diff --git a/go/sdk/variants/dispatch.go b/go/sdk/variants/dispatch.go index ef1440e..8145246 100644 --- a/go/sdk/variants/dispatch.go +++ b/go/sdk/variants/dispatch.go @@ -29,16 +29,8 @@ func (d *dispatcher) handle(ctx context.Context, method string, req mcp.Request, switch method { case "tools/list", "resources/list", "prompts/list", "resources/templates/list": return d.handleList(ctx, method, req) - case "tools/call", "resources/read", "prompts/get": - return d.handleCall(ctx, method, req) - case "resources/subscribe": - return d.handleSubscribe(ctx, req) - case "resources/unsubscribe": - return d.handleUnsubscribe(ctx, req) - case "completion/complete": - return d.handleCompletion(ctx, req) default: - return next(ctx, method, req) + return d.handleReceiveRedirect(ctx, method, req) } } @@ -66,16 +58,23 @@ func (d *dispatcher) createInvalidVariantError(ctx context.Context, requestedVar } } +// isParamsNil checks if params is nil or a typed-nil (a nil pointer wrapped in an interface). +// The SDK can produce typed-nil params for requests with no parameters. +func isParamsNil(params mcp.Params) bool { + if params == nil { + return true + } + v := reflect.ValueOf(params) + return v.Kind() == reflect.Ptr && v.IsNil() +} + // variantIDFromMeta extracts the variant ID from the request's _meta field. // Returns empty string if no variant is specified. Guards against typed-nil // params (e.g. (*ListToolsParams)(nil) wrapped in the mcp.Params interface) // which the SDK can produce for requests with no parameters. func variantIDFromMeta(req mcp.Request) string { params := req.GetParams() - if params == nil { - return "" - } - if v := reflect.ValueOf(params); v.Kind() == reflect.Ptr && v.IsNil() { + if isParamsNil(params) { return "" } meta := params.GetMeta() @@ -160,7 +159,7 @@ func enrichError(err error, variantID string) error { // List methods // --------------------------------------------------------------------------- -// handleList handles list methods by forwarding to the appropriate variant. +// handleList handles list methods using the generic backend session call method. // Implements cursor scoping per SEP-2053: unwraps incoming cursors and wraps outgoing cursors. func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { conn, err := d.getConnection(ctx, req) @@ -171,230 +170,61 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ backendSession := conn.backendSession variantID := backendSession.variantID params := req.GetParams() - extra := req.GetExtra() - switch method { - case "tools/list": - p, _ := params.(*mcp.ListToolsParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - } - result, err := backendSession.ListTools(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) - } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - case "resources/list": - p, _ := params.(*mcp.ListResourcesParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - } - result, err := backendSession.ListResources(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) - } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - case "prompts/list": - p, _ := params.(*mcp.ListPromptsParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - } - result, err := backendSession.ListPrompts(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) - } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - case "resources/templates/list": - p, _ := params.(*mcp.ListResourceTemplatesParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor + // Inject variant metadata and handle cursor unwrapping (guard against typed-nil params) + if !isParamsNil(params) { + injectVariantMeta(params, variantID) + + if f := reflect.ValueOf(params).Elem().FieldByName("Cursor"); f.IsValid() && f.String() != "" { + innerCursor, err := unwrapCursor(f.String(), variantID) + if err != nil { + return nil, err } + f.SetString(innerCursor) } - result, err := backendSession.ListResourceTemplates(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) - } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - default: - return nil, errors.New("unsupported list method: " + method) } -} -// --------------------------------------------------------------------------- -// Call methods -// --------------------------------------------------------------------------- - -// handleCall handles call methods (tools/call, resources/read, prompts/get). -func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - conn, err := d.getConnection(ctx, req) + // Generic dispatch - pass entire request object + result, err := backendSession.handleReceive(ctx, method, req) if err != nil { - return nil, err + return nil, enrichError(err, variantID) } - backendSession := conn.backendSession - variantID := backendSession.variantID - params := req.GetParams() - extra := req.GetExtra() - var result mcp.Result - - switch method { - case "tools/call": - raw, _ := params.(*mcp.CallToolParamsRaw) - if raw == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid tools/call params", - } - } - injectVariantMeta(raw, variantID) - result, err = backendSession.CallTool(ctx, raw, extra) - case "resources/read": - p, _ := params.(*mcp.ReadResourceParams) - if p == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid resources/read params", - } - } - injectVariantMeta(p, variantID) - result, err = backendSession.ReadResource(ctx, p, extra) - case "prompts/get": - p, _ := params.(*mcp.GetPromptParams) - if p == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid prompts/get params", - } - } - injectVariantMeta(p, variantID) - result, err = backendSession.GetPrompt(ctx, p, extra) - default: - return nil, errors.New("unsupported call method: " + method) + if f := reflect.ValueOf(result).Elem().FieldByName("NextCursor"); f.IsValid() && f.String() != "" { + f.SetString(wrapCursor(f.String(), variantID)) } - if err != nil { - return nil, enrichError(err, variantID) - } return result, nil } // --------------------------------------------------------------------------- -// Subscription methods +// Simple methods (no pagination) // --------------------------------------------------------------------------- -// handleSubscribe handles resources/subscribe. -func (d *dispatcher) handleSubscribe(ctx context.Context, req mcp.Request) (mcp.Result, error) { +// handleReceiveRedirect handles all simple methods (call, subscribe, unsubscribe, completion) +// that don't require special cursor handling. This consolidates what were previously +// separate handlers for handleCall, handleSubscribe, handleUnsubscribe, and handleCompletion. +func (d *dispatcher) handleReceiveRedirect(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { conn, err := d.getConnection(ctx, req) if err != nil { return nil, err } backendSession := conn.backendSession - params, _ := req.GetParams().(*mcp.SubscribeParams) - if params == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid resources/subscribe params", - } - } - injectVariantMeta(params, backendSession.variantID) - if err := backendSession.Subscribe(ctx, params, req.GetExtra()); err != nil { - return nil, enrichError(err, backendSession.variantID) - } - return nil, nil -} + variantID := backendSession.variantID + params := req.GetParams() -// handleUnsubscribe handles resources/unsubscribe. -// Per SEP-2053: "Servers MUST continue to accept resources/unsubscribe for -// existing subscription ids even if the underlying resource is no longer available." -func (d *dispatcher) handleUnsubscribe(ctx context.Context, req mcp.Request) (mcp.Result, error) { - conn, err := d.getConnection(ctx, req) - if err != nil { - return nil, err + // Inject variant metadata (guard against typed-nil params) + if !isParamsNil(params) { + injectVariantMeta(params, variantID) } - backendSession := conn.backendSession - params, _ := req.GetParams().(*mcp.UnsubscribeParams) - if params == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid resources/unsubscribe params", - } - } - injectVariantMeta(params, backendSession.variantID) - if err := backendSession.Unsubscribe(ctx, params, req.GetExtra()); err != nil { - return nil, enrichError(err, backendSession.variantID) - } - return nil, nil -} - -// --------------------------------------------------------------------------- -// Completion -// --------------------------------------------------------------------------- - -// handleCompletion handles completion/complete. -func (d *dispatcher) handleCompletion(ctx context.Context, req mcp.Request) (mcp.Result, error) { - conn, err := d.getConnection(ctx, req) + // Generic dispatch - pass entire request object + result, err := backendSession.handleReceive(ctx, method, req) if err != nil { - return nil, err + return nil, enrichError(err, variantID) } - backendSession := conn.backendSession - params, _ := req.GetParams().(*mcp.CompleteParams) - if params == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid completion/complete params", - } - } - injectVariantMeta(params, backendSession.variantID) - result, err := backendSession.Complete(ctx, params, req.GetExtra()) - if err != nil { - return nil, enrichError(err, backendSession.variantID) - } return result, nil } diff --git a/go/sdk/variants/session.go b/go/sdk/variants/session.go index add4563..3e57368 100644 --- a/go/sdk/variants/session.go +++ b/go/sdk/variants/session.go @@ -6,7 +6,7 @@ package variants import ( "context" - "fmt" + "reflect" "sync" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -42,110 +42,27 @@ type backendSession struct { mcpMethodHandler mcp.MethodHandler } -func (s *backendSession) ListTools(ctx context.Context, p *mcp.ListToolsParams, extra *mcp.RequestExtra) (*mcp.ListToolsResult, error) { - result, err := s.mcpMethodHandler(ctx, "tools/list", &mcp.ListToolsRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListToolsResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for tools/list", result) - } - return r, nil -} - -func (s *backendSession) ListResources(ctx context.Context, p *mcp.ListResourcesParams, extra *mcp.RequestExtra) (*mcp.ListResourcesResult, error) { - result, err := s.mcpMethodHandler(ctx, "resources/list", &mcp.ListResourcesRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListResourcesResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for resources/list", result) - } - return r, nil -} - -func (s *backendSession) ListPrompts(ctx context.Context, p *mcp.ListPromptsParams, extra *mcp.RequestExtra) (*mcp.ListPromptsResult, error) { - result, err := s.mcpMethodHandler(ctx, "prompts/list", &mcp.ListPromptsRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListPromptsResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for prompts/list", result) - } - return r, nil -} - -func (s *backendSession) ListResourceTemplates(ctx context.Context, p *mcp.ListResourceTemplatesParams, extra *mcp.RequestExtra) (*mcp.ListResourceTemplatesResult, error) { - result, err := s.mcpMethodHandler(ctx, "resources/templates/list", &mcp.ListResourceTemplatesRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListResourceTemplatesResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for resources/templates/list", result) - } - return r, nil -} - -func (s *backendSession) CallTool(ctx context.Context, p *mcp.CallToolParamsRaw, extra *mcp.RequestExtra) (*mcp.CallToolResult, error) { - result, err := s.mcpMethodHandler(ctx, "tools/call", &mcp.CallToolRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.CallToolResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for tools/call", result) +// handleReceive invokes mcpMethodHandler for any MCP method by modifying the request's +// Session field to point to the inner server session. This replaces the explicit +// per-method functions (ListTools, CallTool, etc.) with a single generic handler. +// +// The dispatcher is responsible for modifying params (metadata injection, +// cursor unwrapping) before calling this method. +func (s *backendSession) handleReceive(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + // Use reflection to modify the Session field in place. We can't wrap the + // request because the SDK's receiving handler does type assertions on the + // concrete request type (e.g., *mcp.ServerRequest[*mcp.CallToolParamsRaw]). + reqVal := reflect.ValueOf(req) + if reqVal.Kind() == reflect.Ptr { + reqVal = reqVal.Elem() } - return r, nil -} -func (s *backendSession) ReadResource(ctx context.Context, p *mcp.ReadResourceParams, extra *mcp.RequestExtra) (*mcp.ReadResourceResult, error) { - result, err := s.mcpMethodHandler(ctx, "resources/read", &mcp.ReadResourceRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ReadResourceResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for resources/read", result) + sessionField := reqVal.FieldByName("Session") + if sessionField.IsValid() && sessionField.CanSet() { + sessionField.Set(reflect.ValueOf(s.serverSession)) } - return r, nil -} -func (s *backendSession) GetPrompt(ctx context.Context, p *mcp.GetPromptParams, extra *mcp.RequestExtra) (*mcp.GetPromptResult, error) { - result, err := s.mcpMethodHandler(ctx, "prompts/get", &mcp.GetPromptRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.GetPromptResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for prompts/get", result) - } - return r, nil -} - -func (s *backendSession) Subscribe(ctx context.Context, p *mcp.SubscribeParams, extra *mcp.RequestExtra) error { - _, err := s.mcpMethodHandler(ctx, "resources/subscribe", &mcp.SubscribeRequest{Session: s.serverSession, Params: p, Extra: extra}) - return err -} - -func (s *backendSession) Unsubscribe(ctx context.Context, p *mcp.UnsubscribeParams, extra *mcp.RequestExtra) error { - _, err := s.mcpMethodHandler(ctx, "resources/unsubscribe", &mcp.UnsubscribeRequest{Session: s.serverSession, Params: p, Extra: extra}) - return err -} - -func (s *backendSession) Complete(ctx context.Context, p *mcp.CompleteParams, extra *mcp.RequestExtra) (*mcp.CompleteResult, error) { - result, err := s.mcpMethodHandler(ctx, "completion/complete", &mcp.CompleteRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.CompleteResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for completion/complete", result) - } - return r, nil + return s.mcpMethodHandler(ctx, method, req) } // sessionState holds all per-session state for one front client.