Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions crates/state/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,11 @@ impl StateStore {
}

fn conn(&self) -> Result<Connection> {
Connection::open(&self.db_path)
.with_context(|| format!("failed to open state db {}", self.db_path.display()))
let conn = Connection::open(&self.db_path)
.with_context(|| format!("failed to open state db {}", self.db_path.display()))?;
conn.execute_batch("PRAGMA foreign_keys = ON;")
.context("failed to enable state db foreign keys")?;
Comment on lines +310 to +311

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using execute_batch for a single, simple PRAGMA statement is less efficient because it parses the input as a batch of multiple SQL statements. It is more idiomatic and efficient to use execute with empty parameters.

Suggested change
conn.execute_batch("PRAGMA foreign_keys = ON;")
.context("failed to enable state db foreign keys")?;
conn.execute("PRAGMA foreign_keys = ON;", [])
.context("failed to enable state db foreign keys")?;

Ok(conn)
}

fn init_schema(&self) -> Result<()> {
Expand Down Expand Up @@ -1818,6 +1821,13 @@ mod tests {
}
}

fn table_count(store: &StateStore, table: &str, thread_id: &str) -> i64 {
let conn = store.conn().expect("open test connection");
let query = format!("SELECT COUNT(*) FROM {table} WHERE thread_id = ?1");
conn.query_row(&query, params![thread_id], |row| row.get(0))
.expect("count rows")
}
Comment on lines +1824 to +1829

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Constructing SQL queries via string interpolation (format!) is a dangerous pattern that can lead to SQL injection vulnerabilities. Although this is a test helper and currently only called with static string literals, it is highly recommended to validate the table parameter against an allowlist of expected table names to prevent unsafe patterns from spreading or being copied into production code.

Suggested change
fn table_count(store: &StateStore, table: &str, thread_id: &str) -> i64 {
let conn = store.conn().expect("open test connection");
let query = format!("SELECT COUNT(*) FROM {table} WHERE thread_id = ?1");
conn.query_row(&query, params![thread_id], |row| row.get(0))
.expect("count rows")
}
fn table_count(store: &StateStore, table: &str, thread_id: &str) -> i64 {
let allowed_tables = ["messages", "thread_dynamic_tools", "checkpoints", "thread_goals"];
assert!(allowed_tables.contains(&table), "invalid table name: {table}");
let conn = store.conn().expect("open test connection");
let query = format!("SELECT COUNT(*) FROM {table} WHERE thread_id = ?1");
conn.query_row(&query, params![thread_id], |row| row.get(0))
.expect("count rows")
}


#[test]
fn thread_goal_crud_round_trips_and_replaces() {
let store = temp_state_store("thread-goal-crud");
Expand Down Expand Up @@ -1867,6 +1877,58 @@ mod tests {
assert!(err.to_string().contains("thread missing-thread not found"));
}

#[test]
fn delete_thread_cascades_child_rows() {
let store = temp_state_store("delete-thread-cascade");
store
.upsert_thread(&test_thread("thread-1"))
.expect("upsert thread");
store
.append_message("thread-1", "user", "hello", None)
.expect("append message");
store
.persist_dynamic_tools(
"thread-1",
&[DynamicToolRecord {
position: 0,
name: "lookup".to_string(),
description: Some("Look something up".to_string()),
input_schema: serde_json::json!({"type": "object"}),
}],
)
.expect("persist dynamic tools");
store
.save_checkpoint("thread-1", "checkpoint-1", &serde_json::json!({"ok": true}))
.expect("save checkpoint");
store
.upsert_thread_goal(&test_goal("thread-1", "finish the thread"))
.expect("upsert goal");

assert_eq!(table_count(&store, "messages", "thread-1"), 1);
assert_eq!(table_count(&store, "thread_dynamic_tools", "thread-1"), 1);
assert_eq!(table_count(&store, "checkpoints", "thread-1"), 1);
assert_eq!(table_count(&store, "thread_goals", "thread-1"), 1);

store.delete_thread("thread-1").expect("delete thread");

assert!(
store
.get_thread("thread-1")
.expect("read deleted thread")
.is_none()
);
assert_eq!(table_count(&store, "messages", "thread-1"), 0);
assert_eq!(table_count(&store, "thread_dynamic_tools", "thread-1"), 0);
assert_eq!(table_count(&store, "checkpoints", "thread-1"), 0);
assert_eq!(table_count(&store, "thread_goals", "thread-1"), 0);
assert!(
store
.get_thread_goal("thread-1")
.expect("read deleted goal")
.is_none()
);
}

#[test]
fn record_thread_goal_usage_accumulates_tokens_and_time() {
let store = temp_state_store("thread-goal-usage");
Expand Down
Loading