Add non-blocking read utility

Add a function that can read data from non-blocking fds, which we will
use later on to read from standard streams with a timeout.
This commit is contained in:
Alex Pyrgiotis 2023-09-13 13:19:37 +03:00
parent 344d6f7bfa
commit fea193e935
No known key found for this signature in database
GPG key ID: B6C15EBA0357C9AA
2 changed files with 153 additions and 1 deletions

View file

@ -1,10 +1,12 @@
import os
import pathlib import pathlib
import platform import platform
import selectors
import string import string
import subprocess import subprocess
import sys import sys
import time import time
from typing import Optional, Self from typing import IO, Optional, Union
import appdirs import appdirs
@ -131,3 +133,66 @@ class Stopwatch:
def stop(self) -> None: def stop(self) -> None:
self.__exit__() self.__exit__()
def nonblocking_read(fd: Union[IO[bytes], int], size: int, timeout: float) -> bytes:
    """Opinionated read function for non-blocking fds.

    This function offers a blocking interface for reading non-blocking fds. Unlike
    the common `os.read()` function, this function accepts a timeout argument as well.

    If the file descriptor has not reached EOF and this function has not read all the
    requested bytes before the timeout expiration, it will raise a TimeoutError. Note
    that partial reads do not affect the timeout duration, and thus this function may
    raise a TimeoutError even if it has read some bytes. Those bytes have already been
    consumed from the file descriptor and are not returned to the caller.

    If the file descriptor has reached EOF, this function may return less than the
    requested number of bytes, which is the same behavior as `os.read()`.

    Args:
        fd: A raw file descriptor, or an object with a `fileno()` method. It must be
            in non-blocking mode.
        size: The maximum number of bytes to read. Must be positive.
        timeout: The total time budget for the whole call, in seconds. Must be
            positive.

    Returns:
        Exactly `size` bytes, or fewer if EOF was reached first.

    Raises:
        ValueError: If the fd is in blocking mode, or if `size` or `timeout` is not
            positive.
        TimeoutError: If the timeout expired before `size` bytes were read or EOF was
            reached.
    """
    if not isinstance(fd, int):
        fd = fd.fileno()

    # Validate the provided arguments.
    if os.get_blocking(fd):
        raise ValueError("Expected a non-blocking file descriptor")
    if size <= 0:
        raise ValueError(f"Expected a positive size value (got {size})")
    if timeout <= 0:
        raise ValueError(f"Expected a positive timeout value (got {timeout})")

    buf = bytearray()
    # Track the outstanding byte count separately from `size`, so that the timeout
    # message can always report progress against the originally requested amount.
    remaining = size

    # Use the selector as a context manager, so that it gets closed on every exit
    # path, including the TimeoutError one.
    with selectors.DefaultSelector() as sel:
        # Register this file descriptor only for read. Also, start the timer for the
        # timeout.
        sel.register(fd, selectors.EVENT_READ)
        sw = Stopwatch(timeout)
        sw.start()

        # Wait on `select()` until:
        #
        # 1. The timeout expired. In that case, `select()` will return an empty
        #    event list ([]).
        # 2. The file descriptor returns EOF. In that case, `os.read()` will return
        #    an empty buffer (b"").
        # 3. We have read all the bytes we requested.
        while remaining > 0:
            events = sel.select(sw.remaining)
            if not events:
                raise TimeoutError(
                    f"Timeout expired while reading {len(buf)}/{size} bytes"
                )

            try:
                chunk = os.read(fd, remaining)
            except BlockingIOError:
                # The fd was reported as readable, but no data was actually
                # available (spurious wakeup). Wait again for the remaining time.
                continue

            if not chunk:
                # EOF
                break

            buf += chunk
            remaining -= len(chunk)

    return bytes(buf)

View file

