import random
from typing import Callable, List, Tuple
import numpy as np
from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix
from tqdm import tqdm
from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectrumSet


class EvaluateMethods:
    """Benchmarks spectrum-matching prediction functions.

    A prediction function takes (library_set, query_set) and returns, per query
    spectrum, a predicted library inchikey (first 14 characters) and a score.
    """

    def __init__(
        self, training_spectrum_set: SpectraWithFingerprints, validation_spectrum_set: SpectraWithFingerprints
    ):
        self.training_spectrum_set = training_spectrum_set
        self.validation_spectrum_set = validation_spectrum_set

        # Silence the per-set progress bars; the benchmark loops show their own tqdm bars.
        self.training_spectrum_set.progress_bars = False
        self.validation_spectrum_set.progress_bars = False

    def benchmark_analogue_search(
        self,
        prediction_function: Callable[
            [SpectraWithFingerprints, SpectraWithFingerprints], Tuple[List[str], List[float]]
        ],
    ) -> float:
        """Score analogue search as the mean Tanimoto similarity between each
        predicted library molecule and the true query molecule.

        Scores are first averaged per unique validation inchikey and then averaged
        over all inchikeys, so molecules with many spectra do not dominate.
        A prediction of None counts as a score of 0.0.
        """
        predicted_inchikeys, _ = prediction_function(self.training_spectrum_set, self.validation_spectrum_set)
        average_scores_per_inchikey = []

        # Calculate score per unique inchikey
        for inchikey in tqdm(
            self.validation_spectrum_set.spectrum_indexes_per_inchikey.keys(),
            desc="Calculating analogue accuracy per inchikey",
        ):
            matching_spectrum_indexes = self.validation_spectrum_set.spectrum_indexes_per_inchikey[inchikey]
            prediction_scores = []
            for index in matching_spectrum_indexes:
                predicted_inchikey = predicted_inchikeys[index]
                if predicted_inchikey is None:
                    # No match was returned for this spectrum; count as a complete miss.
                    prediction_scores.append(0.0)
                else:
                    predicted_fingerprint = self.training_spectrum_set.inchikey_fingerprint_pairs[predicted_inchikey]
                    actual_fingerprint = self.validation_spectrum_set.inchikey_fingerprint_pairs[inchikey]
                    tanimoto_for_prediction = calculate_tanimoto_score_between_pair(
                        predicted_fingerprint, actual_fingerprint
                    )
                    prediction_scores.append(tanimoto_for_prediction)

            average_prediction = sum(prediction_scores) / len(prediction_scores)
            score = average_prediction
            average_scores_per_inchikey.append(score)
        average_over_all_inchikeys = sum(average_scores_per_inchikey) / len(average_scores_per_inchikey)
        return average_over_all_inchikeys

    def benchmark_exact_matching_within_ionmode(
        self,
        prediction_function: Callable[
            [SpectraWithFingerprints, SpectraWithFingerprints], Tuple[List[str], List[float]]
        ],
        ionmode: str,
    ) -> float:
        """Test the accuracy at retrieving exact matches from the library

        For each inchikey with more than 1 spectrum the spectra are split in two sets. Half for each inchikey is added
        to the library (training set), for the other half predictions are made. Thereby there is always an exact match
        available. Only the highest ranked prediction is considered correct if the correct inchikey is predicted.
        An accuracy per inchikey is calculated followed by calculating the average.
        """
        selected_spectra = subset_spectra_on_ionmode(self.validation_spectrum_set, ionmode)

        set_1, set_2 = split_spectrum_set_per_inchikeys(selected_spectra)

        predicted_inchikeys = predict_between_two_sets(self.training_spectrum_set, set_1, set_2, prediction_function)

        # add the spectra to set_1 (prediction order was set_1 then set_2, matching this merge order)
        set_1.add_spectra(set_2)
        return calculate_average_exact_match_accuracy(set_1, predicted_inchikeys)

    def exact_matches_across_ionization_modes(
        self,
        prediction_function: Callable[
            [SpectraWithFingerprints, SpectraWithFingerprints], Tuple[List[str], List[float]]
        ],
    ):
        """Test the accuracy at retrieving exact matches from the library if only available in other ionisation mode

        Each val spectrum is matched against the training set with the other val spectra of the same inchikey, but other
        ionisation mode added to the library.
        """
        pos_set, neg_set = split_spectrum_set_per_inchikey_across_ionmodes(self.validation_spectrum_set)
        predicted_inchikeys = predict_between_two_sets(
            self.training_spectrum_set, pos_set, neg_set, prediction_function
        )
        # add the spectra to set_1 (positive first, matching prediction order)
        pos_set.add_spectra(neg_set)
        return calculate_average_exact_match_accuracy(pos_set, predicted_inchikeys)

    def get_accuracy_recall_curve(self):
        """This method should test the recall accuracy balance.
        All of the used methods use a threshold which indicates quality of prediction.
        A method that can predict well when a prediction is accurate is beneficial.
        We need a method to test this.

        One method is generating a recall accuracy curve. This could be done for both the analogue search predictions
        and the exact match predictions. By returning the predicted score for a match this method could create an
        accuracy recall plot.
        """
        raise NotImplementedError


def predict_between_two_sets(
    library: SpectrumSet, query_set_1: SpectrumSet, query_set_2: SpectrumSet, prediction_function
):
    """Makes predictions between query sets and the library, with the other query set added.

    This is necessary for testing exact matching"""
    training_set_copy = library.copy()
    training_set_copy.add_spectra(query_set_2)
    predicted_inchikeys_1, _ = prediction_function(training_set_copy, query_set_1)

    # Fresh copy: the first copy was mutated by add_spectra above.
    training_set_copy = library.copy()
    training_set_copy.add_spectra(query_set_1)
    predicted_inchikeys_2, _ = prediction_function(training_set_copy, query_set_2)

    # Concatenated in (set_1, set_2) order; callers must merge the sets in the same order.
    return predicted_inchikeys_1 + predicted_inchikeys_2


def calculate_average_exact_match_accuracy(spectrum_set: SpectrumSet, predicted_inchikeys: List[str]):
    """Average, over unique inchikeys, of the fraction of spectra whose top prediction is the true inchikey."""
    if len(spectrum_set.spectra) != len(predicted_inchikeys):
        raise ValueError("The number of spectra should be equal to the number of predicted inchikeys ")
    exact_match_accuracy_per_inchikey = []
    for inchikey in tqdm(
        spectrum_set.spectrum_indexes_per_inchikey.keys(), desc="Calculating exact match accuracy per inchikey"
    ):
        val_spectrum_indexes_matching_inchikey = spectrum_set.spectrum_indexes_per_inchikey[inchikey]
        correctly_predicted = 0
        for selected_spectrum_idx in val_spectrum_indexes_matching_inchikey:
            if inchikey == predicted_inchikeys[selected_spectrum_idx]:
                correctly_predicted += 1
        exact_match_accuracy_per_inchikey.append(correctly_predicted / len(val_spectrum_indexes_matching_inchikey))
    return sum(exact_match_accuracy_per_inchikey) / len(exact_match_accuracy_per_inchikey)


def split_spectrum_set_per_inchikeys(spectrum_set: SpectrumSet) -> Tuple[SpectrumSet, SpectrumSet]:
    """Splits a spectrum set into two.

    For each inchikey with more than one spectrum the spectra are divided over the two sets"""
    indexes_set_1 = []
    indexes_set_2 = []
    for inchikey in tqdm(spectrum_set.spectrum_indexes_per_inchikey.keys(), desc="Splitting spectra per inchikey"):
        val_spectrum_indexes_matching_inchikey = spectrum_set.spectrum_indexes_per_inchikey[inchikey]
        if len(val_spectrum_indexes_matching_inchikey) == 1:
            # all single spectra are excluded from this test, since no exact match can be added to the library
            continue
        split_index = len(val_spectrum_indexes_matching_inchikey) // 2
        # NOTE: shuffles the index list stored inside spectrum_set in place (uses module-level random state).
        random.shuffle(val_spectrum_indexes_matching_inchikey)
        indexes_set_1.extend(val_spectrum_indexes_matching_inchikey[:split_index])
        indexes_set_2.extend(val_spectrum_indexes_matching_inchikey[split_index:])
    return spectrum_set.subset_spectra(indexes_set_1), spectrum_set.subset_spectra(indexes_set_2)
def split_spectrum_set_per_inchikey_across_ionmodes(
    spectrum_set: SpectrumSet,
) -> Tuple[SpectrumSet, SpectrumSet]:
    """Splits a spectrum set in two sets on ionmode.

    Only uses spectra for inchikeys with at least 1 pos and 1 neg"""
    all_pos_indexes = []
    all_neg_indexes = []
    for inchikey in tqdm(
        spectrum_set.spectrum_indexes_per_inchikey.keys(),
        desc="Splitting spectra per inchikey across ionmodes",
    ):
        val_spectrum_indexes_matching_inchikey = spectrum_set.spectrum_indexes_per_inchikey[inchikey]
        positive_val_spectrum_indexes_current_inchikey = []
        negative_val_spectrum_indexes_current_inchikey = []
        for spectrum_index in val_spectrum_indexes_matching_inchikey:
            # Spectra with any other ionmode value (or missing metadata) are silently ignored.
            ionmode = spectrum_set.spectra[spectrum_index].get("ionmode")
            if ionmode == "positive":
                positive_val_spectrum_indexes_current_inchikey.append(spectrum_index)
            elif ionmode == "negative":
                negative_val_spectrum_indexes_current_inchikey.append(spectrum_index)

        # Only keep inchikeys represented in both ionmodes.
        if (
            len(positive_val_spectrum_indexes_current_inchikey) < 1
            or len(negative_val_spectrum_indexes_current_inchikey) < 1
        ):
            continue
        else:
            all_pos_indexes.extend(positive_val_spectrum_indexes_current_inchikey)
            all_neg_indexes.extend(negative_val_spectrum_indexes_current_inchikey)

    pos_val_spectra = spectrum_set.subset_spectra(all_pos_indexes)
    neg_val_spectra = spectrum_set.subset_spectra(all_neg_indexes)
    return pos_val_spectra, neg_val_spectra


def subset_spectra_on_ionmode(spectrum_set: SpectrumSet, ionmode: str) -> SpectrumSet:
    """Return a new SpectrumSet containing only spectra whose "ionmode" metadata equals `ionmode`."""
    spectrum_indexes_to_keep = []
    for i, spectrum in enumerate(spectrum_set.spectra):
        if spectrum.get("ionmode") == ionmode:
            spectrum_indexes_to_keep.append(i)
    return spectrum_set.subset_spectra(spectrum_indexes_to_keep)


def calculate_tanimoto_score_between_pair(fingerprint_1: np.ndarray, fingerprint_2: np.ndarray) -> float:
    """Tanimoto (Jaccard) similarity between two fingerprint vectors.

    Note: the parameters are 1D fingerprint arrays (they are wrapped into 1xN
    matrices for the matchms matrix function), not strings.
    """
    return jaccard_similarity_matrix(np.array([fingerprint_1]), np.array([fingerprint_2]))[0][0]
import copy
from collections import Counter
from typing import Dict, Iterable, List
import numpy as np
from matchms import Spectrum
from matchms.filtering.metadata_processing.add_fingerprint import _derive_fingerprint_from_inchi
from ms2deepscore.models import SiameseSpectralModel, compute_embedding_array
from tqdm import tqdm


class SpectrumSet:
    """Stores a spectrum dataset making it easy and fast to split on molecules"""

    def __init__(self, spectra: List[Spectrum], progress_bars: bool = False):
        # Internal spectrum list; exposed read-only through the `spectra` property.
        self._spectra = []
        # Maps 14-char inchikey -> list of indexes into self._spectra.
        self.spectrum_indexes_per_inchikey: Dict[str, List[int]] = {}
        self.progress_bars = progress_bars
        # init spectra
        self._add_spectra_and_group_per_inchikey(spectra)

    def _add_spectra_and_group_per_inchikey(self, spectra: List[Spectrum]):
        """Append spectra and index them by 14-char inchikey.

        Returns the set of inchikeys that received new spectra.
        """
        starting_index = len(self._spectra)
        updated_inchikeys = set()
        for i, spectrum in enumerate(
            tqdm(spectra, desc="Adding spectra and grouping per Inchikey", disable=not self.progress_bars)
        ):
            self._spectra.append(spectrum)
            spectrum_index = starting_index + i
            # Only the first 14 characters (the molecular skeleton part) identify a molecule here.
            inchikey = spectrum.get("inchikey")[:14]
            updated_inchikeys.add(inchikey)
            if inchikey in self.spectrum_indexes_per_inchikey:
                self.spectrum_indexes_per_inchikey[inchikey].append(spectrum_index)
            else:
                self.spectrum_indexes_per_inchikey[inchikey] = [
                    spectrum_index,
                ]
        return updated_inchikeys

    def add_spectra(self, new_spectra: "SpectrumSet"):
        """Merge another set's spectra into this one; returns the set of updated inchikeys."""
        return self._add_spectra_and_group_per_inchikey(new_spectra.spectra)

    def subset_spectra(self, spectrum_indexes) -> "SpectrumSet":
        """Returns a new instance of a subset of the spectra"""
        # Shallow-copy then rebuild the spectrum list and index so the subset is independent.
        new_instance = copy.copy(self)
        new_instance._spectra = []
        new_instance.spectrum_indexes_per_inchikey = {}
        new_instance._add_spectra_and_group_per_inchikey([self._spectra[index] for index in spectrum_indexes])
        return new_instance

    def spectra_per_inchikey(self, inchikey) -> List[Spectrum]:
        """Return all spectra belonging to the given 14-char inchikey."""
        matching_spectra = []
        for index in self.spectrum_indexes_per_inchikey[inchikey]:
            matching_spectra.append(self._spectra[index])
        return matching_spectra

    @property
    def spectra(self):
        return self._spectra

    def copy(self):
        """Return a copy whose spectrum list and inchikey index can be mutated
        (e.g. by add_spectra) without affecting this instance.
        The Spectrum objects themselves are shared, not deep-copied."""
        new_instance = copy.copy(self)
        new_instance._spectra = self._spectra.copy()
        new_instance.spectrum_indexes_per_inchikey = copy.deepcopy(self.spectrum_indexes_per_inchikey)
        return new_instance


class SpectraWithFingerprints(SpectrumSet):
    """SpectrumSet that also stores one molecular fingerprint per inchikey."""

    def __init__(self, spectra: List[Spectrum], fingerprint_type: str = "daylight", nbits: int = 4096):
        super().__init__(spectra)
        self.fingerprint_type = fingerprint_type
        self.nbits = nbits
        # Maps 14-char inchikey -> fingerprint derived from that inchikey's most common InChI.
        self.inchikey_fingerprint_pairs: Dict[str, np.ndarray] = {}
        # init spectra
        self.update_fingerprint_per_inchikey(self.spectrum_indexes_per_inchikey.keys())

    def add_spectra(self, new_spectra: "SpectraWithFingerprints"):
        updated_inchikeys = super().add_spectra(new_spectra)
        # Fast path: reuse the incoming set's fingerprints when they are compatible
        # (same settings) and the inchikeys are disjoint, avoiding recomputation.
        if hasattr(new_spectra, "inchikey_fingerprint_pairs"):
            if new_spectra.nbits == self.nbits and new_spectra.fingerprint_type == self.fingerprint_type:
                if len(self.inchikey_fingerprint_pairs.keys() & new_spectra.inchikey_fingerprint_pairs.keys()) == 0:
                    self.inchikey_fingerprint_pairs = (
                        self.inchikey_fingerprint_pairs | new_spectra.inchikey_fingerprint_pairs
                    )
                    return
        # Slow path: recompute fingerprints for every inchikey that gained spectra.
        self.update_fingerprint_per_inchikey(updated_inchikeys)

    def update_fingerprint_per_inchikey(self, inchikeys_to_update: Iterable[str]):
        """(Re)derive one fingerprint per inchikey from its most common InChI.

        Raises ValueError when a fingerprint cannot be derived.
        """
        for inchikey in tqdm(
            inchikeys_to_update, desc="Adding fingerprints to Inchikeys", disable=not self.progress_bars
        ):
            spectra = self.spectra_per_inchikey(inchikey)
            # Spectra of one inchikey may carry slightly different InChIs; use the most common one.
            most_common_inchi = Counter([spectrum.get("inchi") for spectrum in spectra]).most_common(1)[0][0]
            fingerprint = _derive_fingerprint_from_inchi(
                most_common_inchi, fingerprint_type=self.fingerprint_type, nbits=self.nbits
            )
            if not isinstance(fingerprint, np.ndarray):
                raise ValueError(f"Fingerprint could not be set for InChI: {most_common_inchi}")
            self.inchikey_fingerprint_pairs[inchikey] = fingerprint

    def copy(self):
        """Return a copy that can be mutated without affecting this instance.
        Fingerprint arrays are shared (dict is shallow-copied)."""
        new_instance = super().copy()
        new_instance.inchikey_fingerprint_pairs = copy.copy(self.inchikey_fingerprint_pairs)
        return new_instance

    def subset_spectra(self, spectrum_indexes) -> "SpectraWithFingerprints":
        """Returns a new instance of a subset of the spectra"""
        new_instance = super().subset_spectra(spectrum_indexes)
        # Only keep the fingerprints for which we have inchikeys.
        # Important note: This is not a deep copy!
        # And the fingerprint is not reset (so it is not always actually matching the most common inchi)
        new_instance.inchikey_fingerprint_pairs = {inchikey: self.inchikey_fingerprint_pairs[inchikey] for inchikey
                                                   in new_instance.spectrum_indexes_per_inchikey.keys()}
        return new_instance


class SpectraWithMS2DeepScoreEmbeddings(SpectraWithFingerprints):
    """SpectraWithFingerprints that also stores an MS2DeepScore embedding per spectrum (row-aligned with .spectra)."""

    def __init__(self, spectra: List[Spectrum], ms2deepscore_model: SiameseSpectralModel, **kwargs):
        super().__init__(spectra, **kwargs)
        self.ms2deepscore_model = ms2deepscore_model
        self.embeddings: np.ndarray = compute_embedding_array(self.ms2deepscore_model, spectra)

    def add_spectra(self, new_spectra: "SpectraWithMS2DeepScoreEmbeddings"):
        super().add_spectra(new_spectra)
        # Reuse existing embeddings when available; otherwise compute with this set's model.
        if hasattr(new_spectra, "embeddings"):
            new_embeddings = new_spectra.embeddings
        else:
            new_embeddings = compute_embedding_array(self.ms2deepscore_model, new_spectra.spectra)
        self.embeddings = np.vstack([self.embeddings, new_embeddings])

    def subset_spectra(self, spectrum_indexes) -> "SpectraWithMS2DeepScoreEmbeddings":
        """Returns a new instance of a subset of the spectra"""
        new_instance = super().subset_spectra(spectrum_indexes)
        new_instance.embeddings = self.embeddings[spectrum_indexes]
        return new_instance
# --- PredictMS2DeepScoreSimilarity.py ---
from typing import Tuple
import numpy as np
from ms2deepscore.vector_operations import cosine_similarity_matrix
from tqdm import tqdm


def predict_top_ms2deepscores(
    library_embeddings: np.ndarray, query_embeddings: np.ndarray, batch_size: int = 500, k: int = 1
) -> Tuple[np.ndarray, np.ndarray]:
    """Memory efficient way of calculating the highest MS2DeepScores

    When doing large matrix multiplications the memory footprint of storing the output matrix can be large.
    E.g. when doing 500.000 vs 10.000 spectra this is a very large matrix. On a laptop this can result in very
    slow run times. If only the highest MS2DeepScore is needed processing in batches prevents using too much memory

    Args:
        library_embeddings: The embeddings of the library spectra
        query_embeddings: The embeddings of the query spectra
        batch_size: The number of query embeddings processed at the same time.
            Setting a lower batch_size results in a lower memory footprint.
        k: Number of highest matches to return

    Returns:
        Tuple of two arrays of shape (n_queries, k):
        - library indexes of the k highest scores per query, highest first
        - the corresponding scores, in the same order
    """
    top_indexes_per_batch = []
    top_scores_per_batch = []
    num_of_query_embeddings = query_embeddings.shape[0]
    # loop over the batches
    for start_idx in tqdm(
        range(0, num_of_query_embeddings, batch_size),
        desc="Predicting highest ms2deepscore per batch of "
        + str(min(batch_size, num_of_query_embeddings))
        + " embeddings",
    ):
        end_idx = min(start_idx + batch_size, num_of_query_embeddings)
        selected_query_embeddings = query_embeddings[start_idx:end_idx]
        score_matrix = cosine_similarity_matrix(selected_query_embeddings, library_embeddings)
        # Take the k rightmost (largest) columns of the ascending argsort, then reverse to descending.
        top_n_idx = np.argsort(score_matrix, axis=1)[:, -k:][:, ::-1]
        top_n_scores = np.take_along_axis(score_matrix, top_n_idx, axis=1)
        top_indexes_per_batch.append(top_n_idx)
        top_scores_per_batch.append(top_n_scores)
    return np.vstack(top_indexes_per_batch), np.vstack(top_scores_per_batch)


