diff --git a/go/README.md b/go/README.md
index 351847b199..ef302c82b8 100644
--- a/go/README.md
+++ b/go/README.md
@@ -474,6 +474,7 @@ Genkit provides a unified interface across all major AI providers. Use whichever
 | **Google AI** | `googlegenai.GoogleAI` | Gemini 2.5 Flash, Gemini 2.5 Pro, and more |
 | **Vertex AI** | `vertexai.VertexAI` | Gemini models via Google Cloud |
 | **Anthropic** | `anthropic.Anthropic` | Claude 3.5, Claude 3 Opus, and more |
+| **OpenAI** | `openai.OpenAI` | GPT-5, GPT-5-mini, GPT-5-nano, GPT-4o, and more |
 | **Ollama** | `ollama.Ollama` | Llama, Mistral, and other open models |
 | **OpenAI Compatible** | `compat_oai` | Any OpenAI-compatible API |
@@ -484,6 +485,9 @@ g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{}))
 // Anthropic
 g := genkit.Init(ctx, genkit.WithPlugins(&anthropic.Anthropic{}))
+// OpenAI
+g := genkit.Init(ctx, genkit.WithPlugins(&openai.OpenAI{}))
+
 // Ollama (local models)
 g := genkit.Init(ctx, genkit.WithPlugins(&ollama.Ollama{
     ServerAddress: "http://localhost:11434",
diff --git a/go/go.mod b/go/go.mod
index 2750230ab0..bfc0befecf 100644
--- a/go/go.mod
+++ b/go/go.mod
@@ -28,6 +28,7 @@ require (
     github.com/jba/slog v0.2.0
     github.com/lib/pq v1.10.9
     github.com/mark3labs/mcp-go v0.29.0
+    github.com/openai/openai-go/v3 v3.16.0
     github.com/pgvector/pgvector-go v0.3.0
     github.com/stretchr/testify v1.10.0
     github.com/weaviate/weaviate v1.30.0
diff --git a/go/go.sum b/go/go.sum
index c832ade7c6..546bc4c947 100644
--- a/go/go.sum
+++ b/go/go.sum
@@ -308,6 +308,8 @@ github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4=
 github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U=
 github.com/openai/openai-go v1.8.2 h1:UqSkJ1vCOPUpz9Ka5tS0324EJFEuOvMc+lA/EarJWP8=
 github.com/openai/openai-go v1.8.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
+github.com/openai/openai-go/v3 v3.16.0 h1:VdqS+GFZgAvEOBcWNyvLVwPlYEIboW5xwiUCcLrVf8c=
+github.com/openai/openai-go/v3 v3.16.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
 github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
 github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
 github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE=
diff --git a/go/plugins/internal/models.go b/go/plugins/internal/models.go
index 220d1461da..c3147d1bf8 100644
--- a/go/plugins/internal/models.go
+++ b/go/plugins/internal/models.go
@@ -46,4 +46,14 @@ var (
         Media:       true,
         Constrained: ai.ConstrainedSupportNone,
     }
+
+    // Media describes model capabilities for models with media and text input
+    // and output, without multi-turn, tool, or system-role support
+    Media = ai.ModelSupports{
+        Multiturn:   false,
+        Tools:       false,
+        ToolChoice:  false,
+        SystemRole:  false,
+        Media:       true,
+        Constrained: ai.ConstrainedSupportNone,
+    }
 )
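Reviewer note: to complement the README init snippet above, here is a minimal end-to-end usage sketch for the new plugin. It only uses APIs exercised elsewhere in this PR (genkit.Init, genkit.Generate, ai.WithModel, openai.Model); the model name and prompt are illustrative, and OPENAI_API_KEY is assumed to be set:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/genkit"
	"github.com/firebase/genkit/go/plugins/openai"
)

func main() {
	ctx := context.Background()

	// The plugin reads OPENAI_API_KEY from the environment when APIKey is unset.
	g := genkit.Init(ctx, genkit.WithPlugins(&openai.OpenAI{}))

	// Generate text with an OpenAI model (model name illustrative).
	resp, err := genkit.Generate(ctx, g,
		ai.WithModel(openai.Model(g, "gpt-4o")),
		ai.WithPrompt("Write a haiku about Go."),
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Text())
}
```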
diff --git a/go/plugins/openai/generate.go b/go/plugins/openai/generate.go
new file mode 100644
index 0000000000..68cedd95dc
--- /dev/null
+++ b/go/plugins/openai/generate.go
@@ -0,0 +1,135 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package openai
+
+import (
+    "context"
+    "fmt"
+
+    "github.com/firebase/genkit/go/ai"
+    "github.com/openai/openai-go/v3"
+    "github.com/openai/openai-go/v3/responses"
+)
+
+// generate is the entry point used to request content generation from the OpenAI client
+func generate(ctx context.Context, client *openai.Client, model string, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error,
+) (*ai.ModelResponse, error) {
+    req, err := toOpenAIResponseParams(model, input)
+    if err != nil {
+        return nil, err
+    }
+
+    // stream mode
+    if cb != nil {
+        return generateStream(ctx, client, req, input, cb)
+    }
+
+    return generateComplete(ctx, client, req, input)
+}
+
+// generateStream starts a new streaming response
+func generateStream(ctx context.Context, client *openai.Client, req *responses.ResponseNewParams, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
+    stream := client.Responses.NewStreaming(ctx, *req)
+    defer stream.Close()
+
+    var (
+        toolRefMap = make(map[string]string)
+        finalResp  *responses.Response
+    )
+
+    for stream.Next() {
+        evt := stream.Current()
+        chunk := &ai.ModelResponseChunk{}
+
+        switch v := evt.AsAny().(type) {
+        case responses.ResponseTextDeltaEvent:
+            chunk.Content = append(chunk.Content, ai.NewTextPart(v.Delta))
+
+        case responses.ResponseReasoningTextDeltaEvent:
+            chunk.Content = append(chunk.Content, ai.NewReasoningPart(v.Delta, nil))
+
+        case responses.ResponseFunctionCallArgumentsDeltaEvent:
+            name := toolRefMap[v.ItemID]
+            chunk.Content = append(chunk.Content, ai.NewToolRequestPart(&ai.ToolRequest{
+                Ref:   v.ItemID,
+                Name:  name,
+                Input: v.Delta,
+            }))
+
+        case responses.ResponseOutputItemAddedEvent:
+            switch item := v.Item.AsAny().(type) {
+            case responses.ResponseFunctionToolCall:
+                toolRefMap[item.CallID] = item.Name
+                chunk.Content = append(chunk.Content, ai.NewToolRequestPart(&ai.ToolRequest{
+                    Ref:  item.CallID,
+                    Name: item.Name,
+                }))
+            }
+
+        case responses.ResponseCompletedEvent:
+            finalResp = &v.Response
+        }
+
+        if len(chunk.Content) > 0 {
+            if err := cb(ctx, chunk); err != nil {
+                return nil, fmt.Errorf("callback error: %w", err)
+            }
+        }
+    }
+
+    if err := stream.Err(); err != nil {
+        return nil, fmt.Errorf("stream error: %w", err)
+    }
+
+    if finalResp != nil {
+        mResp, err := translateResponse(finalResp)
+        if err != nil {
+            return nil, err
+        }
+        mResp.Request = input
+        return mResp, nil
+    }
+
+    // do not return an error if the stream does not provide a [responses.ResponseCompletedEvent];
+    // the user may already have received the chunks throughout the loop
+    return &ai.ModelResponse{
+        Request: input,
+        Message: &ai.Message{Role: ai.RoleModel},
+    }, nil
+}
+
+// generateComplete starts a new (non-streaming) completion
+func generateComplete(ctx context.Context, client *openai.Client, req *responses.ResponseNewParams, input *ai.ModelRequest) (*ai.ModelResponse, error) {
+    resp, err := client.Responses.New(ctx, *req)
+    if err != nil {
+        return nil, err
+    }
+
+    modelResp, err := translateResponse(resp)
+    if err != nil {
+        return nil, err
+    }
+    modelResp.Request = input
+    return modelResp, nil
+}
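Reviewer note: a sketch of how the streaming path above surfaces chunks to callers. The callback passed via ai.WithStreaming receives the parts emitted per event (text deltas come from ResponseTextDeltaEvent); model name and prompt are illustrative, imports as in the live tests below:

```go
// streamExample accumulates streamed text and returns the final translated response.
// Assumes g was initialized with the OpenAI plugin as in the README.
func streamExample(ctx context.Context, g *genkit.Genkit) (string, error) {
	var out string
	final, err := genkit.Generate(ctx, g,
		ai.WithModel(openai.Model(g, "gpt-4o")), // illustrative
		ai.WithPrompt("Tell me a short story."),
		ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error {
			for _, p := range c.Content {
				if p.IsText() { // text deltas from ResponseTextDeltaEvent
					out += p.Text
				}
			}
			return nil
		}),
	)
	if err != nil {
		return "", err
	}
	// final carries the payload translated from ResponseCompletedEvent.
	return final.Text(), nil
}
```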
diff --git a/go/plugins/openai/openai.go b/go/plugins/openai/openai.go
new file mode 100644
index 0000000000..83b4dbb36a
--- /dev/null
+++ b/go/plugins/openai/openai.go
@@ -0,0 +1,308 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package openai contains the Genkit plugin implementation for the OpenAI provider
+package openai
+
+import (
+    "context"
+    "fmt"
+    "log/slog"
+    "os"
+    "strings"
+    "sync"
+
+    "github.com/firebase/genkit/go/ai"
+    "github.com/firebase/genkit/go/core/api"
+    "github.com/firebase/genkit/go/genkit"
+    "github.com/openai/openai-go/v3"
+    "github.com/openai/openai-go/v3/option"
+    "github.com/openai/openai-go/v3/responses"
+)
+
+const (
+    openaiProvider    = "openai"
+    openaiLabelPrefix = "OpenAI"
+)
+
+var defaultOpenAIOpts = ai.ModelOptions{
+    Supports: &ai.ModelSupports{
+        Multiturn:   true,
+        Tools:       true,
+        ToolChoice:  true,
+        SystemRole:  true,
+        Media:       true,
+        Constrained: ai.ConstrainedSupportAll,
+    },
+    Versions: []string{},
+    Stage:    ai.ModelStageUnstable,
+}
+
+var defaultEmbedOpts = ai.EmbedderOptions{}
+
+type OpenAI struct {
+    mu      sync.Mutex             // protects concurrent access to the client and init state
+    initted bool                   // tracks whether the plugin has been initialized
+    client  *openai.Client         // OpenAI client used for making requests
+    Opts    []option.RequestOption // request options for the OpenAI client
+    APIKey  string                 // API key for the OpenAI API; falls back to OPENAI_API_KEY
+    BaseURL string                 // base URL for custom endpoints; falls back to OPENAI_BASE_URL
+}
+
+func (o *OpenAI) Name() string {
+    return openaiProvider
+}
+
+func (o *OpenAI) Init(ctx context.Context) []api.Action {
+    if o == nil {
+        o = &OpenAI{}
+    }
+    o.mu.Lock()
+    defer o.mu.Unlock()
+    if o.initted {
+        panic("plugin already initialized")
+    }
+
+    apiKey := o.APIKey
+    if apiKey == "" {
+        apiKey = os.Getenv("OPENAI_API_KEY")
+    }
+    if apiKey != "" {
+        o.Opts = append([]option.RequestOption{option.WithAPIKey(apiKey)}, o.Opts...)
+    }
+
+    baseURL := o.BaseURL
+    if baseURL == "" {
+        baseURL = os.Getenv("OPENAI_BASE_URL")
+    }
+    if baseURL != "" {
+        o.Opts = append([]option.RequestOption{option.WithBaseURL(baseURL)}, o.Opts...)
+    }
+
+    client := openai.NewClient(o.Opts...)
+    o.client = &client
+    o.initted = true
+
+    return []api.Action{}
+}
+
+// DefineModel defines an unknown model with the given name.
+func (o *OpenAI) DefineModel(g *genkit.Genkit, name string, opts *ai.ModelOptions) (ai.Model, error) {
+    o.mu.Lock()
+    defer o.mu.Unlock()
+    if !o.initted {
+        panic("OpenAI.Init not called")
+    }
+    if name == "" {
+        return nil, fmt.Errorf("OpenAI.DefineModel: called with empty model name")
+    }
+
+    if opts == nil {
+        return nil, fmt.Errorf("OpenAI.DefineModel: called with nil model options")
+    }
+    return newModel(o.client, name, opts), nil
+}
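Reviewer note: a configuration sketch for the options surface above. All values are placeholders, and option.WithMaxRetries is assumed from the openai-go SDK's request options; DefineModel is called on the same plugin instance that was passed to genkit.Init:

```go
// initExample shows explicit configuration plus manual definition of a model
// name the plugin does not resolve by default (all values illustrative).
func initExample(ctx context.Context) (ai.Model, error) {
	oai := &openai.OpenAI{
		APIKey:  "sk-...",                 // placeholder; falls back to OPENAI_API_KEY
		BaseURL: "https://example.com/v1", // any OpenAI-compatible endpoint
		Opts:    []option.RequestOption{option.WithMaxRetries(2)},
	}
	g := genkit.Init(ctx, genkit.WithPlugins(oai))

	return oai.DefineModel(g, "my-custom-model", &ai.ModelOptions{
		Supports: &ai.ModelSupports{Multiturn: true, Tools: true, SystemRole: true},
	})
}
```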
+// Model returns the [ai.Model] with the given name.
+// It returns nil if the model was not previously defined.
+func Model(g *genkit.Genkit, name string) ai.Model {
+    return genkit.LookupModel(g, api.NewName(openaiProvider, name))
+}
+
+// ModelRef creates a new ModelRef for an OpenAI model with the given ID and configuration
+func ModelRef(name string, config *responses.ResponseNewParams) ai.ModelRef {
+    return ai.NewModelRef(openaiProvider+"/"+name, config)
+}
+
+// IsDefinedModel reports whether the named [ai.Model] is defined by this plugin
+func IsDefinedModel(g *genkit.Genkit, name string) bool {
+    return genkit.LookupModel(g, api.NewName(openaiProvider, name)) != nil
+}
+
+// DefineEmbedder defines an embedder with a given name
+func (o *OpenAI) DefineEmbedder(g *genkit.Genkit, name string, embedOpts *ai.EmbedderOptions) (ai.Embedder, error) {
+    o.mu.Lock()
+    defer o.mu.Unlock()
+    if !o.initted {
+        panic("OpenAI.Init not called")
+    }
+    return newEmbedder(o.client, name, embedOpts), nil
+}
+
+// Embedder returns the [ai.Embedder] with the given name.
+// It returns nil if the embedder was not previously defined.
+func Embedder(g *genkit.Genkit, name string) ai.Embedder {
+    return genkit.LookupEmbedder(g, api.NewName(openaiProvider, name))
+}
+
+// IsDefinedEmbedder reports whether the named [ai.Embedder] is defined by this plugin
+func IsDefinedEmbedder(g *genkit.Genkit, name string) bool {
+    return genkit.LookupEmbedder(g, api.NewName(openaiProvider, name)) != nil
+}
+
+// ListActions lists all the actions supported by the OpenAI plugin.
+func (o *OpenAI) ListActions(ctx context.Context) []api.ActionDesc {
+    actions := []api.ActionDesc{}
+    models, err := listOpenAIModels(ctx, o.client)
+    if err != nil {
+        slog.Error("unable to fetch models from OpenAI API", "error", err)
+        return nil
+    }
+
+    for _, name := range models.chat {
+        model := newModel(o.client, name, &defaultOpenAIOpts)
+        if actionDef, ok := model.(api.Action); ok {
+            actions = append(actions, actionDef.Desc())
+        }
+    }
+    for _, e := range models.embedders {
+        embedder := newEmbedder(o.client, e, &defaultEmbedOpts)
+        if actionDef, ok := embedder.(api.Action); ok {
+            actions = append(actions, actionDef.Desc())
+        }
+    }
+    return actions
+}
+
+// ResolveAction resolves an action with the given name.
+func (o *OpenAI) ResolveAction(atype api.ActionType, name string) api.Action {
+    switch atype {
+    case api.ActionTypeEmbedder:
+        return newEmbedder(o.client, name, &ai.EmbedderOptions{}).(api.Action)
+    case api.ActionTypeModel:
+        var supports *ai.ModelSupports
+        var config any
+
+        switch {
+        // TODO: add image and video models
+        default:
+            supports = &ai.ModelSupports{
+                Multiturn:   true,
+                Tools:       true,
+                ToolChoice:  true,
+                SystemRole:  true,
+                Media:       true,
+                Constrained: ai.ConstrainedSupportAll,
+            }
+            config = &responses.ResponseNewParams{}
+        }
+        return newModel(o.client, name, &ai.ModelOptions{
+            Label:        fmt.Sprintf("%s - %s", openaiLabelPrefix, name),
+            Stage:        ai.ModelStageStable,
+            Versions:     []string{},
+            Supports:     supports,
+            ConfigSchema: configToMap(config),
+        }).(api.Action)
+    }
+    return nil
+}
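Reviewer note: the listOpenAIModels helper below buckets model IDs with plain substring checks, applied in order (sora, then image, then embedding, then chat). A quick self-contained illustration of that classification; the IDs are examples only, the real list comes from client.Models.ListAutoPaging:

```go
// classifyExample mirrors the bucketing logic of listOpenAIModels.
func classifyExample() {
	for _, id := range []string{"gpt-4o", "sora-2", "gpt-image-1", "text-embedding-3-small"} {
		switch {
		case strings.Contains(id, "sora"):
			fmt.Println(id, "=> video")
		case strings.Contains(id, "image"):
			fmt.Println(id, "=> image")
		case strings.Contains(id, "embedding"):
			fmt.Println(id, "=> embedders")
		default:
			fmt.Println(id, "=> chat")
		}
	}
}
```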
+// openaiModels contains the collection of supported OpenAI models
+type openaiModels struct {
+    chat      []string // gpt, tts, o1, o2, o3...
+    image     []string // gpt-image
+    video     []string // sora
+    embedders []string // text-embedding...
+}
+
+// listOpenAIModels returns a list of models available in the OpenAI API.
+// The returned struct is a filtered list of models based on plain string comparisons:
+// chat: gpt, tts, o1, o2, o3...
+// image: gpt-image
+// video: sora
+// embedders: text-embedding
+// NOTE: the list returned from the SDK is just a plain slice of model names.
+// No extra information about the model stage or type is provided.
+// See: platform.openai.com/docs/models
+func listOpenAIModels(ctx context.Context, client *openai.Client) (openaiModels, error) {
+    models := openaiModels{}
+    iter := client.Models.ListAutoPaging(ctx)
+    for iter.Next() {
+        m := iter.Current()
+        if strings.Contains(m.ID, "sora") {
+            models.video = append(models.video, m.ID)
+            continue
+        }
+        if strings.Contains(m.ID, "image") {
+            models.image = append(models.image, m.ID)
+            continue
+        }
+        if strings.Contains(m.ID, "embedding") {
+            models.embedders = append(models.embedders, m.ID)
+            continue
+        }
+        models.chat = append(models.chat, m.ID)
+    }
+    if err := iter.Err(); err != nil {
+        return openaiModels{}, err
+    }
+
+    return models, nil
+}
+
+// newEmbedder creates a new embedder without registering it
+func newEmbedder(client *openai.Client, name string, embedOpts *ai.EmbedderOptions) ai.Embedder {
+    return ai.NewEmbedder(api.NewName(openaiProvider, name), embedOpts, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
+        var data openai.EmbeddingNewParamsInputUnion
+        for _, doc := range req.Input {
+            for _, p := range doc.Content {
+                if p.IsText() { // only text parts can be embedded
+                    data.OfArrayOfStrings = append(data.OfArrayOfStrings, p.Text)
+                }
+            }
+        }
+
+        params := openai.EmbeddingNewParams{
+            Input:          data,
+            Model:          name,
+            EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat,
+        }
+
+        embeddingResp, err := client.Embeddings.New(ctx, params)
+        if err != nil {
+            return nil, err
+        }
+
+        resp := &ai.EmbedResponse{}
+        for _, e := range embeddingResp.Data {
+            embedding := make([]float32, len(e.Embedding))
+            for i, v := range e.Embedding {
+                embedding[i] = float32(v)
+            }
+            resp.Embeddings = append(resp.Embeddings, &ai.Embedding{Embedding: embedding})
+        }
+        return resp, nil
+    })
+}
+
+// newModel creates a new model without registering it in the registry
+func newModel(client *openai.Client, name string, opts *ai.ModelOptions) ai.Model {
+    config := &responses.ResponseNewParams{}
+    meta := &ai.ModelOptions{
+        Label:        opts.Label,
+        Supports:     opts.Supports,
+        Versions:     opts.Versions,
+        ConfigSchema: configToMap(config),
+        Stage:        opts.Stage,
+    }
+
+    fn := func(
+        ctx context.Context,
+        input *ai.ModelRequest,
+        cb func(context.Context, *ai.ModelResponseChunk) error,
+    ) (*ai.ModelResponse, error) {
+        // TODO: add support for image and video generation
+        return generate(ctx, client, name, input, cb)
+    }
+
+    return ai.NewModel(api.NewName(openaiProvider, name), meta, fn)
+}
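Reviewer note: an embedding usage sketch for the helpers above. The embedder name is illustrative, and the ai.DocumentFromText and Embed calls reflect the genkit API as I understand it; the EmbedRequest.Input field matches the code above:

```go
// embedExample embeds a text document through the plugin (assumes g was
// initialized with the OpenAI plugin).
func embedExample(ctx context.Context, g *genkit.Genkit) ([]float32, error) {
	e := openai.Embedder(g, "text-embedding-3-small") // name illustrative
	resp, err := e.Embed(ctx, &ai.EmbedRequest{
		Input: []*ai.Document{ai.DocumentFromText("hello world", nil)},
	})
	if err != nil {
		return nil, err
	}
	return resp.Embeddings[0].Embedding, nil
}
```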
diff --git a/go/plugins/openai/openai_live_test.go b/go/plugins/openai/openai_live_test.go
new file mode 100644
index 0000000000..f164ff316c
--- /dev/null
+++ b/go/plugins/openai/openai_live_test.go
@@ -0,0 +1,598 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package openai_test
+
+import (
+    "context"
+    "encoding/base64"
+    "fmt"
+    "io"
+    "net/http"
+    "os"
+    "strings"
+    "testing"
+
+    "github.com/firebase/genkit/go/ai"
+    "github.com/firebase/genkit/go/genkit"
+    oai "github.com/firebase/genkit/go/plugins/openai"
+    openai "github.com/openai/openai-go/v3"
+    "github.com/openai/openai-go/v3/responses"
+    "github.com/openai/openai-go/v3/shared"
+)
+
+func TestOpenAILive(t *testing.T) {
+    if _, ok := requireEnv("OPENAI_API_KEY"); !ok {
+        t.Skip("OPENAI_API_KEY not found in the environment")
+    }
+
+    ctx := context.Background()
+    g := genkit.Init(ctx, genkit.WithPlugins(&oai.OpenAI{}))
+
+    myJokeTool := genkit.DefineTool(
+        g,
+        "myJoke",
+        "When the user asks for a joke, this tool must be used to generate a joke; try to come up with a joke that uses the output of the tool",
+        func(ctx *ai.ToolContext, input *any) (string, error) {
+            return "why did the chicken cross the road?", nil
+        },
+    )
+
+    myStoryTool := genkit.DefineTool(
+        g,
+        "myStory",
+        "When the user asks for a story, create a story about a frog and a fox that are good friends",
+        func(ctx *ai.ToolContext, input *any) (string, error) {
+            return "the fox is named Goph and the frog is called Fred", nil
+        },
+    )
+
+    type WeatherInput struct {
+        Location string `json:"location"`
+    }
+
+    weatherTool := genkit.DefineTool(
+        g,
+        "weather",
+        "Returns the weather for the given location",
+        func(ctx *ai.ToolContext, input *WeatherInput) (string, error) {
+            report := fmt.Sprintf("The weather in %s is sunny", input.Location)
+            return report, nil
+        },
+    )
+
+    gablorkenDefinitionTool := genkit.DefineTool(
+        g,
+        "gablorkenDefinitionTool",
+        "Custom tool that must be used when the user asks for the definition of a gablorken",
+        func(ctx *ai.ToolContext, input *any) (string, error) {
+            return "A gablorken is an interstellar currency for the Andromeda Galaxy. It is equivalent to 0.4 USD per Gablorken (GAB)", nil
+        },
+    )
+
+    t.Run("model version ok", func(t *testing.T) {
+        m := oai.Model(g, "gpt-4o")
+        resp, err := genkit.Generate(ctx, g,
+            ai.WithConfig(&responses.ResponseNewParams{
+                Temperature:     openai.Float(1),
+                MaxOutputTokens: openai.Int(1024),
+            }),
+            ai.WithModel(m),
+            ai.WithSystem("talk to me like an evil pirate and say \"Arr\" several times but be very short"),
+            ai.WithMessages(ai.NewUserMessage(ai.NewTextPart("I'm a fish"))),
+        )
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        if !strings.Contains(resp.Text(), "Arr") {
+            t.Fatalf("not a pirate: %s", resp.Text())
+        }
+    })
+
+    t.Run("model version not ok", func(t *testing.T) {
+        m := oai.Model(g, "non-existent-model")
+        _, err := genkit.Generate(ctx, g,
+            ai.WithConfig(&responses.ResponseNewParams{
+                Temperature:     openai.Float(1),
+                MaxOutputTokens: openai.Int(1024),
+            }),
+            ai.WithModel(m),
+        )
+        if err == nil {
+            t.Fatal("should have failed due to a wrong model version")
+        }
+    })
+
+    t.Run("model ref", func(t *testing.T) {
+        config := &responses.ResponseNewParams{
+            Temperature:     openai.Float(1),
+            MaxOutputTokens: openai.Int(1024),
+        }
+        mr := oai.ModelRef("gpt-4o", config)
+
+        resp, err := genkit.Generate(ctx, g,
+            ai.WithPrompt("tell me a fact about Golang"),
+            ai.WithModel(mr),
+        )
+        if err != nil {
+            t.Fatal(err)
+        }
+        if resp.Text() == "" {
+            t.Error("expected to have a response, got empty")
+        }
+        if resp.Request.Config == nil {
+            t.Fatal("expected a non-nil configuration, got nil")
+        }
+        if cfg, ok := resp.Request.Config.(*responses.ResponseNewParams); ok {
+            if cfg.MaxOutputTokens != openai.Int(1024) {
+                t.Errorf("wrong MaxOutputTokens value, got: %d, want: 1024", cfg.MaxOutputTokens.Value)
+            }
+            if cfg.Temperature != openai.Float(1) {
+                t.Errorf("wrong Temperature value, got: %f, want: 1", cfg.Temperature.Value)
+            }
+        } else {
+            t.Fatalf("unexpected config, got: %T, want: %T", cfg, config)
+        }
+    })
+
+    t.Run("media content", func(t *testing.T) {
+        i, err := fetchImgAsBase64()
+        if err != nil {
+            t.Fatal(err)
+        }
+        m := oai.Model(g, "gpt-4o")
+        resp, err := genkit.Generate(ctx, g,
+            ai.WithSystem("You are a professional image detective that talks like an evil pirate that loves animals; your task is to tell the name of the animal in the image but be very short"),
+            ai.WithModel(m),
+            ai.WithConfig(&responses.ResponseNewParams{
+                Temperature:     openai.Float(1),
+                MaxOutputTokens: openai.Int(1024),
+            }),
+            ai.WithMessages(
+                ai.NewUserMessage(
+                    ai.NewTextPart("do you know which animal is in the image?"),
+                    ai.NewMediaPart("", "data:image/jpeg;base64,"+i))))
+        if err != nil {
+            t.Fatal(err)
+        }
+        if !strings.Contains(strings.ToLower(resp.Text()), "cat") {
+            t.Fatalf("want: cat, got: %s", resp.Text())
+        }
+    })
+
+    t.Run("media content stream", func(t *testing.T) {
+        i, err := fetchImgAsBase64()
+        if err != nil {
+            t.Fatal(err)
+        }
+        out := ""
+        m := oai.Model(g, "gpt-4o")
+        resp, err := genkit.Generate(ctx, g,
+            ai.WithSystem("You are a professional image detective that talks like an evil pirate that loves animals; your task is to tell the name of the animal in the image but be very short"),
+            ai.WithModel(m),
+            ai.WithConfig(&responses.ResponseNewParams{
+                Temperature:     openai.Float(1),
+                MaxOutputTokens: openai.Int(1024),
+            }),
+            ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error {
+                out += c.Content[0].Text
+                return nil
+            }),
+            ai.WithMessages(
+                ai.NewUserMessage(
+                    ai.NewTextPart("do you know which animal is in the image?"),
+                    ai.NewMediaPart("", "data:image/jpeg;base64,"+i))))
+        if err != nil {
+            t.Fatal(err)
+        }
+        if out != resp.Text() {
+            t.Fatalf("want: %s, got: %s", resp.Text(), out)
+        }
+        if !strings.Contains(strings.ToLower(resp.Text()), "cat") {
+            t.Fatalf("want: cat, got: %s", resp.Text())
+        }
+    })
+
+    t.Run("media content stream with thinking", func(t *testing.T) {
+        i, err := fetchImgAsBase64()
+        if err != nil {
+            t.Fatal(err)
+        }
+        out := ""
+        m := oai.Model(g, "gpt-5")
+        resp, err := genkit.Generate(ctx, g,
+            ai.WithSystem(`You are a professional image detective that
+    talks like an evil pirate that loves animals; your task is to tell the name
+    of the animal in the image but be very short`),
+            ai.WithModel(m),
+            ai.WithConfig(&responses.ResponseNewParams{
+                Reasoning: shared.ReasoningParam{
+                    Effort: shared.ReasoningEffortMedium,
+                },
+            }),
+            ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error {
+                for _, p := range c.Content {
+                    if p.IsText() {
+                        out += p.Text
+                    }
+                }
+                return nil
+            }),
+            ai.WithMessages(
+                ai.NewUserMessage(
+                    ai.NewTextPart("do you know which animal is in the image?"),
+                    ai.NewMediaPart("", "data:image/jpeg;base64,"+i))))
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        if out != resp.Text() {
+            t.Fatalf("want: %s, got: %s", resp.Text(), out)
+        }
+        if !strings.Contains(strings.ToLower(resp.Text()), "cat") {
+            t.Fatalf("want: cat, got: %s", resp.Text())
+        }
+        if resp.Usage.ThoughtsTokens == 0 {
+            t.Log("no reasoning tokens reported in usage (they were expected for a reasoning model)")
+        }
+    })
+
+    t.Run("tools", func(t *testing.T) {
+        m := oai.Model(g, "gpt-5")
+        resp, err := genkit.Generate(ctx, g,
+            ai.WithModel(m),
+            ai.WithConfig(&responses.ResponseNewParams{
+                Temperature:     openai.Float(1),
+                MaxOutputTokens: openai.Int(1024),
+            }),
+            ai.WithPrompt("tell me the definition of a gablorken"),
+            ai.WithTools(gablorkenDefinitionTool))
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        if len(resp.Text()) == 0 {
+            t.Fatal("expected a response but nothing was returned")
+        }
+    })
+
+    t.Run("tools with schema", func(t *testing.T) {
+        m := oai.Model(g, "gpt-4o")
+        type weather struct {
+            Report string `json:"report"`
+        }
+
+        resp, err := genkit.Generate(ctx, g,
+            ai.WithModel(m),
+            ai.WithConfig(&responses.ResponseNewParams{
+                Temperature:     openai.Float(1),
+                MaxOutputTokens: openai.Int(1024),
+            }),
+            ai.WithPrompt("what is the weather in San Francisco?"),
+            ai.WithOutputType(weather{}),
+            ai.WithTools(weatherTool))
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        var w weather
+        if err = resp.Output(&w); err != nil {
+            t.Fatal(err)
+        }
+        if w.Report == "" {
+            t.Fatal("empty weather report, tool should have provided an output")
+        }
+    })
+
+    t.Run("streaming", func(t *testing.T) {
+        m := oai.Model(g, "gpt-4o")
+        out := ""
+
+        final, err := genkit.Generate(ctx, g,
+            ai.WithPrompt("Tell me a short story about a frog and a princess"),
+            ai.WithConfig(&responses.ResponseNewParams{
+                Temperature:     openai.Float(1),
+                MaxOutputTokens: openai.Int(1024),
+            }),
+            ai.WithModel(m),
+            ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error {
+                for _, p := range c.Content {
+                    if p.IsText() {
+                        out += p.Text
+                    }
+                }
+                return nil
+            }),
+        )
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        out2 := ""
+        for _, p := range final.Message.Content {
+            out2 += p.Text
+        }
+
+        if out != out2 {
+            t.Fatalf("streaming and final should contain the same text.\nstreaming: %s\nfinal: %s\n", out, out2)
+        }
+        if final.Usage.InputTokens == 0 || final.Usage.OutputTokens == 0 {
+            t.Fatalf("empty usage stats: %#v", *final.Usage)
+        }
+    })
+
+    t.Run("streaming with thinking", func(t *testing.T) {
+        m := oai.Model(g, "gpt-4o")
+        out := ""
+
+        final, err := genkit.Generate(ctx, g,
+            ai.WithPrompt("Sing me a song about metaphysics"),
+            ai.WithConfig(&responses.ResponseNewParams{}),
+            ai.WithModel(m),
+            ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error {
+                for _, p := range c.Content {
+                    if p.IsText() {
+                        out += p.Text
+                    }
+                }
+                return nil
+            }),
+        )
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        out2 := ""
+        for _, p := range final.Message.Content {
+            if p.IsText() {
+                out2 += p.Text
+            }
+        }
+        if out != out2 {
+            t.Fatalf("streaming and final should contain the same text.\n\nstreaming: %s\n\nfinal: %s\n\n", out, out2)
+        }
+
+        if final.Usage.ThoughtsTokens > 0 {
+            t.Logf("Reasoning tokens: %d", final.Usage.ThoughtsTokens)
+        } else {
+            // this might happen if the model decides not to reason much or if stats are missing.
+            t.Log("No reasoning tokens reported.")
+        }
+    })
+
+    t.Run("tools streaming", func(t *testing.T) {
+        m := oai.Model(g, "gpt-4o")
+        out := ""
+
+        final, err := genkit.Generate(ctx, g,
+            ai.WithPrompt("Tell me a short story about a frog and a fox, do not mention anything else, only the short story"),
+            ai.WithModel(m),
+            ai.WithConfig(&responses.ResponseNewParams{
+                Temperature:     openai.Float(1),
+                MaxOutputTokens: openai.Int(1024),
+            }),
+            ai.WithTools(myStoryTool),
+            ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error {
+                for _, p := range c.Content {
+                    if p.IsText() {
+                        out += p.Text
+                    }
+                }
+                return nil
+            }),
+        )
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        out2 := ""
+        for _, p := range final.Message.Content {
+            if p.IsText() {
+                out2 += p.Text
+            }
+        }
+
+        if out != out2 {
+            t.Fatalf("streaming and final should contain the same text\n\nstreaming: %s\n\nfinal: %s\n\n", out, out2)
+        }
+        if final.Usage.InputTokens == 0 || final.Usage.OutputTokens == 0 {
+            t.Fatalf("empty usage stats: %#v", *final.Usage)
+        }
+    })
+
+    t.Run("built-in tools", func(t *testing.T) {
+        m := oai.Model(g, "gpt-4o")
+
+        webSearchTool := responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch)
+        resp, err := genkit.Generate(ctx, g,
+            ai.WithModel(m),
+            ai.WithConfig(&responses.ResponseNewParams{
+                Temperature:     openai.Float(1),
+                MaxOutputTokens: openai.Int(1024),
+                // Add built-in tool via config
+                Tools: []responses.ToolUnionParam{webSearchTool},
+            }),
+            ai.WithPrompt("What's the current weather in SFO?"),
+        )
+        if err != nil {
+            t.Fatal(err)
+        }
+
+        if len(resp.Text()) == 0 {
+            t.Fatal("expected a response but nothing was returned")
+        }
+    })
+
+    t.Run("mixed tools", func(t *testing.T) {
+        m := oai.Model(g, "gpt-4o")
+
+        webSearchTool := responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch)
+
+        resp, err := genkit.Generate(ctx, g,
+            ai.WithModel(m),
+            ai.WithConfig(&responses.ResponseNewParams{
+                Temperature:       openai.Float(1),
+                MaxOutputTokens:   openai.Int(1024),
+                ParallelToolCalls: openai.Bool(true),
+                // Add built-in tool via config
+                Tools: []responses.ToolUnionParam{webSearchTool},
+            }),
+            ai.WithPrompt("I would like to ask you two things: What's the current weather in SFO? What's the meaning of gablorken? Use the web search tool to get the weather in SFO and use the gablorken definition tool to give me its definition.
Make sure to include the response for both questions in your answer"), + ai.WithTools(gablorkenDefinitionTool), + ) + if err != nil { + t.Fatal(err) + } + + if len(resp.Text()) == 0 { + t.Fatal("expected a response but nothing was returned") + } + }) + + t.Run("structured output", func(t *testing.T) { + m := oai.Model(g, "gpt-4o") + + type MovieReview struct { + Title string `json:"title"` + Rating int `json:"rating"` + Reason string `json:"reason"` + } + + resp, err := genkit.Generate(ctx, g, + ai.WithModel(m), + ai.WithPrompt("Review the movie 'Inception'"), + ai.WithOutputType(MovieReview{}), + ) + if err != nil { + t.Fatal(err) + } + var out MovieReview + if err := resp.Output(&out); err != nil { + t.Errorf("expected a movie review, got: %v", err) + } + if out.Title == "" || out.Rating == 0 || out.Reason == "" { + t.Fatalf("expected a movie review, got %#v", out) + } + + review, _, err := genkit.GenerateData[MovieReview](ctx, g, + ai.WithModel(m), + ai.WithPrompt("Review the movie 'Signs'"), + ) + if err != nil { + t.Fatal(err) + } + + if review.Title == "" || review.Rating == 0 || review.Reason == "" { + t.Fatalf("expected a movie review, got %#v", review) + } + }) + + t.Run("streaming using GenerateDataStream", func(t *testing.T) { + m := oai.Model(g, "gpt-4o") + + type answerChunk struct { + Text string `json:"text"` + } + + chunksCount := 0 + var finalAnswer answerChunk + for val, err := range genkit.GenerateDataStream[answerChunk](ctx, g, + ai.WithModel(m), + ai.WithPrompt("Tell me how's a black hole created in 2 sentences."), + ) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalAnswer = val.Output + } else { + chunksCount++ + } + } + + if chunksCount == 0 { + t.Errorf("expected to receive some chunks, got 0") + } + if finalAnswer.Text == "" { + t.Errorf("expected final answer, got empty") + } + }) + + t.Run("GenerateDataStream with custom tools", func(t *testing.T) { + m := oai.Model(g, "gpt-4o") + + type JokeResponse struct { + Setup string `json:"setup"` + Punchline string `json:"punchline"` + } + + chunksCount := 0 + var finalJoke JokeResponse + + for val, err := range genkit.GenerateDataStream[JokeResponse](ctx, g, + ai.WithModel(m), + ai.WithPrompt("Tell me a joke about a chicken crossing the road. 
Use the myJoke tool to get the punchline."),
+        ai.WithTools(myJokeTool),
+    ) {
+        if err != nil {
+            t.Fatalf("unexpected error: %v", err)
+        }
+        if val.Done {
+            finalJoke = val.Output
+        } else {
+            chunksCount++
+        }
+    }
+
+    if chunksCount == 0 {
+        t.Errorf("expected to receive some chunks, got 0")
+    }
+    if finalJoke.Setup == "" || finalJoke.Punchline == "" {
+        t.Errorf("expected final joke setup and punchline to be populated, got %+v", finalJoke)
+    }
+    })
+}
+
+func fetchImgAsBase64() (string, error) {
+    // CC0 license image
+    imgURL := "https://pd.w.org/2025/07/896686fbbcd9990c9.84605288-2048x1365.jpg"
+    resp, err := http.Get(imgURL)
+    if err != nil {
+        return "", err
+    }
+    defer resp.Body.Close()
+
+    if resp.StatusCode != http.StatusOK {
+        return "", fmt.Errorf("unexpected status code fetching image: %s", resp.Status)
+    }
+
+    imageBytes, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return "", err
+    }
+
+    base64string := base64.StdEncoding.EncodeToString(imageBytes)
+    return base64string, nil
+}
+
+func requireEnv(key string) (string, bool) {
+    value, ok := os.LookupEnv(key)
+    if !ok || value == "" {
+        return "", false
+    }
+
+    return value, true
+}
diff --git a/go/plugins/openai/openai_test.go b/go/plugins/openai/openai_test.go
new file mode 100644
index 0000000000..fbf7fd0bd7
--- /dev/null
+++ b/go/plugins/openai/openai_test.go
@@ -0,0 +1,340 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package openai + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/openai/openai-go/v3/responses" + "github.com/openai/openai-go/v3/shared" +) + +func TestToOpenAIResponseParams_SystemMessage(t *testing.T) { + msg := &ai.Message{ + Role: ai.RoleSystem, + Content: []*ai.Part{ai.NewTextPart("system instruction")}, + } + req := &ai.ModelRequest{ + Messages: []*ai.Message{msg}, + } + + params, err := toOpenAIResponseParams("gpt-4o", req) + if err != nil { + t.Fatalf("toOpenAIResponseParams() error = %v", err) + } + + if params.Instructions.Value != "system instruction" { + t.Errorf("Instructions = %q, want %q", params.Instructions.Value, "system instruction") + } +} + +func TestToOpenAIInputItems_JSON(t *testing.T) { + tests := []struct { + name string + msg *ai.Message + want []string // substrings to match in JSON + }{ + { + name: "user text message", + msg: &ai.Message{ + Role: ai.RoleUser, + Content: []*ai.Part{ai.NewTextPart("user query")}, + }, + want: []string{`"role":"user"`, `"type":"input_text"`, `"text":"user query"`}, + }, + { + name: "model text message", + msg: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewTextPart("model response")}, + }, + want: []string{`"role":"assistant"`, `"type":"output_text"`, `"text":"model response"`, `"status":"completed"`}, + }, + { + name: "tool request", + msg: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewToolRequestPart(&ai.ToolRequest{ + Name: "myTool", + Ref: "call_123", + Input: map[string]string{"arg": "val"}, + })}, + }, + want: []string{`"type":"function_call"`, `"name":"myTool"`, `"call_id":"call_123"`, `"arguments":"{\"arg\":\"val\"}"`}, + }, + { + name: "tool response", + msg: &ai.Message{ + Role: ai.RoleTool, + Content: []*ai.Part{ai.NewToolResponsePart(&ai.ToolResponse{ + Name: "myTool", + Ref: "call_123", + Output: map[string]string{"res": "ok"}, + })}, + }, + want: []string{`"type":"function_call_output"`, `"call_id":"call_123"`, `"output":"{\"res\":\"ok\"}"`}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + items, err := toOpenAIInputItems(tc.msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + b, err := json.Marshal(items) + if err != nil { + t.Fatalf("marshal error: %v", err) + } + jsonStr := string(b) + + for _, w := range tc.want { + if !strings.Contains(jsonStr, w) { + t.Errorf("JSON output missing %q. 
Got: %s", w, jsonStr) + } + } + }) + } +} + +func TestToOpenAITools(t *testing.T) { + tests := []struct { + name string + tools []*ai.ToolDefinition + wantCount int + check func(*testing.T, []responses.ToolUnionParam) + }{ + { + name: "basic function tool", + tools: []*ai.ToolDefinition{ + { + Name: "myTool", + Description: "does something", + InputSchema: map[string]any{"type": "object"}, + }, + }, + wantCount: 1, + check: func(t *testing.T, got []responses.ToolUnionParam) { + tool := got[0] + // We need to marshal to check fields since they are hidden in UnionParam + b, _ := json.Marshal(tool) + s := string(b) + if !strings.Contains(s, `"name":"myTool"`) { + t.Errorf("missing name: %s", s) + } + if !strings.Contains(s, `"type":"function"`) { + t.Errorf("missing type function: %s", s) + } + }, + }, + { + name: "empty name tool ignored", + tools: []*ai.ToolDefinition{ + {Name: ""}, + }, + wantCount: 0, + check: func(t *testing.T, got []responses.ToolUnionParam) {}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := toOpenAITools(tc.tools) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != tc.wantCount { + t.Errorf("got %d tools, want %d", len(got), tc.wantCount) + } + tc.check(t, got) + }) + } +} + +func TestTranslateResponse(t *testing.T) { + // this is a workaround function to bypass Union types used in the openai-go SDK + createResponse := func(jsonStr string) *responses.Response { + var r responses.Response + if err := json.Unmarshal([]byte(jsonStr), &r); err != nil { + t.Fatalf("failed to create mock response: %v", err) + } + return &r + } + + tests := []struct { + name string + respJSON string + wantReason ai.FinishReason + check func(*testing.T, *ai.ModelResponse) + }{ + { + name: "text response completed", + respJSON: `{ + "id": "resp_1", + "status": "completed", + "usage": {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello world"}] + } + ] + }`, + wantReason: ai.FinishReasonStop, + check: func(t *testing.T, m *ai.ModelResponse) { + if len(m.Message.Content) != 1 { + t.Fatalf("got %d parts, want 1", len(m.Message.Content)) + } + if got := m.Message.Content[0].Text; got != "Hello world" { + t.Errorf("got text %q, want 'Hello world'", got) + } + if m.Usage.TotalTokens != 30 { + t.Errorf("got usage %d, want 30", m.Usage.TotalTokens) + } + }, + }, + { + name: "incomplete response", + respJSON: `{ + "id": "resp_2", + "status": "incomplete", + "output": [] + }`, + wantReason: ai.FinishReasonLength, // mapped from Incomplete + check: func(t *testing.T, m *ai.ModelResponse) {}, + }, + { + name: "refusal response", + respJSON: `{ + "id": "resp_3", + "status": "completed", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "refusal", "refusal": "I cannot do that"}] + } + ] + }`, + wantReason: ai.FinishReasonBlocked, + check: func(t *testing.T, m *ai.ModelResponse) { + if m.FinishMessage != "I cannot do that" { + t.Errorf("got FinishMessage %q, want 'I cannot do that'", m.FinishMessage) + } + }, + }, + { + name: "tool call", + respJSON: `{ + "id": "resp_4", + "status": "completed", + "output": [ + { + "type": "function_call", + "call_id": "call_abc", + "name": "weather", + "arguments": "{\"loc\":\"SFO\"}" + } + ] + }`, + wantReason: ai.FinishReasonStop, + check: func(t *testing.T, m *ai.ModelResponse) { + if len(m.Message.Content) != 1 { + t.Fatalf("got %d parts, 
want 1", len(m.Message.Content)) + } + p := m.Message.Content[0] + if !p.IsToolRequest() { + t.Fatalf("expected tool request part") + } + if p.ToolRequest.Name != "weather" { + t.Errorf("got tool name %q, want 'weather'", p.ToolRequest.Name) + } + args := p.ToolRequest.Input.(map[string]any) + if args["loc"] != "SFO" { + t.Errorf("got arg loc %v, want 'SFO'", args["loc"]) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := createResponse(tc.respJSON) + got, err := translateResponse(r) + if err != nil { + t.Fatalf("translateResponse() unexpected error: %v", err) + } + if got.FinishReason != tc.wantReason { + t.Errorf("got reason %v, want %v", got.FinishReason, tc.wantReason) + } + tc.check(t, got) + }) + } +} + +func TestConfigFromRequest(t *testing.T) { + tests := []struct { + name string + input any + wantErr bool + check func(*testing.T, *responses.ResponseNewParams) + }{ + { + name: "struct config", + input: responses.ResponseNewParams{ + Model: shared.ResponsesModel("gpt-4o"), + }, + check: func(t *testing.T, got *responses.ResponseNewParams) { + if got.Model != shared.ResponsesModel("gpt-4o") { + t.Errorf("got model %v, want %v", got.Model, shared.ResponsesModel("gpt-4o")) + } + }, + }, + { + name: "map config", + input: map[string]any{ + "model": "gpt-4o", + }, + check: func(t *testing.T, got *responses.ResponseNewParams) { + if got.Model != "gpt-4o" { + t.Errorf("got model %v, want gpt-4o", got.Model) + } + }, + }, + { + name: "invalid type", + input: "some string", + wantErr: true, + check: func(t *testing.T, got *responses.ResponseNewParams) {}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := configFromRequest(tc.input) + if (err != nil) != tc.wantErr { + t.Errorf("configFromRequest() error = %v, wantErr %v", err, tc.wantErr) + return + } + if !tc.wantErr { + tc.check(t, got) + } + }) + } +} diff --git a/go/plugins/openai/translator.go b/go/plugins/openai/translator.go new file mode 100644 index 0000000000..6a1b9deb8f --- /dev/null +++ b/go/plugins/openai/translator.go @@ -0,0 +1,535 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package openai + +import ( + "encoding/json" + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal/base" + "github.com/invopop/jsonschema" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" + "github.com/openai/openai-go/v3/shared" + "github.com/openai/openai-go/v3/shared/constant" +) + +// toOpenAIResponseParams translates an [ai.ModelRequest] into [responses.ResponseNewParams] +func toOpenAIResponseParams(model string, input *ai.ModelRequest) (*responses.ResponseNewParams, error) { + params, err := configFromRequest(input.Config) + if err != nil { + return nil, err + } + if params == nil { + params = &responses.ResponseNewParams{} + } + + params.Model = shared.ResponsesModel(model) + + // Handle output format + params.Text = handleOutputFormat(params.Text, input.Output) + + // Handle tools + if len(input.Tools) > 0 { + + tools, err := toOpenAITools(input.Tools) + if err != nil { + return nil, err + } + // Append user tools to any existing tools (e.g. built-in tools provided in config) + params.Tools = append(params.Tools, tools...) + + switch input.ToolChoice { + case ai.ToolChoiceAuto, "": + params.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{ + OfToolChoiceMode: param.NewOpt(responses.ToolChoiceOptions("auto")), + } + case ai.ToolChoiceRequired: + params.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{ + OfToolChoiceMode: param.NewOpt(responses.ToolChoiceOptions("required")), + } + case ai.ToolChoiceNone: + params.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{ + OfToolChoiceMode: param.NewOpt(responses.ToolChoiceOptions("none")), + } + default: + params.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{ + OfFunctionTool: &responses.ToolChoiceFunctionParam{ + Name: string(input.ToolChoice), + }, + } + } + } + + // messages to input items + var inputItems []responses.ResponseInputItemUnionParam + var instructions []string + + for _, m := range input.Messages { + if m.Role == ai.RoleSystem { + instructions = append(instructions, m.Text()) + continue + } + + items, err := toOpenAIInputItems(m) + if err != nil { + return nil, err + } + inputItems = append(inputItems, items...) 
+ } + + if len(instructions) > 0 { + params.Instructions = param.NewOpt(strings.Join(instructions, "\n")) + } + if len(inputItems) > 0 { + params.Input = responses.ResponseNewParamsInputUnion{ + OfInputItemList: inputItems, + } + } + + return params, nil +} + +// translateResponse translates an [responses.Response] into an [ai.ModelResponse] +func translateResponse(r *responses.Response) (*ai.ModelResponse, error) { + resp := &ai.ModelResponse{ + Message: &ai.Message{ + Role: ai.RoleModel, + Content: make([]*ai.Part, 0), + }, + } + + resp.Usage = &ai.GenerationUsage{ + InputTokens: int(r.Usage.InputTokens), + OutputTokens: int(r.Usage.OutputTokens), + CachedContentTokens: int(r.Usage.InputTokensDetails.CachedTokens), + ThoughtsTokens: int(r.Usage.OutputTokensDetails.ReasoningTokens), + TotalTokens: int(r.Usage.TotalTokens), + } + + switch r.Status { + case responses.ResponseStatusCompleted: + resp.FinishReason = ai.FinishReasonStop + case responses.ResponseStatusIncomplete: + resp.FinishReason = ai.FinishReasonLength + case responses.ResponseStatusFailed, responses.ResponseStatusCancelled: + resp.FinishReason = ai.FinishReasonOther + default: + resp.FinishReason = ai.FinishReasonUnknown + } + + for _, item := range r.Output { + if err := handleResponseItem(item, resp); err != nil { + return nil, err + } + } + + return resp, nil +} + +// handleResponseItem is the entry point to translate response items +func handleResponseItem(item responses.ResponseOutputItemUnion, resp *ai.ModelResponse) error { + switch v := item.AsAny().(type) { + case responses.ResponseOutputMessage: + return handleOutputMessage(v, resp) + case responses.ResponseReasoningItem: + return handleReasoningItem(v, resp) + case responses.ResponseFunctionToolCall: + return handleFunctionToolCall(v, resp) + case responses.ResponseFunctionWebSearch: + return handleWebSearchResponse(v, resp) + default: + return fmt.Errorf("unsupported response item type: %T", v) + } +} + +// handleOutputMessage translates a [responses.ResponseOutputMessage] into an [ai.ModelResponse] +// and appends the content into the provided response message. +func handleOutputMessage(msg responses.ResponseOutputMessage, resp *ai.ModelResponse) error { + for _, content := range msg.Content { + switch c := content.AsAny().(type) { + case responses.ResponseOutputText: + resp.Message.Content = append(resp.Message.Content, ai.NewTextPart(c.Text)) + case responses.ResponseOutputRefusal: + resp.FinishMessage = c.Refusal + resp.FinishReason = ai.FinishReasonBlocked + } + } + return nil +} + +// handleReasoningItem translates a [responses.ResponseReasoningItem] into an [ai.ModelResponse] +// and appends the content into the provided response message. +func handleReasoningItem(item responses.ResponseReasoningItem, resp *ai.ModelResponse) error { + for _, content := range item.Content { + resp.Message.Content = append(resp.Message.Content, ai.NewReasoningPart(content.Text, nil)) + } + return nil +} + +// handleFunctionToolCall translates a [responses.ResponseFunctionToolCall] into an [ai.ModelResponse] +// and appends the content into the provided response message. 
+func handleFunctionToolCall(call responses.ResponseFunctionToolCall, resp *ai.ModelResponse) error {
+    args, err := jsonStringToMap(call.Arguments)
+    if err != nil {
+        return fmt.Errorf("could not parse tool args: %w", err)
+    }
+    resp.Message.Content = append(resp.Message.Content, ai.NewToolRequestPart(&ai.ToolRequest{
+        Ref:   call.CallID,
+        Name:  call.Name,
+        Input: args,
+    }))
+    return nil
+}
+
+// handleWebSearchResponse translates a [responses.ResponseFunctionWebSearch] into an [ai.ModelResponse]
+// and appends the content into the provided response message.
+func handleWebSearchResponse(webSearch responses.ResponseFunctionWebSearch, resp *ai.ModelResponse) error {
+    resp.Message.Content = append(resp.Message.Content, ai.NewToolResponsePart(&ai.ToolResponse{
+        Ref:  webSearch.ID,
+        Name: string(webSearch.Type),
+        Output: map[string]any{
+            "query":   webSearch.Action.Query,
+            "type":    webSearch.Action.Type,
+            "sources": webSearch.Action.Sources,
+        },
+    }))
+    return nil
+}
+
+// toOpenAIInputItems converts a Genkit message to OpenAI input items
+func toOpenAIInputItems(m *ai.Message) ([]responses.ResponseInputItemUnionParam, error) {
+    var items []responses.ResponseInputItemUnionParam
+    var partsBuffer []*ai.Part
+
+    // flush() converts a sequence of text and media parts into a single OpenAI input item.
+    // Message roles taken into consideration:
+    // Model (or Assistant): converted to [responses.ResponseOutputMessageContentUnionParam]
+    // User/System: converted to [responses.ResponseInputContentUnionParam]
+    //
+    // This is needed for the Responses API since it forbids using input items for assistant-role messages
+    flush := func() error {
+        if len(partsBuffer) == 0 {
+            return nil
+        }
+
+        if m.Role == ai.RoleModel {
+            // conversation-history text messages that the model previously generated
+            var content []responses.ResponseOutputMessageContentUnionParam
+            for _, p := range partsBuffer {
+                if p.IsText() {
+                    content = append(content, responses.ResponseOutputMessageContentUnionParam{
+                        OfOutputText: &responses.ResponseOutputTextParam{
+                            Text:        p.Text,
+                            Annotations: []responses.ResponseOutputTextAnnotationUnionParam{},
+                            Type:        constant.OutputText("output_text"),
+                        },
+                    })
+                }
+            }
+            if len(content) > 0 {
+                // we need a unique ID for the output message
+                id := fmt.Sprintf("msg_%p", m)
+                items = append(items, responses.ResponseInputItemParamOfOutputMessage(
+                    content,
+                    id,
+                    responses.ResponseOutputMessageStatusCompleted,
+                ))
+            }
+        } else {
+            var content []responses.ResponseInputContentUnionParam
+            for _, p := range partsBuffer {
+                if p.IsText() {
+                    content = append(content, responses.ResponseInputContentParamOfInputText(p.Text))
+                } else if p.IsImage() || p.IsMedia() {
+                    content = append(content, responses.ResponseInputContentUnionParam{
+                        OfInputImage: &responses.ResponseInputImageParam{
+                            ImageURL: param.NewOpt(p.Text),
+                        },
+                    })
+                }
+            }
+            if len(content) > 0 {
+                role := responses.EasyInputMessageRoleUser
+                // prevent unexpected system messages from being sent as User; use the Developer role
+                // to provide new "system" instructions during the conversation
+                if m.Role == ai.RoleSystem {
+                    role = responses.EasyInputMessageRole("developer")
+                }
+                items = append(items, responses.ResponseInputItemParamOfMessage(
+                    responses.ResponseInputMessageContentListParam(content), role),
+                )
+            }
+        }
+
+        partsBuffer = nil
+        return nil
+    }
+
+    for _, p := range m.Content {
+        if p.IsText() || p.IsImage() || p.IsMedia() {
+            partsBuffer = append(partsBuffer, p)
+        } else if p.IsToolRequest() {
+            if err := flush(); err
!= nil { + return nil, err + } + args, err := anyToJSONString(p.ToolRequest.Input) + if err != nil { + return nil, err + } + ref := p.ToolRequest.Ref + if ref == "" { + ref = p.ToolRequest.Name + } + items = append(items, responses.ResponseInputItemParamOfFunctionCall(args, ref, p.ToolRequest.Name)) + } else if p.IsReasoning() { + if err := flush(); err != nil { + return nil, err + } + id := fmt.Sprintf("reasoning_%p", p) + summary := []responses.ResponseReasoningItemSummaryParam{ + { + Text: p.Text, + Type: constant.SummaryText("summary_text"), + }, + } + items = append(items, responses.ResponseInputItemParamOfReasoning(id, summary)) + } else if p.IsToolResponse() { + if err := flush(); err != nil { + return nil, err + } + + // handle built-in tools + // TODO: consider adding support for more built-in tools + if p.ToolResponse.Name == "web_search_call" { + item, err := handleWebSearchCall(p.ToolResponse, p.ToolResponse.Ref) + if err != nil { + return nil, err + } + items = append(items, item) + continue + } + + output, err := anyToJSONString(p.ToolResponse.Output) + if err != nil { + return nil, err + } + ref := p.ToolResponse.Ref + items = append(items, responses.ResponseInputItemParamOfFunctionCallOutput(ref, output)) + } + } + if err := flush(); err != nil { + return nil, err + } + + return items, nil +} + +// handleWebSearchCall handles built-in tool responses for the web_search tool +func handleWebSearchCall(toolResponse *ai.ToolResponse, ref string) (responses.ResponseInputItemUnionParam, error) { + output, ok := toolResponse.Output.(map[string]any) + if !ok { + return responses.ResponseInputItemUnionParam{}, fmt.Errorf("invalid output format for web_search_call: expected map[string]any") + } + + actionType, _ := output["type"].(string) + jsonBytes, err := json.Marshal(output) + if err != nil { + return responses.ResponseInputItemUnionParam{}, err + } + + var item responses.ResponseInputItemUnionParam + + switch actionType { + case "open_page": + var openPageAction responses.ResponseFunctionWebSearchActionOpenPageParam + if err := json.Unmarshal(jsonBytes, &openPageAction); err != nil { + return responses.ResponseInputItemUnionParam{}, err + } + item = responses.ResponseInputItemParamOfWebSearchCall( + openPageAction, + ref, + responses.ResponseFunctionWebSearchStatusCompleted, + ) + case "find": + var findAction responses.ResponseFunctionWebSearchActionFindParam + if err := json.Unmarshal(jsonBytes, &findAction); err != nil { + return responses.ResponseInputItemUnionParam{}, err + } + item = responses.ResponseInputItemParamOfWebSearchCall( + findAction, + ref, + responses.ResponseFunctionWebSearchStatusCompleted, + ) + default: + var searchAction responses.ResponseFunctionWebSearchActionSearchParam + if err := json.Unmarshal(jsonBytes, &searchAction); err != nil { + return responses.ResponseInputItemUnionParam{}, err + } + item = responses.ResponseInputItemParamOfWebSearchCall( + searchAction, + ref, + responses.ResponseFunctionWebSearchStatusCompleted, + ) + } + return item, nil +} + +// toOpenAITools converts a slice of [ai.ToolDefinition] to [responses.ToolUnionParam] +func toOpenAITools(tools []*ai.ToolDefinition) ([]responses.ToolUnionParam, error) { + var result []responses.ToolUnionParam + for _, t := range tools { + if t == nil || t.Name == "" { + continue + } + result = append(result, responses.ToolParamOfFunction(t.Name, t.InputSchema, false)) + } + return result, nil +} + +// configFromRequest casts the given configuration into [responses.ResponseNewParams] +func 
configFromRequest(config any) (*responses.ResponseNewParams, error) {
+    var openaiConfig responses.ResponseNewParams
+
+    switch cfg := config.(type) {
+    case responses.ResponseNewParams:
+        openaiConfig = cfg
+    case *responses.ResponseNewParams:
+        openaiConfig = *cfg
+    case map[string]any:
+        if err := mapToStruct(cfg, &openaiConfig); err != nil {
+            return nil, fmt.Errorf("failed to convert config to responses.ResponseNewParams: %w", err)
+        }
+    case nil:
+        // empty but valid config
+    default:
+        return nil, fmt.Errorf("unexpected config type: %T", config)
+    }
+    return &openaiConfig, nil
+}
+
+// anyToJSONString marshals any value into a JSON string
+func anyToJSONString(data any) (string, error) {
+    jsonBytes, err := json.Marshal(data)
+    if err != nil {
+        return "", fmt.Errorf("failed to marshal any to JSON string: %w", err)
+    }
+    return string(jsonBytes), nil
+}
+
+// mapToStruct converts the provided map into a given struct
+func mapToStruct(m map[string]any, v any) error {
+    jsonData, err := json.Marshal(m)
+    if err != nil {
+        return err
+    }
+    return json.Unmarshal(jsonData, v)
+}
+
+// jsonStringToMap translates a JSON string into a map
+func jsonStringToMap(jsonString string) (map[string]any, error) {
+    var result map[string]any
+    if err := json.Unmarshal([]byte(jsonString), &result); err != nil {
+        return nil, fmt.Errorf("unmarshal failed to parse json string %s: %w", jsonString, err)
+    }
+    return result, nil
+}
+
+// configToMap converts a config struct to a map[string]any
+func configToMap(config any) map[string]any {
+    r := jsonschema.Reflector{
+        DoNotReference:             true,
+        AllowAdditionalProperties:  false,
+        ExpandedStruct:             true,
+        RequiredFromJSONSchemaTags: true,
+    }
+
+    r.Mapper = func(r reflect.Type) *jsonschema.Schema {
+        if r.Name() == "Opt[float64]" {
+            return &jsonschema.Schema{
+                Type: "number",
+            }
+        }
+        if r.Name() == "Opt[int64]" {
+            return &jsonschema.Schema{
+                Type: "integer",
+            }
+        }
+        if r.Name() == "Opt[string]" {
+            return &jsonschema.Schema{
+                Type: "string",
+            }
+        }
+        if r.Name() == "Opt[bool]" {
+            return &jsonschema.Schema{
+                Type: "boolean",
+            }
+        }
+        return nil
+    }
+    schema := r.Reflect(config)
+    result := base.SchemaAsMap(schema)
+
+    return result
+}
+
+// handleOutputFormat determines whether to enable structured output or json_mode in the request
+func handleOutputFormat(textConfig responses.ResponseTextConfigParam, output *ai.ModelOutputConfig) responses.ResponseTextConfigParam {
+    if output == nil || output.Format != ai.OutputFormatJSON {
+        return textConfig
+    }
+
+    if !output.Constrained || output.Schema == nil {
+        return textConfig
+    }
+
+    // strict mode is used for the latest GPT models
+    name := "output_schema"
+    // OpenAI schemas require a name to be provided
+    if title, ok := output.Schema["title"].(string); ok {
+        name = title
+    }
+
+    textConfig.Format = responses.ResponseFormatTextConfigUnionParam{
+        OfJSONSchema: &responses.ResponseFormatTextJSONSchemaConfigParam{
+            Type:   constant.JSONSchema("json_schema"),
+            Name:   sanitizeSchemaName(name),
+            Strict: param.NewOpt(true),
+            Schema: output.Schema,
+        },
+    }
+    return textConfig
+}
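Reviewer note: a quick, self-contained illustration of the schema-name handling used by handleOutputFormat and sanitizeSchemaName (below); the expected outputs follow directly from the regex and fallback in the function itself:

```go
// schemaNameExample shows how output schema titles are normalized.
func schemaNameExample() {
	fmt.Println(sanitizeSchemaName("Movie Review!"))              // "Movie_Review_"
	fmt.Println(sanitizeSchemaName(""))                           // "output_schema" (fallback)
	fmt.Println(len(sanitizeSchemaName(strings.Repeat("a", 80)))) // 64 (truncated)
}
```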
+// sanitizeSchemaName ensures the schema name contains only alphanumeric characters, underscores, or dashes,
+// replaces invalid characters with underscores (_), and makes sure it is no longer than 64 characters.
+func sanitizeSchemaName(name string) string {
+    schemaNameRegex := regexp.MustCompile(`[^a-zA-Z0-9_-]+`)
+    newName := schemaNameRegex.ReplaceAllString(name, "_")
+
+    // truncate instead of returning an error
+    if len(newName) > 64 {
+        return newName[:64]
+    }
+    if newName == "" {
+        // schema name is a mandatory field
+        return "output_schema"
+    }
+    return newName
+}
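Reviewer note: finally, a usage sketch for the map-based config path handled by configFromRequest. Map keys must match the ResponseNewParams JSON tags; "model" is the key exercised in the unit tests above, everything else here is illustrative:

```go
// mapConfigExample passes config as a map instead of *responses.ResponseNewParams;
// configFromRequest converts it via a JSON round-trip (see mapToStruct above).
func mapConfigExample(ctx context.Context, g *genkit.Genkit) (string, error) {
	resp, err := genkit.Generate(ctx, g,
		ai.WithModel(openai.Model(g, "gpt-4o")), // name illustrative
		ai.WithPrompt("Say hi."),
		ai.WithConfig(map[string]any{"model": "gpt-4o"}),
	)
	if err != nil {
		return "", err
	}
	return resp.Text(), nil
}
```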