@ -1,8 +1,13 @@
import os
import platform import platform
import selectors
import subprocess import subprocess
import threading
import time
from pathlib import Path from pathlib import Path
import pytest import pytest
from pytest_mock import MockerFixture
from dangerzone import util from dangerzone import util
@ -30,3 +35,85 @@ def test_replace_control_chars(uncommon_text: str, sanitized_text: str) -> None:
assert util.replace_control_chars(uncommon_text) == sanitized_text assert util.replace_control_chars(uncommon_text) == sanitized_text
assert util.replace_control_chars("normal text") == "normal text" assert util.replace_control_chars("normal text") == "normal text"
assert util.replace_control_chars("") == "" assert util.replace_control_chars("") == ""
@pytest.mark.skipif(
    platform.system() == "Windows", reason="Cannot test non-blocking read on Windows"
)
def test_nonblocking_read(mocker: MockerFixture) -> None:
    """Test that the nonblocking_read() function works properly.

    The test drives a single OS pipe through four scenarios, in order: argument
    validation, a partial read that gets retried, a timeout after a partial read, and
    EOF detection. Each scenario relies on the pipe state left by the previous one.
    """
    size = 9
    timeout = 1
    r, w = os.pipe()
    # Track whether the write end is still open, so that the cleanup code below does
    # not close it twice once Test 4 has closed it on purpose.
    write_end_open = True

    try:
        # Test 1 - Check that invalid arguments (blocking fd, non-positive
        # size/timeout) raise an exception.
        with pytest.raises(ValueError, match="Expected a non-blocking file descriptor"):
            util.nonblocking_read(r, size, timeout)

        os.set_blocking(r, False)

        with pytest.raises(ValueError, match="Expected a positive size value"):
            util.nonblocking_read(r, 0, timeout)

        with pytest.raises(ValueError, match="Expected a positive timeout value"):
            util.nonblocking_read(r, size, 0)

        # Test 2 - Check that partial reads are retried, for the timeout's duration,
        # and we never read more than we want.
        select_spy = mocker.spy(selectors.DefaultSelector, "select")
        read_spy = mocker.spy(os, "read")

        # Write "1234567890", with a delay of 0.3 seconds.
        os.write(w, b"12345")

        def write_rest() -> None:
            time.sleep(0.3)
            os.write(w, b"67890")

        writer = threading.Thread(target=write_rest)
        writer.start()

        try:
            # Ensure that we receive all the characters, except for the last one
            # ("0"), since it exceeds the requested size.
            assert util.nonblocking_read(r, size, timeout) == b"123456789"
        finally:
            # Always wait for the writer thread, so that it never outlives the test
            # nor writes to a pipe that has already been closed.
            writer.join()

        # Ensure that the read/select calls were retried.
        # FIXME: The following asserts are racy, and assume that a 0.3 second delay
        # will trigger a re-read. If our tests fail due to it, we should find a
        # smarter way to test it.
        assert read_spy.call_count == 2
        assert read_spy.call_args_list[0].args[1] == 9
        assert read_spy.call_args_list[1].args[1] == 4
        assert read_spy.spy_return == b"6789"

        assert select_spy.call_count == 2
        timeout1 = select_spy.call_args_list[0].args[1]
        timeout2 = select_spy.call_args_list[1].args[1]
        assert 1 > timeout1 > timeout2

        # Test 3 - Check that timeouts work, even when we partially read something.
        select_spy.reset_mock()
        read_spy.reset_mock()

        # Ensure that the function raises a timeout error.
        with pytest.raises(TimeoutError):
            util.nonblocking_read(r, size, 0.1)

        # Ensure that the function has read a single character from the previous
        # write operation.
        assert read_spy.call_count == 1
        assert read_spy.spy_return == b"0"

        # Ensure that the select() method has been called twice, and that the second
        # time it returned an empty list (meaning that timeout expired).
        assert select_spy.call_count == 2
        assert select_spy.spy_return == []
        timeout1 = select_spy.call_args_list[0].args[1]
        timeout2 = select_spy.call_args_list[1].args[1]
        assert 0.1 > timeout1 > timeout2

        # Test 4 - Check that EOF is detected.
        buf = b"Bye!"
        os.write(w, buf)
        os.close(w)
        write_end_open = False
        assert util.nonblocking_read(r, size, timeout) == buf
    finally:
        # Release the pipe even when an assertion above fails.
        os.close(r)
        if write_end_open:
            os.close(w)