Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mcrit/matchers/MatcherInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from mcrit.storage.FunctionEntry import FunctionEntry
from mcrit.storage.SampleEntry import SampleEntry
from mcrit.storage.MatchingCache import MatchingCache
from mcrit.storage.MemoryStorage import MemoryStorage
from mcrit.Worker import Worker


Expand Down Expand Up @@ -245,7 +244,7 @@ def _countPackedTuples(self, candidate_pairs) -> int:
return quotient + int(bool(remainder)) # always round up

def _unrollGroupsAsPackedTuples(
self, cache: Union["MatchingCache", "MemoryStorage"], candidate_pairs
self, cache: "MatchingCache", candidate_pairs
) -> Iterable[List[Tuple[int, int, bytes, int, int, bytes]]]:
# Query, VS, Sample
# All were identical
Expand Down
2 changes: 1 addition & 1 deletion mcrit/matchers/MatcherQuery.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,6 @@ def _createMatchingCache(self, candidate_groups):
for function_id in other_function_ids:
if function_id >= 0:
function_ids_from_storage.add(function_id)
cache = self._storage.createMatchingCache(function_ids_from_storage)
cache = self._storage.createMatchingCache(function_ids_from_storage, allow_self_return=True)
cache.addFunctionEntriesToCache(self._function_entries)
return cache
2 changes: 1 addition & 1 deletion mcrit/matchers/MatcherQueryFunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ def _createMatchingCache(self, candidate_groups):
for function_id in other_function_ids:
if function_id >= 0:
function_ids_from_storage.add(function_id)
cache = self._storage.createMatchingCache(function_ids_from_storage)
cache = self._storage.createMatchingCache(function_ids_from_storage, allow_self_return=True)
cache.addFunctionEntriesToCache(self._function_entries)
return cache
41 changes: 37 additions & 4 deletions mcrit/storage/MatchingCache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ def __init__(self, cache_data):
self._func_id_to_sample_id = cache_data["func_id_to_sample_id"]
self._sample_id_to_func_ids = cache_data["sample_id_to_func_ids"]

def _setFunctionEntry(self, function_id, sample_id, minhash):
if function_id in self._func_id_to_sample_id:
old_sample_id = self._func_id_to_sample_id[function_id]
if old_sample_id in self._sample_id_to_func_ids:
self._sample_id_to_func_ids[old_sample_id].discard(function_id)
if not self._sample_id_to_func_ids[old_sample_id]:
del self._sample_id_to_func_ids[old_sample_id]
self._func_id_to_minhash[function_id] = minhash
self._func_id_to_sample_id[function_id] = sample_id
if sample_id not in self._sample_id_to_func_ids:
self._sample_id_to_func_ids[sample_id] = set()
self._sample_id_to_func_ids[sample_id].add(function_id)

def isSampleId(self, sample_id):
return sample_id in self._sample_id_to_func_ids

Expand All @@ -20,9 +33,29 @@ def getFunctionIdsBySampleId(self, sample_id):

def addFunctionEntriesToCache(self, function_entries):
for function_entry in function_entries:
self._func_id_to_minhash[function_entry.function_id] = function_entry.minhash
self._func_id_to_sample_id[function_entry.function_id] = function_entry.sample_id
sample_id = function_entry.sample_id
self._setFunctionEntry(function_entry.function_id, function_entry.sample_id, function_entry.minhash)


class StorageBackedMatchingCache(MatchingCache):
"""A matching cache that reuses storage-backed data without mutating the underlying storage."""

def __init__(self, storage, function_ids):
self._storage = storage
self._func_id_to_minhash = {}
unique_function_ids = set(function_ids)
self._func_id_to_sample_id = dict(self._storage.getSampleIdsByFunctionIds(list(unique_function_ids)))
self._sample_id_to_func_ids = {}
missing_function_ids = unique_function_ids.difference(self._func_id_to_sample_id)
if missing_function_ids:
raise KeyError(missing_function_ids.pop())
for function_id, sample_id in self._func_id_to_sample_id.items():
if sample_id not in self._sample_id_to_func_ids:
self._sample_id_to_func_ids[sample_id] = set()
self._sample_id_to_func_ids[sample_id].add(function_entry.function_id)
self._sample_id_to_func_ids[sample_id].add(function_id)

