Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/packages/jumpstarter-cli/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import pytest


@pytest.fixture
def anyio_backend():
return "asyncio"
193 changes: 179 additions & 14 deletions python/packages/jumpstarter-cli/jumpstarter_cli/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from jumpstarter_cli_common.exceptions import find_exception_in_group, handle_exceptions_with_reauthentication
from jumpstarter_cli_common.oidc import (
TOKEN_EXPIRY_WARNING_SECONDS,
Config,
decode_jwt_issuer,
format_duration,
get_token_remaining_seconds,
)
Expand All @@ -32,6 +34,10 @@

logger = logging.getLogger(__name__)

# Refresh token when less than this many seconds remain
_TOKEN_REFRESH_THRESHOLD_SECONDS = 120



def _run_shell_only(lease, config, command, path: str) -> int:
"""Run just the shell command without log streaming."""
Expand Down Expand Up @@ -60,36 +66,192 @@ def _warn_about_expired_token(lease_name: str, selector: str) -> None:
click.echo(click.style(f"To reconnect: JMP_LEASE={lease_name} jmp shell", fg="cyan"))


async def _monitor_token_expiry(config, cancel_scope) -> None:
"""Monitor token expiry and warn user."""
async def _update_lease_channel(config, lease) -> None:
"""Update the lease's gRPC channel with the current config credentials."""
if lease is not None:
new_channel = await config.channel()
lease.refresh_channel(new_channel)


async def _try_refresh_token(config, lease) -> bool:
"""Attempt to refresh the token and update the lease channel.

Returns True if refresh succeeded, False otherwise.
"""
refresh_token = getattr(config, "refresh_token", None)
if not refresh_token:
return False

old_token = config.token
old_refresh_token = config.refresh_token
try:
issuer = decode_jwt_issuer(config.token)
oidc = Config(
issuer=issuer,
client_id="jumpstarter-cli",
offline_access=True,
)

tokens = await oidc.refresh_token_grant(refresh_token)
config.token = tokens["access_token"]
new_refresh_token = tokens.get("refresh_token")
if new_refresh_token is not None:
config.refresh_token = new_refresh_token

# Update the lease channel first (critical for the running session)
await _update_lease_channel(config, lease)

# Persist to disk (best-effort, uses original config path)
try:
ClientConfigV1Alpha1.save(config, path=config.path)
except Exception as e:
logger.warning("Failed to save refreshed token to disk: %s", e)

return True
except Exception as e:
# Restore old token so the monitor doesn't think we succeeded
config.token = old_token
config.refresh_token = old_refresh_token
logger.debug("Token refresh failed: %s", e)
return False


async def _try_reload_token_from_disk(config, lease) -> bool:
"""Check if the config on disk has a newer/valid token (e.g. from 'jmp login').

If a valid token is found on disk, updates the in-memory config and lease channel.
Returns True if a valid token was loaded, False otherwise.
"""
config_path = getattr(config, "path", None)
if not config_path:
return False

old_token = config.token
old_refresh_token = config.refresh_token
try:
disk_config = ClientConfigV1Alpha1.from_file(config_path)
disk_token = getattr(disk_config, "token", None)
if not disk_token or disk_token == config.token:
return False

# Check if the token on disk is actually valid
disk_remaining = get_token_remaining_seconds(disk_token)
if disk_remaining is None or disk_remaining <= 0:
return False

# Token on disk is valid and different - use it
config.token = disk_token
config.refresh_token = getattr(disk_config, "refresh_token", None)

# Update the lease channel (critical for the running session)
await _update_lease_channel(config, lease)

return True
except Exception as e:
config.token = old_token
config.refresh_token = old_refresh_token
logger.debug("Failed to reload token from disk: %s", e)
return False


async def _attempt_token_recovery(config, lease) -> str | None:
"""Try all available methods to recover a valid token.

Attempts OIDC refresh first, then falls back to reloading from disk
(e.g. if user ran 'jmp login' from the shell).

Returns a message describing the recovery method, or None if all failed.
"""
if await _try_refresh_token(config, lease):
return "Token refreshed automatically."
if await _try_reload_token_from_disk(config, lease):
return "Token reloaded from login."
return None


def _warn_refresh_failed(remaining: float) -> None:
"""Warn the user that token refresh failed."""
if remaining > 0:
duration = format_duration(remaining)
click.echo(
click.style(
f"\nToken expires in {duration} and auto-refresh failed. "
"Run 'jmp login' from this shell to refresh manually.",
fg="yellow",
bold=True,
)
)
else:
click.echo(
click.style(
"\nToken expired and auto-refresh failed. "
"New commands will fail until you run 'jmp login' from this shell.",
fg="red",
bold=True,
)
)


