From c72e8ff550461ee55f626e65068773405a22fe2b Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:04:40 -0400 Subject: [PATCH 01/14] docs: add Genkit provider migration design MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace all hand-rolled provider implementations with Google Genkit Go SDK adapters. Keeps provider.Provider interface unchanged — Genkit is an internal implementation detail. All consumers (executor, ratchet-cli, mesh) unaffected. Co-Authored-By: Claude Opus 4.6 (1M context) --- ...-04-05-genkit-provider-migration-design.md | 213 ++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 docs/plans/2026-04-05-genkit-provider-migration-design.md diff --git a/docs/plans/2026-04-05-genkit-provider-migration-design.md b/docs/plans/2026-04-05-genkit-provider-migration-design.md new file mode 100644 index 0000000..bab52dd --- /dev/null +++ b/docs/plans/2026-04-05-genkit-provider-migration-design.md @@ -0,0 +1,213 @@ +# Genkit Provider Migration Design + +**Date:** 2026-04-05 +**Repo:** workflow-plugin-agent +**Goal:** Replace all hand-rolled provider implementations with Google Genkit Go SDK adapters, keeping the `provider.Provider` interface unchanged. + +## Scope + +- **Replace:** All 12+ provider implementation files with Genkit plugin adapters +- **Keep:** `provider.Provider` interface, `executor/`, `tools/`, `orchestrator/`, mesh support, test providers +- **Gain:** Unified SDK, structured output, built-in tracing, MCP support, fewer direct dependencies + +## Architecture + +```mermaid +graph TD + E[executor.Execute] --> P[provider.Provider interface] + P --> GA[genkit adapter] + GA --> GG[genkit.Generate / GenerateStream] + GG --> PA[Anthropic plugin] + GG --> PO[OpenAI plugin] + GG --> PG[Google AI plugin] + GG --> POL[Ollama plugin] +``` + +Genkit is an **internal implementation detail**. The `provider.Provider` interface is unchanged. All consumers (executor, ratchet-cli, mesh, orchestrator) work without modification. + +## What Gets Deleted + +- `provider/anthropic.go`, `anthropic_bedrock.go`, `anthropic_foundry.go`, `anthropic_vertex.go` +- `provider/anthropic_convert.go`, `anthropic_convert_test.go` +- `provider/openai.go`, `openai_azure.go`, `openai_azure_test.go` +- `provider/gemini.go`, `gemini_test.go` +- `provider/ollama.go`, `ollama_convert.go`, `ollama_test.go`, `ollama_convert_test.go` +- `provider/copilot.go`, `copilot_test.go`, `copilot_models.go`, `copilot_models_test.go` +- `provider/openrouter.go`, `openrouter_test.go` +- `provider/cohere.go` +- `provider/huggingface.go`, `huggingface_test.go` +- `provider/llama_cpp.go`, `llama_cpp_test.go`, `llama_cpp_download.go`, `llama_cpp_download_test.go` +- `provider/local.go`, `local_test.go` +- `provider/ssrf.go`, `provider/models_ssrf_test.go` +- `provider/auth_modes.go` (simplified into adapter) + +## What Stays + +- `provider/provider.go` — interface + types (`Message`, `ToolDef`, `ToolCall`, `StreamEvent`, `Response`, `Usage`, `AuthModeInfo`) +- `provider/models.go` — model registry/definitions (if still needed for model metadata) +- `provider/test_provider.go`, `test_provider_channel.go`, `test_provider_http.go`, `test_provider_scripted.go` — test infrastructure +- `provider/thinking_field_test.go` — test for thinking trace parsing +- `executor/` — entire package unchanged +- `tools/` — entire package unchanged +- `orchestrator/` — entire package unchanged (ProviderRegistry updated to use new factories) + +## New Package: `genkit/` + +``` +genkit/ + genkit.go — Init(), singleton Genkit instance, plugin registration + adapter.go — genkitProvider implementing provider.Provider + convert.go — Bidirectional type conversion (our types ↔ Genkit ai types) + providers.go — Factory functions per provider type + providers_test.go — Tests with mock Genkit model +``` + +### genkit.go — Initialization + +```go +package genkit + +// Init creates and returns a Genkit instance with all available plugins registered. +// Call once at startup. Thread-safe after initialization. +func Init(ctx context.Context, opts ...Option) (*genkit.Genkit, error) + +type Option func(*initConfig) +func WithAnthropicKey(key string) Option +func WithOpenAIKey(key string) Option +func WithGoogleAIKey(key string) Option +func WithOllamaHost(host string) Option +// etc. +``` + +### adapter.go — Provider Adapter + +```go +// genkitProvider adapts a Genkit model to provider.Provider. +type genkitProvider struct { + g *genkit.Genkit + model ai.Model // resolved Genkit model + name string // provider identifier + authInfo provider.AuthModeInfo + config map[string]any // model-specific generation config +} + +func (p *genkitProvider) Name() string +func (p *genkitProvider) AuthModeInfo() provider.AuthModeInfo +func (p *genkitProvider) Chat(ctx, messages, tools) (*provider.Response, error) +func (p *genkitProvider) Stream(ctx, messages, tools) (<-chan provider.StreamEvent, error) +``` + +`Chat()` calls `genkit.Generate()` and converts the response. +`Stream()` calls `genkit.Generate()` with streaming and converts chunks to `StreamEvent` on a channel. + +### convert.go — Type Conversion + +```go +// toGenkitMessages converts []provider.Message → []ai.Message +func toGenkitMessages(msgs []provider.Message) []*ai.Message + +// toGenkitTools converts []provider.ToolDef → []ai.Tool (via genkit.DefineTool or inline) +func toGenkitTools(g *genkit.Genkit, tools []provider.ToolDef) []*ai.Tool + +// fromGenkitResponse converts *ai.ModelResponse → *provider.Response +func fromGenkitResponse(resp *ai.ModelResponse) *provider.Response + +// fromGenkitChunk converts ai.ModelResponseChunk → provider.StreamEvent +func fromGenkitChunk(chunk *ai.ModelResponseChunk) provider.StreamEvent +``` + +### providers.go — Factory Functions + +```go +// NewAnthropicProvider creates a provider backed by Genkit's Anthropic plugin. +func NewAnthropicProvider(cfg AnthropicConfig) (provider.Provider, error) + +// NewOpenAIProvider creates a provider backed by Genkit's OpenAI plugin. +func NewOpenAIProvider(cfg OpenAIConfig) (provider.Provider, error) + +// NewGoogleAIProvider creates a provider backed by Genkit's Google AI plugin. +func NewGoogleAIProvider(cfg GoogleAIConfig) (provider.Provider, error) + +// NewOllamaProvider creates a provider backed by Genkit's Ollama plugin. +func NewOllamaProvider(cfg OllamaConfig) (provider.Provider, error) + +// NewOpenAICompatibleProvider handles OpenRouter, Copilot, Cohere, HuggingFace, etc. +func NewOpenAICompatibleProvider(cfg OpenAICompatibleConfig) (provider.Provider, error) + +// NewBedrockProvider creates a provider via Genkit's AWS Bedrock plugin. +func NewBedrockProvider(cfg BedrockConfig) (provider.Provider, error) + +// NewVertexAIProvider creates a provider via Genkit's Vertex AI plugin. +func NewVertexAIProvider(cfg VertexAIConfig) (provider.Provider, error) + +// NewAzureOpenAIProvider creates a provider via Genkit's OpenAI plugin with Azure endpoint. +func NewAzureOpenAIProvider(cfg AzureOpenAIConfig) (provider.Provider, error) +``` + +Config structs mirror existing provider configs (API key, model, base URL, max tokens, etc.). + +## Provider Mapping + +| Current Implementation | Genkit Plugin | Factory | +|---|---|---| +| `anthropic.go` (direct API) | `github.com/anthropics/anthropic-sdk-go` via Genkit Anthropic plugin | `NewAnthropicProvider` | +| `anthropic_bedrock.go` | Genkit AWS Bedrock plugin | `NewBedrockProvider` | +| `anthropic_vertex.go` | Genkit Vertex AI plugin | `NewVertexAIProvider` | +| `anthropic_foundry.go` | Genkit Anthropic plugin with custom base URL | `NewAnthropicProvider` | +| `openai.go` | Genkit OpenAI plugin | `NewOpenAIProvider` | +| `openai_azure.go` | Genkit OpenAI plugin with Azure endpoint | `NewAzureOpenAIProvider` | +| `gemini.go` | Genkit Google AI plugin | `NewGoogleAIProvider` | +| `ollama.go` + `ollama_convert.go` | Genkit Ollama plugin | `NewOllamaProvider` | +| `copilot.go` | OpenAI-compatible via Genkit OpenAI plugin | `NewOpenAICompatibleProvider` | +| `openrouter.go` | OpenAI-compatible | `NewOpenAICompatibleProvider` | +| `cohere.go` | OpenAI-compatible | `NewOpenAICompatibleProvider` | +| `huggingface.go` | OpenAI-compatible | `NewOpenAICompatibleProvider` | +| `llama_cpp.go` | Ollama (llama.cpp serves OpenAI-compatible) or Genkit Ollama | `NewOllamaProvider` or `NewOpenAICompatibleProvider` | + +## Thinking/Reasoning Trace Support + +Models like Claude and Qwen emit thinking traces. Genkit streams these as part of `ModelResponseChunk`. The `fromGenkitChunk()` converter maps: +- Text content → `StreamEvent{Type: "text", Text: ...}` +- Thinking content → `StreamEvent{Type: "thinking", Thinking: ...}` +- Tool calls → `StreamEvent{Type: "tool_call", Tool: ...}` +- Done → `StreamEvent{Type: "done"}` +- Errors → `StreamEvent{Type: "error", Error: ...}` + +## ProviderRegistry Update + +`orchestrator/provider_registry.go` currently has a `createProvider()` method that switches on provider type and calls constructors like `provider.NewAnthropicProvider()`. This gets updated to call `genkit.NewAnthropicProvider()` etc. The registry interface doesn't change. + +## llama.cpp Download Support + +The current `llama_cpp_download.go` handles binary downloads and HuggingFace model pulls. This is orthogonal to the provider layer — it can stay as a utility in `provider/` or move to a `local/` package. The llama.cpp *provider* itself gets replaced by Genkit's Ollama plugin (llama.cpp serves an OpenAI-compatible API). + +## Dependencies + +**Added:** +- `github.com/firebase/genkit/go` (core) +- Genkit provider plugins (anthropic, openai, googleai, ollama, etc.) + +**Removed (become transitive via Genkit):** +- `github.com/anthropics/anthropic-sdk-go` (direct) +- `github.com/openai/openai-go` (direct) +- `github.com/google/generative-ai-go` (direct) +- `github.com/ollama/ollama` (direct) + +## Testing Strategy + +1. **Unit tests:** Mock Genkit model via `ai.DefineModel()` with canned responses. Test adapter conversion, streaming, tool call mapping. +2. **Existing test providers** (`test_provider.go` etc.) remain unchanged — they implement `provider.Provider` directly and don't use Genkit. +3. **Integration tests** (tagged `//go:build integration`): Call real provider APIs via Genkit to verify end-to-end. +4. **Regression:** All existing executor tests must pass unchanged since `provider.Provider` interface is stable. + +## Migration Order + +1. Add Genkit dependency, create `genkit/` package skeleton +2. Implement `convert.go` (type conversion) +3. Implement `adapter.go` (genkitProvider) +4. Implement `providers.go` (all factory functions) +5. Update `orchestrator/provider_registry.go` to use new factories +6. Delete old provider implementation files +7. Clean up go.mod (remove direct SDK deps that are now transitive) +8. Run full test suite, fix any regressions +9. Tag release From d9d271c68741c70a780fb08a491ebd9da112130b Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:09:48 -0400 Subject: [PATCH 02/14] docs: add Genkit provider migration implementation plan 8-task plan covering: dependency setup, type conversion, adapter, factory functions, registry update, file deletion, dep cleanup, tests. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../2026-04-05-genkit-provider-migration.md | 573 ++++++++++++++++++ 1 file changed, 573 insertions(+) create mode 100644 docs/plans/2026-04-05-genkit-provider-migration.md diff --git a/docs/plans/2026-04-05-genkit-provider-migration.md b/docs/plans/2026-04-05-genkit-provider-migration.md new file mode 100644 index 0000000..1fd0bc9 --- /dev/null +++ b/docs/plans/2026-04-05-genkit-provider-migration.md @@ -0,0 +1,573 @@ +# Genkit Provider Migration — Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Replace all hand-rolled provider implementations in workflow-plugin-agent with Google Genkit Go SDK adapters, keeping the `provider.Provider` interface unchanged. + +**Architecture:** A new `genkit/` package wraps Genkit plugins behind the existing `provider.Provider` interface. The `ProviderRegistry` factory functions are updated to call genkit factories. All provider implementation files (~25 files) are deleted. + +**Tech Stack:** Go 1.26, Genkit Go v1.6.0, Genkit plugins (anthropic, openai, googlegenai, ollama, compat_oai), community plugins (aws-bedrock, azure-openai) + +--- + +## Task 1: Add Genkit dependency and create package skeleton + +**Files:** +- Modify: `go.mod` +- Create: `genkit/genkit.go` + +**Step 1:** Add Genkit core dependency: +```bash +cd /Users/jon/workspace/workflow-plugin-agent +go get github.com/firebase/genkit/go@v1.6.0 +go get github.com/firebase/genkit/go/plugins/anthropic +go get github.com/firebase/genkit/go/plugins/googlegenai +go get github.com/firebase/genkit/go/plugins/ollama +go get github.com/firebase/genkit/go/plugins/compat_oai +go mod tidy +``` + +**Step 2:** Create `genkit/genkit.go` — package with Genkit initialization: +```go +// Package genkit provides Genkit-backed implementations of provider.Provider. +package genkit + +import ( + "context" + "sync" + + gk "github.com/firebase/genkit/go/genkit" +) + +var ( + instance *gk.Genkit + once sync.Once +) + +// Instance returns the shared Genkit instance, initializing it lazily on first call. +// Plugins are registered dynamically when providers are created, not at init time. +func Instance(ctx context.Context) *gk.Genkit { + once.Do(func() { + instance = gk.Init(ctx) + }) + return instance +} +``` + +**Step 3:** Build to verify: `go build ./...` + +**Step 4:** Commit: +```bash +git add go.mod go.sum genkit/genkit.go +git commit -m "chore: add Genkit Go SDK dependency and package skeleton" +``` + +--- + +## Task 2: Implement type conversion layer + +**Files:** +- Create: `genkit/convert.go` +- Create: `genkit/convert_test.go` + +**Step 1:** Create `genkit/convert.go` — bidirectional type conversion between our `provider.*` types and Genkit `ai.*` types: +```go +package genkit + +import ( + "github.com/firebase/genkit/go/ai" + "github.com/GoCodeAlone/workflow-plugin-agent/provider" +) + +// toGenkitMessages converts our messages to Genkit messages. +func toGenkitMessages(msgs []provider.Message) []*ai.Message { + out := make([]*ai.Message, 0, len(msgs)) + for _, m := range msgs { + var role ai.Role + switch m.Role { + case provider.RoleSystem: + role = ai.RoleSystem + case provider.RoleUser: + role = ai.RoleUser + case provider.RoleAssistant: + role = ai.RoleModel + case provider.RoleTool: + role = ai.RoleTool + default: + role = ai.RoleUser + } + + parts := []*ai.Part{ai.NewTextPart(m.Content)} + + // Tool call results: add as ToolResponsePart + if m.ToolCallID != "" { + parts = []*ai.Part{ai.NewToolResponsePart(&ai.ToolResponse{ + Name: m.ToolCallID, + Output: map[string]any{"result": m.Content}, + })} + } + + out = append(out, ai.NewMessage(role, nil, parts...)) + } + return out +} + +// toGenkitToolDefs converts our tool definitions to Genkit tool defs. +// Returns the tool names for WithModelName-style passing. +func toGenkitToolDefs(tools []provider.ToolDef) []ai.ToolDef { + out := make([]ai.ToolDef, 0, len(tools)) + for _, t := range tools { + out = append(out, ai.ToolDef{ + Name: t.Name, + Description: t.Description, + InputSchema: t.Parameters, + }) + } + return out +} + +// fromGenkitResponse converts a Genkit response to our Response type. +func fromGenkitResponse(resp *ai.ModelResponse) *provider.Response { + if resp == nil { + return &provider.Response{} + } + + out := &provider.Response{ + Content: resp.Text(), + } + + // Extract tool calls + if msg := resp.Message; msg != nil { + for _, part := range msg.Content { + if part.ToolRequest != nil { + tc := provider.ToolCall{ + ID: part.ToolRequest.Name, + Name: part.ToolRequest.Name, + Arguments: make(map[string]any), + } + if input, ok := part.ToolRequest.Input.(map[string]any); ok { + tc.Arguments = input + } + out.ToolCalls = append(out.ToolCalls, tc) + } + } + } + + // Extract usage + if resp.Usage != nil { + out.Usage = provider.Usage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + } + } + + return out +} + +// fromGenkitChunk converts a Genkit stream chunk to our StreamEvent. +func fromGenkitChunk(chunk *ai.ModelResponseChunk) provider.StreamEvent { + if chunk == nil { + return provider.StreamEvent{Type: "done"} + } + text := chunk.Text() + if text != "" { + return provider.StreamEvent{Type: "text", Text: text} + } + // Check for tool calls in chunk + for _, part := range chunk.Content { + if part.ToolRequest != nil { + return provider.StreamEvent{ + Type: "tool_call", + Tool: &provider.ToolCall{ + ID: part.ToolRequest.Name, + Name: part.ToolRequest.Name, + Arguments: func() map[string]any { + if m, ok := part.ToolRequest.Input.(map[string]any); ok { + return m + } + return nil + }(), + }, + } + } + } + return provider.StreamEvent{Type: "text", Text: ""} +} +``` + +**Step 2:** Create `genkit/convert_test.go`: +```go +package genkit + +import ( + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/GoCodeAlone/workflow-plugin-agent/provider" +) + +func TestToGenkitMessages(t *testing.T) { + msgs := []provider.Message{ + {Role: provider.RoleSystem, Content: "You are helpful."}, + {Role: provider.RoleUser, Content: "Hello"}, + {Role: provider.RoleAssistant, Content: "Hi there"}, + } + result := toGenkitMessages(msgs) + if len(result) != 3 { + t.Fatalf("expected 3 messages, got %d", len(result)) + } + if result[0].Role != ai.RoleSystem { + t.Errorf("expected system role, got %s", result[0].Role) + } + if result[2].Role != ai.RoleModel { + t.Errorf("expected model role for assistant, got %s", result[2].Role) + } +} + +func TestFromGenkitResponse(t *testing.T) { + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("Hello world"), + Usage: &ai.GenerationUsage{InputTokens: 10, OutputTokens: 5}, + } + result := fromGenkitResponse(resp) + if result.Content != "Hello world" { + t.Errorf("expected 'Hello world', got %q", result.Content) + } + if result.Usage.InputTokens != 10 { + t.Errorf("expected 10 input tokens, got %d", result.Usage.InputTokens) + } +} + +func TestFromGenkitResponseNil(t *testing.T) { + result := fromGenkitResponse(nil) + if result.Content != "" { + t.Errorf("expected empty content, got %q", result.Content) + } +} +``` + +**Step 3:** Run tests: `go test ./genkit/ -v -count=1` + +**Step 4:** Commit: +```bash +git add genkit/convert.go genkit/convert_test.go +git commit -m "feat: add Genkit ↔ provider type conversion layer" +``` + +--- + +## Task 3: Implement the Genkit provider adapter + +**Files:** +- Create: `genkit/adapter.go` +- Create: `genkit/adapter_test.go` + +**Step 1:** Create `genkit/adapter.go` — the `genkitProvider` struct that implements `provider.Provider`: +```go +package genkit + +import ( + "context" + "fmt" + + "github.com/firebase/genkit/go/ai" + gk "github.com/firebase/genkit/go/genkit" + "github.com/GoCodeAlone/workflow-plugin-agent/provider" +) + +// genkitProvider adapts a Genkit model to provider.Provider. +type genkitProvider struct { + g *gk.Genkit + modelName string // "provider/model" format + name string + authInfo provider.AuthModeInfo +} + +func (p *genkitProvider) Name() string { return p.name } +func (p *genkitProvider) AuthModeInfo() provider.AuthModeInfo { return p.authInfo } + +func (p *genkitProvider) Chat(ctx context.Context, messages []provider.Message, tools []provider.ToolDef) (*provider.Response, error) { + opts := []ai.GenerateOption{ + ai.WithModelName(p.modelName), + ai.WithMessages(toGenkitMessages(messages)...), + } + + // Pass tool definitions if provided — use WithReturnToolRequests so we handle tool + // execution ourselves (the executor loop does this, not Genkit). + if len(tools) > 0 { + opts = append(opts, ai.WithReturnToolRequests(true)) + // Register tools dynamically for this call + for _, t := range tools { + tool := gk.DefineTool(p.g, t.Name, t.Description, + func(ctx *ai.ToolContext, input map[string]any) (map[string]any, error) { + // Placeholder — tools are executed by the executor, not here. + return nil, fmt.Errorf("tool %s should not be called via Genkit", t.Name) + }, + ) + opts = append(opts, ai.WithTools(tool)) + } + } + + resp, err := gk.Generate(ctx, p.g, opts...) + if err != nil { + return nil, fmt.Errorf("genkit generate: %w", err) + } + + return fromGenkitResponse(resp), nil +} + +func (p *genkitProvider) Stream(ctx context.Context, messages []provider.Message, tools []provider.ToolDef) (<-chan provider.StreamEvent, error) { + ch := make(chan provider.StreamEvent, 64) + + opts := []ai.GenerateOption{ + ai.WithModelName(p.modelName), + ai.WithMessages(toGenkitMessages(messages)...), + } + + if len(tools) > 0 { + opts = append(opts, ai.WithReturnToolRequests(true)) + for _, t := range tools { + tool := gk.DefineTool(p.g, t.Name, t.Description, + func(ctx *ai.ToolContext, input map[string]any) (map[string]any, error) { + return nil, fmt.Errorf("tool %s should not be called via Genkit", t.Name) + }, + ) + opts = append(opts, ai.WithTools(tool)) + } + } + + go func() { + defer close(ch) + + stream := gk.GenerateStream(ctx, p.g, opts...) + for result, err := range stream { + if err != nil { + ch <- provider.StreamEvent{Type: "error", Error: err.Error()} + return + } + if result.Done { + // Extract final response for tool calls and usage + if result.Response != nil { + final := fromGenkitResponse(result.Response) + for _, tc := range final.ToolCalls { + ch <- provider.StreamEvent{Type: "tool_call", Tool: &tc} + } + if final.Usage.InputTokens > 0 || final.Usage.OutputTokens > 0 { + ch <- provider.StreamEvent{Type: "done", Usage: &final.Usage} + return + } + } + ch <- provider.StreamEvent{Type: "done"} + return + } + if result.Chunk != nil { + ev := fromGenkitChunk(result.Chunk) + if ev.Type != "" && (ev.Text != "" || ev.Tool != nil) { + ch <- ev + } + } + } + // Iterator exhausted without Done — send done anyway + ch <- provider.StreamEvent{Type: "done"} + }() + + return ch, nil +} +``` + +**Step 2:** Create `genkit/adapter_test.go` with mock model tests. Tests should verify Chat/Stream call paths and conversion. + +**Step 3:** Run tests: `go test ./genkit/ -v -count=1` + +**Step 4:** Commit: +```bash +git add genkit/adapter.go genkit/adapter_test.go +git commit -m "feat: implement Genkit provider adapter (provider.Provider interface)" +``` + +--- + +## Task 4: Implement provider factory functions + +**Files:** +- Create: `genkit/providers.go` +- Create: `genkit/providers_test.go` + +**Step 1:** Create `genkit/providers.go` with factory functions for each provider type. Each factory: +1. Initializes the appropriate Genkit plugin +2. Creates a `genkitProvider` with the correct model name format +3. Returns `provider.Provider` + +Factory functions needed (one per ProviderRegistry factory): +- `NewAnthropicProvider(apiKey, model, baseURL string, maxTokens int) provider.Provider` +- `NewOpenAIProvider(apiKey, model, baseURL string, maxTokens int) provider.Provider` +- `NewGoogleAIProvider(apiKey, model string, maxTokens int) (provider.Provider, error)` +- `NewOllamaProvider(model, baseURL string, maxTokens int) provider.Provider` +- `NewOpenAICompatibleProvider(name, apiKey, model, baseURL string, maxTokens int) provider.Provider` — for OpenRouter, Copilot, Cohere, HuggingFace +- `NewBedrockProvider(region, model, accessKeyID, secretAccessKey, sessionToken string, maxTokens int) provider.Provider` — via community plugin or OpenAI-compatible +- `NewVertexAIProvider(projectID, region, model, credentialsJSON string, maxTokens int) (provider.Provider, error)` +- `NewAzureOpenAIProvider(resource, deploymentName, apiVersion, apiKey string, maxTokens int) provider.Provider` — via community plugin or OpenAI-compatible +- `NewAnthropicFoundryProvider(resource, model, apiKey, entraToken string, maxTokens int) provider.Provider` + +Each factory calls `Instance(ctx)` to get the shared Genkit instance, registers the plugin if not already registered, and returns a `genkitProvider`. + +**Step 2:** Create `genkit/providers_test.go` — test factory instantiation (not live API calls). + +**Step 3:** Build: `go build ./...` + +**Step 4:** Commit: +```bash +git add genkit/providers.go genkit/providers_test.go +git commit -m "feat: add Genkit provider factory functions for all provider types" +``` + +--- + +## Task 5: Update ProviderRegistry to use Genkit factories + +**Files:** +- Modify: `orchestrator/provider_registry.go` + +**Step 1:** Update all factory functions in `provider_registry.go` to call `genkit.New*Provider()` instead of `provider.New*Provider()`: + +```go +import gkprov "github.com/GoCodeAlone/workflow-plugin-agent/genkit" + +func anthropicProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewAnthropicProvider(apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens), nil +} + +func openaiProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOpenAIProvider(apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens), nil +} +// ... repeat for all 13 factory functions +``` + +**Step 2:** Build: `go build ./...` + +**Step 3:** Run registry tests: `go test ./orchestrator/ -run TestProviderRegistry -v -count=1` + +**Step 4:** Run full suite: `go test ./... -count=1` + +**Step 5:** Commit: +```bash +git add orchestrator/provider_registry.go +git commit -m "feat: update ProviderRegistry to use Genkit-backed factories" +``` + +--- + +## Task 6: Delete old provider implementation files + +**Files:** +- Delete: All provider implementation files listed in design doc + +**Step 1:** Delete the old provider files (keep `provider.go`, `models.go`, `auth_modes.go`, and all `test_provider*.go` files): +```bash +cd /Users/jon/workspace/workflow-plugin-agent +# Delete implementation files +rm provider/anthropic.go provider/anthropic_bedrock.go provider/anthropic_foundry.go provider/anthropic_vertex.go +rm provider/anthropic_convert.go provider/anthropic_convert_test.go +rm provider/openai.go provider/openai_azure.go provider/openai_azure_test.go +rm provider/gemini.go provider/gemini_test.go +rm provider/ollama.go provider/ollama_convert.go provider/ollama_test.go provider/ollama_convert_test.go +rm provider/copilot.go provider/copilot_test.go provider/copilot_models.go provider/copilot_models_test.go +rm provider/openrouter.go provider/openrouter_test.go +rm provider/cohere.go +rm provider/huggingface.go provider/huggingface_test.go +rm provider/llama_cpp.go provider/llama_cpp_test.go provider/llama_cpp_download.go provider/llama_cpp_download_test.go +rm provider/local.go provider/local_test.go +rm provider/ssrf.go provider/models_ssrf_test.go +``` + +**Step 2:** Fix any compilation errors from dangling references. Key areas to check: +- `provider/models.go` — may reference deleted types +- `provider/auth_modes.go` — may reference deleted types +- `provider/thinking_field_test.go` — may import deleted code +- Any file in `orchestrator/` that directly imports deleted provider constructors + +**Step 3:** Build: `go build ./...` + +**Step 4:** Run tests: `go test ./... -count=1` + +**Step 5:** Commit: +```bash +git add -u # stages deletions +git commit -m "refactor: delete hand-rolled provider implementations (replaced by Genkit)" +``` + +--- + +## Task 7: Clean up dependencies and fix remaining issues + +**Files:** +- Modify: `go.mod` +- Modify: any files with compilation errors + +**Step 1:** Remove direct SDK dependencies that are now transitive via Genkit: +```bash +go mod tidy +``` + +Check if `github.com/anthropics/anthropic-sdk-go`, `github.com/openai/openai-go`, `github.com/google/generative-ai-go`, `github.com/ollama/ollama` are still needed directly. If only Genkit imports them, they'll move to `indirect`. + +**Step 2:** Fix any remaining compile errors or test failures. + +**Step 3:** Run full test suite with race detector: +```bash +go test -race ./... -count=1 +``` + +**Step 4:** Run linter: +```bash +golangci-lint run +``` + +**Step 5:** Commit: +```bash +git add go.mod go.sum +git commit -m "chore: clean up dependencies after Genkit migration" +``` + +--- + +## Task 8: Write integration tests and verify + +**Files:** +- Create: `genkit/integration_test.go` + +**Step 1:** Create integration tests that verify the full path: `ProviderRegistry → genkit factory → genkitProvider → Genkit → mock model`. These tests use in-memory DB + mock secrets like existing registry tests. + +**Step 2:** Verify existing executor tests pass (they use `provider.Provider` interface which is unchanged): +```bash +go test ./executor/ -v -count=1 +``` + +**Step 3:** Verify orchestrator tests: +```bash +go test ./orchestrator/ -v -count=1 +``` + +**Step 4:** Full regression: +```bash +go test -race ./... -count=1 +golangci-lint run +``` + +**Step 5:** Commit: +```bash +git add genkit/integration_test.go +git commit -m "test: add Genkit integration tests and verify full regression" +``` + +--- + +## Execution Order + +``` +Task 1 (deps + skeleton) → Task 2 (convert) → Task 3 (adapter) → Task 4 (factories) + ↓ +Task 5 (registry update) → Task 6 (delete old files) → Task 7 (cleanup) → Task 8 (tests) +``` + +All tasks are sequential — each builds on the previous. From c922430d0085c8ceb52c9d56c899daadb306ca35 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:20:00 -0400 Subject: [PATCH 03/14] chore: add Genkit Go SDK dependency and package skeleton Add github.com/firebase/genkit/go@v1.6.0 and its plugins (anthropic, googlegenai, ollama, compat_oai) as dependencies. Create genkit/ package with a lazily-initialized shared Genkit instance for test/mock scenarios. Co-Authored-By: Claude Opus 4.6 --- genkit/genkit.go | 24 ++++++++++++++++++++++++ go.mod | 11 ++++++++++- go.sum | 24 +++++++++++++++++++++--- 3 files changed, 55 insertions(+), 4 deletions(-) create mode 100644 genkit/genkit.go diff --git a/genkit/genkit.go b/genkit/genkit.go new file mode 100644 index 0000000..2f670dc --- /dev/null +++ b/genkit/genkit.go @@ -0,0 +1,24 @@ +// Package genkit provides Genkit-backed implementations of provider.Provider. +package genkit + +import ( + "context" + "sync" + + gk "github.com/firebase/genkit/go/genkit" +) + +var ( + instance *gk.Genkit + once sync.Once +) + +// Instance returns the shared Genkit instance, initializing it lazily on first call. +// This instance has no plugins and is used for mock/test model definitions. +// Production providers use per-factory Genkit instances initialized with their specific plugin. +func Instance(ctx context.Context) *gk.Genkit { + once.Do(func() { + instance = gk.Init(ctx) + }) + return instance +} diff --git a/go.mod b/go.mod index 0ca6d67..e10a23d 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/aws/aws-sdk-go-v2 v1.41.4 github.com/aws/aws-sdk-go-v2/credentials v1.19.12 github.com/docker/docker v28.5.2+incompatible + github.com/firebase/genkit/go v1.6.0 github.com/go-rod/rod v0.116.2 github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/generative-ai-go v0.20.1 @@ -119,8 +120,10 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-sql-driver/mysql v1.7.1 // indirect github.com/gobwas/glob v0.2.3 // indirect + github.com/goccy/go-yaml v1.17.1 // indirect github.com/golobby/cast v1.3.3 // indirect github.com/google/btree v1.1.3 // indirect + github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect @@ -145,6 +148,7 @@ require ( github.com/hashicorp/memberlist v0.5.4 // indirect github.com/hashicorp/vault/api v1.22.0 // indirect github.com/hashicorp/yamux v0.1.2 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/itchyny/gojq v0.12.18 // indirect github.com/itchyny/timefmt-go v0.1.7 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -161,9 +165,10 @@ require ( github.com/json-iterator/go v1.1.13-0.20220915233716-71ac16282d12 // indirect github.com/klauspost/compress v1.18.4 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect - github.com/mailru/easyjson v0.7.7 // indirect + github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a // indirect github.com/miekg/dns v1.1.72 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/go-testing-interface v1.14.1 // indirect @@ -211,6 +216,10 @@ require ( github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.2.0 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + github.com/xeipuuv/gojsonschema v1.2.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/ysmood/fetchup v0.2.3 // indirect github.com/ysmood/goob v0.4.0 // indirect github.com/ysmood/got v0.40.0 // indirect diff --git a/go.sum b/go.sum index 87fb36e..dcb76a0 100644 --- a/go.sum +++ b/go.sum @@ -256,6 +256,8 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/firebase/genkit/go v1.6.0 h1:W+oTo/yndFkSe/imGEmsYdW5Qi5w2USG7obiXElAQzY= +github.com/firebase/genkit/go v1.6.0/go.mod h1:vu8ZAqNU6MU5qDza66bvqTtzJoUrqhO/+z5/6dtouJQ= github.com/flowchartsman/retry v1.2.0 h1:qDhlw6RNufXz6RGr+IiYimFpMMkt77SUSHY5tgFaUCU= github.com/flowchartsman/retry v1.2.0/go.mod h1:+sfx8OgCCiAr3t5jh2Gk+T0fRTI+k52edaYxURQxY64= github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= @@ -321,6 +323,8 @@ github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPE github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY= +github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/gofrs/uuid v4.3.1+incompatible h1:0/KbAdpx3UXAx1kEOWHJeOkpbgRFGHVgv+CFIY7dBJI= github.com/gofrs/uuid v4.3.1+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -347,6 +351,8 @@ github.com/golobby/cast v1.3.3 h1:s2Lawb9RMz7YyYf8IrfMQY4IFmA1R/lgfmj97Vc6fig= github.com/golobby/cast v1.3.3/go.mod h1:0oDO5IT84HTXcbLDf1YXuk0xtg/cRDrxhbpWKxwtJCY= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 h1:okN800+zMJOGHLJCgry+OGzhhtH6YrjQh1rluHmOacE= +github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254/go.mod h1:k8cjJAQWc//ac/bMnzItyOFbfT01tgRTZGgxELCuxEQ= github.com/google/generative-ai-go v0.20.1 h1:6dEIujpgN2V0PgLhr6c/M1ynRdc7ARtiIDPFzj45uNQ= github.com/google/generative-ai-go v0.20.1/go.mod h1:TjOnZJmZKzarWbjUJgy+r3Ee7HGBRVLhOIgupnwR4Bg= github.com/google/gnostic-models v0.7.1 h1:SisTfuFKJSKM5CPZkffwi6coztzzeYUhc3v4yxLWH8c= @@ -435,6 +441,8 @@ github.com/hashicorp/vault/api v1.22.0 h1:+HYFquE35/B74fHoIeXlZIP2YADVboaPjaSicH github.com/hashicorp/vault/api v1.22.0/go.mod h1:IUZA2cDvr4Ok3+NtK2Oq/r+lJeXkeCrHRmqdyWfpmGM= github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/itchyny/gojq v0.12.18 h1:gFGHyt/MLbG9n6dqnvlliiya2TaMMh6FFaR2b1H6Drc= github.com/itchyny/gojq v0.12.18/go.mod h1:4hPoZ/3lN9fDL1D+aK7DY1f39XZpY9+1Xpjz8atrEkg= github.com/itchyny/timefmt-go v0.1.7 h1:xyftit9Tbw+Dc/huSSPJaEmX1TVL8lw5vxjJLK4GMMA= @@ -465,7 +473,6 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -498,8 +505,8 @@ github.com/lufia/plan9stats v0.0.0-20260216142805-b3301c5f2a88 h1:PTw+yKnXcOFCR6 github.com/lufia/plan9stats v0.0.0-20260216142805-b3301c5f2a88/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= -github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -509,6 +516,8 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a h1:v2cBA3xWKv2cIOVhnzX/gNgkNXqiHfUgJtA3r61Hf7A= +github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a/go.mod h1:Y6ghKH+ZijXn5d9E7qGGZBmjitx7iitZdQiIW97EpTU= github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 h1:KGuD/pM2JpL9FAYvBrnBBeENKZNh6eNtjqytV6TYjnk= @@ -715,8 +724,17 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= +github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/ysmood/fetchup v0.2.3 h1:ulX+SonA0Vma5zUFXtv52Kzip/xe7aj4vqT5AJwQ+ZQ= github.com/ysmood/fetchup v0.2.3/go.mod h1:xhibcRKziSvol0H1/pj33dnKrYyI2ebIvz5cOOkYGns= github.com/ysmood/goob v0.4.0 h1:HsxXhyLBeGzWXnqVKtmT9qM7EuVs/XOgkX7T6r1o1AQ= From f368e7f0e108ea9043ca71aac137822c4d4678d8 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:20:57 -0400 Subject: [PATCH 04/14] feat: add Genkit <-> provider type conversion layer Implement bidirectional conversion between provider.Message/Response/ StreamEvent types and Genkit ai.Message/ModelResponse/ModelResponseChunk. Includes thinking trace support mapping ai.PartReasoning to StreamEvent {Type:"thinking"} per the critical alignment requirement. Co-Authored-By: Claude Opus 4.6 --- genkit/convert.go | 138 +++++++++++++++++++++++++++++++++++++++++ genkit/convert_test.go | 114 ++++++++++++++++++++++++++++++++++ 2 files changed, 252 insertions(+) create mode 100644 genkit/convert.go create mode 100644 genkit/convert_test.go diff --git a/genkit/convert.go b/genkit/convert.go new file mode 100644 index 0000000..ed8c87c --- /dev/null +++ b/genkit/convert.go @@ -0,0 +1,138 @@ +package genkit + +import ( + "github.com/GoCodeAlone/workflow-plugin-agent/provider" + "github.com/firebase/genkit/go/ai" +) + +// toGenkitMessages converts our messages to Genkit messages. +func toGenkitMessages(msgs []provider.Message) []*ai.Message { + out := make([]*ai.Message, 0, len(msgs)) + for _, m := range msgs { + var role ai.Role + switch m.Role { + case provider.RoleSystem: + role = ai.RoleSystem + case provider.RoleUser: + role = ai.RoleUser + case provider.RoleAssistant: + role = ai.RoleModel + case provider.RoleTool: + role = ai.RoleTool + default: + role = ai.RoleUser + } + + var parts []*ai.Part + + // Tool call results: add as ToolResponsePart + if m.ToolCallID != "" { + parts = []*ai.Part{ai.NewToolResponsePart(&ai.ToolResponse{ + Name: m.ToolCallID, + Output: map[string]any{"result": m.Content}, + })} + } else if len(m.ToolCalls) > 0 { + // Assistant message with tool calls + for _, tc := range m.ToolCalls { + parts = append(parts, ai.NewToolRequestPart(&ai.ToolRequest{ + Name: tc.Name, + Input: tc.Arguments, + })) + } + } else { + parts = []*ai.Part{ai.NewTextPart(m.Content)} + } + + out = append(out, ai.NewMessage(role, nil, parts...)) + } + return out +} + +// fromGenkitResponse converts a Genkit response to our Response type. +func fromGenkitResponse(resp *ai.ModelResponse) *provider.Response { + if resp == nil { + return &provider.Response{} + } + + out := &provider.Response{ + Content: resp.Text(), + } + + // Extract thinking/reasoning content + if msg := resp.Message; msg != nil { + for _, part := range msg.Content { + if part.IsReasoning() { + out.Thinking = part.Text + break + } + } + } + + // Extract tool calls + if msg := resp.Message; msg != nil { + for _, part := range msg.Content { + if part.ToolRequest != nil { + tc := provider.ToolCall{ + ID: part.ToolRequest.Name, + Name: part.ToolRequest.Name, + Arguments: make(map[string]any), + } + if input, ok := part.ToolRequest.Input.(map[string]any); ok { + tc.Arguments = input + } + out.ToolCalls = append(out.ToolCalls, tc) + } + } + } + + // Extract usage + if resp.Usage != nil { + out.Usage = provider.Usage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + } + } + + return out +} + +// fromGenkitChunk converts a Genkit stream chunk to our StreamEvent. +func fromGenkitChunk(chunk *ai.ModelResponseChunk) provider.StreamEvent { + if chunk == nil { + return provider.StreamEvent{Type: "done"} + } + + // Check for thinking/reasoning parts first + for _, part := range chunk.Content { + if part.IsReasoning() { + return provider.StreamEvent{Type: "thinking", Thinking: part.Text} + } + } + + // Check for text + text := chunk.Text() + if text != "" { + return provider.StreamEvent{Type: "text", Text: text} + } + + // Check for tool calls in chunk + for _, part := range chunk.Content { + if part.ToolRequest != nil { + return provider.StreamEvent{ + Type: "tool_call", + Tool: &provider.ToolCall{ + ID: part.ToolRequest.Name, + Name: part.ToolRequest.Name, + Arguments: func() map[string]any { + if m, ok := part.ToolRequest.Input.(map[string]any); ok { + return m + } + return nil + }(), + }, + } + } + } + + return provider.StreamEvent{Type: "text", Text: ""} +} diff --git a/genkit/convert_test.go b/genkit/convert_test.go new file mode 100644 index 0000000..379bf8e --- /dev/null +++ b/genkit/convert_test.go @@ -0,0 +1,114 @@ +package genkit + +import ( + "testing" + + "github.com/GoCodeAlone/workflow-plugin-agent/provider" + "github.com/firebase/genkit/go/ai" +) + +func TestToGenkitMessages(t *testing.T) { + msgs := []provider.Message{ + {Role: provider.RoleSystem, Content: "You are helpful."}, + {Role: provider.RoleUser, Content: "Hello"}, + {Role: provider.RoleAssistant, Content: "Hi there"}, + } + result := toGenkitMessages(msgs) + if len(result) != 3 { + t.Fatalf("expected 3 messages, got %d", len(result)) + } + if result[0].Role != ai.RoleSystem { + t.Errorf("expected system role, got %s", result[0].Role) + } + if result[2].Role != ai.RoleModel { + t.Errorf("expected model role for assistant, got %s", result[2].Role) + } +} + +func TestToGenkitMessagesToolResult(t *testing.T) { + msgs := []provider.Message{ + {Role: provider.RoleTool, Content: "42", ToolCallID: "add"}, + } + result := toGenkitMessages(msgs) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d", len(result)) + } + if result[0].Role != ai.RoleTool { + t.Errorf("expected tool role, got %s", result[0].Role) + } + if len(result[0].Content) == 0 || result[0].Content[0].ToolResponse == nil { + t.Error("expected ToolResponsePart in message content") + } +} + +func TestFromGenkitResponse(t *testing.T) { + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("Hello world"), + Usage: &ai.GenerationUsage{InputTokens: 10, OutputTokens: 5}, + } + result := fromGenkitResponse(resp) + if result.Content != "Hello world" { + t.Errorf("expected 'Hello world', got %q", result.Content) + } + if result.Usage.InputTokens != 10 { + t.Errorf("expected 10 input tokens, got %d", result.Usage.InputTokens) + } + if result.Usage.OutputTokens != 5 { + t.Errorf("expected 5 output tokens, got %d", result.Usage.OutputTokens) + } +} + +func TestFromGenkitResponseNil(t *testing.T) { + result := fromGenkitResponse(nil) + if result.Content != "" { + t.Errorf("expected empty content, got %q", result.Content) + } +} + +func TestFromGenkitResponseThinking(t *testing.T) { + msg := ai.NewMessage(ai.RoleModel, nil, + ai.NewReasoningPart("I think therefore I am", nil), + ai.NewTextPart("Result text"), + ) + resp := &ai.ModelResponse{Message: msg} + result := fromGenkitResponse(resp) + if result.Thinking != "I think therefore I am" { + t.Errorf("expected thinking trace, got %q", result.Thinking) + } + if result.Content != "Result text" { + t.Errorf("expected 'Result text', got %q", result.Content) + } +} + +func TestFromGenkitChunkText(t *testing.T) { + chunk := &ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewTextPart("hello")}, + } + ev := fromGenkitChunk(chunk) + if ev.Type != "text" { + t.Errorf("expected type 'text', got %q", ev.Type) + } + if ev.Text != "hello" { + t.Errorf("expected 'hello', got %q", ev.Text) + } +} + +func TestFromGenkitChunkThinking(t *testing.T) { + chunk := &ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewReasoningPart("thinking...", nil)}, + } + ev := fromGenkitChunk(chunk) + if ev.Type != "thinking" { + t.Errorf("expected type 'thinking', got %q", ev.Type) + } + if ev.Thinking != "thinking..." { + t.Errorf("expected 'thinking...', got %q", ev.Thinking) + } +} + +func TestFromGenkitChunkNil(t *testing.T) { + ev := fromGenkitChunk(nil) + if ev.Type != "done" { + t.Errorf("expected type 'done', got %q", ev.Type) + } +} From 33599ec4ba9797d5a126344b1b3b32e7c366695c Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:23:36 -0400 Subject: [PATCH 05/14] feat: implement Genkit provider adapter (provider.Provider interface) Add genkitProvider struct implementing provider.Provider via Genkit's Generate and GenerateStream APIs. Tools are registered lazily per unique name to prevent duplicate registration panics. Streaming extracts thinking/tool calls from the final response. Co-Authored-By: Claude Opus 4.6 --- genkit/adapter.go | 133 ++++++++++++++++++++++++++++++++ genkit/adapter_test.go | 167 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 300 insertions(+) create mode 100644 genkit/adapter.go create mode 100644 genkit/adapter_test.go diff --git a/genkit/adapter.go b/genkit/adapter.go new file mode 100644 index 0000000..3a40158 --- /dev/null +++ b/genkit/adapter.go @@ -0,0 +1,133 @@ +package genkit + +import ( + "context" + "fmt" + "sync" + + "github.com/GoCodeAlone/workflow-plugin-agent/provider" + "github.com/firebase/genkit/go/ai" + gk "github.com/firebase/genkit/go/genkit" +) + +// genkitProvider adapts a Genkit model to provider.Provider. +type genkitProvider struct { + g *gk.Genkit + modelName string // "provider/model" format e.g. "anthropic/claude-sonnet-4-6" + name string + authInfo provider.AuthModeInfo + + mu sync.Mutex + definedTools map[string]bool // tracks which tool names are registered +} + +func (p *genkitProvider) Name() string { return p.name } +func (p *genkitProvider) AuthModeInfo() provider.AuthModeInfo { return p.authInfo } + +// resolveToolRefs ensures each tool is registered exactly once and returns +// their ToolRef representations for use with WithTools. +func (p *genkitProvider) resolveToolRefs(tools []provider.ToolDef) []ai.ToolRef { + if len(tools) == 0 { + return nil + } + p.mu.Lock() + defer p.mu.Unlock() + + if p.definedTools == nil { + p.definedTools = make(map[string]bool) + } + + refs := make([]ai.ToolRef, 0, len(tools)) + for _, t := range tools { + if !p.definedTools[t.Name] { + tool := gk.DefineTool(p.g, t.Name, t.Description, + func(ctx *ai.ToolContext, input map[string]any) (map[string]any, error) { + // Tools are executed by the executor, not Genkit. + return nil, fmt.Errorf("tool %s should not be called via Genkit", t.Name) + }, + ) + refs = append(refs, tool) + p.definedTools[t.Name] = true + } else { + refs = append(refs, ai.ToolName(t.Name)) + } + } + return refs +} + +// Chat sends a non-streaming request and returns the complete response. +func (p *genkitProvider) Chat(ctx context.Context, messages []provider.Message, tools []provider.ToolDef) (*provider.Response, error) { + opts := []ai.GenerateOption{ + ai.WithModelName(p.modelName), + ai.WithMessages(toGenkitMessages(messages)...), + } + + if len(tools) > 0 { + opts = append(opts, ai.WithReturnToolRequests(true)) + for _, ref := range p.resolveToolRefs(tools) { + opts = append(opts, ai.WithTools(ref)) + } + } + + resp, err := gk.Generate(ctx, p.g, opts...) + if err != nil { + return nil, fmt.Errorf("genkit generate: %w", err) + } + + return fromGenkitResponse(resp), nil +} + +// Stream sends a streaming request. Events are delivered on the returned channel. +func (p *genkitProvider) Stream(ctx context.Context, messages []provider.Message, tools []provider.ToolDef) (<-chan provider.StreamEvent, error) { + ch := make(chan provider.StreamEvent, 64) + + opts := []ai.GenerateOption{ + ai.WithModelName(p.modelName), + ai.WithMessages(toGenkitMessages(messages)...), + } + + if len(tools) > 0 { + opts = append(opts, ai.WithReturnToolRequests(true)) + for _, ref := range p.resolveToolRefs(tools) { + opts = append(opts, ai.WithTools(ref)) + } + } + + go func() { + defer close(ch) + + stream := gk.GenerateStream(ctx, p.g, opts...) + for result, err := range stream { + if err != nil { + ch <- provider.StreamEvent{Type: "error", Error: err.Error()} + return + } + if result.Done { + // Extract final response for tool calls and usage + if result.Response != nil { + final := fromGenkitResponse(result.Response) + for i := range final.ToolCalls { + tc := final.ToolCalls[i] + ch <- provider.StreamEvent{Type: "tool_call", Tool: &tc} + } + if final.Usage.InputTokens > 0 || final.Usage.OutputTokens > 0 { + ch <- provider.StreamEvent{Type: "done", Usage: &final.Usage} + return + } + } + ch <- provider.StreamEvent{Type: "done"} + return + } + if result.Chunk != nil { + ev := fromGenkitChunk(result.Chunk) + if ev.Text != "" || ev.Thinking != "" || ev.Tool != nil { + ch <- ev + } + } + } + // Iterator exhausted without Done — send done anyway + ch <- provider.StreamEvent{Type: "done"} + }() + + return ch, nil +} diff --git a/genkit/adapter_test.go b/genkit/adapter_test.go new file mode 100644 index 0000000..a2ded24 --- /dev/null +++ b/genkit/adapter_test.go @@ -0,0 +1,167 @@ +package genkit + +import ( + "context" + "testing" + + "github.com/GoCodeAlone/workflow-plugin-agent/provider" + "github.com/firebase/genkit/go/ai" + gk "github.com/firebase/genkit/go/genkit" +) + +// mockModel defines a canned in-memory model on a Genkit instance for testing. +func mockModel(g *gk.Genkit, name string, resp *ai.ModelResponse) { + opts := &ai.ModelOptions{ + Supports: &ai.ModelSupports{Tools: true, SystemRole: true, Multiturn: true}, + } + gk.DefineModel(g, name, opts, func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + if cb != nil { + // Stream text chunks + if resp.Message != nil { + for _, part := range resp.Message.Content { + if !part.IsReasoning() { + _ = cb(ctx, &ai.ModelResponseChunk{ + Content: []*ai.Part{part}, + }) + } + } + } + } + return resp, nil + }) +} + +func newTestProvider(t *testing.T, modelName string, resp *ai.ModelResponse) *genkitProvider { + t.Helper() + g := gk.Init(context.Background()) + mockModel(g, modelName, resp) + return &genkitProvider{ + g: g, + modelName: modelName, + name: "mock", + } +} + +func TestGenkitProviderChat(t *testing.T) { + const modelName = "mock/test-model" + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("Hello from mock"), + Usage: &ai.GenerationUsage{InputTokens: 5, OutputTokens: 3}, + } + + p := newTestProvider(t, modelName, resp) + + msgs := []provider.Message{{Role: provider.RoleUser, Content: "Hi"}} + got, err := p.Chat(context.Background(), msgs, nil) + if err != nil { + t.Fatalf("Chat returned error: %v", err) + } + if got.Content != "Hello from mock" { + t.Errorf("expected 'Hello from mock', got %q", got.Content) + } + if got.Usage.InputTokens != 5 { + t.Errorf("expected 5 input tokens, got %d", got.Usage.InputTokens) + } +} + +func TestGenkitProviderName(t *testing.T) { + p := &genkitProvider{name: "test-provider"} + if p.Name() != "test-provider" { + t.Errorf("expected 'test-provider', got %q", p.Name()) + } +} + +func TestGenkitProviderStream(t *testing.T) { + const modelName = "mock/stream-model" + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("streamed"), + Usage: &ai.GenerationUsage{InputTokens: 2, OutputTokens: 1}, + } + + p := newTestProvider(t, modelName, resp) + + msgs := []provider.Message{{Role: provider.RoleUser, Content: "Hi"}} + ch, err := p.Stream(context.Background(), msgs, nil) + if err != nil { + t.Fatalf("Stream returned error: %v", err) + } + + var events []provider.StreamEvent + for ev := range ch { + events = append(events, ev) + } + + if len(events) == 0 { + t.Fatal("expected at least one event") + } + last := events[len(events)-1] + if last.Type != "done" { + t.Errorf("expected last event type 'done', got %q", last.Type) + } +} + +func TestGenkitProviderChatWithTools(t *testing.T) { + const modelName = "mock/tool-model" + + // Mock model that returns a tool request + resp := &ai.ModelResponse{ + Message: ai.NewMessage(ai.RoleModel, nil, + ai.NewToolRequestPart(&ai.ToolRequest{ + Name: "calculator", + Input: map[string]any{"a": 1, "b": 2}, + }), + ), + } + + p := newTestProvider(t, modelName, resp) + + msgs := []provider.Message{{Role: provider.RoleUser, Content: "Add 1+2"}} + tools := []provider.ToolDef{{ + Name: "calculator", + Description: "Adds numbers", + Parameters: map[string]any{"type": "object"}, + }} + + got, err := p.Chat(context.Background(), msgs, tools) + if err != nil { + t.Fatalf("Chat returned error: %v", err) + } + + if len(got.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(got.ToolCalls)) + } + if got.ToolCalls[0].Name != "calculator" { + t.Errorf("expected tool 'calculator', got %q", got.ToolCalls[0].Name) + } +} + +func TestGenkitProviderChatToolDeduplication(t *testing.T) { + const modelName = "mock/dedup-model" + + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("ok"), + } + + p := newTestProvider(t, modelName, resp) + + tools := []provider.ToolDef{{ + Name: "my-tool", + Description: "A tool", + Parameters: map[string]any{"type": "object"}, + }} + + msgs := []provider.Message{{Role: provider.RoleUser, Content: "call 1"}} + + // First call - defines the tool + _, err := p.Chat(context.Background(), msgs, tools) + if err != nil { + t.Fatalf("first Chat error: %v", err) + } + + // Second call - must not panic due to duplicate tool registration + msgs2 := []provider.Message{{Role: provider.RoleUser, Content: "call 2"}} + _, err = p.Chat(context.Background(), msgs2, tools) + if err != nil { + t.Fatalf("second Chat error: %v", err) + } +} From affcbc51b18c38621fce8ea539a80e93591b7d38 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:30:34 -0400 Subject: [PATCH 06/14] feat: add Genkit provider factory functions for all provider types Implement factory functions for Anthropic, OpenAI, Google AI, Ollama, OpenAI-compatible (OpenRouter/Copilot/Cohere/llama.cpp), Azure OpenAI, Anthropic Foundry, Vertex AI, and AWS Bedrock. Each factory creates a Genkit instance with the appropriate plugin. Co-Authored-By: Claude Opus 4.6 --- genkit/providers.go | 241 +++++++++++++++++++++++++++++++++++++++ genkit/providers_test.go | 92 +++++++++++++++ go.mod | 2 + go.sum | 2 + 4 files changed, 337 insertions(+) create mode 100644 genkit/providers.go create mode 100644 genkit/providers_test.go diff --git a/genkit/providers.go b/genkit/providers.go new file mode 100644 index 0000000..2272293 --- /dev/null +++ b/genkit/providers.go @@ -0,0 +1,241 @@ +package genkit + +import ( + "context" + "fmt" + + "github.com/GoCodeAlone/workflow-plugin-agent/provider" + gk "github.com/firebase/genkit/go/genkit" + anthropicPlugin "github.com/firebase/genkit/go/plugins/anthropic" + "github.com/firebase/genkit/go/plugins/compat_oai" + openaiPlugin "github.com/firebase/genkit/go/plugins/compat_oai/openai" + "github.com/firebase/genkit/go/plugins/googlegenai" + ollamaPlugin "github.com/firebase/genkit/go/plugins/ollama" + "github.com/openai/openai-go/option" +) + + +// initGenkitWithPlugin creates a Genkit instance with a single plugin registered. +func initGenkitWithPlugin(ctx context.Context, plugin gk.GenkitOption) *gk.Genkit { + return gk.Init(ctx, plugin) +} + +// NewAnthropicProvider creates a provider backed by Genkit's Anthropic plugin. +func NewAnthropicProvider(ctx context.Context, apiKey, model, baseURL string, maxTokens int) (provider.Provider, error) { + if apiKey == "" { + return nil, fmt.Errorf("anthropic: APIKey is required") + } + p := &anthropicPlugin.Anthropic{APIKey: apiKey, BaseURL: baseURL} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "anthropic/" + model, + name: "anthropic", + authInfo: provider.AuthModeInfo{ + Mode: "api_key", + DisplayName: "Anthropic", + ServerSafe: true, + }, + }, nil +} + +// NewOpenAIProvider creates a provider backed by Genkit's OpenAI plugin. +func NewOpenAIProvider(ctx context.Context, apiKey, model, baseURL string, maxTokens int) (provider.Provider, error) { + if apiKey == "" { + return nil, fmt.Errorf("openai: APIKey is required") + } + var extraOpts []option.RequestOption + if baseURL != "" { + extraOpts = append(extraOpts, option.WithBaseURL(baseURL)) + } + p := &openaiPlugin.OpenAI{APIKey: apiKey, Opts: extraOpts} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "openai/" + model, + name: "openai", + authInfo: provider.AuthModeInfo{ + Mode: "api_key", + DisplayName: "OpenAI", + ServerSafe: true, + }, + }, nil +} + +// NewGoogleAIProvider creates a provider backed by Genkit's Google AI plugin (Gemini API). +func NewGoogleAIProvider(ctx context.Context, apiKey, model string, maxTokens int) (provider.Provider, error) { + if apiKey == "" { + return nil, fmt.Errorf("googleai: APIKey is required") + } + p := &googlegenai.GoogleAI{APIKey: apiKey} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "googleai/" + model, + name: "googleai", + authInfo: provider.AuthModeInfo{ + Mode: "api_key", + DisplayName: "Google AI (Gemini)", + ServerSafe: true, + }, + }, nil +} + +// NewOllamaProvider creates a provider backed by Genkit's Ollama plugin. +func NewOllamaProvider(ctx context.Context, model, serverAddress string, maxTokens int) (provider.Provider, error) { + if serverAddress == "" { + serverAddress = "http://localhost:11434" + } + p := &ollamaPlugin.Ollama{ServerAddress: serverAddress} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "ollama/" + model, + name: "ollama", + authInfo: provider.AuthModeInfo{ + Mode: "none", + DisplayName: "Ollama (local)", + ServerSafe: true, + }, + }, nil +} + +// NewOpenAICompatibleProvider creates a provider for OpenAI-compatible endpoints. +// Used for OpenRouter, Copilot, Cohere, HuggingFace, llama.cpp, etc. +func NewOpenAICompatibleProvider(ctx context.Context, providerName, apiKey, model, baseURL string, maxTokens int) (provider.Provider, error) { + effectiveKey := apiKey + if effectiveKey == "" { + // Use a placeholder to avoid errors when no key is needed (e.g., local endpoints) + effectiveKey = "no-key" + } + p := &compat_oai.OpenAICompatible{ + Provider: providerName, + APIKey: effectiveKey, + BaseURL: baseURL, + } + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: providerName + "/" + model, + name: providerName, + authInfo: provider.AuthModeInfo{ + Mode: "api_key", + DisplayName: providerName, + ServerSafe: true, + }, + }, nil +} + +// NewAzureOpenAIProvider creates a provider for Azure OpenAI Service. +func NewAzureOpenAIProvider(ctx context.Context, resource, deploymentName, apiVersion, apiKey, entraToken string, maxTokens int) (provider.Provider, error) { + if resource == "" { + return nil, fmt.Errorf("openai_azure: resource is required") + } + if apiKey == "" && entraToken == "" { + return nil, fmt.Errorf("openai_azure: apiKey or entraToken is required") + } + if apiVersion == "" { + apiVersion = "2024-10-21" + } + + baseURL := fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s", resource, deploymentName) + + opts := []option.RequestOption{ + option.WithBaseURL(baseURL), + option.WithHeaderDel("authorization"), + option.WithQuery("api-version", apiVersion), + } + if apiKey != "" { + opts = append(opts, option.WithHeader("api-key", apiKey)) + } else { + opts = append(opts, option.WithHeader("authorization", "Bearer "+entraToken)) + } + + // Use a placeholder API key to avoid requiring OPENAI_API_KEY env var + p := &openaiPlugin.OpenAI{APIKey: "azure-placeholder", Opts: opts} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "openai/" + deploymentName, + name: "openai_azure", + authInfo: provider.AuthModeInfo{ + Mode: "azure", + DisplayName: "OpenAI (Azure OpenAI Service)", + ServerSafe: true, + }, + }, nil +} + +// NewAnthropicFoundryProvider creates a provider for Anthropic on Azure AI Foundry. +func NewAnthropicFoundryProvider(ctx context.Context, resource, model, apiKey, entraToken string, maxTokens int) (provider.Provider, error) { + if resource == "" { + return nil, fmt.Errorf("anthropic_foundry: resource is required") + } + effectiveKey := apiKey + if effectiveKey == "" { + effectiveKey = entraToken + } + if effectiveKey == "" { + return nil, fmt.Errorf("anthropic_foundry: apiKey or entraToken is required") + } + + // Azure AI Foundry Anthropic endpoints + baseURL := fmt.Sprintf("https://%s.services.ai.azure.com/models", resource) + p := &anthropicPlugin.Anthropic{APIKey: effectiveKey, BaseURL: baseURL} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "anthropic/" + model, + name: "anthropic_foundry", + authInfo: provider.AuthModeInfo{ + Mode: "azure", + DisplayName: "Anthropic (Azure AI Foundry)", + ServerSafe: true, + }, + }, nil +} + +// NewVertexAIProvider creates a provider backed by Genkit's Vertex AI plugin. +func NewVertexAIProvider(ctx context.Context, projectID, region, model, credentialsJSON string, maxTokens int) (provider.Provider, error) { + if projectID == "" { + return nil, fmt.Errorf("vertexai: projectID is required") + } + if region == "" { + region = "us-central1" + } + p := &googlegenai.VertexAI{ + ProjectID: projectID, + Location: region, + } + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "vertexai/" + model, + name: "vertexai", + authInfo: provider.AuthModeInfo{ + Mode: "gcp", + DisplayName: "Vertex AI", + ServerSafe: true, + }, + }, nil +} + +// NewBedrockProvider creates a provider for AWS Bedrock using an OpenAI-compatible endpoint. +// Wraps existing Bedrock implementation as a provider.Provider until a native Genkit plugin is available. +func NewBedrockProvider(ctx context.Context, region, model, accessKeyID, secretAccessKey, sessionToken, baseURL string, maxTokens int) (provider.Provider, error) { + if secretAccessKey == "" { + return nil, fmt.Errorf("anthropic_bedrock: secretAccessKey is required") + } + if region == "" { + region = "us-east-1" + } + return provider.NewAnthropicBedrockProvider(provider.AnthropicBedrockConfig{ + Region: region, + Model: model, + MaxTokens: maxTokens, + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + SessionToken: sessionToken, + BaseURL: baseURL, + }) +} diff --git a/genkit/providers_test.go b/genkit/providers_test.go new file mode 100644 index 0000000..42008f4 --- /dev/null +++ b/genkit/providers_test.go @@ -0,0 +1,92 @@ +package genkit + +import ( + "context" + "testing" + + "github.com/GoCodeAlone/workflow-plugin-agent/provider" +) + +func TestNewAnthropicProvider_MissingKey(t *testing.T) { + _, err := NewAnthropicProvider(context.Background(), "", "claude-sonnet-4-6", "", 4096) + if err == nil { + t.Error("expected error for missing API key") + } +} + +func TestNewOpenAIProvider_MissingKey(t *testing.T) { + _, err := NewOpenAIProvider(context.Background(), "", "gpt-4o", "", 4096) + if err == nil { + t.Error("expected error for missing API key") + } +} + +func TestNewGoogleAIProvider_MissingKey(t *testing.T) { + _, err := NewGoogleAIProvider(context.Background(), "", "gemini-2.0-flash", 4096) + if err == nil { + t.Error("expected error for missing API key") + } +} + +func TestNewOllamaProvider_DefaultAddress(t *testing.T) { + // Ollama doesn't require an API key; verify factory instantiation works + // with default address (no real server needed for creation) + p, err := NewOllamaProvider(context.Background(), "qwen3:8b", "", 4096) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if p == nil { + t.Fatal("expected non-nil provider") + } + if p.Name() != "ollama" { + t.Errorf("expected name 'ollama', got %q", p.Name()) + } +} + +func TestNewOpenAICompatibleProvider_NoKey(t *testing.T) { + // Local providers may not need a key + p, err := NewOpenAICompatibleProvider(context.Background(), "local", "", "model", "http://localhost:8080", 4096) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if p == nil { + t.Fatal("expected non-nil provider") + } +} + +func TestNewAzureOpenAIProvider_MissingResource(t *testing.T) { + _, err := NewAzureOpenAIProvider(context.Background(), "", "gpt-4o", "2024-10-21", "key", "", 4096) + if err == nil { + t.Error("expected error for missing resource") + } +} + +func TestNewAzureOpenAIProvider_MissingCredentials(t *testing.T) { + _, err := NewAzureOpenAIProvider(context.Background(), "myresource", "gpt-4o", "2024-10-21", "", "", 4096) + if err == nil { + t.Error("expected error for missing credentials") + } +} + +func TestNewBedrockProvider_MissingKey(t *testing.T) { + _, err := NewBedrockProvider(context.Background(), "us-east-1", "anthropic.claude-sonnet-4", "", "", "", "", 4096) + if err == nil { + t.Error("expected error for missing secret key") + } +} + +func TestNewVertexAIProvider_MissingProject(t *testing.T) { + _, err := NewVertexAIProvider(context.Background(), "", "us-central1", "gemini-2.0-flash", "", 4096) + if err == nil { + t.Error("expected error for missing project ID") + } +} + +func TestProviderImplementsInterface(t *testing.T) { + // Ensure all returned providers implement provider.Provider + p, err := NewOllamaProvider(context.Background(), "test", "http://localhost:11434", 4096) + if err != nil { + t.Skip("factory failed, skipping interface check") + } + var _ provider.Provider = p +} diff --git a/go.mod b/go.mod index e10a23d..088ad56 100644 --- a/go.mod +++ b/go.mod @@ -124,6 +124,7 @@ require ( github.com/golobby/cast v1.3.3 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect @@ -250,6 +251,7 @@ require ( golang.org/x/sys v0.41.0 // indirect golang.org/x/time v0.15.0 // indirect golang.org/x/tools v0.42.0 // indirect + google.golang.org/genai v1.41.0 // indirect google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect diff --git a/go.sum b/go.sum index dcb76a0..f6f5552 100644 --- a/go.sum +++ b/go.sum @@ -914,6 +914,8 @@ gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E google.golang.org/api v0.271.0 h1:cIPN4qcUc61jlh7oXu6pwOQqbJW2GqYh5PS6rB2C/JY= google.golang.org/api v0.271.0/go.mod h1:CGT29bhwkbF+i11qkRUJb2KMKqcJ1hdFceEIRd9u64Q= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genai v1.41.0 h1:ayXl75LjTmqTu0y94yr96d17gIb4zF8gWVzX2TgioEY= +google.golang.org/genai v1.41.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 h1:VQZ/yAbAtjkHgH80teYd2em3xtIkkHd7ZhqfH2N9CsM= google.golang.org/genproto v0.0.0-20260128011058-8636f8732409/go.mod h1:rxKD3IEILWEu3P44seeNOAwZN4SaoKaQ/2eTg4mM6EM= google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 h1:tu/dtnW1o3wfaxCOjSLn5IRX4YDcJrtlpzYkhHhGaC4= From 718eef5c3768ae2021e52bff0d7e4bd61829f9ac Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:32:12 -0400 Subject: [PATCH 07/14] feat: update ProviderRegistry to use Genkit-backed factories Replace all provider.New*Provider() calls in ProviderRegistry factory functions with gkprov.New*Provider() calls that create Genkit-backed providers. OpenAI-compatible factories (openrouter, copilot, cohere, copilot_models, llama_cpp) route through NewOpenAICompatibleProvider. Co-Authored-By: Claude Opus 4.6 --- orchestrator/provider_registry.go | 119 +++++++++--------------------- 1 file changed, 36 insertions(+), 83 deletions(-) diff --git a/orchestrator/provider_registry.go b/orchestrator/provider_registry.go index bc44c9e..5e2b100 100644 --- a/orchestrator/provider_registry.go +++ b/orchestrator/provider_registry.go @@ -8,6 +8,7 @@ import ( "sync" "time" + gkprov "github.com/GoCodeAlone/workflow-plugin-agent/genkit" "github.com/GoCodeAlone/workflow-plugin-agent/provider" "github.com/GoCodeAlone/workflow/secrets" ) @@ -243,80 +244,56 @@ func mockProviderFactory(_ string, _ LLMProviderConfig) (provider.Provider, erro } func anthropicProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewAnthropicProvider(provider.AnthropicConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + return gkprov.NewAnthropicProvider(context.Background(), apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } func openaiProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewOpenAIProvider(provider.OpenAIConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + return gkprov.NewOpenAIProvider(context.Background(), apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } func openrouterProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewOpenRouterProvider(provider.OpenRouterConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://openrouter.ai/api/v1" + } + return gkprov.NewOpenAICompatibleProvider(context.Background(), "openrouter", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } func copilotProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewCopilotProvider(provider.CopilotConfig{ - Token: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://api.githubcopilot.com" + } + return gkprov.NewOpenAICompatibleProvider(context.Background(), "copilot", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } func cohereProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewCohereProvider(provider.CohereConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://api.cohere.ai/v1" + } + return gkprov.NewOpenAICompatibleProvider(context.Background(), "cohere", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } func copilotModelsProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewCopilotModelsProvider(provider.CopilotModelsConfig{ - Token: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://models.inference.ai.azure.com" + } + return gkprov.NewOpenAICompatibleProvider(context.Background(), "copilot_models", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } func openaiAzureProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { s := cfg.settings() - return provider.NewOpenAIAzureProvider(provider.OpenAIAzureConfig{ - Resource: s["resource"], - DeploymentName: s["deployment_name"], - APIVersion: s["api_version"], - APIKey: apiKey, - EntraToken: s["entra_token"], - MaxTokens: cfg.MaxTokens, - }) + return gkprov.NewAzureOpenAIProvider(context.Background(), + s["resource"], s["deployment_name"], s["api_version"], + apiKey, s["entra_token"], cfg.MaxTokens) } func anthropicFoundryProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { s := cfg.settings() - return provider.NewAnthropicFoundryProvider(provider.AnthropicFoundryConfig{ - Resource: s["resource"], - Model: cfg.Model, - MaxTokens: cfg.MaxTokens, - APIKey: apiKey, - EntraToken: s["entra_token"], - }) + return gkprov.NewAnthropicFoundryProvider(context.Background(), + s["resource"], cfg.Model, apiKey, s["entra_token"], cfg.MaxTokens) } func anthropicVertexProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { @@ -325,49 +302,25 @@ func anthropicVertexProviderFactory(apiKey string, cfg LLMProviderConfig) (provi if credJSON == "" { credJSON = apiKey // fallback: secret may contain the full GCP credentials JSON } - return provider.NewAnthropicVertexProvider(provider.AnthropicVertexConfig{ - ProjectID: s["project_id"], - Region: s["region"], - Model: cfg.Model, - MaxTokens: cfg.MaxTokens, - CredentialsJSON: credJSON, - }) + return gkprov.NewVertexAIProvider(context.Background(), + s["project_id"], s["region"], cfg.Model, credJSON, cfg.MaxTokens) } func geminiProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewGeminiProvider(provider.GeminiConfig{ - APIKey: apiKey, - Model: cfg.Model, - MaxTokens: cfg.MaxTokens, - }) + return gkprov.NewGoogleAIProvider(context.Background(), apiKey, cfg.Model, cfg.MaxTokens) } func ollamaProviderFactory(_ string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewOllamaProvider(provider.OllamaConfig{ - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + return gkprov.NewOllamaProvider(context.Background(), cfg.Model, cfg.BaseURL, cfg.MaxTokens) } func llamaCppProviderFactory(_ string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewLlamaCppProvider(provider.LlamaCppConfig{ - BaseURL: cfg.BaseURL, - ModelPath: cfg.Model, - ModelName: cfg.Model, - MaxTokens: cfg.MaxTokens, - }), nil + // llama.cpp serves an OpenAI-compatible API + return gkprov.NewOpenAICompatibleProvider(context.Background(), "llama_cpp", "", cfg.Model, cfg.BaseURL, cfg.MaxTokens) } func anthropicBedrockProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { s := cfg.settings() - return provider.NewAnthropicBedrockProvider(provider.AnthropicBedrockConfig{ - Region: s["region"], - Model: cfg.Model, - MaxTokens: cfg.MaxTokens, - AccessKeyID: s["access_key_id"], - SecretAccessKey: apiKey, - SessionToken: s["session_token"], - BaseURL: cfg.BaseURL, - }) + return gkprov.NewBedrockProvider(context.Background(), + s["region"], cfg.Model, s["access_key_id"], apiKey, s["session_token"], cfg.BaseURL, cfg.MaxTokens) } From 6b3a70f8986a92e52e49126b9b96f7a664f64fbf Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:42:37 -0400 Subject: [PATCH 08/14] feat: replace old provider impls with Genkit adapters (Task 6) Delete 26 hand-rolled provider files and replace their usage in provider_registry.go, module_provider.go, and step_model_pull.go with Genkit factory calls. Preserve anthropic_bedrock.go (still used by genkit/providers.go) alongside new anthropic_bedrock_convert.go. Add OllamaClient for Pull/ListModels utility access in step_model_pull. Move models.go constants/types/helpers from deleted files. Co-Authored-By: Claude Sonnet 4.6 --- module_provider.go | 58 +-- provider/anthropic.go | 150 ------ ...onvert.go => anthropic_bedrock_convert.go} | 2 + provider/anthropic_convert_test.go | 110 ----- provider/anthropic_foundry.go | 344 -------------- provider/anthropic_foundry_test.go | 311 ------------ provider/anthropic_vertex.go | 188 -------- provider/anthropic_vertex_test.go | 271 ----------- provider/cohere.go | 446 ------------------ provider/copilot.go | 178 ------- provider/copilot_models.go | 55 --- provider/copilot_models_test.go | 139 ------ provider/copilot_test.go | 297 ------------ provider/gemini.go | 325 ------------- provider/gemini_test.go | 110 ----- provider/llama_cpp.go | 231 --------- provider/llama_cpp_test.go | 170 ------- provider/models.go | 64 ++- provider/ollama.go | 163 ------- provider/ollama_client.go | 71 +++ provider/ollama_convert.go | 96 ---- provider/ollama_convert_test.go | 233 --------- provider/ollama_test.go | 217 --------- provider/openai.go | 97 ---- provider/openai_azure.go | 132 ------ provider/openai_azure_test.go | 218 --------- provider/openai_convert.go | 196 -------- provider/openrouter.go | 44 -- provider/openrouter_test.go | 139 ------ provider_registry.go | 51 +- step_model_pull.go | 9 +- 31 files changed, 181 insertions(+), 4934 deletions(-) delete mode 100644 provider/anthropic.go rename provider/{anthropic_convert.go => anthropic_bedrock_convert.go} (99%) delete mode 100644 provider/anthropic_convert_test.go delete mode 100644 provider/anthropic_foundry.go delete mode 100644 provider/anthropic_foundry_test.go delete mode 100644 provider/anthropic_vertex.go delete mode 100644 provider/anthropic_vertex_test.go delete mode 100644 provider/cohere.go delete mode 100644 provider/copilot.go delete mode 100644 provider/copilot_models.go delete mode 100644 provider/copilot_models_test.go delete mode 100644 provider/copilot_test.go delete mode 100644 provider/gemini.go delete mode 100644 provider/gemini_test.go delete mode 100644 provider/llama_cpp.go delete mode 100644 provider/llama_cpp_test.go delete mode 100644 provider/ollama.go create mode 100644 provider/ollama_client.go delete mode 100644 provider/ollama_convert.go delete mode 100644 provider/ollama_convert_test.go delete mode 100644 provider/ollama_test.go delete mode 100644 provider/openai.go delete mode 100644 provider/openai_azure.go delete mode 100644 provider/openai_azure_test.go delete mode 100644 provider/openai_convert.go delete mode 100644 provider/openrouter.go delete mode 100644 provider/openrouter_test.go diff --git a/module_provider.go b/module_provider.go index 9a78270..9f6becf 100644 --- a/module_provider.go +++ b/module_provider.go @@ -7,6 +7,7 @@ import ( "time" "github.com/GoCodeAlone/modular" + gkprov "github.com/GoCodeAlone/workflow-plugin-agent/genkit" "github.com/GoCodeAlone/workflow-plugin-agent/provider" "github.com/GoCodeAlone/workflow/plugin" ) @@ -183,43 +184,42 @@ func newProviderModuleFactory() plugin.ModuleFactory { } case "anthropic": - p = provider.NewAnthropicProvider(provider.AnthropicConfig{ - APIKey: apiKey, - Model: model, - BaseURL: baseURL, - MaxTokens: maxTokens, - }) + if prov, err := gkprov.NewAnthropicProvider(context.Background(), apiKey, model, baseURL, maxTokens); err != nil { + p = &errProvider{err: err} + } else { + p = prov + } case "openai": - p = provider.NewOpenAIProvider(provider.OpenAIConfig{ - APIKey: apiKey, - Model: model, - BaseURL: baseURL, - MaxTokens: maxTokens, - }) + if prov, err := gkprov.NewOpenAIProvider(context.Background(), apiKey, model, baseURL, maxTokens); err != nil { + p = &errProvider{err: err} + } else { + p = prov + } case "copilot": - p = provider.NewCopilotProvider(provider.CopilotConfig{ - Token: apiKey, - Model: model, - BaseURL: baseURL, - MaxTokens: maxTokens, - }) + if baseURL == "" { + baseURL = "https://api.githubcopilot.com" + } + if prov, err := gkprov.NewOpenAICompatibleProvider(context.Background(), "copilot", apiKey, model, baseURL, maxTokens); err != nil { + p = &errProvider{err: err} + } else { + p = prov + } case "ollama": - p = provider.NewOllamaProvider(provider.OllamaConfig{ - Model: model, - BaseURL: baseURL, - MaxTokens: maxTokens, - }) + if prov, err := gkprov.NewOllamaProvider(context.Background(), model, baseURL, maxTokens); err != nil { + p = &errProvider{err: err} + } else { + p = prov + } case "llama_cpp": - p = provider.NewLlamaCppProvider(provider.LlamaCppConfig{ - BaseURL: baseURL, - ModelPath: model, - ModelName: model, - MaxTokens: maxTokens, - }) + if prov, err := gkprov.NewOpenAICompatibleProvider(context.Background(), "llama_cpp", "", model, baseURL, maxTokens); err != nil { + p = &errProvider{err: err} + } else { + p = prov + } default: p = &errProvider{err: fmt.Errorf("agent.provider %q: unrecognized provider type %q (supported: mock, test, anthropic, openai, copilot, ollama, llama_cpp)", name, providerType)} diff --git a/provider/anthropic.go b/provider/anthropic.go deleted file mode 100644 index ae8cd5d..0000000 --- a/provider/anthropic.go +++ /dev/null @@ -1,150 +0,0 @@ -package provider - -import ( - "context" - "fmt" - "net/http" - - anthropic "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/option" -) - -const ( - defaultAnthropicBaseURL = "https://api.anthropic.com" - defaultAnthropicModel = "claude-sonnet-4-20250514" - defaultAnthropicMaxTokens = 4096 - anthropicAPIVersion = "2023-06-01" -) - -// AnthropicConfig holds configuration for the Anthropic provider. -type AnthropicConfig struct { - APIKey string - Model string - BaseURL string - MaxTokens int - HTTPClient *http.Client -} - -// AnthropicProvider implements Provider using the Anthropic Messages API. -type AnthropicProvider struct { - client anthropic.Client - config AnthropicConfig -} - -// NewAnthropicProvider creates a new Anthropic provider with the given config. -func NewAnthropicProvider(cfg AnthropicConfig) *AnthropicProvider { - if cfg.Model == "" { - cfg.Model = defaultAnthropicModel - } - if cfg.BaseURL == "" { - cfg.BaseURL = defaultAnthropicBaseURL - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultAnthropicMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - client := anthropic.NewClient( - option.WithAPIKey(cfg.APIKey), - option.WithBaseURL(cfg.BaseURL), - option.WithHTTPClient(cfg.HTTPClient), - ) - return &AnthropicProvider{client: client, config: cfg} -} - -func (p *AnthropicProvider) Name() string { return "anthropic" } - -func (p *AnthropicProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "direct", - DisplayName: "Anthropic (Direct API)", - Description: "Direct access to Anthropic's Claude models via API key.", - DocsURL: "https://platform.claude.com/docs/en/api/getting-started", - ServerSafe: true, - } -} - -func (p *AnthropicProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - params := toAnthropicParams(p.config.Model, p.config.MaxTokens, messages, tools) - msg, err := p.client.Messages.New(ctx, params) - if err != nil { - return nil, fmt.Errorf("anthropic: %w", err) - } - return fromAnthropicMessage(msg) -} - -func (p *AnthropicProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - params := toAnthropicParams(p.config.Model, p.config.MaxTokens, messages, tools) - stream := p.client.Messages.NewStreaming(ctx, params) - if stream.Err() != nil { - return nil, fmt.Errorf("anthropic: %w", stream.Err()) - } - ch := make(chan StreamEvent, 16) - go streamAnthropicEvents(stream, ch) - return ch, nil -} - -// JSON types used by anthropic_foundry.go (hand-rolled HTTP, no SDK support). - -type anthropicRequest struct { - Model string `json:"model"` - MaxTokens int `json:"max_tokens"` - System string `json:"system,omitempty"` - Messages []anthropicMessage `json:"messages"` - Tools []anthropicTool `json:"tools,omitempty"` - Stream bool `json:"stream,omitempty"` -} - -type anthropicMessage struct { - Role string `json:"role"` - Content any `json:"content"` -} - -type anthropicContent struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input map[string]any `json:"input,omitempty"` - ToolUseID string `json:"tool_use_id,omitempty"` - Content string `json:"content,omitempty"` - IsError bool `json:"is_error,omitempty"` - CacheCtrl *cacheControl `json:"cache_control,omitempty"` -} - -type cacheControl struct { - Type string `json:"type"` -} - -type anthropicTool struct { - Name string `json:"name"` - Description string `json:"description"` - InputSchema map[string]any `json:"input_schema"` -} - -type anthropicResponse struct { - ID string `json:"id"` - Type string `json:"type"` - Content []anthropicRespItem `json:"content"` - Usage anthropicUsage `json:"usage"` - Error *anthropicError `json:"error,omitempty"` -} - -type anthropicRespItem struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input map[string]any `json:"input,omitempty"` -} - -type anthropicUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` -} - -type anthropicError struct { - Type string `json:"type"` - Message string `json:"message"` -} diff --git a/provider/anthropic_convert.go b/provider/anthropic_bedrock_convert.go similarity index 99% rename from provider/anthropic_convert.go rename to provider/anthropic_bedrock_convert.go index bd5450f..429a54a 100644 --- a/provider/anthropic_convert.go +++ b/provider/anthropic_bedrock_convert.go @@ -10,6 +10,8 @@ import ( "github.com/anthropics/anthropic-sdk-go/packages/ssestream" ) +const defaultAnthropicMaxTokens = 4096 + // toAnthropicParams converts provider types to SDK MessageNewParams. func toAnthropicParams(model string, maxTokens int, messages []Message, tools []ToolDef) anthropic.MessageNewParams { params := anthropic.MessageNewParams{ diff --git a/provider/anthropic_convert_test.go b/provider/anthropic_convert_test.go deleted file mode 100644 index ae395c5..0000000 --- a/provider/anthropic_convert_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package provider - -import ( - "encoding/json" - "testing" - - anthropic "github.com/anthropics/anthropic-sdk-go" -) - -// TestToAnthropicParams_AssistantToolCalls verifies that assistant messages with -// tool calls are correctly serialized as tool_use content blocks (multi-turn tool use). -func TestToAnthropicParams_AssistantToolCalls(t *testing.T) { - messages := []Message{ - {Role: RoleUser, Content: "What's the weather?"}, - { - Role: RoleAssistant, - Content: "", - ToolCalls: []ToolCall{ - {ID: "call_1", Name: "get_weather", Arguments: map[string]any{"location": "NYC"}}, - }, - }, - {Role: RoleTool, ToolCallID: "call_1", Content: "72°F, sunny"}, - {Role: RoleAssistant, Content: "It's 72°F and sunny in NYC."}, - } - - params := toAnthropicParams("claude-sonnet-4-20250514", 1024, messages, nil) - - if len(params.Messages) != 4 { - t.Fatalf("expected 4 messages, got %d", len(params.Messages)) - } - - // Message[1] should be an assistant message with a tool_use block. - assistantMsg := params.Messages[1] - if assistantMsg.Role != anthropic.MessageParamRoleAssistant { - t.Fatalf("expected assistant role at index 1, got %q", assistantMsg.Role) - } - blocks := assistantMsg.Content - if len(blocks) != 1 { - t.Fatalf("expected 1 content block in assistant message, got %d", len(blocks)) - } - toolBlock := blocks[0] - if toolBlock.OfToolUse == nil { - t.Fatal("expected tool_use block in assistant message") - } - if toolBlock.OfToolUse.ID != "call_1" { - t.Errorf("tool_use ID = %q, want %q", toolBlock.OfToolUse.ID, "call_1") - } - if toolBlock.OfToolUse.Name != "get_weather" { - t.Errorf("tool_use Name = %q, want %q", toolBlock.OfToolUse.Name, "get_weather") - } - inputBytes, ok := toolBlock.OfToolUse.Input.([]byte) - if !ok { - t.Fatalf("expected []byte input, got %T", toolBlock.OfToolUse.Input) - } - var args map[string]any - if err := json.Unmarshal(inputBytes, &args); err != nil { - t.Fatalf("unmarshal tool_use input: %v", err) - } - if args["location"] != "NYC" { - t.Errorf("tool_use input location = %v, want NYC", args["location"]) - } - - // Message[2] should be a user message with a tool_result block. - userMsg := params.Messages[2] - if userMsg.Role != anthropic.MessageParamRoleUser { - t.Fatalf("expected user role at index 2, got %q", userMsg.Role) - } -} - -// TestToAnthropicParams_AssistantTextAndToolCalls verifies mixed text + tool_use blocks. -func TestToAnthropicParams_AssistantTextAndToolCalls(t *testing.T) { - messages := []Message{ - { - Role: RoleAssistant, - Content: "Let me check that.", - ToolCalls: []ToolCall{ - {ID: "call_2", Name: "search", Arguments: map[string]any{"q": "Go generics"}}, - }, - }, - } - - params := toAnthropicParams("claude-sonnet-4-20250514", 1024, messages, nil) - if len(params.Messages) != 1 { - t.Fatalf("expected 1 message, got %d", len(params.Messages)) - } - blocks := params.Messages[0].Content - if len(blocks) != 2 { - t.Fatalf("expected 2 blocks (text + tool_use), got %d", len(blocks)) - } - if blocks[0].OfText == nil { - t.Error("expected first block to be text") - } - if blocks[1].OfToolUse == nil { - t.Error("expected second block to be tool_use") - } -} - -// TestFromAnthropicMessage_ToolCallUnmarshalError verifies that malformed tool -// call arguments produce an error instead of silent data loss. -func TestFromAnthropicMessage_ToolCallUnmarshalError(t *testing.T) { - msg := &anthropic.Message{ - Content: []anthropic.ContentBlockUnion{ - {Type: "tool_use", ID: "call_bad", Name: "broken_tool", Input: json.RawMessage(`{invalid json}`)}, - }, - } - _, err := fromAnthropicMessage(msg) - if err == nil { - t.Fatal("expected error for malformed tool call input, got nil") - } -} diff --git a/provider/anthropic_foundry.go b/provider/anthropic_foundry.go deleted file mode 100644 index 4ede8db..0000000 --- a/provider/anthropic_foundry.go +++ /dev/null @@ -1,344 +0,0 @@ -package provider - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" -) - -const defaultFoundryModel = "claude-sonnet-4-20250514" - -// AnthropicFoundryConfig configures the Anthropic provider for Microsoft Azure AI Foundry. -// Uses Azure API keys or Entra ID (formerly Azure AD) tokens. -type AnthropicFoundryConfig struct { - // Resource is the Azure AI Services resource name (forms the URL: {resource}.services.ai.azure.com). - Resource string - // Model is the model deployment name. - Model string - // MaxTokens limits the response length. - MaxTokens int - // APIKey is the Azure API key (use this OR Entra ID token, not both). - APIKey string - // EntraToken is a Microsoft Entra ID bearer token (optional, alternative to APIKey). - EntraToken string - // HTTPClient is the HTTP client to use (defaults to http.DefaultClient). - HTTPClient *http.Client -} - -// anthropicFoundryProvider accesses Anthropic models via Azure AI Foundry. -type anthropicFoundryProvider struct { - config AnthropicFoundryConfig - url string -} - -// NewAnthropicFoundryProvider creates a provider that accesses Claude via Azure AI Foundry. -// -// Docs: https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry -func NewAnthropicFoundryProvider(cfg AnthropicFoundryConfig) (*anthropicFoundryProvider, error) { - if cfg.Resource == "" { - return nil, fmt.Errorf("anthropic_foundry: resource name is required") - } - if cfg.APIKey == "" && cfg.EntraToken == "" { - return nil, fmt.Errorf("anthropic_foundry: either APIKey or EntraToken is required") - } - if cfg.Model == "" { - cfg.Model = defaultFoundryModel - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultAnthropicMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - return &anthropicFoundryProvider{ - config: cfg, - url: fmt.Sprintf("https://%s.services.ai.azure.com/anthropic/v1/messages", cfg.Resource), - }, nil -} - -func (p *anthropicFoundryProvider) Name() string { return "anthropic_foundry" } - -func (p *anthropicFoundryProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "foundry", - DisplayName: "Anthropic (Azure AI Foundry)", - Description: "Access Claude models via Microsoft Azure AI Foundry using Azure API keys or Microsoft Entra ID tokens.", - DocsURL: "https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry", - ServerSafe: true, - } -} - -func (p *anthropicFoundryProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - reqBody := p.buildRequest(messages, tools, false) - - data, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: create request: %w", err) - } - p.setHeaders(req) - - resp, err := p.config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: send request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("anthropic_foundry: API error (status %d): %s", resp.StatusCode, string(body)) - } - - var apiResp anthropicResponse - if err := json.Unmarshal(body, &apiResp); err != nil { - return nil, fmt.Errorf("anthropic_foundry: unmarshal response: %w", err) - } - - if apiResp.Error != nil { - return nil, fmt.Errorf("anthropic_foundry: %s: %s", apiResp.Error.Type, apiResp.Error.Message) - } - - return foundryParseResponse(&apiResp), nil -} - -func (p *anthropicFoundryProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - reqBody := p.buildRequest(messages, tools, true) - - data, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: create request: %w", err) - } - p.setHeaders(req) - - resp, err := p.config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: send request: %w", err) - } - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - _ = resp.Body.Close() - return nil, fmt.Errorf("anthropic_foundry: API error (status %d): %s", resp.StatusCode, string(body)) - } - - ch := make(chan StreamEvent, 16) - go foundryReadSSE(resp.Body, ch) - return ch, nil -} - -func (p *anthropicFoundryProvider) setHeaders(req *http.Request) { - req.Header.Set("Content-Type", "application/json") - req.Header.Set("anthropic-version", anthropicAPIVersion) - if p.config.EntraToken != "" { - req.Header.Set("Authorization", "Bearer "+p.config.EntraToken) - } else { - req.Header.Set("api-key", p.config.APIKey) - } -} - -func (p *anthropicFoundryProvider) buildRequest(messages []Message, tools []ToolDef, stream bool) *anthropicRequest { - req := &anthropicRequest{ - Model: p.config.Model, - MaxTokens: p.config.MaxTokens, - Stream: stream, - } - - var apiMessages []anthropicMessage - for _, msg := range messages { - if msg.Role == RoleSystem { - req.System = msg.Content - continue - } - if msg.Role == RoleTool { - apiMessages = append(apiMessages, anthropicMessage{ - Role: "user", - Content: []anthropicContent{ - { - Type: "tool_result", - ToolUseID: msg.ToolCallID, - Content: msg.Content, - }, - }, - }) - continue - } - apiMessages = append(apiMessages, anthropicMessage{ - Role: string(msg.Role), - Content: msg.Content, - }) - } - req.Messages = apiMessages - - for _, t := range tools { - schema := t.Parameters - if schema == nil { - schema = map[string]any{"type": "object", "properties": map[string]any{}} - } - req.Tools = append(req.Tools, anthropicTool{ - Name: t.Name, - Description: t.Description, - InputSchema: schema, - }) - } - - return req -} - -func foundryParseResponse(apiResp *anthropicResponse) *Response { - resp := &Response{ - Usage: Usage{ - InputTokens: apiResp.Usage.InputTokens, - OutputTokens: apiResp.Usage.OutputTokens, - }, - } - - var textParts []string - for _, item := range apiResp.Content { - switch item.Type { - case "text": - textParts = append(textParts, item.Text) - case "tool_use": - resp.ToolCalls = append(resp.ToolCalls, ToolCall{ - ID: item.ID, - Name: item.Name, - Arguments: item.Input, - }) - } - } - resp.Content = strings.Join(textParts, "") - - return resp -} - -func foundryReadSSE(body io.ReadCloser, ch chan<- StreamEvent) { - defer func() { _ = body.Close() }() - defer close(ch) - - scanner := bufio.NewScanner(body) - - var currentToolID, currentToolName string - var toolInputBuf bytes.Buffer - var usage *Usage - - for scanner.Scan() { - line := scanner.Text() - - if !strings.HasPrefix(line, "data: ") { - continue - } - - data := strings.TrimPrefix(line, "data: ") - if data == "[DONE]" { - break - } - - var event struct { - Type string `json:"type"` - Index int `json:"index"` - ContentBlock *struct { - Type string `json:"type"` - ID string `json:"id"` - Name string `json:"name"` - Text string `json:"text"` - } `json:"content_block"` - Delta *struct { - Type string `json:"type"` - Text string `json:"text"` - PartialJSON string `json:"partial_json"` - StopReason string `json:"stop_reason"` - } `json:"delta"` - Message *struct { - Usage anthropicUsage `json:"usage"` - } `json:"message"` - Usage *anthropicUsage `json:"usage"` - } - - if err := json.Unmarshal([]byte(data), &event); err != nil { - continue - } - - switch event.Type { - case "message_start": - if event.Message != nil { - usage = &Usage{ - InputTokens: event.Message.Usage.InputTokens, - OutputTokens: event.Message.Usage.OutputTokens, - } - } - - case "content_block_start": - if event.ContentBlock != nil && event.ContentBlock.Type == "tool_use" { - currentToolID = event.ContentBlock.ID - currentToolName = event.ContentBlock.Name - toolInputBuf.Reset() - } - - case "content_block_delta": - if event.Delta == nil { - continue - } - switch event.Delta.Type { - case "text_delta": - ch <- StreamEvent{Type: "text", Text: event.Delta.Text} - case "input_json_delta": - toolInputBuf.WriteString(event.Delta.PartialJSON) - } - - case "content_block_stop": - if currentToolID != "" { - var args map[string]any - if toolInputBuf.Len() > 0 { - if err := json.Unmarshal(toolInputBuf.Bytes(), &args); err != nil { - ch <- StreamEvent{Type: "error", Error: fmt.Sprintf("tool %s: malformed arguments: %v", currentToolName, err)} - } - } - ch <- StreamEvent{ - Type: "tool_call", - Tool: &ToolCall{ - ID: currentToolID, - Name: currentToolName, - Arguments: args, - }, - } - currentToolID = "" - currentToolName = "" - toolInputBuf.Reset() - } - - case "message_delta": - if event.Usage != nil && usage != nil { - usage.OutputTokens = event.Usage.OutputTokens - } - - case "message_stop": - ch <- StreamEvent{Type: "done", Usage: usage} - return - - case "error": - ch <- StreamEvent{Type: "error", Error: data} - return - } - } - - if err := scanner.Err(); err != nil { - ch <- StreamEvent{Type: "error", Error: fmt.Sprintf("foundry stream read: %v", err)} - } -} diff --git a/provider/anthropic_foundry_test.go b/provider/anthropic_foundry_test.go deleted file mode 100644 index 72a9de4..0000000 --- a/provider/anthropic_foundry_test.go +++ /dev/null @@ -1,311 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestNewAnthropicFoundryProvider_Validation(t *testing.T) { - tests := []struct { - name string - cfg AnthropicFoundryConfig - wantErr string - }{ - { - name: "missing resource", - cfg: AnthropicFoundryConfig{APIKey: "key"}, - wantErr: "resource name is required", - }, - { - name: "missing auth", - cfg: AnthropicFoundryConfig{Resource: "myresource"}, - wantErr: "either APIKey or EntraToken is required", - }, - { - name: "valid with api key", - cfg: AnthropicFoundryConfig{Resource: "myresource", APIKey: "key"}, - }, - { - name: "valid with entra token", - cfg: AnthropicFoundryConfig{Resource: "myresource", EntraToken: "token"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p, err := NewAnthropicFoundryProvider(tt.cfg) - if tt.wantErr != "" { - if err == nil { - t.Fatalf("expected error containing %q, got nil", tt.wantErr) - } - if got := err.Error(); !strings.Contains(got, tt.wantErr) { - t.Fatalf("error %q does not contain %q", got, tt.wantErr) - } - return - } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if p == nil { - t.Fatal("expected non-nil provider") - } - }) - } -} - -func TestAnthropicFoundryProvider_Defaults(t *testing.T) { - p, err := NewAnthropicFoundryProvider(AnthropicFoundryConfig{ - Resource: "myresource", - APIKey: "key", - }) - if err != nil { - t.Fatal(err) - } - - if p.config.Model != defaultFoundryModel { - t.Errorf("default model = %q, want %q", p.config.Model, defaultFoundryModel) - } - if p.config.MaxTokens != defaultAnthropicMaxTokens { - t.Errorf("default max tokens = %d, want %d", p.config.MaxTokens, defaultAnthropicMaxTokens) - } -} - -func TestAnthropicFoundryProvider_Name(t *testing.T) { - p := &anthropicFoundryProvider{} - if got := p.Name(); got != "anthropic_foundry" { - t.Errorf("Name() = %q, want %q", got, "anthropic_foundry") - } -} - -func TestAnthropicFoundryProvider_AuthModeInfo(t *testing.T) { - p := &anthropicFoundryProvider{} - info := p.AuthModeInfo() - if info.Mode != "foundry" { - t.Errorf("Mode = %q, want %q", info.Mode, "foundry") - } - if info.DisplayName != "Anthropic (Azure AI Foundry)" { - t.Errorf("DisplayName = %q", info.DisplayName) - } -} - -func TestAnthropicFoundryProvider_URLConstruction(t *testing.T) { - p, err := NewAnthropicFoundryProvider(AnthropicFoundryConfig{ - Resource: "my-ai-resource", - APIKey: "key", - }) - if err != nil { - t.Fatal(err) - } - - want := "https://my-ai-resource.services.ai.azure.com/anthropic/v1/messages" - if p.url != want { - t.Errorf("url = %q, want %q", p.url, want) - } -} - -func TestAnthropicFoundryProvider_APIKeyAuth(t *testing.T) { - var gotHeaders http.Header - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotHeaders = r.Header.Clone() - json.NewEncoder(w).Encode(anthropicResponse{ - ID: "msg_123", - Type: "message", - Content: []anthropicRespItem{ - {Type: "text", Text: "hello"}, - }, - Usage: anthropicUsage{InputTokens: 10, OutputTokens: 5}, - }) - })) - defer srv.Close() - - p := &anthropicFoundryProvider{ - config: AnthropicFoundryConfig{ - Resource: "test", - Model: "claude-sonnet-4-20250514", - MaxTokens: 1024, - APIKey: "my-azure-key", - HTTPClient: srv.Client(), - }, - url: srv.URL, - } - - _, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - if got := gotHeaders.Get("api-key"); got != "my-azure-key" { - t.Errorf("api-key header = %q, want %q", got, "my-azure-key") - } - if got := gotHeaders.Get("anthropic-version"); got != anthropicAPIVersion { - t.Errorf("anthropic-version header = %q, want %q", got, anthropicAPIVersion) - } - if got := gotHeaders.Get("Authorization"); got != "" { - t.Errorf("Authorization header should be empty with API key auth, got %q", got) - } -} - -func TestAnthropicFoundryProvider_EntraIDAuth(t *testing.T) { - var gotHeaders http.Header - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotHeaders = r.Header.Clone() - json.NewEncoder(w).Encode(anthropicResponse{ - ID: "msg_123", - Type: "message", - Content: []anthropicRespItem{{Type: "text", Text: "hello"}}, - Usage: anthropicUsage{InputTokens: 10, OutputTokens: 5}, - }) - })) - defer srv.Close() - - p := &anthropicFoundryProvider{ - config: AnthropicFoundryConfig{ - Resource: "test", - Model: "claude-sonnet-4-20250514", - MaxTokens: 1024, - EntraToken: "my-entra-token", - HTTPClient: srv.Client(), - }, - url: srv.URL, - } - - _, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - if got := gotHeaders.Get("Authorization"); got != "Bearer my-entra-token" { - t.Errorf("Authorization header = %q, want %q", got, "Bearer my-entra-token") - } - if got := gotHeaders.Get("api-key"); got != "" { - t.Errorf("api-key header should be empty with Entra auth, got %q", got) - } -} - -func TestAnthropicFoundryProvider_Chat(t *testing.T) { - var gotBody anthropicRequest - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewDecoder(r.Body).Decode(&gotBody) - json.NewEncoder(w).Encode(anthropicResponse{ - ID: "msg_123", - Type: "message", - Content: []anthropicRespItem{ - {Type: "text", Text: "Hello from Foundry!"}, - }, - Usage: anthropicUsage{InputTokens: 15, OutputTokens: 8}, - }) - })) - defer srv.Close() - - p := &anthropicFoundryProvider{ - config: AnthropicFoundryConfig{ - Resource: "test", - Model: "claude-sonnet-4-20250514", - MaxTokens: 1024, - APIKey: "key", - HTTPClient: srv.Client(), - }, - url: srv.URL, - } - - resp, err := p.Chat(t.Context(), []Message{ - {Role: RoleSystem, Content: "You are helpful."}, - {Role: RoleUser, Content: "Hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - if resp.Content != "Hello from Foundry!" { - t.Errorf("Content = %q, want %q", resp.Content, "Hello from Foundry!") - } - if resp.Usage.InputTokens != 15 { - t.Errorf("InputTokens = %d, want 15", resp.Usage.InputTokens) - } - if resp.Usage.OutputTokens != 8 { - t.Errorf("OutputTokens = %d, want 8", resp.Usage.OutputTokens) - } - - // Verify request body matches Anthropic Messages format - if gotBody.Model != "claude-sonnet-4-20250514" { - t.Errorf("request model = %q", gotBody.Model) - } - if gotBody.System != "You are helpful." { - t.Errorf("request system = %q", gotBody.System) - } - if len(gotBody.Messages) != 1 { - t.Fatalf("request messages len = %d, want 1", len(gotBody.Messages)) - } -} - -func TestAnthropicFoundryProvider_Stream(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher := w.(http.Flusher) - - events := []string{ - `{"type":"message_start","message":{"usage":{"input_tokens":20,"output_tokens":0}}}`, - `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`, - `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`, - `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" world"}}`, - `{"type":"content_block_stop","index":0}`, - `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}}`, - `{"type":"message_stop"}`, - } - - for _, e := range events { - fmt.Fprintf(w, "data: %s\n\n", e) - flusher.Flush() - } - })) - defer srv.Close() - - p := &anthropicFoundryProvider{ - config: AnthropicFoundryConfig{ - Resource: "test", - Model: "claude-sonnet-4-20250514", - MaxTokens: 1024, - APIKey: "key", - HTTPClient: srv.Client(), - }, - url: srv.URL, - } - - ch, err := p.Stream(t.Context(), []Message{ - {Role: RoleUser, Content: "Hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - var texts []string - var done bool - for ev := range ch { - switch ev.Type { - case "text": - texts = append(texts, ev.Text) - case "done": - done = true - if ev.Usage == nil { - t.Error("expected usage in done event") - } else if ev.Usage.OutputTokens != 10 { - t.Errorf("OutputTokens = %d, want 10", ev.Usage.OutputTokens) - } - } - } - - if !done { - t.Error("expected done event") - } - if got := strings.Join(texts, ""); got != "Hello world" { - t.Errorf("streamed text = %q, want %q", got, "Hello world") - } -} diff --git a/provider/anthropic_vertex.go b/provider/anthropic_vertex.go deleted file mode 100644 index fff35bf..0000000 --- a/provider/anthropic_vertex.go +++ /dev/null @@ -1,188 +0,0 @@ -package provider - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - - anthropic "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/option" - "github.com/anthropics/anthropic-sdk-go/vertex" - - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" -) - -const ( - defaultVertexModel = "claude-sonnet-4@20250514" - defaultVertexRegion = "us-east5" -) - -// AnthropicVertexConfig configures the Anthropic provider for Google Vertex AI. -// Uses GCP Application Default Credentials (ADC) or explicit OAuth2 tokens. -type AnthropicVertexConfig struct { - // ProjectID is the GCP project ID. - ProjectID string - // Region is the GCP region (e.g. "us-east5", "europe-west1"). - Region string - // Model is the Vertex model ID (e.g. "claude-sonnet-4@20250514"). - Model string - // MaxTokens limits the response length. - MaxTokens int - // CredentialsJSON is the GCP service account JSON (optional if using ADC). - CredentialsJSON string - // TokenSource provides OAuth2 tokens (optional, for testing or custom auth). - // If set, CredentialsJSON and ADC are ignored. - TokenSource oauth2.TokenSource - // HTTPClient is the HTTP client to use (defaults to http.DefaultClient). - HTTPClient *http.Client -} - -// anthropicVertexProvider accesses Anthropic models via Google Vertex AI. -type anthropicVertexProvider struct { - client anthropic.Client - config AnthropicVertexConfig -} - -// NewAnthropicVertexProvider creates a provider that accesses Claude via Google Vertex AI. -// -// Docs: https://platform.claude.com/docs/en/build-with-claude/claude-on-vertex-ai -func NewAnthropicVertexProvider(cfg AnthropicVertexConfig) (*anthropicVertexProvider, error) { - if cfg.ProjectID == "" { - return nil, fmt.Errorf("anthropic_vertex: project ID is required") - } - if cfg.Region == "" { - cfg.Region = defaultVertexRegion - } - if cfg.Model == "" { - cfg.Model = defaultVertexModel - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultAnthropicMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - - var client anthropic.Client - - baseURL := fmt.Sprintf("https://%s-aiplatform.googleapis.com/", cfg.Region) - - if cfg.TokenSource != nil { - // Testing / custom-auth path: use middleware for token injection and path rewriting. - client = anthropic.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(cfg.HTTPClient), - option.WithMiddleware(vertexPathRewriteMiddleware(cfg.Region, cfg.ProjectID)), - option.WithMiddleware(vertexBearerTokenMiddleware(cfg.TokenSource)), - ) - } else if cfg.CredentialsJSON != "" { - creds, err := google.CredentialsFromJSON(context.Background(), []byte(cfg.CredentialsJSON), - "https://www.googleapis.com/auth/cloud-platform") - if err != nil { - return nil, fmt.Errorf("anthropic_vertex: parse credentials JSON: %w", err) - } - client = anthropic.NewClient( - vertex.WithCredentials(context.Background(), cfg.Region, cfg.ProjectID, creds), - ) - } else { - client = anthropic.NewClient( - vertex.WithGoogleAuth(context.Background(), cfg.Region, cfg.ProjectID, - "https://www.googleapis.com/auth/cloud-platform"), - ) - } - - return &anthropicVertexProvider{client: client, config: cfg}, nil -} - -func (p *anthropicVertexProvider) Name() string { return "anthropic_vertex" } - -func (p *anthropicVertexProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "vertex", - DisplayName: "Anthropic (Google Vertex AI)", - Description: "Access Claude models via Google Cloud Vertex AI using Application Default Credentials (ADC) or service account JSON.", - DocsURL: "https://platform.claude.com/docs/en/build-with-claude/claude-on-vertex-ai", - ServerSafe: true, - } -} - -func (p *anthropicVertexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - params := toAnthropicParams(p.config.Model, p.config.MaxTokens, messages, tools) - msg, err := p.client.Messages.New(ctx, params) - if err != nil { - return nil, fmt.Errorf("anthropic_vertex: %w", err) - } - return fromAnthropicMessage(msg) -} - -func (p *anthropicVertexProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - params := toAnthropicParams(p.config.Model, p.config.MaxTokens, messages, tools) - stream := p.client.Messages.NewStreaming(ctx, params) - if stream.Err() != nil { - return nil, fmt.Errorf("anthropic_vertex: %w", stream.Err()) - } - ch := make(chan StreamEvent, 16) - go streamAnthropicEvents(stream, ch) - return ch, nil -} - -// vertexPathRewriteMiddleware rewrites /v1/messages to the Vertex AI model path. -func vertexPathRewriteMiddleware(region, projectID string) option.Middleware { - return func(r *http.Request, next option.MiddlewareNext) (*http.Response, error) { - if r.Body == nil || r.URL.Path != "/v1/messages" || r.Method != http.MethodPost { - return next(r) - } - body, err := io.ReadAll(r.Body) - if err != nil { - return nil, err - } - r.Body.Close() - - var params struct { - Model string `json:"model"` - Stream bool `json:"stream"` - } - _ = json.Unmarshal(body, ¶ms) - - // Remove model and stream from body (Vertex uses URL path instead) - var m map[string]any - _ = json.Unmarshal(body, &m) - delete(m, "model") - delete(m, "stream") - body, _ = json.Marshal(m) - - specifier := "rawPredict" - if params.Stream { - specifier = "streamRawPredict" - } - r.URL.Path = fmt.Sprintf("/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", - projectID, region, params.Model, specifier) - - reader := bytes.NewReader(body) - r.Body = io.NopCloser(reader) - r.GetBody = func() (io.ReadCloser, error) { - _, _ = reader.Seek(0, 0) - return io.NopCloser(reader), nil - } - r.ContentLength = int64(len(body)) - - return next(r) - } -} - -// vertexBearerTokenMiddleware adds a GCP Bearer token to each request. -func vertexBearerTokenMiddleware(ts oauth2.TokenSource) option.Middleware { - return func(r *http.Request, next option.MiddlewareNext) (*http.Response, error) { - token, err := ts.Token() - if err != nil { - return nil, fmt.Errorf("anthropic_vertex: get token: %w", err) - } - r.Header.Set("Authorization", "Bearer "+token.AccessToken) - r.Header.Set("anthropic-version", anthropicAPIVersion) - return next(r) - } -} diff --git a/provider/anthropic_vertex_test.go b/provider/anthropic_vertex_test.go deleted file mode 100644 index 196ca96..0000000 --- a/provider/anthropic_vertex_test.go +++ /dev/null @@ -1,271 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "golang.org/x/oauth2" -) - -// mockTokenSource returns a fixed token for testing. -type mockTokenSource struct { - token string -} - -func (m *mockTokenSource) Token() (*oauth2.Token, error) { - return &oauth2.Token{AccessToken: m.token}, nil -} - -func TestNewAnthropicVertexProvider_Validation(t *testing.T) { - tests := []struct { - name string - cfg AnthropicVertexConfig - wantErr string - }{ - { - name: "missing project ID", - cfg: AnthropicVertexConfig{TokenSource: &mockTokenSource{token: "tok"}}, - wantErr: "project ID is required", - }, - { - name: "valid with token source", - cfg: AnthropicVertexConfig{ - ProjectID: "my-project", - TokenSource: &mockTokenSource{token: "tok"}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p, err := NewAnthropicVertexProvider(tt.cfg) - if tt.wantErr != "" { - if err == nil { - t.Fatalf("expected error containing %q, got nil", tt.wantErr) - } - if got := err.Error(); !strings.Contains(got, tt.wantErr) { - t.Fatalf("error %q does not contain %q", got, tt.wantErr) - } - return - } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if p == nil { - t.Fatal("expected non-nil provider") - } - }) - } -} - -func TestAnthropicVertexProvider_Defaults(t *testing.T) { - p, err := NewAnthropicVertexProvider(AnthropicVertexConfig{ - ProjectID: "my-project", - TokenSource: &mockTokenSource{token: "tok"}, - }) - if err != nil { - t.Fatal(err) - } - - if p.config.Model != defaultVertexModel { - t.Errorf("default model = %q, want %q", p.config.Model, defaultVertexModel) - } - if p.config.Region != defaultVertexRegion { - t.Errorf("default region = %q, want %q", p.config.Region, defaultVertexRegion) - } - if p.config.MaxTokens != defaultAnthropicMaxTokens { - t.Errorf("default max tokens = %d, want %d", p.config.MaxTokens, defaultAnthropicMaxTokens) - } -} - -func TestAnthropicVertexProvider_Name(t *testing.T) { - p := &anthropicVertexProvider{} - if got := p.Name(); got != "anthropic_vertex" { - t.Errorf("Name() = %q, want %q", got, "anthropic_vertex") - } -} - -func TestAnthropicVertexProvider_AuthModeInfo(t *testing.T) { - p := &anthropicVertexProvider{} - info := p.AuthModeInfo() - if info.Mode != "vertex" { - t.Errorf("Mode = %q, want %q", info.Mode, "vertex") - } - if info.DisplayName != "Anthropic (Google Vertex AI)" { - t.Errorf("DisplayName = %q", info.DisplayName) - } -} - -func TestAnthropicVertexProvider_BearerAuth(t *testing.T) { - var gotHeaders http.Header - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotHeaders = r.Header.Clone() - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "id": "msg_123", - "type": "message", - "content": []any{map[string]any{"type": "text", "text": "hello"}}, - "usage": map[string]any{"input_tokens": 10, "output_tokens": 5}, - }) - })) - defer srv.Close() - - p, err := NewAnthropicVertexProvider(AnthropicVertexConfig{ - ProjectID: "proj", - Region: "us-east5", - Model: "claude-sonnet-4@20250514", - MaxTokens: 1024, - TokenSource: &mockTokenSource{token: "my-gcp-token"}, - HTTPClient: &http.Client{Transport: &urlRewriteTransport{target: srv.URL}}, - }) - if err != nil { - t.Fatal(err) - } - - _, err = p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - if got := gotHeaders.Get("Authorization"); got != "Bearer my-gcp-token" { - t.Errorf("Authorization header = %q, want %q", got, "Bearer my-gcp-token") - } - if got := gotHeaders.Get("anthropic-version"); got != anthropicAPIVersion { - t.Errorf("anthropic-version header = %q, want %q", got, anthropicAPIVersion) - } - // Vertex should NOT set x-api-key - if got := gotHeaders.Get("x-api-key"); got != "" { - t.Errorf("x-api-key should be empty, got %q", got) - } -} - -// urlRewriteTransport rewrites all request URLs to target for testing. -type urlRewriteTransport struct { - target string - transport http.RoundTripper -} - -func (t *urlRewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - req.URL.Scheme = "http" - req.URL.Host = t.target[len("http://"):] - if t.transport != nil { - return t.transport.RoundTrip(req) - } - return http.DefaultTransport.RoundTrip(req) -} - -func TestAnthropicVertexProvider_Chat(t *testing.T) { - var gotBody map[string]any - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewDecoder(r.Body).Decode(&gotBody) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "id": "msg_123", - "type": "message", - "content": []any{ - map[string]any{"type": "text", "text": "Hello from Vertex!"}, - }, - "usage": map[string]any{"input_tokens": 15, "output_tokens": 8}, - }) - })) - defer srv.Close() - - p, err := NewAnthropicVertexProvider(AnthropicVertexConfig{ - ProjectID: "proj", - Region: "us-east5", - Model: "claude-sonnet-4@20250514", - MaxTokens: 1024, - TokenSource: &mockTokenSource{token: "tok"}, - HTTPClient: &http.Client{Transport: &urlRewriteTransport{target: srv.URL}}, - }) - if err != nil { - t.Fatal(err) - } - - resp, err := p.Chat(t.Context(), []Message{ - {Role: RoleSystem, Content: "You are helpful."}, - {Role: RoleUser, Content: "Hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - if resp.Content != "Hello from Vertex!" { - t.Errorf("Content = %q, want %q", resp.Content, "Hello from Vertex!") - } - if resp.Usage.InputTokens != 15 { - t.Errorf("InputTokens = %d, want 15", resp.Usage.InputTokens) - } -} - -func TestAnthropicVertexProvider_Stream(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher := w.(http.Flusher) - - events := []struct{ typ, data string }{ - {"message_start", `{"type":"message_start","message":{"usage":{"input_tokens":20,"output_tokens":0}}}`}, - {"content_block_start", `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`}, - {"content_block_delta", `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`}, - {"content_block_delta", `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" Vertex"}}`}, - {"content_block_stop", `{"type":"content_block_stop","index":0}`}, - {"message_delta", `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}}`}, - {"message_stop", `{"type":"message_stop"}`}, - } - - for _, e := range events { - fmt.Fprintf(w, "event: %s\ndata: %s\n\n", e.typ, e.data) - flusher.Flush() - } - })) - defer srv.Close() - - p, err := NewAnthropicVertexProvider(AnthropicVertexConfig{ - ProjectID: "proj", - Region: "us-east5", - Model: "claude-sonnet-4@20250514", - MaxTokens: 1024, - TokenSource: &mockTokenSource{token: "tok"}, - HTTPClient: &http.Client{Transport: &urlRewriteTransport{target: srv.URL}}, - }) - if err != nil { - t.Fatal(err) - } - - ch, err := p.Stream(t.Context(), []Message{ - {Role: RoleUser, Content: "Hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - var texts []string - var done bool - for ev := range ch { - switch ev.Type { - case "text": - texts = append(texts, ev.Text) - case "done": - done = true - if ev.Usage == nil { - t.Error("expected usage in done event") - } else if ev.Usage.OutputTokens != 10 { - t.Errorf("OutputTokens = %d, want 10", ev.Usage.OutputTokens) - } - } - } - - if !done { - t.Error("expected done event") - } - if got := strings.Join(texts, ""); got != "Hello Vertex" { - t.Errorf("streamed text = %q, want %q", got, "Hello Vertex") - } -} diff --git a/provider/cohere.go b/provider/cohere.go deleted file mode 100644 index e079af2..0000000 --- a/provider/cohere.go +++ /dev/null @@ -1,446 +0,0 @@ -package provider - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" -) - -const ( - defaultCohereBaseURL = "https://api.cohere.com" - defaultCohereModel = "command-r-plus" - defaultCohereMaxTokens = 4096 -) - -// CohereConfig holds configuration for the Cohere provider. -type CohereConfig struct { - APIKey string - Model string - BaseURL string - MaxTokens int - HTTPClient *http.Client -} - -// CohereProvider implements Provider using the Cohere Chat API v2. -type CohereProvider struct { - config CohereConfig -} - -// NewCohereProvider creates a new Cohere provider with the given config. -func NewCohereProvider(cfg CohereConfig) *CohereProvider { - if cfg.Model == "" { - cfg.Model = defaultCohereModel - } - if cfg.BaseURL == "" { - cfg.BaseURL = defaultCohereBaseURL - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultCohereMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - return &CohereProvider{config: cfg} -} - -func (p *CohereProvider) Name() string { return "cohere" } - -func (p *CohereProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "direct", - DisplayName: "Cohere (Direct API)", - Description: "Direct access to Cohere's Command models via API key.", - DocsURL: "https://docs.cohere.com/reference/chat", - ServerSafe: true, - } -} - -// Cohere Chat API v2 request types - -type cohereRequest struct { - Model string `json:"model"` - Messages []cohereMessage `json:"messages"` - Tools []cohereTool `json:"tools,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Stream bool `json:"stream,omitempty"` -} - -type cohereMessage struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` - ToolCalls []cohereToolCall `json:"tool_calls,omitempty"` -} - -type cohereToolCall struct { - ID string `json:"id"` - Type string `json:"type"` - Function cohereToolFunc `json:"function"` -} - -type cohereToolFunc struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type cohereTool struct { - Type string `json:"type"` - Function cohereToolDef `json:"function"` -} - -type cohereToolDef struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters map[string]any `json:"parameters"` -} - -// Cohere Chat API v2 response types - -type cohereResponse struct { - ID string `json:"id"` - Message cohereRespMsg `json:"message"` - FinishReason string `json:"finish_reason"` - Usage cohereUsage `json:"usage"` -} - -type cohereRespMsg struct { - Role string `json:"role"` - Content []cohereContent `json:"content"` - ToolCalls []cohereToolCall `json:"tool_calls,omitempty"` -} - -type cohereContent struct { - Type string `json:"type"` - Text string `json:"text"` -} - -type cohereUsage struct { - BilledUnits cohereBilledUnits `json:"billed_units"` - Tokens cohereTokens `json:"tokens"` -} - -type cohereBilledUnits struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` -} - -type cohereTokens struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` -} - -func (p *CohereProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - reqBody := p.buildRequest(messages, tools, false) - - data, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("cohere: marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.config.BaseURL+"/v2/chat", bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("cohere: create request: %w", err) - } - p.setHeaders(req) - - resp, err := p.config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("cohere: send request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("cohere: read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("cohere: API error (status %d): %s", resp.StatusCode, string(body)) - } - - var apiResp cohereResponse - if err := json.Unmarshal(body, &apiResp); err != nil { - return nil, fmt.Errorf("cohere: unmarshal response: %w", err) - } - - return p.parseResponse(&apiResp), nil -} - -func (p *CohereProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - reqBody := p.buildRequest(messages, tools, true) - - data, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("cohere: marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.config.BaseURL+"/v2/chat", bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("cohere: create request: %w", err) - } - p.setHeaders(req) - - resp, err := p.config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("cohere: send request: %w", err) - } - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - _ = resp.Body.Close() - return nil, fmt.Errorf("cohere: API error (status %d): %s", resp.StatusCode, string(body)) - } - - ch := make(chan StreamEvent, 16) - go p.readSSE(resp.Body, ch) - return ch, nil -} - -func (p *CohereProvider) buildRequest(messages []Message, tools []ToolDef, stream bool) *cohereRequest { - req := &cohereRequest{ - Model: p.config.Model, - MaxTokens: p.config.MaxTokens, - Stream: stream, - } - - // Cohere v2 uses the same role names as OpenAI: system, user, assistant, tool - for _, msg := range messages { - switch msg.Role { - case RoleTool: - req.Messages = append(req.Messages, cohereMessage{ - Role: "tool", - Content: msg.Content, - ToolCallID: msg.ToolCallID, - }) - case RoleAssistant: - cm := cohereMessage{ - Role: "assistant", - Content: msg.Content, - } - for _, tc := range msg.ToolCalls { - args := "{}" - if tc.Arguments != nil { - if b, err := json.Marshal(tc.Arguments); err == nil { - args = string(b) - } - } - cm.ToolCalls = append(cm.ToolCalls, cohereToolCall{ - ID: tc.ID, - Type: "function", - Function: cohereToolFunc{Name: tc.Name, Arguments: args}, - }) - } - req.Messages = append(req.Messages, cm) - default: - req.Messages = append(req.Messages, cohereMessage{ - Role: string(msg.Role), - Content: msg.Content, - }) - } - } - - // Convert tools to Cohere format (same as OpenAI) - for _, t := range tools { - schema := t.Parameters - if schema == nil { - schema = map[string]any{"type": "object", "properties": map[string]any{}} - } - req.Tools = append(req.Tools, cohereTool{ - Type: "function", - Function: cohereToolDef{ - Name: t.Name, - Description: t.Description, - Parameters: schema, - }, - }) - } - - return req -} - -func (p *CohereProvider) setHeaders(req *http.Request) { - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+p.config.APIKey) -} - -func (p *CohereProvider) parseResponse(apiResp *cohereResponse) *Response { - resp := &Response{ - Usage: Usage{ - InputTokens: apiResp.Usage.Tokens.InputTokens, - OutputTokens: apiResp.Usage.Tokens.OutputTokens, - }, - } - - // Extract text content - var textParts []string - for _, c := range apiResp.Message.Content { - if c.Type == "text" { - textParts = append(textParts, c.Text) - } - } - resp.Content = strings.Join(textParts, "") - - // Extract tool calls - for _, tc := range apiResp.Message.ToolCalls { - var args map[string]any - if tc.Function.Arguments != "" { - _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) - } - resp.ToolCalls = append(resp.ToolCalls, ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: args, - }) - } - - return resp -} - -// Cohere v2 streaming types - -type cohereStreamEvent struct { - Type string `json:"type"` - Index int `json:"index"` - Delta *cohereStreamDelta `json:"delta,omitempty"` - Response *cohereResponse `json:"response,omitempty"` -} - -type cohereStreamDelta struct { - Message *cohereStreamDeltaMsg `json:"message,omitempty"` -} - -type cohereStreamDeltaMsg struct { - Content *cohereStreamContent `json:"content,omitempty"` - ToolCalls *cohereStreamToolDelta `json:"tool_calls,omitempty"` -} - -type cohereStreamContent struct { - Text string `json:"text"` -} - -type cohereStreamToolDelta struct { - Function *cohereStreamFuncDelta `json:"function,omitempty"` -} - -type cohereStreamFuncDelta struct { - Name string `json:"name,omitempty"` - Arguments string `json:"arguments,omitempty"` -} - -// readSSE parses the SSE stream from the Cohere Chat API v2. -func (p *CohereProvider) readSSE(body io.ReadCloser, ch chan<- StreamEvent) { - defer func() { _ = body.Close() }() - defer close(ch) - - scanner := bufio.NewScanner(body) - - // Track tool calls being assembled by index - type pendingToolCall struct { - id string - name string - argsBuf strings.Builder - } - pending := make(map[int]*pendingToolCall) - - var usage *Usage - - for scanner.Scan() { - line := scanner.Text() - - if !strings.HasPrefix(line, "data: ") { - continue - } - - data := strings.TrimPrefix(line, "data: ") - if data == "[DONE]" { - break - } - - var event cohereStreamEvent - if err := json.Unmarshal([]byte(data), &event); err != nil { - continue - } - - switch event.Type { - case "content-delta": - if event.Delta != nil && event.Delta.Message != nil && event.Delta.Message.Content != nil { - ch <- StreamEvent{Type: "text", Text: event.Delta.Message.Content.Text} - } - - case "tool-call-start": - ptc := &pendingToolCall{} - if event.Delta != nil && event.Delta.Message != nil && event.Delta.Message.ToolCalls != nil { - if event.Delta.Message.ToolCalls.Function != nil { - ptc.name = event.Delta.Message.ToolCalls.Function.Name - } - } - pending[event.Index] = ptc - - case "tool-call-delta": - ptc, exists := pending[event.Index] - if !exists { - ptc = &pendingToolCall{} - pending[event.Index] = ptc - } - if event.Delta != nil && event.Delta.Message != nil && event.Delta.Message.ToolCalls != nil { - if event.Delta.Message.ToolCalls.Function != nil { - ptc.argsBuf.WriteString(event.Delta.Message.ToolCalls.Function.Arguments) - } - } - - case "tool-call-end": - ptc, exists := pending[event.Index] - if !exists { - continue - } - var args map[string]any - if ptc.argsBuf.Len() > 0 { - _ = json.Unmarshal([]byte(ptc.argsBuf.String()), &args) - } - ch <- StreamEvent{ - Type: "tool_call", - Tool: &ToolCall{ - ID: ptc.id, - Name: ptc.name, - Arguments: args, - }, - } - delete(pending, event.Index) - - case "message-end": - if event.Response != nil { - usage = &Usage{ - InputTokens: event.Response.Usage.Tokens.InputTokens, - OutputTokens: event.Response.Usage.Tokens.OutputTokens, - } - } - ch <- StreamEvent{Type: "done", Usage: usage} - return - } - } - - // Flush any pending tool calls - for _, ptc := range pending { - var args map[string]any - if ptc.argsBuf.Len() > 0 { - _ = json.Unmarshal([]byte(ptc.argsBuf.String()), &args) - } - ch <- StreamEvent{ - Type: "tool_call", - Tool: &ToolCall{ - ID: ptc.id, - Name: ptc.name, - Arguments: args, - }, - } - } - - if err := scanner.Err(); err != nil { - ch <- StreamEvent{Type: "error", Error: err.Error()} - } -} diff --git a/provider/copilot.go b/provider/copilot.go deleted file mode 100644 index 28db0ab..0000000 --- a/provider/copilot.go +++ /dev/null @@ -1,178 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "sync" - "time" - - openaisdk "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/shared" -) - -const ( - defaultCopilotBaseURL = "https://api.githubcopilot.com" - defaultCopilotModel = "gpt-4o" - defaultCopilotMaxTokens = 4096 - copilotTokenExchangeURL = "https://api.github.com/copilot_internal/v2/token" - copilotEditorVersion = "ratchet/0.1.0" -) - -// CopilotConfig holds configuration for the GitHub Copilot provider. -type CopilotConfig struct { - Token string - Model string - BaseURL string - MaxTokens int - HTTPClient *http.Client -} - -// CopilotProvider implements Provider using the GitHub Copilot Chat API. -// The API follows the OpenAI Chat Completions format. -type CopilotProvider struct { - config CopilotConfig - mu sync.Mutex - bearerToken string - expiresAt time.Time -} - -// copilotTokenResponse is the response from the Copilot token exchange endpoint. -type copilotTokenResponse struct { - Token string `json:"token"` - ExpiresAt int64 `json:"expires_at"` -} - -// NewCopilotProvider creates a new Copilot provider with the given config. -func NewCopilotProvider(cfg CopilotConfig) *CopilotProvider { - if cfg.Model == "" { - cfg.Model = defaultCopilotModel - } - if cfg.BaseURL == "" { - cfg.BaseURL = defaultCopilotBaseURL - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultCopilotMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - return &CopilotProvider{config: cfg} -} - -func (p *CopilotProvider) Name() string { return "copilot" } - -func (p *CopilotProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "personal", - DisplayName: "GitHub Copilot (Personal/IDE)", - Description: "Uses GitHub Copilot's chat completions API via OAuth token exchange. Intended for IDE/CLI use under a Copilot Individual or Business subscription.", - Warning: "This mode uses Copilot's internal API intended for IDE integrations. Using it in server/service contexts may violate GitHub Copilot Terms of Service (https://docs.github.com/en/site-policy/github-terms/github-terms-for-additional-products-and-features).", - DocsURL: "https://docs.github.com/en/copilot", - ServerSafe: false, - } -} - -func (p *CopilotProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - if err := p.ensureBearerToken(ctx); err != nil { - return nil, err - } - p.mu.Lock() - token := p.bearerToken - p.mu.Unlock() - - client := p.newSDKClient(token) - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.Model), - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - resp, err := client.Chat.Completions.New(ctx, params) - if err != nil { - return nil, fmt.Errorf("copilot: %w", err) - } - return fromOpenAIResponse(resp) -} - -func (p *CopilotProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - if err := p.ensureBearerToken(ctx); err != nil { - return nil, err - } - p.mu.Lock() - token := p.bearerToken - p.mu.Unlock() - - client := p.newSDKClient(token) - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.Model), - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - stream := client.Chat.Completions.NewStreaming(ctx, params) - ch := make(chan StreamEvent, 16) - go streamOpenAIEvents(stream, ch) - return ch, nil -} - -// newSDKClient creates a per-request OpenAI SDK client with Copilot auth headers. -func (p *CopilotProvider) newSDKClient(bearerToken string) openaisdk.Client { - return openaisdk.NewClient( - option.WithAPIKey(bearerToken), - option.WithBaseURL(p.config.BaseURL), - option.WithHTTPClient(p.config.HTTPClient), - option.WithHeader("Copilot-Integration-Id", "vscode-chat"), - option.WithHeader("Editor-Version", "vscode/1.100.0"), - option.WithHeader("Editor-Plugin-Version", copilotEditorVersion), - ) -} - -// ensureBearerToken exchanges the GitHub OAuth token for a short-lived Copilot -// bearer token, caching it until 60 seconds before expiry. -func (p *CopilotProvider) ensureBearerToken(ctx context.Context) error { - p.mu.Lock() - defer p.mu.Unlock() - - if p.bearerToken != "" && time.Now().Before(p.expiresAt.Add(-60*time.Second)) { - return nil - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotTokenExchangeURL, nil) - if err != nil { - return fmt.Errorf("copilot: create token exchange request: %w", err) - } - req.Header.Set("Authorization", "Token "+p.config.Token) - req.Header.Set("Accept", "application/json") - - resp, err := p.config.HTTPClient.Do(req) - if err != nil { - return fmt.Errorf("copilot: token exchange request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("copilot: read token exchange response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("copilot: token exchange failed (status %d): %s", resp.StatusCode, truncate(string(body), 200)) - } - - var tokenResp copilotTokenResponse - if err := json.Unmarshal(body, &tokenResp); err != nil { - return fmt.Errorf("copilot: parse token exchange response: %w", err) - } - - p.bearerToken = tokenResp.Token - p.expiresAt = time.Unix(tokenResp.ExpiresAt, 0) - return nil -} diff --git a/provider/copilot_models.go b/provider/copilot_models.go deleted file mode 100644 index 7ee7387..0000000 --- a/provider/copilot_models.go +++ /dev/null @@ -1,55 +0,0 @@ -package provider - -const defaultCopilotModelsBaseURL = "https://models.github.ai/inference" - -// CopilotModelsConfig configures the GitHub Models provider. -// GitHub Models is a separate product from GitHub Copilot, available at models.github.ai. -// It uses fine-grained Personal Access Tokens with the models:read scope. -type CopilotModelsConfig struct { - // Token is a GitHub fine-grained PAT with models:read permission. - Token string - // Model is the model identifier (e.g. "openai/gpt-4o", "anthropic/claude-sonnet-4"). - Model string - // BaseURL overrides the default endpoint. Default: "https://models.github.ai/inference". - BaseURL string - // MaxTokens limits the response length. - MaxTokens int -} - -// CopilotModelsProvider uses GitHub Models (models.github.ai) for inference. -// It wraps OpenAIProvider since GitHub Models uses an OpenAI-compatible API. -type CopilotModelsProvider struct { - *OpenAIProvider -} - -func (p *CopilotModelsProvider) Name() string { return "copilot_models" } - -func (p *CopilotModelsProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "github_models", - DisplayName: "GitHub Models", - Description: "Access AI models via GitHub's Models marketplace using a fine-grained PAT with models:read scope.", - DocsURL: "https://docs.github.com/en/rest/models/inference", - ServerSafe: true, - } -} - -// NewCopilotModelsProvider creates a provider that uses GitHub Models for inference. -// GitHub Models provides access to various AI models via a fine-grained PAT. -// -// Docs: https://docs.github.com/en/rest/models/inference -// Billing: https://docs.github.com/billing/managing-billing-for-your-products/about-billing-for-github-models -func NewCopilotModelsProvider(cfg CopilotModelsConfig) *CopilotModelsProvider { - baseURL := cfg.BaseURL - if baseURL == "" { - baseURL = defaultCopilotModelsBaseURL - } - return &CopilotModelsProvider{ - OpenAIProvider: NewOpenAIProvider(OpenAIConfig{ - APIKey: cfg.Token, - Model: cfg.Model, - BaseURL: baseURL, - MaxTokens: cfg.MaxTokens, - }), - } -} diff --git a/provider/copilot_models_test.go b/provider/copilot_models_test.go deleted file mode 100644 index 66fc1de..0000000 --- a/provider/copilot_models_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" -) - -func TestCopilotModelsProvider_Name(t *testing.T) { - p := NewCopilotModelsProvider(CopilotModelsConfig{Token: "ghp_test"}) - if got := p.Name(); got != "copilot_models" { - t.Errorf("Name() = %q, want %q", got, "copilot_models") - } -} - -func TestCopilotModelsProvider_AuthModeInfo(t *testing.T) { - p := NewCopilotModelsProvider(CopilotModelsConfig{Token: "ghp_test"}) - info := p.AuthModeInfo() - - if info.Mode != "github_models" { - t.Errorf("AuthModeInfo().Mode = %q, want %q", info.Mode, "github_models") - } - if info.DisplayName != "GitHub Models" { - t.Errorf("AuthModeInfo().DisplayName = %q, want %q", info.DisplayName, "GitHub Models") - } - if info.DocsURL != "https://docs.github.com/en/rest/models/inference" { - t.Errorf("AuthModeInfo().DocsURL = %q, want GitHub Models docs URL", info.DocsURL) - } - if !info.ServerSafe { - t.Error("AuthModeInfo().ServerSafe = false, want true") - } -} - -func TestCopilotModelsProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*CopilotModelsProvider)(nil) -} - -func TestCopilotModelsProvider_DefaultBaseURL(t *testing.T) { - p := NewCopilotModelsProvider(CopilotModelsConfig{Token: "ghp_test"}) - if p.config.BaseURL != defaultCopilotModelsBaseURL { - t.Errorf("default BaseURL = %q, want %q", p.config.BaseURL, defaultCopilotModelsBaseURL) - } -} - -func TestCopilotModelsProvider_CustomBaseURL(t *testing.T) { - custom := "https://custom.models.github.ai/inference" - p := NewCopilotModelsProvider(CopilotModelsConfig{Token: "ghp_test", BaseURL: custom}) - if p.config.BaseURL != custom { - t.Errorf("BaseURL = %q, want %q", p.config.BaseURL, custom) - } -} - -func TestCopilotModelsProvider_Chat(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("Authorization"); got != "Bearer ghp_test123" { - t.Errorf("Authorization header = %q, want %q", got, "Bearer ghp_test123") - } - - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"chatcmpl-gh","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"hello from github models"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":8,"completion_tokens":4,"total_tokens":12},"created":1704067200}`) - })) - defer srv.Close() - - p := NewCopilotModelsProvider(CopilotModelsConfig{ - Token: "ghp_test123", - Model: "openai/gpt-4o", - BaseURL: srv.URL, - }) - - got, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if got.Content != "hello from github models" { - t.Errorf("Chat() content = %q, want %q", got.Content, "hello from github models") - } - if got.Usage.InputTokens != 8 || got.Usage.OutputTokens != 4 { - t.Errorf("Chat() usage = %+v, want input=8 output=4", got.Usage) - } -} - -func TestCopilotModelsProvider_Stream(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher, ok := w.(http.Flusher) - if !ok { - t.Fatal("expected http.Flusher") - } - - data, _ := json.Marshal(map[string]any{ - "id": "chatcmpl-gh-stream", - "object": "chat.completion.chunk", - "model": "gpt-4o", - "created": 1704067200, - "choices": []map[string]any{ - {"index": 0, "delta": map[string]any{"role": "assistant", "content": "gh-streamed"}, "finish_reason": nil}, - }, - }) - fmt.Fprintf(w, "data: %s\n\n", string(data)) - flusher.Flush() - - fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() - })) - defer srv.Close() - - p := NewCopilotModelsProvider(CopilotModelsConfig{ - Token: "ghp_test123", - Model: "openai/gpt-4o", - BaseURL: srv.URL, - }) - - ch, err := p.Stream(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Stream() error: %v", err) - } - - var texts []string - for ev := range ch { - switch ev.Type { - case "text": - texts = append(texts, ev.Text) - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if len(texts) == 0 { - t.Fatal("expected at least one text event") - } - if texts[0] != "gh-streamed" { - t.Errorf("first text event = %q, want %q", texts[0], "gh-streamed") - } -} diff --git a/provider/copilot_test.go b/provider/copilot_test.go deleted file mode 100644 index 7cca38d..0000000 --- a/provider/copilot_test.go +++ /dev/null @@ -1,297 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "sync/atomic" - "testing" - "time" -) - -// setupCopilotServers creates two test servers: one for token exchange and one for chat. -// Returns (tokenSrv, chatSrv, cleanup). -func setupCopilotServers(t *testing.T, tokenHandler, chatHandler http.HandlerFunc) (*httptest.Server, *httptest.Server) { - t.Helper() - tokenSrv := httptest.NewServer(tokenHandler) - chatSrv := httptest.NewServer(chatHandler) - t.Cleanup(func() { - tokenSrv.Close() - chatSrv.Close() - }) - return tokenSrv, chatSrv -} - -func validTokenHandler(expiresAt int64) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") == "" { - http.Error(w, "missing auth", http.StatusUnauthorized) - return - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(copilotTokenResponse{ - Token: "copilot-bearer-token", - ExpiresAt: expiresAt, - }) - } -} - -func validChatHandler(t *testing.T) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("Authorization"); got != "Bearer copilot-bearer-token" { - t.Errorf("chat Authorization = %q, want Bearer copilot-bearer-token", got) - } - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"chatcmpl-cop","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"hello from copilot"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8},"created":1704067200}`) - } -} - -func TestCopilotProvider_Name(t *testing.T) { - p := NewCopilotProvider(CopilotConfig{Token: "ghp_test"}) - if got := p.Name(); got != "copilot" { - t.Errorf("Name() = %q, want %q", got, "copilot") - } -} - -func TestCopilotProvider_AuthModeInfo(t *testing.T) { - p := NewCopilotProvider(CopilotConfig{Token: "ghp_test"}) - info := p.AuthModeInfo() - if info.Mode != "personal" { - t.Errorf("AuthModeInfo().Mode = %q, want %q", info.Mode, "personal") - } - if info.ServerSafe { - t.Error("AuthModeInfo().ServerSafe = true, want false") - } -} - -func TestCopilotProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*CopilotProvider)(nil) -} - -func TestCopilotProvider_Chat_HappyPath(t *testing.T) { - expiresAt := time.Now().Add(30 * time.Minute).Unix() - tokenSrv, chatSrv := setupCopilotServers(t, validTokenHandler(expiresAt), validChatHandler(t)) - - p := NewCopilotProvider(CopilotConfig{ - Token: "ghp_oauth_token", - Model: "gpt-4o", - BaseURL: chatSrv.URL, - }) - // Override the token exchange URL by patching the provider's internal call via an http.Client - // that redirects token exchange requests to our test server. - p.config.HTTPClient = redirectingClient(tokenSrv.URL) - - got, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hello"}, - }, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if got.Content != "hello from copilot" { - t.Errorf("Chat() content = %q, want %q", got.Content, "hello from copilot") - } - if got.Usage.InputTokens != 5 || got.Usage.OutputTokens != 3 { - t.Errorf("Chat() usage = %+v, want input=5 output=3", got.Usage) - } -} - -func TestCopilotProvider_TokenExchange_Error(t *testing.T) { - tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - })) - defer tokenSrv.Close() - - chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Error("chat server should not be called when token exchange fails") - })) - defer chatSrv.Close() - - p := NewCopilotProvider(CopilotConfig{ - Token: "bad_token", - Model: "gpt-4o", - BaseURL: chatSrv.URL, - HTTPClient: redirectingClient(tokenSrv.URL), - }) - - _, err := p.Chat(t.Context(), []Message{{Role: RoleUser, Content: "hi"}}, nil) - if err == nil { - t.Fatal("expected error when token exchange returns 401") - } -} - -func TestCopilotProvider_TokenCacheHit(t *testing.T) { - var tokenCallCount atomic.Int32 - expiresAt := time.Now().Add(30 * time.Minute).Unix() - - tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tokenCallCount.Add(1) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(copilotTokenResponse{ - Token: "cached-bearer-token", - ExpiresAt: expiresAt, - }) - })) - defer tokenSrv.Close() - - chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"c1","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2},"created":1704067200}`) - })) - defer chatSrv.Close() - - p := NewCopilotProvider(CopilotConfig{ - Token: "ghp_test", - Model: "gpt-4o", - BaseURL: chatSrv.URL, - HTTPClient: redirectingClient(tokenSrv.URL), - }) - - // Two sequential calls — token exchange should only happen once. - for range 2 { - if _, err := p.Chat(t.Context(), []Message{{Role: RoleUser, Content: "hi"}}, nil); err != nil { - t.Fatalf("Chat() error: %v", err) - } - } - - if n := tokenCallCount.Load(); n != 1 { - t.Errorf("token exchange called %d times, want 1 (cache should be hit on second call)", n) - } -} - -func TestCopilotProvider_TokenRefresh_AfterExpiry(t *testing.T) { - var tokenCallCount atomic.Int32 - // Already-expired token. - pastExpiry := time.Now().Add(-5 * time.Minute).Unix() - - tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tokenCallCount.Add(1) - w.Header().Set("Content-Type", "application/json") - // Return a fresh token each time. - json.NewEncoder(w).Encode(copilotTokenResponse{ - Token: fmt.Sprintf("token-%d", tokenCallCount.Load()), - ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), - }) - })) - defer tokenSrv.Close() - - chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"c1","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2},"created":1704067200}`) - })) - defer chatSrv.Close() - - p := NewCopilotProvider(CopilotConfig{ - Token: "ghp_test", - Model: "gpt-4o", - BaseURL: chatSrv.URL, - HTTPClient: redirectingClient(tokenSrv.URL), - }) - - // Manually set an expired bearer token. - p.bearerToken = "expired-token" - p.expiresAt = time.Unix(pastExpiry, 0) - - if _, err := p.Chat(t.Context(), []Message{{Role: RoleUser, Content: "hi"}}, nil); err != nil { - t.Fatalf("Chat() error: %v", err) - } - - if n := tokenCallCount.Load(); n != 1 { - t.Errorf("token exchange called %d times, want 1 (should refresh expired token)", n) - } -} - -func TestCopilotProvider_Stream(t *testing.T) { - expiresAt := time.Now().Add(30 * time.Minute).Unix() - - tokenSrv := httptest.NewServer(validTokenHandler(expiresAt)) - defer tokenSrv.Close() - - chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher, ok := w.(http.Flusher) - if !ok { - t.Fatal("expected http.Flusher") - } - - data, _ := json.Marshal(map[string]any{ - "id": "chatcmpl-cop-stream", - "object": "chat.completion.chunk", - "model": "gpt-4o", - "created": 1704067200, - "choices": []map[string]any{ - {"index": 0, "delta": map[string]any{"role": "assistant", "content": "copilot-streamed"}, "finish_reason": nil}, - }, - }) - fmt.Fprintf(w, "data: %s\n\n", string(data)) - flusher.Flush() - - fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() - })) - defer chatSrv.Close() - - p := NewCopilotProvider(CopilotConfig{ - Token: "ghp_test", - Model: "gpt-4o", - BaseURL: chatSrv.URL, - HTTPClient: redirectingClient(tokenSrv.URL), - }) - - ch, err := p.Stream(t.Context(), []Message{{Role: RoleUser, Content: "hi"}}, nil) - if err != nil { - t.Fatalf("Stream() error: %v", err) - } - - var texts []string - for ev := range ch { - switch ev.Type { - case "text": - texts = append(texts, ev.Text) - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if len(texts) == 0 { - t.Fatal("expected at least one text event") - } - if texts[0] != "copilot-streamed" { - t.Errorf("first text event = %q, want %q", texts[0], "copilot-streamed") - } -} - -// redirectingClient returns an http.Client whose transport redirects all requests -// with the copilotTokenExchangeURL path to tokenBaseURL, leaving others unchanged. -func redirectingClient(tokenBaseURL string) *http.Client { - return &http.Client{ - Transport: &tokenRedirectTransport{ - tokenBaseURL: tokenBaseURL, - inner: http.DefaultTransport, - }, - } -} - -type tokenRedirectTransport struct { - tokenBaseURL string - inner http.RoundTripper -} - -func (t *tokenRedirectTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Redirect token exchange requests to our test server. - if req.URL.String() == copilotTokenExchangeURL { - redirected := req.Clone(req.Context()) - redirected.URL.Scheme = "http" - redirected.URL.Host = req.URL.Host - // Parse tokenBaseURL and override host. - srv := t.tokenBaseURL - // Strip scheme. - if len(srv) > 7 && srv[:7] == "http://" { - srv = srv[7:] - } - redirected.URL.Host = srv - redirected.URL.Path = req.URL.Path - redirected.Host = srv - return t.inner.RoundTrip(redirected) - } - return t.inner.RoundTrip(req) -} diff --git a/provider/gemini.go b/provider/gemini.go deleted file mode 100644 index 08f562c..0000000 --- a/provider/gemini.go +++ /dev/null @@ -1,325 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - - "github.com/google/generative-ai-go/genai" - "google.golang.org/api/iterator" - googleoption "google.golang.org/api/option" -) - -const ( - defaultGeminiModel = "gemini-2.0-flash" - defaultGeminiMaxTokens = 4096 -) - -// GeminiConfig holds configuration for the Google Gemini provider. -type GeminiConfig struct { - APIKey string - Model string - MaxTokens int - HTTPClient *http.Client -} - -// GeminiProvider implements Provider using the Google Gemini API. -type GeminiProvider struct { - config GeminiConfig -} - -// NewGeminiProvider creates a new Gemini provider. Returns an error if no API key is provided. -func NewGeminiProvider(cfg GeminiConfig) (*GeminiProvider, error) { - if cfg.APIKey == "" { - return nil, fmt.Errorf("gemini: APIKey is required") - } - if cfg.Model == "" { - cfg.Model = defaultGeminiModel - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultGeminiMaxTokens - } - return &GeminiProvider{config: cfg}, nil -} - -func (p *GeminiProvider) Name() string { return "gemini" } - -func (p *GeminiProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "gemini", - DisplayName: "Google Gemini", - Description: "Uses the Google Gemini API with an API key from Google AI Studio.", - DocsURL: "https://ai.google.dev/gemini-api/docs/api-key", - ServerSafe: true, - } -} - -func (p *GeminiProvider) newGenaiClient(ctx context.Context) (*genai.Client, error) { - opts := []googleoption.ClientOption{ - googleoption.WithAPIKey(p.config.APIKey), - } - if p.config.HTTPClient != nil { - opts = append(opts, googleoption.WithHTTPClient(p.config.HTTPClient)) - } - return genai.NewClient(ctx, opts...) -} - -func (p *GeminiProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - client, err := p.newGenaiClient(ctx) - if err != nil { - return nil, fmt.Errorf("gemini: create client: %w", err) - } - defer client.Close() - - model := client.GenerativeModel(p.config.Model) - maxOut := int32(p.config.MaxTokens) - model.MaxOutputTokens = &maxOut - if len(tools) > 0 { - model.Tools = toGeminiTools(tools) - } - - contents, systemInstruction := toGeminiContents(messages) - if systemInstruction != nil { - model.SystemInstruction = systemInstruction - } - - resp, err := model.GenerateContent(ctx, contents...) - if err != nil { - return nil, fmt.Errorf("gemini: %w", err) - } - return fromGeminiResponse(resp), nil -} - -func (p *GeminiProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - client, err := p.newGenaiClient(ctx) - if err != nil { - return nil, fmt.Errorf("gemini: create client: %w", err) - } - - model := client.GenerativeModel(p.config.Model) - maxOut := int32(p.config.MaxTokens) - model.MaxOutputTokens = &maxOut - if len(tools) > 0 { - model.Tools = toGeminiTools(tools) - } - - contents, systemInstruction := toGeminiContents(messages) - if systemInstruction != nil { - model.SystemInstruction = systemInstruction - } - - iter := model.GenerateContentStream(ctx, contents...) - ch := make(chan StreamEvent, 16) - go func() { - defer close(ch) - defer client.Close() - send := func(ev StreamEvent) bool { - select { - case ch <- ev: - return true - case <-ctx.Done(): - return false - } - } - for { - resp, err := iter.Next() - if err == iterator.Done { - send(StreamEvent{Type: "done"}) - return - } - if err != nil { - send(StreamEvent{Type: "error", Error: err.Error()}) - return - } - for _, cand := range resp.Candidates { - if cand.Content == nil { - continue - } - for _, part := range cand.Content.Parts { - switch v := part.(type) { - case genai.Text: - if v != "" { - if !send(StreamEvent{Type: "text", Text: string(v)}) { - return - } - } - case genai.FunctionCall: - if !send(StreamEvent{Type: "tool_call", Tool: &ToolCall{ - ID: v.Name, - Name: v.Name, - Arguments: v.Args, - }}) { - return - } - } - } - } - } - }() - return ch, nil -} - -// toGeminiContents converts provider messages to Gemini content parts. -// System messages are returned separately as they set model.SystemInstruction. -func toGeminiContents(messages []Message) ([]genai.Part, *genai.Content) { - var parts []genai.Part - var systemInstruction *genai.Content - - for _, msg := range messages { - switch msg.Role { - case RoleSystem: - systemInstruction = &genai.Content{ - Parts: []genai.Part{genai.Text(msg.Content)}, - Role: "user", - } - case RoleUser: - parts = append(parts, genai.Text(msg.Content)) - case RoleAssistant: - parts = append(parts, genai.Text(msg.Content)) - case RoleTool: - // Tool results are passed as FunctionResponse parts. - var args map[string]any - _ = json.Unmarshal([]byte(msg.Content), &args) - if args == nil { - args = map[string]any{"result": msg.Content} - } - parts = append(parts, genai.FunctionResponse{ - Name: msg.ToolCallID, - Response: args, - }) - } - } - return parts, systemInstruction -} - -// toGeminiTools converts provider tool definitions to Gemini Tool structs. -func toGeminiTools(tools []ToolDef) []*genai.Tool { - decls := make([]*genai.FunctionDeclaration, 0, len(tools)) - for _, t := range tools { - decls = append(decls, &genai.FunctionDeclaration{ - Name: t.Name, - Description: t.Description, - Parameters: toGeminiSchema(t.Parameters), - }) - } - return []*genai.Tool{{FunctionDeclarations: decls}} -} - -// toGeminiSchema converts a JSON Schema map to a genai.Schema. -func toGeminiSchema(params map[string]any) *genai.Schema { - if params == nil { - return nil - } - schema := &genai.Schema{Type: genai.TypeObject} - if props, ok := params["properties"].(map[string]any); ok { - schema.Properties = make(map[string]*genai.Schema, len(props)) - for name, val := range props { - if propMap, ok := val.(map[string]any); ok { - schema.Properties[name] = toGeminiSchemaProp(propMap) - } - } - } - if req, ok := params["required"].([]any); ok { - for _, r := range req { - if s, ok := r.(string); ok { - schema.Required = append(schema.Required, s) - } - } - } - return schema -} - -func toGeminiSchemaProp(m map[string]any) *genai.Schema { - s := &genai.Schema{} - if t, ok := m["type"].(string); ok { - switch t { - case "string": - s.Type = genai.TypeString - case "number": - s.Type = genai.TypeNumber - case "integer": - s.Type = genai.TypeInteger - case "boolean": - s.Type = genai.TypeBoolean - case "array": - s.Type = genai.TypeArray - case "object": - s.Type = genai.TypeObject - } - } - if desc, ok := m["description"].(string); ok { - s.Description = desc - } - return s -} - -// fromGeminiResponse extracts a provider Response from a Gemini GenerateContentResponse. -func fromGeminiResponse(resp *genai.GenerateContentResponse) *Response { - result := &Response{} - if resp.UsageMetadata != nil { - result.Usage = Usage{ - InputTokens: int(resp.UsageMetadata.PromptTokenCount), - OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount), - } - } - if len(resp.Candidates) == 0 { - return result - } - cand := resp.Candidates[0] - if cand.Content == nil { - return result - } - for _, part := range cand.Content.Parts { - switch v := part.(type) { - case genai.Text: - result.Content += string(v) - case genai.FunctionCall: - result.ToolCalls = append(result.ToolCalls, ToolCall{ - ID: v.Name, - Name: v.Name, - Arguments: v.Args, - }) - } - } - return result -} - -// listGeminiModels lists available Gemini models using the genai SDK. -func listGeminiModels(ctx context.Context, apiKey string) ([]ModelInfo, error) { - client, err := genai.NewClient(ctx, googleoption.WithAPIKey(apiKey)) - if err != nil { - return geminiFallbackModels(), nil - } - defer client.Close() - - iter := client.ListModels(ctx) - var models []ModelInfo - for { - m, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return geminiFallbackModels(), nil - } - models = append(models, ModelInfo{ - ID: m.Name, - Name: m.DisplayName, - }) - } - if len(models) == 0 { - return geminiFallbackModels(), nil - } - return models, nil -} - -func geminiFallbackModels() []ModelInfo { - return []ModelInfo{ - {ID: "gemini-2.5-pro-preview-03-25", Name: "Gemini 2.5 Pro Preview"}, - {ID: "gemini-2.0-flash", Name: "Gemini 2.0 Flash"}, - {ID: "gemini-2.0-flash-lite", Name: "Gemini 2.0 Flash-Lite"}, - {ID: "gemini-1.5-pro", Name: "Gemini 1.5 Pro"}, - {ID: "gemini-1.5-flash", Name: "Gemini 1.5 Flash"}, - } -} diff --git a/provider/gemini_test.go b/provider/gemini_test.go deleted file mode 100644 index dff04ff..0000000 --- a/provider/gemini_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package provider - -import ( - "testing" -) - -func TestGeminiProvider_Name(t *testing.T) { - p, err := NewGeminiProvider(GeminiConfig{APIKey: "test-key"}) - if err != nil { - t.Fatal(err) - } - if got := p.Name(); got != "gemini" { - t.Errorf("Name() = %q, want %q", got, "gemini") - } -} - -func TestGeminiProvider_AuthModeInfo(t *testing.T) { - p, err := NewGeminiProvider(GeminiConfig{APIKey: "test-key"}) - if err != nil { - t.Fatal(err) - } - info := p.AuthModeInfo() - - if info.Mode != "gemini" { - t.Errorf("AuthModeInfo().Mode = %q, want %q", info.Mode, "gemini") - } - if info.DisplayName != "Google Gemini" { - t.Errorf("AuthModeInfo().DisplayName = %q, want %q", info.DisplayName, "Google Gemini") - } - if info.DocsURL != "https://ai.google.dev/gemini-api/docs/api-key" { - t.Errorf("AuthModeInfo().DocsURL = %q, want Gemini docs URL", info.DocsURL) - } - if !info.ServerSafe { - t.Error("AuthModeInfo().ServerSafe = false, want true") - } -} - -func TestGeminiProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*GeminiProvider)(nil) -} - -func TestGeminiProvider_RequiresAPIKey(t *testing.T) { - _, err := NewGeminiProvider(GeminiConfig{}) - if err == nil { - t.Fatal("expected error when no API key provided") - } -} - -func TestGeminiProvider_DefaultModel(t *testing.T) { - p, err := NewGeminiProvider(GeminiConfig{APIKey: "test-key"}) - if err != nil { - t.Fatal(err) - } - if p.config.Model != defaultGeminiModel { - t.Errorf("default Model = %q, want %q", p.config.Model, defaultGeminiModel) - } -} - -func TestGeminiProvider_DefaultMaxTokens(t *testing.T) { - p, err := NewGeminiProvider(GeminiConfig{APIKey: "test-key"}) - if err != nil { - t.Fatal(err) - } - if p.config.MaxTokens != defaultGeminiMaxTokens { - t.Errorf("default MaxTokens = %d, want %d", p.config.MaxTokens, defaultGeminiMaxTokens) - } -} - -func TestToGeminiSchema_NilParams(t *testing.T) { - s := toGeminiSchema(nil) - if s != nil { - t.Errorf("toGeminiSchema(nil) = %v, want nil", s) - } -} - -func TestToGeminiSchema_WithProperties(t *testing.T) { - params := map[string]any{ - "type": "object", - "properties": map[string]any{ - "name": map[string]any{"type": "string", "description": "A name"}, - "age": map[string]any{"type": "integer"}, - }, - "required": []any{"name"}, - } - s := toGeminiSchema(params) - if s == nil { - t.Fatal("toGeminiSchema returned nil for valid params") - } - if len(s.Properties) != 2 { - t.Errorf("Properties count = %d, want 2", len(s.Properties)) - } - if len(s.Required) != 1 || s.Required[0] != "name" { - t.Errorf("Required = %v, want [name]", s.Required) - } -} - -func TestGeminiFallbackModels(t *testing.T) { - models := geminiFallbackModels() - if len(models) == 0 { - t.Error("geminiFallbackModels() returned empty slice") - } - for _, m := range models { - if m.ID == "" { - t.Error("fallback model has empty ID") - } - if m.Name == "" { - t.Error("fallback model has empty Name") - } - } -} diff --git a/provider/llama_cpp.go b/provider/llama_cpp.go deleted file mode 100644 index 357d7ed..0000000 --- a/provider/llama_cpp.go +++ /dev/null @@ -1,231 +0,0 @@ -package provider - -import ( - "context" - "fmt" - "net/http" - "os/exec" - "runtime" - "time" - - openaisdk "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/shared" -) - -const ( - defaultLlamaCppPort = 8081 - defaultLlamaCppGPULayers = -1 - defaultLlamaCppContextSize = 8192 - defaultLlamaCppMaxTokens = 8192 -) - -// LlamaCppConfig holds configuration for the LlamaCpp provider. -// Set BaseURL for external mode (any OpenAI-compatible server). -// Set ModelPath for managed mode (provider starts llama-server). -type LlamaCppConfig struct { - BaseURL string // external mode: OpenAI-compatible server URL - ModelPath string // managed mode: path to .gguf model file - ModelName string // model name sent to server (external mode); defaults to "local" - BinaryPath string // override llama-server binary location - GPULayers int // -ngl flag; 0 → default -1 (all layers) - ContextSize int // -c flag; default 8192 - Threads int // -t flag; default runtime.NumCPU() - Port int // server port; default 8081 - MaxTokens int // default 8192 - DisableThinking bool // send chat_template_kwargs.enable_thinking=false for reasoning models - HTTPClient *http.Client -} - -// LlamaCppProvider implements Provider using an OpenAI-compatible llama-server. -type LlamaCppProvider struct { - client openaisdk.Client - config LlamaCppConfig - cmd *exec.Cmd // non-nil in managed mode after server start -} - -// NewLlamaCppProvider creates a LlamaCppProvider with the given config. -// In external mode (BaseURL set), it points at the given URL. -// In managed mode (ModelPath set), call ensureServer before use. -func NewLlamaCppProvider(cfg LlamaCppConfig) *LlamaCppProvider { - if cfg.GPULayers == 0 { - cfg.GPULayers = defaultLlamaCppGPULayers - } - if cfg.ContextSize <= 0 { - cfg.ContextSize = defaultLlamaCppContextSize - } - if cfg.Threads <= 0 { - cfg.Threads = runtime.NumCPU() - } - if cfg.Port <= 0 { - cfg.Port = defaultLlamaCppPort - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultLlamaCppMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - - baseURL := cfg.BaseURL - if baseURL == "" { - baseURL = fmt.Sprintf("http://localhost:%d/v1", cfg.Port) - } - - client := openaisdk.NewClient( - option.WithAPIKey("no-key"), - option.WithBaseURL(baseURL), - option.WithHTTPClient(cfg.HTTPClient), - ) - return &LlamaCppProvider{client: client, config: cfg} -} - -func (p *LlamaCppProvider) Name() string { return "llama_cpp" } - -// modelName returns the model identifier to send to the server. -func (p *LlamaCppProvider) modelName() string { - if p.config.ModelName != "" { - return p.config.ModelName - } - return "local" -} - -func (p *LlamaCppProvider) AuthModeInfo() AuthModeInfo { - return LocalAuthMode("llama_cpp", "llama.cpp (Local)") -} - -// Chat sends a non-streaming request and applies ParseThinking to the response. -func (p *LlamaCppProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - params := openaisdk.ChatCompletionNewParams{ - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - params.Model = shared.ChatModel(p.modelName()) - - var opts []option.RequestOption - if p.config.DisableThinking { - opts = append(opts, option.WithJSONSet("chat_template_kwargs", map[string]any{"enable_thinking": false})) - } - - resp, err := p.client.Chat.Completions.New(ctx, params, opts...) - if err != nil { - return nil, fmt.Errorf("llama_cpp: %w", err) - } - result, err := fromOpenAIResponse(resp) - if err != nil { - return nil, err - } - result.Thinking, result.Content = ParseThinking(result.Content) - return result, nil -} - -// Stream sends a streaming request, applying ThinkingStreamParser to text events. -func (p *LlamaCppProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - params := openaisdk.ChatCompletionNewParams{ - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - params.Model = shared.ChatModel(p.modelName()) - - stream := p.client.Chat.Completions.NewStreaming(ctx, params) - rawCh := make(chan StreamEvent, 16) - go streamOpenAIEvents(stream, rawCh) - - outCh := make(chan StreamEvent, 16) - go func() { - defer close(outCh) - parser := &ThinkingStreamParser{} - for event := range rawCh { - if event.Type == "text" { - for _, e := range parser.Feed(event.Text) { - outCh <- e - } - } else { - outCh <- event - } - } - }() - return outCh, nil -} - -// EnsureServer starts the managed llama-server if ModelPath is configured. -// No-op in external mode. Must be called before Chat/Stream in managed mode. -func (p *LlamaCppProvider) EnsureServer(ctx context.Context) error { - if p.config.ModelPath == "" { - return nil // external mode, nothing to start - } - return p.ensureServer(ctx) -} - -func (p *LlamaCppProvider) ensureServer(ctx context.Context) error { - binPath := p.config.BinaryPath - if binPath == "" { - if path, err := exec.LookPath("llama-server"); err == nil { - binPath = path - } else { - var dlErr error - binPath, dlErr = EnsureLlamaServer(ctx) - if dlErr != nil { - return fmt.Errorf("llama_cpp: find llama-server: %w", dlErr) - } - } - } - - args := []string{ - "--model", p.config.ModelPath, - "--port", fmt.Sprintf("%d", p.config.Port), - "-ngl", fmt.Sprintf("%d", p.config.GPULayers), - "-c", fmt.Sprintf("%d", p.config.ContextSize), - "-t", fmt.Sprintf("%d", p.config.Threads), - } - - // #nosec G204 -- BinaryPath/binPath comes from config or PATH lookup, not user input. - cmd := exec.Command(binPath, args...) - if err := cmd.Start(); err != nil { - return fmt.Errorf("llama_cpp: start llama-server: %w", err) - } - p.cmd = cmd - - healthURL := fmt.Sprintf("http://localhost:%d/health", p.config.Port) - deadline := time.Now().Add(2 * time.Minute) - ticker := time.NewTicker(500 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - _ = p.cmd.Process.Kill() - _ = p.cmd.Wait() - return ctx.Err() - case <-ticker.C: - if time.Now().After(deadline) { - _ = p.cmd.Process.Kill() - _ = p.cmd.Wait() - return fmt.Errorf("llama_cpp: llama-server did not become healthy within timeout") - } - hReq, _ := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil) - resp, err := p.config.HTTPClient.Do(hReq) - if err == nil { - _ = resp.Body.Close() - if resp.StatusCode == http.StatusOK { - return nil - } - } - } - } -} - -// Close kills the managed llama-server process if one was started -// and waits for it to exit to avoid zombie processes. -func (p *LlamaCppProvider) Close() error { - if p.cmd != nil && p.cmd.Process != nil { - _ = p.cmd.Process.Kill() - _ = p.cmd.Wait() // reap the child process - } - return nil -} diff --git a/provider/llama_cpp_test.go b/provider/llama_cpp_test.go deleted file mode 100644 index 7d50889..0000000 --- a/provider/llama_cpp_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" -) - -// newLlamaCppTestServer starts an httptest server returning OpenAI-compatible responses. -// If thinking is non-empty, the content wraps it in tags. -func newLlamaCppTestServer(t *testing.T, thinking, content string) *httptest.Server { - t.Helper() - rawContent := content - if thinking != "" { - rawContent = fmt.Sprintf("%s%s", thinking, content) - } - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/v1/chat/completions" { - w.Header().Set("Content-Type", "application/json") - resp := map[string]any{ - "id": "chatcmpl-test", - "object": "chat.completion", - "choices": []any{map[string]any{"message": map[string]any{"role": "assistant", "content": rawContent}, "finish_reason": "stop"}}, - "usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 5}, - } - _ = json.NewEncoder(w).Encode(resp) - return - } - http.NotFound(w, r) - })) -} - -// newLlamaCppStreamServer starts an httptest server returning SSE streaming responses. -func newLlamaCppStreamServer(t *testing.T, chunks []string) *httptest.Server { - t.Helper() - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/v1/chat/completions" { - w.Header().Set("Content-Type", "text/event-stream") - for _, chunk := range chunks { - delta := map[string]any{ - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "choices": []any{map[string]any{"delta": map[string]any{"content": chunk}, "finish_reason": nil}}, - } - b, _ := json.Marshal(delta) - fmt.Fprintf(w, "data: %s\n\n", b) - } - fmt.Fprintf(w, "data: [DONE]\n\n") - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - return - } - http.NotFound(w, r) - })) -} - -func TestLlamaCppProvider_Defaults(t *testing.T) { - p := NewLlamaCppProvider(LlamaCppConfig{BaseURL: "http://localhost:8081/v1"}) - if p.config.GPULayers != defaultLlamaCppGPULayers { - t.Errorf("GPULayers: want %d, got %d", defaultLlamaCppGPULayers, p.config.GPULayers) - } - if p.config.ContextSize != defaultLlamaCppContextSize { - t.Errorf("ContextSize: want %d, got %d", defaultLlamaCppContextSize, p.config.ContextSize) - } - if p.config.MaxTokens != defaultLlamaCppMaxTokens { - t.Errorf("MaxTokens: want %d, got %d", defaultLlamaCppMaxTokens, p.config.MaxTokens) - } - if p.config.Port != defaultLlamaCppPort { - t.Errorf("Port: want %d, got %d", defaultLlamaCppPort, p.config.Port) - } - if p.Name() != "llama_cpp" { - t.Errorf("Name: want %q, got %q", "llama_cpp", p.Name()) - } -} - -func TestLlamaCppProvider_Chat_NoThinking(t *testing.T) { - srv := newLlamaCppTestServer(t, "", "The sky is blue.") - defer srv.Close() - - p := NewLlamaCppProvider(LlamaCppConfig{ - BaseURL: srv.URL + "/v1", - HTTPClient: srv.Client(), - }) - - resp, err := p.Chat(context.Background(), []Message{{Role: RoleUser, Content: "What color is the sky?"}}, nil) - if err != nil { - t.Fatalf("Chat: %v", err) - } - if resp.Content != "The sky is blue." { - t.Errorf("Content: want %q, got %q", "The sky is blue.", resp.Content) - } - if resp.Thinking != "" { - t.Errorf("Thinking: want empty, got %q", resp.Thinking) - } -} - -func TestLlamaCppProvider_Chat_WithThinking(t *testing.T) { - srv := newLlamaCppTestServer(t, "Let me reason step by step.", "The answer is 42.") - defer srv.Close() - - p := NewLlamaCppProvider(LlamaCppConfig{ - BaseURL: srv.URL + "/v1", - HTTPClient: srv.Client(), - }) - - resp, err := p.Chat(context.Background(), []Message{{Role: RoleUser, Content: "What is the answer?"}}, nil) - if err != nil { - t.Fatalf("Chat: %v", err) - } - if resp.Thinking != "Let me reason step by step." { - t.Errorf("Thinking: want %q, got %q", "Let me reason step by step.", resp.Thinking) - } - if resp.Content != "The answer is 42." { - t.Errorf("Content: want %q, got %q", "The answer is 42.", resp.Content) - } -} - -func TestLlamaCppProvider_Stream_WithThinking(t *testing.T) { - // Stream chunks that spell out reasoninganswer - chunks := []string{"", "reasoning", "", "answer"} - srv := newLlamaCppStreamServer(t, chunks) - defer srv.Close() - - p := NewLlamaCppProvider(LlamaCppConfig{ - BaseURL: srv.URL + "/v1", - HTTPClient: srv.Client(), - }) - - ch, err := p.Stream(context.Background(), []Message{{Role: RoleUser, Content: "question"}}, nil) - if err != nil { - t.Fatalf("Stream: %v", err) - } - - var thinkingText, contentText string - for event := range ch { - switch event.Type { - case "thinking": - thinkingText += event.Thinking - case "text": - contentText += event.Text - case "error": - t.Fatalf("stream error: %s", event.Error) - } - } - if thinkingText != "reasoning" { - t.Errorf("thinking: want %q, got %q", "reasoning", thinkingText) - } - if contentText != "answer" { - t.Errorf("content: want %q, got %q", "answer", contentText) - } -} - -func TestLlamaCppProvider_Close_NoProcess(t *testing.T) { - p := NewLlamaCppProvider(LlamaCppConfig{BaseURL: "http://localhost:8081/v1"}) - if err := p.Close(); err != nil { - t.Errorf("Close with no process: want nil, got %v", err) - } -} - -func TestLlamaCppProvider_EnsureServer_ExternalMode(t *testing.T) { - p := NewLlamaCppProvider(LlamaCppConfig{BaseURL: "http://localhost:8081/v1"}) - // External mode: EnsureServer should be a no-op. - if err := p.EnsureServer(context.Background()); err != nil { - t.Errorf("EnsureServer external mode: want nil, got %v", err) - } -} diff --git a/provider/models.go b/provider/models.go index a5106f8..14bf5cb 100644 --- a/provider/models.go +++ b/provider/models.go @@ -8,6 +8,10 @@ import ( "net/http" "sort" "strings" + + "github.com/google/generative-ai-go/genai" + "google.golang.org/api/iterator" + googleoption "google.golang.org/api/option" ) // ModelInfo describes an available model from a provider. @@ -17,6 +21,23 @@ type ModelInfo struct { ContextWindow int `json:"context_window,omitempty"` } +// Constants used by model-listing functions, sourced from former provider files. +const ( + defaultAnthropicBaseURL = "https://api.anthropic.com" + anthropicAPIVersion = "2023-06-01" + defaultOpenAIBaseURL = "https://api.openai.com" + defaultCopilotBaseURL = "https://api.githubcopilot.com" + copilotTokenExchangeURL = "https://api.github.com/copilot_internal/v2/token" + copilotEditorVersion = "ratchet/0.1.0" + defaultCohereBaseURL = "https://api.cohere.com" +) + +// copilotTokenResponse is the response from the Copilot token exchange endpoint. +type copilotTokenResponse struct { + Token string `json:"token"` + ExpiresAt int64 `json:"expires_at"` +} + // ListModels fetches available models from the given provider type. // Only requires an API key and optional base URL — no saved provider needed. func ListModels(ctx context.Context, providerType, apiKey, baseURL string) ([]ModelInfo, error) { @@ -429,10 +450,49 @@ func foundryFallbackModels() []ModelInfo { } } +// listGeminiModels lists available Gemini models using the genai SDK. +func listGeminiModels(ctx context.Context, apiKey string) ([]ModelInfo, error) { + client, err := genai.NewClient(ctx, googleoption.WithAPIKey(apiKey)) + if err != nil { + return geminiFallbackModels(), nil + } + defer client.Close() + + iter := client.ListModels(ctx) + var models []ModelInfo + for { + m, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return geminiFallbackModels(), nil + } + models = append(models, ModelInfo{ + ID: m.Name, + Name: m.DisplayName, + }) + } + if len(models) == 0 { + return geminiFallbackModels(), nil + } + return models, nil +} + +func geminiFallbackModels() []ModelInfo { + return []ModelInfo{ + {ID: "gemini-2.5-pro-preview-03-25", Name: "Gemini 2.5 Pro Preview"}, + {ID: "gemini-2.0-flash", Name: "Gemini 2.0 Flash"}, + {ID: "gemini-2.0-flash-lite", Name: "Gemini 2.0 Flash-Lite"}, + {ID: "gemini-1.5-pro", Name: "Gemini 1.5 Pro"}, + {ID: "gemini-1.5-flash", Name: "Gemini 1.5 Flash"}, + } +} + // listOllamaModels lists models available on a local Ollama server. func listOllamaModels(ctx context.Context, baseURL string) ([]ModelInfo, error) { - p := NewOllamaProvider(OllamaConfig{BaseURL: baseURL}) - return p.ListModels(ctx) + c := NewOllamaClient(baseURL) + return c.ListModels(ctx) } func truncate(s string, maxLen int) string { diff --git a/provider/ollama.go b/provider/ollama.go deleted file mode 100644 index 633b5bb..0000000 --- a/provider/ollama.go +++ /dev/null @@ -1,163 +0,0 @@ -package provider - -import ( - "context" - "fmt" - "net/http" - "net/url" - - ollamaapi "github.com/ollama/ollama/api" -) - -const defaultOllamaBaseURL = "http://localhost:11434" - -// OllamaConfig holds configuration for the Ollama provider. -type OllamaConfig struct { - Model string - BaseURL string - MaxTokens int - HTTPClient *http.Client -} - -// OllamaProvider implements Provider using a local Ollama server. -type OllamaProvider struct { - client *ollamaapi.Client - config OllamaConfig -} - -// NewOllamaProvider creates a new OllamaProvider with the given config. -func NewOllamaProvider(cfg OllamaConfig) *OllamaProvider { - if cfg.BaseURL == "" { - cfg.BaseURL = defaultOllamaBaseURL - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - base, err := url.Parse(cfg.BaseURL) - if err != nil { - // Fall back to default URL if configured one is invalid. - base, _ = url.Parse(defaultOllamaBaseURL) - cfg.BaseURL = defaultOllamaBaseURL - } - return &OllamaProvider{ - client: ollamaapi.NewClient(base, cfg.HTTPClient), - config: cfg, - } -} - -func (p *OllamaProvider) Name() string { return "ollama" } - -func (p *OllamaProvider) AuthModeInfo() AuthModeInfo { - return LocalAuthMode("ollama", "Ollama (Local)") -} - -// Chat sends a non-streaming request and returns the complete response. -func (p *OllamaProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - streamFalse := false - req := &ollamaapi.ChatRequest{ - Model: p.config.Model, - Messages: toOllamaMessages(messages), - Stream: &streamFalse, - Options: map[string]any{}, - } - if len(tools) > 0 { - req.Tools = toOllamaTools(tools) - } - if p.config.MaxTokens > 0 { - req.Options["num_predict"] = p.config.MaxTokens - } - - var final ollamaapi.ChatResponse - if err := p.client.Chat(ctx, req, func(resp ollamaapi.ChatResponse) error { - final = resp - return nil - }); err != nil { - return nil, fmt.Errorf("ollama: %w", err) - } - return fromOllamaResponse(final), nil -} - -// Stream sends a streaming request and emits events on the returned channel. -// Thinking tokens extracted via ThinkingStreamParser are emitted as "thinking" events. -func (p *OllamaProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - req := &ollamaapi.ChatRequest{ - Model: p.config.Model, - Messages: toOllamaMessages(messages), - Options: map[string]any{}, - } - if len(tools) > 0 { - req.Tools = toOllamaTools(tools) - } - if p.config.MaxTokens > 0 { - req.Options["num_predict"] = p.config.MaxTokens - } - - ch := make(chan StreamEvent, 16) - go func() { - defer close(ch) - var parser ThinkingStreamParser - var usage *Usage - - err := p.client.Chat(ctx, req, func(resp ollamaapi.ChatResponse) error { - text, toolCalls, done := fromOllamaStreamChunk(resp) - if text != "" { - for _, ev := range parser.Feed(text) { - ch <- ev - } - } - for i := range toolCalls { - ch <- StreamEvent{Type: "tool_call", Tool: &toolCalls[i]} - } - if done { - usage = &Usage{ - InputTokens: resp.PromptEvalCount, - OutputTokens: resp.EvalCount, - } - } - return nil - }) - if err != nil { - ch <- StreamEvent{Type: "error", Error: err.Error()} - return - } - ch <- StreamEvent{Type: "done", Usage: usage} - }() - return ch, nil -} - -// Pull downloads a model via the Ollama server. -// progressFn is called with percent completion (0–100); may be nil. -func (p *OllamaProvider) Pull(ctx context.Context, model string, progressFn func(pct float64)) error { - req := &ollamaapi.PullRequest{Model: model} - return p.client.Pull(ctx, req, func(resp ollamaapi.ProgressResponse) error { - if progressFn != nil && resp.Total > 0 { - progressFn(float64(resp.Completed) / float64(resp.Total) * 100) - } - return nil - }) -} - -// ListModels returns the models available on the Ollama server. -func (p *OllamaProvider) ListModels(ctx context.Context) ([]ModelInfo, error) { - resp, err := p.client.List(ctx) - if err != nil { - return nil, fmt.Errorf("ollama: list models: %w", err) - } - models := make([]ModelInfo, 0, len(resp.Models)) - for _, m := range resp.Models { - name := m.Name - if name == "" { - name = m.Model - } - models = append(models, ModelInfo{ID: name, Name: name}) - } - return models, nil -} - -// Health checks whether the Ollama server is reachable. -func (p *OllamaProvider) Health(ctx context.Context) error { - if err := p.client.Heartbeat(ctx); err != nil { - return fmt.Errorf("ollama: health check: %w", err) - } - return nil -} diff --git a/provider/ollama_client.go b/provider/ollama_client.go new file mode 100644 index 0000000..f04b75e --- /dev/null +++ b/provider/ollama_client.go @@ -0,0 +1,71 @@ +package provider + +import ( + "context" + "fmt" + "net/http" + "net/url" + + ollamaapi "github.com/ollama/ollama/api" +) + +const defaultOllamaClientURL = "http://localhost:11434" + +// OllamaClient provides utility operations against a local Ollama server: +// pulling models and listing available models. For chat/stream, use the +// Genkit-backed provider returned by genkit.NewOllamaProvider. +type OllamaClient struct { + client *ollamaapi.Client +} + +// NewOllamaClient creates an OllamaClient pointing at the given server address. +// If serverAddress is empty, http://localhost:11434 is used. +func NewOllamaClient(serverAddress string) *OllamaClient { + if serverAddress == "" { + serverAddress = defaultOllamaClientURL + } + base, err := url.Parse(serverAddress) + if err != nil { + base, _ = url.Parse(defaultOllamaClientURL) + } + return &OllamaClient{ + client: ollamaapi.NewClient(base, http.DefaultClient), + } +} + +// Pull downloads a model via the Ollama server. +// progressFn is called with percent completion (0–100); may be nil. +func (c *OllamaClient) Pull(ctx context.Context, model string, progressFn func(pct float64)) error { + req := &ollamaapi.PullRequest{Model: model} + return c.client.Pull(ctx, req, func(resp ollamaapi.ProgressResponse) error { + if progressFn != nil && resp.Total > 0 { + progressFn(float64(resp.Completed) / float64(resp.Total) * 100) + } + return nil + }) +} + +// ListModels returns the models available on the Ollama server. +func (c *OllamaClient) ListModels(ctx context.Context) ([]ModelInfo, error) { + resp, err := c.client.List(ctx) + if err != nil { + return nil, fmt.Errorf("ollama: list models: %w", err) + } + models := make([]ModelInfo, 0, len(resp.Models)) + for _, m := range resp.Models { + name := m.Name + if name == "" { + name = m.Model + } + models = append(models, ModelInfo{ID: name, Name: name}) + } + return models, nil +} + +// Health checks whether the Ollama server is reachable. +func (c *OllamaClient) Health(ctx context.Context) error { + if err := c.client.Heartbeat(ctx); err != nil { + return fmt.Errorf("ollama: health check: %w", err) + } + return nil +} diff --git a/provider/ollama_convert.go b/provider/ollama_convert.go deleted file mode 100644 index 2e90568..0000000 --- a/provider/ollama_convert.go +++ /dev/null @@ -1,96 +0,0 @@ -package provider - -import ( - "encoding/json" - - ollamaapi "github.com/ollama/ollama/api" -) - -// toOllamaMessages converts provider messages to Ollama API messages. -func toOllamaMessages(msgs []Message) []ollamaapi.Message { - result := make([]ollamaapi.Message, 0, len(msgs)) - for _, msg := range msgs { - m := ollamaapi.Message{ - Role: string(msg.Role), - Content: msg.Content, - ToolCallID: msg.ToolCallID, - } - for _, tc := range msg.ToolCalls { - args := ollamaapi.NewToolCallFunctionArguments() - for k, v := range tc.Arguments { - args.Set(k, v) - } - m.ToolCalls = append(m.ToolCalls, ollamaapi.ToolCall{ - ID: tc.ID, - Function: ollamaapi.ToolCallFunction{ - Name: tc.Name, - Arguments: args, - }, - }) - } - result = append(result, m) - } - return result -} - -// toOllamaTools converts provider tool definitions to Ollama API tools. -// The JSON Schema parameters are marshaled and unmarshaled into the Ollama type. -func toOllamaTools(tools []ToolDef) []ollamaapi.Tool { - result := make([]ollamaapi.Tool, 0, len(tools)) - for _, t := range tools { - var params ollamaapi.ToolFunctionParameters - if b, err := json.Marshal(t.Parameters); err == nil { - _ = json.Unmarshal(b, ¶ms) - } - result = append(result, ollamaapi.Tool{ - Type: "function", - Function: ollamaapi.ToolFunction{ - Name: t.Name, - Description: t.Description, - Parameters: params, - }, - }) - } - return result -} - -// fromOllamaResponse converts a completed Ollama ChatResponse to a provider Response. -// ParseThinking is applied to extract any ... block from the content. -func fromOllamaResponse(resp ollamaapi.ChatResponse) *Response { - thinking, content := ParseThinking(resp.Message.Content) - // Prefer native thinking field if set (Ollama Think mode) - if resp.Message.Thinking != "" { - thinking = resp.Message.Thinking - content = resp.Message.Content - } - r := &Response{ - Content: content, - Thinking: thinking, - Usage: Usage{ - InputTokens: resp.PromptEvalCount, - OutputTokens: resp.EvalCount, - }, - } - for _, tc := range resp.Message.ToolCalls { - r.ToolCalls = append(r.ToolCalls, ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: tc.Function.Arguments.ToMap(), - }) - } - return r -} - -// fromOllamaStreamChunk extracts text content, tool calls, and done flag from a -// streaming Ollama ChatResponse chunk. -func fromOllamaStreamChunk(resp ollamaapi.ChatResponse) (text string, toolCalls []ToolCall, done bool) { - text = resp.Message.Content - for _, tc := range resp.Message.ToolCalls { - toolCalls = append(toolCalls, ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: tc.Function.Arguments.ToMap(), - }) - } - return text, toolCalls, resp.Done -} diff --git a/provider/ollama_convert_test.go b/provider/ollama_convert_test.go deleted file mode 100644 index 1e3d51c..0000000 --- a/provider/ollama_convert_test.go +++ /dev/null @@ -1,233 +0,0 @@ -package provider - -import ( - "testing" - "time" - - ollamaapi "github.com/ollama/ollama/api" -) - -// --- toOllamaMessages --- - -func TestToOllamaMessages_Roles(t *testing.T) { - msgs := []Message{ - {Role: RoleSystem, Content: "sys"}, - {Role: RoleUser, Content: "hi"}, - {Role: RoleAssistant, Content: "hello"}, - {Role: RoleTool, Content: "tool result", ToolCallID: "tc1"}, - } - got := toOllamaMessages(msgs) - if len(got) != 4 { - t.Fatalf("len=%d, want 4", len(got)) - } - if got[0].Role != "system" || got[0].Content != "sys" { - t.Errorf("system msg wrong: %+v", got[0]) - } - if got[1].Role != "user" || got[1].Content != "hi" { - t.Errorf("user msg wrong: %+v", got[1]) - } - if got[2].Role != "assistant" || got[2].Content != "hello" { - t.Errorf("assistant msg wrong: %+v", got[2]) - } - if got[3].Role != "tool" || got[3].ToolCallID != "tc1" { - t.Errorf("tool msg wrong: %+v", got[3]) - } -} - -func TestToOllamaMessages_ToolCalls(t *testing.T) { - msgs := []Message{ - { - Role: RoleAssistant, - ToolCalls: []ToolCall{ - {ID: "id1", Name: "search", Arguments: map[string]any{"q": "hello"}}, - }, - }, - } - got := toOllamaMessages(msgs) - if len(got[0].ToolCalls) != 1 { - t.Fatalf("tool calls len=%d, want 1", len(got[0].ToolCalls)) - } - tc := got[0].ToolCalls[0] - if tc.ID != "id1" { - t.Errorf("ID=%q, want %q", tc.ID, "id1") - } - if tc.Function.Name != "search" { - t.Errorf("Name=%q, want %q", tc.Function.Name, "search") - } - v, ok := tc.Function.Arguments.Get("q") - if !ok || v != "hello" { - t.Errorf("arg q=%v ok=%v, want 'hello'", v, ok) - } -} - -// --- toOllamaTools --- - -func TestToOllamaTools(t *testing.T) { - tools := []ToolDef{ - { - Name: "calculator", - Description: "does math", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "expr": map[string]any{"type": "string"}, - }, - "required": []any{"expr"}, - }, - }, - } - got := toOllamaTools(tools) - if len(got) != 1 { - t.Fatalf("len=%d, want 1", len(got)) - } - if got[0].Type != "function" { - t.Errorf("Type=%q, want 'function'", got[0].Type) - } - if got[0].Function.Name != "calculator" { - t.Errorf("Name=%q", got[0].Function.Name) - } - if got[0].Function.Description != "does math" { - t.Errorf("Description=%q", got[0].Function.Description) - } -} - -func TestToOllamaTools_Empty(t *testing.T) { - got := toOllamaTools(nil) - if len(got) != 0 { - t.Errorf("expected empty slice, got %d", len(got)) - } -} - -// --- fromOllamaResponse --- - -func TestFromOllamaResponse_Basic(t *testing.T) { - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{Role: "assistant", Content: "hello world"}, - Metrics: ollamaapi.Metrics{PromptEvalCount: 10, EvalCount: 5}, - } - got := fromOllamaResponse(resp) - if got.Content != "hello world" { - t.Errorf("Content=%q, want %q", got.Content, "hello world") - } - if got.Thinking != "" { - t.Errorf("Thinking=%q, want empty", got.Thinking) - } - if got.Usage.InputTokens != 10 { - t.Errorf("InputTokens=%d, want 10", got.Usage.InputTokens) - } - if got.Usage.OutputTokens != 5 { - t.Errorf("OutputTokens=%d, want 5", got.Usage.OutputTokens) - } -} - -func TestFromOllamaResponse_ThinkTagExtraction(t *testing.T) { - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{ - Role: "assistant", - Content: "reasoninganswer", - }, - } - got := fromOllamaResponse(resp) - if got.Thinking != "reasoning" { - t.Errorf("Thinking=%q, want %q", got.Thinking, "reasoning") - } - if got.Content != "answer" { - t.Errorf("Content=%q, want %q", got.Content, "answer") - } -} - -func TestFromOllamaResponse_NativeThinkingField(t *testing.T) { - // When Thinking is set natively (Ollama Think mode), prefer it. - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{ - Role: "assistant", - Content: "answer", - Thinking: "native thinking", - }, - } - got := fromOllamaResponse(resp) - if got.Thinking != "native thinking" { - t.Errorf("Thinking=%q, want %q", got.Thinking, "native thinking") - } - if got.Content != "answer" { - t.Errorf("Content=%q, want %q", got.Content, "answer") - } -} - -func TestFromOllamaResponse_ToolCalls(t *testing.T) { - args := ollamaapi.NewToolCallFunctionArguments() - args.Set("query", "test") - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{ - Role: "assistant", - ToolCalls: []ollamaapi.ToolCall{ - {ID: "tc1", Function: ollamaapi.ToolCallFunction{Name: "search", Arguments: args}}, - }, - }, - } - got := fromOllamaResponse(resp) - if len(got.ToolCalls) != 1 { - t.Fatalf("ToolCalls len=%d, want 1", len(got.ToolCalls)) - } - tc := got.ToolCalls[0] - if tc.ID != "tc1" || tc.Name != "search" { - t.Errorf("tc=%+v", tc) - } - if tc.Arguments["query"] != "test" { - t.Errorf("arg query=%v, want 'test'", tc.Arguments["query"]) - } -} - -// --- fromOllamaStreamChunk --- - -func TestFromOllamaStreamChunk_Text(t *testing.T) { - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{Role: "assistant", Content: "chunk text"}, - Done: false, - CreatedAt: time.Now(), - } - text, toolCalls, done := fromOllamaStreamChunk(resp) - if text != "chunk text" { - t.Errorf("text=%q, want %q", text, "chunk text") - } - if len(toolCalls) != 0 { - t.Errorf("toolCalls=%v, want empty", toolCalls) - } - if done { - t.Error("done should be false") - } -} - -func TestFromOllamaStreamChunk_Done(t *testing.T) { - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{Role: "assistant", Content: ""}, - Done: true, - CreatedAt: time.Now(), - } - _, _, done := fromOllamaStreamChunk(resp) - if !done { - t.Error("done should be true") - } -} - -func TestFromOllamaStreamChunk_ToolCalls(t *testing.T) { - args := ollamaapi.NewToolCallFunctionArguments() - args.Set("x", 42) - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{ - Role: "assistant", - ToolCalls: []ollamaapi.ToolCall{ - {ID: "tc2", Function: ollamaapi.ToolCallFunction{Name: "calc", Arguments: args}}, - }, - }, - Done: true, - CreatedAt: time.Now(), - } - _, toolCalls, _ := fromOllamaStreamChunk(resp) - if len(toolCalls) != 1 { - t.Fatalf("len=%d, want 1", len(toolCalls)) - } - if toolCalls[0].Name != "calc" { - t.Errorf("Name=%q, want %q", toolCalls[0].Name, "calc") - } -} diff --git a/provider/ollama_test.go b/provider/ollama_test.go deleted file mode 100644 index b0b31ea..0000000 --- a/provider/ollama_test.go +++ /dev/null @@ -1,217 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" -) - -// ollamaChatNDJSON builds a minimal Ollama chat NDJSON response line. -func ollamaChatNDJSON(content string, done bool) string { - resp := map[string]any{ - "model": "test", - "created_at": time.Now().Format(time.RFC3339), - "message": map[string]any{ - "role": "assistant", - "content": content, - }, - "done": done, - "prompt_eval_count": 3, - "eval_count": 7, - } - b, _ := json.Marshal(resp) - return string(b) + "\n" -} - -// ollamaMockServer sets up a minimal Ollama HTTP server for unit tests. -func ollamaMockServer(t *testing.T, chatContent string) *httptest.Server { - t.Helper() - mux := http.NewServeMux() - - // HEAD / — heartbeat - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodHead { - w.WriteHeader(http.StatusOK) - } - }) - - // POST /api/chat — chat (streaming NDJSON) - mux.HandleFunc("/api/chat", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/x-ndjson") - w.Write([]byte(ollamaChatNDJSON(chatContent, true))) - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - }) - - // GET /api/tags — list models - mux.HandleFunc("/api/tags", func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(map[string]any{ - "models": []map[string]any{ - {"name": "llama3.2:latest", "model": "llama3.2:latest", "size": int64(4000000000)}, - }, - }) - }) - - return httptest.NewServer(mux) -} - -func TestOllamaProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*OllamaProvider)(nil) -} - -func TestOllamaProvider_Name(t *testing.T) { - p := NewOllamaProvider(OllamaConfig{Model: "qwen3.5:7b"}) - if got := p.Name(); got != "ollama" { - t.Errorf("Name()=%q, want %q", got, "ollama") - } -} - -func TestOllamaProvider_AuthModeInfo(t *testing.T) { - p := NewOllamaProvider(OllamaConfig{}) - info := p.AuthModeInfo() - if info.Mode != "ollama" { - t.Errorf("Mode=%q, want %q", info.Mode, "ollama") - } - if !info.ServerSafe { - t.Error("ServerSafe should be true") - } - if info.Warning != "" { - t.Errorf("Warning should be empty, got %q", info.Warning) - } -} - -func TestOllamaProvider_DefaultBaseURL(t *testing.T) { - p := NewOllamaProvider(OllamaConfig{}) - if p.config.BaseURL != defaultOllamaBaseURL { - t.Errorf("BaseURL=%q, want %q", p.config.BaseURL, defaultOllamaBaseURL) - } -} - -func TestOllamaProvider_Health(t *testing.T) { - srv := ollamaMockServer(t, "") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - if err := p.Health(context.Background()); err != nil { - t.Errorf("Health() error: %v", err) - } -} - -func TestOllamaProvider_Chat(t *testing.T) { - srv := ollamaMockServer(t, "hello world") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - resp, err := p.Chat(context.Background(), []Message{{Role: RoleUser, Content: "hi"}}, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if resp.Content != "hello world" { - t.Errorf("Content=%q, want %q", resp.Content, "hello world") - } - if resp.Usage.InputTokens != 3 { - t.Errorf("InputTokens=%d, want 3", resp.Usage.InputTokens) - } - if resp.Usage.OutputTokens != 7 { - t.Errorf("OutputTokens=%d, want 7", resp.Usage.OutputTokens) - } -} - -func TestOllamaProvider_ChatWithThinkTags(t *testing.T) { - srv := ollamaMockServer(t, "reasoninganswer") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - resp, err := p.Chat(context.Background(), []Message{{Role: RoleUser, Content: "q"}}, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if resp.Thinking != "reasoning" { - t.Errorf("Thinking=%q, want %q", resp.Thinking, "reasoning") - } - if resp.Content != "answer" { - t.Errorf("Content=%q, want %q", resp.Content, "answer") - } -} - -func TestOllamaProvider_Stream(t *testing.T) { - srv := ollamaMockServer(t, "stream result") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - ch, err := p.Stream(context.Background(), []Message{{Role: RoleUser, Content: "hi"}}, nil) - if err != nil { - t.Fatalf("Stream() error: %v", err) - } - - var textParts []string - var sawDone bool - for ev := range ch { - switch ev.Type { - case "text": - textParts = append(textParts, ev.Text) - case "done": - sawDone = true - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if !sawDone { - t.Error("expected done event") - } - full := strings.Join(textParts, "") - if full != "stream result" { - t.Errorf("text=%q, want %q", full, "stream result") - } -} - -func TestOllamaProvider_StreamThinkTags(t *testing.T) { - srv := ollamaMockServer(t, "thinkingcontent") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - ch, err := p.Stream(context.Background(), []Message{{Role: RoleUser, Content: "q"}}, nil) - if err != nil { - t.Fatalf("Stream() error: %v", err) - } - - var thinkParts, textParts []string - for ev := range ch { - switch ev.Type { - case "thinking": - thinkParts = append(thinkParts, ev.Thinking) - case "text": - textParts = append(textParts, ev.Text) - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if strings.Join(thinkParts, "") != "thinking" { - t.Errorf("thinking=%q, want %q", strings.Join(thinkParts, ""), "thinking") - } - if strings.Join(textParts, "") != "content" { - t.Errorf("text=%q, want %q", strings.Join(textParts, ""), "content") - } -} - -func TestOllamaProvider_ListModels(t *testing.T) { - srv := ollamaMockServer(t, "") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - models, err := p.ListModels(context.Background()) - if err != nil { - t.Fatalf("ListModels() error: %v", err) - } - if len(models) != 1 { - t.Fatalf("len=%d, want 1", len(models)) - } - if models[0].ID != "llama3.2:latest" { - t.Errorf("ID=%q, want %q", models[0].ID, "llama3.2:latest") - } -} diff --git a/provider/openai.go b/provider/openai.go deleted file mode 100644 index ffb9b68..0000000 --- a/provider/openai.go +++ /dev/null @@ -1,97 +0,0 @@ -package provider - -import ( - "context" - "fmt" - "net/http" - - openaisdk "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/shared" -) - -const ( - defaultOpenAIBaseURL = "https://api.openai.com" - defaultOpenAIModel = "gpt-4o" - defaultOpenAIMaxTokens = 4096 -) - -// OpenAIConfig holds configuration for the OpenAI provider. -type OpenAIConfig struct { - APIKey string - Model string - BaseURL string - MaxTokens int - HTTPClient *http.Client -} - -// OpenAIProvider implements Provider using the OpenAI Chat Completions API. -type OpenAIProvider struct { - client openaisdk.Client - config OpenAIConfig -} - -// NewOpenAIProvider creates a new OpenAI provider with the given config. -func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider { - if cfg.Model == "" { - cfg.Model = defaultOpenAIModel - } - if cfg.BaseURL == "" { - cfg.BaseURL = defaultOpenAIBaseURL - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultOpenAIMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - client := openaisdk.NewClient( - option.WithAPIKey(cfg.APIKey), - option.WithBaseURL(cfg.BaseURL), - option.WithHTTPClient(cfg.HTTPClient), - ) - return &OpenAIProvider{client: client, config: cfg} -} - -func (p *OpenAIProvider) Name() string { return "openai" } - -func (p *OpenAIProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "direct", - DisplayName: "OpenAI (Direct API)", - Description: "Direct access to OpenAI models via API key.", - DocsURL: "https://platform.openai.com/docs/api-reference/introduction", - ServerSafe: true, - } -} - -func (p *OpenAIProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.Model), - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - resp, err := p.client.Chat.Completions.New(ctx, params) - if err != nil { - return nil, fmt.Errorf("openai: %w", err) - } - return fromOpenAIResponse(resp) -} - -func (p *OpenAIProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.Model), - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - stream := p.client.Chat.Completions.NewStreaming(ctx, params) - ch := make(chan StreamEvent, 16) - go streamOpenAIEvents(stream, ch) - return ch, nil -} diff --git a/provider/openai_azure.go b/provider/openai_azure.go deleted file mode 100644 index 68b70b5..0000000 --- a/provider/openai_azure.go +++ /dev/null @@ -1,132 +0,0 @@ -package provider - -import ( - "context" - "fmt" - "net/http" - - openaisdk "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/shared" -) - -const defaultAzureOpenAIAPIVersion = "2024-10-21" - -// OpenAIAzureConfig configures the OpenAI provider for Azure OpenAI Service. -// Uses Azure API keys or Entra ID tokens. URLs follow the pattern: -// {resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version} -type OpenAIAzureConfig struct { - // Resource is the Azure OpenAI resource name. - Resource string - // DeploymentName is the model deployment name in Azure. - DeploymentName string - // APIVersion is the Azure API version (e.g. "2024-10-21"). - APIVersion string - // MaxTokens limits the response length. - MaxTokens int - // APIKey is the Azure API key (use this OR Entra ID token, not both). - APIKey string - // EntraToken is a Microsoft Entra ID bearer token (optional, alternative to APIKey). - EntraToken string - // HTTPClient overrides the default HTTP client. - HTTPClient *http.Client - // BaseURL overrides the computed Azure endpoint (used in tests). - BaseURL string -} - -// OpenAIAzureProvider accesses OpenAI models via Azure OpenAI Service. -type OpenAIAzureProvider struct { - client openaisdk.Client - config OpenAIAzureConfig -} - -func (p *OpenAIAzureProvider) Name() string { return "openai_azure" } - -func (p *OpenAIAzureProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "azure", - DisplayName: "OpenAI (Azure OpenAI Service)", - Description: "Access OpenAI models via Azure OpenAI Service using Azure API keys or Microsoft Entra ID tokens. Uses deployment-specific URLs.", - DocsURL: "https://learn.microsoft.com/en-us/azure/ai-services/openai/reference", - ServerSafe: true, - } -} - -// NewOpenAIAzureProvider creates a provider that accesses OpenAI models via Azure. -// -// Docs: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference -func NewOpenAIAzureProvider(cfg OpenAIAzureConfig) (*OpenAIAzureProvider, error) { - if cfg.Resource == "" && cfg.BaseURL == "" { - return nil, fmt.Errorf("openai_azure: Resource is required") - } - if cfg.DeploymentName == "" && cfg.BaseURL == "" { - return nil, fmt.Errorf("openai_azure: DeploymentName is required") - } - if cfg.APIKey == "" && cfg.EntraToken == "" { - return nil, fmt.Errorf("openai_azure: APIKey or EntraToken is required") - } - if cfg.APIVersion == "" { - cfg.APIVersion = defaultAzureOpenAIAPIVersion - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultOpenAIMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - - baseURL := cfg.BaseURL - if baseURL == "" { - baseURL = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s", - cfg.Resource, cfg.DeploymentName) - } - - opts := []option.RequestOption{ - // Use a placeholder API key to prevent OPENAI_API_KEY env var from being used, - // then remove the resulting Authorization header for Azure auth. - option.WithAPIKey("azure-placeholder"), - option.WithHeaderDel("authorization"), - option.WithBaseURL(baseURL), - option.WithQuery("api-version", cfg.APIVersion), - option.WithHTTPClient(cfg.HTTPClient), - } - if cfg.APIKey != "" { - opts = append(opts, option.WithHeader("api-key", cfg.APIKey)) - } else if cfg.EntraToken != "" { - opts = append(opts, option.WithHeader("Authorization", "Bearer "+cfg.EntraToken)) - } - - client := openaisdk.NewClient(opts...) - return &OpenAIAzureProvider{client: client, config: cfg}, nil -} - -func (p *OpenAIAzureProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.DeploymentName), - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - resp, err := p.client.Chat.Completions.New(ctx, params) - if err != nil { - return nil, fmt.Errorf("openai_azure: %w", err) - } - return fromOpenAIResponse(resp) -} - -func (p *OpenAIAzureProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.DeploymentName), - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - stream := p.client.Chat.Completions.NewStreaming(ctx, params) - ch := make(chan StreamEvent, 16) - go streamOpenAIEvents(stream, ch) - return ch, nil -} diff --git a/provider/openai_azure_test.go b/provider/openai_azure_test.go deleted file mode 100644 index 7b7ae9c..0000000 --- a/provider/openai_azure_test.go +++ /dev/null @@ -1,218 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestOpenAIAzureProvider_Name(t *testing.T) { - p, err := NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "myresource", DeploymentName: "gpt-4o", APIKey: "key123", - }) - if err != nil { - t.Fatal(err) - } - if got := p.Name(); got != "openai_azure" { - t.Errorf("Name() = %q, want %q", got, "openai_azure") - } -} - -func TestOpenAIAzureProvider_AuthModeInfo(t *testing.T) { - p, err := NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "myresource", DeploymentName: "gpt-4o", APIKey: "key123", - }) - if err != nil { - t.Fatal(err) - } - info := p.AuthModeInfo() - if info.Mode != "azure" { - t.Errorf("AuthModeInfo().Mode = %q, want %q", info.Mode, "azure") - } - if info.DisplayName != "OpenAI (Azure OpenAI Service)" { - t.Errorf("AuthModeInfo().DisplayName = %q, want %q", info.DisplayName, "OpenAI (Azure OpenAI Service)") - } - if !info.ServerSafe { - t.Error("AuthModeInfo().ServerSafe = false, want true") - } -} - -func TestOpenAIAzureProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*OpenAIAzureProvider)(nil) -} - -func TestOpenAIAzureProvider_ValidationErrors(t *testing.T) { - tests := []struct { - name string - cfg OpenAIAzureConfig - want string - }{ - {"missing resource", OpenAIAzureConfig{DeploymentName: "d", APIKey: "k"}, "Resource is required"}, - {"missing deployment", OpenAIAzureConfig{Resource: "r", APIKey: "k"}, "DeploymentName is required"}, - {"missing auth", OpenAIAzureConfig{Resource: "r", DeploymentName: "d"}, "APIKey or EntraToken is required"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := NewOpenAIAzureProvider(tt.cfg) - if err == nil { - t.Fatal("expected error") - } - if got := err.Error(); !strings.Contains(got, tt.want) { - t.Errorf("error = %q, want to contain %q", got, tt.want) - } - }) - } -} - -func TestOpenAIAzureProvider_DefaultAPIVersion(t *testing.T) { - p, err := NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "res", DeploymentName: "dep", APIKey: "key", - }) - if err != nil { - t.Fatal(err) - } - if p.config.APIVersion != defaultAzureOpenAIAPIVersion { - t.Errorf("APIVersion = %q, want %q", p.config.APIVersion, defaultAzureOpenAIAPIVersion) - } -} - -func TestOpenAIAzureProvider_APIKeyAuth(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("api-key"); got != "azure-key-123" { - t.Errorf("api-key header = %q, want %q", got, "azure-key-123") - } - if got := r.Header.Get("Authorization"); got != "" { - t.Errorf("Authorization header should be empty with API key auth, got %q", got) - } - - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"chatcmpl-azure","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"hello from azure"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8},"created":1704067200}`) - })) - defer srv.Close() - - p, err := NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "test", - DeploymentName: "gpt-4o", - APIKey: "azure-key-123", - APIVersion: "2024-10-21", - MaxTokens: 4096, - HTTPClient: http.DefaultClient, - BaseURL: srv.URL, - }) - if err != nil { - t.Fatal(err) - } - - got, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if got.Content != "hello from azure" { - t.Errorf("Chat() content = %q, want %q", got.Content, "hello from azure") - } -} - -func TestOpenAIAzureProvider_EntraTokenAuth(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("Authorization"); got != "Bearer entra-token-xyz" { - t.Errorf("Authorization header = %q, want %q", got, "Bearer entra-token-xyz") - } - if got := r.Header.Get("api-key"); got != "" { - t.Errorf("api-key header should be empty with Entra auth, got %q", got) - } - - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"chatcmpl-entra","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"hello from entra"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8},"created":1704067200}`) - })) - defer srv.Close() - - p, err := NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "test", - DeploymentName: "gpt-4o", - EntraToken: "entra-token-xyz", - APIVersion: "2024-10-21", - MaxTokens: 4096, - HTTPClient: http.DefaultClient, - BaseURL: srv.URL, - }) - if err != nil { - t.Fatal(err) - } - - got, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if got.Content != "hello from entra" { - t.Errorf("Chat() content = %q, want %q", got.Content, "hello from entra") - } -} - -func TestOpenAIAzureProvider_Stream(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher, ok := w.(http.Flusher) - if !ok { - t.Fatal("expected http.Flusher") - } - - data, _ := json.Marshal(map[string]any{ - "id": "chatcmpl-azure-stream", - "object": "chat.completion.chunk", - "model": "gpt-4o", - "created": 1704067200, - "choices": []map[string]any{ - {"index": 0, "delta": map[string]any{"role": "assistant", "content": "azure-streamed"}, "finish_reason": nil}, - }, - }) - fmt.Fprintf(w, "data: %s\n\n", string(data)) - flusher.Flush() - - fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() - })) - defer srv.Close() - - p, err := NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "test", - DeploymentName: "gpt-4o", - APIKey: "azure-key", - APIVersion: "2024-10-21", - MaxTokens: 4096, - HTTPClient: http.DefaultClient, - BaseURL: srv.URL, - }) - if err != nil { - t.Fatal(err) - } - - ch, err := p.Stream(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Stream() error: %v", err) - } - - var texts []string - for ev := range ch { - switch ev.Type { - case "text": - texts = append(texts, ev.Text) - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if len(texts) == 0 { - t.Fatal("expected at least one text event") - } - if texts[0] != "azure-streamed" { - t.Errorf("first text event = %q, want %q", texts[0], "azure-streamed") - } -} diff --git a/provider/openai_convert.go b/provider/openai_convert.go deleted file mode 100644 index ce9f0ef..0000000 --- a/provider/openai_convert.go +++ /dev/null @@ -1,196 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "sort" - "strings" - - openaisdk "github.com/openai/openai-go" - "github.com/openai/openai-go/shared" -) - -// toOpenAIMessages converts provider messages to SDK message params. -func toOpenAIMessages(msgs []Message) []openaisdk.ChatCompletionMessageParamUnion { - result := make([]openaisdk.ChatCompletionMessageParamUnion, 0, len(msgs)) - for _, msg := range msgs { - switch msg.Role { - case RoleTool: - result = append(result, openaisdk.ToolMessage(msg.Content, msg.ToolCallID)) - case RoleAssistant: - asst := openaisdk.ChatCompletionAssistantMessageParam{} - if msg.Content != "" { - asst.Content = openaisdk.ChatCompletionAssistantMessageParamContentUnion{ - OfString: openaisdk.String(msg.Content), - } - } - for _, tc := range msg.ToolCalls { - args := "{}" - if tc.Arguments != nil { - if b, err := json.Marshal(tc.Arguments); err == nil { - args = string(b) - } - } - asst.ToolCalls = append(asst.ToolCalls, openaisdk.ChatCompletionMessageToolCallParam{ - ID: tc.ID, - Function: openaisdk.ChatCompletionMessageToolCallFunctionParam{ - Name: tc.Name, - Arguments: args, - }, - }) - } - result = append(result, openaisdk.ChatCompletionMessageParamUnion{OfAssistant: &asst}) - case RoleSystem: - result = append(result, openaisdk.SystemMessage(msg.Content)) - default: // RoleUser and others - result = append(result, openaisdk.UserMessage(msg.Content)) - } - } - return result -} - -// toOpenAITools converts provider tool definitions to SDK tool params. -func toOpenAITools(tools []ToolDef) []openaisdk.ChatCompletionToolParam { - result := make([]openaisdk.ChatCompletionToolParam, 0, len(tools)) - for _, t := range tools { - schema := shared.FunctionParameters(t.Parameters) - if schema == nil { - schema = shared.FunctionParameters{"type": "object", "properties": map[string]any{}} - } - result = append(result, openaisdk.ChatCompletionToolParam{ - Function: shared.FunctionDefinitionParam{ - Name: t.Name, - Description: openaisdk.String(t.Description), - Parameters: schema, - }, - }) - } - return result -} - -// fromOpenAIResponse converts an SDK ChatCompletion to the provider Response type. -func fromOpenAIResponse(resp *openaisdk.ChatCompletion) (*Response, error) { - result := &Response{ - Usage: Usage{ - InputTokens: int(resp.Usage.PromptTokens), - OutputTokens: int(resp.Usage.CompletionTokens), - }, - } - if len(resp.Choices) == 0 { - return result, nil - } - msg := resp.Choices[0].Message - result.Content = msg.Content - // For reasoning models (Qwen3.5, etc.) that put output in reasoning_content - // instead of content, extract it from the raw JSON extra fields. - if result.Content == "" { - if rc, ok := msg.JSON.ExtraFields["reasoning_content"]; ok && rc.Valid() { - var reasoning string - if err := json.Unmarshal([]byte(rc.Raw()), &reasoning); err == nil && reasoning != "" { - result.Content = reasoning - } - } - } - for _, tc := range msg.ToolCalls { - var args map[string]any - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { - return nil, fmt.Errorf("openai: unmarshal tool call arguments for %q: %w", tc.Function.Name, err) - } - } - result.ToolCalls = append(result.ToolCalls, ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: args, - }) - } - return result, nil -} - -// openaiChunkStream is satisfied by *ssestream.Stream[openaisdk.ChatCompletionChunk]. -type openaiChunkStream interface { - Next() bool - Current() openaisdk.ChatCompletionChunk - Err() error - Close() error -} - -// streamOpenAIEvents drains stream and sends StreamEvents to ch, then closes ch. -func streamOpenAIEvents(stream openaiChunkStream, ch chan<- StreamEvent) { - defer close(ch) - defer stream.Close() - - type pendingToolCall struct { - id string - name string - argsBuf strings.Builder - } - pending := make(map[int64]*pendingToolCall) - var usage *Usage - - for stream.Next() { - chunk := stream.Current() - - if chunk.Usage.PromptTokens > 0 || chunk.Usage.CompletionTokens > 0 { - usage = &Usage{ - InputTokens: int(chunk.Usage.PromptTokens), - OutputTokens: int(chunk.Usage.CompletionTokens), - } - } - - if len(chunk.Choices) == 0 { - continue - } - - delta := chunk.Choices[0].Delta - - if delta.Content != "" { - ch <- StreamEvent{Type: "text", Text: delta.Content} - } - - for _, tc := range delta.ToolCalls { - ptc, exists := pending[tc.Index] - if !exists { - ptc = &pendingToolCall{} - pending[tc.Index] = ptc - } - if tc.ID != "" { - ptc.id = tc.ID - } - if tc.Function.Name != "" { - ptc.name = tc.Function.Name - } - if tc.Function.Arguments != "" { - ptc.argsBuf.WriteString(tc.Function.Arguments) - } - } - } - - if err := stream.Err(); err != nil { - ch <- StreamEvent{Type: "error", Error: err.Error()} - return - } - - indices := make([]int64, 0, len(pending)) - for idx := range pending { - indices = append(indices, idx) - } - sort.Slice(indices, func(i, j int) bool { return indices[i] < indices[j] }) - - for _, idx := range indices { - ptc := pending[idx] - var args map[string]any - if ptc.argsBuf.Len() > 0 { - if err := json.Unmarshal([]byte(ptc.argsBuf.String()), &args); err != nil { - ch <- StreamEvent{Type: "error", Error: fmt.Sprintf("openai: unmarshal tool call arguments for %q: %v", ptc.name, err)} - return - } - } - ch <- StreamEvent{ - Type: "tool_call", - Tool: &ToolCall{ID: ptc.id, Name: ptc.name, Arguments: args}, - } - } - - ch <- StreamEvent{Type: "done", Usage: usage} -} diff --git a/provider/openrouter.go b/provider/openrouter.go deleted file mode 100644 index 6e75bb5..0000000 --- a/provider/openrouter.go +++ /dev/null @@ -1,44 +0,0 @@ -package provider - -const defaultOpenRouterBaseURL = "https://openrouter.ai/api/v1" - -// OpenRouterConfig configures the OpenRouter provider. -type OpenRouterConfig struct { - APIKey string - Model string - BaseURL string - MaxTokens int -} - -// OpenRouterProvider wraps OpenAIProvider with OpenRouter-specific identity and auth info. -type OpenRouterProvider struct { - *OpenAIProvider -} - -func (p *OpenRouterProvider) Name() string { return "openrouter" } - -func (p *OpenRouterProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "openrouter", - DisplayName: "OpenRouter", - Description: "Access multiple AI models via OpenRouter's unified API.", - DocsURL: "https://openrouter.ai/docs/api/reference/authentication", - ServerSafe: true, - } -} - -// NewOpenRouterProvider creates a provider that uses OpenRouter's OpenAI-compatible API. -func NewOpenRouterProvider(cfg OpenRouterConfig) *OpenRouterProvider { - baseURL := cfg.BaseURL - if baseURL == "" { - baseURL = defaultOpenRouterBaseURL - } - return &OpenRouterProvider{ - OpenAIProvider: NewOpenAIProvider(OpenAIConfig{ - APIKey: cfg.APIKey, - Model: cfg.Model, - BaseURL: baseURL, - MaxTokens: cfg.MaxTokens, - }), - } -} diff --git a/provider/openrouter_test.go b/provider/openrouter_test.go deleted file mode 100644 index 0825cd3..0000000 --- a/provider/openrouter_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" -) - -func TestOpenRouterProvider_Name(t *testing.T) { - p := NewOpenRouterProvider(OpenRouterConfig{APIKey: "test-key"}) - if got := p.Name(); got != "openrouter" { - t.Errorf("Name() = %q, want %q", got, "openrouter") - } -} - -func TestOpenRouterProvider_AuthModeInfo(t *testing.T) { - p := NewOpenRouterProvider(OpenRouterConfig{APIKey: "test-key"}) - info := p.AuthModeInfo() - - if info.Mode != "openrouter" { - t.Errorf("AuthModeInfo().Mode = %q, want %q", info.Mode, "openrouter") - } - if info.DisplayName != "OpenRouter" { - t.Errorf("AuthModeInfo().DisplayName = %q, want %q", info.DisplayName, "OpenRouter") - } - if info.DocsURL != "https://openrouter.ai/docs/api/reference/authentication" { - t.Errorf("AuthModeInfo().DocsURL = %q, want openrouter docs URL", info.DocsURL) - } - if !info.ServerSafe { - t.Error("AuthModeInfo().ServerSafe = false, want true") - } -} - -func TestOpenRouterProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*OpenRouterProvider)(nil) -} - -func TestOpenRouterProvider_DefaultBaseURL(t *testing.T) { - p := NewOpenRouterProvider(OpenRouterConfig{APIKey: "test-key"}) - if p.config.BaseURL != defaultOpenRouterBaseURL { - t.Errorf("default BaseURL = %q, want %q", p.config.BaseURL, defaultOpenRouterBaseURL) - } -} - -func TestOpenRouterProvider_CustomBaseURL(t *testing.T) { - custom := "https://custom.openrouter.ai/api/v1" - p := NewOpenRouterProvider(OpenRouterConfig{APIKey: "test-key", BaseURL: custom}) - if p.config.BaseURL != custom { - t.Errorf("BaseURL = %q, want %q", p.config.BaseURL, custom) - } -} - -func TestOpenRouterProvider_Chat(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("Authorization"); got != "Bearer test-key" { - t.Errorf("Authorization header = %q, want %q", got, "Bearer test-key") - } - - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"chatcmpl-test","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"hello from openrouter"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15},"created":1704067200}`) - })) - defer srv.Close() - - p := NewOpenRouterProvider(OpenRouterConfig{ - APIKey: "test-key", - Model: "meta-llama/llama-3-70b", - BaseURL: srv.URL, - }) - - got, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if got.Content != "hello from openrouter" { - t.Errorf("Chat() content = %q, want %q", got.Content, "hello from openrouter") - } - if got.Usage.InputTokens != 10 || got.Usage.OutputTokens != 5 { - t.Errorf("Chat() usage = %+v, want input=10 output=5", got.Usage) - } -} - -func TestOpenRouterProvider_Stream(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher, ok := w.(http.Flusher) - if !ok { - t.Fatal("expected http.Flusher") - } - - data, _ := json.Marshal(map[string]any{ - "id": "chatcmpl-stream", - "object": "chat.completion.chunk", - "model": "gpt-4o", - "created": 1704067200, - "choices": []map[string]any{ - {"index": 0, "delta": map[string]any{"role": "assistant", "content": "streamed"}, "finish_reason": nil}, - }, - }) - fmt.Fprintf(w, "data: %s\n\n", string(data)) - flusher.Flush() - - fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() - })) - defer srv.Close() - - p := NewOpenRouterProvider(OpenRouterConfig{ - APIKey: "test-key", - Model: "meta-llama/llama-3-70b", - BaseURL: srv.URL, - }) - - ch, err := p.Stream(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Stream() error: %v", err) - } - - var texts []string - for ev := range ch { - switch ev.Type { - case "text": - texts = append(texts, ev.Text) - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if len(texts) == 0 { - t.Fatal("expected at least one text event") - } - if texts[0] != "streamed" { - t.Errorf("first text event = %q, want %q", texts[0], "streamed") - } -} diff --git a/provider_registry.go b/provider_registry.go index 1985b14..658d9da 100644 --- a/provider_registry.go +++ b/provider_registry.go @@ -8,6 +8,7 @@ import ( "time" "github.com/GoCodeAlone/modular" + gkprov "github.com/GoCodeAlone/workflow-plugin-agent/genkit" "github.com/GoCodeAlone/workflow-plugin-agent/provider" "github.com/GoCodeAlone/workflow/config" "github.com/GoCodeAlone/workflow/module" @@ -52,54 +53,30 @@ func NewProviderRegistry(db *sql.DB, secretsProvider secrets.Provider) *Provider return &mockProvider{responses: []string{"I have completed the task."}}, nil } r.Factories["anthropic"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewAnthropicProvider(provider.AnthropicConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + return gkprov.NewAnthropicProvider(context.Background(), apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } r.Factories["openai"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewOpenAIProvider(provider.OpenAIConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + return gkprov.NewOpenAIProvider(context.Background(), apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } r.Factories["openrouter"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - if cfg.BaseURL == "" { - cfg.BaseURL = "https://openrouter.ai/api/v1" + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://openrouter.ai/api/v1" } - return provider.NewOpenAIProvider(provider.OpenAIConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + return gkprov.NewOpenAICompatibleProvider(context.Background(), "openrouter", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } r.Factories["copilot"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewCopilotProvider(provider.CopilotConfig{ - Token: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://api.githubcopilot.com" + } + return gkprov.NewOpenAICompatibleProvider(context.Background(), "copilot", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } r.Factories["ollama"] = func(_ string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewOllamaProvider(provider.OllamaConfig{ - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + return gkprov.NewOllamaProvider(context.Background(), cfg.Model, cfg.BaseURL, cfg.MaxTokens) } r.Factories["llama_cpp"] = func(_ string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewLlamaCppProvider(provider.LlamaCppConfig{ - BaseURL: cfg.BaseURL, - ModelPath: cfg.Model, - ModelName: cfg.Model, - MaxTokens: cfg.MaxTokens, - }), nil + return gkprov.NewOpenAICompatibleProvider(context.Background(), "llama_cpp", "", cfg.Model, cfg.BaseURL, cfg.MaxTokens) } return r diff --git a/step_model_pull.go b/step_model_pull.go index 614e325..7d02936 100644 --- a/step_model_pull.go +++ b/step_model_pull.go @@ -44,13 +44,10 @@ func (s *ModelPullStep) Execute(ctx context.Context, _ *module.PipelineContext) } func (s *ModelPullStep) pullOllama(ctx context.Context) (*module.StepResult, error) { - p := provider.NewOllamaProvider(provider.OllamaConfig{ - Model: s.model, - BaseURL: s.ollamaBase, - }) + client := provider.NewOllamaClient(s.ollamaBase) // Check if model is already available. - models, err := p.ListModels(ctx) + models, err := client.ListModels(ctx) if err == nil { for _, m := range models { if m.ID == s.model || m.Name == s.model { @@ -66,7 +63,7 @@ func (s *ModelPullStep) pullOllama(ctx context.Context) (*module.StepResult, err } // Pull (download) the model. - if pullErr := p.Pull(ctx, s.model, nil); pullErr != nil { + if pullErr := client.Pull(ctx, s.model, nil); pullErr != nil { return &module.StepResult{ Output: map[string]any{ "status": "error", From 771cf9f33e9b825cbe9fe42d276d37a54e9e7f2e Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:43:20 -0400 Subject: [PATCH 09/14] chore: run go mod tidy after provider cleanup (Task 7) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Demotes golang.org/x/oauth2 to indirect — no longer directly used after removing old hand-rolled provider implementations. Co-Authored-By: Claude Sonnet 4.6 --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 088ad56..3e42e28 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,6 @@ require ( github.com/ollama/ollama v0.18.3 github.com/openai/openai-go v1.12.0 golang.org/x/crypto v0.48.0 - golang.org/x/oauth2 v0.36.0 golang.org/x/text v0.34.0 google.golang.org/api v0.271.0 gopkg.in/yaml.v3 v3.0.1 @@ -247,6 +246,7 @@ require ( golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect golang.org/x/mod v0.33.0 // indirect golang.org/x/net v0.51.0 // indirect + golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/time v0.15.0 // indirect From a099a2969ad3819727803446f0be4a5f8d027d1b Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:44:42 -0400 Subject: [PATCH 10/14] test: add integration tests for all Genkit providers (Task 8) Integration tests behind //go:build integration tag. Each test skips if the required API key env var is unset, covering: - Anthropic, OpenAI, Google AI, Ollama, OpenRouter - AWS Bedrock, Vertex AI, Azure OpenAI, Anthropic Foundry - provider.Provider interface satisfaction - streaming thinking trace propagation Run with: go test -tags integration ./genkit/... Co-Authored-By: Claude Sonnet 4.6 --- genkit/integration_test.go | 263 +++++++++++++++++++++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 genkit/integration_test.go diff --git a/genkit/integration_test.go b/genkit/integration_test.go new file mode 100644 index 0000000..3513d25 --- /dev/null +++ b/genkit/integration_test.go @@ -0,0 +1,263 @@ +//go:build integration + +package genkit_test + +import ( + "context" + "os" + "strings" + "testing" + + "github.com/GoCodeAlone/workflow-plugin-agent/genkit" + "github.com/GoCodeAlone/workflow-plugin-agent/provider" +) + +// envOrSkip returns the value of an environment variable, or skips the test if unset. +func envOrSkip(t *testing.T, key string) string { + t.Helper() + v := os.Getenv(key) + if v == "" { + t.Skipf("skipping: %s not set", key) + } + return v +} + +// chatRoundTrip sends a simple message and asserts a non-empty text response. +func chatRoundTrip(t *testing.T, p provider.Provider) { + t.Helper() + ctx := context.Background() + msgs := []provider.Message{{Role: provider.RoleUser, Content: "Say exactly: pong"}} + resp, err := p.Chat(ctx, msgs, nil) + if err != nil { + t.Fatalf("Chat error: %v", err) + } + if resp.Content == "" { + t.Error("expected non-empty content") + } +} + +// streamRoundTrip streams a simple message and asserts at least one text event and a done event. +func streamRoundTrip(t *testing.T, p provider.Provider) { + t.Helper() + ctx := context.Background() + msgs := []provider.Message{{Role: provider.RoleUser, Content: "Say exactly: pong"}} + ch, err := p.Stream(ctx, msgs, nil) + if err != nil { + t.Fatalf("Stream error: %v", err) + } + var gotText, gotDone bool + for ev := range ch { + switch ev.Type { + case "text": + if ev.Text != "" { + gotText = true + } + case "done": + gotDone = true + case "error": + t.Fatalf("stream error event: %s", ev.Error) + } + } + if !gotText { + t.Error("expected at least one text event") + } + if !gotDone { + t.Error("expected a done event") + } +} + +// TestIntegration_Anthropic tests the Anthropic (direct API) provider. +func TestIntegration_Anthropic(t *testing.T) { + apiKey := envOrSkip(t, "ANTHROPIC_API_KEY") + p, err := genkit.NewAnthropicProvider(context.Background(), apiKey, "claude-haiku-4-5-20251001", "", 256) + if err != nil { + t.Fatalf("NewAnthropicProvider: %v", err) + } + if p.Name() != "anthropic" { + t.Errorf("expected name 'anthropic', got %q", p.Name()) + } + chatRoundTrip(t, p) + streamRoundTrip(t, p) +} + +// TestIntegration_OpenAI tests the OpenAI provider. +func TestIntegration_OpenAI(t *testing.T) { + apiKey := envOrSkip(t, "OPENAI_API_KEY") + p, err := genkit.NewOpenAIProvider(context.Background(), apiKey, "gpt-4o-mini", "", 256) + if err != nil { + t.Fatalf("NewOpenAIProvider: %v", err) + } + if p.Name() != "openai" { + t.Errorf("expected name 'openai', got %q", p.Name()) + } + chatRoundTrip(t, p) + streamRoundTrip(t, p) +} + +// TestIntegration_GoogleAI tests the Google AI (Gemini API) provider. +func TestIntegration_GoogleAI(t *testing.T) { + apiKey := envOrSkip(t, "GOOGLE_AI_API_KEY") + p, err := genkit.NewGoogleAIProvider(context.Background(), apiKey, "gemini-2.0-flash", 256) + if err != nil { + t.Fatalf("NewGoogleAIProvider: %v", err) + } + if p.Name() != "googleai" { + t.Errorf("expected name 'googleai', got %q", p.Name()) + } + chatRoundTrip(t, p) + streamRoundTrip(t, p) +} + +// TestIntegration_Ollama tests the Ollama local provider. +func TestIntegration_Ollama(t *testing.T) { + serverAddr := os.Getenv("OLLAMA_SERVER") + if serverAddr == "" { + serverAddr = "http://localhost:11434" + } + model := os.Getenv("OLLAMA_MODEL") + if model == "" { + t.Skip("skipping: OLLAMA_MODEL not set") + } + p, err := genkit.NewOllamaProvider(context.Background(), model, serverAddr, 256) + if err != nil { + t.Fatalf("NewOllamaProvider: %v", err) + } + if p.Name() != "ollama" { + t.Errorf("expected name 'ollama', got %q", p.Name()) + } + chatRoundTrip(t, p) + streamRoundTrip(t, p) +} + +// TestIntegration_OpenRouter tests the OpenRouter provider (OpenAI-compatible). +func TestIntegration_OpenRouter(t *testing.T) { + apiKey := envOrSkip(t, "OPENROUTER_API_KEY") + model := os.Getenv("OPENROUTER_MODEL") + if model == "" { + model = "openai/gpt-4o-mini" + } + p, err := genkit.NewOpenAICompatibleProvider( + context.Background(), "openrouter", apiKey, model, + "https://openrouter.ai/api/v1", 256, + ) + if err != nil { + t.Fatalf("NewOpenAICompatibleProvider: %v", err) + } + chatRoundTrip(t, p) +} + +// TestIntegration_Bedrock tests the AWS Bedrock Anthropic provider. +func TestIntegration_Bedrock(t *testing.T) { + accessKey := envOrSkip(t, "AWS_ACCESS_KEY_ID") + secretKey := envOrSkip(t, "AWS_SECRET_ACCESS_KEY") + region := os.Getenv("AWS_REGION") + if region == "" { + region = "us-east-1" + } + p, err := genkit.NewBedrockProvider( + context.Background(), + region, "anthropic.claude-haiku-4-20250514-v1:0", + accessKey, secretKey, "", "", 256, + ) + if err != nil { + t.Fatalf("NewBedrockProvider: %v", err) + } + chatRoundTrip(t, p) +} + +// TestIntegration_VertexAI tests the Google Vertex AI provider. +func TestIntegration_VertexAI(t *testing.T) { + projectID := envOrSkip(t, "VERTEX_PROJECT_ID") + region := os.Getenv("VERTEX_REGION") + if region == "" { + region = "us-central1" + } + p, err := genkit.NewVertexAIProvider( + context.Background(), + projectID, region, "gemini-2.0-flash", "", 256, + ) + if err != nil { + t.Fatalf("NewVertexAIProvider: %v", err) + } + chatRoundTrip(t, p) +} + +// TestIntegration_AzureOpenAI tests the Azure OpenAI provider. +func TestIntegration_AzureOpenAI(t *testing.T) { + resource := envOrSkip(t, "AZURE_OPENAI_RESOURCE") + deploymentName := envOrSkip(t, "AZURE_OPENAI_DEPLOYMENT") + apiKey := envOrSkip(t, "AZURE_OPENAI_API_KEY") + p, err := genkit.NewAzureOpenAIProvider( + context.Background(), + resource, deploymentName, "2024-10-21", apiKey, "", 256, + ) + if err != nil { + t.Fatalf("NewAzureOpenAIProvider: %v", err) + } + chatRoundTrip(t, p) +} + +// TestIntegration_AnthropicFoundry tests the Anthropic on Azure AI Foundry provider. +func TestIntegration_AnthropicFoundry(t *testing.T) { + resource := envOrSkip(t, "FOUNDRY_RESOURCE") + apiKey := envOrSkip(t, "FOUNDRY_API_KEY") + model := os.Getenv("FOUNDRY_MODEL") + if model == "" { + model = "claude-haiku-4-20250514" + } + p, err := genkit.NewAnthropicFoundryProvider( + context.Background(), resource, model, apiKey, "", 256, + ) + if err != nil { + t.Fatalf("NewAnthropicFoundryProvider: %v", err) + } + chatRoundTrip(t, p) +} + +// TestIntegration_ProviderInterface verifies that all concrete providers satisfy provider.Provider. +func TestIntegration_ProviderInterface(t *testing.T) { + apiKey := envOrSkip(t, "ANTHROPIC_API_KEY") + p, err := genkit.NewAnthropicProvider(context.Background(), apiKey, "claude-haiku-4-5-20251001", "", 256) + if err != nil { + t.Fatalf("NewAnthropicProvider: %v", err) + } + var _ provider.Provider = p + if p.AuthModeInfo().Mode == "" { + t.Error("expected non-empty AuthModeInfo.Mode") + } +} + +// TestIntegration_StreamThinkingTrace verifies thinking traces propagate via streaming. +// Uses a model that supports extended thinking (claude-sonnet or claude-3-7-sonnet). +func TestIntegration_StreamThinkingTrace(t *testing.T) { + apiKey := envOrSkip(t, "ANTHROPIC_API_KEY") + // claude-3-7-sonnet-20250219 supports extended thinking + p, err := genkit.NewAnthropicProvider(context.Background(), apiKey, "claude-sonnet-4-20250514", "", 2048) + if err != nil { + t.Fatalf("NewAnthropicProvider: %v", err) + } + ctx := context.Background() + msgs := []provider.Message{{Role: provider.RoleUser, Content: "What is 2+2? Think step by step."}} + ch, err := p.Stream(ctx, msgs, nil) + if err != nil { + t.Fatalf("Stream error: %v", err) + } + var textBuf strings.Builder + var gotDone bool + for ev := range ch { + switch ev.Type { + case "text": + textBuf.WriteString(ev.Text) + case "done": + gotDone = true + case "error": + t.Fatalf("stream error: %s", ev.Error) + } + } + if !gotDone { + t.Error("expected done event") + } + if textBuf.Len() == 0 { + t.Error("expected non-empty text response") + } +} From b9df7a9bbccce70fe407e9c068f3121ba154e34a Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:47:20 -0400 Subject: [PATCH 11/14] fix: apply maxTokens via ai.GenerationCommonConfig in genkitProvider Store maxTokens in genkitProvider struct and apply it to every Chat/Stream call via ai.WithConfig(&ai.GenerationCommonConfig{ MaxOutputTokens: maxTokens}) when non-zero. All provider factory functions now populate the field. Fixes spec-reviewer Issue 2: maxTokens was silently dropped. Co-Authored-By: Claude Sonnet 4.6 --- genkit/adapter.go | 17 +++++++++++++++++ genkit/providers.go | 8 ++++++++ 2 files changed, 25 insertions(+) diff --git a/genkit/adapter.go b/genkit/adapter.go index 3a40158..d79e496 100644 --- a/genkit/adapter.go +++ b/genkit/adapter.go @@ -16,6 +16,7 @@ type genkitProvider struct { modelName string // "provider/model" format e.g. "anthropic/claude-sonnet-4-6" name string authInfo provider.AuthModeInfo + maxTokens int // 0 means use model default mu sync.Mutex definedTools map[string]bool // tracks which tool names are registered @@ -55,6 +56,14 @@ func (p *genkitProvider) resolveToolRefs(tools []provider.ToolDef) []ai.ToolRef return refs } +// generationConfig returns a WithConfig option when maxTokens is configured. +func (p *genkitProvider) generationConfig() ai.GenerateOption { + if p.maxTokens > 0 { + return ai.WithConfig(&ai.GenerationCommonConfig{MaxOutputTokens: p.maxTokens}) + } + return nil +} + // Chat sends a non-streaming request and returns the complete response. func (p *genkitProvider) Chat(ctx context.Context, messages []provider.Message, tools []provider.ToolDef) (*provider.Response, error) { opts := []ai.GenerateOption{ @@ -62,6 +71,10 @@ func (p *genkitProvider) Chat(ctx context.Context, messages []provider.Message, ai.WithMessages(toGenkitMessages(messages)...), } + if cfg := p.generationConfig(); cfg != nil { + opts = append(opts, cfg) + } + if len(tools) > 0 { opts = append(opts, ai.WithReturnToolRequests(true)) for _, ref := range p.resolveToolRefs(tools) { @@ -86,6 +99,10 @@ func (p *genkitProvider) Stream(ctx context.Context, messages []provider.Message ai.WithMessages(toGenkitMessages(messages)...), } + if cfg := p.generationConfig(); cfg != nil { + opts = append(opts, cfg) + } + if len(tools) > 0 { opts = append(opts, ai.WithReturnToolRequests(true)) for _, ref := range p.resolveToolRefs(tools) { diff --git a/genkit/providers.go b/genkit/providers.go index 2272293..95cde7b 100644 --- a/genkit/providers.go +++ b/genkit/providers.go @@ -31,6 +31,7 @@ func NewAnthropicProvider(ctx context.Context, apiKey, model, baseURL string, ma g: g, modelName: "anthropic/" + model, name: "anthropic", + maxTokens: maxTokens, authInfo: provider.AuthModeInfo{ Mode: "api_key", DisplayName: "Anthropic", @@ -54,6 +55,7 @@ func NewOpenAIProvider(ctx context.Context, apiKey, model, baseURL string, maxTo g: g, modelName: "openai/" + model, name: "openai", + maxTokens: maxTokens, authInfo: provider.AuthModeInfo{ Mode: "api_key", DisplayName: "OpenAI", @@ -73,6 +75,7 @@ func NewGoogleAIProvider(ctx context.Context, apiKey, model string, maxTokens in g: g, modelName: "googleai/" + model, name: "googleai", + maxTokens: maxTokens, authInfo: provider.AuthModeInfo{ Mode: "api_key", DisplayName: "Google AI (Gemini)", @@ -92,6 +95,7 @@ func NewOllamaProvider(ctx context.Context, model, serverAddress string, maxToke g: g, modelName: "ollama/" + model, name: "ollama", + maxTokens: maxTokens, authInfo: provider.AuthModeInfo{ Mode: "none", DisplayName: "Ollama (local)", @@ -118,6 +122,7 @@ func NewOpenAICompatibleProvider(ctx context.Context, providerName, apiKey, mode g: g, modelName: providerName + "/" + model, name: providerName, + maxTokens: maxTokens, authInfo: provider.AuthModeInfo{ Mode: "api_key", DisplayName: providerName, @@ -158,6 +163,7 @@ func NewAzureOpenAIProvider(ctx context.Context, resource, deploymentName, apiVe g: g, modelName: "openai/" + deploymentName, name: "openai_azure", + maxTokens: maxTokens, authInfo: provider.AuthModeInfo{ Mode: "azure", DisplayName: "OpenAI (Azure OpenAI Service)", @@ -187,6 +193,7 @@ func NewAnthropicFoundryProvider(ctx context.Context, resource, model, apiKey, e g: g, modelName: "anthropic/" + model, name: "anthropic_foundry", + maxTokens: maxTokens, authInfo: provider.AuthModeInfo{ Mode: "azure", DisplayName: "Anthropic (Azure AI Foundry)", @@ -212,6 +219,7 @@ func NewVertexAIProvider(ctx context.Context, projectID, region, model, credenti g: g, modelName: "vertexai/" + model, name: "vertexai", + maxTokens: maxTokens, authInfo: provider.AuthModeInfo{ Mode: "gcp", DisplayName: "Vertex AI", From ce7e4c6312bfc9d712b88a22d9e18be2256eeab1 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 06:53:34 -0400 Subject: [PATCH 12/14] fix: apply credentialsJSON in VertexAI and eliminate TOCTOU race in provider cache - genkit/providers.go: when credentialsJSON is non-empty, write it to a temp file and set GOOGLE_APPLICATION_CREDENTIALS before calling gk.Init so Genkit's VertexAI plugin picks it up via credentials.DetectDefault(). Previously the parameter was accepted but silently ignored. - orchestrator/provider_registry.go: re-check the cache under the write lock in createAndCache before inserting, eliminating the TOCTOU race where two concurrent callers for the same uncached alias both created a provider. Co-Authored-By: Claude Sonnet 4.6 --- genkit/providers.go | 19 +++++++++++++++++++ orchestrator/provider_registry.go | 6 +++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/genkit/providers.go b/genkit/providers.go index 95cde7b..f8eb304 100644 --- a/genkit/providers.go +++ b/genkit/providers.go @@ -3,6 +3,7 @@ package genkit import ( "context" "fmt" + "os" "github.com/GoCodeAlone/workflow-plugin-agent/provider" gk "github.com/firebase/genkit/go/genkit" @@ -210,6 +211,24 @@ func NewVertexAIProvider(ctx context.Context, projectID, region, model, credenti if region == "" { region = "us-central1" } + + // Genkit's VertexAI plugin uses credentials.DetectDefault() which reads + // GOOGLE_APPLICATION_CREDENTIALS. When inline JSON is provided, write it + // to a temp file and point the env var at it. + if credentialsJSON != "" { + f, err := os.CreateTemp("", "vertexai-creds-*.json") + if err != nil { + return nil, fmt.Errorf("vertexai: create temp credentials file: %w", err) + } + if _, err := f.WriteString(credentialsJSON); err != nil { + _ = f.Close() + _ = os.Remove(f.Name()) + return nil, fmt.Errorf("vertexai: write credentials: %w", err) + } + _ = f.Close() + os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", f.Name()) //nolint:errcheck + } + p := &googlegenai.VertexAI{ ProjectID: projectID, Location: region, diff --git a/orchestrator/provider_registry.go b/orchestrator/provider_registry.go index 5e2b100..d168533 100644 --- a/orchestrator/provider_registry.go +++ b/orchestrator/provider_registry.go @@ -229,8 +229,12 @@ func (r *ProviderRegistry) createAndCache(ctx context.Context, alias string, cfg return nil, fmt.Errorf("provider registry: create %q: %w", alias, err) } - // Cache + // Cache — re-check under write lock to avoid TOCTOU race r.mu.Lock() + if existing, ok := r.cache[alias]; ok { + r.mu.Unlock() + return existing, nil + } r.cache[alias] = p r.mu.Unlock() From 4b7ba62ede476677492f965d4c032b2e30eddb87 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 14:21:43 -0400 Subject: [PATCH 13/14] fix: address all 7 Copilot PR review comments 1. VertexAI credentials: guard env var with mutex, restore previous value, and delete temp file after Genkit init completes. 2. Empty model defaults: add default model constants per provider (claude-sonnet-4-6, gpt-4o, gemini-2.5-flash, qwen3:8b) applied when model param is empty. 3-5. Context propagation: change ProviderFactory signature to accept context.Context, thread caller's ctx through all factory functions in both ProviderRegistry files. Module factories use context.TODO() with comment explaining the ModuleFactory limitation. 6. ToolCall.ID uniqueness: generate uuid.New().String() per tool call instead of reusing the tool name as ID. 7. Tool schema passthrough: use ai.WithInputSchema(t.Parameters) in DefineTool so the LLM receives accurate JSON Schema for each tool. Co-Authored-By: Claude Opus 4.6 (1M context) --- genkit/adapter.go | 6 ++- genkit/convert.go | 5 ++- genkit/providers.go | 47 +++++++++++++++++++++-- genkit/providers_test.go | 4 +- module_provider.go | 10 ++--- orchestrator/plugin.go | 4 +- orchestrator/provider_registry.go | 62 ++++++++++++++++--------------- provider_registry.go | 32 ++++++++-------- 8 files changed, 108 insertions(+), 62 deletions(-) diff --git a/genkit/adapter.go b/genkit/adapter.go index d79e496..c4e9fa3 100644 --- a/genkit/adapter.go +++ b/genkit/adapter.go @@ -41,11 +41,15 @@ func (p *genkitProvider) resolveToolRefs(tools []provider.ToolDef) []ai.ToolRef refs := make([]ai.ToolRef, 0, len(tools)) for _, t := range tools { if !p.definedTools[t.Name] { + // Pass the exact JSON Schema from provider.ToolDef.Parameters + // so the LLM gets accurate parameter definitions for tool calling. + // WithInputSchema requires In=any (not map[string]any). tool := gk.DefineTool(p.g, t.Name, t.Description, - func(ctx *ai.ToolContext, input map[string]any) (map[string]any, error) { + func(ctx *ai.ToolContext, input any) (any, error) { // Tools are executed by the executor, not Genkit. return nil, fmt.Errorf("tool %s should not be called via Genkit", t.Name) }, + ai.WithInputSchema(t.Parameters), ) refs = append(refs, tool) p.definedTools[t.Name] = true diff --git a/genkit/convert.go b/genkit/convert.go index ed8c87c..045399a 100644 --- a/genkit/convert.go +++ b/genkit/convert.go @@ -3,6 +3,7 @@ package genkit import ( "github.com/GoCodeAlone/workflow-plugin-agent/provider" "github.com/firebase/genkit/go/ai" + "github.com/google/uuid" ) // toGenkitMessages converts our messages to Genkit messages. @@ -73,7 +74,7 @@ func fromGenkitResponse(resp *ai.ModelResponse) *provider.Response { for _, part := range msg.Content { if part.ToolRequest != nil { tc := provider.ToolCall{ - ID: part.ToolRequest.Name, + ID: uuid.New().String(), Name: part.ToolRequest.Name, Arguments: make(map[string]any), } @@ -121,7 +122,7 @@ func fromGenkitChunk(chunk *ai.ModelResponseChunk) provider.StreamEvent { return provider.StreamEvent{ Type: "tool_call", Tool: &provider.ToolCall{ - ID: part.ToolRequest.Name, + ID: uuid.New().String(), Name: part.ToolRequest.Name, Arguments: func() map[string]any { if m, ok := part.ToolRequest.Input.(map[string]any); ok { diff --git a/genkit/providers.go b/genkit/providers.go index f8eb304..41594a6 100644 --- a/genkit/providers.go +++ b/genkit/providers.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "sync" "github.com/GoCodeAlone/workflow-plugin-agent/provider" gk "github.com/firebase/genkit/go/genkit" @@ -15,6 +16,18 @@ import ( "github.com/openai/openai-go/option" ) +// Default models per provider when none specified. +const ( + defaultAnthropicModel = "claude-sonnet-4-6" + defaultOpenAIModel = "gpt-4o" + defaultGeminiModel = "gemini-2.5-flash" + defaultOllamaModel = "qwen3:8b" +) + +// vertexCredsMu guards the GOOGLE_APPLICATION_CREDENTIALS env var +// to prevent races when multiple VertexAI providers initialize concurrently. +var vertexCredsMu sync.Mutex + // initGenkitWithPlugin creates a Genkit instance with a single plugin registered. func initGenkitWithPlugin(ctx context.Context, plugin gk.GenkitOption) *gk.Genkit { @@ -26,6 +39,9 @@ func NewAnthropicProvider(ctx context.Context, apiKey, model, baseURL string, ma if apiKey == "" { return nil, fmt.Errorf("anthropic: APIKey is required") } + if model == "" { + model = defaultAnthropicModel + } p := &anthropicPlugin.Anthropic{APIKey: apiKey, BaseURL: baseURL} g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) return &genkitProvider{ @@ -46,6 +62,9 @@ func NewOpenAIProvider(ctx context.Context, apiKey, model, baseURL string, maxTo if apiKey == "" { return nil, fmt.Errorf("openai: APIKey is required") } + if model == "" { + model = defaultOpenAIModel + } var extraOpts []option.RequestOption if baseURL != "" { extraOpts = append(extraOpts, option.WithBaseURL(baseURL)) @@ -70,6 +89,9 @@ func NewGoogleAIProvider(ctx context.Context, apiKey, model string, maxTokens in if apiKey == "" { return nil, fmt.Errorf("googleai: APIKey is required") } + if model == "" { + model = defaultGeminiModel + } p := &googlegenai.GoogleAI{APIKey: apiKey} g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) return &genkitProvider{ @@ -90,6 +112,9 @@ func NewOllamaProvider(ctx context.Context, model, serverAddress string, maxToke if serverAddress == "" { serverAddress = "http://localhost:11434" } + if model == "" { + model = defaultOllamaModel + } p := &ollamaPlugin.Ollama{ServerAddress: serverAddress} g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) return &genkitProvider{ @@ -214,19 +239,35 @@ func NewVertexAIProvider(ctx context.Context, projectID, region, model, credenti // Genkit's VertexAI plugin uses credentials.DetectDefault() which reads // GOOGLE_APPLICATION_CREDENTIALS. When inline JSON is provided, write it - // to a temp file and point the env var at it. + // to a temp file, set the env var, init Genkit, then clean up. + var tempCredFile string if credentialsJSON != "" { + vertexCredsMu.Lock() + defer vertexCredsMu.Unlock() + + prevCreds := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") + f, err := os.CreateTemp("", "vertexai-creds-*.json") if err != nil { return nil, fmt.Errorf("vertexai: create temp credentials file: %w", err) } + tempCredFile = f.Name() if _, err := f.WriteString(credentialsJSON); err != nil { _ = f.Close() - _ = os.Remove(f.Name()) + _ = os.Remove(tempCredFile) return nil, fmt.Errorf("vertexai: write credentials: %w", err) } _ = f.Close() - os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", f.Name()) //nolint:errcheck + _ = os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", tempCredFile) + defer func() { + // Restore previous env var and remove temp file. + if prevCreds == "" { + _ = os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS") + } else { + _ = os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", prevCreds) + } + _ = os.Remove(tempCredFile) + }() } p := &googlegenai.VertexAI{ diff --git a/genkit/providers_test.go b/genkit/providers_test.go index 42008f4..ea82f23 100644 --- a/genkit/providers_test.go +++ b/genkit/providers_test.go @@ -3,8 +3,6 @@ package genkit import ( "context" "testing" - - "github.com/GoCodeAlone/workflow-plugin-agent/provider" ) func TestNewAnthropicProvider_MissingKey(t *testing.T) { @@ -88,5 +86,5 @@ func TestProviderImplementsInterface(t *testing.T) { if err != nil { t.Skip("factory failed, skipping interface check") } - var _ provider.Provider = p + _ = p // already provider.Provider; compile verifies interface } diff --git a/module_provider.go b/module_provider.go index 9f6becf..641696c 100644 --- a/module_provider.go +++ b/module_provider.go @@ -184,14 +184,14 @@ func newProviderModuleFactory() plugin.ModuleFactory { } case "anthropic": - if prov, err := gkprov.NewAnthropicProvider(context.Background(), apiKey, model, baseURL, maxTokens); err != nil { + if prov, err := gkprov.NewAnthropicProvider(context.TODO() /* ModuleFactory doesn't receive ctx; TODO: thread ctx via Start() */, apiKey, model, baseURL, maxTokens); err != nil { p = &errProvider{err: err} } else { p = prov } case "openai": - if prov, err := gkprov.NewOpenAIProvider(context.Background(), apiKey, model, baseURL, maxTokens); err != nil { + if prov, err := gkprov.NewOpenAIProvider(context.TODO() /* ModuleFactory doesn't receive ctx; TODO: thread ctx via Start() */, apiKey, model, baseURL, maxTokens); err != nil { p = &errProvider{err: err} } else { p = prov @@ -201,21 +201,21 @@ func newProviderModuleFactory() plugin.ModuleFactory { if baseURL == "" { baseURL = "https://api.githubcopilot.com" } - if prov, err := gkprov.NewOpenAICompatibleProvider(context.Background(), "copilot", apiKey, model, baseURL, maxTokens); err != nil { + if prov, err := gkprov.NewOpenAICompatibleProvider(context.TODO() /* ModuleFactory doesn't receive ctx; TODO: thread ctx via Start() */, "copilot", apiKey, model, baseURL, maxTokens); err != nil { p = &errProvider{err: err} } else { p = prov } case "ollama": - if prov, err := gkprov.NewOllamaProvider(context.Background(), model, baseURL, maxTokens); err != nil { + if prov, err := gkprov.NewOllamaProvider(context.TODO() /* ModuleFactory doesn't receive ctx; TODO: thread ctx via Start() */, model, baseURL, maxTokens); err != nil { p = &errProvider{err: err} } else { p = prov } case "llama_cpp": - if prov, err := gkprov.NewOpenAICompatibleProvider(context.Background(), "llama_cpp", "", model, baseURL, maxTokens); err != nil { + if prov, err := gkprov.NewOpenAICompatibleProvider(context.TODO() /* ModuleFactory doesn't receive ctx; TODO: thread ctx via Start() */, "llama_cpp", "", model, baseURL, maxTokens); err != nil { p = &errProvider{err: err} } else { p = prov diff --git a/orchestrator/plugin.go b/orchestrator/plugin.go index c6b4a3b..09d22a3 100644 --- a/orchestrator/plugin.go +++ b/orchestrator/plugin.go @@ -664,7 +664,7 @@ func testInteractionHook() plugin.WiringHook { } if regSvc, ok := app.SvcRegistry()["ratchet-provider-registry"]; ok { if registry, ok := regSvc.(*ProviderRegistry); ok { - registry.factories["test"] = func(_ string, _ LLMProviderConfig) (provider.Provider, error) { + registry.factories["test"] = func(_ context.Context, _ string, _ LLMProviderConfig) (provider.Provider, error) { return testProvider, nil } if registry.db != nil { @@ -709,7 +709,7 @@ func testInteractionHook() plugin.WiringHook { if regSvc, ok := app.SvcRegistry()["ratchet-provider-registry"]; ok { if registry, ok := regSvc.(*ProviderRegistry); ok { // Register a "test" factory that returns our pre-built test provider - registry.factories["test"] = func(_ string, _ LLMProviderConfig) (provider.Provider, error) { + registry.factories["test"] = func(_ context.Context, _ string, _ LLMProviderConfig) (provider.Provider, error) { return testProvider, nil } // Update the default provider row in the DB from "mock" to "test" diff --git a/orchestrator/provider_registry.go b/orchestrator/provider_registry.go index d168533..f383d40 100644 --- a/orchestrator/provider_registry.go +++ b/orchestrator/provider_registry.go @@ -35,8 +35,10 @@ func (c *LLMProviderConfig) settings() map[string]string { return m } -// ProviderFactory creates a provider.Provider from an API key and config. -type ProviderFactory func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) +// ProviderFactory creates a provider.Provider from a context, API key, and config. +// The context is propagated from the caller (e.g., GetByAlias/GetDefault) to allow +// cancellation and timeout during provider initialization. +type ProviderFactory func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) // ProviderRegistry manages AI provider lifecycle: factory creation, caching, and DB lookup. type ProviderRegistry struct { @@ -224,7 +226,7 @@ func (r *ProviderRegistry) createAndCache(ctx context.Context, alias string, cfg } // Create provider - p, err := factory(apiKey, *cfg) + p, err := factory(ctx, apiKey, *cfg) if err != nil { return nil, fmt.Errorf("provider registry: create %q: %w", alias, err) } @@ -243,88 +245,88 @@ func (r *ProviderRegistry) createAndCache(ctx context.Context, alias string, cfg // Built-in factory functions -func mockProviderFactory(_ string, _ LLMProviderConfig) (provider.Provider, error) { +func mockProviderFactory(_ context.Context, _ string, _ LLMProviderConfig) (provider.Provider, error) { return &mockProvider{responses: []string{"I have completed the task."}}, nil } -func anthropicProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return gkprov.NewAnthropicProvider(context.Background(), apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) +func anthropicProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewAnthropicProvider(ctx, apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } -func openaiProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return gkprov.NewOpenAIProvider(context.Background(), apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) +func openaiProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOpenAIProvider(ctx, apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } -func openrouterProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func openrouterProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { baseURL := cfg.BaseURL if baseURL == "" { baseURL = "https://openrouter.ai/api/v1" } - return gkprov.NewOpenAICompatibleProvider(context.Background(), "openrouter", apiKey, cfg.Model, baseURL, cfg.MaxTokens) + return gkprov.NewOpenAICompatibleProvider(ctx, "openrouter", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } -func copilotProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func copilotProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { baseURL := cfg.BaseURL if baseURL == "" { baseURL = "https://api.githubcopilot.com" } - return gkprov.NewOpenAICompatibleProvider(context.Background(), "copilot", apiKey, cfg.Model, baseURL, cfg.MaxTokens) + return gkprov.NewOpenAICompatibleProvider(ctx, "copilot", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } -func cohereProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func cohereProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { baseURL := cfg.BaseURL if baseURL == "" { baseURL = "https://api.cohere.ai/v1" } - return gkprov.NewOpenAICompatibleProvider(context.Background(), "cohere", apiKey, cfg.Model, baseURL, cfg.MaxTokens) + return gkprov.NewOpenAICompatibleProvider(ctx, "cohere", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } -func copilotModelsProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func copilotModelsProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { baseURL := cfg.BaseURL if baseURL == "" { baseURL = "https://models.inference.ai.azure.com" } - return gkprov.NewOpenAICompatibleProvider(context.Background(), "copilot_models", apiKey, cfg.Model, baseURL, cfg.MaxTokens) + return gkprov.NewOpenAICompatibleProvider(ctx, "copilot_models", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } -func openaiAzureProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func openaiAzureProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { s := cfg.settings() - return gkprov.NewAzureOpenAIProvider(context.Background(), + return gkprov.NewAzureOpenAIProvider(ctx, s["resource"], s["deployment_name"], s["api_version"], apiKey, s["entra_token"], cfg.MaxTokens) } -func anthropicFoundryProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func anthropicFoundryProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { s := cfg.settings() - return gkprov.NewAnthropicFoundryProvider(context.Background(), + return gkprov.NewAnthropicFoundryProvider(ctx, s["resource"], cfg.Model, apiKey, s["entra_token"], cfg.MaxTokens) } -func anthropicVertexProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func anthropicVertexProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { s := cfg.settings() credJSON := s["credentials_json"] if credJSON == "" { credJSON = apiKey // fallback: secret may contain the full GCP credentials JSON } - return gkprov.NewVertexAIProvider(context.Background(), + return gkprov.NewVertexAIProvider(ctx, s["project_id"], s["region"], cfg.Model, credJSON, cfg.MaxTokens) } -func geminiProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return gkprov.NewGoogleAIProvider(context.Background(), apiKey, cfg.Model, cfg.MaxTokens) +func geminiProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewGoogleAIProvider(ctx, apiKey, cfg.Model, cfg.MaxTokens) } -func ollamaProviderFactory(_ string, cfg LLMProviderConfig) (provider.Provider, error) { - return gkprov.NewOllamaProvider(context.Background(), cfg.Model, cfg.BaseURL, cfg.MaxTokens) +func ollamaProviderFactory(ctx context.Context, _ string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOllamaProvider(ctx, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } -func llamaCppProviderFactory(_ string, cfg LLMProviderConfig) (provider.Provider, error) { +func llamaCppProviderFactory(ctx context.Context, _ string, cfg LLMProviderConfig) (provider.Provider, error) { // llama.cpp serves an OpenAI-compatible API - return gkprov.NewOpenAICompatibleProvider(context.Background(), "llama_cpp", "", cfg.Model, cfg.BaseURL, cfg.MaxTokens) + return gkprov.NewOpenAICompatibleProvider(ctx, "llama_cpp", "", cfg.Model, cfg.BaseURL, cfg.MaxTokens) } -func anthropicBedrockProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func anthropicBedrockProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { s := cfg.settings() - return gkprov.NewBedrockProvider(context.Background(), + return gkprov.NewBedrockProvider(ctx, s["region"], cfg.Model, s["access_key_id"], apiKey, s["session_token"], cfg.BaseURL, cfg.MaxTokens) } diff --git a/provider_registry.go b/provider_registry.go index 658d9da..dbd280c 100644 --- a/provider_registry.go +++ b/provider_registry.go @@ -28,8 +28,8 @@ type LLMProviderConfig struct { IsDefault int `json:"is_default"` } -// ProviderFactory creates a provider.Provider from an API key and config. -type ProviderFactory func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) +// ProviderFactory creates a provider.Provider from a context, API key, and config. +type ProviderFactory func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) // ProviderRegistry manages AI provider lifecycle: factory creation, caching, and DB lookup. type ProviderRegistry struct { @@ -49,34 +49,34 @@ func NewProviderRegistry(db *sql.DB, secretsProvider secrets.Provider) *Provider Factories: make(map[string]ProviderFactory), } - r.Factories["mock"] = func(_ string, _ LLMProviderConfig) (provider.Provider, error) { + r.Factories["mock"] = func(_ context.Context, _ string, _ LLMProviderConfig) (provider.Provider, error) { return &mockProvider{responses: []string{"I have completed the task."}}, nil } - r.Factories["anthropic"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return gkprov.NewAnthropicProvider(context.Background(), apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) + r.Factories["anthropic"] = func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewAnthropicProvider(ctx, apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } - r.Factories["openai"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return gkprov.NewOpenAIProvider(context.Background(), apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) + r.Factories["openai"] = func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOpenAIProvider(ctx, apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } - r.Factories["openrouter"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + r.Factories["openrouter"] = func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { baseURL := cfg.BaseURL if baseURL == "" { baseURL = "https://openrouter.ai/api/v1" } - return gkprov.NewOpenAICompatibleProvider(context.Background(), "openrouter", apiKey, cfg.Model, baseURL, cfg.MaxTokens) + return gkprov.NewOpenAICompatibleProvider(ctx, "openrouter", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } - r.Factories["copilot"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + r.Factories["copilot"] = func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { baseURL := cfg.BaseURL if baseURL == "" { baseURL = "https://api.githubcopilot.com" } - return gkprov.NewOpenAICompatibleProvider(context.Background(), "copilot", apiKey, cfg.Model, baseURL, cfg.MaxTokens) + return gkprov.NewOpenAICompatibleProvider(ctx, "copilot", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } - r.Factories["ollama"] = func(_ string, cfg LLMProviderConfig) (provider.Provider, error) { - return gkprov.NewOllamaProvider(context.Background(), cfg.Model, cfg.BaseURL, cfg.MaxTokens) + r.Factories["ollama"] = func(ctx context.Context, _ string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOllamaProvider(ctx, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } - r.Factories["llama_cpp"] = func(_ string, cfg LLMProviderConfig) (provider.Provider, error) { - return gkprov.NewOpenAICompatibleProvider(context.Background(), "llama_cpp", "", cfg.Model, cfg.BaseURL, cfg.MaxTokens) + r.Factories["llama_cpp"] = func(ctx context.Context, _ string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOpenAICompatibleProvider(ctx, "llama_cpp", "", cfg.Model, cfg.BaseURL, cfg.MaxTokens) } return r @@ -192,7 +192,7 @@ func (r *ProviderRegistry) createAndCache(ctx context.Context, alias string, cfg return nil, fmt.Errorf("provider registry: unknown provider type %q", cfg.Type) } - p, err := factory(apiKey, *cfg) + p, err := factory(ctx, apiKey, *cfg) if err != nil { return nil, fmt.Errorf("provider registry: create %q: %w", alias, err) } From 02556f2ce5fb77b6545d2f8a21f73f506b607e78 Mon Sep 17 00:00:00 2001 From: Jon Langevin Date: Sun, 5 Apr 2026 15:12:55 -0400 Subject: [PATCH 14/14] fix: address 11 new Copilot PR review comments (round 2) convert.go: - Tool-call identity: use tc.ID in ToolRequest.Name for correlation - Preserve assistant text alongside ToolCalls (don't drop Content) - Try JSON-decode tool results before wrapping (avoid double-wrap) - Remove chunk-based tool_call emission (emit only from final response) adapter.go: - Stream() emits tool_call events only from final Done response, avoiding duplicates with unstable IDs from incremental chunks providers.go: - Update default models to current: claude-sonnet-4-6, gpt-4.1, gemini-2.5-flash - Add SSRF validation (provider.ValidateBaseURL) for remote providers (Anthropic, OpenAI, OpenAI-compatible), skip for local providers - Add model default for OpenAICompatible (falls back to gpt-4.1) orchestrator/provider_registry.go: - Fix copilot_models default URL: models.github.ai/inference - Add llama_cpp default baseURL: http://127.0.0.1:8080/v1 - Use singleflight.Group to deduplicate concurrent cold-start provider creation per alias Co-Authored-By: Claude Opus 4.6 (1M context) --- genkit/adapter.go | 3 +- genkit/convert.go | 49 +++++++++++--------- genkit/providers.go | 20 ++++++++- go.mod | 2 +- orchestrator/provider_registry.go | 74 +++++++++++++++++++------------ 5 files changed, 95 insertions(+), 53 deletions(-) diff --git a/genkit/adapter.go b/genkit/adapter.go index c4e9fa3..aa2d540 100644 --- a/genkit/adapter.go +++ b/genkit/adapter.go @@ -124,7 +124,8 @@ func (p *genkitProvider) Stream(ctx context.Context, messages []provider.Message return } if result.Done { - // Extract final response for tool calls and usage + // Tool calls are emitted only from the final response to avoid + // duplicates with unstable IDs from incremental chunks. if result.Response != nil { final := fromGenkitResponse(result.Response) for i := range final.ToolCalls { diff --git a/genkit/convert.go b/genkit/convert.go index 045399a..59e8a8c 100644 --- a/genkit/convert.go +++ b/genkit/convert.go @@ -1,6 +1,8 @@ package genkit import ( + "encoding/json" + "github.com/GoCodeAlone/workflow-plugin-agent/provider" "github.com/firebase/genkit/go/ai" "github.com/google/uuid" @@ -26,17 +28,34 @@ func toGenkitMessages(msgs []provider.Message) []*ai.Message { var parts []*ai.Part - // Tool call results: add as ToolResponsePart + // Tool call results: add as ToolResponsePart. + // Try to JSON-decode the content to avoid double-wrapping structured results. if m.ToolCallID != "" { + var output any + if err := json.Unmarshal([]byte(m.Content), &output); err != nil { + // Not valid JSON — wrap as string. + output = map[string]any{"result": m.Content} + } parts = []*ai.Part{ai.NewToolResponsePart(&ai.ToolResponse{ Name: m.ToolCallID, - Output: map[string]any{"result": m.Content}, + Output: output, })} } else if len(m.ToolCalls) > 0 { - // Assistant message with tool calls + // Assistant message with tool calls. + // Preserve text content alongside tool requests (the executor records + // assistant text + tool calls together). + if m.Content != "" { + parts = append(parts, ai.NewTextPart(m.Content)) + } for _, tc := range m.ToolCalls { + // Use tc.ID as the ToolRequest name so tool responses can be + // correlated back to the correct request. + reqName := tc.Name + if tc.ID != "" { + reqName = tc.ID + } parts = append(parts, ai.NewToolRequestPart(&ai.ToolRequest{ - Name: tc.Name, + Name: reqName, Input: tc.Arguments, })) } @@ -116,24 +135,10 @@ func fromGenkitChunk(chunk *ai.ModelResponseChunk) provider.StreamEvent { return provider.StreamEvent{Type: "text", Text: text} } - // Check for tool calls in chunk - for _, part := range chunk.Content { - if part.ToolRequest != nil { - return provider.StreamEvent{ - Type: "tool_call", - Tool: &provider.ToolCall{ - ID: uuid.New().String(), - Name: part.ToolRequest.Name, - Arguments: func() map[string]any { - if m, ok := part.ToolRequest.Input.(map[string]any); ok { - return m - } - return nil - }(), - }, - } - } - } + // Note: tool_call events from chunks are NOT emitted here because Genkit + // provides the complete tool call list in the final Done response. Emitting + // from both chunks and Done would produce duplicate events with unstable IDs. + // The adapter's Stream() method emits tool_call events from the final response only. return provider.StreamEvent{Type: "text", Text: ""} } diff --git a/genkit/providers.go b/genkit/providers.go index 41594a6..f7da8a6 100644 --- a/genkit/providers.go +++ b/genkit/providers.go @@ -19,7 +19,7 @@ import ( // Default models per provider when none specified. const ( defaultAnthropicModel = "claude-sonnet-4-6" - defaultOpenAIModel = "gpt-4o" + defaultOpenAIModel = "gpt-4.1" defaultGeminiModel = "gemini-2.5-flash" defaultOllamaModel = "qwen3:8b" ) @@ -42,6 +42,9 @@ func NewAnthropicProvider(ctx context.Context, apiKey, model, baseURL string, ma if model == "" { model = defaultAnthropicModel } + if err := provider.ValidateBaseURL(baseURL); err != nil { + return nil, fmt.Errorf("anthropic: %w", err) + } p := &anthropicPlugin.Anthropic{APIKey: apiKey, BaseURL: baseURL} g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) return &genkitProvider{ @@ -65,6 +68,9 @@ func NewOpenAIProvider(ctx context.Context, apiKey, model, baseURL string, maxTo if model == "" { model = defaultOpenAIModel } + if err := provider.ValidateBaseURL(baseURL); err != nil { + return nil, fmt.Errorf("openai: %w", err) + } var extraOpts []option.RequestOption if baseURL != "" { extraOpts = append(extraOpts, option.WithBaseURL(baseURL)) @@ -133,6 +139,18 @@ func NewOllamaProvider(ctx context.Context, model, serverAddress string, maxToke // NewOpenAICompatibleProvider creates a provider for OpenAI-compatible endpoints. // Used for OpenRouter, Copilot, Cohere, HuggingFace, llama.cpp, etc. func NewOpenAICompatibleProvider(ctx context.Context, providerName, apiKey, model, baseURL string, maxTokens int) (provider.Provider, error) { + if model == "" { + model = defaultOpenAIModel + } + // Skip SSRF validation for local providers that use localhost endpoints. + switch providerName { + case "llama_cpp", "local", "test": + // Local providers intentionally use http://localhost — skip SSRF checks. + default: + if err := provider.ValidateBaseURL(baseURL); err != nil { + return nil, fmt.Errorf("%s: %w", providerName, err) + } + } effectiveKey := apiKey if effectiveKey == "" { // Use a placeholder to avoid errors when no key is needed (e.g., local endpoints) diff --git a/go.mod b/go.mod index 3e42e28..a9705be 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/ollama/ollama v0.18.3 github.com/openai/openai-go v1.12.0 golang.org/x/crypto v0.48.0 + golang.org/x/sync v0.20.0 golang.org/x/text v0.34.0 google.golang.org/api v0.271.0 gopkg.in/yaml.v3 v3.0.1 @@ -247,7 +248,6 @@ require ( golang.org/x/mod v0.33.0 // indirect golang.org/x/net v0.51.0 // indirect golang.org/x/oauth2 v0.36.0 // indirect - golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/time v0.15.0 // indirect golang.org/x/tools v0.42.0 // indirect diff --git a/orchestrator/provider_registry.go b/orchestrator/provider_registry.go index f383d40..a6df1df 100644 --- a/orchestrator/provider_registry.go +++ b/orchestrator/provider_registry.go @@ -11,6 +11,7 @@ import ( gkprov "github.com/GoCodeAlone/workflow-plugin-agent/genkit" "github.com/GoCodeAlone/workflow-plugin-agent/provider" "github.com/GoCodeAlone/workflow/secrets" + "golang.org/x/sync/singleflight" ) // LLMProviderConfig represents a configured LLM provider stored in the database. @@ -47,6 +48,7 @@ type ProviderRegistry struct { secrets secrets.Provider cache map[string]provider.Provider factories map[string]ProviderFactory + sflight singleflight.Group // deduplicates concurrent cold-start creation per alias } // NewProviderRegistry creates a new ProviderRegistry with built-in factories registered. @@ -208,39 +210,51 @@ func (r *ProviderRegistry) loadConfig(ctx context.Context, alias string) (*LLMPr } // createAndCache resolves the secret, creates the provider via factory, and caches it. +// Uses singleflight to ensure only one goroutine creates a provider per alias, +// avoiding duplicate expensive Genkit init on concurrent cold starts. func (r *ProviderRegistry) createAndCache(ctx context.Context, alias string, cfg *LLMProviderConfig) (provider.Provider, error) { - // Resolve API key from secrets - var apiKey string - if cfg.SecretName != "" && r.secrets != nil { - var err error - apiKey, err = r.secrets.Get(ctx, cfg.SecretName) - if err != nil { - return nil, fmt.Errorf("provider registry: resolve secret %q: %w", cfg.SecretName, err) + result, err, _ := r.sflight.Do(alias, func() (any, error) { + // Re-check cache under singleflight — another caller may have populated it. + r.mu.RLock() + if p, ok := r.cache[alias]; ok { + r.mu.RUnlock() + return p, nil } - } + r.mu.RUnlock() - // Find factory - factory, ok := r.factories[cfg.Type] - if !ok { - return nil, fmt.Errorf("provider registry: unknown provider type %q", cfg.Type) - } + // Resolve API key from secrets + var apiKey string + if cfg.SecretName != "" && r.secrets != nil { + var err error + apiKey, err = r.secrets.Get(ctx, cfg.SecretName) + if err != nil { + return nil, fmt.Errorf("provider registry: resolve secret %q: %w", cfg.SecretName, err) + } + } - // Create provider - p, err := factory(ctx, apiKey, *cfg) - if err != nil { - return nil, fmt.Errorf("provider registry: create %q: %w", alias, err) - } + // Find factory + factory, ok := r.factories[cfg.Type] + if !ok { + return nil, fmt.Errorf("provider registry: unknown provider type %q", cfg.Type) + } - // Cache — re-check under write lock to avoid TOCTOU race - r.mu.Lock() - if existing, ok := r.cache[alias]; ok { + // Create provider + p, err := factory(ctx, apiKey, *cfg) + if err != nil { + return nil, fmt.Errorf("provider registry: create %q: %w", alias, err) + } + + // Cache + r.mu.Lock() + r.cache[alias] = p r.mu.Unlock() - return existing, nil - } - r.cache[alias] = p - r.mu.Unlock() - return p, nil + return p, nil + }) + if err != nil { + return nil, err + } + return result.(provider.Provider), nil } // Built-in factory functions @@ -284,7 +298,7 @@ func cohereProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderCo func copilotModelsProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { baseURL := cfg.BaseURL if baseURL == "" { - baseURL = "https://models.inference.ai.azure.com" + baseURL = "https://models.github.ai/inference" } return gkprov.NewOpenAICompatibleProvider(ctx, "copilot_models", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } @@ -322,7 +336,11 @@ func ollamaProviderFactory(ctx context.Context, _ string, cfg LLMProviderConfig) func llamaCppProviderFactory(ctx context.Context, _ string, cfg LLMProviderConfig) (provider.Provider, error) { // llama.cpp serves an OpenAI-compatible API - return gkprov.NewOpenAICompatibleProvider(ctx, "llama_cpp", "", cfg.Model, cfg.BaseURL, cfg.MaxTokens) + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "http://127.0.0.1:8080/v1" + } + return gkprov.NewOpenAICompatibleProvider(ctx, "llama_cpp", "", cfg.Model, baseURL, cfg.MaxTokens) } func anthropicBedrockProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) {