From 865dca96e231bae2f086fa56947af75f465c8dac Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 17 May 2026 01:01:48 +0000 Subject: [PATCH] feat(commands): add `shelltime update` self-update command Adds a top-level `shelltime update` command that downloads the latest release archive from GitHub, verifies its SHA256 against checksums.txt, extracts the bundled shelltime / shelltime-daemon binaries, and replaces the running install in place. After the swap it runs `shelltime daemon reinstall` so systemd/launchd picks up the new binary (opt-out via --skip-daemon-reinstall). Homebrew installs are detected (via the resolved executable path containing /Cellar/, /opt/homebrew/, or /home/linuxbrew/.linuxbrew/) and redirected to `brew upgrade shelltime/tap/shelltime` instead of being overwritten. Binaries at unrecognized locations are left untouched with a warning. Flags: --check (dry run), --force (override same-version / dev-build guards), --skip-daemon-reinstall. Pure logic (URL/archive-name builder, checksum parser, archive extractor with zip-slip guard, .bak-rename replace) lives in model/updater.go with table-driven tests in model/updater_test.go. --- README.md | 15 ++ cmd/cli/main.go | 1 + commands/update.go | 194 ++++++++++++++++++ model/updater.go | 459 ++++++++++++++++++++++++++++++++++++++++++ model/updater_test.go | 265 ++++++++++++++++++++++++ 5 files changed, 934 insertions(+) create mode 100644 commands/update.go create mode 100644 model/updater.go create mode 100644 model/updater_test.go diff --git a/README.md b/README.md index fb7b74c..6c44f1e 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,20 @@ brew install shelltime/tap/shelltime curl -sSL https://shelltime.xyz/i | bash ``` +### Upgrading + +For curl-installed users, upgrade in place: + +```bash +shelltime update +``` + +Homebrew users should upgrade via brew: + +```bash +brew upgrade shelltime/tap/shelltime +``` + ## Quick Start The fastest setup path is: @@ -58,6 +72,7 @@ shelltime codex install |---------|-------------| | `shelltime init` | Bootstrap auth, hooks, daemon, and AI-code integrations | | `shelltime auth` | Authenticate with `shelltime.xyz` | +| `shelltime update` | Download and install the latest release in place | | `shelltime doctor` | Check installation and environment health | | `shelltime web` | Open the ShellTime dashboard in a browser | diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 74fd4ad..f090c3c 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -104,6 +104,7 @@ func main() { commands.GrepCommand, commands.ConfigCommand, commands.IosCommand, + commands.UpdateCommand, } err = app.Run(os.Args) if err != nil { diff --git a/commands/update.go b/commands/update.go new file mode 100644 index 0000000..4048c31 --- /dev/null +++ b/commands/update.go @@ -0,0 +1,194 @@ +package commands + +import ( + "context" + "fmt" + "log/slog" + "os" + "path/filepath" + "runtime" + + "github.com/gookit/color" + "github.com/malamtime/cli/model" + "github.com/urfave/cli/v2" +) + +var UpdateCommand *cli.Command = &cli.Command{ + Name: "update", + Usage: "Download and install the latest shelltime release in place", + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "check", + Aliases: []string{"c"}, + Usage: "Only report current vs latest version, do not install", + }, + &cli.BoolFlag{ + Name: "force", + Aliases: []string{"f"}, + Usage: "Proceed even if already on the latest version or running a dev build", + }, + &cli.BoolFlag{ + Name: "skip-daemon-reinstall", + Usage: "Skip refreshing the daemon service after replacing binaries", + }, + }, + Action: commandUpdate, +} + +func commandUpdate(c *cli.Context) error { + ctx, span := commandTracer.Start(c.Context, "update") + defer span.End() + + check := c.Bool("check") + force := c.Bool("force") + skipDaemonReinstall := c.Bool("skip-daemon-reinstall") + + color.Yellow.Println("🔍 Checking for updates...") + + cliPath, err := model.ResolveCLIBinaryPath() + if err != nil { + return fmt.Errorf("resolve running binary path: %w", err) + } + + switch model.DetectInstallKind(cliPath) { + case model.InstallKindHomebrew: + color.Yellow.Println("đŸ“Ļ Detected Homebrew installation.") + color.Yellow.Println(" Run: brew upgrade shelltime/tap/shelltime") + return nil + case model.InstallKindUnknown: + color.Yellow.Printf("âš ī¸ Binary at %s is not in a known auto-updatable location.\n", cliPath) + color.Yellow.Println(" Reinstall via the curl installer or Homebrew to enable in-place updates.") + return nil + } + + latest, err := model.FetchLatestVersion(ctx) + if err != nil { + return fmt.Errorf("fetch latest release: %w", err) + } + + current := commitID + if current == "" { + current = "dev" + } + normalizedLatest := model.NormalizeVersion(latest) + normalizedCurrent := model.NormalizeVersion(current) + + color.Cyan.Printf(" Current: %s\n", current) + color.Cyan.Printf(" Latest: %s\n", latest) + + if check { + if normalizedLatest == normalizedCurrent { + color.Green.Println("✅ Already on the latest version.") + } else { + color.Yellow.Println("âŦ†ī¸ An update is available. Run `shelltime update` to install it.") + } + return nil + } + + if current == "dev" && !force { + color.Yellow.Println("âš ī¸ Refusing to overwrite a dev build. Use --force to proceed anyway.") + return nil + } + + if normalizedLatest == normalizedCurrent && !force { + color.Green.Println("✅ Already on the latest version. Use --force to reinstall.") + return nil + } + + archiveName, err := model.BuildArchiveName(runtime.GOOS, runtime.GOARCH) + if err != nil { + return err + } + downloadURL := model.BuildDownloadURL(latest, archiveName) + + expectedSum, ok, err := model.FetchChecksum(ctx, latest, archiveName) + if err != nil { + color.Yellow.Printf("âš ī¸ Could not fetch checksums.txt: %v (proceeding without verification)\n", err) + } else if !ok { + color.Yellow.Println("âš ī¸ No checksum entry for this archive — proceeding without verification.") + } + + tmpDir, err := os.MkdirTemp("", "shelltime-update-*") + if err != nil { + return fmt.Errorf("create temp dir: %w", err) + } + defer os.RemoveAll(tmpDir) + + archivePath := filepath.Join(tmpDir, archiveName) + color.Yellow.Printf("âŦ‡ī¸ Downloading %s ...\n", archiveName) + if err := model.DownloadAndVerify(ctx, downloadURL, expectedSum, archivePath); err != nil { + return fmt.Errorf("download release: %w", err) + } + + extractDir := filepath.Join(tmpDir, "extracted") + if err := os.MkdirAll(extractDir, 0o755); err != nil { + return err + } + binaries, err := model.ExtractBinaries(archivePath, extractDir) + if err != nil { + return fmt.Errorf("extract archive: %w", err) + } + if _, ok := binaries["shelltime"]; !ok { + return fmt.Errorf("archive %s did not contain a shelltime binary", archiveName) + } + + color.Yellow.Println("🔄 Replacing binaries...") + + if err := model.ReplaceBinary(binaries["shelltime"], cliPath); err != nil { + return fmt.Errorf("replace shelltime binary: %w", err) + } + color.Green.Printf(" shelltime -> %s\n", cliPath) + + if daemonSrc, ok := binaries["shelltime-daemon"]; ok { + daemonDest := resolveDaemonDest() + if err := model.ReplaceBinary(daemonSrc, daemonDest); err != nil { + return fmt.Errorf("replace shelltime-daemon binary: %w", err) + } + color.Green.Printf(" shelltime-daemon -> %s\n", daemonDest) + } + + if shouldReinstallDaemon(ctx, skipDaemonReinstall) { + color.Yellow.Println("🔁 Refreshing daemon service...") + if err := commandDaemonReinstall(c); err != nil { + color.Yellow.Printf("âš ī¸ Daemon reinstall reported an error: %v\n", err) + color.Yellow.Println(" You can rerun `shelltime daemon reinstall` manually.") + } + } else { + color.Yellow.Println("â„šī¸ Skipping daemon reinstall. Run `shelltime daemon reinstall` to pick up the new binary.") + } + + color.Green.Printf("✅ Updated to %s. Restart your shell to use the new binary.\n", latest) + return nil +} + +// resolveDaemonDest returns the path the daemon binary should be written to — +// the existing daemon location if installed, otherwise the curl-installer default. +func resolveDaemonDest() string { + if p, err := model.ResolveDaemonBinaryPath(); err == nil { + return p + } + return filepath.Join(model.GetBinFolderPath(), "shelltime-daemon") +} + +// shouldReinstallDaemon decides whether to call commandDaemonReinstall after a +// binary swap. +func shouldReinstallDaemon(_ context.Context, skipFlag bool) bool { + if skipFlag { + return false + } + if runtime.GOOS == "windows" { + return false + } + if _, err := model.ResolveDaemonBinaryPath(); err != nil { + return false + } + installer, err := model.NewDaemonInstaller("", "", "") + if err != nil { + slog.Debug("skip daemon reinstall: installer factory failed", slog.Any("err", err)) + return false + } + if err := installer.Check(); err != nil { + return false + } + return true +} diff --git a/model/updater.go b/model/updater.go new file mode 100644 index 0000000..e935a2d --- /dev/null +++ b/model/updater.go @@ -0,0 +1,459 @@ +package model + +import ( + "archive/tar" + "archive/zip" + "compress/gzip" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + "time" + + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" +) + +const ( + // GitHub repo for ShellTime CLI releases. + githubReleasesOwner = "shelltime" + githubReleasesRepo = "cli" + + // Goreleaser ProjectName for release archive naming (cli__.). + releaseArchivePrefix = "cli" + + // Max bytes accepted from an archive entry to defeat zip-bombs. + maxArchiveEntrySize = 200 * 1024 * 1024 + + // HTTP timeouts. + updaterAPITimeout = 15 * time.Second + updaterDownloadTimeout = 5 * time.Minute +) + +// Binary names extracted from release archives. +var allowedArchiveBinaries = map[string]bool{ + "shelltime": true, + "shelltime-daemon": true, + "shelltime.exe": true, + "shelltime-daemon.exe": true, +} + +// LatestRelease is the subset of the GitHub releases API response we use. +type LatestRelease struct { + TagName string `json:"tag_name"` +} + +func newUpdaterHTTPClient(timeout time.Duration) *http.Client { + return &http.Client{ + Timeout: timeout, + Transport: otelhttp.NewTransport(http.DefaultTransport), + } +} + +func updaterUserAgent() string { + v := commitID + if v == "" { + v = "dev" + } + return "shelltimeCLI@" + v +} + +// FetchLatestVersion calls the GitHub API for the latest stable release tag. +func FetchLatestVersion(ctx context.Context) (string, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", githubReleasesOwner, githubReleasesRepo) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", err + } + req.Header.Set("User-Agent", updaterUserAgent()) + req.Header.Set("Accept", "application/vnd.github+json") + + client := newUpdaterHTTPClient(updaterAPITimeout) + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("github api request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("github api returned status %d", resp.StatusCode) + } + + var rel LatestRelease + if err := json.NewDecoder(resp.Body).Decode(&rel); err != nil { + return "", fmt.Errorf("decode github api response: %w", err) + } + if rel.TagName == "" { + return "", errors.New("github api returned empty tag_name") + } + return rel.TagName, nil +} + +// BuildArchiveName returns the release archive filename for the given platform, +// matching the goreleaser name_template (e.g. cli_Darwin_x86_64.zip). +func BuildArchiveName(goos, goarch string) (string, error) { + var osPart string + switch goos { + case "darwin": + osPart = "Darwin" + case "linux": + osPart = "Linux" + case "windows": + osPart = "Windows" + default: + return "", fmt.Errorf("unsupported OS: %s", goos) + } + + var archPart string + switch goarch { + case "amd64": + archPart = "x86_64" + case "arm64": + archPart = "arm64" + case "386": + archPart = "i386" + default: + return "", fmt.Errorf("unsupported architecture: %s", goarch) + } + + ext := "tar.gz" + if goos == "darwin" || goos == "windows" { + ext = "zip" + } + + return fmt.Sprintf("%s_%s_%s.%s", releaseArchivePrefix, osPart, archPart, ext), nil +} + +// BuildDownloadURL returns the direct release-asset URL for a specific tag. +func BuildDownloadURL(tag, archiveName string) string { + return fmt.Sprintf( + "https://github.com/%s/%s/releases/download/%s/%s", + githubReleasesOwner, githubReleasesRepo, tag, archiveName, + ) +} + +// BuildChecksumsURL returns the checksums.txt URL for a specific tag. +func BuildChecksumsURL(tag string) string { + return fmt.Sprintf( + "https://github.com/%s/%s/releases/download/%s/checksums.txt", + githubReleasesOwner, githubReleasesRepo, tag, + ) +} + +// FetchChecksum returns the expected SHA256 for archiveName. The bool reports +// whether a checksum was found; callers may proceed without verification if false. +func FetchChecksum(ctx context.Context, tag, archiveName string) (string, bool, error) { + url := BuildChecksumsURL(tag) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", false, err + } + req.Header.Set("User-Agent", updaterUserAgent()) + + client := newUpdaterHTTPClient(updaterAPITimeout) + resp, err := client.Do(req) + if err != nil { + return "", false, fmt.Errorf("download checksums.txt: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return "", false, nil + } + if resp.StatusCode != http.StatusOK { + return "", false, fmt.Errorf("checksums.txt returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) + if err != nil { + return "", false, fmt.Errorf("read checksums.txt: %w", err) + } + return parseChecksumLine(string(body), archiveName) +} + +// parseChecksumLine finds the SHA256 for archiveName in goreleaser's checksums.txt +// format: " ". +func parseChecksumLine(content, archiveName string) (string, bool, error) { + for _, line := range strings.Split(content, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + fields := strings.Fields(line) + if len(fields) != 2 { + continue + } + if fields[1] == archiveName { + if len(fields[0]) != 64 { + return "", false, fmt.Errorf("malformed sha256 for %s: %q", archiveName, fields[0]) + } + return strings.ToLower(fields[0]), true, nil + } + } + return "", false, nil +} + +// DownloadAndVerify streams url to destPath while hashing. If expectedSha256 is +// non-empty, the download fails unless the computed digest matches. +func DownloadAndVerify(ctx context.Context, url, expectedSha256, destPath string) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + req.Header.Set("User-Agent", updaterUserAgent()) + + client := newUpdaterHTTPClient(updaterDownloadTimeout) + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("download %s: %w", url, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("download %s returned status %d", url, resp.StatusCode) + } + + out, err := os.Create(destPath) + if err != nil { + return err + } + defer out.Close() + + hasher := sha256.New() + if _, err := io.Copy(out, io.TeeReader(resp.Body, hasher)); err != nil { + return fmt.Errorf("write archive: %w", err) + } + + if expectedSha256 != "" { + got := hex.EncodeToString(hasher.Sum(nil)) + if !strings.EqualFold(got, expectedSha256) { + return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedSha256, got) + } + } + return nil +} + +// safeExtractPath joins destDir and entryName, rejecting any result that escapes +// destDir (defends against zip-slip / tar path traversal). +func safeExtractPath(destDir, entryName string) (string, error) { + cleanDest := filepath.Clean(destDir) + target := filepath.Join(cleanDest, filepath.Base(entryName)) + if !strings.HasPrefix(target, cleanDest+string(filepath.Separator)) && target != cleanDest { + return "", fmt.Errorf("archive entry escapes destination: %q", entryName) + } + return target, nil +} + +// ExtractBinaries unpacks the archive into tmpDir, returning a map from binary +// basename (without .exe) to the extracted file path. Only entries matching the +// allowed binary names are extracted. +func ExtractBinaries(archivePath, tmpDir string) (map[string]string, error) { + if strings.HasSuffix(archivePath, ".zip") { + return extractZipBinaries(archivePath, tmpDir) + } + if strings.HasSuffix(archivePath, ".tar.gz") || strings.HasSuffix(archivePath, ".tgz") { + return extractTarGzBinaries(archivePath, tmpDir) + } + return nil, fmt.Errorf("unsupported archive format: %s", archivePath) +} + +func extractZipBinaries(archivePath, tmpDir string) (map[string]string, error) { + zr, err := zip.OpenReader(archivePath) + if err != nil { + return nil, err + } + defer zr.Close() + + out := map[string]string{} + for _, f := range zr.File { + base := filepath.Base(f.Name) + if !allowedArchiveBinaries[base] { + continue + } + target, err := safeExtractPath(tmpDir, base) + if err != nil { + return nil, err + } + rc, err := f.Open() + if err != nil { + return nil, err + } + if err := writeBinary(target, rc); err != nil { + rc.Close() + return nil, err + } + rc.Close() + out[stripExe(base)] = target + } + return out, nil +} + +func extractTarGzBinaries(archivePath, tmpDir string) (map[string]string, error) { + f, err := os.Open(archivePath) + if err != nil { + return nil, err + } + defer f.Close() + + gzr, err := gzip.NewReader(f) + if err != nil { + return nil, err + } + defer gzr.Close() + + out := map[string]string{} + tr := tar.NewReader(gzr) + for { + hdr, err := tr.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, err + } + if hdr.Typeflag != tar.TypeReg { + continue + } + base := filepath.Base(hdr.Name) + if !allowedArchiveBinaries[base] { + continue + } + target, err := safeExtractPath(tmpDir, base) + if err != nil { + return nil, err + } + if err := writeBinary(target, tr); err != nil { + return nil, err + } + out[stripExe(base)] = target + } + return out, nil +} + +func writeBinary(target string, src io.Reader) error { + dst, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) + if err != nil { + return err + } + defer dst.Close() + if _, err := io.Copy(dst, io.LimitReader(src, maxArchiveEntrySize+1)); err != nil { + return err + } + info, err := dst.Stat() + if err != nil { + return err + } + if info.Size() > maxArchiveEntrySize { + return fmt.Errorf("archive entry %s exceeds max size %d", target, maxArchiveEntrySize) + } + return nil +} + +func stripExe(name string) string { + return strings.TrimSuffix(name, ".exe") +} + +// ReplaceBinary swaps a freshly-downloaded binary into destPath, renaming any +// existing destPath to destPath+".bak" (overwriting a previous .bak). On Unix +// this is safe even while the binary is running because the kernel keeps the +// old inode alive for the current process. +func ReplaceBinary(srcPath, destPath string) error { + bak := destPath + ".bak" + _ = os.Remove(bak) + if _, err := os.Stat(destPath); err == nil { + if err := os.Rename(destPath, bak); err != nil { + return fmt.Errorf("rename %s -> %s: %w", destPath, bak, err) + } + } + if err := moveFile(srcPath, destPath); err != nil { + // Try to restore .bak on failure so we don't leave the user without a binary. + _ = os.Rename(bak, destPath) + return err + } + if err := os.Chmod(destPath, 0o755); err != nil { + return err + } + return nil +} + +// moveFile renames src to dst, falling back to copy+remove when crossing +// filesystems (e.g. /tmp to $HOME on Linux with separate mounts). +func moveFile(src, dst string) error { + if err := os.Rename(src, dst); err == nil { + return nil + } + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) + if err != nil { + return err + } + if _, err := io.Copy(out, in); err != nil { + out.Close() + return err + } + if err := out.Close(); err != nil { + return err + } + return os.Remove(src) +} + +// NormalizeVersion strips a leading "v" so "v0.94.5" and "0.94.5" compare equal. +func NormalizeVersion(v string) string { + return strings.TrimPrefix(strings.TrimSpace(v), "v") +} + +// ResolveCLIBinaryPath returns the real (symlink-resolved) path of the running +// CLI binary. +func ResolveCLIBinaryPath() (string, error) { + exe, err := os.Executable() + if err != nil { + return "", err + } + real, err := filepath.EvalSymlinks(exe) + if err != nil { + return exe, nil + } + return real, nil +} + +// InstallKind describes how the running CLI binary appears to be installed. +type InstallKind int + +const ( + InstallKindUnknown InstallKind = iota + InstallKindHomebrew + InstallKindCurl +) + +// DetectInstallKind classifies binPath as a Homebrew install, a curl-installer +// install ($HOME/.shelltime/bin), or unknown. +func DetectInstallKind(binPath string) InstallKind { + clean := filepath.Clean(binPath) + if strings.Contains(clean, string(filepath.Separator)+"Cellar"+string(filepath.Separator)) || + strings.HasPrefix(clean, "/opt/homebrew/") || + strings.HasPrefix(clean, "/home/linuxbrew/.linuxbrew/") { + return InstallKindHomebrew + } + expected := filepath.Clean(filepath.Join(GetBaseStoragePath(), "bin")) + if strings.HasPrefix(clean, expected+string(filepath.Separator)) { + return InstallKindCurl + } + return InstallKindUnknown +} + +// CurrentPlatform returns the goos/goarch pair, exposed for tests and logging. +func CurrentPlatform() (string, string) { + return runtime.GOOS, runtime.GOARCH +} diff --git a/model/updater_test.go b/model/updater_test.go new file mode 100644 index 0000000..132f26e --- /dev/null +++ b/model/updater_test.go @@ -0,0 +1,265 @@ +package model + +import ( + "archive/tar" + "archive/zip" + "compress/gzip" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildArchiveName(t *testing.T) { + tests := []struct { + name string + goos string + goarch string + want string + wantErr bool + }{ + {"linux amd64", "linux", "amd64", "cli_Linux_x86_64.tar.gz", false}, + {"linux arm64", "linux", "arm64", "cli_Linux_arm64.tar.gz", false}, + {"darwin amd64", "darwin", "amd64", "cli_Darwin_x86_64.zip", false}, + {"darwin arm64", "darwin", "arm64", "cli_Darwin_arm64.zip", false}, + {"windows amd64", "windows", "amd64", "cli_Windows_x86_64.zip", false}, + {"windows arm64", "windows", "arm64", "cli_Windows_arm64.zip", false}, + {"linux 386", "linux", "386", "cli_Linux_i386.tar.gz", false}, + {"unsupported os", "freebsd", "amd64", "", true}, + {"unsupported arch", "linux", "mips", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := BuildArchiveName(tt.goos, tt.goarch) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestBuildDownloadURL(t *testing.T) { + got := BuildDownloadURL("v1.2.3", "cli_Linux_x86_64.tar.gz") + assert.Equal(t, "https://github.com/shelltime/cli/releases/download/v1.2.3/cli_Linux_x86_64.tar.gz", got) +} + +func TestBuildChecksumsURL(t *testing.T) { + got := BuildChecksumsURL("v1.2.3") + assert.Equal(t, "https://github.com/shelltime/cli/releases/download/v1.2.3/checksums.txt", got) +} + +func TestParseChecksumLine(t *testing.T) { + content := strings.Join([]string{ + "abc123 some_other_file.tar.gz", + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef cli_Linux_x86_64.tar.gz", + "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef cli_Darwin_arm64.zip", + "", + }, "\n") + + t.Run("found", func(t *testing.T) { + sum, ok, err := parseChecksumLine(content, "cli_Linux_x86_64.tar.gz") + require.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", sum) + }) + + t.Run("missing", func(t *testing.T) { + sum, ok, err := parseChecksumLine(content, "cli_Windows_arm64.zip") + require.NoError(t, err) + assert.False(t, ok) + assert.Empty(t, sum) + }) + + t.Run("malformed sha", func(t *testing.T) { + _, _, err := parseChecksumLine("abc123 cli_Linux_x86_64.tar.gz", "cli_Linux_x86_64.tar.gz") + assert.Error(t, err) + }) +} + +func TestSafeExtractPath(t *testing.T) { + tests := []struct { + name string + destDir string + entryName string + wantErr bool + }{ + {"clean basename", "/tmp/foo", "shelltime", false}, + {"path traversal stripped by basename", "/tmp/foo", "../../etc/passwd", false}, + {"absolute stripped by basename", "/tmp/foo", "/etc/passwd", false}, + {"nested name", "/tmp/foo", "bin/shelltime", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := safeExtractPath(tt.destDir, tt.entryName) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.True(t, strings.HasPrefix(got, "/tmp/foo")) + assert.Equal(t, filepath.Base(tt.entryName), filepath.Base(got)) + }) + } +} + +func TestNormalizeVersion(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"v0.94.5", "0.94.5"}, + {"0.94.5", "0.94.5"}, + {" v1.0.0 ", "1.0.0"}, + {"", ""}, + {"v", ""}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + assert.Equal(t, tt.want, NormalizeVersion(tt.in)) + }) + } +} + +func TestDetectInstallKind(t *testing.T) { + base := GetBaseStoragePath() + tests := []struct { + name string + path string + want InstallKind + }{ + {"curl install", filepath.Join(base, "bin", "shelltime"), InstallKindCurl}, + {"homebrew apple silicon", "/opt/homebrew/bin/shelltime", InstallKindHomebrew}, + {"homebrew cellar", "/usr/local/Cellar/shelltime/0.1.0/bin/shelltime", InstallKindHomebrew}, + {"linuxbrew", "/home/linuxbrew/.linuxbrew/bin/shelltime", InstallKindHomebrew}, + {"random location", "/usr/bin/shelltime", InstallKindUnknown}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, DetectInstallKind(tt.path)) + }) + } +} + +func TestExtractBinariesZip(t *testing.T) { + tmp := t.TempDir() + archivePath := filepath.Join(tmp, "release.zip") + + zf, err := os.Create(archivePath) + require.NoError(t, err) + zw := zip.NewWriter(zf) + + cli, err := zw.Create("shelltime") + require.NoError(t, err) + _, err = cli.Write([]byte("CLI_BINARY")) + require.NoError(t, err) + + daemon, err := zw.Create("shelltime-daemon") + require.NoError(t, err) + _, err = daemon.Write([]byte("DAEMON_BINARY")) + require.NoError(t, err) + + junk, err := zw.Create("README.md") + require.NoError(t, err) + _, err = junk.Write([]byte("ignored")) + require.NoError(t, err) + + require.NoError(t, zw.Close()) + require.NoError(t, zf.Close()) + + dest := filepath.Join(tmp, "out") + require.NoError(t, os.MkdirAll(dest, 0o755)) + + got, err := ExtractBinaries(archivePath, dest) + require.NoError(t, err) + assert.Len(t, got, 2) + assert.Contains(t, got, "shelltime") + assert.Contains(t, got, "shelltime-daemon") + + body, err := os.ReadFile(got["shelltime"]) + require.NoError(t, err) + assert.Equal(t, "CLI_BINARY", string(body)) +} + +func TestExtractBinariesTarGz(t *testing.T) { + tmp := t.TempDir() + archivePath := filepath.Join(tmp, "release.tar.gz") + + tf, err := os.Create(archivePath) + require.NoError(t, err) + gw := gzip.NewWriter(tf) + tw := tar.NewWriter(gw) + + writeEntry := func(name, body string) { + require.NoError(t, tw.WriteHeader(&tar.Header{ + Name: name, + Mode: 0o755, + Size: int64(len(body)), + Typeflag: tar.TypeReg, + })) + _, err := tw.Write([]byte(body)) + require.NoError(t, err) + } + + writeEntry("shelltime", "CLI_TAR") + writeEntry("shelltime-daemon", "DAEMON_TAR") + writeEntry("LICENSE", "ignored") + + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + require.NoError(t, tf.Close()) + + dest := filepath.Join(tmp, "out") + require.NoError(t, os.MkdirAll(dest, 0o755)) + + got, err := ExtractBinaries(archivePath, dest) + require.NoError(t, err) + assert.Len(t, got, 2) + body, err := os.ReadFile(got["shelltime"]) + require.NoError(t, err) + assert.Equal(t, "CLI_TAR", string(body)) +} + +func TestReplaceBinary(t *testing.T) { + tmp := t.TempDir() + dest := filepath.Join(tmp, "shelltime") + require.NoError(t, os.WriteFile(dest, []byte("OLD"), 0o755)) + + src := filepath.Join(tmp, "src", "shelltime") + require.NoError(t, os.MkdirAll(filepath.Dir(src), 0o755)) + require.NoError(t, os.WriteFile(src, []byte("NEW"), 0o755)) + + require.NoError(t, ReplaceBinary(src, dest)) + + body, err := os.ReadFile(dest) + require.NoError(t, err) + assert.Equal(t, "NEW", string(body)) + + body, err = os.ReadFile(dest + ".bak") + require.NoError(t, err) + assert.Equal(t, "OLD", string(body)) +} + +func TestReplaceBinaryNoExisting(t *testing.T) { + tmp := t.TempDir() + dest := filepath.Join(tmp, "shelltime") + + src := filepath.Join(tmp, "src", "shelltime") + require.NoError(t, os.MkdirAll(filepath.Dir(src), 0o755)) + require.NoError(t, os.WriteFile(src, []byte("NEW"), 0o755)) + + require.NoError(t, ReplaceBinary(src, dest)) + + body, err := os.ReadFile(dest) + require.NoError(t, err) + assert.Equal(t, "NEW", string(body)) + + _, err = os.Stat(dest + ".bak") + assert.True(t, os.IsNotExist(err)) +}