diff --git a/internal/cmd/agents.go b/internal/cmd/agents.go index 18e36db..a8878e9 100644 --- a/internal/cmd/agents.go +++ b/internal/cmd/agents.go @@ -4,8 +4,6 @@ import ( "errors" "fmt" "os" - "path/filepath" - "strings" "github.com/spf13/cobra" @@ -15,7 +13,8 @@ import ( var agentID string -// GetCurrentAgentID returns the agent ID from flag, env var, or file (in priority order) +// GetCurrentAgentID returns the agent ID from flag, env var, or file (in priority order). +// The file is per-API-key (see state_scope.go) with a legacy unscoped fallback. func GetCurrentAgentID() string { if agentID != "" { return agentID @@ -23,55 +22,31 @@ func GetCurrentAgentID() string { if envID := os.Getenv(config.EnvAgentID); envID != "" { return envID } - configDir, err := config.Dir() + id, err := readStateFile(config.CurrentAgentFile) if err != nil { return "" } - data, err := os.ReadFile(filepath.Join(configDir, config.CurrentAgentFile)) - if err != nil { - return "" - } - return strings.TrimSpace(string(data)) + return id } func setCurrentAgent(id string) error { - configDir, err := config.Dir() - if err != nil { - return err - } - if err := os.MkdirAll(configDir, 0o700); err != nil { - return err - } - return os.WriteFile(filepath.Join(configDir, config.CurrentAgentFile), []byte(id), 0o600) + return writeStateFile(config.CurrentAgentFile, id) } func clearCurrentAgent() error { - configDir, err := config.Dir() - if err != nil { - return err - } - path := filepath.Join(configDir, config.CurrentAgentFile) - if err := os.Remove(path); err != nil && !os.IsNotExist(err) { - return err - } - return nil + return clearStateFile(config.CurrentAgentFile) } +// clearCurrentAgentIfMatches removes the per-API-key current_agent file (and any +// legacy unscoped file) only when its contents match expectedID. This avoids +// clobbering a more recent value written by a later command. func clearCurrentAgentIfMatches(expectedID string) error { - configDir, err := config.Dir() + id, err := readStateFile(config.CurrentAgentFile) if err != nil { return err } - path := filepath.Join(configDir, config.CurrentAgentFile) - data, err := os.ReadFile(path) - if err != nil { - if os.IsNotExist(err) { - return nil - } - return err - } - if strings.TrimSpace(string(data)) == expectedID { - return os.Remove(path) + if id != "" && id == expectedID { + return clearStateFile(config.CurrentAgentFile) } return nil } diff --git a/internal/cmd/sessions.go b/internal/cmd/sessions.go index 2b04b97..75a2738 100644 --- a/internal/cmd/sessions.go +++ b/internal/cmd/sessions.go @@ -44,7 +44,8 @@ var ( sessionReplayOutput string ) -// GetCurrentSessionID returns the session ID from flag, env var, or file (in priority order) +// GetCurrentSessionID returns the session ID from flag, env var, or file (in priority order). +// The file is per-API-key (see state_scope.go) with a legacy unscoped fallback. func GetCurrentSessionID() string { // 1. Check --session-id flag (already in sessionID variable if set) if sessionID != "" { @@ -56,118 +57,69 @@ func GetCurrentSessionID() string { return envID } - // 3. Check current_session file - configDir, err := config.Dir() + // 3. Check current_session file (per-API-key, with legacy fallback) + id, err := readStateFile(config.CurrentSessionFile) if err != nil { return "" } - data, err := os.ReadFile(filepath.Join(configDir, config.CurrentSessionFile)) - if err != nil { - return "" - } - return strings.TrimSpace(string(data)) + return id } -// setCurrentSession saves the session ID to the current_session file +// setCurrentSession saves the session ID to the per-API-key current_session file. func setCurrentSession(id string) error { - configDir, err := config.Dir() - if err != nil { - return err - } - // Ensure directory exists - if err := os.MkdirAll(configDir, 0o700); err != nil { - return err - } - return os.WriteFile(filepath.Join(configDir, config.CurrentSessionFile), []byte(id), 0o600) + return writeStateFile(config.CurrentSessionFile, id) } -// clearCurrentSession removes the current_session file +// clearCurrentSession removes the per-API-key current_session file and any +// legacy unscoped file left over from a prior CLI version. func clearCurrentSession() error { - configDir, err := config.Dir() - if err != nil { - return err - } - path := filepath.Join(configDir, config.CurrentSessionFile) - if err := os.Remove(path); err != nil && !os.IsNotExist(err) { - return err - } - return nil + return clearStateFile(config.CurrentSessionFile) } -// setCurrentViewerURL saves the viewer URL to the current_viewer_url file +// setCurrentViewerURL saves the viewer URL to the per-API-key current_viewer_url file. func setCurrentViewerURL(url string) error { - configDir, err := config.Dir() - if err != nil { - return err - } - if err := os.MkdirAll(configDir, 0o700); err != nil { - return err - } - return os.WriteFile(filepath.Join(configDir, config.CurrentViewerURLFile), []byte(url), 0o600) + return writeStateFile(config.CurrentViewerURLFile, url) } -// getCurrentViewerURL reads the viewer URL from the current_viewer_url file +// getCurrentViewerURL reads the viewer URL from the per-API-key current_viewer_url +// file, with a legacy unscoped fallback. func getCurrentViewerURL() string { - configDir, err := config.Dir() + url, err := readStateFile(config.CurrentViewerURLFile) if err != nil { return "" } - data, err := os.ReadFile(filepath.Join(configDir, config.CurrentViewerURLFile)) - if err != nil { - return "" - } - return strings.TrimSpace(string(data)) + return url } -// clearCurrentViewerURL removes the current_viewer_url file +// clearCurrentViewerURL removes the per-API-key current_viewer_url file and any +// legacy unscoped file. func clearCurrentViewerURL() error { - configDir, err := config.Dir() - if err != nil { - return err - } - path := filepath.Join(configDir, config.CurrentViewerURLFile) - if err := os.Remove(path); err != nil && !os.IsNotExist(err) { - return err - } - return nil + return clearStateFile(config.CurrentViewerURLFile) } -// setCurrentSessionExpiry saves the session expiry timestamp to the current_session_expiry file +// setCurrentSessionExpiry saves the session expiry timestamp to the per-API-key +// current_session_expiry file. func setCurrentSessionExpiry(t time.Time) error { - configDir, err := config.Dir() - if err != nil { - return err - } - if err := os.MkdirAll(configDir, 0o700); err != nil { - return err - } - return os.WriteFile(filepath.Join(configDir, config.CurrentSessionExpiryFile), []byte(t.Format(time.RFC3339)), 0o600) + return writeStateFile(config.CurrentSessionExpiryFile, t.Format(time.RFC3339)) } -// getCurrentSessionExpiry reads the session expiry timestamp from the current_session_expiry file +// getCurrentSessionExpiry reads the session expiry timestamp from the per-API-key +// current_session_expiry file, with a legacy unscoped fallback. func getCurrentSessionExpiry() (time.Time, error) { - configDir, err := config.Dir() + raw, err := readStateFile(config.CurrentSessionExpiryFile) if err != nil { return time.Time{}, err } - data, err := os.ReadFile(filepath.Join(configDir, config.CurrentSessionExpiryFile)) - if err != nil { - return time.Time{}, err + if raw == "" { + return time.Time{}, os.ErrNotExist } - return time.Parse(time.RFC3339, strings.TrimSpace(string(data))) + return time.Parse(time.RFC3339, raw) } -// clearCurrentSessionExpiry removes the current_session_expiry file +// clearCurrentSessionExpiry removes the per-API-key current_session_expiry file +// and any legacy unscoped file. func clearCurrentSessionExpiry() error { - configDir, err := config.Dir() - if err != nil { - return err - } - path := filepath.Join(configDir, config.CurrentSessionExpiryFile) - if err := os.Remove(path); err != nil && !os.IsNotExist(err) { - return err - } - return nil + return clearStateFile(config.CurrentSessionExpiryFile) } // RequireSessionID ensures a session ID is available from flag, env, or file diff --git a/internal/cmd/state_scope.go b/internal/cmd/state_scope.go new file mode 100644 index 0000000..c13bcd5 --- /dev/null +++ b/internal/cmd/state_scope.go @@ -0,0 +1,123 @@ +package cmd + +import ( + "crypto/sha256" + "encoding/hex" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/nottelabs/notte-cli/internal/auth" + "github.com/nottelabs/notte-cli/internal/config" +) + +var ( + apiKeyScopeSuffixOnce sync.Once + apiKeyScopeSuffixCache string + + // apiKeyScopeOverride, when non-nil, short-circuits the cached lookup. + // Tests use SetAPIKeyScopeForTesting to deterministically pin the scope. + apiKeyScopeOverride *string +) + +// apiKeyScopeSuffix returns a per-API-key suffix (e.g. ".abc12345") used to +// scope CLI state files such as current_session / current_agent so that two +// different accounts on the same machine never collide on the same shared +// file. When no API key is configured (e.g. during `notte auth login` itself), +// it returns "" - state files keep their legacy unscoped names. +func apiKeyScopeSuffix() string { + if apiKeyScopeOverride != nil { + return *apiKeyScopeOverride + } + apiKeyScopeSuffixOnce.Do(func() { + key, _, err := auth.GetAPIKey("") + if err != nil || key == "" { + apiKeyScopeSuffixCache = "" + return + } + sum := sha256.Sum256([]byte(key)) + apiKeyScopeSuffixCache = "." + hex.EncodeToString(sum[:])[:8] + }) + return apiKeyScopeSuffixCache +} + +// SetAPIKeyScopeForTesting pins the API-key scope suffix to a deterministic +// value for the duration of the caller's test. Restores the previous override +// (typically nil) via t.Cleanup. Tests only. +func SetAPIKeyScopeForTesting(suffix string) func() { + prev := apiKeyScopeOverride + override := suffix + apiKeyScopeOverride = &override + return func() { apiKeyScopeOverride = prev } +} + +// stateFilePath returns the per-API-key path for a CLI state file. If no API +// key is configured the suffix is empty and the path matches the legacy one. +func stateFilePath(name string) (string, error) { + configDir, err := config.Dir() + if err != nil { + return "", err + } + return filepath.Join(configDir, name+apiKeyScopeSuffix()), nil +} + +// readStateFile returns the trimmed contents of a state file, trying the +// scoped path first and falling back to the legacy unscoped path for users +// upgrading from a version that didn't scope by API key. Returns ("", nil) +// when neither file exists, mirroring the previous best-effort read behavior. +func readStateFile(name string) (string, error) { + configDir, err := config.Dir() + if err != nil { + return "", nil + } + scoped := filepath.Join(configDir, name+apiKeyScopeSuffix()) + legacy := filepath.Join(configDir, name) + for _, p := range dedupePaths(scoped, legacy) { + data, err := os.ReadFile(p) + if err == nil { + return strings.TrimSpace(string(data)), nil + } + if !os.IsNotExist(err) { + return "", err + } + } + return "", nil +} + +// writeStateFile writes content to the scoped state file path, creating the +// config directory if needed. +func writeStateFile(name, content string) error { + configDir, err := config.Dir() + if err != nil { + return err + } + if err := os.MkdirAll(configDir, 0o700); err != nil { + return err + } + return os.WriteFile(filepath.Join(configDir, name+apiKeyScopeSuffix()), []byte(content), 0o600) +} + +// clearStateFile removes the scoped state file and, when distinct, the legacy +// unscoped file too. Cleanup is best-effort: a non-existent file is not an error. +func clearStateFile(name string) error { + configDir, err := config.Dir() + if err != nil { + return err + } + scoped := filepath.Join(configDir, name+apiKeyScopeSuffix()) + legacy := filepath.Join(configDir, name) + for _, p := range dedupePaths(scoped, legacy) { + if err := os.Remove(p); err != nil && !os.IsNotExist(err) { + return err + } + } + return nil +} + +func dedupePaths(a, b string) []string { + if a == b { + return []string{a} + } + return []string{a, b} +} diff --git a/internal/cmd/state_scope_test.go b/internal/cmd/state_scope_test.go new file mode 100644 index 0000000..5adb922 --- /dev/null +++ b/internal/cmd/state_scope_test.go @@ -0,0 +1,226 @@ +package cmd + +import ( + "crypto/sha256" + "encoding/hex" + "os" + "path/filepath" + "sync" + "testing" + + "github.com/nottelabs/notte-cli/internal/config" + "github.com/nottelabs/notte-cli/internal/testutil" +) + +// TestMain pins the API-key state-file scope to "" for the entire cmd test +// package so existing tests that hand-roll state-file paths continue to read +// and write the legacy unscoped name. Tests that exercise scoping flip the +// override locally via SetAPIKeyScopeForTesting. +func TestMain(m *testing.M) { + empty := "" + apiKeyScopeOverride = &empty + os.Exit(m.Run()) +} + +func expectedAPIKeyScopeSuffix(apiKey string) string { + if apiKey == "" { + return "" + } + sum := sha256.Sum256([]byte(apiKey)) + return "." + hex.EncodeToString(sum[:])[:8] +} + +// resetAPIKeyScopeCache forces the next apiKeyScopeSuffix() call to re-derive +// the cached value from the current environment. Restores the test-default +// (override pinned to "") via t.Cleanup. Local test helper only. +func resetAPIKeyScopeCache(t *testing.T) { + t.Helper() + apiKeyScopeSuffixOnce = sync.Once{} + apiKeyScopeSuffixCache = "" + apiKeyScopeOverride = nil + t.Cleanup(func() { + empty := "" + apiKeyScopeSuffixOnce = sync.Once{} + apiKeyScopeSuffixCache = "" + apiKeyScopeOverride = &empty + }) +} + +// TestAPIKeyScopeSuffix_FromEnvVar verifies the scope suffix is derived from +// NOTTE_API_KEY and is stable across calls. +func TestAPIKeyScopeSuffix_FromEnvVar(t *testing.T) { + resetAPIKeyScopeCache(t) + + env := testutil.SetupTestEnv(t) + env.SetEnv("NOTTE_API_KEY", "test-key-12345") + + got := apiKeyScopeSuffix() + want := expectedAPIKeyScopeSuffix("test-key-12345") + if got != want { + t.Fatalf("apiKeyScopeSuffix() = %q, want %q", got, want) + } + + // Stability across calls (cached via sync.Once). + if again := apiKeyScopeSuffix(); again != want { + t.Fatalf("apiKeyScopeSuffix() second call = %q, want %q", again, want) + } +} + +// TestStateFilePath_WithAPIKeyScope verifies writes land at the suffixed path. +func TestStateFilePath_WithAPIKeyScope(t *testing.T) { + cleanup := SetAPIKeyScopeForTesting(".abc12345") + t.Cleanup(cleanup) + + tmpDir := t.TempDir() + config.SetTestConfigDir(tmpDir) + t.Cleanup(func() { config.SetTestConfigDir("") }) + + if err := writeStateFile(config.CurrentSessionFile, "sess_abc"); err != nil { + t.Fatalf("writeStateFile: %v", err) + } + + scopedPath := filepath.Join(tmpDir, config.ConfigDirName, config.CurrentSessionFile+".abc12345") + legacyPath := filepath.Join(tmpDir, config.ConfigDirName, config.CurrentSessionFile) + + if data, err := os.ReadFile(scopedPath); err != nil { + t.Fatalf("expected scoped file %s to exist: %v", scopedPath, err) + } else if string(data) != "sess_abc" { + t.Fatalf("scoped file content = %q, want %q", string(data), "sess_abc") + } + + if _, err := os.Stat(legacyPath); !os.IsNotExist(err) { + t.Fatalf("legacy unscoped file must not be written when scope is set; stat err=%v", err) + } +} + +// TestReadStateFile_FallsBackToLegacyUnscopedPath verifies users upgrading +// from a previous CLI version do not lose their existing current_session. +func TestReadStateFile_FallsBackToLegacyUnscopedPath(t *testing.T) { + cleanup := SetAPIKeyScopeForTesting(".abc12345") + t.Cleanup(cleanup) + + tmpDir := t.TempDir() + config.SetTestConfigDir(tmpDir) + t.Cleanup(func() { config.SetTestConfigDir("") }) + + configDir := filepath.Join(tmpDir, config.ConfigDirName) + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatalf("mkdir: %v", err) + } + legacyPath := filepath.Join(configDir, config.CurrentSessionFile) + if err := os.WriteFile(legacyPath, []byte("sess_legacy"), 0o600); err != nil { + t.Fatalf("write legacy: %v", err) + } + + got, err := readStateFile(config.CurrentSessionFile) + if err != nil { + t.Fatalf("readStateFile: %v", err) + } + if got != "sess_legacy" { + t.Fatalf("readStateFile = %q, want %q (legacy fallback)", got, "sess_legacy") + } +} + +// TestReadStateFile_PrefersScopedOverLegacy verifies the per-API-key file +// wins over the legacy unscoped one when both exist (e.g. a previous CLI +// version left the unscoped one behind). +func TestReadStateFile_PrefersScopedOverLegacy(t *testing.T) { + cleanup := SetAPIKeyScopeForTesting(".abc12345") + t.Cleanup(cleanup) + + tmpDir := t.TempDir() + config.SetTestConfigDir(tmpDir) + t.Cleanup(func() { config.SetTestConfigDir("") }) + + configDir := filepath.Join(tmpDir, config.ConfigDirName) + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(configDir, config.CurrentSessionFile), []byte("sess_legacy"), 0o600); err != nil { + t.Fatalf("write legacy: %v", err) + } + if err := os.WriteFile(filepath.Join(configDir, config.CurrentSessionFile+".abc12345"), []byte("sess_scoped"), 0o600); err != nil { + t.Fatalf("write scoped: %v", err) + } + + got, err := readStateFile(config.CurrentSessionFile) + if err != nil { + t.Fatalf("readStateFile: %v", err) + } + if got != "sess_scoped" { + t.Fatalf("readStateFile = %q, want %q (scoped takes precedence)", got, "sess_scoped") + } +} + +// TestClearStateFile_RemovesBothScopedAndLegacy verifies the cleanup path +// drops the legacy file alongside the scoped one so users don't see stale +// state after a clear. +func TestClearStateFile_RemovesBothScopedAndLegacy(t *testing.T) { + cleanup := SetAPIKeyScopeForTesting(".abc12345") + t.Cleanup(cleanup) + + tmpDir := t.TempDir() + config.SetTestConfigDir(tmpDir) + t.Cleanup(func() { config.SetTestConfigDir("") }) + + configDir := filepath.Join(tmpDir, config.ConfigDirName) + if err := os.MkdirAll(configDir, 0o700); err != nil { + t.Fatalf("mkdir: %v", err) + } + legacyPath := filepath.Join(configDir, config.CurrentSessionFile) + scopedPath := filepath.Join(configDir, config.CurrentSessionFile+".abc12345") + if err := os.WriteFile(legacyPath, []byte("sess_legacy"), 0o600); err != nil { + t.Fatalf("write legacy: %v", err) + } + if err := os.WriteFile(scopedPath, []byte("sess_scoped"), 0o600); err != nil { + t.Fatalf("write scoped: %v", err) + } + + if err := clearStateFile(config.CurrentSessionFile); err != nil { + t.Fatalf("clearStateFile: %v", err) + } + + for _, p := range []string{legacyPath, scopedPath} { + if _, err := os.Stat(p); !os.IsNotExist(err) { + t.Fatalf("expected %s to be removed, stat err=%v", p, err) + } + } +} + +// TestDifferentAPIKeysIsolated verifies two API keys land on different paths +// so two accounts on the same machine never see each other's sessions. +func TestDifferentAPIKeysIsolated(t *testing.T) { + tmpDir := t.TempDir() + config.SetTestConfigDir(tmpDir) + t.Cleanup(func() { config.SetTestConfigDir("") }) + + cleanupA := SetAPIKeyScopeForTesting(".aaaaaaaa") + if err := writeStateFile(config.CurrentSessionFile, "sess_for_key_A"); err != nil { + t.Fatalf("writeStateFile A: %v", err) + } + cleanupA() + + cleanupB := SetAPIKeyScopeForTesting(".bbbbbbbb") + if err := writeStateFile(config.CurrentSessionFile, "sess_for_key_B"); err != nil { + t.Fatalf("writeStateFile B: %v", err) + } + + got, err := readStateFile(config.CurrentSessionFile) + if err != nil { + t.Fatalf("readStateFile (B scope): %v", err) + } + if got != "sess_for_key_B" { + t.Fatalf("readStateFile (B scope) = %q, want %q", got, "sess_for_key_B") + } + cleanupB() + + cleanupA2 := SetAPIKeyScopeForTesting(".aaaaaaaa") + t.Cleanup(cleanupA2) + got, err = readStateFile(config.CurrentSessionFile) + if err != nil { + t.Fatalf("readStateFile (A scope): %v", err) + } + if got != "sess_for_key_A" { + t.Fatalf("readStateFile (A scope) = %q, want %q (account isolation)", got, "sess_for_key_A") + } +}