diff --git a/docs/plans/2026-04-05-genkit-provider-migration-design.md b/docs/plans/2026-04-05-genkit-provider-migration-design.md new file mode 100644 index 0000000..bab52dd --- /dev/null +++ b/docs/plans/2026-04-05-genkit-provider-migration-design.md @@ -0,0 +1,213 @@ +# Genkit Provider Migration Design + +**Date:** 2026-04-05 +**Repo:** workflow-plugin-agent +**Goal:** Replace all hand-rolled provider implementations with Google Genkit Go SDK adapters, keeping the `provider.Provider` interface unchanged. + +## Scope + +- **Replace:** All 12+ provider implementation files with Genkit plugin adapters +- **Keep:** `provider.Provider` interface, `executor/`, `tools/`, `orchestrator/`, mesh support, test providers +- **Gain:** Unified SDK, structured output, built-in tracing, MCP support, fewer direct dependencies + +## Architecture + +```mermaid +graph TD + E[executor.Execute] --> P[provider.Provider interface] + P --> GA[genkit adapter] + GA --> GG[genkit.Generate / GenerateStream] + GG --> PA[Anthropic plugin] + GG --> PO[OpenAI plugin] + GG --> PG[Google AI plugin] + GG --> POL[Ollama plugin] +``` + +Genkit is an **internal implementation detail**. The `provider.Provider` interface is unchanged. All consumers (executor, ratchet-cli, mesh, orchestrator) work without modification. 
+ +## What Gets Deleted + +- `provider/anthropic.go`, `anthropic_bedrock.go`, `anthropic_foundry.go`, `anthropic_vertex.go` +- `provider/anthropic_convert.go`, `anthropic_convert_test.go` +- `provider/openai.go`, `openai_azure.go`, `openai_azure_test.go` +- `provider/gemini.go`, `gemini_test.go` +- `provider/ollama.go`, `ollama_convert.go`, `ollama_test.go`, `ollama_convert_test.go` +- `provider/copilot.go`, `copilot_test.go`, `copilot_models.go`, `copilot_models_test.go` +- `provider/openrouter.go`, `openrouter_test.go` +- `provider/cohere.go` +- `provider/huggingface.go`, `huggingface_test.go` +- `provider/llama_cpp.go`, `llama_cpp_test.go`, `llama_cpp_download.go`, `llama_cpp_download_test.go` +- `provider/local.go`, `local_test.go` +- `provider/ssrf.go`, `provider/models_ssrf_test.go` +- `provider/auth_modes.go` (simplified into adapter) + +## What Stays + +- `provider/provider.go` — interface + types (`Message`, `ToolDef`, `ToolCall`, `StreamEvent`, `Response`, `Usage`, `AuthModeInfo`) +- `provider/models.go` — model registry/definitions (if still needed for model metadata) +- `provider/test_provider.go`, `test_provider_channel.go`, `test_provider_http.go`, `test_provider_scripted.go` — test infrastructure +- `provider/thinking_field_test.go` — test for thinking trace parsing +- `executor/` — entire package unchanged +- `tools/` — entire package unchanged +- `orchestrator/` — entire package unchanged (ProviderRegistry updated to use new factories) + +## New Package: `genkit/` + +``` +genkit/ + genkit.go — Init(), singleton Genkit instance, plugin registration + adapter.go — genkitProvider implementing provider.Provider + convert.go — Bidirectional type conversion (our types ↔ Genkit ai types) + providers.go — Factory functions per provider type + providers_test.go — Tests with mock Genkit model +``` + +### genkit.go — Initialization + +```go +package genkit + +// Init creates and returns a Genkit instance with all available plugins registered. 
+// Call once at startup. Thread-safe after initialization. +func Init(ctx context.Context, opts ...Option) (*genkit.Genkit, error) + +type Option func(*initConfig) +func WithAnthropicKey(key string) Option +func WithOpenAIKey(key string) Option +func WithGoogleAIKey(key string) Option +func WithOllamaHost(host string) Option +// etc. +``` + +### adapter.go — Provider Adapter + +```go +// genkitProvider adapts a Genkit model to provider.Provider. +type genkitProvider struct { + g *genkit.Genkit + model ai.Model // resolved Genkit model + name string // provider identifier + authInfo provider.AuthModeInfo + config map[string]any // model-specific generation config +} + +func (p *genkitProvider) Name() string +func (p *genkitProvider) AuthModeInfo() provider.AuthModeInfo +func (p *genkitProvider) Chat(ctx, messages, tools) (*provider.Response, error) +func (p *genkitProvider) Stream(ctx, messages, tools) (<-chan provider.StreamEvent, error) +``` + +`Chat()` calls `genkit.Generate()` and converts the response. +`Stream()` calls `genkit.Generate()` with streaming and converts chunks to `StreamEvent` on a channel. + +### convert.go — Type Conversion + +```go +// toGenkitMessages converts []provider.Message → []ai.Message +func toGenkitMessages(msgs []provider.Message) []*ai.Message + +// toGenkitTools converts []provider.ToolDef → []ai.Tool (via genkit.DefineTool or inline) +func toGenkitTools(g *genkit.Genkit, tools []provider.ToolDef) []*ai.Tool + +// fromGenkitResponse converts *ai.ModelResponse → *provider.Response +func fromGenkitResponse(resp *ai.ModelResponse) *provider.Response + +// fromGenkitChunk converts ai.ModelResponseChunk → provider.StreamEvent +func fromGenkitChunk(chunk *ai.ModelResponseChunk) provider.StreamEvent +``` + +### providers.go — Factory Functions + +```go +// NewAnthropicProvider creates a provider backed by Genkit's Anthropic plugin. 
+func NewAnthropicProvider(cfg AnthropicConfig) (provider.Provider, error) + +// NewOpenAIProvider creates a provider backed by Genkit's OpenAI plugin. +func NewOpenAIProvider(cfg OpenAIConfig) (provider.Provider, error) + +// NewGoogleAIProvider creates a provider backed by Genkit's Google AI plugin. +func NewGoogleAIProvider(cfg GoogleAIConfig) (provider.Provider, error) + +// NewOllamaProvider creates a provider backed by Genkit's Ollama plugin. +func NewOllamaProvider(cfg OllamaConfig) (provider.Provider, error) + +// NewOpenAICompatibleProvider handles OpenRouter, Copilot, Cohere, HuggingFace, etc. +func NewOpenAICompatibleProvider(cfg OpenAICompatibleConfig) (provider.Provider, error) + +// NewBedrockProvider creates a provider via Genkit's AWS Bedrock plugin. +func NewBedrockProvider(cfg BedrockConfig) (provider.Provider, error) + +// NewVertexAIProvider creates a provider via Genkit's Vertex AI plugin. +func NewVertexAIProvider(cfg VertexAIConfig) (provider.Provider, error) + +// NewAzureOpenAIProvider creates a provider via Genkit's OpenAI plugin with Azure endpoint. +func NewAzureOpenAIProvider(cfg AzureOpenAIConfig) (provider.Provider, error) +``` + +Config structs mirror existing provider configs (API key, model, base URL, max tokens, etc.). 
+ +## Provider Mapping + +| Current Implementation | Genkit Plugin | Factory | +|---|---|---| +| `anthropic.go` (direct API) | `github.com/anthropics/anthropic-sdk-go` via Genkit Anthropic plugin | `NewAnthropicProvider` | +| `anthropic_bedrock.go` | Genkit AWS Bedrock plugin | `NewBedrockProvider` | +| `anthropic_vertex.go` | Genkit Vertex AI plugin | `NewVertexAIProvider` | +| `anthropic_foundry.go` | Genkit Anthropic plugin with custom base URL | `NewAnthropicProvider` | +| `openai.go` | Genkit OpenAI plugin | `NewOpenAIProvider` | +| `openai_azure.go` | Genkit OpenAI plugin with Azure endpoint | `NewAzureOpenAIProvider` | +| `gemini.go` | Genkit Google AI plugin | `NewGoogleAIProvider` | +| `ollama.go` + `ollama_convert.go` | Genkit Ollama plugin | `NewOllamaProvider` | +| `copilot.go` | OpenAI-compatible via Genkit OpenAI plugin | `NewOpenAICompatibleProvider` | +| `openrouter.go` | OpenAI-compatible | `NewOpenAICompatibleProvider` | +| `cohere.go` | OpenAI-compatible | `NewOpenAICompatibleProvider` | +| `huggingface.go` | OpenAI-compatible | `NewOpenAICompatibleProvider` | +| `llama_cpp.go` | Ollama (llama.cpp serves OpenAI-compatible) or Genkit Ollama | `NewOllamaProvider` or `NewOpenAICompatibleProvider` | + +## Thinking/Reasoning Trace Support + +Models like Claude and Qwen emit thinking traces. Genkit streams these as part of `ModelResponseChunk`. The `fromGenkitChunk()` converter maps: +- Text content → `StreamEvent{Type: "text", Text: ...}` +- Thinking content → `StreamEvent{Type: "thinking", Thinking: ...}` +- Tool calls → `StreamEvent{Type: "tool_call", Tool: ...}` +- Done → `StreamEvent{Type: "done"}` +- Errors → `StreamEvent{Type: "error", Error: ...}` + +## ProviderRegistry Update + +`orchestrator/provider_registry.go` currently has a `createProvider()` method that switches on provider type and calls constructors like `provider.NewAnthropicProvider()`. This gets updated to call `genkit.NewAnthropicProvider()` etc. 
The registry interface doesn't change. + +## llama.cpp Download Support + +The current `llama_cpp_download.go` handles binary downloads and HuggingFace model pulls. This is orthogonal to the provider layer — it can stay as a utility in `provider/` or move to a `local/` package. The llama.cpp *provider* itself gets replaced by Genkit's Ollama plugin (llama.cpp serves an OpenAI-compatible API). + +## Dependencies + +**Added:** +- `github.com/firebase/genkit/go` (core) +- Genkit provider plugins (anthropic, openai, googleai, ollama, etc.) + +**Removed (become transitive via Genkit):** +- `github.com/anthropics/anthropic-sdk-go` (direct) +- `github.com/openai/openai-go` (direct) +- `github.com/google/generative-ai-go` (direct) +- `github.com/ollama/ollama` (direct) + +## Testing Strategy + +1. **Unit tests:** Mock Genkit model via `ai.DefineModel()` with canned responses. Test adapter conversion, streaming, tool call mapping. +2. **Existing test providers** (`test_provider.go` etc.) remain unchanged — they implement `provider.Provider` directly and don't use Genkit. +3. **Integration tests** (tagged `//go:build integration`): Call real provider APIs via Genkit to verify end-to-end. +4. **Regression:** All existing executor tests must pass unchanged since `provider.Provider` interface is stable. + +## Migration Order + +1. Add Genkit dependency, create `genkit/` package skeleton +2. Implement `convert.go` (type conversion) +3. Implement `adapter.go` (genkitProvider) +4. Implement `providers.go` (all factory functions) +5. Update `orchestrator/provider_registry.go` to use new factories +6. Delete old provider implementation files +7. Clean up go.mod (remove direct SDK deps that are now transitive) +8. Run full test suite, fix any regressions +9. 
Tag release diff --git a/docs/plans/2026-04-05-genkit-provider-migration.md b/docs/plans/2026-04-05-genkit-provider-migration.md new file mode 100644 index 0000000..1fd0bc9 --- /dev/null +++ b/docs/plans/2026-04-05-genkit-provider-migration.md @@ -0,0 +1,573 @@ +# Genkit Provider Migration — Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Replace all hand-rolled provider implementations in workflow-plugin-agent with Google Genkit Go SDK adapters, keeping the `provider.Provider` interface unchanged. + +**Architecture:** A new `genkit/` package wraps Genkit plugins behind the existing `provider.Provider` interface. The `ProviderRegistry` factory functions are updated to call genkit factories. All provider implementation files (~25 files) are deleted. + +**Tech Stack:** Go 1.26, Genkit Go v1.6.0, Genkit plugins (anthropic, openai, googlegenai, ollama, compat_oai), community plugins (aws-bedrock, azure-openai) + +--- + +## Task 1: Add Genkit dependency and create package skeleton + +**Files:** +- Modify: `go.mod` +- Create: `genkit/genkit.go` + +**Step 1:** Add Genkit core dependency: +```bash +cd /Users/jon/workspace/workflow-plugin-agent +go get github.com/firebase/genkit/go@v1.6.0 +go get github.com/firebase/genkit/go/plugins/anthropic +go get github.com/firebase/genkit/go/plugins/googlegenai +go get github.com/firebase/genkit/go/plugins/ollama +go get github.com/firebase/genkit/go/plugins/compat_oai +go mod tidy +``` + +**Step 2:** Create `genkit/genkit.go` — package with Genkit initialization: +```go +// Package genkit provides Genkit-backed implementations of provider.Provider. +package genkit + +import ( + "context" + "sync" + + gk "github.com/firebase/genkit/go/genkit" +) + +var ( + instance *gk.Genkit + once sync.Once +) + +// Instance returns the shared Genkit instance, initializing it lazily on first call. 
+// Plugins are registered dynamically when providers are created, not at init time. +func Instance(ctx context.Context) *gk.Genkit { + once.Do(func() { + instance = gk.Init(ctx) + }) + return instance +} +``` + +**Step 3:** Build to verify: `go build ./...` + +**Step 4:** Commit: +```bash +git add go.mod go.sum genkit/genkit.go +git commit -m "chore: add Genkit Go SDK dependency and package skeleton" +``` + +--- + +## Task 2: Implement type conversion layer + +**Files:** +- Create: `genkit/convert.go` +- Create: `genkit/convert_test.go` + +**Step 1:** Create `genkit/convert.go` — bidirectional type conversion between our `provider.*` types and Genkit `ai.*` types: +```go +package genkit + +import ( + "github.com/firebase/genkit/go/ai" + "github.com/GoCodeAlone/workflow-plugin-agent/provider" +) + +// toGenkitMessages converts our messages to Genkit messages. +func toGenkitMessages(msgs []provider.Message) []*ai.Message { + out := make([]*ai.Message, 0, len(msgs)) + for _, m := range msgs { + var role ai.Role + switch m.Role { + case provider.RoleSystem: + role = ai.RoleSystem + case provider.RoleUser: + role = ai.RoleUser + case provider.RoleAssistant: + role = ai.RoleModel + case provider.RoleTool: + role = ai.RoleTool + default: + role = ai.RoleUser + } + + parts := []*ai.Part{ai.NewTextPart(m.Content)} + + // Tool call results: add as ToolResponsePart + if m.ToolCallID != "" { + parts = []*ai.Part{ai.NewToolResponsePart(&ai.ToolResponse{ + Name: m.ToolCallID, + Output: map[string]any{"result": m.Content}, + })} + } + + out = append(out, ai.NewMessage(role, nil, parts...)) + } + return out +} + +// toGenkitToolDefs converts our tool definitions to Genkit tool defs. +// Returns the tool names for WithModelName-style passing. 
+func toGenkitToolDefs(tools []provider.ToolDef) []ai.ToolDef { + out := make([]ai.ToolDef, 0, len(tools)) + for _, t := range tools { + out = append(out, ai.ToolDef{ + Name: t.Name, + Description: t.Description, + InputSchema: t.Parameters, + }) + } + return out +} + +// fromGenkitResponse converts a Genkit response to our Response type. +func fromGenkitResponse(resp *ai.ModelResponse) *provider.Response { + if resp == nil { + return &provider.Response{} + } + + out := &provider.Response{ + Content: resp.Text(), + } + + // Extract tool calls + if msg := resp.Message; msg != nil { + for _, part := range msg.Content { + if part.ToolRequest != nil { + tc := provider.ToolCall{ + ID: part.ToolRequest.Name, + Name: part.ToolRequest.Name, + Arguments: make(map[string]any), + } + if input, ok := part.ToolRequest.Input.(map[string]any); ok { + tc.Arguments = input + } + out.ToolCalls = append(out.ToolCalls, tc) + } + } + } + + // Extract usage + if resp.Usage != nil { + out.Usage = provider.Usage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + } + } + + return out +} + +// fromGenkitChunk converts a Genkit stream chunk to our StreamEvent. 
+func fromGenkitChunk(chunk *ai.ModelResponseChunk) provider.StreamEvent { + if chunk == nil { + return provider.StreamEvent{Type: "done"} + } + text := chunk.Text() + if text != "" { + return provider.StreamEvent{Type: "text", Text: text} + } + // Check for tool calls in chunk + for _, part := range chunk.Content { + if part.ToolRequest != nil { + return provider.StreamEvent{ + Type: "tool_call", + Tool: &provider.ToolCall{ + ID: part.ToolRequest.Name, + Name: part.ToolRequest.Name, + Arguments: func() map[string]any { + if m, ok := part.ToolRequest.Input.(map[string]any); ok { + return m + } + return nil + }(), + }, + } + } + } + return provider.StreamEvent{Type: "text", Text: ""} +} +``` + +**Step 2:** Create `genkit/convert_test.go`: +```go +package genkit + +import ( + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/GoCodeAlone/workflow-plugin-agent/provider" +) + +func TestToGenkitMessages(t *testing.T) { + msgs := []provider.Message{ + {Role: provider.RoleSystem, Content: "You are helpful."}, + {Role: provider.RoleUser, Content: "Hello"}, + {Role: provider.RoleAssistant, Content: "Hi there"}, + } + result := toGenkitMessages(msgs) + if len(result) != 3 { + t.Fatalf("expected 3 messages, got %d", len(result)) + } + if result[0].Role != ai.RoleSystem { + t.Errorf("expected system role, got %s", result[0].Role) + } + if result[2].Role != ai.RoleModel { + t.Errorf("expected model role for assistant, got %s", result[2].Role) + } +} + +func TestFromGenkitResponse(t *testing.T) { + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("Hello world"), + Usage: &ai.GenerationUsage{InputTokens: 10, OutputTokens: 5}, + } + result := fromGenkitResponse(resp) + if result.Content != "Hello world" { + t.Errorf("expected 'Hello world', got %q", result.Content) + } + if result.Usage.InputTokens != 10 { + t.Errorf("expected 10 input tokens, got %d", result.Usage.InputTokens) + } +} + +func TestFromGenkitResponseNil(t *testing.T) { + result := 
fromGenkitResponse(nil) + if result.Content != "" { + t.Errorf("expected empty content, got %q", result.Content) + } +} +``` + +**Step 3:** Run tests: `go test ./genkit/ -v -count=1` + +**Step 4:** Commit: +```bash +git add genkit/convert.go genkit/convert_test.go +git commit -m "feat: add Genkit ↔ provider type conversion layer" +``` + +--- + +## Task 3: Implement the Genkit provider adapter + +**Files:** +- Create: `genkit/adapter.go` +- Create: `genkit/adapter_test.go` + +**Step 1:** Create `genkit/adapter.go` — the `genkitProvider` struct that implements `provider.Provider`: +```go +package genkit + +import ( + "context" + "fmt" + + "github.com/firebase/genkit/go/ai" + gk "github.com/firebase/genkit/go/genkit" + "github.com/GoCodeAlone/workflow-plugin-agent/provider" +) + +// genkitProvider adapts a Genkit model to provider.Provider. +type genkitProvider struct { + g *gk.Genkit + modelName string // "provider/model" format + name string + authInfo provider.AuthModeInfo +} + +func (p *genkitProvider) Name() string { return p.name } +func (p *genkitProvider) AuthModeInfo() provider.AuthModeInfo { return p.authInfo } + +func (p *genkitProvider) Chat(ctx context.Context, messages []provider.Message, tools []provider.ToolDef) (*provider.Response, error) { + opts := []ai.GenerateOption{ + ai.WithModelName(p.modelName), + ai.WithMessages(toGenkitMessages(messages)...), + } + + // Pass tool definitions if provided — use WithReturnToolRequests so we handle tool + // execution ourselves (the executor loop does this, not Genkit). + if len(tools) > 0 { + opts = append(opts, ai.WithReturnToolRequests(true)) + // Register tools dynamically for this call + for _, t := range tools { + tool := gk.DefineTool(p.g, t.Name, t.Description, + func(ctx *ai.ToolContext, input map[string]any) (map[string]any, error) { + // Placeholder — tools are executed by the executor, not here. 
+ return nil, fmt.Errorf("tool %s should not be called via Genkit", t.Name) + }, + ) + opts = append(opts, ai.WithTools(tool)) + } + } + + resp, err := gk.Generate(ctx, p.g, opts...) + if err != nil { + return nil, fmt.Errorf("genkit generate: %w", err) + } + + return fromGenkitResponse(resp), nil +} + +func (p *genkitProvider) Stream(ctx context.Context, messages []provider.Message, tools []provider.ToolDef) (<-chan provider.StreamEvent, error) { + ch := make(chan provider.StreamEvent, 64) + + opts := []ai.GenerateOption{ + ai.WithModelName(p.modelName), + ai.WithMessages(toGenkitMessages(messages)...), + } + + if len(tools) > 0 { + opts = append(opts, ai.WithReturnToolRequests(true)) + for _, t := range tools { + tool := gk.DefineTool(p.g, t.Name, t.Description, + func(ctx *ai.ToolContext, input map[string]any) (map[string]any, error) { + return nil, fmt.Errorf("tool %s should not be called via Genkit", t.Name) + }, + ) + opts = append(opts, ai.WithTools(tool)) + } + } + + go func() { + defer close(ch) + + stream := gk.GenerateStream(ctx, p.g, opts...) 
+ for result, err := range stream { + if err != nil { + ch <- provider.StreamEvent{Type: "error", Error: err.Error()} + return + } + if result.Done { + // Extract final response for tool calls and usage + if result.Response != nil { + final := fromGenkitResponse(result.Response) + for _, tc := range final.ToolCalls { + ch <- provider.StreamEvent{Type: "tool_call", Tool: &tc} + } + if final.Usage.InputTokens > 0 || final.Usage.OutputTokens > 0 { + ch <- provider.StreamEvent{Type: "done", Usage: &final.Usage} + return + } + } + ch <- provider.StreamEvent{Type: "done"} + return + } + if result.Chunk != nil { + ev := fromGenkitChunk(result.Chunk) + if ev.Type != "" && (ev.Text != "" || ev.Tool != nil) { + ch <- ev + } + } + } + // Iterator exhausted without Done — send done anyway + ch <- provider.StreamEvent{Type: "done"} + }() + + return ch, nil +} +``` + +**Step 2:** Create `genkit/adapter_test.go` with mock model tests. Tests should verify Chat/Stream call paths and conversion. + +**Step 3:** Run tests: `go test ./genkit/ -v -count=1` + +**Step 4:** Commit: +```bash +git add genkit/adapter.go genkit/adapter_test.go +git commit -m "feat: implement Genkit provider adapter (provider.Provider interface)" +``` + +--- + +## Task 4: Implement provider factory functions + +**Files:** +- Create: `genkit/providers.go` +- Create: `genkit/providers_test.go` + +**Step 1:** Create `genkit/providers.go` with factory functions for each provider type. Each factory: +1. Initializes the appropriate Genkit plugin +2. Creates a `genkitProvider` with the correct model name format +3. 
Returns `provider.Provider` + +Factory functions needed (one per ProviderRegistry factory): +- `NewAnthropicProvider(apiKey, model, baseURL string, maxTokens int) provider.Provider` +- `NewOpenAIProvider(apiKey, model, baseURL string, maxTokens int) provider.Provider` +- `NewGoogleAIProvider(apiKey, model string, maxTokens int) (provider.Provider, error)` +- `NewOllamaProvider(model, baseURL string, maxTokens int) provider.Provider` +- `NewOpenAICompatibleProvider(name, apiKey, model, baseURL string, maxTokens int) provider.Provider` — for OpenRouter, Copilot, Cohere, HuggingFace +- `NewBedrockProvider(region, model, accessKeyID, secretAccessKey, sessionToken string, maxTokens int) provider.Provider` — via community plugin or OpenAI-compatible +- `NewVertexAIProvider(projectID, region, model, credentialsJSON string, maxTokens int) (provider.Provider, error)` +- `NewAzureOpenAIProvider(resource, deploymentName, apiVersion, apiKey string, maxTokens int) provider.Provider` — via community plugin or OpenAI-compatible +- `NewAnthropicFoundryProvider(resource, model, apiKey, entraToken string, maxTokens int) provider.Provider` + +Each factory calls `Instance(ctx)` to get the shared Genkit instance, registers the plugin if not already registered, and returns a `genkitProvider`. + +**Step 2:** Create `genkit/providers_test.go` — test factory instantiation (not live API calls). 
+ +**Step 3:** Build: `go build ./...` + +**Step 4:** Commit: +```bash +git add genkit/providers.go genkit/providers_test.go +git commit -m "feat: add Genkit provider factory functions for all provider types" +``` + +--- + +## Task 5: Update ProviderRegistry to use Genkit factories + +**Files:** +- Modify: `orchestrator/provider_registry.go` + +**Step 1:** Update all factory functions in `provider_registry.go` to call `genkit.New*Provider()` instead of `provider.New*Provider()`: + +```go +import gkprov "github.com/GoCodeAlone/workflow-plugin-agent/genkit" + +func anthropicProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewAnthropicProvider(apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens), nil +} + +func openaiProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOpenAIProvider(apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens), nil +} +// ... repeat for all 13 factory functions +``` + +**Step 2:** Build: `go build ./...` + +**Step 3:** Run registry tests: `go test ./orchestrator/ -run TestProviderRegistry -v -count=1` + +**Step 4:** Run full suite: `go test ./... 
-count=1` + +**Step 5:** Commit: +```bash +git add orchestrator/provider_registry.go +git commit -m "feat: update ProviderRegistry to use Genkit-backed factories" +``` + +--- + +## Task 6: Delete old provider implementation files + +**Files:** +- Delete: All provider implementation files listed in design doc + +**Step 1:** Delete the old provider files (keep `provider.go`, `models.go`, `auth_modes.go`, and all `test_provider*.go` files): +```bash +cd /Users/jon/workspace/workflow-plugin-agent +# Delete implementation files +rm provider/anthropic.go provider/anthropic_bedrock.go provider/anthropic_foundry.go provider/anthropic_vertex.go +rm provider/anthropic_convert.go provider/anthropic_convert_test.go +rm provider/openai.go provider/openai_azure.go provider/openai_azure_test.go +rm provider/gemini.go provider/gemini_test.go +rm provider/ollama.go provider/ollama_convert.go provider/ollama_test.go provider/ollama_convert_test.go +rm provider/copilot.go provider/copilot_test.go provider/copilot_models.go provider/copilot_models_test.go +rm provider/openrouter.go provider/openrouter_test.go +rm provider/cohere.go +rm provider/huggingface.go provider/huggingface_test.go +rm provider/llama_cpp.go provider/llama_cpp_test.go provider/llama_cpp_download.go provider/llama_cpp_download_test.go +rm provider/local.go provider/local_test.go +rm provider/ssrf.go provider/models_ssrf_test.go +``` + +**Step 2:** Fix any compilation errors from dangling references. Key areas to check: +- `provider/models.go` — may reference deleted types +- `provider/auth_modes.go` — may reference deleted types +- `provider/thinking_field_test.go` — may import deleted code +- Any file in `orchestrator/` that directly imports deleted provider constructors + +**Step 3:** Build: `go build ./...` + +**Step 4:** Run tests: `go test ./... 
-count=1` + +**Step 5:** Commit: +```bash +git add -u # stages deletions +git commit -m "refactor: delete hand-rolled provider implementations (replaced by Genkit)" +``` + +--- + +## Task 7: Clean up dependencies and fix remaining issues + +**Files:** +- Modify: `go.mod` +- Modify: any files with compilation errors + +**Step 1:** Remove direct SDK dependencies that are now transitive via Genkit: +```bash +go mod tidy +``` + +Check if `github.com/anthropics/anthropic-sdk-go`, `github.com/openai/openai-go`, `github.com/google/generative-ai-go`, `github.com/ollama/ollama` are still needed directly. If only Genkit imports them, they'll move to `indirect`. + +**Step 2:** Fix any remaining compile errors or test failures. + +**Step 3:** Run full test suite with race detector: +```bash +go test -race ./... -count=1 +``` + +**Step 4:** Run linter: +```bash +golangci-lint run +``` + +**Step 5:** Commit: +```bash +git add go.mod go.sum +git commit -m "chore: clean up dependencies after Genkit migration" +``` + +--- + +## Task 8: Write integration tests and verify + +**Files:** +- Create: `genkit/integration_test.go` + +**Step 1:** Create integration tests that verify the full path: `ProviderRegistry → genkit factory → genkitProvider → Genkit → mock model`. These tests use in-memory DB + mock secrets like existing registry tests. + +**Step 2:** Verify existing executor tests pass (they use `provider.Provider` interface which is unchanged): +```bash +go test ./executor/ -v -count=1 +``` + +**Step 3:** Verify orchestrator tests: +```bash +go test ./orchestrator/ -v -count=1 +``` + +**Step 4:** Full regression: +```bash +go test -race ./... 
-count=1 +golangci-lint run +``` + +**Step 5:** Commit: +```bash +git add genkit/integration_test.go +git commit -m "test: add Genkit integration tests and verify full regression" +``` + +--- + +## Execution Order + +``` +Task 1 (deps + skeleton) → Task 2 (convert) → Task 3 (adapter) → Task 4 (factories) + ↓ +Task 5 (registry update) → Task 6 (delete old files) → Task 7 (cleanup) → Task 8 (tests) +``` + +All tasks are sequential — each builds on the previous. diff --git a/genkit/adapter.go b/genkit/adapter.go new file mode 100644 index 0000000..aa2d540 --- /dev/null +++ b/genkit/adapter.go @@ -0,0 +1,155 @@ +package genkit + +import ( + "context" + "fmt" + "sync" + + "github.com/GoCodeAlone/workflow-plugin-agent/provider" + "github.com/firebase/genkit/go/ai" + gk "github.com/firebase/genkit/go/genkit" +) + +// genkitProvider adapts a Genkit model to provider.Provider. +type genkitProvider struct { + g *gk.Genkit + modelName string // "provider/model" format e.g. "anthropic/claude-sonnet-4-6" + name string + authInfo provider.AuthModeInfo + maxTokens int // 0 means use model default + + mu sync.Mutex + definedTools map[string]bool // tracks which tool names are registered +} + +func (p *genkitProvider) Name() string { return p.name } +func (p *genkitProvider) AuthModeInfo() provider.AuthModeInfo { return p.authInfo } + +// resolveToolRefs ensures each tool is registered exactly once and returns +// their ToolRef representations for use with WithTools. +func (p *genkitProvider) resolveToolRefs(tools []provider.ToolDef) []ai.ToolRef { + if len(tools) == 0 { + return nil + } + p.mu.Lock() + defer p.mu.Unlock() + + if p.definedTools == nil { + p.definedTools = make(map[string]bool) + } + + refs := make([]ai.ToolRef, 0, len(tools)) + for _, t := range tools { + if !p.definedTools[t.Name] { + // Pass the exact JSON Schema from provider.ToolDef.Parameters + // so the LLM gets accurate parameter definitions for tool calling. 
+ // WithInputSchema requires In=any (not map[string]any). + tool := gk.DefineTool(p.g, t.Name, t.Description, + func(ctx *ai.ToolContext, input any) (any, error) { + // Tools are executed by the executor, not Genkit. + return nil, fmt.Errorf("tool %s should not be called via Genkit", t.Name) + }, + ai.WithInputSchema(t.Parameters), + ) + refs = append(refs, tool) + p.definedTools[t.Name] = true + } else { + refs = append(refs, ai.ToolName(t.Name)) + } + } + return refs +} + +// generationConfig returns a WithConfig option when maxTokens is configured. +func (p *genkitProvider) generationConfig() ai.GenerateOption { + if p.maxTokens > 0 { + return ai.WithConfig(&ai.GenerationCommonConfig{MaxOutputTokens: p.maxTokens}) + } + return nil +} + +// Chat sends a non-streaming request and returns the complete response. +func (p *genkitProvider) Chat(ctx context.Context, messages []provider.Message, tools []provider.ToolDef) (*provider.Response, error) { + opts := []ai.GenerateOption{ + ai.WithModelName(p.modelName), + ai.WithMessages(toGenkitMessages(messages)...), + } + + if cfg := p.generationConfig(); cfg != nil { + opts = append(opts, cfg) + } + + if len(tools) > 0 { + opts = append(opts, ai.WithReturnToolRequests(true)) + for _, ref := range p.resolveToolRefs(tools) { + opts = append(opts, ai.WithTools(ref)) + } + } + + resp, err := gk.Generate(ctx, p.g, opts...) + if err != nil { + return nil, fmt.Errorf("genkit generate: %w", err) + } + + return fromGenkitResponse(resp), nil +} + +// Stream sends a streaming request. Events are delivered on the returned channel. 
+func (p *genkitProvider) Stream(ctx context.Context, messages []provider.Message, tools []provider.ToolDef) (<-chan provider.StreamEvent, error) { + ch := make(chan provider.StreamEvent, 64) + + opts := []ai.GenerateOption{ + ai.WithModelName(p.modelName), + ai.WithMessages(toGenkitMessages(messages)...), + } + + if cfg := p.generationConfig(); cfg != nil { + opts = append(opts, cfg) + } + + if len(tools) > 0 { + opts = append(opts, ai.WithReturnToolRequests(true)) + for _, ref := range p.resolveToolRefs(tools) { + opts = append(opts, ai.WithTools(ref)) + } + } + + go func() { + defer close(ch) + + stream := gk.GenerateStream(ctx, p.g, opts...) + for result, err := range stream { + if err != nil { + ch <- provider.StreamEvent{Type: "error", Error: err.Error()} + return + } + if result.Done { + // Tool calls are emitted only from the final response to avoid + // duplicates with unstable IDs from incremental chunks. + if result.Response != nil { + final := fromGenkitResponse(result.Response) + for i := range final.ToolCalls { + tc := final.ToolCalls[i] + ch <- provider.StreamEvent{Type: "tool_call", Tool: &tc} + } + if final.Usage.InputTokens > 0 || final.Usage.OutputTokens > 0 { + ch <- provider.StreamEvent{Type: "done", Usage: &final.Usage} + return + } + } + ch <- provider.StreamEvent{Type: "done"} + return + } + if result.Chunk != nil { + ev := fromGenkitChunk(result.Chunk) + if ev.Text != "" || ev.Thinking != "" || ev.Tool != nil { + ch <- ev + } + } + } + // Iterator exhausted without Done — send done anyway + ch <- provider.StreamEvent{Type: "done"} + }() + + return ch, nil +} diff --git a/genkit/adapter_test.go b/genkit/adapter_test.go new file mode 100644 index 0000000..a2ded24 --- /dev/null +++ b/genkit/adapter_test.go @@ -0,0 +1,167 @@ +package genkit + +import ( + "context" + "testing" + + "github.com/GoCodeAlone/workflow-plugin-agent/provider" + "github.com/firebase/genkit/go/ai" + gk "github.com/firebase/genkit/go/genkit" +) + +// mockModel defines a 
canned in-memory model on a Genkit instance for testing. +func mockModel(g *gk.Genkit, name string, resp *ai.ModelResponse) { + opts := &ai.ModelOptions{ + Supports: &ai.ModelSupports{Tools: true, SystemRole: true, Multiturn: true}, + } + gk.DefineModel(g, name, opts, func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + if cb != nil { + // Stream text chunks + if resp.Message != nil { + for _, part := range resp.Message.Content { + if !part.IsReasoning() { + _ = cb(ctx, &ai.ModelResponseChunk{ + Content: []*ai.Part{part}, + }) + } + } + } + } + return resp, nil + }) +} + +func newTestProvider(t *testing.T, modelName string, resp *ai.ModelResponse) *genkitProvider { + t.Helper() + g := gk.Init(context.Background()) + mockModel(g, modelName, resp) + return &genkitProvider{ + g: g, + modelName: modelName, + name: "mock", + } +} + +func TestGenkitProviderChat(t *testing.T) { + const modelName = "mock/test-model" + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("Hello from mock"), + Usage: &ai.GenerationUsage{InputTokens: 5, OutputTokens: 3}, + } + + p := newTestProvider(t, modelName, resp) + + msgs := []provider.Message{{Role: provider.RoleUser, Content: "Hi"}} + got, err := p.Chat(context.Background(), msgs, nil) + if err != nil { + t.Fatalf("Chat returned error: %v", err) + } + if got.Content != "Hello from mock" { + t.Errorf("expected 'Hello from mock', got %q", got.Content) + } + if got.Usage.InputTokens != 5 { + t.Errorf("expected 5 input tokens, got %d", got.Usage.InputTokens) + } +} + +func TestGenkitProviderName(t *testing.T) { + p := &genkitProvider{name: "test-provider"} + if p.Name() != "test-provider" { + t.Errorf("expected 'test-provider', got %q", p.Name()) + } +} + +func TestGenkitProviderStream(t *testing.T) { + const modelName = "mock/stream-model" + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("streamed"), + Usage: &ai.GenerationUsage{InputTokens: 2, OutputTokens: 1}, + 
} + + p := newTestProvider(t, modelName, resp) + + msgs := []provider.Message{{Role: provider.RoleUser, Content: "Hi"}} + ch, err := p.Stream(context.Background(), msgs, nil) + if err != nil { + t.Fatalf("Stream returned error: %v", err) + } + + var events []provider.StreamEvent + for ev := range ch { + events = append(events, ev) + } + + if len(events) == 0 { + t.Fatal("expected at least one event") + } + last := events[len(events)-1] + if last.Type != "done" { + t.Errorf("expected last event type 'done', got %q", last.Type) + } +} + +func TestGenkitProviderChatWithTools(t *testing.T) { + const modelName = "mock/tool-model" + + // Mock model that returns a tool request + resp := &ai.ModelResponse{ + Message: ai.NewMessage(ai.RoleModel, nil, + ai.NewToolRequestPart(&ai.ToolRequest{ + Name: "calculator", + Input: map[string]any{"a": 1, "b": 2}, + }), + ), + } + + p := newTestProvider(t, modelName, resp) + + msgs := []provider.Message{{Role: provider.RoleUser, Content: "Add 1+2"}} + tools := []provider.ToolDef{{ + Name: "calculator", + Description: "Adds numbers", + Parameters: map[string]any{"type": "object"}, + }} + + got, err := p.Chat(context.Background(), msgs, tools) + if err != nil { + t.Fatalf("Chat returned error: %v", err) + } + + if len(got.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(got.ToolCalls)) + } + if got.ToolCalls[0].Name != "calculator" { + t.Errorf("expected tool 'calculator', got %q", got.ToolCalls[0].Name) + } +} + +func TestGenkitProviderChatToolDeduplication(t *testing.T) { + const modelName = "mock/dedup-model" + + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("ok"), + } + + p := newTestProvider(t, modelName, resp) + + tools := []provider.ToolDef{{ + Name: "my-tool", + Description: "A tool", + Parameters: map[string]any{"type": "object"}, + }} + + msgs := []provider.Message{{Role: provider.RoleUser, Content: "call 1"}} + + // First call - defines the tool + _, err := p.Chat(context.Background(), msgs, 
tools) + if err != nil { + t.Fatalf("first Chat error: %v", err) + } + + // Second call - must not panic due to duplicate tool registration + msgs2 := []provider.Message{{Role: provider.RoleUser, Content: "call 2"}} + _, err = p.Chat(context.Background(), msgs2, tools) + if err != nil { + t.Fatalf("second Chat error: %v", err) + } +} diff --git a/genkit/convert.go b/genkit/convert.go new file mode 100644 index 0000000..59e8a8c --- /dev/null +++ b/genkit/convert.go @@ -0,0 +1,144 @@ +package genkit + +import ( + "encoding/json" + + "github.com/GoCodeAlone/workflow-plugin-agent/provider" + "github.com/firebase/genkit/go/ai" + "github.com/google/uuid" +) + +// toGenkitMessages converts our messages to Genkit messages. +func toGenkitMessages(msgs []provider.Message) []*ai.Message { + out := make([]*ai.Message, 0, len(msgs)) + for _, m := range msgs { + var role ai.Role + switch m.Role { + case provider.RoleSystem: + role = ai.RoleSystem + case provider.RoleUser: + role = ai.RoleUser + case provider.RoleAssistant: + role = ai.RoleModel + case provider.RoleTool: + role = ai.RoleTool + default: + role = ai.RoleUser + } + + var parts []*ai.Part + + // Tool call results: add as ToolResponsePart. + // Try to JSON-decode the content to avoid double-wrapping structured results. + if m.ToolCallID != "" { + var output any + if err := json.Unmarshal([]byte(m.Content), &output); err != nil { + // Not valid JSON — wrap as string. + output = map[string]any{"result": m.Content} + } + parts = []*ai.Part{ai.NewToolResponsePart(&ai.ToolResponse{ + Name: m.ToolCallID, + Output: output, + })} + } else if len(m.ToolCalls) > 0 { + // Assistant message with tool calls. + // Preserve text content alongside tool requests (the executor records + // assistant text + tool calls together). 
+			if m.Content != "" {
+				parts = append(parts, ai.NewTextPart(m.Content))
+			}
+			for _, tc := range m.ToolCalls {
+				// Use tc.ID as the ToolRequest name so tool responses can be
+				// correlated back to the correct request.
+				reqName := tc.Name
+				if tc.ID != "" {
+					reqName = tc.ID
+				}
+				parts = append(parts, ai.NewToolRequestPart(&ai.ToolRequest{
+					Name:  reqName,
+					Input: tc.Arguments,
+				}))
+			}
+		} else {
+			parts = []*ai.Part{ai.NewTextPart(m.Content)}
+		}
+
+		out = append(out, ai.NewMessage(role, nil, parts...))
+	}
+	return out
+}
+
+// fromGenkitResponse converts a Genkit response to our Response type.
+func fromGenkitResponse(resp *ai.ModelResponse) *provider.Response {
+	if resp == nil {
+		return &provider.Response{}
+	}
+
+	out := &provider.Response{
+		Content: resp.Text(),
+	}
+
+	// Concatenate all reasoning parts into the thinking trace; some
+	// providers emit the thinking as several reasoning blocks, so
+	// keeping only the first would drop content.
+	if msg := resp.Message; msg != nil {
+		for _, part := range msg.Content {
+			if part.IsReasoning() {
+				out.Thinking += part.Text
+			}
+		}
+	}
+
+	// Extract tool calls
+	if msg := resp.Message; msg != nil {
+		for _, part := range msg.Content {
+			if part.ToolRequest != nil {
+				tc := provider.ToolCall{
+					ID:        uuid.New().String(),
+					Name:      part.ToolRequest.Name,
+					Arguments: make(map[string]any),
+				}
+				if input, ok := part.ToolRequest.Input.(map[string]any); ok {
+					tc.Arguments = input
+				}
+				out.ToolCalls = append(out.ToolCalls, tc)
+			}
+		}
+	}
+
+	// Extract usage
+	if resp.Usage != nil {
+		out.Usage = provider.Usage{
+			InputTokens:  resp.Usage.InputTokens,
+			OutputTokens: resp.Usage.OutputTokens,
+		}
+	}
+
+	return out
+}
+
+// fromGenkitChunk converts a Genkit stream chunk to our StreamEvent. 
+func fromGenkitChunk(chunk *ai.ModelResponseChunk) provider.StreamEvent { + if chunk == nil { + return provider.StreamEvent{Type: "done"} + } + + // Check for thinking/reasoning parts first + for _, part := range chunk.Content { + if part.IsReasoning() { + return provider.StreamEvent{Type: "thinking", Thinking: part.Text} + } + } + + // Check for text + text := chunk.Text() + if text != "" { + return provider.StreamEvent{Type: "text", Text: text} + } + + // Note: tool_call events from chunks are NOT emitted here because Genkit + // provides the complete tool call list in the final Done response. Emitting + // from both chunks and Done would produce duplicate events with unstable IDs. + // The adapter's Stream() method emits tool_call events from the final response only. + + return provider.StreamEvent{Type: "text", Text: ""} +} diff --git a/genkit/convert_test.go b/genkit/convert_test.go new file mode 100644 index 0000000..379bf8e --- /dev/null +++ b/genkit/convert_test.go @@ -0,0 +1,114 @@ +package genkit + +import ( + "testing" + + "github.com/GoCodeAlone/workflow-plugin-agent/provider" + "github.com/firebase/genkit/go/ai" +) + +func TestToGenkitMessages(t *testing.T) { + msgs := []provider.Message{ + {Role: provider.RoleSystem, Content: "You are helpful."}, + {Role: provider.RoleUser, Content: "Hello"}, + {Role: provider.RoleAssistant, Content: "Hi there"}, + } + result := toGenkitMessages(msgs) + if len(result) != 3 { + t.Fatalf("expected 3 messages, got %d", len(result)) + } + if result[0].Role != ai.RoleSystem { + t.Errorf("expected system role, got %s", result[0].Role) + } + if result[2].Role != ai.RoleModel { + t.Errorf("expected model role for assistant, got %s", result[2].Role) + } +} + +func TestToGenkitMessagesToolResult(t *testing.T) { + msgs := []provider.Message{ + {Role: provider.RoleTool, Content: "42", ToolCallID: "add"}, + } + result := toGenkitMessages(msgs) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d", len(result)) + } + if 
result[0].Role != ai.RoleTool { + t.Errorf("expected tool role, got %s", result[0].Role) + } + if len(result[0].Content) == 0 || result[0].Content[0].ToolResponse == nil { + t.Error("expected ToolResponsePart in message content") + } +} + +func TestFromGenkitResponse(t *testing.T) { + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("Hello world"), + Usage: &ai.GenerationUsage{InputTokens: 10, OutputTokens: 5}, + } + result := fromGenkitResponse(resp) + if result.Content != "Hello world" { + t.Errorf("expected 'Hello world', got %q", result.Content) + } + if result.Usage.InputTokens != 10 { + t.Errorf("expected 10 input tokens, got %d", result.Usage.InputTokens) + } + if result.Usage.OutputTokens != 5 { + t.Errorf("expected 5 output tokens, got %d", result.Usage.OutputTokens) + } +} + +func TestFromGenkitResponseNil(t *testing.T) { + result := fromGenkitResponse(nil) + if result.Content != "" { + t.Errorf("expected empty content, got %q", result.Content) + } +} + +func TestFromGenkitResponseThinking(t *testing.T) { + msg := ai.NewMessage(ai.RoleModel, nil, + ai.NewReasoningPart("I think therefore I am", nil), + ai.NewTextPart("Result text"), + ) + resp := &ai.ModelResponse{Message: msg} + result := fromGenkitResponse(resp) + if result.Thinking != "I think therefore I am" { + t.Errorf("expected thinking trace, got %q", result.Thinking) + } + if result.Content != "Result text" { + t.Errorf("expected 'Result text', got %q", result.Content) + } +} + +func TestFromGenkitChunkText(t *testing.T) { + chunk := &ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewTextPart("hello")}, + } + ev := fromGenkitChunk(chunk) + if ev.Type != "text" { + t.Errorf("expected type 'text', got %q", ev.Type) + } + if ev.Text != "hello" { + t.Errorf("expected 'hello', got %q", ev.Text) + } +} + +func TestFromGenkitChunkThinking(t *testing.T) { + chunk := &ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewReasoningPart("thinking...", nil)}, + } + ev := fromGenkitChunk(chunk) + if 
ev.Type != "thinking" { + t.Errorf("expected type 'thinking', got %q", ev.Type) + } + if ev.Thinking != "thinking..." { + t.Errorf("expected 'thinking...', got %q", ev.Thinking) + } +} + +func TestFromGenkitChunkNil(t *testing.T) { + ev := fromGenkitChunk(nil) + if ev.Type != "done" { + t.Errorf("expected type 'done', got %q", ev.Type) + } +} diff --git a/genkit/genkit.go b/genkit/genkit.go new file mode 100644 index 0000000..2f670dc --- /dev/null +++ b/genkit/genkit.go @@ -0,0 +1,24 @@ +// Package genkit provides Genkit-backed implementations of provider.Provider. +package genkit + +import ( + "context" + "sync" + + gk "github.com/firebase/genkit/go/genkit" +) + +var ( + instance *gk.Genkit + once sync.Once +) + +// Instance returns the shared Genkit instance, initializing it lazily on first call. +// This instance has no plugins and is used for mock/test model definitions. +// Production providers use per-factory Genkit instances initialized with their specific plugin. +func Instance(ctx context.Context) *gk.Genkit { + once.Do(func() { + instance = gk.Init(ctx) + }) + return instance +} diff --git a/genkit/integration_test.go b/genkit/integration_test.go new file mode 100644 index 0000000..3513d25 --- /dev/null +++ b/genkit/integration_test.go @@ -0,0 +1,263 @@ +//go:build integration + +package genkit_test + +import ( + "context" + "os" + "strings" + "testing" + + "github.com/GoCodeAlone/workflow-plugin-agent/genkit" + "github.com/GoCodeAlone/workflow-plugin-agent/provider" +) + +// envOrSkip returns the value of an environment variable, or skips the test if unset. +func envOrSkip(t *testing.T, key string) string { + t.Helper() + v := os.Getenv(key) + if v == "" { + t.Skipf("skipping: %s not set", key) + } + return v +} + +// chatRoundTrip sends a simple message and asserts a non-empty text response. 
+func chatRoundTrip(t *testing.T, p provider.Provider) { + t.Helper() + ctx := context.Background() + msgs := []provider.Message{{Role: provider.RoleUser, Content: "Say exactly: pong"}} + resp, err := p.Chat(ctx, msgs, nil) + if err != nil { + t.Fatalf("Chat error: %v", err) + } + if resp.Content == "" { + t.Error("expected non-empty content") + } +} + +// streamRoundTrip streams a simple message and asserts at least one text event and a done event. +func streamRoundTrip(t *testing.T, p provider.Provider) { + t.Helper() + ctx := context.Background() + msgs := []provider.Message{{Role: provider.RoleUser, Content: "Say exactly: pong"}} + ch, err := p.Stream(ctx, msgs, nil) + if err != nil { + t.Fatalf("Stream error: %v", err) + } + var gotText, gotDone bool + for ev := range ch { + switch ev.Type { + case "text": + if ev.Text != "" { + gotText = true + } + case "done": + gotDone = true + case "error": + t.Fatalf("stream error event: %s", ev.Error) + } + } + if !gotText { + t.Error("expected at least one text event") + } + if !gotDone { + t.Error("expected a done event") + } +} + +// TestIntegration_Anthropic tests the Anthropic (direct API) provider. +func TestIntegration_Anthropic(t *testing.T) { + apiKey := envOrSkip(t, "ANTHROPIC_API_KEY") + p, err := genkit.NewAnthropicProvider(context.Background(), apiKey, "claude-haiku-4-5-20251001", "", 256) + if err != nil { + t.Fatalf("NewAnthropicProvider: %v", err) + } + if p.Name() != "anthropic" { + t.Errorf("expected name 'anthropic', got %q", p.Name()) + } + chatRoundTrip(t, p) + streamRoundTrip(t, p) +} + +// TestIntegration_OpenAI tests the OpenAI provider. 
+func TestIntegration_OpenAI(t *testing.T) { + apiKey := envOrSkip(t, "OPENAI_API_KEY") + p, err := genkit.NewOpenAIProvider(context.Background(), apiKey, "gpt-4o-mini", "", 256) + if err != nil { + t.Fatalf("NewOpenAIProvider: %v", err) + } + if p.Name() != "openai" { + t.Errorf("expected name 'openai', got %q", p.Name()) + } + chatRoundTrip(t, p) + streamRoundTrip(t, p) +} + +// TestIntegration_GoogleAI tests the Google AI (Gemini API) provider. +func TestIntegration_GoogleAI(t *testing.T) { + apiKey := envOrSkip(t, "GOOGLE_AI_API_KEY") + p, err := genkit.NewGoogleAIProvider(context.Background(), apiKey, "gemini-2.0-flash", 256) + if err != nil { + t.Fatalf("NewGoogleAIProvider: %v", err) + } + if p.Name() != "googleai" { + t.Errorf("expected name 'googleai', got %q", p.Name()) + } + chatRoundTrip(t, p) + streamRoundTrip(t, p) +} + +// TestIntegration_Ollama tests the Ollama local provider. +func TestIntegration_Ollama(t *testing.T) { + serverAddr := os.Getenv("OLLAMA_SERVER") + if serverAddr == "" { + serverAddr = "http://localhost:11434" + } + model := os.Getenv("OLLAMA_MODEL") + if model == "" { + t.Skip("skipping: OLLAMA_MODEL not set") + } + p, err := genkit.NewOllamaProvider(context.Background(), model, serverAddr, 256) + if err != nil { + t.Fatalf("NewOllamaProvider: %v", err) + } + if p.Name() != "ollama" { + t.Errorf("expected name 'ollama', got %q", p.Name()) + } + chatRoundTrip(t, p) + streamRoundTrip(t, p) +} + +// TestIntegration_OpenRouter tests the OpenRouter provider (OpenAI-compatible). 
+func TestIntegration_OpenRouter(t *testing.T) { + apiKey := envOrSkip(t, "OPENROUTER_API_KEY") + model := os.Getenv("OPENROUTER_MODEL") + if model == "" { + model = "openai/gpt-4o-mini" + } + p, err := genkit.NewOpenAICompatibleProvider( + context.Background(), "openrouter", apiKey, model, + "https://openrouter.ai/api/v1", 256, + ) + if err != nil { + t.Fatalf("NewOpenAICompatibleProvider: %v", err) + } + chatRoundTrip(t, p) +} + +// TestIntegration_Bedrock tests the AWS Bedrock Anthropic provider. +func TestIntegration_Bedrock(t *testing.T) { + accessKey := envOrSkip(t, "AWS_ACCESS_KEY_ID") + secretKey := envOrSkip(t, "AWS_SECRET_ACCESS_KEY") + region := os.Getenv("AWS_REGION") + if region == "" { + region = "us-east-1" + } + p, err := genkit.NewBedrockProvider( + context.Background(), + region, "anthropic.claude-haiku-4-20250514-v1:0", + accessKey, secretKey, "", "", 256, + ) + if err != nil { + t.Fatalf("NewBedrockProvider: %v", err) + } + chatRoundTrip(t, p) +} + +// TestIntegration_VertexAI tests the Google Vertex AI provider. +func TestIntegration_VertexAI(t *testing.T) { + projectID := envOrSkip(t, "VERTEX_PROJECT_ID") + region := os.Getenv("VERTEX_REGION") + if region == "" { + region = "us-central1" + } + p, err := genkit.NewVertexAIProvider( + context.Background(), + projectID, region, "gemini-2.0-flash", "", 256, + ) + if err != nil { + t.Fatalf("NewVertexAIProvider: %v", err) + } + chatRoundTrip(t, p) +} + +// TestIntegration_AzureOpenAI tests the Azure OpenAI provider. 
+func TestIntegration_AzureOpenAI(t *testing.T) { + resource := envOrSkip(t, "AZURE_OPENAI_RESOURCE") + deploymentName := envOrSkip(t, "AZURE_OPENAI_DEPLOYMENT") + apiKey := envOrSkip(t, "AZURE_OPENAI_API_KEY") + p, err := genkit.NewAzureOpenAIProvider( + context.Background(), + resource, deploymentName, "2024-10-21", apiKey, "", 256, + ) + if err != nil { + t.Fatalf("NewAzureOpenAIProvider: %v", err) + } + chatRoundTrip(t, p) +} + +// TestIntegration_AnthropicFoundry tests the Anthropic on Azure AI Foundry provider. +func TestIntegration_AnthropicFoundry(t *testing.T) { + resource := envOrSkip(t, "FOUNDRY_RESOURCE") + apiKey := envOrSkip(t, "FOUNDRY_API_KEY") + model := os.Getenv("FOUNDRY_MODEL") + if model == "" { + model = "claude-haiku-4-20250514" + } + p, err := genkit.NewAnthropicFoundryProvider( + context.Background(), resource, model, apiKey, "", 256, + ) + if err != nil { + t.Fatalf("NewAnthropicFoundryProvider: %v", err) + } + chatRoundTrip(t, p) +} + +// TestIntegration_ProviderInterface verifies that all concrete providers satisfy provider.Provider. +func TestIntegration_ProviderInterface(t *testing.T) { + apiKey := envOrSkip(t, "ANTHROPIC_API_KEY") + p, err := genkit.NewAnthropicProvider(context.Background(), apiKey, "claude-haiku-4-5-20251001", "", 256) + if err != nil { + t.Fatalf("NewAnthropicProvider: %v", err) + } + var _ provider.Provider = p + if p.AuthModeInfo().Mode == "" { + t.Error("expected non-empty AuthModeInfo.Mode") + } +} + +// TestIntegration_StreamThinkingTrace verifies thinking traces propagate via streaming. +// Uses a model that supports extended thinking (claude-sonnet or claude-3-7-sonnet). 
+func TestIntegration_StreamThinkingTrace(t *testing.T) { + apiKey := envOrSkip(t, "ANTHROPIC_API_KEY") + // claude-3-7-sonnet-20250219 supports extended thinking + p, err := genkit.NewAnthropicProvider(context.Background(), apiKey, "claude-sonnet-4-20250514", "", 2048) + if err != nil { + t.Fatalf("NewAnthropicProvider: %v", err) + } + ctx := context.Background() + msgs := []provider.Message{{Role: provider.RoleUser, Content: "What is 2+2? Think step by step."}} + ch, err := p.Stream(ctx, msgs, nil) + if err != nil { + t.Fatalf("Stream error: %v", err) + } + var textBuf strings.Builder + var gotDone bool + for ev := range ch { + switch ev.Type { + case "text": + textBuf.WriteString(ev.Text) + case "done": + gotDone = true + case "error": + t.Fatalf("stream error: %s", ev.Error) + } + } + if !gotDone { + t.Error("expected done event") + } + if textBuf.Len() == 0 { + t.Error("expected non-empty text response") + } +} diff --git a/genkit/providers.go b/genkit/providers.go new file mode 100644 index 0000000..f7da8a6 --- /dev/null +++ b/genkit/providers.go @@ -0,0 +1,327 @@ +package genkit + +import ( + "context" + "fmt" + "os" + "sync" + + "github.com/GoCodeAlone/workflow-plugin-agent/provider" + gk "github.com/firebase/genkit/go/genkit" + anthropicPlugin "github.com/firebase/genkit/go/plugins/anthropic" + "github.com/firebase/genkit/go/plugins/compat_oai" + openaiPlugin "github.com/firebase/genkit/go/plugins/compat_oai/openai" + "github.com/firebase/genkit/go/plugins/googlegenai" + ollamaPlugin "github.com/firebase/genkit/go/plugins/ollama" + "github.com/openai/openai-go/option" +) + +// Default models per provider when none specified. +const ( + defaultAnthropicModel = "claude-sonnet-4-6" + defaultOpenAIModel = "gpt-4.1" + defaultGeminiModel = "gemini-2.5-flash" + defaultOllamaModel = "qwen3:8b" +) + +// vertexCredsMu guards the GOOGLE_APPLICATION_CREDENTIALS env var +// to prevent races when multiple VertexAI providers initialize concurrently. 
+var vertexCredsMu sync.Mutex + + +// initGenkitWithPlugin creates a Genkit instance with a single plugin registered. +func initGenkitWithPlugin(ctx context.Context, plugin gk.GenkitOption) *gk.Genkit { + return gk.Init(ctx, plugin) +} + +// NewAnthropicProvider creates a provider backed by Genkit's Anthropic plugin. +func NewAnthropicProvider(ctx context.Context, apiKey, model, baseURL string, maxTokens int) (provider.Provider, error) { + if apiKey == "" { + return nil, fmt.Errorf("anthropic: APIKey is required") + } + if model == "" { + model = defaultAnthropicModel + } + if err := provider.ValidateBaseURL(baseURL); err != nil { + return nil, fmt.Errorf("anthropic: %w", err) + } + p := &anthropicPlugin.Anthropic{APIKey: apiKey, BaseURL: baseURL} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "anthropic/" + model, + name: "anthropic", + maxTokens: maxTokens, + authInfo: provider.AuthModeInfo{ + Mode: "api_key", + DisplayName: "Anthropic", + ServerSafe: true, + }, + }, nil +} + +// NewOpenAIProvider creates a provider backed by Genkit's OpenAI plugin. 
+func NewOpenAIProvider(ctx context.Context, apiKey, model, baseURL string, maxTokens int) (provider.Provider, error) { + if apiKey == "" { + return nil, fmt.Errorf("openai: APIKey is required") + } + if model == "" { + model = defaultOpenAIModel + } + if err := provider.ValidateBaseURL(baseURL); err != nil { + return nil, fmt.Errorf("openai: %w", err) + } + var extraOpts []option.RequestOption + if baseURL != "" { + extraOpts = append(extraOpts, option.WithBaseURL(baseURL)) + } + p := &openaiPlugin.OpenAI{APIKey: apiKey, Opts: extraOpts} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "openai/" + model, + name: "openai", + maxTokens: maxTokens, + authInfo: provider.AuthModeInfo{ + Mode: "api_key", + DisplayName: "OpenAI", + ServerSafe: true, + }, + }, nil +} + +// NewGoogleAIProvider creates a provider backed by Genkit's Google AI plugin (Gemini API). +func NewGoogleAIProvider(ctx context.Context, apiKey, model string, maxTokens int) (provider.Provider, error) { + if apiKey == "" { + return nil, fmt.Errorf("googleai: APIKey is required") + } + if model == "" { + model = defaultGeminiModel + } + p := &googlegenai.GoogleAI{APIKey: apiKey} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "googleai/" + model, + name: "googleai", + maxTokens: maxTokens, + authInfo: provider.AuthModeInfo{ + Mode: "api_key", + DisplayName: "Google AI (Gemini)", + ServerSafe: true, + }, + }, nil +} + +// NewOllamaProvider creates a provider backed by Genkit's Ollama plugin. 
+func NewOllamaProvider(ctx context.Context, model, serverAddress string, maxTokens int) (provider.Provider, error) { + if serverAddress == "" { + serverAddress = "http://localhost:11434" + } + if model == "" { + model = defaultOllamaModel + } + p := &ollamaPlugin.Ollama{ServerAddress: serverAddress} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "ollama/" + model, + name: "ollama", + maxTokens: maxTokens, + authInfo: provider.AuthModeInfo{ + Mode: "none", + DisplayName: "Ollama (local)", + ServerSafe: true, + }, + }, nil +} + +// NewOpenAICompatibleProvider creates a provider for OpenAI-compatible endpoints. +// Used for OpenRouter, Copilot, Cohere, HuggingFace, llama.cpp, etc. +func NewOpenAICompatibleProvider(ctx context.Context, providerName, apiKey, model, baseURL string, maxTokens int) (provider.Provider, error) { + if model == "" { + model = defaultOpenAIModel + } + // Skip SSRF validation for local providers that use localhost endpoints. + switch providerName { + case "llama_cpp", "local", "test": + // Local providers intentionally use http://localhost — skip SSRF checks. + default: + if err := provider.ValidateBaseURL(baseURL); err != nil { + return nil, fmt.Errorf("%s: %w", providerName, err) + } + } + effectiveKey := apiKey + if effectiveKey == "" { + // Use a placeholder to avoid errors when no key is needed (e.g., local endpoints) + effectiveKey = "no-key" + } + p := &compat_oai.OpenAICompatible{ + Provider: providerName, + APIKey: effectiveKey, + BaseURL: baseURL, + } + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: providerName + "/" + model, + name: providerName, + maxTokens: maxTokens, + authInfo: provider.AuthModeInfo{ + Mode: "api_key", + DisplayName: providerName, + ServerSafe: true, + }, + }, nil +} + +// NewAzureOpenAIProvider creates a provider for Azure OpenAI Service. 
+func NewAzureOpenAIProvider(ctx context.Context, resource, deploymentName, apiVersion, apiKey, entraToken string, maxTokens int) (provider.Provider, error) { + if resource == "" { + return nil, fmt.Errorf("openai_azure: resource is required") + } + if apiKey == "" && entraToken == "" { + return nil, fmt.Errorf("openai_azure: apiKey or entraToken is required") + } + if apiVersion == "" { + apiVersion = "2024-10-21" + } + + baseURL := fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s", resource, deploymentName) + + opts := []option.RequestOption{ + option.WithBaseURL(baseURL), + option.WithHeaderDel("authorization"), + option.WithQuery("api-version", apiVersion), + } + if apiKey != "" { + opts = append(opts, option.WithHeader("api-key", apiKey)) + } else { + opts = append(opts, option.WithHeader("authorization", "Bearer "+entraToken)) + } + + // Use a placeholder API key to avoid requiring OPENAI_API_KEY env var + p := &openaiPlugin.OpenAI{APIKey: "azure-placeholder", Opts: opts} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "openai/" + deploymentName, + name: "openai_azure", + maxTokens: maxTokens, + authInfo: provider.AuthModeInfo{ + Mode: "azure", + DisplayName: "OpenAI (Azure OpenAI Service)", + ServerSafe: true, + }, + }, nil +} + +// NewAnthropicFoundryProvider creates a provider for Anthropic on Azure AI Foundry. 
+func NewAnthropicFoundryProvider(ctx context.Context, resource, model, apiKey, entraToken string, maxTokens int) (provider.Provider, error) { + if resource == "" { + return nil, fmt.Errorf("anthropic_foundry: resource is required") + } + effectiveKey := apiKey + if effectiveKey == "" { + effectiveKey = entraToken + } + if effectiveKey == "" { + return nil, fmt.Errorf("anthropic_foundry: apiKey or entraToken is required") + } + + // Azure AI Foundry Anthropic endpoints + baseURL := fmt.Sprintf("https://%s.services.ai.azure.com/models", resource) + p := &anthropicPlugin.Anthropic{APIKey: effectiveKey, BaseURL: baseURL} + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "anthropic/" + model, + name: "anthropic_foundry", + maxTokens: maxTokens, + authInfo: provider.AuthModeInfo{ + Mode: "azure", + DisplayName: "Anthropic (Azure AI Foundry)", + ServerSafe: true, + }, + }, nil +} + +// NewVertexAIProvider creates a provider backed by Genkit's Vertex AI plugin. +func NewVertexAIProvider(ctx context.Context, projectID, region, model, credentialsJSON string, maxTokens int) (provider.Provider, error) { + if projectID == "" { + return nil, fmt.Errorf("vertexai: projectID is required") + } + if region == "" { + region = "us-central1" + } + + // Genkit's VertexAI plugin uses credentials.DetectDefault() which reads + // GOOGLE_APPLICATION_CREDENTIALS. When inline JSON is provided, write it + // to a temp file, set the env var, init Genkit, then clean up. 
+ var tempCredFile string + if credentialsJSON != "" { + vertexCredsMu.Lock() + defer vertexCredsMu.Unlock() + + prevCreds := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") + + f, err := os.CreateTemp("", "vertexai-creds-*.json") + if err != nil { + return nil, fmt.Errorf("vertexai: create temp credentials file: %w", err) + } + tempCredFile = f.Name() + if _, err := f.WriteString(credentialsJSON); err != nil { + _ = f.Close() + _ = os.Remove(tempCredFile) + return nil, fmt.Errorf("vertexai: write credentials: %w", err) + } + _ = f.Close() + _ = os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", tempCredFile) + defer func() { + // Restore previous env var and remove temp file. + if prevCreds == "" { + _ = os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS") + } else { + _ = os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", prevCreds) + } + _ = os.Remove(tempCredFile) + }() + } + + p := &googlegenai.VertexAI{ + ProjectID: projectID, + Location: region, + } + g := initGenkitWithPlugin(ctx, gk.WithPlugins(p)) + return &genkitProvider{ + g: g, + modelName: "vertexai/" + model, + name: "vertexai", + maxTokens: maxTokens, + authInfo: provider.AuthModeInfo{ + Mode: "gcp", + DisplayName: "Vertex AI", + ServerSafe: true, + }, + }, nil +} + +// NewBedrockProvider creates a provider for AWS Bedrock using an OpenAI-compatible endpoint. +// Wraps existing Bedrock implementation as a provider.Provider until a native Genkit plugin is available. 
+func NewBedrockProvider(ctx context.Context, region, model, accessKeyID, secretAccessKey, sessionToken, baseURL string, maxTokens int) (provider.Provider, error) { + if secretAccessKey == "" { + return nil, fmt.Errorf("anthropic_bedrock: secretAccessKey is required") + } + if region == "" { + region = "us-east-1" + } + return provider.NewAnthropicBedrockProvider(provider.AnthropicBedrockConfig{ + Region: region, + Model: model, + MaxTokens: maxTokens, + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + SessionToken: sessionToken, + BaseURL: baseURL, + }) +} diff --git a/genkit/providers_test.go b/genkit/providers_test.go new file mode 100644 index 0000000..ea82f23 --- /dev/null +++ b/genkit/providers_test.go @@ -0,0 +1,90 @@ +package genkit + +import ( + "context" + "testing" +) + +func TestNewAnthropicProvider_MissingKey(t *testing.T) { + _, err := NewAnthropicProvider(context.Background(), "", "claude-sonnet-4-6", "", 4096) + if err == nil { + t.Error("expected error for missing API key") + } +} + +func TestNewOpenAIProvider_MissingKey(t *testing.T) { + _, err := NewOpenAIProvider(context.Background(), "", "gpt-4o", "", 4096) + if err == nil { + t.Error("expected error for missing API key") + } +} + +func TestNewGoogleAIProvider_MissingKey(t *testing.T) { + _, err := NewGoogleAIProvider(context.Background(), "", "gemini-2.0-flash", 4096) + if err == nil { + t.Error("expected error for missing API key") + } +} + +func TestNewOllamaProvider_DefaultAddress(t *testing.T) { + // Ollama doesn't require an API key; verify factory instantiation works + // with default address (no real server needed for creation) + p, err := NewOllamaProvider(context.Background(), "qwen3:8b", "", 4096) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if p == nil { + t.Fatal("expected non-nil provider") + } + if p.Name() != "ollama" { + t.Errorf("expected name 'ollama', got %q", p.Name()) + } +} + +func TestNewOpenAICompatibleProvider_NoKey(t *testing.T) { + 
// Local providers may not need a key + p, err := NewOpenAICompatibleProvider(context.Background(), "local", "", "model", "http://localhost:8080", 4096) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if p == nil { + t.Fatal("expected non-nil provider") + } +} + +func TestNewAzureOpenAIProvider_MissingResource(t *testing.T) { + _, err := NewAzureOpenAIProvider(context.Background(), "", "gpt-4o", "2024-10-21", "key", "", 4096) + if err == nil { + t.Error("expected error for missing resource") + } +} + +func TestNewAzureOpenAIProvider_MissingCredentials(t *testing.T) { + _, err := NewAzureOpenAIProvider(context.Background(), "myresource", "gpt-4o", "2024-10-21", "", "", 4096) + if err == nil { + t.Error("expected error for missing credentials") + } +} + +func TestNewBedrockProvider_MissingKey(t *testing.T) { + _, err := NewBedrockProvider(context.Background(), "us-east-1", "anthropic.claude-sonnet-4", "", "", "", "", 4096) + if err == nil { + t.Error("expected error for missing secret key") + } +} + +func TestNewVertexAIProvider_MissingProject(t *testing.T) { + _, err := NewVertexAIProvider(context.Background(), "", "us-central1", "gemini-2.0-flash", "", 4096) + if err == nil { + t.Error("expected error for missing project ID") + } +} + +func TestProviderImplementsInterface(t *testing.T) { + // Ensure all returned providers implement provider.Provider + p, err := NewOllamaProvider(context.Background(), "test", "http://localhost:11434", 4096) + if err != nil { + t.Skip("factory failed, skipping interface check") + } + _ = p // already provider.Provider; compile verifies interface +} diff --git a/go.mod b/go.mod index 0ca6d67..a9705be 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/aws/aws-sdk-go-v2 v1.41.4 github.com/aws/aws-sdk-go-v2/credentials v1.19.12 github.com/docker/docker v28.5.2+incompatible + github.com/firebase/genkit/go v1.6.0 github.com/go-rod/rod v0.116.2 github.com/golang-jwt/jwt/v5 v5.3.1 
github.com/google/generative-ai-go v0.20.1 @@ -18,7 +19,7 @@ require ( github.com/ollama/ollama v0.18.3 github.com/openai/openai-go v1.12.0 golang.org/x/crypto v0.48.0 - golang.org/x/oauth2 v0.36.0 + golang.org/x/sync v0.20.0 golang.org/x/text v0.34.0 google.golang.org/api v0.271.0 gopkg.in/yaml.v3 v3.0.1 @@ -119,8 +120,11 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-sql-driver/mysql v1.7.1 // indirect github.com/gobwas/glob v0.2.3 // indirect + github.com/goccy/go-yaml v1.17.1 // indirect github.com/golobby/cast v1.3.3 // indirect github.com/google/btree v1.1.3 // indirect + github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect @@ -145,6 +149,7 @@ require ( github.com/hashicorp/memberlist v0.5.4 // indirect github.com/hashicorp/vault/api v1.22.0 // indirect github.com/hashicorp/yamux v0.1.2 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/itchyny/gojq v0.12.18 // indirect github.com/itchyny/timefmt-go v0.1.7 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -161,9 +166,10 @@ require ( github.com/json-iterator/go v1.1.13-0.20220915233716-71ac16282d12 // indirect github.com/klauspost/compress v1.18.4 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect - github.com/mailru/easyjson v0.7.7 // indirect + github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a // indirect github.com/miekg/dns v1.1.72 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/go-testing-interface v1.14.1 // indirect @@ -211,6 +217,10 @@ require ( github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram 
v1.2.0 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + github.com/xeipuuv/gojsonschema v1.2.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/ysmood/fetchup v0.2.3 // indirect github.com/ysmood/goob v0.4.0 // indirect github.com/ysmood/got v0.40.0 // indirect @@ -237,10 +247,11 @@ require ( golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect golang.org/x/mod v0.33.0 // indirect golang.org/x/net v0.51.0 // indirect - golang.org/x/sync v0.20.0 // indirect + golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/time v0.15.0 // indirect golang.org/x/tools v0.42.0 // indirect + google.golang.org/genai v1.41.0 // indirect google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect diff --git a/go.sum b/go.sum index 87fb36e..f6f5552 100644 --- a/go.sum +++ b/go.sum @@ -256,6 +256,8 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/firebase/genkit/go v1.6.0 h1:W+oTo/yndFkSe/imGEmsYdW5Qi5w2USG7obiXElAQzY= +github.com/firebase/genkit/go v1.6.0/go.mod h1:vu8ZAqNU6MU5qDza66bvqTtzJoUrqhO/+z5/6dtouJQ= github.com/flowchartsman/retry v1.2.0 h1:qDhlw6RNufXz6RGr+IiYimFpMMkt77SUSHY5tgFaUCU= github.com/flowchartsman/retry v1.2.0/go.mod h1:+sfx8OgCCiAr3t5jh2Gk+T0fRTI+k52edaYxURQxY64= github.com/fortytw2/leaktest v1.3.0 
h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= @@ -321,6 +323,8 @@ github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPE github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY= +github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/gofrs/uuid v4.3.1+incompatible h1:0/KbAdpx3UXAx1kEOWHJeOkpbgRFGHVgv+CFIY7dBJI= github.com/gofrs/uuid v4.3.1+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -347,6 +351,8 @@ github.com/golobby/cast v1.3.3 h1:s2Lawb9RMz7YyYf8IrfMQY4IFmA1R/lgfmj97Vc6fig= github.com/golobby/cast v1.3.3/go.mod h1:0oDO5IT84HTXcbLDf1YXuk0xtg/cRDrxhbpWKxwtJCY= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 h1:okN800+zMJOGHLJCgry+OGzhhtH6YrjQh1rluHmOacE= +github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254/go.mod h1:k8cjJAQWc//ac/bMnzItyOFbfT01tgRTZGgxELCuxEQ= github.com/google/generative-ai-go v0.20.1 h1:6dEIujpgN2V0PgLhr6c/M1ynRdc7ARtiIDPFzj45uNQ= github.com/google/generative-ai-go v0.20.1/go.mod h1:TjOnZJmZKzarWbjUJgy+r3Ee7HGBRVLhOIgupnwR4Bg= github.com/google/gnostic-models v0.7.1 h1:SisTfuFKJSKM5CPZkffwi6coztzzeYUhc3v4yxLWH8c= @@ -435,6 +441,8 @@ github.com/hashicorp/vault/api v1.22.0 h1:+HYFquE35/B74fHoIeXlZIP2YADVboaPjaSicH github.com/hashicorp/vault/api v1.22.0/go.mod h1:IUZA2cDvr4Ok3+NtK2Oq/r+lJeXkeCrHRmqdyWfpmGM= github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= 
github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/itchyny/gojq v0.12.18 h1:gFGHyt/MLbG9n6dqnvlliiya2TaMMh6FFaR2b1H6Drc= github.com/itchyny/gojq v0.12.18/go.mod h1:4hPoZ/3lN9fDL1D+aK7DY1f39XZpY9+1Xpjz8atrEkg= github.com/itchyny/timefmt-go v0.1.7 h1:xyftit9Tbw+Dc/huSSPJaEmX1TVL8lw5vxjJLK4GMMA= @@ -465,7 +473,6 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -498,8 +505,8 @@ github.com/lufia/plan9stats v0.0.0-20260216142805-b3301c5f2a88 h1:PTw+yKnXcOFCR6 github.com/lufia/plan9stats v0.0.0-20260216142805-b3301c5f2a88/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= -github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= 
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -509,6 +516,8 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a h1:v2cBA3xWKv2cIOVhnzX/gNgkNXqiHfUgJtA3r61Hf7A= +github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a/go.mod h1:Y6ghKH+ZijXn5d9E7qGGZBmjitx7iitZdQiIW97EpTU= github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 h1:KGuD/pM2JpL9FAYvBrnBBeENKZNh6eNtjqytV6TYjnk= @@ -715,8 +724,17 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= +github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 
h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/ysmood/fetchup v0.2.3 h1:ulX+SonA0Vma5zUFXtv52Kzip/xe7aj4vqT5AJwQ+ZQ= github.com/ysmood/fetchup v0.2.3/go.mod h1:xhibcRKziSvol0H1/pj33dnKrYyI2ebIvz5cOOkYGns= github.com/ysmood/goob v0.4.0 h1:HsxXhyLBeGzWXnqVKtmT9qM7EuVs/XOgkX7T6r1o1AQ= @@ -896,6 +914,8 @@ gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E google.golang.org/api v0.271.0 h1:cIPN4qcUc61jlh7oXu6pwOQqbJW2GqYh5PS6rB2C/JY= google.golang.org/api v0.271.0/go.mod h1:CGT29bhwkbF+i11qkRUJb2KMKqcJ1hdFceEIRd9u64Q= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genai v1.41.0 h1:ayXl75LjTmqTu0y94yr96d17gIb4zF8gWVzX2TgioEY= +google.golang.org/genai v1.41.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 h1:VQZ/yAbAtjkHgH80teYd2em3xtIkkHd7ZhqfH2N9CsM= google.golang.org/genproto v0.0.0-20260128011058-8636f8732409/go.mod h1:rxKD3IEILWEu3P44seeNOAwZN4SaoKaQ/2eTg4mM6EM= google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 h1:tu/dtnW1o3wfaxCOjSLn5IRX4YDcJrtlpzYkhHhGaC4= diff --git a/module_provider.go b/module_provider.go index 9a78270..641696c 100644 --- a/module_provider.go +++ b/module_provider.go @@ -7,6 +7,7 @@ 
import ( "time" "github.com/GoCodeAlone/modular" + gkprov "github.com/GoCodeAlone/workflow-plugin-agent/genkit" "github.com/GoCodeAlone/workflow-plugin-agent/provider" "github.com/GoCodeAlone/workflow/plugin" ) @@ -183,43 +184,42 @@ func newProviderModuleFactory() plugin.ModuleFactory { } case "anthropic": - p = provider.NewAnthropicProvider(provider.AnthropicConfig{ - APIKey: apiKey, - Model: model, - BaseURL: baseURL, - MaxTokens: maxTokens, - }) + if prov, err := gkprov.NewAnthropicProvider(context.TODO() /* ModuleFactory doesn't receive ctx; TODO: thread ctx via Start() */, apiKey, model, baseURL, maxTokens); err != nil { + p = &errProvider{err: err} + } else { + p = prov + } case "openai": - p = provider.NewOpenAIProvider(provider.OpenAIConfig{ - APIKey: apiKey, - Model: model, - BaseURL: baseURL, - MaxTokens: maxTokens, - }) + if prov, err := gkprov.NewOpenAIProvider(context.TODO() /* ModuleFactory doesn't receive ctx; TODO: thread ctx via Start() */, apiKey, model, baseURL, maxTokens); err != nil { + p = &errProvider{err: err} + } else { + p = prov + } case "copilot": - p = provider.NewCopilotProvider(provider.CopilotConfig{ - Token: apiKey, - Model: model, - BaseURL: baseURL, - MaxTokens: maxTokens, - }) + if baseURL == "" { + baseURL = "https://api.githubcopilot.com" + } + if prov, err := gkprov.NewOpenAICompatibleProvider(context.TODO() /* ModuleFactory doesn't receive ctx; TODO: thread ctx via Start() */, "copilot", apiKey, model, baseURL, maxTokens); err != nil { + p = &errProvider{err: err} + } else { + p = prov + } case "ollama": - p = provider.NewOllamaProvider(provider.OllamaConfig{ - Model: model, - BaseURL: baseURL, - MaxTokens: maxTokens, - }) + if prov, err := gkprov.NewOllamaProvider(context.TODO() /* ModuleFactory doesn't receive ctx; TODO: thread ctx via Start() */, model, baseURL, maxTokens); err != nil { + p = &errProvider{err: err} + } else { + p = prov + } case "llama_cpp": - p = provider.NewLlamaCppProvider(provider.LlamaCppConfig{ - 
BaseURL: baseURL, - ModelPath: model, - ModelName: model, - MaxTokens: maxTokens, - }) + if prov, err := gkprov.NewOpenAICompatibleProvider(context.TODO() /* ModuleFactory doesn't receive ctx; TODO: thread ctx via Start() */, "llama_cpp", "", model, baseURL, maxTokens); err != nil { + p = &errProvider{err: err} + } else { + p = prov + } default: p = &errProvider{err: fmt.Errorf("agent.provider %q: unrecognized provider type %q (supported: mock, test, anthropic, openai, copilot, ollama, llama_cpp)", name, providerType)} diff --git a/orchestrator/plugin.go b/orchestrator/plugin.go index c6b4a3b..09d22a3 100644 --- a/orchestrator/plugin.go +++ b/orchestrator/plugin.go @@ -664,7 +664,7 @@ func testInteractionHook() plugin.WiringHook { } if regSvc, ok := app.SvcRegistry()["ratchet-provider-registry"]; ok { if registry, ok := regSvc.(*ProviderRegistry); ok { - registry.factories["test"] = func(_ string, _ LLMProviderConfig) (provider.Provider, error) { + registry.factories["test"] = func(_ context.Context, _ string, _ LLMProviderConfig) (provider.Provider, error) { return testProvider, nil } if registry.db != nil { @@ -709,7 +709,7 @@ func testInteractionHook() plugin.WiringHook { if regSvc, ok := app.SvcRegistry()["ratchet-provider-registry"]; ok { if registry, ok := regSvc.(*ProviderRegistry); ok { // Register a "test" factory that returns our pre-built test provider - registry.factories["test"] = func(_ string, _ LLMProviderConfig) (provider.Provider, error) { + registry.factories["test"] = func(_ context.Context, _ string, _ LLMProviderConfig) (provider.Provider, error) { return testProvider, nil } // Update the default provider row in the DB from "mock" to "test" diff --git a/orchestrator/provider_registry.go b/orchestrator/provider_registry.go index bc44c9e..a6df1df 100644 --- a/orchestrator/provider_registry.go +++ b/orchestrator/provider_registry.go @@ -8,8 +8,10 @@ import ( "sync" "time" + gkprov "github.com/GoCodeAlone/workflow-plugin-agent/genkit" 
"github.com/GoCodeAlone/workflow-plugin-agent/provider" "github.com/GoCodeAlone/workflow/secrets" + "golang.org/x/sync/singleflight" ) // LLMProviderConfig represents a configured LLM provider stored in the database. @@ -34,8 +36,10 @@ func (c *LLMProviderConfig) settings() map[string]string { return m } -// ProviderFactory creates a provider.Provider from an API key and config. -type ProviderFactory func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) +// ProviderFactory creates a provider.Provider from a context, API key, and config. +// The context is propagated from the caller (e.g., GetByAlias/GetDefault) to allow +// cancellation and timeout during provider initialization. +type ProviderFactory func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) // ProviderRegistry manages AI provider lifecycle: factory creation, caching, and DB lookup. type ProviderRegistry struct { @@ -44,6 +48,7 @@ type ProviderRegistry struct { secrets secrets.Provider cache map[string]provider.Provider factories map[string]ProviderFactory + sflight singleflight.Group // deduplicates concurrent cold-start creation per alias } // NewProviderRegistry creates a new ProviderRegistry with built-in factories registered. @@ -205,169 +210,141 @@ func (r *ProviderRegistry) loadConfig(ctx context.Context, alias string) (*LLMPr } // createAndCache resolves the secret, creates the provider via factory, and caches it. +// Uses singleflight to ensure only one goroutine creates a provider per alias, +// avoiding duplicate expensive Genkit init on concurrent cold starts. 
func (r *ProviderRegistry) createAndCache(ctx context.Context, alias string, cfg *LLMProviderConfig) (provider.Provider, error) { - // Resolve API key from secrets - var apiKey string - if cfg.SecretName != "" && r.secrets != nil { - var err error - apiKey, err = r.secrets.Get(ctx, cfg.SecretName) + result, err, _ := r.sflight.Do(alias, func() (any, error) { + // Re-check cache under singleflight — another caller may have populated it. + r.mu.RLock() + if p, ok := r.cache[alias]; ok { + r.mu.RUnlock() + return p, nil + } + r.mu.RUnlock() + + // Resolve API key from secrets + var apiKey string + if cfg.SecretName != "" && r.secrets != nil { + var err error + apiKey, err = r.secrets.Get(ctx, cfg.SecretName) + if err != nil { + return nil, fmt.Errorf("provider registry: resolve secret %q: %w", cfg.SecretName, err) + } + } + + // Find factory + factory, ok := r.factories[cfg.Type] + if !ok { + return nil, fmt.Errorf("provider registry: unknown provider type %q", cfg.Type) + } + + // Create provider + p, err := factory(ctx, apiKey, *cfg) if err != nil { - return nil, fmt.Errorf("provider registry: resolve secret %q: %w", cfg.SecretName, err) + return nil, fmt.Errorf("provider registry: create %q: %w", alias, err) } - } - // Find factory - factory, ok := r.factories[cfg.Type] - if !ok { - return nil, fmt.Errorf("provider registry: unknown provider type %q", cfg.Type) - } + // Cache + r.mu.Lock() + r.cache[alias] = p + r.mu.Unlock() - // Create provider - p, err := factory(apiKey, *cfg) + return p, nil + }) if err != nil { - return nil, fmt.Errorf("provider registry: create %q: %w", alias, err) + return nil, err } - - // Cache - r.mu.Lock() - r.cache[alias] = p - r.mu.Unlock() - - return p, nil + return result.(provider.Provider), nil } // Built-in factory functions -func mockProviderFactory(_ string, _ LLMProviderConfig) (provider.Provider, error) { +func mockProviderFactory(_ context.Context, _ string, _ LLMProviderConfig) (provider.Provider, error) { return 
&mockProvider{responses: []string{"I have completed the task."}}, nil } -func anthropicProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewAnthropicProvider(provider.AnthropicConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil +func anthropicProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewAnthropicProvider(ctx, apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } -func openaiProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewOpenAIProvider(provider.OpenAIConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil +func openaiProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOpenAIProvider(ctx, apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } -func openrouterProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewOpenRouterProvider(provider.OpenRouterConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil +func openrouterProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://openrouter.ai/api/v1" + } + return gkprov.NewOpenAICompatibleProvider(ctx, "openrouter", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } -func copilotProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewCopilotProvider(provider.CopilotConfig{ - Token: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil +func copilotProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + baseURL := cfg.BaseURL + if 
baseURL == "" { + baseURL = "https://api.githubcopilot.com" + } + return gkprov.NewOpenAICompatibleProvider(ctx, "copilot", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } -func cohereProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewCohereProvider(provider.CohereConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil +func cohereProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://api.cohere.ai/v1" + } + return gkprov.NewOpenAICompatibleProvider(ctx, "cohere", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } -func copilotModelsProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewCopilotModelsProvider(provider.CopilotModelsConfig{ - Token: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil +func copilotModelsProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://models.github.ai/inference" + } + return gkprov.NewOpenAICompatibleProvider(ctx, "copilot_models", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } -func openaiAzureProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func openaiAzureProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { s := cfg.settings() - return provider.NewOpenAIAzureProvider(provider.OpenAIAzureConfig{ - Resource: s["resource"], - DeploymentName: s["deployment_name"], - APIVersion: s["api_version"], - APIKey: apiKey, - EntraToken: s["entra_token"], - MaxTokens: cfg.MaxTokens, - }) + return gkprov.NewAzureOpenAIProvider(ctx, + s["resource"], s["deployment_name"], s["api_version"], + apiKey, s["entra_token"], cfg.MaxTokens) } -func 
anthropicFoundryProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func anthropicFoundryProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { s := cfg.settings() - return provider.NewAnthropicFoundryProvider(provider.AnthropicFoundryConfig{ - Resource: s["resource"], - Model: cfg.Model, - MaxTokens: cfg.MaxTokens, - APIKey: apiKey, - EntraToken: s["entra_token"], - }) + return gkprov.NewAnthropicFoundryProvider(ctx, + s["resource"], cfg.Model, apiKey, s["entra_token"], cfg.MaxTokens) } -func anthropicVertexProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func anthropicVertexProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { s := cfg.settings() credJSON := s["credentials_json"] if credJSON == "" { credJSON = apiKey // fallback: secret may contain the full GCP credentials JSON } - return provider.NewAnthropicVertexProvider(provider.AnthropicVertexConfig{ - ProjectID: s["project_id"], - Region: s["region"], - Model: cfg.Model, - MaxTokens: cfg.MaxTokens, - CredentialsJSON: credJSON, - }) + return gkprov.NewVertexAIProvider(ctx, + s["project_id"], s["region"], cfg.Model, credJSON, cfg.MaxTokens) } -func geminiProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewGeminiProvider(provider.GeminiConfig{ - APIKey: apiKey, - Model: cfg.Model, - MaxTokens: cfg.MaxTokens, - }) +func geminiProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewGoogleAIProvider(ctx, apiKey, cfg.Model, cfg.MaxTokens) } -func ollamaProviderFactory(_ string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewOllamaProvider(provider.OllamaConfig{ - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil +func ollamaProviderFactory(ctx context.Context, _ string, cfg 
LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOllamaProvider(ctx, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } -func llamaCppProviderFactory(_ string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewLlamaCppProvider(provider.LlamaCppConfig{ - BaseURL: cfg.BaseURL, - ModelPath: cfg.Model, - ModelName: cfg.Model, - MaxTokens: cfg.MaxTokens, - }), nil +func llamaCppProviderFactory(ctx context.Context, _ string, cfg LLMProviderConfig) (provider.Provider, error) { + // llama.cpp serves an OpenAI-compatible API + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "http://127.0.0.1:8080/v1" + } + return gkprov.NewOpenAICompatibleProvider(ctx, "llama_cpp", "", cfg.Model, baseURL, cfg.MaxTokens) } -func anthropicBedrockProviderFactory(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { +func anthropicBedrockProviderFactory(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { s := cfg.settings() - return provider.NewAnthropicBedrockProvider(provider.AnthropicBedrockConfig{ - Region: s["region"], - Model: cfg.Model, - MaxTokens: cfg.MaxTokens, - AccessKeyID: s["access_key_id"], - SecretAccessKey: apiKey, - SessionToken: s["session_token"], - BaseURL: cfg.BaseURL, - }) + return gkprov.NewBedrockProvider(ctx, + s["region"], cfg.Model, s["access_key_id"], apiKey, s["session_token"], cfg.BaseURL, cfg.MaxTokens) } diff --git a/provider/anthropic.go b/provider/anthropic.go deleted file mode 100644 index ae8cd5d..0000000 --- a/provider/anthropic.go +++ /dev/null @@ -1,150 +0,0 @@ -package provider - -import ( - "context" - "fmt" - "net/http" - - anthropic "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/option" -) - -const ( - defaultAnthropicBaseURL = "https://api.anthropic.com" - defaultAnthropicModel = "claude-sonnet-4-20250514" - defaultAnthropicMaxTokens = 4096 - anthropicAPIVersion = "2023-06-01" -) - -// AnthropicConfig holds configuration for the 
Anthropic provider. -type AnthropicConfig struct { - APIKey string - Model string - BaseURL string - MaxTokens int - HTTPClient *http.Client -} - -// AnthropicProvider implements Provider using the Anthropic Messages API. -type AnthropicProvider struct { - client anthropic.Client - config AnthropicConfig -} - -// NewAnthropicProvider creates a new Anthropic provider with the given config. -func NewAnthropicProvider(cfg AnthropicConfig) *AnthropicProvider { - if cfg.Model == "" { - cfg.Model = defaultAnthropicModel - } - if cfg.BaseURL == "" { - cfg.BaseURL = defaultAnthropicBaseURL - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultAnthropicMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - client := anthropic.NewClient( - option.WithAPIKey(cfg.APIKey), - option.WithBaseURL(cfg.BaseURL), - option.WithHTTPClient(cfg.HTTPClient), - ) - return &AnthropicProvider{client: client, config: cfg} -} - -func (p *AnthropicProvider) Name() string { return "anthropic" } - -func (p *AnthropicProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "direct", - DisplayName: "Anthropic (Direct API)", - Description: "Direct access to Anthropic's Claude models via API key.", - DocsURL: "https://platform.claude.com/docs/en/api/getting-started", - ServerSafe: true, - } -} - -func (p *AnthropicProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - params := toAnthropicParams(p.config.Model, p.config.MaxTokens, messages, tools) - msg, err := p.client.Messages.New(ctx, params) - if err != nil { - return nil, fmt.Errorf("anthropic: %w", err) - } - return fromAnthropicMessage(msg) -} - -func (p *AnthropicProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - params := toAnthropicParams(p.config.Model, p.config.MaxTokens, messages, tools) - stream := p.client.Messages.NewStreaming(ctx, params) - if stream.Err() != nil { - return nil, 
fmt.Errorf("anthropic: %w", stream.Err()) - } - ch := make(chan StreamEvent, 16) - go streamAnthropicEvents(stream, ch) - return ch, nil -} - -// JSON types used by anthropic_foundry.go (hand-rolled HTTP, no SDK support). - -type anthropicRequest struct { - Model string `json:"model"` - MaxTokens int `json:"max_tokens"` - System string `json:"system,omitempty"` - Messages []anthropicMessage `json:"messages"` - Tools []anthropicTool `json:"tools,omitempty"` - Stream bool `json:"stream,omitempty"` -} - -type anthropicMessage struct { - Role string `json:"role"` - Content any `json:"content"` -} - -type anthropicContent struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input map[string]any `json:"input,omitempty"` - ToolUseID string `json:"tool_use_id,omitempty"` - Content string `json:"content,omitempty"` - IsError bool `json:"is_error,omitempty"` - CacheCtrl *cacheControl `json:"cache_control,omitempty"` -} - -type cacheControl struct { - Type string `json:"type"` -} - -type anthropicTool struct { - Name string `json:"name"` - Description string `json:"description"` - InputSchema map[string]any `json:"input_schema"` -} - -type anthropicResponse struct { - ID string `json:"id"` - Type string `json:"type"` - Content []anthropicRespItem `json:"content"` - Usage anthropicUsage `json:"usage"` - Error *anthropicError `json:"error,omitempty"` -} - -type anthropicRespItem struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input map[string]any `json:"input,omitempty"` -} - -type anthropicUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` -} - -type anthropicError struct { - Type string `json:"type"` - Message string `json:"message"` -} diff --git a/provider/anthropic_convert.go b/provider/anthropic_bedrock_convert.go similarity index 
99% rename from provider/anthropic_convert.go rename to provider/anthropic_bedrock_convert.go index bd5450f..429a54a 100644 --- a/provider/anthropic_convert.go +++ b/provider/anthropic_bedrock_convert.go @@ -10,6 +10,8 @@ import ( "github.com/anthropics/anthropic-sdk-go/packages/ssestream" ) +const defaultAnthropicMaxTokens = 4096 + // toAnthropicParams converts provider types to SDK MessageNewParams. func toAnthropicParams(model string, maxTokens int, messages []Message, tools []ToolDef) anthropic.MessageNewParams { params := anthropic.MessageNewParams{ diff --git a/provider/anthropic_convert_test.go b/provider/anthropic_convert_test.go deleted file mode 100644 index ae395c5..0000000 --- a/provider/anthropic_convert_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package provider - -import ( - "encoding/json" - "testing" - - anthropic "github.com/anthropics/anthropic-sdk-go" -) - -// TestToAnthropicParams_AssistantToolCalls verifies that assistant messages with -// tool calls are correctly serialized as tool_use content blocks (multi-turn tool use). -func TestToAnthropicParams_AssistantToolCalls(t *testing.T) { - messages := []Message{ - {Role: RoleUser, Content: "What's the weather?"}, - { - Role: RoleAssistant, - Content: "", - ToolCalls: []ToolCall{ - {ID: "call_1", Name: "get_weather", Arguments: map[string]any{"location": "NYC"}}, - }, - }, - {Role: RoleTool, ToolCallID: "call_1", Content: "72°F, sunny"}, - {Role: RoleAssistant, Content: "It's 72°F and sunny in NYC."}, - } - - params := toAnthropicParams("claude-sonnet-4-20250514", 1024, messages, nil) - - if len(params.Messages) != 4 { - t.Fatalf("expected 4 messages, got %d", len(params.Messages)) - } - - // Message[1] should be an assistant message with a tool_use block. 
- assistantMsg := params.Messages[1] - if assistantMsg.Role != anthropic.MessageParamRoleAssistant { - t.Fatalf("expected assistant role at index 1, got %q", assistantMsg.Role) - } - blocks := assistantMsg.Content - if len(blocks) != 1 { - t.Fatalf("expected 1 content block in assistant message, got %d", len(blocks)) - } - toolBlock := blocks[0] - if toolBlock.OfToolUse == nil { - t.Fatal("expected tool_use block in assistant message") - } - if toolBlock.OfToolUse.ID != "call_1" { - t.Errorf("tool_use ID = %q, want %q", toolBlock.OfToolUse.ID, "call_1") - } - if toolBlock.OfToolUse.Name != "get_weather" { - t.Errorf("tool_use Name = %q, want %q", toolBlock.OfToolUse.Name, "get_weather") - } - inputBytes, ok := toolBlock.OfToolUse.Input.([]byte) - if !ok { - t.Fatalf("expected []byte input, got %T", toolBlock.OfToolUse.Input) - } - var args map[string]any - if err := json.Unmarshal(inputBytes, &args); err != nil { - t.Fatalf("unmarshal tool_use input: %v", err) - } - if args["location"] != "NYC" { - t.Errorf("tool_use input location = %v, want NYC", args["location"]) - } - - // Message[2] should be a user message with a tool_result block. - userMsg := params.Messages[2] - if userMsg.Role != anthropic.MessageParamRoleUser { - t.Fatalf("expected user role at index 2, got %q", userMsg.Role) - } -} - -// TestToAnthropicParams_AssistantTextAndToolCalls verifies mixed text + tool_use blocks. 
-func TestToAnthropicParams_AssistantTextAndToolCalls(t *testing.T) { - messages := []Message{ - { - Role: RoleAssistant, - Content: "Let me check that.", - ToolCalls: []ToolCall{ - {ID: "call_2", Name: "search", Arguments: map[string]any{"q": "Go generics"}}, - }, - }, - } - - params := toAnthropicParams("claude-sonnet-4-20250514", 1024, messages, nil) - if len(params.Messages) != 1 { - t.Fatalf("expected 1 message, got %d", len(params.Messages)) - } - blocks := params.Messages[0].Content - if len(blocks) != 2 { - t.Fatalf("expected 2 blocks (text + tool_use), got %d", len(blocks)) - } - if blocks[0].OfText == nil { - t.Error("expected first block to be text") - } - if blocks[1].OfToolUse == nil { - t.Error("expected second block to be tool_use") - } -} - -// TestFromAnthropicMessage_ToolCallUnmarshalError verifies that malformed tool -// call arguments produce an error instead of silent data loss. -func TestFromAnthropicMessage_ToolCallUnmarshalError(t *testing.T) { - msg := &anthropic.Message{ - Content: []anthropic.ContentBlockUnion{ - {Type: "tool_use", ID: "call_bad", Name: "broken_tool", Input: json.RawMessage(`{invalid json}`)}, - }, - } - _, err := fromAnthropicMessage(msg) - if err == nil { - t.Fatal("expected error for malformed tool call input, got nil") - } -} diff --git a/provider/anthropic_foundry.go b/provider/anthropic_foundry.go deleted file mode 100644 index 4ede8db..0000000 --- a/provider/anthropic_foundry.go +++ /dev/null @@ -1,344 +0,0 @@ -package provider - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" -) - -const defaultFoundryModel = "claude-sonnet-4-20250514" - -// AnthropicFoundryConfig configures the Anthropic provider for Microsoft Azure AI Foundry. -// Uses Azure API keys or Entra ID (formerly Azure AD) tokens. -type AnthropicFoundryConfig struct { - // Resource is the Azure AI Services resource name (forms the URL: {resource}.services.ai.azure.com). 
- Resource string - // Model is the model deployment name. - Model string - // MaxTokens limits the response length. - MaxTokens int - // APIKey is the Azure API key (use this OR Entra ID token, not both). - APIKey string - // EntraToken is a Microsoft Entra ID bearer token (optional, alternative to APIKey). - EntraToken string - // HTTPClient is the HTTP client to use (defaults to http.DefaultClient). - HTTPClient *http.Client -} - -// anthropicFoundryProvider accesses Anthropic models via Azure AI Foundry. -type anthropicFoundryProvider struct { - config AnthropicFoundryConfig - url string -} - -// NewAnthropicFoundryProvider creates a provider that accesses Claude via Azure AI Foundry. -// -// Docs: https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry -func NewAnthropicFoundryProvider(cfg AnthropicFoundryConfig) (*anthropicFoundryProvider, error) { - if cfg.Resource == "" { - return nil, fmt.Errorf("anthropic_foundry: resource name is required") - } - if cfg.APIKey == "" && cfg.EntraToken == "" { - return nil, fmt.Errorf("anthropic_foundry: either APIKey or EntraToken is required") - } - if cfg.Model == "" { - cfg.Model = defaultFoundryModel - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultAnthropicMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - return &anthropicFoundryProvider{ - config: cfg, - url: fmt.Sprintf("https://%s.services.ai.azure.com/anthropic/v1/messages", cfg.Resource), - }, nil -} - -func (p *anthropicFoundryProvider) Name() string { return "anthropic_foundry" } - -func (p *anthropicFoundryProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "foundry", - DisplayName: "Anthropic (Azure AI Foundry)", - Description: "Access Claude models via Microsoft Azure AI Foundry using Azure API keys or Microsoft Entra ID tokens.", - DocsURL: "https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry", - ServerSafe: true, - } -} - -func (p 
*anthropicFoundryProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - reqBody := p.buildRequest(messages, tools, false) - - data, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: create request: %w", err) - } - p.setHeaders(req) - - resp, err := p.config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: send request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("anthropic_foundry: API error (status %d): %s", resp.StatusCode, string(body)) - } - - var apiResp anthropicResponse - if err := json.Unmarshal(body, &apiResp); err != nil { - return nil, fmt.Errorf("anthropic_foundry: unmarshal response: %w", err) - } - - if apiResp.Error != nil { - return nil, fmt.Errorf("anthropic_foundry: %s: %s", apiResp.Error.Type, apiResp.Error.Message) - } - - return foundryParseResponse(&apiResp), nil -} - -func (p *anthropicFoundryProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - reqBody := p.buildRequest(messages, tools, true) - - data, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: create request: %w", err) - } - p.setHeaders(req) - - resp, err := p.config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("anthropic_foundry: send request: %w", err) - } - - if resp.StatusCode 
!= http.StatusOK { - body, _ := io.ReadAll(resp.Body) - _ = resp.Body.Close() - return nil, fmt.Errorf("anthropic_foundry: API error (status %d): %s", resp.StatusCode, string(body)) - } - - ch := make(chan StreamEvent, 16) - go foundryReadSSE(resp.Body, ch) - return ch, nil -} - -func (p *anthropicFoundryProvider) setHeaders(req *http.Request) { - req.Header.Set("Content-Type", "application/json") - req.Header.Set("anthropic-version", anthropicAPIVersion) - if p.config.EntraToken != "" { - req.Header.Set("Authorization", "Bearer "+p.config.EntraToken) - } else { - req.Header.Set("api-key", p.config.APIKey) - } -} - -func (p *anthropicFoundryProvider) buildRequest(messages []Message, tools []ToolDef, stream bool) *anthropicRequest { - req := &anthropicRequest{ - Model: p.config.Model, - MaxTokens: p.config.MaxTokens, - Stream: stream, - } - - var apiMessages []anthropicMessage - for _, msg := range messages { - if msg.Role == RoleSystem { - req.System = msg.Content - continue - } - if msg.Role == RoleTool { - apiMessages = append(apiMessages, anthropicMessage{ - Role: "user", - Content: []anthropicContent{ - { - Type: "tool_result", - ToolUseID: msg.ToolCallID, - Content: msg.Content, - }, - }, - }) - continue - } - apiMessages = append(apiMessages, anthropicMessage{ - Role: string(msg.Role), - Content: msg.Content, - }) - } - req.Messages = apiMessages - - for _, t := range tools { - schema := t.Parameters - if schema == nil { - schema = map[string]any{"type": "object", "properties": map[string]any{}} - } - req.Tools = append(req.Tools, anthropicTool{ - Name: t.Name, - Description: t.Description, - InputSchema: schema, - }) - } - - return req -} - -func foundryParseResponse(apiResp *anthropicResponse) *Response { - resp := &Response{ - Usage: Usage{ - InputTokens: apiResp.Usage.InputTokens, - OutputTokens: apiResp.Usage.OutputTokens, - }, - } - - var textParts []string - for _, item := range apiResp.Content { - switch item.Type { - case "text": - textParts = 
append(textParts, item.Text) - case "tool_use": - resp.ToolCalls = append(resp.ToolCalls, ToolCall{ - ID: item.ID, - Name: item.Name, - Arguments: item.Input, - }) - } - } - resp.Content = strings.Join(textParts, "") - - return resp -} - -func foundryReadSSE(body io.ReadCloser, ch chan<- StreamEvent) { - defer func() { _ = body.Close() }() - defer close(ch) - - scanner := bufio.NewScanner(body) - - var currentToolID, currentToolName string - var toolInputBuf bytes.Buffer - var usage *Usage - - for scanner.Scan() { - line := scanner.Text() - - if !strings.HasPrefix(line, "data: ") { - continue - } - - data := strings.TrimPrefix(line, "data: ") - if data == "[DONE]" { - break - } - - var event struct { - Type string `json:"type"` - Index int `json:"index"` - ContentBlock *struct { - Type string `json:"type"` - ID string `json:"id"` - Name string `json:"name"` - Text string `json:"text"` - } `json:"content_block"` - Delta *struct { - Type string `json:"type"` - Text string `json:"text"` - PartialJSON string `json:"partial_json"` - StopReason string `json:"stop_reason"` - } `json:"delta"` - Message *struct { - Usage anthropicUsage `json:"usage"` - } `json:"message"` - Usage *anthropicUsage `json:"usage"` - } - - if err := json.Unmarshal([]byte(data), &event); err != nil { - continue - } - - switch event.Type { - case "message_start": - if event.Message != nil { - usage = &Usage{ - InputTokens: event.Message.Usage.InputTokens, - OutputTokens: event.Message.Usage.OutputTokens, - } - } - - case "content_block_start": - if event.ContentBlock != nil && event.ContentBlock.Type == "tool_use" { - currentToolID = event.ContentBlock.ID - currentToolName = event.ContentBlock.Name - toolInputBuf.Reset() - } - - case "content_block_delta": - if event.Delta == nil { - continue - } - switch event.Delta.Type { - case "text_delta": - ch <- StreamEvent{Type: "text", Text: event.Delta.Text} - case "input_json_delta": - toolInputBuf.WriteString(event.Delta.PartialJSON) - } - - case 
"content_block_stop": - if currentToolID != "" { - var args map[string]any - if toolInputBuf.Len() > 0 { - if err := json.Unmarshal(toolInputBuf.Bytes(), &args); err != nil { - ch <- StreamEvent{Type: "error", Error: fmt.Sprintf("tool %s: malformed arguments: %v", currentToolName, err)} - } - } - ch <- StreamEvent{ - Type: "tool_call", - Tool: &ToolCall{ - ID: currentToolID, - Name: currentToolName, - Arguments: args, - }, - } - currentToolID = "" - currentToolName = "" - toolInputBuf.Reset() - } - - case "message_delta": - if event.Usage != nil && usage != nil { - usage.OutputTokens = event.Usage.OutputTokens - } - - case "message_stop": - ch <- StreamEvent{Type: "done", Usage: usage} - return - - case "error": - ch <- StreamEvent{Type: "error", Error: data} - return - } - } - - if err := scanner.Err(); err != nil { - ch <- StreamEvent{Type: "error", Error: fmt.Sprintf("foundry stream read: %v", err)} - } -} diff --git a/provider/anthropic_foundry_test.go b/provider/anthropic_foundry_test.go deleted file mode 100644 index 72a9de4..0000000 --- a/provider/anthropic_foundry_test.go +++ /dev/null @@ -1,311 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestNewAnthropicFoundryProvider_Validation(t *testing.T) { - tests := []struct { - name string - cfg AnthropicFoundryConfig - wantErr string - }{ - { - name: "missing resource", - cfg: AnthropicFoundryConfig{APIKey: "key"}, - wantErr: "resource name is required", - }, - { - name: "missing auth", - cfg: AnthropicFoundryConfig{Resource: "myresource"}, - wantErr: "either APIKey or EntraToken is required", - }, - { - name: "valid with api key", - cfg: AnthropicFoundryConfig{Resource: "myresource", APIKey: "key"}, - }, - { - name: "valid with entra token", - cfg: AnthropicFoundryConfig{Resource: "myresource", EntraToken: "token"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p, err := 
NewAnthropicFoundryProvider(tt.cfg) - if tt.wantErr != "" { - if err == nil { - t.Fatalf("expected error containing %q, got nil", tt.wantErr) - } - if got := err.Error(); !strings.Contains(got, tt.wantErr) { - t.Fatalf("error %q does not contain %q", got, tt.wantErr) - } - return - } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if p == nil { - t.Fatal("expected non-nil provider") - } - }) - } -} - -func TestAnthropicFoundryProvider_Defaults(t *testing.T) { - p, err := NewAnthropicFoundryProvider(AnthropicFoundryConfig{ - Resource: "myresource", - APIKey: "key", - }) - if err != nil { - t.Fatal(err) - } - - if p.config.Model != defaultFoundryModel { - t.Errorf("default model = %q, want %q", p.config.Model, defaultFoundryModel) - } - if p.config.MaxTokens != defaultAnthropicMaxTokens { - t.Errorf("default max tokens = %d, want %d", p.config.MaxTokens, defaultAnthropicMaxTokens) - } -} - -func TestAnthropicFoundryProvider_Name(t *testing.T) { - p := &anthropicFoundryProvider{} - if got := p.Name(); got != "anthropic_foundry" { - t.Errorf("Name() = %q, want %q", got, "anthropic_foundry") - } -} - -func TestAnthropicFoundryProvider_AuthModeInfo(t *testing.T) { - p := &anthropicFoundryProvider{} - info := p.AuthModeInfo() - if info.Mode != "foundry" { - t.Errorf("Mode = %q, want %q", info.Mode, "foundry") - } - if info.DisplayName != "Anthropic (Azure AI Foundry)" { - t.Errorf("DisplayName = %q", info.DisplayName) - } -} - -func TestAnthropicFoundryProvider_URLConstruction(t *testing.T) { - p, err := NewAnthropicFoundryProvider(AnthropicFoundryConfig{ - Resource: "my-ai-resource", - APIKey: "key", - }) - if err != nil { - t.Fatal(err) - } - - want := "https://my-ai-resource.services.ai.azure.com/anthropic/v1/messages" - if p.url != want { - t.Errorf("url = %q, want %q", p.url, want) - } -} - -func TestAnthropicFoundryProvider_APIKeyAuth(t *testing.T) { - var gotHeaders http.Header - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, 
r *http.Request) { - gotHeaders = r.Header.Clone() - json.NewEncoder(w).Encode(anthropicResponse{ - ID: "msg_123", - Type: "message", - Content: []anthropicRespItem{ - {Type: "text", Text: "hello"}, - }, - Usage: anthropicUsage{InputTokens: 10, OutputTokens: 5}, - }) - })) - defer srv.Close() - - p := &anthropicFoundryProvider{ - config: AnthropicFoundryConfig{ - Resource: "test", - Model: "claude-sonnet-4-20250514", - MaxTokens: 1024, - APIKey: "my-azure-key", - HTTPClient: srv.Client(), - }, - url: srv.URL, - } - - _, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - if got := gotHeaders.Get("api-key"); got != "my-azure-key" { - t.Errorf("api-key header = %q, want %q", got, "my-azure-key") - } - if got := gotHeaders.Get("anthropic-version"); got != anthropicAPIVersion { - t.Errorf("anthropic-version header = %q, want %q", got, anthropicAPIVersion) - } - if got := gotHeaders.Get("Authorization"); got != "" { - t.Errorf("Authorization header should be empty with API key auth, got %q", got) - } -} - -func TestAnthropicFoundryProvider_EntraIDAuth(t *testing.T) { - var gotHeaders http.Header - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotHeaders = r.Header.Clone() - json.NewEncoder(w).Encode(anthropicResponse{ - ID: "msg_123", - Type: "message", - Content: []anthropicRespItem{{Type: "text", Text: "hello"}}, - Usage: anthropicUsage{InputTokens: 10, OutputTokens: 5}, - }) - })) - defer srv.Close() - - p := &anthropicFoundryProvider{ - config: AnthropicFoundryConfig{ - Resource: "test", - Model: "claude-sonnet-4-20250514", - MaxTokens: 1024, - EntraToken: "my-entra-token", - HTTPClient: srv.Client(), - }, - url: srv.URL, - } - - _, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - if got := gotHeaders.Get("Authorization"); got != "Bearer my-entra-token" { - 
t.Errorf("Authorization header = %q, want %q", got, "Bearer my-entra-token") - } - if got := gotHeaders.Get("api-key"); got != "" { - t.Errorf("api-key header should be empty with Entra auth, got %q", got) - } -} - -func TestAnthropicFoundryProvider_Chat(t *testing.T) { - var gotBody anthropicRequest - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewDecoder(r.Body).Decode(&gotBody) - json.NewEncoder(w).Encode(anthropicResponse{ - ID: "msg_123", - Type: "message", - Content: []anthropicRespItem{ - {Type: "text", Text: "Hello from Foundry!"}, - }, - Usage: anthropicUsage{InputTokens: 15, OutputTokens: 8}, - }) - })) - defer srv.Close() - - p := &anthropicFoundryProvider{ - config: AnthropicFoundryConfig{ - Resource: "test", - Model: "claude-sonnet-4-20250514", - MaxTokens: 1024, - APIKey: "key", - HTTPClient: srv.Client(), - }, - url: srv.URL, - } - - resp, err := p.Chat(t.Context(), []Message{ - {Role: RoleSystem, Content: "You are helpful."}, - {Role: RoleUser, Content: "Hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - if resp.Content != "Hello from Foundry!" { - t.Errorf("Content = %q, want %q", resp.Content, "Hello from Foundry!") - } - if resp.Usage.InputTokens != 15 { - t.Errorf("InputTokens = %d, want 15", resp.Usage.InputTokens) - } - if resp.Usage.OutputTokens != 8 { - t.Errorf("OutputTokens = %d, want 8", resp.Usage.OutputTokens) - } - - // Verify request body matches Anthropic Messages format - if gotBody.Model != "claude-sonnet-4-20250514" { - t.Errorf("request model = %q", gotBody.Model) - } - if gotBody.System != "You are helpful." 
{ - t.Errorf("request system = %q", gotBody.System) - } - if len(gotBody.Messages) != 1 { - t.Fatalf("request messages len = %d, want 1", len(gotBody.Messages)) - } -} - -func TestAnthropicFoundryProvider_Stream(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher := w.(http.Flusher) - - events := []string{ - `{"type":"message_start","message":{"usage":{"input_tokens":20,"output_tokens":0}}}`, - `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`, - `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`, - `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" world"}}`, - `{"type":"content_block_stop","index":0}`, - `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}}`, - `{"type":"message_stop"}`, - } - - for _, e := range events { - fmt.Fprintf(w, "data: %s\n\n", e) - flusher.Flush() - } - })) - defer srv.Close() - - p := &anthropicFoundryProvider{ - config: AnthropicFoundryConfig{ - Resource: "test", - Model: "claude-sonnet-4-20250514", - MaxTokens: 1024, - APIKey: "key", - HTTPClient: srv.Client(), - }, - url: srv.URL, - } - - ch, err := p.Stream(t.Context(), []Message{ - {Role: RoleUser, Content: "Hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - var texts []string - var done bool - for ev := range ch { - switch ev.Type { - case "text": - texts = append(texts, ev.Text) - case "done": - done = true - if ev.Usage == nil { - t.Error("expected usage in done event") - } else if ev.Usage.OutputTokens != 10 { - t.Errorf("OutputTokens = %d, want 10", ev.Usage.OutputTokens) - } - } - } - - if !done { - t.Error("expected done event") - } - if got := strings.Join(texts, ""); got != "Hello world" { - t.Errorf("streamed text = %q, want %q", got, "Hello world") - } -} diff --git a/provider/anthropic_vertex.go 
b/provider/anthropic_vertex.go deleted file mode 100644 index fff35bf..0000000 --- a/provider/anthropic_vertex.go +++ /dev/null @@ -1,188 +0,0 @@ -package provider - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - - anthropic "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/option" - "github.com/anthropics/anthropic-sdk-go/vertex" - - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" -) - -const ( - defaultVertexModel = "claude-sonnet-4@20250514" - defaultVertexRegion = "us-east5" -) - -// AnthropicVertexConfig configures the Anthropic provider for Google Vertex AI. -// Uses GCP Application Default Credentials (ADC) or explicit OAuth2 tokens. -type AnthropicVertexConfig struct { - // ProjectID is the GCP project ID. - ProjectID string - // Region is the GCP region (e.g. "us-east5", "europe-west1"). - Region string - // Model is the Vertex model ID (e.g. "claude-sonnet-4@20250514"). - Model string - // MaxTokens limits the response length. - MaxTokens int - // CredentialsJSON is the GCP service account JSON (optional if using ADC). - CredentialsJSON string - // TokenSource provides OAuth2 tokens (optional, for testing or custom auth). - // If set, CredentialsJSON and ADC are ignored. - TokenSource oauth2.TokenSource - // HTTPClient is the HTTP client to use (defaults to http.DefaultClient). - HTTPClient *http.Client -} - -// anthropicVertexProvider accesses Anthropic models via Google Vertex AI. -type anthropicVertexProvider struct { - client anthropic.Client - config AnthropicVertexConfig -} - -// NewAnthropicVertexProvider creates a provider that accesses Claude via Google Vertex AI. 
-// -// Docs: https://platform.claude.com/docs/en/build-with-claude/claude-on-vertex-ai -func NewAnthropicVertexProvider(cfg AnthropicVertexConfig) (*anthropicVertexProvider, error) { - if cfg.ProjectID == "" { - return nil, fmt.Errorf("anthropic_vertex: project ID is required") - } - if cfg.Region == "" { - cfg.Region = defaultVertexRegion - } - if cfg.Model == "" { - cfg.Model = defaultVertexModel - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultAnthropicMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - - var client anthropic.Client - - baseURL := fmt.Sprintf("https://%s-aiplatform.googleapis.com/", cfg.Region) - - if cfg.TokenSource != nil { - // Testing / custom-auth path: use middleware for token injection and path rewriting. - client = anthropic.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(cfg.HTTPClient), - option.WithMiddleware(vertexPathRewriteMiddleware(cfg.Region, cfg.ProjectID)), - option.WithMiddleware(vertexBearerTokenMiddleware(cfg.TokenSource)), - ) - } else if cfg.CredentialsJSON != "" { - creds, err := google.CredentialsFromJSON(context.Background(), []byte(cfg.CredentialsJSON), - "https://www.googleapis.com/auth/cloud-platform") - if err != nil { - return nil, fmt.Errorf("anthropic_vertex: parse credentials JSON: %w", err) - } - client = anthropic.NewClient( - vertex.WithCredentials(context.Background(), cfg.Region, cfg.ProjectID, creds), - ) - } else { - client = anthropic.NewClient( - vertex.WithGoogleAuth(context.Background(), cfg.Region, cfg.ProjectID, - "https://www.googleapis.com/auth/cloud-platform"), - ) - } - - return &anthropicVertexProvider{client: client, config: cfg}, nil -} - -func (p *anthropicVertexProvider) Name() string { return "anthropic_vertex" } - -func (p *anthropicVertexProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "vertex", - DisplayName: "Anthropic (Google Vertex AI)", - Description: "Access Claude models via Google Cloud Vertex AI 
using Application Default Credentials (ADC) or service account JSON.", - DocsURL: "https://platform.claude.com/docs/en/build-with-claude/claude-on-vertex-ai", - ServerSafe: true, - } -} - -func (p *anthropicVertexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - params := toAnthropicParams(p.config.Model, p.config.MaxTokens, messages, tools) - msg, err := p.client.Messages.New(ctx, params) - if err != nil { - return nil, fmt.Errorf("anthropic_vertex: %w", err) - } - return fromAnthropicMessage(msg) -} - -func (p *anthropicVertexProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - params := toAnthropicParams(p.config.Model, p.config.MaxTokens, messages, tools) - stream := p.client.Messages.NewStreaming(ctx, params) - if stream.Err() != nil { - return nil, fmt.Errorf("anthropic_vertex: %w", stream.Err()) - } - ch := make(chan StreamEvent, 16) - go streamAnthropicEvents(stream, ch) - return ch, nil -} - -// vertexPathRewriteMiddleware rewrites /v1/messages to the Vertex AI model path. 
-func vertexPathRewriteMiddleware(region, projectID string) option.Middleware { - return func(r *http.Request, next option.MiddlewareNext) (*http.Response, error) { - if r.Body == nil || r.URL.Path != "/v1/messages" || r.Method != http.MethodPost { - return next(r) - } - body, err := io.ReadAll(r.Body) - if err != nil { - return nil, err - } - r.Body.Close() - - var params struct { - Model string `json:"model"` - Stream bool `json:"stream"` - } - _ = json.Unmarshal(body, ¶ms) - - // Remove model and stream from body (Vertex uses URL path instead) - var m map[string]any - _ = json.Unmarshal(body, &m) - delete(m, "model") - delete(m, "stream") - body, _ = json.Marshal(m) - - specifier := "rawPredict" - if params.Stream { - specifier = "streamRawPredict" - } - r.URL.Path = fmt.Sprintf("/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", - projectID, region, params.Model, specifier) - - reader := bytes.NewReader(body) - r.Body = io.NopCloser(reader) - r.GetBody = func() (io.ReadCloser, error) { - _, _ = reader.Seek(0, 0) - return io.NopCloser(reader), nil - } - r.ContentLength = int64(len(body)) - - return next(r) - } -} - -// vertexBearerTokenMiddleware adds a GCP Bearer token to each request. 
-func vertexBearerTokenMiddleware(ts oauth2.TokenSource) option.Middleware { - return func(r *http.Request, next option.MiddlewareNext) (*http.Response, error) { - token, err := ts.Token() - if err != nil { - return nil, fmt.Errorf("anthropic_vertex: get token: %w", err) - } - r.Header.Set("Authorization", "Bearer "+token.AccessToken) - r.Header.Set("anthropic-version", anthropicAPIVersion) - return next(r) - } -} diff --git a/provider/anthropic_vertex_test.go b/provider/anthropic_vertex_test.go deleted file mode 100644 index 196ca96..0000000 --- a/provider/anthropic_vertex_test.go +++ /dev/null @@ -1,271 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "golang.org/x/oauth2" -) - -// mockTokenSource returns a fixed token for testing. -type mockTokenSource struct { - token string -} - -func (m *mockTokenSource) Token() (*oauth2.Token, error) { - return &oauth2.Token{AccessToken: m.token}, nil -} - -func TestNewAnthropicVertexProvider_Validation(t *testing.T) { - tests := []struct { - name string - cfg AnthropicVertexConfig - wantErr string - }{ - { - name: "missing project ID", - cfg: AnthropicVertexConfig{TokenSource: &mockTokenSource{token: "tok"}}, - wantErr: "project ID is required", - }, - { - name: "valid with token source", - cfg: AnthropicVertexConfig{ - ProjectID: "my-project", - TokenSource: &mockTokenSource{token: "tok"}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p, err := NewAnthropicVertexProvider(tt.cfg) - if tt.wantErr != "" { - if err == nil { - t.Fatalf("expected error containing %q, got nil", tt.wantErr) - } - if got := err.Error(); !strings.Contains(got, tt.wantErr) { - t.Fatalf("error %q does not contain %q", got, tt.wantErr) - } - return - } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if p == nil { - t.Fatal("expected non-nil provider") - } - }) - } -} - -func TestAnthropicVertexProvider_Defaults(t 
*testing.T) { - p, err := NewAnthropicVertexProvider(AnthropicVertexConfig{ - ProjectID: "my-project", - TokenSource: &mockTokenSource{token: "tok"}, - }) - if err != nil { - t.Fatal(err) - } - - if p.config.Model != defaultVertexModel { - t.Errorf("default model = %q, want %q", p.config.Model, defaultVertexModel) - } - if p.config.Region != defaultVertexRegion { - t.Errorf("default region = %q, want %q", p.config.Region, defaultVertexRegion) - } - if p.config.MaxTokens != defaultAnthropicMaxTokens { - t.Errorf("default max tokens = %d, want %d", p.config.MaxTokens, defaultAnthropicMaxTokens) - } -} - -func TestAnthropicVertexProvider_Name(t *testing.T) { - p := &anthropicVertexProvider{} - if got := p.Name(); got != "anthropic_vertex" { - t.Errorf("Name() = %q, want %q", got, "anthropic_vertex") - } -} - -func TestAnthropicVertexProvider_AuthModeInfo(t *testing.T) { - p := &anthropicVertexProvider{} - info := p.AuthModeInfo() - if info.Mode != "vertex" { - t.Errorf("Mode = %q, want %q", info.Mode, "vertex") - } - if info.DisplayName != "Anthropic (Google Vertex AI)" { - t.Errorf("DisplayName = %q", info.DisplayName) - } -} - -func TestAnthropicVertexProvider_BearerAuth(t *testing.T) { - var gotHeaders http.Header - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotHeaders = r.Header.Clone() - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "id": "msg_123", - "type": "message", - "content": []any{map[string]any{"type": "text", "text": "hello"}}, - "usage": map[string]any{"input_tokens": 10, "output_tokens": 5}, - }) - })) - defer srv.Close() - - p, err := NewAnthropicVertexProvider(AnthropicVertexConfig{ - ProjectID: "proj", - Region: "us-east5", - Model: "claude-sonnet-4@20250514", - MaxTokens: 1024, - TokenSource: &mockTokenSource{token: "my-gcp-token"}, - HTTPClient: &http.Client{Transport: &urlRewriteTransport{target: srv.URL}}, - }) - if err != nil { - t.Fatal(err) - 
} - - _, err = p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - if got := gotHeaders.Get("Authorization"); got != "Bearer my-gcp-token" { - t.Errorf("Authorization header = %q, want %q", got, "Bearer my-gcp-token") - } - if got := gotHeaders.Get("anthropic-version"); got != anthropicAPIVersion { - t.Errorf("anthropic-version header = %q, want %q", got, anthropicAPIVersion) - } - // Vertex should NOT set x-api-key - if got := gotHeaders.Get("x-api-key"); got != "" { - t.Errorf("x-api-key should be empty, got %q", got) - } -} - -// urlRewriteTransport rewrites all request URLs to target for testing. -type urlRewriteTransport struct { - target string - transport http.RoundTripper -} - -func (t *urlRewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - req.URL.Scheme = "http" - req.URL.Host = t.target[len("http://"):] - if t.transport != nil { - return t.transport.RoundTrip(req) - } - return http.DefaultTransport.RoundTrip(req) -} - -func TestAnthropicVertexProvider_Chat(t *testing.T) { - var gotBody map[string]any - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewDecoder(r.Body).Decode(&gotBody) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "id": "msg_123", - "type": "message", - "content": []any{ - map[string]any{"type": "text", "text": "Hello from Vertex!"}, - }, - "usage": map[string]any{"input_tokens": 15, "output_tokens": 8}, - }) - })) - defer srv.Close() - - p, err := NewAnthropicVertexProvider(AnthropicVertexConfig{ - ProjectID: "proj", - Region: "us-east5", - Model: "claude-sonnet-4@20250514", - MaxTokens: 1024, - TokenSource: &mockTokenSource{token: "tok"}, - HTTPClient: &http.Client{Transport: &urlRewriteTransport{target: srv.URL}}, - }) - if err != nil { - t.Fatal(err) - } - - resp, err := p.Chat(t.Context(), []Message{ - {Role: 
RoleSystem, Content: "You are helpful."}, - {Role: RoleUser, Content: "Hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - if resp.Content != "Hello from Vertex!" { - t.Errorf("Content = %q, want %q", resp.Content, "Hello from Vertex!") - } - if resp.Usage.InputTokens != 15 { - t.Errorf("InputTokens = %d, want 15", resp.Usage.InputTokens) - } -} - -func TestAnthropicVertexProvider_Stream(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher := w.(http.Flusher) - - events := []struct{ typ, data string }{ - {"message_start", `{"type":"message_start","message":{"usage":{"input_tokens":20,"output_tokens":0}}}`}, - {"content_block_start", `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`}, - {"content_block_delta", `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`}, - {"content_block_delta", `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" Vertex"}}`}, - {"content_block_stop", `{"type":"content_block_stop","index":0}`}, - {"message_delta", `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}}`}, - {"message_stop", `{"type":"message_stop"}`}, - } - - for _, e := range events { - fmt.Fprintf(w, "event: %s\ndata: %s\n\n", e.typ, e.data) - flusher.Flush() - } - })) - defer srv.Close() - - p, err := NewAnthropicVertexProvider(AnthropicVertexConfig{ - ProjectID: "proj", - Region: "us-east5", - Model: "claude-sonnet-4@20250514", - MaxTokens: 1024, - TokenSource: &mockTokenSource{token: "tok"}, - HTTPClient: &http.Client{Transport: &urlRewriteTransport{target: srv.URL}}, - }) - if err != nil { - t.Fatal(err) - } - - ch, err := p.Stream(t.Context(), []Message{ - {Role: RoleUser, Content: "Hi"}, - }, nil) - if err != nil { - t.Fatal(err) - } - - var texts []string - var done bool - for ev := range ch { - switch ev.Type { - 
case "text": - texts = append(texts, ev.Text) - case "done": - done = true - if ev.Usage == nil { - t.Error("expected usage in done event") - } else if ev.Usage.OutputTokens != 10 { - t.Errorf("OutputTokens = %d, want 10", ev.Usage.OutputTokens) - } - } - } - - if !done { - t.Error("expected done event") - } - if got := strings.Join(texts, ""); got != "Hello Vertex" { - t.Errorf("streamed text = %q, want %q", got, "Hello Vertex") - } -} diff --git a/provider/cohere.go b/provider/cohere.go deleted file mode 100644 index e079af2..0000000 --- a/provider/cohere.go +++ /dev/null @@ -1,446 +0,0 @@ -package provider - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" -) - -const ( - defaultCohereBaseURL = "https://api.cohere.com" - defaultCohereModel = "command-r-plus" - defaultCohereMaxTokens = 4096 -) - -// CohereConfig holds configuration for the Cohere provider. -type CohereConfig struct { - APIKey string - Model string - BaseURL string - MaxTokens int - HTTPClient *http.Client -} - -// CohereProvider implements Provider using the Cohere Chat API v2. -type CohereProvider struct { - config CohereConfig -} - -// NewCohereProvider creates a new Cohere provider with the given config. 
-func NewCohereProvider(cfg CohereConfig) *CohereProvider { - if cfg.Model == "" { - cfg.Model = defaultCohereModel - } - if cfg.BaseURL == "" { - cfg.BaseURL = defaultCohereBaseURL - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultCohereMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - return &CohereProvider{config: cfg} -} - -func (p *CohereProvider) Name() string { return "cohere" } - -func (p *CohereProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "direct", - DisplayName: "Cohere (Direct API)", - Description: "Direct access to Cohere's Command models via API key.", - DocsURL: "https://docs.cohere.com/reference/chat", - ServerSafe: true, - } -} - -// Cohere Chat API v2 request types - -type cohereRequest struct { - Model string `json:"model"` - Messages []cohereMessage `json:"messages"` - Tools []cohereTool `json:"tools,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Stream bool `json:"stream,omitempty"` -} - -type cohereMessage struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` - ToolCalls []cohereToolCall `json:"tool_calls,omitempty"` -} - -type cohereToolCall struct { - ID string `json:"id"` - Type string `json:"type"` - Function cohereToolFunc `json:"function"` -} - -type cohereToolFunc struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type cohereTool struct { - Type string `json:"type"` - Function cohereToolDef `json:"function"` -} - -type cohereToolDef struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters map[string]any `json:"parameters"` -} - -// Cohere Chat API v2 response types - -type cohereResponse struct { - ID string `json:"id"` - Message cohereRespMsg `json:"message"` - FinishReason string `json:"finish_reason"` - Usage cohereUsage `json:"usage"` -} - -type cohereRespMsg struct { - Role string `json:"role"` - Content 
[]cohereContent `json:"content"` - ToolCalls []cohereToolCall `json:"tool_calls,omitempty"` -} - -type cohereContent struct { - Type string `json:"type"` - Text string `json:"text"` -} - -type cohereUsage struct { - BilledUnits cohereBilledUnits `json:"billed_units"` - Tokens cohereTokens `json:"tokens"` -} - -type cohereBilledUnits struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` -} - -type cohereTokens struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` -} - -func (p *CohereProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - reqBody := p.buildRequest(messages, tools, false) - - data, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("cohere: marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.config.BaseURL+"/v2/chat", bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("cohere: create request: %w", err) - } - p.setHeaders(req) - - resp, err := p.config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("cohere: send request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("cohere: read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("cohere: API error (status %d): %s", resp.StatusCode, string(body)) - } - - var apiResp cohereResponse - if err := json.Unmarshal(body, &apiResp); err != nil { - return nil, fmt.Errorf("cohere: unmarshal response: %w", err) - } - - return p.parseResponse(&apiResp), nil -} - -func (p *CohereProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - reqBody := p.buildRequest(messages, tools, true) - - data, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("cohere: marshal request: %w", err) - } - - req, err := 
http.NewRequestWithContext(ctx, http.MethodPost, p.config.BaseURL+"/v2/chat", bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("cohere: create request: %w", err) - } - p.setHeaders(req) - - resp, err := p.config.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("cohere: send request: %w", err) - } - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - _ = resp.Body.Close() - return nil, fmt.Errorf("cohere: API error (status %d): %s", resp.StatusCode, string(body)) - } - - ch := make(chan StreamEvent, 16) - go p.readSSE(resp.Body, ch) - return ch, nil -} - -func (p *CohereProvider) buildRequest(messages []Message, tools []ToolDef, stream bool) *cohereRequest { - req := &cohereRequest{ - Model: p.config.Model, - MaxTokens: p.config.MaxTokens, - Stream: stream, - } - - // Cohere v2 uses the same role names as OpenAI: system, user, assistant, tool - for _, msg := range messages { - switch msg.Role { - case RoleTool: - req.Messages = append(req.Messages, cohereMessage{ - Role: "tool", - Content: msg.Content, - ToolCallID: msg.ToolCallID, - }) - case RoleAssistant: - cm := cohereMessage{ - Role: "assistant", - Content: msg.Content, - } - for _, tc := range msg.ToolCalls { - args := "{}" - if tc.Arguments != nil { - if b, err := json.Marshal(tc.Arguments); err == nil { - args = string(b) - } - } - cm.ToolCalls = append(cm.ToolCalls, cohereToolCall{ - ID: tc.ID, - Type: "function", - Function: cohereToolFunc{Name: tc.Name, Arguments: args}, - }) - } - req.Messages = append(req.Messages, cm) - default: - req.Messages = append(req.Messages, cohereMessage{ - Role: string(msg.Role), - Content: msg.Content, - }) - } - } - - // Convert tools to Cohere format (same as OpenAI) - for _, t := range tools { - schema := t.Parameters - if schema == nil { - schema = map[string]any{"type": "object", "properties": map[string]any{}} - } - req.Tools = append(req.Tools, cohereTool{ - Type: "function", - Function: cohereToolDef{ - Name: 
t.Name, - Description: t.Description, - Parameters: schema, - }, - }) - } - - return req -} - -func (p *CohereProvider) setHeaders(req *http.Request) { - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+p.config.APIKey) -} - -func (p *CohereProvider) parseResponse(apiResp *cohereResponse) *Response { - resp := &Response{ - Usage: Usage{ - InputTokens: apiResp.Usage.Tokens.InputTokens, - OutputTokens: apiResp.Usage.Tokens.OutputTokens, - }, - } - - // Extract text content - var textParts []string - for _, c := range apiResp.Message.Content { - if c.Type == "text" { - textParts = append(textParts, c.Text) - } - } - resp.Content = strings.Join(textParts, "") - - // Extract tool calls - for _, tc := range apiResp.Message.ToolCalls { - var args map[string]any - if tc.Function.Arguments != "" { - _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) - } - resp.ToolCalls = append(resp.ToolCalls, ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: args, - }) - } - - return resp -} - -// Cohere v2 streaming types - -type cohereStreamEvent struct { - Type string `json:"type"` - Index int `json:"index"` - Delta *cohereStreamDelta `json:"delta,omitempty"` - Response *cohereResponse `json:"response,omitempty"` -} - -type cohereStreamDelta struct { - Message *cohereStreamDeltaMsg `json:"message,omitempty"` -} - -type cohereStreamDeltaMsg struct { - Content *cohereStreamContent `json:"content,omitempty"` - ToolCalls *cohereStreamToolDelta `json:"tool_calls,omitempty"` -} - -type cohereStreamContent struct { - Text string `json:"text"` -} - -type cohereStreamToolDelta struct { - Function *cohereStreamFuncDelta `json:"function,omitempty"` -} - -type cohereStreamFuncDelta struct { - Name string `json:"name,omitempty"` - Arguments string `json:"arguments,omitempty"` -} - -// readSSE parses the SSE stream from the Cohere Chat API v2. 
-func (p *CohereProvider) readSSE(body io.ReadCloser, ch chan<- StreamEvent) { - defer func() { _ = body.Close() }() - defer close(ch) - - scanner := bufio.NewScanner(body) - - // Track tool calls being assembled by index - type pendingToolCall struct { - id string - name string - argsBuf strings.Builder - } - pending := make(map[int]*pendingToolCall) - - var usage *Usage - - for scanner.Scan() { - line := scanner.Text() - - if !strings.HasPrefix(line, "data: ") { - continue - } - - data := strings.TrimPrefix(line, "data: ") - if data == "[DONE]" { - break - } - - var event cohereStreamEvent - if err := json.Unmarshal([]byte(data), &event); err != nil { - continue - } - - switch event.Type { - case "content-delta": - if event.Delta != nil && event.Delta.Message != nil && event.Delta.Message.Content != nil { - ch <- StreamEvent{Type: "text", Text: event.Delta.Message.Content.Text} - } - - case "tool-call-start": - ptc := &pendingToolCall{} - if event.Delta != nil && event.Delta.Message != nil && event.Delta.Message.ToolCalls != nil { - if event.Delta.Message.ToolCalls.Function != nil { - ptc.name = event.Delta.Message.ToolCalls.Function.Name - } - } - pending[event.Index] = ptc - - case "tool-call-delta": - ptc, exists := pending[event.Index] - if !exists { - ptc = &pendingToolCall{} - pending[event.Index] = ptc - } - if event.Delta != nil && event.Delta.Message != nil && event.Delta.Message.ToolCalls != nil { - if event.Delta.Message.ToolCalls.Function != nil { - ptc.argsBuf.WriteString(event.Delta.Message.ToolCalls.Function.Arguments) - } - } - - case "tool-call-end": - ptc, exists := pending[event.Index] - if !exists { - continue - } - var args map[string]any - if ptc.argsBuf.Len() > 0 { - _ = json.Unmarshal([]byte(ptc.argsBuf.String()), &args) - } - ch <- StreamEvent{ - Type: "tool_call", - Tool: &ToolCall{ - ID: ptc.id, - Name: ptc.name, - Arguments: args, - }, - } - delete(pending, event.Index) - - case "message-end": - if event.Response != nil { - usage = 
&Usage{ - InputTokens: event.Response.Usage.Tokens.InputTokens, - OutputTokens: event.Response.Usage.Tokens.OutputTokens, - } - } - ch <- StreamEvent{Type: "done", Usage: usage} - return - } - } - - // Flush any pending tool calls - for _, ptc := range pending { - var args map[string]any - if ptc.argsBuf.Len() > 0 { - _ = json.Unmarshal([]byte(ptc.argsBuf.String()), &args) - } - ch <- StreamEvent{ - Type: "tool_call", - Tool: &ToolCall{ - ID: ptc.id, - Name: ptc.name, - Arguments: args, - }, - } - } - - if err := scanner.Err(); err != nil { - ch <- StreamEvent{Type: "error", Error: err.Error()} - } -} diff --git a/provider/copilot.go b/provider/copilot.go deleted file mode 100644 index 28db0ab..0000000 --- a/provider/copilot.go +++ /dev/null @@ -1,178 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "sync" - "time" - - openaisdk "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/shared" -) - -const ( - defaultCopilotBaseURL = "https://api.githubcopilot.com" - defaultCopilotModel = "gpt-4o" - defaultCopilotMaxTokens = 4096 - copilotTokenExchangeURL = "https://api.github.com/copilot_internal/v2/token" - copilotEditorVersion = "ratchet/0.1.0" -) - -// CopilotConfig holds configuration for the GitHub Copilot provider. -type CopilotConfig struct { - Token string - Model string - BaseURL string - MaxTokens int - HTTPClient *http.Client -} - -// CopilotProvider implements Provider using the GitHub Copilot Chat API. -// The API follows the OpenAI Chat Completions format. -type CopilotProvider struct { - config CopilotConfig - mu sync.Mutex - bearerToken string - expiresAt time.Time -} - -// copilotTokenResponse is the response from the Copilot token exchange endpoint. -type copilotTokenResponse struct { - Token string `json:"token"` - ExpiresAt int64 `json:"expires_at"` -} - -// NewCopilotProvider creates a new Copilot provider with the given config. 
-func NewCopilotProvider(cfg CopilotConfig) *CopilotProvider { - if cfg.Model == "" { - cfg.Model = defaultCopilotModel - } - if cfg.BaseURL == "" { - cfg.BaseURL = defaultCopilotBaseURL - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultCopilotMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - return &CopilotProvider{config: cfg} -} - -func (p *CopilotProvider) Name() string { return "copilot" } - -func (p *CopilotProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "personal", - DisplayName: "GitHub Copilot (Personal/IDE)", - Description: "Uses GitHub Copilot's chat completions API via OAuth token exchange. Intended for IDE/CLI use under a Copilot Individual or Business subscription.", - Warning: "This mode uses Copilot's internal API intended for IDE integrations. Using it in server/service contexts may violate GitHub Copilot Terms of Service (https://docs.github.com/en/site-policy/github-terms/github-terms-for-additional-products-and-features).", - DocsURL: "https://docs.github.com/en/copilot", - ServerSafe: false, - } -} - -func (p *CopilotProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - if err := p.ensureBearerToken(ctx); err != nil { - return nil, err - } - p.mu.Lock() - token := p.bearerToken - p.mu.Unlock() - - client := p.newSDKClient(token) - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.Model), - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - resp, err := client.Chat.Completions.New(ctx, params) - if err != nil { - return nil, fmt.Errorf("copilot: %w", err) - } - return fromOpenAIResponse(resp) -} - -func (p *CopilotProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - if err := p.ensureBearerToken(ctx); err != nil { - return nil, err - } - 
p.mu.Lock() - token := p.bearerToken - p.mu.Unlock() - - client := p.newSDKClient(token) - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.Model), - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - stream := client.Chat.Completions.NewStreaming(ctx, params) - ch := make(chan StreamEvent, 16) - go streamOpenAIEvents(stream, ch) - return ch, nil -} - -// newSDKClient creates a per-request OpenAI SDK client with Copilot auth headers. -func (p *CopilotProvider) newSDKClient(bearerToken string) openaisdk.Client { - return openaisdk.NewClient( - option.WithAPIKey(bearerToken), - option.WithBaseURL(p.config.BaseURL), - option.WithHTTPClient(p.config.HTTPClient), - option.WithHeader("Copilot-Integration-Id", "vscode-chat"), - option.WithHeader("Editor-Version", "vscode/1.100.0"), - option.WithHeader("Editor-Plugin-Version", copilotEditorVersion), - ) -} - -// ensureBearerToken exchanges the GitHub OAuth token for a short-lived Copilot -// bearer token, caching it until 60 seconds before expiry. 
-func (p *CopilotProvider) ensureBearerToken(ctx context.Context) error { - p.mu.Lock() - defer p.mu.Unlock() - - if p.bearerToken != "" && time.Now().Before(p.expiresAt.Add(-60*time.Second)) { - return nil - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotTokenExchangeURL, nil) - if err != nil { - return fmt.Errorf("copilot: create token exchange request: %w", err) - } - req.Header.Set("Authorization", "Token "+p.config.Token) - req.Header.Set("Accept", "application/json") - - resp, err := p.config.HTTPClient.Do(req) - if err != nil { - return fmt.Errorf("copilot: token exchange request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("copilot: read token exchange response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("copilot: token exchange failed (status %d): %s", resp.StatusCode, truncate(string(body), 200)) - } - - var tokenResp copilotTokenResponse - if err := json.Unmarshal(body, &tokenResp); err != nil { - return fmt.Errorf("copilot: parse token exchange response: %w", err) - } - - p.bearerToken = tokenResp.Token - p.expiresAt = time.Unix(tokenResp.ExpiresAt, 0) - return nil -} diff --git a/provider/copilot_models.go b/provider/copilot_models.go deleted file mode 100644 index 7ee7387..0000000 --- a/provider/copilot_models.go +++ /dev/null @@ -1,55 +0,0 @@ -package provider - -const defaultCopilotModelsBaseURL = "https://models.github.ai/inference" - -// CopilotModelsConfig configures the GitHub Models provider. -// GitHub Models is a separate product from GitHub Copilot, available at models.github.ai. -// It uses fine-grained Personal Access Tokens with the models:read scope. -type CopilotModelsConfig struct { - // Token is a GitHub fine-grained PAT with models:read permission. - Token string - // Model is the model identifier (e.g. "openai/gpt-4o", "anthropic/claude-sonnet-4"). 
- Model string - // BaseURL overrides the default endpoint. Default: "https://models.github.ai/inference". - BaseURL string - // MaxTokens limits the response length. - MaxTokens int -} - -// CopilotModelsProvider uses GitHub Models (models.github.ai) for inference. -// It wraps OpenAIProvider since GitHub Models uses an OpenAI-compatible API. -type CopilotModelsProvider struct { - *OpenAIProvider -} - -func (p *CopilotModelsProvider) Name() string { return "copilot_models" } - -func (p *CopilotModelsProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "github_models", - DisplayName: "GitHub Models", - Description: "Access AI models via GitHub's Models marketplace using a fine-grained PAT with models:read scope.", - DocsURL: "https://docs.github.com/en/rest/models/inference", - ServerSafe: true, - } -} - -// NewCopilotModelsProvider creates a provider that uses GitHub Models for inference. -// GitHub Models provides access to various AI models via a fine-grained PAT. -// -// Docs: https://docs.github.com/en/rest/models/inference -// Billing: https://docs.github.com/billing/managing-billing-for-your-products/about-billing-for-github-models -func NewCopilotModelsProvider(cfg CopilotModelsConfig) *CopilotModelsProvider { - baseURL := cfg.BaseURL - if baseURL == "" { - baseURL = defaultCopilotModelsBaseURL - } - return &CopilotModelsProvider{ - OpenAIProvider: NewOpenAIProvider(OpenAIConfig{ - APIKey: cfg.Token, - Model: cfg.Model, - BaseURL: baseURL, - MaxTokens: cfg.MaxTokens, - }), - } -} diff --git a/provider/copilot_models_test.go b/provider/copilot_models_test.go deleted file mode 100644 index 66fc1de..0000000 --- a/provider/copilot_models_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" -) - -func TestCopilotModelsProvider_Name(t *testing.T) { - p := NewCopilotModelsProvider(CopilotModelsConfig{Token: "ghp_test"}) - if got := p.Name(); got != 
"copilot_models" { - t.Errorf("Name() = %q, want %q", got, "copilot_models") - } -} - -func TestCopilotModelsProvider_AuthModeInfo(t *testing.T) { - p := NewCopilotModelsProvider(CopilotModelsConfig{Token: "ghp_test"}) - info := p.AuthModeInfo() - - if info.Mode != "github_models" { - t.Errorf("AuthModeInfo().Mode = %q, want %q", info.Mode, "github_models") - } - if info.DisplayName != "GitHub Models" { - t.Errorf("AuthModeInfo().DisplayName = %q, want %q", info.DisplayName, "GitHub Models") - } - if info.DocsURL != "https://docs.github.com/en/rest/models/inference" { - t.Errorf("AuthModeInfo().DocsURL = %q, want GitHub Models docs URL", info.DocsURL) - } - if !info.ServerSafe { - t.Error("AuthModeInfo().ServerSafe = false, want true") - } -} - -func TestCopilotModelsProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*CopilotModelsProvider)(nil) -} - -func TestCopilotModelsProvider_DefaultBaseURL(t *testing.T) { - p := NewCopilotModelsProvider(CopilotModelsConfig{Token: "ghp_test"}) - if p.config.BaseURL != defaultCopilotModelsBaseURL { - t.Errorf("default BaseURL = %q, want %q", p.config.BaseURL, defaultCopilotModelsBaseURL) - } -} - -func TestCopilotModelsProvider_CustomBaseURL(t *testing.T) { - custom := "https://custom.models.github.ai/inference" - p := NewCopilotModelsProvider(CopilotModelsConfig{Token: "ghp_test", BaseURL: custom}) - if p.config.BaseURL != custom { - t.Errorf("BaseURL = %q, want %q", p.config.BaseURL, custom) - } -} - -func TestCopilotModelsProvider_Chat(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("Authorization"); got != "Bearer ghp_test123" { - t.Errorf("Authorization header = %q, want %q", got, "Bearer ghp_test123") - } - - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"chatcmpl-gh","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"hello from github 
models"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":8,"completion_tokens":4,"total_tokens":12},"created":1704067200}`) - })) - defer srv.Close() - - p := NewCopilotModelsProvider(CopilotModelsConfig{ - Token: "ghp_test123", - Model: "openai/gpt-4o", - BaseURL: srv.URL, - }) - - got, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if got.Content != "hello from github models" { - t.Errorf("Chat() content = %q, want %q", got.Content, "hello from github models") - } - if got.Usage.InputTokens != 8 || got.Usage.OutputTokens != 4 { - t.Errorf("Chat() usage = %+v, want input=8 output=4", got.Usage) - } -} - -func TestCopilotModelsProvider_Stream(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher, ok := w.(http.Flusher) - if !ok { - t.Fatal("expected http.Flusher") - } - - data, _ := json.Marshal(map[string]any{ - "id": "chatcmpl-gh-stream", - "object": "chat.completion.chunk", - "model": "gpt-4o", - "created": 1704067200, - "choices": []map[string]any{ - {"index": 0, "delta": map[string]any{"role": "assistant", "content": "gh-streamed"}, "finish_reason": nil}, - }, - }) - fmt.Fprintf(w, "data: %s\n\n", string(data)) - flusher.Flush() - - fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() - })) - defer srv.Close() - - p := NewCopilotModelsProvider(CopilotModelsConfig{ - Token: "ghp_test123", - Model: "openai/gpt-4o", - BaseURL: srv.URL, - }) - - ch, err := p.Stream(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Stream() error: %v", err) - } - - var texts []string - for ev := range ch { - switch ev.Type { - case "text": - texts = append(texts, ev.Text) - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if len(texts) == 0 { - t.Fatal("expected at least one text event") - 
} - if texts[0] != "gh-streamed" { - t.Errorf("first text event = %q, want %q", texts[0], "gh-streamed") - } -} diff --git a/provider/copilot_test.go b/provider/copilot_test.go deleted file mode 100644 index 7cca38d..0000000 --- a/provider/copilot_test.go +++ /dev/null @@ -1,297 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "sync/atomic" - "testing" - "time" -) - -// setupCopilotServers creates two test servers: one for token exchange and one for chat. -// Returns (tokenSrv, chatSrv, cleanup). -func setupCopilotServers(t *testing.T, tokenHandler, chatHandler http.HandlerFunc) (*httptest.Server, *httptest.Server) { - t.Helper() - tokenSrv := httptest.NewServer(tokenHandler) - chatSrv := httptest.NewServer(chatHandler) - t.Cleanup(func() { - tokenSrv.Close() - chatSrv.Close() - }) - return tokenSrv, chatSrv -} - -func validTokenHandler(expiresAt int64) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") == "" { - http.Error(w, "missing auth", http.StatusUnauthorized) - return - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(copilotTokenResponse{ - Token: "copilot-bearer-token", - ExpiresAt: expiresAt, - }) - } -} - -func validChatHandler(t *testing.T) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("Authorization"); got != "Bearer copilot-bearer-token" { - t.Errorf("chat Authorization = %q, want Bearer copilot-bearer-token", got) - } - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"chatcmpl-cop","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"hello from copilot"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8},"created":1704067200}`) - } -} - -func TestCopilotProvider_Name(t *testing.T) { - p := 
NewCopilotProvider(CopilotConfig{Token: "ghp_test"}) - if got := p.Name(); got != "copilot" { - t.Errorf("Name() = %q, want %q", got, "copilot") - } -} - -func TestCopilotProvider_AuthModeInfo(t *testing.T) { - p := NewCopilotProvider(CopilotConfig{Token: "ghp_test"}) - info := p.AuthModeInfo() - if info.Mode != "personal" { - t.Errorf("AuthModeInfo().Mode = %q, want %q", info.Mode, "personal") - } - if info.ServerSafe { - t.Error("AuthModeInfo().ServerSafe = true, want false") - } -} - -func TestCopilotProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*CopilotProvider)(nil) -} - -func TestCopilotProvider_Chat_HappyPath(t *testing.T) { - expiresAt := time.Now().Add(30 * time.Minute).Unix() - tokenSrv, chatSrv := setupCopilotServers(t, validTokenHandler(expiresAt), validChatHandler(t)) - - p := NewCopilotProvider(CopilotConfig{ - Token: "ghp_oauth_token", - Model: "gpt-4o", - BaseURL: chatSrv.URL, - }) - // Override the token exchange URL by patching the provider's internal call via an http.Client - // that redirects token exchange requests to our test server. 
- p.config.HTTPClient = redirectingClient(tokenSrv.URL) - - got, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hello"}, - }, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if got.Content != "hello from copilot" { - t.Errorf("Chat() content = %q, want %q", got.Content, "hello from copilot") - } - if got.Usage.InputTokens != 5 || got.Usage.OutputTokens != 3 { - t.Errorf("Chat() usage = %+v, want input=5 output=3", got.Usage) - } -} - -func TestCopilotProvider_TokenExchange_Error(t *testing.T) { - tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "unauthorized", http.StatusUnauthorized) - })) - defer tokenSrv.Close() - - chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Error("chat server should not be called when token exchange fails") - })) - defer chatSrv.Close() - - p := NewCopilotProvider(CopilotConfig{ - Token: "bad_token", - Model: "gpt-4o", - BaseURL: chatSrv.URL, - HTTPClient: redirectingClient(tokenSrv.URL), - }) - - _, err := p.Chat(t.Context(), []Message{{Role: RoleUser, Content: "hi"}}, nil) - if err == nil { - t.Fatal("expected error when token exchange returns 401") - } -} - -func TestCopilotProvider_TokenCacheHit(t *testing.T) { - var tokenCallCount atomic.Int32 - expiresAt := time.Now().Add(30 * time.Minute).Unix() - - tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tokenCallCount.Add(1) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(copilotTokenResponse{ - Token: "cached-bearer-token", - ExpiresAt: expiresAt, - }) - })) - defer tokenSrv.Close() - - chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, 
`{"id":"c1","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2},"created":1704067200}`) - })) - defer chatSrv.Close() - - p := NewCopilotProvider(CopilotConfig{ - Token: "ghp_test", - Model: "gpt-4o", - BaseURL: chatSrv.URL, - HTTPClient: redirectingClient(tokenSrv.URL), - }) - - // Two sequential calls — token exchange should only happen once. - for range 2 { - if _, err := p.Chat(t.Context(), []Message{{Role: RoleUser, Content: "hi"}}, nil); err != nil { - t.Fatalf("Chat() error: %v", err) - } - } - - if n := tokenCallCount.Load(); n != 1 { - t.Errorf("token exchange called %d times, want 1 (cache should be hit on second call)", n) - } -} - -func TestCopilotProvider_TokenRefresh_AfterExpiry(t *testing.T) { - var tokenCallCount atomic.Int32 - // Already-expired token. - pastExpiry := time.Now().Add(-5 * time.Minute).Unix() - - tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tokenCallCount.Add(1) - w.Header().Set("Content-Type", "application/json") - // Return a fresh token each time. 
- json.NewEncoder(w).Encode(copilotTokenResponse{ - Token: fmt.Sprintf("token-%d", tokenCallCount.Load()), - ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), - }) - })) - defer tokenSrv.Close() - - chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"c1","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2},"created":1704067200}`) - })) - defer chatSrv.Close() - - p := NewCopilotProvider(CopilotConfig{ - Token: "ghp_test", - Model: "gpt-4o", - BaseURL: chatSrv.URL, - HTTPClient: redirectingClient(tokenSrv.URL), - }) - - // Manually set an expired bearer token. - p.bearerToken = "expired-token" - p.expiresAt = time.Unix(pastExpiry, 0) - - if _, err := p.Chat(t.Context(), []Message{{Role: RoleUser, Content: "hi"}}, nil); err != nil { - t.Fatalf("Chat() error: %v", err) - } - - if n := tokenCallCount.Load(); n != 1 { - t.Errorf("token exchange called %d times, want 1 (should refresh expired token)", n) - } -} - -func TestCopilotProvider_Stream(t *testing.T) { - expiresAt := time.Now().Add(30 * time.Minute).Unix() - - tokenSrv := httptest.NewServer(validTokenHandler(expiresAt)) - defer tokenSrv.Close() - - chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher, ok := w.(http.Flusher) - if !ok { - t.Fatal("expected http.Flusher") - } - - data, _ := json.Marshal(map[string]any{ - "id": "chatcmpl-cop-stream", - "object": "chat.completion.chunk", - "model": "gpt-4o", - "created": 1704067200, - "choices": []map[string]any{ - {"index": 0, "delta": map[string]any{"role": "assistant", "content": "copilot-streamed"}, "finish_reason": nil}, - }, - }) - fmt.Fprintf(w, "data: %s\n\n", 
string(data)) - flusher.Flush() - - fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() - })) - defer chatSrv.Close() - - p := NewCopilotProvider(CopilotConfig{ - Token: "ghp_test", - Model: "gpt-4o", - BaseURL: chatSrv.URL, - HTTPClient: redirectingClient(tokenSrv.URL), - }) - - ch, err := p.Stream(t.Context(), []Message{{Role: RoleUser, Content: "hi"}}, nil) - if err != nil { - t.Fatalf("Stream() error: %v", err) - } - - var texts []string - for ev := range ch { - switch ev.Type { - case "text": - texts = append(texts, ev.Text) - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if len(texts) == 0 { - t.Fatal("expected at least one text event") - } - if texts[0] != "copilot-streamed" { - t.Errorf("first text event = %q, want %q", texts[0], "copilot-streamed") - } -} - -// redirectingClient returns an http.Client whose transport redirects all requests -// with the copilotTokenExchangeURL path to tokenBaseURL, leaving others unchanged. -func redirectingClient(tokenBaseURL string) *http.Client { - return &http.Client{ - Transport: &tokenRedirectTransport{ - tokenBaseURL: tokenBaseURL, - inner: http.DefaultTransport, - }, - } -} - -type tokenRedirectTransport struct { - tokenBaseURL string - inner http.RoundTripper -} - -func (t *tokenRedirectTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Redirect token exchange requests to our test server. - if req.URL.String() == copilotTokenExchangeURL { - redirected := req.Clone(req.Context()) - redirected.URL.Scheme = "http" - redirected.URL.Host = req.URL.Host - // Parse tokenBaseURL and override host. - srv := t.tokenBaseURL - // Strip scheme. 
- if len(srv) > 7 && srv[:7] == "http://" { - srv = srv[7:] - } - redirected.URL.Host = srv - redirected.URL.Path = req.URL.Path - redirected.Host = srv - return t.inner.RoundTrip(redirected) - } - return t.inner.RoundTrip(req) -} diff --git a/provider/gemini.go b/provider/gemini.go deleted file mode 100644 index 08f562c..0000000 --- a/provider/gemini.go +++ /dev/null @@ -1,325 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - - "github.com/google/generative-ai-go/genai" - "google.golang.org/api/iterator" - googleoption "google.golang.org/api/option" -) - -const ( - defaultGeminiModel = "gemini-2.0-flash" - defaultGeminiMaxTokens = 4096 -) - -// GeminiConfig holds configuration for the Google Gemini provider. -type GeminiConfig struct { - APIKey string - Model string - MaxTokens int - HTTPClient *http.Client -} - -// GeminiProvider implements Provider using the Google Gemini API. -type GeminiProvider struct { - config GeminiConfig -} - -// NewGeminiProvider creates a new Gemini provider. Returns an error if no API key is provided. 
-func NewGeminiProvider(cfg GeminiConfig) (*GeminiProvider, error) { - if cfg.APIKey == "" { - return nil, fmt.Errorf("gemini: APIKey is required") - } - if cfg.Model == "" { - cfg.Model = defaultGeminiModel - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultGeminiMaxTokens - } - return &GeminiProvider{config: cfg}, nil -} - -func (p *GeminiProvider) Name() string { return "gemini" } - -func (p *GeminiProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "gemini", - DisplayName: "Google Gemini", - Description: "Uses the Google Gemini API with an API key from Google AI Studio.", - DocsURL: "https://ai.google.dev/gemini-api/docs/api-key", - ServerSafe: true, - } -} - -func (p *GeminiProvider) newGenaiClient(ctx context.Context) (*genai.Client, error) { - opts := []googleoption.ClientOption{ - googleoption.WithAPIKey(p.config.APIKey), - } - if p.config.HTTPClient != nil { - opts = append(opts, googleoption.WithHTTPClient(p.config.HTTPClient)) - } - return genai.NewClient(ctx, opts...) -} - -func (p *GeminiProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - client, err := p.newGenaiClient(ctx) - if err != nil { - return nil, fmt.Errorf("gemini: create client: %w", err) - } - defer client.Close() - - model := client.GenerativeModel(p.config.Model) - maxOut := int32(p.config.MaxTokens) - model.MaxOutputTokens = &maxOut - if len(tools) > 0 { - model.Tools = toGeminiTools(tools) - } - - contents, systemInstruction := toGeminiContents(messages) - if systemInstruction != nil { - model.SystemInstruction = systemInstruction - } - - resp, err := model.GenerateContent(ctx, contents...) 
- if err != nil { - return nil, fmt.Errorf("gemini: %w", err) - } - return fromGeminiResponse(resp), nil -} - -func (p *GeminiProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - client, err := p.newGenaiClient(ctx) - if err != nil { - return nil, fmt.Errorf("gemini: create client: %w", err) - } - - model := client.GenerativeModel(p.config.Model) - maxOut := int32(p.config.MaxTokens) - model.MaxOutputTokens = &maxOut - if len(tools) > 0 { - model.Tools = toGeminiTools(tools) - } - - contents, systemInstruction := toGeminiContents(messages) - if systemInstruction != nil { - model.SystemInstruction = systemInstruction - } - - iter := model.GenerateContentStream(ctx, contents...) - ch := make(chan StreamEvent, 16) - go func() { - defer close(ch) - defer client.Close() - send := func(ev StreamEvent) bool { - select { - case ch <- ev: - return true - case <-ctx.Done(): - return false - } - } - for { - resp, err := iter.Next() - if err == iterator.Done { - send(StreamEvent{Type: "done"}) - return - } - if err != nil { - send(StreamEvent{Type: "error", Error: err.Error()}) - return - } - for _, cand := range resp.Candidates { - if cand.Content == nil { - continue - } - for _, part := range cand.Content.Parts { - switch v := part.(type) { - case genai.Text: - if v != "" { - if !send(StreamEvent{Type: "text", Text: string(v)}) { - return - } - } - case genai.FunctionCall: - if !send(StreamEvent{Type: "tool_call", Tool: &ToolCall{ - ID: v.Name, - Name: v.Name, - Arguments: v.Args, - }}) { - return - } - } - } - } - } - }() - return ch, nil -} - -// toGeminiContents converts provider messages to Gemini content parts. -// System messages are returned separately as they set model.SystemInstruction. 
-func toGeminiContents(messages []Message) ([]genai.Part, *genai.Content) { - var parts []genai.Part - var systemInstruction *genai.Content - - for _, msg := range messages { - switch msg.Role { - case RoleSystem: - systemInstruction = &genai.Content{ - Parts: []genai.Part{genai.Text(msg.Content)}, - Role: "user", - } - case RoleUser: - parts = append(parts, genai.Text(msg.Content)) - case RoleAssistant: - parts = append(parts, genai.Text(msg.Content)) - case RoleTool: - // Tool results are passed as FunctionResponse parts. - var args map[string]any - _ = json.Unmarshal([]byte(msg.Content), &args) - if args == nil { - args = map[string]any{"result": msg.Content} - } - parts = append(parts, genai.FunctionResponse{ - Name: msg.ToolCallID, - Response: args, - }) - } - } - return parts, systemInstruction -} - -// toGeminiTools converts provider tool definitions to Gemini Tool structs. -func toGeminiTools(tools []ToolDef) []*genai.Tool { - decls := make([]*genai.FunctionDeclaration, 0, len(tools)) - for _, t := range tools { - decls = append(decls, &genai.FunctionDeclaration{ - Name: t.Name, - Description: t.Description, - Parameters: toGeminiSchema(t.Parameters), - }) - } - return []*genai.Tool{{FunctionDeclarations: decls}} -} - -// toGeminiSchema converts a JSON Schema map to a genai.Schema. 
-func toGeminiSchema(params map[string]any) *genai.Schema { - if params == nil { - return nil - } - schema := &genai.Schema{Type: genai.TypeObject} - if props, ok := params["properties"].(map[string]any); ok { - schema.Properties = make(map[string]*genai.Schema, len(props)) - for name, val := range props { - if propMap, ok := val.(map[string]any); ok { - schema.Properties[name] = toGeminiSchemaProp(propMap) - } - } - } - if req, ok := params["required"].([]any); ok { - for _, r := range req { - if s, ok := r.(string); ok { - schema.Required = append(schema.Required, s) - } - } - } - return schema -} - -func toGeminiSchemaProp(m map[string]any) *genai.Schema { - s := &genai.Schema{} - if t, ok := m["type"].(string); ok { - switch t { - case "string": - s.Type = genai.TypeString - case "number": - s.Type = genai.TypeNumber - case "integer": - s.Type = genai.TypeInteger - case "boolean": - s.Type = genai.TypeBoolean - case "array": - s.Type = genai.TypeArray - case "object": - s.Type = genai.TypeObject - } - } - if desc, ok := m["description"].(string); ok { - s.Description = desc - } - return s -} - -// fromGeminiResponse extracts a provider Response from a Gemini GenerateContentResponse. -func fromGeminiResponse(resp *genai.GenerateContentResponse) *Response { - result := &Response{} - if resp.UsageMetadata != nil { - result.Usage = Usage{ - InputTokens: int(resp.UsageMetadata.PromptTokenCount), - OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount), - } - } - if len(resp.Candidates) == 0 { - return result - } - cand := resp.Candidates[0] - if cand.Content == nil { - return result - } - for _, part := range cand.Content.Parts { - switch v := part.(type) { - case genai.Text: - result.Content += string(v) - case genai.FunctionCall: - result.ToolCalls = append(result.ToolCalls, ToolCall{ - ID: v.Name, - Name: v.Name, - Arguments: v.Args, - }) - } - } - return result -} - -// listGeminiModels lists available Gemini models using the genai SDK. 
-func listGeminiModels(ctx context.Context, apiKey string) ([]ModelInfo, error) { - client, err := genai.NewClient(ctx, googleoption.WithAPIKey(apiKey)) - if err != nil { - return geminiFallbackModels(), nil - } - defer client.Close() - - iter := client.ListModels(ctx) - var models []ModelInfo - for { - m, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return geminiFallbackModels(), nil - } - models = append(models, ModelInfo{ - ID: m.Name, - Name: m.DisplayName, - }) - } - if len(models) == 0 { - return geminiFallbackModels(), nil - } - return models, nil -} - -func geminiFallbackModels() []ModelInfo { - return []ModelInfo{ - {ID: "gemini-2.5-pro-preview-03-25", Name: "Gemini 2.5 Pro Preview"}, - {ID: "gemini-2.0-flash", Name: "Gemini 2.0 Flash"}, - {ID: "gemini-2.0-flash-lite", Name: "Gemini 2.0 Flash-Lite"}, - {ID: "gemini-1.5-pro", Name: "Gemini 1.5 Pro"}, - {ID: "gemini-1.5-flash", Name: "Gemini 1.5 Flash"}, - } -} diff --git a/provider/gemini_test.go b/provider/gemini_test.go deleted file mode 100644 index dff04ff..0000000 --- a/provider/gemini_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package provider - -import ( - "testing" -) - -func TestGeminiProvider_Name(t *testing.T) { - p, err := NewGeminiProvider(GeminiConfig{APIKey: "test-key"}) - if err != nil { - t.Fatal(err) - } - if got := p.Name(); got != "gemini" { - t.Errorf("Name() = %q, want %q", got, "gemini") - } -} - -func TestGeminiProvider_AuthModeInfo(t *testing.T) { - p, err := NewGeminiProvider(GeminiConfig{APIKey: "test-key"}) - if err != nil { - t.Fatal(err) - } - info := p.AuthModeInfo() - - if info.Mode != "gemini" { - t.Errorf("AuthModeInfo().Mode = %q, want %q", info.Mode, "gemini") - } - if info.DisplayName != "Google Gemini" { - t.Errorf("AuthModeInfo().DisplayName = %q, want %q", info.DisplayName, "Google Gemini") - } - if info.DocsURL != "https://ai.google.dev/gemini-api/docs/api-key" { - t.Errorf("AuthModeInfo().DocsURL = %q, want Gemini docs URL", 
info.DocsURL) - } - if !info.ServerSafe { - t.Error("AuthModeInfo().ServerSafe = false, want true") - } -} - -func TestGeminiProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*GeminiProvider)(nil) -} - -func TestGeminiProvider_RequiresAPIKey(t *testing.T) { - _, err := NewGeminiProvider(GeminiConfig{}) - if err == nil { - t.Fatal("expected error when no API key provided") - } -} - -func TestGeminiProvider_DefaultModel(t *testing.T) { - p, err := NewGeminiProvider(GeminiConfig{APIKey: "test-key"}) - if err != nil { - t.Fatal(err) - } - if p.config.Model != defaultGeminiModel { - t.Errorf("default Model = %q, want %q", p.config.Model, defaultGeminiModel) - } -} - -func TestGeminiProvider_DefaultMaxTokens(t *testing.T) { - p, err := NewGeminiProvider(GeminiConfig{APIKey: "test-key"}) - if err != nil { - t.Fatal(err) - } - if p.config.MaxTokens != defaultGeminiMaxTokens { - t.Errorf("default MaxTokens = %d, want %d", p.config.MaxTokens, defaultGeminiMaxTokens) - } -} - -func TestToGeminiSchema_NilParams(t *testing.T) { - s := toGeminiSchema(nil) - if s != nil { - t.Errorf("toGeminiSchema(nil) = %v, want nil", s) - } -} - -func TestToGeminiSchema_WithProperties(t *testing.T) { - params := map[string]any{ - "type": "object", - "properties": map[string]any{ - "name": map[string]any{"type": "string", "description": "A name"}, - "age": map[string]any{"type": "integer"}, - }, - "required": []any{"name"}, - } - s := toGeminiSchema(params) - if s == nil { - t.Fatal("toGeminiSchema returned nil for valid params") - } - if len(s.Properties) != 2 { - t.Errorf("Properties count = %d, want 2", len(s.Properties)) - } - if len(s.Required) != 1 || s.Required[0] != "name" { - t.Errorf("Required = %v, want [name]", s.Required) - } -} - -func TestGeminiFallbackModels(t *testing.T) { - models := geminiFallbackModels() - if len(models) == 0 { - t.Error("geminiFallbackModels() returned empty slice") - } - for _, m := range models { - if m.ID == "" { - t.Error("fallback model 
has empty ID") - } - if m.Name == "" { - t.Error("fallback model has empty Name") - } - } -} diff --git a/provider/llama_cpp.go b/provider/llama_cpp.go deleted file mode 100644 index 357d7ed..0000000 --- a/provider/llama_cpp.go +++ /dev/null @@ -1,231 +0,0 @@ -package provider - -import ( - "context" - "fmt" - "net/http" - "os/exec" - "runtime" - "time" - - openaisdk "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/shared" -) - -const ( - defaultLlamaCppPort = 8081 - defaultLlamaCppGPULayers = -1 - defaultLlamaCppContextSize = 8192 - defaultLlamaCppMaxTokens = 8192 -) - -// LlamaCppConfig holds configuration for the LlamaCpp provider. -// Set BaseURL for external mode (any OpenAI-compatible server). -// Set ModelPath for managed mode (provider starts llama-server). -type LlamaCppConfig struct { - BaseURL string // external mode: OpenAI-compatible server URL - ModelPath string // managed mode: path to .gguf model file - ModelName string // model name sent to server (external mode); defaults to "local" - BinaryPath string // override llama-server binary location - GPULayers int // -ngl flag; 0 → default -1 (all layers) - ContextSize int // -c flag; default 8192 - Threads int // -t flag; default runtime.NumCPU() - Port int // server port; default 8081 - MaxTokens int // default 8192 - DisableThinking bool // send chat_template_kwargs.enable_thinking=false for reasoning models - HTTPClient *http.Client -} - -// LlamaCppProvider implements Provider using an OpenAI-compatible llama-server. -type LlamaCppProvider struct { - client openaisdk.Client - config LlamaCppConfig - cmd *exec.Cmd // non-nil in managed mode after server start -} - -// NewLlamaCppProvider creates a LlamaCppProvider with the given config. -// In external mode (BaseURL set), it points at the given URL. -// In managed mode (ModelPath set), call ensureServer before use. 
-func NewLlamaCppProvider(cfg LlamaCppConfig) *LlamaCppProvider { - if cfg.GPULayers == 0 { - cfg.GPULayers = defaultLlamaCppGPULayers - } - if cfg.ContextSize <= 0 { - cfg.ContextSize = defaultLlamaCppContextSize - } - if cfg.Threads <= 0 { - cfg.Threads = runtime.NumCPU() - } - if cfg.Port <= 0 { - cfg.Port = defaultLlamaCppPort - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultLlamaCppMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - - baseURL := cfg.BaseURL - if baseURL == "" { - baseURL = fmt.Sprintf("http://localhost:%d/v1", cfg.Port) - } - - client := openaisdk.NewClient( - option.WithAPIKey("no-key"), - option.WithBaseURL(baseURL), - option.WithHTTPClient(cfg.HTTPClient), - ) - return &LlamaCppProvider{client: client, config: cfg} -} - -func (p *LlamaCppProvider) Name() string { return "llama_cpp" } - -// modelName returns the model identifier to send to the server. -func (p *LlamaCppProvider) modelName() string { - if p.config.ModelName != "" { - return p.config.ModelName - } - return "local" -} - -func (p *LlamaCppProvider) AuthModeInfo() AuthModeInfo { - return LocalAuthMode("llama_cpp", "llama.cpp (Local)") -} - -// Chat sends a non-streaming request and applies ParseThinking to the response. -func (p *LlamaCppProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - params := openaisdk.ChatCompletionNewParams{ - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - params.Model = shared.ChatModel(p.modelName()) - - var opts []option.RequestOption - if p.config.DisableThinking { - opts = append(opts, option.WithJSONSet("chat_template_kwargs", map[string]any{"enable_thinking": false})) - } - - resp, err := p.client.Chat.Completions.New(ctx, params, opts...) 
- if err != nil { - return nil, fmt.Errorf("llama_cpp: %w", err) - } - result, err := fromOpenAIResponse(resp) - if err != nil { - return nil, err - } - result.Thinking, result.Content = ParseThinking(result.Content) - return result, nil -} - -// Stream sends a streaming request, applying ThinkingStreamParser to text events. -func (p *LlamaCppProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - params := openaisdk.ChatCompletionNewParams{ - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - params.Model = shared.ChatModel(p.modelName()) - - stream := p.client.Chat.Completions.NewStreaming(ctx, params) - rawCh := make(chan StreamEvent, 16) - go streamOpenAIEvents(stream, rawCh) - - outCh := make(chan StreamEvent, 16) - go func() { - defer close(outCh) - parser := &ThinkingStreamParser{} - for event := range rawCh { - if event.Type == "text" { - for _, e := range parser.Feed(event.Text) { - outCh <- e - } - } else { - outCh <- event - } - } - }() - return outCh, nil -} - -// EnsureServer starts the managed llama-server if ModelPath is configured. -// No-op in external mode. Must be called before Chat/Stream in managed mode. 
-func (p *LlamaCppProvider) EnsureServer(ctx context.Context) error { - if p.config.ModelPath == "" { - return nil // external mode, nothing to start - } - return p.ensureServer(ctx) -} - -func (p *LlamaCppProvider) ensureServer(ctx context.Context) error { - binPath := p.config.BinaryPath - if binPath == "" { - if path, err := exec.LookPath("llama-server"); err == nil { - binPath = path - } else { - var dlErr error - binPath, dlErr = EnsureLlamaServer(ctx) - if dlErr != nil { - return fmt.Errorf("llama_cpp: find llama-server: %w", dlErr) - } - } - } - - args := []string{ - "--model", p.config.ModelPath, - "--port", fmt.Sprintf("%d", p.config.Port), - "-ngl", fmt.Sprintf("%d", p.config.GPULayers), - "-c", fmt.Sprintf("%d", p.config.ContextSize), - "-t", fmt.Sprintf("%d", p.config.Threads), - } - - // #nosec G204 -- BinaryPath/binPath comes from config or PATH lookup, not user input. - cmd := exec.Command(binPath, args...) - if err := cmd.Start(); err != nil { - return fmt.Errorf("llama_cpp: start llama-server: %w", err) - } - p.cmd = cmd - - healthURL := fmt.Sprintf("http://localhost:%d/health", p.config.Port) - deadline := time.Now().Add(2 * time.Minute) - ticker := time.NewTicker(500 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - _ = p.cmd.Process.Kill() - _ = p.cmd.Wait() - return ctx.Err() - case <-ticker.C: - if time.Now().After(deadline) { - _ = p.cmd.Process.Kill() - _ = p.cmd.Wait() - return fmt.Errorf("llama_cpp: llama-server did not become healthy within timeout") - } - hReq, _ := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil) - resp, err := p.config.HTTPClient.Do(hReq) - if err == nil { - _ = resp.Body.Close() - if resp.StatusCode == http.StatusOK { - return nil - } - } - } - } -} - -// Close kills the managed llama-server process if one was started -// and waits for it to exit to avoid zombie processes. 
-func (p *LlamaCppProvider) Close() error { - if p.cmd != nil && p.cmd.Process != nil { - _ = p.cmd.Process.Kill() - _ = p.cmd.Wait() // reap the child process - } - return nil -} diff --git a/provider/llama_cpp_test.go b/provider/llama_cpp_test.go deleted file mode 100644 index 7d50889..0000000 --- a/provider/llama_cpp_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" -) - -// newLlamaCppTestServer starts an httptest server returning OpenAI-compatible responses. -// If thinking is non-empty, the content wraps it in tags. -func newLlamaCppTestServer(t *testing.T, thinking, content string) *httptest.Server { - t.Helper() - rawContent := content - if thinking != "" { - rawContent = fmt.Sprintf("%s%s", thinking, content) - } - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/v1/chat/completions" { - w.Header().Set("Content-Type", "application/json") - resp := map[string]any{ - "id": "chatcmpl-test", - "object": "chat.completion", - "choices": []any{map[string]any{"message": map[string]any{"role": "assistant", "content": rawContent}, "finish_reason": "stop"}}, - "usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 5}, - } - _ = json.NewEncoder(w).Encode(resp) - return - } - http.NotFound(w, r) - })) -} - -// newLlamaCppStreamServer starts an httptest server returning SSE streaming responses. 
-func newLlamaCppStreamServer(t *testing.T, chunks []string) *httptest.Server { - t.Helper() - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/v1/chat/completions" { - w.Header().Set("Content-Type", "text/event-stream") - for _, chunk := range chunks { - delta := map[string]any{ - "id": "chatcmpl-test", - "object": "chat.completion.chunk", - "choices": []any{map[string]any{"delta": map[string]any{"content": chunk}, "finish_reason": nil}}, - } - b, _ := json.Marshal(delta) - fmt.Fprintf(w, "data: %s\n\n", b) - } - fmt.Fprintf(w, "data: [DONE]\n\n") - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - return - } - http.NotFound(w, r) - })) -} - -func TestLlamaCppProvider_Defaults(t *testing.T) { - p := NewLlamaCppProvider(LlamaCppConfig{BaseURL: "http://localhost:8081/v1"}) - if p.config.GPULayers != defaultLlamaCppGPULayers { - t.Errorf("GPULayers: want %d, got %d", defaultLlamaCppGPULayers, p.config.GPULayers) - } - if p.config.ContextSize != defaultLlamaCppContextSize { - t.Errorf("ContextSize: want %d, got %d", defaultLlamaCppContextSize, p.config.ContextSize) - } - if p.config.MaxTokens != defaultLlamaCppMaxTokens { - t.Errorf("MaxTokens: want %d, got %d", defaultLlamaCppMaxTokens, p.config.MaxTokens) - } - if p.config.Port != defaultLlamaCppPort { - t.Errorf("Port: want %d, got %d", defaultLlamaCppPort, p.config.Port) - } - if p.Name() != "llama_cpp" { - t.Errorf("Name: want %q, got %q", "llama_cpp", p.Name()) - } -} - -func TestLlamaCppProvider_Chat_NoThinking(t *testing.T) { - srv := newLlamaCppTestServer(t, "", "The sky is blue.") - defer srv.Close() - - p := NewLlamaCppProvider(LlamaCppConfig{ - BaseURL: srv.URL + "/v1", - HTTPClient: srv.Client(), - }) - - resp, err := p.Chat(context.Background(), []Message{{Role: RoleUser, Content: "What color is the sky?"}}, nil) - if err != nil { - t.Fatalf("Chat: %v", err) - } - if resp.Content != "The sky is blue." 
{ - t.Errorf("Content: want %q, got %q", "The sky is blue.", resp.Content) - } - if resp.Thinking != "" { - t.Errorf("Thinking: want empty, got %q", resp.Thinking) - } -} - -func TestLlamaCppProvider_Chat_WithThinking(t *testing.T) { - srv := newLlamaCppTestServer(t, "Let me reason step by step.", "The answer is 42.") - defer srv.Close() - - p := NewLlamaCppProvider(LlamaCppConfig{ - BaseURL: srv.URL + "/v1", - HTTPClient: srv.Client(), - }) - - resp, err := p.Chat(context.Background(), []Message{{Role: RoleUser, Content: "What is the answer?"}}, nil) - if err != nil { - t.Fatalf("Chat: %v", err) - } - if resp.Thinking != "Let me reason step by step." { - t.Errorf("Thinking: want %q, got %q", "Let me reason step by step.", resp.Thinking) - } - if resp.Content != "The answer is 42." { - t.Errorf("Content: want %q, got %q", "The answer is 42.", resp.Content) - } -} - -func TestLlamaCppProvider_Stream_WithThinking(t *testing.T) { - // Stream chunks that spell out reasoninganswer - chunks := []string{"", "reasoning", "", "answer"} - srv := newLlamaCppStreamServer(t, chunks) - defer srv.Close() - - p := NewLlamaCppProvider(LlamaCppConfig{ - BaseURL: srv.URL + "/v1", - HTTPClient: srv.Client(), - }) - - ch, err := p.Stream(context.Background(), []Message{{Role: RoleUser, Content: "question"}}, nil) - if err != nil { - t.Fatalf("Stream: %v", err) - } - - var thinkingText, contentText string - for event := range ch { - switch event.Type { - case "thinking": - thinkingText += event.Thinking - case "text": - contentText += event.Text - case "error": - t.Fatalf("stream error: %s", event.Error) - } - } - if thinkingText != "reasoning" { - t.Errorf("thinking: want %q, got %q", "reasoning", thinkingText) - } - if contentText != "answer" { - t.Errorf("content: want %q, got %q", "answer", contentText) - } -} - -func TestLlamaCppProvider_Close_NoProcess(t *testing.T) { - p := NewLlamaCppProvider(LlamaCppConfig{BaseURL: "http://localhost:8081/v1"}) - if err := p.Close(); err != nil 
{ - t.Errorf("Close with no process: want nil, got %v", err) - } -} - -func TestLlamaCppProvider_EnsureServer_ExternalMode(t *testing.T) { - p := NewLlamaCppProvider(LlamaCppConfig{BaseURL: "http://localhost:8081/v1"}) - // External mode: EnsureServer should be a no-op. - if err := p.EnsureServer(context.Background()); err != nil { - t.Errorf("EnsureServer external mode: want nil, got %v", err) - } -} diff --git a/provider/models.go b/provider/models.go index a5106f8..14bf5cb 100644 --- a/provider/models.go +++ b/provider/models.go @@ -8,6 +8,10 @@ import ( "net/http" "sort" "strings" + + "github.com/google/generative-ai-go/genai" + "google.golang.org/api/iterator" + googleoption "google.golang.org/api/option" ) // ModelInfo describes an available model from a provider. @@ -17,6 +21,23 @@ type ModelInfo struct { ContextWindow int `json:"context_window,omitempty"` } +// Constants used by model-listing functions, sourced from former provider files. +const ( + defaultAnthropicBaseURL = "https://api.anthropic.com" + anthropicAPIVersion = "2023-06-01" + defaultOpenAIBaseURL = "https://api.openai.com" + defaultCopilotBaseURL = "https://api.githubcopilot.com" + copilotTokenExchangeURL = "https://api.github.com/copilot_internal/v2/token" + copilotEditorVersion = "ratchet/0.1.0" + defaultCohereBaseURL = "https://api.cohere.com" +) + +// copilotTokenResponse is the response from the Copilot token exchange endpoint. +type copilotTokenResponse struct { + Token string `json:"token"` + ExpiresAt int64 `json:"expires_at"` +} + // ListModels fetches available models from the given provider type. // Only requires an API key and optional base URL — no saved provider needed. func ListModels(ctx context.Context, providerType, apiKey, baseURL string) ([]ModelInfo, error) { @@ -429,10 +450,49 @@ func foundryFallbackModels() []ModelInfo { } } +// listGeminiModels lists available Gemini models using the genai SDK. 
+func listGeminiModels(ctx context.Context, apiKey string) ([]ModelInfo, error) { + client, err := genai.NewClient(ctx, googleoption.WithAPIKey(apiKey)) + if err != nil { + return geminiFallbackModels(), nil + } + defer client.Close() + + iter := client.ListModels(ctx) + var models []ModelInfo + for { + m, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return geminiFallbackModels(), nil + } + models = append(models, ModelInfo{ + ID: m.Name, + Name: m.DisplayName, + }) + } + if len(models) == 0 { + return geminiFallbackModels(), nil + } + return models, nil +} + +func geminiFallbackModels() []ModelInfo { + return []ModelInfo{ + {ID: "gemini-2.5-pro-preview-03-25", Name: "Gemini 2.5 Pro Preview"}, + {ID: "gemini-2.0-flash", Name: "Gemini 2.0 Flash"}, + {ID: "gemini-2.0-flash-lite", Name: "Gemini 2.0 Flash-Lite"}, + {ID: "gemini-1.5-pro", Name: "Gemini 1.5 Pro"}, + {ID: "gemini-1.5-flash", Name: "Gemini 1.5 Flash"}, + } +} + // listOllamaModels lists models available on a local Ollama server. func listOllamaModels(ctx context.Context, baseURL string) ([]ModelInfo, error) { - p := NewOllamaProvider(OllamaConfig{BaseURL: baseURL}) - return p.ListModels(ctx) + c := NewOllamaClient(baseURL) + return c.ListModels(ctx) } func truncate(s string, maxLen int) string { diff --git a/provider/ollama.go b/provider/ollama.go deleted file mode 100644 index 633b5bb..0000000 --- a/provider/ollama.go +++ /dev/null @@ -1,163 +0,0 @@ -package provider - -import ( - "context" - "fmt" - "net/http" - "net/url" - - ollamaapi "github.com/ollama/ollama/api" -) - -const defaultOllamaBaseURL = "http://localhost:11434" - -// OllamaConfig holds configuration for the Ollama provider. -type OllamaConfig struct { - Model string - BaseURL string - MaxTokens int - HTTPClient *http.Client -} - -// OllamaProvider implements Provider using a local Ollama server. 
-type OllamaProvider struct { - client *ollamaapi.Client - config OllamaConfig -} - -// NewOllamaProvider creates a new OllamaProvider with the given config. -func NewOllamaProvider(cfg OllamaConfig) *OllamaProvider { - if cfg.BaseURL == "" { - cfg.BaseURL = defaultOllamaBaseURL - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - base, err := url.Parse(cfg.BaseURL) - if err != nil { - // Fall back to default URL if configured one is invalid. - base, _ = url.Parse(defaultOllamaBaseURL) - cfg.BaseURL = defaultOllamaBaseURL - } - return &OllamaProvider{ - client: ollamaapi.NewClient(base, cfg.HTTPClient), - config: cfg, - } -} - -func (p *OllamaProvider) Name() string { return "ollama" } - -func (p *OllamaProvider) AuthModeInfo() AuthModeInfo { - return LocalAuthMode("ollama", "Ollama (Local)") -} - -// Chat sends a non-streaming request and returns the complete response. -func (p *OllamaProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - streamFalse := false - req := &ollamaapi.ChatRequest{ - Model: p.config.Model, - Messages: toOllamaMessages(messages), - Stream: &streamFalse, - Options: map[string]any{}, - } - if len(tools) > 0 { - req.Tools = toOllamaTools(tools) - } - if p.config.MaxTokens > 0 { - req.Options["num_predict"] = p.config.MaxTokens - } - - var final ollamaapi.ChatResponse - if err := p.client.Chat(ctx, req, func(resp ollamaapi.ChatResponse) error { - final = resp - return nil - }); err != nil { - return nil, fmt.Errorf("ollama: %w", err) - } - return fromOllamaResponse(final), nil -} - -// Stream sends a streaming request and emits events on the returned channel. -// Thinking tokens extracted via ThinkingStreamParser are emitted as "thinking" events. 
-func (p *OllamaProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - req := &ollamaapi.ChatRequest{ - Model: p.config.Model, - Messages: toOllamaMessages(messages), - Options: map[string]any{}, - } - if len(tools) > 0 { - req.Tools = toOllamaTools(tools) - } - if p.config.MaxTokens > 0 { - req.Options["num_predict"] = p.config.MaxTokens - } - - ch := make(chan StreamEvent, 16) - go func() { - defer close(ch) - var parser ThinkingStreamParser - var usage *Usage - - err := p.client.Chat(ctx, req, func(resp ollamaapi.ChatResponse) error { - text, toolCalls, done := fromOllamaStreamChunk(resp) - if text != "" { - for _, ev := range parser.Feed(text) { - ch <- ev - } - } - for i := range toolCalls { - ch <- StreamEvent{Type: "tool_call", Tool: &toolCalls[i]} - } - if done { - usage = &Usage{ - InputTokens: resp.PromptEvalCount, - OutputTokens: resp.EvalCount, - } - } - return nil - }) - if err != nil { - ch <- StreamEvent{Type: "error", Error: err.Error()} - return - } - ch <- StreamEvent{Type: "done", Usage: usage} - }() - return ch, nil -} - -// Pull downloads a model via the Ollama server. -// progressFn is called with percent completion (0–100); may be nil. -func (p *OllamaProvider) Pull(ctx context.Context, model string, progressFn func(pct float64)) error { - req := &ollamaapi.PullRequest{Model: model} - return p.client.Pull(ctx, req, func(resp ollamaapi.ProgressResponse) error { - if progressFn != nil && resp.Total > 0 { - progressFn(float64(resp.Completed) / float64(resp.Total) * 100) - } - return nil - }) -} - -// ListModels returns the models available on the Ollama server. 
-func (p *OllamaProvider) ListModels(ctx context.Context) ([]ModelInfo, error) { - resp, err := p.client.List(ctx) - if err != nil { - return nil, fmt.Errorf("ollama: list models: %w", err) - } - models := make([]ModelInfo, 0, len(resp.Models)) - for _, m := range resp.Models { - name := m.Name - if name == "" { - name = m.Model - } - models = append(models, ModelInfo{ID: name, Name: name}) - } - return models, nil -} - -// Health checks whether the Ollama server is reachable. -func (p *OllamaProvider) Health(ctx context.Context) error { - if err := p.client.Heartbeat(ctx); err != nil { - return fmt.Errorf("ollama: health check: %w", err) - } - return nil -} diff --git a/provider/ollama_client.go b/provider/ollama_client.go new file mode 100644 index 0000000..f04b75e --- /dev/null +++ b/provider/ollama_client.go @@ -0,0 +1,71 @@ +package provider + +import ( + "context" + "fmt" + "net/http" + "net/url" + + ollamaapi "github.com/ollama/ollama/api" +) + +const defaultOllamaClientURL = "http://localhost:11434" + +// OllamaClient provides utility operations against a local Ollama server: +// pulling models and listing available models. For chat/stream, use the +// Genkit-backed provider returned by genkit.NewOllamaProvider. +type OllamaClient struct { + client *ollamaapi.Client +} + +// NewOllamaClient creates an OllamaClient pointing at the given server address. +// If serverAddress is empty, http://localhost:11434 is used. +func NewOllamaClient(serverAddress string) *OllamaClient { + if serverAddress == "" { + serverAddress = defaultOllamaClientURL + } + base, err := url.Parse(serverAddress) + if err != nil { + base, _ = url.Parse(defaultOllamaClientURL) + } + return &OllamaClient{ + client: ollamaapi.NewClient(base, http.DefaultClient), + } +} + +// Pull downloads a model via the Ollama server. +// progressFn is called with percent completion (0–100); may be nil. 
+func (c *OllamaClient) Pull(ctx context.Context, model string, progressFn func(pct float64)) error { + req := &ollamaapi.PullRequest{Model: model} + return c.client.Pull(ctx, req, func(resp ollamaapi.ProgressResponse) error { + if progressFn != nil && resp.Total > 0 { + progressFn(float64(resp.Completed) / float64(resp.Total) * 100) + } + return nil + }) +} + +// ListModels returns the models available on the Ollama server. +func (c *OllamaClient) ListModels(ctx context.Context) ([]ModelInfo, error) { + resp, err := c.client.List(ctx) + if err != nil { + return nil, fmt.Errorf("ollama: list models: %w", err) + } + models := make([]ModelInfo, 0, len(resp.Models)) + for _, m := range resp.Models { + name := m.Name + if name == "" { + name = m.Model + } + models = append(models, ModelInfo{ID: name, Name: name}) + } + return models, nil +} + +// Health checks whether the Ollama server is reachable. +func (c *OllamaClient) Health(ctx context.Context) error { + if err := c.client.Heartbeat(ctx); err != nil { + return fmt.Errorf("ollama: health check: %w", err) + } + return nil +} diff --git a/provider/ollama_convert.go b/provider/ollama_convert.go deleted file mode 100644 index 2e90568..0000000 --- a/provider/ollama_convert.go +++ /dev/null @@ -1,96 +0,0 @@ -package provider - -import ( - "encoding/json" - - ollamaapi "github.com/ollama/ollama/api" -) - -// toOllamaMessages converts provider messages to Ollama API messages. 
-func toOllamaMessages(msgs []Message) []ollamaapi.Message { - result := make([]ollamaapi.Message, 0, len(msgs)) - for _, msg := range msgs { - m := ollamaapi.Message{ - Role: string(msg.Role), - Content: msg.Content, - ToolCallID: msg.ToolCallID, - } - for _, tc := range msg.ToolCalls { - args := ollamaapi.NewToolCallFunctionArguments() - for k, v := range tc.Arguments { - args.Set(k, v) - } - m.ToolCalls = append(m.ToolCalls, ollamaapi.ToolCall{ - ID: tc.ID, - Function: ollamaapi.ToolCallFunction{ - Name: tc.Name, - Arguments: args, - }, - }) - } - result = append(result, m) - } - return result -} - -// toOllamaTools converts provider tool definitions to Ollama API tools. -// The JSON Schema parameters are marshaled and unmarshaled into the Ollama type. -func toOllamaTools(tools []ToolDef) []ollamaapi.Tool { - result := make([]ollamaapi.Tool, 0, len(tools)) - for _, t := range tools { - var params ollamaapi.ToolFunctionParameters - if b, err := json.Marshal(t.Parameters); err == nil { - _ = json.Unmarshal(b, ¶ms) - } - result = append(result, ollamaapi.Tool{ - Type: "function", - Function: ollamaapi.ToolFunction{ - Name: t.Name, - Description: t.Description, - Parameters: params, - }, - }) - } - return result -} - -// fromOllamaResponse converts a completed Ollama ChatResponse to a provider Response. -// ParseThinking is applied to extract any ... block from the content. 
-func fromOllamaResponse(resp ollamaapi.ChatResponse) *Response { - thinking, content := ParseThinking(resp.Message.Content) - // Prefer native thinking field if set (Ollama Think mode) - if resp.Message.Thinking != "" { - thinking = resp.Message.Thinking - content = resp.Message.Content - } - r := &Response{ - Content: content, - Thinking: thinking, - Usage: Usage{ - InputTokens: resp.PromptEvalCount, - OutputTokens: resp.EvalCount, - }, - } - for _, tc := range resp.Message.ToolCalls { - r.ToolCalls = append(r.ToolCalls, ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: tc.Function.Arguments.ToMap(), - }) - } - return r -} - -// fromOllamaStreamChunk extracts text content, tool calls, and done flag from a -// streaming Ollama ChatResponse chunk. -func fromOllamaStreamChunk(resp ollamaapi.ChatResponse) (text string, toolCalls []ToolCall, done bool) { - text = resp.Message.Content - for _, tc := range resp.Message.ToolCalls { - toolCalls = append(toolCalls, ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: tc.Function.Arguments.ToMap(), - }) - } - return text, toolCalls, resp.Done -} diff --git a/provider/ollama_convert_test.go b/provider/ollama_convert_test.go deleted file mode 100644 index 1e3d51c..0000000 --- a/provider/ollama_convert_test.go +++ /dev/null @@ -1,233 +0,0 @@ -package provider - -import ( - "testing" - "time" - - ollamaapi "github.com/ollama/ollama/api" -) - -// --- toOllamaMessages --- - -func TestToOllamaMessages_Roles(t *testing.T) { - msgs := []Message{ - {Role: RoleSystem, Content: "sys"}, - {Role: RoleUser, Content: "hi"}, - {Role: RoleAssistant, Content: "hello"}, - {Role: RoleTool, Content: "tool result", ToolCallID: "tc1"}, - } - got := toOllamaMessages(msgs) - if len(got) != 4 { - t.Fatalf("len=%d, want 4", len(got)) - } - if got[0].Role != "system" || got[0].Content != "sys" { - t.Errorf("system msg wrong: %+v", got[0]) - } - if got[1].Role != "user" || got[1].Content != "hi" { - t.Errorf("user msg wrong: %+v", 
got[1]) - } - if got[2].Role != "assistant" || got[2].Content != "hello" { - t.Errorf("assistant msg wrong: %+v", got[2]) - } - if got[3].Role != "tool" || got[3].ToolCallID != "tc1" { - t.Errorf("tool msg wrong: %+v", got[3]) - } -} - -func TestToOllamaMessages_ToolCalls(t *testing.T) { - msgs := []Message{ - { - Role: RoleAssistant, - ToolCalls: []ToolCall{ - {ID: "id1", Name: "search", Arguments: map[string]any{"q": "hello"}}, - }, - }, - } - got := toOllamaMessages(msgs) - if len(got[0].ToolCalls) != 1 { - t.Fatalf("tool calls len=%d, want 1", len(got[0].ToolCalls)) - } - tc := got[0].ToolCalls[0] - if tc.ID != "id1" { - t.Errorf("ID=%q, want %q", tc.ID, "id1") - } - if tc.Function.Name != "search" { - t.Errorf("Name=%q, want %q", tc.Function.Name, "search") - } - v, ok := tc.Function.Arguments.Get("q") - if !ok || v != "hello" { - t.Errorf("arg q=%v ok=%v, want 'hello'", v, ok) - } -} - -// --- toOllamaTools --- - -func TestToOllamaTools(t *testing.T) { - tools := []ToolDef{ - { - Name: "calculator", - Description: "does math", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "expr": map[string]any{"type": "string"}, - }, - "required": []any{"expr"}, - }, - }, - } - got := toOllamaTools(tools) - if len(got) != 1 { - t.Fatalf("len=%d, want 1", len(got)) - } - if got[0].Type != "function" { - t.Errorf("Type=%q, want 'function'", got[0].Type) - } - if got[0].Function.Name != "calculator" { - t.Errorf("Name=%q", got[0].Function.Name) - } - if got[0].Function.Description != "does math" { - t.Errorf("Description=%q", got[0].Function.Description) - } -} - -func TestToOllamaTools_Empty(t *testing.T) { - got := toOllamaTools(nil) - if len(got) != 0 { - t.Errorf("expected empty slice, got %d", len(got)) - } -} - -// --- fromOllamaResponse --- - -func TestFromOllamaResponse_Basic(t *testing.T) { - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{Role: "assistant", Content: "hello world"}, - Metrics: 
ollamaapi.Metrics{PromptEvalCount: 10, EvalCount: 5}, - } - got := fromOllamaResponse(resp) - if got.Content != "hello world" { - t.Errorf("Content=%q, want %q", got.Content, "hello world") - } - if got.Thinking != "" { - t.Errorf("Thinking=%q, want empty", got.Thinking) - } - if got.Usage.InputTokens != 10 { - t.Errorf("InputTokens=%d, want 10", got.Usage.InputTokens) - } - if got.Usage.OutputTokens != 5 { - t.Errorf("OutputTokens=%d, want 5", got.Usage.OutputTokens) - } -} - -func TestFromOllamaResponse_ThinkTagExtraction(t *testing.T) { - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{ - Role: "assistant", - Content: "reasoninganswer", - }, - } - got := fromOllamaResponse(resp) - if got.Thinking != "reasoning" { - t.Errorf("Thinking=%q, want %q", got.Thinking, "reasoning") - } - if got.Content != "answer" { - t.Errorf("Content=%q, want %q", got.Content, "answer") - } -} - -func TestFromOllamaResponse_NativeThinkingField(t *testing.T) { - // When Thinking is set natively (Ollama Think mode), prefer it. 
- resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{ - Role: "assistant", - Content: "answer", - Thinking: "native thinking", - }, - } - got := fromOllamaResponse(resp) - if got.Thinking != "native thinking" { - t.Errorf("Thinking=%q, want %q", got.Thinking, "native thinking") - } - if got.Content != "answer" { - t.Errorf("Content=%q, want %q", got.Content, "answer") - } -} - -func TestFromOllamaResponse_ToolCalls(t *testing.T) { - args := ollamaapi.NewToolCallFunctionArguments() - args.Set("query", "test") - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{ - Role: "assistant", - ToolCalls: []ollamaapi.ToolCall{ - {ID: "tc1", Function: ollamaapi.ToolCallFunction{Name: "search", Arguments: args}}, - }, - }, - } - got := fromOllamaResponse(resp) - if len(got.ToolCalls) != 1 { - t.Fatalf("ToolCalls len=%d, want 1", len(got.ToolCalls)) - } - tc := got.ToolCalls[0] - if tc.ID != "tc1" || tc.Name != "search" { - t.Errorf("tc=%+v", tc) - } - if tc.Arguments["query"] != "test" { - t.Errorf("arg query=%v, want 'test'", tc.Arguments["query"]) - } -} - -// --- fromOllamaStreamChunk --- - -func TestFromOllamaStreamChunk_Text(t *testing.T) { - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{Role: "assistant", Content: "chunk text"}, - Done: false, - CreatedAt: time.Now(), - } - text, toolCalls, done := fromOllamaStreamChunk(resp) - if text != "chunk text" { - t.Errorf("text=%q, want %q", text, "chunk text") - } - if len(toolCalls) != 0 { - t.Errorf("toolCalls=%v, want empty", toolCalls) - } - if done { - t.Error("done should be false") - } -} - -func TestFromOllamaStreamChunk_Done(t *testing.T) { - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{Role: "assistant", Content: ""}, - Done: true, - CreatedAt: time.Now(), - } - _, _, done := fromOllamaStreamChunk(resp) - if !done { - t.Error("done should be true") - } -} - -func TestFromOllamaStreamChunk_ToolCalls(t *testing.T) { - args := ollamaapi.NewToolCallFunctionArguments() - 
args.Set("x", 42) - resp := ollamaapi.ChatResponse{ - Message: ollamaapi.Message{ - Role: "assistant", - ToolCalls: []ollamaapi.ToolCall{ - {ID: "tc2", Function: ollamaapi.ToolCallFunction{Name: "calc", Arguments: args}}, - }, - }, - Done: true, - CreatedAt: time.Now(), - } - _, toolCalls, _ := fromOllamaStreamChunk(resp) - if len(toolCalls) != 1 { - t.Fatalf("len=%d, want 1", len(toolCalls)) - } - if toolCalls[0].Name != "calc" { - t.Errorf("Name=%q, want %q", toolCalls[0].Name, "calc") - } -} diff --git a/provider/ollama_test.go b/provider/ollama_test.go deleted file mode 100644 index b0b31ea..0000000 --- a/provider/ollama_test.go +++ /dev/null @@ -1,217 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" -) - -// ollamaChatNDJSON builds a minimal Ollama chat NDJSON response line. -func ollamaChatNDJSON(content string, done bool) string { - resp := map[string]any{ - "model": "test", - "created_at": time.Now().Format(time.RFC3339), - "message": map[string]any{ - "role": "assistant", - "content": content, - }, - "done": done, - "prompt_eval_count": 3, - "eval_count": 7, - } - b, _ := json.Marshal(resp) - return string(b) + "\n" -} - -// ollamaMockServer sets up a minimal Ollama HTTP server for unit tests. 
-func ollamaMockServer(t *testing.T, chatContent string) *httptest.Server { - t.Helper() - mux := http.NewServeMux() - - // HEAD / — heartbeat - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodHead { - w.WriteHeader(http.StatusOK) - } - }) - - // POST /api/chat — chat (streaming NDJSON) - mux.HandleFunc("/api/chat", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/x-ndjson") - w.Write([]byte(ollamaChatNDJSON(chatContent, true))) - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - }) - - // GET /api/tags — list models - mux.HandleFunc("/api/tags", func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(map[string]any{ - "models": []map[string]any{ - {"name": "llama3.2:latest", "model": "llama3.2:latest", "size": int64(4000000000)}, - }, - }) - }) - - return httptest.NewServer(mux) -} - -func TestOllamaProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*OllamaProvider)(nil) -} - -func TestOllamaProvider_Name(t *testing.T) { - p := NewOllamaProvider(OllamaConfig{Model: "qwen3.5:7b"}) - if got := p.Name(); got != "ollama" { - t.Errorf("Name()=%q, want %q", got, "ollama") - } -} - -func TestOllamaProvider_AuthModeInfo(t *testing.T) { - p := NewOllamaProvider(OllamaConfig{}) - info := p.AuthModeInfo() - if info.Mode != "ollama" { - t.Errorf("Mode=%q, want %q", info.Mode, "ollama") - } - if !info.ServerSafe { - t.Error("ServerSafe should be true") - } - if info.Warning != "" { - t.Errorf("Warning should be empty, got %q", info.Warning) - } -} - -func TestOllamaProvider_DefaultBaseURL(t *testing.T) { - p := NewOllamaProvider(OllamaConfig{}) - if p.config.BaseURL != defaultOllamaBaseURL { - t.Errorf("BaseURL=%q, want %q", p.config.BaseURL, defaultOllamaBaseURL) - } -} - -func TestOllamaProvider_Health(t *testing.T) { - srv := ollamaMockServer(t, "") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - if err 
:= p.Health(context.Background()); err != nil { - t.Errorf("Health() error: %v", err) - } -} - -func TestOllamaProvider_Chat(t *testing.T) { - srv := ollamaMockServer(t, "hello world") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - resp, err := p.Chat(context.Background(), []Message{{Role: RoleUser, Content: "hi"}}, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if resp.Content != "hello world" { - t.Errorf("Content=%q, want %q", resp.Content, "hello world") - } - if resp.Usage.InputTokens != 3 { - t.Errorf("InputTokens=%d, want 3", resp.Usage.InputTokens) - } - if resp.Usage.OutputTokens != 7 { - t.Errorf("OutputTokens=%d, want 7", resp.Usage.OutputTokens) - } -} - -func TestOllamaProvider_ChatWithThinkTags(t *testing.T) { - srv := ollamaMockServer(t, "reasoninganswer") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - resp, err := p.Chat(context.Background(), []Message{{Role: RoleUser, Content: "q"}}, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if resp.Thinking != "reasoning" { - t.Errorf("Thinking=%q, want %q", resp.Thinking, "reasoning") - } - if resp.Content != "answer" { - t.Errorf("Content=%q, want %q", resp.Content, "answer") - } -} - -func TestOllamaProvider_Stream(t *testing.T) { - srv := ollamaMockServer(t, "stream result") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - ch, err := p.Stream(context.Background(), []Message{{Role: RoleUser, Content: "hi"}}, nil) - if err != nil { - t.Fatalf("Stream() error: %v", err) - } - - var textParts []string - var sawDone bool - for ev := range ch { - switch ev.Type { - case "text": - textParts = append(textParts, ev.Text) - case "done": - sawDone = true - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if !sawDone { - t.Error("expected done event") - } - full := strings.Join(textParts, "") - if full != "stream result" { 
- t.Errorf("text=%q, want %q", full, "stream result") - } -} - -func TestOllamaProvider_StreamThinkTags(t *testing.T) { - srv := ollamaMockServer(t, "thinkingcontent") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - ch, err := p.Stream(context.Background(), []Message{{Role: RoleUser, Content: "q"}}, nil) - if err != nil { - t.Fatalf("Stream() error: %v", err) - } - - var thinkParts, textParts []string - for ev := range ch { - switch ev.Type { - case "thinking": - thinkParts = append(thinkParts, ev.Thinking) - case "text": - textParts = append(textParts, ev.Text) - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if strings.Join(thinkParts, "") != "thinking" { - t.Errorf("thinking=%q, want %q", strings.Join(thinkParts, ""), "thinking") - } - if strings.Join(textParts, "") != "content" { - t.Errorf("text=%q, want %q", strings.Join(textParts, ""), "content") - } -} - -func TestOllamaProvider_ListModels(t *testing.T) { - srv := ollamaMockServer(t, "") - defer srv.Close() - - p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "test"}) - models, err := p.ListModels(context.Background()) - if err != nil { - t.Fatalf("ListModels() error: %v", err) - } - if len(models) != 1 { - t.Fatalf("len=%d, want 1", len(models)) - } - if models[0].ID != "llama3.2:latest" { - t.Errorf("ID=%q, want %q", models[0].ID, "llama3.2:latest") - } -} diff --git a/provider/openai.go b/provider/openai.go deleted file mode 100644 index ffb9b68..0000000 --- a/provider/openai.go +++ /dev/null @@ -1,97 +0,0 @@ -package provider - -import ( - "context" - "fmt" - "net/http" - - openaisdk "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/shared" -) - -const ( - defaultOpenAIBaseURL = "https://api.openai.com" - defaultOpenAIModel = "gpt-4o" - defaultOpenAIMaxTokens = 4096 -) - -// OpenAIConfig holds configuration for the OpenAI provider. 
-type OpenAIConfig struct { - APIKey string - Model string - BaseURL string - MaxTokens int - HTTPClient *http.Client -} - -// OpenAIProvider implements Provider using the OpenAI Chat Completions API. -type OpenAIProvider struct { - client openaisdk.Client - config OpenAIConfig -} - -// NewOpenAIProvider creates a new OpenAI provider with the given config. -func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider { - if cfg.Model == "" { - cfg.Model = defaultOpenAIModel - } - if cfg.BaseURL == "" { - cfg.BaseURL = defaultOpenAIBaseURL - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultOpenAIMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - client := openaisdk.NewClient( - option.WithAPIKey(cfg.APIKey), - option.WithBaseURL(cfg.BaseURL), - option.WithHTTPClient(cfg.HTTPClient), - ) - return &OpenAIProvider{client: client, config: cfg} -} - -func (p *OpenAIProvider) Name() string { return "openai" } - -func (p *OpenAIProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "direct", - DisplayName: "OpenAI (Direct API)", - Description: "Direct access to OpenAI models via API key.", - DocsURL: "https://platform.openai.com/docs/api-reference/introduction", - ServerSafe: true, - } -} - -func (p *OpenAIProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.Model), - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - resp, err := p.client.Chat.Completions.New(ctx, params) - if err != nil { - return nil, fmt.Errorf("openai: %w", err) - } - return fromOpenAIResponse(resp) -} - -func (p *OpenAIProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.Model), - 
Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - stream := p.client.Chat.Completions.NewStreaming(ctx, params) - ch := make(chan StreamEvent, 16) - go streamOpenAIEvents(stream, ch) - return ch, nil -} diff --git a/provider/openai_azure.go b/provider/openai_azure.go deleted file mode 100644 index 68b70b5..0000000 --- a/provider/openai_azure.go +++ /dev/null @@ -1,132 +0,0 @@ -package provider - -import ( - "context" - "fmt" - "net/http" - - openaisdk "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/shared" -) - -const defaultAzureOpenAIAPIVersion = "2024-10-21" - -// OpenAIAzureConfig configures the OpenAI provider for Azure OpenAI Service. -// Uses Azure API keys or Entra ID tokens. URLs follow the pattern: -// {resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version} -type OpenAIAzureConfig struct { - // Resource is the Azure OpenAI resource name. - Resource string - // DeploymentName is the model deployment name in Azure. - DeploymentName string - // APIVersion is the Azure API version (e.g. "2024-10-21"). - APIVersion string - // MaxTokens limits the response length. - MaxTokens int - // APIKey is the Azure API key (use this OR Entra ID token, not both). - APIKey string - // EntraToken is a Microsoft Entra ID bearer token (optional, alternative to APIKey). - EntraToken string - // HTTPClient overrides the default HTTP client. - HTTPClient *http.Client - // BaseURL overrides the computed Azure endpoint (used in tests). - BaseURL string -} - -// OpenAIAzureProvider accesses OpenAI models via Azure OpenAI Service. 
-type OpenAIAzureProvider struct { - client openaisdk.Client - config OpenAIAzureConfig -} - -func (p *OpenAIAzureProvider) Name() string { return "openai_azure" } - -func (p *OpenAIAzureProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "azure", - DisplayName: "OpenAI (Azure OpenAI Service)", - Description: "Access OpenAI models via Azure OpenAI Service using Azure API keys or Microsoft Entra ID tokens. Uses deployment-specific URLs.", - DocsURL: "https://learn.microsoft.com/en-us/azure/ai-services/openai/reference", - ServerSafe: true, - } -} - -// NewOpenAIAzureProvider creates a provider that accesses OpenAI models via Azure. -// -// Docs: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference -func NewOpenAIAzureProvider(cfg OpenAIAzureConfig) (*OpenAIAzureProvider, error) { - if cfg.Resource == "" && cfg.BaseURL == "" { - return nil, fmt.Errorf("openai_azure: Resource is required") - } - if cfg.DeploymentName == "" && cfg.BaseURL == "" { - return nil, fmt.Errorf("openai_azure: DeploymentName is required") - } - if cfg.APIKey == "" && cfg.EntraToken == "" { - return nil, fmt.Errorf("openai_azure: APIKey or EntraToken is required") - } - if cfg.APIVersion == "" { - cfg.APIVersion = defaultAzureOpenAIAPIVersion - } - if cfg.MaxTokens <= 0 { - cfg.MaxTokens = defaultOpenAIMaxTokens - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = http.DefaultClient - } - - baseURL := cfg.BaseURL - if baseURL == "" { - baseURL = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s", - cfg.Resource, cfg.DeploymentName) - } - - opts := []option.RequestOption{ - // Use a placeholder API key to prevent OPENAI_API_KEY env var from being used, - // then remove the resulting Authorization header for Azure auth. 
- option.WithAPIKey("azure-placeholder"), - option.WithHeaderDel("authorization"), - option.WithBaseURL(baseURL), - option.WithQuery("api-version", cfg.APIVersion), - option.WithHTTPClient(cfg.HTTPClient), - } - if cfg.APIKey != "" { - opts = append(opts, option.WithHeader("api-key", cfg.APIKey)) - } else if cfg.EntraToken != "" { - opts = append(opts, option.WithHeader("Authorization", "Bearer "+cfg.EntraToken)) - } - - client := openaisdk.NewClient(opts...) - return &OpenAIAzureProvider{client: client, config: cfg}, nil -} - -func (p *OpenAIAzureProvider) Chat(ctx context.Context, messages []Message, tools []ToolDef) (*Response, error) { - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.DeploymentName), - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - resp, err := p.client.Chat.Completions.New(ctx, params) - if err != nil { - return nil, fmt.Errorf("openai_azure: %w", err) - } - return fromOpenAIResponse(resp) -} - -func (p *OpenAIAzureProvider) Stream(ctx context.Context, messages []Message, tools []ToolDef) (<-chan StreamEvent, error) { - params := openaisdk.ChatCompletionNewParams{ - Model: shared.ChatModel(p.config.DeploymentName), - Messages: toOpenAIMessages(messages), - MaxTokens: openaisdk.Int(int64(p.config.MaxTokens)), - } - if len(tools) > 0 { - params.Tools = toOpenAITools(tools) - } - stream := p.client.Chat.Completions.NewStreaming(ctx, params) - ch := make(chan StreamEvent, 16) - go streamOpenAIEvents(stream, ch) - return ch, nil -} diff --git a/provider/openai_azure_test.go b/provider/openai_azure_test.go deleted file mode 100644 index 7b7ae9c..0000000 --- a/provider/openai_azure_test.go +++ /dev/null @@ -1,218 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestOpenAIAzureProvider_Name(t *testing.T) { - p, err := 
NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "myresource", DeploymentName: "gpt-4o", APIKey: "key123", - }) - if err != nil { - t.Fatal(err) - } - if got := p.Name(); got != "openai_azure" { - t.Errorf("Name() = %q, want %q", got, "openai_azure") - } -} - -func TestOpenAIAzureProvider_AuthModeInfo(t *testing.T) { - p, err := NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "myresource", DeploymentName: "gpt-4o", APIKey: "key123", - }) - if err != nil { - t.Fatal(err) - } - info := p.AuthModeInfo() - if info.Mode != "azure" { - t.Errorf("AuthModeInfo().Mode = %q, want %q", info.Mode, "azure") - } - if info.DisplayName != "OpenAI (Azure OpenAI Service)" { - t.Errorf("AuthModeInfo().DisplayName = %q, want %q", info.DisplayName, "OpenAI (Azure OpenAI Service)") - } - if !info.ServerSafe { - t.Error("AuthModeInfo().ServerSafe = false, want true") - } -} - -func TestOpenAIAzureProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*OpenAIAzureProvider)(nil) -} - -func TestOpenAIAzureProvider_ValidationErrors(t *testing.T) { - tests := []struct { - name string - cfg OpenAIAzureConfig - want string - }{ - {"missing resource", OpenAIAzureConfig{DeploymentName: "d", APIKey: "k"}, "Resource is required"}, - {"missing deployment", OpenAIAzureConfig{Resource: "r", APIKey: "k"}, "DeploymentName is required"}, - {"missing auth", OpenAIAzureConfig{Resource: "r", DeploymentName: "d"}, "APIKey or EntraToken is required"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := NewOpenAIAzureProvider(tt.cfg) - if err == nil { - t.Fatal("expected error") - } - if got := err.Error(); !strings.Contains(got, tt.want) { - t.Errorf("error = %q, want to contain %q", got, tt.want) - } - }) - } -} - -func TestOpenAIAzureProvider_DefaultAPIVersion(t *testing.T) { - p, err := NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "res", DeploymentName: "dep", APIKey: "key", - }) - if err != nil { - t.Fatal(err) - } - if p.config.APIVersion != 
defaultAzureOpenAIAPIVersion { - t.Errorf("APIVersion = %q, want %q", p.config.APIVersion, defaultAzureOpenAIAPIVersion) - } -} - -func TestOpenAIAzureProvider_APIKeyAuth(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("api-key"); got != "azure-key-123" { - t.Errorf("api-key header = %q, want %q", got, "azure-key-123") - } - if got := r.Header.Get("Authorization"); got != "" { - t.Errorf("Authorization header should be empty with API key auth, got %q", got) - } - - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"chatcmpl-azure","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"hello from azure"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8},"created":1704067200}`) - })) - defer srv.Close() - - p, err := NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "test", - DeploymentName: "gpt-4o", - APIKey: "azure-key-123", - APIVersion: "2024-10-21", - MaxTokens: 4096, - HTTPClient: http.DefaultClient, - BaseURL: srv.URL, - }) - if err != nil { - t.Fatal(err) - } - - got, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if got.Content != "hello from azure" { - t.Errorf("Chat() content = %q, want %q", got.Content, "hello from azure") - } -} - -func TestOpenAIAzureProvider_EntraTokenAuth(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("Authorization"); got != "Bearer entra-token-xyz" { - t.Errorf("Authorization header = %q, want %q", got, "Bearer entra-token-xyz") - } - if got := r.Header.Get("api-key"); got != "" { - t.Errorf("api-key header should be empty with Entra auth, got %q", got) - } - - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, 
`{"id":"chatcmpl-entra","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"hello from entra"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8},"created":1704067200}`) - })) - defer srv.Close() - - p, err := NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "test", - DeploymentName: "gpt-4o", - EntraToken: "entra-token-xyz", - APIVersion: "2024-10-21", - MaxTokens: 4096, - HTTPClient: http.DefaultClient, - BaseURL: srv.URL, - }) - if err != nil { - t.Fatal(err) - } - - got, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if got.Content != "hello from entra" { - t.Errorf("Chat() content = %q, want %q", got.Content, "hello from entra") - } -} - -func TestOpenAIAzureProvider_Stream(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher, ok := w.(http.Flusher) - if !ok { - t.Fatal("expected http.Flusher") - } - - data, _ := json.Marshal(map[string]any{ - "id": "chatcmpl-azure-stream", - "object": "chat.completion.chunk", - "model": "gpt-4o", - "created": 1704067200, - "choices": []map[string]any{ - {"index": 0, "delta": map[string]any{"role": "assistant", "content": "azure-streamed"}, "finish_reason": nil}, - }, - }) - fmt.Fprintf(w, "data: %s\n\n", string(data)) - flusher.Flush() - - fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() - })) - defer srv.Close() - - p, err := NewOpenAIAzureProvider(OpenAIAzureConfig{ - Resource: "test", - DeploymentName: "gpt-4o", - APIKey: "azure-key", - APIVersion: "2024-10-21", - MaxTokens: 4096, - HTTPClient: http.DefaultClient, - BaseURL: srv.URL, - }) - if err != nil { - t.Fatal(err) - } - - ch, err := p.Stream(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - 
t.Fatalf("Stream() error: %v", err) - } - - var texts []string - for ev := range ch { - switch ev.Type { - case "text": - texts = append(texts, ev.Text) - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if len(texts) == 0 { - t.Fatal("expected at least one text event") - } - if texts[0] != "azure-streamed" { - t.Errorf("first text event = %q, want %q", texts[0], "azure-streamed") - } -} diff --git a/provider/openai_convert.go b/provider/openai_convert.go deleted file mode 100644 index ce9f0ef..0000000 --- a/provider/openai_convert.go +++ /dev/null @@ -1,196 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "sort" - "strings" - - openaisdk "github.com/openai/openai-go" - "github.com/openai/openai-go/shared" -) - -// toOpenAIMessages converts provider messages to SDK message params. -func toOpenAIMessages(msgs []Message) []openaisdk.ChatCompletionMessageParamUnion { - result := make([]openaisdk.ChatCompletionMessageParamUnion, 0, len(msgs)) - for _, msg := range msgs { - switch msg.Role { - case RoleTool: - result = append(result, openaisdk.ToolMessage(msg.Content, msg.ToolCallID)) - case RoleAssistant: - asst := openaisdk.ChatCompletionAssistantMessageParam{} - if msg.Content != "" { - asst.Content = openaisdk.ChatCompletionAssistantMessageParamContentUnion{ - OfString: openaisdk.String(msg.Content), - } - } - for _, tc := range msg.ToolCalls { - args := "{}" - if tc.Arguments != nil { - if b, err := json.Marshal(tc.Arguments); err == nil { - args = string(b) - } - } - asst.ToolCalls = append(asst.ToolCalls, openaisdk.ChatCompletionMessageToolCallParam{ - ID: tc.ID, - Function: openaisdk.ChatCompletionMessageToolCallFunctionParam{ - Name: tc.Name, - Arguments: args, - }, - }) - } - result = append(result, openaisdk.ChatCompletionMessageParamUnion{OfAssistant: &asst}) - case RoleSystem: - result = append(result, openaisdk.SystemMessage(msg.Content)) - default: // RoleUser and others - result = append(result, 
openaisdk.UserMessage(msg.Content)) - } - } - return result -} - -// toOpenAITools converts provider tool definitions to SDK tool params. -func toOpenAITools(tools []ToolDef) []openaisdk.ChatCompletionToolParam { - result := make([]openaisdk.ChatCompletionToolParam, 0, len(tools)) - for _, t := range tools { - schema := shared.FunctionParameters(t.Parameters) - if schema == nil { - schema = shared.FunctionParameters{"type": "object", "properties": map[string]any{}} - } - result = append(result, openaisdk.ChatCompletionToolParam{ - Function: shared.FunctionDefinitionParam{ - Name: t.Name, - Description: openaisdk.String(t.Description), - Parameters: schema, - }, - }) - } - return result -} - -// fromOpenAIResponse converts an SDK ChatCompletion to the provider Response type. -func fromOpenAIResponse(resp *openaisdk.ChatCompletion) (*Response, error) { - result := &Response{ - Usage: Usage{ - InputTokens: int(resp.Usage.PromptTokens), - OutputTokens: int(resp.Usage.CompletionTokens), - }, - } - if len(resp.Choices) == 0 { - return result, nil - } - msg := resp.Choices[0].Message - result.Content = msg.Content - // For reasoning models (Qwen3.5, etc.) that put output in reasoning_content - // instead of content, extract it from the raw JSON extra fields. 
- if result.Content == "" { - if rc, ok := msg.JSON.ExtraFields["reasoning_content"]; ok && rc.Valid() { - var reasoning string - if err := json.Unmarshal([]byte(rc.Raw()), &reasoning); err == nil && reasoning != "" { - result.Content = reasoning - } - } - } - for _, tc := range msg.ToolCalls { - var args map[string]any - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { - return nil, fmt.Errorf("openai: unmarshal tool call arguments for %q: %w", tc.Function.Name, err) - } - } - result.ToolCalls = append(result.ToolCalls, ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: args, - }) - } - return result, nil -} - -// openaiChunkStream is satisfied by *ssestream.Stream[openaisdk.ChatCompletionChunk]. -type openaiChunkStream interface { - Next() bool - Current() openaisdk.ChatCompletionChunk - Err() error - Close() error -} - -// streamOpenAIEvents drains stream and sends StreamEvents to ch, then closes ch. -func streamOpenAIEvents(stream openaiChunkStream, ch chan<- StreamEvent) { - defer close(ch) - defer stream.Close() - - type pendingToolCall struct { - id string - name string - argsBuf strings.Builder - } - pending := make(map[int64]*pendingToolCall) - var usage *Usage - - for stream.Next() { - chunk := stream.Current() - - if chunk.Usage.PromptTokens > 0 || chunk.Usage.CompletionTokens > 0 { - usage = &Usage{ - InputTokens: int(chunk.Usage.PromptTokens), - OutputTokens: int(chunk.Usage.CompletionTokens), - } - } - - if len(chunk.Choices) == 0 { - continue - } - - delta := chunk.Choices[0].Delta - - if delta.Content != "" { - ch <- StreamEvent{Type: "text", Text: delta.Content} - } - - for _, tc := range delta.ToolCalls { - ptc, exists := pending[tc.Index] - if !exists { - ptc = &pendingToolCall{} - pending[tc.Index] = ptc - } - if tc.ID != "" { - ptc.id = tc.ID - } - if tc.Function.Name != "" { - ptc.name = tc.Function.Name - } - if tc.Function.Arguments != "" { - 
ptc.argsBuf.WriteString(tc.Function.Arguments) - } - } - } - - if err := stream.Err(); err != nil { - ch <- StreamEvent{Type: "error", Error: err.Error()} - return - } - - indices := make([]int64, 0, len(pending)) - for idx := range pending { - indices = append(indices, idx) - } - sort.Slice(indices, func(i, j int) bool { return indices[i] < indices[j] }) - - for _, idx := range indices { - ptc := pending[idx] - var args map[string]any - if ptc.argsBuf.Len() > 0 { - if err := json.Unmarshal([]byte(ptc.argsBuf.String()), &args); err != nil { - ch <- StreamEvent{Type: "error", Error: fmt.Sprintf("openai: unmarshal tool call arguments for %q: %v", ptc.name, err)} - return - } - } - ch <- StreamEvent{ - Type: "tool_call", - Tool: &ToolCall{ID: ptc.id, Name: ptc.name, Arguments: args}, - } - } - - ch <- StreamEvent{Type: "done", Usage: usage} -} diff --git a/provider/openrouter.go b/provider/openrouter.go deleted file mode 100644 index 6e75bb5..0000000 --- a/provider/openrouter.go +++ /dev/null @@ -1,44 +0,0 @@ -package provider - -const defaultOpenRouterBaseURL = "https://openrouter.ai/api/v1" - -// OpenRouterConfig configures the OpenRouter provider. -type OpenRouterConfig struct { - APIKey string - Model string - BaseURL string - MaxTokens int -} - -// OpenRouterProvider wraps OpenAIProvider with OpenRouter-specific identity and auth info. -type OpenRouterProvider struct { - *OpenAIProvider -} - -func (p *OpenRouterProvider) Name() string { return "openrouter" } - -func (p *OpenRouterProvider) AuthModeInfo() AuthModeInfo { - return AuthModeInfo{ - Mode: "openrouter", - DisplayName: "OpenRouter", - Description: "Access multiple AI models via OpenRouter's unified API.", - DocsURL: "https://openrouter.ai/docs/api/reference/authentication", - ServerSafe: true, - } -} - -// NewOpenRouterProvider creates a provider that uses OpenRouter's OpenAI-compatible API. 
-func NewOpenRouterProvider(cfg OpenRouterConfig) *OpenRouterProvider { - baseURL := cfg.BaseURL - if baseURL == "" { - baseURL = defaultOpenRouterBaseURL - } - return &OpenRouterProvider{ - OpenAIProvider: NewOpenAIProvider(OpenAIConfig{ - APIKey: cfg.APIKey, - Model: cfg.Model, - BaseURL: baseURL, - MaxTokens: cfg.MaxTokens, - }), - } -} diff --git a/provider/openrouter_test.go b/provider/openrouter_test.go deleted file mode 100644 index 0825cd3..0000000 --- a/provider/openrouter_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package provider - -import ( - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" -) - -func TestOpenRouterProvider_Name(t *testing.T) { - p := NewOpenRouterProvider(OpenRouterConfig{APIKey: "test-key"}) - if got := p.Name(); got != "openrouter" { - t.Errorf("Name() = %q, want %q", got, "openrouter") - } -} - -func TestOpenRouterProvider_AuthModeInfo(t *testing.T) { - p := NewOpenRouterProvider(OpenRouterConfig{APIKey: "test-key"}) - info := p.AuthModeInfo() - - if info.Mode != "openrouter" { - t.Errorf("AuthModeInfo().Mode = %q, want %q", info.Mode, "openrouter") - } - if info.DisplayName != "OpenRouter" { - t.Errorf("AuthModeInfo().DisplayName = %q, want %q", info.DisplayName, "OpenRouter") - } - if info.DocsURL != "https://openrouter.ai/docs/api/reference/authentication" { - t.Errorf("AuthModeInfo().DocsURL = %q, want openrouter docs URL", info.DocsURL) - } - if !info.ServerSafe { - t.Error("AuthModeInfo().ServerSafe = false, want true") - } -} - -func TestOpenRouterProvider_ImplementsProvider(t *testing.T) { - var _ Provider = (*OpenRouterProvider)(nil) -} - -func TestOpenRouterProvider_DefaultBaseURL(t *testing.T) { - p := NewOpenRouterProvider(OpenRouterConfig{APIKey: "test-key"}) - if p.config.BaseURL != defaultOpenRouterBaseURL { - t.Errorf("default BaseURL = %q, want %q", p.config.BaseURL, defaultOpenRouterBaseURL) - } -} - -func TestOpenRouterProvider_CustomBaseURL(t *testing.T) { - custom := 
"https://custom.openrouter.ai/api/v1" - p := NewOpenRouterProvider(OpenRouterConfig{APIKey: "test-key", BaseURL: custom}) - if p.config.BaseURL != custom { - t.Errorf("BaseURL = %q, want %q", p.config.BaseURL, custom) - } -} - -func TestOpenRouterProvider_Chat(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("Authorization"); got != "Bearer test-key" { - t.Errorf("Authorization header = %q, want %q", got, "Bearer test-key") - } - - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"id":"chatcmpl-test","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"hello from openrouter"},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15},"created":1704067200}`) - })) - defer srv.Close() - - p := NewOpenRouterProvider(OpenRouterConfig{ - APIKey: "test-key", - Model: "meta-llama/llama-3-70b", - BaseURL: srv.URL, - }) - - got, err := p.Chat(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Chat() error: %v", err) - } - if got.Content != "hello from openrouter" { - t.Errorf("Chat() content = %q, want %q", got.Content, "hello from openrouter") - } - if got.Usage.InputTokens != 10 || got.Usage.OutputTokens != 5 { - t.Errorf("Chat() usage = %+v, want input=10 output=5", got.Usage) - } -} - -func TestOpenRouterProvider_Stream(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - flusher, ok := w.(http.Flusher) - if !ok { - t.Fatal("expected http.Flusher") - } - - data, _ := json.Marshal(map[string]any{ - "id": "chatcmpl-stream", - "object": "chat.completion.chunk", - "model": "gpt-4o", - "created": 1704067200, - "choices": []map[string]any{ - {"index": 0, "delta": map[string]any{"role": "assistant", "content": 
"streamed"}, "finish_reason": nil}, - }, - }) - fmt.Fprintf(w, "data: %s\n\n", string(data)) - flusher.Flush() - - fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() - })) - defer srv.Close() - - p := NewOpenRouterProvider(OpenRouterConfig{ - APIKey: "test-key", - Model: "meta-llama/llama-3-70b", - BaseURL: srv.URL, - }) - - ch, err := p.Stream(t.Context(), []Message{ - {Role: RoleUser, Content: "hi"}, - }, nil) - if err != nil { - t.Fatalf("Stream() error: %v", err) - } - - var texts []string - for ev := range ch { - switch ev.Type { - case "text": - texts = append(texts, ev.Text) - case "error": - t.Fatalf("stream error: %s", ev.Error) - } - } - if len(texts) == 0 { - t.Fatal("expected at least one text event") - } - if texts[0] != "streamed" { - t.Errorf("first text event = %q, want %q", texts[0], "streamed") - } -} diff --git a/provider_registry.go b/provider_registry.go index 1985b14..dbd280c 100644 --- a/provider_registry.go +++ b/provider_registry.go @@ -8,6 +8,7 @@ import ( "time" "github.com/GoCodeAlone/modular" + gkprov "github.com/GoCodeAlone/workflow-plugin-agent/genkit" "github.com/GoCodeAlone/workflow-plugin-agent/provider" "github.com/GoCodeAlone/workflow/config" "github.com/GoCodeAlone/workflow/module" @@ -27,8 +28,8 @@ type LLMProviderConfig struct { IsDefault int `json:"is_default"` } -// ProviderFactory creates a provider.Provider from an API key and config. -type ProviderFactory func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) +// ProviderFactory creates a provider.Provider from a context, API key, and config. +type ProviderFactory func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) // ProviderRegistry manages AI provider lifecycle: factory creation, caching, and DB lookup. 
type ProviderRegistry struct { @@ -48,58 +49,34 @@ func NewProviderRegistry(db *sql.DB, secretsProvider secrets.Provider) *Provider Factories: make(map[string]ProviderFactory), } - r.Factories["mock"] = func(_ string, _ LLMProviderConfig) (provider.Provider, error) { + r.Factories["mock"] = func(_ context.Context, _ string, _ LLMProviderConfig) (provider.Provider, error) { return &mockProvider{responses: []string{"I have completed the task."}}, nil } - r.Factories["anthropic"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewAnthropicProvider(provider.AnthropicConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + r.Factories["anthropic"] = func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewAnthropicProvider(ctx, apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } - r.Factories["openai"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewOpenAIProvider(provider.OpenAIConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + r.Factories["openai"] = func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOpenAIProvider(ctx, apiKey, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } - r.Factories["openrouter"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - if cfg.BaseURL == "" { - cfg.BaseURL = "https://openrouter.ai/api/v1" + r.Factories["openrouter"] = func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://openrouter.ai/api/v1" } - return provider.NewOpenAIProvider(provider.OpenAIConfig{ - APIKey: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + return gkprov.NewOpenAICompatibleProvider(ctx, 
"openrouter", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } - r.Factories["copilot"] = func(apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewCopilotProvider(provider.CopilotConfig{ - Token: apiKey, - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + r.Factories["copilot"] = func(ctx context.Context, apiKey string, cfg LLMProviderConfig) (provider.Provider, error) { + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://api.githubcopilot.com" + } + return gkprov.NewOpenAICompatibleProvider(ctx, "copilot", apiKey, cfg.Model, baseURL, cfg.MaxTokens) } - r.Factories["ollama"] = func(_ string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewOllamaProvider(provider.OllamaConfig{ - Model: cfg.Model, - BaseURL: cfg.BaseURL, - MaxTokens: cfg.MaxTokens, - }), nil + r.Factories["ollama"] = func(ctx context.Context, _ string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOllamaProvider(ctx, cfg.Model, cfg.BaseURL, cfg.MaxTokens) } - r.Factories["llama_cpp"] = func(_ string, cfg LLMProviderConfig) (provider.Provider, error) { - return provider.NewLlamaCppProvider(provider.LlamaCppConfig{ - BaseURL: cfg.BaseURL, - ModelPath: cfg.Model, - ModelName: cfg.Model, - MaxTokens: cfg.MaxTokens, - }), nil + r.Factories["llama_cpp"] = func(ctx context.Context, _ string, cfg LLMProviderConfig) (provider.Provider, error) { + return gkprov.NewOpenAICompatibleProvider(ctx, "llama_cpp", "", cfg.Model, cfg.BaseURL, cfg.MaxTokens) } return r @@ -215,7 +192,7 @@ func (r *ProviderRegistry) createAndCache(ctx context.Context, alias string, cfg return nil, fmt.Errorf("provider registry: unknown provider type %q", cfg.Type) } - p, err := factory(apiKey, *cfg) + p, err := factory(ctx, apiKey, *cfg) if err != nil { return nil, fmt.Errorf("provider registry: create %q: %w", alias, err) } diff --git a/step_model_pull.go b/step_model_pull.go index 614e325..7d02936 100644 --- 
a/step_model_pull.go +++ b/step_model_pull.go @@ -44,13 +44,10 @@ func (s *ModelPullStep) Execute(ctx context.Context, _ *module.PipelineContext) } func (s *ModelPullStep) pullOllama(ctx context.Context) (*module.StepResult, error) { - p := provider.NewOllamaProvider(provider.OllamaConfig{ - Model: s.model, - BaseURL: s.ollamaBase, - }) + client := provider.NewOllamaClient(s.ollamaBase) // Check if model is already available. - models, err := p.ListModels(ctx) + models, err := client.ListModels(ctx) if err == nil { for _, m := range models { if m.ID == s.model || m.Name == s.model { @@ -66,7 +63,7 @@ func (s *ModelPullStep) pullOllama(ctx context.Context) (*module.StepResult, err } // Pull (download) the model. - if pullErr := p.Pull(ctx, s.model, nil); pullErr != nil { + if pullErr := client.Pull(ctx, s.model, nil); pullErr != nil { return &module.StepResult{ Output: map[string]any{ "status": "error",