Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 56 additions & 16 deletions src/ai/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,28 @@
//!
//! 支持 OpenAI-compatible API,用于自动生成 commit message。
//!
//! 环境变量:
//! - `AGIT_LLM_API_KEY` — API 密钥(必需)
//! - `AGIT_LLM_API_URL` — API 端点(默认 `https://api.openai.com/v1`)
//! - `AGIT_LLM_MODEL` — 模型名(默认 `gpt-4o-mini`)
//! ## 配置方式(按优先级)
//!
//! 1. 环境变量 `AGIT_LLM_API_KEY` / `AGIT_LLM_PROVIDER` / `AGIT_LLM_MODEL`
//! 2. 仓库级 `.agit/config.toml` 的 `[llm]` 段
//! 3. 全局 `~/.agitconfig.toml` 的 `[llm]` 段
//! 4. 默认:provider=openai, model=gpt-4o-mini
//!
//! ## 配置文件示例
//!
//! ```toml
//! [llm]
//! api_key = "sk-xxx"
//! provider = "deepseek" # openai / deepseek / anthropic / moonshot / zhipu / ollama
//! model = "deepseek-chat" # 可选,不填自动匹配 provider 默认模型
//! ```
//!
//! `AGIT_LLM_API_URL` 环境变量可直接覆盖 API 端点(优先级最高)。

use crate::config;
use serde::{Deserialize, Serialize};

