Skip to content

Commit 9ff358b

Browse files
authored
Merge pull request #99 from Muneer320/main
2 parents 2c50506 + 87bd301 commit 9ff358b

3 files changed

Lines changed: 420 additions & 25 deletions

File tree

cmd/cli/createMsg.go

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cmd
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"os"
@@ -10,14 +11,9 @@ import (
1011

1112
"github.com/atotto/clipboard"
1213
"github.com/dfanso/commit-msg/cmd/cli/store"
13-
"github.com/dfanso/commit-msg/internal/chatgpt"
14-
"github.com/dfanso/commit-msg/internal/claude"
1514
"github.com/dfanso/commit-msg/internal/display"
16-
"github.com/dfanso/commit-msg/internal/gemini"
1715
"github.com/dfanso/commit-msg/internal/git"
18-
"github.com/dfanso/commit-msg/internal/grok"
19-
"github.com/dfanso/commit-msg/internal/groq"
20-
"github.com/dfanso/commit-msg/internal/ollama"
16+
"github.com/dfanso/commit-msg/internal/llm"
2117
"github.com/dfanso/commit-msg/internal/stats"
2218
"github.com/dfanso/commit-msg/pkg/types"
2319
"github.com/google/shlex"
@@ -102,6 +98,17 @@ func CreateCommitMsg(dryRun bool, autoCommit bool) {
10298
return
10399
}
104100

101+
ctx := context.Background()
102+
103+
providerInstance, err := llm.NewProvider(commitLLM, llm.ProviderOptions{
104+
Credential: apiKey,
105+
Config: config,
106+
})
107+
if err != nil {
108+
displayProviderError(commitLLM, err)
109+
os.Exit(1)
110+
}
111+
105112
pterm.Println()
106113
spinnerGenerating, err := pterm.DefaultSpinner.
107114
WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏").
@@ -112,7 +119,7 @@ func CreateCommitMsg(dryRun bool, autoCommit bool) {
112119
}
113120

114121
attempt := 1
115-
commitMsg, err := generateMessage(commitLLM, config, changes, apiKey, withAttempt(nil, attempt))
122+
commitMsg, err := generateMessage(ctx, providerInstance, changes, withAttempt(nil, attempt))
116123
if err != nil {
117124
spinnerGenerating.Fail("Failed to generate commit message")
118125
displayProviderError(commitLLM, err)
@@ -174,7 +181,7 @@ interactionLoop:
174181
pterm.Error.Printf("Failed to start spinner: %v\n", err)
175182
continue
176183
}
177-
updatedMessage, genErr := generateMessage(commitLLM, config, changes, apiKey, generationOpts)
184+
updatedMessage, genErr := generateMessage(ctx, providerInstance, changes, generationOpts)
178185
if genErr != nil {
179186
spinner.Fail("Regeneration failed")
180187
displayProviderError(commitLLM, genErr)
@@ -283,22 +290,8 @@ func resolveOllamaConfig(apiKey string) (url, model string) {
283290
return url, model
284291
}
285292

286-
func generateMessage(provider types.LLMProvider, config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) {
287-
switch provider {
288-
case types.ProviderGemini:
289-
return gemini.GenerateCommitMessage(config, changes, apiKey, opts)
290-
case types.ProviderOpenAI:
291-
return chatgpt.GenerateCommitMessage(config, changes, apiKey, opts)
292-
case types.ProviderClaude:
293-
return claude.GenerateCommitMessage(config, changes, apiKey, opts)
294-
case types.ProviderGroq:
295-
return groq.GenerateCommitMessage(config, changes, apiKey, opts)
296-
case types.ProviderOllama:
297-
url, model := resolveOllamaConfig(apiKey)
298-
return ollama.GenerateCommitMessage(config, changes, url, model, opts)
299-
default:
300-
return grok.GenerateCommitMessage(config, changes, apiKey, opts)
301-
}
293+
func generateMessage(ctx context.Context, provider llm.Provider, changes string, opts *types.GenerationOptions) (string, error) {
294+
return provider.Generate(ctx, changes, opts)
302295
}
303296

304297
func promptActionSelection() (string, error) {
@@ -456,6 +449,11 @@ func withAttempt(styleOpts *types.GenerationOptions, attempt int) *types.Generat
456449
}
457450

458451
func displayProviderError(provider types.LLMProvider, err error) {
452+
if errors.Is(err, llm.ErrMissingCredential) {
453+
displayMissingCredentialHint(provider)
454+
return
455+
}
456+
459457
switch provider {
460458
case types.ProviderGemini:
461459
pterm.Error.Printf("Gemini API error: %v. Check your GEMINI_API_KEY environment variable or run: commit llm setup\n", err)
@@ -467,8 +465,29 @@ func displayProviderError(provider types.LLMProvider, err error) {
467465
pterm.Error.Printf("Groq API error: %v. Check your GROQ_API_KEY environment variable or run: commit llm setup\n", err)
468466
case types.ProviderGrok:
469467
pterm.Error.Printf("Grok API error: %v. Check your GROK_API_KEY environment variable or run: commit llm setup\n", err)
468+
case types.ProviderOllama:
469+
pterm.Error.Printf("Ollama error: %v. Verify the Ollama service URL or run: commit llm setup\n", err)
470+
default:
471+
pterm.Error.Printf("LLM error: %v\n", err)
472+
}
473+
}
474+
475+
func displayMissingCredentialHint(provider types.LLMProvider) {
476+
switch provider {
477+
case types.ProviderGemini:
478+
pterm.Error.Println("Gemini requires an API key. Run: commit llm setup or set GEMINI_API_KEY.")
479+
case types.ProviderOpenAI:
480+
pterm.Error.Println("OpenAI requires an API key. Run: commit llm setup or set OPENAI_API_KEY.")
481+
case types.ProviderClaude:
482+
pterm.Error.Println("Claude requires an API key. Run: commit llm setup or set CLAUDE_API_KEY.")
483+
case types.ProviderGroq:
484+
pterm.Error.Println("Groq requires an API key. Run: commit llm setup or set GROQ_API_KEY.")
485+
case types.ProviderGrok:
486+
pterm.Error.Println("Grok requires an API key. Run: commit llm setup or set GROK_API_KEY.")
487+
case types.ProviderOllama:
488+
pterm.Error.Println("Ollama requires a reachable service URL. Run: commit llm setup or set OLLAMA_URL.")
470489
default:
471-
pterm.Error.Printf("LLM API error: %v\n", err)
490+
pterm.Error.Printf("%s is missing credentials. Run: commit llm setup.\n", provider)
472491
}
473492
}
474493

internal/llm/provider.go

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
package llm
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"os"
8+
"strings"
9+
"sync"
10+
11+
"github.com/dfanso/commit-msg/internal/chatgpt"
12+
"github.com/dfanso/commit-msg/internal/claude"
13+
"github.com/dfanso/commit-msg/internal/gemini"
14+
"github.com/dfanso/commit-msg/internal/grok"
15+
"github.com/dfanso/commit-msg/internal/groq"
16+
"github.com/dfanso/commit-msg/internal/ollama"
17+
"github.com/dfanso/commit-msg/pkg/types"
18+
)
19+
20+
// ErrMissingCredential signals that a provider requires a credential such as an API key or URL.
21+
var ErrMissingCredential = errors.New("llm: missing credential")
22+
23+
// Provider declares the behaviour required by commit-msg to talk to an LLM backend.
24+
type Provider interface {
25+
// Name returns the LLM provider identifier this instance represents.
26+
Name() types.LLMProvider
27+
// Generate requests a commit message for the supplied repository changes.
28+
Generate(ctx context.Context, changes string, opts *types.GenerationOptions) (string, error)
29+
}
30+
31+
// ProviderOptions captures the data needed to construct a provider instance.
32+
type ProviderOptions struct {
33+
Credential string
34+
Config *types.Config
35+
}
36+
37+
// Factory describes a function capable of building a Provider.
38+
type Factory func(ProviderOptions) (Provider, error)
39+
40+
var (
41+
factoryMu sync.RWMutex
42+
factories = map[types.LLMProvider]Factory{
43+
types.ProviderOpenAI: newOpenAIProvider,
44+
types.ProviderClaude: newClaudeProvider,
45+
types.ProviderGemini: newGeminiProvider,
46+
types.ProviderGrok: newGrokProvider,
47+
types.ProviderGroq: newGroqProvider,
48+
types.ProviderOllama: newOllamaProvider,
49+
}
50+
)
51+
52+
// RegisterFactory allows callers (primarily tests) to override or extend provider creation logic.
53+
func RegisterFactory(name types.LLMProvider, factory Factory) {
54+
factoryMu.Lock()
55+
defer factoryMu.Unlock()
56+
factories[name] = factory
57+
}
58+
59+
// NewProvider returns a concrete Provider implementation for the requested name.
60+
func NewProvider(name types.LLMProvider, opts ProviderOptions) (Provider, error) {
61+
factoryMu.RLock()
62+
factory, ok := factories[name]
63+
factoryMu.RUnlock()
64+
if !ok {
65+
return nil, fmt.Errorf("llm: unsupported provider %s", name)
66+
}
67+
68+
opts.Config = ensureConfig(opts.Config)
69+
return factory(opts)
70+
}
71+
72+
type missingCredentialError struct {
73+
provider types.LLMProvider
74+
}
75+
76+
func (e *missingCredentialError) Error() string {
77+
return fmt.Sprintf("%s credential is required", e.provider.String())
78+
}
79+
80+
func (e *missingCredentialError) Unwrap() error {
81+
return ErrMissingCredential
82+
}
83+
84+
func newMissingCredentialError(provider types.LLMProvider) error {
85+
return &missingCredentialError{provider: provider}
86+
}
87+
88+
func ensureConfig(cfg *types.Config) *types.Config {
89+
if cfg != nil {
90+
return cfg
91+
}
92+
return &types.Config{}
93+
}
94+
95+
// --- Provider implementations ------------------------------------------------
96+
97+
type openAIProvider struct {
98+
apiKey string
99+
config *types.Config
100+
}
101+
102+
func newOpenAIProvider(opts ProviderOptions) (Provider, error) {
103+
key := strings.TrimSpace(opts.Credential)
104+
if key == "" {
105+
key = strings.TrimSpace(os.Getenv("OPENAI_API_KEY"))
106+
}
107+
if key == "" {
108+
return nil, newMissingCredentialError(types.ProviderOpenAI)
109+
}
110+
return &openAIProvider{apiKey: key, config: opts.Config}, nil
111+
}
112+
113+
func (p *openAIProvider) Name() types.LLMProvider {
114+
return types.ProviderOpenAI
115+
}
116+
117+
func (p *openAIProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) {
118+
return chatgpt.GenerateCommitMessage(p.config, changes, p.apiKey, opts)
119+
}
120+
121+
type claudeProvider struct {
122+
apiKey string
123+
config *types.Config
124+
}
125+
126+
func newClaudeProvider(opts ProviderOptions) (Provider, error) {
127+
key := strings.TrimSpace(opts.Credential)
128+
if key == "" {
129+
key = strings.TrimSpace(os.Getenv("CLAUDE_API_KEY"))
130+
}
131+
if key == "" {
132+
return nil, newMissingCredentialError(types.ProviderClaude)
133+
}
134+
return &claudeProvider{apiKey: key, config: opts.Config}, nil
135+
}
136+
137+
func (p *claudeProvider) Name() types.LLMProvider {
138+
return types.ProviderClaude
139+
}
140+
141+
func (p *claudeProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) {
142+
return claude.GenerateCommitMessage(p.config, changes, p.apiKey, opts)
143+
}
144+
145+
type geminiProvider struct {
146+
apiKey string
147+
config *types.Config
148+
}
149+
150+
func newGeminiProvider(opts ProviderOptions) (Provider, error) {
151+
key := strings.TrimSpace(opts.Credential)
152+
if key == "" {
153+
key = strings.TrimSpace(os.Getenv("GEMINI_API_KEY"))
154+
}
155+
if key == "" {
156+
return nil, newMissingCredentialError(types.ProviderGemini)
157+
}
158+
return &geminiProvider{apiKey: key, config: opts.Config}, nil
159+
}
160+
161+
func (p *geminiProvider) Name() types.LLMProvider {
162+
return types.ProviderGemini
163+
}
164+
165+
func (p *geminiProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) {
166+
return gemini.GenerateCommitMessage(p.config, changes, p.apiKey, opts)
167+
}
168+
169+
type grokProvider struct {
170+
apiKey string
171+
config *types.Config
172+
}
173+
174+
func newGrokProvider(opts ProviderOptions) (Provider, error) {
175+
key := strings.TrimSpace(opts.Credential)
176+
if key == "" {
177+
key = strings.TrimSpace(os.Getenv("GROK_API_KEY"))
178+
}
179+
if key == "" {
180+
return nil, newMissingCredentialError(types.ProviderGrok)
181+
}
182+
return &grokProvider{apiKey: key, config: opts.Config}, nil
183+
}
184+
185+
func (p *grokProvider) Name() types.LLMProvider {
186+
return types.ProviderGrok
187+
}
188+
189+
func (p *grokProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) {
190+
return grok.GenerateCommitMessage(p.config, changes, p.apiKey, opts)
191+
}
192+
193+
type groqProvider struct {
194+
apiKey string
195+
config *types.Config
196+
}
197+
198+
func newGroqProvider(opts ProviderOptions) (Provider, error) {
199+
key := strings.TrimSpace(opts.Credential)
200+
if key == "" {
201+
key = strings.TrimSpace(os.Getenv("GROQ_API_KEY"))
202+
}
203+
if key == "" {
204+
return nil, newMissingCredentialError(types.ProviderGroq)
205+
}
206+
return &groqProvider{apiKey: key, config: opts.Config}, nil
207+
}
208+
209+
func (p *groqProvider) Name() types.LLMProvider {
210+
return types.ProviderGroq
211+
}
212+
213+
func (p *groqProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) {
214+
return groq.GenerateCommitMessage(p.config, changes, p.apiKey, opts)
215+
}
216+
217+
type ollamaProvider struct {
218+
url string
219+
model string
220+
config *types.Config
221+
}
222+
223+
func newOllamaProvider(opts ProviderOptions) (Provider, error) {
224+
url := strings.TrimSpace(opts.Credential)
225+
if url == "" {
226+
url = strings.TrimSpace(os.Getenv("OLLAMA_URL"))
227+
if url == "" {
228+
url = "http://localhost:11434/api/generate"
229+
}
230+
}
231+
232+
model := strings.TrimSpace(os.Getenv("OLLAMA_MODEL"))
233+
if model == "" {
234+
model = "llama3.1"
235+
}
236+
237+
return &ollamaProvider{url: url, model: model, config: opts.Config}, nil
238+
}
239+
240+
func (p *ollamaProvider) Name() types.LLMProvider {
241+
return types.ProviderOllama
242+
}
243+
244+
func (p *ollamaProvider) Generate(_ context.Context, changes string, opts *types.GenerationOptions) (string, error) {
245+
return ollama.GenerateCommitMessage(p.config, changes, p.url, p.model, opts)
246+
}

0 commit comments

Comments
 (0)