Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
module claude-usage

go 1.24.0
go 1.25.0

toolchain go1.24.12
toolchain go1.25.5

require fyne.io/systray v1.12.0

require (
github.com/godbus/dbus/v5 v5.1.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/sys v0.42.0 // indirect
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ fyne.io/systray v1.12.0 h1:CA1Kk0e2zwFlxtc02L3QFSiIbxJ/P0n582YrZHT7aTM=
fyne.io/systray v1.12.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
70 changes: 17 additions & 53 deletions internal/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"time"

"claude-usage/internal/config"
Expand Down Expand Up @@ -127,17 +126,17 @@ type tokenRefreshResponse struct {

// FetchRateLimits fetches usage data from the OAuth usage endpoint.
// This is a free endpoint that doesn't consume any tokens.
func (c *Client) FetchRateLimits() (*RateLimitData, error) {
return c.fetchRateLimitsWithRetry(0)
func (c *Client) FetchRateLimits(ctx context.Context) (*RateLimitData, error) {
return c.fetchRateLimitsWithRetry(ctx, 0)
}

// fetchRateLimitsWithRetry implements retry logic with automatic token refresh on 401
func (c *Client) fetchRateLimitsWithRetry(attempt int) (*RateLimitData, error) {
func (c *Client) fetchRateLimitsWithRetry(ctx context.Context, attempt int) (*RateLimitData, error) {
if c.token == "" {
return nil, fmt.Errorf("no OAuth token configured")
}

req, err := http.NewRequest("GET", usageEndpoint, nil)
req, err := http.NewRequestWithContext(ctx, "GET", usageEndpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
Expand All @@ -154,7 +153,10 @@ func (c *Client) fetchRateLimitsWithRetry(attempt int) (*RateLimitData, error) {
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
defer func() {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}()

// Handle 401 Unauthorized with token refresh
if resp.StatusCode == http.StatusUnauthorized {
Expand All @@ -170,7 +172,7 @@ func (c *Client) fetchRateLimitsWithRetry(attempt int) (*RateLimitData, error) {
log.Printf("Token expired (attempt %d/%d), refreshing...", attempt+1, maxRetries)

// Attempt to refresh the token
newToken, err := c.RefreshAccessToken()
newToken, err := c.RefreshAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err)
}
Expand All @@ -180,7 +182,7 @@ func (c *Client) fetchRateLimitsWithRetry(attempt int) (*RateLimitData, error) {
log.Printf("Token refreshed successfully, retrying request")

// Retry the request with the new token
return c.fetchRateLimitsWithRetry(attempt + 1)
return c.fetchRateLimitsWithRetry(ctx, attempt+1)
}

// Check for other errors
Expand All @@ -201,7 +203,7 @@ func (c *Client) fetchRateLimitsWithRetry(attempt int) (*RateLimitData, error) {

// RefreshAccessToken uses the refresh token to obtain a new access token.
// Returns the new access token on success.
func (c *Client) RefreshAccessToken() (string, error) {
func (c *Client) RefreshAccessToken(ctx context.Context) (string, error) {
if c.refreshToken == "" {
return "", fmt.Errorf("no refresh token available")
}
Expand All @@ -219,7 +221,7 @@ func (c *Client) RefreshAccessToken() (string, error) {
return "", fmt.Errorf("failed to marshal refresh request: %w", err)
}

req, err := http.NewRequest("POST", tokenEndpoint, bytes.NewBuffer(jsonBody))
req, err := http.NewRequestWithContext(ctx, "POST", tokenEndpoint, bytes.NewBuffer(jsonBody))
if err != nil {
return "", fmt.Errorf("failed to create refresh request: %w", err)
}
Expand All @@ -234,7 +236,10 @@ func (c *Client) RefreshAccessToken() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to make refresh request: %w", err)
}
defer resp.Body.Close()
defer func() {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}()

// Check response status
if resp.StatusCode != http.StatusOK {
Expand All @@ -253,7 +258,6 @@ func (c *Client) RefreshAccessToken() (string, error) {

// Check if we got a new refresh token (some OAuth servers rotate them)
if refreshResp.RefreshToken != "" && refreshResp.RefreshToken != c.refreshToken {
oldRefreshToken := c.refreshToken
c.refreshToken = refreshResp.RefreshToken
log.Printf("Received a new refresh token from the server (token rotated)")

Expand All @@ -262,51 +266,11 @@ func (c *Client) RefreshAccessToken() (string, error) {
c.onRefreshTokenUpdate(refreshResp.RefreshToken)
}

// Also write a debug warning file next to the binary (for troubleshooting)
if err := c.writeRefreshTokenWarning(refreshResp.RefreshToken, oldRefreshToken); err != nil {
log.Printf("Failed to write refresh token warning file: %v", err)
}
}

return refreshResp.AccessToken, nil
}

// writeRefreshTokenWarning creates a debug file next to the binary when a new refresh token is received.
// This is kept for debugging purposes even though we now persist tokens to credentials file.
func (c *Client) writeRefreshTokenWarning(newRefreshToken, oldRefreshToken string) error {
exe, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}

warningPath := filepath.Join(filepath.Dir(exe), "NEW_REFRESH_TOKEN_WARNING.txt")

content := fmt.Sprintf(`DEBUG: Refresh Token Rotation Detected
========================================

A new refresh token was received at: %s

The credentials file should have been updated automatically.
This file is kept for debugging purposes.

Old refresh token: %s
New refresh token: %s

If you see authentication errors after restart, check the credentials file.
`,
time.Now().Format(time.RFC3339),
oldRefreshToken,
newRefreshToken,
)

if err := os.WriteFile(warningPath, []byte(content), 0644); err != nil {
return fmt.Errorf("failed to write warning file: %w", err)
}

log.Printf("Refresh token rotation debug info written to: %s", warningPath)
return nil
}

// parseUsageResponse converts the API response to RateLimitData
func parseUsageResponse(usage *usageResponse) *RateLimitData {
data := &RateLimitData{
Expand Down
31 changes: 19 additions & 12 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package app

import (
"context"
"log"
"runtime"
"sync"
Expand Down Expand Up @@ -29,7 +30,7 @@ type App struct {
iconGen *icon.Generator
apiClient *api.Client
stats *stats.WeeklyStats
statsMu sync.RWMutex
mu sync.RWMutex // guards stats and apiClient
credOrigin string // "file" or "keychain" — tracks where credentials were loaded from
stopCh chan struct{}
refreshCh chan struct{}
Expand All @@ -50,9 +51,10 @@ func New(version string) (*App, error) {
version: version,
tray: tray.New(version, cfg.GetSourceDisplayName()),
iconGen: icon.DefaultGenerator(),
apiClient: nil, // Will be initialized when we have a token
stopCh: make(chan struct{}),
refreshCh: make(chan struct{}, 1),
apiClient: nil, // Will be initialized when we have a token
credOrigin: credOriginFile,
stopCh: make(chan struct{}),
refreshCh: make(chan struct{}, 1),
}, nil
}

Expand Down Expand Up @@ -182,9 +184,9 @@ func (a *App) refresh() {
a.fetchAndApplyRateLimits(weeklyStats, creds.ClaudeAiOauth.AccessToken, creds.ClaudeAiOauth.RefreshToken)

// Store stats
a.statsMu.Lock()
a.mu.Lock()
a.stats = weeklyStats
a.statsMu.Unlock()
a.mu.Unlock()

// Update tray
a.updateTray(weeklyStats)
Expand All @@ -198,7 +200,8 @@ func (a *App) refresh() {

// fetchAndApplyRateLimits fetches rate limits from the API and applies them to weeklyStats.
func (a *App) fetchAndApplyRateLimits(weeklyStats *stats.WeeklyStats, token string, refreshToken string) {
// Initialize or update API client
// Initialize or update API client (guarded by mutex since toggleSource can nil it)
a.mu.Lock()
if a.apiClient == nil {
a.apiClient = api.NewClient(token)

Expand All @@ -210,9 +213,11 @@ func (a *App) fetchAndApplyRateLimits(weeklyStats *stats.WeeklyStats, token stri

// Always set the refresh token so the client can auto-refresh on 401
a.apiClient.SetRefreshToken(refreshToken)
client := a.apiClient
a.mu.Unlock()

// Fetch rate limits
rateLimits, err := a.apiClient.FetchRateLimits()
rateLimits, err := client.FetchRateLimits(context.Background())
if err != nil {
log.Printf("Warning: could not fetch rate limits from API: %v", err)
return
Expand Down Expand Up @@ -293,7 +298,9 @@ func (a *App) toggleSource() {
a.tray.UpdateSourceToggle(newSource)

// Reset the API client so it gets re-initialized with the new credentials
a.mu.Lock()
a.apiClient = nil
a.mu.Unlock()

// Trigger a refresh to load the new credentials
a.triggerRefresh()
Expand Down Expand Up @@ -336,8 +343,8 @@ func (a *App) createRefreshTokenCallback() func(string) {

// GetStats returns the current weekly stats (thread-safe).
func (a *App) GetStats() *stats.WeeklyStats {
a.statsMu.RLock()
defer a.statsMu.RUnlock()
a.mu.RLock()
defer a.mu.RUnlock()
return a.stats
}

Expand All @@ -356,9 +363,9 @@ func (a *App) performUpdate() {
// Restore normal tooltip after a delay
go func() {
time.Sleep(5 * time.Second)
a.statsMu.RLock()
a.mu.RLock()
stats := a.stats
a.statsMu.RUnlock()
a.mu.RUnlock()
if stats != nil {
a.updateTray(stats)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (c *Config) Save() error {
return err
}

return os.WriteFile(GetConfigPath(), data, 0644)
return os.WriteFile(GetConfigPath(), data, 0600)
}

// GetStatsPath returns the effective stats path (config or default).
Expand Down
6 changes: 0 additions & 6 deletions internal/stats/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,6 @@ func ParseKeychainCredentials() (*Credentials, error) {
return &creds, nil
}

// FileExists checks if a file exists at the given path.
func FileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}

// UpdateKeychainRefreshToken updates the refresh token in the macOS Keychain entry.
// It reads the current keychain entry, updates the refreshToken field, and writes it back.
// The -U flag to security add-generic-password updates an existing entry.
Expand Down
7 changes: 1 addition & 6 deletions internal/stats/weekly.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package stats

import (
"math/rand"
"time"
)

Expand Down Expand Up @@ -169,11 +168,7 @@ func (w *WeeklyStats) GetPercentage() int {
limit := GetWeeklyLimit(w.SubscriptionType, w.RateLimitTier)

if limit == 0 {
// Can't calculate without a limit, return placeholder
// Use a random value that changes daily (seeded by day)
seed := time.Now().UTC().YearDay()
r := rand.New(rand.NewSource(int64(seed)))
return r.Intn(100)
return 0
}

// Calculate percentage
Expand Down
11 changes: 8 additions & 3 deletions internal/update/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ func Update() (*Result, error) {

// Verify the download is a valid executable (basic check: has content)
info, err := os.Stat(newBinaryPath)
if err != nil || info.Size() < 1000 {
if err != nil {
return nil, fmt.Errorf("failed to stat downloaded file: %w", err)
}
if info.Size() < 1000 {
return nil, fmt.Errorf("downloaded file appears invalid (size: %d)", info.Size())
}

Expand Down Expand Up @@ -245,8 +248,10 @@ func copyFile(src, dst string) error {
}
defer destFile.Close()

_, err = io.Copy(destFile, sourceFile)
return err
if _, err = io.Copy(destFile, sourceFile); err != nil {
return err
}
return destFile.Sync()
}

// downloadBinary downloads the binary from the given URL to a temporary file.
Expand Down