diff --git a/go.mod b/go.mod index 3c78582..2bc0593 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,10 @@ module commander go 1.25.4 +// TEMP: local replace for unpublished squadron-wire OAuth proxy messages. +// Revert + publish a new squadron-wire tag before merging. +replace github.com/mlund01/squadron-wire => ../squadron-wire + require ( github.com/gorilla/websocket v1.5.3 github.com/mlund01/squadron-wire v0.0.40 diff --git a/internal/api/oauth.go b/internal/api/oauth.go new file mode 100644 index 0000000..1ff4683 --- /dev/null +++ b/internal/api/oauth.go @@ -0,0 +1,197 @@ +package api + +import ( + "encoding/json" + "fmt" + "html" + "log" + "net/http" + "time" + + "github.com/mlund01/squadron-wire/protocol" + + "commander/internal/hub" +) + +// HandleOAuthCallback serves GET /oauth/callback, the public URL IdPs +// redirect the user's browser to after authorization. The callback is +// routed to the right squadron instance via the cryptographic `state` +// value (which squadron reserved in advance via OAuthRegisterFlow). +// +// This handler is intentionally unauthenticated — IdPs do not carry +// commander session cookies. Security comes from the state value being +// unguessable and single-use. +func HandleOAuthCallback(h *hub.Hub) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + state := q.Get("state") + code := q.Get("code") + idpErr := q.Get("error") + if idpErrDesc := q.Get("error_description"); idpErrDesc != "" && idpErr != "" { + idpErr = idpErr + ": " + idpErrDesc + } + + if state == "" { + writeOAuthErrorPage(w, "callback missing state parameter") + return + } + + flow, ok := h.PendingFlows().Claim(state) + if !ok { + writeOAuthErrorPage(w, "no matching OAuth flow (it may have expired)") + return + } + + // Forward to the originating squadron. + env, err := protocol.NewRequest(protocol.TypeOAuthCallbackDelivery, &protocol.OAuthCallbackDeliveryPayload{ + State: state, + Code: code, + Error: idpErr, + }) + if err != nil { + writeOAuthErrorPage(w, "internal error building delivery: "+err.Error()) + return + } + resp, err := h.SendRequest(flow.InstanceID, env, 30*time.Second) + if err != nil { + writeOAuthErrorPage(w, "failed to deliver callback to squadron: "+err.Error()) + return + } + if resp.Type == protocol.TypeError { + var perr protocol.ErrorPayload + _ = protocol.DecodePayload(resp, &perr) + writeOAuthErrorPage(w, "squadron rejected callback: "+perr.Message) + return + } + + // Notify any open commander tabs for this instance. + success := idpErr == "" && code != "" + noteType := "oauth_completed" + if !success { + noteType = "oauth_failed" + } + h.Notifications().Publish(flow.InstanceID, hub.Notification{ + Type: noteType, + Data: map[string]interface{}{ + "mcpName": flow.McpName, + "error": idpErr, + }, + }) + + if success { + writeOAuthSuccessPage(w, flow.McpName) + } else { + writeOAuthErrorPage(w, idpErr) + } + } +} + +// HandleStartOAuth kicks off a commander-initiated OAuth login for the +// named MCP server on the specified squadron. Returns the authorization URL +// for the browser to open in a new tab. +func HandleStartOAuth(h *hub.Hub) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + instanceID := r.PathValue("id") + mcpName := r.PathValue("name") + + env, err := protocol.NewRequest(protocol.TypeStartMCPLogin, &protocol.StartMCPLoginPayload{ + McpName: mcpName, + }) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + resp, err := h.SendRequest(instanceID, env, 30*time.Second) + if err != nil { + writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()}) + return + } + if resp.Type == protocol.TypeError { + var perr protocol.ErrorPayload + _ = protocol.DecodePayload(resp, &perr) + writeJSON(w, http.StatusBadGateway, map[string]string{"error": perr.Message}) + return + } + var ack protocol.StartMCPLoginAckPayload + if err := protocol.DecodePayload(resp, &ack); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + if !ack.Accepted { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": ack.Reason}) + return + } + writeJSON(w, http.StatusOK, map[string]string{"authUrl": ack.AuthURL}) + } +} + +// HandleNotifications opens an SSE stream of per-instance notifications +// (e.g. oauth_completed). Used by the commander SPA to surface toasts. +func HandleNotifications(h *hub.Hub) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + instanceID := r.PathValue("id") + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming unsupported", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + ch, cleanup := h.Notifications().Subscribe(instanceID) + defer cleanup() + + // Initial comment line so the connection is flushed immediately. + fmt.Fprint(w, ": connected\n\n") + flusher.Flush() + + keepalive := time.NewTicker(30 * time.Second) + defer keepalive.Stop() + + for { + select { + case <-r.Context().Done(): + return + case note, ok := <-ch: + if !ok { + return + } + data, err := json.Marshal(note) + if err != nil { + log.Printf("notification marshal: %v", err) + continue + } + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + case <-keepalive.C: + fmt.Fprint(w, ": keepalive\n\n") + flusher.Flush() + } + } + } +} + +func writeOAuthSuccessPage(w http.ResponseWriter, mcpName string) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _, _ = fmt.Fprintf(w, ` +
%s is now connected. You can close this window.
+ +`, html.EscapeString(mcpName)) +} + +func writeOAuthErrorPage(w http.ResponseWriter, msg string) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, ` +%s
+You can close this window and try again from the command center UI.
+`, html.EscapeString(msg)) +} diff --git a/internal/api/routes.go b/internal/api/routes.go index 2daa36b..fe89a19 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -69,6 +69,12 @@ func RegisterRoutes(mux *http.ServeMux, h *hub.Hub, ka *keepalive.KeepAlive) { mux.HandleFunc("GET /api/instances/{id}/agents/{name}/chats", handleChatHistory(h)) mux.HandleFunc("GET /api/instances/{id}/chats/{sessionId}/messages", handleChatMessages(h)) mux.HandleFunc("DELETE /api/instances/{id}/chats/{sessionId}", handleArchiveChat(h)) + + // OAuth proxy: start a login flow, stream completion notifications. + // The public callback endpoint (/oauth/callback) is registered separately + // on the outer mux so IdPs can reach it without auth. + mux.HandleFunc("POST /api/instances/{id}/mcp/{name}/oauth/start", HandleStartOAuth(h)) + mux.HandleFunc("GET /api/instances/{id}/notifications", HandleNotifications(h)) } func handleListInstances(h *hub.Hub) http.HandlerFunc { diff --git a/internal/hub/connection.go b/internal/hub/connection.go index 0049670..b571a06 100644 --- a/internal/hub/connection.go +++ b/internal/hub/connection.go @@ -371,11 +371,44 @@ func (c *Connection) dispatch(env *protocol.Envelope) { c.fanOutChatEvent(env) case protocol.TypeChatComplete: c.fanOutChatComplete(env) + case protocol.TypeOAuthRegisterFlow: + c.handleOAuthRegisterFlow(env) default: log.Printf("Unhandled message type: %s", env.Type) } } +// handleOAuthRegisterFlow records a pending OAuth flow for later callback +// routing. Called when a squadron kicks off an MCP login and asks commander +// to reserve the `state` value. +func (c *Connection) handleOAuthRegisterFlow(env *protocol.Envelope) { + var payload protocol.OAuthRegisterFlowPayload + if err := protocol.DecodePayload(env, &payload); err != nil { + log.Printf("Invalid oauth_register_flow payload: %v", err) + ack, _ := protocol.NewError(env.RequestID, "decode_error", err.Error()) + c.Send(ack) + return + } + if c.instanceID == "" { + ack, _ := protocol.NewError(env.RequestID, "not_registered", "instance not registered yet") + c.Send(ack) + return + } + if payload.State == "" { + ack, _ := protocol.NewResponse(env.RequestID, protocol.TypeOAuthRegisterFlowAck, &protocol.OAuthRegisterFlowAckPayload{ + Accepted: false, + Reason: "state is required", + }) + c.Send(ack) + return + } + c.hub.PendingFlows().Register(payload.State, c.instanceID, payload.McpName) + ack, _ := protocol.NewResponse(env.RequestID, protocol.TypeOAuthRegisterFlowAck, &protocol.OAuthRegisterFlowAckPayload{ + Accepted: true, + }) + c.Send(ack) +} + func (c *Connection) handleRegister(env *protocol.Envelope) { var payload protocol.RegisterPayload if err := protocol.DecodePayload(env, &payload); err != nil { diff --git a/internal/hub/hub.go b/internal/hub/hub.go index 8e01f13..57c4052 100644 --- a/internal/hub/hub.go +++ b/internal/hub/hub.go @@ -8,6 +8,8 @@ import ( "github.com/gorilla/websocket" "github.com/mlund01/squadron-wire/protocol" + + oauthflows "commander/internal/oauth" ) var upgrader = websocket.Upgrader{ @@ -19,6 +21,8 @@ type Hub struct { mu sync.RWMutex connections map[string]*Connection // instanceID → connection registry *Registry + pendingFlows *oauthflows.PendingFlows + notifications *Notifications AllowConfigEdit bool } @@ -27,10 +31,18 @@ func New(allowConfigEdit bool) *Hub { return &Hub{ connections: make(map[string]*Connection), registry: NewRegistry(), + pendingFlows: oauthflows.New(), + notifications: NewNotifications(), AllowConfigEdit: allowConfigEdit, } } +// PendingFlows returns the OAuth flow store. +func (h *Hub) PendingFlows() *oauthflows.PendingFlows { return h.pendingFlows } + +// Notifications returns the per-instance notification fan-out. +func (h *Hub) Notifications() *Notifications { return h.notifications } + // Start initializes background tasks (heartbeat, cleanup, etc.). func (h *Hub) Start() { // TODO: Start heartbeat ticker diff --git a/internal/hub/notifications.go b/internal/hub/notifications.go new file mode 100644 index 0000000..38c2328 --- /dev/null +++ b/internal/hub/notifications.go @@ -0,0 +1,68 @@ +package hub + +import ( + "sync" + "time" +) + +// Notification is a generic per-instance event pushed to any open browser +// tab subscribed to that instance. Initially used to confirm OAuth-proxy +// MCP logins; designed to accept future types without schema churn. +type Notification struct { + Type string `json:"type"` // e.g. "oauth_completed" + Timestamp time.Time `json:"timestamp"` + Data map[string]interface{} `json:"data,omitempty"` +} + +// Notifications fans out per-instance notifications to SSE subscribers. +// Unlike the mission-event fan-out on Connection, notifications are keyed +// by instanceID (not missionID) and have no buffer — they are ephemeral +// hints, not reliable history. Subscribers that aren't listening when an +// event fires will miss it. +type Notifications struct { + mu sync.Mutex + subs map[string][]chan Notification // instanceID → subscribers +} + +// NewNotifications creates an empty fan-out. +func NewNotifications() *Notifications { + return &Notifications{subs: make(map[string][]chan Notification)} +} + +// Subscribe returns a channel for the given instance's notifications and a +// cleanup function to remove the subscription. +func (n *Notifications) Subscribe(instanceID string) (chan Notification, func()) { + ch := make(chan Notification, 16) + n.mu.Lock() + n.subs[instanceID] = append(n.subs[instanceID], ch) + n.mu.Unlock() + return ch, func() { + n.mu.Lock() + defer n.mu.Unlock() + subs := n.subs[instanceID] + for i, s := range subs { + if s == ch { + n.subs[instanceID] = append(subs[:i], subs[i+1:]...) + break + } + } + close(ch) + } +} + +// Publish delivers a notification to all subscribers for the instance. +// Slow subscribers are skipped (no blocking). +func (n *Notifications) Publish(instanceID string, note Notification) { + if note.Timestamp.IsZero() { + note.Timestamp = time.Now() + } + n.mu.Lock() + subs := append([]chan Notification(nil), n.subs[instanceID]...) + n.mu.Unlock() + for _, ch := range subs { + select { + case ch <- note: + default: + } + } +} diff --git a/internal/oauth/flows.go b/internal/oauth/flows.go new file mode 100644 index 0000000..97cf1da --- /dev/null +++ b/internal/oauth/flows.go @@ -0,0 +1,104 @@ +// Package oauth is the command center side of the OAuth proxy. +// +// When a squadron instance wants to authenticate against an MCP server's +// OAuth provider, it asks commander to reserve an entry in the flow store +// keyed by the cryptographic `state` value. When the IdP later redirects +// the user's browser to `