diff --git a/.bumpversion.toml b/.bumpversion.toml index 3af8812f..a18da6d6 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -3,7 +3,7 @@ # https://peps.python.org/pep-0440/ [tool.bumpversion] - current_version = "0.4.2" + current_version = "0.4.3" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. diff --git a/pyproject.toml b/pyproject.toml index eb9c40a4..1c54f975 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,14 +28,14 @@ dependencies = [ "ag-ui-protocol>=0.1.14", - "agentic-mesh-protocol==0.2.3", + "agentic-mesh-protocol==0.2.4", "anyio==4.13.0", "grpcio-health-checking==1.78.0", "grpcio-reflection==1.78.0", "grpcio-status==1.78.0", "pydantic==2.12.5", ] - version = "0.4.2" + version = "0.4.3" [project.optional-dependencies] profiling = [ diff --git a/src/digitalkin/__version__.py b/src/digitalkin/__version__.py index 1d93645d..c17c9532 100644 --- a/src/digitalkin/__version__.py +++ b/src/digitalkin/__version__.py @@ -5,4 +5,4 @@ try: __version__ = version("digitalkin") except PackageNotFoundError: - __version__ = "0.4.2" + __version__ = "0.4.3" diff --git a/src/digitalkin/modules/_base_module.py b/src/digitalkin/modules/_base_module.py index c172f48a..64c82754 100644 --- a/src/digitalkin/modules/_base_module.py +++ b/src/digitalkin/modules/_base_module.py @@ -649,6 +649,10 @@ async def _resolve_tools(self, config_setup_data: SetupModelT) -> None: config_setup_data: Setup data containing tool references. """ logger.debug("Starting tool resolution", extra=self.context.session.current_ids()) + # New setup version: discard any inherited resolved_tools so the live + # tool-module schemas are re-fetched. Mission runs reuse the persisted + # resolved_tools (via build_tool_cache in start()) and never reach here. + config_setup_data.resolved_tools = {} tool_cache = await config_setup_data.build_tool_cache(self.context.registry, self.context.communication) self.context.tool_cache = tool_cache logger.debug( diff --git a/src/digitalkin/services/storage/default_storage.py b/src/digitalkin/services/storage/default_storage.py index 5d630a79..35d0d7ca 100644 --- a/src/digitalkin/services/storage/default_storage.py +++ b/src/digitalkin/services/storage/default_storage.py @@ -62,7 +62,7 @@ def _load_from_file(self) -> dict[str, StorageRecord]: continue data_model = model_cls.model_validate(rd["data"]) rec = StorageRecord( - mission_id=rd["mission_id"], + context=rd["context"], collection=rd["collection"], record_id=rd["record_id"], data=data_model, @@ -93,7 +93,7 @@ def _save_to_file(self) -> None: serial: dict[str, dict] = {} for key, record in self.storage.items(): serial[key] = { - "mission_id": record.mission_id, + "context": record.context, "collection": record.collection, "record_id": record.record_id, "data_type": record.data_type.name, @@ -119,7 +119,7 @@ async def _store(self, record: StorageRecord) -> StorageRecord: Raises: ValueError: If the record already exists """ - key = f"{record.collection}:{record.record_id}" + key = self._key(record.context, record.collection, record.record_id) if key in self.storage: msg = f"Document {key!r} already exists" raise ValueError(msg) @@ -131,31 +131,36 @@ async def _store(self, record: StorageRecord) -> StorageRecord: logger.debug("Created %s", key) return record - async def _read(self, collection: str, record_id: str) -> StorageRecord | None: - """Get records from the database. + @staticmethod + def _key(context: str, collection: str, record_id: str) -> str: + return f"{context}|{collection}:{record_id}" + + async def _read(self, collection: str, record_id: str, context: str) -> StorageRecord | None: + """Get a record from the database scoped to a specific context. Args: collection: The unique name to retrieve data for record_id: The unique ID of the record + context: Owner context scoping the lookup. Returns: StorageRecord: The corresponding record """ - key = f"{collection}:{record_id}" - return self.storage.get(key) + return self.storage.get(self._key(context, collection, record_id)) - async def _update(self, collection: str, record_id: str, data: BaseModel) -> StorageRecord | None: - """Update records in the database and persist to file. + async def _update(self, collection: str, record_id: str, data: BaseModel, context: str) -> StorageRecord | None: + """Update a record in the database scoped to a specific context. Args: collection: The unique name to retrieve data for record_id: The unique ID of the record data: The data to modify + context: Owner context scoping the update. Returns: StorageRecord: The modified record """ - key = f"{collection}:{record_id}" + key = self._key(context, collection, record_id) rec = self.storage.get(key) if not rec: return None @@ -165,17 +170,18 @@ async def _update(self, collection: str, record_id: str, data: BaseModel) -> Sto logger.debug("Modified %s", key) return rec - async def _remove(self, collection: str, record_id: str) -> bool: - """Delete records from the database and update file. + async def _remove(self, collection: str, record_id: str, context: str) -> bool: + """Delete a record from the database scoped to a specific context. Args: collection: The unique name to retrieve data for record_id: The unique ID of the record + context: Owner context scoping the deletion. Returns: bool: True if the record was removed, False otherwise """ - key = f"{collection}:{record_id}" + key = self._key(context, collection, record_id) if key not in self.storage: return False del self.storage[key] @@ -183,28 +189,30 @@ async def _remove(self, collection: str, record_id: str) -> bool: logger.debug("Removed %s", key) return True - async def _list(self, collection: str) -> list[StorageRecord]: - """Implements StorageStrategy._list. + async def _list(self, collection: str, context: str) -> list[StorageRecord]: + """List records in a collection scoped to a specific context. Args: collection: The unique name to retrieve data for + context: Owner context scoping the listing. Returns: A list of storage records """ - prefix = f"{collection}:" + prefix = f"{context}|{collection}:" return [r for k, r in self.storage.items() if k.startswith(prefix)] - async def _remove_collection(self, collection: str) -> bool: - """Implements StorageStrategy._remove_collection. + async def _remove_collection(self, collection: str, context: str) -> bool: + """Wipe a collection scoped to a specific context. Args: collection: The unique name to retrieve data for + context: Owner context scoping the wipe. Returns: bool: True if the collection was removed, False otherwise """ - prefix = f"{collection}:" + prefix = f"{context}|{collection}:" to_delete = [k for k in self.storage if k.startswith(prefix)] for k in to_delete: del self.storage[k] diff --git a/src/digitalkin/services/storage/grpc_storage.py b/src/digitalkin/services/storage/grpc_storage.py index 815aff8f..ae9a3595 100644 --- a/src/digitalkin/services/storage/grpc_storage.py +++ b/src/digitalkin/services/storage/grpc_storage.py @@ -34,7 +34,7 @@ def _build_record_from_proto(self, proto: data_pb2.StorageRecord) -> StorageReco A fully validated StorageRecord. """ # Direct field access for scalars (avoids full MessageToDict overhead) - mission = proto.mission_id + ctx = proto.context coll = proto.collection rid = proto.record_id dtype = DataType[data_pb2.DataType.Name(proto.data_type)] @@ -48,7 +48,7 @@ def _build_record_from_proto(self, proto: data_pb2.StorageRecord) -> StorageReco validated = self._validate_data(coll, payload) return StorageRecord( - mission_id=mission, + context=ctx, collection=coll, record_id=rid, data=validated, @@ -75,7 +75,7 @@ async def _store(self, record: StorageRecord) -> StorageRecord: data_struct.update(record.data.model_dump()) req = data_pb2.StoreRecordRequest( data=data_struct, - mission_id=record.mission_id, + context=record.context, collection=record.collection, record_id=record.record_id, data_type=record.data_type.name, @@ -90,16 +90,16 @@ async def _store(self, record: StorageRecord) -> StorageRecord: ) raise StorageServiceError(str(e)) from e - async def _read(self, collection: str, record_id: str) -> StorageRecord | None: - """Fetch a single document by collection + record_id. + async def _read(self, collection: str, record_id: str, context: str) -> StorageRecord | None: + """Fetch a single document scoped to a specific context. Returns: StorageData: The record """ - logger.debug("debug:_read collection=%s id=%s", collection, record_id) + logger.debug("debug:_read context=%s collection=%s id=%s", context, collection, record_id) try: req = data_pb2.ReadRecordRequest( - mission_id=self.mission_id, + context=context, collection=collection, record_id=record_id, ) @@ -114,24 +114,20 @@ async def _update( collection: str, record_id: str, data: BaseModel, + context: str, ) -> StorageRecord | None: - """Overwrite a document via gRPC. - - Args: - collection: The unique name for the record type - record_id: The unique ID for the record - data: The validated data model + """Overwrite a document via gRPC scoped to a specific context. Returns: - StorageRecord: The updated record + StorageRecord: The updated record, or None on failure. """ - logger.debug("debug:_update collection=%s id=%s", collection, record_id) + logger.debug("debug:_update context=%s collection=%s id=%s", context, collection, record_id) try: struct = Struct() struct.update(data.model_dump()) req = data_pb2.UpdateRecordRequest( data=struct, - mission_id=self.mission_id, + context=context, collection=collection, record_id=record_id, ) @@ -141,20 +137,16 @@ async def _update( logger.warning("gRPC UpdateRecord failed for %s:%s", collection, record_id) return None - async def _remove(self, collection: str, record_id: str) -> bool: - """Delete a document via gRPC. - - Args: - collection: The unique name for the record type - record_id: The unique ID for the record + async def _remove(self, collection: str, record_id: str, context: str) -> bool: + """Delete a document via gRPC scoped to a specific context. Returns: - bool: True if the record was deleted, False otherwise + bool: True if the record was deleted, False otherwise. """ - logger.debug("debug:_remove collection=%s id=%s", collection, record_id) + logger.debug("debug:_remove context=%s collection=%s id=%s", context, collection, record_id) try: req = data_pb2.RemoveRecordRequest( - mission_id=self.mission_id, + context=context, collection=collection, record_id=record_id, ) @@ -168,19 +160,16 @@ async def _remove(self, collection: str, record_id: str) -> bool: return False return True - async def _list(self, collection: str) -> list[StorageRecord]: - """List all documents in a collection via gRPC. - - Args: - collection: The unique name for the record type + async def _list(self, collection: str, context: str) -> list[StorageRecord]: + """List all documents in a collection via gRPC scoped to a specific context. Returns: - list[StorageRecord]: A list of storage records + list[StorageRecord]: The records found, or an empty list on failure. """ - logger.debug("debug:_list collection=%s", collection) + logger.debug("debug:_list context=%s collection=%s", context, collection) try: req = data_pb2.ListRecordsRequest( - mission_id=self.mission_id, + context=context, collection=collection, ) resp = await self.exec_grpc_query("ListRecords", req) @@ -189,18 +178,15 @@ async def _list(self, collection: str) -> list[StorageRecord]: logger.warning("gRPC ListRecords failed for %s", collection) return [] - async def _remove_collection(self, collection: str) -> bool: - """Delete an entire collection via gRPC. - - Args: - collection: The unique name for the record type + async def _remove_collection(self, collection: str, context: str) -> bool: + """Delete an entire collection via gRPC scoped to a specific context. Returns: - bool: True if the collection was deleted, False otherwise + bool: True if the collection was removed, False otherwise. """ try: req = data_pb2.RemoveCollectionRequest( - mission_id=self.mission_id, + context=context, collection=collection, ) await self.exec_grpc_query("RemoveCollection", req) diff --git a/src/digitalkin/services/storage/storage_strategy.py b/src/digitalkin/services/storage/storage_strategy.py index 2b704482..7d454e00 100644 --- a/src/digitalkin/services/storage/storage_strategy.py +++ b/src/digitalkin/services/storage/storage_strategy.py @@ -29,7 +29,7 @@ class DataType(Enum): class StorageRecord(BaseModel): """A single record stored in a collection, with metadata.""" - mission_id: str = Field(..., description="ID of the mission (bucket) this doc belongs to") + context: str = Field(..., description="Owner context (`missions:` or `setup_versions:`)") collection: str = Field(..., description="Logical collection name") record_id: str = Field(..., description="Unique ID of this record in its collection") data_type: DataType = Field(default=DataType.OUTPUT, description="Category of the data of this record") @@ -38,8 +38,25 @@ class StorageRecord(BaseModel): update_date: datetime.datetime | None = Field(default=None, description="When this record was last modified") +Scope = Literal["mission", "setup"] + + class StorageStrategy(BaseStrategy, ABC): - """Define CRUD + list/remove-collection against a collection/record store.""" + """Define CRUD + list/remove-collection against a collection/record store. + + Records are scoped by a `context` string (the proto field), which is either + `self.mission_id` (mission scope, the default) or `self.setup_version_id` + (setup-version scope). Both attributes are expected to already contain the + full prefix (`missions:` / `setup_versions:`). + + Public methods accept `scope: Literal["mission", "setup"]` (default + `"mission"`); internally we resolve it to the matching context string and + pass that to the abstract `_store/_read/_update/_remove/_list/_remove_collection`. + """ + + def _resolve_context(self, scope: Scope) -> str: + """Return the context string for the given scope.""" + return self.mission_id if scope == "mission" else self.setup_version_id def _validate_data(self, collection: str, data: dict[str, Any]) -> BaseModel: """Validate data against the model schema for the given key. @@ -65,52 +82,34 @@ def _validate_data(self, collection: str, data: dict[str, Any]) -> BaseModel: msg = f"Validation failed for '{collection}': {e!s}" raise ValueError(msg) from e + @staticmethod def _create_storage_record( - self, collection: str, record_id: str, validated_data: BaseModel, data_type: DataType, + context: str, ) -> StorageRecord: - """Create a storage record with metadata. + """Create a storage record stamped with the given context. Args: collection: The unique name for the record type record_id: The unique ID for the record validated_data: The validated data model data_type: The type of data + context: Owner context to stamp on the record (mission or setup-version). Returns: A complete storage record with metadata """ return StorageRecord( - mission_id=self.mission_id, + context=context, collection=collection, record_id=record_id, data=validated_data, data_type=data_type, ) - def _verify_mission_id(self, record: StorageRecord) -> bool: - """Check that a record belongs to this strategy's mission. - - Args: - record: The record to verify. - - Returns: - True if the record's mission_id matches, False otherwise. - """ - if record.mission_id != self.mission_id: - logger.warning( - "Mission ID mismatch: expected %s, got %s (collection=%s, record_id=%s)", - self.mission_id, - record.mission_id, - record.collection, - record.record_id, - ) - return False - return True - @staticmethod def _is_valid_data_type_name(value: str) -> TypeGuard[str]: return value in DataType.__members__ @@ -120,66 +119,71 @@ async def _store(self, record: StorageRecord) -> StorageRecord: """Store a new record in the storage. Args: - record: The record to store + record: The record to store (context is encoded in record.context) Returns: The ID of the created record """ @abstractmethod - async def _read(self, collection: str, record_id: str) -> StorageRecord | None: - """Get records from storage by key. + async def _read(self, collection: str, record_id: str, context: str) -> StorageRecord | None: + """Get records from storage scoped to a specific context. Args: collection: The unique name to retrieve data for record_id: The unique ID of the record + context: Owner context (e.g. `missions:` or `setup_versions:`). Returns: A storage record with validated data """ @abstractmethod - async def _update(self, collection: str, record_id: str, data: BaseModel) -> StorageRecord | None: - """Overwrite an existing record's payload. + async def _update(self, collection: str, record_id: str, data: BaseModel, context: str) -> StorageRecord | None: + """Overwrite an existing record's payload scoped to a specific context. Args: collection: The unique name for the record type record_id: The unique ID of the record data: The new data to store + context: Owner context for the record being updated. Returns: StorageRecord: The modified record """ @abstractmethod - async def _remove(self, collection: str, record_id: str) -> bool: - """Delete a record from the storage. + async def _remove(self, collection: str, record_id: str, context: str) -> bool: + """Delete a record from the storage scoped to a specific context. Args: collection: The unique name for the record type record_id: The unique ID of the record + context: Owner context for the record being deleted. Returns: True if the deletion was successful, False otherwise """ @abstractmethod - async def _list(self, collection: str) -> list[StorageRecord]: - """List all records in a collection. + async def _list(self, collection: str, context: str) -> list[StorageRecord]: + """List all records in a collection scoped to a specific context. Args: collection: The unique name for the record type + context: Owner context filter. Returns: A list of storage records """ @abstractmethod - async def _remove_collection(self, collection: str) -> bool: - """Delete all records in a collection. + async def _remove_collection(self, collection: str, context: str) -> bool: + """Delete all records in a collection scoped to a specific context. Args: collection: The unique name for the record type + context: Owner context for which to wipe records. Returns: True if the deletion was successful, False otherwise @@ -195,9 +199,9 @@ def __init__( """Initialize the storage strategy. Args: - mission_id: The ID of the mission this strategy is associated with + mission_id: Already-prefixed mission context (`missions:`). setup_id: The ID of the setup - setup_version_id: The ID of the setup version + setup_version_id: Already-prefixed setup-version context (`setup_versions:`). config: A dictionary mapping names to Pydantic model classes """ super().__init__(mission_id, setup_id, setup_version_id) @@ -205,17 +209,18 @@ def __init__( self.config: dict[str, type[BaseModel]] = config self._record_locks: dict[str, asyncio.Lock] = {} - def _record_lock(self, collection: str, record_id: str) -> asyncio.Lock: - """Get or create an asyncio.Lock for a specific record. + def _record_lock(self, context: str, collection: str, record_id: str) -> asyncio.Lock: + """Get or create an asyncio.Lock for a specific record under a given context. Args: + context: Owner context the record lives under collection: The collection name record_id: The record ID Returns: - An asyncio.Lock scoped to the given collection:record_id pair. + An asyncio.Lock scoped to the given context:collection:record_id triple. """ - return self._record_locks.setdefault(f"{collection}:{record_id}", asyncio.Lock()) + return self._record_locks.setdefault(f"{context}|{collection}:{record_id}", asyncio.Lock()) async def store( self, @@ -223,6 +228,7 @@ async def store( record_id: str | None, data: dict[str, Any], data_type: Literal["OUTPUT", "VIEW", "LOGS", "OTHER"] = "OUTPUT", + scope: Scope = "mission", ) -> StorageRecord: """Store a new record in the storage. @@ -231,6 +237,8 @@ async def store( record_id: The unique ID for the record (optional) data: The data to store data_type: The type of data being stored (default: OUTPUT) + scope: "mission" (default) writes under the current mission context; + "setup" writes under the setup-version context. Returns: The ID of the created record @@ -243,84 +251,94 @@ async def store( raise ValueError(msg) record_id = record_id or uuid4().hex data_type_enum = DataType[data_type] - validated_data = self._validate_data(collection, {**data, "mission_id": self.mission_id}) - record = self._create_storage_record(collection, record_id, validated_data, data_type_enum) - async with self._record_lock(collection, record_id): + context = self._resolve_context(scope) + validated_data = self._validate_data(collection, data) + record = self._create_storage_record(collection, record_id, validated_data, data_type_enum, context) + async with self._record_lock(context, collection, record_id): return await self._store(record) - async def read(self, collection: str, record_id: str) -> StorageRecord | None: - """Get records from storage by key. + async def read(self, collection: str, record_id: str, scope: Scope = "mission") -> StorageRecord | None: + """Get a record by key under the given scope. Args: collection: The unique name to retrieve data for record_id: The unique ID of the record + scope: Which context to read from (default: "mission"). Returns: - A storage record with validated data, or None if not found - or if the record belongs to a different mission. + The matching record if it exists, otherwise None. """ - async with self._record_lock(collection, record_id): - record = await self._read(collection, record_id) - if record is not None and not self._verify_mission_id(record): - return None - return record + context = self._resolve_context(scope) + async with self._record_lock(context, collection, record_id): + return await self._read(collection, record_id, context) - async def update(self, collection: str, record_id: str, data: dict[str, Any]) -> StorageRecord | None: - """Validate & overwrite an existing record. + async def update( + self, + collection: str, + record_id: str, + data: dict[str, Any], + scope: Scope = "mission", + ) -> StorageRecord | None: + """Validate & overwrite an existing record under the given scope. Args: collection: The unique name for the record type record_id: The unique ID of the record data: The new data to store + scope: Which context the record lives under (default: "mission"). Returns: StorageRecord: The modified record """ validated_data = self._validate_data(collection, data) - async with self._record_lock(collection, record_id): - return await self._update(collection, record_id, validated_data) + context = self._resolve_context(scope) + async with self._record_lock(context, collection, record_id): + return await self._update(collection, record_id, validated_data, context) - async def remove(self, collection: str, record_id: str) -> bool: - """Delete a record from the storage. + async def remove(self, collection: str, record_id: str, scope: Scope = "mission") -> bool: + """Delete a record from the storage under the given scope. Args: collection: The unique name for the record type record_id: The unique ID of the record + scope: Which context the record lives under (default: "mission"). Returns: True if the deletion was successful, False otherwise """ - key = f"{collection}:{record_id}" - async with self._record_lock(collection, record_id): - result = await self._remove(collection, record_id) + context = self._resolve_context(scope) + async with self._record_lock(context, collection, record_id): + result = await self._remove(collection, record_id, context) if result: - self._record_locks.pop(key, None) + self._record_locks.pop(f"{context}|{collection}:{record_id}", None) return result - async def list(self, collection: str) -> list[StorageRecord]: - """Get all records within a collection scoped to this mission. + async def list(self, collection: str, scope: Scope = "mission") -> list[StorageRecord]: + """Get all records in a collection under the given scope. Args: collection: The unique name for the record type + scope: Which context to list (default: "mission"). Returns: - A list of storage records belonging to this mission. + A list of storage records under the resolved context. """ - records = await self._list(collection) - return [r for r in records if self._verify_mission_id(r)] + return await self._list(collection, self._resolve_context(scope)) - async def remove_collection(self, collection: str) -> bool: - """Wipe a record clean. + async def remove_collection(self, collection: str, scope: Scope = "mission") -> bool: + """Wipe a collection clean under the given scope. Args: collection: The unique name for the record type + scope: Which context the records live under (default: "mission"). Returns: True if the deletion was successful, False otherwise """ - result = await self._remove_collection(collection) + context = self._resolve_context(scope) + result = await self._remove_collection(collection, context) if result: - prefix = f"{collection}:" + prefix = f"{context}|{collection}:" for key in [k for k in self._record_locks if k.startswith(prefix)]: self._record_locks.pop(key, None) return result @@ -331,18 +349,20 @@ async def upsert( record_id: str, data: dict[str, Any], data_type: Literal["OUTPUT", "VIEW", "LOGS", "OTHER"] = "OUTPUT", + scope: Scope = "mission", ) -> StorageRecord: - """Insert or update a record atomically. + """Insert or update a record atomically under the given scope. - If a record with the given collection/record_id exists, it is updated; - otherwise a new record is created. The operation is protected by a - per-record lock to prevent races. + If a record with the given collection/record_id exists under that + context it is updated; otherwise a new record is created. The operation + is protected by a per-record lock to prevent races. Args: collection: The unique name for the record type record_id: The unique ID for the record data: The data to store data_type: The type of data being stored (default: OUTPUT) + scope: Which context to upsert under (default: "mission"). Returns: The created or updated storage record @@ -355,13 +375,14 @@ async def upsert( msg = f"Invalid data type '{data_type}'. Must be one of {list(DataType.__members__.keys())}" raise ValueError(msg) data_type_enum = DataType[data_type] - validated_data = self._validate_data(collection, {**data, "mission_id": self.mission_id}) - async with self._record_lock(collection, record_id): - if await self._read(collection, record_id): - updated = await self._update(collection, record_id, validated_data) + context = self._resolve_context(scope) + validated_data = self._validate_data(collection, data) + async with self._record_lock(context, collection, record_id): + if await self._read(collection, record_id, context): + updated = await self._update(collection, record_id, validated_data, context) if updated is None: msg = f"Update failed for existing record '{collection}:{record_id}'" raise StorageServiceError(msg) return updated - record = self._create_storage_record(collection, record_id, validated_data, data_type_enum) + record = self._create_storage_record(collection, record_id, validated_data, data_type_enum, context) return await self._store(record) diff --git a/tests/services/storage/mock_storage_servicer.py b/tests/services/storage/mock_storage_servicer.py index 26b0dece..a2f53d92 100644 --- a/tests/services/storage/mock_storage_servicer.py +++ b/tests/services/storage/mock_storage_servicer.py @@ -21,7 +21,7 @@ def __init__(self, schema_config: dict[str, type[BaseModel]] | None = None) -> N schema_config: Dictionary mapping collection names to Pydantic model classes """ super().__init__() - # mission_id -> collection -> record_id -> record_data + # context -> collection -> record_id -> record_data self.records: dict[str, dict[str, dict[str, dict[str, Any]]]] = {} # Schema configuration for validation self.schema_config = schema_config or {} @@ -47,7 +47,7 @@ def _validate_schema(self, collection: str, data: dict[str, Any]) -> None: def _create_proto_record( self, - mission_id: str, + ctx: str, collection: str, record_id: str, record_data: dict[str, Any], @@ -55,7 +55,7 @@ def _create_proto_record( """Convert internal record data to proto StorageRecord. Args: - mission_id: Mission ID + ctx: Owner context (`missions:` or `setup_versions:`) collection: Collection name record_id: Record ID record_data: The record data dictionary @@ -88,7 +88,7 @@ def _create_proto_record( update_ts.FromDatetime(update_dt) return data_pb2.StorageRecord( - mission_id=mission_id, + context=ctx, collection=collection, record_id=record_id, data_type=value, @@ -111,9 +111,9 @@ def StoreRecord( """ try: # Validate required fields - if not request.mission_id: + if not request.context: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) - context.set_details("Mission ID is required") + context.set_details("Context is required") return data_pb2.StoreRecordResponse() if not request.collection: @@ -153,7 +153,7 @@ def StoreRecord( return data_pb2.StoreRecordResponse() # Check if record already exists - mission_records = self.records.setdefault(request.mission_id, {}) + mission_records = self.records.setdefault(request.context, {}) collection_records = mission_records.setdefault(request.collection, {}) if request.record_id in collection_records: @@ -176,10 +176,10 @@ def StoreRecord( # Create response stored_record = self._create_proto_record( - request.mission_id, request.collection, request.record_id, record_data + request.context, request.collection, request.record_id, record_data ) - logger.info(f"Stored record: {request.record_id} in {request.collection} for mission {request.mission_id}") + logger.info(f"Stored record: {request.record_id} in {request.collection} for context {request.context}") return data_pb2.StoreRecordResponse(stored_data=stored_record) except Exception as e: @@ -194,16 +194,16 @@ def ReadRecord( """Read a record from the mock database. Args: - request: ReadRecordRequest containing mission_id, collection, record_id + request: ReadRecordRequest containing context, collection, record_id context: gRPC context Returns: ReadRecordResponse: Response containing the record or empty if not found """ try: - if not request.mission_id: + if not request.context: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) - context.set_details("Mission ID is required") + context.set_details("Context is required") return data_pb2.ReadRecordResponse() if not request.collection: @@ -217,7 +217,7 @@ def ReadRecord( return data_pb2.ReadRecordResponse() # Try to find the record - mission_records = self.records.get(request.mission_id, {}) + mission_records = self.records.get(request.context, {}) collection_records = mission_records.get(request.collection, {}) record_data = collection_records.get(request.record_id) @@ -228,7 +228,7 @@ def ReadRecord( # Create response stored_record = self._create_proto_record( - request.mission_id, request.collection, request.record_id, record_data + request.context, request.collection, request.record_id, record_data ) logger.info(f"Read record: {request.record_id} from {request.collection}") @@ -253,9 +253,9 @@ def UpdateRecord( UpdateRecordResponse: Response containing updated record """ try: - if not request.mission_id: + if not request.context: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) - context.set_details("Mission ID is required") + context.set_details("Context is required") return data_pb2.UpdateRecordResponse() if not request.collection: @@ -269,7 +269,7 @@ def UpdateRecord( return data_pb2.UpdateRecordResponse() # Try to find the record - mission_records = self.records.get(request.mission_id, {}) + mission_records = self.records.get(request.context, {}) collection_records = mission_records.get(request.collection, {}) record_data = collection_records.get(request.record_id) @@ -297,7 +297,7 @@ def UpdateRecord( # Create response stored_record = self._create_proto_record( - request.mission_id, request.collection, request.record_id, record_data + request.context, request.collection, request.record_id, record_data ) logger.info(f"Updated record: {request.record_id} in {request.collection}") @@ -315,16 +315,16 @@ def RemoveRecord( """Remove a record from the mock database. Args: - request: RemoveRecordRequest containing mission_id, collection, record_id + request: RemoveRecordRequest containing context, collection, record_id context: gRPC context Returns: RemoveRecordResponse: Empty response """ try: - if not request.mission_id: + if not request.context: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) - context.set_details("Mission ID is required") + context.set_details("Context is required") return data_pb2.RemoveRecordResponse() if not request.collection: @@ -338,7 +338,7 @@ def RemoveRecord( return data_pb2.RemoveRecordResponse() # Try to find and remove the record - mission_records = self.records.get(request.mission_id, {}) + mission_records = self.records.get(request.context, {}) collection_records = mission_records.get(request.collection, {}) if request.record_id not in collection_records: @@ -363,16 +363,16 @@ def ListRecords( """List all records in a collection. Args: - request: ListRecordsRequest containing mission_id and collection + request: ListRecordsRequest containing context and collection context: gRPC context Returns: ListRecordsResponse: Response containing list of records """ try: - if not request.mission_id: + if not request.context: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) - context.set_details("Mission ID is required") + context.set_details("Context is required") return data_pb2.ListRecordsResponse(records=[]) if not request.collection: @@ -381,13 +381,13 @@ def ListRecords( return data_pb2.ListRecordsResponse(records=[]) # Get all records in the collection - mission_records = self.records.get(request.mission_id, {}) + mission_records = self.records.get(request.context, {}) collection_records = mission_records.get(request.collection, {}) # Convert to proto records proto_records = [] for record_id, record_data in collection_records.items(): - proto_record = self._create_proto_record(request.mission_id, request.collection, record_id, record_data) + proto_record = self._create_proto_record(request.context, request.collection, record_id, record_data) proto_records.append(proto_record) logger.info(f"Listed {len(proto_records)} records from {request.collection}") @@ -405,16 +405,16 @@ def RemoveCollection( """Remove all records in a collection. Args: - request: RemoveCollectionRequest containing mission_id and collection + request: RemoveCollectionRequest containing context and collection context: gRPC context Returns: RemoveCollectionResponse: Empty response """ try: - if not request.mission_id: + if not request.context: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) - context.set_details("Mission ID is required") + context.set_details("Context is required") return data_pb2.RemoveCollectionResponse() if not request.collection: @@ -423,7 +423,7 @@ def RemoveCollection( return data_pb2.RemoveCollectionResponse() # Remove the entire collection - mission_records = self.records.get(request.mission_id, {}) + mission_records = self.records.get(request.context, {}) if request.collection in mission_records: del mission_records[request.collection] logger.info(f"Removed collection: {request.collection}") diff --git a/tests/services/storage/test_grpc_storage.py b/tests/services/storage/test_grpc_storage.py index e485dfe1..11f8e8cf 100644 --- a/tests/services/storage/test_grpc_storage.py +++ b/tests/services/storage/test_grpc_storage.py @@ -15,12 +15,12 @@ import pytest from agentic_mesh_protocol.storage.v1 import data_pb2, storage_service_pb2, storage_service_pb2_grpc from pydantic import BaseModel, Field +from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext +from tests.services.storage.mock_storage_servicer import MockStorageServicer from digitalkin.models.grpc_servers.models import ClientConfig from digitalkin.services.storage.grpc_storage import GrpcStorage from digitalkin.services.storage.storage_strategy import DataType, StorageServiceError -from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext -from tests.services.storage.mock_storage_servicer import MockStorageServicer # Set timeout for all tests in this file (20 seconds) pytestmark = pytest.mark.timeout(20) @@ -118,8 +118,7 @@ def dummy_client_config() -> ClientConfig: Returns: ClientConfig instance with test values """ - from digitalkin.models.settings.utils.channel import SecurityMode - from digitalkin.models.settings.utils.channel import ControlFlow + from digitalkin.models.settings.utils.channel import ControlFlow, SecurityMode return ClientConfig( host="localhost", @@ -194,7 +193,7 @@ def test_store_record_success( _, request, rpc = test_channel.take_unary_unary(method_desc) # Verify request - assert request.mission_id == MISSION_ID + assert request.context == MISSION_ID assert request.collection == collection assert request.record_id == record_id # data_type is now a protobuf enum integer value @@ -215,7 +214,7 @@ def test_store_record_success( # Verify result assert result is not None - assert result.mission_id == MISSION_ID + assert result.context == MISSION_ID assert result.collection == collection assert result.record_id == record_id assert result.data_type == DataType.OUTPUT @@ -525,7 +524,7 @@ def test_read_record_success( read_future = thread_pool.submit(asyncio.run, client.read(collection, record_id)) _, read_request, read_rpc = test_channel.take_unary_unary(read_method_desc) - assert read_request.mission_id == MISSION_ID + assert read_request.context == MISSION_ID assert read_request.collection == collection assert read_request.record_id == record_id @@ -693,7 +692,7 @@ def test_update_record_success( update_future = thread_pool.submit(asyncio.run, client.update(collection, record_id, updated_data)) _, update_request, update_rpc = test_channel.take_unary_unary(update_method_desc) - assert update_request.mission_id == MISSION_ID + assert update_request.context == MISSION_ID assert update_request.collection == collection assert update_request.record_id == record_id @@ -819,7 +818,7 @@ def test_remove_record_success( remove_future = thread_pool.submit(asyncio.run, client.remove(collection, record_id)) _, remove_request, remove_rpc = test_channel.take_unary_unary(remove_method_desc) - assert remove_request.mission_id == MISSION_ID + assert remove_request.context == MISSION_ID assert remove_request.collection == collection assert remove_request.record_id == record_id @@ -983,7 +982,7 @@ def test_remove_collection_success( remove_future = thread_pool.submit(asyncio.run, client.remove_collection(collection)) _, remove_request, remove_rpc = test_channel.take_unary_unary(remove_coll_method_desc) - assert remove_request.mission_id == MISSION_ID + assert remove_request.context == MISSION_ID assert remove_request.collection == collection remove_context = FakeContext() @@ -1159,7 +1158,7 @@ def test_list_records_success( list_future = thread_pool.submit(asyncio.run, client.list(collection)) _, list_request, list_rpc = test_channel.take_unary_unary(list_method_desc) - assert list_request.mission_id == MISSION_ID + assert list_request.context == MISSION_ID assert list_request.collection == collection list_context = FakeContext() diff --git a/tests/services/storage/test_storage_strategy_locks.py b/tests/services/storage/test_storage_strategy_locks.py index b3730783..40dac94b 100644 --- a/tests/services/storage/test_storage_strategy_locks.py +++ b/tests/services/storage/test_storage_strategy_locks.py @@ -1,7 +1,5 @@ """Tests for StorageStrategy lock creation and cleanup.""" -import asyncio - import pytest from pydantic import BaseModel, Field @@ -11,7 +9,6 @@ class _SimpleModel(BaseModel): """Minimal model for lock tests.""" - mission_id: str = Field(default="m1") value: str = Field(default="v") @@ -19,38 +16,48 @@ class _InMemoryStorage(StorageStrategy): """Minimal concrete StorageStrategy backed by a dict.""" def __init__(self) -> None: - super().__init__("m1", "s1", "sv1", {"items": _SimpleModel}) + super().__init__("missions:m1", "s1", "setup_versions:sv1", {"items": _SimpleModel}) self._store_data: dict[str, StorageRecord] = {} + @staticmethod + def _key(context: str, collection: str, record_id: str) -> str: + return f"{context}|{collection}:{record_id}" + async def _store(self, record: StorageRecord) -> StorageRecord: - self._store_data[f"{record.collection}:{record.record_id}"] = record + self._store_data[self._key(record.context, record.collection, record.record_id)] = record return record - async def _read(self, collection: str, record_id: str) -> StorageRecord | None: - return self._store_data.get(f"{collection}:{record_id}") + async def _read(self, collection: str, record_id: str, context: str) -> StorageRecord | None: + return self._store_data.get(self._key(context, collection, record_id)) - async def _update(self, collection: str, record_id: str, data: BaseModel) -> StorageRecord | None: - key = f"{collection}:{record_id}" + async def _update( + self, collection: str, record_id: str, data: BaseModel, context: str + ) -> StorageRecord | None: + key = self._key(context, collection, record_id) rec = self._store_data.get(key) if rec is None: return None rec.data = data return rec - async def _remove(self, collection: str, record_id: str) -> bool: - return self._store_data.pop(f"{collection}:{record_id}", None) is not None + async def _remove(self, collection: str, record_id: str, context: str) -> bool: + return self._store_data.pop(self._key(context, collection, record_id), None) is not None - async def _list(self, collection: str) -> list[StorageRecord]: - return [r for k, r in self._store_data.items() if k.startswith(f"{collection}:")] + async def _list(self, collection: str, context: str) -> list[StorageRecord]: + prefix = f"{context}|{collection}:" + return [r for k, r in self._store_data.items() if k.startswith(prefix)] - async def _remove_collection(self, collection: str) -> bool: - prefix = f"{collection}:" + async def _remove_collection(self, collection: str, context: str) -> bool: + prefix = f"{context}|{collection}:" keys = [k for k in self._store_data if k.startswith(prefix)] for k in keys: del self._store_data[k] return bool(keys) +_MISSION_LOCK_PREFIX = "missions:m1|items:" + + class TestRecordLockAtomicity: """Tests for atomic lock creation via setdefault.""" @@ -58,16 +65,24 @@ class TestRecordLockAtomicity: async def test_record_lock_returns_same_instance(self) -> None: """Consecutive calls for same key return the same Lock object.""" storage = _InMemoryStorage() - lock1 = storage._record_lock("items", "r1") - lock2 = storage._record_lock("items", "r1") + lock1 = storage._record_lock("missions:m1", "items", "r1") + lock2 = storage._record_lock("missions:m1", "items", "r1") assert lock1 is lock2 @pytest.mark.asyncio async def test_different_keys_get_different_locks(self) -> None: """Different collection:record_id pairs get independent locks.""" storage = _InMemoryStorage() - lock1 = storage._record_lock("items", "r1") - lock2 = storage._record_lock("items", "r2") + lock1 = storage._record_lock("missions:m1", "items", "r1") + lock2 = storage._record_lock("missions:m1", "items", "r2") + assert lock1 is not lock2 + + @pytest.mark.asyncio + async def test_different_contexts_get_different_locks(self) -> None: + """Same collection:record_id under different contexts get independent locks.""" + storage = _InMemoryStorage() + lock1 = storage._record_lock("missions:m1", "items", "r1") + lock2 = storage._record_lock("setup_versions:sv1", "items", "r1") assert lock1 is not lock2 @@ -79,25 +94,25 @@ async def test_remove_cleans_up_lock(self) -> None: """Removing a record also removes its lock entry.""" storage = _InMemoryStorage() await storage.store("items", "r1", {"value": "x"}) - assert "items:r1" in storage._record_locks + assert f"{_MISSION_LOCK_PREFIX}r1" in storage._record_locks result = await storage.remove("items", "r1") assert result is True - assert "items:r1" not in storage._record_locks + assert f"{_MISSION_LOCK_PREFIX}r1" not in storage._record_locks @pytest.mark.asyncio async def test_remove_nonexistent_keeps_lock(self) -> None: """Removing a nonexistent record does not remove the lock.""" storage = _InMemoryStorage() # Create lock by accessing it - storage._record_lock("items", "r1") - assert "items:r1" in storage._record_locks + storage._record_lock("missions:m1", "items", "r1") + assert f"{_MISSION_LOCK_PREFIX}r1" in storage._record_locks result = await storage.remove("items", "r1") assert result is False - assert "items:r1" in storage._record_locks + assert f"{_MISSION_LOCK_PREFIX}r1" in storage._record_locks @pytest.mark.asyncio async def test_remove_collection_cleans_up_locks(self) -> None: @@ -106,12 +121,12 @@ async def test_remove_collection_cleans_up_locks(self) -> None: await storage.store("items", "r1", {"value": "a"}) await storage.store("items", "r2", {"value": "b"}) await storage.store("items", "r3", {"value": "c"}) - assert len([k for k in storage._record_locks if k.startswith("items:")]) == 3 + assert len([k for k in storage._record_locks if k.startswith(_MISSION_LOCK_PREFIX)]) == 3 result = await storage.remove_collection("items") assert result is True - assert not any(k.startswith("items:") for k in storage._record_locks) + assert not any(k.startswith(_MISSION_LOCK_PREFIX) for k in storage._record_locks) @pytest.mark.asyncio async def test_remove_collection_preserves_other_collection_locks(self) -> None: @@ -123,5 +138,5 @@ async def test_remove_collection_preserves_other_collection_locks(self) -> None: await storage.remove_collection("items") - assert "items:r1" not in storage._record_locks - assert "other:r1" in storage._record_locks + assert f"{_MISSION_LOCK_PREFIX}r1" not in storage._record_locks + assert "missions:m1|other:r1" in storage._record_locks diff --git a/uv.lock b/uv.lock index c9f89002..fc18cbc7 100644 --- a/uv.lock +++ b/uv.lock @@ -23,7 +23,7 @@ wheels = [ [[package]] name = "agentic-mesh-protocol" -version = "0.2.3" +version = "0.2.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "bump-my-version" }, @@ -33,9 +33,9 @@ dependencies = [ { name = "protobuf" }, { name = "protovalidate" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4e/e7/8679acaa44b01bbc858275d3fa262e420d61f2a40598200e728924d1d247/agentic_mesh_protocol-0.2.3.tar.gz", hash = "sha256:a542f476d61b4d5acd3f03e7318cbd837ff016be996c1f80ad120222be2d1d95", size = 78843, upload-time = "2026-03-04T16:26:32.86Z" } +sdist = { url = "https://files.pythonhosted.org/packages/49/cf/35df606a8bdea46441ed8832ee1330860bd7d4e7724b9308a9ba3430babd/agentic_mesh_protocol-0.2.4.tar.gz", hash = "sha256:ee856bd5c891875418162af8251f4dd2df7ce3f50b713d038a77a4369ead1115", size = 79521, upload-time = "2026-05-06T15:50:49.989Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/41/d487e2505531797aba03129187d1b4247499f37a57288a7b2b53d18050c0/agentic_mesh_protocol-0.2.3-py3-none-any.whl", hash = "sha256:ced7c0e4ca2e71ae02cfd8c908c7259165f8d4d98dad85a3e7b86ea0d701b1a9", size = 118882, upload-time = "2026-03-04T16:26:31.258Z" }, + { url = "https://files.pythonhosted.org/packages/86/89/89cc28b35ffa6a6a68f22cc6236125b811aa1813d13abd7a32a76afee4e9/agentic_mesh_protocol-0.2.4-py3-none-any.whl", hash = "sha256:fece3e8293d0b74674453735fa6b13323043cc621fe8208e188d5f6bacd863fa", size = 119660, upload-time = "2026-05-06T15:50:48.18Z" }, ] [[package]] @@ -850,7 +850,7 @@ wheels = [ [[package]] name = "digitalkin" -version = "0.4.1.dev5" +version = "0.4.2" source = { editable = "." } dependencies = [ { name = "ag-ui-protocol" }, @@ -931,7 +931,7 @@ tests = [ [package.metadata] requires-dist = [ { name = "ag-ui-protocol", specifier = ">=0.1.14" }, - { name = "agentic-mesh-protocol", specifier = "==0.2.3" }, + { name = "agentic-mesh-protocol", specifier = "==0.2.4" }, { name = "anyio", specifier = "==4.13.0" }, { name = "asyncio-inspector", marker = "extra == 'profiling'", specifier = "==0.1.0" }, { name = "grpcio-health-checking", specifier = "==1.78.0" },