diff --git a/README.md b/README.md index 5cd7ede..cc377ca 100644 --- a/README.md +++ b/README.md @@ -81,98 +81,59 @@ calc:000/ Construction: Multiply(Inverse(Named("1", Int(1))), Multiply(Int(22), Sessions can be saved and loaded with `.save` and `.load`. +Functions are called with C-style syntax, e.g. `sin(0)`, `log(8, 2)`, or +`atan2(1, 1)`. The library, also listed in the REPL with `.show functions`: + +| Group | Functions | +|-------|-----------| +| Basic | `abs(x)` | +| Trigonometric | `sin(x)`, `cos(x)`, `tan(x)`, `asin(x)`, `acos(x)`, `atan(x)`, `atan2(y, x)` | +| Exponential and logarithmic | `exp(x)`, `ln(x)`, `log10(x)`, `log2(x)`, `log(x, base)` | +| Roots | `sqrt(x)`, `cbrt(x)` | +| Rounding | `floor(x)`, `ceil(x)`, `round(x)` | +| Comparison | `min(x, ...)`, `max(x, ...)` | +| Hyperbolic | `sinh(x)`, `cosh(x)`, `tanh(x)` | +| Combinatorial | `factorial(n)`, `gamma(x)` | + +Trig functions take and return radians; work in degrees by converting +explicitly, e.g. `sin(45 * PI / 180)`. + `cg` ---- -Run a command and annotate each line of its stdout and stderr with a stream -indicator (`O` for stdout, `E` for stderr, `I` for cg's own lifecycle -messages). At the end of the run, print a one-line summary with the exit -code, wall duration, and per-stream line counts. - -Acts like the `annotate-output` script; `cg` is short for command guard. +Run a command and annotate each output line with a stream indicator: `O` for +stdout, `E` for stderr, `I` for cg's own lifecycle messages. At the end of the +run, a one-line summary reports the exit code, wall duration, and per-stream +line counts. ``` go install github.com/ripta/rt/cmd/cg@latest ``` -Basic usage: - ``` ❯ cg -- echo hello O: hello I: Finished exitcode=0 in 2ms (out=1 err=0) -``` -Stdout and stderr are distinguished: - -``` ❯ cg -- sh -c 'echo out; echo err >&2' O: out E: err I: Finished exitcode=0 in 3ms (out=1 err=1) ``` -The child's exit code is propagated: - -``` -❯ cg -- sh -c 'exit 42'; echo $? -I: Finished exitcode=42 in 2ms (out=0 err=0) -42 -``` - -If the child is killed by a signal, the summary reports the signal number -instead of an exit code: - -``` -❯ cg -- sh -c 'kill -TERM $$' -I: Finished signal=15 in 2ms (out=0 err=0) -``` - -SIGINT and SIGTERM are forwarded to the child process. +The child's exit code propagates to the shell. If the child is killed by a +signal, the summary reports the signal number instead. SIGINT and SIGTERM are +forwarded to the child. -Verbose mode (`-v` / `--verbose`) restores the older preamble — version line, -prefix echo, `Started` line — and prefixes every output line with a -timestamp: - -``` -❯ cg -v -- echo hello -19:02:59 I: cg v0.1.0 -19:02:59 I: prefix="15:04:05 " -19:02:59 I: Started echo hello -19:02:59 O: hello -19:02:59 I: Finished exitcode=0 in 2ms (out=1 err=0) -``` - -The verbose timestamp format follows the Go `time.Format` layout and is -customised with `--format`: - -``` -❯ cg -v --format '2006-01-02T15:04:05 ' -- echo hello -2026-02-22T19:05:00 I: cg v0.1.0 -2026-02-22T19:05:00 I: prefix="2006-01-02T15:04:05 " -2026-02-22T19:05:00 I: Started echo hello -2026-02-22T19:05:00 O: hello -2026-02-22T19:05:00 I: Finished exitcode=0 in 2ms (out=1 err=0) -``` - -### Capturing output - -`-c` / `--capture` writes the child's stdout and stderr to files under -`$TMPDIR/cg//` and appends a short run ID to the summary line. The ID is -6 characters of Crockford base-32 (no `I`, `L`, `O`, or `U`), regenerated on -collision. +`-c` / `--capture` writes the child's stdout and stderr to `$TMPDIR/cg//` +and appends a short run ID to the summary line. Resolution subcommands thread +the ID through follow-up calls: ``` ❯ cg -c -- sh -c 'echo out; echo err >&2' I: Finished exitcode=0 in 3ms (out=1 err=1) id=Q3F9K2 -``` - -Each run directory contains `stdout`, `stderr`, and a `meta.json` written -atomically at end-of-run. Resolution subcommands let downstream tooling -thread the ID through follow-up calls without scraping paths: -``` ❯ cg out Q3F9K2 /tmp/cg/Q3F9K2/stdout @@ -183,334 +144,87 @@ thread the ID through follow-up calls without scraping paths: ❯ rg -i FOO $(cg out Q3F9K2) ``` -`cg ls` lists recent runs, most-recent-first by mtime, one row per run: +`cg ls` lists recent runs, most-recent-first; `cg ls -n N` overrides the +default cap of 20. Capture never deletes anything; `cg prune` is the explicit +cleanup hook: ``` ❯ cg ls Q3F9K2 exit=0 3ms sh -c 'echo out; echo err >&2' M7P4QX exit=42 2ms sh -c 'exit 42' -``` - -`cg ls -n N` overrides the default cap of 20. - -Capture itself never deletes anything. `cg prune` is the explicit cleanup -hook: -``` ❯ cg prune # keep the 50 most recent by mtime -❯ cg prune --keep 10 # keep the 10 most recent -❯ cg prune --older-than 7d # evict runs older than seven days -❯ cg prune --dry-run # print what would be removed, change nothing +❯ cg prune --keep 10 +❯ cg prune --older-than 7d +❯ cg prune --dry-run ``` -`--keep` and `--older-than` are mutually exclusive. `--older-than` accepts -the Go `time.ParseDuration` grammar (`90m`, `1h30m`, `2h`) plus convenience -suffixes `Nd` (days) and `Nw` (weeks). Stray non-run entries and incomplete -runs (no `meta.json`) under `$TMPDIR/cg/` are skipped. - -### Other flags +`-v` / `--verbose` prefixes every line with a timestamp and adds a started/finished +preamble; `--format` controls the layout using Go's `time.Format` syntax. +`--buffered` defers child output until the command finishes, grouping by stream. +`--log-parse json|logfmt` reformats structured log lines inline. -`--buffered` defers the child's output until the command finishes, grouping -by stream instead of streaming in real time. - -`--log-parse json|logfmt` reformats structured child log lines; see -`cg --help` for the message-key, timestamp-key, timestamp-format, and field -selectors. - -### MCP server - -`cg mcp` starts a stdio MCP server that exposes the capture-run model as -native tools. Coding agents that speak MCP (Claude Code, the Anthropic SDK, -others) can call `cg` with structured JSON input and output rather than -constructing shell argv and parsing printed paths. The server is a thin -wrapper over the same on-disk capture model the shell subcommands use, so a -run started with `cg -c -- cmd` is visible to the MCP tools and a run -started by `cg_run` is visible to `cg ls`. MCP is additive; the shell -subcommands continue to work unchanged. - -Register the server with Claude Code using its CLI: +`cg mcp` starts a stdio MCP server that exposes the capture-run model as native +tools, using the same on-disk storage the shell subcommands use — a run started +with `cg -c` is visible to `cg_list`, and a run started by `cg_run` is visible +to `cg ls`. Register with Claude Code: ``` claude mcp add cg cg mcp ``` -For agents that support CLI-based registration but use a different command, -check their docs — the pattern is the same: server name `cg`, command `cg`, -argument `mcp`. - -For MCP hosts that require editing config by hand, add a `cg` entry under -`mcpServers`: +Or by hand in the MCP host config: ```json { "mcpServers": { - "cg": { - "command": "cg", - "args": ["mcp"] - } + "cg": { "command": "cg", "args": ["mcp"] } } } ``` -Any MCP host that speaks the stdio transport launches the server the same -way: spawn `cg mcp` and exchange MCP messages over its stdin and stdout. - The server registers ten tools: | Tool | Purpose | |------|---------| -| `cg_run` | Run a command with capture and return metadata plus head- or tail-window excerpts. | -| `cg_list` | List recent capture runs, most-recent-first by mtime. | -| `cg_meta` | Return the run state and `meta.json` fields for a run. | +| `cg_run` | Run a command with capture; returns metadata and head/tail excerpts. | +| `cg_list` | List recent runs, most-recent-first. | +| `cg_meta` | Return run state and metadata. | | `cg_wait` | Block until a run finishes or a timeout elapses. | | `cg_cancel` | Signal a run's process group, with optional escalation. | -| `cg_paths` | Return absolute paths for a run's `stdout`, `stderr`, `meta.json`. | -| `cg_stdout` | Fetch captured stdout for a run, with byte limits and head/tail windowing. | -| `cg_stderr` | Fetch captured stderr for a run, with byte limits and head/tail windowing. | -| `cg_grep` | Search a run's captured output and return matching lines. | -| `cg_prune` | Evict capture runs by count (`keep`) or age (`older_than`). | - -Unknown IDs and malformed inputs surface as MCP tool errors. A child -command exiting non-zero is data, not an error: `cg_run` returns -successfully with `exit_code: N` and the caller decides how to react. - -#### `cg_run` - -Run a command with capture. Blocks until the child exits or -`wait_timeout_ms` elapses; on timeout, the child keeps running and the -capture continues on disk. - -**Inputs** - -| Field | Type | Default | Notes | -|-------|------|---------|-------| -| `command` | `string[]` | required | argv; index 0 is the program. | -| `cwd` | `string` | server cwd | working directory. | -| `env` | `object` | server env | environment overrides, merged onto the server's env. | -| `wait` | `bool` | `true` | block until exit or timeout. | -| `wait_timeout_ms` | `int` | `60000` | how long to wait before returning `timed_out: true`. | -| `excerpt_bytes` | `int` | `4096` | per-stream excerpt cap; max `16384`. | -| `excerpt_from` | `string` | `auto` | excerpt window: `auto` picks head on success, tail on non-zero exit / signal / timeout; `head` or `tail` forces the window. | +| `cg_paths` | Return absolute paths for a run's stdout, stderr, and meta.json. | +| `cg_stdout` | Fetch captured stdout with byte limits and head/tail windowing. | +| `cg_stderr` | Fetch captured stderr with byte limits and head/tail windowing. | +| `cg_grep` | Search captured output and return matching lines. | +| `cg_prune` | Evict runs by count or age. | -**Outputs** - -| Field | Type | Notes | -|-------|------|-------| -| `id` | `string` | Capture run ID. | -| `started` | `bool` | Set when `wait: false`. | -| `timed_out` | `bool` | Set when the wait timeout fired. | -| `exit_code` | `int?` | Child exit code; absent if timed out. | -| `signal` | `int?` | Signal that killed the child, if any. | -| `duration_ms` | `int?` | Wall-clock run duration; absent if timed out. | -| `stdout_lines` | `int?` | Total stdout lines; absent if timed out. | -| `stderr_lines` | `int?` | Total stderr lines; absent if timed out. | -| `stdout_excerpt` | `string` | `excerpt_bytes` from stdout; window per `excerpt_from`. | -| `stderr_excerpt` | `string` | `excerpt_bytes` from stderr; window per `excerpt_from`. | -| `excerpt_from` | `string` | Window that was used: `head` or `tail`. Omitted when no excerpts (e.g., `wait: false`). | -| `truncated` | `bool` | Either stream had more than `excerpt_bytes`. | - -#### `cg_list` - -List recent capture runs, most-recent-first by directory mtime. The default -surfaces only finished runs; pass `state` to include in-flight runs (started, -no `meta.json` yet) or to ask for them on their own. - -**Inputs** - -| Field | Type | Default | Notes | -|-------|------|---------|-------| -| `limit` | `int` | `20` | maximum runs to return; max `1000`. | -| `state` | `string` | `finished` | filter: `all`, `finished`, or `running`. | - -**Outputs** - -| Field | Type | Notes | -|-------|------|-------| -| `runs` | `object[]` | One entry per matching run; see fields below. | - -Every `runs[]` entry has `id` and `state` (`"finished"` or `"running"`). -Finished entries also carry `command`, `started_at`, `finished_at`, -`duration_ms`, `exit_code`, `signal?`, `stdout_lines`, `stderr_lines`. -In-flight entries are sparse: only `id`, `state`, and `started_at` -synthesized from the run directory's mtime. - -#### `cg_meta` - -Return a run's state and `meta.json` fields. An in-flight run (no -`meta.json` yet) returns `{id, state: "running"}` with no error; a finished -run returns `state: "finished"` plus all meta fields. An unknown ID is a -tool error. - -**Inputs** - -| Field | Type | Default | Notes | -|-------|------|---------|-------| -| `id` | `string` | required | capture run ID. | - -**Outputs** - -| Field | Type | Notes | -|-------|------|-------| -| `id` | `string` | Run ID. | -| `state` | `string` | `"running"` or `"finished"`. | -| `command` | `string[]` | argv that was executed; finished runs only. | -| `started_at` | `string` | RFC 3339 timestamp; finished runs only. | -| `finished_at` | `string` | RFC 3339 timestamp; finished runs only. | -| `duration_ms` | `int` | Wall-clock duration; finished runs only. | -| `exit_code` | `int` | Child exit code; finished runs only. | -| `signal` | `int?` | Signal that killed the child, if any. | -| `stdout_lines` | `int` | Total stdout lines; finished runs only. | -| `stderr_lines` | `int` | Total stderr lines; finished runs only. | - -#### `cg_wait` - -Block until a run finishes or `timeout_ms` elapses. Uses the in-process -`Done` channel for runs this server started and falls back to filesystem -polling otherwise. An unknown ID is a tool error. - -**Inputs** - -| Field | Type | Default | Notes | -|-------|------|---------|-------| -| `id` | `string` | required | capture run ID. | -| `timeout_ms` | `int` | `60000` | how long to block before returning `finished: false`. | - -**Outputs** - -| Field | Type | Notes | -|-------|------|-------| -| `id` | `string` | Run ID. | -| `finished` | `bool` | `true` if the run completed before the timeout. | -| meta fields | — | When `finished`, the same fields as `cg_meta`. | - -#### `cg_cancel` - -Send a signal to a run's process group. An already-finished run returns -`{signaled: false}` without error; an unknown ID is a tool error. With -`escalate_after_ms > 0`, the server sends the initial signal, waits up to -the deadline, and sends `escalate_signal` if the child is still running. - -**Inputs** - -| Field | Type | Default | Notes | -|-------|------|---------|-------| -| `id` | `string` | required | capture run ID. | -| `signal` | `string`/`int` | `SIGTERM` | initial signal; `SIGTERM`, `SIGINT`, `SIGKILL`, or numeric. | -| `escalate_after_ms` | `int` | `0` | wait this long, then escalate; `0` disables escalation. | -| `escalate_signal` | `string`/`int` | `SIGKILL` | signal sent on escalation. | - -**Outputs** - -| Field | Type | Notes | -|-------|------|-------| -| `id` | `string` | Run ID. | -| `signaled` | `bool` | Whether the initial signal was sent. | -| `signal` | `int` | Numeric value of the initial signal. | -| `escalated` | `bool` | Whether `escalate_signal` was sent. | -| `escalate_signal` | `int?` | Numeric escalation signal; present only when escalated. | -| `finished` | `bool` | Whether the child had exited by the time the call returned. | - -#### `cg_paths` - -Return absolute paths for a run's `stdout`, `stderr`, and `meta.json` -files. Works for in-flight runs; the `meta` path is returned even when the -file does not yet exist, so callers can poll the same path. - -**Inputs** - -| Field | Type | Default | Notes | -|-------|------|---------|-------| -| `id` | `string` | required | capture run ID. | - -**Outputs** - -| Field | Type | Notes | -|-------|------|-------| -| `stdout` | `string` | Absolute path to the stdout file. | -| `stderr` | `string` | Absolute path to the stderr file. | -| `meta` | `string` | Absolute path to `meta.json` (may not exist yet). | - -#### `cg_stdout` and `cg_stderr` - -Fetch captured stdout or stderr for a run. Defaults to the first 16 KiB; -`from: "tail"` reads the last `max_bytes` instead. Works for in-flight -runs. The default encoding validates bytes as UTF-8 and falls back to -base64 automatically on invalid input (binary streams or a tail read -that lands mid-codepoint); set `content_encoding: "base64"` to force -base64 for known binary streams. - -**Inputs** - -| Field | Type | Default | Notes | -|-------|------|---------|-------| -| `id` | `string` | required | capture run ID. | -| `max_bytes` | `int` | `16384` | response cap; max `1048576` (1 MiB), clamped if higher. | -| `from` | `string` | `"head"` | `"head"` reads from `offset`; `"tail"` reads the last `max_bytes`. | -| `offset` | `int` | `0` | byte offset for head reads; ignored when `from: "tail"`. | -| `content_encoding` | `string` | `"utf8"` | `"utf8"` validates UTF-8 and falls back to base64 on invalid bytes; `"base64"` always base64-encodes. | - -**Outputs** - -| Field | Type | Notes | -|-------|------|-------| -| `content` | `string` | Bytes read from the stream, encoded per `content_encoding`. | -| `content_encoding` | `string` | `"utf8"` or `"base64"`; describes how to decode `content`. | -| `total_bytes` | `int` | Total size of the stream file. | -| `returned_bytes` | `int` | Length of `content` in bytes. | -| `truncated` | `bool` | More data exists beyond the returned window. | -| `clamped` | `bool` | `max_bytes` was reduced to the 1 MiB ceiling. | - -#### `cg_grep` +A non-zero child exit code is data, not an MCP error: `cg_run` returns +successfully with `exit_code: N` and the caller decides how to react. -Search a run's captured output line by line and return matching lines. -Supply exactly one of `text` (fixed substring) or `pattern` (RE2 regex). -Searches both streams by default. Works for in-flight runs. A line with -invalid UTF-8 is base64-encoded and tagged `content_encoding: "base64"`. - -**Inputs** - -| Field | Type | Default | Notes | -|-------|------|---------|-------| -| `id` | `string` | required | capture run ID. | -| `text` | `string` | — | fixed-string substring; mutually exclusive with `pattern`. | -| `pattern` | `string` | — | RE2 regex; mutually exclusive with `text`. | -| `streams` | `string` | `all` | which streams to search: `all`, `stdout`, or `stderr`. | -| `case_insensitive` | `bool` | `false` | fold case when matching. | -| `invert_match` | `bool` | `false` | return lines that do NOT match. | -| `max_matches` | `int` | `1000` | cap on returned matches; max `10000`. | - -**Outputs** - -| Field | Type | Notes | -|-------|------|-------| -| `matches` | `object[]` | One entry per matching line; see fields below. | -| `match_count` | `int` | Number of returned matches. | -| `truncated` | `bool` | `max_matches` was hit before the streams were fully scanned. | - -Each `matches[]` entry has `stream` (`"stdout"` or `"stderr"`), -`line_number` (1-based, per stream), `line`, and `content_encoding` -(omitted for UTF-8 lines, `"base64"` when the line is base64-encoded). - -#### `cg_prune` - -Evict capture runs from `$TMPDIR/cg/`. Either keep the `N` most recent -runs by mtime or remove runs older than a duration. `keep` and -`older_than` are mutually exclusive. - -**Inputs** - -| Field | Type | Default | Notes | -|-------|------|---------|-------| -| `keep` | `int` | `50` | keep N most recent runs by mtime. | -| `older_than` | `string` | unset | evict runs older than the given duration, e.g. `7d`, `2h`, `90m`. | -| `dry_run` | `bool` | `false` | report what would be removed without removing. | - -**Outputs** - -| Field | Type | Notes | -|-------|------|-------| -| `removed` | `string[]` | Run IDs that were or would be removed. | -| `dry_run` | `bool` | Echoes the input flag. | +`cg_run` checks each command against an approval matcher before running it. The +default mode prompts for unmatched commands when the client supports elicitation, +and otherwise fails closed; `cg mcp --blindly-allow` skips the gate entirely. +Rules live in `~/.config/cg/approve.yaml` (global) and `.cg.yaml` / +`.cg/approve.yaml` / `.claude/cg.yaml` (project), merged at startup: + +```yaml +version: 1 +mode: enforce # enforce (default), allow-all, or deny-all +deny: + - regex: '^/tmp/' + message: do not run executables from temporary directories +allow: + - prefix: [go, test] + as_basename: true + - regex: '^/opt/foo/bin/[^ ]+(\s|$)' +``` + +In enforce mode the matcher checks deny rules, then allow rules, then prompts; +deny always wins. Each rule matches by `exact` argv, `prefix` tokens, `glob`, +or `regex`. `argv[0]` is resolved to an absolute path before matching, so +path-based rules work however the command was spelled. `as_basename: true` +matches the program's basename instead, regardless of install path; shells and +inline-code interpreters are denied by default and cannot be re-allowed. `enc` diff --git a/go.mod b/go.mod index eceba20..bcdefd3 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 github.com/r3labs/diff/v3 v3.0.2 github.com/ripta/hypercmd v0.6.0 - github.com/ripta/reals v0.0.0-20251220032726-c99f163d5c5c + github.com/ripta/reals v0.0.0-20260614185130-b214700ec783 github.com/ripta/unihan v0.0.0-20250404091138-c307c698a880 github.com/rogpeppe/go-internal v1.15.0 github.com/spf13/cobra v1.10.2 @@ -58,12 +58,12 @@ require ( github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/exp v0.0.0-20260603202125-055de637280b // indirect + golang.org/x/exp v0.0.0-20260611194520-c48552f49976 // indirect golang.org/x/mod v0.37.0 // indirect golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sync v0.21.0 // indirect golang.org/x/sys v0.46.0 // indirect - golang.org/x/tools v0.45.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20260608224507-4308a22a1bab // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260608224507-4308a22a1bab // indirect + golang.org/x/tools v0.46.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260610212136-7ab31c22f7ad // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260610212136-7ab31c22f7ad // indirect ) diff --git a/go.sum b/go.sum index fce4b45..9386cee 100644 --- a/go.sum +++ b/go.sum @@ -73,8 +73,8 @@ github.com/r3labs/diff/v3 v3.0.2 h1:yVuxAY1V6MeM4+HNur92xkS39kB/N+cFi2hMkY06BbA= github.com/r3labs/diff/v3 v3.0.2/go.mod h1:Cy542hv0BAEmhDYWtGxXRQ4kqRsVIcEjG9gChUlTmkw= github.com/ripta/hypercmd v0.6.0 h1:lUxdqhP/cR/Spu8Yi0Ve60Tw8+LYPkBdkhaplDf7fbo= github.com/ripta/hypercmd v0.6.0/go.mod h1:8QnmkN5AFLtPDl5LnGExQdSG/CMEhAk9GqBpYJWScrw= -github.com/ripta/reals v0.0.0-20251220032726-c99f163d5c5c h1:4bBR+jNoWIs1roinlXrVDUtmSvqjtNbrJ3cuQtFci5g= -github.com/ripta/reals v0.0.0-20251220032726-c99f163d5c5c/go.mod h1:WErCt40puDDQdpVq8Hg1DzjB0svufA8WboSYG4BI2+E= +github.com/ripta/reals v0.0.0-20260614185130-b214700ec783 h1:C9fMjkM7wA0VJFYwPf8BCmIxBnqLSqv7KNo88ugJYOo= +github.com/ripta/reals v0.0.0-20260614185130-b214700ec783/go.mod h1:HJD35VfuiXMIOHm9kFC3c637WX/sznn9gES6ajLTUKk= github.com/ripta/unihan v0.0.0-20250404091138-c307c698a880 h1:ZzDUYlZP/LHJmkh+PtgRZHEKa+eNVefq6YR8BnUCQ2I= github.com/ripta/unihan v0.0.0-20250404091138-c307c698a880/go.mod h1:ZLBfCas48lym/27GOsyFjRo7OGejoGHzOTdUdoRtDqU= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= @@ -111,12 +111,12 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto= golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio= -golang.org/x/exp v0.0.0-20260603202125-055de637280b h1:v1uXiEBHo8QA0LiGCo7UgHMzHT4Kdfpl2zmtH5vaP1Q= -golang.org/x/exp v0.0.0-20260603202125-055de637280b/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw= +golang.org/x/exp v0.0.0-20260611194520-c48552f49976 h1:X8Hz2ImujgbmetVuW+w2YkyZChE3cBpZi2P158rTG9M= +golang.org/x/exp v0.0.0-20260611194520-c48552f49976/go.mod h1:vnf4pv9iKZXY58sQE1L86zmNWJ4159e1RkcWiLCkeEY= golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ= golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0= -golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= -golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= +golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o= +golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM= @@ -131,12 +131,12 @@ golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE= golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= -golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8= -golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0= -google.golang.org/genproto/googleapis/api v0.0.0-20260608224507-4308a22a1bab h1:Foefixyu0l973HSYkX8Etw/fPxAmKRhyMGwuqXFiVI0= -google.golang.org/genproto/googleapis/api v0.0.0-20260608224507-4308a22a1bab/go.mod h1:KdNqO+rCIWgFumrNBSEDlDNrkrQnpkax7Tv1WxNY8V4= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260608224507-4308a22a1bab h1:cY0oV1VnAqvaim8VsR8ZyEKAudzbRJMRGwD3W/L7yOw= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260608224507-4308a22a1bab/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +golang.org/x/tools v0.46.0 h1:7jTurBkPZu4moS/Uy4OQT1M+QBlsj3wejyZwsT8Z7rk= +golang.org/x/tools v0.46.0/go.mod h1:FrD85F8l+NWL+9XWBSyVSHO6Ne4jutsfIFba7AWQ5Ys= +google.golang.org/genproto/googleapis/api v0.0.0-20260610212136-7ab31c22f7ad h1:3iLyITS/sySRwbUKoC7ogfj2Yr1Cjs0pfaRKj5U5HEw= +google.golang.org/genproto/googleapis/api v0.0.0-20260610212136-7ab31c22f7ad/go.mod h1:KdNqO+rCIWgFumrNBSEDlDNrkrQnpkax7Tv1WxNY8V4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260610212136-7ab31c22f7ad h1:45WmJvIV6C2+O/jjLkPUH+F3aOj/1miDoU2DD0+NWbg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260610212136-7ab31c22f7ad/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/pkg/calc/calculator.go b/pkg/calc/calculator.go index 5f24ff1..946bf93 100644 --- a/pkg/calc/calculator.go +++ b/pkg/calc/calculator.go @@ -198,8 +198,7 @@ func init() { return c.handleSet(args) }, ".show": func(c *Calculator, args []string) error { - c.handleShow() - return nil + return c.handleShow(args) }, ".toggle": func(c *Calculator, args []string) error { return c.handleToggle(args) @@ -424,8 +423,33 @@ func (c *Calculator) handleLoad(args []string) error { return nil } -// handleShow displays current settings -func (c *Calculator) handleShow() { +// showTopics maps each .show topic to its handler. Topics accept any +// unambiguous prefix, like every other meta-command. +var showTopics = map[string]func(*Calculator){ + "settings": (*Calculator).showSettings, + "functions": (*Calculator).showFunctions, +} + +// handleShow displays settings or the function library, selected by an optional +// topic. No topic shows settings; a topic selects its listing by unambiguous +// prefix. +func (c *Calculator) handleShow(args []string) error { + topic := "settings" + if len(args) > 0 { + topic = args[0] + } + + show, err := findByPrefix(topic, showTopics) + if err != nil { + return err + } + + show(c) + return nil +} + +// showSettings prints each setting's current value. +func (c *Calculator) showSettings() { fmt.Println("settings:") for name, setting := range settingsRegistry { switch setting.Type { @@ -437,11 +461,38 @@ func (c *Calculator) handleShow() { } } +// showFunctions lists the registered functions grouped by category. Trig +// functions take and return radians; convert degrees explicitly, e.g. +// sin(45 * PI / 180). +func (c *Calculator) showFunctions() { + fns := parser.Functions() + + width := 0 + for _, f := range fns { + if len(f.Signature) > width { + width = len(f.Signature) + } + } + + group := "" + for _, f := range fns { + if f.Group != group { + if group != "" { + fmt.Println() + } + group = f.Group + fmt.Printf("%s:\n", group) + } + fmt.Printf(" %-*s %s\n", width, f.Signature, f.Summary) + } +} + // handleHelp displays available meta-commands func (c *Calculator) handleHelp() { fmt.Println("Available commands:") fmt.Println(" .set - Change a setting") fmt.Println(" .show - Show current settings") + fmt.Println(" .show functions - List available functions") fmt.Println(" .toggle - Toggle a boolean setting") fmt.Println(" .save [path] - Save session (default: ~/.local/state/rt/calc/session.txt)") fmt.Println(" .load [path] - Load session (default: ~/.local/state/rt/calc/session.txt)") diff --git a/pkg/calc/calculator_test.go b/pkg/calc/calculator_test.go index d6d7cb1..11643e0 100644 --- a/pkg/calc/calculator_test.go +++ b/pkg/calc/calculator_test.go @@ -67,6 +67,26 @@ var handleMetaCommandTests = []handleMetaCommandTest{ cmd: ".sho", wantErr: false, }, + { + name: ".show functions", + cmd: ".show functions", + wantErr: false, + }, + { + name: ".sh functions prefix", + cmd: ".sh functions", + wantErr: false, + }, + { + name: ".sh f resolves functions topic by prefix", + cmd: ".sh f", + wantErr: false, + }, + { + name: ".show unknown topic", + cmd: ".show bogus", + wantErr: true, + }, } func TestHandleMetaCommand(t *testing.T) { @@ -86,6 +106,25 @@ func TestHandleMetaCommand(t *testing.T) { } } +func TestHandleShowTopics(t *testing.T) { + t.Parallel() + + c := &Calculator{DecimalPlaces: 30} + + if err := c.handleShow(nil); err != nil { + t.Errorf("handleShow(nil) = %v, want nil", err) + } + if err := c.handleShow([]string{"settings"}); err != nil { + t.Errorf("handleShow(settings) = %v, want nil", err) + } + if err := c.handleShow([]string{"functions"}); err != nil { + t.Errorf("handleShow(functions) = %v, want nil", err) + } + if err := c.handleShow([]string{"bogus"}); err == nil { + t.Error("handleShow(bogus) = nil, want error") + } +} + type parseBoolTest struct { name string input string diff --git a/pkg/calc/lexer/lex_expression.go b/pkg/calc/lexer/lex_expression.go index 552c96d..ca3dd0f 100644 --- a/pkg/calc/lexer/lex_expression.go +++ b/pkg/calc/lexer/lex_expression.go @@ -88,6 +88,10 @@ func lexExpression(l *L) lexingState { l.Emit(tokens.RPAREN) return lexExpression + case r == ',': + l.Emit(tokens.COMMA) + return lexExpression + case r == '$': return lexLineIdent diff --git a/pkg/calc/lexer/lexer_test.go b/pkg/calc/lexer/lexer_test.go index d36bfe1..063349d 100644 --- a/pkg/calc/lexer/lexer_test.go +++ b/pkg/calc/lexer/lexer_test.go @@ -85,6 +85,19 @@ var tokenTests = []tokenTest{ }, wantErr: "too many decimal points", }, + { + name: "function call with comma-separated arguments", + input: "max(1, 2)", + want: []tokenExpectation{ + {Type: tokens.IDENT, Value: "max", Col: 1}, + {Type: tokens.LPAREN, Value: "(", Col: 4}, + {Type: tokens.LIT_INT, Value: "1", Col: 5}, + {Type: tokens.COMMA, Value: ",", Col: 6}, + {Type: tokens.WHITESPACE, Value: " ", Col: 7}, + {Type: tokens.LIT_INT, Value: "2", Col: 8}, + {Type: tokens.RPAREN, Value: ")", Col: 9}, + }, + }, { name: "minus operator and identifier", input: "-foo", diff --git a/pkg/calc/parser/functions.go b/pkg/calc/parser/functions.go new file mode 100644 index 0000000..1e39969 --- /dev/null +++ b/pkg/calc/parser/functions.go @@ -0,0 +1,435 @@ +package parser + +import ( + "errors" + "fmt" + "math/big" + "strings" + + "github.com/ripta/reals/pkg/constructive" + "github.com/ripta/reals/pkg/rational" + "github.com/ripta/reals/pkg/unified" +) + +// function is a registry entry. minArgs and maxArgs bound the accepted argument +// count; maxArgs of -1 means unbounded variadic. fn receives the environment so +// implementations that decide irrational operands at a binary precision can read +// env.precision. name, group, signature, and summary drive the discoverability +// listing; group orders the catalog and signature/summary are human-facing. +type function struct { + name string + group string + signature string + summary string + minArgs int + maxArgs int + fn func(env *Env, args []*unified.Real) (*unified.Real, error) +} + +// functionCatalog is the source of truth for the function registry. Entries are +// ordered by group so the discoverability listing iterates it directly and emits +// a header whenever the group changes, with no separate ordering table. The +// dispatch map below is derived from it. +var functionCatalog = []function{ + {name: "abs", group: "Basic", signature: "abs(x)", summary: "absolute value", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Abs(), nil }}, + {name: "signum", group: "Basic", signature: "signum(x)", summary: "sign of x as -1, 0, or 1", minArgs: 1, maxArgs: 1, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return signum(a[0], e.precision) }}, + + {name: "sin", group: "Trigonometric (radians)", signature: "sin(x)", summary: "sine", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Sin(), nil }}, + {name: "cos", group: "Trigonometric (radians)", signature: "cos(x)", summary: "cosine", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Cos(), nil }}, + {name: "tan", group: "Trigonometric (radians)", signature: "tan(x)", summary: "tangent", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Tan(), nil }}, + {name: "asin", group: "Trigonometric (radians)", signature: "asin(x)", summary: "inverse sine, returns radians", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Asin() }}, + {name: "acos", group: "Trigonometric (radians)", signature: "acos(x)", summary: "inverse cosine, returns radians", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Acos() }}, + {name: "atan", group: "Trigonometric (radians)", signature: "atan(x)", summary: "inverse tangent, returns radians", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Atan(), nil }}, + {name: "atan2", group: "Trigonometric (radians)", signature: "atan2(y, x)", summary: "angle of the point (x, y), returns radians", minArgs: 2, maxArgs: 2, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Atan2(a[1]) }}, + {name: "deg2rad", group: "Trigonometric (radians)", signature: "deg2rad(x)", summary: "convert degrees to radians", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return deg2rad(a[0]), nil }}, + {name: "rad2deg", group: "Trigonometric (radians)", signature: "rad2deg(x)", summary: "convert radians to degrees", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return rad2deg(a[0]), nil }}, + + {name: "exp", group: "Exponential and logarithmic", signature: "exp(x)", summary: "e raised to x", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Exp(), nil }}, + {name: "ln", group: "Exponential and logarithmic", signature: "ln(x)", summary: "natural logarithm", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Ln() }}, + {name: "log10", group: "Exponential and logarithmic", signature: "log10(x)", summary: "base-10 logarithm", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Log10() }}, + {name: "log2", group: "Exponential and logarithmic", signature: "log2(x)", summary: "base-2 logarithm", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Log2() }}, + {name: "log", group: "Exponential and logarithmic", signature: "log(x, base)", summary: "logarithm of x in the given base", minArgs: 2, maxArgs: 2, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Log(a[1]) }}, + + {name: "sqrt", group: "Roots", signature: "sqrt(x)", summary: "square root", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Sqrt() }}, + {name: "cbrt", group: "Roots", signature: "cbrt(x)", summary: "cube root, defined for negatives", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Cbrt(), nil }}, + + {name: "hypot", group: "Geometry", signature: "hypot(x, y)", summary: "length of the hypotenuse, sqrt(x^2 + y^2)", minArgs: 2, maxArgs: 2, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return hypot(a[0], a[1]) }}, + {name: "dist", group: "Geometry", signature: "dist(x1, y1, x2, y2)", summary: "Euclidean distance between two points", minArgs: 4, maxArgs: 4, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return dist(a[0], a[1], a[2], a[3]) }}, + {name: "norm", group: "Geometry", signature: "norm(x, ...)", summary: "Euclidean norm, sqrt of the sum of squares", minArgs: 1, maxArgs: -1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return norm(a) }}, + + {name: "floor", group: "Rounding", signature: "floor(x)", summary: "round down to an integer", minArgs: 1, maxArgs: 1, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Floor(e.precision), nil }}, + {name: "ceil", group: "Rounding", signature: "ceil(x)", summary: "round up to an integer", minArgs: 1, maxArgs: 1, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Ceil(e.precision), nil }}, + {name: "round", group: "Rounding", signature: "round(x)", summary: "round half away from zero", minArgs: 1, maxArgs: 1, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Round(e.precision), nil }}, + {name: "trunc", group: "Rounding", signature: "trunc(x)", summary: "round toward zero", minArgs: 1, maxArgs: 1, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return trunc(a[0], e.precision) }}, + + {name: "min", group: "Comparison", signature: "min(x, ...)", summary: "smallest argument", minArgs: 1, maxArgs: -1, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return foldMinMax(a, e.precision, false), nil }}, + {name: "max", group: "Comparison", signature: "max(x, ...)", summary: "largest argument", minArgs: 1, maxArgs: -1, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return foldMinMax(a, e.precision, true), nil }}, + + {name: "sinh", group: "Hyperbolic", signature: "sinh(x)", summary: "hyperbolic sine", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Sinh(), nil }}, + {name: "cosh", group: "Hyperbolic", signature: "cosh(x)", summary: "hyperbolic cosine", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Cosh(), nil }}, + {name: "tanh", group: "Hyperbolic", signature: "tanh(x)", summary: "hyperbolic tangent", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Tanh(), nil }}, + {name: "asinh", group: "Hyperbolic", signature: "asinh(x)", summary: "inverse hyperbolic sine", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return asinh(a[0]) }}, + {name: "acosh", group: "Hyperbolic", signature: "acosh(x)", summary: "inverse hyperbolic cosine", minArgs: 1, maxArgs: 1, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return acosh(a[0], e.precision) }}, + {name: "atanh", group: "Hyperbolic", signature: "atanh(x)", summary: "inverse hyperbolic tangent", minArgs: 1, maxArgs: 1, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return atanh(a[0], e.precision) }}, + + {name: "factorial", group: "Combinatorial", signature: "factorial(n)", summary: "factorial of a non-negative integer", minArgs: 1, maxArgs: 1, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return factorial(a[0], e.precision) }}, + {name: "gamma", group: "Combinatorial", signature: "gamma(x)", summary: "gamma function, the continuous factorial", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return a[0].Gamma() }}, + {name: "lgamma", group: "Combinatorial", signature: "lgamma(x)", summary: "natural log of the absolute value of gamma", minArgs: 1, maxArgs: 1, fn: func(_ *Env, a []*unified.Real) (*unified.Real, error) { return lgamma(a[0]) }}, + {name: "choose", group: "Combinatorial", signature: "choose(n, k)", summary: "binomial coefficient, n choose k", minArgs: 2, maxArgs: 2, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return choose(a[0], a[1], e.precision) }}, + {name: "perm", group: "Combinatorial", signature: "perm(n, k)", summary: "number of k-permutations of n", minArgs: 2, maxArgs: 2, fn: func(e *Env, a []*unified.Real) (*unified.Real, error) { return perm(a[0], a[1], e.precision) }}, +} + +// functions is the registry consulted by CallNode, derived from functionCatalog. +// Names share the identifier namespace with variables and constants but are only +// looked up in call position, so a variable named sin and the sin function coexist. +var functions = func() map[string]function { + m := make(map[string]function, len(functionCatalog)) + for _, f := range functionCatalog { + m[f.name] = f + } + return m +}() + +// FunctionInfo describes a registered function for discoverability output. +type FunctionInfo struct { + Name string + Group string + Signature string + Summary string +} + +// Functions returns the registered functions in catalog order, grouped by +// category, for the calculator's discoverability listing. +func Functions() []FunctionInfo { + infos := make([]FunctionInfo, len(functionCatalog)) + for i, f := range functionCatalog { + infos[i] = FunctionInfo{ + Name: f.name, + Group: f.group, + Signature: f.signature, + Summary: f.summary, + } + } + return infos +} + +// errFactorialDomain reports an argument to factorial that is negative or not an +// integer. It carries no sentinel match in domainError, so its own message +// becomes the reason text: factorial(-1): argument must be a non-negative integer. +var errFactorialDomain = errors.New("argument must be a non-negative integer") + +// foldMinMax reduces args to their minimum or maximum with a left-to-right +// pairwise fold. Operands equal within precision resolve to the leftmost, which +// is the library's own tie behavior. +func foldMinMax(args []*unified.Real, precision int, max bool) *unified.Real { + acc := args[0] + for _, a := range args[1:] { + if max { + acc = acc.Max(a, precision) + } else { + acc = acc.Min(a, precision) + } + } + return acc +} + +// errIndeterminate reports a value that cannot be decided at the active +// precision. It is reachable only when the precision itself is invalid, so it +// stands in for an internal failure rather than a user domain error. +var errIndeterminate = errors.New("cannot decide value at the current precision") + +// errCombinatoricDomain reports an argument to choose or perm that is negative +// or not an integer. Like errFactorialDomain it carries no sentinel match, so +// its own message becomes the reason text. +var errCombinatoricDomain = errors.New("arguments must be non-negative integers") + +// errAcoshDomain and errAtanhDomain report arguments outside the real domains of +// the inverse hyperbolic cosine and tangent. +var ( + errAcoshDomain = errors.New("argument must be at least 1") + errAtanhDomain = errors.New("argument must be in (-1, 1)") +) + +// approxRat decides r to a big.Rat at the given binary precision, returning nil +// only when the precision is invalid and the value cannot be approximated. +func approxRat(r *unified.Real, precision int) *big.Rat { + approx := constructive.Approximate(r.Constructive(), precision) + if approx == nil { + return nil + } + + scale := new(big.Int).Exp(big.NewInt(2), big.NewInt(int64(-precision)), nil) + return new(big.Rat).SetFrac(approx, scale) +} + +// intReal wraps an exact integer as a rational Real. +func intReal(n *big.Int) *unified.Real { + return unified.New(constructive.One(), rational.FromRational(new(big.Rat).SetInt(n))) +} + +// factorial computes the exact factorial of a non-negative integer. The argument +// is decided at precision; a non-integer or negative value is a domain error. +// The product is exact over big.Int, so the result is a rational Real. +func factorial(r *unified.Real, precision int) (*unified.Real, error) { + rat := approxRat(r, precision) + if rat == nil || !rat.IsInt() || rat.Sign() < 0 { + return nil, errFactorialDomain + } + + n := rat.Num() + result := big.NewInt(1) + for i := big.NewInt(2); i.Cmp(n) <= 0; i.Add(i, big.NewInt(1)) { + result.Mul(result, i) + } + + return intReal(result), nil +} + +// signum returns the sign of r as -1, 0, or 1, decided at precision. +func signum(r *unified.Real, precision int) (*unified.Real, error) { + rat := approxRat(r, precision) + if rat == nil { + return nil, errIndeterminate + } + return intReal(big.NewInt(int64(rat.Sign()))), nil +} + +// trunc rounds r toward zero, decided at precision: floor for non-negative +// values, ceil for negative ones. +func trunc(r *unified.Real, precision int) (*unified.Real, error) { + rat := approxRat(r, precision) + if rat == nil { + return nil, errIndeterminate + } + if rat.Sign() < 0 { + return r.Ceil(precision), nil + } + return r.Floor(precision), nil +} + +// deg2rad converts an angle in degrees to radians: x * pi / 180. +func deg2rad(r *unified.Real) *unified.Real { + return r.Multiply(unified.Pi()).Divide(intReal(big.NewInt(180))) +} + +// rad2deg converts an angle in radians to degrees: x * 180 / pi. +func rad2deg(r *unified.Real) *unified.Real { + return r.Multiply(intReal(big.NewInt(180))).Divide(unified.Pi()) +} + +// hypot returns sqrt(x^2 + y^2). The radicand is never negative, so the only +// error path is an upstream failure to take the root. +func hypot(x, y *unified.Real) (*unified.Real, error) { + return x.Multiply(x).Add(y.Multiply(y)).Sqrt() +} + +// dist returns the Euclidean distance between the points (x1, y1) and (x2, y2). +func dist(x1, y1, x2, y2 *unified.Real) (*unified.Real, error) { + return hypot(x2.Subtract(x1), y2.Subtract(y1)) +} + +// norm returns the Euclidean norm of a vector, sqrt of the sum of squares. It +// accepts any number of components, generalizing hypot beyond two dimensions. +func norm(components []*unified.Real) (*unified.Real, error) { + sum := components[0].Multiply(components[0]) + for _, c := range components[1:] { + sum = sum.Add(c.Multiply(c)) + } + return sum.Sqrt() +} + +// asinh returns the inverse hyperbolic sine, ln(x + sqrt(x^2 + 1)). The +// arguments to both sqrt and ln stay positive for every real x. +func asinh(x *unified.Real) (*unified.Real, error) { + root, err := x.Multiply(x).Add(intReal(big.NewInt(1))).Sqrt() + if err != nil { + return nil, err + } + return x.Add(root).Ln() +} + +// acosh returns the inverse hyperbolic cosine, ln(x + sqrt(x^2 - 1)), defined +// for x >= 1. +func acosh(x *unified.Real, precision int) (*unified.Real, error) { + rat := approxRat(x, precision) + if rat == nil { + return nil, errIndeterminate + } + if rat.Cmp(big.NewRat(1, 1)) < 0 { + return nil, errAcoshDomain + } + root, err := x.Multiply(x).Subtract(intReal(big.NewInt(1))).Sqrt() + if err != nil { + return nil, err + } + return x.Add(root).Ln() +} + +// atanh returns the inverse hyperbolic tangent, ln((1 + x) / (1 - x)) / 2, +// defined for -1 < x < 1. +func atanh(x *unified.Real, precision int) (*unified.Real, error) { + rat := approxRat(x, precision) + if rat == nil { + return nil, errIndeterminate + } + if rat.Cmp(big.NewRat(1, 1)) >= 0 || rat.Cmp(big.NewRat(-1, 1)) <= 0 { + return nil, errAtanhDomain + } + one := intReal(big.NewInt(1)) + l, err := one.Add(x).Divide(one.Subtract(x)).Ln() + if err != nil { + return nil, err + } + return l.Divide(intReal(big.NewInt(2))), nil +} + +// lgamma returns the natural logarithm of the absolute value of the gamma +// function. The pole error from gamma propagates; elsewhere |gamma| is positive +// so the logarithm is defined. +func lgamma(x *unified.Real) (*unified.Real, error) { + g, err := x.Gamma() + if err != nil { + return nil, err + } + return g.Abs().Ln() +} + +// combInt decides r to a non-negative integer at precision, returning +// errCombinatoricDomain when r is negative or not an integer. +func combInt(r *unified.Real, precision int) (*big.Int, error) { + rat := approxRat(r, precision) + if rat == nil || !rat.IsInt() || rat.Sign() < 0 { + return nil, errCombinatoricDomain + } + return rat.Num(), nil +} + +// choose returns the binomial coefficient C(n, k) for non-negative integers, +// zero when k > n. The running product C(n, i) is integral at every step, so the +// division stays exact. +func choose(nr, kr *unified.Real, precision int) (*unified.Real, error) { + n, err := combInt(nr, precision) + if err != nil { + return nil, err + } + k, err := combInt(kr, precision) + if err != nil { + return nil, err + } + if k.Cmp(n) > 0 { + return intReal(big.NewInt(0)), nil + } + + one := big.NewInt(1) + result := big.NewInt(1) + num := new(big.Int).Set(n) + den := big.NewInt(1) + for i := new(big.Int); i.Cmp(k) < 0; i.Add(i, one) { + result.Mul(result, num) + result.Quo(result, den) + num.Sub(num, one) + den.Add(den, one) + } + return intReal(result), nil +} + +// perm returns the number of k-permutations of n for non-negative integers, +// zero when k > n. +func perm(nr, kr *unified.Real, precision int) (*unified.Real, error) { + n, err := combInt(nr, precision) + if err != nil { + return nil, err + } + k, err := combInt(kr, precision) + if err != nil { + return nil, err + } + if k.Cmp(n) > 0 { + return intReal(big.NewInt(0)), nil + } + + one := big.NewInt(1) + result := big.NewInt(1) + num := new(big.Int).Set(n) + for i := new(big.Int); i.Cmp(k) < 0; i.Add(i, one) { + result.Mul(result, num) + num.Sub(num, one) + } + return intReal(result), nil +} + +// arityError formats a message describing the accepted argument count against +// the count actually supplied. +func arityError(f function, got int) string { + switch { + case f.maxArgs < 0: + return fmt.Sprintf("expects at least %s, got %d", plural(f.minArgs), got) + case f.minArgs == f.maxArgs: + return fmt.Sprintf("expects %s, got %d", plural(f.minArgs), got) + default: + return fmt.Sprintf("expects %d to %d arguments, got %d", f.minArgs, f.maxArgs, got) + } +} + +func plural(n int) string { + if n == 1 { + return "1 argument" + } + return fmt.Sprintf("%d arguments", n) +} + +// domainError formats a function error in call form, naming the function and +// its rendered arguments followed by the reason: sqrt(-1): argument must be +// non-negative. Recognized upstream sentinels map to calc-owned reason text; +// any other error falls back to its own message so non-domain failures still +// surface. +func domainError(name string, args []*unified.Real, precision int, err error) error { + rendered := make([]string, len(args)) + for i, a := range args { + rendered[i] = formatArg(a, precision) + } + call := fmt.Sprintf("%s(%s)", name, strings.Join(rendered, ", ")) + + var reason string + switch { + case errors.Is(err, unified.ErrNonPositive): + reason = "argument must be positive" + case errors.Is(err, unified.ErrNegative): + reason = "argument must be non-negative" + case errors.Is(err, unified.ErrOutsideUnitInterval): + reason = "argument must be in [-1, 1]" + case errors.Is(err, unified.ErrUndefinedAtOrigin): + reason = "undefined at the origin" + case errors.Is(err, unified.ErrInvalidBase): + reason = "base must not be equal to one" + case errors.Is(err, unified.ErrGammaPole): + reason = "argument must not be a non-positive integer" + default: + reason = err.Error() + } + + return fmt.Errorf("%s: %s", call, reason) +} + +// formatArg renders a Real for an error message, approximating it to the active +// precision. Integers print as plain digits; other values print as a decimal +// with trailing zeros trimmed. +func formatArg(r *unified.Real, precision int) string { + approx := constructive.Approximate(r.Constructive(), precision) + if approx == nil { + return "?" + } + + denom := new(big.Int).Exp(big.NewInt(2), big.NewInt(int64(-precision)), nil) + rat := new(big.Rat).SetFrac(approx, denom) + if rat.IsInt() { + return rat.Num().String() + } + + s := rat.FloatString(-precision) + if strings.Contains(s, ".") { + s = strings.TrimRight(s, "0") + s = strings.TrimRight(s, ".") + } + return s +} diff --git a/pkg/calc/parser/functions_test.go b/pkg/calc/parser/functions_test.go new file mode 100644 index 0000000..d6aaac1 --- /dev/null +++ b/pkg/calc/parser/functions_test.go @@ -0,0 +1,46 @@ +package parser + +import "testing" + +// TestFunctionsMatchRegistry guards against drift between the dispatch map and +// the catalog the discoverability listing reads from. +func TestFunctionsMatchRegistry(t *testing.T) { + t.Parallel() + + infos := Functions() + if len(infos) != len(functions) { + t.Fatalf("Functions() has %d entries, registry has %d", len(infos), len(functions)) + } + + seen := make(map[string]bool, len(infos)) + for _, info := range infos { + if _, ok := functions[info.Name]; !ok { + t.Errorf("Functions() lists %q, absent from the dispatch registry", info.Name) + } + seen[info.Name] = true + } + + for name := range functions { + if !seen[name] { + t.Errorf("registry has %q, absent from Functions()", name) + } + } +} + +// TestFunctionMetadataPopulated verifies every catalog entry carries the +// display fields the listing depends on. +func TestFunctionMetadataPopulated(t *testing.T) { + t.Parallel() + + for _, info := range Functions() { + if info.Group == "" { + t.Errorf("%q: empty Group", info.Name) + } + if info.Signature == "" { + t.Errorf("%q: empty Signature", info.Name) + } + if info.Summary == "" { + t.Errorf("%q: empty Summary", info.Name) + } + } +} diff --git a/pkg/calc/parser/parser.go b/pkg/calc/parser/parser.go index 000c442..815f513 100644 --- a/pkg/calc/parser/parser.go +++ b/pkg/calc/parser/parser.go @@ -299,7 +299,15 @@ func (p *P) parsePrimary() (Node, error) { node = &NumberNode{Value: val} case tokens.IDENT: - node = &IdentNode{Name: tok} + if p.peek().Type == tokens.LPAREN { + var err error + node, err = p.parseCall(tok) + if err != nil { + return nil, err + } + } else { + node = &IdentNode{Name: tok} + } case tokens.LPAREN: var err error @@ -331,6 +339,37 @@ func (p *P) parsePrimary() (Node, error) { return node, nil } +func (p *P) parseCall(name tokens.Token) (Node, error) { + p.next() // consume LPAREN + + var args []Node + if p.peek().Type == tokens.RPAREN { + p.next() + return &CallNode{Func: name, Args: args}, nil + } + + for { + arg, err := p.parseAssignment() + if err != nil { + return nil, err + } + args = append(args, arg) + + tok := p.next() + if p.err != nil { + return nil, p.err + } + switch tok.Type { + case tokens.COMMA: + continue + case tokens.RPAREN: + return &CallNode{Func: name, Args: args}, nil + default: + return nil, p.errorf(tok, "expected ',' or ')', got %s", tok.Type) + } + } +} + func (p *P) parseNumber(tok tokens.Token) (*unified.Real, error) { cleaned := strings.ReplaceAll(tok.Value, "_", "") rat := new(big.Rat) diff --git a/pkg/calc/parser/parser_test.go b/pkg/calc/parser/parser_test.go index 0953b85..c4f4adc 100644 --- a/pkg/calc/parser/parser_test.go +++ b/pkg/calc/parser/parser_test.go @@ -359,6 +359,114 @@ func TestParserExpressions(t *testing.T) { exprs: []string{`a "assign" = 5`, `a * 2`}, want: 10, }, + { + name: "sqrt call", + exprs: []string{"sqrt(4)"}, + want: 2, + }, + { + name: "abs call", + exprs: []string{"abs(-3)"}, + want: 3, + }, + { + name: "sin call", + exprs: []string{"sin(0)"}, + want: 0, + }, + { + name: "call with grouped argument expression", + exprs: []string{"sqrt(2 + 2)"}, + want: 2, + }, + { + name: "call nested in expression", + exprs: []string{"sqrt(9) + abs(-1)"}, + want: 4, + }, + { + name: "call with whitespace around argument", + exprs: []string{"sqrt( 16 )"}, + want: 4, + }, + { + name: "variable shares name with function", + exprs: []string{"sin = 5", "sin"}, + want: 5, + }, + { + name: "variable and function coexist", + exprs: []string{"sin = 5", "sin(0) + sin"}, + want: 5, + }, + {name: "cos call", exprs: []string{"cos(0)"}, want: 1}, + {name: "tan call", exprs: []string{"tan(0)"}, want: 0}, + {name: "exp call", exprs: []string{"exp(1)"}, want: math.E}, + {name: "ln call", exprs: []string{"ln(1)"}, want: 0}, + {name: "ln of e", exprs: []string{"ln(E)"}, want: 1}, + {name: "log10 call", exprs: []string{"log10(1000)"}, want: 3}, + {name: "log2 call", exprs: []string{"log2(8)"}, want: 3}, + {name: "log base call", exprs: []string{"log(8, 2)"}, want: 3}, + {name: "atan call", exprs: []string{"atan(0)"}, want: 0}, + {name: "atan2 call", exprs: []string{"atan2(1, 1)"}, want: math.Pi / 4}, + {name: "asin call", exprs: []string{"asin(1)"}, want: math.Pi / 2}, + {name: "acos call", exprs: []string{"acos(1)"}, want: 0}, + {name: "cbrt of negative", exprs: []string{"cbrt(-8)"}, want: -2}, + {name: "cbrt of positive", exprs: []string{"cbrt(27)"}, want: 3}, + {name: "floor call", exprs: []string{"floor(3.7)"}, want: 3}, + {name: "floor of integer", exprs: []string{"floor(3.0)"}, want: 3}, + {name: "ceil call", exprs: []string{"ceil(3.2)"}, want: 4}, + {name: "round half away from zero", exprs: []string{"round(2.5)"}, want: 3}, + {name: "round negative half away from zero", exprs: []string{"round(-2.5)"}, want: -3}, + {name: "min call", exprs: []string{"min(3, 1, 2)"}, want: 1}, + {name: "max call", exprs: []string{"max(3, 1, 2)"}, want: 3}, + {name: "min single argument", exprs: []string{"min(5)"}, want: 5}, + {name: "max of equal irrationals", exprs: []string{"max(PI, PI)"}, want: math.Pi}, + {name: "sinh call", exprs: []string{"sinh(0)"}, want: 0}, + {name: "cosh call", exprs: []string{"cosh(0)"}, want: 1}, + {name: "tanh call", exprs: []string{"tanh(0)"}, want: 0}, + {name: "factorial of zero", exprs: []string{"factorial(0)"}, want: 1}, + {name: "factorial call", exprs: []string{"factorial(5)"}, want: 120}, + {name: "factorial of ten", exprs: []string{"factorial(10)"}, want: 3628800}, + {name: "gamma of one", exprs: []string{"gamma(1)"}, want: 1}, + {name: "gamma matches factorial", exprs: []string{"gamma(5)"}, want: 24}, + {name: "gamma of one half", exprs: []string{"gamma(0.5)"}, want: math.Sqrt(math.Pi)}, + {name: "lgamma of one", exprs: []string{"lgamma(1)"}, want: 0}, + {name: "lgamma matches log factorial", exprs: []string{"lgamma(5)"}, want: math.Log(24)}, + {name: "lgamma of one half", exprs: []string{"lgamma(0.5)"}, want: math.Log(math.Sqrt(math.Pi))}, + {name: "signum of negative", exprs: []string{"signum(-3)"}, want: -1}, + {name: "signum of zero", exprs: []string{"signum(0)"}, want: 0}, + {name: "signum of positive", exprs: []string{"signum(2.5)"}, want: 1}, + {name: "trunc of positive", exprs: []string{"trunc(3.9)"}, want: 3}, + {name: "trunc of negative", exprs: []string{"trunc(-3.9)"}, want: -3}, + {name: "trunc of integer", exprs: []string{"trunc(5)"}, want: 5}, + {name: "deg2rad of straight angle", exprs: []string{"deg2rad(180)"}, want: math.Pi}, + {name: "deg2rad of right angle", exprs: []string{"deg2rad(90)"}, want: math.Pi / 2}, + {name: "rad2deg of pi", exprs: []string{"rad2deg(PI)"}, want: 180}, + {name: "deg2rad and rad2deg round trip", exprs: []string{"rad2deg(deg2rad(57))"}, want: 57}, + {name: "hypot of 3 and 4", exprs: []string{"hypot(3, 4)"}, want: 5}, + {name: "hypot of zeroes", exprs: []string{"hypot(0, 0)"}, want: 0}, + {name: "dist of two points", exprs: []string{"dist(0, 0, 3, 4)"}, want: 5}, + {name: "dist with negative coordinates", exprs: []string{"dist(-1, -1, 2, 3)"}, want: 5}, + {name: "dist of coincident points", exprs: []string{"dist(7, 2, 7, 2)"}, want: 0}, + {name: "norm of one component", exprs: []string{"norm(-5)"}, want: 5}, + {name: "norm of two components", exprs: []string{"norm(3, 4)"}, want: 5}, + {name: "norm of three components", exprs: []string{"norm(2, 3, 6)"}, want: 7}, + {name: "asinh of zero", exprs: []string{"asinh(0)"}, want: 0}, + {name: "asinh of one", exprs: []string{"asinh(1)"}, want: math.Asinh(1)}, + {name: "asinh of negative", exprs: []string{"asinh(-2)"}, want: math.Asinh(-2)}, + {name: "acosh of one", exprs: []string{"acosh(1)"}, want: 0}, + {name: "acosh of two", exprs: []string{"acosh(2)"}, want: math.Acosh(2)}, + {name: "atanh of zero", exprs: []string{"atanh(0)"}, want: 0}, + {name: "atanh of one half", exprs: []string{"atanh(0.5)"}, want: math.Atanh(0.5)}, + {name: "atanh of negative", exprs: []string{"atanh(-0.5)"}, want: math.Atanh(-0.5)}, + {name: "choose call", exprs: []string{"choose(5, 2)"}, want: 10}, + {name: "choose with zero k", exprs: []string{"choose(10, 0)"}, want: 1}, + {name: "choose with k greater than n", exprs: []string{"choose(4, 5)"}, want: 0}, + {name: "choose of a poker hand", exprs: []string{"choose(52, 5)"}, want: 2598960}, + {name: "perm call", exprs: []string{"perm(5, 2)"}, want: 20}, + {name: "perm with zero k", exprs: []string{"perm(5, 0)"}, want: 1}, + {name: "perm with k greater than n", exprs: []string{"perm(4, 5)"}, want: 0}, } for _, tt := range tests { @@ -406,6 +514,16 @@ func TestParserErrors(t *testing.T) { expr: "$", wantErr: "expected digits after '$'", }, + { + name: "missing comma between arguments", + expr: "max(1 2)", + wantErr: "expected ',' or ')'", + }, + { + name: "unterminated argument list", + expr: "sqrt(4", + wantErr: "expected ',' or ')', got EOF", + }, } for _, tt := range tests { @@ -424,6 +542,146 @@ func TestParserErrors(t *testing.T) { } } +func TestCallEvalErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + expr string + wantErr string + }{ + { + name: "too few arguments", + expr: "sin()", + wantErr: "sin expects 1 argument, got 0", + }, + { + name: "too many arguments", + expr: "sin(1, 2)", + wantErr: "sin expects 1 argument, got 2", + }, + { + name: "unknown function", + expr: "nope(1)", + wantErr: `unknown function "nope"`, + }, + { + name: "domain error", + expr: "sqrt(-1)", + wantErr: "sqrt(-1): argument must be non-negative", + }, + { + name: "ln of non-positive", + expr: "ln(-1)", + wantErr: "ln(-1): argument must be positive", + }, + { + name: "log10 of zero", + expr: "log10(0)", + wantErr: "log10(0): argument must be positive", + }, + { + name: "asin outside unit interval", + expr: "asin(2)", + wantErr: "asin(2): argument must be in [-1, 1]", + }, + { + name: "acos outside unit interval", + expr: "acos(2)", + wantErr: "acos(2): argument must be in [-1, 1]", + }, + { + name: "atan2 at origin", + expr: "atan2(0, 0)", + wantErr: "atan2(0, 0): undefined at the origin", + }, + { + name: "log base one", + expr: "log(8, 1)", + wantErr: "log(8, 1): base must not be equal to one", + }, + { + name: "log with one argument", + expr: "log(8)", + wantErr: "log expects 2 arguments, got 1", + }, + { + name: "factorial of negative", + expr: "factorial(-1)", + wantErr: "factorial(-1): argument must be a non-negative integer", + }, + { + name: "factorial of non-integer", + expr: "factorial(1.5)", + wantErr: "factorial(1.5): argument must be a non-negative integer", + }, + { + name: "gamma at a pole", + expr: "gamma(0)", + wantErr: "gamma(0): argument must not be a non-positive integer", + }, + { + name: "gamma at a negative pole", + expr: "gamma(-2)", + wantErr: "gamma(-2): argument must not be a non-positive integer", + }, + { + name: "min with no arguments", + expr: "min()", + wantErr: "min expects at least 1 argument, got 0", + }, + { + name: "acosh below one", + expr: "acosh(0.5)", + wantErr: "acosh(0.5): argument must be at least 1", + }, + { + name: "atanh at the boundary", + expr: "atanh(1)", + wantErr: "atanh(1): argument must be in (-1, 1)", + }, + { + name: "atanh below the boundary", + expr: "atanh(-2)", + wantErr: "atanh(-2): argument must be in (-1, 1)", + }, + { + name: "lgamma at a pole", + expr: "lgamma(0)", + wantErr: "lgamma(0): argument must not be a non-positive integer", + }, + { + name: "choose of negative", + expr: "choose(-1, 2)", + wantErr: "choose(-1, 2): arguments must be non-negative integers", + }, + { + name: "choose of non-integer", + expr: "choose(5, 1.5)", + wantErr: "choose(5, 1.5): arguments must be non-negative integers", + }, + { + name: "perm of negative", + expr: "perm(-1, 2)", + wantErr: "perm(-1, 2): arguments must be non-negative integers", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := parseAndEval(t, tt.expr, NewEnv()) + if err == nil { + t.Fatalf("expected error containing %q", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("error mismatch: got %v want substring %q", err, tt.wantErr) + } + }) + } +} + func TestEvalUndefinedIdentifier(t *testing.T) { t.Parallel() diff --git a/pkg/calc/parser/tree.go b/pkg/calc/parser/tree.go index 984dd58..5545e25 100644 --- a/pkg/calc/parser/tree.go +++ b/pkg/calc/parser/tree.go @@ -237,6 +237,37 @@ func (n *IdentNode) Eval(env *Env) (*unified.Real, error) { return nil, fmt.Errorf("%s: undefined identifier %q", n.Name.Pos, n.Name.Value) } +type CallNode struct { + Func tokens.Token + Args []Node +} + +func (n *CallNode) Eval(env *Env) (*unified.Real, error) { + fn, ok := functions[n.Func.Value] + if !ok { + return nil, fmt.Errorf("%s: unknown function %q", n.Func.Pos, n.Func.Value) + } + + if len(n.Args) < fn.minArgs || (fn.maxArgs >= 0 && len(n.Args) > fn.maxArgs) { + return nil, fmt.Errorf("%s: %s %s", n.Func.Pos, n.Func.Value, arityError(fn, len(n.Args))) + } + + args := make([]*unified.Real, len(n.Args)) + for i, a := range n.Args { + val, err := a.Eval(env) + if err != nil { + return nil, err + } + args[i] = val + } + + result, err := fn.fn(env, args) + if err != nil { + return nil, fmt.Errorf("%s: %w", n.Func.Pos, domainError(n.Func.Value, args, env.precision, err)) + } + return result, nil +} + type AssignNode struct { Name tokens.Token Value Node diff --git a/pkg/calc/tokens/tokens.go b/pkg/calc/tokens/tokens.go index 91c0a2b..65f3140 100644 --- a/pkg/calc/tokens/tokens.go +++ b/pkg/calc/tokens/tokens.go @@ -52,6 +52,7 @@ const ( LPAREN // ( RPAREN // ) + COMMA // , ) var tokenNames = map[TokenType]string{ @@ -79,6 +80,7 @@ var tokenNames = map[TokenType]string{ LPAREN: "LPAREN", RPAREN: "RPAREN", + COMMA: "COMMA", } func (t TokenType) String() string { diff --git a/pkg/cg/approve/approve.go b/pkg/cg/approve/approve.go index 44edd97..3f1d8a5 100644 --- a/pkg/cg/approve/approve.go +++ b/pkg/cg/approve/approve.go @@ -47,13 +47,15 @@ const ( // Rule is one allow or deny entry. Exactly one of Exact/Prefix/Glob/Regex is // populated after validation; kind records which. Message is valid only on deny -// rules and PermitUnsafeEnvs only on allow rules. +// rules and PermitUnsafeEnvs only on allow rules. AsBasename, valid on both, +// matches the basename form of the subject instead of the canonical form. type Rule struct { Exact []string `yaml:"exact,omitempty"` Prefix []string `yaml:"prefix,omitempty"` Glob string `yaml:"glob,omitempty"` Regex string `yaml:"regex,omitempty"` + AsBasename bool `yaml:"as_basename,omitempty"` Message string `yaml:"message,omitempty"` PermitUnsafeEnvs []string `yaml:"permit_unsafe_envs,omitempty"` @@ -111,6 +113,20 @@ type Store struct { // Ruleset returns the current ruleset the matcher evaluates. func (s *Store) Ruleset() *Ruleset { return s.rules.Load() } +// Subject is the command representation the matcher evaluates. Argv is the +// original argv as invoked. Canonical is the resolved, symlink-evaluated argv +// whose first element is the absolute executable path and whose tail mirrors +// Argv[1:]; it is nil when the executable could not be resolved or canonicalized. +// +// Rules match Canonical by default, so a non-basename rule cannot match when +// Canonical is nil. A rule with AsBasename set instead matches a form derived +// from filepath.Base(Argv[0]), the invoked token, so name-based rules still +// evaluate even when canonicalization fails. +type Subject struct { + Argv []string + Canonical []string +} + // Decision is the matcher's verdict for a command. type Decision int diff --git a/pkg/cg/approve/builtin.go b/pkg/cg/approve/builtin.go index bbc484b..57f8f83 100644 --- a/pkg/cg/approve/builtin.go +++ b/pkg/cg/approve/builtin.go @@ -6,9 +6,10 @@ package approve // would otherwise wave through arbitrary code in an argument the rules do not // introspect. Allowlisting an interpreter allowlists everything it can run. // -// Rules are prefix kind, so the deny-side basename normalization catches the -// program however it is spelled (sh, /bin/sh, ./sh). The two-token interpreter -// rules additionally pin argv[1] to the eval flag. +// Rules are prefix kind with as_basename set, so they match the invoked token's +// basename and catch the program however it is spelled (sh, /bin/sh, ./sh), even +// when /bin/sh is a symlink to dash or busybox. The two-token interpreter rules +// additionally pin argv[1] to the eval flag. func builtinDenyRules() []Rule { tokens := [][]string{ {"sh"}, @@ -26,7 +27,7 @@ func builtinDenyRules() []Rule { rules := make([]Rule, len(tokens)) for i, t := range tokens { - rules[i] = Rule{Prefix: t, kind: KindPrefix} + rules[i] = Rule{Prefix: t, AsBasename: true, kind: KindPrefix} } return rules diff --git a/pkg/cg/approve/builtin_test.go b/pkg/cg/approve/builtin_test.go index df96e52..19faa35 100644 --- a/pkg/cg/approve/builtin_test.go +++ b/pkg/cg/approve/builtin_test.go @@ -41,7 +41,7 @@ func TestBuiltinDeny(t *testing.T) { for _, tt := range builtinTests { t.Run(tt.name, func(t *testing.T) { - if got := rs.Match(tt.argv); got.Decision != tt.want { + if got := rs.Match(identitySubject(tt.argv)); got.Decision != tt.want { t.Errorf("Match(%v) = %v, want %v", tt.argv, got.Decision, tt.want) } }) @@ -58,7 +58,7 @@ func TestBuiltinDenyNotOverridable(t *testing.T) { Deny: builtinDenyRules(), Allow: []Rule{prefixRule("sh")}, } - if got := rs.Match([]string{"sh", "-c", "x"}); got.Decision != DecisionRefuse { + if got := rs.Match(identitySubject([]string{"sh", "-c", "x"})); got.Decision != DecisionRefuse { t.Errorf("Match(sh -c x) = %v, want refuse (builtin deny not overridable)", got.Decision) } } diff --git a/pkg/cg/approve/matcher.go b/pkg/cg/approve/matcher.go index d1fe93c..43475ef 100644 --- a/pkg/cg/approve/matcher.go +++ b/pkg/cg/approve/matcher.go @@ -2,18 +2,31 @@ package approve import ( "path/filepath" - "strings" "github.com/ripta/rt/pkg/cg" ) -// Match evaluates argv against the frozen ruleset and returns the verdict. mode -// allow-all and deny-all short-circuit; otherwise the first matching deny rule -// refuses, the first matching allow rule runs, and no match prompts. Deny is +// matchForm is one argv representation a rule can match against: the token slice +// and its precomputed quoted join. ok is false for the canonical form when the +// subject could not be canonicalized, which makes non-basename rules skip it. +type matchForm struct { + argv []string + quoted string + ok bool +} + +// Match evaluates a subject against the frozen ruleset and returns the verdict. +// mode allow-all and deny-all short-circuit; otherwise the first matching deny +// rule refuses, the first matching allow rule runs, and no match prompts. Deny is // evaluated in full before allow, so a deny match always wins, including across // layers. -func (rs *Ruleset) Match(argv []string) MatchResult { - if len(argv) == 0 { +// +// Rules match the canonical form by default; a rule with AsBasename set matches +// the basename form instead. When the subject has no canonical form, non-basename +// rules cannot match, so a command with an unknown executable identity is never +// allowed by canonical policy and falls through to prompt or fail-closed. +func (rs *Ruleset) Match(subj Subject) MatchResult { + if len(subj.Argv) == 0 { return MatchResult{Decision: DecisionRefuse} } @@ -24,15 +37,15 @@ func (rs *Ruleset) Match(argv []string) MatchResult { return MatchResult{Decision: DecisionRefuse} } - quoted := cg.EscapeArgs(argv) + canonical, basename := subj.forms() for i := range rs.Deny { - if ruleMatches(&rs.Deny[i], argv, quoted, true) { + if ruleMatches(&rs.Deny[i], canonical, basename) { return MatchResult{Decision: DecisionRefuse, Rule: &rs.Deny[i]} } } for i := range rs.Allow { - if ruleMatches(&rs.Allow[i], argv, quoted, false) { + if ruleMatches(&rs.Allow[i], canonical, basename) { return MatchResult{Decision: DecisionRun, Rule: &rs.Allow[i]} } } @@ -40,25 +53,51 @@ func (rs *Ruleset) Match(argv []string) MatchResult { return MatchResult{Decision: DecisionPrompt} } -// ruleMatches reports whether a single rule matches the command. exact and -// prefix compare argv tokens; glob and regex match the precomputed quoted join. -func ruleMatches(rule *Rule, argv []string, quoted string, isDeny bool) bool { +// forms builds the canonical and basename match forms once per Match call. The +// canonical form is unavailable when Canonical is nil. The basename form replaces +// only Argv[0] with its basename, the invoked token, and leaves the tail intact. +func (s Subject) forms() (canonical, basename matchForm) { + if s.Canonical != nil { + canonical = matchForm{argv: s.Canonical, quoted: cg.EscapeArgs(s.Canonical), ok: true} + } + + base := make([]string, len(s.Argv)) + copy(base, s.Argv) + base[0] = filepath.Base(s.Argv[0]) + basename = matchForm{argv: base, quoted: cg.EscapeArgs(base), ok: true} + + return canonical, basename +} + +// ruleMatches reports whether a single rule matches the subject. AsBasename +// selects the basename form; otherwise the rule matches the canonical form, which +// it cannot do when that form is unavailable. exact and prefix compare argv +// tokens; glob and regex match the precomputed quoted join. +func ruleMatches(rule *Rule, canonical, basename matchForm) bool { + form := canonical + if rule.AsBasename { + form = basename + } + if !form.ok { + return false + } + switch rule.kind { case KindExact: - return matchTokens(rule.Exact, argv, true, isDeny) + return matchTokens(rule.Exact, form.argv, true) case KindPrefix: - return matchTokens(rule.Prefix, argv, false, isDeny) + return matchTokens(rule.Prefix, form.argv, false) case KindGlob, KindRegex: - return rule.compiled != nil && rule.compiled.MatchString(quoted) + return rule.compiled != nil && rule.compiled.MatchString(form.quoted) } return false } // matchTokens compares rule tokens against argv element-wise. exact requires -// equal length; prefix requires argv to be at least as long as the rule. argv[0] -// uses the asymmetric program-token normalization; argv[1:] compare byte-exact. -func matchTokens(tokens, argv []string, exact, isDeny bool) bool { +// equal length; prefix requires argv to be at least as long as the rule. Every +// token, including argv[0], compares byte-exact. +func matchTokens(tokens, argv []string, exact bool) bool { if len(tokens) == 0 { return false } @@ -69,10 +108,7 @@ func matchTokens(tokens, argv []string, exact, isDeny bool) bool { return false } - if !programTokenMatches(tokens[0], argv[0], isDeny) { - return false - } - for i := 1; i < len(tokens); i++ { + for i := range tokens { if tokens[i] != argv[i] { return false } @@ -80,23 +116,3 @@ func matchTokens(tokens, argv []string, exact, isDeny bool) bool { return true } - -// programTokenMatches applies the asymmetric program-token normalization to -// argv[0]. A slash-bearing rule token matches argv[0] literally. A no-slash rule -// token broadens a deny to basename(argv[0]) so [sh] catches sh, /bin/sh, and -// ./sh; on an allow it matches only when argv[0] itself has no slash, so a -// path-qualified program falls through to a prompt rather than being -// rubber-stamped. -func programTokenMatches(token, argv0 string, isDeny bool) bool { - if strings.ContainsRune(token, '/') { - return token == argv0 - } - if isDeny { - return filepath.Base(argv0) == token - } - if strings.ContainsRune(argv0, '/') { - return false - } - - return token == argv0 -} diff --git a/pkg/cg/approve/matcher_test.go b/pkg/cg/approve/matcher_test.go index f914088..9e45982 100644 --- a/pkg/cg/approve/matcher_test.go +++ b/pkg/cg/approve/matcher_test.go @@ -9,6 +9,12 @@ func exactRule(tokens ...string) Rule { return Rule{Exact: tokens, kind: KindExa func prefixRule(tokens ...string) Rule { return Rule{Prefix: tokens, kind: KindPrefix} } +// asBase marks a rule as basename-matching, the as_basename: true form. +func asBase(r Rule) Rule { + r.AsBasename = true + return r +} + func globRule(t *testing.T, pattern string) Rule { t.Helper() r := Rule{Glob: pattern, kind: KindGlob} @@ -29,12 +35,23 @@ func regexRule(t *testing.T, pattern string) Rule { return r } +// identitySubject builds a subject whose canonical form equals its argv, for +// tests that do not exercise resolution. Used across the approve test files. +func identitySubject(argv []string) Subject { + return Subject{Argv: argv, Canonical: argv} +} + type matchTest struct { - name string - mode Mode - deny []Rule - allow []Rule - argv []string + name string + mode Mode + deny []Rule + allow []Rule + argv []string + // canonical overrides the canonical form; when nil it defaults to argv. + canonical []string + // unresolved leaves the canonical form unavailable, as when canonicalization + // fails. It takes precedence over canonical. + unresolved bool want Decision wantMessage string } @@ -49,43 +66,53 @@ func TestMatch(t *testing.T) { {name: "allow-all overrides deny", mode: ModeAllowAll, deny: []Rule{prefixRule("rm")}, argv: []string{"rm", "-rf", "/"}, want: DecisionRun}, {name: "deny-all overrides allow", mode: ModeDenyAll, allow: []Rule{prefixRule("git")}, argv: []string{"git", "status"}, want: DecisionRefuse}, - // exact + // exact against the canonical form {name: "exact match", allow: []Rule{exactRule("git", "status")}, argv: []string{"git", "status"}, want: DecisionRun}, {name: "exact longer argv no match", allow: []Rule{exactRule("git", "status")}, argv: []string{"git", "status", "-s"}, want: DecisionPrompt}, {name: "exact different arg no match", allow: []Rule{exactRule("git", "status")}, argv: []string{"git", "log"}, want: DecisionPrompt}, - // prefix + // prefix against the canonical form {name: "prefix match with extra args", allow: []Rule{prefixRule("go", "test")}, argv: []string{"go", "test", "./..."}, want: DecisionRun}, {name: "prefix argv shorter no match", allow: []Rule{prefixRule("go", "test")}, argv: []string{"go"}, want: DecisionPrompt}, {name: "prefix differing token no match", allow: []Rule{prefixRule("go", "test")}, argv: []string{"go", "vet"}, want: DecisionPrompt}, - // deny program-token normalization (basename broadens) - {name: "deny sh plain", deny: []Rule{prefixRule("sh")}, argv: []string{"sh", "-c", "x"}, want: DecisionRefuse}, - {name: "deny sh absolute path", deny: []Rule{prefixRule("sh")}, argv: []string{"/bin/sh", "-c", "x"}, want: DecisionRefuse}, - {name: "deny sh relative path", deny: []Rule{prefixRule("sh")}, argv: []string{"./sh", "-c", "x"}, want: DecisionRefuse}, - - // allow program-token normalization (slash in argv0 falls through) - {name: "allow make plain", allow: []Rule{prefixRule("make")}, argv: []string{"make"}, want: DecisionRun}, - {name: "allow make planted absolute path", allow: []Rule{prefixRule("make")}, argv: []string{"/tmp/evil/make"}, want: DecisionPrompt}, - {name: "allow make relative path", allow: []Rule{prefixRule("make")}, argv: []string{"./make"}, want: DecisionPrompt}, + // canonical path policy: rules match the resolved executable path + {name: "canonical path allow", allow: []Rule{prefixRule("/opt/foo/bin/foo")}, argv: []string{"foo"}, canonical: []string{"/opt/foo/bin/foo"}, want: DecisionRun}, + {name: "canonical path allow with tail", allow: []Rule{prefixRule("/opt/foo/bin/foo")}, argv: []string{"foo", "--bar"}, canonical: []string{"/opt/foo/bin/foo", "--bar"}, want: DecisionRun}, + {name: "bare token does not match canonical path", allow: []Rule{prefixRule("foo")}, argv: []string{"foo"}, canonical: []string{"/opt/foo/bin/foo"}, want: DecisionPrompt}, + {name: "canonical exact full path", allow: []Rule{exactRule("/usr/bin/git", "status")}, argv: []string{"git", "status"}, canonical: []string{"/usr/bin/git", "status"}, want: DecisionRun}, - // rule token with slash matches literally - {name: "deny literal path match", deny: []Rule{prefixRule("/bin/sh")}, argv: []string{"/bin/sh"}, want: DecisionRefuse}, - {name: "deny literal path no bare match", deny: []Rule{prefixRule("/bin/sh")}, argv: []string{"sh"}, want: DecisionPrompt}, - {name: "deny literal path other path no match", deny: []Rule{prefixRule("/bin/sh")}, argv: []string{"/usr/bin/sh"}, want: DecisionPrompt}, - {name: "allow literal path exact", allow: []Rule{exactRule("/usr/bin/git", "status")}, argv: []string{"/usr/bin/git", "status"}, want: DecisionRun}, - {name: "allow literal path bare no match", allow: []Rule{exactRule("/usr/bin/git", "status")}, argv: []string{"git", "status"}, want: DecisionPrompt}, + // element-wise comparison includes argv[0]; no implicit normalization + {name: "literal path prefix match", deny: []Rule{prefixRule("/bin/sh")}, argv: []string{"/bin/sh"}, want: DecisionRefuse}, + {name: "literal path no bare match", deny: []Rule{prefixRule("/bin/sh")}, argv: []string{"sh"}, want: DecisionPrompt}, + {name: "literal path other path no match", deny: []Rule{prefixRule("/bin/sh")}, argv: []string{"/usr/bin/sh"}, want: DecisionPrompt}, - // argv[1:] compares byte-exact, no basename normalization + // argv[1:] compares byte-exact {name: "deny rm -rf exact tail", deny: []Rule{prefixRule("rm", "-rf")}, argv: []string{"rm", "-rf"}, want: DecisionRefuse}, {name: "deny rm -rf with target", deny: []Rule{prefixRule("rm", "-rf")}, argv: []string{"rm", "-rf", "/tmp"}, want: DecisionRefuse}, {name: "deny rm -rf differing flag", deny: []Rule{prefixRule("rm", "-rf")}, argv: []string{"rm", "-r"}, want: DecisionPrompt}, + // as_basename matches the invoked token's basename, however it is spelled + {name: "basename deny plain", deny: []Rule{asBase(prefixRule("sh"))}, argv: []string{"sh", "-c", "x"}, want: DecisionRefuse}, + {name: "basename deny absolute path", deny: []Rule{asBase(prefixRule("sh"))}, argv: []string{"/bin/sh", "-c", "x"}, canonical: []string{"/bin/dash", "-c", "x"}, want: DecisionRefuse}, + {name: "basename deny relative path", deny: []Rule{asBase(prefixRule("sh"))}, argv: []string{"./sh", "-c", "x"}, want: DecisionRefuse}, + {name: "basename allow ignores install path", allow: []Rule{asBase(prefixRule("make"))}, argv: []string{"/tmp/evil/make"}, canonical: []string{"/tmp/evil/make"}, want: DecisionRun}, + {name: "basename allow exact", allow: []Rule{asBase(exactRule("go", "version"))}, argv: []string{"/usr/local/go/bin/go", "version"}, canonical: []string{"/usr/local/go/bin/go", "version"}, want: DecisionRun}, + + // canonical unavailable: non-basename rules cannot match, basename can + {name: "unresolved non-basename allow falls through", allow: []Rule{prefixRule("/opt/foo")}, argv: []string{"foo"}, unresolved: true, want: DecisionPrompt}, + {name: "unresolved non-basename deny does not fire", deny: []Rule{prefixRule("/tmp/x")}, allow: []Rule{asBase(prefixRule("foo"))}, argv: []string{"foo"}, unresolved: true, want: DecisionRun}, + {name: "unresolved basename deny still fires", deny: []Rule{asBase(prefixRule("sh"))}, argv: []string{"sh", "-c", "x"}, unresolved: true, want: DecisionRefuse}, + // deny precedence and layering {name: "deny wins over allow", deny: []Rule{prefixRule("git", "push", "--force")}, allow: []Rule{prefixRule("git")}, argv: []string{"git", "push", "--force"}, want: DecisionRefuse}, {name: "allow when no deny matches", deny: []Rule{prefixRule("git", "push", "--force")}, allow: []Rule{prefixRule("git")}, argv: []string{"git", "status"}, want: DecisionRun}, {name: "no match prompts", allow: []Rule{prefixRule("go")}, argv: []string{"cargo", "build"}, want: DecisionPrompt}, + // deny wins across canonical and basename forms + {name: "basename deny beats canonical allow", deny: []Rule{asBase(prefixRule("sh"))}, allow: []Rule{prefixRule("/bin/sh")}, argv: []string{"/bin/sh", "-c", "x"}, canonical: []string{"/bin/sh", "-c", "x"}, want: DecisionRefuse}, + {name: "canonical path deny beats basename allow", deny: []Rule{prefixRule("/tmp/make")}, allow: []Rule{asBase(prefixRule("make"))}, argv: []string{"make"}, canonical: []string{"/tmp/make"}, want: DecisionRefuse}, + // deny message propagation {name: "deny message surfaced", deny: []Rule{{Prefix: []string{"rm", "-rf"}, Message: "delete specific paths", kind: KindPrefix}}, argv: []string{"rm", "-rf", "/"}, want: DecisionRefuse, wantMessage: "delete specific paths"}, } @@ -101,20 +128,30 @@ func TestMatchPatterns(t *testing.T) { t.Parallel() tests := []matchTest{ - // glob is fully anchored + // glob is fully anchored, over the canonical join {name: "glob trailing star matches tail", allow: []Rule{globRule(t, "kubectl get *")}, argv: []string{"kubectl", "get", "pods", "-n", "x"}, want: DecisionRun}, {name: "glob different verb no match", allow: []Rule{globRule(t, "kubectl get *")}, argv: []string{"kubectl", "describe", "pods"}, want: DecisionPrompt}, {name: "glob no wildcard exact", allow: []Rule{globRule(t, "make")}, argv: []string{"make"}, want: DecisionRun}, {name: "glob no wildcard rejects extra", allow: []Rule{globRule(t, "make")}, argv: []string{"make", "build"}, want: DecisionPrompt}, {name: "glob over quoted join", allow: []Rule{globRule(t, "git commit -m *")}, argv: []string{"git", "commit", "-m", "hello world"}, want: DecisionRun}, - // regex is unanchored search + // regex is unanchored search, over the canonical join {name: "regex sudo leading", deny: []Rule{regexRule(t, `(^|\s)sudo(\s|$)`)}, argv: []string{"sudo", "rm"}, want: DecisionRefuse}, {name: "regex sudo substring no match", deny: []Rule{regexRule(t, `(^|\s)sudo(\s|$)`)}, argv: []string{"mysudo", "foo"}, want: DecisionPrompt}, {name: "regex sudo mid line", deny: []Rule{regexRule(t, `(^|\s)sudo(\s|$)`)}, argv: []string{"echo", "sudo", "hi"}, want: DecisionRefuse}, {name: "regex npm test", allow: []Rule{regexRule(t, `^npm (run )?(test|lint)$`)}, argv: []string{"npm", "test"}, want: DecisionRun}, {name: "regex npm run lint", allow: []Rule{regexRule(t, `^npm (run )?(test|lint)$`)}, argv: []string{"npm", "run", "lint"}, want: DecisionRun}, {name: "regex npm install no match", allow: []Rule{regexRule(t, `^npm (run )?(test|lint)$`)}, argv: []string{"npm", "install"}, want: DecisionPrompt}, + + // regex over the canonical path: directory policies + {name: "regex deny tmp directory", deny: []Rule{regexRule(t, `^/tmp/`)}, argv: []string{"x"}, canonical: []string{"/tmp/x"}, want: DecisionRefuse}, + {name: "regex allow opt directory", allow: []Rule{regexRule(t, `^/opt/foo/bin/`)}, argv: []string{"foo"}, canonical: []string{"/opt/foo/bin/foo"}, want: DecisionRun}, + {name: "regex path bare token no match", allow: []Rule{regexRule(t, `^go test`)}, argv: []string{"go", "test"}, canonical: []string{"/usr/bin/go", "test"}, want: DecisionPrompt}, + + // as_basename glob/regex match the basename join + {name: "basename regex matches by name", allow: []Rule{asBase(regexRule(t, `^go test`))}, argv: []string{"go", "test"}, canonical: []string{"/usr/bin/go", "test"}, want: DecisionRun}, + {name: "basename regex deny sudo by name", deny: []Rule{asBase(regexRule(t, `^sudo(\s|$)`))}, argv: []string{"/usr/bin/sudo", "rm"}, canonical: []string{"/usr/bin/sudo", "rm"}, want: DecisionRefuse}, + {name: "basename glob by name", allow: []Rule{asBase(globRule(t, "kubectl get *"))}, argv: []string{"/usr/local/bin/kubectl", "get", "pods"}, canonical: []string{"/usr/local/bin/kubectl", "get", "pods"}, want: DecisionRun}, } for _, tt := range tests { @@ -124,10 +161,24 @@ func TestMatchPatterns(t *testing.T) { } } +// subjectFor builds the match subject for a test case. unresolved leaves the +// canonical form nil; otherwise canonical defaults to argv. +func subjectFor(tt matchTest) Subject { + if tt.unresolved { + return Subject{Argv: tt.argv} + } + canonical := tt.canonical + if canonical == nil { + canonical = tt.argv + } + + return Subject{Argv: tt.argv, Canonical: canonical} +} + func runMatchCase(t *testing.T, tt matchTest) { t.Helper() rs := &Ruleset{Mode: tt.mode, Deny: tt.deny, Allow: tt.allow} - got := rs.Match(tt.argv) + got := rs.Match(subjectFor(tt)) if got.Decision != tt.want { t.Fatalf("Match(%v) decision = %v, want %v", tt.argv, got.Decision, tt.want) } @@ -144,7 +195,7 @@ func runMatchCase(t *testing.T, tt matchTest) { func TestMatchEmptyArgv(t *testing.T) { t.Parallel() rs := &Ruleset{Mode: ModeEnforce} - if got := rs.Match(nil); got.Decision != DecisionRefuse { - t.Errorf("Match(nil) = %v, want refuse", got.Decision) + if got := rs.Match(Subject{}); got.Decision != DecisionRefuse { + t.Errorf("Match(empty) = %v, want refuse", got.Decision) } } diff --git a/pkg/cg/approve/persist.go b/pkg/cg/approve/persist.go index 9afac36..d7d1f35 100644 --- a/pkg/cg/approve/persist.go +++ b/pkg/cg/approve/persist.go @@ -95,8 +95,9 @@ func (s *Store) AppendProjectAllowPrefix(tokens []string, strategy WriteStrategy if root == nil { return fmt.Errorf("project document has no root mapping") } + asBasename := prefixWantsBasename(tokens) allowSeq := ensureSeq(root, "allow") - allowSeq.Content = append(allowSeq.Content, buildPrefixEntry(tokens)) + allowSeq.Content = append(allowSeq.Content, buildPrefixEntry(tokens, asBasename)) canonicalizeRuleSeqs(root) data, err := renderDocument(doc) @@ -116,11 +117,24 @@ func (s *Store) AppendProjectAllowPrefix(tokens []string, strategy WriteStrategy s.Project.Snapshot = data s.Project.Present = true - s.appendLiveAllow(tokens) + s.appendLiveAllow(tokens, asBasename) return nil } +// prefixWantsBasename reports whether a remembered prefix rule should match by +// basename. A name-based program token, one with no slash, is matched by the +// invoked basename so the rule fires however the program is spelled or resolved; +// a path-bearing token describes an install location and matches the canonical +// path instead. +func prefixWantsBasename(tokens []string) bool { + if len(tokens) == 0 { + return false + } + + return !strings.ContainsRune(tokens[0], filepath.Separator) && !strings.ContainsRune(tokens[0], '/') +} + // baseDocument returns the document node the new rule is appended to. The direct // and overwrite strategies build on cg's in-memory view; reload-merge rebases on // the current on-disk file and unions in-memory rules back in. @@ -186,8 +200,10 @@ func mergeAbsent(seq *yaml.Node, mem, disk []Rule, isDeny bool) { // appendLiveAllow swaps a new ruleset with tokens appended to Allow into the // atomic pointer. The deny slice is shared because it is never mutated; the // allow slice is copied so the live snapshot the matcher reads stays immutable. -func (s *Store) appendLiveAllow(tokens []string) { - rule := Rule{Prefix: slices.Clone(tokens), kind: KindPrefix} +// asBasename mirrors what was written to disk so the live rule matches the same +// commands the reloaded file would. +func (s *Store) appendLiveAllow(tokens []string, asBasename bool) { + rule := Rule{Prefix: slices.Clone(tokens), AsBasename: asBasename, kind: KindPrefix} cur := s.rules.Load() next := &Ruleset{ Mode: cur.Mode, @@ -230,15 +246,16 @@ func ruleEqual(a, b *Rule) bool { return false } -// buildPrefixEntry builds the mapping node for a prefix allow rule. -func buildPrefixEntry(tokens []string) *yaml.Node { - return buildRuleEntry(&Rule{Prefix: tokens, kind: KindPrefix}, false) +// buildPrefixEntry builds the mapping node for a prefix allow rule, marking it +// as_basename when the rule should match by basename. +func buildPrefixEntry(tokens []string, asBasename bool) *yaml.Node { + return buildRuleEntry(&Rule{Prefix: tokens, AsBasename: asBasename, kind: KindPrefix}, false) } -// buildRuleEntry builds the YAML mapping node for a rule: its single kind key -// plus any message (deny) or permit_unsafe_envs (allow). Every scalar carries an -// explicit !!str tag so a token like yes or 123 round-trips as a string rather -// than re-parsing as a bool or int. +// buildRuleEntry builds the YAML mapping node for a rule: its single kind key, +// an as_basename flag when set, plus any message (deny) or permit_unsafe_envs +// (allow). Every scalar carries an explicit !!str tag so a token like yes or 123 +// round-trips as a string rather than re-parsing as a bool or int. func buildRuleEntry(rule *Rule, isDeny bool) *yaml.Node { entry := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} @@ -253,6 +270,9 @@ func buildRuleEntry(rule *Rule, isDeny bool) *yaml.Node { entry.Content = append(entry.Content, strScalar("regex"), strScalar(rule.Regex)) } + if rule.AsBasename { + entry.Content = append(entry.Content, strScalar("as_basename"), boolScalar(true)) + } if isDeny && rule.Message != "" { entry.Content = append(entry.Content, strScalar("message"), strScalar(rule.Message)) } @@ -397,6 +417,16 @@ func strScalar(v string) *yaml.Node { return &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: v} } +// boolScalar builds a boolean scalar node with an explicit tag. +func boolScalar(v bool) *yaml.Node { + val := "false" + if v { + val = "true" + } + + return &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: val} +} + // flowSeq builds a flow-style string sequence node ([a, b]) so token lists // render compactly, matching the canonical format the proposal documents. func flowSeq(tokens []string) *yaml.Node { diff --git a/pkg/cg/approve/persist_test.go b/pkg/cg/approve/persist_test.go index 810c043..b9589ac 100644 --- a/pkg/cg/approve/persist_test.go +++ b/pkg/cg/approve/persist_test.go @@ -85,7 +85,7 @@ func TestAppendCreatesProjectFile(t *testing.T) { // The file must reload cleanly and the rule must be live. reloaded := loadProject(t, root) - if got := reloaded.Ruleset().Match([]string{"make", "build"}); got.Decision != DecisionRun { + if got := reloaded.Ruleset().Match(identitySubject([]string{"make", "build"})); got.Decision != DecisionRun { t.Errorf("reloaded match = %v, want run", got.Decision) } } @@ -96,13 +96,13 @@ func TestAppendUpdatesLiveMatcher(t *testing.T) { root := t.TempDir() s := loadProject(t, root) - if got := s.Ruleset().Match([]string{"make"}); got.Decision != DecisionPrompt { + if got := s.Ruleset().Match(identitySubject([]string{"make"})); got.Decision != DecisionPrompt { t.Fatalf("before append, match = %v, want prompt", got.Decision) } if err := s.AppendProjectAllowPrefix([]string{"make"}, WriteDirect); err != nil { t.Fatalf("AppendProjectAllowPrefix: %v", err) } - if got := s.Ruleset().Match([]string{"make"}); got.Decision != DecisionRun { + if got := s.Ruleset().Match(identitySubject([]string{"make"})); got.Decision != DecisionRun { t.Errorf("after append, match = %v, want run (live swap)", got.Decision) } } @@ -126,8 +126,9 @@ func TestAppendPreservesCommentsAndQuotesTokens(t *testing.T) { if !strings.Contains(got, "# keep this comment") { t.Errorf("comment not preserved:\n%s", got) } - // yes must round-trip as a string, not the boolean true. - if !strings.Contains(got, "yes") || strings.Contains(got, "true") { + // yes must round-trip as a string, not the boolean true. The rule itself + // carries as_basename: true, so check the token rendering specifically. + if !strings.Contains(got, "[weird, yes]") || strings.Contains(got, "[weird, true]") { t.Errorf("token yes mis-rendered:\n%s", got) } @@ -136,10 +137,10 @@ func TestAppendPreservesCommentsAndQuotesTokens(t *testing.T) { if len(reloaded.Project.Doc.Allow) != 2 { t.Errorf("allow count = %d, want 2", len(reloaded.Project.Doc.Allow)) } - if got := reloaded.Ruleset().Match([]string{"go", "test"}); got.Decision != DecisionRun { + if got := reloaded.Ruleset().Match(identitySubject([]string{"go", "test"})); got.Decision != DecisionRun { t.Errorf("existing rule lost: match = %v, want run", got.Decision) } - if got := reloaded.Ruleset().Match([]string{"weird", "yes"}); got.Decision != DecisionRun { + if got := reloaded.Ruleset().Match(identitySubject([]string{"weird", "yes"})); got.Decision != DecisionRun { t.Errorf("new rule missing: match = %v, want run", got.Decision) } } @@ -259,7 +260,7 @@ func TestReloadMergeUnionsRules(t *testing.T) { {"cargo", "build", "--release"}, } for _, argv := range wants { - if got := reloaded.Ruleset().Match(argv); got.Decision != DecisionRun { + if got := reloaded.Ruleset().Match(identitySubject(argv)); got.Decision != DecisionRun { t.Errorf("after reload-merge, match %v = %v, want run", argv, got.Decision) } } @@ -281,13 +282,13 @@ func TestOverwriteDropsDiskChanges(t *testing.T) { } reloaded := loadProject(t, root) - if got := reloaded.Ruleset().Match([]string{"npm", "ci"}); got.Decision == DecisionRun { + if got := reloaded.Ruleset().Match(identitySubject([]string{"npm", "ci"})); got.Decision == DecisionRun { t.Errorf("overwrite kept the dropped on-disk rule") } - if got := reloaded.Ruleset().Match([]string{"go", "vet"}); got.Decision != DecisionRun { + if got := reloaded.Ruleset().Match(identitySubject([]string{"go", "vet"})); got.Decision != DecisionRun { t.Errorf("overwrite did not add the new rule") } - if got := reloaded.Ruleset().Match([]string{"make"}); got.Decision != DecisionRun { + if got := reloaded.Ruleset().Match(identitySubject([]string{"make"})); got.Decision != DecisionRun { t.Errorf("overwrite lost the in-memory make rule") } } diff --git a/pkg/cg/execresolve.go b/pkg/cg/execresolve.go index 4d96f95..868b341 100644 --- a/pkg/cg/execresolve.go +++ b/pkg/cg/execresolve.go @@ -82,6 +82,21 @@ func resolveExecPath(argv0, cwd string) (string, error) { return filepath.Join(base, argv0), nil } +// CanonicalArgv is the argv the approval matcher evaluates against canonical +// rules: the symlink-evaluated executable path followed by the original argument +// tail. It is nil when canonicalization did not succeed, so a command with an +// unknown executable identity cannot be allowed by canonical policy. +func (r *Resolution) CanonicalArgv() []string { + if r == nil || r.Canonical == "" || len(r.Argv) == 0 { + return nil + } + + out := make([]string, len(r.Argv)) + out[0] = r.Canonical + copy(out[1:], r.Argv[1:]) + return out +} + // ExecPath is the path RunCapture execs: the canonical path when available, then // the resolved path, falling back to the original argv[0] so a command that // could not be resolved still surfaces its start failure through exec. diff --git a/pkg/cg/execresolve_test.go b/pkg/cg/execresolve_test.go new file mode 100644 index 0000000..7f0072f --- /dev/null +++ b/pkg/cg/execresolve_test.go @@ -0,0 +1,232 @@ +package cg + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +// plantExec writes an executable file named name under dir and returns its path. +func plantExec(t *testing.T, dir, name string) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte("#!/bin/sh\necho hi\n"), 0o755); err != nil { + t.Fatalf("planting %s: %v", path, err) + } + return path +} + +// evalSymlinks resolves any symlinks in path so comparisons survive macOS, where +// the temp root (/var/folders/...) is itself a symlink into /private. +func evalSymlinks(t *testing.T, path string) string { + t.Helper() + got, err := filepath.EvalSymlinks(path) + if err != nil { + t.Fatalf("EvalSymlinks(%s): %v", path, err) + } + return got +} + +func TestResolveCommandBarePath(t *testing.T) { + dir := t.TempDir() + planted := plantExec(t, dir, "tool") + t.Setenv("PATH", dir) + + r, err := ResolveCommand([]string{"tool", "--flag"}, "") + if err != nil { + t.Fatalf("ResolveCommand: %v", err) + } + if r.Resolved == "" || filepath.Base(r.Resolved) != "tool" { + t.Errorf("Resolved = %q, want an absolute path ending in tool", r.Resolved) + } + if want := evalSymlinks(t, planted); r.Canonical != want { + t.Errorf("Canonical = %q, want %q", r.Canonical, want) + } + if got := r.CanonicalArgv(); len(got) != 2 || got[1] != "--flag" { + t.Errorf("CanonicalArgv tail = %v, want [.. --flag]", got) + } +} + +func TestResolveCommandBareNotFound(t *testing.T) { + t.Setenv("PATH", t.TempDir()) + + r, err := ResolveCommand([]string{"definitely-not-on-path-zzz"}, "") + if err == nil { + t.Fatalf("expected error for unresolvable command") + } + if r.Resolved != "" || r.Canonical != "" { + t.Errorf("Resolved=%q Canonical=%q, want both empty", r.Resolved, r.Canonical) + } + if len(r.Argv) != 1 { + t.Errorf("Argv = %v, want the original command preserved", r.Argv) + } +} + +func TestResolveCommandAbsolutePath(t *testing.T) { + dir := t.TempDir() + planted := plantExec(t, dir, "tool") + + r, err := ResolveCommand([]string{planted, "x"}, "") + if err != nil { + t.Fatalf("ResolveCommand: %v", err) + } + if r.Resolved != filepath.Clean(planted) { + t.Errorf("Resolved = %q, want %q", r.Resolved, filepath.Clean(planted)) + } + if want := evalSymlinks(t, planted); r.Canonical != want { + t.Errorf("Canonical = %q, want %q", r.Canonical, want) + } +} + +func TestResolveCommandRelativeUsesCwd(t *testing.T) { + dir := t.TempDir() + planted := plantExec(t, dir, "tool") + + r, err := ResolveCommand([]string{"./tool"}, dir) + if err != nil { + t.Fatalf("ResolveCommand: %v", err) + } + if r.Resolved != filepath.Join(dir, "tool") { + t.Errorf("Resolved = %q, want %q", r.Resolved, filepath.Join(dir, "tool")) + } + if want := evalSymlinks(t, planted); r.Canonical != want { + t.Errorf("Canonical = %q, want %q", r.Canonical, want) + } +} + +func TestResolveExecPathRelativeUsesServerCwd(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + + got, err := resolveExecPath("./tool", "") + if err != nil { + t.Fatalf("resolveExecPath: %v", err) + } + if want := filepath.Join(wd, "tool"); got != want { + t.Errorf("resolveExecPath(./tool, \"\") = %q, want %q", got, want) + } +} + +func TestResolveCommandSymlinkCanonicalizes(t *testing.T) { + realDir := t.TempDir() + linkDir := t.TempDir() + real := plantExec(t, realDir, "tool") + link := filepath.Join(linkDir, "tool") + if err := os.Symlink(real, link); err != nil { + t.Fatalf("symlink: %v", err) + } + + r, err := ResolveCommand([]string{link}, "") + if err != nil { + t.Fatalf("ResolveCommand: %v", err) + } + if r.Resolved != link { + t.Errorf("Resolved = %q, want the invoked link %q", r.Resolved, link) + } + if want := evalSymlinks(t, real); r.Canonical != want { + t.Errorf("Canonical = %q, want the symlink target %q", r.Canonical, want) + } +} + +func TestResolveCommandCanonicalizeFailure(t *testing.T) { + missing := filepath.Join(t.TempDir(), "nope", "tool") + + r, err := ResolveCommand([]string{missing}, "") + if err == nil { + t.Fatalf("expected canonicalization error for nonexistent path") + } + if r.Resolved != filepath.Clean(missing) { + t.Errorf("Resolved = %q, want %q", r.Resolved, filepath.Clean(missing)) + } + if r.Canonical != "" { + t.Errorf("Canonical = %q, want empty on canonicalization failure", r.Canonical) + } +} + +func TestResolveCommandEmptyArgv(t *testing.T) { + if _, err := ResolveCommand(nil, ""); err == nil { + t.Fatalf("expected error for empty argv") + } +} + +// TestResolveCommandUsesServerPathOnly plants the same program name in two +// directories and confirms resolution follows the server PATH. ResolveCommand +// takes no env argument, so a caller-supplied env.PATH cannot redirect the +// top-level executable that gets approved and execed. +func TestResolveCommandUsesServerPathOnly(t *testing.T) { + serverDir := t.TempDir() + otherDir := t.TempDir() + serverTool := plantExec(t, serverDir, "tool") + plantExec(t, otherDir, "tool") + t.Setenv("PATH", serverDir) + + r, err := ResolveCommand([]string{"tool"}, "") + if err != nil { + t.Fatalf("ResolveCommand: %v", err) + } + if want := evalSymlinks(t, serverTool); r.Canonical != want { + t.Errorf("Canonical = %q, want the server-PATH copy %q", r.Canonical, want) + } + if strings.HasPrefix(r.Canonical, evalSymlinks(t, otherDir)) { + t.Errorf("Canonical = %q resolved into the off-PATH dir", r.Canonical) + } +} + +type canonicalArgvTest struct { + name string + res Resolution + want []string +} + +var canonicalArgvTests = []canonicalArgvTest{ + {name: "canonical with tail", res: Resolution{Argv: []string{"foo", "-x"}, Canonical: "/opt/foo"}, want: []string{"/opt/foo", "-x"}}, + {name: "canonical only", res: Resolution{Argv: []string{"foo"}, Canonical: "/opt/foo"}, want: []string{"/opt/foo"}}, + {name: "no canonical is nil", res: Resolution{Argv: []string{"foo"}}, want: nil}, + {name: "no argv is nil", res: Resolution{Canonical: "/opt/foo"}, want: nil}, +} + +func TestCanonicalArgv(t *testing.T) { + t.Parallel() + + for _, tt := range canonicalArgvTests { + t.Run(tt.name, func(t *testing.T) { + got := tt.res.CanonicalArgv() + if len(got) != len(tt.want) { + t.Fatalf("CanonicalArgv() = %v, want %v", got, tt.want) + } + for i := range tt.want { + if got[i] != tt.want[i] { + t.Fatalf("CanonicalArgv() = %v, want %v", got, tt.want) + } + } + }) + } +} + +type execPathTest struct { + name string + res Resolution + want string +} + +var execPathTests = []execPathTest{ + {name: "prefers canonical", res: Resolution{Argv: []string{"foo"}, Resolved: "/r/foo", Canonical: "/c/foo"}, want: "/c/foo"}, + {name: "falls back to resolved", res: Resolution{Argv: []string{"foo"}, Resolved: "/r/foo"}, want: "/r/foo"}, + {name: "falls back to argv0", res: Resolution{Argv: []string{"foo"}}, want: "foo"}, + {name: "empty resolution", res: Resolution{}, want: ""}, +} + +func TestExecPath(t *testing.T) { + t.Parallel() + + for _, tt := range execPathTests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.res.ExecPath(); got != tt.want { + t.Errorf("ExecPath() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/pkg/cg/mcp/gate.go b/pkg/cg/mcp/gate.go index fa50fd1..3210af5 100644 --- a/pkg/cg/mcp/gate.go +++ b/pkg/cg/mcp/gate.go @@ -7,6 +7,7 @@ import ( mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/ripta/rt/pkg/cg" "github.com/ripta/rt/pkg/cg/approve" ) @@ -29,17 +30,20 @@ type gate struct { } // check evaluates the command against the gate and returns nil to permit -// execution or a refusal error. blindlyAllow (and a nil gate) bypasses matching -// and lets the env override pass through untouched, matching allow-all. A real -// allow rule additionally gates dangerous env overrides; allow-all (whose rule -// is nil) does not. A command that matches nothing prompts the user when el is -// available, and otherwise fails closed. -func (g *gate) check(ctx context.Context, in runInput, el elicitor) error { +// execution or a refusal error. The matcher sees the canonical executable +// resolved from argv[0], so policy and execution agree on which file runs. +// blindlyAllow (and a nil gate) bypasses matching and lets the env override pass +// through untouched, matching allow-all. A real allow rule additionally gates +// dangerous env overrides; allow-all (whose rule is nil) does not. A command that +// matches nothing prompts the user when el is available, and otherwise fails +// closed. +func (g *gate) check(ctx context.Context, in runInput, resolved *cg.Resolution, el elicitor) error { if g == nil || g.blindlyAllow { return nil } - res := g.store.Ruleset().Match(in.Command) + subject := approve.Subject{Argv: in.Command, Canonical: resolved.CanonicalArgv()} + res := g.store.Ruleset().Match(subject) switch res.Decision { case approve.DecisionRun: if res.Rule != nil { diff --git a/pkg/cg/mcp/gate_test.go b/pkg/cg/mcp/gate_test.go index b0b23e3..487216f 100644 --- a/pkg/cg/mcp/gate_test.go +++ b/pkg/cg/mcp/gate_test.go @@ -4,12 +4,36 @@ import ( "context" "os" "path/filepath" + "regexp" "strings" "testing" "github.com/ripta/rt/pkg/cg/approve" ) +// plantScript writes an executable shell script under dir that echoes marker, +// and returns its path. Used to drive resolution and exec end-to-end. +func plantScript(t *testing.T, dir, name, marker string) string { + t.Helper() + path := filepath.Join(dir, name) + body := "#!/bin/sh\necho " + marker + "\n" + if err := os.WriteFile(path, []byte(body), 0o755); err != nil { + t.Fatalf("planting %s: %v", path, err) + } + return path +} + +// canonicalPath resolves symlinks so an approval rule can name the executable's +// canonical path, which is what the gate matches after resolution. +func canonicalPath(t *testing.T, path string) string { + t.Helper() + got, err := filepath.EvalSymlinks(path) + if err != nil { + t.Fatalf("EvalSymlinks(%s): %v", path, err) + } + return got +} + // newTestGate builds a gate from project YAML written under a fresh temp project // root. Use newTestGateAt when the test needs the root path to inspect the file. func newTestGate(t *testing.T, projectYAML string, blindly bool) *gate { @@ -42,7 +66,7 @@ func newTestGateAt(t *testing.T, root, projectYAML string, blindly bool) *gate { func TestGateAllowRuns(t *testing.T) { t.Setenv("TMPDIR", t.TempDir()) - g := newTestGate(t, "version: 1\nallow:\n - prefix: [echo]\n", false) + g := newTestGate(t, "version: 1\nallow:\n - prefix: [echo]\n as_basename: true\n", false) _, out, err := handleRun(context.Background(), nil, g, nil, runInput{ Command: []string{"echo", "hi"}, }) @@ -57,7 +81,7 @@ func TestGateAllowRuns(t *testing.T) { func TestGateDenyRefusesWithMessage(t *testing.T) { t.Setenv("TMPDIR", t.TempDir()) - g := newTestGate(t, "version: 1\ndeny:\n - prefix: [rm, -rf]\n message: delete specific paths instead\n", false) + g := newTestGate(t, "version: 1\ndeny:\n - prefix: [rm, -rf]\n as_basename: true\n message: delete specific paths instead\n", false) _, _, err := handleRun(context.Background(), nil, g, nil, runInput{ Command: []string{"rm", "-rf", "x"}, }) @@ -115,7 +139,7 @@ func TestGateFailsClosedOnUnmatched(t *testing.T) { func TestGateEnvGateRefuses(t *testing.T) { t.Setenv("TMPDIR", t.TempDir()) - g := newTestGate(t, "version: 1\nallow:\n - prefix: [echo]\n", false) + g := newTestGate(t, "version: 1\nallow:\n - prefix: [echo]\n as_basename: true\n", false) _, _, err := handleRun(context.Background(), nil, g, nil, runInput{ Command: []string{"echo", "hi"}, Env: map[string]string{"LD_PRELOAD": "evil.so"}, @@ -131,7 +155,7 @@ func TestGateEnvGateRefuses(t *testing.T) { func TestGateEnvGatePermitted(t *testing.T) { t.Setenv("TMPDIR", t.TempDir()) - g := newTestGate(t, "version: 1\nallow:\n - prefix: [echo]\n permit_unsafe_envs: [PATH]\n", false) + g := newTestGate(t, "version: 1\nallow:\n - prefix: [echo]\n as_basename: true\n permit_unsafe_envs: [PATH]\n", false) _, out, err := handleRun(context.Background(), nil, g, nil, runInput{ Command: []string{"echo", "hi"}, Env: map[string]string{"PATH": os.Getenv("PATH")}, @@ -160,3 +184,89 @@ func TestGateAllowAllPassesEnv(t *testing.T) { t.Errorf("ExitCode = %v, want 0", out.ExitCode) } } + +// TestGateSymlinkCanonicalAllow allows the executable by its canonical path and +// invokes it through a symlink in another directory. Resolution canonicalizes +// the link to its target, so the path-based allow rule matches and the program +// runs. +func TestGateSymlinkCanonicalAllow(t *testing.T) { + t.Setenv("TMPDIR", t.TempDir()) + + realDir := t.TempDir() + linkDir := t.TempDir() + real := plantScript(t, realDir, "tool", "ran-real") + link := filepath.Join(linkDir, "tool") + if err := os.Symlink(real, link); err != nil { + t.Fatalf("symlink: %v", err) + } + + yaml := "version: 1\nallow:\n - prefix: ['" + canonicalPath(t, real) + "']\n" + g := newTestGate(t, yaml, false) + _, out, err := handleRun(context.Background(), nil, g, nil, runInput{ + Command: []string{link}, + }) + if err != nil { + t.Fatalf("handleRun: %v", err) + } + if !strings.Contains(out.StdoutExcerpt, "ran-real") { + t.Errorf("StdoutExcerpt = %q, want it to contain %q", out.StdoutExcerpt, "ran-real") + } +} + +// TestGateSymlinkCanonicalDeny denies a directory by its canonical path and +// invokes an executable inside it through a symlink elsewhere. The deny fires on +// the canonicalized target even though the invoked path is the link. +func TestGateSymlinkCanonicalDeny(t *testing.T) { + t.Setenv("TMPDIR", t.TempDir()) + + realDir := t.TempDir() + linkDir := t.TempDir() + real := plantScript(t, realDir, "tool", "should-not-run") + link := filepath.Join(linkDir, "tool") + if err := os.Symlink(real, link); err != nil { + t.Fatalf("symlink: %v", err) + } + + denyDir := regexp.QuoteMeta(canonicalPath(t, realDir)) + yaml := "version: 1\ndeny:\n - regex: '^" + denyDir + "/'\n message: no executables from that directory\n" + g := newTestGate(t, yaml, false) + _, _, err := handleRun(context.Background(), nil, g, nil, runInput{ + Command: []string{link}, + }) + if err == nil { + t.Fatalf("expected refusal for canonical-path deny") + } + if !strings.Contains(err.Error(), "no executables from that directory") { + t.Errorf("err = %v, want the rule message surfaced", err) + } +} + +// TestGateEnvPathDoesNotRedirectExec plants the same program name on the server +// PATH and in an off-PATH directory, then runs it with env.PATH pointed at the +// off-PATH copy. The server-PATH binary resolves and execs, so env.PATH cannot +// redirect the approved top-level executable; the dangerous-env gate still +// requires the explicit permit_unsafe_envs exemption to run at all. +func TestGateEnvPathDoesNotRedirectExec(t *testing.T) { + t.Setenv("TMPDIR", t.TempDir()) + + serverDir := t.TempDir() + evilDir := t.TempDir() + plantScript(t, serverDir, "tool", "server") + plantScript(t, evilDir, "tool", "evil") + t.Setenv("PATH", serverDir) + + g := newTestGate(t, "version: 1\nallow:\n - prefix: [tool]\n as_basename: true\n permit_unsafe_envs: [PATH]\n", false) + _, out, err := handleRun(context.Background(), nil, g, nil, runInput{ + Command: []string{"tool"}, + Env: map[string]string{"PATH": evilDir}, + }) + if err != nil { + t.Fatalf("handleRun: %v", err) + } + if !strings.Contains(out.StdoutExcerpt, "server") { + t.Errorf("StdoutExcerpt = %q, want the server-PATH copy to run", out.StdoutExcerpt) + } + if strings.Contains(out.StdoutExcerpt, "evil") { + t.Errorf("StdoutExcerpt = %q, env.PATH redirected the exec", out.StdoutExcerpt) + } +} diff --git a/pkg/cg/mcp/run.go b/pkg/cg/mcp/run.go index fd0bdd0..21c7704 100644 --- a/pkg/cg/mcp/run.go +++ b/pkg/cg/mcp/run.go @@ -72,7 +72,7 @@ func handleRun(ctx context.Context, reg *runRegistry, g *gate, el elicitor, in r resolved, _ := cg.ResolveCommand(in.Command, in.Cwd) - if err := g.check(ctx, in, el); err != nil { + if err := g.check(ctx, in, resolved, el); err != nil { return nil, runOutput{}, err } diff --git a/pkg/cg/runner_test.go b/pkg/cg/runner_test.go index e7b5785..4113b3a 100644 --- a/pkg/cg/runner_test.go +++ b/pkg/cg/runner_test.go @@ -3,7 +3,9 @@ package cg import ( "bytes" "errors" + "os" "os/exec" + "os/signal" "regexp" "strings" "syscall" @@ -381,6 +383,38 @@ func TestCommandCustomFormat(t *testing.T) { } } +// sendSignalAndWait sends SIGTERM to the test process so cg's forwarding goroutine relays it to +// the child, then waits for cmd.Execute to return. +// +// Delivery of a signal sent to our own process is asynchronous, and under scheduling pressure the +// forwarding goroutine occasionally does not observe a single SIGTERM before the deadline, so we +// resend periodically until the command finishes. +func sendSignalAndWait(t *testing.T, done <-chan error) { + t.Helper() + + // Register a handler so a stray or late SIGTERM cannot apply the default + // action and kill the test binary once the runner has stopped forwarding. + guard := make(chan os.Signal, 1) + signal.Notify(guard, syscall.SIGTERM) + defer signal.Stop(guard) + + syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + deadline := time.After(30 * time.Second) + resend := time.NewTicker(250 * time.Millisecond) + defer resend.Stop() + for { + select { + case <-done: + return + case <-deadline: + t.Fatal("timed out waiting for command to finish after signal") + case <-resend.C: + syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + } + } +} + func TestCommandSignalForwarding(t *testing.T) { if testing.Short() { t.Skip("skipping integration test") @@ -416,15 +450,7 @@ func TestCommandSignalForwarding(t *testing.T) { time.Sleep(10 * time.Millisecond) } - // Send SIGTERM to our own process group; the child should receive it via - // the signal forwarding goroutine - syscall.Kill(syscall.Getpid(), syscall.SIGTERM) - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Fatal("timed out waiting for command to finish after signal") - } + sendSignalAndWait(t, done) out := buf.String() if !strings.Contains(out, "O: got_sigterm") { @@ -477,13 +503,7 @@ func TestCommandFinishLineSignaled(t *testing.T) { time.Sleep(10 * time.Millisecond) } - syscall.Kill(syscall.Getpid(), syscall.SIGTERM) - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Fatal("timed out waiting for command to finish after signal") - } + sendSignalAndWait(t, done) out := buf.String() m := finishLineRE.FindStringSubmatch(out)