Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions pytest-embedded/pytest_embedded/dut_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import datetime
import gc
import io
Expand Down Expand Up @@ -42,8 +43,19 @@ def _drop_none_kwargs(kwargs: dict[t.Any, t.Any]):
PARAMETRIZED_FIXTURES_CACHE = {}


def _listen(q: MessageQueue, filepath: str, with_timestamp: bool = True, count: int = 1, total: int = 1) -> None:
_STDOUT_LOCK = None


def set_stdout_lock(lock) -> None:
global _STDOUT_LOCK
_STDOUT_LOCK = lock


def _listen(
q: MessageQueue, filepath: str, with_timestamp: bool = True, count: int = 1, total: int = 1, _stdout_lock=None
) -> None:
shall_add_prefix = True
_pending = ''
while True:
msg = q.get()
if not msg:
Expand Down Expand Up @@ -71,20 +83,25 @@ def _listen(q: MessageQueue, filepath: str, with_timestamp: bool = True, count:
if _s.endswith('\n'): # complete line
shall_add_prefix = True
_s = _s[:-1].replace('\n', '\n' + prefix) + '\n'
with _stdout_lock if _stdout_lock else contextlib.nullcontext():
_stdout.write(_pending + _s)
_stdout.flush()
_pending = ''
else:
shall_add_prefix = False
_s = _s.replace('\n', '\n' + prefix)

_stdout.write(_s)
_stdout.flush()
_pending += _s


def _listener_gn(msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total) -> multiprocessing.Process:
def _listener_gn(
msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total, _stdout_lock=None
) -> multiprocessing.Process:
os.makedirs(os.path.dirname(_pexpect_logfile), exist_ok=True)
kwargs = {
'with_timestamp': with_timestamp,
'count': dut_index,
'total': dut_total,
'_stdout_lock': _stdout_lock,
}

return _ctx.Process(
Expand Down Expand Up @@ -753,7 +770,9 @@ def create(
)
logging.debug('You can get your custom DUT log file at the following path: %s.', _pexpect_logfile)

_listener = _listener_gn(msg_queue, _pexpect_logfile, True, DUT_GLOBAL_INDEX, DUT_GLOBAL_INDEX + 1)
_listener = _listener_gn(
msg_queue, _pexpect_logfile, True, DUT_GLOBAL_INDEX, DUT_GLOBAL_INDEX + 1, _stdout_lock=_STDOUT_LOCK
)
layout.append(_listener)

_pexpect_fr = _pexpect_fr_gn(_pexpect_logfile, _listener)
Expand Down
25 changes: 24 additions & 1 deletion pytest-embedded/pytest_embedded/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .dut import Dut
from .dut_factory import (
DutFactory,
_ctx,
_fixture_classes_and_options_fn,
_listener_gn,
_pexpect_fr_gn,
Expand All @@ -41,6 +42,7 @@
qemu_gn,
serial_gn,
set_parametrized_fixtures_cache,
set_stdout_lock,
wokwi_gn,
)
from .log import MessageQueue, MessageQueueManager, PexpectProcess
Expand Down Expand Up @@ -695,6 +697,23 @@ def _mp_manager():
manager.shutdown()


@pytest.fixture(scope='session', autouse=True)
def _stdout_lock():
"""
A session-scoped multiprocessing lock used to serialize stdout writes across
all DUT listener processes, preventing garbled output when multiple DUTs
print to stdout simultaneously.

It is marked ``autouse=True`` so that the lock is created and registered
globally (via ``set_stdout_lock``) before any DUT fixture is instantiated,
ensuring every listener process receives a valid lock reference regardless
of test ordering.
"""
lock = _ctx.Lock()
set_stdout_lock(lock)
yield lock


@pytest.fixture
def test_case_tempdir(test_case_name: str, session_tempdir: str) -> str:
"""Function scoped temp dir for pytest-embedded"""
Expand Down Expand Up @@ -746,13 +765,17 @@ def with_timestamp(request: FixtureRequest) -> bool:

@pytest.fixture
@multi_dut_generator_fixture
def _listener(msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total) -> multiprocessing.Process:
def _listener(
msg_queue, _pexpect_logfile, with_timestamp, dut_index, dut_total, _stdout_lock
) -> multiprocessing.Process:
"""
The listener would create a `_listen` process. The `_listen` process would get the string from the message queue,
and do two things together:

1. print the string to `sys.stdout`
2. write the string to `_pexpect_logfile`

A shared lock (_stdout_lock) is used to prevent interleaved output when multiple DUTs print simultaneously.
"""
return _listener_gn(**locals())

Expand Down
189 changes: 189 additions & 0 deletions pytest-embedded/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,3 +821,192 @@ def test_metric_no_path(log_metric):

result = pytester.runpytest()
result.assert_outcomes(passed=1)


# ---------------------------------------------------------------------------
# Tests for the stdout-lock feature (_stdout_lock / set_stdout_lock / _listen)
# ---------------------------------------------------------------------------


def test_set_stdout_lock():
"""set_stdout_lock updates the module-level _STDOUT_LOCK variable."""
import pytest_embedded.dut_factory as m
from pytest_embedded.dut_factory import set_stdout_lock

original = m._STDOUT_LOCK
try:
sentinel = object()
set_stdout_lock(sentinel)
assert m._STDOUT_LOCK is sentinel

set_stdout_lock(None)
assert m._STDOUT_LOCK is None
finally:
set_stdout_lock(original)


def test_listen_no_data_loss_without_lock(tmp_path):
"""_listen writes every queued message to the logfile when no lock is used."""
import time

from pytest_embedded.dut_factory import _ctx, _listen
from pytest_embedded.log import MessageQueue

logfile = str(tmp_path / 'test.log')
q = MessageQueue()
messages = [f'line_{i}\n'.encode() for i in range(20)]

p = _ctx.Process(target=_listen, args=(q, logfile), kwargs={'with_timestamp': False})
p.start()
try:
for msg in messages:
q.put(msg)

deadline = time.monotonic() + 10
while time.monotonic() < deadline:
try:
content = open(logfile, 'rb').read()
if all(msg in content for msg in messages):
break
except OSError:
pass
time.sleep(0.05)
finally:
p.terminate()
p.join(timeout=5)
assert p.exitcode is not None, 'listener process did not terminate'

content = open(logfile, 'rb').read()
for msg in messages:
assert msg in content, f'{msg!r} missing from logfile'


def test_listen_no_data_loss_with_lock(tmp_path):
"""_listen writes every queued message to the logfile when a Manager lock is used."""
import multiprocessing
import time

from pytest_embedded.dut_factory import _ctx, _listen
from pytest_embedded.log import MessageQueue

logfile = str(tmp_path / 'test.log')
q = MessageQueue()
messages = [f'line_{i}\n'.encode() for i in range(20)]

manager = multiprocessing.Manager()
try:
lock = manager.Lock()
p = _ctx.Process(
target=_listen,
args=(q, logfile),
kwargs={'with_timestamp': False, '_stdout_lock': lock},
)
p.start()
try:
for msg in messages:
q.put(msg)

deadline = time.monotonic() + 10
while time.monotonic() < deadline:
try:
content = open(logfile, 'rb').read()
if all(msg in content for msg in messages):
break
except OSError:
pass
time.sleep(0.05)
finally:
p.terminate()
p.join(timeout=5)
assert p.exitcode is not None, 'listener process did not terminate'
finally:
manager.shutdown()

content = open(logfile, 'rb').read()
for msg in messages:
assert msg in content, f'{msg!r} missing from logfile'


def test_stdout_lock_concurrent_no_data_loss(tmp_path):
"""Two concurrent _listen processes sharing a Manager lock both preserve all data."""
import multiprocessing
import time

from pytest_embedded.dut_factory import _ctx, _listen
from pytest_embedded.log import MessageQueue

logfile0 = str(tmp_path / 'dut0.log')
logfile1 = str(tmp_path / 'dut1.log')
q0 = MessageQueue()
q1 = MessageQueue()
messages0 = [f'dut0_line_{i}\n'.encode() for i in range(20)]
messages1 = [f'dut1_line_{i}\n'.encode() for i in range(20)]

manager = multiprocessing.Manager()
try:
lock = manager.Lock()
p0 = _ctx.Process(
target=_listen,
args=(q0, logfile0),
kwargs={'with_timestamp': False, 'count': 1, 'total': 2, '_stdout_lock': lock},
)
p1 = _ctx.Process(
target=_listen,
args=(q1, logfile1),
kwargs={'with_timestamp': False, 'count': 2, 'total': 2, '_stdout_lock': lock},
)
p0.start()
p1.start()
try:
# interleave writes from both DUTs to maximize lock contention
for msg0, msg1 in zip(messages0, messages1):
q0.put(msg0)
q1.put(msg1)

deadline = time.monotonic() + 15
while time.monotonic() < deadline:
try:
c0 = open(logfile0, 'rb').read()
c1 = open(logfile1, 'rb').read()
if all(m in c0 for m in messages0) and all(m in c1 for m in messages1):
break
except OSError:
pass
time.sleep(0.05)
finally:
p0.terminate()
p1.terminate()
p0.join(timeout=5)
p1.join(timeout=5)
assert p0.exitcode is not None, 'dut0 listener process did not terminate'
assert p1.exitcode is not None, 'dut1 listener process did not terminate'
finally:
manager.shutdown()

c0 = open(logfile0, 'rb').read()
c1 = open(logfile1, 'rb').read()
for msg in messages0:
assert msg in c0, f'{msg!r} missing from dut0 logfile'
for msg in messages1:
assert msg in c1, f'{msg!r} missing from dut1 logfile'


def test_multi_dut_no_data_loss(testdir):
"""In a 2-DUT test, all messages written by each DUT can be expected - nothing is dropped."""
testdir.makepyfile(r"""
import pytest

@pytest.mark.parametrize('count', [2], indirect=True)
def test_concurrent_dut_writes(dut):
n = 15
for i in range(n):
dut[0].write(f'dut0_msg_{i}')
dut[1].write(f'dut1_msg_{i}')

for i in range(n):
dut[0].expect_exact(f'dut0_msg_{i}')
dut[1].expect_exact(f'dut1_msg_{i}')
""")

result = testdir.runpytest()
result.assert_outcomes(passed=1)