diff --git a/crates/config/src/lib.rs b/crates/config/src/lib.rs index b8a982cbc..7d169f037 100644 --- a/crates/config/src/lib.rs +++ b/crates/config/src/lib.rs @@ -3239,16 +3239,24 @@ pub const LEGACY_APP_DIR: &str = ".deepseek"; /// `$CODEWHALE_HOME` takes precedence when set. Otherwise defaults to /// `$HOME/.codewhale`. This is the write target for new product state. pub fn codewhale_home() -> Result { - if let Ok(val) = std::env::var("CODEWHALE_HOME") { - let trimmed = val.trim(); - if !trimmed.is_empty() { - return Ok(PathBuf::from(trimmed)); - } + if let Some(path) = explicit_codewhale_home() { + return Ok(path); } let home = effective_home_dir().context("failed to resolve home directory")?; Ok(home.join(CODEWHALE_APP_DIR)) } +fn explicit_codewhale_home() -> Option { + std::env::var("CODEWHALE_HOME").ok().and_then(|val| { + let trimmed = val.trim(); + if trimmed.is_empty() { + None + } else { + Some(PathBuf::from(trimmed)) + } + }) +} + /// Resolve the legacy DeepSeek home directory (`$HOME/.deepseek`). /// /// Always returns the legacy path regardless of whether it exists. @@ -3308,6 +3316,9 @@ pub fn resolve_state_dir(subdir: &str) -> Result { if primary.exists() { return Ok(primary); } + if explicit_codewhale_home().is_some() { + return Ok(primary); + } let legacy = legacy_deepseek_home()?.join(subdir); if legacy.exists() { return Ok(legacy); @@ -3328,7 +3339,9 @@ pub fn resolve_state_dir(subdir: &str) -> Result { pub fn ensure_state_dir(subdir: &str) -> Result { ensure_safe_state_subdir(subdir)?; let dir = codewhale_home()?.join(subdir); - migrate_legacy_state_dir(&dir, subdir)?; + if explicit_codewhale_home().is_none() { + migrate_legacy_state_dir(&dir, subdir)?; + } std::fs::create_dir_all(&dir) .with_context(|| format!("failed to create {}/", dir.display()))?; Ok(dir) @@ -3599,6 +3612,9 @@ pub fn default_config_path() -> Result { if primary.exists() { return Ok(primary); } + if explicit_codewhale_home().is_some() { + return Ok(primary); + } let legacy = legacy_deepseek_home()?.join(CONFIG_FILE_NAME); if legacy.exists() { return Ok(legacy); @@ -3632,6 +3648,9 @@ pub fn migrate_config_if_needed() -> Result> { if primary.exists() { return Ok(None); } + if explicit_codewhale_home().is_some() { + return Ok(None); + } let legacy = legacy_deepseek_home()?.join(CONFIG_FILE_NAME); if !legacy.exists() { return Ok(None); diff --git a/crates/config/src/tests.rs b/crates/config/src/tests.rs index 5ac83ec74..b3406d6e2 100644 --- a/crates/config/src/tests.rs +++ b/crates/config/src/tests.rs @@ -1911,7 +1911,7 @@ fn migrate_config_reports_copied_legacy_path() { unsafe { env::set_var("HOME", &home); env::set_var("USERPROFILE", &home); - env::set_var("CODEWHALE_HOME", &primary_dir); + env::remove_var("CODEWHALE_HOME"); } let migration = migrate_config_if_needed() @@ -1962,9 +1962,9 @@ impl Drop for StateEnvRestore { } } -/// Points `HOME`/`USERPROFILE`/`CODEWHALE_HOME` at a fresh temp tree so -/// `codewhale_home()` -> `/.codewhale` and `legacy_deepseek_home()` -/// -> `/.deepseek`. Env is restored on drop. +/// Points `HOME`/`USERPROFILE` at a fresh temp tree so `codewhale_home()` -> +/// `/.codewhale` and `legacy_deepseek_home()` -> `/.deepseek`. +/// Env is restored on drop. struct StateDirEnv { home: PathBuf, _restore: StateEnvRestore, @@ -1985,7 +1985,7 @@ impl StateDirEnv { unsafe { env::set_var("HOME", &home); env::set_var("USERPROFILE", &home); - env::set_var("CODEWHALE_HOME", home.join(CODEWHALE_APP_DIR)); + env::remove_var("CODEWHALE_HOME"); } Self { home, @@ -2087,6 +2087,89 @@ fn resolve_state_dir_still_finds_legacy_for_backfill() { let _ = fs::remove_dir_all(&state_env.home); } +#[test] +fn explicit_codewhale_home_does_not_read_or_migrate_legacy_state() { + let _lock = env_lock(); + let unique = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("clock") + .as_nanos(); + let home = std::env::temp_dir().join(format!( + "codewhale-explicit-state-home-{}-{unique}", + std::process::id() + )); + let explicit = home.join("isolated-codewhale"); + let restore = StateEnvRestore { + home: env::var_os("HOME"), + userprofile: env::var_os("USERPROFILE"), + codewhale_home: env::var_os("CODEWHALE_HOME"), + }; + // Safety: test-only environment mutation is serialized by env_lock(). + unsafe { + env::set_var("HOME", &home); + env::set_var("USERPROFILE", &home); + env::set_var("CODEWHALE_HOME", &explicit); + } + let _restore = restore; + + let legacy_catalog = home.join(LEGACY_APP_DIR).join("catalog"); + fs::create_dir_all(&legacy_catalog).expect("legacy catalog"); + fs::write(legacy_catalog.join("state.json"), b"legacy").expect("legacy file"); + + assert_eq!( + resolve_state_dir("catalog").expect("resolve"), + explicit.join("catalog") + ); + assert_eq!( + ensure_state_dir("catalog").expect("ensure"), + explicit.join("catalog") + ); + assert!(!explicit.join("catalog").join("state.json").exists()); + assert!(legacy_catalog.join("state.json").exists()); + + let _ = fs::remove_dir_all(home); +} + +#[test] +fn explicit_codewhale_home_does_not_read_or_migrate_legacy_config() { + let _lock = env_lock(); + let unique = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("clock") + .as_nanos(); + let home = std::env::temp_dir().join(format!( + "codewhale-explicit-config-home-{}-{unique}", + std::process::id() + )); + let explicit = home.join("isolated-codewhale"); + let restore = StateEnvRestore { + home: env::var_os("HOME"), + userprofile: env::var_os("USERPROFILE"), + codewhale_home: env::var_os("CODEWHALE_HOME"), + }; + // Safety: test-only environment mutation is serialized by env_lock(). + unsafe { + env::set_var("HOME", &home); + env::set_var("USERPROFILE", &home); + env::set_var("CODEWHALE_HOME", &explicit); + } + let _restore = restore; + + let legacy_config = home.join(LEGACY_APP_DIR).join(CONFIG_FILE_NAME); + fs::create_dir_all(legacy_config.parent().expect("legacy parent")).expect("legacy dir"); + fs::write(&legacy_config, b"provider = \"deepseek\"\n").expect("legacy config"); + + assert_eq!( + default_config_path().expect("default config path"), + explicit.join(CONFIG_FILE_NAME) + ); + assert_eq!(migrate_config_if_needed().expect("migration"), None); + assert!(!explicit.join(CONFIG_FILE_NAME).exists()); + assert!(legacy_config.exists()); + + let _ = fs::remove_dir_all(home); +} + #[test] fn state_resolvers_reject_path_traversal_subdirs() { // Defense against path injection (#3240 hardening): the public state diff --git a/crates/secrets/src/lib.rs b/crates/secrets/src/lib.rs index 15506df3a..93bdbfa03 100644 --- a/crates/secrets/src/lib.rs +++ b/crates/secrets/src/lib.rs @@ -480,7 +480,22 @@ impl FileKeyringStore { } } let body = serde_json::to_string_pretty(blob)?; - fs::write(&self.path, body)?; + #[cfg(unix)] + { + use std::io::Write; + use std::os::unix::fs::OpenOptionsExt; + let mut file = fs::OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .mode(0o600) + .open(&self.path)?; + file.write_all(body.as_bytes())?; + } + #[cfg(not(unix))] + { + fs::write(&self.path, body)?; + } #[cfg(unix)] { use std::os::unix::fs::PermissionsExt; diff --git a/crates/state/src/lib.rs b/crates/state/src/lib.rs index 34b7bfe45..103e9526b 100644 --- a/crates/state/src/lib.rs +++ b/crates/state/src/lib.rs @@ -305,8 +305,11 @@ impl StateStore { } fn conn(&self) -> Result { - Connection::open(&self.db_path) - .with_context(|| format!("failed to open state db {}", self.db_path.display())) + let conn = Connection::open(&self.db_path) + .with_context(|| format!("failed to open state db {}", self.db_path.display()))?; + conn.execute_batch("PRAGMA foreign_keys = ON;") + .context("failed to enable sqlite foreign keys")?; + Ok(conn) } fn init_schema(&self) -> Result<()> { @@ -1867,6 +1870,45 @@ mod tests { assert!(err.to_string().contains("thread missing-thread not found")); } + #[test] + fn delete_thread_cascades_child_rows() { + let store = temp_state_store("delete-thread-cascades"); + store + .upsert_thread(&test_thread("thread-1")) + .expect("upsert thread"); + store + .append_message("thread-1", "user", "hello", Some(serde_json::json!({}))) + .expect("append message"); + store + .save_checkpoint( + "thread-1", + "checkpoint-1", + &serde_json::json!({ "ok": true }), + ) + .expect("save checkpoint"); + + store.delete_thread("thread-1").expect("delete thread"); + + let conn = store.conn().expect("conn"); + let message_count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM messages WHERE thread_id = ?1", + params!["thread-1"], + |row| row.get(0), + ) + .expect("count messages"); + let checkpoint_count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM checkpoints WHERE thread_id = ?1", + params!["thread-1"], + |row| row.get(0), + ) + .expect("count checkpoints"); + + assert_eq!(message_count, 0); + assert_eq!(checkpoint_count, 0); + } + #[test] fn record_thread_goal_usage_accumulates_tokens_and_time() { let store = temp_state_store("thread-goal-usage"); diff --git a/crates/tui/src/automation_manager.rs b/crates/tui/src/automation_manager.rs index 6a394e76b..b44cff8a1 100644 --- a/crates/tui/src/automation_manager.rs +++ b/crates/tui/src/automation_manager.rs @@ -319,17 +319,33 @@ impl AutomationManager { Self::open(default_automations_dir()) } - fn automation_path(&self, id: &str) -> PathBuf { - self.automations_dir.join(format!("{id}.json")) + fn validate_storage_id(id: &str, field: &str) -> Result<()> { + let path = Path::new(id); + let mut components = path.components(); + let Some(component) = components.next() else { + bail!("{field} must not be empty"); + }; + if components.next().is_some() || !matches!(component, std::path::Component::Normal(_)) { + bail!("{field} must be a single path component"); + } + Ok(()) } - fn runs_dir_for(&self, automation_id: &str) -> PathBuf { - self.runs_dir.join(automation_id) + fn automation_path(&self, id: &str) -> Result { + Self::validate_storage_id(id, "automation_id")?; + Ok(self.automations_dir.join(format!("{id}.json"))) } - fn run_path(&self, automation_id: &str, run_id: &str) -> PathBuf { - self.runs_dir_for(automation_id) - .join(format!("{run_id}.json")) + fn runs_dir_for(&self, automation_id: &str) -> Result { + Self::validate_storage_id(automation_id, "automation_id")?; + Ok(self.runs_dir.join(automation_id)) + } + + fn run_path(&self, automation_id: &str, run_id: &str) -> Result { + Self::validate_storage_id(run_id, "run_id")?; + Ok(self + .runs_dir_for(automation_id)? + .join(format!("{run_id}.json"))) } pub fn create_automation(&self, req: CreateAutomationRequest) -> Result { @@ -362,7 +378,7 @@ impl AutomationManager { } pub fn get_automation(&self, id: &str) -> Result { - let path = self.automation_path(id); + let path = self.automation_path(id)?; let raw = fs::read_to_string(&path) .with_context(|| format!("Failed to read automation {}", path.display()))?; let record: AutomationRecord = serde_json::from_str(&raw) @@ -378,7 +394,7 @@ impl AutomationManager { } pub fn save_automation(&self, record: &AutomationRecord) -> Result<()> { - write_json_atomic(&self.automation_path(&record.id), record) + write_json_atomic(&self.automation_path(&record.id)?, record) } pub fn list_automations(&self) -> Result> { @@ -476,11 +492,11 @@ impl AutomationManager { pub fn delete_automation(&self, id: &str) -> Result { let existing = self.get_automation(id)?; - let path = self.automation_path(id); + let path = self.automation_path(id)?; fs::remove_file(&path) .with_context(|| format!("Failed to delete automation {}", path.display()))?; - let runs_dir = self.runs_dir_for(id); + let runs_dir = self.runs_dir_for(id)?; if runs_dir.exists() { fs::remove_dir_all(&runs_dir).with_context(|| { format!("Failed to delete automation runs {}", runs_dir.display()) @@ -495,7 +511,7 @@ impl AutomationManager { automation_id: &str, limit: Option, ) -> Result> { - let dir = self.runs_dir_for(automation_id); + let dir = self.runs_dir_for(automation_id)?; if !dir.exists() { return Ok(Vec::new()); } @@ -531,9 +547,9 @@ impl AutomationManager { } fn save_run(&self, run: &AutomationRunRecord) -> Result<()> { - let dir = self.runs_dir_for(&run.automation_id); + let dir = self.runs_dir_for(&run.automation_id)?; fs::create_dir_all(&dir).with_context(|| format!("Failed to create {}", dir.display()))?; - write_json_atomic(&self.run_path(&run.automation_id, &run.id), run) + write_json_atomic(&self.run_path(&run.automation_id, &run.id)?, run) } async fn enqueue_run_task( @@ -941,14 +957,50 @@ mod tests { error: None, }; manager.save_run(&run).expect("save run"); - assert!(manager.runs_dir_for(&created.id).exists()); + assert!( + manager + .runs_dir_for(&created.id) + .expect("runs dir") + .exists() + ); manager .delete_automation(&created.id) .expect("delete automation"); assert!(manager.get_automation(&created.id).is_err()); - assert!(!manager.runs_dir_for(&created.id).exists()); + assert!( + !manager + .runs_dir_for(&created.id) + .expect("runs dir") + .exists() + ); + } + + #[test] + fn rejects_storage_ids_with_path_components() { + let tempdir = tempfile::tempdir().expect("tempdir"); + let manager = AutomationManager::open(tempdir.path().to_path_buf()).expect("manager"); + + assert!(manager.get_automation("../escape").is_err()); + assert!(manager.list_runs("../escape", None).is_err()); + + let run = AutomationRunRecord { + schema_version: CURRENT_RUN_SCHEMA_VERSION, + id: "../escape".to_string(), + automation_id: Uuid::new_v4().to_string(), + scheduled_for: Utc::now(), + status: AutomationRunStatus::Queued, + created_at: Utc::now(), + started_at: None, + ended_at: None, + task_id: None, + thread_id: None, + turn_id: None, + error: None, + }; + assert!(manager.save_run(&run).is_err()); + assert!(!tempdir.path().join("escape.json").exists()); } #[test] diff --git a/crates/tui/src/task_manager.rs b/crates/tui/src/task_manager.rs index a9ee32bc7..049f860fc 100644 --- a/crates/tui/src/task_manager.rs +++ b/crates/tui/src/task_manager.rs @@ -39,6 +39,18 @@ const fn default_task_schema_version() -> u32 { CURRENT_TASK_SCHEMA_VERSION } +fn validate_storage_component(value: &str, field: &str) -> Result<()> { + let path = Path::new(value); + let mut components = path.components(); + let Some(component) = components.next() else { + bail!("{field} must not be empty"); + }; + if components.next().is_some() || !matches!(component, std::path::Component::Normal(_)) { + bail!("{field} must be a single path component"); + } + Ok(()) +} + /// Durable task status. #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] @@ -1361,6 +1373,7 @@ impl TaskManager { } fn write_artifact(&self, task_id: &str, label: &str, content: &str) -> Result { + validate_storage_component(task_id, "task_id")?; let artifact_dir = self.artifacts_dir.join(task_id); fs::create_dir_all(&artifact_dir) .with_context(|| format!("Failed to create artifact dir {}", artifact_dir.display()))?; @@ -1962,6 +1975,22 @@ mod tests { Ok(()) } + #[tokio::test] + async fn artifact_writer_rejects_path_component_task_ids() -> Result<()> { + let root = std::env::temp_dir().join(format!("deepseek-task-test-{}", Uuid::new_v4())); + let manager = + TaskManager::start_with_executor(test_config(root.clone()), Arc::new(MockExecutor)) + .await?; + + let err = manager + .write_task_artifact("../escape", "patch", "diff") + .expect_err("path-shaped task id should be rejected"); + + assert!(err.to_string().contains("task_id")); + assert!(!root.join("escape").exists()); + Ok(()) + } + #[tokio::test] async fn cancel_running_task_marks_canceled() -> Result<()> { let root = std::env::temp_dir().join(format!("deepseek-task-test-{}", Uuid::new_v4())); diff --git a/crates/tui/src/tools/tasks.rs b/crates/tui/src/tools/tasks.rs index 54b86b5a1..1bc6a7309 100644 --- a/crates/tui/src/tools/tasks.rs +++ b/crates/tui/src/tools/tasks.rs @@ -25,6 +25,22 @@ const MAX_SUMMARY_CHARS: usize = 900; const DEFAULT_GATE_TIMEOUT_MS: u64 = 120_000; const MAX_GATE_TIMEOUT_MS: u64 = 600_000; +fn validate_storage_component(value: &str, field: &str) -> Result<(), ToolError> { + let path = Path::new(value); + let mut components = path.components(); + let Some(component) = components.next() else { + return Err(ToolError::invalid_input(format!( + "{field} must not be empty" + ))); + }; + if components.next().is_some() || !matches!(component, std::path::Component::Normal(_)) { + return Err(ToolError::invalid_input(format!( + "{field} must be a single path component" + ))); + } + Ok(()) +} + fn build_gate_command_parts(command: &str) -> (String, Vec) { ( "/bin/sh".to_string(), @@ -593,7 +609,16 @@ impl ToolSpec for PrAttemptRecordTool { } async fn execute(&self, input: Value, context: &ToolContext) -> Result { - let task_id = task_id_from_input_or_context(&input, context)?; + let mut task_id = task_id_from_input_or_context(&input, context)?; + if let Some(manager) = context.runtime.task_manager.as_ref() { + task_id = manager + .get_task(&task_id) + .await + .map_err(|e| ToolError::execution_failed(e.to_string()))? + .id; + } else { + validate_storage_component(&task_id, "task_id")?; + } let base_sha = git_output(&context.workspace, &["rev-parse", "HEAD"]).ok(); let head_sha = base_sha.clone(); let branch = git_output(&context.workspace, &["rev-parse", "--abbrev-ref", "HEAD"]).ok(); @@ -608,7 +633,7 @@ impl ToolSpec for PrAttemptRecordTool { .filter(|line| !line.trim().is_empty()) .map(ToString::to_string) .collect::>(); - let patch_path = write_task_artifact_for(context, &task_id, "attempt_patch", &diff)?; + let patch_path = write_task_artifact_for(context, &task_id, "attempt_patch", &diff).await?; let attempt = TaskAttemptRecord { id: format!("attempt_{}", &Uuid::new_v4().to_string()[..8]), attempt_group_id: optional_str(&input, "attempt_group_id") @@ -836,6 +861,7 @@ fn write_runtime_artifact( let Some(data_dir) = context.runtime.task_data_dir.as_ref() else { return Ok(None); }; + validate_storage_component(task_id, "task_id")?; let artifact_dir = data_dir.join("artifacts").join(task_id); std::fs::create_dir_all(&artifact_dir) .map_err(|e| ToolError::execution_failed(format!("create artifact dir: {e}")))?; @@ -855,21 +881,26 @@ fn write_runtime_artifact( )) } -fn write_task_artifact_for( +async fn write_task_artifact_for( context: &ToolContext, task_id: &str, label: &str, content: &str, ) -> Result, ToolError> { if let Some(manager) = context.runtime.task_manager.as_ref() { + let resolved = manager + .get_task(task_id) + .await + .map_err(|e| ToolError::execution_failed(e.to_string()))?; return manager - .write_task_artifact(task_id, label, content) + .write_task_artifact(&resolved.id, label, content) .map(Some) .map_err(|e| ToolError::execution_failed(e.to_string())); } if context.runtime.active_task_id.as_deref() != Some(task_id) { return Ok(None); } + validate_storage_component(task_id, "task_id")?; write_runtime_artifact(context, label, content) } @@ -1003,7 +1034,47 @@ fn sanitize_filename(input: &str) -> String { #[cfg(test)] mod tests { use super::*; + use std::sync::Arc; + + use crate::task_manager::{ + ExecutionTask, TaskExecutionEvent, TaskExecutionResult, TaskExecutor, TaskManager, + TaskManagerConfig, TaskStatus, + }; + use crate::tools::spec::RuntimeToolServices; use crate::tools::spec::ToolSpec; + use tokio::sync::mpsc; + use tokio_util::sync::CancellationToken; + + struct NoopExecutor; + + #[async_trait] + impl TaskExecutor for NoopExecutor { + async fn execute( + &self, + _task: ExecutionTask, + _events: mpsc::UnboundedSender, + _cancel: CancellationToken, + ) -> TaskExecutionResult { + TaskExecutionResult { + status: TaskStatus::Completed, + result_text: None, + error: None, + } + } + } + + fn test_task_manager_config(root: PathBuf) -> TaskManagerConfig { + TaskManagerConfig { + data_dir: root, + worker_count: 1, + default_workspace: PathBuf::from("."), + default_model: "deepseek-v4-flash".to_string(), + default_mode: "agent".to_string(), + allow_shell: false, + trust_mode: false, + max_subagents: 2, + } + } #[test] fn durable_task_schema_requires_prompt() { @@ -1031,6 +1102,37 @@ mod tests { assert!(wait_schema["properties"]["gate"].is_object()); } + #[tokio::test] + async fn task_artifact_helper_requires_existing_manager_task_before_write() { + let root = std::env::temp_dir().join(format!("codewhale-task-tool-{}", Uuid::new_v4())); + let manager = TaskManager::start_with_executor( + test_task_manager_config(root.clone()), + Arc::new(NoopExecutor), + ) + .await + .expect("task manager"); + let context = + ToolContext::new(std::env::temp_dir()).with_runtime_services(RuntimeToolServices { + task_manager: Some(manager), + task_data_dir: Some(root.clone()), + ..RuntimeToolServices::default() + }); + + let err = write_task_artifact_for(&context, "missing-task", "attempt_patch", "patch") + .await + .expect_err("missing task must be rejected before artifact write"); + + assert!( + err.to_string().contains("Task not found"), + "unexpected error: {err}" + ); + assert!( + !root.join("artifacts").join("missing-task").exists(), + "artifact directory must not be created for missing task" + ); + let _ = std::fs::remove_dir_all(root); + } + #[test] fn gate_command_uses_login_shell_invocation() { let (program, args) = build_gate_command_parts("echo hello");