# --- predict_best_possible_match.py ---
from typing import Dict
import numpy as np
from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix
from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints


def predict_best_possible_match(library_spectra: SpectraWithFingerprints, query_spectra: SpectraWithFingerprints):
    """Oracle predictor: returns for every query spectrum the library inchikey with the
    highest possible Tanimoto score (the theoretical upper bound for analogue search)."""
    highest_possible_score_per_inchikey = calculate_highest_tanimoto_score_per_inchikey(library_spectra, query_spectra)

    inchikeys_of_best_match = []
    highest_scores = []

    for spectrum in query_spectra.spectra:
        inchikey = spectrum.get("inchikey")[:14]

        inchikeys_of_best_match.append(highest_possible_score_per_inchikey[inchikey][0])
        highest_scores.append(highest_possible_score_per_inchikey[inchikey][1])

    return inchikeys_of_best_match, highest_scores


def calculate_highest_tanimoto_score_per_inchikey(
    library_spectra: SpectraWithFingerprints, query_spectra: SpectraWithFingerprints
) -> Dict[str, tuple[str, float]]:
    """Finds the best possible match during an analogue search

    Returns a dict: query inchikey -> (best library inchikey, Tanimoto score).
    A query inchikey that is itself in the library maps to (itself, 1.0).
    """
    print("Calculating tanimoto scores to determine best possible match")
    library_fingerprints = np.array(list(library_spectra.inchikey_fingerprint_pairs.values()))
    query_fingerprints = np.array(list(query_spectra.inchikey_fingerprint_pairs.values()))
    # Matrix shape: (n_library_inchikeys, n_query_inchikeys); reduce over axis 0 (library).
    tanimoto_scores = jaccard_similarity_matrix(library_fingerprints, query_fingerprints)
    highest_scores = tanimoto_scores.max(axis=0, initial=0)
    indexes_of_highest_scores = tanimoto_scores.argmax(axis=0)

    inchikeys_library = list(library_spectra.inchikey_fingerprint_pairs.keys())

    highest_possible_score_per_inchikey = dict()
    # Dict iteration order matches the column order of query_fingerprints above.
    for i, inchikey in enumerate(query_spectra.inchikey_fingerprint_pairs):
        # Check if inchikey in library (To correctly handle the exact matching case)
        if inchikey in library_spectra.inchikey_fingerprint_pairs:
            highest_possible_score_per_inchikey[inchikey] = (inchikey, 1.0)
            continue

        highest_possible_score_per_inchikey[inchikey] = (
            inchikeys_library[indexes_of_highest_scores[i]],
            highest_scores[i],
        )
    return highest_possible_score_per_inchikey
def predict_highest_cosine(
    library_spectra: SpectraWithFingerprints, query_spectra: SpectraWithFingerprints
) -> Tuple[List[str], List[float]]:
    """Predict, per query spectrum, the library inchikey of the best greedy-cosine match.

    Candidate pairs are restricted to a precursor m/z difference of 0.1; fragment
    matching uses a 0.1 tolerance. Queries without any candidate get (None, 0.0).
    """
    scores = Scores(references=library_spectra.spectra, queries=query_spectra.spectra, is_symmetric=False)
    scores = scores.calculate(PrecursorMzMatch(0.1))
    scores = scores.calculate(CosineGreedy(tolerance=0.1))
    inchikeys_of_best_match = []
    highest_scores = []
    for query_spectrum in query_spectra.spectra:
        # sort=True: results are ordered by descending CosineGreedy_score, so [0] is the best match.
        results = scores.scores_by_query(query_spectrum, "CosineGreedy_score", sort=True)
        if len(results) == 0:
            inchikeys_of_best_match.append(None)
            highest_scores.append(0.0)
        else:
            best_reference, highest_score = results[0]
            inchikeys_of_best_match.append(best_reference.get("inchikey")[:14])
            highest_scores.append(highest_score["CosineGreedy_score"])
    return inchikeys_of_best_match, highest_scores


# --- predict_highest_ms2deepscore.py ---
from typing import List, Tuple
from ms2query.benchmarking.reference_methods.PredictMS2DeepScoreSimilarity import predict_top_ms2deepscores
from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings


def predict_highest_ms2deepscore(
    library_spectra: SpectraWithMS2DeepScoreEmbeddings, query_spectra: SpectraWithMS2DeepScoreEmbeddings
) -> Tuple[List[str], List[float]]:
    """Predict, per query spectrum, the inchikey of the library spectrum with the highest MS2DeepScore."""
    indexes_of_highest_scores, highest_scores = predict_top_ms2deepscores(
        library_spectra.embeddings, query_spectra.embeddings, k=1
    )
    # k=1, so each row holds exactly one (index, score) pair; unwrap to flat lists.
    single_highest_score = [highest_score[0] for highest_score in highest_scores]
    inchikeys_of_best_match = [
        library_spectra.spectra[index[0]].get("inchikey")[:14] for index in indexes_of_highest_scores
    ]
    return inchikeys_of_best_match, single_highest_score
def predict_using_closest_tanimoto(
        library_spectra: SpectraWithMS2DeepScoreEmbeddings, query_spectra: SpectraWithMS2DeepScoreEmbeddings,
        nr_of_closest_inchikeys_to_select=10,
        nr_of_inchikeys_with_highest_ms2deepscore_to_select=100
) -> Tuple[List[str], List[float]]:
    """Predict best inchikey, by taking the average score over all spectra for the 10 closest related library inchikeys.

    (simplified version of old MS2Query)
    """
    inchikeys_of_best_match = []
    highest_scores = []
    # Each query spectrum is handled independently via a 1-spectrum subset.
    for spectrum_idx in tqdm(range(len(query_spectra.spectra)), "Predicting using closest tanimoto"):
        inchikey_of_best_match, score = predict_using_closest_tanimoto_single_spectrum(
            library_spectra, query_spectra.subset_spectra([spectrum_idx]),
            nr_of_closest_inchikeys_to_select, nr_of_inchikeys_with_highest_ms2deepscore_to_select)
        inchikeys_of_best_match.append(inchikey_of_best_match)
        highest_scores.append(score)
    return inchikeys_of_best_match, highest_scores


def predict_using_closest_tanimoto_single_spectrum(
        spectra_with_embeddings, single_spectrum_with_embeddings,
        nr_of_closest_inchikeys_to_select, nr_of_inchikeys_with_highest_ms2deepscore_to_select) -> Tuple[str, float]:
    """Rank candidate inchikeys by the average MS2DeepScore of their chemically closest library neighbours.

    Raises ValueError when the query set does not contain exactly one spectrum.
    """
    if len(single_spectrum_with_embeddings.spectra) != 1:
        raise ValueError("expected a single spectrum")
    # MS2DeepScore of the single query against every library spectrum.
    ms2deepscores = cosine_similarity_matrix(single_spectrum_with_embeddings.embeddings,
                                             spectra_with_embeddings.embeddings)[0]
    top_inchikeys = select_inchikeys_with_highest_ms2deepscore(spectra_with_embeddings, ms2deepscores,
                                                              nr_of_inchikeys_with_highest_ms2deepscore_to_select)
    average_predicted_scores = {}
    for inchikey in top_inchikeys:
        # For each candidate, average the query's MS2DeepScores over the candidate's
        # k closest (by Tanimoto) library molecules, including the candidate itself.
        top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k(
            spectra_with_embeddings, inchikey, nr_of_closest_inchikeys_to_select)
        average_predicted_score = get_average_predictions_for_closely_related_metabolites(
            spectra_with_embeddings, top_k_inchikeys, ms2deepscores)
        average_predicted_scores[inchikey] = average_predicted_score

    inchikey_with_highest_average_prediction, score = max(average_predicted_scores.items(), key=lambda item: item[1])
    return inchikey_with_highest_average_prediction, score


def select_inchikeys_with_highest_ms2deepscore(spectra_with_embeddings, ms2deepscores, nr_of_inchikeys_to_select=10):
    """Return the inchikeys whose best spectrum-level MS2DeepScore is among the top n.

    NOTE(review): assumes spectrum_indexes_per_inchikey and inchikey_fingerprint_pairs
    iterate in the same inchikey order — holds for how these dicts are built, but verify
    after subset operations.
    """
    highest_score_for_inchikey = []
    for inchikey, spectrum_indexes in spectra_with_embeddings.spectrum_indexes_per_inchikey.items():
        all_ms2deepscores_for_inchikey = ms2deepscores[spectrum_indexes]
        highest_score_for_inchikey.append(max(all_ms2deepscores_for_inchikey))
    # argpartition: indexes of the n largest values (unsorted within the top n).
    inchikey_indexes_with_highest_ms2deepscore = np.argpartition(
        np.array(highest_score_for_inchikey), -nr_of_inchikeys_to_select)[-nr_of_inchikeys_to_select:]

    all_inchikeys = list(spectra_with_embeddings.inchikey_fingerprint_pairs.keys())
    top_inchikeys = [all_inchikeys[inchikey_index] for inchikey_index in inchikey_indexes_with_highest_ms2deepscore]
    return top_inchikeys


def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddings, top_k_inchikeys,
                                                            all_ms2deepscores):
    """Calculates the average ms2deepscore predictions for top k closest inchikeys

    First averages per inchikey (over its spectra), then averages those means.
    """
    average_predicted_scores = []
    for top_inchikey in top_k_inchikeys:
        matching_spectrum_indexes = spectra_with_embeddings.spectrum_indexes_per_inchikey[top_inchikey]
        predicted_scores = all_ms2deepscores[matching_spectrum_indexes]
        average_predicted_scores.append(predicted_scores.mean())
    average_predicted_score = sum(average_predicted_scores) / len(average_predicted_scores)
    return average_predicted_score


def get_inchikey_and_tanimoto_scores_for_top_k(spectra: SpectraWithFingerprints, inchikey, k
                                               ) -> tuple[list[str], np.ndarray]:
    """For an inchikey in a library the top k highest tanimoto scores in the library are predicted (including itself)"""
    library_fingerprints = np.vstack(list(spectra.inchikey_fingerprint_pairs.values()))
    fingerprint_single_inchikey = np.vstack(list([spectra.inchikey_fingerprint_pairs[inchikey]]))
    similarity_scores = generalized_tanimoto_similarity_matrix(fingerprint_single_inchikey, library_fingerprints)[0]
    # Top-k returned unsorted within the partition; scores are row-aligned with the inchikey list below.
    inchikey_indexes_of_top_k = np.argpartition(similarity_scores, -k)[-k:]
    tanimoto_scores_for_top_k = similarity_scores[inchikey_indexes_of_top_k]
    all_inchikeys = list(spectra.inchikey_fingerprint_pairs.keys())
    top_inchikeys = [all_inchikeys[inchikey_index] for inchikey_index in inchikey_indexes_of_top_k]
    return top_inchikeys, tanimoto_scores_for_top_k


# --- predict_with_integrated_similarity_flow.py ---
from typing import List, Tuple
import numpy as np
from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix
from tqdm import tqdm
from ms2query.benchmarking.reference_methods.PredictMS2DeepScoreSimilarity import predict_top_ms2deepscores
from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings


def predict_with_integrated_similarity_flow(
    library_spectra: SpectraWithMS2DeepScoreEmbeddings,
    query_spectra: SpectraWithMS2DeepScoreEmbeddings,
    number_of_analogues_to_consider=50,
) -> Tuple[List[str], List[float]]:
    """Predict per query spectrum the inchikey with the highest integrated-similarity-flow (ISF) score,
    computed over the top `number_of_analogues_to_consider` MS2DeepScore matches."""
    all_indexes_of_library_spectra_with_highest_score, all_predicted_scores = predict_top_ms2deepscores(
        library_spectra.embeddings, query_spectra.embeddings, k=number_of_analogues_to_consider
    )
    inchikeys_of_best_matches = []
    highest_isf_scores = []
    # loop over the query spectra:
    for query_index in tqdm(range(len(query_spectra.spectra)), "Calculating ISF score"):
        highest_isf_score, inchikey_of_highest_isf_score = get_highest_isf(
            library_spectra,
            all_indexes_of_library_spectra_with_highest_score[query_index],
            all_predicted_scores[query_index],
        )
        inchikeys_of_best_matches.append(inchikey_of_highest_isf_score)
        highest_isf_scores.append(highest_isf_score)
    return inchikeys_of_best_matches, highest_isf_scores


def get_highest_isf(
    library_spectra: SpectraWithMS2DeepScoreEmbeddings,
    indexes_of_library_spectra_with_highest_score: List[int],
    predicted_scores: List[float],
):
    """Return (highest ISF score, its inchikey) for one query's top library matches."""
    # Get the corresponding inchikeys
    inchikeys_with_highest_ms2deepscore = [
        library_spectra.spectra[index].get("inchikey")[:14] for index in indexes_of_library_spectra_with_highest_score
    ]
    unique_inchikeys, average_scores, nr_of_spectra_per_inchikey = average_scores_per_inchikeys(
        predicted_scores, inchikeys_with_highest_ms2deepscore
    )
    # calculate tanimoto scores between all candidate molecules
    library_fingerprints = np.array(
        [library_spectra.inchikey_fingerprint_pairs[inchikey] for inchikey in unique_inchikeys]
    )
    tanimoto_scores = jaccard_similarity_matrix(library_fingerprints, library_fingerprints)

    isf_scores = integrated_similarity_flow(average_scores, tanimoto_scores, nr_of_spectra_per_inchikey)
    index_of_highest_score = np.argmax(isf_scores)
    highest_isf_score = isf_scores[index_of_highest_score]
    inchikey_of_highest_isf_score = unique_inchikeys[index_of_highest_score]
    return highest_isf_score, inchikey_of_highest_isf_score
def average_scores_per_inchikeys(predicted_scores, inchikeys):
    """Group predicted scores by inchikey and average each group.

    Returns three parallel lists — (unique_inchikeys, average_scores,
    nr_of_spectra_per_inchikey) — with inchikeys in first-appearance order.
    Raises ValueError when the two input lists differ in length.
    """
    if len(predicted_scores) != len(inchikeys):
        raise ValueError
    # Collect the scores belonging to each inchikey, preserving first-seen order.
    grouped = {}
    for inchikey, score in zip(inchikeys, predicted_scores):
        grouped.setdefault(inchikey, []).append(score)
    # Reduce each group to its mean and its size.
    unique_inchikeys = list(grouped)
    average_scores = [sum(scores) / len(scores) for scores in grouped.values()]
    nr_of_spectra_per_inchikey = [len(scores) for scores in grouped.values()]
    return unique_inchikeys, average_scores, nr_of_spectra_per_inchikey
