From bad6aae19cd8f20d431080ce9bf28791339184f3 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 12 Jan 2026 13:46:41 -0800 Subject: [PATCH 1/9] Improved interrupts API. --- go/ai/generate.go | 30 +- go/ai/generate_test.go | 53 +-- go/ai/option.go | 7 +- go/ai/option_test.go | 11 +- go/ai/tools.go | 311 +++++++++++--- go/ai/tools_test.go | 451 ++++++++++++++++++++- go/genkit/genkit.go | 13 +- go/internal/base/json.go | 26 ++ go/plugins/compat_oai/generate.go | 14 +- go/plugins/googlegenai/gemini.go | 12 +- go/plugins/googlegenai/imagen.go | 5 +- go/plugins/internal/anthropic/anthropic.go | 17 +- go/samples/basic/main.go | 3 +- 13 files changed, 813 insertions(+), 140 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 0d3d75e7db..73208da40f 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -30,6 +30,7 @@ import ( "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal/base" + "github.com/google/uuid" ) // Model represents a model that can generate content based on a request. @@ -361,6 +362,9 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, err } + // Ensure all tool requests have unique refs for matching during resume. + ensureToolRequestRefs(resp.Message) + // If this is a long-running operation response, return it immediately without further processing if bm != nil && resp.Operation != nil { return resp, nil @@ -552,6 +556,9 @@ func GenerateText(ctx context.Context, r api.Registry, opts ...GenerateOption) ( } // GenerateData runs a generate request and returns strongly-typed output. +// If the response doesn't contain text output (e.g., contains tool requests +// or interrupts instead), the output will be nil and no error is returned. +// Check resp.Interrupts() or resp.ToolRequests() to handle these cases. func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) (*Out, *ModelResponse, error) { var value Out opts = append(opts, WithOutputType(value)) @@ -561,9 +568,16 @@ func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...Generate return nil, nil, err } + // If there's no text content to parse (e.g., the response contains tool + // requests or interrupts), return nil output. The caller should check + // resp.Interrupts() or resp.ToolRequests() to handle these cases. + if resp.Text() == "" { + return nil, resp, nil + } + err = resp.Output(&value) if err != nil { - return nil, nil, err + return nil, resp, err } return &value, resp, nil @@ -711,6 +725,20 @@ func (m *model) supportsConstrained(hasTools bool) bool { return true } +// ensureToolRequestRefs assigns unique refs to tool request parts that don't have one. +// This ensures that when there are multiple calls to the same tool, each can be +// individually matched when resuming with Restart or Respond directives. +func ensureToolRequestRefs(msg *Message) { + if msg == nil { + return + } + for _, part := range msg.Content { + if part.IsToolRequest() && part.ToolRequest.Ref == "" { + part.ToolRequest.Ref = uuid.New().String() + } + } +} + // clone creates a deep copy of the provided object using JSON marshaling and unmarshaling. func clone[T any](obj *T) *T { if obj == nil { diff --git a/go/ai/generate_test.go b/go/ai/generate_test.go index 050a0f3cce..f07cc4ad77 100644 --- a/go/ai/generate_test.go +++ b/go/ai/generate_test.go @@ -993,13 +993,19 @@ func validTestMessage(m *Message, output *ModelOutputConfig) (*Message, error) { return handler.ParseMessage(m) } +type conditionalToolInput struct { + Value string + Interrupt bool +} + +type resumableToolInput struct { + Action string + Data string +} + func TestToolInterruptsAndResume(t *testing.T) { conditionalTool := DefineTool(r, "conditional", "tool that may interrupt based on input", - func(ctx *ToolContext, input struct { - Value string - Interrupt bool - }, - ) (string, error) { + func(ctx *ToolContext, input conditionalToolInput) (string, error) { if input.Interrupt { return "", ctx.Interrupt(&InterruptOptions{ Metadata: map[string]any{ @@ -1014,11 +1020,7 @@ func TestToolInterruptsAndResume(t *testing.T) { ) resumableTool := DefineTool(r, "resumable", "tool that can be resumed", - func(ctx *ToolContext, input struct { - Action string - Data string - }, - ) (string, error) { + func(ctx *ToolContext, input resumableToolInput) (string, error) { if ctx.Resumed != nil { resumedData, ok := ctx.Resumed["data"].(string) if ok { @@ -1156,11 +1158,12 @@ func TestToolInterruptsAndResume(t *testing.T) { interruptedPart := res.Message.Content[1] + newInput := conditionalToolInput{ + Value: "new_test_data", + Interrupt: false, + } restartPart := conditionalTool.Restart(interruptedPart, &RestartOptions{ - ReplaceInput: map[string]any{ - "Value": "new_test_data", - "Interrupt": false, - }, + ReplaceInput: newInput, ResumedMetadata: map[string]any{ "data": "resumed_data", "source": "restart", @@ -1175,17 +1178,17 @@ func TestToolInterruptsAndResume(t *testing.T) { t.Errorf("expected tool request name 'conditional', got %q", restartPart.ToolRequest.Name) } - newInput, ok := restartPart.ToolRequest.Input.(map[string]any) + replacedInput, ok := restartPart.ToolRequest.Input.(conditionalToolInput) if !ok { - t.Fatal("expected input to be map[string]any") + t.Fatalf("expected input to be conditionalInput, got %T", restartPart.ToolRequest.Input) } - if newInput["Value"] != "new_test_data" { - t.Errorf("expected new input value 'new_test_data', got %v", newInput["Value"]) + if replacedInput.Value != "new_test_data" { + t.Errorf("expected new input value 'new_test_data', got %v", replacedInput.Value) } - if newInput["Interrupt"] != false { - t.Errorf("expected interrupt to be false, got %v", newInput["Interrupt"]) + if replacedInput.Interrupt != false { + t.Errorf("expected interrupt to be false, got %v", replacedInput.Interrupt) } if _, hasInterrupt := restartPart.Metadata["interrupt"]; hasInterrupt { @@ -1242,11 +1245,13 @@ func TestToolInterruptsAndResume(t *testing.T) { } interruptedPart := res.Message.Content[1] + + newInput := conditionalToolInput{ + Value: "restarted_data", + Interrupt: false, + } restartPart := conditionalTool.Restart(interruptedPart, &RestartOptions{ - ReplaceInput: map[string]any{ - "Value": "restarted_data", - "Interrupt": false, - }, + ReplaceInput: newInput, ResumedMetadata: map[string]any{ "data": "restart_context", }, diff --git a/go/ai/option.go b/go/ai/option.go index 9e5ce1217f..d28c68e3e9 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -915,12 +915,15 @@ func (o *generateOptions) applyGenerate(genOpts *generateOptions) error { return nil } -// WithToolResponses sets the tool responses to return from interrupted tool calls. +// WithToolResponses provides resolved responses for interrupted tool calls. +// Use this when you already have the result and want to skip re-executing the tool. func WithToolResponses(parts ...*Part) GenerateOption { return &generateOptions{RespondParts: parts} } -// WithToolRestarts sets the tool requests to restart interrupted tools with. +// WithToolRestarts re-executes interrupted tool calls with additional metadata. +// Use this when the original call lacked required context (e.g., auth, user confirmation) +// that should now allow the tool to complete successfully. func WithToolRestarts(parts ...*Part) GenerateOption { return &generateOptions{RestartParts: parts} } diff --git a/go/ai/option_test.go b/go/ai/option_test.go index 04fee69a59..72796e2e9c 100644 --- a/go/ai/option_test.go +++ b/go/ai/option_test.go @@ -20,6 +20,7 @@ import ( "context" "testing" + "github.com/firebase/genkit/go/core/api" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" ) @@ -658,15 +659,7 @@ func (t *mockTool) RunRawMultipart(ctx context.Context, input any) (*MultipartTo return nil, nil } -func (t *mockTool) Respond(toolReq *Part, outputData any, opts *RespondOptions) *Part { - return nil -} - -func (t *mockTool) Restart(toolReq *Part, opts *RestartOptions) *Part { - return nil -} - -func (t *mockTool) Register(r interface{ RegisterValue(string, any) }) { +func (t *mockTool) Register(r api.Registry) { } func TestWithInputSchemaName(t *testing.T) { diff --git a/go/ai/tools.go b/go/ai/tools.go index 05c30bf096..33a2c712ed 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -54,11 +54,11 @@ func (t ToolName) Name() string { return (string)(t) } -// tool is an action with functions specific to tools. +// ToolDef is an action with functions specific to tools. // Internally, all tools use the v2 format (returning MultipartToolResponse). // For regular tools, RunRaw unwraps the Output field for backward compatibility. -type tool struct { - api.Action +type ToolDef[In, Out any] struct { + action api.Action // The underlying action. multipart bool // Whether this is a multipart-only tool. registry api.Registry // Registry for schema resolution. Set when registered. } @@ -88,6 +88,12 @@ type toolInterruptError struct { } func (e *toolInterruptError) Error() string { + if e.Metadata != nil { + data, err := json.MarshalIndent(e.Metadata, "", " ") + if err == nil { + return fmt.Sprintf("tool execution interrupted: \n\n%s", string(data)) + } + } return "tool execution interrupted" } @@ -122,13 +128,62 @@ type RespondOptions struct { Metadata map[string]any } +// RespondWithOption is a functional option for [ToolDef.RespondWith]. +type RespondWithOption[Out any] interface { + applyRespondWith(*RespondOptions) error +} + +// applyRespondWith applies the option to the respond options. +func (o *RespondOptions) applyRespondWith(opts *RespondOptions) error { + if o.Metadata != nil { + if opts.Metadata != nil { + return errors.New("cannot set metadata more than once (WithResponseMetadata)") + } + opts.Metadata = o.Metadata + } + return nil +} + +// WithResponseMetadata sets metadata for the response. +func WithResponseMetadata[Out any](meta map[string]any) RespondWithOption[Out] { + return &RespondOptions{Metadata: meta} +} + +// RestartWithOption is a functional option for [ToolDef.RestartWith]. +type RestartWithOption[In any] interface { + applyRestartWith(*RestartOptions) error +} + +// applyRestartWith applies the option to the restart options. +func (o *RestartOptions) applyRestartWith(opts *RestartOptions) error { + if o.ReplaceInput != nil { + if opts.ReplaceInput != nil { + return errors.New("cannot set replace input more than once (WithReplaceInput)") + } + opts.ReplaceInput = o.ReplaceInput + } + if o.ResumedMetadata != nil { + if opts.ResumedMetadata != nil { + return errors.New("cannot set resumed metadata more than once (WithResumedMetadata)") + } + opts.ResumedMetadata = o.ResumedMetadata + } + return nil +} + +// WithReplaceInput sets a new input value to replace the original tool request input. +func WithReplaceInput[In any](input In) RestartWithOption[In] { + return &RestartOptions{ReplaceInput: input} +} + +// WithResumedMetadata sets metadata to pass to the resumed tool execution. +func WithResumedMetadata[In any](meta any) RestartWithOption[In] { + return &RestartOptions{ResumedMetadata: meta} +} + // ToolContext provides context and utility functions for tool execution. type ToolContext struct { context.Context - // Interrupt is a function that can be used to interrupt the tool execution. - // Interrupting tool execution returns the control to the caller with the - // total model response so far. - Interrupt func(opts *InterruptOptions) error // Resumed is optional metadata that can be used to resume the tool execution. // Map is not nil only if the tool was interrupted. Resumed map[string]any @@ -136,14 +191,85 @@ type ToolContext struct { OriginalInput any } -// DefineTool creates a new [Tool] and registers it. +// Interrupt interrupts the tool execution and returns control to the caller +// with the total model response so far. The provided metadata is preserved +// and passed back via [ToolContext.Resumed] when the tool is restarted. +func (tc *ToolContext) Interrupt(opts *InterruptOptions) error { + if opts == nil { + opts = &InterruptOptions{} + } + return &toolInterruptError{ + Metadata: opts.Metadata, + } +} + +// InterruptWith is a convenience function to interrupt a tool with a strongly-typed metadata value. +// The metadata is converted to map[string]any via JSON marshaling. +func InterruptWith[T any](tc *ToolContext, meta T) error { + m, err := base.StructToMap(meta) + if err != nil { + return fmt.Errorf("InterruptWith: failed to convert metadata: %w", err) + } + return tc.Interrupt(&InterruptOptions{Metadata: m}) +} + +// InterruptAs extracts strongly-typed metadata from an interrupted tool request [Part]. +// Returns the zero value and false if the part is not an interrupt or the type doesn't match. +func InterruptAs[T any](p *Part) (T, bool) { + var zero T + if p == nil || !p.IsInterrupt() { + return zero, false + } + meta, ok := p.Metadata["interrupt"].(map[string]any) + if !ok { + return zero, false + } + result, err := base.MapToStruct[T](meta) + if err != nil { + return zero, false + } + return result, true +} + +// IsResumed returns true if this tool execution is a resumption after an interrupt. +func (tc *ToolContext) IsResumed() bool { + return tc.Resumed != nil +} + +// ResumedValue retrieves a typed value from the Resumed metadata. +// Returns the zero value and false if the key doesn't exist or the type doesn't match. +func ResumedValue[T any](tc *ToolContext, key string) (T, bool) { + var zero T + if tc.Resumed == nil { + return zero, false + } + v, ok := tc.Resumed[key] + if !ok { + return zero, false + } + typed, ok := v.(T) + return typed, ok +} + +// OriginalInputAs returns the original input typed appropriately. +// Returns the zero value and false if not resumed or type doesn't match. +func OriginalInputAs[T any](tc *ToolContext) (T, bool) { + var zero T + if tc.OriginalInput == nil { + return zero, false + } + typed, ok := tc.OriginalInput.(T) + return typed, ok +} + +// DefineTool creates a new [ToolDef] and registers it. // Use [WithInputSchema] to provide a custom JSON schema instead of inferring from the type parameter. func DefineTool[In, Out any]( r api.Registry, name, description string, fn ToolFunc[In, Out], opts ...ToolOption, -) Tool { +) *ToolDef[In, Out] { toolOpts := &toolOptions{} for _, opt := range opts { if err := opt.applyTool(toolOpts); err != nil { @@ -166,10 +292,10 @@ func DefineTool[In, Out any]( provider, id := api.ParseName(name) r.RegisterAction(api.NewKey(api.ActionTypeTool, provider, id), action) - return &tool{Action: action, multipart: false, registry: r} + return &ToolDef[In, Out]{action: action, multipart: false, registry: r} } -// DefineToolWithInputSchema creates a new [Tool] with a custom input schema and registers it. +// DefineToolWithInputSchema creates a new [ToolDef] with a custom input schema and registers it. // // Deprecated: Use [DefineTool] with [WithInputSchema] instead. func DefineToolWithInputSchema[Out any]( @@ -177,13 +303,13 @@ func DefineToolWithInputSchema[Out any]( name, description string, inputSchema map[string]any, fn ToolFunc[any, Out], -) Tool { +) *ToolDef[any, Out] { return DefineTool(r, name, description, fn, WithInputSchema(inputSchema)) } -// NewTool creates a new [Tool]. It can be passed directly to [Generate]. +// NewTool creates a new [ToolDef]. It can be passed directly to [Generate]. // Use [WithInputSchema] to provide a custom JSON schema instead of inferring from the type parameter. -func NewTool[In, Out any](name, description string, fn ToolFunc[In, Out], opts ...ToolOption) Tool { +func NewTool[In, Out any](name, description string, fn ToolFunc[In, Out], opts ...ToolOption) *ToolDef[In, Out] { toolOpts := &toolOptions{} for _, opt := range opts { if err := opt.applyTool(toolOpts); err != nil { @@ -203,17 +329,17 @@ func NewTool[In, Out any](name, description string, fn ToolFunc[In, Out], opts . metadata, wrappedFn := wrapToolFunc(name, description, fn) metadata["dynamic"] = true action := core.NewAction(name, api.ActionTypeToolV2, metadata, toolOpts.InputSchema, wrappedFn) - return &tool{Action: action, multipart: false} + return &ToolDef[In, Out]{action: action, multipart: false} } -// NewToolWithInputSchema creates a new [Tool] with a custom input schema. It can be passed directly to [Generate]. +// NewToolWithInputSchema creates a new [ToolDef] with a custom input schema. It can be passed directly to [Generate]. // // Deprecated: Use [NewTool] with [WithInputSchema] instead. -func NewToolWithInputSchema[Out any](name, description string, inputSchema map[string]any, fn ToolFunc[any, Out]) Tool { +func NewToolWithInputSchema[Out any](name, description string, inputSchema map[string]any, fn ToolFunc[any, Out]) *ToolDef[any, Out] { return NewTool(name, description, fn, WithInputSchema(inputSchema)) } -// DefineMultipartTool creates a new multipart [Tool] and registers it. +// DefineMultipartTool creates a new multipart [ToolDef] and registers it. // Multipart tools can return both output data and additional content parts (like media). // Use [WithInputSchema] to provide a custom JSON schema instead of inferring from the type parameter. func DefineMultipartTool[In any]( @@ -221,7 +347,7 @@ func DefineMultipartTool[In any]( name, description string, fn MultipartToolFunc[In], opts ...ToolOption, -) Tool { +) *ToolDef[In, *MultipartToolResponse] { toolOpts := &toolOptions{} for _, opt := range opts { if err := opt.applyTool(toolOpts); err != nil { @@ -231,13 +357,13 @@ func DefineMultipartTool[In any]( metadata, wrappedFn := wrapMultipartToolFunc(name, description, fn) action := core.DefineAction(r, name, api.ActionTypeToolV2, metadata, toolOpts.InputSchema, wrappedFn) - return &tool{Action: action, multipart: true, registry: r} + return &ToolDef[In, *MultipartToolResponse]{action: action, multipart: true, registry: r} } -// NewMultipartTool creates a new multipart [Tool]. It can be passed directly to [Generate]. +// NewMultipartTool creates a new multipart [ToolDef]. It can be passed directly to [Generate]. // Multipart tools can return both output data and additional content parts (like media). // Use [WithInputSchema] to provide a custom JSON schema instead of inferring from the type parameter. -func NewMultipartTool[In any](name, description string, fn MultipartToolFunc[In], opts ...ToolOption) Tool { +func NewMultipartTool[In any](name, description string, fn MultipartToolFunc[In], opts ...ToolOption) *ToolDef[In, *MultipartToolResponse] { toolOpts := &toolOptions{} for _, opt := range opts { if err := opt.applyTool(toolOpts); err != nil { @@ -248,7 +374,7 @@ func NewMultipartTool[In any](name, description string, fn MultipartToolFunc[In] metadata, wrappedFn := wrapMultipartToolFunc(name, description, fn) metadata["dynamic"] = true action := core.NewAction(name, api.ActionTypeToolV2, metadata, toolOpts.InputSchema, wrappedFn) - return &tool{Action: action, multipart: true} + return &ToolDef[In, *MultipartToolResponse]{action: action, multipart: true} } // wrapToolFunc wraps a regular tool function to return MultipartToolResponse. @@ -271,12 +397,7 @@ func wrapToolFunc[In, Out any](name, description string, fn ToolFunc[In, Out]) ( wrappedFn := func(ctx context.Context, input In) (*MultipartToolResponse, error) { toolCtx := &ToolContext{ - Context: ctx, - Interrupt: func(opts *InterruptOptions) error { - return &toolInterruptError{ - Metadata: opts.Metadata, - } - }, + Context: ctx, Resumed: resumedCtxKey.FromContext(ctx), OriginalInput: origInputCtxKey.FromContext(ctx), } @@ -299,12 +420,7 @@ func wrapMultipartToolFunc[In any](name, description string, fn MultipartToolFun } wrappedFn := func(ctx context.Context, input In) (*MultipartToolResponse, error) { toolCtx := &ToolContext{ - Context: ctx, - Interrupt: func(opts *InterruptOptions) error { - return &toolInterruptError{ - Metadata: opts.Metadata, - } - }, + Context: ctx, Resumed: resumedCtxKey.FromContext(ctx), OriginalInput: origInputCtxKey.FromContext(ctx), } @@ -313,9 +429,14 @@ func wrapMultipartToolFunc[In any](name, description string, fn MultipartToolFun return metadata, wrappedFn } +// Name returns the name of the tool. +func (t *ToolDef[In, Out]) Name() string { + return t.action.Name() +} + // Definition returns [ToolDefinition] for for this tool. -func (t *tool) Definition() *ToolDefinition { - desc := t.Action.Desc() +func (t *ToolDef[In, Out]) Definition() *ToolDefinition { + desc := t.action.Desc() // Resolve the input schema if it contains a $ref. inputSchema := desc.InputSchema @@ -350,18 +471,18 @@ func (t *tool) Definition() *ToolDefinition { } // Register registers the tool with the given registry. -func (t *tool) Register(r api.Registry) { +func (t *ToolDef[In, Out]) Register(r api.Registry) { t.registry = r - t.Action.Register(r) + t.action.Register(r) if !t.multipart { // Also register under the "tool" key for backward compatibility. - provider, id := api.ParseName(t.Action.Name()) - r.RegisterAction(api.NewKey(api.ActionTypeTool, provider, id), t.Action) + provider, id := api.ParseName(t.action.Name()) + r.RegisterAction(api.NewKey(api.ActionTypeTool, provider, id), t.action) } } // RunRaw runs this tool using the provided raw map format data (JSON parsed as map[string]any). -func (t *tool) RunRaw(ctx context.Context, input any) (any, error) { +func (t *ToolDef[In, Out]) RunRaw(ctx context.Context, input any) (any, error) { resp, err := t.RunRawMultipart(ctx, input) if err != nil { return nil, err @@ -371,7 +492,7 @@ func (t *tool) RunRaw(ctx context.Context, input any) (any, error) { // RunRawMultipart runs this tool using the provided raw map format data (JSON parsed as map[string]any). // It returns the full multipart response. -func (t *tool) RunRawMultipart(ctx context.Context, input any) (*MultipartToolResponse, error) { +func (t *ToolDef[In, Out]) RunRawMultipart(ctx context.Context, input any) (*MultipartToolResponse, error) { if t == nil { return nil, core.NewError(core.INVALID_ARGUMENT, "ai.Tool.RunRawMultipart: tool called on a nil tool; check that all tools are defined") } @@ -380,7 +501,7 @@ func (t *tool) RunRawMultipart(ctx context.Context, input any) (*MultipartToolRe if err != nil { return nil, fmt.Errorf("error marshalling tool input for %v: %v", t.Name(), err) } - output, err := t.Action.RunJSON(ctx, mi, nil) + output, err := t.action.RunJSON(ctx, mi, nil) if err != nil { return nil, fmt.Errorf("error calling tool %v: %w", t.Name(), err) } @@ -394,6 +515,7 @@ func (t *tool) RunRawMultipart(ctx context.Context, input any) (*MultipartToolRe // LookupTool looks up the tool in the registry by provided name and returns it. // It checks for "tool.v2" first, then falls back to "tool" for legacy compatibility. +// Since the types are not known at lookup time, it returns a type-erased tool. func LookupTool(r api.Registry, name string) Tool { if name == "" { return nil @@ -423,17 +545,19 @@ func LookupTool(r api.Registry, name string) Tool { } } - return &tool{Action: action, multipart: multipart, registry: r} + return &ToolDef[any, any]{action: action, multipart: multipart, registry: r} } // IsMultipart returns true if the tool is a multipart tool (tool.v2 only). -func (t *tool) IsMultipart() bool { +func (t *ToolDef[In, Out]) IsMultipart() bool { return t.multipart } -// Respond creates a tool response for an interrupted tool call to pass to the [WithToolResponses] option to [Generate]. -// If the part provided is not a tool request, it returns nil. -func (t *tool) Respond(toolReq *Part, output any, opts *RespondOptions) *Part { +// Respond creates a part for [WithToolResponses] to provide a resolved response for an interrupted tool call. +// Returns nil if the part is not a tool request. +// +// Deprecated: Use [ToolDef.RespondWith] instead for strongly-typed options. +func (t *ToolDef[In, Out]) Respond(toolReq *Part, output any, opts *RespondOptions) *Part { if toolReq == nil || !toolReq.IsToolRequest() { return nil } @@ -453,9 +577,11 @@ func (t *tool) Respond(toolReq *Part, output any, opts *RespondOptions) *Part { return newToolResp } -// Restart creates a tool request for an interrupted tool call to pass to the [WithToolRestarts] option to [Generate]. -// If the part provided is not a tool request, it returns nil. -func (t *tool) Restart(p *Part, opts *RestartOptions) *Part { +// Restart creates a part for [WithToolRestarts] to re-execute an interrupted tool call with additional context. +// Returns nil if the part is not a tool request. +// +// Deprecated: Use [ToolDef.RestartWith] instead for strongly-typed options. +func (t *ToolDef[In, Out]) Restart(p *Part, opts *RestartOptions) *Part { if p == nil || !p.IsToolRequest() { return nil } @@ -498,6 +624,87 @@ func (t *tool) Restart(p *Part, opts *RestartOptions) *Part { return newToolReq } +// RespondWith creates a part for [WithToolResponses] to provide a resolved response for an interrupted tool call. +// Returns nil if the part is not a tool request. +// +// Example: +// +// part := myTool.RespondWith(toolReq, output, WithResponseMetadata[MyOutput](meta)) +func (t *ToolDef[In, Out]) RespondWith(toolReq *Part, output Out, opts ...RespondWithOption[Out]) *Part { + if toolReq == nil || !toolReq.IsToolRequest() { + return nil + } + + cfg := &RespondOptions{} + for _, opt := range opts { + if err := opt.applyRespondWith(cfg); err != nil { + panic(fmt.Errorf("ai.ToolDef.RespondWith: %w", err)) + } + } + + newToolResp := NewResponseForToolRequest(toolReq, output) + newToolResp.Metadata = map[string]any{ + "interruptResponse": true, + } + if cfg.Metadata != nil { + newToolResp.Metadata["interruptResponse"] = cfg.Metadata + } + + return newToolResp +} + +// RestartWith creates a part for [WithToolRestarts] to re-execute an interrupted tool call with additional context. +// Returns nil if the part is not a tool request. +// +// Example: +// +// part := myTool.RestartWith(toolReq, WithReplaceInput(newInput), WithResumedMetadata[MyInput](meta)) +func (t *ToolDef[In, Out]) RestartWith(toolReq *Part, opts ...RestartWithOption[In]) *Part { + if toolReq == nil || !toolReq.IsToolRequest() { + return nil + } + + cfg := &RestartOptions{} + for _, opt := range opts { + if err := opt.applyRestartWith(cfg); err != nil { + panic(fmt.Errorf("ai.ToolDef.RestartWith: %w", err)) + } + } + + newInput := toolReq.ToolRequest.Input + var originalInput any + + if cfg.ReplaceInput != nil { + originalInput = newInput + newInput = cfg.ReplaceInput + } + + newMeta := maps.Clone(toolReq.Metadata) + if newMeta == nil { + newMeta = make(map[string]any) + } + + newMeta["resumed"] = true + if cfg.ResumedMetadata != nil { + newMeta["resumed"] = cfg.ResumedMetadata + } + + if originalInput != nil { + newMeta["replacedInput"] = originalInput + } + + delete(newMeta, "interrupt") + + newToolReqPart := NewToolRequestPart(&ToolRequest{ + Name: toolReq.ToolRequest.Name, + Ref: toolReq.ToolRequest.Ref, + Input: newInput, + }) + newToolReqPart.Metadata = newMeta + + return newToolReqPart +} + // resolveUniqueTools resolves the list of tool refs to a list of all tool names and new tools that must be registered. // Returns an error if there are tool refs with duplicate names. func resolveUniqueTools(r api.Registry, toolRefs []ToolRef) (toolNames []string, newTools []Tool, err error) { diff --git a/go/ai/tools_test.go b/go/ai/tools_test.go index 857be8c2b2..5386601732 100644 --- a/go/ai/tools_test.go +++ b/go/ai/tools_test.go @@ -44,9 +44,18 @@ func TestToolName(t *testing.T) { } func TestToolInterruptError(t *testing.T) { - t.Run("Error returns fixed message", func(t *testing.T) { + t.Run("Error includes metadata when present", func(t *testing.T) { err := &toolInterruptError{Metadata: map[string]any{"key": "value"}} got := err.Error() + want := "tool execution interrupted: \n\n{\n \"key\": \"value\"\n}" + if got != want { + t.Errorf("Error() = %q, want %q", got, want) + } + }) + + t.Run("Error returns simple message when no metadata", func(t *testing.T) { + err := &toolInterruptError{} + got := err.Error() want := "tool execution interrupted" if got != want { t.Errorf("Error() = %q, want %q", got, want) @@ -653,14 +662,19 @@ func TestToolRestart(t *testing.T) { }) reqPart.Metadata = map[string]any{"interrupt": true} + newInputVal := struct { + Value int `json:"value"` + }{Value: 20} opts := &RestartOptions{ - ReplaceInput: map[string]any{"value": 20}, + ReplaceInput: newInputVal, } restart := tl.Restart(reqPart, opts) - newInput := restart.ToolRequest.Input.(map[string]any) - if newInput["value"] != 20 { - t.Errorf("new input value = %v, want 20", newInput["value"]) + newInput := restart.ToolRequest.Input.(struct { + Value int `json:"value"` + }) + if newInput.Value != 20 { + t.Errorf("new input value = %v, want 20", newInput.Value) } if restart.Metadata["replacedInput"] == nil { t.Error("replacedInput not set in metadata") @@ -858,9 +872,7 @@ func TestIsMultipart(t *testing.T) { return "result", nil }) - // IsMultipart is on the internal *tool type, so we need to type assert - internalTool := tl.(*tool) - if internalTool.IsMultipart() { + if tl.IsMultipart() { t.Error("IsMultipart() = true for standard tool, want false") } }) @@ -871,8 +883,7 @@ func TestIsMultipart(t *testing.T) { return "result", nil }) - internalTool := tl.(*tool) - if internalTool.IsMultipart() { + if tl.IsMultipart() { t.Error("IsMultipart() = true for NewTool, want false") } }) @@ -886,8 +897,7 @@ func TestIsMultipart(t *testing.T) { }, nil }) - internalTool := tl.(*tool) - if !internalTool.IsMultipart() { + if !tl.IsMultipart() { t.Error("IsMultipart() = false for multipart tool, want true") } }) @@ -900,9 +910,422 @@ func TestIsMultipart(t *testing.T) { }, nil }) - internalTool := tl.(*tool) - if !internalTool.IsMultipart() { + if !tl.IsMultipart() { t.Error("IsMultipart() = false for NewMultipartTool, want true") } }) } + +func TestToolContextIsResumed(t *testing.T) { + t.Run("returns false when Resumed is nil", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + Resumed: nil, + } + + if tc.IsResumed() { + t.Error("IsResumed() = true, want false") + } + }) + + t.Run("returns true when Resumed is set", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + Resumed: map[string]any{"step": "confirm"}, + } + + if !tc.IsResumed() { + t.Error("IsResumed() = false, want true") + } + }) + + t.Run("returns true even for empty map", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + Resumed: map[string]any{}, + } + + if !tc.IsResumed() { + t.Error("IsResumed() = false for empty map, want true") + } + }) +} + +func TestResumedValue(t *testing.T) { + t.Run("returns value when key exists and type matches", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + Resumed: map[string]any{ + "step": "confirmation", + "count": 42, + }, + } + + step, ok := ResumedValue[string](tc, "step") + if !ok { + t.Error("ResumedValue[string] ok = false, want true") + } + if step != "confirmation" { + t.Errorf("step = %q, want %q", step, "confirmation") + } + + count, ok := ResumedValue[int](tc, "count") + if !ok { + t.Error("ResumedValue[int] ok = false, want true") + } + if count != 42 { + t.Errorf("count = %d, want %d", count, 42) + } + }) + + t.Run("returns false when key does not exist", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + Resumed: map[string]any{"other": "value"}, + } + + val, ok := ResumedValue[string](tc, "missing") + if ok { + t.Error("ResumedValue ok = true for missing key, want false") + } + if val != "" { + t.Errorf("val = %q, want zero value", val) + } + }) + + t.Run("returns false when type does not match", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + Resumed: map[string]any{"count": "not a number"}, + } + + val, ok := ResumedValue[int](tc, "count") + if ok { + t.Error("ResumedValue ok = true for wrong type, want false") + } + if val != 0 { + t.Errorf("val = %d, want zero value", val) + } + }) + + t.Run("returns false when Resumed is nil", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + Resumed: nil, + } + + val, ok := ResumedValue[string](tc, "anything") + if ok { + t.Error("ResumedValue ok = true for nil Resumed, want false") + } + if val != "" { + t.Errorf("val = %q, want zero value", val) + } + }) + + t.Run("works with complex types", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + Resumed: map[string]any{ + "options": []string{"a", "b", "c"}, + "nested": map[string]any{"key": "value"}, + }, + } + + options, ok := ResumedValue[[]string](tc, "options") + if !ok { + t.Error("ResumedValue[[]string] ok = false, want true") + } + if len(options) != 3 { + t.Errorf("len(options) = %d, want 3", len(options)) + } + + nested, ok := ResumedValue[map[string]any](tc, "nested") + if !ok { + t.Error("ResumedValue[map[string]any] ok = false, want true") + } + if nested["key"] != "value" { + t.Errorf("nested[key] = %v, want %q", nested["key"], "value") + } + }) +} + +func TestOriginalInputAs(t *testing.T) { + type MyInput struct { + Query string `json:"query"` + Limit int `json:"limit"` + } + + t.Run("returns typed input when type matches", func(t *testing.T) { + original := MyInput{Query: "test", Limit: 10} + tc := &ToolContext{ + Context: context.Background(), + OriginalInput: original, + } + + input, ok := OriginalInputAs[MyInput](tc) + if !ok { + t.Error("OriginalInputAs ok = false, want true") + } + if input.Query != "test" { + t.Errorf("input.Query = %q, want %q", input.Query, "test") + } + if input.Limit != 10 { + t.Errorf("input.Limit = %d, want %d", input.Limit, 10) + } + }) + + t.Run("returns false when OriginalInput is nil", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + OriginalInput: nil, + } + + input, ok := OriginalInputAs[MyInput](tc) + if ok { + t.Error("OriginalInputAs ok = true for nil, want false") + } + if input.Query != "" || input.Limit != 0 { + t.Errorf("input = %+v, want zero value", input) + } + }) + + t.Run("returns false when type does not match", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + OriginalInput: "wrong type", + } + + input, ok := OriginalInputAs[MyInput](tc) + if ok { + t.Error("OriginalInputAs ok = true for wrong type, want false") + } + if input.Query != "" || input.Limit != 0 { + t.Errorf("input = %+v, want zero value", input) + } + }) + + t.Run("works with map type", func(t *testing.T) { + original := map[string]any{"query": "test", "limit": 10} + tc := &ToolContext{ + Context: context.Background(), + OriginalInput: original, + } + + input, ok := OriginalInputAs[map[string]any](tc) + if !ok { + t.Error("OriginalInputAs ok = false, want true") + } + if input["query"] != "test" { + t.Errorf("input[query] = %v, want %q", input["query"], "test") + } + }) + + t.Run("works with pointer types", func(t *testing.T) { + original := &MyInput{Query: "pointer", Limit: 5} + tc := &ToolContext{ + Context: context.Background(), + OriginalInput: original, + } + + input, ok := OriginalInputAs[*MyInput](tc) + if !ok { + t.Error("OriginalInputAs ok = false, want true") + } + if input.Query != "pointer" { + t.Errorf("input.Query = %q, want %q", input.Query, "pointer") + } + }) +} + +func TestToolContextInterruptMethod(t *testing.T) { + t.Run("interrupt with nil options", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + } + + err := tc.Interrupt(nil) + if err == nil { + t.Fatal("Interrupt(nil) = nil, want error") + } + + isInterrupt, meta := IsToolInterruptError(err) + if !isInterrupt { + t.Error("IsToolInterruptError() = false, want true") + } + if meta != nil { + t.Errorf("metadata = %v, want nil", meta) + } + }) + + t.Run("interrupt with empty options", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + } + + err := tc.Interrupt(&InterruptOptions{}) + if err == nil { + t.Fatal("Interrupt() = nil, want error") + } + + isInterrupt, meta := IsToolInterruptError(err) + if !isInterrupt { + t.Error("IsToolInterruptError() = false, want true") + } + if meta != nil { + t.Errorf("metadata = %v, want nil", meta) + } + }) + + t.Run("interrupt with metadata", func(t *testing.T) { + tc := &ToolContext{ + Context: context.Background(), + } + + err := tc.Interrupt(&InterruptOptions{ + Metadata: map[string]any{ + "step": "confirm", + "preview": "deleting files", + }, + }) + if err == nil { + t.Fatal("Interrupt() = nil, want error") + } + + isInterrupt, meta := IsToolInterruptError(err) + if !isInterrupt { + t.Error("IsToolInterruptError() = false, want true") + } + if meta["step"] != "confirm" { + t.Errorf("meta[step] = %v, want %q", meta["step"], "confirm") + } + if meta["preview"] != "deleting files" { + t.Errorf("meta[preview] = %v, want %q", meta["preview"], "deleting files") + } + }) +} + +func TestInterruptFor(t *testing.T) { + type ConfirmMeta struct { + Reason string `json:"reason"` + Amount float64 `json:"amount"` + Recipient string `json:"recipient"` + } + + t.Run("creates interrupt with typed metadata", func(t *testing.T) { + tc := &ToolContext{Context: context.Background()} + + err := InterruptWith(tc, ConfirmMeta{ + Reason: "new recipient", + Amount: 50.0, + Recipient: "Alice", + }) + + if err == nil { + t.Fatal("InterruptFor() = nil, want error") + } + + isInterrupt, meta := IsToolInterruptError(err) + if !isInterrupt { + t.Error("IsToolInterruptError() = false, want true") + } + if meta["reason"] != "new recipient" { + t.Errorf("meta[reason] = %v, want %q", meta["reason"], "new recipient") + } + if meta["amount"] != 50.0 { + t.Errorf("meta[amount] = %v, want %v", meta["amount"], 50.0) + } + if meta["recipient"] != "Alice" { + t.Errorf("meta[recipient] = %v, want %q", meta["recipient"], "Alice") + } + }) + + t.Run("handles nested structs", func(t *testing.T) { + type Nested struct { + Inner struct { + Value string `json:"value"` + } `json:"inner"` + } + + tc := &ToolContext{Context: context.Background()} + err := InterruptWith(tc, Nested{Inner: struct { + Value string `json:"value"` + }{Value: "test"}}) + + isInterrupt, meta := IsToolInterruptError(err) + if !isInterrupt { + t.Error("IsToolInterruptError() = false, want true") + } + inner, ok := meta["inner"].(map[string]any) + if !ok { + t.Fatal("meta[inner] is not a map") + } + if inner["value"] != "test" { + t.Errorf("inner[value] = %v, want %q", inner["value"], "test") + } + }) +} + +func TestInterruptMetadata(t *testing.T) { + type ConfirmMeta struct { + Reason string `json:"reason"` + Amount float64 `json:"amount"` + Recipient string `json:"recipient"` + } + + t.Run("extracts typed metadata from interrupt part", func(t *testing.T) { + part := NewToolRequestPart(&ToolRequest{ + Name: "testTool", + Input: map[string]any{}, + }) + part.Metadata = map[string]any{ + "interrupt": map[string]any{ + "reason": "large amount", + "amount": 200.0, + "recipient": "Bob", + }, + } + + meta, ok := InterruptAs[ConfirmMeta](part) + if !ok { + t.Fatal("InterruptMetadata() ok = false, want true") + } + if meta.Reason != "large amount" { + t.Errorf("meta.Reason = %q, want %q", meta.Reason, "large amount") + } + if meta.Amount != 200.0 { + t.Errorf("meta.Amount = %v, want %v", meta.Amount, 200.0) + } + if meta.Recipient != "Bob" { + t.Errorf("meta.Recipient = %q, want %q", meta.Recipient, "Bob") + } + }) + + t.Run("returns false for non-interrupt part", func(t *testing.T) { + part := NewTextPart("not an interrupt") + + _, ok := InterruptAs[ConfirmMeta](part) + if ok { + t.Error("InterruptMetadata() ok = true for text part, want false") + } + }) + + t.Run("returns false for nil part", func(t *testing.T) { + _, ok := InterruptAs[ConfirmMeta](nil) + if ok { + t.Error("InterruptMetadata() ok = true for nil, want false") + } + }) + + t.Run("returns false when interrupt metadata is not a map", func(t *testing.T) { + part := NewToolRequestPart(&ToolRequest{Name: "test"}) + part.Metadata = map[string]any{ + "interrupt": true, // bool instead of map + } + + _, ok := InterruptAs[ConfirmMeta](part) + if ok { + t.Error("InterruptMetadata() ok = true for bool interrupt, want false") + } + }) +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 8123debf37..24a71cb3b7 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -546,12 +546,12 @@ func LookupBackgroundModel(g *Genkit, name string) ai.BackgroundModel { // } // // fmt.Println(resp.Text()) // Might output something like "The weather in Paris is Sunny, 25°C." -func DefineTool[In, Out any](g *Genkit, name, description string, fn ai.ToolFunc[In, Out], opts ...ai.ToolOption) ai.Tool { +func DefineTool[In, Out any](g *Genkit, name, description string, fn ai.ToolFunc[In, Out], opts ...ai.ToolOption) *ai.ToolDef[In, Out] { return ai.DefineTool(g.reg, name, description, fn, opts...) } // DefineToolWithInputSchema defines a tool with a custom input schema that can be used by models during generation, -// registers it as a [core.Action] of type Tool, and returns an [ai.Tool]. +// registers it as a [core.Action] of type Tool, and returns an [*ai.ToolDef]. // // This variant of [DefineTool] allows specifying a JSON Schema for the tool's input, providing more // control over input validation and model guidance. The input parameter to the tool function will be @@ -595,12 +595,12 @@ func DefineTool[In, Out any](g *Genkit, name, description string, fn ai.ToolFunc // }, // ai.WithToolInputSchema(inputSchema), // ) -func DefineToolWithInputSchema[Out any](g *Genkit, name, description string, inputSchema map[string]any, fn ai.ToolFunc[any, Out]) ai.Tool { +func DefineToolWithInputSchema[Out any](g *Genkit, name, description string, inputSchema map[string]any, fn ai.ToolFunc[any, Out]) *ai.ToolDef[any, Out] { return ai.DefineTool(g.reg, name, description, fn, ai.WithInputSchema(inputSchema)) } // DefineMultipartTool defines a multipart tool that can be used by models during generation, -// registers it as a [core.Action] of type Tool, and returns an [ai.Tool]. +// registers it as a [core.Action] of type Tool, and returns an [*ai.ToolDef]. // Unlike regular tools that return just an output value, multipart tools can return // both an output value and additional content parts (like images or other media). // @@ -649,13 +649,14 @@ func DefineToolWithInputSchema[Out any](g *Genkit, name, description string, inp // } // // fmt.Println(resp.Text()) -func DefineMultipartTool[In any](g *Genkit, name, description string, fn ai.MultipartToolFunc[In], opts ...ai.ToolOption) ai.Tool { +func DefineMultipartTool[In any](g *Genkit, name, description string, fn ai.MultipartToolFunc[In], opts ...ai.ToolOption) *ai.ToolDef[In, *ai.MultipartToolResponse] { return ai.DefineMultipartTool(g.reg, name, description, fn, opts...) } -// LookupTool retrieves a registered [ai.Tool] by its name. +// LookupTool retrieves a registered tool by its name. // It returns the tool instance if found, or `nil` if no tool with the // given name is registered (e.g., via [DefineTool]). +// Since the types are not known at lookup time, it returns a type-erased tool. func LookupTool(g *Genkit, name string) ai.Tool { return ai.LookupTool(g.reg, name) } diff --git a/go/internal/base/json.go b/go/internal/base/json.go index 4f413ab8f7..82fef0f07b 100644 --- a/go/internal/base/json.go +++ b/go/internal/base/json.go @@ -118,6 +118,32 @@ func InferJSONSchema(x any) (s *jsonschema.Schema) { return s } +// MapToStruct converts a map[string]any to a struct of type T via JSON round-trip. +func MapToStruct[T any](m map[string]any) (T, error) { + var result T + data, err := json.Marshal(m) + if err != nil { + return result, err + } + if err := json.Unmarshal(data, &result); err != nil { + return result, err + } + return result, nil +} + +// StructToMap converts a struct to map[string]any via JSON round-trip. +func StructToMap[T any](v T) (map[string]any, error) { + data, err := json.Marshal(v) + if err != nil { + return nil, err + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return nil, err + } + return m, nil +} + // SchemaAsMap converts json schema struct to a map (JSON representation). func SchemaAsMap(s *jsonschema.Schema) map[string]any { jsb, err := s.MarshalJSON() diff --git a/go/plugins/compat_oai/generate.go b/go/plugins/compat_oai/generate.go index fd5050a291..acc7865582 100644 --- a/go/plugins/compat_oai/generate.go +++ b/go/plugins/compat_oai/generate.go @@ -20,20 +20,12 @@ import ( "fmt" "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal/base" "github.com/openai/openai-go" "github.com/openai/openai-go/packages/param" "github.com/openai/openai-go/shared" ) -// mapToStruct unmarshals a map[string]any to the expected config api. -func mapToStruct(m map[string]any, v any) error { - jsonData, err := json.Marshal(m) - if err != nil { - return err - } - return json.Unmarshal(jsonData, v) -} - // ModelGenerator handles OpenAI generation requests type ModelGenerator struct { client *openai.Client @@ -163,7 +155,9 @@ func (g *ModelGenerator) WithConfig(config any) *ModelGenerator { case *openai.ChatCompletionNewParams: openaiConfig = *cfg case map[string]any: - if err := mapToStruct(cfg, &openaiConfig); err != nil { + var err error + openaiConfig, err = base.MapToStruct[openai.ChatCompletionNewParams](cfg) + if err != nil { g.err = fmt.Errorf("failed to convert config to openai.ChatCompletionNewParams: %w", err) return g } diff --git a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go index c3ca5297b0..9e09b45ae4 100644 --- a/go/plugins/googlegenai/gemini.go +++ b/go/plugins/googlegenai/gemini.go @@ -99,14 +99,6 @@ func configToMap(config any) map[string]any { return result } -// mapToStruct unmarshals a map[string]any to the expected config api. -func mapToStruct(m map[string]any, v any) error { - jsonData, err := json.Marshal(m) - if err != nil { - return err - } - return json.Unmarshal(jsonData, v) -} // configFromRequest converts any supported config type to [genai.GenerateContentConfig]. func configFromRequest(input *ai.ModelRequest) (*genai.GenerateContentConfig, error) { @@ -119,7 +111,9 @@ func configFromRequest(input *ai.ModelRequest) (*genai.GenerateContentConfig, er result = *config case map[string]any: // TODO: Log warnings if unknown parameters are found. - if err := mapToStruct(config, &result); err != nil { + var err error + result, err = base.MapToStruct[genai.GenerateContentConfig](config) + if err != nil { return nil, err } case nil: diff --git a/go/plugins/googlegenai/imagen.go b/go/plugins/googlegenai/imagen.go index aabe7ec490..6003494a30 100644 --- a/go/plugins/googlegenai/imagen.go +++ b/go/plugins/googlegenai/imagen.go @@ -22,6 +22,7 @@ import ( "fmt" "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal/base" "google.golang.org/genai" ) @@ -45,7 +46,9 @@ func imagenConfigFromRequest(input *ai.ModelRequest) (*genai.GenerateImagesConfi case *genai.GenerateImagesConfig: result = *config case map[string]any: - if err := mapToStruct(config, &result); err != nil { + var err error + result, err = base.MapToStruct[genai.GenerateImagesConfig](config) + if err != nil { return nil, err } case nil: diff --git a/go/plugins/internal/anthropic/anthropic.go b/go/plugins/internal/anthropic/anthropic.go index 3d68a5219d..b856223d0f 100644 --- a/go/plugins/internal/anthropic/anthropic.go +++ b/go/plugins/internal/anthropic/anthropic.go @@ -26,6 +26,7 @@ import ( "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/internal/base" "github.com/firebase/genkit/go/plugins/internal/uri" "github.com/anthropics/anthropic-sdk-go" @@ -195,14 +196,6 @@ func toAnthropicRequest(i *ai.ModelRequest) (*anthropic.MessageNewParams, error) return req, nil } -// mapToStruct unmarshals a map[string]any to the expected type -func mapToStruct(m map[string]any, v any) error { - jsonData, err := json.Marshal(m) - if err != nil { - return err - } - return json.Unmarshal(jsonData, v) -} // configFromRequest converts any supported config type to [anthropic.MessageNewParams] func configFromRequest(input *ai.ModelRequest) (*anthropic.MessageNewParams, error) { @@ -214,7 +207,9 @@ func configFromRequest(input *ai.ModelRequest) (*anthropic.MessageNewParams, err case *anthropic.MessageNewParams: result = *config case map[string]any: - if err := mapToStruct(config, &result); err != nil { + var err error + result, err = base.MapToStruct[anthropic.MessageNewParams](config) + if err != nil { return nil, err } case nil: @@ -243,7 +238,9 @@ func toAnthropicTools(tools []*ai.ToolDefinition) ([]anthropic.ToolUnionParam, e if len(inputSchema) == 0 { inputSchema = map[string]any{"type": "object", "properties": map[string]any{}} } - if err := mapToStruct(inputSchema, &schema); err != nil { + var err error + schema, err = base.MapToStruct[anthropic.ToolInputSchemaParam](inputSchema) + if err != nil { return nil, fmt.Errorf("unable to parse tool input schema: %w", err) } diff --git a/go/samples/basic/main.go b/go/samples/basic/main.go index 2031340ac5..c76b64ce62 100644 --- a/go/samples/basic/main.go +++ b/go/samples/basic/main.go @@ -50,8 +50,7 @@ func main() { })), ai.WithPrompt("Share a joke about %s.", input), ) - }, - ) + }) // Define a streaming flow that generates jokes about a given topic with passthrough streaming. genkit.DefineStreamingFlow(g, "streamingJokesFlow", From a172696ea24d6d24ebca5cb5c84e76fa6239e34c Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 12 Jan 2026 13:54:08 -0800 Subject: [PATCH 2/9] Update tools.go --- go/ai/tools.go | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/go/ai/tools.go b/go/ai/tools.go index 33a2c712ed..454f4d3f4d 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -625,20 +625,25 @@ func (t *ToolDef[In, Out]) Restart(p *Part, opts *RestartOptions) *Part { } // RespondWith creates a part for [WithToolResponses] to provide a resolved response for an interrupted tool call. -// Returns nil if the part is not a tool request. // // Example: // -// part := myTool.RespondWith(toolReq, output, WithResponseMetadata[MyOutput](meta)) -func (t *ToolDef[In, Out]) RespondWith(toolReq *Part, output Out, opts ...RespondWithOption[Out]) *Part { - if toolReq == nil || !toolReq.IsToolRequest() { - return nil +// part, err := myTool.RespondWith(toolReq, output, WithResponseMetadata[MyOutput](meta)) +func (t *ToolDef[In, Out]) RespondWith(toolReq *Part, output Out, opts ...RespondWithOption[Out]) (*Part, error) { + if toolReq == nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.RespondWith: toolReq is nil") + } + if !toolReq.IsToolRequest() { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.RespondWith: part is not a tool request") + } + if toolReq.ToolRequest.Name != t.Name() { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.RespondWith: tool request is for %q, not %q", toolReq.ToolRequest.Name, t.Name()) } cfg := &RespondOptions{} for _, opt := range opts { if err := opt.applyRespondWith(cfg); err != nil { - panic(fmt.Errorf("ai.ToolDef.RespondWith: %w", err)) + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.RespondWith: %v", err) } } @@ -650,24 +655,29 @@ func (t *ToolDef[In, Out]) RespondWith(toolReq *Part, output Out, opts ...Respon newToolResp.Metadata["interruptResponse"] = cfg.Metadata } - return newToolResp + return newToolResp, nil } // RestartWith creates a part for [WithToolRestarts] to re-execute an interrupted tool call with additional context. -// Returns nil if the part is not a tool request. // // Example: // -// part := myTool.RestartWith(toolReq, WithReplaceInput(newInput), WithResumedMetadata[MyInput](meta)) -func (t *ToolDef[In, Out]) RestartWith(toolReq *Part, opts ...RestartWithOption[In]) *Part { - if toolReq == nil || !toolReq.IsToolRequest() { - return nil +// part, err := myTool.RestartWith(toolReq, WithReplaceInput(newInput), WithResumedMetadata[MyInput](meta)) +func (t *ToolDef[In, Out]) RestartWith(toolReq *Part, opts ...RestartWithOption[In]) (*Part, error) { + if toolReq == nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.RestartWith: toolReq is nil") + } + if !toolReq.IsToolRequest() { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.RestartWith: part is not a tool request") + } + if toolReq.ToolRequest.Name != t.Name() { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.RestartWith: tool request is for %q, not %q", toolReq.ToolRequest.Name, t.Name()) } cfg := &RestartOptions{} for _, opt := range opts { if err := opt.applyRestartWith(cfg); err != nil { - panic(fmt.Errorf("ai.ToolDef.RestartWith: %w", err)) + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.RestartWith: %v", err) } } @@ -702,7 +712,7 @@ func (t *ToolDef[In, Out]) RestartWith(toolReq *Part, opts ...RestartWithOption[ }) newToolReqPart.Metadata = newMeta - return newToolReqPart + return newToolReqPart, nil } // resolveUniqueTools resolves the list of tool refs to a list of all tool names and new tools that must be registered. From f1a9c4d523b173b214b4160d8564f53fe8ee3dee Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 12 Jan 2026 15:23:37 -0800 Subject: [PATCH 3/9] Added intermediate tool interrupts sample. --- go/ai/generate.go | 2 +- go/ai/tools.go | 15 +- go/plugins/googlegenai/googlegenai.go | 2 +- go/samples/basic-prompts/main.go | 6 +- go/samples/basic-structured/main.go | 6 +- go/samples/basic/main.go | 4 +- go/samples/intermediate-interrupts/main.go | 227 +++++++++++++++++++++ 7 files changed, 250 insertions(+), 12 deletions(-) create mode 100644 go/samples/intermediate-interrupts/main.go diff --git a/go/ai/generate.go b/go/ai/generate.go index 73208da40f..0327627b24 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -1190,7 +1190,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene } } } - if originalInputVal, ok := restartPart.Metadata["originalInput"]; ok { + if originalInputVal, ok := restartPart.Metadata["replacedInput"]; ok { resumedCtx = origInputCtxKey.NewContext(resumedCtx, originalInputVal) } diff --git a/go/ai/tools.go b/go/ai/tools.go index 454f4d3f4d..48b907bde6 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -258,8 +258,19 @@ func OriginalInputAs[T any](tc *ToolContext) (T, bool) { if tc.OriginalInput == nil { return zero, false } - typed, ok := tc.OriginalInput.(T) - return typed, ok + // Try direct type assertion first (for when input is already typed) + if typed, ok := tc.OriginalInput.(T); ok { + return typed, ok + } + // Otherwise try to convert from map[string]any (common case from JSON) + if m, ok := tc.OriginalInput.(map[string]any); ok { + result, err := base.MapToStruct[T](m) + if err != nil { + return zero, false + } + return result, true + } + return zero, false } // DefineTool creates a new [ToolDef] and registers it. diff --git a/go/plugins/googlegenai/googlegenai.go b/go/plugins/googlegenai/googlegenai.go index 8ddbdfbe4a..04f361f74b 100644 --- a/go/plugins/googlegenai/googlegenai.go +++ b/go/plugins/googlegenai/googlegenai.go @@ -285,7 +285,7 @@ func (v *VertexAI) IsDefinedEmbedder(g *genkit.Genkit, name string) bool { // ModelRef creates a new ModelRef for a Google Gen AI model with the given name and configuration. func ModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { - return ai.NewModelRef(googleAIProvider+"/"+name, config) + return ai.NewModelRef(name, config) } // GoogleAIModelRef creates a new ModelRef for a Google AI model with the given ID and configuration. diff --git a/go/samples/basic-prompts/main.go b/go/samples/basic-prompts/main.go index ccbc308a13..3147542e0c 100644 --- a/go/samples/basic-prompts/main.go +++ b/go/samples/basic-prompts/main.go @@ -103,7 +103,7 @@ func main() { func DefineSimpleJokeWithInlinePrompt(g *genkit.Genkit) { jokePrompt := genkit.DefinePrompt( g, "joke.code", - ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ai.WithModel(googlegenai.ModelRef("googleai/gemini-2.5-flash", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ ThinkingBudget: genai.Ptr[int32](0), }, @@ -162,7 +162,7 @@ func DefineSimpleJokeWithDotprompt(g *genkit.Genkit) { func DefineStructuredJokeWithInlinePrompt(g *genkit.Genkit) { jokePrompt := genkit.DefineDataPrompt[JokeRequest, *Joke]( g, "structured-joke.code", - ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ai.WithModel(googlegenai.ModelRef("googleai/gemini-2.5-flash", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ ThinkingBudget: genai.Ptr[int32](0), }, @@ -215,7 +215,7 @@ func DefineStructuredJokeWithDotprompt(g *genkit.Genkit) { func DefineRecipeWithInlinePrompt(g *genkit.Genkit) { recipePrompt := genkit.DefineDataPrompt[RecipeRequest, *Recipe]( g, "recipe.code", - ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ai.WithModel(googlegenai.ModelRef("googleai/gemini-2.5-flash", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ ThinkingBudget: genai.Ptr[int32](0), }, diff --git a/go/samples/basic-structured/main.go b/go/samples/basic-structured/main.go index 428636de4d..0c47507d96 100644 --- a/go/samples/basic-structured/main.go +++ b/go/samples/basic-structured/main.go @@ -89,7 +89,7 @@ func DefineSimpleJoke(g *genkit.Genkit) { genkit.DefineStreamingFlow(g, "simpleJokesFlow", func(ctx context.Context, input string, sendChunk core.StreamCallback[string]) (string, error) { stream := genkit.GenerateStream(ctx, g, - ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ai.WithModel(googlegenai.ModelRef("googleai/gemini-2.5-flash", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ ThinkingBudget: genai.Ptr[int32](0), }, @@ -118,7 +118,7 @@ func DefineStructuredJoke(g *genkit.Genkit) { genkit.DefineStreamingFlow(g, "structuredJokesFlow", func(ctx context.Context, input JokeRequest, sendChunk core.StreamCallback[*Joke]) (*Joke, error) { stream := genkit.GenerateDataStream[*Joke](ctx, g, - ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ai.WithModel(googlegenai.ModelRef("googleai/gemini-2.5-flash", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ ThinkingBudget: genai.Ptr[int32](0), }, @@ -146,7 +146,7 @@ func DefineRecipe(g *genkit.Genkit) { genkit.DefineStreamingFlow(g, "recipeFlow", func(ctx context.Context, input RecipeRequest, sendChunk core.StreamCallback[[]*Ingredient]) (*Recipe, error) { stream := genkit.GenerateDataStream[*Recipe](ctx, g, - ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ai.WithModel(googlegenai.ModelRef("googleai/gemini-2.5-flash", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ ThinkingBudget: genai.Ptr[int32](0), }, diff --git a/go/samples/basic/main.go b/go/samples/basic/main.go index c76b64ce62..9cf3952ea9 100644 --- a/go/samples/basic/main.go +++ b/go/samples/basic/main.go @@ -43,7 +43,7 @@ func main() { } return genkit.GenerateText(ctx, g, - ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ai.WithModel(googlegenai.ModelRef("googleai/gemini-2.5-flash", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ ThinkingBudget: genai.Ptr[int32](0), }, @@ -60,7 +60,7 @@ func main() { } resp, err := genkit.Generate(ctx, g, - ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ai.WithModel(googlegenai.ModelRef("googleai/gemini-2.5-flash", &genai.GenerateContentConfig{ ThinkingConfig: &genai.ThinkingConfig{ ThinkingBudget: genai.Ptr[int32](0), }, diff --git a/go/samples/intermediate-interrupts/main.go b/go/samples/intermediate-interrupts/main.go new file mode 100644 index 0000000000..3c90e60799 --- /dev/null +++ b/go/samples/intermediate-interrupts/main.go @@ -0,0 +1,227 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tool-interrupts demonstrates the tool interrupts feature in Genkit. +// It shows how to pause generation for human-in-the-loop interactions +// and resume with user input using RestartWith and RespondWith. +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strconv" + "strings" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "google.golang.org/genai" +) + +// TransferInput is the input schema for the transferMoney tool. +type TransferInput struct { + ToAccount string `json:"toAccount" jsonschema:"description=destination account ID"` + Amount float64 `json:"amount" jsonschema:"description=amount in dollars (e.g. 50.00 for $50)"` +} + +// TransferOutput is the output schema for the transferMoney tool. +type TransferOutput struct { + Status string `json:"status"` + Message string `json:"message,omitempty"` + NewBalance float64 `json:"newBalance,omitempty"` +} + +// TransferInterrupt is the typed interrupt metadata for transfer issues. +type TransferInterrupt struct { + Reason string `json:"reason"` // "insufficient_balance" or "confirm_large" + ToAccount string `json:"toAccount"` + Amount float64 `json:"amount"` + Balance float64 `json:"balance,omitempty"` +} + +var accountBalance = 150.00 + +func main() { + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + reader := bufio.NewReader(os.Stdin) + + // Define the transfer tool with interrupt logic + transferMoney := genkit.DefineTool(g, "transferMoney", + "Transfers money to another account. Use this when the user wants to send money.", + func(ctx *ai.ToolContext, input TransferInput) (TransferOutput, error) { + if input.Amount > accountBalance { + if accountBalance <= 0 { + return TransferOutput{"rejected", "Account balance is 0. Please add funds.", accountBalance}, nil + } + return TransferOutput{}, ai.InterruptWith(ctx, TransferInterrupt{ + Reason: "insufficient_balance", + ToAccount: input.ToAccount, + Amount: input.Amount, + Balance: accountBalance, + }) + } + + if !ctx.IsResumed() && input.Amount > 100 { + return TransferOutput{}, ai.InterruptWith(ctx, TransferInterrupt{ + Reason: "confirm_large", + ToAccount: input.ToAccount, + Amount: input.Amount, + }) + } + + accountBalance -= input.Amount + message := fmt.Sprintf("Transferred $%.2f to %s", input.Amount, input.ToAccount) + if orig, ok := ai.OriginalInputAs[TransferInput](ctx); ok { + message = fmt.Sprintf("Transferred $%.2f to %s (adjusted from $%.2f due to insufficient balance)", input.Amount, input.ToAccount, orig.Amount) + } + + return TransferOutput{"completed", message, accountBalance}, nil + }) + + // Define the payment agent flow + paymentAgent := genkit.DefineFlow(g, "paymentAgent", func(ctx context.Context, request string) (string, error) { + resp, err := genkit.Generate(ctx, g, + ai.WithModel(googlegenai.ModelRef("googleai/gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ThinkingBudget: genai.Ptr[int32](0)}, + })), + ai.WithSystem("You are a helpful payment assistant. When the user wants to transfer money, use the transferMoney tool. Always confirm the result with the user."), + ai.WithPrompt(request), + ai.WithTools(transferMoney), + ) + if err != nil { + return "", err + } + + for resp.FinishReason == ai.FinishReasonInterrupted { + var restarts, responses []*ai.Part + + for _, interrupt := range resp.Interrupts() { + meta, ok := ai.InterruptAs[TransferInterrupt](interrupt) + if !ok { + continue + } + + switch meta.Reason { + case "insufficient_balance": + fmt.Printf("\n[Insufficient Balance] You requested $%.2f but only have $%.2f\n", + meta.Amount, meta.Balance) + fmt.Printf("Options: [1] Transfer $%.2f instead [2] Cancel\n", meta.Balance) + fmt.Print("Choice: ") + + if promptChoice(reader, 1, 2) == 1 { + // RestartWith + WithReplaceInput: Retry with adjusted amount + part, err := transferMoney.RestartWith(interrupt, + ai.WithReplaceInput(TransferInput{meta.ToAccount, meta.Balance})) + if err != nil { + return "", fmt.Errorf("RestartWith: %w", err) + } + restarts = append(restarts, part) + } else { + // RespondWith: Provide cancelled output directly + part, err := transferMoney.RespondWith(interrupt, + TransferOutput{"cancelled", "Transfer cancelled by user.", accountBalance}) + if err != nil { + return "", fmt.Errorf("RespondWith: %w", err) + } + responses = append(responses, part) + } + + case "confirm_large": + fmt.Printf("\n[Confirm Large Transfer] Send $%.2f to %s? (yes/no): ", + meta.Amount, meta.ToAccount) + + if promptYesNo(reader) { + // RestartWith: Re-execute the tool with approval + part, err := transferMoney.RestartWith(interrupt) + if err != nil { + return "", fmt.Errorf("RestartWith: %w", err) + } + restarts = append(restarts, part) + } else { + // RespondWith: Provide cancelled output directly + part, err := transferMoney.RespondWith(interrupt, + TransferOutput{"cancelled", "Transfer cancelled by user.", accountBalance}) + if err != nil { + return "", fmt.Errorf("RespondWith: %w", err) + } + responses = append(responses, part) + } + } + } + + resp, err = genkit.Generate(ctx, g, + ai.WithModel(googlegenai.ModelRef("googleai/gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ThinkingBudget: genai.Ptr[int32](0)}, + })), + ai.WithMessages(resp.History()...), + ai.WithTools(transferMoney), + ai.WithToolRestarts(restarts...), + ai.WithToolResponses(responses...), + ) + if err != nil { + return "", err + } + } + + return resp.Text(), nil + }) + + fmt.Println("Payment Agent - Tool Interrupts Demo") + fmt.Printf("Balance: $%.2f\n", accountBalance) + fmt.Println("Type 'quit' to exit.") + fmt.Println() + + for { + fmt.Print("> ") + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(input) + + if input == "quit" || input == "exit" { + break + } + if input == "" { + continue + } + + result, err := paymentAgent.Run(ctx, input) + if err != nil { + fmt.Printf("Error: %v\n\n", err) + continue + } + fmt.Printf("\n%s\n\n", result) + } +} + +// promptChoice reads a number choice from stdin. +func promptChoice(reader *bufio.Reader, min, max int) int { + for { + text, _ := reader.ReadString('\n') + text = strings.TrimSpace(text) + n, err := strconv.Atoi(text) + if err == nil && n >= min && n <= max { + return n + } + fmt.Printf("Please enter a number between %d and %d: ", min, max) + } +} + +// promptYesNo reads a yes/no response from stdin. +func promptYesNo(reader *bufio.Reader) bool { + text, _ := reader.ReadString('\n') + text = strings.ToLower(strings.TrimSpace(text)) + return text == "yes" || text == "y" +} From 715698121d9bde2201820ab4d3cc70072d0eb84d Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 12 Jan 2026 15:47:41 -0800 Subject: [PATCH 4/9] Update main.go --- go/samples/intermediate-interrupts/main.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/go/samples/intermediate-interrupts/main.go b/go/samples/intermediate-interrupts/main.go index 3c90e60799..bcc89ac3a4 100644 --- a/go/samples/intermediate-interrupts/main.go +++ b/go/samples/intermediate-interrupts/main.go @@ -68,18 +68,13 @@ func main() { return TransferOutput{"rejected", "Account balance is 0. Please add funds.", accountBalance}, nil } return TransferOutput{}, ai.InterruptWith(ctx, TransferInterrupt{ - Reason: "insufficient_balance", - ToAccount: input.ToAccount, - Amount: input.Amount, - Balance: accountBalance, + "insufficient_balance", input.ToAccount, input.Amount, accountBalance, }) } if !ctx.IsResumed() && input.Amount > 100 { return TransferOutput{}, ai.InterruptWith(ctx, TransferInterrupt{ - Reason: "confirm_large", - ToAccount: input.ToAccount, - Amount: input.Amount, + "confirm_large", input.ToAccount, input.Amount, accountBalance, }) } From 6c909cd583daecbcce12970e1e6644ae6d48a080 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 12 Jan 2026 15:49:13 -0800 Subject: [PATCH 5/9] Update main.go --- go/samples/intermediate-interrupts/main.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/go/samples/intermediate-interrupts/main.go b/go/samples/intermediate-interrupts/main.go index bcc89ac3a4..9eb6ab7c79 100644 --- a/go/samples/intermediate-interrupts/main.go +++ b/go/samples/intermediate-interrupts/main.go @@ -112,8 +112,7 @@ func main() { switch meta.Reason { case "insufficient_balance": - fmt.Printf("\n[Insufficient Balance] You requested $%.2f but only have $%.2f\n", - meta.Amount, meta.Balance) + fmt.Printf("\n[Insufficient Balance] You requested $%.2f but only have $%.2f\n", meta.Amount, meta.Balance) fmt.Printf("Options: [1] Transfer $%.2f instead [2] Cancel\n", meta.Balance) fmt.Print("Choice: ") @@ -136,8 +135,7 @@ func main() { } case "confirm_large": - fmt.Printf("\n[Confirm Large Transfer] Send $%.2f to %s? (yes/no): ", - meta.Amount, meta.ToAccount) + fmt.Printf("\n[Confirm Large Transfer] Send $%.2f to %s? (yes/no): ", meta.Amount, meta.ToAccount) if promptYesNo(reader) { // RestartWith: Re-execute the tool with approval From a92864add7f45a99ddab4eb38604267c0fd2a209 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 12 Jan 2026 16:54:58 -0800 Subject: [PATCH 6/9] Changed option name. --- go/ai/tools.go | 8 ++++---- go/samples/intermediate-interrupts/main.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go/ai/tools.go b/go/ai/tools.go index 48b907bde6..7da2618c43 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -158,7 +158,7 @@ type RestartWithOption[In any] interface { func (o *RestartOptions) applyRestartWith(opts *RestartOptions) error { if o.ReplaceInput != nil { if opts.ReplaceInput != nil { - return errors.New("cannot set replace input more than once (WithReplaceInput)") + return errors.New("cannot set new input more than once (WithNewInput)") } opts.ReplaceInput = o.ReplaceInput } @@ -171,8 +171,8 @@ func (o *RestartOptions) applyRestartWith(opts *RestartOptions) error { return nil } -// WithReplaceInput sets a new input value to replace the original tool request input. -func WithReplaceInput[In any](input In) RestartWithOption[In] { +// WithNewInput sets a new input value to replace the original tool request input. +func WithNewInput[In any](input In) RestartWithOption[In] { return &RestartOptions{ReplaceInput: input} } @@ -673,7 +673,7 @@ func (t *ToolDef[In, Out]) RespondWith(toolReq *Part, output Out, opts ...Respon // // Example: // -// part, err := myTool.RestartWith(toolReq, WithReplaceInput(newInput), WithResumedMetadata[MyInput](meta)) +// part, err := myTool.RestartWith(toolReq, WithNewInput(newInput), WithResumedMetadata[MyInput](meta)) func (t *ToolDef[In, Out]) RestartWith(toolReq *Part, opts ...RestartWithOption[In]) (*Part, error) { if toolReq == nil { return nil, core.NewError(core.INVALID_ARGUMENT, "ai.RestartWith: toolReq is nil") diff --git a/go/samples/intermediate-interrupts/main.go b/go/samples/intermediate-interrupts/main.go index 9eb6ab7c79..5dd2a697a0 100644 --- a/go/samples/intermediate-interrupts/main.go +++ b/go/samples/intermediate-interrupts/main.go @@ -117,9 +117,9 @@ func main() { fmt.Print("Choice: ") if promptChoice(reader, 1, 2) == 1 { - // RestartWith + WithReplaceInput: Retry with adjusted amount + // RestartWith + WithNewInput: Retry with adjusted amount part, err := transferMoney.RestartWith(interrupt, - ai.WithReplaceInput(TransferInput{meta.ToAccount, meta.Balance})) + ai.WithNewInput(TransferInput{meta.ToAccount, meta.Balance})) if err != nil { return "", fmt.Errorf("RestartWith: %w", err) } From b66f5b221b9009facca2a49131b2a5fddd480857 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 13 Jan 2026 07:58:10 -0800 Subject: [PATCH 7/9] Update tools.go --- go/ai/tools.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/go/ai/tools.go b/go/ai/tools.go index 7da2618c43..453ad779a3 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -177,7 +177,8 @@ func WithNewInput[In any](input In) RestartWithOption[In] { } // WithResumedMetadata sets metadata to pass to the resumed tool execution. -func WithResumedMetadata[In any](meta any) RestartWithOption[In] { +// The metadata will be available in the tool's [ToolContext.Resumed] field. +func WithResumedMetadata[In any](meta map[string]any) RestartWithOption[In] { return &RestartOptions{ResumedMetadata: meta} } From fe0d53f5306d95f9a194047c233c8a762af8879d Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 13 Jan 2026 09:02:55 -0800 Subject: [PATCH 8/9] Ran go mod tidy. --- go/samples/mcp-git-pr-explainer/go.mod | 2 +- go/samples/mcp-git-pr-explainer/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/samples/mcp-git-pr-explainer/go.mod b/go/samples/mcp-git-pr-explainer/go.mod index 445211f6e0..4a19281a89 100644 --- a/go/samples/mcp-git-pr-explainer/go.mod +++ b/go/samples/mcp-git-pr-explainer/go.mod @@ -43,7 +43,7 @@ require ( golang.org/x/net v0.41.0 // indirect golang.org/x/sys v0.34.0 // indirect golang.org/x/text v0.27.0 // indirect - google.golang.org/genai v1.30.0 // indirect + google.golang.org/genai v1.40.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect google.golang.org/grpc v1.73.0 // indirect google.golang.org/protobuf v1.36.6 // indirect diff --git a/go/samples/mcp-git-pr-explainer/go.sum b/go/samples/mcp-git-pr-explainer/go.sum index 428a4e6eb7..b8b244b3cb 100644 --- a/go/samples/mcp-git-pr-explainer/go.sum +++ b/go/samples/mcp-git-pr-explainer/go.sum @@ -98,8 +98,8 @@ golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= -google.golang.org/genai v1.30.0 h1:7021aneIvl24nEBLbtQFEWleHsMbjzpcQvkT4WcJ1dc= -google.golang.org/genai v1.30.0/go.mod h1:7pAilaICJlQBonjKKJNhftDFv3SREhZcTe9F6nRcjbg= +google.golang.org/genai v1.40.0 h1:kYxyQSH+vsib8dvsgyLJzsVEIv5k3ZmHJyVqdvGncmc= +google.golang.org/genai v1.40.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= From 76ea6ba0a2abdaccd752452e994f905062c9d458 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 13 Jan 2026 19:50:31 -0800 Subject: [PATCH 9/9] Formatting. --- go/plugins/googlegenai/gemini.go | 1 - go/plugins/internal/anthropic/anthropic.go | 1 - 2 files changed, 2 deletions(-) diff --git a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go index 9e09b45ae4..ae61255c3f 100644 --- a/go/plugins/googlegenai/gemini.go +++ b/go/plugins/googlegenai/gemini.go @@ -99,7 +99,6 @@ func configToMap(config any) map[string]any { return result } - // configFromRequest converts any supported config type to [genai.GenerateContentConfig]. func configFromRequest(input *ai.ModelRequest) (*genai.GenerateContentConfig, error) { var result genai.GenerateContentConfig diff --git a/go/plugins/internal/anthropic/anthropic.go b/go/plugins/internal/anthropic/anthropic.go index b856223d0f..d40930cc55 100644 --- a/go/plugins/internal/anthropic/anthropic.go +++ b/go/plugins/internal/anthropic/anthropic.go @@ -196,7 +196,6 @@ func toAnthropicRequest(i *ai.ModelRequest) (*anthropic.MessageNewParams, error) return req, nil } - // configFromRequest converts any supported config type to [anthropic.MessageNewParams] func configFromRequest(input *ai.ModelRequest) (*anthropic.MessageNewParams, error) { var result anthropic.MessageNewParams