diff --git a/python/packages/jumpstarter-cli/conftest.py b/python/packages/jumpstarter-cli/conftest.py new file mode 100644 index 00000000..af7e4799 --- /dev/null +++ b/python/packages/jumpstarter-cli/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture +def anyio_backend(): + return "asyncio" diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py b/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py index 764c12e0..d69449c2 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py @@ -13,6 +13,8 @@ from jumpstarter_cli_common.exceptions import find_exception_in_group, handle_exceptions_with_reauthentication from jumpstarter_cli_common.oidc import ( TOKEN_EXPIRY_WARNING_SECONDS, + Config, + decode_jwt_issuer, format_duration, get_token_remaining_seconds, ) @@ -32,6 +34,10 @@ logger = logging.getLogger(__name__) +# Refresh token when less than this many seconds remain +_TOKEN_REFRESH_THRESHOLD_SECONDS = 120 + + def _run_shell_only(lease, config, command, path: str) -> int: """Run just the shell command without log streaming.""" @@ -60,36 +66,192 @@ def _warn_about_expired_token(lease_name: str, selector: str) -> None: click.echo(click.style(f"To reconnect: JMP_LEASE={lease_name} jmp shell", fg="cyan")) -async def _monitor_token_expiry(config, cancel_scope) -> None: - """Monitor token expiry and warn user.""" +async def _update_lease_channel(config, lease) -> None: + """Update the lease's gRPC channel with the current config credentials.""" + if lease is not None: + new_channel = await config.channel() + lease.refresh_channel(new_channel) + + +async def _try_refresh_token(config, lease) -> bool: + """Attempt to refresh the token and update the lease channel. + + Returns True if refresh succeeded, False otherwise. 
+ """ + refresh_token = getattr(config, "refresh_token", None) + if not refresh_token: + return False + + old_token = config.token + old_refresh_token = config.refresh_token + try: + issuer = decode_jwt_issuer(config.token) + oidc = Config( + issuer=issuer, + client_id="jumpstarter-cli", + offline_access=True, + ) + + tokens = await oidc.refresh_token_grant(refresh_token) + config.token = tokens["access_token"] + new_refresh_token = tokens.get("refresh_token") + if new_refresh_token is not None: + config.refresh_token = new_refresh_token + + # Update the lease channel first (critical for the running session) + await _update_lease_channel(config, lease) + + # Persist to disk (best-effort, uses original config path) + try: + ClientConfigV1Alpha1.save(config, path=config.path) + except Exception as e: + logger.warning("Failed to save refreshed token to disk: %s", e) + + return True + except Exception as e: + # Restore old token so the monitor doesn't think we succeeded + config.token = old_token + config.refresh_token = old_refresh_token + logger.debug("Token refresh failed: %s", e) + return False + + +async def _try_reload_token_from_disk(config, lease) -> bool: + """Check if the config on disk has a newer/valid token (e.g. from 'jmp login'). + + If a valid token is found on disk, updates the in-memory config and lease channel. + Returns True if a valid token was loaded, False otherwise. 
+ """ + config_path = getattr(config, "path", None) + if not config_path: + return False + + old_token = config.token + old_refresh_token = config.refresh_token + try: + disk_config = ClientConfigV1Alpha1.from_file(config_path) + disk_token = getattr(disk_config, "token", None) + if not disk_token or disk_token == config.token: + return False + + # Check if the token on disk is actually valid + disk_remaining = get_token_remaining_seconds(disk_token) + if disk_remaining is None or disk_remaining <= 0: + return False + + # Token on disk is valid and different - use it + config.token = disk_token + config.refresh_token = getattr(disk_config, "refresh_token", None) + + # Update the lease channel (critical for the running session) + await _update_lease_channel(config, lease) + + return True + except Exception as e: + config.token = old_token + config.refresh_token = old_refresh_token + logger.debug("Failed to reload token from disk: %s", e) + return False + + +async def _attempt_token_recovery(config, lease) -> str | None: + """Try all available methods to recover a valid token. + + Attempts OIDC refresh first, then falls back to reloading from disk + (e.g. if user ran 'jmp login' from the shell). + + Returns a message describing the recovery method, or None if all failed. + """ + if await _try_refresh_token(config, lease): + return "Token refreshed automatically." + if await _try_reload_token_from_disk(config, lease): + return "Token reloaded from login." + return None + + +def _warn_refresh_failed(remaining: float) -> None: + """Warn the user that token refresh failed.""" + if remaining > 0: + duration = format_duration(remaining) + click.echo( + click.style( + f"\nToken expires in {duration} and auto-refresh failed. " + "Run 'jmp login' from this shell to refresh manually.", + fg="yellow", + bold=True, + ) + ) + else: + click.echo( + click.style( + "\nToken expired and auto-refresh failed. 
" + "New commands will fail until you run 'jmp login' from this shell.", + fg="red", + bold=True, + ) + ) + + +async def _handle_token_refresh(config, lease, remaining, warn_state, token_state=None) -> None: + """Try to recover the token and update warning state accordingly.""" + recovery_msg = await _attempt_token_recovery(config, lease) + if recovery_msg: + click.echo(click.style(f"\n{recovery_msg}", fg="green")) + warn_state["expiry"] = False + warn_state["refresh_failed"] = False + warn_state["token_expired"] = False + if token_state is not None: + token_state["expired_unrecovered"] = False + elif remaining <= 0 and not warn_state["token_expired"]: + _warn_refresh_failed(remaining) + warn_state["token_expired"] = True + if token_state is not None: + token_state["expired_unrecovered"] = True + elif remaining > 0 and not warn_state["refresh_failed"]: + _warn_refresh_failed(remaining) + warn_state["refresh_failed"] = True + + +async def _monitor_token_expiry(config, lease, cancel_scope, token_state=None) -> None: + """Monitor token expiry, auto-refresh when possible, warn user otherwise. + + this monitor: + 1. Proactively refreshes the token before it expires using the refresh token + 2. Updates the lease's gRPC channel with new credentials + 3. If refresh fails, periodically checks the config on disk for a token + refreshed externally (e.g. via 'jmp login' from within the shell) + 4. Never cancels the scope - the shell stays alive regardless + """ token = getattr(config, "token", None) if not token: return - warned = False + warn_state = {"expiry": False, "refresh_failed": False, "token_expired": False} while not cancel_scope.cancel_called: try: - remaining = get_token_remaining_seconds(token) + # Re-read config.token each iteration since it may have been refreshed + remaining = get_token_remaining_seconds(config.token) if remaining is None: return - if remaining <= 0: - click.echo(click.style("\nToken expired! 
Exiting shell.", fg="red", bold=True)) - cancel_scope.cancel() - return - - if remaining <= TOKEN_EXPIRY_WARNING_SECONDS and not warned: + if remaining <= _TOKEN_REFRESH_THRESHOLD_SECONDS: + await _handle_token_refresh(config, lease, remaining, warn_state, token_state) + elif remaining <= TOKEN_EXPIRY_WARNING_SECONDS and not warn_state["expiry"]: duration = format_duration(remaining) click.echo( click.style( - f"\nToken expires in {duration}. Session will continue but cleanup may fail on exit.", + f"\nToken expires in {duration}. Will attempt auto-refresh.", fg="yellow", bold=True, ) ) - warned = True + warn_state["expiry"] = True - await anyio.sleep(30) + # Check more frequently as we approach expiry + if remaining <= _TOKEN_REFRESH_THRESHOLD_SECONDS: + await anyio.sleep(5) + else: + await anyio.sleep(30) except Exception: return @@ -255,6 +417,7 @@ async def _shell_with_signal_handling( # noqa: C901 exit_code = 0 cancelled_exc_class = get_cancelled_exc_class() lease_used = None + token_state = {"expired_unrecovered": False} # Check token before starting token = getattr(config, "token", None) @@ -277,11 +440,13 @@ async def _shell_with_signal_handling( # noqa: C901 lease_used = lease # Start token monitoring only once we're in the shell - tg.start_soon(_monitor_token_expiry, config, tg.cancel_scope) + tg.start_soon(_monitor_token_expiry, config, lease, tg.cancel_scope, token_state) exit_code = await _run_shell_with_lease_async( lease, exporter_logs, config, command, tg.cancel_scope ) + if lease.release and lease.name and token_state["expired_unrecovered"]: + _warn_about_expired_token(lease.name, selector) except BaseExceptionGroup as eg: for exc in eg.exceptions: if isinstance(exc, TimeoutError): diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py b/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py index 132b7b32..fc412c7e 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py +++ 
b/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py @@ -1,6 +1,7 @@ import base64 import inspect import json +import logging import time from contextlib import asynccontextmanager from datetime import datetime, timedelta @@ -11,12 +12,24 @@ import pytest from jumpstarter_cli_common.exceptions import handle_exceptions_with_reauthentication -from jumpstarter_cli.shell import _resolve_lease_from_active_async, _shell_with_signal_handling, shell +from jumpstarter_cli.shell import ( + _attempt_token_recovery, + _monitor_token_expiry, + _resolve_lease_from_active_async, + _shell_with_signal_handling, + _try_refresh_token, + _try_reload_token_from_disk, + _update_lease_channel, + _warn_refresh_failed, + shell, +) from jumpstarter.client.grpc import Lease, LeaseList from jumpstarter.config.client import ClientConfigV1Alpha1 from jumpstarter.config.env import JMP_LEASE +pytestmark = pytest.mark.anyio + def _make_lease(name: str, client: str = "test-client") -> Lease: return Lease( @@ -77,6 +90,49 @@ def test_shell_passes_exporter_name_to_lease_async(): assert config.captured[1] == "laptop-test-exporter" +async def test_shell_warns_when_expired_token_prevents_cleanup_on_normal_exit(): + lease = Mock() + lease.release = True + lease.name = "expired-lease" + lease.lease_ended = False + lease.lease_transferred = False + + config = _DummyConfig() + + @asynccontextmanager + async def lease_async(selector, exporter_name, lease_name, duration, portal, acquisition_timeout): + yield lease + + config.lease_async = lease_async + + async def fake_monitor(_config, _lease, _cancel_scope, token_state=None): + if token_state is not None: + token_state["expired_unrecovered"] = True + + async def fake_run_shell(*_args): + await anyio.sleep(0) + return 0 + + with ( + patch("jumpstarter_cli.shell._monitor_token_expiry", side_effect=fake_monitor), + patch("jumpstarter_cli.shell._run_shell_with_lease_async", side_effect=fake_run_shell), + 
patch("jumpstarter_cli.shell._warn_about_expired_token") as mock_warn, + ): + exit_code = await _shell_with_signal_handling( + config, + None, + None, + None, + timedelta(minutes=1), + False, + (), + None, + ) + + assert exit_code == 0 + mock_warn.assert_called_once_with("expired-lease", None) + + def test_shell_requires_selector_or_name_when_no_leases(): config = Mock(spec=ClientConfigV1Alpha1) config.metadata = type("Metadata", (), {"name": "test-client"})() @@ -310,3 +366,413 @@ def run_shell(): run_shell() login_mock.assert_called_once_with(config) + + +def _make_config(token="tok", refresh_token="rt", path="/tmp/config.yaml"): + """Create a mock config with sensible defaults.""" + config = Mock() + config.token = token + config.refresh_token = refresh_token + config.path = path + config.channel = AsyncMock(return_value=Mock(name="new_channel")) + return config + + +def _make_mock_lease(): + """Create a mock lease with a refresh_channel method.""" + lease = Mock() + lease.refresh_channel = Mock() + return lease + + +class TestUpdateLeaseChannel: + async def test_updates_channel_on_lease(self): + config = _make_config() + lease = _make_mock_lease() + + await _update_lease_channel(config, lease) + + config.channel.assert_awaited_once() + lease.refresh_channel.assert_called_once_with(config.channel.return_value) + + async def test_noop_when_lease_is_none(self): + config = _make_config() + + await _update_lease_channel(config, None) + + config.channel.assert_not_awaited() + + +class TestTryRefreshToken: + async def test_returns_false_when_no_refresh_token(self): + config = _make_config(refresh_token=None) + assert await _try_refresh_token(config, _make_mock_lease()) is False + + async def test_returns_false_when_refresh_token_is_empty(self): + config = _make_config(refresh_token="") + assert await _try_refresh_token(config, _make_mock_lease()) is False + + @patch("jumpstarter_cli.shell.ClientConfigV1Alpha1") + @patch("jumpstarter_cli.shell.Config") + 
@patch("jumpstarter_cli.shell.decode_jwt_issuer", return_value="https://issuer") + async def test_successful_refresh(self, _mock_issuer, mock_oidc_cls, mock_save): + config = _make_config() + lease = _make_mock_lease() + + mock_oidc = AsyncMock() + mock_oidc.refresh_token_grant.return_value = { + "access_token": "new_tok", + "refresh_token": "new_rt", + } + mock_oidc_cls.return_value = mock_oidc + + result = await _try_refresh_token(config, lease) + + assert result is True + assert config.token == "new_tok" + assert config.refresh_token == "new_rt" + lease.refresh_channel.assert_called_once() + mock_save.save.assert_called_once() + + @patch("jumpstarter_cli.shell.ClientConfigV1Alpha1") + @patch("jumpstarter_cli.shell.Config") + @patch("jumpstarter_cli.shell.decode_jwt_issuer", return_value="https://issuer") + async def test_successful_refresh_without_new_refresh_token( + self, _mock_issuer, mock_oidc_cls, _mock_save + ): + config = _make_config() + lease = _make_mock_lease() + + mock_oidc = AsyncMock() + mock_oidc.refresh_token_grant.return_value = { + "access_token": "new_tok", + # No refresh_token in response + } + mock_oidc_cls.return_value = mock_oidc + + result = await _try_refresh_token(config, lease) + + assert result is True + assert config.token == "new_tok" + assert config.refresh_token == "rt" # unchanged + + @patch("jumpstarter_cli.shell.decode_jwt_issuer", side_effect=ValueError("bad jwt")) + async def test_rollback_on_failure(self, _mock_issuer): + config = _make_config(token="original_tok", refresh_token="original_rt") + lease = _make_mock_lease() + + result = await _try_refresh_token(config, lease) + + assert result is False + assert config.token == "original_tok" + assert config.refresh_token == "original_rt" + lease.refresh_channel.assert_not_called() + + @patch("jumpstarter_cli.shell.ClientConfigV1Alpha1") + @patch("jumpstarter_cli.shell.Config") + @patch("jumpstarter_cli.shell.decode_jwt_issuer", return_value="https://issuer") + async def 
test_save_failure_does_not_fail_refresh( + self, _mock_issuer, mock_oidc_cls, mock_save, caplog + ): + """Disk save is best-effort; refresh should still succeed.""" + config = _make_config() + lease = _make_mock_lease() + + mock_oidc = AsyncMock() + mock_oidc.refresh_token_grant.return_value = { + "access_token": "new_tok", + } + mock_oidc_cls.return_value = mock_oidc + mock_save.save.side_effect = OSError("disk full") + + with caplog.at_level(logging.WARNING): + result = await _try_refresh_token(config, lease) + + assert result is True + assert config.token == "new_tok" + assert "Failed to save refreshed token to disk" in caplog.text + + +class TestTryReloadTokenFromDisk: + async def test_returns_false_when_no_path(self): + config = _make_config(path=None) + assert await _try_reload_token_from_disk(config, _make_mock_lease()) is False + + @patch("jumpstarter_cli.shell.ClientConfigV1Alpha1") + @patch("jumpstarter_cli.shell.get_token_remaining_seconds", return_value=3600) + async def test_successful_reload(self, _mock_remaining, mock_client_cfg): + config = _make_config(token="old_tok", refresh_token="old_rt") + lease = _make_mock_lease() + + disk_config = Mock() + disk_config.token = "disk_tok" + disk_config.refresh_token = "disk_rt" + mock_client_cfg.from_file.return_value = disk_config + + result = await _try_reload_token_from_disk(config, lease) + + assert result is True + assert config.token == "disk_tok" + assert config.refresh_token == "disk_rt" + lease.refresh_channel.assert_called_once() + + @patch("jumpstarter_cli.shell.ClientConfigV1Alpha1") + @patch("jumpstarter_cli.shell.get_token_remaining_seconds", return_value=3600) + async def test_clears_refresh_token_when_disk_has_none(self, _mock_remaining, mock_client_cfg): + """If disk config has no refresh token, in-memory refresh token must be cleared.""" + config = _make_config(token="old_tok", refresh_token="stale_rt") + lease = _make_mock_lease() + + disk_config = Mock() + disk_config.token = "disk_tok" + 
disk_config.refresh_token = None + mock_client_cfg.from_file.return_value = disk_config + + result = await _try_reload_token_from_disk(config, lease) + + assert result is True + assert config.token == "disk_tok" + assert config.refresh_token is None + + @patch("jumpstarter_cli.shell.ClientConfigV1Alpha1") + async def test_returns_false_when_disk_token_is_same(self, mock_client_cfg): + config = _make_config(token="same_tok") + disk_config = Mock() + disk_config.token = "same_tok" + mock_client_cfg.from_file.return_value = disk_config + + result = await _try_reload_token_from_disk(config, _make_mock_lease()) + + assert result is False + + @patch("jumpstarter_cli.shell.ClientConfigV1Alpha1") + @patch("jumpstarter_cli.shell.get_token_remaining_seconds", return_value=-10) + async def test_returns_false_when_disk_token_is_expired( + self, _mock_remaining, mock_client_cfg + ): + config = _make_config(token="old_tok") + disk_config = Mock() + disk_config.token = "disk_tok" + mock_client_cfg.from_file.return_value = disk_config + + result = await _try_reload_token_from_disk(config, _make_mock_lease()) + + assert result is False + assert config.token == "old_tok" # unchanged + + @patch("jumpstarter_cli.shell.ClientConfigV1Alpha1") + async def test_rollback_on_file_error(self, mock_client_cfg): + config = _make_config(token="orig_tok", refresh_token="orig_rt") + mock_client_cfg.from_file.side_effect = FileNotFoundError("gone") + + result = await _try_reload_token_from_disk(config, _make_mock_lease()) + + assert result is False + assert config.token == "orig_tok" + assert config.refresh_token == "orig_rt" + + +class TestAttemptTokenRecovery: + @patch("jumpstarter_cli.shell._try_reload_token_from_disk", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell._try_refresh_token", new_callable=AsyncMock) + async def test_returns_message_on_oidc_success(self, mock_refresh, mock_disk): + mock_refresh.return_value = True + + result = await _attempt_token_recovery(Mock(), Mock()) + + 
assert result == "Token refreshed automatically." + mock_disk.assert_not_awaited() # should not fall through + + @patch("jumpstarter_cli.shell._try_reload_token_from_disk", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell._try_refresh_token", new_callable=AsyncMock) + async def test_falls_back_to_disk_reload(self, mock_refresh, mock_disk): + mock_refresh.return_value = False + mock_disk.return_value = True + + result = await _attempt_token_recovery(Mock(), Mock()) + + assert result == "Token reloaded from login." + + @patch("jumpstarter_cli.shell._try_reload_token_from_disk", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell._try_refresh_token", new_callable=AsyncMock) + async def test_returns_none_when_all_fail(self, mock_refresh, mock_disk): + mock_refresh.return_value = False + mock_disk.return_value = False + + result = await _attempt_token_recovery(Mock(), Mock()) + + assert result is None + + +class TestWarnRefreshFailed: + @patch("jumpstarter_cli.shell.click") + def test_warns_yellow_when_time_remaining(self, mock_click): + _warn_refresh_failed(300) + mock_click.style.assert_called_once() + _, kwargs = mock_click.style.call_args + assert kwargs["fg"] == "yellow" + + @patch("jumpstarter_cli.shell.click") + def test_warns_red_when_expired(self, mock_click): + _warn_refresh_failed(-10) + mock_click.style.assert_called_once() + _, kwargs = mock_click.style.call_args + assert kwargs["fg"] == "red" + + +class TestMonitorTokenExpiry: + async def test_returns_immediately_when_no_token(self): + config = Mock(spec=[]) # no token attribute + cancel_scope = Mock(cancel_called=False) + + await _monitor_token_expiry(config, None, cancel_scope) + # Should return without error + + @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell.get_token_remaining_seconds", return_value=None) + async def test_returns_when_remaining_is_none(self, _mock_remaining, _mock_sleep): + config = _make_config() + cancel_scope = 
Mock(cancel_called=False) + + await _monitor_token_expiry(config, None, cancel_scope) + + @patch("jumpstarter_cli.shell.click") + @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell.get_token_remaining_seconds") + async def test_refreshes_when_below_threshold( + self, mock_remaining, mock_recovery, mock_sleep, mock_click + ): + # First call: below threshold; second call: raise to exit + mock_remaining.side_effect = [60, Exception("done")] + mock_recovery.return_value = "Token refreshed automatically." + config = _make_config() + cancel_scope = Mock(cancel_called=False) + + await _monitor_token_expiry(config, _make_mock_lease(), cancel_scope) + + mock_recovery.assert_awaited_once() + # Should print the green success message + mock_click.echo.assert_called() + + @patch("jumpstarter_cli.shell.click") + @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell.get_token_remaining_seconds") + async def test_warns_when_refresh_fails( + self, mock_remaining, mock_recovery, mock_sleep, mock_click + ): + mock_remaining.side_effect = [60, Exception("done")] + mock_recovery.return_value = None # all recovery failed + config = _make_config() + cancel_scope = Mock(cancel_called=False) + + await _monitor_token_expiry(config, _make_mock_lease(), cancel_scope) + + mock_recovery.assert_awaited_once() + + @patch("jumpstarter_cli.shell.click") + @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell.get_token_remaining_seconds") + async def test_warns_within_expiry_window( + self, mock_remaining, mock_sleep, mock_click + ): + from jumpstarter_cli_common.oidc import TOKEN_EXPIRY_WARNING_SECONDS + + # First iteration: within warning window but above refresh threshold + # Second iteration: 
exit via exception + mock_remaining.side_effect = [ + TOKEN_EXPIRY_WARNING_SECONDS - 10, + Exception("done"), + ] + config = _make_config() + cancel_scope = Mock(cancel_called=False) + + await _monitor_token_expiry(config, _make_mock_lease(), cancel_scope) + + # Verify warning was echoed + mock_click.echo.assert_called() + args = mock_click.style.call_args + assert "auto-refresh" in args[0][0] + + @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell.get_token_remaining_seconds", return_value=500) + async def test_sleeps_30s_when_above_threshold(self, _mock_remaining, mock_sleep): + # Exit after one loop via cancel_called + call_count = 0 + + def check_cancelled(): + nonlocal call_count + call_count += 1 + return call_count > 1 + + config = _make_config() + + class _CancelScope(Mock): + cancel_called = property(lambda self: check_cancelled()) + + cancel_scope = _CancelScope() + + await _monitor_token_expiry(config, _make_mock_lease(), cancel_scope) + + mock_sleep.assert_awaited_with(30) + + @patch("jumpstarter_cli.shell.click") + @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell.get_token_remaining_seconds") + async def test_sleeps_5s_when_below_threshold( + self, mock_remaining, mock_recovery, mock_sleep, _mock_click + ): + mock_remaining.side_effect = [60, Exception("done")] + mock_recovery.return_value = None + config = _make_config() + cancel_scope = Mock(cancel_called=False) + + await _monitor_token_expiry(config, _make_mock_lease(), cancel_scope) + + mock_sleep.assert_awaited_with(5) + + @patch("jumpstarter_cli.shell.click") + @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell.get_token_remaining_seconds") + async def test_does_not_cancel_scope_on_expiry( 
+ self, mock_remaining, mock_recovery, mock_sleep, _mock_click + ): + """The monitor must never cancel the scope — the shell stays alive.""" + mock_remaining.side_effect = [60, Exception("done")] + mock_recovery.return_value = None + config = _make_config() + cancel_scope = Mock(cancel_called=False) + + await _monitor_token_expiry(config, _make_mock_lease(), cancel_scope) + + cancel_scope.cancel.assert_not_called() + + @patch("jumpstarter_cli.shell.click") + @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock) + @patch("jumpstarter_cli.shell.get_token_remaining_seconds") + async def test_warns_red_when_token_transitions_to_expired( + self, mock_remaining, mock_recovery, mock_sleep, mock_click + ): + """After a yellow 'approaching expiry' warning, a red 'expired' warning + must still appear when the token actually crosses zero.""" + mock_remaining.side_effect = [60, -5, Exception("done")] + mock_recovery.return_value = None # all recovery fails + config = _make_config() + cancel_scope = Mock(cancel_called=False) + token_state = {"expired_unrecovered": False} + + await _monitor_token_expiry(config, _make_mock_lease(), cancel_scope, token_state) + + warn_calls = mock_click.style.call_args_list + # Find the yellow warning (remaining > 0) + yellow_calls = [c for c in warn_calls if c[1].get("fg") == "yellow"] + # Find the red warning (remaining <= 0) + + red_calls = [c for c in warn_calls if c[1].get("fg") == "red"] + assert len(yellow_calls) >= 1, "Expected yellow warning for near-expiry" + assert len(red_calls) >= 1, "Expected red warning for actual expiry" + assert token_state["expired_unrecovered"] is True diff --git a/python/packages/jumpstarter/jumpstarter/client/lease.py b/python/packages/jumpstarter/jumpstarter/client/lease.py index 3e89bb65..a51c87db 100644 --- a/python/packages/jumpstarter/jumpstarter/client/lease.py +++ 
b/python/packages/jumpstarter/jumpstarter/client/lease.py @@ -105,6 +105,16 @@ def __post_init__(self): self.controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel) self.svc = ClientService(channel=self.channel, namespace=self.namespace) + def refresh_channel(self, channel: Channel): + """Update the gRPC channel used for controller communication. + + This is used when the auth token is refreshed to update the + underlying gRPC channel with new credentials. + """ + self.channel = channel + self.controller = jumpstarter_pb2_grpc.ControllerServiceStub(channel) + self.svc = ClientService(channel=channel, namespace=self.namespace) + async def _create(self): logger.debug("Creating lease request for selector %s for duration %s", self.selector, self.duration) with translate_grpc_exceptions(): @@ -373,7 +383,7 @@ async def _wait_for_ready_connection(self, path: str): else: raise ConnectionError("Socket not ready at %s" % path) from e - def _notify_lease_ending(self, remaining: timedelta): + def _notify_lease_ending(self, remaining: timedelta) -> None: """Log lease status and invoke the ending callback if set.""" if remaining <= timedelta(0): self.lease_ended = True @@ -381,41 +391,61 @@ def _notify_lease_ending(self, remaining: timedelta): if self.lease_ending_callback is not None: self.lease_ending_callback(self, remaining) + def _get_lease_end_time(self, lease) -> datetime | None: + """Extract the end time from a lease response, or None if not available.""" + if lease.effective_end_time: + return lease.effective_end_time + if not (lease.effective_begin_time and lease.duration): + return None + return lease.effective_begin_time + lease.duration + @asynccontextmanager async def monitor_async(self, threshold: timedelta = timedelta(minutes=5)): async def _monitor(): check_interval = 30 # seconds - check periodically for external lease changes - end_time = None # Track across iterations for error recovery + last_known_end_time = None while True: try: lease = 
await self.get() - except Exception: - # gRPC channel broken — check if lease was expected to have ended - if end_time and datetime.now().astimezone() >= end_time: - self._notify_lease_ending(timedelta(0)) + except Exception as e: + logger.warning("Failed to check lease %s status: %s", self.name, e) + # If we know when the lease should end, use it to bound the sleep + if last_known_end_time is not None: + remain = (last_known_end_time - datetime.now().astimezone()).total_seconds() + if remain <= 0: + logger.info( + "Lease %s estimated to have ended at %s (unable to confirm with server)", + self.name, + last_known_end_time, + ) + self._notify_lease_ending(timedelta(0)) + break + await sleep(min(check_interval, remain)) else: - logger.debug("Lease monitor: connection lost unexpectedly") + await sleep(check_interval) + continue + + end_time = self._get_lease_end_time(lease) + if end_time is None: + await sleep(1) + continue + + last_known_end_time = end_time + remain = end_time - datetime.now().astimezone() + if remain < timedelta(0): + logger.info("Lease {} ended at {}".format(self.name, end_time)) + self._notify_lease_ending(timedelta(0)) break - if lease.effective_begin_time and lease.effective_duration: - if lease.effective_end_time: # already ended - end_time = lease.effective_end_time - else: - end_time = lease.effective_begin_time + lease.duration - remain = end_time - datetime.now().astimezone() - if remain < timedelta(0): - self._notify_lease_ending(timedelta(0)) - break - # Log once when entering the threshold window - if threshold - timedelta(seconds=check_interval) <= remain < threshold: - logger.info( - "Lease {} ending in {} minutes at {}".format( - self.name, int((remain.total_seconds() + 30) // 60), end_time - ) + + # Log once when entering the threshold window + if threshold - timedelta(seconds=check_interval) <= remain < threshold: + logger.info( + "Lease {} ending in {} minutes at {}".format( + self.name, int((remain.total_seconds() + 30) // 60), 
import asyncio
import logging
import sys
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, Mock, patch

