diff --git a/dangerzone/util.py b/dangerzone/util.py
index bdd6452..e19a174 100644
--- a/dangerzone/util.py
+++ b/dangerzone/util.py
@@ -1,10 +1,12 @@
+import os
 import pathlib
 import platform
+import selectors
 import string
 import subprocess
 import sys
 import time
-from typing import Optional, Self
+from typing import IO, Optional, Union
 
 import appdirs
 
@@ -131,3 +133,69 @@ class Stopwatch:
 
     def stop(self) -> None:
         self.__exit__()
+
+
+def nonblocking_read(fd: Union[IO[bytes], int], size: int, timeout: float) -> bytes:
+    """Opinionated read function for non-blocking fds.
+
+    This function offers a blocking interface for reading non-blocking fds. Unlike
+    the common `os.read()` function, this function accepts a timeout argument as well.
+
+    If the file descriptor has not reached EOF and this function has not read all the
+    requested bytes before the timeout expiration, it will raise a TimeoutError. Note
+    that partial reads do not affect the timeout duration, and thus this function may
+    raise a TimeoutError, even if it has read some bytes.
+
+    If the file descriptor has reached EOF, this function may return less than the
+    requested number of bytes, which is the same behavior as `os.read()`.
+    """
+    if not isinstance(fd, int):
+        fd = fd.fileno()
+
+    # Validate the provided arguments.
+    if os.get_blocking(fd):
+        raise ValueError("Expected a non-blocking file descriptor")
+    if size <= 0:
+        raise ValueError(f"Expected a positive size value (got {size})")
+    if timeout <= 0:
+        raise ValueError(f"Expected a positive timeout value (got {timeout})")
+
+    # Register this file descriptor only for read. The selector is used as a context
+    # manager, so that it gets closed even if we raise a TimeoutError below.
+    with selectors.DefaultSelector() as sel:
+        sel.register(fd, selectors.EVENT_READ)
+
+        buf = b""
+        # Start the timer for the timeout.
+        sw = Stopwatch(timeout)
+        sw.start()
+
+        # Wait on `select()` until:
+        #
+        # 1. The timeout expired. In that case, `select()` will return an empty
+        #    event ([]).
+        # 2. The file descriptor returns EOF. In that case, `os.read()` will return
+        #    an empty buffer ("").
+        # 3. We have read all the bytes we requested.
+        while True:
+            events = sel.select(sw.remaining)
+            if not events:
+                raise TimeoutError(
+                    f"Timeout expired while reading {len(buf)}/{size} bytes"
+                )
+
+            chunk = os.read(fd, size)
+            buf += chunk
+            if chunk == b"":
+                # EOF
+                break
+
+            # Recalculate the remaining timeout and size arguments.
+            size -= len(chunk)
+
+            assert size >= 0
+            if size == 0:
+                # We have read everything
+                break
+
+    return buf
diff --git a/tests/test_util.py b/tests/test_util.py
index 04eecb2..b0beff0 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -1,8 +1,13 @@
+import os
 import platform
+import selectors
 import subprocess
+import threading
+import time
 from pathlib import Path
 
 import pytest
+from pytest_mock import MockerFixture
 
 from dangerzone import util
 
@@ -30,3 +35,88 @@ def test_replace_control_chars(uncommon_text: str, sanitized_text: str) -> None:
     assert util.replace_control_chars(uncommon_text) == sanitized_text
     assert util.replace_control_chars("normal text") == "normal text"
     assert util.replace_control_chars("") == ""
+
+
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="Cannot test non-blocking read on Windows"
+)
+def test_nonblocking_read(mocker: MockerFixture) -> None:
+    """Test that the nonblocking_read() function works properly."""
+    size = 9
+    timeout = 1
+    r, w = os.pipe()
+
+    # Test 1 - Check that invalid arguments (blocking fd, non-positive size/timeout)
+    # raise an exception.
+    with pytest.raises(ValueError, match="Expected a non-blocking file descriptor"):
+        util.nonblocking_read(r, size, timeout)
+
+    os.set_blocking(r, False)
+
+    with pytest.raises(ValueError, match="Expected a positive size value"):
+        util.nonblocking_read(r, 0, timeout)
+
+    with pytest.raises(ValueError, match="Expected a positive timeout value"):
+        util.nonblocking_read(r, size, 0)
+
+    # Test 2 - Check that partial reads are retried, for the timeout's duration,
+    # and we never read more than we want.
+    select_spy = mocker.spy(selectors.DefaultSelector, "select")
+    read_spy = mocker.spy(os, "read")
+
+    # Write "1234567890", with a delay of 0.3 seconds.
+    os.write(w, b"12345")
+
+    def write_rest() -> None:
+        time.sleep(0.3)
+        os.write(w, b"67890")
+
+    writer = threading.Thread(target=write_rest)
+    writer.start()
+
+    # Ensure that we receive all the characters, except for the last one ("0"), since it
+    # exceeds the requested size.
+    assert util.nonblocking_read(r, size, timeout) == b"123456789"
+    writer.join()
+
+    # Ensure that the read/select calls were retried.
+    # FIXME: The following asserts are racy, and assume that a 0.3 second delay will
+    # trigger a re-read. If our tests fail due to it, we should find a smarter way to
+    # test it.
+    assert read_spy.call_count == 2
+    assert read_spy.call_args_list[0].args[1] == 9
+    assert read_spy.call_args_list[1].args[1] == 4
+    assert read_spy.spy_return == b"6789"
+
+    assert select_spy.call_count == 2
+    timeout1 = select_spy.call_args_list[0].args[1]
+    timeout2 = select_spy.call_args_list[1].args[1]
+    assert 1 > timeout1 > timeout2
+
+    # Test 3 - Check that timeouts work, even when we partially read something.
+    select_spy.reset_mock()
+    read_spy.reset_mock()
+
+    # Ensure that the function raises a timeout error.
+    with pytest.raises(TimeoutError):
+        util.nonblocking_read(r, size, 0.1)
+
+    # Ensure that the function has read a single character from the previous write
+    # operation.
+    assert read_spy.call_count == 1
+    assert read_spy.spy_return == b"0"
+
+    # Ensure that the select() method has been called twice, and that the second time it
+    # returned an empty list (meaning that timeout expired).
+    assert select_spy.call_count == 2
+    assert select_spy.spy_return == []
+    timeout1 = select_spy.call_args_list[0].args[1]
+    timeout2 = select_spy.call_args_list[1].args[1]
+    assert 0.1 > timeout1 > timeout2
+
+    # Test 4 - Check that EOF is detected.
+    buf = b"Bye!"
+    os.write(w, buf)
+    os.close(w)
+    assert util.nonblocking_read(r, size, timeout) == buf
+    os.close(r)