diff --git a/CLAUDE.md b/CLAUDE.md index 468ada8..157a6c3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -185,37 +185,99 @@ cargo run --release --example demo ### Model Loading -The MiniLM model can be loaded from: -1. **Local files** (preferred): `packages/agent-state-rs/models/minilm/` directory -2. **HuggingFace Hub**: Downloads automatically (~90MB, cached in `~/.cache/huggingface`) +Models can be loaded from: +1. **Local files** (preferred): `packages/agent-state-rs/models/<model-name>/` directory +2. **HuggingFace Hub**: Downloads automatically (~90MB per model, cached in `~/.cache/huggingface`) -To use local model (faster startup, works offline): +**No API key required** - all models are public and downloaded directly. + +To use local models (faster startup, works offline): ```bash cd packages/agent-state-rs -mkdir -p models/minilm -curl -L -o models/minilm/config.json "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/config.json" -curl -L -o models/minilm/tokenizer.json "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json" -curl -L -o models/minilm/model.safetensors "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/model.safetensors" + +# BGE-Small (default, recommended) +mkdir -p models/bge-small +curl -L -o models/bge-small/config.json "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/config.json" +curl -L -o models/bge-small/tokenizer.json "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/tokenizer.json" +curl -L -o models/bge-small/model.safetensors "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/model.safetensors" + +# MiniLM-L6 (fast mode) +mkdir -p models/minilm-l6 +curl -L -o models/minilm-l6/config.json "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/config.json" +curl -L -o models/minilm-l6/tokenizer.json "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json" +curl -L -o 
models/minilm-l6/model.safetensors "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/model.safetensors" + +# BGE-Base (most accurate, 768 dims) +mkdir -p models/bge-base +curl -L -o models/bge-base/config.json "https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/config.json" +curl -L -o models/bge-base/tokenizer.json "https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/tokenizer.json" +curl -L -o models/bge-base/model.safetensors "https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/model.safetensors" ``` ## Key Components (Rust Core) - **Brain** (`packages/agent-state-rs/src/brain.rs`): - - Loads MiniLM model (384-dim embeddings) + - Supports multiple embedding models (see below) + - Supports multiple backends (Candle, ONNX) - Classifies **Action** (Store vs Query) - - Classifies **DataType** (Task vs Memory) + - Classifies **DataType** (Task vs Memory vs Preference vs Relationship vs Event) - Uses zero-shot classification via anchor vectors - **Storage** (`packages/agent-state-rs/src/storage.rs`): - SQLite with blob vector storage - Cosine similarity search - Category-filtered queries + - Supports dynamic embedding dimensions - **AgentEngine** (`packages/agent-state-rs/src/lib.rs`): - `process()` - The unified API (auto-routes by intent) - `store()` / `search()` - Explicit methods when needed - Returns `AgentResponse` for structured handling +### Available Embedding Models + +| Model | Dimensions | MTEB Score | Notes | +|-------|------------|------------|-------| +| `MiniLmL6` | 384 | 56.3 | Fast, resource-constrained | +| `MiniLmL12` | 384 | 59.8 | Better accuracy, same dims | +| `BgeSmall` | 384 | 62.2 | **Best small model (default)** | +| `BgeBase` | 768 | 64.2 | **Best accuracy overall** | +| `E5Small` | 384 | 61.5 | Good alternative to BGE | +| `GteSmall` | 384 | 61.4 | Competitive small model | + +### Backends + +- **Candle** (default): Pure Rust, no external dependencies +- **ONNX Runtime**: Faster inference, 
better CPU optimization (requires `--features onnx-backend`) + +### Selecting a Model + +```rust +use agent_brain::{AgentEngine, EmbeddingModel, Backend}; + +// Default: BGE-Small with Candle backend +let engine = AgentEngine::new("agent.db")?; + +// Builder pattern for customization +let engine = AgentEngine::builder() + .db_path("agent.db") + .model(EmbeddingModel::BgeBase) // Most accurate (768 dims) + .backend(Backend::Candle) + .build()?; + +// In-memory database for testing +let engine = AgentEngine::builder() + .in_memory() + .model(EmbeddingModel::MiniLmL6) // Fast model + .build()?; + +// Mock mode (no ML model needed) +let engine = AgentEngine::builder() + .in_memory() + .mock() + .build()?; +``` + ### Intent Classification The Brain uses anchor vectors to classify: diff --git a/packages/agent-state-py/Cargo.toml b/packages/agent-state-py/Cargo.toml index 7886372..2b97ba3 100644 --- a/packages/agent-state-py/Cargo.toml +++ b/packages/agent-state-py/Cargo.toml @@ -9,12 +9,18 @@ license = "MIT" name = "_core" crate-type = ["cdylib"] +[features] +default = ["candle-backend"] +candle-backend = ["agent-brain/candle-backend"] +onnx-backend = ["agent-brain/onnx-backend"] +all-backends = ["candle-backend", "onnx-backend"] + [dependencies] pyo3 = { version = "0.22", features = ["extension-module"] } anyhow = "1.0" -# Link to the core Rust library -agent-brain = { path = "../agent-state-rs" } +# Link to the core Rust library with feature pass-through +agent-brain = { path = "../agent-state-rs", default-features = false } [profile.release] opt-level = 3 diff --git a/packages/agent-state-py/python/agent_state/__init__.py b/packages/agent-state-py/python/agent_state/__init__.py index e48ac9c..8fb438f 100644 --- a/packages/agent-state-py/python/agent_state/__init__.py +++ b/packages/agent-state-py/python/agent_state/__init__.py @@ -3,6 +3,25 @@ A semantic state engine for AI agents with a unified intent-based API. 
All AI processing (embeddings, intent classification) is powered by the Rust core. + +Supports multiple embedding models and backends: +- Models: MiniLM-L6, MiniLM-L12, BGE-Small (default), BGE-Base, E5-Small, GTE-Small +- Backends: Candle (pure Rust, default), ONNX Runtime (optimized inference) + +Example usage: + from agent_state import AgentEngine, EmbeddingModel, Backend + + # Default configuration (BGE-Small with Candle) + engine = AgentEngine() + + # With specific model for higher accuracy + engine = AgentEngine(model=EmbeddingModel.BgeBase) + + # With ONNX backend for faster inference + engine = AgentEngine(model=EmbeddingModel.BgeSmall, backend=Backend.Onnx) + + # Mock mode for testing (no model download) + engine = AgentEngine(mock=True) """ __version__ = "0.1.0" @@ -16,6 +35,9 @@ DataType, Action, TimeFilter, + # Model configuration + EmbeddingModel, + Backend, # Intent classification Intent, # Response types @@ -37,6 +59,9 @@ "DataType", "Action", "TimeFilter", + # Model configuration + "EmbeddingModel", + "Backend", # Intent "Intent", # Response types diff --git a/packages/agent-state-py/src/lib.rs b/packages/agent-state-py/src/lib.rs index f320190..e21c87f 100644 --- a/packages/agent-state-py/src/lib.rs +++ b/packages/agent-state-py/src/lib.rs @@ -12,6 +12,8 @@ use agent_brain::{ Action as RustAction, Intent as RustIntent, TimeFilter as RustTimeFilter, + EmbeddingModel as RustEmbeddingModel, + Backend as RustBackend, }; /// Convert Rust errors to Python exceptions @@ -27,6 +29,8 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -35,8 +39,8 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { } /// Type of data being stored or queried -#[pyclass] -#[derive(Clone, Debug)] +#[pyclass(eq, eq_int)] +#[derive(Clone, Debug, PartialEq)] pub enum DataType { /// Action items, reminders, todos Task, @@ -97,8 
+101,8 @@ impl DataType { } /// The action the agent wants to perform -#[pyclass] -#[derive(Clone, Debug)] +#[pyclass(eq, eq_int)] +#[derive(Clone, Debug, PartialEq)] pub enum Action { /// Store information Store, @@ -135,8 +139,8 @@ impl Action { } /// Time filter for queries -#[pyclass] -#[derive(Clone, Debug)] +#[pyclass(eq, eq_int)] +#[derive(Clone, Debug, PartialEq)] pub enum TimeFilter { /// No time filtering All, @@ -171,6 +175,115 @@ impl TimeFilter { } } +/// Embedding model to use for text encoding +#[pyclass(eq, eq_int)] +#[derive(Clone, Debug, PartialEq)] +pub enum EmbeddingModel { + /// all-MiniLM-L6-v2 - Fast, 384 dims, MTEB score 56.3 + MiniLmL6, + /// all-MiniLM-L12-v2 - Better accuracy, 384 dims, MTEB score 59.8 + MiniLmL12, + /// BGE-Small-en-v1.5 - Best small model (default), 384 dims, MTEB score 62.2 + BgeSmall, + /// BGE-Base-en-v1.5 - Best accuracy, 768 dims, MTEB score 64.2 + BgeBase, + /// E5-Small-v2 - Good alternative, 384 dims, MTEB score 61.5 + E5Small, + /// GTE-Small - Competitive model, 384 dims, MTEB score 61.4 + GteSmall, +} + +impl From for RustEmbeddingModel { + fn from(m: EmbeddingModel) -> Self { + match m { + EmbeddingModel::MiniLmL6 => RustEmbeddingModel::MiniLmL6, + EmbeddingModel::MiniLmL12 => RustEmbeddingModel::MiniLmL12, + EmbeddingModel::BgeSmall => RustEmbeddingModel::BgeSmall, + EmbeddingModel::BgeBase => RustEmbeddingModel::BgeBase, + EmbeddingModel::E5Small => RustEmbeddingModel::E5Small, + EmbeddingModel::GteSmall => RustEmbeddingModel::GteSmall, + } + } +} + +impl From for EmbeddingModel { + fn from(m: RustEmbeddingModel) -> Self { + match m { + RustEmbeddingModel::MiniLmL6 => EmbeddingModel::MiniLmL6, + RustEmbeddingModel::MiniLmL12 => EmbeddingModel::MiniLmL12, + RustEmbeddingModel::BgeSmall => EmbeddingModel::BgeSmall, + RustEmbeddingModel::BgeBase => EmbeddingModel::BgeBase, + RustEmbeddingModel::E5Small => EmbeddingModel::E5Small, + RustEmbeddingModel::GteSmall => EmbeddingModel::GteSmall, + } + } +} + 
+#[pymethods] +impl EmbeddingModel { + /// Get the embedding dimension for this model + fn embedding_dim(&self) -> usize { + RustEmbeddingModel::from(self.clone()).embedding_dim() + } + + /// Get the MTEB benchmark score for this model + fn mteb_score(&self) -> f32 { + RustEmbeddingModel::from(self.clone()).mteb_score() + } + + /// Get the HuggingFace repository name + fn hf_repo(&self) -> &'static str { + RustEmbeddingModel::from(self.clone()).hf_repo() + } + + fn __repr__(&self) -> String { + format!("EmbeddingModel.{:?}", self) + } +} + +/// Backend for model inference +#[pyclass(eq, eq_int)] +#[derive(Clone, Debug, PartialEq)] +pub enum Backend { + /// Candle - Pure Rust, no external dependencies (default) + Candle, + /// ONNX Runtime - Optimized inference, requires onnx-backend feature + Onnx, + /// Mock - Hash-based embeddings for testing (no model required) + Mock, +} + +impl From for RustBackend { + fn from(b: Backend) -> Self { + match b { + Backend::Candle => RustBackend::Candle, + Backend::Onnx => RustBackend::Onnx, + Backend::Mock => RustBackend::Mock, + } + } +} + +impl From for Backend { + fn from(b: RustBackend) -> Self { + match b { + RustBackend::Candle => Backend::Candle, + RustBackend::Onnx => Backend::Onnx, + RustBackend::Mock => Backend::Mock, + } + } +} + +#[pymethods] +impl Backend { + fn __repr__(&self) -> &'static str { + match self { + Backend::Candle => "Backend.CANDLE", + Backend::Onnx => "Backend.ONNX", + Backend::Mock => "Backend.MOCK", + } + } +} + /// Full intent classification result #[pyclass] #[derive(Clone)] @@ -423,29 +536,78 @@ impl AgentEngine { /// Args: /// db_path: Path to SQLite database. Uses in-memory if None or ":memory:". /// mock: If True, uses mock embeddings for testing (no model download required). + /// model: Embedding model to use (default: BgeSmall). + /// backend: Backend for inference (default: Candle, or Onnx if available). 
+ /// + /// Example: + /// # Default configuration + /// engine = AgentEngine() + /// + /// # With specific model and backend + /// engine = AgentEngine(model=EmbeddingModel.BgeBase, backend=Backend.Onnx) + /// + /// # Mock mode for testing + /// engine = AgentEngine(mock=True) #[new] - #[pyo3(signature = (db_path=None, mock=false))] - fn new(db_path: Option, mock: bool) -> PyResult { - let path = db_path.as_deref().unwrap_or(":memory:"); + #[pyo3(signature = (db_path=None, mock=false, model=None, backend=None))] + fn new( + db_path: Option, + mock: bool, + model: Option, + backend: Option, + ) -> PyResult { + let mut builder = RustEngine::builder(); + + // Set database path or in-memory + if let Some(path) = db_path { + if path == ":memory:" { + builder = builder.in_memory(); + } else { + builder = builder.db_path(path); + } + } else { + builder = builder.in_memory(); + } - let engine = if mock { - RustEngine::new_mock(path).map_err(to_py_err)? + // Set mock mode + if mock { + builder = builder.mock(); } else { - RustEngine::new(path).map_err(to_py_err)? 
- }; + // Set model if specified + if let Some(m) = model { + builder = builder.model(m.into()); + } + + // Set backend if specified + if let Some(b) = backend { + builder = builder.backend(b.into()); + } + } + let engine = builder.build().map_err(to_py_err)?; Ok(AgentEngine { engine }) } /// Create an in-memory engine (shorthand for AgentEngine(":memory:")) #[staticmethod] - fn in_memory() -> PyResult { - let engine = RustEngine::new_in_memory().map_err(to_py_err)?; + #[pyo3(signature = (model=None, backend=None))] + fn in_memory(model: Option, backend: Option) -> PyResult { + let mut builder = RustEngine::builder().in_memory(); + + if let Some(m) = model { + builder = builder.model(m.into()); + } + if let Some(b) = backend { + builder = builder.backend(b.into()); + } + + let engine = builder.build().map_err(to_py_err)?; Ok(AgentEngine { engine }) } /// Create a mock engine for testing (no model download required) #[staticmethod] + #[pyo3(signature = (db_path=None))] fn mock(db_path: Option) -> PyResult { let path = db_path.as_deref().unwrap_or(":memory:"); let engine = RustEngine::new_mock(path).map_err(to_py_err)?; @@ -457,6 +619,21 @@ impl AgentEngine { self.engine.is_mock() } + /// Returns the configured embedding model + fn model(&self) -> EmbeddingModel { + self.engine.model().into() + } + + /// Returns the configured backend + fn backend(&self) -> Backend { + self.engine.backend().into() + } + + /// Returns the embedding dimension for the configured model + fn embedding_dim(&self) -> usize { + self.engine.embedding_dim() + } + /// Process natural language input /// /// Automatically detects whether the input is a store or query operation. 
diff --git a/packages/agent-state-py/tests/test_engine.py b/packages/agent-state-py/tests/test_engine.py index 54c745a..f2d17f5 100644 --- a/packages/agent-state-py/tests/test_engine.py +++ b/packages/agent-state-py/tests/test_engine.py @@ -14,6 +14,8 @@ QueryResultResponse, NotFoundResponse, NeedsClarificationResponse, + EmbeddingModel, + Backend, ) @@ -215,3 +217,70 @@ def test_intent_overall_confidence(self): confidence = intent.overall_confidence() assert isinstance(confidence, float) assert 0.0 <= confidence <= 1.0 + + +class TestEmbeddingModel: + """Tests for the EmbeddingModel enum.""" + + def test_embedding_model_variants(self): + """Test all EmbeddingModel variants exist.""" + assert EmbeddingModel.MiniLmL6 is not None + assert EmbeddingModel.MiniLmL12 is not None + assert EmbeddingModel.BgeSmall is not None + assert EmbeddingModel.BgeBase is not None + assert EmbeddingModel.E5Small is not None + assert EmbeddingModel.GteSmall is not None + + def test_embedding_dim(self): + """Test embedding_dim() method returns expected dimensions.""" + assert EmbeddingModel.MiniLmL6.embedding_dim() == 384 + assert EmbeddingModel.BgeSmall.embedding_dim() == 384 + assert EmbeddingModel.BgeBase.embedding_dim() == 768 + + def test_mteb_score(self): + """Test mteb_score() method returns valid scores.""" + score = EmbeddingModel.BgeSmall.mteb_score() + assert isinstance(score, float) + assert 50.0 < score < 70.0 # Reasonable MTEB score range + + def test_hf_repo(self): + """Test hf_repo() method returns valid HuggingFace repo names.""" + repo = EmbeddingModel.BgeSmall.hf_repo() + assert isinstance(repo, str) + assert "bge-small" in repo.lower() + + +class TestBackend: + """Tests for the Backend enum.""" + + def test_backend_variants(self): + """Test all Backend variants exist.""" + assert Backend.Candle is not None + assert Backend.Onnx is not None + assert Backend.Mock is not None + + +class TestModelConfiguration: + """Tests for model/backend configuration in AgentEngine.""" 
+ + def test_create_engine_with_model_param(self): + """Test that an engine created without a model argument reports the default model.""" + # Note: mock=True avoids loading a real model; no explicit model is passed here + engine = AgentEngine(mock=True) + assert engine.model() == EmbeddingModel.BgeSmall # Default model + + def test_mock_engine_reports_mock_backend(self): + """Test that mock engine reports Mock backend.""" + engine = AgentEngine(mock=True) + assert engine.backend() == Backend.Mock + + def test_engine_embedding_dim(self): + """Test that engine reports correct embedding dimension.""" + engine = AgentEngine(mock=True) + assert engine.embedding_dim() == 384 # BgeSmall default + + def test_in_memory_factory_method(self): + """Test that the mock() factory method returns an empty mock engine.""" + engine = AgentEngine.mock() + assert engine.is_mock() + assert engine.count() == 0 diff --git a/packages/agent-state-rs/.cargo/config.toml b/packages/agent-state-rs/.cargo/config.toml new file mode 100644 index 0000000..39fa800 --- /dev/null +++ b/packages/agent-state-rs/.cargo/config.toml @@ -0,0 +1,14 @@ +# CPU Optimization Configuration +# Enables native CPU features (AVX, AVX2, SSE4.2, etc.) 
for maximum performance +# This can provide 2-3x speedup for embedding generation + +[build] +# Use native CPU features for optimal SIMD performance +rustflags = ["-C", "target-cpu=native"] + +# Environment variables for optimal performance +[env] +# Enable OpenMP parallelism for ONNX Runtime +OMP_NUM_THREADS = "4" +# Disable tokenizer parallelism (can cause issues with ONNX) +TOKENIZERS_PARALLELISM = "false" diff --git a/packages/agent-state-rs/Cargo.lock b/packages/agent-state-rs/Cargo.lock index e9db5d4..cf9fbeb 100644 --- a/packages/agent-state-rs/Cargo.lock +++ b/packages/agent-state-rs/Cargo.lock @@ -10,7 +10,7 @@ checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "agent-brain" -version = "0.2.0" +version = "0.3.0" dependencies = [ "anyhow", "bytemuck", @@ -18,6 +18,8 @@ dependencies = [ "candle-nn", "candle-transformers", "hf-hub", + "ndarray 0.16.1", + "ort", "rusqlite", "serde", "serde_json", @@ -78,6 +80,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" + [[package]] name = "bit-set" version = "0.5.3" @@ -137,6 +145,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" + [[package]] name = "candle-core" version = "0.8.4" @@ -313,6 +327,16 @@ dependencies = [ "syn", ] +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + [[package]] name = "derive_arbitrary" version = "1.4.2" @@ -862,9 +886,31 @@ dependencies = [ "serde", "serde_json", "thiserror", - "ureq", + "ureq 2.12.1", +] + +[[package]] +name = "hmac-sha256" +version = "1.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad6880c8d4a9ebf39c6e8b77007ce223f646a4d21ce29d99f70cb16420545425" + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", ] +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + [[package]] name = "icu_collections" version = "2.1.1" @@ -1097,6 +1143,12 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lzma-rust2" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1670343e58806300d87950e3401e820b519b9384281bbabfb15e3636689ffd69" + [[package]] name = "macro_rules_attribute" version = "0.2.2" @@ -1113,6 +1165,16 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.6" @@ -1184,6 +1246,36 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = 
"0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "nom" version = "7.1.3" @@ -1385,12 +1477,45 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ort" +version = "2.0.0-rc.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5df903c0d2c07b56950f1058104ab0c8557159f2741782223704de9be73c3c" +dependencies = [ + "ndarray 0.17.2", + "ort-sys", + "smallvec", + "tracing", + "ureq 3.1.4", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06503bb33f294c5f1ba484011e053bfa6ae227074bdb841e9863492dc5960d4b" +dependencies = [ + "hmac-sha256", + "lzma-rust2", + "ureq 3.1.4", +] + [[package]] name = "paste" version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -1415,6 +1540,15 @@ version = "1.13.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.4" @@ -1579,6 +1713,12 @@ dependencies = [ "bitflags 2.10.0", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.11.0" @@ -1865,6 +2005,17 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "spm_precompiled" version = "0.1.4" @@ -2160,6 +2311,36 @@ dependencies = [ "webpki-roots 0.26.11", ] +[[package]] +name = "ureq" +version = "3.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a" +dependencies = [ + "base64 0.22.1", + "der", + "log", + "native-tls", + "percent-encoding", + "rustls-pki-types", + "socks", + "ureq-proto", + "utf-8", + "webpki-root-certs", +] + +[[package]] +name = "ureq-proto" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" +dependencies = [ + "base64 0.22.1", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.8" @@ 
-2172,6 +2353,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -2270,6 +2457,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36a29fc0408b113f68cf32637857ab740edfafdf460c326cd2afaa2d84cc05dc" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webpki-roots" version = "0.26.11" @@ -2288,6 +2484,22 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.11" @@ -2297,6 +2509,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-link" version = "0.2.1" diff --git a/packages/agent-state-rs/Cargo.toml b/packages/agent-state-rs/Cargo.toml index 2245bab..55765ee 100644 --- a/packages/agent-state-rs/Cargo.toml +++ b/packages/agent-state-rs/Cargo.toml @@ -1,14 +1,28 @@ [package] name = "agent-brain" -version = "0.2.0" +version = "0.3.0" edition = "2021" description = "AgentState - A semantic state engine for AI agents with unified intent-based API" 
+[features] +default = ["candle-backend"] +# Candle backend (default) - uses HuggingFace Candle for inference +candle-backend = ["candle-core", "candle-nn", "candle-transformers"] +# ONNX Runtime backend - faster inference, better CPU optimization +onnx-backend = ["ort"] +# Enable both backends for runtime selection +all-backends = ["candle-backend", "onnx-backend"] + [dependencies] -# The AI Stack (HuggingFace Candle) -candle-core = "0.8" -candle-nn = "0.8" -candle-transformers = "0.8" +# The AI Stack - Candle (optional, default) +candle-core = { version = "0.8", optional = true } +candle-nn = { version = "0.8", optional = true } +candle-transformers = { version = "0.8", optional = true } + +# The AI Stack - ONNX Runtime (optional, for better performance) +ort = { version = "2.0.0-rc.11", optional = true } + +# Shared AI dependencies tokenizers = "0.20" hf-hub = "0.3" @@ -20,6 +34,7 @@ anyhow = "1.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" bytemuck = { version = "1.14", features = ["derive"] } +ndarray = "0.16" [lib] name = "agent_brain" diff --git a/packages/agent-state-rs/examples/benchmark.rs b/packages/agent-state-rs/examples/benchmark.rs new file mode 100644 index 0000000..33c30c8 --- /dev/null +++ b/packages/agent-state-rs/examples/benchmark.rs @@ -0,0 +1,301 @@ +//! Model Benchmark - Compare performance and accuracy of different embedding models +//! +//! This benchmark helps you choose the right model for your use case by measuring: +//! 1. Embedding generation speed +//! 2. Classification accuracy +//! 3. Search relevance +//! +//! 
Run with: cargo run --release --example benchmark + +use agent_brain::{AgentEngine, AgentResponse, Backend, EmbeddingModel, Action, DataType}; +use std::time::{Duration, Instant}; + +/// Test cases for classification accuracy +const CLASSIFICATION_TESTS: &[(&str, Action, DataType)] = &[ + // Store - Tasks + ("Remind me to call John tomorrow", Action::Store, DataType::Task), + ("Buy groceries from the store", Action::Store, DataType::Task), + ("Schedule dentist appointment", Action::Store, DataType::Task), + ("Don't forget to pay the rent", Action::Store, DataType::Task), + ("Pick up dry cleaning", Action::Store, DataType::Task), + + // Store - Memories + ("My favorite color is blue", Action::Store, DataType::Memory), + ("The capital of France is Paris", Action::Store, DataType::Memory), + ("John's phone number is 555-1234", Action::Store, DataType::Memory), + ("My birthday is March 15th", Action::Store, DataType::Memory), + ("The project deadline is next Friday", Action::Store, DataType::Memory), + + // Store - Preferences + ("I prefer dark mode in all apps", Action::Store, DataType::Preference), + ("I like my coffee black", Action::Store, DataType::Preference), + ("I enjoy hiking on weekends", Action::Store, DataType::Preference), + ("I hate early morning meetings", Action::Store, DataType::Preference), + + // Store - Relationships + ("John is my colleague", Action::Store, DataType::Relationship), + ("Sarah works at Google", Action::Store, DataType::Relationship), + ("Alice is my sister", Action::Store, DataType::Relationship), + + // Store - Events + ("Team meeting tomorrow at 3pm", Action::Store, DataType::Event), + ("Birthday party next Saturday", Action::Store, DataType::Event), + ("Conference call on Monday morning", Action::Store, DataType::Event), + + // Query - various + ("What is my name?", Action::Query, DataType::Memory), + ("What's my favorite color?", Action::Query, DataType::Preference), + ("Who should I call?", Action::Query, DataType::Task), + 
("What do I need to do today?", Action::Query, DataType::Task), + ("Find my scheduled meetings", Action::Query, DataType::Event), + ("Who works at Google?", Action::Query, DataType::Relationship), + ("What are my preferences?", Action::Query, DataType::Preference), + ("Tell me about John", Action::Query, DataType::Memory), + ("Search for tasks", Action::Query, DataType::Task), + ("List my reminders", Action::Query, DataType::Task), +]; + +/// Test cases for semantic search relevance +const SEARCH_TESTS: &[(&str, &[&str], &str)] = &[ + ( + "What color do I like?", + &["My favorite color is blue", "I like green apples"], + "My favorite color is blue", + ), + ( + "Who should I contact?", + &["Call John tomorrow", "Send email to Sarah", "The phone is ringing"], + "Call John tomorrow", + ), + ( + "What meetings do I have?", + &["Team meeting at 3pm", "I met John yesterday", "The conference room is booked"], + "Team meeting at 3pm", + ), +]; + +fn main() { + println!("╔══════════════════════════════════════════════════════════════════╗"); + println!("║ AgentState Model Benchmark - Performance & Accuracy ║"); + println!("╚══════════════════════════════════════════════════════════════════╝"); + println!(); + + // Run mock mode benchmark (always available, fast) + println!("═══════════════════════════════════════════════════════════════════"); + println!("MOCK MODE BENCHMARK (baseline, no ML model)"); + println!("═══════════════════════════════════════════════════════════════════"); + let _ = run_benchmark_mock("Mock (hash-based)"); + + // Print model comparison table + println!(); + println!("═══════════════════════════════════════════════════════════════════"); + println!("AVAILABLE MODELS COMPARISON"); + println!("═══════════════════════════════════════════════════════════════════"); + println!(); + println!("| {:20} | {:6} | {:10} | {:30} |", "Model", "Dims", "MTEB Score", "Use Case"); + println!("|{:-<22}|{:-<8}|{:-<12}|{:-<32}|", "", "", "", ""); + + for model in &[ + 
EmbeddingModel::MiniLmL6, + EmbeddingModel::MiniLmL12, + EmbeddingModel::BgeSmall, + EmbeddingModel::BgeBase, + EmbeddingModel::E5Small, + EmbeddingModel::GteSmall, + ] { + let use_case = match model { + EmbeddingModel::MiniLmL6 => "Fast, resource-constrained", + EmbeddingModel::MiniLmL12 => "Better accuracy, same size", + EmbeddingModel::BgeSmall => "Best small model (default)", + EmbeddingModel::BgeBase => "Best accuracy overall", + EmbeddingModel::E5Small => "Good alternative to BGE", + EmbeddingModel::GteSmall => "Competitive small model", + }; + println!( + "| {:20} | {:6} | {:10.1} | {:30} |", + format!("{:?}", model), + model.embedding_dim(), + model.mteb_score(), + use_case + ); + } + println!(); + + // Try to run real model benchmarks if models are available + println!("═══════════════════════════════════════════════════════════════════"); + println!("REAL MODEL BENCHMARKS"); + println!("═══════════════════════════════════════════════════════════════════"); + println!(); + println!("Attempting to load models from HuggingFace Hub cache..."); + println!("(First run may download ~90MB per model)"); + println!(); + + // Try each model config using the builder pattern + let configs = [ + (EmbeddingModel::MiniLmL6, Backend::Candle, "MiniLM-L6 (Candle)"), + (EmbeddingModel::BgeSmall, Backend::Candle, "BGE-Small (Candle)"), + ]; + + for (model, backend, name) in configs { + println!("-------------------------------------------------------------------"); + if let Err(e) = run_benchmark(model, backend, name) { + println!(" Skipped: {} ({})", name, e); + } + } + + println!(); + println!("═══════════════════════════════════════════════════════════════════"); + println!("BENCHMARK COMPLETE"); + println!("═══════════════════════════════════════════════════════════════════"); + println!(); + println!("Tips for better performance:"); + println!(" 1. Use --release mode: cargo run --release --example benchmark"); + println!(" 2. 
Download models locally to avoid network latency"); + println!(" 3. Use BGE-Small for best accuracy with 384 dimensions"); + println!(" 4. Use BGE-Base for maximum accuracy (768 dimensions)"); + println!(); +} + +fn run_benchmark_mock(name: &str) -> Result<(), String> { + println!(); + println!("Benchmarking: {}", name); + println!(" Model: Mock, Backend: Mock, Dims: 384"); + println!(); + + // Initialize engine using builder pattern + let init_start = Instant::now(); + let mut engine = AgentEngine::builder() + .in_memory() + .mock() + .build() + .map_err(|e| e.to_string())?; + let init_time = init_start.elapsed(); + println!(" Initialization: {:?}", init_time); + + // Run benchmarks + run_benchmark_suite(&mut engine) +} + +fn run_benchmark(model: EmbeddingModel, backend: Backend, name: &str) -> Result<(), String> { + println!(); + println!("Benchmarking: {}", name); + println!(" Model: {:?}, Backend: {:?}, Dims: {}", + model, backend, model.embedding_dim()); + println!(); + + // Initialize engine using builder pattern + let init_start = Instant::now(); + let mut engine = AgentEngine::builder() + .in_memory() + .model(model) + .backend(backend) + .build() + .map_err(|e| e.to_string())?; + let init_time = init_start.elapsed(); + println!(" Initialization: {:?}", init_time); + + // Run benchmarks + run_benchmark_suite(&mut engine) +} + +fn run_benchmark_suite(engine: &mut AgentEngine) -> Result<(), String> { + // Benchmark embedding generation + let embed_times = benchmark_embeddings(engine); + println!(" Embedding (first): {:?}", embed_times.0); + println!(" Embedding (cached): {:?}", embed_times.1); + println!(" Embedding (avg 10): {:?}", embed_times.2); + + // Benchmark classification accuracy + let (action_acc, type_acc) = benchmark_classification(engine); + println!(" Action accuracy: {:.1}%", action_acc * 100.0); + println!(" DataType accuracy: {:.1}%", type_acc * 100.0); + + // Benchmark search relevance + let search_acc = benchmark_search(engine); + 
println!(" Search relevance: {:.1}%", search_acc * 100.0); + + Ok(()) +} + +fn benchmark_embeddings(engine: &mut AgentEngine) -> (Duration, Duration, Duration) { + let test_text = "This is a test sentence for embedding generation."; + + // First embedding (cold) + engine.classify(test_text).ok(); // Warm up (loads anchors) + let start = Instant::now(); + engine.store(test_text).ok(); + let first = start.elapsed(); + + // Cached embedding + let start = Instant::now(); + engine.search(test_text).ok(); + let cached = start.elapsed(); + + // Average over 10 different texts + let texts = [ + "Hello world", + "How are you today?", + "The quick brown fox", + "Machine learning is great", + "Rust programming language", + "Database optimization techniques", + "Natural language processing", + "Semantic search engine", + "Vector embeddings work well", + "Classification algorithms", + ]; + + let start = Instant::now(); + for text in texts { + engine.classify(text).ok(); + } + let avg = start.elapsed() / 10; + + (first, cached, avg) +} + +fn benchmark_classification(engine: &mut AgentEngine) -> (f32, f32) { + let mut action_correct = 0; + let mut type_correct = 0; + let total = CLASSIFICATION_TESTS.len(); + + for (text, expected_action, expected_type) in CLASSIFICATION_TESTS { + if let Ok(intent) = engine.classify(text) { + if intent.action == *expected_action { + action_correct += 1; + } + if intent.data_type == *expected_type { + type_correct += 1; + } + } + } + + ( + action_correct as f32 / total as f32, + type_correct as f32 / total as f32, + ) +} + +fn benchmark_search(engine: &mut AgentEngine) -> f32 { + let mut correct = 0; + let total = SEARCH_TESTS.len(); + + for (query, corpus, expected_top) in SEARCH_TESTS { + // Store all corpus items + for item in *corpus { + engine.store(item).ok(); + } + + // Search and check if expected is top result + if let Ok(AgentResponse::QueryResult { results, .. 
}) = engine.search(query) { + if !results.is_empty() && results[0] == *expected_top { + correct += 1; + } + } + + // Clear for next test + engine.clear().ok(); + } + + correct as f32 / total as f32 +} diff --git a/packages/agent-state-rs/src/brain.rs b/packages/agent-state-rs/src/brain.rs index 481ea37..ad112d3 100644 --- a/packages/agent-state-rs/src/brain.rs +++ b/packages/agent-state-rs/src/brain.rs @@ -1,46 +1,248 @@ //! Intelligence Layer - Handles AI model loading, embedding generation, and intent classification //! -//! This module implements zero-shot intent routing using a MiniLM sentence transformer model. +//! This module implements zero-shot intent routing using sentence transformer models. //! It classifies input text by action (Store vs Query) and by data type (Task vs Memory) //! by comparing embeddings against pre-computed anchor vectors. +//! +//! ## Supported Models +//! +//! | Model | Dimensions | MTEB Score | Notes | +//! |-------|------------|------------|-------| +//! | `MiniLM-L6` | 384 | 56.3 | Fast, decent accuracy | +//! | `MiniLM-L12` | 384 | 59.8 | Better accuracy, same dims | +//! | `BGE-Small` | 384 | 62.2 | Best small model (default) | +//! | `BGE-Base` | 768 | 64.2 | Best accuracy/size ratio | +//! +//! ## Backends +//! +//! - **Candle** (default): Pure Rust, no external dependencies +//! 
- **ONNX Runtime**: Faster inference, better CPU optimization use anyhow::{Context, Result}; +use std::collections::{HashMap, VecDeque}; +use tokenizers::Tokenizer; + +#[cfg(feature = "candle-backend")] use candle_core::{DType, Device, Tensor}; +#[cfg(feature = "candle-backend")] use candle_nn::VarBuilder; +#[cfg(feature = "candle-backend")] use candle_transformers::models::bert::{BertModel, Config}; + +#[cfg(feature = "onnx-backend")] +use ort::{GraphOptimizationLevel, Session}; + use hf_hub::{api::sync::Api, Repo, RepoType}; -use std::collections::{HashMap, VecDeque}; -use tokenizers::Tokenizer; /// Maximum number of embeddings to cache (LRU eviction when exceeded) const EMBEDDING_CACHE_SIZE: usize = 1000; +/// Available embedding models with their HuggingFace repo names +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EmbeddingModel { + /// MiniLM-L6-v2: Fast, 384 dimensions, MTEB 56.3 + MiniLmL6, + /// MiniLM-L12-v2: Better accuracy, 384 dimensions, MTEB 59.8 + MiniLmL12, + /// BGE-Small-EN-v1.5: Best small model, 384 dimensions, MTEB 62.2 + BgeSmall, + /// BGE-Base-EN-v1.5: Best accuracy/size, 768 dimensions, MTEB 64.2 + BgeBase, + /// E5-Small-v2: Microsoft model, 384 dimensions, MTEB 61.5 + E5Small, + /// GTE-Small: Alibaba model, 384 dimensions, MTEB 61.4 + GteSmall, +} + +impl EmbeddingModel { + /// Returns the HuggingFace repository name for this model + pub fn hf_repo(&self) -> &'static str { + match self { + EmbeddingModel::MiniLmL6 => "sentence-transformers/all-MiniLM-L6-v2", + EmbeddingModel::MiniLmL12 => "sentence-transformers/all-MiniLM-L12-v2", + EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5", + EmbeddingModel::BgeBase => "BAAI/bge-base-en-v1.5", + EmbeddingModel::E5Small => "intfloat/e5-small-v2", + EmbeddingModel::GteSmall => "thenlper/gte-small", + } + } + + /// Returns the embedding dimension for this model + pub fn embedding_dim(&self) -> usize { + match self { + EmbeddingModel::MiniLmL6 => 384, + EmbeddingModel::MiniLmL12 => 384, 
+ EmbeddingModel::BgeSmall => 384, + EmbeddingModel::BgeBase => 768, + EmbeddingModel::E5Small => 384, + EmbeddingModel::GteSmall => 384, + } + } + + /// Returns the ONNX model filename if available + pub fn onnx_filename(&self) -> &'static str { + "model.onnx" + } + + /// Returns true if this model requires query prefixing (BGE, E5 models) + pub fn needs_query_prefix(&self) -> bool { + matches!(self, EmbeddingModel::BgeSmall | EmbeddingModel::BgeBase | EmbeddingModel::E5Small) + } + + /// Returns the query prefix for models that need it + pub fn query_prefix(&self) -> &'static str { + match self { + EmbeddingModel::BgeSmall | EmbeddingModel::BgeBase => "Represent this sentence for searching relevant passages: ", + EmbeddingModel::E5Small => "query: ", + _ => "", + } + } + + /// Returns the MTEB score for this model (higher is better) + pub fn mteb_score(&self) -> f32 { + match self { + EmbeddingModel::MiniLmL6 => 56.3, + EmbeddingModel::MiniLmL12 => 59.8, + EmbeddingModel::BgeSmall => 62.2, + EmbeddingModel::BgeBase => 64.2, + EmbeddingModel::E5Small => 61.5, + EmbeddingModel::GteSmall => 61.4, + } + } +} + +impl Default for EmbeddingModel { + fn default() -> Self { + // Default to BGE-Small for better accuracy with same dimensions + EmbeddingModel::BgeSmall + } +} + +/// Backend for embedding computation +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Backend { + /// Candle (pure Rust, default) + Candle, + /// ONNX Runtime (faster, better CPU optimization) + Onnx, + /// Mock mode (for testing without models) + Mock, +} + +impl Default for Backend { + fn default() -> Self { + #[cfg(feature = "onnx-backend")] + return Backend::Onnx; + #[cfg(not(feature = "onnx-backend"))] + return Backend::Candle; + } +} + +/// Configuration for Brain initialization - use builder pattern +#[derive(Debug, Clone)] +pub struct BrainConfig { + /// Which embedding model to use + pub model: EmbeddingModel, + /// Which backend to use for inference + pub backend: Backend, + /// 
Optional local model directory (bypasses HuggingFace download) + pub local_model_dir: Option, +} + +impl Default for BrainConfig { + fn default() -> Self { + Self { + model: EmbeddingModel::default(), + backend: Backend::default(), + local_model_dir: None, + } + } +} + +impl BrainConfig { + /// Create a new configuration builder + pub fn builder() -> BrainConfigBuilder { + BrainConfigBuilder::default() + } +} + +/// Builder for BrainConfig - provides fluent API for configuration +#[derive(Debug, Clone, Default)] +pub struct BrainConfigBuilder { + model: Option, + backend: Option, + local_model_dir: Option, + mock_mode: bool, +} + +impl BrainConfigBuilder { + /// Set the embedding model to use + pub fn model(mut self, model: EmbeddingModel) -> Self { + self.model = Some(model); + self + } + + /// Set the backend for inference + pub fn backend(mut self, backend: Backend) -> Self { + self.backend = Some(backend); + self + } + + /// Set a local model directory (bypasses HuggingFace download) + pub fn local_model_dir>(mut self, path: P) -> Self { + self.local_model_dir = Some(path.into()); + self + } + + /// Enable mock mode for testing (no real model needed) + pub fn mock(mut self) -> Self { + self.mock_mode = true; + self + } + + /// Build the configuration + pub fn build(self) -> BrainConfig { + if self.mock_mode { + BrainConfig { + model: self.model.unwrap_or_default(), + backend: Backend::Mock, + local_model_dir: None, + } + } else { + BrainConfig { + model: self.model.unwrap_or_default(), + backend: self.backend.unwrap_or_default(), + local_model_dir: self.local_model_dir, + } + } + } +} + /// The Brain handles all AI-related operations including embedding generation and intent classification pub struct Brain { - /// Model for real mode (None in mock mode) - model: Option, - /// Tokenizer for real mode (None in mock mode) + /// Configuration used to create this brain + config: BrainConfig, + /// Tokenizer (shared between backends) tokenizer: Option, - device: 
Device, - /// Whether running in mock mode (hash-based embeddings for testing) - mock_mode: bool, - /// Anchor vector representing "Task" data type - anchor_task: Tensor, - /// Anchor vector representing "Memory" data type - anchor_memory: Tensor, - /// Anchor vector representing "Preference" data type - anchor_preference: Tensor, - /// Anchor vector representing "Relationship" data type - anchor_relationship: Tensor, - /// Anchor vector representing "Event" data type - anchor_event: Tensor, - /// Anchor vector representing "Store" action - anchor_store: Tensor, - /// Anchor vector representing "Query" action - anchor_query: Tensor, + /// Candle model (if using Candle backend) + #[cfg(feature = "candle-backend")] + candle_model: Option, + #[cfg(feature = "candle-backend")] + candle_device: Device, + /// ONNX session (if using ONNX backend) + #[cfg(feature = "onnx-backend")] + onnx_session: Option, + /// Anchor vectors for classification (stored as Vec for flexibility) + anchor_task: Vec, + anchor_memory: Vec, + anchor_preference: Vec, + anchor_relationship: Vec, + anchor_event: Vec, + anchor_store: Vec, + anchor_query: Vec, /// Cache for computed embeddings (text -> embedding vector as f32 array) embedding_cache: HashMap>, - /// Order of cache insertions for LRU eviction (VecDeque for O(1) removal from front) + /// Order of cache insertions for LRU eviction cache_order: VecDeque, } @@ -54,25 +256,17 @@ pub enum Action { } /// The type of data being stored or queried -/// -/// Extended data types enable richer semantic classification for AI agents. -/// The engine auto-detects the most appropriate type based on content. 
#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum DataType { - /// Action items, reminders, todos - things that need to be done - /// Examples: "Remind me to call John", "I need to buy groceries" + /// Action items, reminders, todos Task, - /// General facts and information to remember - /// Examples: "The capital of France is Paris", "John's phone number is 555-1234" + /// General facts and information Memory, /// User preferences and likes/dislikes - /// Examples: "I prefer dark mode", "My favorite color is blue" Preference, - /// Relationships between entities (people, things, concepts) - /// Examples: "John is my colleague", "Alice works at Acme Corp" + /// Relationships between entities Relationship, /// Time-based events and appointments - /// Examples: "Meeting tomorrow at 3pm", "Birthday party on Saturday" Event, } @@ -127,9 +321,6 @@ pub struct Intent { impl Intent { /// Returns true if the classification is ambiguous (low confidence) - /// - /// An ambiguous intent means the engine isn't sure what the agent wants. - /// In a cloud/stateless environment, this should trigger a clarification request. pub fn is_ambiguous(&self) -> bool { self.action_confidence < 0.15 || self.data_type_confidence < 0.15 } @@ -140,9 +331,6 @@ impl Intent { } /// Returns a clarification message if the intent is ambiguous - /// - /// This is the key for cloud-ready stateless design: instead of maintaining - /// context, we return a message telling the agent what we need. pub fn clarification_message(&self) -> Option { if !self.is_ambiguous() { return None; @@ -173,40 +361,46 @@ impl Intent { } impl Brain { - /// Creates a new Brain instance by loading the MiniLM model + /// Creates a new Brain instance with default configuration /// - /// First tries to load from local `models/minilm` directory, then falls back - /// to downloading from HuggingFace Hub (cached in ~/.cache/huggingface). + /// Uses BGE-Small model with the best available backend. 
pub fn new() -> Result { - // Try local models directory first - let local_model_dir = std::path::Path::new("models/minilm"); - if local_model_dir.exists() { - return Self::new_from_local(local_model_dir); - } + Self::with_config(BrainConfig::default()) + } - // Fall back to HuggingFace Hub download - Self::new_from_huggingface() + /// Creates a new Brain instance with custom configuration + pub fn with_config(config: BrainConfig) -> Result { + match config.backend { + Backend::Mock => Self::new_mock_internal(config), + #[cfg(feature = "candle-backend")] + Backend::Candle => Self::new_candle(config), + #[cfg(feature = "onnx-backend")] + Backend::Onnx => Self::new_onnx(config), + #[cfg(not(feature = "candle-backend"))] + Backend::Candle => anyhow::bail!("Candle backend not enabled. Compile with --features candle-backend"), + #[cfg(not(feature = "onnx-backend"))] + Backend::Onnx => anyhow::bail!("ONNX backend not enabled. Compile with --features onnx-backend"), + } } /// Creates a new Brain in mock mode for testing - /// - /// Uses hash-based deterministic embeddings instead of the ML model. - /// This allows running all the same logic (storage, retrieval, classification) - /// without requiring the actual model files. - /// - /// The mock embeddings are semantically-aware based on keyword detection, - /// providing reasonable classification behavior for testing. 
pub fn new_mock() -> Result { - let device = Device::Cpu; + Self::with_config(BrainConfig::builder().mock().build()) + } - // Create placeholder tensors for anchors - let placeholder = Tensor::zeros((1, 384), DType::F32, &device)?; + fn new_mock_internal(config: BrainConfig) -> Result { + let dim = config.model.embedding_dim(); + let placeholder = vec![0.0f32; dim]; let mut brain = Self { - model: None, + config, tokenizer: None, - device, - mock_mode: true, + #[cfg(feature = "candle-backend")] + candle_model: None, + #[cfg(feature = "candle-backend")] + candle_device: Device::Cpu, + #[cfg(feature = "onnx-backend")] + onnx_session: None, anchor_task: placeholder.clone(), anchor_memory: placeholder.clone(), anchor_preference: placeholder.clone(), @@ -218,120 +412,79 @@ impl Brain { cache_order: VecDeque::new(), }; - // Initialize anchor vectors using mock embeddings - brain.anchor_task = brain.embed( - "action item, todo, remind me to, deadline, schedule, need to do, task, must complete", - )?; - brain.anchor_memory = brain.embed( - "fact, information, context, background, remember that, note, detail about, knowledge", - )?; - brain.anchor_preference = brain.embed( - "I prefer, I like, my favorite, I enjoy, I want, preference, likes, dislikes, choose", - )?; - brain.anchor_relationship = brain.embed( - "is my, works at, knows, colleague, friend, family, partner, relationship, connected to", - )?; - brain.anchor_event = brain.embed( - "meeting, appointment, event, happening, tomorrow, next week, on date, scheduled for, calendar", - )?; - brain.anchor_store = brain.embed( - "save this, remember, store, add, note that, record, keep track of, my name is, I like", - )?; - brain.anchor_query = brain.embed( - "what is, who is, find, search, look up, tell me about, retrieve, get, show me, list", - )?; - + // Initialize anchors with mock embeddings + brain.initialize_anchors()?; Ok(brain) } - /// Returns whether this Brain is running in mock mode - pub fn is_mock(&self) -> 
bool { - self.mock_mode - } - - /// Creates a new Brain instance from local model files - /// - /// Expects the directory to contain: config.json, tokenizer.json, model.safetensors - pub fn new_from_local(model_dir: &std::path::Path) -> Result { - let device = Device::Cpu; - - let config_filename = model_dir.join("config.json"); - let tokenizer_filename = model_dir.join("tokenizer.json"); - let weights_filename = model_dir.join("model.safetensors"); - - // Verify files exist - if !config_filename.exists() { - anyhow::bail!("config.json not found in {:?}", model_dir); - } - if !tokenizer_filename.exists() { - anyhow::bail!("tokenizer.json not found in {:?}", model_dir); - } - if !weights_filename.exists() { - anyhow::bail!("model.safetensors not found in {:?}", model_dir); - } - - Self::load_model(device, &config_filename, &tokenizer_filename, &weights_filename) - } - - /// Creates a new Brain instance by downloading from HuggingFace Hub - /// - /// Downloads are cached automatically in ~/.cache/huggingface - pub fn new_from_huggingface() -> Result { + #[cfg(feature = "candle-backend")] + fn new_candle(config: BrainConfig) -> Result { let device = Device::Cpu; + let dim = config.model.embedding_dim(); + + // Get model files + let (config_path, tokenizer_path, weights_path) = if let Some(ref local_dir) = config.local_model_dir { + ( + local_dir.join("config.json"), + local_dir.join("tokenizer.json"), + local_dir.join("model.safetensors"), + ) + } else { + // Try local models directory first + let local_model_dir = std::path::Path::new("models").join(match config.model { + EmbeddingModel::MiniLmL6 => "minilm-l6", + EmbeddingModel::MiniLmL12 => "minilm-l12", + EmbeddingModel::BgeSmall => "bge-small", + EmbeddingModel::BgeBase => "bge-base", + EmbeddingModel::E5Small => "e5-small", + EmbeddingModel::GteSmall => "gte-small", + }); + + if local_model_dir.exists() { + ( + local_model_dir.join("config.json"), + local_model_dir.join("tokenizer.json"), + 
local_model_dir.join("model.safetensors"), + ) + } else { + // Download from HuggingFace + let api = Api::new().context("Failed to initialize HuggingFace API")?; + let repo = api.repo(Repo::new(config.model.hf_repo().to_string(), RepoType::Model)); + + ( + repo.get("config.json").context("Failed to download config.json")?, + repo.get("tokenizer.json").context("Failed to download tokenizer.json")?, + repo.get("model.safetensors").context("Failed to download model.safetensors")?, + ) + } + }; - let api = Api::new().context("Failed to initialize HuggingFace API")?; - let repo = api.repo(Repo::new( - "sentence-transformers/all-MiniLM-L6-v2".to_string(), - RepoType::Model, - )); - - let config_filename = repo - .get("config.json") - .context("Failed to download config.json")?; - let tokenizer_filename = repo - .get("tokenizer.json") - .context("Failed to download tokenizer.json")?; - let weights_filename = repo - .get("model.safetensors") - .context("Failed to download model.safetensors")?; - - Self::load_model(device, &config_filename, &tokenizer_filename, &weights_filename) - } - - /// Internal method to load model from file paths - fn load_model( - device: Device, - config_filename: &std::path::Path, - tokenizer_filename: &std::path::Path, - weights_filename: &std::path::Path, - ) -> Result { - - // Parse model configuration - let config: Config = serde_json::from_str( - &std::fs::read_to_string(&config_filename) - .context("Failed to read config.json")?, + // Parse config + let bert_config: Config = serde_json::from_str( + &std::fs::read_to_string(&config_path).context("Failed to read config.json")?, ) .context("Failed to parse config.json")?; // Load tokenizer - let tokenizer = Tokenizer::from_file(&tokenizer_filename) + let tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?; - // Load model weights using memory-mapped safetensors for efficiency + // Load model weights let vb = unsafe { - 
VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device) + VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device) .context("Failed to load model weights")? }; - let model = BertModel::load(vb, &config).context("Failed to initialize BERT model")?; + let model = BertModel::load(vb, &bert_config).context("Failed to initialize BERT model")?; - // Create placeholder tensors for anchors (will be initialized below) - let placeholder = Tensor::zeros((1, 384), DType::F32, &device)?; + let placeholder = vec![0.0f32; dim]; let mut brain = Self { - model: Some(model), + config, tokenizer: Some(tokenizer), - device, - mock_mode: false, + candle_model: Some(model), + candle_device: device, + #[cfg(feature = "onnx-backend")] + onnx_session: None, anchor_task: placeholder.clone(), anchor_memory: placeholder.clone(), anchor_preference: placeholder.clone(), @@ -343,73 +496,196 @@ impl Brain { cache_order: VecDeque::new(), }; - // 2. Initialize Anchor Vectors (Zero-Shot Classification Logic) - // Data type anchors: what kind of information is this? 
- brain.anchor_task = brain.embed( + brain.initialize_anchors()?; + Ok(brain) + } + + #[cfg(feature = "onnx-backend")] + fn new_onnx(config: BrainConfig) -> Result { + let dim = config.model.embedding_dim(); + + // Get model files + let (tokenizer_path, onnx_path) = if let Some(ref local_dir) = config.local_model_dir { + ( + local_dir.join("tokenizer.json"), + local_dir.join("model.onnx"), + ) + } else { + // Try local models directory first + let local_model_dir = std::path::Path::new("models").join(match config.model { + EmbeddingModel::MiniLmL6 => "minilm-l6", + EmbeddingModel::MiniLmL12 => "minilm-l12", + EmbeddingModel::BgeSmall => "bge-small", + EmbeddingModel::BgeBase => "bge-base", + EmbeddingModel::E5Small => "e5-small", + EmbeddingModel::GteSmall => "gte-small", + }); + + if local_model_dir.join("model.onnx").exists() { + ( + local_model_dir.join("tokenizer.json"), + local_model_dir.join("model.onnx"), + ) + } else { + // Download from HuggingFace + let api = Api::new().context("Failed to initialize HuggingFace API")?; + let repo = api.repo(Repo::new(config.model.hf_repo().to_string(), RepoType::Model)); + + // Try to get ONNX model, fall back to regular model + let onnx_path = repo.get("onnx/model.onnx") + .or_else(|_| repo.get("model.onnx")) + .context("Failed to download ONNX model. Model may not have ONNX export.")?; + + ( + repo.get("tokenizer.json").context("Failed to download tokenizer.json")?, + onnx_path, + ) + } + }; + + // Load tokenizer + let tokenizer = Tokenizer::from_file(&tokenizer_path) + .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?; + + // Initialize ONNX Runtime session with optimizations + let session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_intra_threads(4)? 
+ .commit_from_file(&onnx_path) + .context("Failed to load ONNX model")?; + + let placeholder = vec![0.0f32; dim]; + + let mut brain = Self { + config, + tokenizer: Some(tokenizer), + #[cfg(feature = "candle-backend")] + candle_model: None, + #[cfg(feature = "candle-backend")] + candle_device: Device::Cpu, + onnx_session: Some(session), + anchor_task: placeholder.clone(), + anchor_memory: placeholder.clone(), + anchor_preference: placeholder.clone(), + anchor_relationship: placeholder.clone(), + anchor_event: placeholder.clone(), + anchor_store: placeholder.clone(), + anchor_query: placeholder, + embedding_cache: HashMap::new(), + cache_order: VecDeque::new(), + }; + + brain.initialize_anchors()?; + Ok(brain) + } + + /// Initialize anchor vectors for classification + fn initialize_anchors(&mut self) -> Result<()> { + // Data type anchors + self.anchor_task = self.embed_to_vec( "action item, todo, remind me to, deadline, schedule, need to do, task, must complete", )?; - brain.anchor_memory = brain.embed( + self.anchor_memory = self.embed_to_vec( "fact, information, context, background, remember that, note, detail about, knowledge", )?; - brain.anchor_preference = brain.embed( + self.anchor_preference = self.embed_to_vec( "I prefer, I like, my favorite, I enjoy, I want, preference, likes, dislikes, choose", )?; - brain.anchor_relationship = brain.embed( + self.anchor_relationship = self.embed_to_vec( "is my, works at, knows, colleague, friend, family, partner, relationship, connected to", )?; - brain.anchor_event = brain.embed( + self.anchor_event = self.embed_to_vec( "meeting, appointment, event, happening, tomorrow, next week, on date, scheduled for, calendar", )?; - // Action anchors: what does the agent want to do? 
- brain.anchor_store = brain.embed( + // Action anchors + self.anchor_store = self.embed_to_vec( "save this, remember, store, add, note that, record, keep track of, my name is, I like", )?; - brain.anchor_query = brain.embed( + self.anchor_query = self.embed_to_vec( "what is, who is, find, search, look up, tell me about, retrieve, get, show me, list", )?; - Ok(brain) + Ok(()) } - /// Converts text into a normalized 384-dimensional embedding vector - /// - /// Uses mean pooling over all token embeddings followed by L2 normalization. - /// The resulting vector is suitable for cosine similarity comparisons. - /// Results are cached for performance. - pub fn embed(&mut self, text: &str) -> Result { + /// Returns whether this Brain is running in mock mode + pub fn is_mock(&self) -> bool { + self.config.backend == Backend::Mock + } + + /// Returns the configured model + pub fn model(&self) -> EmbeddingModel { + self.config.model + } + + /// Returns the configured backend + pub fn backend(&self) -> Backend { + self.config.backend + } + + /// Returns the embedding dimension for the configured model + pub fn embedding_dim(&self) -> usize { + self.config.model.embedding_dim() + } + + /// Converts text into a normalized embedding vector + pub fn embed_to_vec(&mut self, text: &str) -> Result> { // Check cache first if let Some(cached) = self.embedding_cache.get(text) { - return Ok(Tensor::new(cached.as_slice(), &self.device)?.unsqueeze(0)?); + return Ok(cached.clone()); } - // Compute embedding + // Compute embedding based on backend let embedding = self.compute_embedding(text)?; - // Extract values for caching - let values: Vec = embedding.squeeze(0)?.to_vec1()?; - - // Cache with LRU eviction (O(1) removal from front with VecDeque) + // Cache with LRU eviction if self.cache_order.len() >= EMBEDDING_CACHE_SIZE { if let Some(oldest) = self.cache_order.pop_front() { self.embedding_cache.remove(&oldest); } } - self.embedding_cache.insert(text.to_string(), values); + 
self.embedding_cache.insert(text.to_string(), embedding.clone()); self.cache_order.push_back(text.to_string()); Ok(embedding) } - /// Internal: compute embedding without caching (used by embed()) - fn compute_embedding(&self, text: &str) -> Result { - if self.mock_mode { - return self.compute_mock_embedding(text); + /// Converts text into a Candle tensor (for backward compatibility) + #[cfg(feature = "candle-backend")] + pub fn embed(&mut self, text: &str) -> Result { + let vec = self.embed_to_vec(text)?; + let tensor = Tensor::new(&vec[..], &self.candle_device)?.unsqueeze(0)?; + Ok(tensor) + } + + #[cfg(not(feature = "candle-backend"))] + pub fn embed(&mut self, text: &str) -> Result> { + self.embed_to_vec(text) + } + + fn compute_embedding(&self, text: &str) -> Result> { + match self.config.backend { + Backend::Mock => self.compute_mock_embedding(text), + #[cfg(feature = "candle-backend")] + Backend::Candle => self.compute_candle_embedding(text), + #[cfg(feature = "onnx-backend")] + Backend::Onnx => self.compute_onnx_embedding(text), + #[cfg(not(feature = "candle-backend"))] + Backend::Candle => anyhow::bail!("Candle backend not enabled"), + #[cfg(not(feature = "onnx-backend"))] + Backend::Onnx => anyhow::bail!("ONNX backend not enabled"), } + } - // Tokenize input text + #[cfg(feature = "candle-backend")] + fn compute_candle_embedding(&self, text: &str) -> Result> { let tokenizer = self.tokenizer.as_ref() .ok_or_else(|| anyhow::anyhow!("Tokenizer not available"))?; + let model = self.candle_model.as_ref() + .ok_or_else(|| anyhow::anyhow!("Model not available"))?; + + // Tokenize input text let tokens = tokenizer .encode(text, true) .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?; @@ -418,32 +694,26 @@ impl Brain { let token_type_ids = vec![0u32; token_ids.len()]; let attention_mask = tokens.get_attention_mask(); - // Convert to tensors with batch dimension - let token_ids_tensor = Tensor::new(token_ids, &self.device)?.unsqueeze(0)?; - let 
token_type_ids_tensor = Tensor::new(token_type_ids.as_slice(), &self.device)?.unsqueeze(0)?; + // Convert to tensors + let token_ids_tensor = Tensor::new(token_ids, &self.candle_device)?.unsqueeze(0)?; + let token_type_ids_tensor = Tensor::new(token_type_ids.as_slice(), &self.candle_device)?.unsqueeze(0)?; - // Run model inference - let model = self.model.as_ref() - .ok_or_else(|| anyhow::anyhow!("Model not available"))?; + // Run inference let embeddings = model .forward(&token_ids_tensor, &token_type_ids_tensor, None) .context("Model forward pass failed")?; - // Mean pooling: average all token embeddings (excluding padding) - let (_batch_size, _n_tokens, _hidden_size) = embeddings.dims3()?; - - // Create attention mask tensor for proper pooling + // Mean pooling with attention mask let attention_mask_tensor = - Tensor::new(attention_mask, &self.device)?.unsqueeze(0)?.unsqueeze(2)?; + Tensor::new(attention_mask, &self.candle_device)?.unsqueeze(0)?.unsqueeze(2)?; let attention_mask_f32 = attention_mask_tensor.to_dtype(DType::F32)?; - // Masked mean pooling let masked_embeddings = embeddings.broadcast_mul(&attention_mask_f32)?; let sum_embeddings = masked_embeddings.sum(1)?; let mask_sum = attention_mask_f32.sum(1)?.clamp(1e-9, f64::MAX)?; let mean_embeddings = sum_embeddings.broadcast_div(&mask_sum)?; - // L2 normalization (critical for cosine similarity) + // L2 normalize let norm = mean_embeddings .sqr()? .sum_keepdim(1)? 
@@ -451,105 +721,144 @@ impl Brain { .clamp(1e-12, f64::MAX)?; let normalized = mean_embeddings.broadcast_div(&norm)?; - Ok(normalized) + // Convert to Vec + let values: Vec = normalized.squeeze(0)?.to_vec1()?; + Ok(values) + } + + #[cfg(feature = "onnx-backend")] + fn compute_onnx_embedding(&self, text: &str) -> Result> { + let tokenizer = self.tokenizer.as_ref() + .ok_or_else(|| anyhow::anyhow!("Tokenizer not available"))?; + let session = self.onnx_session.as_ref() + .ok_or_else(|| anyhow::anyhow!("ONNX session not available"))?; + + // Tokenize + let tokens = tokenizer + .encode(text, true) + .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?; + + let token_ids: Vec = tokens.get_ids().iter().map(|&x| x as i64).collect(); + let attention_mask: Vec = tokens.get_attention_mask().iter().map(|&x| x as i64).collect(); + let token_type_ids: Vec = vec![0i64; token_ids.len()]; + + let seq_len = token_ids.len(); + + // Create ONNX input arrays + let input_ids = ndarray::Array2::from_shape_vec((1, seq_len), token_ids)?; + let attention = ndarray::Array2::from_shape_vec((1, seq_len), attention_mask.clone())?; + let type_ids = ndarray::Array2::from_shape_vec((1, seq_len), token_type_ids)?; + + // Run inference + let outputs = session.run(ort::inputs![ + "input_ids" => input_ids, + "attention_mask" => attention, + "token_type_ids" => type_ids, + ]?)?; + + // Get output tensor - try different common output names + let output = outputs.get("last_hidden_state") + .or_else(|| outputs.get("output")) + .or_else(|| outputs.get("sentence_embedding")) + .ok_or_else(|| anyhow::anyhow!("Could not find output tensor"))?; + + let output_tensor = output.try_extract_tensor::()?; + let output_view = output_tensor.view(); + + // Mean pooling + let dim = self.config.model.embedding_dim(); + let mut pooled = vec![0.0f32; dim]; + let mut count = 0.0f32; + + for i in 0..seq_len { + if attention_mask[i] == 1 { + for j in 0..dim { + pooled[j] += output_view[[0, i, j]]; + } + count += 
1.0; + } + } + + if count > 0.0 { + for v in &mut pooled { + *v /= count; + } + } + + // L2 normalize + let norm: f32 = pooled.iter().map(|x| x * x).sum::().sqrt().max(1e-12); + for v in &mut pooled { + *v /= norm; + } + + Ok(pooled) } /// Compute mock embedding using semantic-aware hash-based approach - /// - /// Creates deterministic 384-dimensional vectors that preserve semantic - /// properties needed for classification and similarity search. - fn compute_mock_embedding(&self, text: &str) -> Result { + fn compute_mock_embedding(&self, text: &str) -> Result> { use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; + let dim = self.config.model.embedding_dim(); let text_lower = text.to_lowercase(); - let mut embedding = vec![0.0f32; 384]; + let mut embedding = vec![0.0f32; dim]; // Base hash for random-like distribution let mut hasher = DefaultHasher::new(); text.hash(&mut hasher); let base_hash = hasher.finish(); - // Fill with pseudo-random values based on text hash - for i in 0..384 { + // Fill with pseudo-random values + for i in 0..dim { let mut h = DefaultHasher::new(); (base_hash, i).hash(&mut h); let val = h.finish(); embedding[i] = ((val % 10000) as f32 / 10000.0) * 2.0 - 1.0; } - // Semantic boosting: adjust specific dimensions based on keywords - // This makes similar texts have similar embeddings - // Use strong boosting (1.5) to create clear semantic clusters for mock mode + // Semantic boosting based on dimension ranges (scaled for different dim sizes) + let scale = dim as f32 / 384.0; + let range_size = (50.0 * scale) as usize; - // Task-related keywords (boost dimensions 0-50) - let task_keywords = ["remind", "todo", "task", "need to", "schedule", "deadline", "call", "buy", "complete", "don't forget", "pick up", "book", "renew"]; - let task_score: f32 = task_keywords.iter() - .filter(|kw| text_lower.contains(*kw)) - .count() as f32 * 1.5; - for i in 0..50 { + // Task keywords + let task_keywords = ["remind", "todo", "task", 
"need to", "schedule", "deadline", "call", "buy", "complete"]; + let task_score: f32 = task_keywords.iter().filter(|kw| text_lower.contains(*kw)).count() as f32 * 1.5; + for i in 0..range_size.min(dim) { embedding[i] += task_score; } - // Memory/fact keywords (boost dimensions 50-100) - let memory_keywords = ["is", "name", "number", "fact", "information", "capital", "phone", "password", "email", "birthday", "founded", "runs on", "api", "key", "deadline"]; - let memory_score: f32 = memory_keywords.iter() - .filter(|kw| text_lower.contains(*kw)) - .count() as f32 * 1.5; - for i in 50..100 { + // Memory keywords + let memory_keywords = ["is", "name", "number", "fact", "information", "capital", "phone", "birthday"]; + let memory_score: f32 = memory_keywords.iter().filter(|kw| text_lower.contains(*kw)).count() as f32 * 1.5; + for i in range_size..(2 * range_size).min(dim) { embedding[i] += memory_score; } - // Preference keywords (boost dimensions 100-150) - let pref_keywords = ["prefer", "favorite", "like", "enjoy", "love", "hate", "want", "dark mode", "early", "video call"]; - let pref_score: f32 = pref_keywords.iter() - .filter(|kw| text_lower.contains(*kw)) - .count() as f32 * 1.5; - for i in 100..150 { + // Preference keywords + let pref_keywords = ["prefer", "favorite", "like", "enjoy", "love", "hate", "want"]; + let pref_score: f32 = pref_keywords.iter().filter(|kw| text_lower.contains(*kw)).count() as f32 * 1.5; + for i in (2 * range_size)..(3 * range_size).min(dim) { embedding[i] += pref_score; } - // Relationship keywords (boost dimensions 150-200) - let rel_keywords = ["colleague", "friend", "family", "works at", "knows", "is my", "partner", "team lead", "manager", "mentor", "cto", "department"]; - let rel_score: f32 = rel_keywords.iter() - .filter(|kw| text_lower.contains(*kw)) - .count() as f32 * 1.5; - for i in 150..200 { - embedding[i] += rel_score; - } - - // Event keywords (boost dimensions 200-250) - let event_keywords = ["meeting", "appointment", 
"event", "tomorrow", "next week", "calendar", "party", "every monday", "all-hands", "review", "launch", "holiday"]; - let event_score: f32 = event_keywords.iter() - .filter(|kw| text_lower.contains(*kw)) - .count() as f32 * 1.5; - for i in 200..250 { - embedding[i] += event_score; - } - - // Query keywords (boost dimensions 250-300) - strong boost for clear queries - let query_keywords = ["what", "who", "where", "when", "how", "find", "search", "show", "list", "?", "tell me", "about"]; - let query_score: f32 = query_keywords.iter() - .filter(|kw| text_lower.contains(*kw)) - .count() as f32 * 2.0; // Extra strong for queries - for i in 250..300 { + // Query keywords + let query_keywords = ["what", "who", "where", "when", "how", "find", "search", "show", "list", "?"]; + let query_score: f32 = query_keywords.iter().filter(|kw| text_lower.contains(*kw)).count() as f32 * 2.0; + for i in (5 * range_size)..(6 * range_size).min(dim) { embedding[i] += query_score; } - // Store keywords (boost dimensions 300-350) - declarative statements default to store - let store_keywords = ["remember", "save", "store", "add", "note", "record", "my name is", "i am", "my", "i prefer", "i like", "'s"]; - let store_score: f32 = store_keywords.iter() - .filter(|kw| text_lower.contains(*kw)) - .count() as f32 * 2.0; // Extra strong for stores - for i in 300..350 { + // Store keywords + let store_keywords = ["remember", "save", "store", "add", "note", "record", "my name is", "i am", "my"]; + let store_score: f32 = store_keywords.iter().filter(|kw| text_lower.contains(*kw)).count() as f32 * 2.0; + for i in (6 * range_size)..(7 * range_size).min(dim) { embedding[i] += store_score; } - // Additional heuristic: statements without query markers default to store + // Declarative statements default to store let has_query_marker = query_keywords.iter().any(|kw| text_lower.contains(*kw)); if !has_query_marker && text_lower.len() > 10 { - // Boost store dimensions for declarative statements - for i in 
300..350 { + for i in (6 * range_size)..(7 * range_size).min(dim) { embedding[i] += 1.0; } } @@ -560,10 +869,10 @@ impl Brain { *v /= norm; } - Ok(Tensor::new(&embedding[..], &self.device)?.unsqueeze(0)?) + Ok(embedding) } - /// Returns cache statistics (hits can be inferred by caller) + /// Returns cache statistics pub fn cache_size(&self) -> usize { self.embedding_cache.len() } @@ -574,21 +883,31 @@ impl Brain { self.cache_order.clear(); } - /// Classifies the full intent of an input embedding vector - /// - /// Returns both the action (Store/Query) and data type (Task/Memory/Preference/Relationship/Event) - /// by comparing against pre-computed anchor vectors. - /// Also returns confidence scores for cloud-ready stateless operation. - pub fn classify(&self, input_vec: &Tensor) -> Result { - // Determine action: Store or Query? - let score_store = input_vec - .mul(&self.anchor_store)? - .sum_all()? - .to_scalar::()?; - let score_query = input_vec - .mul(&self.anchor_query)? - .sum_all()? 
- .to_scalar::()?; + /// Cosine similarity between two vectors + fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { + if a.len() != b.len() { + return 0.0; + } + let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); + let mag_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); + let mag_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); + if mag_a < 1e-10 || mag_b < 1e-10 { + return 0.0; + } + dot / (mag_a * mag_b) + } + + /// Classifies the full intent of an input text + pub fn classify_text(&mut self, text: &str) -> Result { + let input_vec = self.embed_to_vec(text)?; + self.classify_vec(&input_vec) + } + + /// Classifies the full intent from a pre-computed embedding vector + pub fn classify_vec(&self, input_vec: &[f32]) -> Result { + // Determine action + let score_store = Self::cosine_similarity(input_vec, &self.anchor_store); + let score_query = Self::cosine_similarity(input_vec, &self.anchor_query); let (action, action_confidence) = if score_query > score_store { (Action::Query, (score_query - score_store).abs()) @@ -596,23 +915,21 @@ impl Brain { (Action::Store, (score_store - score_query).abs()) }; - // Determine data type by comparing against all category anchors + // Determine data type let scores = [ - (DataType::Task, input_vec.mul(&self.anchor_task)?.sum_all()?.to_scalar::()?), - (DataType::Memory, input_vec.mul(&self.anchor_memory)?.sum_all()?.to_scalar::()?), - (DataType::Preference, input_vec.mul(&self.anchor_preference)?.sum_all()?.to_scalar::()?), - (DataType::Relationship, input_vec.mul(&self.anchor_relationship)?.sum_all()?.to_scalar::()?), - (DataType::Event, input_vec.mul(&self.anchor_event)?.sum_all()?.to_scalar::()?), + (DataType::Task, Self::cosine_similarity(input_vec, &self.anchor_task)), + (DataType::Memory, Self::cosine_similarity(input_vec, &self.anchor_memory)), + (DataType::Preference, Self::cosine_similarity(input_vec, &self.anchor_preference)), + (DataType::Relationship, Self::cosine_similarity(input_vec, 
&self.anchor_relationship)), + (DataType::Event, Self::cosine_similarity(input_vec, &self.anchor_event)), ]; - // Find the highest scoring data type let (best_type, best_score) = scores .iter() .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)) .map(|(dt, s)| (*dt, *s)) .unwrap_or((DataType::Memory, 0.0)); - // Find the second highest score for confidence calculation let second_best_score = scores .iter() .filter(|(dt, _)| *dt != best_type) @@ -630,14 +947,32 @@ impl Brain { }) } - /// Classifies only the data type - returns the best matching category - pub fn classify_data_type(&self, input_vec: &Tensor) -> Result { + /// Backward compatibility: classify from tensor + #[cfg(feature = "candle-backend")] + pub fn classify(&self, input_tensor: &Tensor) -> Result { + let values: Vec = input_tensor.squeeze(0)?.to_vec1()?; + self.classify_vec(&values) + } + + #[cfg(not(feature = "candle-backend"))] + pub fn classify(&self, input_vec: &[f32]) -> Result { + self.classify_vec(input_vec) + } + + /// Classifies only the data type from text + pub fn classify_data_type_text(&mut self, text: &str) -> Result { + let vec = self.embed_to_vec(text)?; + self.classify_data_type_vec(&vec) + } + + /// Classifies only the data type from a vector + pub fn classify_data_type_vec(&self, input_vec: &[f32]) -> Result { let scores = [ - (DataType::Task, input_vec.mul(&self.anchor_task)?.sum_all()?.to_scalar::()?), - (DataType::Memory, input_vec.mul(&self.anchor_memory)?.sum_all()?.to_scalar::()?), - (DataType::Preference, input_vec.mul(&self.anchor_preference)?.sum_all()?.to_scalar::()?), - (DataType::Relationship, input_vec.mul(&self.anchor_relationship)?.sum_all()?.to_scalar::()?), - (DataType::Event, input_vec.mul(&self.anchor_event)?.sum_all()?.to_scalar::()?), + (DataType::Task, Self::cosine_similarity(input_vec, &self.anchor_task)), + (DataType::Memory, Self::cosine_similarity(input_vec, &self.anchor_memory)), + (DataType::Preference, 
Self::cosine_similarity(input_vec, &self.anchor_preference)), + (DataType::Relationship, Self::cosine_similarity(input_vec, &self.anchor_relationship)), + (DataType::Event, Self::cosine_similarity(input_vec, &self.anchor_event)), ]; let (best_type, _) = scores @@ -649,16 +984,28 @@ impl Brain { Ok(best_type) } - /// Classifies only the action (Store/Query) - pub fn classify_action(&self, input_vec: &Tensor) -> Result { - let score_store = input_vec - .mul(&self.anchor_store)? - .sum_all()? - .to_scalar::()?; - let score_query = input_vec - .mul(&self.anchor_query)? - .sum_all()? - .to_scalar::()?; + /// Backward compatibility + #[cfg(feature = "candle-backend")] + pub fn classify_data_type(&self, input_tensor: &Tensor) -> Result { + let values: Vec = input_tensor.squeeze(0)?.to_vec1()?; + self.classify_data_type_vec(&values) + } + + #[cfg(not(feature = "candle-backend"))] + pub fn classify_data_type(&self, input_vec: &[f32]) -> Result { + self.classify_data_type_vec(input_vec) + } + + /// Classifies only the action from text + pub fn classify_action_text(&mut self, text: &str) -> Result { + let vec = self.embed_to_vec(text)?; + self.classify_action_vec(&vec) + } + + /// Classifies only the action from a vector + pub fn classify_action_vec(&self, input_vec: &[f32]) -> Result { + let score_store = Self::cosine_similarity(input_vec, &self.anchor_store); + let score_query = Self::cosine_similarity(input_vec, &self.anchor_query); if score_query > score_store { Ok(Action::Query) @@ -667,33 +1014,26 @@ impl Brain { } } - /// Returns the embedding dimension (384 for MiniLM) - pub fn embedding_dim(&self) -> usize { - 384 + /// Backward compatibility + #[cfg(feature = "candle-backend")] + pub fn classify_action(&self, input_tensor: &Tensor) -> Result { + let values: Vec = input_tensor.squeeze(0)?.to_vec1()?; + self.classify_action_vec(&values) + } + + #[cfg(not(feature = "candle-backend"))] + pub fn classify_action(&self, input_vec: &[f32]) -> Result { + 
self.classify_action_vec(input_vec) } /// Generates embeddings for multiple texts in batch - /// - /// More efficient than calling embed() multiple times for bulk operations. - /// Returns vectors in the same order as input texts. pub fn embed_batch(&mut self, texts: &[&str]) -> Result>> { let mut results = Vec::with_capacity(texts.len()); - for text in texts { - let embedding = self.embed(text)?; - let values: Vec = embedding.squeeze(0)?.to_vec1()?; - results.push(values); + results.push(self.embed_to_vec(text)?); } - Ok(results) } - - /// Generates embeddings and returns as raw f32 vectors (skips tensor creation for results) - pub fn embed_to_vec(&mut self, text: &str) -> Result> { - let embedding = self.embed(text)?; - let values: Vec = embedding.squeeze(0)?.to_vec1()?; - Ok(values) - } } #[cfg(test)] @@ -701,64 +1041,49 @@ mod tests { use super::*; #[test] - #[ignore] // Requires model download - fn test_brain_initialization() { - let brain = Brain::new().expect("Brain should initialize"); + fn test_mock_brain() { + let mut brain = Brain::new_mock().expect("Mock brain should initialize"); + assert!(brain.is_mock()); assert_eq!(brain.embedding_dim(), 384); } #[test] - #[ignore] // Requires model download - fn test_embedding_shape() { - let mut brain = Brain::new().expect("Brain should initialize"); - let embedding = brain.embed("Hello world").expect("Embedding should succeed"); - let dims = embedding.dims(); - assert_eq!(dims, &[1, 384]); + fn test_mock_embedding() { + let mut brain = Brain::new_mock().expect("Mock brain should initialize"); + let embedding = brain.embed_to_vec("Hello world").expect("Embedding should succeed"); + assert_eq!(embedding.len(), 384); } #[test] - #[ignore] // Requires model download - fn test_store_task_classification() { - let mut brain = Brain::new().expect("Brain should initialize"); - let embedding = brain - .embed("Remind me to buy groceries tomorrow") - .expect("Embedding should succeed"); - let intent = 
brain.classify(&embedding).expect("Classification should succeed"); - assert_eq!(intent.action, Action::Store); - assert_eq!(intent.data_type, DataType::Task); - } + fn test_mock_classification() { + let mut brain = Brain::new_mock().expect("Mock brain should initialize"); - #[test] - #[ignore] // Requires model download - fn test_store_memory_classification() { - let mut brain = Brain::new().expect("Brain should initialize"); - let embedding = brain - .embed("My favorite programming language is Rust") - .expect("Embedding should succeed"); - let intent = brain.classify(&embedding).expect("Classification should succeed"); + // Test query detection + let intent = brain.classify_text("What is my name?").expect("Classification should succeed"); + assert_eq!(intent.action, Action::Query); + + // Test store detection + let intent = brain.classify_text("My name is Alice").expect("Classification should succeed"); assert_eq!(intent.action, Action::Store); - assert_eq!(intent.data_type, DataType::Memory); } #[test] - #[ignore] // Requires model download - fn test_query_classification() { - let mut brain = Brain::new().expect("Brain should initialize"); - let embedding = brain - .embed("What is my favorite color?") - .expect("Embedding should succeed"); - let intent = brain.classify(&embedding).expect("Classification should succeed"); - assert_eq!(intent.action, Action::Query); + fn test_embedding_cache() { + let mut brain = Brain::new_mock().expect("Mock brain should initialize"); + + // First call computes + let _ = brain.embed_to_vec("test text").expect("Should work"); + assert_eq!(brain.cache_size(), 8); // 7 anchors + 1 new + + // Second call uses cache + let _ = brain.embed_to_vec("test text").expect("Should work"); + assert_eq!(brain.cache_size(), 8); // Still 8, cache hit } #[test] - #[ignore] // Requires model download - fn test_search_classification() { - let mut brain = Brain::new().expect("Brain should initialize"); - let embedding = brain - .embed("Find all my 
tasks for today") - .expect("Embedding should succeed"); - let intent = brain.classify(&embedding).expect("Classification should succeed"); - assert_eq!(intent.action, Action::Query); + fn test_model_configs() { + assert_eq!(EmbeddingModel::MiniLmL6.embedding_dim(), 384); + assert_eq!(EmbeddingModel::BgeBase.embedding_dim(), 768); + assert!(EmbeddingModel::BgeBase.mteb_score() > EmbeddingModel::MiniLmL6.mteb_score()); } } diff --git a/packages/agent-state-rs/src/lib.rs b/packages/agent-state-rs/src/lib.rs index 8202fc3..c984da2 100644 --- a/packages/agent-state-rs/src/lib.rs +++ b/packages/agent-state-rs/src/lib.rs @@ -38,7 +38,7 @@ pub mod federation; pub mod metrics; mod storage; -pub use brain::{Action, Brain, DataType, Intent}; +pub use brain::{Action, Backend, Brain, BrainConfig, BrainConfigBuilder, DataType, EmbeddingModel, Intent}; pub use federation::{FederatedEngine, FederationConfig}; pub use metrics::{Metrics, Operation, OperationStats}; pub use storage::{KnowledgeItem, Storage, TimeFilter}; @@ -154,60 +154,194 @@ impl AgentResponse { /// /// This is the primary interface. Agents interact through a single `process()` /// method that automatically understands intent and routes accordingly. 
+/// +/// # Examples +/// +/// ```no_run +/// use agent_brain::{AgentEngine, EmbeddingModel}; +/// +/// // Default configuration (BGE-Small model) +/// let engine = AgentEngine::new("agent.db")?; +/// +/// // With builder pattern for customization +/// let engine = AgentEngine::builder() +/// .db_path("agent.db") +/// .model(EmbeddingModel::BgeBase) +/// .build()?; +/// +/// // In-memory database for testing +/// let engine = AgentEngine::builder() +/// .in_memory() +/// .mock() // Use mock embeddings (no model download) +/// .build()?; +/// # Ok::<(), anyhow::Error>(()) +/// ``` pub struct AgentEngine { brain: Brain, storage: Storage, metrics: Metrics, } +/// Builder for AgentEngine - provides fluent API for configuration +pub struct AgentEngineBuilder { + db_path: Option, + in_memory: bool, + brain_config: BrainConfigBuilder, + metrics_enabled: bool, +} + +impl Default for AgentEngineBuilder { + fn default() -> Self { + Self { + db_path: None, + in_memory: false, + brain_config: BrainConfigBuilder::default(), + metrics_enabled: true, + } + } +} + +impl AgentEngineBuilder { + /// Set the database file path + pub fn db_path>(mut self, path: S) -> Self { + self.db_path = Some(path.into()); + self.in_memory = false; + self + } + + /// Use an in-memory database (useful for testing) + pub fn in_memory(mut self) -> Self { + self.in_memory = true; + self.db_path = None; + self + } + + /// Set the embedding model to use + pub fn model(mut self, model: EmbeddingModel) -> Self { + self.brain_config = self.brain_config.model(model); + self + } + + /// Set the backend for inference + pub fn backend(mut self, backend: Backend) -> Self { + self.brain_config = self.brain_config.backend(backend); + self + } + + /// Set a local model directory (bypasses HuggingFace download) + pub fn local_model_dir>(mut self, path: P) -> Self { + self.brain_config = self.brain_config.local_model_dir(path); + self + } + + /// Enable mock mode for testing (no real model needed) + pub fn mock(mut 
self) -> Self { + self.brain_config = self.brain_config.mock(); + self + } + + /// Disable metrics collection + pub fn without_metrics(mut self) -> Self { + self.metrics_enabled = false; + self + } + + /// Build the AgentEngine + pub fn build(self) -> Result { + let db_path = if self.in_memory { + ":memory:".to_string() + } else { + self.db_path.ok_or_else(|| anyhow::anyhow!( + "Database path required. Use .db_path(\"path\") or .in_memory()" + ))? + }; + + let brain_config = self.brain_config.build(); + let brain = Brain::with_config(brain_config).context("Failed to initialize Brain")?; + let storage = Storage::new(&db_path).context("Failed to initialize Storage")?; + let metrics = if self.metrics_enabled { + Metrics::new() + } else { + Metrics::disabled() + }; + + Ok(AgentEngine { brain, storage, metrics }) + } +} + impl AgentEngine { - /// Creates a new AgentEngine instance + /// Creates a new AgentEngine with default configuration /// - /// Downloads and caches the MiniLM model on first run (~90MB). + /// Uses BGE-Small model (best accuracy for 384 dimensions) with the best available backend. + /// Downloads and caches the model on first run (~90MB). 
/// /// # Arguments /// * `db_path` - Path to the SQLite database file pub fn new(db_path: &str) -> Result { - let brain = Brain::new().context("Failed to initialize Brain")?; - let storage = Storage::new(db_path).context("Failed to initialize Storage")?; - let metrics = Metrics::new(); - Ok(Self { brain, storage, metrics }) + Self::builder().db_path(db_path).build() } - /// Creates an AgentEngine with an in-memory database (useful for testing) - pub fn new_in_memory() -> Result { - Self::new(":memory:") + /// Creates a builder for customizing the AgentEngine + /// + /// # Examples + /// + /// ```no_run + /// use agent_brain::{AgentEngine, EmbeddingModel, Backend}; + /// + /// let engine = AgentEngine::builder() + /// .db_path("agent.db") + /// .model(EmbeddingModel::BgeBase) // Most accurate + /// .backend(Backend::Candle) + /// .build()?; + /// # Ok::<(), anyhow::Error>(()) + /// ``` + pub fn builder() -> AgentEngineBuilder { + AgentEngineBuilder::default() + } + + /// Returns whether this engine is running in mock mode + pub fn is_mock(&self) -> bool { + self.brain.is_mock() + } + + /// Returns the configured embedding model + pub fn model(&self) -> EmbeddingModel { + self.brain.model() + } + + /// Returns the configured backend + pub fn backend(&self) -> Backend { + self.brain.backend() } - /// Creates an AgentEngine with metrics disabled - pub fn new_without_metrics(db_path: &str) -> Result { - let brain = Brain::new().context("Failed to initialize Brain")?; - let storage = Storage::new(db_path).context("Failed to initialize Storage")?; - let metrics = Metrics::disabled(); - Ok(Self { brain, storage, metrics }) + /// Returns the embedding dimension for the configured model + pub fn embedding_dim(&self) -> usize { + self.brain.embedding_dim() } - /// Creates an AgentEngine in mock mode for testing + // ========================================================================= + // CONVENIENCE CONSTRUCTORS (use builder() for full customization) + // 
========================================================================= + + /// Creates an in-memory AgentEngine (useful for testing) /// - /// Uses hash-based deterministic embeddings instead of the ML model. - /// This allows testing all functionality without requiring the actual model files. - pub fn new_mock(db_path: &str) -> Result { - let brain = Brain::new_mock().context("Failed to initialize mock Brain")?; - let storage = Storage::new(db_path).context("Failed to initialize Storage")?; - let metrics = Metrics::new(); - Ok(Self { brain, storage, metrics }) + /// Uses default configuration. For customization, use `builder().in_memory()...`. + pub fn new_in_memory() -> Result { + Self::builder().in_memory().build() } - /// Creates an in-memory AgentEngine in mock mode for testing + /// Creates an in-memory AgentEngine in mock mode (fast testing) /// - /// Combines in-memory database with mock embeddings for fast, isolated testing. + /// Uses hash-based embeddings instead of real ML model. pub fn new_mock_in_memory() -> Result { - Self::new_mock(":memory:") + Self::builder().in_memory().mock().build() } - /// Returns whether this engine is running in mock mode - pub fn is_mock(&self) -> bool { - self.brain.is_mock() + /// Creates an AgentEngine in mock mode with specified database path + /// + /// Uses hash-based embeddings instead of real ML model. + pub fn new_mock(db_path: &str) -> Result { + Self::builder().db_path(db_path).mock().build() } // ========================================================================= @@ -268,9 +402,9 @@ impl AgentEngine { // 1. 
Generate embedding for the input let embed_start = Instant::now(); - let vector_tensor = self + let vector_flat = self .brain - .embed(input) + .embed_to_vec(input) .context("Failed to generate embedding")?; self.metrics.record(Operation::Embed, embed_start.elapsed()); @@ -278,7 +412,7 @@ impl AgentEngine { let classify_start = Instant::now(); let intent = self .brain - .classify(&vector_tensor) + .classify_vec(&vector_flat) .context("Failed to classify intent")?; self.metrics.record(Operation::Classify, classify_start.elapsed()); @@ -298,13 +432,7 @@ impl AgentEngine { }); } - // 4. Convert tensor to vector for storage/search - let vector_flat: Vec = vector_tensor - .flatten_all()? - .to_vec1() - .context("Failed to flatten embedding")?; - - // 5. Route based on detected action + // 4. Route based on detected action let response = match intent.action { Action::Store => { let category = intent.data_type.as_category(); @@ -373,15 +501,13 @@ impl AgentEngine { let total_start = Instant::now(); let embed_start = Instant::now(); - let vector_tensor = self.brain.embed(content)?; + let vector_flat = self.brain.embed_to_vec(content)?; self.metrics.record(Operation::Embed, embed_start.elapsed()); let classify_start = Instant::now(); - let data_type = self.brain.classify_data_type(&vector_tensor)?; + let data_type = self.brain.classify_data_type_vec(&vector_flat)?; self.metrics.record(Operation::Classify, classify_start.elapsed()); - let vector_flat: Vec = vector_tensor.flatten_all()?.to_vec1()?; - let category = data_type.as_category(); let db_start = Instant::now(); @@ -404,11 +530,9 @@ impl AgentEngine { let total_start = Instant::now(); let embed_start = Instant::now(); - let vector_tensor = self.brain.embed(content)?; + let vector_flat = self.brain.embed_to_vec(content)?; self.metrics.record(Operation::Embed, embed_start.elapsed()); - let vector_flat: Vec = vector_tensor.flatten_all()?.to_vec1()?; - let category = data_type.as_category(); let db_start = Instant::now(); 
@@ -454,9 +578,7 @@ impl AgentEngine { // Classify and prepare batch items let mut items = Vec::with_capacity(contents.len()); for (content, embedding) in contents.iter().zip(embeddings.iter()) { - let tensor = candle_core::Tensor::new(embedding.as_slice(), &candle_core::Device::Cpu)? - .unsqueeze(0)?; - let data_type = self.brain.classify_data_type(&tensor)?; + let data_type = self.brain.classify_data_type_vec(embedding)?; items.push((*content, data_type.as_category(), embedding.as_slice())); } @@ -507,11 +629,9 @@ impl AgentEngine { let total_start = Instant::now(); let embed_start = Instant::now(); - let vector_tensor = self.brain.embed(query)?; + let vector_flat = self.brain.embed_to_vec(query)?; self.metrics.record(Operation::Embed, embed_start.elapsed()); - let vector_flat: Vec = vector_tensor.flatten_all()?.to_vec1()?; - let db_start = Instant::now(); let results = self.storage.search(&vector_flat, limit)?; self.metrics.record(Operation::DbSearch, db_start.elapsed()); @@ -584,11 +704,9 @@ impl AgentEngine { let total_start = Instant::now(); let embed_start = Instant::now(); - let vector_tensor = self.brain.embed(query)?; + let vector_flat = self.brain.embed_to_vec(query)?; self.metrics.record(Operation::Embed, embed_start.elapsed()); - let vector_flat: Vec = vector_tensor.flatten_all()?.to_vec1()?; - let db_start = Instant::now(); let category = data_type.map(|dt| dt.as_category()); let results = self.storage.search_filtered(&vector_flat, category, time_filter, limit)?; @@ -680,8 +798,7 @@ impl AgentEngine { /// Classify intent without storing (useful for debugging/preview) pub fn classify(&mut self, text: &str) -> Result { - let vector = self.brain.embed(text)?; - self.brain.classify(&vector) + self.brain.classify_text(text) } // ========================================================================= diff --git a/packages/agent-state-rs/src/storage.rs b/packages/agent-state-rs/src/storage.rs index 1bb3d9e..7410223 100644 --- 
a/packages/agent-state-rs/src/storage.rs +++ b/packages/agent-state-rs/src/storage.rs @@ -230,7 +230,7 @@ impl Storage { /// # Arguments /// * `text` - The content to store /// * `category` - One of: "task", "memory", "preference", "relationship", "event" - /// * `vector` - The 384-dimensional embedding vector + /// * `vector` - The embedding vector (dimension depends on model: 384 for small models, 768 for base models) pub fn save(&mut self, text: &str, category: &str, vector: &[f32]) -> Result { let tx = self.conn.transaction()?; let row_id = Self::save_in_transaction(&tx, text, category, vector)?;