import pytest
from rich.console import Console

from jumpstarter.client.lease import Lease, LeaseAcquisitionSpinner


class TestRefreshChannel:
    """Tests for Lease.refresh_channel."""

    def _make_lease(self):
        """Create a Lease with mocked dependencies, bypassing __init__."""
        lease = object.__new__(Lease)
        lease.channel = Mock(name="original_channel")
        lease.namespace = "default"
        lease.controller = Mock(name="original_controller")
        lease.svc = Mock(name="original_svc")
        return lease

    @patch("jumpstarter.client.lease.ClientService")
    @patch("jumpstarter.client.lease.jumpstarter_pb2_grpc.ControllerServiceStub")
    def test_replaces_channel_and_stubs(self, mock_stub_cls, mock_svc_cls):
        """refresh_channel swaps the channel and rebuilds both gRPC stubs."""
        lease = self._make_lease()
        new_channel = Mock(name="new_channel")

        lease.refresh_channel(new_channel)

        assert lease.channel is new_channel
        # The controller stub must be rebuilt on top of the new channel.
        mock_stub_cls.assert_called_once_with(new_channel)
        assert lease.controller is mock_stub_cls.return_value
        mock_svc_cls.assert_called_once()


class TestNotifyLeaseEnding:
    """Tests for Lease._notify_lease_ending."""

    def _make_lease(self):
        lease = object.__new__(Lease)
        lease.lease_ending_callback = None
        return lease

    def test_calls_callback_when_set(self):
        """The callback receives the lease itself and the remaining time."""
        lease = self._make_lease()
        callback = Mock()
        lease.lease_ending_callback = callback
        remaining = timedelta(minutes=3)

        lease._notify_lease_ending(remaining)

        callback.assert_called_once_with(lease, remaining)

    def test_noop_when_no_callback(self):
        """Without a callback the notification must be a silent no-op."""
        lease = self._make_lease()

        # Should not raise
        lease._notify_lease_ending(timedelta(0))