async def _handle_token_refresh(config, lease, remaining, warn_state, token_state=None) -> None:
"""Try to recover the token and update warning state accordingly."""
recovery_msg = await _attempt_token_recovery(config, lease)
if recovery_msg:
click.echo(click.style(f"\n{recovery_msg}", fg="green"))
warn_state["expiry"] = False
warn_state["refresh_failed"] = False
warn_state["token_expired"] = False
if token_state is not None:
token_state["expired_unrecovered"] = False
elif remaining <= 0 and not warn_state["token_expired"]:
_warn_refresh_failed(remaining)
warn_state["token_expired"] = True
if token_state is not None:
token_state["expired_unrecovered"] = True
elif remaining > 0 and not warn_state["refresh_failed"]:
_warn_refresh_failed(remaining)
warn_state["refresh_failed"] = True


async def _monitor_token_expiry(config, lease, cancel_scope, token_state=None) -> None:
"""Monitor token expiry, auto-refresh when possible, warn user otherwise.

this monitor:
1. Proactively refreshes the token before it expires using the refresh token
2. Updates the lease's gRPC channel with new credentials
3. If refresh fails, periodically checks the config on disk for a token
refreshed externally (e.g. via 'jmp login' from within the shell)
4. Never cancels the scope - the shell stays alive regardless
"""
token = getattr(config, "token", None)
if not token:
return

warned = False
warn_state = {"expiry": False, "refresh_failed": False, "token_expired": False}
while not cancel_scope.cancel_called:
try:
remaining = get_token_remaining_seconds(token)
# Re-read config.token each iteration since it may have been refreshed
remaining = get_token_remaining_seconds(config.token)
if remaining is None:
return

if remaining <= 0:
click.echo(click.style("\nToken expired! Exiting shell.", fg="red", bold=True))
cancel_scope.cancel()
return

if remaining <= TOKEN_EXPIRY_WARNING_SECONDS and not warned:
if remaining <= _TOKEN_REFRESH_THRESHOLD_SECONDS:
await _handle_token_refresh(config, lease, remaining, warn_state, token_state)
elif remaining <= TOKEN_EXPIRY_WARNING_SECONDS and not warn_state["expiry"]:
duration = format_duration(remaining)
click.echo(
click.style(
f"\nToken expires in {duration}. Session will continue but cleanup may fail on exit.",
f"\nToken expires in {duration}. Will attempt auto-refresh.",
fg="yellow",
bold=True,
)
)
warned = True
warn_state["expiry"] = True

await anyio.sleep(30)
# Check more frequently as we approach expiry
if remaining <= _TOKEN_REFRESH_THRESHOLD_SECONDS:
await anyio.sleep(5)
else:
await anyio.sleep(30)
except Exception:
return

Expand Down Expand Up @@ -255,6 +417,7 @@ async def _shell_with_signal_handling( # noqa: C901
exit_code = 0
cancelled_exc_class = get_cancelled_exc_class()
lease_used = None
token_state = {"expired_unrecovered": False}

# Check token before starting
token = getattr(config, "token", None)
Expand All @@ -277,11 +440,13 @@ async def _shell_with_signal_handling( # noqa: C901
lease_used = lease

# Start token monitoring only once we're in the shell
tg.start_soon(_monitor_token_expiry, config, tg.cancel_scope)
tg.start_soon(_monitor_token_expiry, config, lease, tg.cancel_scope, token_state)

exit_code = await _run_shell_with_lease_async(
lease, exporter_logs, config, command, tg.cancel_scope
)
if lease.release and lease.name and token_state["expired_unrecovered"]:
_warn_about_expired_token(lease.name, selector)
Comment on lines +448 to +449
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Don’t warn about an orphaned lease after natural expiry.

token_state["expired_unrecovered"] only says auth-based cleanup cannot run. If lease.monitor_async() has already set lease.lease_ended, this branch still tells the user the lease will remain active and suggests reconnecting to a lease that no longer exists.

💡 Proposed fix
-                        if lease.release and lease.name and token_state["expired_unrecovered"]:
+                        if (
+                            lease.release
+                            and lease.name
+                            and token_state["expired_unrecovered"]
+                            and not lease.lease_ended
+                        ):
                             _warn_about_expired_token(lease.name, selector)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/packages/jumpstarter-cli/jumpstarter_cli/shell.py` around lines 448 -
449, The warning is shown even when the lease has already ended because the
condition only checks token_state["expired_unrecovered"]; update the conditional
in the block that calls _warn_about_expired_token to also verify the lease
hasn't already ended (e.g., require not lease.lease_ended) so you only warn
about orphaned leases that are still active; locate the check around
lease.release, lease.name and token_state in the function/method that handles
lease state and add the extra lease.lease_ended guard before calling
_warn_about_expired_token.

except BaseExceptionGroup as eg:
for exc in eg.exceptions:
if isinstance(exc, TimeoutError):
Expand Down
Loading
Loading