Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 1 addition & 38 deletions commands/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,7 @@ func commandQuery(c *cli.Context) error {
// Print newline after streaming
fmt.Println()

newCommand := sanitizeSuggestedCommand(result.String())
if newCommand == "" {
color.Red.Println("AI returned an empty response. Try rephrasing your query.")
return fmt.Errorf("empty AI response")
}
newCommand := strings.TrimSpace(result.String())
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restore command sanitization before auto-run path

Replacing sanitizeSuggestedCommand with a plain TrimSpace means markdown-formatted model output (for example fenced blocks or inline backticks) is now passed through unchanged. When auto-run is enabled, executeCommand receives those backticks via shell -c, which can trigger shell parsing/command-substitution errors instead of executing the intended command; classification logic also runs on the malformed text.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

By removing the sanitizeSuggestedCommand function, the CLI no longer strips markdown code fences (e.g., ```bash ... ```) from the AI's response. If the AI returns a formatted response, newCommand will contain these backticks, causing the shell execution at line 139 to fail with a syntax error. Since many LLMs default to markdown, it is recommended to at least strip triple-backtick fences.


// Check auto-run configuration
if cfg.AI != nil && (cfg.AI.Agent.View || cfg.AI.Agent.Edit || cfg.AI.Agent.Delete) {
Expand Down Expand Up @@ -192,39 +188,6 @@ func executeCommand(ctx context.Context, command string) error {
return nil
}

// sanitizeSuggestedCommand normalizes raw AI output into an executable command.
// It strips triple-backtick fences (with optional language tag like bash, sh,
// zsh, fish, pwsh, powershell), strips surrounding single backticks when the
// result is a single line, and trims whitespace. Responses that start with `#`
// are treated as refusal comments and preserved verbatim so the caller can
// surface them to the user without attempting execution.
func sanitizeSuggestedCommand(raw string) string {
s := strings.TrimSpace(raw)
if s == "" {
return ""
}

if strings.HasPrefix(s, "```") {
s = strings.TrimPrefix(s, "```")
if nl := strings.IndexByte(s, '\n'); nl >= 0 {
switch strings.ToLower(strings.TrimSpace(s[:nl])) {
case "", "bash", "sh", "shell", "zsh", "fish", "pwsh", "powershell":
s = s[nl+1:]
}
}
s = strings.TrimRight(s, " \t\n")
s = strings.TrimSuffix(s, "```")
s = strings.TrimSpace(s)
}

if !strings.ContainsRune(s, '\n') && len(s) >= 2 &&
strings.HasPrefix(s, "`") && strings.HasSuffix(s, "`") {
s = strings.TrimSpace(s[1 : len(s)-1])
}

return s
}

func getSystemContext(query string, ai *model.AIConfig) (model.CommandSuggestVariables, error) {
// Get shell information
shell := os.Getenv("SHELL")
Expand Down
37 changes: 1 addition & 36 deletions commands/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,42 +458,7 @@ func (s *queryTestSuite) TestQueryCommandEmptyAIResponse() {
}

err := s.app.Run(command)
assert.NotNil(s.T(), err)
assert.Contains(s.T(), err.Error(), "empty AI response")
}

func (s *queryTestSuite) TestSanitizeSuggestedCommand() {
tests := []struct {
name string
in string
want string
}{
{"plain", "ls -la", "ls -la"},
{"trims whitespace", " ls -la \n\t", "ls -la"},
{"fence with bash tag", "```bash\necho hi\n```", "echo hi"},
{"fence with sh tag", "```sh\necho hi\n```", "echo hi"},
{"fence with zsh tag", "```zsh\necho hi\n```", "echo hi"},
{"fence with shell tag", "```shell\necho hi\n```", "echo hi"},
{"fence with fish tag", "```fish\nset -x FOO bar\n```", "set -x FOO bar"},
{"fence with powershell tag", "```powershell\nGet-Process\n```", "Get-Process"},
{"fence with pwsh tag", "```pwsh\nGet-Process\n```", "Get-Process"},
{"fence no language tag", "```\necho hi\n```", "echo hi"},
{"fence with trailing newline before closing", "```bash\nls -la\n\n```", "ls -la"},
{"single backticks around single-line", "`ls -la`", "ls -la"},
{"single backticks with surrounding space", " `ls -la` ", "ls -la"},
{"only whitespace", " \n\t ", ""},
{"empty", "", ""},
{"comment passthrough preserved", "# refusing: unsafe request", "# refusing: unsafe request"},
{"multiline without fences kept", "ls\ncat foo", "ls\ncat foo"},
}
for _, tt := range tests {
s.T().Run(tt.name, func(t *testing.T) {
got := sanitizeSuggestedCommand(tt.in)
if got != tt.want {
t.Errorf("sanitizeSuggestedCommand(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
assert.Nil(s.T(), err)
}

func (s *queryTestSuite) TestQueryCommandDescription() {
Expand Down
36 changes: 10 additions & 26 deletions model/ai_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,24 @@ func (s *sseAIService) QueryCommandStream(
for scanner.Scan() {
line := scanner.Text()

if line == "" {
isError = false
if line == "event: error" {
isError = true
continue
}
Comment on lines +77 to 80
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The string comparison line == "event: error" is too restrictive. The SSE specification allows for an optional space after the colon, and some servers might omit it (e.g., event:error). This hardcoded string comparison will fail to detect errors if the server's formatting varies slightly while remaining spec-compliant.


if v, ok := stripSSEField(line, "event:"); ok {
if v == "error" {
isError = true
}
continue
}
if strings.HasPrefix(line, "data:") {
data := line[len("data:"):]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Trim SSE field value before [DONE] detection

The stream parser takes line[len("data:"):] verbatim, so a valid SSE frame like data: [DONE] yields " [DONE]" and never matches the terminator check. In that case the sentinel is emitted to onToken as part of the suggested command and the loop only ends on EOF, which can break command output and streaming behavior for standards-compliant SSE servers that include the optional space after the colon.

Useful? React with 👍 / 👎.


if v, ok := stripSSEField(line, "data:"); ok {
if isError {
return fmt.Errorf("server error: %s", v)
return fmt.Errorf("server error: %s", data)
}
if v == "[DONE]" {

if data == "[DONE]" {
return nil
}
onToken(v)

onToken(data)
isError = false
}
Comment on lines +82 to 95
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This logic fails to handle the optional leading space in SSE data fields correctly. Per the SSE spec (§9.2), a single leading space after the colon should be removed. Most importantly, if a server sends data: [DONE] (with the standard space), the data variable will contain " [DONE]". This causes the check at line 89 to fail, and the string [DONE] will be passed to onToken, resulting in it being printed to the user's terminal and appended to the command string.

Suggested change
if strings.HasPrefix(line, "data:") {
data := line[len("data:"):]
if v, ok := stripSSEField(line, "data:"); ok {
if isError {
return fmt.Errorf("server error: %s", v)
return fmt.Errorf("server error: %s", data)
}
if v == "[DONE]" {
if data == "[DONE]" {
return nil
}
onToken(v)
onToken(data)
isError = false
}
if strings.HasPrefix(line, "data:") {
data := strings.TrimPrefix(line, "data:")
if strings.HasPrefix(data, " ") {
data = data[1:]
}
if isError {
return fmt.Errorf("server error: %s", data)
}
if data == "[DONE]" {
return nil
}
onToken(data)
isError = false
}

}

Expand All @@ -103,17 +101,3 @@ func (s *sseAIService) QueryCommandStream(

return nil
}

// stripSSEField returns the value after prefix, stripping one optional leading
// space per the SSE specification (§9.2 "If value starts with a U+0020 SPACE
// character, remove it from value").
func stripSSEField(line, prefix string) (string, bool) {
if !strings.HasPrefix(line, prefix) {
return "", false
}
v := line[len(prefix):]
if strings.HasPrefix(v, " ") {
v = v[1:]
}
return v, true
}
122 changes: 0 additions & 122 deletions model/ai_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

Expand Down Expand Up @@ -91,124 +90,3 @@ func TestQueryCommandStream_ErrorResponseBody(t *testing.T) {
})
}
}
Comment on lines 90 to 92
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The removal of TestQueryCommandStream_SSEParsing and TestStripSSEField significantly reduces the test coverage for the streaming logic. Even if the implementation is being simplified, it's important to maintain tests that verify how the service handles various SSE inputs (like leading spaces, error events, and the [DONE] signal) to prevent regressions.


func TestQueryCommandStream_SSEParsing(t *testing.T) {
tests := []struct {
name string
body string
wantErr bool
wantErrSubstr string
wantTokens []string
}{
{
name: "data with space and [DONE] terminates cleanly",
body: "data: [DONE]\n\n",
wantTokens: nil,
},
{
name: "data without space and [DONE] terminates cleanly",
body: "data:[DONE]\n\n",
wantTokens: nil,
},
{
name: "single data token with leading space is stripped",
body: "data: hello\n\ndata: [DONE]\n\n",
wantTokens: []string{"hello"},
},
{
name: "single data token without leading space passes through",
body: "data:hello\n\ndata:[DONE]\n\n",
wantTokens: []string{"hello"},
},
{
name: "multi-token stream concatenates without spurious spaces",
body: "data: ls\n\ndata: -la\n\ndata: [DONE]\n\n",
wantTokens: []string{"ls", " -la"},
},
{
name: "event error with space",
body: "event: error\ndata: boom\n\n",
wantErr: true,
wantErrSubstr: "boom",
},
{
name: "event error without space",
body: "event:error\ndata:boom\n\n",
wantErr: true,
wantErrSubstr: "boom",
},
{
name: "blank line resets error state between events",
body: "event: error\n\ndata: hello\n\ndata: [DONE]\n\n",
wantTokens: []string{"hello"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(tt.body))
}))
defer server.Close()

var got []string
svc := NewAIService()
err := svc.QueryCommandStream(
context.Background(),
CommandSuggestVariables{Shell: "bash", Os: "linux", Query: "test"},
Endpoint{APIEndpoint: server.URL, Token: "test-token"},
func(token string) { got = append(got, token) },
)

if tt.wantErr {
if err == nil {
t.Fatalf("expected error, got nil (tokens=%v)", got)
}
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
t.Fatalf("expected error to contain %q, got %q", tt.wantErrSubstr, err.Error())
}
return
}

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(got) != len(tt.wantTokens) {
t.Fatalf("token count mismatch: want %d %v, got %d %v", len(tt.wantTokens), tt.wantTokens, len(got), got)
}
for i, tok := range tt.wantTokens {
if got[i] != tok {
t.Errorf("token[%d] = %q, want %q", i, got[i], tok)
}
}
})
}
}

func TestStripSSEField(t *testing.T) {
tests := []struct {
name string
line string
prefix string
wantVal string
wantOk bool
}{
{"no match", "foo:bar", "data:", "", false},
{"match no space", "data:hello", "data:", "hello", true},
{"match one space stripped", "data: hello", "data:", "hello", true},
{"match two spaces preserves second", "data: hello", "data:", " hello", true},
{"empty value no space", "data:", "data:", "", true},
{"empty value one space", "data: ", "data:", "", true},
{"event error with space", "event: error", "event:", "error", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v, ok := stripSSEField(tt.line, tt.prefix)
if ok != tt.wantOk || v != tt.wantVal {
t.Errorf("stripSSEField(%q, %q) = (%q, %v), want (%q, %v)", tt.line, tt.prefix, v, ok, tt.wantVal, tt.wantOk)
}
})
}
}
Loading