class TestGetLeaseEndTime:
    """Tests for Lease._get_lease_end_time."""

    def _make_lease(self):
        return object.__new__(Lease)

    def test_returns_none_when_no_begin_time(self):
        lease = self._make_lease()
        response = Mock(
            effective_begin_time=None,
            duration=timedelta(minutes=30),
            effective_end_time=None,
        )

        assert lease._get_lease_end_time(response) is None

    def test_returns_none_when_no_duration(self):
        lease = self._make_lease()
        response = Mock(
            effective_begin_time=datetime.now(tz=timezone.utc),
            duration=None,
            effective_end_time=None,
        )

        assert lease._get_lease_end_time(response) is None

    def test_returns_effective_end_time_when_present(self):
        """An explicit effective_end_time wins over begin + duration."""
        lease = self._make_lease()
        end_time = datetime(2025, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
        response = Mock(
            effective_begin_time=datetime(2025, 6, 1, 11, 0, 0, tzinfo=timezone.utc),
            duration=timedelta(hours=1),
            effective_end_time=end_time,
        )

        assert lease._get_lease_end_time(response) is end_time

    def test_returns_effective_end_time_even_without_begin_or_duration(self):
        lease = self._make_lease()
        end_time = datetime(2025, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
        response = Mock(
            effective_begin_time=None,
            duration=None,
            effective_end_time=end_time,
        )

        assert lease._get_lease_end_time(response) is end_time

    def test_calculates_end_time_when_no_effective_end(self):
        """Falls back to begin + requested duration, not elapsed effective_duration."""
        lease = self._make_lease()
        begin = datetime(2025, 6, 1, 11, 0, 0, tzinfo=timezone.utc)
        duration = timedelta(hours=2)
        response = Mock(
            effective_begin_time=begin,
            effective_duration=timedelta(hours=1),  # elapsed time, not used for calculation
            effective_end_time=None,
            duration=duration,
        )

        result = lease._get_lease_end_time(response)

        assert result == begin + duration


class TestMonitorAsyncError:
    """Tests for the error handling in monitor_async."""

    def _make_lease_for_monitor(self):
        lease = object.__new__(Lease)
        lease.name = "test-lease"
        lease.lease_ending_callback = None
        lease.get = AsyncMock()
        return lease

    @pytest.mark.anyio
    async def test_continues_on_get_failure_without_end_time(self):
        """When get() fails and we have no end time, monitor retries."""
        lease = self._make_lease_for_monitor()
        call_count = 0

        async def failing_get():
            nonlocal call_count
            call_count += 1
            if call_count <= 2:
                raise Exception("transient error")
            # Third call: return an already-expired lease so the monitor exits.
            end_time = datetime.now(tz=timezone.utc) - timedelta(seconds=10)
            return Mock(
                effective_begin_time=end_time - timedelta(hours=1),
                effective_duration=timedelta(hours=1),
                effective_end_time=end_time,
            )

        lease.get = failing_get

        # Patch out sleeping so the retry loop spins instantly.
        with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock):
            async with lease.monitor_async():
                pass

        assert call_count == 3  # two failures + one success

    @pytest.mark.anyio
    async def test_estimates_expiry_from_last_known_end_time(self, caplog):
        """When get() fails after we've seen an end time, use the cached value."""
        lease = self._make_lease_for_monitor()
        callback = Mock()
        lease.lease_ending_callback = callback

        # End time slightly in the future so the monitor caches it and sleeps.
        future_end = datetime.now(tz=timezone.utc) + timedelta(milliseconds=50)
        call_count = 0

        async def get_then_fail():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return Mock(
                    effective_begin_time=future_end - timedelta(hours=1),
                    effective_duration=timedelta(hours=1),
                    effective_end_time=None,
                    duration=timedelta(hours=1),
                )
            raise Exception("server unavailable")

        lease.get = get_then_fail

        with caplog.at_level(logging.WARNING):
            async with lease.monitor_async():
                # Keep the body alive long enough for the monitor to loop
                # through the first get(), sleep, second get() (fails), and
                # the error handler using the cached end time.
                await asyncio.sleep(0.2)

        # Should have gone through the error handler using the cached end time,
        # and the failure must have been surfaced as a warning in the log.
        assert call_count >= 2
        assert "Failed to check lease" in caplog.text
        callback.assert_called()
        _, remain_arg = callback.call_args[0]
        assert remain_arg == timedelta(0)