Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions api/acp_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
type acpBootstrapResponse struct {
Conversation *spritzv1.SpritzConversation `json:"conversation"`
EffectiveSessionID string `json:"effectiveSessionId,omitempty"`
EffectiveCWD string `json:"effectiveCwd,omitempty"`
BindingState string `json:"bindingState,omitempty"`
Loaded bool `json:"loaded,omitempty"`
Replaced bool `json:"replaced,omitempty"`
Expand Down Expand Up @@ -417,10 +418,16 @@ func (s *server) bootstrapACPConversationBinding(ctx context.Context, conversati
return nil, err
}

return s.bootstrapACPConversationBindingWithClient(ctx, conversation, client, initResult)
return s.bootstrapACPConversationBindingWithClient(ctx, conversation, spritz, client, initResult)
}

func (s *server) bootstrapACPConversationBindingWithClient(ctx context.Context, conversation *spritzv1.SpritzConversation, client *acpBootstrapInstanceClient, initResult *acpBootstrapInitializeResult) (*acpBootstrapResponse, error) {
func (s *server) bootstrapACPConversationBindingWithClient(
ctx context.Context,
conversation *spritzv1.SpritzConversation,
spritz *spritzv1.Spritz,
client *acpBootstrapInstanceClient,
initResult *acpBootstrapInitializeResult,
) (*acpBootstrapResponse, error) {
if !initResult.AgentCapabilities.LoadSession {
err := errors.New("agent does not support session/load")
s.recordConversationBindingError(ctx, conversation.Namespace, conversation.Name, "", err)
Expand All @@ -429,6 +436,7 @@ func (s *server) bootstrapACPConversationBindingWithClient(ctx context.Context,

agentInfo := normalizeBootstrapAgentInfo(initResult)
capabilities := normalizeBootstrapCapabilities(initResult)
effectiveCWD := resolveConversationEffectiveCWD(spritz, conversation)
effectiveSessionID := strings.TrimSpace(conversation.Spec.SessionID)
previousSessionID := ""
bindingState := "active"
Expand All @@ -438,12 +446,12 @@ func (s *server) bootstrapACPConversationBindingWithClient(ctx context.Context,
var err error

if effectiveSessionID != "" {
replayMessageCount, err = client.loadSession(ctx, effectiveSessionID, normalizeConversationCWD(conversation.Spec.CWD))
replayMessageCount, err = client.loadSession(ctx, effectiveSessionID, effectiveCWD)
if err != nil {
var rpcErr *acpBootstrapRPCError
if errors.As(err, &rpcErr) && rpcErr.missingSession() {
previousSessionID = effectiveSessionID
effectiveSessionID, err = client.newSession(ctx, normalizeConversationCWD(conversation.Spec.CWD))
effectiveSessionID, err = client.newSession(ctx, effectiveCWD)
if err != nil {
s.recordConversationBindingError(ctx, conversation.Namespace, conversation.Name, previousSessionID, err)
return nil, err
Expand All @@ -458,7 +466,7 @@ func (s *server) bootstrapACPConversationBindingWithClient(ctx context.Context,
loaded = true
}
} else {
effectiveSessionID, err = client.newSession(ctx, normalizeConversationCWD(conversation.Spec.CWD))
effectiveSessionID, err = client.newSession(ctx, effectiveCWD)
if err != nil {
s.recordConversationBindingError(ctx, conversation.Namespace, conversation.Name, "", err)
return nil, err
Expand All @@ -473,11 +481,13 @@ func (s *server) bootstrapACPConversationBindingWithClient(ctx context.Context,

updatedConversation, err := s.updateConversationBinding(ctx, conversation.Namespace, conversation.Name, func(current *spritzv1.SpritzConversation) {
now := metav1.Now()
setConversationCWDOverride(current, normalizeConversationOverrideCWD(spritz, current))
current.Spec.SessionID = effectiveSessionID
current.Spec.AgentInfo = agentInfo
current.Spec.Capabilities = capabilities
current.Status.BoundSessionID = effectiveSessionID
current.Status.BindingState = bindingState
current.Status.EffectiveCWD = effectiveCWD
current.Status.PreviousSessionID = previousSessionID
current.Status.LastBoundAt = &now
current.Status.LastReplayMessageCount = replayMessageCount
Expand All @@ -496,6 +506,7 @@ func (s *server) bootstrapACPConversationBindingWithClient(ctx context.Context,
return &acpBootstrapResponse{
Conversation: updatedConversation,
EffectiveSessionID: effectiveSessionID,
EffectiveCWD: effectiveCWD,
BindingState: bindingState,
Loaded: loaded,
Replaced: replaced,
Expand Down Expand Up @@ -527,18 +538,20 @@ func (s *server) updateConversationBinding(ctx context.Context, namespace, name
}
beforeSpec := current.Spec
beforeStatus := current.Status
beforeAnnotations := cloneStringMap(current.Annotations)
mutate(current)
specChanged := !apiequality.Semantic.DeepEqual(beforeSpec, current.Spec)
statusChanged := !apiequality.Semantic.DeepEqual(beforeStatus, current.Status)
annotationsChanged := !apiequality.Semantic.DeepEqual(beforeAnnotations, current.Annotations)
desiredStatus := current.Status
if specChanged {
if specChanged || annotationsChanged {
if err := s.client.Update(ctx, current); err != nil {
return err
}
}
if statusChanged {
statusTarget := current
if specChanged {
if specChanged || annotationsChanged {
statusTarget = &spritzv1.SpritzConversation{}
if err := s.client.Get(ctx, clientKey(namespace, name), statusTarget); err != nil {
return err
Expand Down
1 change: 1 addition & 0 deletions api/acp_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const (
acpConversationLabelValue = "true"
acpConversationSpritzLabelKey = "spritz.sh/spritz-name"
acpConversationOwnerLabelKey = ownerLabelKey
acpConversationExplicitCWDKey = "spritz.sh/acp-cwd-override-explicit"
)

type acpConfig struct {
Expand Down
57 changes: 42 additions & 15 deletions api/acp_conversations.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/labstack/echo/v4"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client"

spritzv1 "spritz.sh/operator/api/v1"
Expand Down Expand Up @@ -157,24 +158,50 @@ func (s *server) updateACPConversation(c echo.Context) error {
return writeError(c, http.StatusBadRequest, err.Error())
}

changed := false
if body.Title != nil && conversation.Spec.Title != strings.TrimSpace(*body.Title) {
conversation.Spec.Title = strings.TrimSpace(*body.Title)
if conversation.Spec.Title == "" {
conversation.Spec.Title = defaultACPConversationTitle
}
changed = true
}
if body.CWD != nil && conversation.Spec.CWD != normalizeConversationCWD(*body.CWD) {
conversation.Spec.CWD = normalizeConversationCWD(*body.CWD)
changed = true
if body.Title == nil && body.CWD == nil {
return writeJSON(c, http.StatusOK, conversation)
}
if changed {
if err := s.client.Update(c.Request().Context(), conversation); err != nil {
return writeError(c, http.StatusInternalServerError, err.Error())

updatedConversation, err := s.updateConversationBinding(c.Request().Context(), conversation.Namespace, conversation.Name, func(current *spritzv1.SpritzConversation) {
if body.Title != nil && current.Spec.Title != strings.TrimSpace(*body.Title) {
current.Spec.Title = strings.TrimSpace(*body.Title)
if current.Spec.Title == "" {
current.Spec.Title = defaultACPConversationTitle
}
}
if body.CWD == nil {
return
}

nextCWD := normalizeConversationCWD(*body.CWD)
nextExplicit := nextCWD != ""
currentExplicit := conversationHasExplicitCWDOverride(current)
if current.Spec.CWD == nextCWD && currentExplicit == nextExplicit {
return
}

setConversationCWDOverride(current, *body.CWD)

previousSessionID := strings.TrimSpace(current.Status.BoundSessionID)
if previousSessionID == "" {
previousSessionID = strings.TrimSpace(current.Spec.SessionID)
}
current.Spec.SessionID = ""
current.Status.BindingState = "pending"
current.Status.BoundSessionID = ""
current.Status.EffectiveCWD = ""
current.Status.PreviousSessionID = previousSessionID
current.Status.LastBoundAt = nil
current.Status.LastReplayAt = nil
current.Status.LastReplayMessageCount = 0
current.Status.LastError = ""
now := metav1.Now()
current.Status.UpdatedAt = &now
})
if err != nil {
return writeError(c, http.StatusInternalServerError, err.Error())
}
return writeJSON(c, http.StatusOK, conversation)
return writeJSON(c, http.StatusOK, updatedConversation)
}

func decodeACPBody(c echo.Context, target any) error {
Expand Down
195 changes: 195 additions & 0 deletions api/acp_cwd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package main

import (
"net/url"
"path"
"strconv"
"strings"

spritzv1 "spritz.sh/operator/api/v1"
)

// normalizeConversationCWD trims client input and preserves empty values so the
// conversation resource can distinguish "no override" from an explicit cwd.
func normalizeConversationCWD(value string) string {
return strings.TrimSpace(value)
}

// resolveConversationEffectiveCWD resolves the cwd that should be used for ACP
// bootstrap and reconnect flows after accounting for explicit overrides,
// instance defaults, and legacy copied-default values.
func resolveConversationEffectiveCWD(spritz *spritzv1.Spritz, conversation *spritzv1.SpritzConversation) string {
defaultCWD := resolveSpritzDefaultCWD(spritz)
if conversation == nil {
return defaultCWD
}
if override := normalizeConversationOverrideCWD(spritz, conversation); override != "" {
return override
}
return defaultCWD
}

// normalizeConversationOverrideCWD distinguishes an explicit override from an
// inherited instance default without guessing about ambiguous historical values.
func normalizeConversationOverrideCWD(spritz *spritzv1.Spritz, conversation *spritzv1.SpritzConversation) string {
if conversation == nil {
return ""
}
override := normalizeConversationCWD(conversation.Spec.CWD)
if override == "" {
return ""
}

defaultCWD := resolveSpritzDefaultCWD(spritz)
if conversationHasExplicitCWDOverride(conversation) {
return override
}
if override == defaultCWD {
return ""
}
if override == defaultACPCWD {
return ""
}
return override
}

// resolveSpritzDefaultCWD derives the runtime-owned default cwd from explicit
// env overrides first and falls back to the primary repo checkout directory.
func resolveSpritzDefaultCWD(spritz *spritzv1.Spritz) string {
if spritz == nil {
return defaultACPCWD
}

for _, key := range []string{
"SPRITZ_CONVERSATION_DEFAULT_CWD",
"SPRITZ_CODEX_WORKDIR",
"SPRITZ_CLAUDE_CODE_WORKDIR",
"SPRITZ_REPO_DIR",
} {
if value := spritzEnvValue(spritz, key); value != "" {
return value
}
}

if repoDir := resolvePrimaryRepoDir(spritz); repoDir != "" {
return repoDir
}
return defaultACPCWD
}

func spritzEnvValue(spritz *spritzv1.Spritz, key string) string {
if spritz == nil {
return ""
}
for i := len(spritz.Spec.Env) - 1; i >= 0; i-- {
env := spritz.Spec.Env[i]
if strings.TrimSpace(env.Name) != key {
continue
}
if value := strings.TrimSpace(env.Value); value != "" {
return value
}
}
return ""
}

func resolvePrimaryRepoDir(spritz *spritzv1.Spritz) string {
if spritz == nil {
return ""
}

repos := spritz.Spec.Repos
if len(repos) > 0 {
return repoDirForConversationDefault(repos[0], 0, len(repos))
}
if spritz.Spec.Repo != nil && strings.TrimSpace(spritz.Spec.Repo.URL) != "" {
return repoDirForConversationDefault(*spritz.Spec.Repo, 0, 1)
}
return ""
}

func repoDirForConversationDefault(repo spritzv1.SpritzRepo, index int, total int) string {
repoDir := strings.TrimSpace(repo.Dir)
if repoDir == "" {
if total > 1 {
repoDir = "/workspace/repo-" + strconv.Itoa(index+1)
} else if inferred := inferConversationRepoName(repo.URL); inferred != "" {
repoDir = path.Join("/workspace", inferred)
} else {
repoDir = "/workspace/repo"
}
}
if !strings.HasPrefix(repoDir, "/") {
repoDir = path.Join("/workspace", repoDir)
}
return path.Clean(repoDir)
}

func inferConversationRepoName(raw string) string {
value := strings.TrimSpace(raw)
if value == "" {
return ""
}
pathPart := ""
if strings.Contains(value, "://") {
parsed, err := url.Parse(value)
if err != nil {
return ""
}
pathPart = parsed.Path
} else if strings.Contains(value, ":") {
parts := strings.SplitN(value, ":", 2)
if len(parts) == 2 {
pathPart = parts[1]
} else {
pathPart = value
}
} else {
pathPart = value
}
pathPart = strings.SplitN(pathPart, "?", 2)[0]
pathPart = strings.SplitN(pathPart, "#", 2)[0]
pathPart = strings.TrimSuffix(pathPart, "/")
if pathPart == "" {
return ""
}
base := path.Base(pathPart)
if base == "." || base == "/" {
return ""
}
base = strings.TrimSuffix(base, ".git")
if base == "" || base == "." || base == "/" {
return ""
}
return base
}

func conversationHasExplicitCWDOverride(conversation *spritzv1.SpritzConversation) bool {
if conversation == nil || conversation.Annotations == nil {
return false
}
value := strings.TrimSpace(conversation.Annotations[acpConversationExplicitCWDKey])
switch strings.ToLower(value) {
case "1", "true", "yes", "on":
return true
default:
return false
}
}

func setConversationCWDOverride(conversation *spritzv1.SpritzConversation, value string) {
if conversation == nil {
return
}
conversation.Spec.CWD = normalizeConversationCWD(value)
if conversation.Spec.CWD == "" {
if conversation.Annotations != nil {
delete(conversation.Annotations, acpConversationExplicitCWDKey)
}
return
}
if conversation.Annotations == nil {
conversation.Annotations = map[string]string{}
}
conversation.Annotations[acpConversationExplicitCWDKey] = "true"
}
Loading
Loading