diff --git a/craft_cli/messages.py b/craft_cli/messages.py index 9ff2935c..98177419 100644 --- a/craft_cli/messages.py +++ b/craft_cli/messages.py @@ -40,7 +40,7 @@ import platformdirs -from craft_cli import errors +from craft_cli import errors, printer from craft_cli.printer import Printer if TYPE_CHECKING: @@ -344,6 +344,7 @@ def __exit__( exc_tb: TracebackType | None, ) -> Literal[False]: self.pipe_reader.stop() + printer.reset_terminal_style(self.pipe_reader.stream) return False # do not consume any exception diff --git a/craft_cli/printer.py b/craft_cli/printer.py index 59b8a79b..f313c993 100644 --- a/craft_cli/printer.py +++ b/craft_cli/printer.py @@ -29,6 +29,7 @@ import time import weakref from collections.abc import Callable +from contextlib import suppress from dataclasses import dataclass, field from datetime import datetime from functools import lru_cache @@ -50,6 +51,21 @@ ANSI_CLEAR_LINE_TO_END = "\x1b[K" # ANSI escape code to clear the rest of the line. ANSI_HIDE_CURSOR = "\x1b[?25l" ANSI_SHOW_CURSOR = "\x1b[?25h" +ANSI_RESET = "\x1b[0m" + + +def _safe_print(*args: Any, **kwargs: Any) -> None: + """Print to a stream, ignoring BrokenPipeError from downstream consumers.""" + with suppress(BrokenPipeError): + print(*args, **kwargs) + + +def reset_terminal_style(stream: TextIO | None) -> None: + """Reset ANSI terminal style on the given stream if supported.""" + if stream is None: + return + if _stream_is_terminal(stream) and _supports_ansi_escape_sequences(): + _safe_print(ANSI_RESET, end="", flush=True, file=stream) @dataclass diff --git a/tests/integration/test_messages_integration.py b/tests/integration/test_messages_integration.py index d9074df9..3bb60a7e 100644 --- a/tests/integration/test_messages_integration.py +++ b/tests/integration/test_messages_integration.py @@ -84,6 +84,7 @@ def remove_control_characters(string: str) -> str: string.replace(printer.ANSI_CLEAR_LINE_TO_END, "") .replace(printer.ANSI_HIDE_CURSOR, "") .replace(printer.ANSI_SHOW_CURSOR, "") + .replace(printer.ANSI_RESET, "") ) diff --git a/tests/unit/test_messages_stream_cm.py b/tests/unit/test_messages_stream_cm.py index 52187af7..de0f13b1 100644 --- a/tests/unit/test_messages_stream_cm.py +++ b/tests/unit/test_messages_stream_cm.py @@ -183,6 +183,56 @@ def test_streamcm_dont_consume_exceptions(recording_printer): raise ValueError +@pytest.mark.parametrize("stream", [sys.stdout, sys.stderr]) +def test_streamcm_exit_resets_terminal_style(monkeypatch, recording_printer, stream): + """Closing a stream resets ANSI terminal state.""" + calls = [] + + def fake_reset_terminal_style(target_stream): + calls.append(target_stream) + + monkeypatch.setattr(printer, "reset_terminal_style", fake_reset_terminal_style) + + scm = _StreamContextManager( + recording_printer, + "initial text", + stream=stream, + use_timestamp=False, + ephemeral_mode=False, + ) + + with scm: + pass + + assert calls == [stream] + + +@pytest.mark.parametrize("stream", [sys.stdout, sys.stderr]) +def test_streamcm_exit_always_delegates_terminal_reset( + monkeypatch, recording_printer, stream +): + """Closing a stream delegates terminal reset to the printer helper.""" + calls = [] + + def fake_reset_terminal_style(target_stream): + calls.append(target_stream) + + monkeypatch.setattr(printer, "reset_terminal_style", fake_reset_terminal_style) + + scm = _StreamContextManager( + recording_printer, + "initial text", + stream=stream, + use_timestamp=False, + ephemeral_mode=False, + ) + + with scm: + pass + + assert calls == [stream] + + # -- tests for the pipe reader