From 309a812acb3d8dba080da1bd80abff7d72fe35d2 Mon Sep 17 00:00:00 2001 From: "Sebastian L." Date: Mon, 16 Mar 2026 17:53:19 +0100 Subject: [PATCH] Implement an OpenAI provider --- example.config.toml | 5 +++ go.mod | 1 + go.sum | 2 + llm_provider.go | 95 +++++++++++++++++++++++++++++++++++++++++++++ main.go | 58 +++++++++++++++++++++++++++ 5 files changed, 161 insertions(+) diff --git a/example.config.toml b/example.config.toml index 9b9181a..677647f 100644 --- a/example.config.toml +++ b/example.config.toml @@ -38,6 +38,11 @@ hate_speech_threshold = "none" sexually_explicit_threshold = "none" dangerous_content_threshold = "none" +[openai] +base_url = "your_custom_openai_endpoint" # Replace with your openai compatible endpoint or remove to use OpenAI +api_key = "your_openai_key" +model = "gpt-4o-mini" + [localization] # Default language for the bot default_language = "en" diff --git a/go.mod b/go.mod index 4084067..c24b7ca 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( require ( github.com/google/go-cmp v0.7.0 // indirect + github.com/sashabaranov/go-openai v1.41.2 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect ) diff --git a/go.sum b/go.sum index c64f8fa..dafaaa3 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM= +github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= 
github.com/tomnomnom/linkheader v0.0.0-20250811210735-e5fe3b51442e h1:tD38/4xg4nuQCASJ/JxcvCHNb46w0cdAaJfkzQOO1bA= diff --git a/llm_provider.go b/llm_provider.go index db12060..7d4ab07 100644 --- a/llm_provider.go +++ b/llm_provider.go @@ -17,6 +17,7 @@ import ( "time" genai "google.golang.org/genai" + openai "github.com/sashabaranov/go-openai" ) // LLMProvider interface defines the methods that all LLM providers must implement @@ -51,6 +52,13 @@ type TransformersProvider struct { stopMonitor chan bool } +// OpenAIProvider implements LLMProvider for OpenAI and compatibles +type OpenAIProvider struct { + client *openai.Client + model string + baseURL string +} + // NewLLMProvider creates a new LLM provider based on the configuration func NewLLMProvider(config Config) (LLMProvider, error) { switch config.LLM.Provider { @@ -60,6 +68,8 @@ func NewLLMProvider(config Config) (LLMProvider, error) { return setupOllamaProvider(config) case "transformers": return setupTransformersProvider(config) + case "openai": + return setupOpenAIProvider(config) default: return nil, fmt.Errorf("unsupported LLM provider: %s", config.LLM.Provider) } @@ -163,6 +173,38 @@ func setupOllamaProvider(config Config) (*OllamaProvider, error) { return provider, nil } +func setupOpenAIProvider(config Config) (*OpenAIProvider, error) { + // Validate required fields + if config.Openai.APIKey == "" { + return nil, fmt.Errorf("OpenAI API key is required for OpenAI provider") + } + + // Create OpenAI compatible client configuration + openaiConfig := openai.DefaultConfig(config.Openai.APIKey) + + if config.Openai.BaseURL != "" { + openaiConfig.BaseURL = config.Openai.BaseURL + } else { + openaiConfig.BaseURL = "https://api.openai.com/v1" + } + + model := "gpt-4o-mini" + if config.Openai.Model != "" { + model = config.Openai.Model + } + + // Create client + client := openai.NewClientWithConfig(openaiConfig) + + provider := &OpenAIProvider{ + client: client, + model: model, + baseURL: openaiConfig.BaseURL, 
+ } + + return provider, nil +} + // GenerateAltText implementations for each provider func (p *GeminiProvider) GenerateAltText(prompt string, imageData []byte, format string, targetLanguage string) (string, error) { mimeType, err := inferImageMIME(format) @@ -256,6 +298,55 @@ func (p *OllamaProvider) GenerateVideoAltText(prompt string, videoData []byte, f return "", fmt.Errorf("video processing not supported by Ollama provider") } +// GenerateAltText for OpenAI compatible provider +func (p *OpenAIProvider) GenerateAltText(prompt string, imageData []byte, format string, targetLanguage string) (string, error) { + // Convert image to base64 + base64Image := base64.StdEncoding.EncodeToString(imageData) + + // Prepare messages + messages := []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + MultiContent: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: prompt, + }, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: fmt.Sprintf("data:image/%s;base64,%s", format, base64Image), + }, + }, + }, + }, + } + + // Create request + req := openai.ChatCompletionRequest{ + Model: p.model, + Messages: messages, + } + + // Call OpenAI API + resp, err := p.client.CreateChatCompletion(ctx, req) + if err != nil { + return "", fmt.Errorf("error calling OpenAI API: %v", err) + } + + if len(resp.Choices) == 0 { + return "", fmt.Errorf("no choices in response") + } + + return resp.Choices[0].Message.Content, nil +} + +// GenerateVideoAltText for OpenAI compatible provider +func (p *OpenAIProvider) GenerateVideoAltText(prompt string, videoData []byte, format string, targetLanguage string) (string, error) { + // Depending on the backend, OpenAI compatible models can support video processing directly but it's not implemented yet + return "", fmt.Errorf("video processing not yet supported by OpenAI compatible provider") +} + func (p *TransformersProvider) GenerateAltText(prompt string, imageData 
[]byte, format string, targetLanguage string) (string, error) { if config.LLM.UseTranslationLayer && targetLanguage != "en" { // Use translation layer @@ -447,6 +538,10 @@ func (p *OllamaProvider) Close() error { return nil // Nothing to close for Ollama } +func (p *OpenAIProvider) Close() error { + return nil // Nothing to close for OpenAI +} + func (p *TransformersProvider) Close() error { if p.monitoring { p.stopMonitor <- true diff --git a/main.go b/main.go index 5213ba5..87d2a8d 100644 --- a/main.go +++ b/main.go @@ -39,6 +39,7 @@ import ( "golang.org/x/text/language" genai "google.golang.org/genai" + openai "github.com/sashabaranov/go-openai" "github.com/mattn/go-mastodon" "github.com/nfnt/resize" @@ -88,6 +89,11 @@ type Config struct { SexuallyExplicitThreshold string `toml:"sexually_explicit_threshold"` DangerousContentThreshold string `toml:"dangerous_content_threshold"` } `toml:"gemini"` + Openai struct { + BaseURL string `toml:"base_url"` + Model string `toml:"model"` + APIKey string `toml:"api_key"` + } `toml:"openai"` Localization struct { DefaultLanguage string `toml:"default_language"` } `toml:"localization"` @@ -187,6 +193,9 @@ var metricsManager *MetricsManager var llmProvider LLMProvider +var openaiClient *openai.Client +var openaiModel string + const ( sourceURL = "https://github.com/micr0-dev/Altbot" donateURL = "https://ko-fi.com/micr0byte" @@ -279,6 +288,11 @@ func main() { videoProcessingCapability = true audioProcessingCapability = true + case "openai": + // Not yet implemented + videoProcessingCapability = false + audioProcessingCapability = false + default: log.Fatalf("Unsupported LLM provider: %s", config.LLM.Provider) } @@ -328,6 +342,12 @@ func main() { log.Fatal(err) } + // Set up Open AI compatible model (needed for dev mode too if using openai) + err = openaiSetup(config.Openai.APIKey) + if err != nil && !devMode { + log.Fatal(err) + } + // In dev mode, skip all Mastodon-related initialization if devMode { fmt.Printf("%s %d Custom 
settings loaded\n", getStatusSymbol(customSettingsCount > 0), customSettingsCount) @@ -588,6 +608,39 @@ func Setup(apiKey string) error { return nil } +// openaiSetup initializes the OpenAI compatible endpoint with the provided API key +func openaiSetup(apiKey string) error { + if ctx == nil { + ctx = context.Background() + } + + if config.LLM.Provider != "openai" { + return nil + } + + // Create OpenAI compatible client configuration + openaiConfig := openai.DefaultConfig(config.Openai.APIKey) + + if config.Openai.BaseURL != "" { + openaiConfig.BaseURL = config.Openai.BaseURL + } else { + openaiConfig.BaseURL = "https://api.openai.com/v1" + } + + if config.Openai.Model != "" { + openaiModel = config.Openai.Model + } else { + openaiModel = "gpt-4o-mini" + } + + // Create client + if openaiClient == nil { + openaiClient = openai.NewClientWithConfig(openaiConfig) + } + + return nil +} + // handleMention processes incoming mentions and generates alt-text descriptions func handleMention(c *mastodon.Client, notification *mastodon.Notification) { if isDNI(&notification.Account) { @@ -2158,6 +2211,11 @@ func updateBotProfile(client *mastodon.Client, config Config) error { Name: "Model", Value: config.Gemini.Model, }) + } else if config.LLM.Provider == "openai" { + fields = append(fields, mastodon.Field{ + Name: "Model", + Value: openaiModel, + }) } case "source":