From 5d225e6e65a7aded3142d07549e722e3c7420691 Mon Sep 17 00:00:00 2001 From: Lucas Saavedra Vaz <32426024+lucasssvaz@users.noreply.github.com> Date: Fri, 1 May 2026 13:16:32 -0300 Subject: [PATCH] fix: Add lock for simultaneous printing --- .../pytest_embedded/dut_factory.py | 31 ++- pytest-embedded/pytest_embedded/plugin.py | 25 ++- pytest-embedded/tests/test_base.py | 189 ++++++++++++++++++ 3 files changed, 238 insertions(+), 7 deletions(-) diff --git a/pytest-embedded/pytest_embedded/dut_factory.py b/pytest-embedded/pytest_embedded/dut_factory.py index 0ca939b1..1b11e301 100644 --- a/pytest-embedded/pytest_embedded/dut_factory.py +++ b/pytest-embedded/pytest_embedded/dut_factory.py @@ -1,3 +1,4 @@ +import contextlib import datetime import gc import io @@ -42,8 +43,19 @@ def _drop_none_kwargs(kwargs: dict[t.Any, t.Any]): PARAMETRIZED_FIXTURES_CACHE = {} -def _listen(q: MessageQueue, filepath: str, with_timestamp: bool = True, count: int = 1, total: int = 1) -> None: +_STDOUT_LOCK = None + + +def set_stdout_lock(lock) -> None: + global _STDOUT_LOCK + _STDOUT_LOCK = lock + + +def _listen( + q: MessageQueue, filepath: str, with_timestamp: bool = True, count: int = 1, total: int = 1, _stdout_lock=None +) -> None: shall_add_prefix = True + _pending = '' while True: msg = q.get() if not msg: @@ -71,20 +83,25 @@ def _listen(q: MessageQueue, filepath: str, with_timestamp: bool = True, count: if _s.endswith('\n'): # complete line shall_add_prefix = True _s = _s[:-1].replace('\n', '\n' + prefix) + '\n' + with _stdout_lock if _stdout_lock else contextlib.nullcontext(): + _stdout.write(_pending + _s) + _stdout.flush() + _pending = '' else: shall_add_prefix = False _s = _s.replace('\n', '\n' + prefix) - - _stdout.write(_s) - _stdout.flush() + _pending += _s -def _listener_gn(msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total) -> multiprocessing.Process: +def _listener_gn( + msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total, _stdout_lock=None +) -> multiprocessing.Process: os.makedirs(os.path.dirname(_pexpect_logfile), exist_ok=True) kwargs = { 'with_timestamp': with_timestamp, 'count': dut_index, 'total': dut_total, + '_stdout_lock': _stdout_lock, } return _ctx.Process( @@ -753,7 +770,9 @@ def create( ) logging.debug('You can get your custom DUT log file at the following path: %s.', _pexpect_logfile) - _listener = _listener_gn(msg_queue, _pexpect_logfile, True, DUT_GLOBAL_INDEX, DUT_GLOBAL_INDEX + 1) + _listener = _listener_gn( + msg_queue, _pexpect_logfile, True, DUT_GLOBAL_INDEX, DUT_GLOBAL_INDEX + 1, _stdout_lock=_STDOUT_LOCK + ) layout.append(_listener) _pexpect_fr = _pexpect_fr_gn(_pexpect_logfile, _listener) diff --git a/pytest-embedded/pytest_embedded/plugin.py b/pytest-embedded/pytest_embedded/plugin.py index 3bd79139..38d54503 100644 --- a/pytest-embedded/pytest_embedded/plugin.py +++ b/pytest-embedded/pytest_embedded/plugin.py @@ -30,6 +30,7 @@ from .dut import Dut from .dut_factory import ( DutFactory, + _ctx, _fixture_classes_and_options_fn, _listener_gn, _pexpect_fr_gn, @@ -41,6 +42,7 @@ qemu_gn, serial_gn, set_parametrized_fixtures_cache, + set_stdout_lock, wokwi_gn, ) from .log import MessageQueue, MessageQueueManager, PexpectProcess @@ -695,6 +697,23 @@ def _mp_manager(): manager.shutdown() +@pytest.fixture(scope='session', autouse=True) +def _stdout_lock(): + """ + A session-scoped multiprocessing lock used to serialize stdout writes across + all DUT listener processes, preventing garbled output when multiple DUTs + print to stdout simultaneously. + + It is marked ``autouse=True`` so that the lock is created and registered + globally (via ``set_stdout_lock``) before any DUT fixture is instantiated, + ensuring every listener process receives a valid lock reference regardless + of test ordering. + """ + lock = _ctx.Lock() + set_stdout_lock(lock) + yield lock + + @pytest.fixture def test_case_tempdir(test_case_name: str, session_tempdir: str) -> str: """Function scoped temp dir for pytest-embedded""" @@ -746,13 +765,17 @@ def with_timestamp(request: FixtureRequest) -> bool: @pytest.fixture @multi_dut_generator_fixture -def _listener(msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total) -> multiprocessing.Process: +def _listener( + msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total, _stdout_lock +) -> multiprocessing.Process: """ The listener would create a `_listen` process. The `_listen` process would get the string from the message queue, and do two things together: 1. print the string to `sys.stdout` 2. write the string to `_pexpect_logfile` + + A shared lock (_stdout_lock) is used to prevent interleaved output when multiple DUTs print simultaneously. """ return _listener_gn(**locals()) diff --git a/pytest-embedded/tests/test_base.py b/pytest-embedded/tests/test_base.py index 9df1c51f..89b1cab2 100644 --- a/pytest-embedded/tests/test_base.py +++ b/pytest-embedded/tests/test_base.py @@ -821,3 +821,192 @@ def test_metric_no_path(log_metric): result = pytester.runpytest() result.assert_outcomes(passed=1) + + +# --------------------------------------------------------------------------- +# Tests for the stdout-lock feature (_stdout_lock / set_stdout_lock / _listen) +# --------------------------------------------------------------------------- + + +def test_set_stdout_lock(): + """set_stdout_lock updates the module-level _STDOUT_LOCK variable.""" + import pytest_embedded.dut_factory as m + from pytest_embedded.dut_factory import set_stdout_lock + + original = m._STDOUT_LOCK + try: + sentinel = object() + set_stdout_lock(sentinel) + assert m._STDOUT_LOCK is sentinel + + set_stdout_lock(None) + assert m._STDOUT_LOCK is None + finally: + set_stdout_lock(original) + + +def test_listen_no_data_loss_without_lock(tmp_path): + """_listen writes every queued message to the logfile when no lock is used.""" + import time + + from pytest_embedded.dut_factory import _ctx, _listen + from pytest_embedded.log import MessageQueue + + logfile = str(tmp_path / 'test.log') + q = MessageQueue() + messages = [f'line_{i}\n'.encode() for i in range(20)] + + p = _ctx.Process(target=_listen, args=(q, logfile), kwargs={'with_timestamp': False}) + p.start() + try: + for msg in messages: + q.put(msg) + + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + try: + content = open(logfile, 'rb').read() + if all(msg in content for msg in messages): + break + except OSError: + pass + time.sleep(0.05) + finally: + p.terminate() + p.join(timeout=5) + assert p.exitcode is not None, 'listener process did not terminate' + + content = open(logfile, 'rb').read() + for msg in messages: + assert msg in content, f'{msg!r} missing from logfile' + + +def test_listen_no_data_loss_with_lock(tmp_path): + """_listen writes every queued message to the logfile when a Manager lock is used.""" + import multiprocessing + import time + + from pytest_embedded.dut_factory import _ctx, _listen + from pytest_embedded.log import MessageQueue + + logfile = str(tmp_path / 'test.log') + q = MessageQueue() + messages = [f'line_{i}\n'.encode() for i in range(20)] + + manager = multiprocessing.Manager() + try: + lock = manager.Lock() + p = _ctx.Process( + target=_listen, + args=(q, logfile), + kwargs={'with_timestamp': False, '_stdout_lock': lock}, + ) + p.start() + try: + for msg in messages: + q.put(msg) + + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + try: + content = open(logfile, 'rb').read() + if all(msg in content for msg in messages): + break + except OSError: + pass + time.sleep(0.05) + finally: + p.terminate() + p.join(timeout=5) + assert p.exitcode is not None, 'listener process did not terminate' + finally: + manager.shutdown() + + content = open(logfile, 'rb').read() + for msg in messages: + assert msg in content, f'{msg!r} missing from logfile' + + +def test_stdout_lock_concurrent_no_data_loss(tmp_path): + """Two concurrent _listen processes sharing a Manager lock both preserve all data.""" + import multiprocessing + import time + + from pytest_embedded.dut_factory import _ctx, _listen + from pytest_embedded.log import MessageQueue + + logfile0 = str(tmp_path / 'dut0.log') + logfile1 = str(tmp_path / 'dut1.log') + q0 = MessageQueue() + q1 = MessageQueue() + messages0 = [f'dut0_line_{i}\n'.encode() for i in range(20)] + messages1 = [f'dut1_line_{i}\n'.encode() for i in range(20)] + + manager = multiprocessing.Manager() + try: + lock = manager.Lock() + p0 = _ctx.Process( + target=_listen, + args=(q0, logfile0), + kwargs={'with_timestamp': False, 'count': 1, 'total': 2, '_stdout_lock': lock}, + ) + p1 = _ctx.Process( + target=_listen, + args=(q1, logfile1), + kwargs={'with_timestamp': False, 'count': 2, 'total': 2, '_stdout_lock': lock}, + ) + p0.start() + p1.start() + try: + # interleave writes from both DUTs to maximize lock contention + for msg0, msg1 in zip(messages0, messages1): + q0.put(msg0) + q1.put(msg1) + + deadline = time.monotonic() + 15 + while time.monotonic() < deadline: + try: + c0 = open(logfile0, 'rb').read() + c1 = open(logfile1, 'rb').read() + if all(m in c0 for m in messages0) and all(m in c1 for m in messages1): + break + except OSError: + pass + time.sleep(0.05) + finally: + p0.terminate() + p1.terminate() + p0.join(timeout=5) + p1.join(timeout=5) + assert p0.exitcode is not None, 'dut0 listener process did not terminate' + assert p1.exitcode is not None, 'dut1 listener process did not terminate' + finally: + manager.shutdown() + + c0 = open(logfile0, 'rb').read() + c1 = open(logfile1, 'rb').read() + for msg in messages0: + assert msg in c0, f'{msg!r} missing from dut0 logfile' + for msg in messages1: + assert msg in c1, f'{msg!r} missing from dut1 logfile' + + +def test_multi_dut_no_data_loss(testdir): + """In a 2-DUT test, all messages written by each DUT can be expected - nothing is dropped.""" + testdir.makepyfile(r""" + import pytest + + @pytest.mark.parametrize('count', [2], indirect=True) + def test_concurrent_dut_writes(dut): + n = 15 + for i in range(n): + dut[0].write(f'dut0_msg_{i}') + dut[1].write(f'dut1_msg_{i}') + + for i in range(n): + dut[0].expect_exact(f'dut0_msg_{i}') + dut[1].expect_exact(f'dut1_msg_{i}') + """) + + result = testdir.runpytest() + result.assert_outcomes(passed=1)