def getMinHashByFunctionId(self, function_id):
if function_id in self._func_id_to_minhash:
return self._func_id_to_minhash[function_id]
if function_id not in self._func_id_to_sample_id:
raise KeyError(function_id)
return self._storage.getMinHashByFunctionId(function_id)
25 changes: 19 additions & 6 deletions mcrit/storage/MemoryStorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from mcrit.storage.FamilyEntry import FamilyEntry
from mcrit.storage.FunctionEntry import FunctionEntry
from mcrit.storage.FunctionLabelEntry import FunctionLabelEntry
from mcrit.storage.MatchingCache import MatchingCache
from mcrit.storage.MatchingCache import MatchingCache, StorageBackedMatchingCache
from mcrit.storage.SampleEntry import SampleEntry
from mcrit.storage.StorageInterface import StorageInterface

Expand Down Expand Up @@ -483,8 +483,9 @@ def addMinHashes(self, minhashes: List["MinHash"]) -> None:
for minhash in minhashes:
self.addMinHash(minhash)

def createMatchingCache(self, function_ids: List[int]) -> MatchingCache:
# TODO: we might want add a flag to allow/disallow returning self
def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> MatchingCache:
if allow_self_return:
return StorageBackedMatchingCache(self, function_ids)
cache_data = self._getCacheDataForFunctionIds(function_ids)
return MatchingCache(cache_data)

Expand All @@ -493,8 +494,8 @@ def _getCacheDataForFunctionIds(self, function_ids: List[int]) -> Dict:
sample_ids = {}
sample_to_func_ids = {}
minhashes = {}
for function_id in function_ids:
function_entry = self._functions[function_id]
for function_id in set(function_ids):
function_entry = self._query_functions[function_id] if function_id < 0 else self._functions[function_id]
function_id = function_entry.function_id
sample_id = function_entry.sample_id
minhashes[function_id] = function_entry.minhash
Expand Down Expand Up @@ -637,6 +638,15 @@ def getSampleIdByFunctionId(self, function_id: int) -> Optional[int]:
sample_id = self._query_functions[function_id].sample_id
return sample_id

def getSampleIdsByFunctionIds(self, function_ids: List[int]) -> Dict[int, int]:
sample_ids = {}
for function_id in set(function_ids):
if function_id in self._functions:
sample_ids[function_id] = self._functions[function_id].sample_id
elif function_id in self._query_functions:
sample_ids[function_id] = self._query_functions[function_id].sample_id
return sample_ids

def getSamples(self, start_index: int, limit: int) -> Optional["SampleEntry"]:
index = 0
sample_entries = []
Expand Down Expand Up @@ -678,6 +688,9 @@ def getMinHashByFunctionId(self, function_id: int) -> Optional[bytes]:
if function_id in self._functions:
function_entry = self._functions[function_id]
return function_entry.minhash
if function_id in self._query_functions:
function_entry = self._query_functions[function_id]
return function_entry.minhash
return None

