diff --git a/src/ai/llm.rs b/src/ai/llm.rs index c1c6b50..f8c58eb 100644 --- a/src/ai/llm.rs +++ b/src/ai/llm.rs @@ -2,14 +2,28 @@ //! //! 支持 OpenAI-compatible API,用于自动生成 commit message。 //! -//! 环境变量: -//! - `AGIT_LLM_API_KEY` — API 密钥(必需) -//! - `AGIT_LLM_API_URL` — API 端点(默认 `https://api.openai.com/v1`) -//! - `AGIT_LLM_MODEL` — 模型名(默认 `gpt-4o-mini`) +//! ## 配置方式(按优先级) +//! +//! 1. 环境变量 `AGIT_LLM_API_KEY` / `AGIT_LLM_PROVIDER` / `AGIT_LLM_MODEL` +//! 2. 仓库级 `.agit/config.toml` 的 `[llm]` 段 +//! 3. 全局 `~/.agitconfig.toml` 的 `[llm]` 段 +//! 4. 默认:provider=openai, model=gpt-4o-mini +//! +//! ## 配置文件示例 +//! +//! ```toml +//! [llm] +//! api_key = "sk-xxx" +//! provider = "deepseek" # openai / deepseek / anthropic / moonshot / zhipu / ollama +//! model = "deepseek-chat" # 可选,不填自动匹配 provider 默认模型 +//! ``` +//! +//! `AGIT_LLM_API_URL` 环境变量可直接覆盖 API 端点(优先级最高)。 +use crate::config; use serde::{Deserialize, Serialize}; -/// LLM API 配置。 +/// LLM API 运行时配置(已解析)。 #[derive(Clone)] pub struct LlmConfig { pub api_key: String, @@ -18,13 +32,33 @@ pub struct LlmConfig { } impl LlmConfig { - /// 从环境变量加载配置。 - /// 如果 `AGIT_LLM_API_KEY` 未设置则返回 None。 - pub fn from_env() -> Option { - let api_key = std::env::var("AGIT_LLM_API_KEY").ok()?; - let api_url = std::env::var("AGIT_LLM_API_URL") - .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()); - let model = std::env::var("AGIT_LLM_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()); + /// 从 Config 系统加载 LLM 配置。 + /// 优先级:AGIT_LLM_API_URL 环境变量 > provider 预设 > 默认 OpenAI + pub fn from_config(cfg: &config::Config) -> Option { + let api_key = cfg.llm.api_key.clone()?; + + // 解析 API URL + let api_url = std::env::var("AGIT_LLM_API_URL").ok().unwrap_or_else(|| { + // 根据 provider 查预设 + if let Some(ref provider) = cfg.llm.provider { + if let Some((url, _)) = config::resolve_llm_provider(provider) { + return url.to_string(); + } + } + // 默认 OpenAI + "https://api.openai.com/v1".to_string() + }); + + let model = cfg.llm.model.clone().unwrap_or_else(|| { + // 根据 provider 查默认 model + if let Some(ref provider) = cfg.llm.provider { + if let Some((_, default_model)) = config::resolve_llm_provider(provider) { + return default_model.to_string(); + } + } + "gpt-4o-mini".to_string() + }); + Some(LlmConfig { api_key, api_url, @@ -63,14 +97,14 @@ struct MessageContent { } /// 调用 LLM API 获取 chat completion。 -/// -/// 返回生成的文本内容,失败时返回错误。 pub fn chat_completion( config: &LlmConfig, system_prompt: &str, user_prompt: &str, ) -> Result> { - let client = reqwest::blocking::Client::new(); + let client = reqwest::blocking::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build()?; let request = ChatRequest { model: config.model.clone(), @@ -98,7 +132,13 @@ pub fn chat_completion( if !response.status().is_success() { let status = response.status(); let body = response.text().unwrap_or_default(); - return Err(format!("LLM API error ({}): {}", status, body).into()); + return Err(format!( + "LLM API error ({}): {}\n\ + Hint: check AGIT_LLM_API_KEY and AGIT_LLM_PROVIDER.\n\ + Supported providers: openai, deepseek, anthropic, moonshot, zhipu, ollama", + status, body + ) + .into()); } let chat_response: ChatResponse = response.json()?; diff --git a/src/commands/commit.rs b/src/commands/commit.rs index 11914d1..7117d4d 100644 --- a/src/commands/commit.rs +++ b/src/commands/commit.rs @@ -1,6 +1,7 @@ use crate::ai; use crate::config; use crate::core::index::Index; +#[cfg(feature = "ai")] use crate::core::objects::blob::Blob; use crate::core::objects::commit::Commit; use crate::core::objects::tree::Tree; @@ -174,9 +175,14 @@ pub fn run(message: Option, ai_flag: bool) -> Result<(), Box String { let summary = build_staged_summary(repo_root, index); - if let Some(config) = ai::llm::LlmConfig::from_env() { - println!("[AI] Generating commit message via {}...", config.model); - match ai::llm::generate_commit_message(&config, &summary, None) { + let cfg = crate::config::load(); + if let Some(llm_cfg) = ai::llm::LlmConfig::from_config(&cfg) { + let provider = cfg.llm.provider.as_deref().unwrap_or("openai"); + println!( + "[AI] Generating commit message via {} ({})...", + provider, llm_cfg.model + ); + match ai::llm::generate_commit_message(&llm_cfg, &summary, None) { Ok(msg) => { println!("[AI] Generated: {}", msg); return msg; @@ -186,7 +192,13 @@ fn generate_ai_message(repo_root: &std::path::Path, index: &Index) -> String { } } } else { - println!("[AI] AGIT_LLM_API_KEY not set, using basic template."); + println!( + "[AI] No API key configured.\n\ + Configure via:\n \ + 1. env (one-shot): AGIT_LLM_API_KEY=sk-xxx AGIT_LLM_PROVIDER=deepseek\n \ + 2. file (persist): ~/.agitconfig.toml\n [llm]\n api_key = \"sk-xxx\"\n provider = \"deepseek\"\n\ + Supported providers: openai, deepseek, anthropic, moonshot, zhipu, ollama" + ); } // 回退:基于文件列表生成简单消息 let paths: Vec<&str> = index.entries.keys().map(|s| s.as_str()).collect(); @@ -204,6 +216,7 @@ fn generate_ai_message(_repo_root: &std::path::Path, _index: &Index) -> String { } /// 构建暂存区文件变更摘要(供 AI 使用)。 +#[cfg(feature = "ai")] fn build_staged_summary(repo_root: &std::path::Path, index: &Index) -> String { let mut lines = vec!["Staged changes:".to_string()]; diff --git a/src/commands/init.rs b/src/commands/init.rs index 732df64..a8f72e0 100644 --- a/src/commands/init.rs +++ b/src/commands/init.rs @@ -1001,12 +1001,15 @@ pub fn run( "# git ls-files --others --exclude-standard\n", )?; - // ── .gitignore 模板 ── + // ── .gitignore ── + // 始终包含 .agit/ 防止 agit 配置(含密钥)被意外提交 + let mut gitignore = String::from("# agit config directory (may contain secrets)\n.agit/\n\n"); + let has_pattern = pattern.is_some(); if let Some(pattern_name) = pattern { match get_pattern_text(pattern_name) { - Some(content) => { - fs::write(target.join(".gitignore"), content)?; - println!(" Created .gitignore ({})", pattern_name); + Some(template) => { + gitignore.push_str(template); + println!(" Created .gitignore (.agit/ + {})", pattern_name); } None => { eprintln!( @@ -1021,6 +1024,10 @@ pub fn run( } } } + if !has_pattern { + println!(" Created .gitignore (.agit/)"); + } + fs::write(target.join(".gitignore"), gitignore)?; // ── 许可证模板 ── if let Some(licence_name) = licence { diff --git a/src/config/mod.rs b/src/config/mod.rs index 1cc84f4..23a9688 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -9,6 +9,18 @@ pub struct Config { pub user_name: String, pub user_email: String, pub aliases: HashMap, + /// LLM API 配置(api_key / provider / model) + #[allow(dead_code)] + pub llm: LlmConfig, +} + +/// LLM 配置段。 +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct LlmConfig { + pub api_key: Option, + pub provider: Option, + pub model: Option, } /// TOML 配置文件的反序列化结构。 @@ -18,6 +30,8 @@ struct ConfigFile { user: Option, #[serde(default)] alias: Option>, + #[serde(default)] + llm: Option, } #[derive(Debug, Clone, Default, Deserialize)] @@ -28,12 +42,53 @@ struct UserSection { email: Option, } +#[derive(Debug, Clone, Default, Deserialize)] +struct LlmSection { + #[serde(default)] + api_key: Option, + #[serde(default)] + provider: Option, + #[serde(default)] + model: Option, +} + +/// 内置 LLM 厂商预设:(provider, api_url, default_model) +#[allow(dead_code)] +pub const LLM_PROVIDERS: &[(&str, &str, &str)] = &[ + ("openai", "https://api.openai.com/v1", "gpt-4o-mini"), + ("deepseek", "https://api.deepseek.com/v1", "deepseek-chat"), + ( + "anthropic", + "https://api.anthropic.com/v1", + "claude-haiku-3-5", + ), + ("moonshot", "https://api.moonshot.cn/v1", "moonshot-v1-8k"), + ( + "zhipu", + "https://open.bigmodel.cn/api/paas/v4", + "glm-4-flash", + ), + ("ollama", "http://localhost:11434/v1", "llama3"), +]; + +/// 根据 provider 名查找预设的 API URL 和 model。 +#[allow(dead_code)] +pub fn resolve_llm_provider(provider: &str) -> Option<(&'static str, &'static str)> { + LLM_PROVIDERS + .iter() + .find(|(name, _, _)| *name == provider.to_lowercase()) + .map(|(_, url, model)| (*url, *model)) +} + impl Config { /// 按优先级加载配置:环境变量 > 仓库级 .agit/config.toml > 全局 ~/.agitconfig.toml > 默认值。 pub fn load(repo_path: Option<&Path>) -> Self { let mut file_user_name: Option = None; let mut file_user_email: Option = None; let mut aliases = HashMap::new(); + let mut llm_api_key: Option = None; + let mut llm_provider: Option = None; + let mut llm_model: Option = None; // 1. 读取全局配置文件 ~/.agitconfig.toml if let Some(home) = dirs_fallback() { @@ -45,10 +100,19 @@ impl Config { &mut file_user_email, &mut aliases, ); + // 全局配置允许 api_key + merge_llm_config( + &cfg, + &mut llm_api_key, + &mut llm_provider, + &mut llm_model, + true, + ); } } // 2. 读取仓库级配置文件 .agit/config.toml(覆盖全局) + // 注意:api_key 不从仓库级配置读取,防止泄漏到仓库中 if let Some(repo) = repo_path { let repo_config = repo.join(".agit").join("config.toml"); if let Some(cfg) = read_config_file(&repo_config) { @@ -58,6 +122,14 @@ impl Config { &mut file_user_email, &mut aliases, ); + // api_key 不从仓库级配置读取(include_api_key=false) + merge_llm_config( + &cfg, + &mut llm_api_key, + &mut llm_provider, + &mut llm_model, + false, + ); } } @@ -74,10 +146,26 @@ impl Config { .or(file_user_email) .unwrap_or_else(|| "agit@localhost".to_string()); + // LLM 环境变量覆盖配置文件 + if let Ok(key) = env::var("AGIT_LLM_API_KEY") { + llm_api_key = Some(key); + } + if let Ok(provider) = env::var("AGIT_LLM_PROVIDER") { + llm_provider = Some(provider); + } + if let Ok(model) = env::var("AGIT_LLM_MODEL") { + llm_model = Some(model); + } + Config { user_name, user_email, aliases, + llm: LlmConfig { + api_key: llm_api_key, + provider: llm_provider, + model: llm_model, + }, } } } @@ -108,6 +196,29 @@ fn merge_config( } } +fn merge_llm_config( + cfg: &ConfigFile, + api_key: &mut Option, + provider: &mut Option, + model: &mut Option, + include_api_key: bool, +) { + if let Some(ref llm) = cfg.llm { + // api_key 只从全局配置读取,不从仓库级配置读取(防泄漏) + if include_api_key { + if let Some(ref key) = llm.api_key { + *api_key = Some(key.clone()); + } + } + if let Some(ref p) = llm.provider { + *provider = Some(p.clone()); + } + if let Some(ref m) = llm.model { + *model = Some(m.clone()); + } + } +} + /// 不依赖 `dirs` crate 的 home 目录获取。 fn dirs_fallback() -> Option { env::var("HOME") diff --git a/tests/p1_integration_test.rs b/tests/p1_integration_test.rs index 29a83ee..4d8c4d5 100644 --- a/tests/p1_integration_test.rs +++ b/tests/p1_integration_test.rs @@ -428,7 +428,8 @@ fn test_init_pattern_unknown_warns() { let stderr = String::from_utf8_lossy(&output.stderr); assert!(stderr.contains("unknown gitignore pattern")); assert!(stderr.contains("Available:")); - assert!(!repo.join(".gitignore").exists()); + // .gitignore 始终创建(含 .agit/ 安全守卫),即使 pattern 无效 + assert!(repo.join(".gitignore").exists()); let _ = fs::remove_dir_all(&repo); }