From a472c2a967721813f94e7ed4ec21b688574177a3 Mon Sep 17 00:00:00 2001 From: r0ny123 <49360849+r0ny123@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:21:00 +0000 Subject: [PATCH 1/6] Add allow_self_return flag to createMatchingCache - Updated StorageInterface.createMatchingCache signature to include allow_self_return. - Modified MemoryStorage to return self when allow_self_return=True. - Implemented addFunctionEntriesToCache in MemoryStorage for MatchingCache compatibility. - Updated MemoryStorage.getMinHashByFunctionId to handle query functions. - Updated MongoDbStorage.createMatchingCache signature for consistency. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- mcrit/storage/MemoryStorage.py | 17 +++++++++++++++-- mcrit/storage/MongoDbStorage.py | 2 +- mcrit/storage/StorageInterface.py | 3 ++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/mcrit/storage/MemoryStorage.py b/mcrit/storage/MemoryStorage.py index 7de102d..f3d067d 100644 --- a/mcrit/storage/MemoryStorage.py +++ b/mcrit/storage/MemoryStorage.py @@ -483,11 +483,21 @@ def addMinHashes(self, minhashes: List["MinHash"]) -> None: for minhash in minhashes: self.addMinHash(minhash) - def createMatchingCache(self, function_ids: List[int]) -> MatchingCache: - # TODO: we might want add a flag to allow/disallow returning self + def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> Union["MemoryStorage", "MatchingCache"]: + if allow_self_return: + return self cache_data = self._getCacheDataForFunctionIds(function_ids) return MatchingCache(cache_data) + def addFunctionEntriesToCache(self, function_entries: List["FunctionEntry"]) -> None: + for function_entry in function_entries: + if function_entry.function_id < 0: + self._query_functions[function_entry.function_id] = function_entry + else: + self._functions[function_entry.function_id] = function_entry + if function_entry.function_id not in self._sample_id_to_function_ids[function_entry.sample_id]: + self._sample_id_to_function_ids[function_entry.sample_id].append(function_entry.function_id) + def _getCacheDataForFunctionIds(self, function_ids: List[int]) -> Dict: cache_data = {} sample_ids = {} @@ -678,6 +688,9 @@ def getMinHashByFunctionId(self, function_id: int) -> Optional[bytes]: if function_id in self._functions: function_entry = self._functions[function_id] return function_entry.minhash + if function_id in self._query_functions: + function_entry = self._query_functions[function_id] + return function_entry.minhash return None # -> Dict[function_id, Set[function_id]] diff --git a/mcrit/storage/MongoDbStorage.py b/mcrit/storage/MongoDbStorage.py index 976ffb1..82270ef 100644 --- a/mcrit/storage/MongoDbStorage.py +++ b/mcrit/storage/MongoDbStorage.py @@ -1011,7 +1011,7 @@ def _getFunctionDocument( self._encodeFunction(function_dict) return function_dict - def createMatchingCache(self, function_ids: List[int]) -> MatchingCache: + def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> MatchingCache: cache_data = self._getCacheDataForFunctionIds(function_ids) # TODO dont store this as attribute self._matching_cache = MatchingCache(cache_data) diff --git a/mcrit/storage/StorageInterface.py b/mcrit/storage/StorageInterface.py index d90069b..82e4171 100644 --- a/mcrit/storage/StorageInterface.py +++ b/mcrit/storage/StorageInterface.py @@ -367,11 +367,12 @@ def clearMatchingCache(self) -> None: # TODO: make a MatchingCacheInterface, or MemoryStorage a subclass of MatchingCache? # TODO rename -> get? - def createMatchingCache(self, function_ids: List[int]) -> Union["MemoryStorage", "MatchingCache"]: + def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> Union["MemoryStorage", "MatchingCache"]: """Creates a temporary matching cache, for a list of function_ids Args: function_ids: list of function ids + allow_self_return: (optional) if True, allows returning self as cache Returns: a matching cache for the specified list of function ids From f30f3d39bfb94534c661ca8b7dd045f856981c5b Mon Sep 17 00:00:00 2001 From: r0ny123 <49360849+r0ny123@users.noreply.github.com> Date: Mon, 6 Apr 2026 08:23:11 +0000 Subject: [PATCH 2/6] Add allow_self_return flag to createMatchingCache and ensure compatibility - Updated StorageInterface, MemoryStorage, and MongoDbStorage. - Implemented addFunctionEntriesToCache in MemoryStorage. - Handled query functions in MemoryStorage.getMinHashByFunctionId. - Addressed PR feedback regarding MatchingCache compatibility. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- mcrit/storage/MemoryStorage.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mcrit/storage/MemoryStorage.py b/mcrit/storage/MemoryStorage.py index f3d067d..c0d84de 100644 --- a/mcrit/storage/MemoryStorage.py +++ b/mcrit/storage/MemoryStorage.py @@ -537,7 +537,7 @@ def getFamilyIds(self) -> List[int]: return deepcopy(list(self._families.keys())) def isSampleId(self, sample_id: int) -> bool: - return sample_id in self._samples or sample_id in self._query_samples + return sample_id in self._samples or sample_id in self._query_samples or sample_id in self._sample_id_to_function_ids def isFunctionId(self, function_id: int) -> bool: return function_id in self._functions or function_id in self._query_functions @@ -609,10 +609,9 @@ def getFunctionsBySampleId(self, sample_id: int) -> Optional[List["FunctionEntry return None def getFunctionIdsBySampleId(self, sample_id: int) -> Optional[List["int"]]: - function_ids = None - if sample_id in self._samples or sample_id in self._query_samples: - function_ids = self._sample_id_to_function_ids[sample_id] - return function_ids + if sample_id in self._sample_id_to_function_ids: + return self._sample_id_to_function_ids[sample_id] + return None def getFunctions(self, start_index: int, limit: int) -> Optional["FunctionEntry"]: index = 0 From 5d2732d3ab4807a4708dac344454bb71f05fb3ac Mon Sep 17 00:00:00 2001 From: r0ny123 <49360849+r0ny123@users.noreply.github.com> Date: Tue, 7 Apr 2026 00:32:59 +0530 Subject: [PATCH 3/6] Fix matching cache isolation in MemoryStorage --- mcrit/matchers/MatcherInterface.py | 3 +- mcrit/matchers/MatcherQuery.py | 2 +- mcrit/matchers/MatcherQueryFunction.py | 2 +- mcrit/storage/MatchingCache.py | 41 +++++++++++++++++++++++--- mcrit/storage/MemoryStorage.py | 26 ++++++---------- mcrit/storage/StorageInterface.py | 7 ++--- tests/testStorage.py | 17 +++++++++++ 7 files changed, 69 insertions(+), 29 deletions(-) diff --git a/mcrit/matchers/MatcherInterface.py b/mcrit/matchers/MatcherInterface.py index 65316d9..cc9d400 100644 --- a/mcrit/matchers/MatcherInterface.py +++ b/mcrit/matchers/MatcherInterface.py @@ -15,7 +15,6 @@ from mcrit.storage.FunctionEntry import FunctionEntry from mcrit.storage.SampleEntry import SampleEntry from mcrit.storage.MatchingCache import MatchingCache - from mcrit.storage.MemoryStorage import MemoryStorage from mcrit.Worker import Worker @@ -245,7 +244,7 @@ def _countPackedTuples(self, candidate_pairs) -> int: return quotient + int(bool(remainder)) # always round up def _unrollGroupsAsPackedTuples( - self, cache: Union["MatchingCache", "MemoryStorage"], candidate_pairs + self, cache: "MatchingCache", candidate_pairs ) -> Iterable[List[Tuple[int, int, bytes, int, int, bytes]]]: # Query, VS, Sample # All were identical diff --git a/mcrit/matchers/MatcherQuery.py b/mcrit/matchers/MatcherQuery.py index c79e6d5..7460a28 100644 --- a/mcrit/matchers/MatcherQuery.py +++ b/mcrit/matchers/MatcherQuery.py @@ -57,6 +57,6 @@ def _createMatchingCache(self, candidate_groups): for function_id in other_function_ids: if function_id >= 0: function_ids_from_storage.add(function_id) - cache = self._storage.createMatchingCache(function_ids_from_storage) + cache = self._storage.createMatchingCache(function_ids_from_storage, allow_self_return=True) cache.addFunctionEntriesToCache(self._function_entries) return cache diff --git a/mcrit/matchers/MatcherQueryFunction.py b/mcrit/matchers/MatcherQueryFunction.py index 95bf723..e5e5c0a 100644 --- a/mcrit/matchers/MatcherQueryFunction.py +++ b/mcrit/matchers/MatcherQueryFunction.py @@ -59,6 +59,6 @@ def _createMatchingCache(self, candidate_groups): for function_id in other_function_ids: if function_id >= 0: function_ids_from_storage.add(function_id) - cache = self._storage.createMatchingCache(function_ids_from_storage) + cache = self._storage.createMatchingCache(function_ids_from_storage, allow_self_return=True) cache.addFunctionEntriesToCache(self._function_entries) return cache diff --git a/mcrit/storage/MatchingCache.py b/mcrit/storage/MatchingCache.py index e4fa012..7cc4236 100644 --- a/mcrit/storage/MatchingCache.py +++ b/mcrit/storage/MatchingCache.py @@ -6,6 +6,19 @@ def __init__(self, cache_data): self._func_id_to_sample_id = cache_data["func_id_to_sample_id"] self._sample_id_to_func_ids = cache_data["sample_id_to_func_ids"] + def _setFunctionEntry(self, function_id, sample_id, minhash): + if function_id in self._func_id_to_sample_id: + old_sample_id = self._func_id_to_sample_id[function_id] + if old_sample_id in self._sample_id_to_func_ids: + self._sample_id_to_func_ids[old_sample_id].discard(function_id) + if not self._sample_id_to_func_ids[old_sample_id]: + del self._sample_id_to_func_ids[old_sample_id] + self._func_id_to_minhash[function_id] = minhash + self._func_id_to_sample_id[function_id] = sample_id + if sample_id not in self._sample_id_to_func_ids: + self._sample_id_to_func_ids[sample_id] = set() + self._sample_id_to_func_ids[sample_id].add(function_id) + def isSampleId(self, sample_id): return sample_id in self._sample_id_to_func_ids @@ -20,9 +33,29 @@ def getFunctionIdsBySampleId(self, sample_id): def addFunctionEntriesToCache(self, function_entries): for function_entry in function_entries: - self._func_id_to_minhash[function_entry.function_id] = function_entry.minhash - self._func_id_to_sample_id[function_entry.function_id] = function_entry.sample_id - sample_id = function_entry.sample_id + self._setFunctionEntry(function_entry.function_id, function_entry.sample_id, function_entry.minhash) + + +class StorageBackedMatchingCache(MatchingCache): + """A matching cache that reuses storage-backed data without mutating the underlying storage.""" + + def __init__(self, storage, function_ids): + self._storage = storage + self._func_id_to_minhash = {} + self._func_id_to_sample_id = {} + self._sample_id_to_func_ids = {} + for function_id in function_ids: + sample_id = self._storage.getSampleIdByFunctionId(function_id) + if sample_id is None: + raise KeyError(function_id) + self._func_id_to_sample_id[function_id] = sample_id if sample_id not in self._sample_id_to_func_ids: self._sample_id_to_func_ids[sample_id] = set() - self._sample_id_to_func_ids[sample_id].add(function_entry.function_id) + self._sample_id_to_func_ids[sample_id].add(function_id) + + def getMinHashByFunctionId(self, function_id): + if function_id in self._func_id_to_minhash: + return self._func_id_to_minhash[function_id] + if function_id not in self._func_id_to_sample_id: + raise KeyError(function_id) + return self._storage.getMinHashByFunctionId(function_id) diff --git a/mcrit/storage/MemoryStorage.py b/mcrit/storage/MemoryStorage.py index c0d84de..f21319b 100644 --- a/mcrit/storage/MemoryStorage.py +++ b/mcrit/storage/MemoryStorage.py @@ -17,7 +17,7 @@ from mcrit.storage.FamilyEntry import FamilyEntry from mcrit.storage.FunctionEntry import FunctionEntry from mcrit.storage.FunctionLabelEntry import FunctionLabelEntry -from mcrit.storage.MatchingCache import MatchingCache +from mcrit.storage.MatchingCache import MatchingCache, StorageBackedMatchingCache from mcrit.storage.SampleEntry import SampleEntry from mcrit.storage.StorageInterface import StorageInterface @@ -483,21 +483,12 @@ def addMinHashes(self, minhashes: List["MinHash"]) -> None: for minhash in minhashes: self.addMinHash(minhash) - def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> Union["MemoryStorage", "MatchingCache"]: + def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> MatchingCache: if allow_self_return: - return self + return StorageBackedMatchingCache(self, function_ids) cache_data = self._getCacheDataForFunctionIds(function_ids) return MatchingCache(cache_data) - def addFunctionEntriesToCache(self, function_entries: List["FunctionEntry"]) -> None: - for function_entry in function_entries: - if function_entry.function_id < 0: - self._query_functions[function_entry.function_id] = function_entry - else: - self._functions[function_entry.function_id] = function_entry - if function_entry.function_id not in self._sample_id_to_function_ids[function_entry.sample_id]: - self._sample_id_to_function_ids[function_entry.sample_id].append(function_entry.function_id) - def _getCacheDataForFunctionIds(self, function_ids: List[int]) -> Dict: cache_data = {} sample_ids = {} @@ -537,7 +528,7 @@ def getFamilyIds(self) -> List[int]: return deepcopy(list(self._families.keys())) def isSampleId(self, sample_id: int) -> bool: - return sample_id in self._samples or sample_id in self._query_samples or sample_id in self._sample_id_to_function_ids + return sample_id in self._samples or sample_id in self._query_samples def isFunctionId(self, function_id: int) -> bool: return function_id in self._functions or function_id in self._query_functions @@ -609,9 +600,10 @@ def getFunctionsBySampleId(self, sample_id: int) -> Optional[List["FunctionEntry return None def getFunctionIdsBySampleId(self, sample_id: int) -> Optional[List["int"]]: - if sample_id in self._sample_id_to_function_ids: - return self._sample_id_to_function_ids[sample_id] - return None + function_ids = None + if sample_id in self._samples or sample_id in self._query_samples: + function_ids = self._sample_id_to_function_ids[sample_id] + return function_ids def getFunctions(self, start_index: int, limit: int) -> Optional["FunctionEntry"]: index = 0 @@ -962,4 +954,4 @@ def findFunctionByString(self, search_tree: NodeType, cursor: Optional[FullSearc result_dict[entry.function_id] = entry if len(result_dict) >= max_num_results: break - return result_dict \ No newline at end of file + return result_dict diff --git a/mcrit/storage/StorageInterface.py b/mcrit/storage/StorageInterface.py index 82e4171..736f22e 100644 --- a/mcrit/storage/StorageInterface.py +++ b/mcrit/storage/StorageInterface.py @@ -10,7 +10,6 @@ from mcrit.config.MinHashConfig import MinHashConfig from mcrit.storage.FunctionEntry import FunctionEntry from mcrit.storage.MatchingCache import MatchingCache - from mcrit.storage.MemoryStorage import MemoryStorage from mcrit.storage.SampleEntry import SampleEntry from smda.common.SmdaFunction import SmdaFunction from smda.common.SmdaReport import SmdaReport @@ -365,14 +364,14 @@ def clearMatchingCache(self) -> None: """ raise NotImplementedError - # TODO: make a MatchingCacheInterface, or MemoryStorage a subclass of MatchingCache? + # TODO: make a MatchingCacheInterface for all backends. # TODO rename -> get? - def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> Union["MemoryStorage", "MatchingCache"]: + def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> "MatchingCache": """Creates a temporary matching cache, for a list of function_ids Args: function_ids: list of function ids - allow_self_return: (optional) if True, allows returning self as cache + allow_self_return: (optional) if True, allows a backend-specific optimized cache implementation Returns: a matching cache for the specified list of function ids diff --git a/tests/testStorage.py b/tests/testStorage.py index 5f4b51e..cf0f4e8 100644 --- a/tests/testStorage.py +++ b/tests/testStorage.py @@ -302,6 +302,23 @@ def testMatchingCache(self): self.assertTrue(hasattr(cache, "getMinHashByFunctionId")) self.assertTrue(hasattr(cache, "getSampleIdByFunctionId")) + def testMatchingCacheAllowSelfReturnDoesNotMutateStorage(self): + self.storage.clearStorage() + smda_report = SmdaReport.fromFile(self.example_file_path) + self.storage.addSmdaReport(smda_report) + + cache = self.storage.createMatchingCache([0], allow_self_return=True) + cache_only_entry = self.storage.getFunctionById(0, with_xcfg=True) + cache_only_entry.sample_id = 999 + cache_only_entry.minhash = b"cache-only-minhash" + cache.addFunctionEntriesToCache([cache_only_entry]) + + self.assertEqual(0, self.storage.getSampleIdByFunctionId(0)) + self.assertNotEqual(b"cache-only-minhash", self.storage.getMinHashByFunctionId(0)) + self.assertEqual(999, cache.getSampleIdByFunctionId(0)) + self.assertEqual(b"cache-only-minhash", cache.getMinHashByFunctionId(0)) + self.assertEqual(set([0]), cache.getFunctionIdsBySampleId(999)) + ### Added mongo attribute import pytest From 380a6743894edfd8e6be3ae651d2b5ea0a9defd7 Mon Sep 17 00:00:00 2001 From: Rony <49360849+r0ny123@users.noreply.github.com> Date: Tue, 7 Apr 2026 00:43:46 +0530 Subject: [PATCH 4/6] Update mcrit/storage/MatchingCache.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- mcrit/storage/MatchingCache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcrit/storage/MatchingCache.py b/mcrit/storage/MatchingCache.py index 7cc4236..4c2503e 100644 --- a/mcrit/storage/MatchingCache.py +++ b/mcrit/storage/MatchingCache.py @@ -44,7 +44,7 @@ def __init__(self, storage, function_ids): self._func_id_to_minhash = {} self._func_id_to_sample_id = {} self._sample_id_to_func_ids = {} - for function_id in function_ids: + for function_id in set(function_ids): sample_id = self._storage.getSampleIdByFunctionId(function_id) if sample_id is None: raise KeyError(function_id) From cd6499af780ae958b72bc9aa30ce8207aaac8318 Mon Sep 17 00:00:00 2001 From: r0ny123 <49360849+r0ny123@users.noreply.github.com> Date: Tue, 7 Apr 2026 00:51:01 +0530 Subject: [PATCH 5/6] Optimize matching cache lookups --- mcrit/storage/MatchingCache.py | 12 +++---- mcrit/storage/MemoryStorage.py | 13 ++++++-- mcrit/storage/MongoDbStorage.py | 54 +++++++++++++++++++++++-------- mcrit/storage/StorageInterface.py | 11 +++++++ tests/testStorage.py | 17 ++++++++++ 5 files changed, 86 insertions(+), 21 deletions(-) diff --git a/mcrit/storage/MatchingCache.py b/mcrit/storage/MatchingCache.py index 4c2503e..7c6940d 100644 --- a/mcrit/storage/MatchingCache.py +++ b/mcrit/storage/MatchingCache.py @@ -42,13 +42,13 @@ class StorageBackedMatchingCache(MatchingCache): def __init__(self, storage, function_ids): self._storage = storage self._func_id_to_minhash = {} - self._func_id_to_sample_id = {} + unique_function_ids = set(function_ids) + self._func_id_to_sample_id = dict(self._storage.getSampleIdsByFunctionIds(list(unique_function_ids))) self._sample_id_to_func_ids = {} - for function_id in set(function_ids): - sample_id = self._storage.getSampleIdByFunctionId(function_id) - if sample_id is None: - raise KeyError(function_id) - self._func_id_to_sample_id[function_id] = sample_id + missing_function_ids = unique_function_ids.difference(self._func_id_to_sample_id) + if missing_function_ids: + raise KeyError(missing_function_ids.pop()) + for function_id, sample_id in self._func_id_to_sample_id.items(): if sample_id not in self._sample_id_to_func_ids: self._sample_id_to_func_ids[sample_id] = set() self._sample_id_to_func_ids[sample_id].add(function_id) diff --git a/mcrit/storage/MemoryStorage.py b/mcrit/storage/MemoryStorage.py index f21319b..726da7a 100644 --- a/mcrit/storage/MemoryStorage.py +++ b/mcrit/storage/MemoryStorage.py @@ -494,8 +494,8 @@ def _getCacheDataForFunctionIds(self, function_ids: List[int]) -> Dict: sample_ids = {} sample_to_func_ids = {} minhashes = {} - for function_id in function_ids: - function_entry = self._functions[function_id] + for function_id in set(function_ids): + function_entry = self._query_functions[function_id] if function_id < 0 else self._functions[function_id] function_id = function_entry.function_id sample_id = function_entry.sample_id minhashes[function_id] = function_entry.minhash @@ -638,6 +638,15 @@ def getSampleIdByFunctionId(self, function_id: int) -> Optional[int]: sample_id = self._query_functions[function_id].sample_id return sample_id + def getSampleIdsByFunctionIds(self, function_ids: List[int]) -> Dict[int, int]: + sample_ids = {} + for function_id in set(function_ids): + if function_id in self._functions: + sample_ids[function_id] = self._functions[function_id].sample_id + elif function_id in self._query_functions: + sample_ids[function_id] = self._query_functions[function_id].sample_id + return sample_ids + def getSamples(self, start_index: int, limit: int) -> Optional["SampleEntry"]: index = 0 sample_entries = [] diff --git a/mcrit/storage/MongoDbStorage.py b/mcrit/storage/MongoDbStorage.py index 82270ef..494afd6 100644 --- a/mcrit/storage/MongoDbStorage.py +++ b/mcrit/storage/MongoDbStorage.py @@ -401,6 +401,25 @@ def getSampleIdByFunctionId(self, function_id: int) -> Optional[int]: return None return function_document["sample_id"] + def getSampleIdsByFunctionIds(self, function_ids: List[int]) -> Dict[int, int]: + sample_ids = {} + positive_function_ids = [function_id for function_id in set(function_ids) if function_id >= 0] + negative_function_ids = [function_id for function_id in set(function_ids) if function_id < 0] + for collection_name, collection_ids in ( + ("functions", positive_function_ids), + ("query_functions", negative_function_ids), + ): + if not collection_ids: + continue + for sliced_ids in zip_longest(*[iter(collection_ids)] * 500000): + query_function_ids = [function_id for function_id in sliced_ids if function_id is not None] + for function_document in self._getDb()[collection_name].find( + {"function_id": {"$in": query_function_ids}}, + {"_id": 0, "function_id": 1, "sample_id": 1}, + ): + sample_ids[function_document["function_id"]] = function_document["sample_id"] + return sample_ids + def deleteSample(self, sample_id: int) -> bool: sample_entry = self.getSampleById(sample_id) if sample_entry is None: @@ -924,19 +943,28 @@ def _getCacheDataForFunctionIds(self, function_ids: List[int]) -> Dict: sample_to_func_ids = {} minhashes = {} # process this in batches as the number of function_ids can be exceedingly large, pushing beyond Mongo's 16M limit - for sliced_ids in zip_longest(*[iter(function_ids)]*500000): - query_function_ids = [fid for fid in sliced_ids if fid is not None] - for function_document in self._getDb().functions.find( - {"function_id": {"$in": list(query_function_ids)}}, {"_id": 0, "sample_id": 1, "minhash": 1, "function_id": 1} - ): - function_id = function_document["function_id"] - sample_id = function_document["sample_id"] - minhash = bytes.fromhex(function_document["minhash"]) - minhashes[function_id] = minhash - sample_ids[function_id] = sample_id - if sample_id not in sample_to_func_ids: - sample_to_func_ids[sample_id] = set() - sample_to_func_ids[sample_id].add(function_id) + positive_function_ids = [function_id for function_id in set(function_ids) if function_id >= 0] + negative_function_ids = [function_id for function_id in set(function_ids) if function_id < 0] + for collection_name, collection_ids in ( + ("functions", positive_function_ids), + ("query_functions", negative_function_ids), + ): + if not collection_ids: + continue + for sliced_ids in zip_longest(*[iter(collection_ids)] * 500000): + query_function_ids = [function_id for function_id in sliced_ids if function_id is not None] + for function_document in self._getDb()[collection_name].find( + {"function_id": {"$in": query_function_ids}}, + {"_id": 0, "sample_id": 1, "minhash": 1, "function_id": 1}, + ): + function_id = function_document["function_id"] + sample_id = function_document["sample_id"] + minhash = bytes.fromhex(function_document["minhash"]) + minhashes[function_id] = minhash + sample_ids[function_id] = sample_id + if sample_id not in sample_to_func_ids: + sample_to_func_ids[sample_id] = set() + sample_to_func_ids[sample_id].add(function_id) cache_data["func_id_to_minhash"] = minhashes cache_data["func_id_to_sample_id"] = sample_ids cache_data["sample_id_to_func_ids"] = sample_to_func_ids diff --git a/mcrit/storage/StorageInterface.py b/mcrit/storage/StorageInterface.py index 736f22e..9f566d2 100644 --- a/mcrit/storage/StorageInterface.py +++ b/mcrit/storage/StorageInterface.py @@ -256,6 +256,17 @@ def getSampleIdByFunctionId(self, function_id: int) -> Optional[int]: """ raise NotImplementedError + def getSampleIdsByFunctionIds(self, function_ids: List[int]) -> Dict[int, int]: + """For a given list of function_ids, return the corresponding sample_ids. + + Args: + function_ids: a list of function ids + + Returns: + a dict mapping function_id to sample_id for all function_ids found + """ + raise NotImplementedError + def getSampleById(self, sample_id: int) -> Optional["SampleEntry"]: """Given a sample_id, return the respective SampleEntry or None, if sample_id was not found. diff --git a/tests/testStorage.py b/tests/testStorage.py index cf0f4e8..e229803 100644 --- a/tests/testStorage.py +++ b/tests/testStorage.py @@ -319,6 +319,23 @@ def testMatchingCacheAllowSelfReturnDoesNotMutateStorage(self): self.assertEqual(b"cache-only-minhash", cache.getMinHashByFunctionId(0)) self.assertEqual(set([0]), cache.getFunctionIdsBySampleId(999)) + def testMatchingCacheSupportsQueryFunctions(self): + self.storage.clearStorage() + smda_report = SmdaReport.fromFile(self.example_file_path) + query_sample = self.storage.addSmdaReport(smda_report, isQuery=True) + assert query_sample is not None + query_function_ids = self.storage.getFunctionIdsBySampleId(query_sample.sample_id) + assert query_function_ids is not None + query_function_id = query_function_ids[0] + query_function = self.storage.getFunctionById(query_function_id) + assert query_function is not None + + for allow_self_return in (False, True): + cache = self.storage.createMatchingCache([query_function_id], allow_self_return=allow_self_return) + self.assertEqual(query_sample.sample_id, cache.getSampleIdByFunctionId(query_function_id)) + self.assertEqual(query_function.minhash, cache.getMinHashByFunctionId(query_function_id)) + self.assertEqual(set([query_function_id]), set(cache.getFunctionIdsBySampleId(query_sample.sample_id))) + ### Added mongo attribute import pytest From 15d7301ce7c6f3ee86a184ee8cca92b562cc5652 Mon Sep 17 00:00:00 2001 From: r0ny123 <49360849+r0ny123@users.noreply.github.com> Date: Tue, 7 Apr 2026 01:08:24 +0530 Subject: [PATCH 6/6] Fix MongoDB function id lookup --- mcrit/storage/MongoDbStorage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mcrit/storage/MongoDbStorage.py b/mcrit/storage/MongoDbStorage.py index 494afd6..a81fcc9 100644 --- a/mcrit/storage/MongoDbStorage.py +++ b/mcrit/storage/MongoDbStorage.py @@ -724,7 +724,7 @@ def getFunctionsBySampleId(self, sample_id: int) -> Optional[List["FunctionEntry functions.append(FunctionEntry.fromDict(f)) return functions - def getFunctionIdsBySampleId(self, sample_id: int) -> Optional[List["FunctionEntry"]]: + def getFunctionIdsBySampleId(self, sample_id: int) -> Optional[List[int]]: function_ids = None if not self.isSampleId(sample_id): return function_ids @@ -734,7 +734,7 @@ def getFunctionIdsBySampleId(self, sample_id: int) -> Optional[List["FunctionEnt else: function_dicts = list(self._getDb().functions.find({"sample_id": sample_id}, {"_id": 0, "function_id": 1})) for f in function_dicts: - function_ids(f["function_ids"]) + function_ids.append(f["function_id"]) return function_ids def getFunctions(self, start_index: int, limit: int) -> Optional["FunctionEntry"]: