mirror of
https://github.com/freedomofpress/dangerzone.git
synced 2025-04-28 18:02:38 +02:00
Add non-blocking read utility
Add a function that can read data from non-blocking fds, which we will used later on to read from standard streams with a timeout.
This commit is contained in:
parent
344d6f7bfa
commit
fea193e935
2 changed files with 153 additions and 1 deletions
|
@ -1,10 +1,12 @@
|
|||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import selectors
|
||||
import string
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional, Self
|
||||
from typing import IO, Optional, Union
|
||||
|
||||
import appdirs
|
||||
|
||||
|
@ -131,3 +133,66 @@ class Stopwatch:
|
|||
|
||||
def stop(self) -> None:
|
||||
self.__exit__()
|
||||
|
||||
|
||||
def nonblocking_read(fd: Union[IO[bytes], int], size: int, timeout: float) -> bytes:
|
||||
"""Opinionated read function for non-blocking fds.
|
||||
|
||||
This function offers a blocking interface for reading non-blocking fds. Unlike
|
||||
the common `os.read()` function, this function accepts a timeout argument as well.
|
||||
|
||||
If the file descriptor has not reached EOF and this function has not read all the
|
||||
requested bytes before the timeout expiration, it will raise a TimeoutError. Note
|
||||
that partial reads do not affect the timeout duration, and thus this function may
|
||||
return a TimeoutError, even if it has read some bytes.
|
||||
|
||||
If the file descriptor has reached EOF, this function may return less than the
|
||||
requested number of bytes, which is the same behavior as `os.read()`.
|
||||
"""
|
||||
if not isinstance(fd, int):
|
||||
fd = fd.fileno()
|
||||
|
||||
# Validate the provided arguments.
|
||||
if os.get_blocking(fd):
|
||||
raise ValueError("Expected a non-blocking file descriptor")
|
||||
if size <= 0:
|
||||
raise ValueError(f"Expected a positive size value (got {size})")
|
||||
if timeout <= 0:
|
||||
raise ValueError(f"Expected a positive timeout value (got {timeout})")
|
||||
|
||||
# Register this file descriptor only for read. Also, start the timer for the
|
||||
# timeout.
|
||||
sel = selectors.DefaultSelector()
|
||||
sel.register(fd, selectors.EVENT_READ)
|
||||
|
||||
buf = b""
|
||||
sw = Stopwatch(timeout)
|
||||
sw.start()
|
||||
|
||||
# Wait on `select()` until:
|
||||
#
|
||||
# 1. The timeout expired. In that case, `select()` will return an empty event ([]).
|
||||
# 2. The file descriptor returns EOF. In that case, `os.read()` will return an empty
|
||||
# buffer ("").
|
||||
# 3. We have read all the bytes we requested.
|
||||
while True:
|
||||
events = sel.select(sw.remaining)
|
||||
if not events:
|
||||
raise TimeoutError(f"Timeout expired while reading {len(buf)}/{size} bytes")
|
||||
|
||||
chunk = os.read(fd, size)
|
||||
buf += chunk
|
||||
if chunk == b"":
|
||||
# EOF
|
||||
break
|
||||
|
||||
# Recalculate the remaining timeout and size arguments.
|
||||
size -= len(chunk)
|
||||
|
||||
assert size >= 0
|
||||
if size == 0:
|
||||
# We have read everything
|
||||
break
|
||||
|
||||
sel.close()
|
||||
return buf
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
import os
|
||||
import platform
|
||||
import selectors
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from dangerzone import util
|
||||
|
||||
|
@ -30,3 +35,85 @@ def test_replace_control_chars(uncommon_text: str, sanitized_text: str) -> None:
|
|||
assert util.replace_control_chars(uncommon_text) == sanitized_text
|
||||
assert util.replace_control_chars("normal text") == "normal text"
|
||||
assert util.replace_control_chars("") == ""
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows", reason="Cannot test non-blocking read on Windows"
|
||||
)
|
||||
def test_nonblocking_read(mocker: MockerFixture) -> None:
|
||||
"""Test that the nonblocking_read() function works properly."""
|
||||
size = 9
|
||||
timeout = 1
|
||||
r, w = os.pipe()
|
||||
|
||||
# Test 1 - Check that invalid arguments (blocking fd, negative size/timeout ) raise
|
||||
# an exception.
|
||||
with pytest.raises(ValueError, match="Expected a non-blocking file descriptor"):
|
||||
util.nonblocking_read(r, size, timeout)
|
||||
|
||||
os.set_blocking(r, False)
|
||||
|
||||
with pytest.raises(ValueError, match="Expected a positive size value"):
|
||||
util.nonblocking_read(r, 0, timeout)
|
||||
|
||||
with pytest.raises(ValueError, match="Expected a positive timeout value"):
|
||||
util.nonblocking_read(r, size, 0)
|
||||
|
||||
# Test 2 - Check that partial reads are retried, for the timeout's duration,
|
||||
# and we never read more than we want.
|
||||
select_spy = mocker.spy(selectors.DefaultSelector, "select")
|
||||
read_spy = mocker.spy(os, "read")
|
||||
|
||||
# Write "1234567890", with a delay of 0.3 seconds.
|
||||
os.write(w, b"12345")
|
||||
|
||||
def write_rest() -> None:
|
||||
time.sleep(0.3)
|
||||
os.write(w, b"67890")
|
||||
|
||||
threading.Thread(target=write_rest).start()
|
||||
|
||||
# Ensure that we receive all the characters, except for the last one ("0"), since it
|
||||
# exceeds the requested size.
|
||||
assert util.nonblocking_read(r, size, timeout) == b"123456789"
|
||||
|
||||
# Ensure that the read/select calls were retried.
|
||||
# FIXME: The following asserts are racy, and assume that a 0.3 second delay will
|
||||
# trigger a re-read. If our tests fail due to it, we should find a smarter way to
|
||||
# test it.
|
||||
assert read_spy.call_count == 2
|
||||
assert read_spy.call_args_list[0].args[1] == 9
|
||||
assert read_spy.call_args_list[1].args[1] == 4
|
||||
assert read_spy.spy_return == b"6789"
|
||||
|
||||
assert select_spy.call_count == 2
|
||||
timeout1 = select_spy.call_args_list[0].args[1]
|
||||
timeout2 = select_spy.call_args_list[1].args[1]
|
||||
assert 1 > timeout1 > timeout2
|
||||
|
||||
# Test 3 - Check that timeouts work, even when we partially read something.
|
||||
select_spy.reset_mock()
|
||||
read_spy.reset_mock()
|
||||
|
||||
# Ensure that the function raises a timeout error.
|
||||
with pytest.raises(TimeoutError):
|
||||
util.nonblocking_read(r, size, 0.1)
|
||||
|
||||
# Ensure that the function has read a single character from the previous write
|
||||
# operation.
|
||||
assert read_spy.call_count == 1
|
||||
assert read_spy.spy_return == b"0"
|
||||
|
||||
# Ensure that the select() method has been called twice, and that the second time it
|
||||
# returned an empty list (meaning that timeout expired).
|
||||
assert select_spy.call_count == 2
|
||||
assert select_spy.spy_return == []
|
||||
timeout1 = select_spy.call_args_list[0].args[1]
|
||||
timeout2 = select_spy.call_args_list[1].args[1]
|
||||
assert 0.1 > timeout1 > timeout2
|
||||
|
||||
# Test 4 - Check that EOF is detected.
|
||||
buf = b"Bye!"
|
||||
os.write(w, buf)
|
||||
os.close(w)
|
||||
assert util.nonblocking_read(r, size, timeout) == buf
|
||||
|
|
Loading…
Reference in a new issue