Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/base"
"github.com/google/uuid"
)

// Model represents a model that can generate content based on a request.
Expand Down Expand Up @@ -361,6 +362,9 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi
return nil, err
}

// Ensure all tool requests have unique refs for matching during resume.
ensureToolRequestRefs(resp.Message)

// If this is a long-running operation response, return it immediately without further processing
if bm != nil && resp.Operation != nil {
return resp, nil
Expand Down Expand Up @@ -552,6 +556,9 @@ func GenerateText(ctx context.Context, r api.Registry, opts ...GenerateOption) (
}

// GenerateData runs a generate request and returns strongly-typed output.
// If the response doesn't contain text output (e.g., contains tool requests
// or interrupts instead), the output will be nil and no error is returned.
// Check resp.Interrupts() or resp.ToolRequests() to handle these cases.
func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) (*Out, *ModelResponse, error) {
var value Out
opts = append(opts, WithOutputType(value))
Expand All @@ -561,9 +568,16 @@ func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...Generate
return nil, nil, err
}

// If there's no text content to parse (e.g., the response contains tool
// requests or interrupts), return nil output. The caller should check
// resp.Interrupts() or resp.ToolRequests() to handle these cases.
if resp.Text() == "" {
return nil, resp, nil
}

err = resp.Output(&value)
if err != nil {
return nil, nil, err
return nil, resp, err
}

return &value, resp, nil
Expand Down Expand Up @@ -715,6 +729,20 @@ func (m *model) supportsConstrained(hasTools bool) bool {
return true
}

// ensureToolRequestRefs assigns unique refs to tool request parts that don't have one.
// This ensures that when there are multiple calls to the same tool, each can be
// individually matched when resuming with Restart or Respond directives.
func ensureToolRequestRefs(msg *Message) {
if msg == nil {
return
}
for _, part := range msg.Content {
if part.IsToolRequest() && part.ToolRequest.Ref == "" {
part.ToolRequest.Ref = uuid.New().String()
}
}
}

// clone creates a deep copy of the provided object using JSON marshaling and unmarshaling.
func clone[T any](obj *T) *T {
if obj == nil {
Expand Down Expand Up @@ -1166,7 +1194,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene
}
}
}
if originalInputVal, ok := restartPart.Metadata["originalInput"]; ok {
if originalInputVal, ok := restartPart.Metadata["replacedInput"]; ok {
resumedCtx = origInputCtxKey.NewContext(resumedCtx, originalInputVal)
}

Expand Down
53 changes: 29 additions & 24 deletions go/ai/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -993,13 +993,19 @@ func validTestMessage(m *Message, output *ModelOutputConfig) (*Message, error) {
return handler.ParseMessage(m)
}

type conditionalToolInput struct {
Value string
Interrupt bool
}

type resumableToolInput struct {
Action string
Data string
}

func TestToolInterruptsAndResume(t *testing.T) {
conditionalTool := DefineTool(r, "conditional", "tool that may interrupt based on input",
func(ctx *ToolContext, input struct {
Value string
Interrupt bool
},
) (string, error) {
func(ctx *ToolContext, input conditionalToolInput) (string, error) {
if input.Interrupt {
return "", ctx.Interrupt(&InterruptOptions{
Metadata: map[string]any{
Expand All @@ -1014,11 +1020,7 @@ func TestToolInterruptsAndResume(t *testing.T) {
)

resumableTool := DefineTool(r, "resumable", "tool that can be resumed",
func(ctx *ToolContext, input struct {
Action string
Data string
},
) (string, error) {
func(ctx *ToolContext, input resumableToolInput) (string, error) {
if ctx.Resumed != nil {
resumedData, ok := ctx.Resumed["data"].(string)
if ok {
Expand Down Expand Up @@ -1156,11 +1158,12 @@ func TestToolInterruptsAndResume(t *testing.T) {

interruptedPart := res.Message.Content[1]

newInput := conditionalToolInput{
Value: "new_test_data",
Interrupt: false,
}
restartPart := conditionalTool.Restart(interruptedPart, &RestartOptions{
ReplaceInput: map[string]any{
"Value": "new_test_data",
"Interrupt": false,
},
ReplaceInput: newInput,
ResumedMetadata: map[string]any{
"data": "resumed_data",
"source": "restart",
Expand All @@ -1175,17 +1178,17 @@ func TestToolInterruptsAndResume(t *testing.T) {
t.Errorf("expected tool request name 'conditional', got %q", restartPart.ToolRequest.Name)
}

newInput, ok := restartPart.ToolRequest.Input.(map[string]any)
replacedInput, ok := restartPart.ToolRequest.Input.(conditionalToolInput)
if !ok {
t.Fatal("expected input to be map[string]any")
t.Fatalf("expected input to be conditionalInput, got %T", restartPart.ToolRequest.Input)
}

if newInput["Value"] != "new_test_data" {
t.Errorf("expected new input value 'new_test_data', got %v", newInput["Value"])
if replacedInput.Value != "new_test_data" {
t.Errorf("expected new input value 'new_test_data', got %v", replacedInput.Value)
}

if newInput["Interrupt"] != false {
t.Errorf("expected interrupt to be false, got %v", newInput["Interrupt"])
if replacedInput.Interrupt != false {
t.Errorf("expected interrupt to be false, got %v", replacedInput.Interrupt)
}

if _, hasInterrupt := restartPart.Metadata["interrupt"]; hasInterrupt {
Expand Down Expand Up @@ -1242,11 +1245,13 @@ func TestToolInterruptsAndResume(t *testing.T) {
}

interruptedPart := res.Message.Content[1]

newInput := conditionalToolInput{
Value: "restarted_data",
Interrupt: false,
}
restartPart := conditionalTool.Restart(interruptedPart, &RestartOptions{
ReplaceInput: map[string]any{
"Value": "restarted_data",
"Interrupt": false,
},
ReplaceInput: newInput,
ResumedMetadata: map[string]any{
"data": "restart_context",
},
Expand Down
7 changes: 5 additions & 2 deletions go/ai/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -915,12 +915,15 @@ func (o *generateOptions) applyGenerate(genOpts *generateOptions) error {
return nil
}

// WithToolResponses sets the tool responses to return from interrupted tool calls.
// WithToolResponses provides resolved responses for interrupted tool calls.
// Use this when you already have the result and want to skip re-executing the tool.
func WithToolResponses(parts ...*Part) GenerateOption {
return &generateOptions{RespondParts: parts}
}

// WithToolRestarts sets the tool requests to restart interrupted tools with.
// WithToolRestarts re-executes interrupted tool calls with additional metadata.
// Use this when the original call lacked required context (e.g., auth, user confirmation)
// that should now allow the tool to complete successfully.
func WithToolRestarts(parts ...*Part) GenerateOption {
return &generateOptions{RestartParts: parts}
}
Expand Down
11 changes: 2 additions & 9 deletions go/ai/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"testing"

"github.com/firebase/genkit/go/core/api"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
Expand Down Expand Up @@ -658,15 +659,7 @@ func (t *mockTool) RunRawMultipart(ctx context.Context, input any) (*MultipartTo
return nil, nil
}

func (t *mockTool) Respond(toolReq *Part, outputData any, opts *RespondOptions) *Part {
return nil
}

func (t *mockTool) Restart(toolReq *Part, opts *RestartOptions) *Part {
return nil
}

func (t *mockTool) Register(r interface{ RegisterValue(string, any) }) {
func (t *mockTool) Register(r api.Registry) {
}

func TestWithInputSchemaName(t *testing.T) {
Expand Down
Loading
Loading