diff --git a/go.mod b/go.mod index e49ded5..467fcf0 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,12 @@ module claude-usage -go 1.24.0 +go 1.25.0 -toolchain go1.24.12 +toolchain go1.25.5 require fyne.io/systray v1.12.0 require ( github.com/godbus/dbus/v5 v5.1.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/sys v0.42.0 // indirect ) diff --git a/go.sum b/go.sum index 762c518..5809996 100644 --- a/go.sum +++ b/go.sum @@ -2,5 +2,5 @@ fyne.io/systray v1.12.0 h1:CA1Kk0e2zwFlxtc02L3QFSiIbxJ/P0n582YrZHT7aTM= fyne.io/systray v1.12.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= diff --git a/internal/api/client.go b/internal/api/client.go index dc42cc6..15ce998 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -2,13 +2,12 @@ package api import ( "bytes" + "context" "encoding/json" "fmt" "io" "log" "net/http" - "os" - "path/filepath" "time" "claude-usage/internal/config" @@ -127,17 +126,17 @@ type tokenRefreshResponse struct { // FetchRateLimits fetches usage data from the OAuth usage endpoint. // This is a free endpoint that doesn't consume any tokens. -func (c *Client) FetchRateLimits() (*RateLimitData, error) { - return c.fetchRateLimitsWithRetry(0) +func (c *Client) FetchRateLimits(ctx context.Context) (*RateLimitData, error) { + return c.fetchRateLimitsWithRetry(ctx, 0) } // fetchRateLimitsWithRetry implements retry logic with automatic token refresh on 401 -func (c *Client) fetchRateLimitsWithRetry(attempt int) (*RateLimitData, error) { +func (c *Client) fetchRateLimitsWithRetry(ctx context.Context, attempt int) (*RateLimitData, error) { if c.token == "" { return nil, fmt.Errorf("no OAuth token configured") } - req, err := http.NewRequest("GET", usageEndpoint, nil) + req, err := http.NewRequestWithContext(ctx, "GET", usageEndpoint, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -154,7 +153,10 @@ func (c *Client) fetchRateLimitsWithRetry(attempt int) (*RateLimitData, error) { if err != nil { return nil, fmt.Errorf("failed to make request: %w", err) } - defer resp.Body.Close() + defer func() { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() // Handle 401 Unauthorized with token refresh if resp.StatusCode == http.StatusUnauthorized { @@ -170,7 +172,7 @@ func (c *Client) fetchRateLimitsWithRetry(attempt int) (*RateLimitData, error) { log.Printf("Token expired (attempt %d/%d), refreshing...", attempt+1, maxRetries) // Attempt to refresh the token - newToken, err := c.RefreshAccessToken() + newToken, err := c.RefreshAccessToken(ctx) if err != nil { return nil, fmt.Errorf("failed to refresh token: %w", err) } @@ -180,7 +182,7 @@ func (c *Client) fetchRateLimitsWithRetry(attempt int) (*RateLimitData, error) { log.Printf("Token refreshed successfully, retrying request") // Retry the request with the new token - return c.fetchRateLimitsWithRetry(attempt + 1) + return c.fetchRateLimitsWithRetry(ctx, attempt+1) } // Check for other errors @@ -201,7 +203,7 @@ func (c *Client) fetchRateLimitsWithRetry(attempt int) (*RateLimitData, error) { // RefreshAccessToken uses the refresh token to obtain a new access token. // Returns the new access token on success. -func (c *Client) RefreshAccessToken() (string, error) { +func (c *Client) RefreshAccessToken(ctx context.Context) (string, error) { if c.refreshToken == "" { return "", fmt.Errorf("no refresh token available") } @@ -219,7 +221,7 @@ func (c *Client) RefreshAccessToken() (string, error) { return "", fmt.Errorf("failed to marshal refresh request: %w", err) } - req, err := http.NewRequest("POST", tokenEndpoint, bytes.NewBuffer(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", tokenEndpoint, bytes.NewBuffer(jsonBody)) if err != nil { return "", fmt.Errorf("failed to create refresh request: %w", err) } @@ -234,7 +236,10 @@ func (c *Client) RefreshAccessToken() (string, error) { if err != nil { return "", fmt.Errorf("failed to make refresh request: %w", err) } - defer resp.Body.Close() + defer func() { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() // Check response status if resp.StatusCode != http.StatusOK { @@ -253,7 +258,6 @@ func (c *Client) RefreshAccessToken() (string, error) { // Check if we got a new refresh token (some OAuth servers rotate them) if refreshResp.RefreshToken != "" && refreshResp.RefreshToken != c.refreshToken { - oldRefreshToken := c.refreshToken c.refreshToken = refreshResp.RefreshToken log.Printf("Received a new refresh token from the server (token rotated)") @@ -262,51 +266,11 @@ func (c *Client) RefreshAccessToken() (string, error) { c.onRefreshTokenUpdate(refreshResp.RefreshToken) } - // Also write a debug warning file next to the binary (for troubleshooting) - if err := c.writeRefreshTokenWarning(refreshResp.RefreshToken, oldRefreshToken); err != nil { - log.Printf("Failed to write refresh token warning file: %v", err) - } } return refreshResp.AccessToken, nil } -// writeRefreshTokenWarning creates a debug file next to the binary when a new refresh token is received. -// This is kept for debugging purposes even though we now persist tokens to credentials file. -func (c *Client) writeRefreshTokenWarning(newRefreshToken, oldRefreshToken string) error { - exe, err := os.Executable() - if err != nil { - return fmt.Errorf("failed to get executable path: %w", err) - } - - warningPath := filepath.Join(filepath.Dir(exe), "NEW_REFRESH_TOKEN_WARNING.txt") - - content := fmt.Sprintf(`DEBUG: Refresh Token Rotation Detected -======================================== - -A new refresh token was received at: %s - -The credentials file should have been updated automatically. -This file is kept for debugging purposes. - -Old refresh token: %s -New refresh token: %s - -If you see authentication errors after restart, check the credentials file. -`, - time.Now().Format(time.RFC3339), - oldRefreshToken, - newRefreshToken, - ) - - if err := os.WriteFile(warningPath, []byte(content), 0644); err != nil { - return fmt.Errorf("failed to write warning file: %w", err) - } - - log.Printf("Refresh token rotation debug info written to: %s", warningPath) - return nil -} - // parseUsageResponse converts the API response to RateLimitData func parseUsageResponse(usage *usageResponse) *RateLimitData { data := &RateLimitData{ diff --git a/internal/app/app.go b/internal/app/app.go index 07aef86..d7f1454 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -2,6 +2,7 @@ package app import ( + "context" "log" "runtime" "sync" @@ -29,7 +30,7 @@ type App struct { iconGen *icon.Generator apiClient *api.Client stats *stats.WeeklyStats - statsMu sync.RWMutex + mu sync.RWMutex // guards stats and apiClient credOrigin string // "file" or "keychain" — tracks where credentials were loaded from stopCh chan struct{} refreshCh chan struct{} @@ -50,9 +51,10 @@ func New(version string) (*App, error) { version: version, tray: tray.New(version, cfg.GetSourceDisplayName()), iconGen: icon.DefaultGenerator(), - apiClient: nil, // Will be initialized when we have a token - stopCh: make(chan struct{}), - refreshCh: make(chan struct{}, 1), + apiClient: nil, // Will be initialized when we have a token + credOrigin: credOriginFile, + stopCh: make(chan struct{}), + refreshCh: make(chan struct{}, 1), }, nil } @@ -182,9 +184,9 @@ func (a *App) refresh() { a.fetchAndApplyRateLimits(weeklyStats, creds.ClaudeAiOauth.AccessToken, creds.ClaudeAiOauth.RefreshToken) // Store stats - a.statsMu.Lock() + a.mu.Lock() a.stats = weeklyStats - a.statsMu.Unlock() + a.mu.Unlock() // Update tray a.updateTray(weeklyStats) @@ -198,7 +200,8 @@ func (a *App) refresh() { // fetchAndApplyRateLimits fetches rate limits from the API and applies them to weeklyStats. func (a *App) fetchAndApplyRateLimits(weeklyStats *stats.WeeklyStats, token string, refreshToken string) { - // Initialize or update API client + // Initialize or update API client (guarded by mutex since toggleSource can nil it) + a.mu.Lock() if a.apiClient == nil { a.apiClient = api.NewClient(token) @@ -210,9 +213,11 @@ func (a *App) fetchAndApplyRateLimits(weeklyStats *stats.WeeklyStats, token stri // Always set the refresh token so the client can auto-refresh on 401 a.apiClient.SetRefreshToken(refreshToken) + client := a.apiClient + a.mu.Unlock() // Fetch rate limits - rateLimits, err := a.apiClient.FetchRateLimits() + rateLimits, err := client.FetchRateLimits(context.Background()) if err != nil { log.Printf("Warning: could not fetch rate limits from API: %v", err) return @@ -293,7 +298,9 @@ func (a *App) toggleSource() { a.tray.UpdateSourceToggle(newSource) // Reset the API client so it gets re-initialized with the new credentials + a.mu.Lock() a.apiClient = nil + a.mu.Unlock() // Trigger a refresh to load the new credentials a.triggerRefresh() @@ -336,8 +343,8 @@ func (a *App) createRefreshTokenCallback() func(string) { // GetStats returns the current weekly stats (thread-safe). func (a *App) GetStats() *stats.WeeklyStats { - a.statsMu.RLock() - defer a.statsMu.RUnlock() + a.mu.RLock() + defer a.mu.RUnlock() return a.stats } @@ -356,9 +363,9 @@ func (a *App) performUpdate() { // Restore normal tooltip after a delay go func() { time.Sleep(5 * time.Second) - a.statsMu.RLock() + a.mu.RLock() stats := a.stats - a.statsMu.RUnlock() + a.mu.RUnlock() if stats != nil { a.updateTray(stats) } diff --git a/internal/config/config.go b/internal/config/config.go index 549f608..8eb935c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -138,7 +138,7 @@ func (c *Config) Save() error { return err } - return os.WriteFile(GetConfigPath(), data, 0644) + return os.WriteFile(GetConfigPath(), data, 0600) } // GetStatsPath returns the effective stats path (config or default). diff --git a/internal/stats/parser.go b/internal/stats/parser.go index 3e4daf9..bd8cad6 100644 --- a/internal/stats/parser.go +++ b/internal/stats/parser.go @@ -112,12 +112,6 @@ func ParseKeychainCredentials() (*Credentials, error) { return &creds, nil } -// FileExists checks if a file exists at the given path. -func FileExists(path string) bool { - _, err := os.Stat(path) - return err == nil -} - // UpdateKeychainRefreshToken updates the refresh token in the macOS Keychain entry. // It reads the current keychain entry, updates the refreshToken field, and writes it back. // The -U flag to security add-generic-password updates an existing entry. diff --git a/internal/stats/weekly.go b/internal/stats/weekly.go index c1f69b7..fd283e6 100644 --- a/internal/stats/weekly.go +++ b/internal/stats/weekly.go @@ -1,7 +1,6 @@ package stats import ( - "math/rand" "time" ) @@ -169,11 +168,7 @@ func (w *WeeklyStats) GetPercentage() int { limit := GetWeeklyLimit(w.SubscriptionType, w.RateLimitTier) if limit == 0 { - // Can't calculate without a limit, return placeholder - // Use a random value that changes daily (seeded by day) - seed := time.Now().UTC().YearDay() - r := rand.New(rand.NewSource(int64(seed))) - return r.Intn(100) + return 0 } // Calculate percentage diff --git a/internal/update/updater.go b/internal/update/updater.go index 0e38c37..e5296bb 100644 --- a/internal/update/updater.go +++ b/internal/update/updater.go @@ -106,7 +106,10 @@ func Update() (*Result, error) { // Verify the download is a valid executable (basic check: has content) info, err := os.Stat(newBinaryPath) - if err != nil || info.Size() < 1000 { + if err != nil { + return nil, fmt.Errorf("failed to stat downloaded file: %w", err) + } + if info.Size() < 1000 { return nil, fmt.Errorf("downloaded file appears invalid (size: %d)", info.Size()) } @@ -245,8 +248,10 @@ func copyFile(src, dst string) error { } defer destFile.Close() - _, err = io.Copy(destFile, sourceFile) - return err + if _, err = io.Copy(destFile, sourceFile); err != nil { + return err + } + return destFile.Sync() } // downloadBinary downloads the binary from the given URL to a temporary file.