/// LLM API 配置
/// LLM API 运行时配置(已解析)
#[derive(Clone)]
pub struct LlmConfig {
pub api_key: String,
Expand All @@ -18,13 +32,33 @@ pub struct LlmConfig {
}

impl LlmConfig {
/// 从环境变量加载配置。
/// 如果 `AGIT_LLM_API_KEY` 未设置则返回 None。
pub fn from_env() -> Option<Self> {
let api_key = std::env::var("AGIT_LLM_API_KEY").ok()?;
let api_url = std::env::var("AGIT_LLM_API_URL")
.unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
let model = std::env::var("AGIT_LLM_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string());
/// 从 Config 系统加载 LLM 配置。
/// 优先级:AGIT_LLM_API_URL 环境变量 > provider 预设 > 默认 OpenAI
pub fn from_config(cfg: &config::Config) -> Option<Self> {
let api_key = cfg.llm.api_key.clone()?;

// 解析 API URL
let api_url = std::env::var("AGIT_LLM_API_URL").ok().unwrap_or_else(|| {
// 根据 provider 查预设
if let Some(ref provider) = cfg.llm.provider {
if let Some((url, _)) = config::resolve_llm_provider(provider) {
return url.to_string();
}
}
// 默认 OpenAI
"https://api.openai.com/v1".to_string()
});

let model = cfg.llm.model.clone().unwrap_or_else(|| {
// 根据 provider 查默认 model
if let Some(ref provider) = cfg.llm.provider {
if let Some((_, default_model)) = config::resolve_llm_provider(provider) {
return default_model.to_string();
}
}
"gpt-4o-mini".to_string()
});

Some(LlmConfig {
api_key,
api_url,
Expand Down Expand Up @@ -63,14 +97,14 @@ struct MessageContent {
}

/// 调用 LLM API 获取 chat completion。
///
/// 返回生成的文本内容,失败时返回错误。
pub fn chat_completion(
config: &LlmConfig,
system_prompt: &str,
user_prompt: &str,
) -> Result<String, Box<dyn std::error::Error>> {
let client = reqwest::blocking::Client::new();
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()?;

let request = ChatRequest {
model: config.model.clone(),
Expand Down Expand Up @@ -98,7 +132,13 @@ pub fn chat_completion(
if !response.status().is_success() {
let status = response.status();
let body = response.text().unwrap_or_default();
return Err(format!("LLM API error ({}): {}", status, body).into());
return Err(format!(
"LLM API error ({}): {}\n\
Hint: check AGIT_LLM_API_KEY and AGIT_LLM_PROVIDER.\n\
Supported providers: openai, deepseek, anthropic, moonshot, zhipu, ollama",
status, body
)
.into());
}

let chat_response: ChatResponse = response.json()?;
Expand Down
21 changes: 17 additions & 4 deletions src/commands/commit.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::ai;
use crate::config;
use crate::core::index::Index;
#[cfg(feature = "ai")]
use crate::core::objects::blob::Blob;
use crate::core::objects::commit::Commit;
use crate::core::objects::tree::Tree;
Expand Down Expand Up @@ -174,9 +175,14 @@ pub fn run(message: Option<String>, ai_flag: bool) -> Result<(), Box<dyn std::er
#[cfg(feature = "ai")]
fn generate_ai_message(repo_root: &std::path::Path, index: &Index) -> String {
let summary = build_staged_summary(repo_root, index);
if let Some(config) = ai::llm::LlmConfig::from_env() {
println!("[AI] Generating commit message via {}...", config.model);
match ai::llm::generate_commit_message(&config, &summary, None) {
let cfg = crate::config::load();
if let Some(llm_cfg) = ai::llm::LlmConfig::from_config(&cfg) {
let provider = cfg.llm.provider.as_deref().unwrap_or("openai");
println!(
"[AI] Generating commit message via {} ({})...",
provider, llm_cfg.model
);
match ai::llm::generate_commit_message(&llm_cfg, &summary, None) {
Ok(msg) => {
println!("[AI] Generated: {}", msg);
return msg;
Expand All @@ -186,7 +192,13 @@ fn generate_ai_message(repo_root: &std::path::Path, index: &Index) -> String {
}
}
} else {
println!("[AI] AGIT_LLM_API_KEY not set, using basic template.");
println!(
"[AI] No API key configured.\n\
Configure via:\n \
1. env (one-shot): AGIT_LLM_API_KEY=sk-xxx AGIT_LLM_PROVIDER=deepseek\n \
2. file (persist): ~/.agitconfig.toml\n [llm]\n api_key = \"sk-xxx\"\n provider = \"deepseek\"\n\
Supported providers: openai, deepseek, anthropic, moonshot, zhipu, ollama"
);
}
// 回退:基于文件列表生成简单消息
let paths: Vec<&str> = index.entries.keys().map(|s| s.as_str()).collect();
Expand All @@ -204,6 +216,7 @@ fn generate_ai_message(_repo_root: &std::path::Path, _index: &Index) -> String {
}

/// 构建暂存区文件变更摘要(供 AI 使用)。
#[cfg(feature = "ai")]
fn build_staged_summary(repo_root: &std::path::Path, index: &Index) -> String {
let mut lines = vec!["Staged changes:".to_string()];

Expand Down
15 changes: 11 additions & 4 deletions src/commands/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1001,12 +1001,15 @@ pub fn run(
"# git ls-files --others --exclude-standard\n",
)?;

// ── .gitignore 模板 ──
// ── .gitignore ──
// 始终包含 .agit/ 防止 agit 配置(含密钥)被意外提交
let mut gitignore = String::from("# agit config directory (may contain secrets)\n.agit/\n\n");
let has_pattern = pattern.is_some();
if let Some(pattern_name) = pattern {
match get_pattern_text(pattern_name) {
Some(content) => {
fs::write(target.join(".gitignore"), content)?;
println!(" Created .gitignore ({})", pattern_name);
Some(template) => {
gitignore.push_str(template);
println!(" Created .gitignore (.agit/ + {})", pattern_name);
}
None => {
eprintln!(
Expand All @@ -1021,6 +1024,10 @@ pub fn run(
}
}
}
if !has_pattern {
println!(" Created .gitignore (.agit/)");
}
fs::write(target.join(".gitignore"), gitignore)?;

// ── 许可证模板 ──
if let Some(licence_name) = licence {
Expand Down
111 changes: 111 additions & 0 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@ pub struct Config {
pub user_name: String,
pub user_email: String,
pub aliases: HashMap<String, String>,
/// LLM API 配置(api_key / provider / model)
#[allow(dead_code)]
pub llm: LlmConfig,
}

/// LLM 配置段。
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct LlmConfig {
pub api_key: Option<String>,
pub provider: Option<String>,
pub model: Option<String>,
}

/// TOML 配置文件的反序列化结构。
Expand All @@ -18,6 +30,8 @@ struct ConfigFile {
user: Option<UserSection>,
#[serde(default)]
alias: Option<HashMap<String, String>>,
#[serde(default)]
llm: Option<LlmSection>,
}

#[derive(Debug, Clone, Default, Deserialize)]
Expand All @@ -28,12 +42,53 @@ struct UserSection {
email: Option<String>,
}

#[derive(Debug, Clone, Default, Deserialize)]
struct LlmSection {
#[serde(default)]
api_key: Option<String>,
#[serde(default)]
provider: Option<String>,
#[serde(default)]
model: Option<String>,
}

/// 内置 LLM 厂商预设:(provider, api_url, default_model)
#[allow(dead_code)]
pub const LLM_PROVIDERS: &[(&str, &str, &str)] = &[
("openai", "https://api.openai.com/v1", "gpt-4o-mini"),
("deepseek", "https://api.deepseek.com/v1", "deepseek-chat"),
(
"anthropic",
"https://api.anthropic.com/v1",
"claude-haiku-3-5",
),
("moonshot", "https://api.moonshot.cn/v1", "moonshot-v1-8k"),
(
"zhipu",
"https://open.bigmodel.cn/api/paas/v4",
"glm-4-flash",
),
("ollama", "http://localhost:11434/v1", "llama3"),
];

/// 根据 provider 名查找预设的 API URL 和 model。
#[allow(dead_code)]
pub fn resolve_llm_provider(provider: &str) -> Option<(&'static str, &'static str)> {
LLM_PROVIDERS
.iter()
.find(|(name, _, _)| *name == provider.to_lowercase())
.map(|(_, url, model)| (*url, *model))
}

impl Config {
/// 按优先级加载配置:环境变量 > 仓库级 .agit/config.toml > 全局 ~/.agitconfig.toml > 默认值。
pub fn load(repo_path: Option<&Path>) -> Self {
let mut file_user_name: Option<String> = None;
let mut file_user_email: Option<String> = None;
let mut aliases = HashMap::new();
let mut llm_api_key: Option<String> = None;
let mut llm_provider: Option<String> = None;
let mut llm_model: Option<String> = None;

// 1. 读取全局配置文件 ~/.agitconfig.toml
if let Some(home) = dirs_fallback() {
Expand All @@ -45,10 +100,19 @@ impl Config {
&mut file_user_email,
&mut aliases,
);
// 全局配置允许 api_key
merge_llm_config(
&cfg,
&mut llm_api_key,
&mut llm_provider,
&mut llm_model,
true,
);
}
}

// 2. 读取仓库级配置文件 .agit/config.toml(覆盖全局)
// 注意:api_key 不从仓库级配置读取,防止泄漏到仓库中
if let Some(repo) = repo_path {
let repo_config = repo.join(".agit").join("config.toml");
if let Some(cfg) = read_config_file(&repo_config) {
Expand All @@ -58,6 +122,14 @@ impl Config {
&mut file_user_email,
&mut aliases,
);
// api_key 不从仓库级配置读取(include_api_key=false)
merge_llm_config(
&cfg,
&mut llm_api_key,
&mut llm_provider,
&mut llm_model,
false,
);
}
}

Expand All @@ -74,10 +146,26 @@ impl Config {
.or(file_user_email)
.unwrap_or_else(|| "agit@localhost".to_string());

// LLM 环境变量覆盖配置文件
if let Ok(key) = env::var("AGIT_LLM_API_KEY") {
llm_api_key = Some(key);
}
if let Ok(provider) = env::var("AGIT_LLM_PROVIDER") {
llm_provider = Some(provider);
}
if let Ok(model) = env::var("AGIT_LLM_MODEL") {
llm_model = Some(model);
}

Config {
user_name,
user_email,
aliases,
llm: LlmConfig {
api_key: llm_api_key,
provider: llm_provider,
model: llm_model,
},
}
}
}
Expand Down Expand Up @@ -108,6 +196,29 @@ fn merge_config(
}
}

fn merge_llm_config(
cfg: &ConfigFile,
api_key: &mut Option<String>,
provider: &mut Option<String>,
model: &mut Option<String>,
include_api_key: bool,
) {
if let Some(ref llm) = cfg.llm {
// api_key 只从全局配置读取,不从仓库级配置读取(防泄漏)
if include_api_key {
if let Some(ref key) = llm.api_key {
*api_key = Some(key.clone());
}
}
if let Some(ref p) = llm.provider {
*provider = Some(p.clone());
}
if let Some(ref m) = llm.model {
*model = Some(m.clone());
}
}
}

/// 不依赖 `dirs` crate 的 home 目录获取。
fn dirs_fallback() -> Option<PathBuf> {
env::var("HOME")
Expand Down
3 changes: 2 additions & 1 deletion tests/p1_integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,8 @@ fn test_init_pattern_unknown_warns() {
let stderr = String::from_utf8_lossy(&output.stderr);
assert!(stderr.contains("unknown gitignore pattern"));
assert!(stderr.contains("Available:"));
assert!(!repo.join(".gitignore").exists());
// .gitignore 始终创建(含 .agit/ 安全守卫),即使 pattern 无效
assert!(repo.join(".gitignore").exists());
let _ = fs::remove_dir_all(&repo);
}

Expand Down
Loading