diff --git a/Cargo.lock b/Cargo.lock index 71e74fb..3471623 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -135,6 +135,35 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-test" +version = "18.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce2a8627e8d8851f894696b39f2b67807d6375c177361d376173ace306a21e2" +dependencies = [ + "anyhow", + "axum", + "bytes", + "bytesize", + "cookie", + "expect-json", + "http 1.4.0", + "http-body-util", + "hyper", + "hyper-util", + "mime", + "pretty_assertions", + "reserve-port", + "rust-multipart-rfc7578_2", + "serde", + "serde_json", + "serde_urlencoded", + "smallvec", + "tokio", + "tower", + "url", +] + [[package]] name = "backtrace" version = "0.3.76" @@ -237,6 +266,12 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "bytesize" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd91ee7b2422bcb158d90ef4d14f75ef67f340943fc4149891dcce8f8b972a3" + [[package]] name = "bzip2-sys" version = "0.1.13+1.0.8" @@ -354,6 +389,16 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "time", + "version_check", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -449,6 +494,21 @@ dependencies = [ "uuid", ] +[[package]] +name = "deranged" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + [[package]] name = "digest" version = "0.10.7" @@ -482,6 +542,15 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "email_address" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449" +dependencies = [ + "serde", +] + [[package]] name = "encoding_rs" version = "0.8.35" @@ -497,6 +566,17 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "erased-serde" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2add8a07dd6a8d93ff627029c51de145e12686fbc36ecb298ac22e74cf02dec" +dependencies = [ + "serde", + "serde_core", + "typeid", +] + [[package]] name = "errno" version = "0.3.14" @@ -507,6 +587,35 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "expect-json" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "869f97f4abe8e78fc812a94ad6b721d72c4fb5532877c79610f2c238d7ccf6c4" +dependencies = [ + "chrono", + "email_address", + "expect-json-macros", + "num", + "regex", + "serde", + "serde_json", + "thiserror", + "typetag", + "uuid", +] + +[[package]] +name = "expect-json-macros" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0637949cd816934f3b7aab44ff98e7ec1fb903c379e07dcb9eac943ec33499e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "eyre" version = "0.6.12" @@ -779,11 +888,13 @@ version = "0.1.0" dependencies = [ "api", "axum", + "axum-test", "defs", "index", "serde", "serde_json", "storage", + "tempfile", "tokio", "tracing", ] @@ -1089,6 +1200,15 @@ dependencies = [ "serde_core", ] +[[package]] +name = "inventory" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4f0c30c76f2f4ccee3fe55a2435f691ca00c0e4bd87abe4f4a851b1d4dac39b" +dependencies = [ + "rustversion", +] + [[package]] name = "ipnet" version = "2.12.0" @@ -1388,6 +1508,76 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-conv" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521739c6d2bac4aa25192232afe6841231376b2b26d4d9fae5ecf8ca5772e441" + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -1567,6 +1757,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -1576,6 +1772,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "pretty_assertions" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" +dependencies = [ + "diff", + "yansi", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -1828,6 +2034,15 @@ dependencies = [ "web-sys", ] +[[package]] +name = "reserve-port" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94070964579245eb2f76e62a7668fe87bd9969ed6c41256f3bf614e3323dd3cc" +dependencies = [ + "thiserror", +] + [[package]] name = "ring" version = "0.17.14" @@ -1852,6 +2067,21 @@ dependencies = [ "librocksdb-sys", ] +[[package]] +name = "rust-multipart-rfc7578_2" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c839d037155ebc06a571e305af66ff9fd9063a6e662447051737e1ac75beea41" +dependencies = [ + "bytes", + "futures-core", + "futures-util", + "http 1.4.0", + "mime", + "rand", + "thiserror", +] + [[package]] name = "rustc-demangle" version = "0.1.27" @@ -2329,6 +2559,26 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.9" @@ -2338,6 +2588,37 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "time" +version = "0.3.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde_core", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" + +[[package]] +name = "time-macros" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tinystr" version = "0.8.2" @@ -2630,12 +2911,42 @@ dependencies = [ "uuid", ] +[[package]] +name = "typeid" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc7d623258602320d5c55d1bc22793b57daff0ec7efc270ea7d55ce1d5f5471c" + [[package]] name = "typenum" version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +[[package]] +name = "typetag" +version = "0.2.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5a897b12c6c1151ad0b138b8db50252dc301f93bc3b027db05eec82aeed298c" +dependencies = [ + "erased-serde", + "inventory", + "once_cell", + "serde", + "typetag-impl", +] + +[[package]] +name = "typetag-impl" +version = "0.2.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf808357c6ed7e13ba0f3277ec8d8f21b2d501274895104263985330c726c1c5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicase" version = "2.9.0" @@ -3207,6 +3518,12 @@ dependencies = [ "rustix", ] +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yoke" version = "0.8.1" diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index 2cdad92..5a94828 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -10,8 +10,7 @@ use std::sync::{Arc, RwLock}; use index::flat::index::FlatIndex; use index::{IndexType, VectorIndex}; use snapshot::Snapshot; -use storage::rocks_db::RocksDbStorage; -use storage::{StorageEngine, StorageType, VectorPage}; +use storage::{StorageEngine, StorageType, VectorPage, create_storage_engine}; use uuid::Uuid; @@ -252,10 +251,7 @@ pub fn restore_from_snapshot(config: &DbRestoreConfig) -> Result Result { // Initialize the storage engine - let storage = match config.storage_type { - StorageType::RocksDb => Arc::new(RocksDbStorage::new(config.data_path)?), - _ => Arc::new(RocksDbStorage::new(config.data_path)?), - }; + let storage = create_storage_engine(config.storage_type, config.data_path)?; // Initialize the vector index let index: Arc> = match config.index_type { @@ -290,9 +286,13 @@ mod tests { // Helper function to create a test database fn create_test_db() -> (VectorDb, TempDir) { + create_test_db_with_storage(StorageType::RocksDb) + } + + fn create_test_db_with_storage(storage_type: StorageType) -> (VectorDb, TempDir) { let temp_dir = tempdir().unwrap(); let config = DbConfig { - storage_type: StorageType::RocksDb, + storage_type, index_type: IndexType::Flat, data_path: temp_dir.path().to_path_buf(), dimension: 3, @@ -301,6 +301,13 @@ mod tests { (init_api(config).unwrap(), temp_dir) } + fn test_payload(content: &str) -> Payload { + Payload { + content_type: ContentType::Text, + content: content.to_string(), + } + } + #[test] fn test_insert_and_get() { let (db, _temp_dir) = create_test_db(); @@ -326,6 +333,20 @@ mod tests { assert_eq!(point.payload.as_ref().unwrap().content, "Test content"); } + #[test] + fn test_insert_and_get_with_in_memory_storage() { + let (db, _temp_dir) = create_test_db_with_storage(StorageType::InMemory); + let vector = vec![1.0, 2.0, 3.0]; + let payload = test_payload("Test content"); + + let id = db.insert(vector.clone(), payload.clone()).unwrap(); + let point = db.get(id).unwrap().unwrap(); + + assert_eq!(point.id, id); + assert_eq!(point.vector, Some(vector)); + assert_eq!(point.payload, Some(payload)); + } + #[test] fn test_dimension_mismatch() { let (db, _temp_dir) = create_test_db(); @@ -593,6 +614,34 @@ mod tests { assert!(loaded_db.get(id2).unwrap().unwrap().vector.unwrap() == v2); } + #[test] + fn test_create_and_load_snapshot_with_in_memory_storage() { + let (old_db, temp_dir) = create_test_db_with_storage(StorageType::InMemory); + + let v1 = vec![0.0, 1.0, 2.0]; + let v2 = vec![3.0, 4.0, 5.0]; + let v3 = vec![6.0, 7.0, 8.0]; + + let id1 = old_db.insert(v1.clone(), test_payload("one")).unwrap(); + let id2 = old_db.insert(v2.clone(), test_payload("two")).unwrap(); + + let temp_snapshot_dir = tempdir().unwrap(); + let snapshot_path = old_db.create_snapshot(temp_snapshot_dir.path()).unwrap(); + + let id3 = old_db.insert(v3, test_payload("three")).unwrap(); + + let reload_config = DbRestoreConfig { + data_path: temp_dir.path().to_path_buf(), + snapshot_path, + }; + + let loaded_db = restore_from_snapshot(&reload_config).unwrap(); + + assert_eq!(loaded_db.get(id1).unwrap().unwrap().vector, Some(v1)); + assert_eq!(loaded_db.get(id2).unwrap().unwrap().vector, Some(v2)); + assert!(loaded_db.get(id3).unwrap().is_none()); + } + #[test] fn test_snapshot_engine() { let (_db, _temp_dir) = create_test_db(); diff --git a/crates/grpc/src/error.rs b/crates/grpc/src/error.rs index c8bd1fd..faf02b4 100644 --- a/crates/grpc/src/error.rs +++ b/crates/grpc/src/error.rs @@ -135,6 +135,15 @@ impl From for GrpcError { StorageError::RocksDbFlush { source: _ } => GrpcError::Internal { message: "flush error".to_string(), }, + StorageError::InMemoryLock {} => GrpcError::Internal { + message: "failed to lock in-memory storage".to_string(), + }, + StorageError::InMemoryCheckpoint { msg } => GrpcError::Internal { + message: format!("in-memory checkpoint error: {}", msg), + }, + StorageError::InMemoryCheckpointIo { msg, source: _ } => GrpcError::Internal { + message: format!("in-memory checkpoint io error: {}", msg), + }, } } } diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index 4f3595e..fb25fad 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -16,3 +16,7 @@ serde_json.workspace = true storage.workspace = true tokio.workspace = true tracing.workspace = true + +[dev-dependencies] +axum-test.workspace = true +tempfile.workspace = true diff --git a/crates/http/src/handler.rs b/crates/http/src/handler.rs index 67c726e..387d167 100644 --- a/crates/http/src/handler.rs +++ b/crates/http/src/handler.rs @@ -149,7 +149,10 @@ fn api_error_to_response(err: &ApiError) -> (StatusCode, String) { | StorageError::RocksDbFlush { .. } | StorageError::RocksDbInitialization { .. } | StorageError::RocksDbCheckpointMsg { .. } - | StorageError::RocksDbCheckpointIo { .. } => { + | StorageError::RocksDbCheckpointIo { .. } + | StorageError::InMemoryLock { .. } + | StorageError::InMemoryCheckpoint { .. } + | StorageError::InMemoryCheckpointIo { .. } => { (StatusCode::INTERNAL_SERVER_ERROR, source.to_string()) } }, diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index 544a28a..450137b 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -46,3 +46,67 @@ pub async fn run_http_server(db: Arc, addr: SocketAddr) -> Result<(), axum::serve(listener, app.into_make_service()).await?; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use api::DbConfig; + use axum::http::StatusCode; + use axum_test::TestServer; + use defs::Similarity; + use index::IndexType; + use serde_json::json; + use storage::StorageType; + + #[tokio::test] + async fn in_memory_storage_http_smoke_test() { + let temp_dir = tempfile::tempdir().unwrap(); + let db = api::init_api(DbConfig { + storage_type: StorageType::InMemory, + index_type: IndexType::Flat, + data_path: temp_dir.path().to_path_buf(), + dimension: 3, + similarity: Similarity::Cosine, + }) + .unwrap(); + let server = TestServer::new(create_router(Arc::new(db))).unwrap(); + + let insert_response = server + .post("/points") + .json(&json!({ + "vector": [1.0, 0.0, 0.0], + "payload": { + "content_type": "Text", + "content": "smoke-test" + } + })) + .await; + insert_response.assert_status(StatusCode::CREATED); + let insert_body: serde_json::Value = insert_response.json(); + let point_id = insert_body["point_id"].as_str().unwrap(); + + let get_response = server.get(&format!("/points/{point_id}")).await; + get_response.assert_status_ok(); + let point_body: serde_json::Value = get_response.json(); + assert_eq!(point_body["payload"]["content"], "smoke-test"); + assert_eq!(point_body["vector"], json!([1.0, 0.0, 0.0])); + + let search_response = server + .post("/points/search") + .json(&json!({ + "vector": [1.0, 0.0, 0.0], + "similarity": "Cosine", + "limit": 1 + })) + .await; + search_response.assert_status_ok(); + let search_body: serde_json::Value = search_response.json(); + assert_eq!(search_body["results"], json!([point_id])); + + let delete_response = server.delete(&format!("/points/{point_id}")).await; + delete_response.assert_status(StatusCode::NO_CONTENT); + + let missing_response = server.get(&format!("/points/{point_id}")).await; + missing_response.assert_status(StatusCode::NOT_FOUND); + } +} diff --git a/crates/snapshot/src/lib.rs b/crates/snapshot/src/lib.rs index 51b26fc..7324a75 100644 --- a/crates/snapshot/src/lib.rs +++ b/crates/snapshot/src/lib.rs @@ -26,7 +26,8 @@ use std::{ time::SystemTime, }; use storage::{ - StorageEngine, StorageType, checkpoint::StorageCheckpoint, rocks_db::RocksDbStorage, + StorageEngine, StorageType, checkpoint::StorageCheckpoint, in_memory::MemoryStorage, + rocks_db::RocksDbStorage, }; use tar::Archive; use tempfile::tempdir; @@ -201,17 +202,12 @@ impl Snapshot { )); } - // only rocksdb is supported for snapshots as of now let mut storage_engine: Box = match manifest.storage_type { + StorageType::InMemory => Box::new(MemoryStorage::new()), StorageType::RocksDb => Box::new( RocksDbStorage::new(storage_data_path) .map_err(|e| DbError::StorageError(format!("Could not open storage: {e}")))?, ), - _ => { - return Err(DbError::SnapshotError( - "Unsupported storage type".to_string(), - )); - } }; let id = manifest.id; diff --git a/crates/storage/src/checkpoint.rs b/crates/storage/src/checkpoint.rs index 09827fd..fc4be08 100644 --- a/crates/storage/src/checkpoint.rs +++ b/crates/storage/src/checkpoint.rs @@ -35,6 +35,7 @@ impl StorageCheckpoint { .0; let storage_type = match marker { + INMEMORY_CHECKPOINT_FILENAME_MARKER => StorageType::InMemory, ROCKSDB_CHECKPOINT_FILENAME_MARKER => StorageType::RocksDb, _ => { return Err(DbError::StorageCheckpointError( diff --git a/crates/storage/src/error.rs b/crates/storage/src/error.rs index d54f67e..70971a7 100644 --- a/crates/storage/src/error.rs +++ b/crates/storage/src/error.rs @@ -37,6 +37,15 @@ pub enum StorageError { #[snafu(display("Failed to iterate over storage: {source}"))] RocksDbIteration { source: rocksdb::Error }, + #[snafu(display("Failed to lock in-memory storage"))] + InMemoryLock {}, + + #[snafu(display("In-memory checkpoint error: {}", msg))] + InMemoryCheckpoint { msg: String }, + + #[snafu(display("{} : {}", msg, source))] + InMemoryCheckpointIo { msg: String, source: std::io::Error }, + #[snafu(display("Failed to serialize point {id}: {source}"))] Serialization { id: PointId, source: bincode::Error }, diff --git a/crates/storage/src/in_memory.rs b/crates/storage/src/in_memory.rs index 096d9ed..a619bb0 100644 --- a/crates/storage/src/in_memory.rs +++ b/crates/storage/src/in_memory.rs @@ -1,18 +1,26 @@ use crate::StorageType; use crate::error::StorageError; use crate::{StorageEngine, VectorPage, checkpoint::StorageCheckpoint}; -use defs::{DenseVector, Payload, PointId}; -use std::path::{Path, PathBuf}; +use bincode::{deserialize_from, serialize_into}; +use defs::{DenseVector, Payload, Point, PointId}; +use std::collections::BTreeMap; +use std::fs::File; +use std::ops::Bound::{Excluded, Unbounded}; +use std::path::Path; +use std::sync::RwLock; pub const INMEMORY_CHECKPOINT_FILENAME_MARKER: &str = "inmemory"; +const INMEMORY_CHECKPOINT_EXTENSION: &str = "bin"; pub struct MemoryStorage { - // define here how MemoryStorage will be defined + points: RwLock>, } impl MemoryStorage { pub fn new() -> Self { - MemoryStorage {} + MemoryStorage { + points: RwLock::new(BTreeMap::new()), + } } } @@ -25,40 +33,288 @@ impl Default for MemoryStorage { impl StorageEngine for MemoryStorage { fn insert_point( &self, - _id: PointId, - _vector: Option, - _payload: Option, + id: PointId, + vector: Option, + payload: Option, ) -> Result<(), StorageError> { + let mut points = self + .points + .write() + .map_err(|_| StorageError::InMemoryLock {})?; + points.insert( + id, + Point { + id, + vector, + payload, + }, + ); Ok(()) } - fn contains_point(&self, _id: PointId) -> Result { - Ok(true) + fn contains_point(&self, id: PointId) -> Result { + let points = self + .points + .read() + .map_err(|_| StorageError::InMemoryLock {})?; + Ok(points.contains_key(&id)) } - fn delete_point(&self, _id: PointId) -> Result<(), StorageError> { + fn delete_point(&self, id: PointId) -> Result<(), StorageError> { + let mut points = self + .points + .write() + .map_err(|_| StorageError::InMemoryLock {})?; + points.remove(&id); Ok(()) } - fn get_payload(&self, _id: PointId) -> Result, StorageError> { - Ok(None) + fn get_payload(&self, id: PointId) -> Result, StorageError> { + let points = self + .points + .read() + .map_err(|_| StorageError::InMemoryLock {})?; + Ok(points.get(&id).and_then(|point| point.payload.clone())) } - fn get_vector(&self, _id: PointId) -> Result, StorageError> { - Ok(None) + fn get_vector(&self, id: PointId) -> Result, StorageError> { + let points = self + .points + .read() + .map_err(|_| StorageError::InMemoryLock {})?; + Ok(points.get(&id).and_then(|point| point.vector.clone())) } fn list_vectors( &self, - _offset: PointId, - _limit: usize, + offset: PointId, + limit: usize, ) -> Result, StorageError> { - Ok(None) + if limit < 1 { + return Ok(None); + } + + let points = self + .points + .read() + .map_err(|_| StorageError::InMemoryLock {})?; + let mut result = Vec::with_capacity(limit); + let mut last_id = offset; + + for (id, point) in points.range((Excluded(offset), Unbounded)) { + if let Some(vector) = &point.vector { + last_id = *id; + result.push((*id, vector.clone())); + if result.len() == limit { + break; + } + } + } + + Ok(Some((result, last_id))) } - fn checkpoint_at(&self, _path: &Path) -> Result { + fn checkpoint_at(&self, path: &Path) -> Result { + let checkpoint_filename = format!( + "{}-{}.{}", + INMEMORY_CHECKPOINT_FILENAME_MARKER, + uuid::Uuid::new_v4(), + INMEMORY_CHECKPOINT_EXTENSION + ); + let checkpoint_path = path.join(checkpoint_filename); + let file = File::create(&checkpoint_path).map_err(|source| { + StorageError::InMemoryCheckpointIo { + msg: "Couldn't create in-memory checkpoint".to_string(), + source, + } + })?; + let points = self + .points + .read() + .map_err(|_| StorageError::InMemoryLock {})?; + serialize_into(file, &*points).map_err(|source| StorageError::Serialization { + id: PointId::nil(), + source, + })?; + Ok(StorageCheckpoint { - path: PathBuf::default(), + path: checkpoint_path, storage_type: StorageType::InMemory, }) } - fn restore_checkpoint(&mut self, _checkpoint: &StorageCheckpoint) -> Result<(), StorageError> { + fn restore_checkpoint(&mut self, checkpoint: &StorageCheckpoint) -> Result<(), StorageError> { + if checkpoint.storage_type != StorageType::InMemory { + return Err(StorageError::InMemoryCheckpoint { + msg: "Invalid storage type".to_string(), + }); + } + + let checkpoint_filename = checkpoint + .path + .file_name() + .ok_or_else(|| StorageError::InMemoryCheckpoint { + msg: "Could not read checkpoint filename".to_string(), + })? + .to_str() + .ok_or_else(|| StorageError::InMemoryCheckpoint { + msg: "Checkpoint filename is not valid UTF-8".to_string(), + })?; + if !checkpoint_filename.starts_with(INMEMORY_CHECKPOINT_FILENAME_MARKER) + || checkpoint.path.extension().and_then(|ext| ext.to_str()) + != Some(INMEMORY_CHECKPOINT_EXTENSION) + { + return Err(StorageError::InMemoryCheckpoint { + msg: "Invalid file name".to_string(), + }); + } + + let file = + File::open(&checkpoint.path).map_err(|source| StorageError::InMemoryCheckpointIo { + msg: "Couldn't open in-memory checkpoint".to_string(), + source, + })?; + let restored_points: BTreeMap = + deserialize_from(file).map_err(|source| StorageError::Deserialization { + id: PointId::nil(), + source, + })?; + let mut points = self + .points + .write() + .map_err(|_| StorageError::InMemoryLock {})?; + *points = restored_points; Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use defs::ContentType; + use tempfile::{TempDir, tempdir}; + use uuid::Uuid; + + fn create_test_storage() -> MemoryStorage { + MemoryStorage::new() + } + + fn test_payload(content: &str) -> Payload { + Payload { + content_type: ContentType::Text, + content: content.to_string(), + } + } + + #[test] + fn test_insert_and_get_vector() { + let storage = create_test_storage(); + let id = Uuid::new_v4(); + let vector = Some(vec![0.1, 0.2, 0.3]); + let payload = Some(test_payload("Test")); + + storage.insert_point(id, vector.clone(), payload).unwrap(); + + assert_eq!(storage.get_vector(id).unwrap(), vector); + } + + #[test] + fn test_insert_and_get_payload() { + let storage = create_test_storage(); + let id = Uuid::new_v4(); + let payload = Some(test_payload("Test")); + + storage.insert_point(id, None, payload.clone()).unwrap(); + + assert_eq!(storage.get_payload(id).unwrap(), payload); + } + + #[test] + fn test_contains_and_delete_point() { + let storage = create_test_storage(); + let id = Uuid::new_v4(); + + assert!(!storage.contains_point(id).unwrap()); + + storage + .insert_point(id, Some(vec![0.4, 0.5, 0.6]), Some(test_payload("Test"))) + .unwrap(); + assert!(storage.contains_point(id).unwrap()); + + storage.delete_point(id).unwrap(); + assert!(!storage.contains_point(id).unwrap()); + assert_eq!(storage.get_vector(id).unwrap(), None); + assert_eq!(storage.get_payload(id).unwrap(), None); + } + + #[test] + fn test_list_vectors_respects_offset_limit_and_skips_payload_only_points() { + let storage = create_test_storage(); + let ids = [ + Uuid::from_u128(1), + Uuid::from_u128(2), + Uuid::from_u128(3), + Uuid::from_u128(4), + ]; + + storage + .insert_point(ids[0], Some(vec![1.0, 1.1]), Some(test_payload("one"))) + .unwrap(); + storage + .insert_point(ids[1], None, Some(test_payload("payload-only"))) + .unwrap(); + storage + .insert_point(ids[2], Some(vec![3.0, 3.1]), Some(test_payload("three"))) + .unwrap(); + storage + .insert_point(ids[3], Some(vec![4.0, 4.1]), Some(test_payload("four"))) + .unwrap(); + + let (first_page, next_offset) = storage.list_vectors(Uuid::nil(), 2).unwrap().unwrap(); + assert_eq!( + first_page, + vec![(ids[0], vec![1.0, 1.1]), (ids[2], vec![3.0, 3.1])] + ); + assert_eq!(next_offset, ids[2]); + + let (second_page, next_offset) = storage.list_vectors(next_offset, 2).unwrap().unwrap(); + assert_eq!(second_page, vec![(ids[3], vec![4.0, 4.1])]); + assert_eq!(next_offset, ids[3]); + } + + #[test] + fn test_list_vectors_with_zero_limit_returns_none() { + let storage = create_test_storage(); + + assert_eq!(storage.list_vectors(Uuid::nil(), 0).unwrap(), None); + } + + #[test] + fn test_create_and_restore_checkpoint() { + let mut storage = create_test_storage(); + let temp_dir: TempDir = tempdir().unwrap(); + let id_before_checkpoint = Uuid::new_v4(); + let id_after_checkpoint = Uuid::new_v4(); + + storage + .insert_point( + id_before_checkpoint, + Some(vec![0.1, 0.2, 0.3]), + Some(test_payload("before")), + ) + .unwrap(); + let checkpoint = storage.checkpoint_at(temp_dir.path()).unwrap(); + + storage + .insert_point( + id_after_checkpoint, + Some(vec![0.4, 0.5, 0.6]), + Some(test_payload("after")), + ) + .unwrap(); + + storage.restore_checkpoint(&checkpoint).unwrap(); + + assert!(storage.contains_point(id_before_checkpoint).unwrap()); + assert!(!storage.contains_point(id_after_checkpoint).unwrap()); + assert_eq!( + storage.get_payload(id_before_checkpoint).unwrap(), + Some(test_payload("before")) + ); + } +}