# -> Dict[function_id, Set[function_id]]
Expand Down Expand Up @@ -950,4 +963,4 @@ def findFunctionByString(self, search_tree: NodeType, cursor: Optional[FullSearc
result_dict[entry.function_id] = entry
if len(result_dict) >= max_num_results:
break
return result_dict
return result_dict
60 changes: 44 additions & 16 deletions mcrit/storage/MongoDbStorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,25 @@ def getSampleIdByFunctionId(self, function_id: int) -> Optional[int]:
return None
return function_document["sample_id"]

def getSampleIdsByFunctionIds(self, function_ids: List[int]) -> Dict[int, int]:
sample_ids = {}
positive_function_ids = [function_id for function_id in set(function_ids) if function_id >= 0]
negative_function_ids = [function_id for function_id in set(function_ids) if function_id < 0]
for collection_name, collection_ids in (
("functions", positive_function_ids),
("query_functions", negative_function_ids),
):
if not collection_ids:
continue
for sliced_ids in zip_longest(*[iter(collection_ids)] * 500000):
query_function_ids = [function_id for function_id in sliced_ids if function_id is not None]
for function_document in self._getDb()[collection_name].find(
{"function_id": {"$in": query_function_ids}},
{"_id": 0, "function_id": 1, "sample_id": 1},
):
sample_ids[function_document["function_id"]] = function_document["sample_id"]
return sample_ids

def deleteSample(self, sample_id: int) -> bool:
sample_entry = self.getSampleById(sample_id)
if sample_entry is None:
Expand Down Expand Up @@ -705,7 +724,7 @@ def getFunctionsBySampleId(self, sample_id: int) -> Optional[List["FunctionEntry
functions.append(FunctionEntry.fromDict(f))
return functions

def getFunctionIdsBySampleId(self, sample_id: int) -> Optional[List["FunctionEntry"]]:
def getFunctionIdsBySampleId(self, sample_id: int) -> Optional[List[int]]:
function_ids = None
if not self.isSampleId(sample_id):
return function_ids
Expand All @@ -715,7 +734,7 @@ def getFunctionIdsBySampleId(self, sample_id: int) -> Optional[List["FunctionEnt
else:
function_dicts = list(self._getDb().functions.find({"sample_id": sample_id}, {"_id": 0, "function_id": 1}))
for f in function_dicts:
function_ids(f["function_ids"])
function_ids.append(f["function_id"])
return function_ids

def getFunctions(self, start_index: int, limit: int) -> Optional["FunctionEntry"]:
Expand Down Expand Up @@ -924,19 +943,28 @@ def _getCacheDataForFunctionIds(self, function_ids: List[int]) -> Dict:
sample_to_func_ids = {}
minhashes = {}
# process this in batches as the number of function_ids can be exceedingly large, pushing beyond Mongo's 16M limit
for sliced_ids in zip_longest(*[iter(function_ids)]*500000):
query_function_ids = [fid for fid in sliced_ids if fid is not None]
for function_document in self._getDb().functions.find(
{"function_id": {"$in": list(query_function_ids)}}, {"_id": 0, "sample_id": 1, "minhash": 1, "function_id": 1}
):
function_id = function_document["function_id"]
sample_id = function_document["sample_id"]
minhash = bytes.fromhex(function_document["minhash"])
minhashes[function_id] = minhash
sample_ids[function_id] = sample_id
if sample_id not in sample_to_func_ids:
sample_to_func_ids[sample_id] = set()
sample_to_func_ids[sample_id].add(function_id)
positive_function_ids = [function_id for function_id in set(function_ids) if function_id >= 0]
negative_function_ids = [function_id for function_id in set(function_ids) if function_id < 0]
for collection_name, collection_ids in (
("functions", positive_function_ids),
("query_functions", negative_function_ids),
):
if not collection_ids:
continue
for sliced_ids in zip_longest(*[iter(collection_ids)] * 500000):
query_function_ids = [function_id for function_id in sliced_ids if function_id is not None]
for function_document in self._getDb()[collection_name].find(
{"function_id": {"$in": query_function_ids}},
{"_id": 0, "sample_id": 1, "minhash": 1, "function_id": 1},
):
function_id = function_document["function_id"]
sample_id = function_document["sample_id"]
minhash = bytes.fromhex(function_document["minhash"])
minhashes[function_id] = minhash
sample_ids[function_id] = sample_id
if sample_id not in sample_to_func_ids:
sample_to_func_ids[sample_id] = set()
sample_to_func_ids[sample_id].add(function_id)
cache_data["func_id_to_minhash"] = minhashes
cache_data["func_id_to_sample_id"] = sample_ids
cache_data["sample_id_to_func_ids"] = sample_to_func_ids
Expand Down Expand Up @@ -1011,7 +1039,7 @@ def _getFunctionDocument(
self._encodeFunction(function_dict)
return function_dict

def createMatchingCache(self, function_ids: List[int]) -> MatchingCache:
def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> MatchingCache:
cache_data = self._getCacheDataForFunctionIds(function_ids)
# TODO dont store this as attribute
self._matching_cache = MatchingCache(cache_data)
Expand Down
17 changes: 14 additions & 3 deletions mcrit/storage/StorageInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from mcrit.config.MinHashConfig import MinHashConfig
from mcrit.storage.FunctionEntry import FunctionEntry
from mcrit.storage.MatchingCache import MatchingCache
from mcrit.storage.MemoryStorage import MemoryStorage
from mcrit.storage.SampleEntry import SampleEntry
from smda.common.SmdaFunction import SmdaFunction
from smda.common.SmdaReport import SmdaReport
Expand Down Expand Up @@ -257,6 +256,17 @@ def getSampleIdByFunctionId(self, function_id: int) -> Optional[int]:
"""
raise NotImplementedError

def getSampleIdsByFunctionIds(self, function_ids: List[int]) -> Dict[int, int]:
"""For a given list of function_ids, return the corresponding sample_ids.

Args:
function_ids: a list of function ids

Returns:
a dict mapping function_id to sample_id for all function_ids found
"""
raise NotImplementedError

def getSampleById(self, sample_id: int) -> Optional["SampleEntry"]:
"""Given a sample_id, return the respective SampleEntry or None, if sample_id was not found.

Expand Down Expand Up @@ -365,13 +375,14 @@ def clearMatchingCache(self) -> None:
"""
raise NotImplementedError

# TODO: make a MatchingCacheInterface, or MemoryStorage a subclass of MatchingCache?
# TODO: make a MatchingCacheInterface for all backends.
# TODO rename -> get?
def createMatchingCache(self, function_ids: List[int]) -> Union["MemoryStorage", "MatchingCache"]:
def createMatchingCache(self, function_ids: List[int], allow_self_return: bool = False) -> "MatchingCache":
"""Creates a temporary matching cache, for a list of function_ids

Args:
function_ids: list of function ids
allow_self_return: (optional) if True, allows a backend-specific optimized cache implementation

Returns:
a matching cache for the specified list of function ids
Expand Down
34 changes: 34 additions & 0 deletions tests/testStorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,40 @@ def testMatchingCache(self):
self.assertTrue(hasattr(cache, "getMinHashByFunctionId"))
self.assertTrue(hasattr(cache, "getSampleIdByFunctionId"))

def testMatchingCacheAllowSelfReturnDoesNotMutateStorage(self):
self.storage.clearStorage()
smda_report = SmdaReport.fromFile(self.example_file_path)
self.storage.addSmdaReport(smda_report)

cache = self.storage.createMatchingCache([0], allow_self_return=True)
cache_only_entry = self.storage.getFunctionById(0, with_xcfg=True)
cache_only_entry.sample_id = 999
cache_only_entry.minhash = b"cache-only-minhash"
cache.addFunctionEntriesToCache([cache_only_entry])

self.assertEqual(0, self.storage.getSampleIdByFunctionId(0))
self.assertNotEqual(b"cache-only-minhash", self.storage.getMinHashByFunctionId(0))
self.assertEqual(999, cache.getSampleIdByFunctionId(0))
self.assertEqual(b"cache-only-minhash", cache.getMinHashByFunctionId(0))
self.assertEqual(set([0]), cache.getFunctionIdsBySampleId(999))

def testMatchingCacheSupportsQueryFunctions(self):
self.storage.clearStorage()
smda_report = SmdaReport.fromFile(self.example_file_path)
query_sample = self.storage.addSmdaReport(smda_report, isQuery=True)
assert query_sample is not None
query_function_ids = self.storage.getFunctionIdsBySampleId(query_sample.sample_id)
assert query_function_ids is not None
query_function_id = query_function_ids[0]
query_function = self.storage.getFunctionById(query_function_id)
assert query_function is not None

for allow_self_return in (False, True):
cache = self.storage.createMatchingCache([query_function_id], allow_self_return=allow_self_return)
self.assertEqual(query_sample.sample_id, cache.getSampleIdByFunctionId(query_function_id))
self.assertEqual(query_function.minhash, cache.getMinHashByFunctionId(query_function_id))
self.assertEqual(set([query_function_id]), set(cache.getFunctionIdsBySampleId(query_sample.sample_id)))


### Added mongo attribute
import pytest
Expand Down
Loading