diff --git a/mcrit/matchers/MatcherInterface.py b/mcrit/matchers/MatcherInterface.py index 65316d9..cc9d400 100644 --- a/mcrit/matchers/MatcherInterface.py +++ b/mcrit/matchers/MatcherInterface.py @@ -15,7 +15,6 @@ from mcrit.storage.FunctionEntry import FunctionEntry from mcrit.storage.SampleEntry import SampleEntry from mcrit.storage.MatchingCache import MatchingCache - from mcrit.storage.MemoryStorage import MemoryStorage from mcrit.Worker import Worker @@ -245,7 +244,7 @@ def _countPackedTuples(self, candidate_pairs) -> int: return quotient + int(bool(remainder)) # always round up def _unrollGroupsAsPackedTuples( - self, cache: Union["MatchingCache", "MemoryStorage"], candidate_pairs + self, cache: "MatchingCache", candidate_pairs ) -> Iterable[List[Tuple[int, int, bytes, int, int, bytes]]]: # Query, VS, Sample # All were identical diff --git a/mcrit/matchers/MatcherQuery.py b/mcrit/matchers/MatcherQuery.py index c79e6d5..7460a28 100644 --- a/mcrit/matchers/MatcherQuery.py +++ b/mcrit/matchers/MatcherQuery.py @@ -57,6 +57,6 @@ def _createMatchingCache(self, candidate_groups): for function_id in other_function_ids: if function_id >= 0: function_ids_from_storage.add(function_id) - cache = self._storage.createMatchingCache(function_ids_from_storage) + cache = self._storage.createMatchingCache(function_ids_from_storage, allow_self_return=True) cache.addFunctionEntriesToCache(self._function_entries) return cache diff --git a/mcrit/matchers/MatcherQueryFunction.py b/mcrit/matchers/MatcherQueryFunction.py index 95bf723..e5e5c0a 100644 --- a/mcrit/matchers/MatcherQueryFunction.py +++ b/mcrit/matchers/MatcherQueryFunction.py @@ -59,6 +59,6 @@ def _createMatchingCache(self, candidate_groups): for function_id in other_function_ids: if function_id >= 0: function_ids_from_storage.add(function_id) - cache = self._storage.createMatchingCache(function_ids_from_storage) + cache = self._storage.createMatchingCache(function_ids_from_storage, allow_self_return=True) cache.addFunctionEntriesToCache(self._function_entries) return cache diff --git a/mcrit/storage/MatchingCache.py b/mcrit/storage/MatchingCache.py index e4fa012..7c6940d 100644 --- a/mcrit/storage/MatchingCache.py +++ b/mcrit/storage/MatchingCache.py @@ -6,6 +6,19 @@ def __init__(self, cache_data): self._func_id_to_sample_id = cache_data["func_id_to_sample_id"] self._sample_id_to_func_ids = cache_data["sample_id_to_func_ids"] + def _setFunctionEntry(self, function_id, sample_id, minhash): + if function_id in self._func_id_to_sample_id: + old_sample_id = self._func_id_to_sample_id[function_id] + if old_sample_id in self._sample_id_to_func_ids: + self._sample_id_to_func_ids[old_sample_id].discard(function_id) + if not self._sample_id_to_func_ids[old_sample_id]: + del self._sample_id_to_func_ids[old_sample_id] + self._func_id_to_minhash[function_id] = minhash + self._func_id_to_sample_id[function_id] = sample_id + if sample_id not in self._sample_id_to_func_ids: + self._sample_id_to_func_ids[sample_id] = set() + self._sample_id_to_func_ids[sample_id].add(function_id) + def isSampleId(self, sample_id): return sample_id in self._sample_id_to_func_ids @@ -20,9 +33,29 @@ def getFunctionIdsBySampleId(self, sample_id): def addFunctionEntriesToCache(self, function_entries): for function_entry in function_entries: - self._func_id_to_minhash[function_entry.function_id] = function_entry.minhash - self._func_id_to_sample_id[function_entry.function_id] = function_entry.sample_id - sample_id = function_entry.sample_id + self._setFunctionEntry(function_entry.function_id, function_entry.sample_id, function_entry.minhash) + + +class StorageBackedMatchingCache(MatchingCache): + """A matching cache that reuses storage-backed data without mutating the underlying storage.""" + + def __init__(self, storage, function_ids): + self._storage = storage + self._func_id_to_minhash = {} + unique_function_ids = set(function_ids) + self._func_id_to_sample_id = dict(self._storage.getSampleIdsByFunctionIds(list(unique_function_ids))) + self._sample_id_to_func_ids = {} + missing_function_ids = unique_function_ids.difference(self._func_id_to_sample_id) + if missing_function_ids: + raise KeyError(missing_function_ids.pop()) + for function_id, sample_id in self._func_id_to_sample_id.items(): if sample_id not in self._sample_id_to_func_ids: self._sample_id_to_func_ids[sample_id] = set() - self._sample_id_to_func_ids[sample_id].add(function_entry.function_id) + self._sample_id_to_func_ids[sample_id].add(function_id) + + def getMinHashByFunctionId(self, function_id): + if function_id in self._func_id_to_minhash: + return self._func_id_to_minhash[function_id] + if function_id not in self._func_id_to_sample_id: + raise KeyError(function_id) + return self._storage.getMinHashByFunctionId(function_id) diff --git a/mcrit/storage/MemoryStorage.py b/mcrit/storage/MemoryStorage.py index 7de102d..726da7a 100644 --- a/mcrit/storage/MemoryStorage.py +++ b/mcrit/storage/MemoryStorage.py @@ -17,7 +17,7 @@ from mcrit.storage.FamilyEntry import FamilyEntry from mcrit.storage.FunctionEntry import FunctionEntry from mcrit.storage.FunctionLabelEntry import FunctionLabelEntry -from mcrit.storage.MatchingCache import MatchingCache +from mcrit.storage.MatchingCache import MatchingCache, StorageBackedMatchingCache from mcrit.storage.SampleEntry import SampleEntry from mcrit.storage.StorageInterface import StorageInterface @@ -483,8 +483,9 @@ def addMinHashes(self, minhashes: List["MinHash"]) -> None: for minhash in minhashes: self.addMinHash(minhash) - def createMatchingCache(self, function_ids: List[int]) -> MatchingCache: - # TODO: we might want add a flag to allow/disallow returning self + def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> MatchingCache: + if allow_self_return: + return StorageBackedMatchingCache(self, function_ids) cache_data = self._getCacheDataForFunctionIds(function_ids) return MatchingCache(cache_data) @@ -493,8 +494,8 @@ def _getCacheDataForFunctionIds(self, function_ids: List[int]) -> Dict: sample_ids = {} sample_to_func_ids = {} minhashes = {} - for function_id in function_ids: - function_entry = self._functions[function_id] + for function_id in set(function_ids): + function_entry = self._query_functions[function_id] if function_id < 0 else self._functions[function_id] function_id = function_entry.function_id sample_id = function_entry.sample_id minhashes[function_id] = function_entry.minhash @@ -637,6 +638,15 @@ def getSampleIdByFunctionId(self, function_id: int) -> Optional[int]: sample_id = self._query_functions[function_id].sample_id return sample_id + def getSampleIdsByFunctionIds(self, function_ids: List[int]) -> Dict[int, int]: + sample_ids = {} + for function_id in set(function_ids): + if function_id in self._functions: + sample_ids[function_id] = self._functions[function_id].sample_id + elif function_id in self._query_functions: + sample_ids[function_id] = self._query_functions[function_id].sample_id + return sample_ids + def getSamples(self, start_index: int, limit: int) -> Optional["SampleEntry"]: index = 0 sample_entries = [] @@ -678,6 +688,9 @@ def getMinHashByFunctionId(self, function_id: int) -> Optional[bytes]: if function_id in self._functions: function_entry = self._functions[function_id] return function_entry.minhash + if function_id in self._query_functions: + function_entry = self._query_functions[function_id] + return function_entry.minhash return None # -> Dict[function_id, Set[function_id]] @@ -950,4 +963,4 @@ def findFunctionByString(self, search_tree: NodeType, cursor: Optional[FullSearc result_dict[entry.function_id] = entry if len(result_dict) >= max_num_results: break - return result_dict \ No newline at end of file + return result_dict diff --git a/mcrit/storage/MongoDbStorage.py b/mcrit/storage/MongoDbStorage.py index 976ffb1..a81fcc9 100644 --- a/mcrit/storage/MongoDbStorage.py +++ b/mcrit/storage/MongoDbStorage.py @@ -401,6 +401,25 @@ def getSampleIdByFunctionId(self, function_id: int) -> Optional[int]: return None return function_document["sample_id"] + def getSampleIdsByFunctionIds(self, function_ids: List[int]) -> Dict[int, int]: + sample_ids = {} + positive_function_ids = [function_id for function_id in set(function_ids) if function_id >= 0] + negative_function_ids = [function_id for function_id in set(function_ids) if function_id < 0] + for collection_name, collection_ids in ( + ("functions", positive_function_ids), + ("query_functions", negative_function_ids), + ): + if not collection_ids: + continue + for sliced_ids in zip_longest(*[iter(collection_ids)] * 500000): + query_function_ids = [function_id for function_id in sliced_ids if function_id is not None] + for function_document in self._getDb()[collection_name].find( + {"function_id": {"$in": query_function_ids}}, + {"_id": 0, "function_id": 1, "sample_id": 1}, + ): + sample_ids[function_document["function_id"]] = function_document["sample_id"] + return sample_ids + def deleteSample(self, sample_id: int) -> bool: sample_entry = self.getSampleById(sample_id) if sample_entry is None: @@ -705,7 +724,7 @@ def getFunctionsBySampleId(self, sample_id: int) -> Optional[List["FunctionEntry functions.append(FunctionEntry.fromDict(f)) return functions - def getFunctionIdsBySampleId(self, sample_id: int) -> Optional[List["FunctionEntry"]]: + def getFunctionIdsBySampleId(self, sample_id: int) -> Optional[List[int]]: function_ids = None if not self.isSampleId(sample_id): return function_ids @@ -715,7 +734,7 @@ def getFunctionIdsBySampleId(self, sample_id: int) -> Optional[List["FunctionEnt else: function_dicts = list(self._getDb().functions.find({"sample_id": sample_id}, {"_id": 0, "function_id": 1})) for f in function_dicts: - function_ids(f["function_ids"]) + function_ids.append(f["function_id"]) return function_ids def getFunctions(self, start_index: int, limit: int) -> Optional["FunctionEntry"]: @@ -924,19 +943,28 @@ def _getCacheDataForFunctionIds(self, function_ids: List[int]) -> Dict: sample_to_func_ids = {} minhashes = {} # process this in batches as the number of function_ids can be exceedingly large, pushing beyond Mongo's 16M limit - for sliced_ids in zip_longest(*[iter(function_ids)]*500000): - query_function_ids = [fid for fid in sliced_ids if fid is not None] - for function_document in self._getDb().functions.find( - {"function_id": {"$in": list(query_function_ids)}}, {"_id": 0, "sample_id": 1, "minhash": 1, "function_id": 1} - ): - function_id = function_document["function_id"] - sample_id = function_document["sample_id"] - minhash = bytes.fromhex(function_document["minhash"]) - minhashes[function_id] = minhash - sample_ids[function_id] = sample_id - if sample_id not in sample_to_func_ids: - sample_to_func_ids[sample_id] = set() - sample_to_func_ids[sample_id].add(function_id) + positive_function_ids = [function_id for function_id in set(function_ids) if function_id >= 0] + negative_function_ids = [function_id for function_id in set(function_ids) if function_id < 0] + for collection_name, collection_ids in ( + ("functions", positive_function_ids), + ("query_functions", negative_function_ids), + ): + if not collection_ids: + continue + for sliced_ids in zip_longest(*[iter(collection_ids)] * 500000): + query_function_ids = [function_id for function_id in sliced_ids if function_id is not None] + for function_document in self._getDb()[collection_name].find( + {"function_id": {"$in": query_function_ids}}, + {"_id": 0, "sample_id": 1, "minhash": 1, "function_id": 1}, + ): + function_id = function_document["function_id"] + sample_id = function_document["sample_id"] + minhash = bytes.fromhex(function_document["minhash"]) + minhashes[function_id] = minhash + sample_ids[function_id] = sample_id + if sample_id not in sample_to_func_ids: + sample_to_func_ids[sample_id] = set() + sample_to_func_ids[sample_id].add(function_id) cache_data["func_id_to_minhash"] = minhashes cache_data["func_id_to_sample_id"] = sample_ids cache_data["sample_id_to_func_ids"] = sample_to_func_ids @@ -1011,7 +1039,7 @@ def _getFunctionDocument( self._encodeFunction(function_dict) return function_dict - def createMatchingCache(self, function_ids: List[int]) -> MatchingCache: + def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> MatchingCache: cache_data = self._getCacheDataForFunctionIds(function_ids) # TODO dont store this as attribute self._matching_cache = MatchingCache(cache_data) diff --git a/mcrit/storage/StorageInterface.py b/mcrit/storage/StorageInterface.py index d90069b..9f566d2 100644 --- a/mcrit/storage/StorageInterface.py +++ b/mcrit/storage/StorageInterface.py @@ -10,7 +10,6 @@ from mcrit.config.MinHashConfig import MinHashConfig from mcrit.storage.FunctionEntry import FunctionEntry from mcrit.storage.MatchingCache import MatchingCache - from mcrit.storage.MemoryStorage import MemoryStorage from mcrit.storage.SampleEntry import SampleEntry from smda.common.SmdaFunction import SmdaFunction from smda.common.SmdaReport import SmdaReport @@ -257,6 +256,17 @@ def getSampleIdByFunctionId(self, function_id: int) -> Optional[int]: """ raise NotImplementedError + def getSampleIdsByFunctionIds(self, function_ids: List[int]) -> Dict[int, int]: + """For a given list of function_ids, return the corresponding sample_ids. + + Args: + function_ids: a list of function ids + + Returns: + a dict mapping function_id to sample_id for all function_ids found + """ + raise NotImplementedError + def getSampleById(self, sample_id: int) -> Optional["SampleEntry"]: """Given a sample_id, return the respective SampleEntry or None, if sample_id was not found. @@ -365,13 +375,14 @@ def clearMatchingCache(self) -> None: """ raise NotImplementedError - # TODO: make a MatchingCacheInterface, or MemoryStorage a subclass of MatchingCache? + # TODO: make a MatchingCacheInterface for all backends. # TODO rename -> get? - def createMatchingCache(self, function_ids: List[int]) -> Union["MemoryStorage", "MatchingCache"]: + def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> "MatchingCache": """Creates a temporary matching cache, for a list of function_ids Args: function_ids: list of function ids + allow_self_return: (optional) if True, allows a backend-specific optimized cache implementation Returns: a matching cache for the specified list of function ids diff --git a/tests/testStorage.py b/tests/testStorage.py index 5f4b51e..e229803 100644 --- a/tests/testStorage.py +++ b/tests/testStorage.py @@ -302,6 +302,40 @@ def testMatchingCache(self): self.assertTrue(hasattr(cache, "getMinHashByFunctionId")) self.assertTrue(hasattr(cache, "getSampleIdByFunctionId")) + def testMatchingCacheAllowSelfReturnDoesNotMutateStorage(self): + self.storage.clearStorage() + smda_report = SmdaReport.fromFile(self.example_file_path) + self.storage.addSmdaReport(smda_report) + + cache = self.storage.createMatchingCache([0], allow_self_return=True) + cache_only_entry = self.storage.getFunctionById(0, with_xcfg=True) + cache_only_entry.sample_id = 999 + cache_only_entry.minhash = b"cache-only-minhash" + cache.addFunctionEntriesToCache([cache_only_entry]) + + self.assertEqual(0, self.storage.getSampleIdByFunctionId(0)) + self.assertNotEqual(b"cache-only-minhash", self.storage.getMinHashByFunctionId(0)) + self.assertEqual(999, cache.getSampleIdByFunctionId(0)) + self.assertEqual(b"cache-only-minhash", cache.getMinHashByFunctionId(0)) + self.assertEqual(set([0]), cache.getFunctionIdsBySampleId(999)) + + def testMatchingCacheSupportsQueryFunctions(self): + self.storage.clearStorage() + smda_report = SmdaReport.fromFile(self.example_file_path) + query_sample = self.storage.addSmdaReport(smda_report, isQuery=True) + assert query_sample is not None + query_function_ids = self.storage.getFunctionIdsBySampleId(query_sample.sample_id) + assert query_function_ids is not None + query_function_id = query_function_ids[0] + query_function = self.storage.getFunctionById(query_function_id) + assert query_function is not None + + for allow_self_return in (False, True): + cache = self.storage.createMatchingCache([query_function_id], allow_self_return=allow_self_return) + self.assertEqual(query_sample.sample_id, cache.getSampleIdByFunctionId(query_function_id)) + self.assertEqual(query_function.minhash, cache.getMinHashByFunctionId(query_function_id)) + self.assertEqual(set([query_function_id]), set(cache.getFunctionIdsBySampleId(query_sample.sample_id))) + ### Added mongo attribute import pytest