"""Compute the confidence of the prediction for each candidate. + Integrated similarity flow (ISF) scores are calculated using the similarity of candidates among each other + and their distance to the query spectrum. + + Args: + distances (list): Distances of the candidates to the query spectrum in the chemical space. + similarities (list of lists): Jaccard similarity of all candidates to each other. + + Returns: + dict[int, float]: ISF scores for each candid+ate. + """ + num_hits = len(predicted_scores) + isf_scores = [] + + # Total similarity + total_similarity = sum([predicted_scores[i] * nr_of_spectra_per_inchikey[i] for i in range(len(predicted_scores))]) + + for i in range(num_hits): + isf_score = ( + sum(predicted_scores[j] * similarities[i][j] * nr_of_spectra_per_inchikey[j] for j in range(num_hits)) + / total_similarity + ) + isf_scores.append(isf_score) + + return isf_scores diff --git a/ms2query/data_processing/__init__.py b/ms2query/data_processing/__init__.py index 7dacae7..a3f3836 100644 --- a/ms2query/data_processing/__init__.py +++ b/ms2query/data_processing/__init__.py @@ -1,5 +1,5 @@ from .chemistry_utils import compute_morgan_fingerprints, inchikey14_from_full -from .fingerprint_computation import compute_fingerprints_from_smiles +from .fingerprint_computation import compute_fingerprints_from_smiles, merge_fingerprints from .merging_utils import cluster_block, get_merged_spectra from .spectra_processing import compute_spectra_embeddings, normalize_spectrum_sum @@ -11,5 +11,6 @@ "compute_spectra_embeddings", "get_merged_spectra", "inchikey14_from_full", + "merge_fingerprints", "normalize_spectrum_sum", ] diff --git a/ms2query/data_processing/fingerprint_computation.py b/ms2query/data_processing/fingerprint_computation.py index b6b2e5f..ea2f3ca 100644 --- a/ms2query/data_processing/fingerprint_computation.py +++ b/ms2query/data_processing/fingerprint_computation.py @@ -1,6 +1,8 @@ +from typing import Optional, Sequence, Tuple import numba import 
numpy as np from numba import typed, types +from numpy.typing import NDArray from rdkit import Chem from tqdm import tqdm @@ -255,6 +257,105 @@ def count_fingerprint_keys(fingerprints): return unique_keys[order], count_arr[order], first_arr[order] +def merge_fingerprints( + fingerprints: Sequence[Tuple[NDArray[np.integer], NDArray[np.floating]]], + weights: Optional[NDArray[np.floating]] = None, +) -> Tuple[NDArray[np.integer], NDArray[np.floating]]: + """ + Merge multiple sparse Morgan (count/TF-IDF) fingerprints into a single + weighted-average fingerprint. + + Parameters + ---------- + fingerprints : + Sequence of (bits, values) pairs. + - bits: 1D integer array of bit indices (non-zero entries) + - values: 1D float array of TF-IDF (or other) weights, + same length as `bits`. + weights : + Optional 1D array-like of length len(fingerprints) with one weight + per fingerprint. Each fingerprint's values are scaled by its weight, + then the merged fingerprint is normalized by the sum of all weights. + + - If None, all fingerprints are weighted equally (weight = 1.0). + + Returns + ------- + merged_bits, merged_values : + - merged_bits: 1D integer array of unique bit indices + - merged_values: 1D float array of weighted-average values per bit + (sum over all weighted fingerprints, divided by sum(weights)). 
+ """ + n_fps = len(fingerprints) + if n_fps == 0: + # Return empty sparse fingerprint + return ( + np.array([], dtype=np.int64), + np.array([], dtype=np.float64), + ) + + if weights is not None: + w = np.asarray(weights, dtype=np.float64).ravel() + if w.shape[0] != n_fps: + raise ValueError( + f"weights must have length {n_fps}, got {w.shape[0]}" + ) + total_weight = float(w.sum()) + if total_weight <= 0.0: + raise ValueError("Sum of weights must be positive.") + else: + # Equal weighting + w = None + total_weight = float(n_fps) + + # Concatenate all indices and (weighted) values + bits_list = [] + vals_list = [] + + for i, (bits, vals) in enumerate(fingerprints): + bits = np.asarray(bits) + vals = np.asarray(vals, dtype=np.float64) + + if bits.shape[0] != vals.shape[0]: + raise ValueError( + f"Fingerprint {i}: bits and values must have same length, " + f"got {bits.shape[0]} and {vals.shape[0]}" + ) + + if w is not None: + vals = vals * w[i] + + bits_list.append(bits) + vals_list.append(vals) + + if not bits_list: + return ( + np.array([], dtype=np.int64), + np.array([], dtype=np.float64), + ) + + all_bits = np.concatenate(bits_list) + all_vals = np.concatenate(vals_list) + + if all_bits.size == 0: + return ( + np.array([], dtype=np.int64), + np.array([], dtype=np.float64), + ) + + # Group by bit index and sum weighted values + unique_bits, inverse = np.unique(all_bits, return_inverse=True) + summed_vals = np.bincount(inverse, weights=all_vals) + + # Weighted average: divide by sum of weights + avg_vals = summed_vals / total_weight + + # Keep dtypes reasonably tight + merged_bits = unique_bits.astype(all_bits.dtype, copy=False) + merged_vals = avg_vals.astype(np.float32, copy=False) + + return merged_bits, merged_vals + ### ------------------------ ### Bit Scaling and Weighing ### ------------------------ diff --git a/ms2query/database/ann_vector_index.py b/ms2query/database/ann_vector_index.py index 240c628..042c296 100644 --- 
a/ms2query/database/ann_vector_index.py +++ b/ms2query/database/ann_vector_index.py @@ -323,28 +323,63 @@ def _create_hnsw_index( def query( self, - vector: np.ndarray, + vectors: np.ndarray, k: int = 10, ef: Optional[int] = None, - ) -> List[Tuple[str, float]]: + num_threads: int = 0, + ) -> List[Tuple[str, float]] | List[List[Tuple[str, float]]]: """ Query for k nearest neighbors. - Returns list of (spec_id, similarity) tuples. + Parameters + ---------- + vectors : np.ndarray + Either a single vector of shape (dim,) or a batch of shape (N, dim). + k : int + Number of neighbors. + ef : Optional[int] + Optional per-query ef parameter for HNSW. + num_threads : int + Number of threads to use inside nmslib (0 = library default). + + Returns + ------- + Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]] + - If a single vector is given, returns a list of (spec_id, similarity). + - If a batch is given, returns a list (per query) of such lists. """ if self._index is None: raise RuntimeError("Index not built or loaded.") - v = np.asarray(vector, dtype=np.float32).reshape(1, -1) - if v.shape[1] != self.dim: - raise ValueError(f"Query must have dim={self.dim}") + X = np.asarray(vectors, dtype=np.float32) + + single = False + if X.ndim == 1: + # Single query vector: (dim,) -> (1, dim) + if X.size != self.dim: + raise ValueError(f"Query must have dim={self.dim}") + X = X.reshape(1, -1) + single = True + elif X.ndim == 2: + if X.shape[1] != self.dim: + raise ValueError(f"Expected shape (N, {self.dim}), got {X.shape}") + else: + raise ValueError("vectors must be 1D or 2D array.") if ef is not None: self._index.setQueryTimeParams({"ef": ef}) - idxs, dists = self._index.knnQueryBatch(v, k=k)[0] - sims = 1.0 - np.asarray(dists, dtype=np.float32) # cosine distance -> similarity - return [(str(self._ids[i]), float(sims[j])) for j, i in enumerate(idxs)] + batch_results = self._index.knnQueryBatch(X, k=k, num_threads=num_threads) + + all_out: List[List[Tuple[str, float]]] = 
[] + for idxs, dists in batch_results: + idxs = np.asarray(idxs, dtype=np.int64) + dists = np.asarray(dists, dtype=np.float32) + sims = 1.0 - dists # cosine distance -> similarity + out = [(str(self._ids[i]), float(s)) for i, s in zip(idxs, sims)] + all_out.append(out) + + return all_out[0] if single else all_out def save_index(self, path_prefix: str) -> None: if self._index is None: @@ -450,51 +485,141 @@ def build_index( def query( self, - query_fp: Tuple[np.ndarray, np.ndarray] | sp.csr_matrix, + query_fp: ( + Tuple[np.ndarray, np.ndarray] + | sp.csr_matrix + | Sequence[Tuple[np.ndarray, np.ndarray]] + ), k: int = 10, *, ef: Optional[int] = None, re_rank: bool = True, candidate_multiplier: int = 5, - ) -> List[Tuple[int, float]]: + num_threads: int = 0, + ) -> List[Tuple[int, float]] | List[List[Tuple[int, float]]]: """ Query for k nearest neighbors. Parameters ---------- - query_fp : (indices, values) tuple or single-row CSR - k : Number of results - re_rank : Use exact Tanimoto re-ranking - candidate_multiplier : Fetch k * multiplier candidates for re-ranking - - Returns list of (comp_id, similarity) tuples. + query_fp : + - Single query: + * (indices, values) tuple + * single-row CSR of shape (1, dim) + - Batched queries: + * CSR of shape (N, dim) + * Sequence of (indices, values) tuples + k : int + Number of results per query. + re_rank : bool + Use exact Tanimoto re-ranking. + candidate_multiplier : int + Fetch k * multiplier candidates for re-ranking. + num_threads : int + Number of threads to use inside nmslib (0 = library default). + + Returns + ------- + Union[List[Tuple[int, float]], List[List[Tuple[int, float]]]] + - For a single query, returns a list of (comp_id, similarity). + - For multiple queries, returns a list (per query) of such lists. 
""" if self._index is None: raise RuntimeError("Index not built or loaded.") - q = self._normalize_query(query_fp) - if q.nnz == 0: - return [] + # ------------------------- + # Normalize input to CSR + # ------------------------- + single = False + + if isinstance(query_fp, sp.csr_matrix): + Q = query_fp.astype(np.float32, copy=False) + if Q.shape[1] != self.dim: + raise ValueError(f"CSR query must have shape (N, {self.dim})") + single = Q.shape[0] == 1 + + elif isinstance(query_fp, tuple): + # Single (indices, values) + Q = csr_row_from_tuple(query_fp, dim=self.dim) + single = True + + else: + # Assume sequence of (indices, values) tuples -> batched queries + Q = tuples_to_csr(query_fp, dim=self.dim) + single = Q.shape[0] == 1 + + if (Q.data < 0).any(): + raise ValueError("Query must be non-negative for Tanimoto.") + + # Handle completely empty queries quickly + row_nnz = Q.indptr[1:] - Q.indptr[:-1] + if row_nnz.sum() == 0: + if single: + return [] + return [[] for _ in range(Q.shape[0])] if ef is not None: self._index.setQueryTimeParams({"ef": ef}) fetch = max(k, k * candidate_multiplier) - idxs, dists = self._index.knnQueryBatch(q, k=fetch)[0] - idxs = np.asarray(idxs, dtype=np.int64) - dists = np.asarray(dists, dtype=np.float32) - # Without re-ranking, return cosine similarities + # ------------------------- + # ANN search for all queries + # ------------------------- + batch_results = self._index.knnQueryBatch(Q, k=fetch, num_threads=num_threads) + + # ------------------------- + # No re-ranking: cosine sims only + # ------------------------- if not re_rank or self._csr is None or self._l1 is None: - sims = 1.0 - dists - return [(int(self._comp_ids[i]), float(s)) for i, s in zip(idxs[:k], sims[:k])] + all_out: List[List[Tuple[int, float]]] = [] - # Re-rank with exact Tanimoto - Y = self._csr[idxs] - tan = tanimoto_l1_query_vs_block(q, Y, sum1=float(q.sum()), sumsY=self._l1[idxs]) - order = np.argsort(-tan)[:k] + for qi, (idxs, dists) in 
enumerate(batch_results): + if row_nnz[qi] == 0: + all_out.append([]) + continue - return [(int(self._comp_ids[idxs[i]]), float(tan[i])) for i in order] + idxs = np.asarray(idxs, dtype=np.int64) + dists = np.asarray(dists, dtype=np.float32) + + sims = 1.0 - dists + out = [ + (self._comp_ids[i], float(s)) + for i, s in zip(idxs[:k], sims[:k]) + ] + all_out.append(out) + + return all_out[0] if single else all_out + + # ------------------------- + # Exact Tanimoto re-ranking + # ------------------------- + all_out: List[List[Tuple[int, float]]] = [] + + for qi, (idxs, dists) in enumerate(batch_results): + if row_nnz[qi] == 0: + all_out.append([]) + continue + + idxs = np.asarray(idxs, dtype=np.int64) + + q_row = Q[qi] + Y = self._csr[idxs] + tan = tanimoto_l1_query_vs_block( + q_row, + Y, + sum1=float(q_row.sum()), + sumsY=self._l1[idxs], + ) + + order = np.argsort(-tan)[:k] + out = [ + (self._comp_ids[idxs[i]], float(tan[i])) + for i in order + ] + all_out.append(out) + + return all_out[0] if single else all_out def _normalize_query(self, query_fp) -> sp.csr_matrix: """Convert query to single-row CSR and validate.""" @@ -602,10 +727,11 @@ def save_index(self, path_prefix: str) -> None: if self._index is None: raise RuntimeError("Index not built.") - self._index.saveIndex(f"{path_prefix}.nmslib") + # Also save data so that load_index(..., load_data=True) works + self._index.saveIndex(f"{path_prefix}.nmslib", save_data=True) np.save(f"{path_prefix}.ids.npy", self._comp_ids) - meta = {**self._meta, "dim": self.dim, "space": self.space} + meta = {**self._meta, "dim": int(self.dim), "space": str(self.space)} with open(f"{path_prefix}.meta.json", "w") as f: json.dump(meta, f) diff --git a/ms2query/database/compound_database.py b/ms2query/database/compound_database.py index 80b931f..140aaea 100644 --- a/ms2query/database/compound_database.py +++ b/ms2query/database/compound_database.py @@ -467,6 +467,15 @@ def compute_fingerprints( Does NOT write to the database. 
Provide exactly one of (smiles, inchis). + Parameters + ---------- + smiles : Optional[List[str]], optional + List of SMILES strings, by default None + inchis : Optional[List[str]], optional + List of InChI strings, by default None + progress_bar : bool, optional + Whether to show a progress bar, by default False + Returns the same shapes/types as compute_morgan_fingerprints: - dense: np.ndarray of shape (N, nbits) - sparse/binary: List[np.ndarray[uint32]] @@ -570,6 +579,39 @@ def get_compounds(self, comp_ids: List[str]) -> pd.DataFrame: .drop(columns="__order") .reset_index(drop=True)) + def get_all_compound_ids(self) -> List[str]: + """Return all compound IDs in ascending comp_id order. + """ + rows = self._conn.execute(f""" + SELECT comp_id FROM {self.table} + ORDER BY comp_id ASC + """).fetchall() + return [row["comp_id"] for row in rows] + + def get_all_fingerprints_and_comp_ids(self) -> Dict[str, AnyFP]: + """Return all compound IDs and their fingerprints in ascending comp_id order. 
+ """ + rows = self._conn.execute(f""" + SELECT comp_id, fingerprint_bits, fingerprint_counts, fingerprint_dense + FROM {self.table} + ORDER BY comp_id ASC + """).fetchall() + comp_ids = [] + fps = [] + for row in rows: + comp_ids.append(row["comp_id"]) + dense_blob = row["fingerprint_dense"] or b"" + bits_blob = row["fingerprint_bits"] or b"" + counts_blob = row["fingerprint_counts"] or b"" + if dense_blob: + fps.append(decode_dense_fp(dense_blob, dtype=self.fingerprint_dtype_dense)) + elif bits_blob or counts_blob: + bits, counts = decode_sparse_fp(bits_blob, counts_blob) + fps.append(bits if counts.size == 0 else (bits, counts)) + else: + fps.append(None) + return {"comp_ids": comp_ids, "fingerprints": fps} + def sql_query(self, query: str) -> pd.DataFrame: return pd.read_sql_query(query, self._conn) diff --git a/ms2query/database/spec_to_compound_mapper.py b/ms2query/database/spec_to_compound_mapper.py index 7442437..e7c5e4c 100644 --- a/ms2query/database/spec_to_compound_mapper.py +++ b/ms2query/database/spec_to_compound_mapper.py @@ -1,7 +1,7 @@ import sqlite3 from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple import pandas as pd from ms2query.data_processing import inchikey14_from_full from ms2query.database import CompoundDatabase @@ -11,10 +11,11 @@ # Mapping: spectrum <-> compound (spec_to_comp) # ================================================== + @dataclass class SpecToCompoundMap: - """This class manages the mapping between spectrum IDs and compound IDs. - + """Manage the mapping between spectrum IDs and compound IDs (inchikey14). 
+ Attributes ---------- sqlite_path : str @@ -27,6 +28,7 @@ class SpecToCompoundMap: sqlite_path: str table: str = "spec_to_comp" compound_table: str = "compounds" + _conn: sqlite3.Connection = field(init=False, repr=False) def __post_init__(self): @@ -35,74 +37,149 @@ def __post_init__(self): self._conn.row_factory = sqlite3.Row self._ensure_schema() - def close(self): + # ---------------- lifecycle ---------------- + + def close(self) -> None: try: self._conn.close() except Exception: pass - def _ensure_schema(self): + # ---------------- schema ---------------- + + def _ensure_schema(self) -> None: cur = self._conn.cursor() - # No strict FK enforcement (SpectralDatabase may have been created without FK pragma), - # here: index both sides for fast lookup. - cur.executescript(f""" + # No strict FK enforcement; index both sides for fast lookup. + cur.executescript( + f""" CREATE TABLE IF NOT EXISTS {self.table}( spec_id TEXT NOT NULL, - comp_id TEXT NOT NULL, + comp_id TEXT NOT NULL, PRIMARY KEY (spec_id), CHECK (length(comp_id) = 14) ); - CREATE INDEX IF NOT EXISTS idx_spec_to_comp_comp ON {self.table}(comp_id); - """) + CREATE INDEX IF NOT EXISTS idx_{self.table}_comp ON {self.table}(comp_id); + """ + ) self._conn.commit() - # ---------- API ---------- + # ---------------- API ---------------- - def link(self, spec_id: str, comp_id: str): - """Insert or replace a single mapping.""" + def link(self, spec_id: str, comp_id: str) -> None: + """Insert or replace a single mapping (spec_id -> comp_id).""" if not comp_id or len(comp_id) != 14: raise ValueError("comp_id must be inchikey14 (14 characters).") - self._conn.execute(f""" + self._conn.execute( + f""" INSERT INTO {self.table} (spec_id, comp_id) VALUES (?, ?) 
- ON CONFLICT(spec_id) DO UPDATE SET comp_id=excluded.comp_id - """, (spec_id, comp_id)) + ON CONFLICT(spec_id) DO UPDATE SET comp_id = excluded.comp_id + """, + (spec_id, comp_id), + ) self._conn.commit() - def link_many(self, pairs: Iterable[Tuple[int, str]]): - """Bulk link (spec_id, comp_id).""" + def link_many(self, pairs: Iterable[Tuple[str, str]]) -> None: + """Bulk link (spec_id, comp_id) pairs.""" cur = self._conn.cursor() cur.execute("BEGIN") try: - cur.executemany(f""" + cur.executemany( + f""" INSERT INTO {self.table} (spec_id, comp_id) VALUES (?, ?) - ON CONFLICT(spec_id) DO UPDATE SET comp_id=excluded.comp_id - """, list(pairs)) + ON CONFLICT(spec_id) DO UPDATE SET comp_id = excluded.comp_id + """, + list(pairs), + ) cur.execute("COMMIT") except Exception: cur.execute("ROLLBACK") raise - def get_comp_id_for_specs(self, spec_ids: List[str]) -> pd.DataFrame: - """Return a DataFrame with columns [spec_id, comp_id] for the provided spec_ids.""" + # ---- getters: spec_id -> comp_id ---- + + def get_comp_id_for_spec(self, spec_id: str) -> Optional[str]: + """ + Return the comp_id for a single spec_id, or None if not mapped. + """ + df = self.get_comp_id_for_specs([spec_id]) + if df.empty: + return None + val = df.loc[0, "comp_id"] + return None if pd.isna(val) else str(val) + + def get_comp_id_for_specs(self, spec_ids: Sequence[str]) -> pd.DataFrame: + """ + Return a DataFrame with columns ['spec_id', 'comp_id'] for the provided spec_ids. + + Behaviour + --------- + - Only returns rows for spec_ids that actually have a mapping. + - spec_id values come from the database (TEXT → strings). + - The result may have fewer rows than requested. + """ + cols = ["spec_id", "comp_id"] if not spec_ids: - return pd.DataFrame(columns=["spec_id", "comp_id"]) + return pd.DataFrame(columns=cols) + placeholders = ",".join("?" 
* len(spec_ids)) rows = self._conn.execute( - f"SELECT spec_id, comp_id FROM {self.table} WHERE spec_id IN ({placeholders})", - spec_ids + f""" + SELECT spec_id, comp_id + FROM {self.table} + WHERE spec_id IN ({placeholders}) + """, + list(spec_ids), ).fetchall() - return pd.DataFrame(rows, columns=["spec_id", "comp_id"]) + + # rows already have the correct types from SQLite (TEXT for both columns) + return pd.DataFrame(rows, columns=cols) + + # ---- getters: comp_id -> spec_id(s) ---- def get_specs_for_comp(self, comp_id: str) -> List[str]: - """Return list of spec_ids for a given comp_id.""" - rows = self._conn.execute(f"SELECT spec_id FROM {self.table} WHERE comp_id = ?", (comp_id,)).fetchall() - return [r[0] for r in rows] + """ + Return list of spec_ids (as strings) for a single comp_id. + """ + rows = self._conn.execute( + f"SELECT spec_id FROM {self.table} WHERE comp_id = ?", + (comp_id,), + ).fetchall() + return [str(r[0]) for r in rows] + + def get_specs_for_comps(self, comp_ids: Sequence[str]) -> pd.DataFrame: + """ + Return a DataFrame with columns ['comp_id', 'spec_id'] for the given comp_ids. + + Behaviour + --------- + - For each comp_id, all mapped spec_ids are returned (1:N). + - Order of returned rows is determined by the underlying query. + - If a comp_id has no spectra, it will not appear in the result. + """ + cols = ["comp_id", "spec_id"] + if not comp_ids: + return pd.DataFrame(columns=cols) + + placeholders = ",".join("?" 
* len(comp_ids)) + rows = self._conn.execute( + f""" + SELECT comp_id, spec_id + FROM {self.table} + WHERE comp_id IN ({placeholders}) + """, + list(comp_ids), + ).fetchall() + return pd.DataFrame(rows, columns=cols) + + # ---- misc ---- def get_all_mappings(self) -> pd.DataFrame: """Return all spec_id <-> comp_id mappings as a DataFrame.""" - rows = self._conn.execute(f"SELECT spec_id, comp_id FROM {self.table}").fetchall() + rows = self._conn.execute( + f"SELECT spec_id, comp_id FROM {self.table}" + ).fetchall() return pd.DataFrame(rows, columns=["spec_id", "comp_id"]) @@ -110,6 +187,7 @@ def get_all_mappings(self) -> pd.DataFrame: # Integrations with SpectralDatabase # ================================================== + def map_from_spectraldb_metadata( spectral_db_sqlite_path: str, mapping_sqlite_path: Optional[str] = None, @@ -118,15 +196,17 @@ def map_from_spectraldb_metadata( compound_table: str = "compounds", mapping_table: str = "spec_to_comp", *, - create_missing_compounds: bool = True + create_missing_compounds: bool = True, ) -> Tuple[int, int]: """ Read spectra metadata (expects 'inchikey' in metadata), create comp_id (inchikey14), populate spec_to_comp, and optionally upsert minimal compounds. - Returns: (n_mapped, n_new_compounds) + Returns + ------- + (n_mapped, n_new_compounds) """ - # We do not import the class to avoid circular imports; use sqlite directly. + # We do not import the SpectralDatabase class to avoid circular imports; use sqlite directly. 
s_conn = sqlite3.connect(spectral_db_sqlite_path) s_conn.row_factory = sqlite3.Row @@ -138,13 +218,20 @@ def map_from_spectraldb_metadata( # Discover which columns exist in the spectra table cols = {r[1] for r in s_conn.execute(f"PRAGMA table_info({spectra_table})").fetchall()} - want = ["spec_id", "inchikey", "smiles", "inchi", "classyfire_class", "classyfire_superclass"] + want = [ + "spec_id", + "inchikey", + "smiles", + "inchi", + "classyfire_class", + "classyfire_superclass", + ] have = [c for c in want if c in cols] select_cols = ", ".join(have) rows = s_conn.execute(f"SELECT {select_cols} FROM {spectra_table}").fetchall() - to_link: List[Tuple[int, str]] = [] + to_link: List[Tuple[str, str]] = [] new_comp_rows: List[Dict[str, Any]] = [] for r in rows: @@ -153,20 +240,26 @@ def map_from_spectraldb_metadata( ik_full = r.get("inchikey") if not ik_full: continue + comp_id = inchikey14_from_full(ik_full) if not comp_id: continue - to_link.append((spec_id, comp_id)) + + # spec_id may be int in the spectra table; mapping expects TEXT, but + # SQLite will happily store the string representation. 
+ to_link.append((str(spec_id), comp_id)) if create_missing_compounds: - new_comp_rows.append({ - "smiles": r.get("smiles"), - "inchi": r.get("inchi"), - "inchikey": ik_full, - "classyfire_class": r.get("classyfire_class"), - "classyfire_superclass": r.get("classyfire_superclass"), - "fingerprint": None, # backfill later - }) + new_comp_rows.append( + { + "smiles": r.get("smiles"), + "inchi": r.get("inchi"), + "inchikey": ik_full, + "classyfire_class": r.get("classyfire_class"), + "classyfire_superclass": r.get("classyfire_superclass"), + "fingerprint": None, # backfill later + } + ) # Bulk linking if to_link: @@ -175,7 +268,7 @@ def map_from_spectraldb_metadata( # Upsert compounds n_new_compounds = 0 if create_missing_compounds and new_comp_rows: - # Deduplicate by comp_id to avoid redundant upserts + # Deduplicate by comp_id (inchikey14) to avoid redundant upserts seen: set[str] = set() dedup_rows: List[Dict[str, Any]] = [] for r in new_comp_rows: @@ -183,9 +276,10 @@ def map_from_spectraldb_metadata( if cid and cid not in seen: seen.add(cid) dedup_rows.append(r) + before = compdb.sql_query(f"SELECT COUNT(*) AS n FROM {compound_table}")["n"].iloc[0] compdb.upsert_many(dedup_rows) - after = compdb.sql_query(f"SELECT COUNT(*) AS n FROM {compound_table}")["n"].iloc[0] + after = compdb.sql_query(f"SELECT COUNT(*) AS n FROM {compound_table}")["n"].iloc[0] n_new_compounds = int(after - before) n_mapped = len(to_link) @@ -202,41 +296,61 @@ def get_unique_compounds_from_spectraldb( spectral_db_sqlite_path: str, spectra_table: str = "spectra", external_meta: Optional[pd.DataFrame] = None, - external_key_col: str = "inchikey14" + external_key_col: str = "inchikey14", ) -> pd.DataFrame: """ - Return a DataFrame of unique compounds present in the spectral DB, inferred via inchikey → inchikey14. - Columns: inchikey14, inchikey (full), n_spectra. If `external_meta` is provided, - it will be left-joined on `external_key_col` (default 'inchikey14'). 
+ Return a DataFrame of unique compounds present in the spectral DB, inferred via + inchikey → inchikey14. + + Columns (first three): + - inchikey14 + - n_spectra + - inchikey (full) + + If `external_meta` is provided, it is left-joined on `external_key_col` + (default: 'inchikey14'). """ conn = sqlite3.connect(spectral_db_sqlite_path) conn.row_factory = sqlite3.Row - # pull spec_id + inchikey from spectra + # Pull spec_id + inchikey from spectra df = pd.read_sql_query(f"SELECT spec_id, inchikey FROM {spectra_table}", conn) conn.close() if df.empty: - base = pd.DataFrame(columns=["inchikey14", "inchikey", "n_spectra"]) + base = pd.DataFrame(columns=["inchikey14", "n_spectra", "inchikey"]) if external_meta is not None: - return base.merge(external_meta, how="left", left_on="inchikey14", right_on=external_key_col) + return base.merge( + external_meta, + how="left", + left_on="inchikey14", + right_on=external_key_col, + ) return base # Compute inchikey14 ik14 = df["inchikey"].fillna("").map(inchikey14_from_full) df["inchikey14"] = ik14 - # Aggregate + # Aggregate: order of columns is important for tests: + # ["inchikey14", "n_spectra", "inchikey"] agg = ( df.dropna(subset=["inchikey14"]) - .groupby(["inchikey14"], as_index=False) - .agg(n_spectra=("spec_id", "count"), - inchikey=("inchikey", "first")) # first full key seen + .groupby(["inchikey14"], as_index=False) + .agg( + n_spectra=("spec_id", "count"), + inchikey=("inchikey", "first"), # first full key seen + ) ) # Optional join with external meta if external_meta is not None and not external_meta.empty: - agg = agg.merge(external_meta, how="left", left_on="inchikey14", right_on=external_key_col) + agg = agg.merge( + external_meta, + how="left", + left_on="inchikey14", + right_on=external_key_col, + ) # Order by prevalence agg = agg.sort_values("n_spectra", ascending=False).reset_index(drop=True) diff --git a/ms2query/database/spectral_database.py b/ms2query/database/spectral_database.py index 5676bd4..806a1b8 
100644 --- a/ms2query/database/spectral_database.py +++ b/ms2query/database/spectral_database.py @@ -57,10 +57,19 @@ def _normalize_metadata(md: Dict[str, Any], fields: Iterable[str]) -> Dict[str, class SpectralDatabase: sqlite_path: str table: str = "spectra" - metadata_fields: List[str] = field(default_factory=lambda: [ - "precursor_mz", "ionmode", "smiles", "inchikey", "inchi", "name", - "instrument_type", "adduct", "collision_energy" - ]) + metadata_fields: List[str] = field( + default_factory=lambda: [ + "precursor_mz", + "ionmode", + "smiles", + "inchikey", + "inchi", + "name", + "instrument_type", + "adduct", + "collision_energy", + ] + ) spectrum_sum_normalization_for_embedding: bool = True _conn: sqlite3.Connection = field(init=False, repr=False) _ms2ds_model_path: Optional[str] = field(default=None, repr=False) @@ -81,7 +90,8 @@ def add_spectra(self, spectra: List[Spectrum]) -> List[str]: cur = self._conn.cursor() # Bulk-load speed PRAGMAs (safe for single-user/batch ingest) - cur.executescript(""" + cur.executescript( + """ PRAGMA journal_mode=WAL; PRAGMA synchronous=OFF; PRAGMA temp_store=MEMORY; @@ -130,14 +140,14 @@ def ids(self) -> List[str]: rows = cur.execute(f"SELECT spec_id FROM {self.table}").fetchall() return [str(row["spec_id"]) for row in rows] - def get_spectra_by_ids(self, specIDs: List[str]) -> List[Spectrum]: - """Retrieve full Spectrum objects for given specIDs (order preserved, missing IDs skipped).""" + def get_spectra_by_ids(self, spec_ids: List[str]) -> List[Spectrum]: + """Retrieve full Spectrum objects for given spec_ids (order preserved, missing IDs skipped).""" rows = self._fetch_rows_by_ids( - specIDs, cols="spec_id, mz_blob, intensity_blob, n_peaks, " + ", ".join(self.metadata_fields)) + spec_ids, cols="spec_id, mz_blob, intensity_blob, n_peaks, " + ", ".join(self.metadata_fields)) by_id = {row["spec_id"]: row for row in rows} result: List[Spectrum] = [] - for sid in specIDs: + for sid in spec_ids: row = by_id.get(sid) if row 
is None: continue @@ -149,13 +159,19 @@ def get_spectra_by_ids(self, specIDs: List[str]) -> List[Spectrum]: result.append(Spectrum(mz=mz, intensities=inten, metadata=md)) return result - def get_fragments_by_ids(self, specIDs: List[str]) -> List[Tuple[np.ndarray, np.ndarray]]: - """Retrieve (mz, intensity) arrays for given specIDs (order preserved, missing IDs skipped).""" - rows = self._fetch_rows_by_ids(specIDs, cols="spec_id, mz_blob, intensity_blob, n_peaks") + def get_fragments_by_ids(self, spec_ids: List[str]) -> List[Tuple[np.ndarray, np.ndarray]]: + """ + Retrieve (mz, intensity) arrays for given spec_ids. + + Order is preserved with respect to `spec_ids`. + Missing IDs are skipped. + """ + cols = "spec_id, mz_blob, intensity_blob, n_peaks" + rows = self._fetch_rows_by_ids(spec_ids, cols=cols) by_id = {row["spec_id"]: row for row in rows} out: List[Tuple[np.ndarray, np.ndarray]] = [] - for sid in specIDs: + for sid in spec_ids: row = by_id.get(sid) if row is None: continue @@ -165,16 +181,34 @@ def get_fragments_by_ids(self, specIDs: List[str]) -> List[Tuple[np.ndarray, np. out.append((mz, inten)) return out - def get_metadata_by_ids(self, specIDs: List[str]) -> pd.DataFrame: - """Retrieve metadata for given specIDs (order preserved).""" + def get_metadata_by_ids(self, spec_ids: List[str]) -> pd.DataFrame: + """ + Retrieve metadata for given spec_ids. + + Returns a DataFrame with **one row per requested spec_id** in the same + order as `spec_ids`. If a spec_id is not present in the database, a row + with that spec_id and metadata columns set to None/NaN is returned. 
+ """ cols = ["spec_id"] + self.metadata_fields - rows = self._fetch_rows_by_ids(specIDs, cols=", ".join(cols)) - df = pd.DataFrame(rows, columns=cols) - if not df.empty: - order = {sid: i for i, sid in enumerate(specIDs)} - df["__order"] = df["spec_id"].map(order) - df = df.sort_values("__order").drop(columns="__order").reset_index(drop=True) - return df + if not spec_ids: + return pd.DataFrame(columns=cols) + + rows = self._fetch_rows_by_ids(spec_ids, cols=", ".join(cols)) + by_id = {row["spec_id"]: row for row in rows} + + records: List[Dict[str, Any]] = [] + for sid in spec_ids: + row = by_id.get(sid) + if row is None: + rec = {"spec_id": sid} + rec.update({k: None for k in self.metadata_fields}) + else: + rec = {"spec_id": sid} + for k in self.metadata_fields: + rec[k] = row[k] + records.append(rec) + + return pd.DataFrame.from_records(records, columns=cols) def sql_query(self, query: str) -> pd.DataFrame: """Run a raw SQL SELECT and return a DataFrame.""" @@ -225,7 +259,6 @@ def compute_embeddings_to_sqlite( - Uses `matchms.Spectrum` objects reconstructed from the stored peaks & metadata. - Stores raw float32 vectors (no extra header) with their dimension `d`. """ - # TODO: add batch_size to speed up? 
spectra_table = spectra_table or self.table self._ensure_schema() # spectra schema self.ensure_embeddings_schema(embeddings_table) @@ -254,7 +287,9 @@ def compute_embeddings_to_sqlite( model = self.load_ms2deepscore_model(model_path) inserted = 0 - buf: List[Tuple[str, bytes, bytes, int, float, str, Optional[int]]] = [] + buf: List[ + Tuple[str, bytes, bytes, int, float, str, Optional[int]] + ] = [] done_since_commit = 0 def flush(batch) -> int: @@ -265,24 +300,35 @@ def flush(batch) -> int: for sid, mz_blob, it_blob, n_peaks, prec_mz, ionmode, charge in batch: mz = _from_float32_bytes(mz_blob, int(n_peaks)) it = _from_float32_bytes(it_blob, int(n_peaks)) - spectrum = Spectrum(mz=mz, intensities=it, metadata={ - "precursor_mz": float(prec_mz) if prec_mz is not None else None, - "ionmode": ionmode, - "charge": charge, - "spec_id": sid, - }) + spectrum = Spectrum( + mz=mz, + intensities=it, + metadata={ + "precursor_mz": float(prec_mz) if prec_mz is not None else None, + "ionmode": ionmode, + "charge": charge, + "spec_id": sid, + }, + ) specs.append(spectrum) sids.append(sid) embeddings = compute_spectra_embeddings( - model, specs, - normalize_spectrum=self.spectrum_sum_normalization_for_embedding - ) + model, + specs, + normalize_spectrum=self.spectrum_sum_normalization_for_embedding, + ) dim = int(embeddings.shape[1]) - q = f"INSERT OR REPLACE INTO {embeddings_table} (spec_id, d, vec) VALUES (?, ?, ?);" + q = ( + f"INSERT OR REPLACE INTO {embeddings_table} " + f"(spec_id, d, vec) VALUES (?, ?, ?);" + ) with self._conn: for sid, embedding in zip(sids, embeddings): - self._conn.execute(q, (sid, dim, sqlite3.Binary(_as_float32_bytes(embedding)))) + self._conn.execute( + q, + (sid, dim, sqlite3.Binary(_as_float32_bytes(embedding))), + ) return len(batch) while True: @@ -303,7 +349,7 @@ def flush(batch) -> int: def get_embeddings( self, - ids: Optional[List[str]] = None, + spec_ids: Optional[List[str]] = None, *, embeddings_table: str = "embeddings", normalized: bool = 
True, @@ -314,13 +360,14 @@ def get_embeddings( If normalized=True, L2-normalize (recommended for cosine). """ cur = self._conn.cursor() - if ids is None: + if spec_ids is None: cur.execute(f"SELECT spec_id, d, vec FROM {embeddings_table} ORDER BY spec_id ASC;") else: - ph = ",".join("?" for _ in ids) + placeholders = ",".join("?" for _ in spec_ids) cur.execute( - f"SELECT spec_id, d, vec FROM {embeddings_table} WHERE spec_id IN ({ph}) ORDER BY spec_id ASC;", - ids) + f"""SELECT spec_id, d, vec FROM {embeddings_table} + WHERE spec_id IN ({placeholders}) ORDER BY spec_id ASC;""", + spec_ids) sids: List[str] = [] vecs: List[np.ndarray] = [] @@ -343,29 +390,19 @@ def get_embeddings( X = X / n return np.asarray(sids, dtype=object), X - def get_embedding_for_id( - self, - spec_id: str, - *, - embeddings_table: str = "embeddings", - normalized: bool = True, - ) -> Optional[np.ndarray]: - ids, X = self.get_embeddings([spec_id], embeddings_table=embeddings_table, normalized=normalized) - return X[0] if X.shape[0] else None - # expose raw connection for ANN builders @property def connection(self) -> sqlite3.Connection: return self._conn # ---------- internal ---------- - def _fetch_rows_by_ids(self, specIDs: List[str], cols: str) -> List[sqlite3.Row]: - if not specIDs: + def _fetch_rows_by_ids(self, spec_ids: List[str], cols: str) -> List[sqlite3.Row]: + if not spec_ids: return [] - placeholders = ",".join("?" for _ in specIDs) + placeholders = ",".join("?" 
for _ in spec_ids) sql = f"SELECT {cols} FROM {self.table} WHERE spec_id IN ({placeholders})" cur = self._conn.cursor() - return cur.execute(sql, specIDs).fetchall() + return cur.execute(sql, spec_ids).fetchall() def _ensure_schema(self): cur = self._conn.cursor() diff --git a/ms2query/library_io.py b/ms2query/library_io.py index 86815c0..d53e7f8 100644 --- a/ms2query/library_io.py +++ b/ms2query/library_io.py @@ -8,7 +8,7 @@ from tqdm import tqdm from ms2query import MS2QueryDatabase, MS2QueryLibrary from ms2query.data_processing.merging_utils import cluster_block, get_merged_spectra -from ms2query.database import EmbeddingIndex +from ms2query.database import EmbeddingIndex, FingerprintSparseIndex from ms2query.database.spectra_merging import _split_by_mode_charge @@ -18,6 +18,7 @@ _SQLITE_NAME = "ms2query_library.sqlite" _EMB_TABLE = "embeddings" _EMB_INDEX_BASENAME = "embedding_index" # will create embedding_index.{nmslib,ids.npy,meta.json} +_FP_INDEX_BASENAME = "fingerprint_index" # will create fingerprint_index.{nmslib,ids.npy,meta.json} def _handle_default_settings(settings: dict) -> dict: @@ -81,6 +82,7 @@ def create_new_library( model_path: str, additional_compound_file: Optional[str] = None, build_embedding_index: bool = True, + build_fingerprint_index: bool = True, embedding_index_params: Optional[dict] = None, compute_embeddings_batch_rows: int = 4096, **settings, @@ -101,6 +103,8 @@ def create_new_library( CSV/TSV file with additional compounds (inchikey/smiles/etc.). No fingerprints assumed here. build_embedding_index : bool Whether to build the nmslib cosine HNSW index over embeddings. + build_fingerprint_index : bool + Whether to build the FingerprintSparseIndex over compound fingerprints. 
embedding_index_params : dict Params for HNSW: {'M': int, 'ef_construction': int, 'post_init_ef': int, 'batch_rows': int} compute_embeddings_batch_rows : int @@ -139,6 +143,8 @@ def create_new_library( _print_progress(f"Inserted {creation_stats['n_inserted_spectra']} spectra.") _print_progress(f"Mapped {creation_stats['n_mapped']} spectra to compounds; " f"created {creation_stats['n_new_compounds']} new compounds.") + stats = ms2query_db.ref_cdb.compute_fingerprints_missing() + _print_progress(f"Computed fingerprints for {stats['updated']} compounds.") if additional_compound_file is not None: if not additional_compound_file.lower().endswith((".csv", ".tsv", ".txt")): @@ -199,6 +205,23 @@ def create_new_library( _print_progress(f"Saved EmbeddingIndex files with prefix: {emb_prefix}") lib.set_embedding_index(emb_index) + if build_fingerprint_index: + # TODO: this is not efficient yet; improve later + _print_progress("Building FingerprintSparseIndex ...") + results = lib.db.ref_cdb.get_all_fingerprints_and_comp_ids() + max_bits = [x[0][-1] for x in results["fingerprints"]] + fp_index = FingerprintSparseIndex(dim=int(max(max_bits) + 1)) + + fp_index.build_index( + results["fingerprints"], + results["comp_ids"], + ) + fp_prefix = str(out_dir / _FP_INDEX_BASENAME) + fp_index.save_index(fp_prefix) + _print_progress(f"Saved FingerprintSparseIndex files with prefix: {fp_prefix}") + lib.set_fingerprint_index(fp_index) + + # ----------------------------- # Manifest # ----------------------------- @@ -208,6 +231,7 @@ def create_new_library( "embedding_table": _EMB_TABLE, "model_path": model_path, # stored for convenience; not copied "embedding_index_prefix": _EMB_INDEX_BASENAME if build_embedding_index else None, + "fingerprint_index_prefix": _FP_INDEX_BASENAME if build_fingerprint_index else None, "settings": settings, } with open(out_dir / _MANIFEST_NAME, "w", encoding="utf-8") as f: @@ -262,6 +286,13 @@ def load_created_library(folder: str) -> MS2QueryLibrary: emb_index = 
EmbeddingIndex() emb_index.load_index(str(out_dir / emb_prefix)) lib.set_embedding_index(emb_index) + + # Load FingerprintSparseIndex if present + fp_prefix = manifest.get("fingerprint_index_prefix") + if fp_prefix: + fp_index = FingerprintSparseIndex() + fp_index.load_index(str(out_dir / fp_prefix)) + lib.set_fingerprint_index(fp_index) # (Optional) Load fingerprint index here if/when you add it later. diff --git a/ms2query/ms2query_database.py b/ms2query/ms2query_database.py index 8cbb035..7cb3c60 100644 --- a/ms2query/ms2query_database.py +++ b/ms2query/ms2query_database.py @@ -1,6 +1,7 @@ import sqlite3 from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Sequence +import numpy as np import pandas as pd from ms2query.data_processing import inchikey14_from_full from ms2query.database import ( @@ -13,6 +14,7 @@ # ================================ public wrapper ============================== + @dataclass class MS2QueryDatabase: """Wrapper class as main hub/glue between the different MS2Query database elements. @@ -23,17 +25,26 @@ class MS2QueryDatabase: * Provide one-stop creation from (already processed!) `matchms.Spectrum` objects. * Offer ergonomic retrievals by `spec_id`, `comp_id` (inchikey14). * Keep *types and table access paths* in one place. 
- """ sqlite_path: str ref_spectra_table: str = "spectra" ref_compound_table: str = "compounds" non_annotated_compound_table: str = "compounds_all" - metadata_fields: List[str] = field(default_factory=lambda: [ - "precursor_mz", "ionmode", "smiles", "inchikey", "inchi", "name", - "charge", "instrument_type", "adduct", "collision_energy" - ]) + metadata_fields: List[str] = field( + default_factory=lambda: [ + "precursor_mz", + "ionmode", + "smiles", + "inchikey", + "inchi", + "name", + "charge", + "instrument_type", + "adduct", + "collision_energy", + ] + ) # component singletons ref_sdb: SpectralDatabase = field(init=False) @@ -43,12 +54,18 @@ class MS2QueryDatabase: def __post_init__(self): # Initialize components (each manages its own connection) - self.ref_sdb = SpectralDatabase(self.sqlite_path, table=self.ref_spectra_table, - metadata_fields=self.metadata_fields) + self.ref_sdb = SpectralDatabase( + self.sqlite_path, + table=self.ref_spectra_table, + metadata_fields=self.metadata_fields, + ) self.ref_cdb = CompoundDatabase(self.sqlite_path, table=self.ref_compound_table) - self.all_cdb = CompoundDatabase(self.sqlite_path, table=self.non_annotated_compound_table) - self.mapper = SpecToCompoundMap(self.sqlite_path, compound_table=self.ref_compound_table) - + self.all_cdb = CompoundDatabase( + self.sqlite_path, table=self.non_annotated_compound_table + ) + self.mapper = SpecToCompoundMap( + self.sqlite_path, compound_table=self.ref_compound_table + ) # ----------------------------- creation pipeline ----------------------------- @@ -63,14 +80,17 @@ def create_from_spectra( Parameters ---------- - spectra: List[matchms.Spectrum] + spectra : List[matchms.Spectrum] List of matchms Spectrum objects to be inserted into the database. - map_compounds: bool, default=True + map_compounds : bool, default=True Whether to map spectra to compounds based on metadata InChIKeys. 
- create_missing_compounds: bool, default=True + create_missing_compounds : bool, default=True Whether to create compound entries for spectra that do not have a matching compound yet. - Returns counts: {"n_inserted_spectra": int, "n_mapped": int, "n_new_compounds": int} + Returns + ------- + dict + Counts: {"n_inserted_spectra": int, "n_mapped": int, "n_new_compounds": int} """ spec_ids = self.ref_sdb.add_spectra(spectra) n_mapped = 0 @@ -90,13 +110,13 @@ def create_from_spectra( "n_mapped": int(n_mapped), "n_new_compounds": int(n_new), } - - def add_second_compound_database(self, df): + + def add_second_compound_database(self, df: pd.DataFrame) -> None: """Add an additional 'all compound' database without need for spectral data. Parameters ---------- - df: pd.DataFrame + df : pd.DataFrame DataFrame containing inchikey and other relevant compound information. Should at least contain smiles or inchi. """ @@ -105,30 +125,91 @@ def add_second_compound_database(self, df): # --------------------------------- retrievals -------------------------------- # ---- by spec_id ---- - def spectra_by_spec_ids(self, spec_ids: List[int]): - return self.ref_sdb.get_spectra_by_ids(spec_ids) + def spectra_by_spec_ids(self, spec_ids: Sequence[str]): + """Return list[Spectrum] for the given spec_ids.""" + return self.ref_sdb.get_spectra_by_ids(list(spec_ids)) - def fragments_by_spec_ids(self, spec_ids: List[int]): - return self.ref_sdb.get_fragments_by_ids(spec_ids) + def fragments_by_spec_ids(self, spec_ids: Sequence[str]): + """Return list[(mz, intensity)] for the given spec_ids.""" + return self.ref_sdb.get_fragments_by_ids(list(spec_ids)) - def metadata_by_spec_ids(self, spec_ids: List[int]) -> pd.DataFrame: - return self.ref_sdb.get_metadata_by_ids(spec_ids) + def metadata_by_spec_ids(self, spec_ids: Sequence[str]) -> pd.DataFrame: + """Return metadata DataFrame for the given spec_ids.""" + return self.ref_sdb.get_metadata_by_ids(list(spec_ids)) - # ---- by comp_id (inchikey14) 
---- + def embeddings_by_spec_ids(self, spec_ids: Sequence[str]): + """Return (ids, embeddings) tuple for the given spec_ids.""" + return self.ref_sdb.get_embeddings(spec_ids=list(spec_ids)) - def spec_ids_by_comp_id(self, comp_id: str) -> List[int]: - return self.mapper.get_specs_for_comp(comp_id) + # ---- by comp_ids (inchikey14) ---- + + def spec_ids_by_comp_ids(self, comp_ids: Sequence[str]) -> pd.DataFrame: + """ + Return mapping of comp_ids -> spec_ids. + + Returns + ------- + pd.DataFrame + Columns: ['comp_id', 'spec_id']. + One row per existing mapping (1:N). + """ + return self.mapper.get_specs_for_comps(list(comp_ids)) - def spectra_by_comp_id(self, comp_id: str): - return self.ref_sdb.get_spectra_by_ids(self.spec_ids_by_comp_id(comp_id)) + def spectra_by_comp_ids(self, comp_ids: Sequence[str]): + """ + Return all spectra mapped to any of the given comp_ids. + + Notes + ----- + * The order of spectra is determined by the underlying `get_spectra_by_ids` + implementation and mapping table. + * If you need to know which comp_id each spectrum belongs to, combine this + with `spec_ids_by_comp_ids`. + """ + df_map = self.mapper.get_specs_for_comps(list(comp_ids)) + if df_map.empty: + return [] + spec_ids = df_map["spec_id"].tolist() + return self.ref_sdb.get_spectra_by_ids(spec_ids) - def metadata_by_comp_id(self, comp_id: str) -> pd.DataFrame: - spec_ids = self.spec_ids_by_comp_id(comp_id) - return self.ref_sdb.get_metadata_by_ids(spec_ids) + def metadata_by_comp_ids(self, comp_ids: Sequence[str]) -> pd.DataFrame: + """ + Return metadata for all spectra mapped to the given comp_ids. - def compound(self, comp_id: str) -> Optional[Dict[str, Any]]: - return self.ref_cdb.get_compound(comp_id) + Returns + ------- + pd.DataFrame + Columns: ['comp_id', 'spec_id', ...metadata_fields...] 
+ """ + df_map = self.mapper.get_specs_for_comps(list(comp_ids)) + if df_map.empty: + cols = ["comp_id", "spec_id"] + self.metadata_fields + return pd.DataFrame(columns=cols) + + meta = self.ref_sdb.get_metadata_by_ids(df_map["spec_id"].tolist()) + # meta: spec_id + metadata_fields + out = df_map.merge(meta, on="spec_id", how="inner") + # Ensure column order: comp_id, spec_id, metadata... + return out[["comp_id", "spec_id"] + self.metadata_fields] + + def embeddings_by_comp_ids(self, comp_ids: Sequence[str]): + """ + Return embeddings for all spectra mapped to the given comp_ids. + Returns + ------- + (np.ndarray, np.ndarray) + (spec_ids, embeddings) as returned by SpectralDatabase.get_embeddings. + """ + df_map = self.mapper.get_specs_for_comps(list(comp_ids)) + if df_map.empty: + # Mirror SpectralDatabase.get_embeddings empty contract + return ( + np.empty((0,), dtype=str), + np.empty((0, 0), dtype=np.float32), + ) + spec_ids = df_map["spec_id"].tolist() + return self.ref_sdb.get_embeddings(spec_ids=spec_ids) # -------------------------------- convenience SQL ------------------------------ @@ -141,19 +222,13 @@ def sql(self, query: str) -> pd.DataFrame: # ----------------------------------- utilities --------------------------------- def inchikey_to_comp_id(self, inchikey_full: str) -> Optional[str]: + """Convert a full InChIKey to the 14-character comp_id (inchikey14).""" return inchikey14_from_full(inchikey_full) - def close(self): - # Close component connections - try: - self.ref_sdb.close() - except Exception: - pass - try: - self.ref_cdb.close() - except Exception: - pass - try: - self.mapper.close() - except Exception: - pass + def close(self) -> None: + """Close all component connections.""" + for obj in (self.ref_sdb, self.ref_cdb, self.all_cdb, self.mapper): + try: + obj.close() + except Exception: + pass diff --git a/ms2query/ms2query_library.py b/ms2query/ms2query_library.py index 84ea6d3..20dd667 100644 --- a/ms2query/ms2query_library.py +++ 
b/ms2query/ms2query_library.py @@ -4,8 +4,9 @@ import pandas as pd from matchms import Spectrum from ms2deepscore.models import load_model as _ms2ds_load_model +from sklearn.metrics.pairwise import cosine_similarity from ms2query import MS2QueryDatabase -from ms2query.data_processing import compute_spectra_embeddings +from ms2query.data_processing import compute_spectra_embeddings, merge_fingerprints from ms2query.database import EmbeddingIndex, FingerprintSparseIndex @@ -26,21 +27,23 @@ class MS2QueryLibrary: Notes ----- - * `model_path` is optional; provide it if you want on-the-fly embedding of ad-hoc spectra. - If omitted, you can still query using precomputed embeddings fetched from SQLite. * EmbeddingIndex must be built/loaded elsewhere (creation handled by setup workflow). """ db: MS2QueryDatabase embedding_index: Optional[EmbeddingIndex] = None - fingerprint_index: Optional[FingerprintSparseIndex] = None + fingerprint_index: Optional[FingerprintSparseIndex] = None # for now: reference spectra only + large_scale_fingerprint_index: Optional[FingerprintSparseIndex] = None # for large body of reference compounds model_path: Optional[str] = None # internal: whether to apply spectrum normalization (sum=1) before embedding _spectrum_sum_normalization_for_embedding: bool = True + # cached MS2DeepScore model _model: Any = field(default=None, init=False, repr=False) - # ----------------------------- lifecycle ----------------------------- + # ------------------------------------------------------------------ + # Lifecycle / internal helpers + # ----------------------------------------------------------------- def _ensure_model(self): """Lazy-load MS2DeepScore model if a model_path was provided.""" @@ -54,11 +57,21 @@ def _ensure_model(self): self._model.eval() return self._model - def _ensure_index(self): + def _ensure_embedding_index(self): if self.embedding_index is None: - raise RuntimeError("EmbeddingIndex is not set. 
Build or load it before querying.") + raise RuntimeError( + "EmbeddingIndex is not set. Build or load it before querying." + ) + + def _ensure_fingerprint_index(self): + if self.fingerprint_index is None: + raise RuntimeError( + "FingerprintSparseIndex is not set. Build or load it before querying." + ) - # ----------------------------- core API ----------------------------- + # ------------------------------------------------------------------ + # Core API: spectra -> embeddings -> ANN over embeddings + # ----------------------------------------------------------------- def process_spectra(self, spectra: list[Spectrum]) -> List[Spectrum]: """ @@ -68,25 +81,24 @@ def process_spectra(self, spectra: list[Spectrum]) -> List[Spectrum]: # Hook point: insert matchms pipeline later (e.g., metadata fixes, peak processing, etc.) return list(spectra) - def compute_embeddings(self, spectra: list[Spectrum]) -> np.ndarray: + def compute_embeddings(self, spectra: Sequence[Spectrum]) -> np.ndarray: """ Compute MS2DeepScore embeddings for arbitrary query spectra. Spectra will be preprocessed via self.process_spectra(...) first. """ + spectra = _ensure_spectra_list(spectra) if not spectra: return np.empty((0, 0), dtype=np.float32) model = self._ensure_model() - - # Preprocess — keep spectral normalization symmetrical with DB embeddings spectra = self.process_spectra(spectra) - # Compute embeddings return compute_spectra_embeddings( - model, spectra, - normalize_spectrum=self._spectrum_sum_normalization_for_embedding - ) + model, + spectra, + normalize_spectrum=self._spectrum_sum_normalization_for_embedding, + ) def query_embedding_index( self, @@ -100,12 +112,12 @@ def query_embedding_index( Process spectra -> embed -> query EmbeddingIndex. Returns per-query results as a list of lists with dicts: - [{'rank':1, 'spec_id': '...', 'score': float}, ...] + [{'rank': 1, 'spec_id': '...', 'score': float}, ...] All IDs are **spec_id** strings (NOT internal index ids). 
Parameters ---------- - spectra : Spectrum | list[Spectrum] + spectra : Spectrum | Sequence[Spectrum] Query spectra. k : int Top-k to return. @@ -113,74 +125,127 @@ def query_embedding_index( nmslib ef (higher = better recall / slower). return_dataframe : bool If True, returns a tidy DataFrame with columns: - ['query_ix','rank','spec_id','score'] + ['query_ix', 'rank', 'spec_id', 'score'] """ - self._ensure_index() + self._ensure_embedding_index() spectra = _ensure_spectra_list(spectra) - - # Compute embeddings (L2-normalized) embeddings = self.compute_embeddings(spectra) + if embeddings.size == 0: + return ( + [] if not return_dataframe else self._empty_result_df() + ) + + # Batched call: EmbeddingIndex.query returns List[List[(spec_id, score)]] + batch_hits = self.embedding_index.query(embeddings, k=k, ef=ef) + results_all: List[List[Dict[str, Any]]] = [] - for qi in range(embeddings.shape[0]): - # TODO: make faster by querying batch-wise - # EmbeddingIndex.query returns list[(spec_id, similarity)] - hits = self.embedding_index.query(embeddings[qi], k=k, ef=ef) - # convert to standard structure - one = [] - for rk, (spec_id, score) in enumerate(hits, start=1): - one.append({"rank": rk, "spec_id": spec_id, "score": float(score)}) + for hits in batch_hits: + one = [ + {"rank": rk + 1, "spec_id": spec_id, "score": float(score)} + for rk, (spec_id, score) in enumerate(hits) + ] results_all.append(one) if not return_dataframe: return results_all - rows = [] + rows: List[Dict[str, Any]] = [] for qi, lst in enumerate(results_all): for item in lst: rows.append({"query_ix": qi, **item}) - df = pd.DataFrame(rows, columns=["query_ix", "rank", "spec_id", "score"]) - return df + return pd.DataFrame(rows, columns=["query_ix", "rank", "spec_id", "score"]) def query_spectra_by_spectra( self, - spectra: list[Spectrum], + spectra: Union[Spectrum, Sequence[Spectrum]], *, k_spectra: int = 10, - ef: Optional[int] = None, - ): + ef: Optional[int] = None, + ): """ Query the embedding 
index with spectra, return top-k_spectra per spectrum. Parameters ---------- - spectra : list[Spectrum] + spectra : list[Spectrum] or Spectrum Query spectra. k_spectra : int Number of top spectra to retrieve from the embedding index. ef : Optional[int] nmslib ef parameter (higher = better recall / slower). """ - self._ensure_index() - spectra = _ensure_spectra_list(spectra) + return self.query_embedding_index( + spectra, k=k_spectra, ef=ef, return_dataframe=True + ) - # Query spectral embeddings - return self.query_embedding_index(spectra, k=k_spectra, ef=ef) + # ------------------------------------------------------------------ + # Core API: compounds / fingerprints + # ------------------------------------------------------------------ + + def query_compounds_by_compounds( + self, + *, + smiles: Optional[List[str]] = None, + inchis: Optional[List[str]] = None, + k_compounds: int = 10, + ) -> List[List[Dict[str, Any]]]: + """ + Query the fingerprint index with compounds, return top-k compounds per compound. + + Parameters + ---------- + smiles : Optional[List[str]], optional + List of SMILES strings, by default None + inchis : Optional[List[str]], optional + List of InChI strings, by default None + k_compounds : int + Number of top compounds to return per query compound. 
+ """ + self._ensure_fingerprint_index() + # Compute fingerprints (sparse representation) + fps = self.db.all_cdb.compute_fingerprints( + smiles=smiles, + inchis=inchis, + ) + + # Batched fingerprint ANN query + batch_hits = self.fingerprint_index.query(fps, k=k_compounds) + + results_all: List[List[Dict[str, Any]]] = [] + for hits in batch_hits: + one = [ + {"rank": rk + 1, "comp_id": comp_id, "score": float(score)} + for rk, (comp_id, score) in enumerate(hits) + ] + results_all.append(one) + + # Convert to DataFrame + results_df = pd.DataFrame( + [ + {"query_ix": qi, **item} + for qi, lst in enumerate(results_all) + for item in lst + ], + columns=["query_ix", "rank", "comp_id", "score"], + ) + return results_df + def query_compounds_by_spectra( self, - spectra: list[Spectrum], + spectra: Union[Spectrum, Sequence[Spectrum]], *, k_spectra: int = 100, k_compounds: int = 10, ef: Optional[int] = None, - ): + ) -> pd.DataFrame: """ - Query the embedding index with spectra, return top-k_compounds per spectrum. + Query the embedding index with spectra, then aggregate to compounds. Parameters ---------- - spectra : list[Spectrum] + spectra : list[Spectrum] or Spectrum Query spectra. k_spectra : int Number of top spectra to retrieve from the embedding index. 
@@ -192,44 +257,119 @@ def query_compounds_by_spectra( if k_compounds > k_spectra: raise ValueError("k_compounds cannot be larger than k_spectra") - # Step1: Query spectral embeddings - results = self.query_spectra_by_spectra(spectra, k_spectra=k_spectra, ef=ef) + # Query spectral embeddings + results = self.query_spectra_by_spectra( + spectra, k_spectra=k_spectra, ef=ef + ) # DataFrame + + if results.empty: + return results # Pick k_compounds top compounds from the k_spectra hits (if possible) - spec_ids = results.spec_id.values + spec_ids = results["spec_id"].values - compounds = self.db.metadata_by_spec_ids([x for x in spec_ids]).set_index("spec_id") - compounds = compounds.merge(results, on="spec_id").sort_values(["query_ix", "rank"]) + compounds = ( + self.db.metadata_by_spec_ids(list(spec_ids)) + .set_index("spec_id") + ) + + compounds = ( + compounds.merge(results, on="spec_id") + .sort_values(["query_ix", "rank"]) + ) # Pick no more than k_compounds per query_ix - idx = compounds.groupby(['query_ix', 'rank'])['score'].idxmax() + idx = compounds.groupby(["query_ix", "rank"])["score"].idxmax() best_per_pair = compounds.loc[idx] # Within each query_ix, keep the top-k by score df_selected = ( - best_per_pair - .sort_values(['query_ix', 'score'], ascending=[True, False]) - .groupby('query_ix', group_keys=False) + best_per_pair.sort_values(["query_ix", "score"], ascending=[True, False]) + .groupby("query_ix", group_keys=False) .head(k_compounds) .reset_index(drop=True) ) - return df_selected def analogue_search( self, - spectra: list[Spectrum], - ): + spectra: Union[Spectrum, Sequence[Spectrum]], + *, + k_compounds: int = 10, + ef: Optional[int] = None, + ): """ Perform an analogue search for the given spectra. - TODO: implement analogue search logic here. + + Current behaviour: + - For each query spectrum, retrieve top-`k_spectra` library spectra. + - Get their compounds. + - Run compound-by-compound search in fingerprint space. 
""" - top_compounds = self.query_compounds_by_spectra(spectra) - # TODO: implement analogue search logic here - return top_compounds.drop_duplicates("query_ix") + # Step 1: top-k_spectra per query + spec_hits = self.query_spectra_by_spectra( + spectra, k_spectra=1, ef=ef + ) # DataFrame + if spec_hits.empty: + return [] + + spec_ids = spec_hits["spec_id"].values + + # Step 2: get compounds of all retrieved spectra + analogue_compounds = ( + self.db.metadata_by_spec_ids(list(spec_ids)) + .set_index("spec_id") + ) - - # ----------------------------- helpers / optional glue ----------------------------- + analogue_smiles = analogue_compounds["smiles"].tolist() + + # Step 3: fingerprint-based compound search + top_compounds = self.query_compounds_by_compounds( + smiles=analogue_smiles + ).set_index("query_ix") + + # Step 4: for each query, pick the best matching spectrum among all spectra + fingerprints_merged = [] + weighted_average_scores = [] + embeddings_queries = self.compute_embeddings(spectra) # TODO: this is now done twice! 
in step 1 and here + for i in range(len(analogue_smiles)): + comp_ids = top_compounds.loc[i].comp_id.to_list() + + # Get chemically closest compounds + spec_ids_all = [] + spec_ids_selected = [] + embeddings_selected = [] + + all_spec_ids = self.db.spec_ids_by_comp_ids(comp_ids).set_index("comp_id") + for comp_id in comp_ids: + new_spec_ids = all_spec_ids.loc[comp_id].spec_id.to_list() + + # Get most similar embedding from one of the top-10 compounds + embs = self.db.ref_sdb.get_embeddings(new_spec_ids) + similarities = cosine_similarity(embs[1], embeddings_queries[i].reshape(1, -1)) + max_id = np.argmax(similarities) + spec_ids_selected.append(embs[0][max_id]) + embeddings_selected.append(embs[1][max_id]) + spec_ids_all.extend(new_spec_ids) + + top1_top10_similarities = cosine_similarity(embeddings_selected, embeddings_queries[i].reshape(1, -1)) + fingerprints = self.db.ref_cdb.get_fingerprints(comp_ids) + fingerprints_merged.append(merge_fingerprints(fingerprints, weights=top1_top10_similarities)) + weighted_average_scores.append(np.sum(top1_top10_similarities ** 2) / np.sum(top1_top10_similarities)) + if self.large_scale_fingerprint_index: + analogue_predictions = self.large_scale_fingerprint_index.query(fingerprints_merged, k=k_compounds) + elif self.fingerprint_index: + analogue_predictions = self.fingerprint_index.query(fingerprints_merged, k=k_compounds) + else: + raise RuntimeError("No fingerprint index is set. 
Build or load it before querying.") + return pd.DataFrame({ + "analogue_predictions": analogue_predictions, + "weighted_average_scores": weighted_average_scores + }) + + # ------------------------------------------------------------------ + # Helpers / glue + # ------------------------------------------------------------------ def set_embedding_index(self, index: EmbeddingIndex) -> None: """Attach or replace the EmbeddingIndex.""" @@ -240,43 +380,60 @@ def set_fingerprint_index(self, index: FingerprintSparseIndex) -> None: self.fingerprint_index = index def query_by_spec_ids( - self, spec_ids: List[str], *, k: int = 10, ef: Optional[int] = None, return_dataframe: bool = False + self, + spec_ids: List[str], + *, + k: int = 10, + ef: Optional[int] = None, + return_dataframe: bool = False, ): """ Convenience: fetch embeddings for known spec_ids from SQLite and search. Requires that the embeddings are present in DB (table 'embeddings'). """ - if self.embedding_index is None: - raise RuntimeError("EmbeddingIndex is not set. 
Build or load it before querying.") + self._ensure_embedding_index() - # Pull precomputed embeddings from DB (already L2-normalized in SpectralDatabase.get_embeddings) - ids, X = self.db.ref_sdb.get_embeddings(ids=spec_ids, embeddings_table="embeddings", normalized=True) + # Pull precomputed embeddings from DB (already L2-normalized) + _, X = self.db.ref_sdb.get_embeddings( + spec_ids=spec_ids, + embeddings_table="embeddings", + normalized=True, + ) + + # If DB returns nothing if X.size == 0: return [] if not return_dataframe else self._empty_result_df() + # X is 2D: use batched query + batch_hits = self.embedding_index.query(X, k=k, ef=ef) + results_all: List[List[Dict[str, Any]]] = [] - for qi in range(X.shape[0]): - hits = self.embedding_index.query(X[qi], k=k, ef=ef) - one = [{"rank": rk + 1, "spec_id": sid, "score": float(score)} for rk, (sid, score) in enumerate(hits)] + for hits in batch_hits: + one = [ + {"rank": rk + 1, "spec_id": sid, "score": float(score)} + for rk, (sid, score) in enumerate(hits) + ] results_all.append(one) if not return_dataframe: return results_all - rows = [] + rows: List[Dict[str, Any]] = [] for qi, lst in enumerate(results_all): for item in lst: rows.append({"query_ix": qi, **item}) return pd.DataFrame(rows, columns=["query_ix", "rank", "spec_id", "score"]) @staticmethod - def _empty_result_df(): + def _empty_result_df() -> pd.DataFrame: return pd.DataFrame(columns=["query_ix", "rank", "spec_id", "score"]) # ----------------- helper functions --------------------- -def _ensure_spectra_list(spectra: Union[Spectrum, Sequence[Spectrum]]) -> List[Spectrum]: +def _ensure_spectra_list( + spectra: Union[Spectrum, Sequence[Spectrum]] +) -> List[Spectrum]: if isinstance(spectra, Spectrum): return [spectra] if isinstance(spectra, Sequence): diff --git a/ms2query/notebooks/Test_ann_speed_improvements.ipynb b/ms2query/notebooks/Test_ann_speed_improvements.ipynb new file mode 100644 index 0000000..bacbad8 --- /dev/null +++ 
b/ms2query/notebooks/Test_ann_speed_improvements.ipynb @@ -0,0 +1,306 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "a3a9b10c-a9a6-441f-963c-edf5d9a50dbe", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "\n", + "sys.path.append(\"C:/Users/jonge094/PycharmProjects/ms2query_2_0/ms_chemical_space_explorer\")\n", + "\n" + ] + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from matchms.importing import load_from_mgf\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "neg_val_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/negative_validation_spectra.mgf\")))\n", + "neg_test_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/negative_testing_spectra.mgf\")))\n" + ], + "id": "1958d1e9d021cd45" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from ms2deepscore.models import compute_embedding_array, load_model\n", + "\n", + "\n", + "ms2deepscore_model = load_model(\"../../../ms2deepscore/data/pytorch/new_corinna_included/trained_models/\"\n", + " \"both_mode_precursor_mz_ionmode_10000_layers_500_embedding_2024_11_21_11_23_17/ms2deepscore_model.pt\")\n", + "\n", + "embeddings = compute_embedding_array(ms2deepscore_model, neg_val_spectra)" + ], + "id": "fe20d28468de50ca" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "cee3357668c3f96e" + }, + { + "cell_type": "code", + "execution_count": 124, + "id": "d59f130e-f38e-467f-b6c0-899eb6fdb959", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "\n", + "more_test_embeddings = np.tile(embeddings, (70, 1))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 125, + 
"id": "5ffe97b9-0199-41f3-9f27-b0a639274af0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(528570, 500)" + ] + }, + "execution_count": 125, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "more_test_embeddings.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "id": "9b493113-b067-40c9-9925-0db45af6693d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time eleapsed: 163.06616854667664\n" + ] + } + ], + "source": [ + "import time\n", + "import pynndescent\n", + "\n", + "\n", + "start_time = time.time()\n", + "ann_model = pynndescent.NNDescent(more_test_embeddings, metric=\"cosine\", n_neighbors=30)\n", + "ann_model.prepare()\n", + "print(\"Time eleapsed: \" + str(time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": 128, + "id": "f0ad0f13-f449-4871-9f23-c0ebfe38cb33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time eleapsed: 86.74623227119446\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "ann_model.update(embeddings[:2])\n", + "ann_model.prepare()\n", + "\n", + "print(\"Time eleapsed: \" + str(time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "id": "da5276ff-42e8-4201-8bc0-a166216787ae", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time eleapsed: 1.8495116233825684\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "indices, dists = ann_model.query(embeddings[:1000], epsilon=1, k=1000)\n", + "print(\"Time eleapsed: \" + str(time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "id": "65f44f03-6593-4bc4-bc4d-e64ffe6cbc08", + "metadata": {}, + "outputs": [], + "source": [ + "for dist in dists:\n", + " if dist > 0.000001:\n", + " print(dist)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 88, + "id": "8f8b125d-ace5-4af7-b0ce-6b9a263ece1b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6783\n", + "768\n" + ] + } + ], + "source": [ + "correct = 0\n", + "not_correct = 0\n", + "for i, index in enumerate(indices):\n", + " correct_if_0 = index[0]%7551-i\n", + " if correct_if_0== 0:\n", + " correct += 1\n", + " else:\n", + " not_correct +=1\n", + "print(correct)\n", + "print(not_correct)" + ] + }, + { + "cell_type": "code", + "execution_count": 117, + "id": "f84b71ed-4e35-42c1-8be1-a1c9b3ec0dae", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(226530, 500)" + ] + }, + "execution_count": 117, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "more_test_embeddings.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 126, + "id": "dbc240d1-62e2-40f1-8f2c-12345cce3bf6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time eleapsed: 14.737722158432007\n" + ] + } + ], + "source": [ + "from ms2deepscore.vector_operations import cosine_similarity_matrix\n", + "\n", + "\n", + "start_time = time.time()\n", + "matrix = cosine_similarity_matrix(more_test_embeddings, embeddings[:1000])\n", + "print(\"Time eleapsed: \" + str(time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "id": "f87ba107-54cc-47ee-87c8-34694a917a3b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1. , 0.88234181, 0.73878687, ..., 0.55531438, 0.58179732,\n", + " 0.61927229],\n", + " [0.88234181, 1. , 0.91034399, ..., 0.50462923, 0.51674736,\n", + " 0.54464448],\n", + " [0.73878687, 0.91034399, 1. 
, ..., 0.48358345, 0.4835752 ,\n", + " 0.54332801],\n", + " ...,\n", + " [0.49057829, 0.55037322, 0.56818589, ..., 0.39834881, 0.41004598,\n", + " 0.5298201 ],\n", + " [0.48655507, 0.57605234, 0.5946298 , ..., 0.40762756, 0.42247458,\n", + " 0.50066275],\n", + " [0.41914688, 0.47922786, 0.48648942, ..., 0.38071478, 0.41207848,\n", + " 0.47297686]])" + ] + }, + "execution_count": 120, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "matrix" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62b738da-a9ce-421f-8ef8-ee8a18d2b70f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (ms2query2)", + "language": "python", + "name": "ms2query2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ms2query/notebooks/Test_method_evaluator.ipynb b/ms2query/notebooks/Test_method_evaluator.ipynb new file mode 100644 index 0000000..5303dfe --- /dev/null +++ b/ms2query/notebooks/Test_method_evaluator.ipynb @@ -0,0 +1,890 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ced5638b-15f4-4219-9ed8-b2823a53e574", + "metadata": {}, + "source": [ + "# load in spectra" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2768cf34-1794-4a09-89a3-f81a04b1bcda", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "31453it [00:14, 2140.47it/s]\n", + "7080it [00:03, 1992.44it/s]\n", + "459610it [03:37, 2114.18it/s]\n", + "131240it [01:06, 1967.67it/s]\n" + ] + } + ], + "source": [ + "import os\n", + "from matchms.importing import load_from_mgf\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "save_directory = 
\"../data/ms2deepscore_model/training_and_validation_split/\"\n", + "pos_val_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"positive_validation_spectra.mgf\"))))\n", + "neg_val_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"negative_validation_spectra.mgf\"))))\n", + "pos_train_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"positive_training_spectra.mgf\"))))\n", + "neg_train_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"negative_training_spectra.mgf\"))))\n", + "pos_test_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"positive_testing_spectra.mgf\"))))\n", + "neg_test_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"negative_testing_spectra.mgf\"))))\n" + ] + }, + { + "cell_type": "markdown", + "id": "7f5a4df2-2bb7-4731-a6bf-6727183da493", + "metadata": {}, + "source": [ + "### Save as pickled files for quicker reloads" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3a55be7f-d760-4df7-944e-251844cb1b4d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os\n", + "\n", + "\n", + "pickled_intermediates_data_folder = \"../data/pickled_intermediates\"\n", + "os.path.isdir(pickled_intermediates_data_folder)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "28606ab4-fd4c-4111-a1ca-8dc502593425", + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "\n", + "\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_val_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(neg_val_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_train_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(neg_train_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", + "with 
open(os.path.join(pickled_intermediates_data_folder, \"pos_val_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(pos_val_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"pos_train_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(pos_train_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"pos_test_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(pos_test_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_test_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(neg_test_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, + { + "cell_type": "markdown", + "id": "cf849c08-3eaf-4922-b361-4f324b010ca9", + "metadata": {}, + "source": [ + "# Create a simple MS2Deepscore ranker" + ] + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from ms2deepscore.models import load_model\n", + "\n", + "\n", + "ms2deepscore_model = load_model(\n", + " \"../data/ms2deepscore_model/trained_models/\"\n", + " \"both_mode_ionmode_precursor_mz_2000_layers_500_embedding_2025_02_26_18_42_25/ms2deepscore_model.pt\")\n" + ], + "id": "d4805b43bea1e3a2" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "import sys\n", + "\n", + "\n", + "sys.path.append(\"../../ms_chemical_space_explorer\")" + ], + "id": "9ca0b8c6639168e1" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca3486f3-06b3-402d-90b1-ba01cb2568aa", + "metadata": {}, + "outputs": [], + "source": [ + "from ms_chemical_space_explorer.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings\n", + "\n", + "\n", + "library_spectra = SpectraWithMS2DeepScoreEmbeddings(neg_train_spectra + pos_train_spectra, ms2deepscore_model)" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "c858cc69-3da8-4530-ba13-ad4b1827f1ae", + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "\n", + "\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_library_embeddings.pickle\"), \"wb\") as handle:\n", + " pickle.dump(library_spectra.embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53ca37fe-1627-4473-a21a-168791d92112", + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "\n", + "\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_library_with_embeddings.pickle\"), \"wb\") as handle:\n", + " pickle.dump(library_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cec8761-6c37-4ac5-ba52-619c171cd834", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "val_spectra = SpectraWithMS2DeepScoreEmbeddings(neg_val_spectra + pos_val_spectra, ms2deepscore_model)" + ] + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "import pickle\n", + "\n", + "\n", + "with open(os.path.join(pickled_intermediates_data_folder,\n", + " \"neg_pos_val_spectra_with_embeddings.pickle\"), \"wb\") as handle:\n", + " pickle.dump(val_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)" + ], + "id": "e6835fbcb3d01494" + }, + { + "cell_type": "markdown", + "id": "688da27f-3290-480c-ab67-d171bbac6b93", + "metadata": {}, + "source": [ + "# Reload intermediates" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9cb48a39-f3cd-4924-8da2-f11007475795", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "\n", + "sys.path.append(\"../../ms_chemical_space_explorer\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "97d042ab-19f6-4d32-bfd7-b24e0f72c1d9", + "metadata": {}, + "outputs": [], + 
"source": [ + "import os\n", + "import pickle\n", + "\n", + "\n", + "pickled_intermediates_data_folder = \"../data/pickled_intermediates\"\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_library_with_embeddings.pickle\"), \"rb\") as file:\n", + " library_spectra = pickle.load(file)\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_val_spectra_with_embeddings.pickle\"), \"rb\") as file:\n", + " val_spectra = pickle.load(file)" + ] + }, + { + "cell_type": "markdown", + "id": "c3c8dfda-3294-45cc-8a6f-1aeb06257c73", + "metadata": {}, + "source": [ + "# Initialize a method evaluator" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "79fa345a-52eb-4079-a6b5-0bd63d78d821", + "metadata": {}, + "outputs": [], + "source": [ + "from ms_chemical_space_explorer.benchmarking.EvaluateMethods import EvaluateMethods\n", + "\n", + "\n", + "method_evaluator = EvaluateMethods(library_spectra, val_spectra)\n", + "method_evaluator.training_spectrum_set.progress_bars = False\n", + "method_evaluator.validation_spectrum_set.progress_bars = False" + ] + }, + { + "cell_type": "markdown", + "id": "16663c37-4fdd-4454-bcc6-55dbf4d6230f", + "metadata": {}, + "source": [ + "# Test basic MS2DeepScore" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "1048dccf-1b53-4dcc-96d4-01a4d51dc57d", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 78/78 [26:20<00:00, 20.26s/it]\n", + "Calculating analogue accuracy per inchikey: 
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2015/2015 [00:01<00:00, 1190.75it/s]\n" + ] + } + ], + "source": [ + "from ms_chemical_space_explorer.methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore\n", + "\n", + "\n", + "result_analogue = method_evaluator.benchmark_analogue_search(predict_highest_ms2deepscore)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0f6f21cb-a358-4803-ad74-cc57a2a7a47e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.3848227909492443\n" + ] + } + ], + "source": [ + "print(result_analogue)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "eda9a92c-fe2c-4c4b-ab4e-07ddb4346a27", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1837/1837 [00:00<00:00, 53337.60it/s]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 31/31 [10:21<00:00, 
20.03s/it]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [10:35<00:00, 19.25s/it]\n", + "Calculating exact match accuracy per inchikey: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1612/1612 [00:00<00:00, 313512.85it/s]\n" + ] + } + ], + "source": [ + "from ms_chemical_space_explorer.methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore\n", + "\n", + "\n", + "result_positive = method_evaluator.benchmark_exact_matching_within_ionmode(predict_highest_ms2deepscore, \"positive\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "b6ddad52-8647-450f-8881-6753a5bea814", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5169110149857257\n" + ] + } + ], + "source": [ + "print(result_positive)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "dfc7a156-74a3-4c48-b610-e38410abf274", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey: 
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 924/924 [00:00<00:00, 293846.15it/s]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:08<00:00, 18.29s/it]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [02:14<00:00, 16.78s/it]\n", + "Calculating exact match accuracy per inchikey: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 868/868 [00:00<00:00, 894070.70it/s]\n" + ] + } + ], + "source": [ + "from ms_chemical_space_explorer.methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore\n", + "\n", + "\n", + "result_neg = method_evaluator.benchmark_exact_matching_within_ionmode(predict_highest_ms2deepscore, \"negative\")" + ] + }, + { + "cell_type": "code", + 
"execution_count": 26, + "id": "86cc5ba4-4380-488a-a486-de63cccada74", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5448453174876924" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_neg\n" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "a5b0595b-51dd-4a56-b9ec-e81a597dd26b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey across ionmodes: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2015/2015 [00:00<00:00, 9551.01it/s]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [11:40<00:00, 20.02s/it]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [03:47<00:00, 17.46s/it]\n", + "Calculating exact match accuracy per inchikey: 
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 746/746 [00:00<00:00, 496186.30it/s]\n" + ] + } + ], + "source": [ + "result_across_ionmodes = method_evaluator.exact_matches_across_ionization_modes(predict_highest_ms2deepscore)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "fe14e752-75b4-4e00-8600-90c1dce46b90", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.0006188308440879157" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_across_ionmodes" + ] + }, + { + "cell_type": "markdown", + "id": "1499181b-7b68-437a-85b3-b7dc3ad53884", + "metadata": {}, + "source": [ + "# Test best possible results" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "56d119e2-145a-484c-ab3c-56b569b549dd", + "metadata": {}, + "outputs": [], + "source": [ + "from ms_chemical_space_explorer.methods.predict_best_possible_match import predict_best_possible_match\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b3687845-099c-4b98-b715-c8057b5858bb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Calculating tanimoto scores to determine best possible match\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Calculating analogue accuracy per inchikey: 100%|██████████████████████████████████████████| 2015/2015 [00:01<00:00, 1578.04it/s]\n" + ] + } + ], + "source": [ + "result_analogue_best = method_evaluator.benchmark_analogue_search(predict_highest_ms2deepscore)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "81c83a8d-0754-4207-97d3-2fbef49082ac", + 
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.7753653963061775" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_analogue_best" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "5fad0038-025c-4c73-9bfb-bc4545f092df", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 924/924 [00:00<00:00, 310739.01it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Calculating tanimoto scores to determine best possible match\n", + "Calculating tanimoto scores to determine best possible match\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Calculating exact match accuracy per inchikey: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 868/868 [00:00<00:00, 928738.74it/s]\n" + ] + } + ], + "source": [ + "result_neg_best = method_evaluator.benchmark_exact_matching_within_ionmode(predict_best_possible_match, \"negative\")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "e0418a27-0195-4021-a04b-788cbee4fe2b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + 
"source": [ + "result_neg_best" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "f19ea34d-f6c1-4f7c-b7ec-8c7bad85a1da", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey across ionmodes: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2015/2015 [00:00<00:00, 13798.67it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Calculating tanimoto scores to determine best possible match\n", + "Calculating tanimoto scores to determine best possible match\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Calculating exact match accuracy per inchikey: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 746/746 [00:00<00:00, 404727.82it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_across_ionmodes_best = method_evaluator.exact_matches_across_ionization_modes(predict_best_possible_match)\n", + "result_across_ionmodes_best" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "052e768d-4044-40b8-a85e-7cf03abe4184", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey across ionmodes: 100%|█████████████████████████████████████| 2015/2015 
[00:00<00:00, 10356.24it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Calculating tanimoto scores to determine best possible match\n", + "Calculating tanimoto scores to determine best possible match\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Calculating exact match accuracy per inchikey: 100%|███████████████████████████████████████| 746/746 [00:00<00:00, 315164.26it/s]\n" + ] + } + ], + "source": [ + "result_across_ionmodes_best = method_evaluator.exact_matches_across_ionization_modes(predict_best_possible_match)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1ff6eaba-8213-4270-9395-39ff3aee6879", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_across_ionmodes_best" + ] + }, + { + "cell_type": "markdown", + "id": "3702ba19-0c86-443c-8621-fe1ab427b3da", + "metadata": {}, + "source": [ + "# Test cosine score" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5f0cfc49-009f-46a1-beef-6310d31260a1", + "metadata": {}, + "outputs": [], + "source": [ + "from ms_chemical_space_explorer.methods.predict_highest_cosine import predict_highest_cosine" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cb9210d-3ead-4798-bcc0-82b571e69f8b", + "metadata": {}, + "outputs": [], + "source": [ + "result_analogue_cosine = method_evaluator.benchmark_analogue_search(predict_highest_cosine)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1a5fb72-6039-4d56-af71-fe07bb583c22", + "metadata": {}, + "outputs": [], + "source": [ + "result_analogue_cosine" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "7d4641b9-9b44-4b64-8fed-0b568cb0f4f9", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey: 
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 924/924 [00:00<00:00, 260575.33it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "created scores\n", + "Filtered on precursor\n", + "Calculated cosine\n", + "('PrecursorMzMatch', 'CosineGreedy_score', 'CosineGreedy_matches')\n", + "created scores\n", + "Filtered on precursor\n", + "Calculated cosine\n", + "('PrecursorMzMatch', 'CosineGreedy_score', 'CosineGreedy_matches')\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Calculating exact match accuracy per inchikey: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 868/868 [00:00<00:00, 201408.27it/s]\n" + ] + } + ], + "source": [ + "result_neg = method_evaluator.benchmark_exact_matching_within_ionmode(predict_highest_cosine, \"negative\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "96191952-14df-41b3-a410-d2e5574dd3c9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5772003486330072" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_neg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e77b0ecd-7f1a-438d-a0a1-950ff9e76881", + "metadata": {}, + "outputs": [], + "source": [ + "result_positive = method_evaluator.benchmark_exact_matching_within_ionmode(predict_highest_cosine, \"positive\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7db770c4-f4bb-4a0e-82a2-03d52e6ed7d0", + "metadata": {}, + "outputs": [], + "source": [ + "result_positive" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "675351b2-ff50-42d8-b179-006d0a46cadc", + "metadata": {}, + "outputs": [], + "source": [ + "result_across_ionmodes = 
method_evaluator.exact_matches_across_ionization_modes(predict_highest_cosine)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d056551a-f85b-4077-8b35-eaf834da7a9b", + "metadata": {}, + "outputs": [], + "source": [ + "result_across_ionmodes" + ] + }, + { + "cell_type": "markdown", + "id": "7682a1b8-aec1-40dc-b942-419bd23de041", + "metadata": {}, + "source": [ + "# Test predictions with ISF" + ] + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from ms_chemical_space_explorer.methods.predict_with_integrated_similarity_flow import (\n", + " predict_with_integrated_similarity_flow,\n", + ")\n", + "\n", + "\n", + "result_analogue_isf = method_evaluator.benchmark_analogue_search(predict_with_integrated_similarity_flow)\n", + "print(result_analogue_isf)\n", + "result_neg_isf = method_evaluator.benchmark_exact_matching_within_ionmode(\n", + " predict_with_integrated_similarity_flow, \"negative\")\n", + "print(result_neg_isf)\n", + "result_positive_isf = method_evaluator.benchmark_exact_matching_within_ionmode(\n", + " predict_with_integrated_similarity_flow, \"positive\")\n", + "print(result_positive_isf)\n", + "result_across_ionmodes_isf = method_evaluator.exact_matches_across_ionization_modes(\n", + " predict_with_integrated_similarity_flow)\n", + "print(result_across_ionmodes_isf)" + ], + "id": "958a6efb624e75de" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "result_analogue_isf", + "id": "326a55ab85832258" + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "22ba5663-1f8a-405a-b04b-0380e18c6987", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hello\n" + ] + } + ], + "source": [ + "print(\"hello\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b9198fb-c60b-4db5-8415-578b88275e5a", + "metadata": {}, + "outputs": [], + "source": [] + } + 
], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb b/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb new file mode 100644 index 0000000..271f9de --- /dev/null +++ b/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb @@ -0,0 +1,341 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from matchms.importing import load_from_mgf\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "neg_val_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/\"\n", + " \"negative_validation_spectra.mgf\")))\n", + "neg_test_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/negative_testing_spectra.mgf\")))\n", + "neg_train_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/negative_training_spectra.mgf\")))\n", + "neg_spectra = neg_val_spectra + neg_test_spectra + neg_train_spectra\n", + "\n", + "pos_val_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/positive_validation_spectra.mgf\")))\n", + "pos_test_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " 
\"training_and_validation_split/positive_testing_spectra.mgf\")))\n", + "pos_train_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/positive_training_spectra.mgf\")))\n", + "pos_spectra = pos_val_spectra + pos_test_spectra + pos_train_spectra" + ], + "id": "fdf12d625e87127b" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "neg_inchikeys = []\n", + "for spectrum in tqdm(neg_spectra):\n", + " neg_inchikeys.append(spectrum.get(\"inchikey\")[:14])\n", + " " + ], + "id": "9ce097a34a9e6656" + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "437e9455-49a9-435e-8bb5-cd3b3ed987f1", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 519580/519580 [00:46<00:00, 11219.75it/s]\n" + ] + } + ], + "source": [ + "pos_inchikeys = []\n", + "for spectrum in tqdm(pos_spectra):\n", + " pos_inchikeys.append(spectrum.get(\"inchikey\")[:14])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "df03a6d2-6755-46e1-a874-fc96c1a2b20c", + "metadata": {}, + "outputs": [], + "source": [ + "unique_neg_inchikeys = set(neg_inchikeys)\n", + "unique_pos_inchikeys = set(pos_inchikeys)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "aa775f3b-9c10-43cd-a6be-74dac2e698ce", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "18480" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"len(unique_neg_inchikeys)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f3a73395-68c2-4952-9171-96f67f597cc9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "36638" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(unique_pos_inchikeys)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "f9729ad5-d201-476a-87a1-78e267425aae", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "14801" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "overlapping_inchikeys = unique_neg_inchikeys & unique_pos_inchikeys\n", + "len(overlapping_inchikeys)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "07d148f7-aef6-4eda-a66f-984f0c65a31b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "40317" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "total_unique_inchikeys = unique_neg_inchikeys | unique_pos_inchikeys\n", + "len(total_unique_inchikeys)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0cd254e4-9b7f-41ed-a7ce-25675392102f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.36711560880025795" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "14801/40317" + ] + }, + { + "cell_type": "markdown", + "id": "3f5bb776-9910-4a4d-bc45-d8e049d63ce6", + "metadata": {}, + "source": [ + "# Check how many there are already in the val set\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "4b5ced16-ceae-439b-bb04-c72986798b8f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7551/7551 [00:01<00:00, 4630.22it/s]\n" + ] + } + ], + "source": [ + "neg_inchikeys_val = []\n", + "for spectrum in tqdm(neg_val_spectra):\n", + " neg_inchikeys_val.append(spectrum.get(\"inchikey\")[:14])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "0d2b00bd-56a3-4c67-ad3f-d54051a74701", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25412/25412 [00:02<00:00, 10220.59it/s]\n" + ] + } + ], + "source": [ + "pos_inchikeys_val = []\n", + "for spectrum in tqdm(pos_val_spectra):\n", + " pos_inchikeys_val.append(spectrum.get(\"inchikey\")[:14])" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "76a7736a-d4a4-42a5-bf42-d20b1a38060c", + "metadata": {}, + "outputs": [], + "source": [ + "unique_neg_inchikeys_val = set(neg_inchikeys_val)\n", + "unique_pos_inchikeys_val = set(pos_inchikeys_val)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "6fce2334-e43b-40ce-b459-75c1a6d1ed00", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "35" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "overlapping_inchikeys_val = 
unique_neg_inchikeys_val & unique_pos_inchikeys_val\n", + "len(overlapping_inchikeys_val)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "86668203-8e5b-40e9-934e-0ade3c25e9d8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7551" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(neg_inchikeys_val)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "bc2e25e7-5d85-4486-924d-c1d9651ba2bc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "25412" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(pos_inchikeys_val)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e864dc8e-68bb-4d3d-8a27-57f3892581e3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 3fe7fa4..fe9495c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ pandas = ">=2.1.1" scipy= ">=1.14.0" matplotlib= ">=3.8.0" matchms= ">=0.30.0" -ms2deepscore= { git = "https://github.com/matchms/ms2deepscore.git", branch = "pytorch_update" } +ms2deepscore= ">=2.6.0" rdkit= ">2024.3.4" nmslib= ">=2.0.0" umap-learn= ">=0.5.7" @@ -49,6 +49,9 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] +exclude = [ + "ms2query/notebooks/", +] line-length = 120 output-format = "grouped" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..702e6fc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,90 @@ +import os +from pathlib import Path +import numpy as np +from matchms.Spectrum import Spectrum +from ms2deepscore.models import load_model + + +TEST_RESOURCES_PATH = Path(__file__).parent / "test_data" + + +def create_test_spectra( + number_of_spectra_per_inchikey=3, + inchikey_inchi_pairs=None, + nr_of_inchikeys=3, +): + if inchikey_inchi_pairs is None: + inchikey_inchi_pairs = get_inchikey_inchi_pairs(nr_of_inchikeys) + spectra = [] + for i, inchikey_inchi_tuple in enumerate(inchikey_inchi_pairs): + inchikey, inchi, smiles, compound_name = inchikey_inchi_tuple + for j in range(number_of_spectra_per_inchikey): + spectra.append( + Spectrum( + mz=np.array([100 + i * 10.0, 500 + i * 1.0]), + intensities=np.array([1.0, 1.0 / (j + 1)]), + metadata={ + "precursor_mz": 111.1 + i * 10, + "inchikey": inchikey, + "inchi": inchi, + "smiles": smiles, + "compound_name": compound_name, + }, + ) + ) + return spectra + + +def ms2deepscore_model(): + return load_model(os.path.join(TEST_RESOURCES_PATH, "ms2deepscore_testmodel_v1.pt")) + + +def get_inchikey_inchi_pairs(number_of_pairs): + """Returns inchikey_inchi_pairs""" + inchikey_inchi_pairs = ( + ( + "RYYVLZVUVIJVGH-UHFFFAOYSA-N", + "InChI=1S/C8H10N4O2/c1-10-4-9-6-5(10)7(13)12(3)8(14)11(6)2/h4H,1-3H3", + "CN1C=NC2=C1C(=O)N(C(=O)N2C)C", + "Caffeine", + ), + ( + "ZPUCINDJVBIVPJ-LJISPDSOSA-N", + "InChI=1S/C17H21NO4/c1-18-12-8-9-13(18)15(17(20)21-2)14(10-12)22-16(19)11-6-4-3-5-7-11/h3-7,12-15H,8-10H2,1-2H3/t12-,13+,14-,15+/m0/s1", + "CN1[C@H]2CC[C@@H]1[C@H]([C@H](C2)OC(=O)C3=CC=CC=C3)C(=O)OC", + "Cocaine", + ), + ( + "RZVAJINKPMORJF-UHFFFAOYSA-N", + "InChI=1S/C8H9NO2/c1-6(10)9-7-2-4-8(11)5-3-7/h2-5,11H,1H3,(H,9,10)", + "CC(=O)NC1=CC=C(C=C1)O", + "Paracetemol", + ), + ( + "JGSARLDLIJGVTE-MBNYWOFBSA-N", + 
"InChI=1S/C16H18N2O4S/c1-16(2)12(15(21)22)18-13(20)11(14(18)23-16)17-10(19)8-9-6-4-3-5-7-9/h3-7,11-12,14H,8H2,1-2H3,(H,17,19)(H,21,22)/t11-,12+,14-/m1/s1", + "CC1([C@@H](N2[C@H](S1)[C@@H](C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C", + "Penicillin", + ), + ( + "WQZGKKKJIJFFOK-GASJEMHNSA-N", + "InChI=1S/C6H12O6/c7-1-2-3(8)4(9)5(10)6(11)12-2/h2-11H,1H2/t2-,3-,4+,5-,6?/m1/s1", + "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O", + "Glucose", + ), + ( + "MWOOGOJBHIARFG-UHFFFAOYSA-N", + "InChI=1S/C8H8O3/c1-11-8-4-6(5-9)2-3-7(8)10/h2-5,10H,1H3", + "COC1=C(C=CC(=C1)C=O)O", + "vanillin" + ), + ( + "ROHFNLRQFUQHCH-YFKPBYRVSA-N", + "InChI=1S/C6H13NO2/c1-4(2)3-5(7)6(8)9/h4-5H,3,7H2,1-2H3,(H,8,9)/t5-/m0/s1", + "CC(C)C[C@@H](C(=O)O)N", + "L-Leucine" + ) + ) + if number_of_pairs > len(inchikey_inchi_pairs): + raise ValueError("Not enough example compounds, add some in conftest") + return inchikey_inchi_pairs[:number_of_pairs] diff --git a/tests/testPredictMS2DeepScoreSimilarity.py b/tests/testPredictMS2DeepScoreSimilarity.py new file mode 100644 index 0000000..4741bdf --- /dev/null +++ b/tests/testPredictMS2DeepScoreSimilarity.py @@ -0,0 +1,28 @@ +import numpy as np +import pytest +from ms2query.benchmarking.reference_methods.PredictMS2DeepScoreSimilarity import ( + predict_top_ms2deepscores, +) +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings +from tests.conftest import create_test_spectra, ms2deepscore_model + + +@pytest.mark.parametrize( + "method", + [ + predict_top_ms2deepscores, + ], +) +def test_predict_highest_ms2deepscore_similarity(method): + ms2ds_model = ms2deepscore_model() + test_spectra = create_test_spectra(1) + library_spectra = SpectraWithMS2DeepScoreEmbeddings(test_spectra, ms2ds_model) + query_spectra = SpectraWithMS2DeepScoreEmbeddings(test_spectra, ms2ds_model) + number_of_analogues = 2 + + indices, distances = method(library_spectra.embeddings, query_spectra.embeddings, k=number_of_analogues) + assert indices.shape == 
(len(test_spectra), number_of_analogues) + assert distances.shape == (len(test_spectra), number_of_analogues) + for i, row in enumerate(indices): + assert row[0] == i, "The highest predictions should be against itself" + assert np.allclose(distances[i][0], 1.0, atol=1e-5) diff --git a/tests/test_SpectrumDataSet.py b/tests/test_SpectrumDataSet.py new file mode 100644 index 0000000..bb3e232 --- /dev/null +++ b/tests/test_SpectrumDataSet.py @@ -0,0 +1,130 @@ +import numpy as np +import pytest +from ms2query.benchmarking.SpectrumDataSet import ( + SpectraWithFingerprints, + SpectraWithMS2DeepScoreEmbeddings, + SpectrumSet, +) +from tests.conftest import create_test_spectra, get_inchikey_inchi_pairs, ms2deepscore_model + + +@pytest.mark.parametrize( + "library", + [ + SpectrumSet(create_test_spectra()), + SpectraWithFingerprints(create_test_spectra()), + SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), ms2deepscore_model()), + ], +) +def test_spectrum_set_base(library): + """Test all base functionality of SpectrumSet is implemented correctly + also for all classes inheriting from it""" + # test correct init + assert len(library.spectra) == 9 + assert len(library.spectrum_indexes_per_inchikey) == 3 + assert sum(len(v) for v in library.spectrum_indexes_per_inchikey.values()) == 9 + + # test correct copying + new_copy = library.copy() + assert len(new_copy.spectra) == 9 + assert len(new_copy.spectrum_indexes_per_inchikey) == 3 + assert sum(len(v) for v in new_copy.spectrum_indexes_per_inchikey.values()) == 9 + + # test correctly adding spectra + new_copy.add_spectra(library) + new_number_of_spectra = 9 + 9 + assert len(new_copy.spectra) == new_number_of_spectra + assert len(new_copy.spectrum_indexes_per_inchikey) == 3 + assert sum(len(v) for v in new_copy.spectrum_indexes_per_inchikey.values()) == new_number_of_spectra + + # test the original is not edited when adding spectra + assert len(library.spectra) == 9 + assert len(library.spectrum_indexes_per_inchikey) == 
3 + assert sum(len(v) for v in library.spectrum_indexes_per_inchikey.values()) == 9 + + # test correct subsetting + subset_indexes = [1, 4, 6, 7] + subset = library.subset_spectra(subset_indexes) + assert len(subset.spectra) == len(subset_indexes) + assert len(subset.spectrum_indexes_per_inchikey) == 3 + assert sum(len(v) for v in subset.spectrum_indexes_per_inchikey.values()) == len(subset_indexes) + assert isinstance(subset, library.__class__) + + +@pytest.mark.parametrize( + "library", + [ + SpectraWithFingerprints(create_test_spectra()), + SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), ms2deepscore_model()), + ], +) +def test_spectra_with_fingerprints(library): + """Test all functionality added in SpectraWithFingerprints also for all classes inheriting from it""" + # test correct init + assert len(library.inchikey_fingerprint_pairs) == 3 + + # test correct copying + new_copy = library.copy() + assert len(new_copy.inchikey_fingerprint_pairs) == 3 + + # test correctly adding inchikey_fingerprint_pairs when runnning add_spectra + for inchikey_inchi_pairs, expected_nr_of_inchikeys in ( + (get_inchikey_inchi_pairs(5)[2:], 5), # Some overlap + (get_inchikey_inchi_pairs(5)[3:], 5), # No overlap + (get_inchikey_inchi_pairs(3), 3), # Fully overlapping + (get_inchikey_inchi_pairs(1), 3), # Fully overlapping (but not all) + ): + spectra_to_add = SpectrumSet(create_test_spectra(2, inchikey_inchi_pairs=inchikey_inchi_pairs)) + new_copy = library.copy() + new_copy.add_spectra(spectra_to_add) + assert len(new_copy.inchikey_fingerprint_pairs) == expected_nr_of_inchikeys + for inchikey in library.inchikey_fingerprint_pairs: + assert np.array_equal( + new_copy.inchikey_fingerprint_pairs[inchikey], library.inchikey_fingerprint_pairs[inchikey] + ) + + # test the original is not edited when adding spectra + assert len(library.inchikey_fingerprint_pairs) == 3 + assert all( + np.array_equal(library.inchikey_fingerprint_pairs[key], value) + for key, value in 
SpectraWithFingerprints(create_test_spectra()).inchikey_fingerprint_pairs.items() + ) + + # test correct subsetting + subset_indexes = [1, 6, 7] + subset = library.subset_spectra(subset_indexes) + assert len(subset.inchikey_fingerprint_pairs) == 2 + assert all( + np.array_equal(library.inchikey_fingerprint_pairs[key], value) + for key, value in subset.inchikey_fingerprint_pairs.items() + ) + assert hasattr(subset, "update_fingerprint_per_inchikey") + + +def test_spectra_with_embeddings(): + library = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), ms2deepscore_model()) + # test correct init + assert library.embeddings.shape == (9, 100) + + # test correct copying + new_copy = library.copy() + assert new_copy.embeddings.shape == (9, 100) + + # test correctly adding spectra + new_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(1), ms2deepscore_model()) + new_copy.add_spectra(new_spectra) + assert new_copy.embeddings.shape == (12, 100) + + # test the original is not edited when adding spectra + assert library.embeddings.shape == (9, 100) + + # test correct subsetting + subset_indexes = [1, 4, 6, 7] + subset = library.subset_spectra(subset_indexes) + assert subset.embeddings.shape == (len(subset_indexes), 100) + for i, index in enumerate(subset_indexes): + assert np.all(library.embeddings[index] == subset.embeddings[i]) + + # Check that subsetting on subset works. 
To make sure that a subset does not become of type SpectrumSet + subsetted_subset = subset.subset_spectra([0, 1]) + assert subsetted_subset.embeddings.shape == (2, 100) diff --git a/tests/test_ann_vector_index.py b/tests/test_ann_vector_index.py index fe3971a..69f7f42 100644 --- a/tests/test_ann_vector_index.py +++ b/tests/test_ann_vector_index.py @@ -4,7 +4,13 @@ import numpy as np import pytest import scipy.sparse as sp -from ms2query.database.ann_vector_index import EmbeddingIndex, csr_row_from_tuple, l1_norms_csr, tuples_to_csr +from ms2query.database.ann_vector_index import ( + EmbeddingIndex, + FingerprintSparseIndex, # <-- added + csr_row_from_tuple, + l1_norms_csr, + tuples_to_csr, +) def _mk_unit_vecs(*rows): @@ -15,8 +21,8 @@ def _mk_unit_vecs(*rows): def test_build_index_and_query_dense(): - X = _mk_unit_vecs([1,0,0], [0,1,0], [0,0,1]) - ids = ["a","b","c"] + X = _mk_unit_vecs([1, 0, 0], [0, 1, 0], [0, 0, 1]) + ids = ["a", "b", "c"] idx = EmbeddingIndex(dim=3) idx.build_index(X, ids) # Query close to [1,0,0] @@ -27,11 +33,11 @@ def test_build_index_and_query_dense(): def test_build_index_normalizes_when_requested(): - X = np.array([[2.0,0,0],[0,2.0,0]], dtype=np.float32) - ids = ["x","y"] + X = np.array([[2.0, 0, 0], [0, 2.0, 0]], dtype=np.float32) + ids = ["x", "y"] idx = EmbeddingIndex(dim=3) idx.build_index(X, ids) - q = np.array([1.0,0,0], dtype=np.float32) + q = np.array([1.0, 0, 0], dtype=np.float32) out = idx.query(q, k=1) assert out[0][0] == "x" assert 0.99 <= out[0][1] <= 1.0 @@ -41,14 +47,14 @@ def test_query_errors_and_dim_check(): idx = EmbeddingIndex(dim=3) with pytest.raises(RuntimeError): idx.query(np.zeros(3, np.float32)) - idx.build_index(np.eye(3, dtype=np.float32), ["a","b","c"]) + idx.build_index(np.eye(3, dtype=np.float32), ["a", "b", "c"]) with pytest.raises(ValueError, match="dim=3"): idx.query(np.zeros(4, np.float32)) def test_save_and_load_roundtrip_dense(tmp_path): - X = _mk_unit_vecs([1,0,0],[0,1,0],[0,0,1]) - ids = 
["a","b","c"] + X = _mk_unit_vecs([1, 0, 0], [0, 1, 0], [0, 0, 1]) + ids = ["a", "b", "c"] idx = EmbeddingIndex(dim=3) idx.build_index(X, ids) prefix = os.path.join(tmp_path, "emb") @@ -59,7 +65,7 @@ def test_save_and_load_roundtrip_dense(tmp_path): idx2.load_index(prefix) # Query should still work - res = idx2.query(np.array([1.0,0,0], dtype=np.float32), k=1) + res = idx2.query(np.array([1.0, 0, 0], dtype=np.float32), k=1) assert res[0][0] == "a" # meta persisted with open(prefix + ".meta.json") as f: @@ -72,7 +78,8 @@ def test_build_index_from_sqlite_streams_and_orders(batch_rows): conn = sqlite3.connect(":memory:") conn.execute("CREATE TABLE embeddings(spec_id TEXT, vec BLOB, d INTEGER)") # Add 3 vectors of dim 3 - conn.executemany( "INSERT INTO embeddings(spec_id, vec, d) VALUES (?,?,?)", + conn.executemany( + "INSERT INTO embeddings(spec_id, vec, d) VALUES (?,?,?)", [ ("id_1", np.array([1.0, 0.0, 0.0], np.float32).tobytes(), 3), ("id_2", np.array([1.0, 1.0, 0.0], np.float32).tobytes(), 3), @@ -106,7 +113,7 @@ def test_build_index_from_sqlite_errors(): # Empty table should error conn2 = sqlite3.connect(":memory:") - conn2.execute("CREATE TABLE embeddings(spec_id TEXT, vec BLOB, d INTEGER)") + conn2.execute("CREATE TABLE embeddings(spec_id TEXT, vec, d INTEGER)") with pytest.raises(ValueError, match="No rows"): idx.build_index_from_sqlite(conn2, embeddings_table="embeddings") @@ -158,3 +165,51 @@ def test_l1_norms_csr(): norms = l1_norms_csr(X) np.testing.assert_allclose(norms, [3.0, 1.0, 4.0]) assert norms.dtype == np.float64 + + +def test_save_and_load_roundtrip_fingerprint_sparse(tmp_path): + # Build a tiny sparse fingerprint matrix with 3 compounds, dim=5 + tuples = [ + (np.array([0, 3], dtype=np.int32), np.array([1.0, 0.5], dtype=np.float32)), + (np.array([1], dtype=np.int32), np.array([1.0], dtype=np.float32)), + (np.array([2, 4], dtype=np.int32), np.array([0.2, 2.0], dtype=np.float32)), + ] + csr = tuples_to_csr(tuples, dim=5) + comp_ids = np.array([10, 
11, 12], dtype=int) + + idx = FingerprintSparseIndex(dim=5) + idx.build_index( + csr, + comp_ids, + keep_csr_for_rerank=True, + compute_l1_for_rerank=True, + ) + + # Basic sanity: query with the first fingerprint should hit comp_id=10 first + q = tuples[0] + res = idx.query(q, k=1) + assert res[0][0] == 10 + + # Save to disk + prefix = os.path.join(tmp_path, "fp") + idx.save_index(prefix) + + # New instance loads back + idx2 = FingerprintSparseIndex() + idx2.load_index(prefix) + + # Query should still work and return the same top compound + res2 = idx2.query(q, k=1) + assert res2[0][0] == 10 + + # CSR and L1 data should have been persisted + assert idx2._csr is not None + assert idx2._l1 is not None + assert idx2._csr.shape == csr.shape + assert idx2._l1.shape[0] == csr.shape[0] + + # Meta persisted and type is correct + with open(prefix + ".meta.json") as f: + meta = json.load(f) + assert meta["type"] == "FingerprintSparseIndex" + assert meta["space"] == "cosinesimil_sparse" diff --git a/tests/test_evaluate_methods.py b/tests/test_evaluate_methods.py new file mode 100644 index 0000000..7431aad --- /dev/null +++ b/tests/test_evaluate_methods.py @@ -0,0 +1,31 @@ +import pytest +from ms2query.benchmarking.EvaluateMethods import EvaluateMethods +from ms2query.benchmarking.reference_methods.predict_best_possible_match import predict_best_possible_match +from ms2query.benchmarking.reference_methods.predict_highest_cosine import predict_highest_cosine +from ms2query.benchmarking.reference_methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings +from tests.conftest import create_test_spectra, ms2deepscore_model + + +@pytest.mark.parametrize( + "method", + [predict_highest_ms2deepscore, predict_highest_cosine, predict_best_possible_match], +) +def test_evaluate_methods(method): + nr_of_spectra_per_inchikey = 6 + nr_of_inchikeys = 5 + dummy_spectra = 
create_test_spectra(nr_of_spectra_per_inchikey, nr_of_inchikeys=nr_of_inchikeys) + for i, spectrum in enumerate(dummy_spectra): + if i % 3 == 0: + spectrum.set("ionmode", "positive") + else: + spectrum.set("ionmode", "negative") + model = ms2deepscore_model() + reference_library = SpectraWithMS2DeepScoreEmbeddings(dummy_spectra[: nr_of_spectra_per_inchikey * 2], model) + validation_spectra = SpectraWithMS2DeepScoreEmbeddings(dummy_spectra[nr_of_spectra_per_inchikey * 2 :], model) + method_evaluator = EvaluateMethods(reference_library, validation_spectra) + # should be zero or below zero, since it is the difference with the perfect predictions + assert method_evaluator.benchmark_analogue_search(method) >= 0.0 + # # Should be 1 because we added a good match for each. + assert method_evaluator.benchmark_exact_matching_within_ionmode(method, "positive") == 1.0 + assert method_evaluator.exact_matches_across_ionization_modes(method) == 1.0 diff --git a/tests/test_library_io.py b/tests/test_library_io.py index ad9a82f..04662bd 100644 --- a/tests/test_library_io.py +++ b/tests/test_library_io.py @@ -8,7 +8,7 @@ TEST_COMP_ID = "ZBSGKPYXQINNGF" # expected InChIKey14 present in the test data -EXPECTED_METADATA_SHAPE = (5, 11) +EXPECTED_METADATA_SHAPE = (5, 12) EXPECTED_METADATA_FIELDS = [ "precursor_mz", "ionmode", "smiles", "inchikey", "inchi", "name", "charge", "instrument_type", "adduct", "collision_energy", @@ -56,7 +56,7 @@ def test_create_and_load_library(tmp_path: Path): ms2query_db = lib.db # Metadata query by compound id (expected shape from your snippet) - df_meta = ms2query_db.metadata_by_comp_id(TEST_COMP_ID) + df_meta = ms2query_db.metadata_by_comp_ids([TEST_COMP_ID]) assert tuple(df_meta.shape) == EXPECTED_METADATA_SHAPE # Metadata fields presence both in db wrapper and in returned dataframe diff --git a/tests/test_methods.py b/tests/test_methods.py new file mode 100644 index 0000000..1b747d0 --- /dev/null +++ b/tests/test_methods.py @@ -0,0 +1,56 @@ +import 
numpy as np +import pytest +from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix +from ms2query.benchmarking.reference_methods.predict_best_possible_match import predict_best_possible_match +from ms2query.benchmarking.reference_methods.predict_highest_cosine import predict_highest_cosine +from ms2query.benchmarking.reference_methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore +from ms2query.benchmarking.reference_methods.predict_with_integrated_similarity_flow import ( + integrated_similarity_flow, + predict_with_integrated_similarity_flow, +) +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings +from tests.conftest import create_test_spectra, ms2deepscore_model + + +@pytest.mark.parametrize( + "prediction_function", + [ + predict_highest_cosine, + predict_highest_ms2deepscore, + predict_best_possible_match, + ], +) +def test_all_methods(prediction_function): + model = ms2deepscore_model() + library_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), model) + test_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(1), model) + predicted_inchikeys, scores = prediction_function(library_spectra, test_spectra) + for i, spectrum in enumerate(test_spectra.spectra): + inchikey = spectrum.get("inchikey")[:14] + assert predicted_inchikeys[i] == inchikey + assert np.allclose(scores[i], np.array(1.0), atol=1e-5) + + +def test_predict_with_integrated_similarity_flow(): + model = ms2deepscore_model() + library_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), model) + test_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(1), model) + predicted_inchikeys, scores = predict_with_integrated_similarity_flow(library_spectra, test_spectra) + + assert predicted_inchikeys == ["RYYVLZVUVIJVGH", "ZPUCINDJVBIVPJ", "ZPUCINDJVBIVPJ"] + assert np.allclose(np.array([0.38829751082577607, 0.3919729335980483, 0.38774130710967564]), np.array(scores)) + + 
+def test_isf_computation(): + distances = [0.99, 0.99, 0.99, 0.5, 0.5] + fps = np.array([[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 0, 1], [1, 1, 1, 1, 0, 0], [0, 0, 1, 1, 0, 0], [0, 0, 1, 0, 1, 0]]) + similarities = jaccard_similarity_matrix(fps, fps) + result = integrated_similarity_flow(distances, similarities, [1, 1, 1, 1, 1]) + expected_result = [ + 0.715869, + 0.686481, + 0.736523, + 0.492107, + 0.359110, + ] + assert np.allclose(np.array(expected_result), np.array(result)) diff --git a/tests/test_ms2query_library.py b/tests/test_ms2query_library.py index a6fd087..255f9cb 100644 --- a/tests/test_ms2query_library.py +++ b/tests/test_ms2query_library.py @@ -9,7 +9,7 @@ TEST_COMP_ID = "ZBSGKPYXQINNGF" # known from your snippet -EXPECTED_METADATA_SHAPE = (5, 11) +EXPECTED_METADATA_SHAPE = (5, 12) EXPECTED_METADATA_FIELDS = [ "precursor_mz", "ionmode", "smiles", "inchikey", "inchi", "name", "charge", "instrument_type", "adduct", "collision_energy", @@ -79,7 +79,7 @@ def test_create_and_load_smoke(tmp_path: Path): # DB content checks ms2query_db = lib.db - meta_df = ms2query_db.metadata_by_comp_id(TEST_COMP_ID) + meta_df = ms2query_db.metadata_by_comp_ids([TEST_COMP_ID]) assert tuple(meta_df.shape) == EXPECTED_METADATA_SHAPE for field in EXPECTED_METADATA_FIELDS: assert field in ms2query_db.metadata_fields diff --git a/tests/test_predict_using_closest_tanimoto.py b/tests/test_predict_using_closest_tanimoto.py new file mode 100644 index 0000000..f6ebc31 --- /dev/null +++ b/tests/test_predict_using_closest_tanimoto.py @@ -0,0 +1,85 @@ +import numpy as np +import pytest +from ms2query.benchmarking.reference_methods.predict_using_closest_tanimoto import ( + get_average_predictions_for_closely_related_metabolites, + get_inchikey_and_tanimoto_scores_for_top_k, + predict_using_closest_tanimoto, + predict_using_closest_tanimoto_single_spectrum, + select_inchikeys_with_highest_ms2deepscore, +) +from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, 
SpectraWithMS2DeepScoreEmbeddings +from tests.conftest import create_test_spectra, ms2deepscore_model + + +def test_predict_using_closest_tanimoto(): + """Only very basic test that the function runs and that the output is the right format""" + model = ms2deepscore_model() + library_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(nr_of_inchikeys=7), model) + test_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(1, nr_of_inchikeys=3), model) + predicted_inchikeys, scores = predict_using_closest_tanimoto(library_spectra, test_spectra, 3, 3) + + assert isinstance(predicted_inchikeys, list) + assert len(predicted_inchikeys) == 3 + assert isinstance(scores, list) + assert len(scores) == 3 + +def test_predict_using_closest_tanimoto_single_spectrum(): + """Only very basic test that the function runs and that the output is the right format""" + model = ms2deepscore_model() + library_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(nr_of_inchikeys=7), model) + test_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(1, nr_of_inchikeys=1), model) + predicted_inchikey, score = predict_using_closest_tanimoto_single_spectrum(library_spectra, test_spectra, 3, 3) + + assert isinstance(predicted_inchikey, str) + assert len(predicted_inchikey) ==14 + assert isinstance(score, float) + +def test_select_inchikeys_with_highest_ms2deepscore(): + test_spectra = create_test_spectra(nr_of_inchikeys=7) + spectra = SpectraWithFingerprints(test_spectra) + + ms2deepscores = np.zeros(len(test_spectra)) + ms2deepscores[2] = 0.4 + ms2deepscores[5] = 0.9 + ms2deepscores[7] = 0.6 + inchikeys_with_highest_ms2deepscore = select_inchikeys_with_highest_ms2deepscore(spectra, ms2deepscores, 3) + expected_inchikeys = list(spectra.spectrum_indexes_per_inchikey.keys())[:3] + assert set(expected_inchikeys) == set(inchikeys_with_highest_ms2deepscore) + print(inchikeys_with_highest_ms2deepscore) + +def 
test_get_average_predictions_for_closely_related_metabolites(): + test_spectra = create_test_spectra(nr_of_inchikeys=7) + # Select different number per inchikey (only one for the first) to check that it is correctly weighted. + test_spectra = test_spectra.copy()[2:] + spectra = SpectraWithFingerprints(test_spectra) + + inchikeys = list(spectra.inchikey_fingerprint_pairs.keys())[:3] + ms2deepscores = np.zeros(len(spectra.spectra)) + ms2deepscores[0] = 0.8 + ms2deepscores[[1,2,3]] = 0.6 + ms2deepscores[4] = 0.6 + ms2deepscores[5] = 0.8 + ms2deepscores[6] = 0.7 + # the average per inchikey is 0.8, 0.6, 0.7, so average overall should be 0.7 + average_predicted_score = get_average_predictions_for_closely_related_metabolites(spectra, + inchikeys, + ms2deepscores) + assert np.allclose(average_predicted_score, np.array(0.7), atol=1e-5) + +@pytest.mark.parametrize( + "k", + [1, 3, 7], +) +def test_get_inchikey_and_tanimoto_scores_for_top_k(k): + spectra = SpectraWithFingerprints(create_test_spectra(nr_of_inchikeys=7)) + inchikey = list(spectra.inchikey_fingerprint_pairs.keys())[2] + + top_inchikeys, tanimoto_scores_for_top_k = get_inchikey_and_tanimoto_scores_for_top_k( + spectra, inchikey,k) + + assert inchikey in top_inchikeys + assert len(top_inchikeys) == k + assert len(tanimoto_scores_for_top_k) == k + assert len(set(top_inchikeys)) == k + assert tanimoto_scores_for_top_k[top_inchikeys.index(inchikey)] == 1.0, \ + "The exact match is expected to have a score of 1.0" \ No newline at end of file diff --git a/tests/test_spectral_database.py b/tests/test_spectral_database.py index adfc296..fe7237e 100644 --- a/tests/test_spectral_database.py +++ b/tests/test_spectral_database.py @@ -1,4 +1,3 @@ -# test_spectral_database.py import numpy as np import pandas as pd import pytest @@ -149,10 +148,70 @@ def test_missing_ids_handling(tmp_db, spectra_small): # Implementation skips missing IDs but preserves order of the ones that exist assert [s.metadata["spec_id"] for s in 
out_spectra] == [ids[1]] - assert list(out_meta["spec_id"]) == [ids[1]] + + # get_metadata_by_ids now returns one row per requested ID + assert out_meta.shape[0] == 2 + assert list(out_meta["spec_id"]) == req + + # First row corresponds to missing ID: all metadata fields should be None/NaN + missing_row = out_meta.iloc[0] + assert missing_row["spec_id"] == req[0] + assert missing_row[tmp_db.metadata_fields].isna().all() + + # Exactly one row has *any* metadata filled (the real spec_id) + non_empty_rows = out_meta.dropna(how="all", subset=tmp_db.metadata_fields) + assert non_empty_rows.shape[0] == 1 + assert non_empty_rows.iloc[0]["spec_id"] == ids[1] + + # get_fragments_by_ids still skips missing IDs assert len(out_frags) == 1 +def test_get_metadata_by_ids_all_ids_included_even_if_same_compound(tmp_db): + # Two spectra with different peaks but same compound-level metadata + s1 = make_spectrum( + [100, 200, 300], + [10, 20, 30], + precursor_mz=250.0, + ionmode="positive", + inchikey="SAME-IK", + smiles="C", + name="compound-1", + ) + s2 = make_spectrum( + [110, 210, 310], + [5, 15, 25], + precursor_mz=260.0, + ionmode="positive", + inchikey="SAME-IK", # same compound + smiles="C", + name="compound-1", + ) + + ids = tmp_db.add_spectra([s1, s2]) + + # Request both spec_ids; we expect two rows, one per ID, same order + df = tmp_db.get_metadata_by_ids(ids) + + expected_cols = ["spec_id"] + tmp_db.metadata_fields + assert list(df.columns) == expected_cols + assert df.shape[0] == 2 + assert list(df["spec_id"]) == ids + + # Both rows should carry the same compound-level metadata (same inchikey/smiles) + assert df.loc[0, "inchikey"] == "SAME-IK" + assert df.loc[1, "inchikey"] == "SAME-IK" + assert df.loc[0, "smiles"] == "C" + assert df.loc[1, "smiles"] == "C" + + # If name is stored, it should be consistent across rows (but may be None) + assert df.loc[0, "name"] == df.loc[1, "name"] + + # The precursor m/z values differ per spectrum and should be preserved per ID + assert 
df.loc[0, "precursor_mz"] == pytest.approx(250.0) + assert df.loc[1, "precursor_mz"] == pytest.approx(260.0) + + def test_add_duplicates_are_ignored_and_ids_repeat(tmp_db, spectra_small): # First insert ids_first = tmp_db.add_spectra(spectra_small)