import copy
from collections import Counter
from typing import Dict, Iterable, List, Set

import numpy as np
from matchms import Spectrum
from matchms.filtering.metadata_processing.add_fingerprint import _derive_fingerprint_from_inchi
from ms2deepscore.models import SiameseSpectralModel, compute_embedding_array
from tqdm import tqdm


class SpectrumSetBase:
    """Stores a spectrum dataset, making it easy and fast to split on molecules (inchikeys)."""

    def __init__(self, spectra: List[Spectrum], progress_bars: bool = False):
        self._spectra: List[Spectrum] = []
        # Maps the first 14 characters of an inchikey (the 2D-structure part) to spectrum indexes.
        self.spectrum_indexes_per_inchikey: Dict[str, List[int]] = {}
        self.progress_bars = progress_bars
        self._add_spectra_and_group_per_inchikey(spectra)

    def _add_spectra_and_group_per_inchikey(self, spectra: List[Spectrum]) -> Set[str]:
        """Appends spectra to self._spectra and indexes them per (14-char) inchikey.

        Returns the set of inchikeys that received new spectra.
        """
        starting_index = len(self._spectra)
        updated_inchikeys = set()
        for i, spectrum in enumerate(
            tqdm(spectra, desc="Adding spectra and grouping per Inchikey", disable=not self.progress_bars)
        ):
            self._spectra.append(spectrum)
            spectrum_index = starting_index + i
            inchikey = spectrum.get("inchikey")[:14]
            updated_inchikeys.add(inchikey)
            self.spectrum_indexes_per_inchikey.setdefault(inchikey, []).append(spectrum_index)
        return updated_inchikeys

    def add_spectra(self, new_spectra: "SpectrumSetBase") -> Set[str]:
        """Adds all spectra of another spectrum set. Returns the inchikeys that were updated."""
        return self._add_spectra_and_group_per_inchikey(new_spectra.spectra)

    def subset_spectra(self, spectrum_indexes) -> "SpectrumSetBase":
        """Returns a new instance containing only the spectra at the given indexes."""
        new_instance = copy.copy(self)
        new_instance._spectra = []
        new_instance.spectrum_indexes_per_inchikey = {}
        new_instance._add_spectra_and_group_per_inchikey([self._spectra[index] for index in spectrum_indexes])
        return new_instance

    def spectra_per_inchikey(self, inchikey: str) -> List[Spectrum]:
        """Returns all spectra belonging to one (14-char) inchikey."""
        return [self._spectra[index] for index in self.spectrum_indexes_per_inchikey[inchikey]]

    @property
    def spectra(self) -> List[Spectrum]:
        return self._spectra

    def copy(self) -> "SpectrumSetBase":
        """Returns a copy whose spectrum list and index mapping can be mutated independently.

        The Spectrum objects themselves are shared between the copies (not deep-copied).
        """
        new_instance = copy.copy(self)
        new_instance._spectra = self._spectra.copy()
        new_instance.spectrum_indexes_per_inchikey = copy.deepcopy(self.spectrum_indexes_per_inchikey)
        return new_instance


class SpectraWithFingerprints(SpectrumSetBase):
    """Spectrum dataset that additionally stores one molecular fingerprint per inchikey."""

    def __init__(
        self,
        spectra: List[Spectrum],
        fingerprint_type: str = "daylight",
        nbits: int = 4096,
        progress_bars: bool = False,
    ):
        super().__init__(spectra, progress_bars=progress_bars)
        self.fingerprint_type = fingerprint_type
        self.nbits = nbits
        self.inchikey_fingerprint_pairs: Dict[str, np.ndarray] = {}
        self.update_fingerprint_per_inchikey(self.spectrum_indexes_per_inchikey.keys())

    def add_spectra(self, new_spectra: "SpectraWithFingerprints") -> Set[str]:
        """Adds all spectra of another set, reusing its fingerprints when compatible."""
        updated_inchikeys = super().add_spectra(new_spectra)
        if hasattr(new_spectra, "inchikey_fingerprint_pairs"):
            # Reuse precomputed fingerprints when the settings match and no inchikey overlaps,
            # so no recomputation is needed at all.
            if new_spectra.nbits == self.nbits and new_spectra.fingerprint_type == self.fingerprint_type:
                if not (self.inchikey_fingerprint_pairs.keys() & new_spectra.inchikey_fingerprint_pairs.keys()):
                    self.inchikey_fingerprint_pairs = (
                        self.inchikey_fingerprint_pairs | new_spectra.inchikey_fingerprint_pairs
                    )
                    return updated_inchikeys
        self.update_fingerprint_per_inchikey(updated_inchikeys)
        return updated_inchikeys

    def update_fingerprint_per_inchikey(self, inchikeys_to_update: Iterable[str]):
        """(Re)computes the fingerprint for each given inchikey from its most common InChI.

        Raises:
            ValueError: If a fingerprint could not be derived for an InChI.
        """
        for inchikey in tqdm(
            inchikeys_to_update, desc="Adding fingerprints to Inchikeys", disable=not self.progress_bars
        ):
            spectra = self.spectra_per_inchikey(inchikey)
            # Multiple spectra of one inchikey may carry (slightly) different InChI strings;
            # use the most common one as the representative structure.
            most_common_inchi = Counter(spectrum.get("inchi") for spectrum in spectra).most_common(1)[0][0]
            fingerprint = _derive_fingerprint_from_inchi(
                most_common_inchi, fingerprint_type=self.fingerprint_type, nbits=self.nbits
            )
            if not isinstance(fingerprint, np.ndarray):
                raise ValueError(f"Fingerprint could not be set for InChI: {most_common_inchi}")
            self.inchikey_fingerprint_pairs[inchikey] = fingerprint

    def copy(self) -> "SpectraWithFingerprints":
        """Returns a copy whose fingerprint mapping can be mutated independently."""
        new_instance = super().copy()
        new_instance.inchikey_fingerprint_pairs = copy.copy(self.inchikey_fingerprint_pairs)
        return new_instance


class SpectraWithMS2DeepScoreEmbeddings(SpectraWithFingerprints):
    """Spectrum dataset that additionally stores an MS2DeepScore embedding per spectrum."""

    def __init__(self, spectra: List[Spectrum], ms2deepscore_model: SiameseSpectralModel, **kwargs):
        super().__init__(spectra, **kwargs)
        self.ms2deepscore_model = ms2deepscore_model
        # One embedding row per spectrum, aligned with self.spectra.
        self.embeddings: np.ndarray = compute_embedding_array(self.ms2deepscore_model, spectra)

    def add_spectra(self, new_spectra: "SpectraWithMS2DeepScoreEmbeddings") -> Set[str]:
        """Adds all spectra of another set, reusing its embeddings when available."""
        updated_inchikeys = super().add_spectra(new_spectra)
        if hasattr(new_spectra, "embeddings"):
            new_embeddings = new_spectra.embeddings
        else:
            new_embeddings = compute_embedding_array(self.ms2deepscore_model, new_spectra.spectra)
        self.embeddings = np.vstack([self.embeddings, new_embeddings])
        return updated_inchikeys

    def subset_spectra(self, spectrum_indexes) -> "SpectraWithMS2DeepScoreEmbeddings":
        """Returns a new instance of a subset of the spectra (embeddings are subset as well)."""
        new_instance = super().subset_spectra(spectrum_indexes)
        new_instance.embeddings = self.embeddings[spectrum_indexes]
        return new_instance
import random
from typing import Callable, List, Tuple

import numpy as np
from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix
from tqdm import tqdm

from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectrumSetBase

# A prediction function takes (library_set, query_set) and returns, per query spectrum,
# the predicted library inchikey (None when no prediction was made) and a confidence score.
PredictionFunction = Callable[[SpectraWithFingerprints, SpectraWithFingerprints], Tuple[List[str], List[float]]]


class EvaluateMethods:
    """Benchmarks analogue search and exact match methods on a train/validation spectrum split."""

    def __init__(
        self, training_spectrum_set: SpectraWithFingerprints, validation_spectrum_set: SpectraWithFingerprints
    ):
        self.training_spectrum_set = training_spectrum_set
        self.validation_spectrum_set = validation_spectrum_set
        # Disable the per-set progress bars; the benchmark loops show their own.
        self.training_spectrum_set.progress_bars = False
        self.validation_spectrum_set.progress_bars = False

    def benchmark_analogue_search(self, prediction_function: PredictionFunction) -> float:
        """Scores an analogue search method.

        For every validation spectrum the Tanimoto score between the predicted and the actual
        inchikey is used as score (0.0 when nothing was predicted). Scores are first averaged per
        inchikey and then averaged over all inchikeys, so that frequently measured molecules do
        not dominate the benchmark.
        """
        predicted_inchikeys, _ = prediction_function(self.training_spectrum_set, self.validation_spectrum_set)
        average_scores_per_inchikey = []
        for inchikey, matching_spectrum_indexes in tqdm(
            self.validation_spectrum_set.spectrum_indexes_per_inchikey.items(),
            desc="Calculating analogue accuracy per inchikey",
        ):
            prediction_scores = []
            for index in matching_spectrum_indexes:
                predicted_inchikey = predicted_inchikeys[index]
                if predicted_inchikey is None:
                    prediction_scores.append(0.0)
                else:
                    predicted_fingerprint = self.training_spectrum_set.inchikey_fingerprint_pairs[predicted_inchikey]
                    actual_fingerprint = self.validation_spectrum_set.inchikey_fingerprint_pairs[inchikey]
                    prediction_scores.append(
                        calculate_tanimoto_score_between_pair(predicted_fingerprint, actual_fingerprint)
                    )
            average_scores_per_inchikey.append(sum(prediction_scores) / len(prediction_scores))
        return sum(average_scores_per_inchikey) / len(average_scores_per_inchikey)

    def benchmark_exact_matching_within_ionmode(
        self, prediction_function: PredictionFunction, ionmode: str
    ) -> float:
        """Tests the accuracy at retrieving exact matches from the library.

        For each inchikey with more than one spectrum the spectra are split in two sets. One half
        per inchikey is added to the library (training set); for the other half predictions are
        made. Thereby there is always an exact match available. Only the highest ranked prediction
        is considered, and it is correct if it is the true inchikey. An accuracy per inchikey is
        calculated, followed by calculating the average over inchikeys.
        """
        selected_spectra = subset_spectra_on_ionmode(self.validation_spectrum_set, ionmode)
        set_1, set_2 = split_spectrum_set_per_inchikeys(selected_spectra)
        predicted_inchikeys = predict_between_two_sets(self.training_spectrum_set, set_1, set_2, prediction_function)
        # Predictions are ordered set_1 then set_2, so merge the spectra in the same order.
        set_1.add_spectra(set_2)
        return calculate_average_exact_match_accuracy(set_1, predicted_inchikeys)

    def exact_matches_across_ionization_modes(self, prediction_function: PredictionFunction) -> float:
        """Tests exact match retrieval when the match is only available in the other ionmode.

        Each validation spectrum is matched against the training set with the validation spectra
        of the same inchikey, but measured in the other ionisation mode, added to the library.
        """
        pos_set, neg_set = split_spectrum_set_per_inchikey_across_ionmodes(self.validation_spectrum_set)
        predicted_inchikeys = predict_between_two_sets(
            self.training_spectrum_set, pos_set, neg_set, prediction_function
        )
        # Predictions are ordered pos_set then neg_set, so merge the spectra in the same order.
        pos_set.add_spectra(neg_set)
        return calculate_average_exact_match_accuracy(pos_set, predicted_inchikeys)

    def get_accuracy_recall_curve(self):
        """Should test the accuracy/recall balance of a method (not implemented yet).

        All of the used methods return a score that indicates the quality of a prediction.
        A method that can predict well when a prediction is accurate is beneficial. One way to
        test this is generating an accuracy-recall curve from the returned scores, both for the
        analogue search predictions and the exact match predictions.
        """
        raise NotImplementedError


def predict_between_two_sets(
    library: SpectrumSetBase, query_set_1: SpectrumSetBase, query_set_2: SpectrumSetBase, prediction_function
) -> List[str]:
    """Predicts for both query sets against the library with the other query set added.

    This is necessary for testing exact matching: each query spectrum needs the other half of its
    inchikey split to be present in the searched library. The returned predictions are ordered as
    query_set_1 followed by query_set_2.
    """
    library_with_set_2 = library.copy()
    library_with_set_2.add_spectra(query_set_2)
    predicted_inchikeys_1, _ = prediction_function(library_with_set_2, query_set_1)

    library_with_set_1 = library.copy()
    library_with_set_1.add_spectra(query_set_1)
    predicted_inchikeys_2, _ = prediction_function(library_with_set_1, query_set_2)

    return predicted_inchikeys_1 + predicted_inchikeys_2


def calculate_average_exact_match_accuracy(spectrum_set: SpectrumSetBase, predicted_inchikeys: List[str]) -> float:
    """Returns the fraction of correctly predicted spectra, averaged per inchikey.

    Raises:
        ValueError: If the number of predictions does not match the number of spectra.
    """
    if len(spectrum_set.spectra) != len(predicted_inchikeys):
        raise ValueError("The number of spectra should be equal to the number of predicted inchikeys ")
    exact_match_accuracy_per_inchikey = []
    for inchikey, spectrum_indexes in tqdm(
        spectrum_set.spectrum_indexes_per_inchikey.items(), desc="Calculating exact match accuracy per inchikey"
    ):
        correctly_predicted = sum(
            1 for spectrum_idx in spectrum_indexes if predicted_inchikeys[spectrum_idx] == inchikey
        )
        exact_match_accuracy_per_inchikey.append(correctly_predicted / len(spectrum_indexes))
    return sum(exact_match_accuracy_per_inchikey) / len(exact_match_accuracy_per_inchikey)


def split_spectrum_set_per_inchikeys(spectrum_set: SpectrumSetBase) -> Tuple[SpectrumSetBase, SpectrumSetBase]:
    """Splits a spectrum set into two.

    For each inchikey with more than one spectrum the spectra are divided randomly over the two
    sets. Inchikeys with a single spectrum are excluded, since no exact match could be added to
    the library for them.
    """
    indexes_set_1: List[int] = []
    indexes_set_2: List[int] = []
    for spectrum_indexes in tqdm(
        spectrum_set.spectrum_indexes_per_inchikey.values(), desc="Splitting spectra per inchikey"
    ):
        if len(spectrum_indexes) == 1:
            # All single spectra are excluded from this test.
            continue
        # random.sample returns a shuffled copy, so the index administration of spectrum_set is
        # not reordered in place (random.shuffle would mutate the stored list).
        shuffled_indexes = random.sample(spectrum_indexes, len(spectrum_indexes))
        split_index = len(shuffled_indexes) // 2
        indexes_set_1.extend(shuffled_indexes[:split_index])
        indexes_set_2.extend(shuffled_indexes[split_index:])
    return spectrum_set.subset_spectra(indexes_set_1), spectrum_set.subset_spectra(indexes_set_2)


def split_spectrum_set_per_inchikey_across_ionmodes(
    spectrum_set: SpectrumSetBase,
) -> Tuple[SpectrumSetBase, SpectrumSetBase]:
    """Splits a spectrum set in two sets on ionmode.

    Only spectra of inchikeys with at least one positive and one negative spectrum are used.
    """
    all_pos_indexes: List[int] = []
    all_neg_indexes: List[int] = []
    for spectrum_indexes in tqdm(
        spectrum_set.spectrum_indexes_per_inchikey.values(),
        desc="Splitting spectra per inchikey across ionmodes",
    ):
        positive_indexes = []
        negative_indexes = []
        for spectrum_index in spectrum_indexes:
            ionmode = spectrum_set.spectra[spectrum_index].get("ionmode")
            if ionmode == "positive":
                positive_indexes.append(spectrum_index)
            elif ionmode == "negative":
                negative_indexes.append(spectrum_index)
        # Skip inchikeys that are not measured in both ionmodes.
        if positive_indexes and negative_indexes:
            all_pos_indexes.extend(positive_indexes)
            all_neg_indexes.extend(negative_indexes)
    return spectrum_set.subset_spectra(all_pos_indexes), spectrum_set.subset_spectra(all_neg_indexes)


def subset_spectra_on_ionmode(spectrum_set: SpectrumSetBase, ionmode: str) -> SpectrumSetBase:
    """Returns a new spectrum set with only the spectra measured in the given ionmode."""
    spectrum_indexes_to_keep = [
        i for i, spectrum in enumerate(spectrum_set.spectra) if spectrum.get("ionmode") == ionmode
    ]
    return spectrum_set.subset_spectra(spectrum_indexes_to_keep)


def calculate_tanimoto_score_between_pair(fingerprint_1: np.ndarray, fingerprint_2: np.ndarray) -> float:
    """Returns the Tanimoto (Jaccard) score between two molecular fingerprints."""
    return jaccard_similarity_matrix(np.array([fingerprint_1]), np.array([fingerprint_2]))[0][0]
from typing import Dict, List, Tuple

import numpy as np
from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix
from ms2deepscore.vector_operations import cosine_similarity_matrix
from tqdm import tqdm

from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints


def predict_top_ms2deepscores(
    library_embeddings: np.ndarray, query_embeddings: np.ndarray, batch_size: int = 500, k: int = 1
) -> Tuple[np.ndarray, np.ndarray]:
    """Memory efficient way of calculating the highest MS2DeepScores.

    When doing large matrix multiplications the memory footprint of storing the full output matrix
    can be large, e.g. when comparing 500,000 vs 10,000 spectra. On a laptop this can result in
    very slow run times. If only the highest MS2DeepScores are needed, processing the query
    embeddings in batches prevents using too much memory.

    Args:
        library_embeddings: The embeddings of the library spectra.
        query_embeddings: The embeddings of the query spectra.
        batch_size: The number of query embeddings processed at the same time.
            Setting a lower batch_size results in a lower memory footprint.
        k: Number of highest matches to return per query.

    Returns:
        An array of shape (n_queries, k) with the library indexes of the k highest scores per
        query, ordered from highest to lowest score, and an array of the same shape with the
        corresponding scores.
    """
    top_indexes_per_batch = []
    top_scores_per_batch = []
    num_of_query_embeddings = query_embeddings.shape[0]
    for start_idx in tqdm(
        range(0, num_of_query_embeddings, batch_size),
        desc="Predicting highest ms2deepscore per batch of "
        + str(min(batch_size, num_of_query_embeddings))
        + " embeddings",
    ):
        end_idx = min(start_idx + batch_size, num_of_query_embeddings)
        selected_query_embeddings = query_embeddings[start_idx:end_idx]
        score_matrix = cosine_similarity_matrix(selected_query_embeddings, library_embeddings)
        # Take the k highest scores per row, ordered from highest to lowest.
        top_n_idx = np.argsort(score_matrix, axis=1)[:, -k:][:, ::-1]
        top_n_scores = np.take_along_axis(score_matrix, top_n_idx, axis=1)
        top_indexes_per_batch.append(top_n_idx)
        top_scores_per_batch.append(top_n_scores)
    return np.vstack(top_indexes_per_batch), np.vstack(top_scores_per_batch)


def predict_best_possible_match(
    library_spectra: SpectraWithFingerprints, query_spectra: SpectraWithFingerprints
) -> Tuple[List[str], List[float]]:
    """Returns per query spectrum the library inchikey with the highest possible Tanimoto score.

    This represents the theoretical upper bound any analogue search method could reach.
    """
    highest_possible_score_per_inchikey = calculate_highest_tanimoto_score_per_inchikey(library_spectra, query_spectra)
    inchikeys_of_best_match = []
    highest_scores = []
    for spectrum in query_spectra.spectra:
        inchikey = spectrum.get("inchikey")[:14]
        best_match_inchikey, best_match_score = highest_possible_score_per_inchikey[inchikey]
        inchikeys_of_best_match.append(best_match_inchikey)
        highest_scores.append(best_match_score)
    return inchikeys_of_best_match, highest_scores


def calculate_highest_tanimoto_score_per_inchikey(
    library_spectra: SpectraWithFingerprints, query_spectra: SpectraWithFingerprints
) -> Dict[str, Tuple[str, float]]:
    """Finds per query inchikey the best possible match (library inchikey and Tanimoto score)."""
    print("Calculating tanimoto scores to determine best possible match")
    library_fingerprints = np.array(list(library_spectra.inchikey_fingerprint_pairs.values()))
    query_fingerprints = np.array(list(query_spectra.inchikey_fingerprint_pairs.values()))
    # Shape (n_library_inchikeys, n_query_inchikeys).
    tanimoto_scores = jaccard_similarity_matrix(library_fingerprints, query_fingerprints)
    highest_scores = tanimoto_scores.max(axis=0, initial=0)
    indexes_of_highest_scores = tanimoto_scores.argmax(axis=0)

    inchikeys_library = list(library_spectra.inchikey_fingerprint_pairs.keys())

    highest_possible_score_per_inchikey = {}
    for i, inchikey in enumerate(query_spectra.inchikey_fingerprint_pairs):
        # An inchikey already in the library is a perfect match
        # (this correctly handles the exact matching case).
        if inchikey in library_spectra.inchikey_fingerprint_pairs:
            highest_possible_score_per_inchikey[inchikey] = (inchikey, 1.0)
            continue
        highest_possible_score_per_inchikey[inchikey] = (
            inchikeys_library[indexes_of_highest_scores[i]],
            highest_scores[i],
        )
    return highest_possible_score_per_inchikey
from typing import List, Tuple

from matchms import Scores
from matchms.similarity.CosineGreedy import CosineGreedy
from matchms.similarity.PrecursorMzMatch import PrecursorMzMatch
from tqdm import tqdm

from ms2query.benchmarking.reference_methods.PredictMS2DeepScoreSimilarity import predict_top_ms2deepscores
from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectraWithMS2DeepScoreEmbeddings


def predict_highest_cosine(
    library_spectra: SpectraWithFingerprints,
    query_spectra: SpectraWithFingerprints,
    precursor_mz_tolerance: float = 0.1,
    fragment_mz_tolerance: float = 0.1,
) -> Tuple[List[str], List[float]]:
    """Predicts per query spectrum the library inchikey with the highest greedy cosine score.

    Only library spectra within precursor_mz_tolerance of the query precursor m/z are considered.
    Queries without any candidate get (None, 0.0).

    Args:
        library_spectra: The library spectrum set that is searched.
        query_spectra: The query spectrum set predictions are made for.
        precursor_mz_tolerance: Maximum precursor m/z difference for candidate library spectra.
        fragment_mz_tolerance: Fragment m/z tolerance used by the greedy cosine score.
    """
    scores = Scores(references=library_spectra.spectra, queries=query_spectra.spectra, is_symmetric=False)
    scores = scores.calculate(PrecursorMzMatch(precursor_mz_tolerance))
    scores = scores.calculate(CosineGreedy(tolerance=fragment_mz_tolerance))
    inchikeys_of_best_match = []
    highest_scores = []
    for query_spectrum in query_spectra.spectra:
        results = scores.scores_by_query(query_spectrum, "CosineGreedy_score", sort=True)
        if len(results) == 0:
            # No library spectrum within the precursor m/z tolerance.
            inchikeys_of_best_match.append(None)
            highest_scores.append(0.0)
        else:
            best_reference, highest_score = results[0]
            inchikeys_of_best_match.append(best_reference.get("inchikey")[:14])
            highest_scores.append(highest_score["CosineGreedy_score"])
    return inchikeys_of_best_match, highest_scores


def predict_highest_ms2deepscore(
    library_spectra: SpectraWithMS2DeepScoreEmbeddings, query_spectra: SpectraWithMS2DeepScoreEmbeddings
) -> Tuple[List[str], List[float]]:
    """Predicts per query spectrum the library inchikey with the highest MS2DeepScore."""
    indexes_of_highest_scores, highest_scores = predict_top_ms2deepscores(
        library_spectra.embeddings, query_spectra.embeddings, k=1
    )
    single_highest_score = [highest_score[0] for highest_score in highest_scores]
    inchikeys_of_best_match = [
        library_spectra.spectra[index[0]].get("inchikey")[:14] for index in indexes_of_highest_scores
    ]
    return inchikeys_of_best_match, single_highest_score


def predict_with_integrated_similarity_flow(
    library_spectra: SpectraWithMS2DeepScoreEmbeddings,
    query_spectra: SpectraWithMS2DeepScoreEmbeddings,
    number_of_analogues_to_consider: int = 50,
) -> Tuple[List[str], List[float]]:
    """Predicts per query spectrum the library inchikey with the highest ISF score.

    The top number_of_analogues_to_consider MS2DeepScore matches per query are reranked with the
    integrated similarity flow (ISF) score, which also takes the Tanimoto similarity of the
    candidates among each other into account.
    """
    all_indexes_of_library_spectra_with_highest_score, all_predicted_scores = predict_top_ms2deepscores(
        library_spectra.embeddings, query_spectra.embeddings, k=number_of_analogues_to_consider
    )
    inchikeys_of_best_matches = []
    highest_isf_scores = []
    # Rerank the candidates of each query spectrum separately.
    for query_index in tqdm(range(len(query_spectra.spectra)), "Calculating ISF score"):
        highest_isf_score, inchikey_of_highest_isf_score = get_highest_isf(
            library_spectra,
            all_indexes_of_library_spectra_with_highest_score[query_index],
            all_predicted_scores[query_index],
        )
        inchikeys_of_best_matches.append(inchikey_of_highest_isf_score)
        highest_isf_scores.append(highest_isf_score)
    return inchikeys_of_best_matches, highest_isf_scores
def get_highest_isf(
    library_spectra: "SpectraWithMS2DeepScoreEmbeddings",
    indexes_of_library_spectra_with_highest_score: List[int],
    predicted_scores: List[float],
) -> Tuple[float, str]:
    """Returns the highest integrated similarity flow (ISF) score and its inchikey.

    The candidate library spectra (those with the highest MS2DeepScore) are first aggregated per
    inchikey, after which the ISF score reranks the candidate inchikeys.
    """
    # Get the inchikey for each candidate library spectrum.
    inchikeys_with_highest_ms2deepscore = [
        library_spectra.spectra[index].get("inchikey")[:14] for index in indexes_of_library_spectra_with_highest_score
    ]
    unique_inchikeys, average_scores, nr_of_spectra_per_inchikey = average_scores_per_inchikeys(
        predicted_scores, inchikeys_with_highest_ms2deepscore
    )
    # Tanimoto similarity of all candidate inchikeys to each other.
    library_fingerprints = np.array(
        [library_spectra.inchikey_fingerprint_pairs[inchikey] for inchikey in unique_inchikeys]
    )
    tanimoto_scores = jaccard_similarity_matrix(library_fingerprints, library_fingerprints)

    isf_scores = integrated_similarity_flow(average_scores, tanimoto_scores, nr_of_spectra_per_inchikey)
    index_of_highest_score = int(np.argmax(isf_scores))
    return isf_scores[index_of_highest_score], unique_inchikeys[index_of_highest_score]


def average_scores_per_inchikeys(
    predicted_scores: List[float], inchikeys: List[str]
) -> Tuple[List[str], List[float], List[int]]:
    """Calculates the average predicted score per inchikey.

    Aggregating per inchikey helps speed up the ISF computation. Returns the unique inchikeys
    (in first-seen order), their average scores, and the number of spectra per inchikey.

    Raises:
        ValueError: If the two input lists differ in length.
    """
    if len(predicted_scores) != len(inchikeys):
        raise ValueError("predicted_scores and inchikeys should have the same length")
    scores_per_inchikey: Dict[str, List[float]] = {}
    for inchikey, score in zip(inchikeys, predicted_scores):
        scores_per_inchikey.setdefault(inchikey, []).append(score)
    # Take the average over the scores per inchikey (dicts preserve insertion order).
    unique_inchikeys = []
    average_scores = []
    nr_of_spectra_per_inchikey = []
    for inchikey, scores in scores_per_inchikey.items():
        unique_inchikeys.append(inchikey)
        average_scores.append(sum(scores) / len(scores))
        nr_of_spectra_per_inchikey.append(len(scores))
    return unique_inchikeys, average_scores, nr_of_spectra_per_inchikey


def integrated_similarity_flow(
    predicted_scores: List[float], similarities, nr_of_spectra_per_inchikey: List[int]
) -> List[float]:
    """Computes the integrated similarity flow (ISF) confidence score for each candidate.

    ISF scores are calculated using the similarity of the candidates among each other and their
    (predicted) similarity to the query spectrum, weighted by the number of spectra per candidate.

    Args:
        predicted_scores: Average predicted score (e.g. MS2DeepScore) per candidate inchikey.
        similarities: Square matrix with the Jaccard similarity of all candidates to each other.
        nr_of_spectra_per_inchikey: Number of candidate spectra per candidate inchikey.

    Returns:
        The ISF score for each candidate, in the same order as predicted_scores.
    """
    num_hits = len(predicted_scores)
    # Normalization term: the total weighted similarity of all candidates to the query.
    total_similarity = sum(predicted_scores[i] * nr_of_spectra_per_inchikey[i] for i in range(num_hits))
    if total_similarity == 0:
        # All predicted scores are zero; avoid a ZeroDivisionError.
        return [0.0] * num_hits
    isf_scores = []
    for i in range(num_hits):
        isf_score = (
            sum(predicted_scores[j] * similarities[i][j] * nr_of_spectra_per_inchikey[j] for j in range(num_hits))
            / total_similarity
        )
        isf_scores.append(isf_score)
    return isf_scores
"sys.path.append(\"C:/Users/jonge094/PycharmProjects/ms2query_2_0/ms_chemical_space_explorer\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "69613d68-6340-4985-b266-ebdb56b21271", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "7551it [00:05, 1452.14it/s]\n", + "7142it [00:04, 1502.96it/s]\n" + ] + } + ], + "source": [ + "from matchms.importing import load_from_mgf\n", + "from tqdm import tqdm\n", + "\n", + "neg_val_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_validation_spectra.mgf\")))\n", + "neg_test_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_testing_spectra.mgf\")))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "55f31b78-e6b8-42a2-8ba0-d5cea429520d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\jonge094\\AppData\\Local\\miniconda3\\envs\\ms2query2\\lib\\site-packages\\ms2deepscore\\models\\load_model.py:34: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.\n", + " model_settings = torch.load(filename, map_location=device)\n", + "7551it [02:50, 44.16it/s]\n" + ] + } + ], + "source": [ + "from ms2deepscore.models import load_model\n", + "from ms2deepscore.models import compute_embedding_array\n", + "\n", + "ms2deepscore_model = load_model(\"../../../ms2deepscore/data/pytorch/new_corinna_included/trained_models/both_mode_precursor_mz_ionmode_10000_layers_500_embedding_2024_11_21_11_23_17/ms2deepscore_model.pt\")\n", + "\n", + "embeddings = compute_embedding_array(ms2deepscore_model, neg_val_spectra)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fe9d805-561f-49c5-87f4-ce56c374db42", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 124, + "id": "d59f130e-f38e-467f-b6c0-899eb6fdb959", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "more_test_embeddings = np.tile(embeddings, (70, 1))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 125, + "id": "5ffe97b9-0199-41f3-9f27-b0a639274af0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(528570, 500)" + ] + }, + "execution_count": 125, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "more_test_embeddings.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "id": "9b493113-b067-40c9-9925-0db45af6693d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time eleapsed: 163.06616854667664\n" + ] + } + ], + "source": [ + "import pynndescent\n", + "import time\n", + "start_time = time.time()\n", + "ann_model = pynndescent.NNDescent(more_test_embeddings, metric=\"cosine\", n_neighbors=30)\n", + "ann_model.prepare()\n", + "print(\"Time eleapsed: \" + str(time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": 128, + "id": 
"f0ad0f13-f449-4871-9f23-c0ebfe38cb33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time eleapsed: 86.74623227119446\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "ann_model.update(embeddings[:2])\n", + "ann_model.prepare()\n", + "\n", + "print(\"Time eleapsed: \" + str(time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "id": "da5276ff-42e8-4201-8bc0-a166216787ae", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time eleapsed: 1.8495116233825684\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "indices, dists = ann_model.query(embeddings[:1000], epsilon=1, k=1000)\n", + "print(\"Time eleapsed: \" + str(time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "id": "65f44f03-6593-4bc4-bc4d-e64ffe6cbc08", + "metadata": {}, + "outputs": [], + "source": [ + "for dist in dists:\n", + " if dist > 0.000001:\n", + " print(dist)" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "8f8b125d-ace5-4af7-b0ce-6b9a263ece1b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6783\n", + "768\n" + ] + } + ], + "source": [ + "correct = 0\n", + "not_correct = 0\n", + "for i, index in enumerate(indices):\n", + " correct_if_0 = index[0]%7551-i\n", + " if correct_if_0== 0:\n", + " correct += 1\n", + " else:\n", + " not_correct +=1\n", + "print(correct)\n", + "print(not_correct)" + ] + }, + { + "cell_type": "code", + "execution_count": 117, + "id": "f84b71ed-4e35-42c1-8be1-a1c9b3ec0dae", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(226530, 500)" + ] + }, + "execution_count": 117, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "more_test_embeddings.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 126, + "id": 
"dbc240d1-62e2-40f1-8f2c-12345cce3bf6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time eleapsed: 14.737722158432007\n" + ] + } + ], + "source": [ + "from ms2deepscore.vector_operations import cosine_similarity_matrix\n", + "start_time = time.time()\n", + "matrix = cosine_similarity_matrix(more_test_embeddings, embeddings[:1000])\n", + "print(\"Time eleapsed: \" + str(time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "id": "f87ba107-54cc-47ee-87c8-34694a917a3b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1. , 0.88234181, 0.73878687, ..., 0.55531438, 0.58179732,\n", + " 0.61927229],\n", + " [0.88234181, 1. , 0.91034399, ..., 0.50462923, 0.51674736,\n", + " 0.54464448],\n", + " [0.73878687, 0.91034399, 1. , ..., 0.48358345, 0.4835752 ,\n", + " 0.54332801],\n", + " ...,\n", + " [0.49057829, 0.55037322, 0.56818589, ..., 0.39834881, 0.41004598,\n", + " 0.5298201 ],\n", + " [0.48655507, 0.57605234, 0.5946298 , ..., 0.40762756, 0.42247458,\n", + " 0.50066275],\n", + " [0.41914688, 0.47922786, 0.48648942, ..., 0.38071478, 0.41207848,\n", + " 0.47297686]])" + ] + }, + "execution_count": 120, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "matrix" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62b738da-a9ce-421f-8ef8-ee8a18d2b70f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (ms2query2)", + "language": "python", + "name": "ms2query2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ms2query/notebooks/Test_method_evaluator.ipynb 
b/ms2query/notebooks/Test_method_evaluator.ipynb new file mode 100644 index 0000000..097c1d3 --- /dev/null +++ b/ms2query/notebooks/Test_method_evaluator.ipynb @@ -0,0 +1,942 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ced5638b-15f4-4219-9ed8-b2823a53e574", + "metadata": {}, + "source": [ + "# load in spectra" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2768cf34-1794-4a09-89a3-f81a04b1bcda", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "31453it [00:14, 2140.47it/s]\n", + "7080it [00:03, 1992.44it/s]\n", + "459610it [03:37, 2114.18it/s]\n", + "131240it [01:06, 1967.67it/s]\n" + ] + } + ], + "source": [ + "from matchms.importing import load_from_mgf\n", + "from tqdm import tqdm\n", + "import os\n", + "\n", + "save_directory = \"../data/ms2deepscore_model/training_and_validation_split/\"\n", + "pos_val_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"positive_validation_spectra.mgf\"))))\n", + "neg_val_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"negative_validation_spectra.mgf\"))))\n", + "pos_train_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"positive_training_spectra.mgf\"))))\n", + "neg_train_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"negative_training_spectra.mgf\"))))\n", + "pos_test_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"positive_testing_spectra.mgf\"))))\n", + "neg_test_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"negative_testing_spectra.mgf\"))))\n" + ] + }, + { + "cell_type": "markdown", + "id": "7f5a4df2-2bb7-4731-a6bf-6727183da493", + "metadata": {}, + "source": [ + "### Save as pickled files for quicker reloads" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3a55be7f-d760-4df7-944e-251844cb1b4d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": 
{}, + "output_type": "execute_result" + } + ], + "source": [ + "import os\n", + "pickled_intermediates_data_folder = \"../data/pickled_intermediates\"\n", + "os.path.isdir(pickled_intermediates_data_folder)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "28606ab4-fd4c-4111-a1ca-8dc502593425", + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_val_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(neg_val_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_train_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(neg_train_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"pos_val_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(pos_val_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"pos_train_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(pos_train_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"pos_test_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(pos_test_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_test_spectra.pickle\"), \"wb\") as handle:\n", + " pickle.dump(neg_test_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, + { + "cell_type": "markdown", + "id": "cf849c08-3eaf-4922-b361-4f324b010ca9", + "metadata": {}, + "source": [ + "# Create a simple MS2Deepscore ranker" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "48bd37e7-384a-4bbc-821d-946975bfdf5b", + "metadata": {}, + "outputs": [], + "source": [ + "from ms2deepscore.models import load_model\n", + "ms2deepscore_model = 
load_model(\"../data/ms2deepscore_model/trained_models/both_mode_ionmode_precursor_mz_2000_layers_500_embedding_2025_02_26_18_42_25/ms2deepscore_model.pt\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9710c95d-803d-4492-b279-13588afa9c27", + "metadata": {}, + "outputs": [], + "source": [ + "import sys \n", + "sys.path.append(\"../../ms_chemical_space_explorer\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca3486f3-06b3-402d-90b1-ba01cb2568aa", + "metadata": {}, + "outputs": [], + "source": [ + "from ms_chemical_space_explorer.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings\n", + "library_spectra = SpectraWithMS2DeepScoreEmbeddings(neg_train_spectra + pos_train_spectra, ms2deepscore_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c858cc69-3da8-4530-ba13-ad4b1827f1ae", + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_library_embeddings.pickle\"), \"wb\") as handle:\n", + " pickle.dump(library_spectra.embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53ca37fe-1627-4473-a21a-168791d92112", + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_library_with_embeddings.pickle\"), \"wb\") as handle:\n", + " pickle.dump(library_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cec8761-6c37-4ac5-ba52-619c171cd834", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "val_spectra = SpectraWithMS2DeepScoreEmbeddings(neg_val_spectra + pos_val_spectra, ms2deepscore_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "3ba69a93-5ede-4920-b3a6-e7881e20eb39", + "metadata": {}, + "outputs": [], + "source": [ + 
"import pickle\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_val_spectra_with_embeddings.pickle\"), \"wb\") as handle:\n", + " pickle.dump(val_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, + { + "cell_type": "markdown", + "id": "688da27f-3290-480c-ab67-d171bbac6b93", + "metadata": {}, + "source": [ + "# Reload intermediates" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9cb48a39-f3cd-4924-8da2-f11007475795", + "metadata": {}, + "outputs": [], + "source": [ + "import sys \n", + "sys.path.append(\"../../ms_chemical_space_explorer\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "97d042ab-19f6-4d32-bfd7-b24e0f72c1d9", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pickle\n", + "pickled_intermediates_data_folder = \"../data/pickled_intermediates\"\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_library_with_embeddings.pickle\"), \"rb\") as file:\n", + " library_spectra = pickle.load(file)\n", + "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_val_spectra_with_embeddings.pickle\"), \"rb\") as file:\n", + " val_spectra = pickle.load(file)" + ] + }, + { + "cell_type": "markdown", + "id": "c3c8dfda-3294-45cc-8a6f-1aeb06257c73", + "metadata": {}, + "source": [ + "# Initialize a method evaluator" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "79fa345a-52eb-4079-a6b5-0bd63d78d821", + "metadata": {}, + "outputs": [], + "source": [ + "from ms_chemical_space_explorer.benchmarking.EvaluateMethods import EvaluateMethods\n", + "\n", + "method_evaluator = EvaluateMethods(library_spectra, val_spectra)\n", + "method_evaluator.training_spectrum_set.progress_bars = False\n", + "method_evaluator.validation_spectrum_set.progress_bars = False" + ] + }, + { + "cell_type": "markdown", + "id": "16663c37-4fdd-4454-bcc6-55dbf4d6230f", + "metadata": {}, + "source": [ + "# Test basic MS2DeepScore" + ] + }, + { + 
"cell_type": "code", + "execution_count": 21, + "id": "1048dccf-1b53-4dcc-96d4-01a4d51dc57d", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 78/78 [26:20<00:00, 20.26s/it]\n", + "Calculating analogue accuracy per inchikey: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2015/2015 [00:01<00:00, 1190.75it/s]\n" + ] + } + ], + "source": [ + "from ms_chemical_space_explorer.methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore\n", + "\n", + "result_analogue = method_evaluator.benchmark_analogue_search(predict_highest_ms2deepscore)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0f6f21cb-a358-4803-ad74-cc57a2a7a47e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.3848227909492443\n" + ] + } + ], + "source": [ + "print(result_analogue)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "eda9a92c-fe2c-4c4b-ab4e-07ddb4346a27", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey: 
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1837/1837 [00:00<00:00, 53337.60it/s]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 31/31 [10:21<00:00, 20.03s/it]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [10:35<00:00, 19.25s/it]\n", + "Calculating exact match accuracy per inchikey: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1612/1612 [00:00<00:00, 313512.85it/s]\n" + ] + } + ], + "source": [ + "from ms_chemical_space_explorer.methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore\n", + "\n", + "result_positive = method_evaluator.benchmark_exact_matching_within_ionmode(predict_highest_ms2deepscore, \"positive\")" + ] + }, + { + "cell_type": "code", + 
"execution_count": 24, + "id": "b6ddad52-8647-450f-8881-6753a5bea814", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5169110149857257\n" + ] + } + ], + "source": [ + "print(result_positive)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "dfc7a156-74a3-4c48-b610-e38410abf274", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 924/924 [00:00<00:00, 293846.15it/s]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:08<00:00, 18.29s/it]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [02:14<00:00, 16.78s/it]\n", + "Calculating exact match accuracy per inchikey: 
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 868/868 [00:00<00:00, 894070.70it/s]\n" + ] + } + ], + "source": [ + "from ms_chemical_space_explorer.methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore\n", + "\n", + "result_neg = method_evaluator.benchmark_exact_matching_within_ionmode(predict_highest_ms2deepscore, \"negative\")" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "86cc5ba4-4380-488a-a486-de63cccada74", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5448453174876924" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_neg\n" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "a5b0595b-51dd-4a56-b9ec-e81a597dd26b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey across ionmodes: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2015/2015 [00:00<00:00, 9551.01it/s]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 
35/35 [11:40<00:00, 20.02s/it]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [03:47<00:00, 17.46s/it]\n", + "Calculating exact match accuracy per inchikey: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 746/746 [00:00<00:00, 496186.30it/s]\n" + ] + } + ], + "source": [ + "result_across_ionmodes = method_evaluator.exact_matches_across_ionization_modes(predict_highest_ms2deepscore)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "fe14e752-75b4-4e00-8600-90c1dce46b90", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.0006188308440879157" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_across_ionmodes" + ] + }, + { + "cell_type": "markdown", + "id": "1499181b-7b68-437a-85b3-b7dc3ad53884", + "metadata": {}, + "source": [ + "# Test best possible results" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "56d119e2-145a-484c-ab3c-56b569b549dd", + "metadata": {}, + "outputs": [], + "source": [ + "from ms_chemical_space_explorer.methods.predict_best_possible_match import predict_best_possible_match\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b3687845-099c-4b98-b715-c8057b5858bb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"Calculating tanimoto scores to determine best possible match\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Calculating analogue accuracy per inchikey: 100%|██████████████████████████████████████████| 2015/2015 [00:01<00:00, 1578.04it/s]\n" + ] + } + ], + "source": [ + "result_analogue_best = method_evaluator.benchmark_analogue_search(predict_highest_ms2deepscore)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "81c83a8d-0754-4207-97d3-2fbef49082ac", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.7753653963061775" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_analogue_best" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "5fad0038-025c-4c73-9bfb-bc4545f092df", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 924/924 [00:00<00:00, 310739.01it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Calculating tanimoto scores to determine best possible match\n", + "Calculating tanimoto scores to determine best possible match\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Calculating exact match accuracy per inchikey: 
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 868/868 [00:00<00:00, 928738.74it/s]\n" + ] + } + ], + "source": [ + "result_neg_best = method_evaluator.benchmark_exact_matching_within_ionmode(predict_best_possible_match, \"negative\")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "e0418a27-0195-4021-a04b-788cbee4fe2b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_neg_best" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "f19ea34d-f6c1-4f7c-b7ec-8c7bad85a1da", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey across ionmodes: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2015/2015 [00:00<00:00, 13798.67it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Calculating tanimoto scores to determine best possible match\n", + "Calculating tanimoto scores to determine best possible match\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Calculating exact match accuracy per inchikey: 
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 746/746 [00:00<00:00, 404727.82it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_across_ionmodes_best = method_evaluator.exact_matches_across_ionization_modes(predict_best_possible_match)\n", + "result_across_ionmodes_best" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "052e768d-4044-40b8-a85e-7cf03abe4184", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey across ionmodes: 100%|█████████████████████████████████████| 2015/2015 [00:00<00:00, 10356.24it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Calculating tanimoto scores to determine best possible match\n", + "Calculating tanimoto scores to determine best possible match\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Calculating exact match accuracy per inchikey: 100%|███████████████████████████████████████| 746/746 [00:00<00:00, 315164.26it/s]\n" + ] + } + ], + "source": [ + "result_across_ionmodes_best = method_evaluator.exact_matches_across_ionization_modes(predict_best_possible_match)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1ff6eaba-8213-4270-9395-39ff3aee6879", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_across_ionmodes_best" + ] + }, + { + "cell_type": "markdown", + "id": 
"3702ba19-0c86-443c-8621-fe1ab427b3da", + "metadata": {}, + "source": [ + "# Test cosine score" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5f0cfc49-009f-46a1-beef-6310d31260a1", + "metadata": {}, + "outputs": [], + "source": [ + "from ms_chemical_space_explorer.methods.predict_highest_cosine import predict_highest_cosine" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cb9210d-3ead-4798-bcc0-82b571e69f8b", + "metadata": {}, + "outputs": [], + "source": [ + "result_analogue_cosine = method_evaluator.benchmark_analogue_search(predict_highest_cosine)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1a5fb72-6039-4d56-af71-fe07bb583c22", + "metadata": {}, + "outputs": [], + "source": [ + "result_analogue_cosine" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "7d4641b9-9b44-4b64-8fed-0b568cb0f4f9", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 924/924 [00:00<00:00, 260575.33it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "created scores\n", + "Filtered on precursor\n", + "Calculated cosine\n", + "('PrecursorMzMatch', 'CosineGreedy_score', 'CosineGreedy_matches')\n", + "created scores\n", + "Filtered on precursor\n", + "Calculated cosine\n", + "('PrecursorMzMatch', 'CosineGreedy_score', 'CosineGreedy_matches')\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Calculating exact match accuracy per inchikey: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 868/868 [00:00<00:00, 201408.27it/s]\n" + ] + } + ], + "source": [ + "result_neg = method_evaluator.benchmark_exact_matching_within_ionmode(predict_highest_cosine, 
\"negative\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "96191952-14df-41b3-a410-d2e5574dd3c9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5772003486330072" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result_neg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e77b0ecd-7f1a-438d-a0a1-950ff9e76881", + "metadata": {}, + "outputs": [], + "source": [ + "result_positive = method_evaluator.benchmark_exact_matching_within_ionmode(predict_highest_cosine, \"positive\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7db770c4-f4bb-4a0e-82a2-03d52e6ed7d0", + "metadata": {}, + "outputs": [], + "source": [ + "result_positive" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "675351b2-ff50-42d8-b179-006d0a46cadc", + "metadata": {}, + "outputs": [], + "source": [ + "result_across_ionmodes = method_evaluator.exact_matches_across_ionization_modes(predict_highest_cosine)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d056551a-f85b-4077-8b35-eaf834da7a9b", + "metadata": {}, + "outputs": [], + "source": [ + "result_across_ionmodes" + ] + }, + { + "cell_type": "markdown", + "id": "7682a1b8-aec1-40dc-b942-419bd23de041", + "metadata": {}, + "source": [ + "# Test predictions with ISF" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1dc19bef-69ab-471e-94ab-6a9caffd9030", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "IOStream.flush timed outepscore per batch of 500 embeddings: 77%|██████████████████████████████████████████████████████████████████▏ | 60/78 [33:53<22:39, 75.54s/it]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████| 78/78 [42:49<00:00, 32.95s/it]\n", + "Calculating ISF score: 
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 38533/38533 [00:44<00:00, 857.97it/s]\n", + "Calculating analogue accuracy per inchikey: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 2015/2015 [00:01<00:00, 1076.80it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.35576217271302824\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 924/924 [00:00<00:00, 202462.49it/s]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:56<00:00, 25.28s/it]\n", + "Calculating ISF score: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3303/3303 [00:04<00:00, 713.49it/s]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████| 8/8 [03:30<00:00, 26.27s/it]\n", + "Calculating ISF score: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3721/3721 [00:04<00:00, 856.75it/s]\n", + "Calculating exact match accuracy per inchikey: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 868/868 [00:00<00:00, 176636.55it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.04118290017821497\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 
1837/1837 [00:00<00:00, 134548.79it/s]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 84%|███████████████████████████████████████████████████████████████████████▎ | 26/31 [38:00<38:36, 463.28s/it]IOStream.flush timed out\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████| 31/31 [41:05<00:00, 79.53s/it]\n", + "Calculating ISF score: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15218/15218 [00:20<00:00, 753.86it/s]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████| 33/33 [12:25<00:00, 22.58s/it]\n", + "Calculating ISF score: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16010/16010 [00:22<00:00, 719.72it/s]\n", + "Calculating exact match accuracy per inchikey: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 1612/1612 [00:00<00:00, 236406.23it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.05296326960524268\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Splitting spectra per inchikey across ionmodes: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 2015/2015 [00:00<00:00, 3809.64it/s]\n", + "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████| 35/35 [13:13<00:00, 22.66s/it]\n", + "Calculating ISF score: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17128/17128 [00:21<00:00, 808.89it/s]\n", + "Predicting highest ms2deepscore per batch of 
500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████| 13/13 [04:14<00:00, 19.57s/it]\n", + "Calculating ISF score: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6059/6059 [00:08<00:00, 714.94it/s]\n", + "Calculating exact match accuracy per inchikey: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 746/746 [00:00<00:00, 56840.41it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.0006614631666202545\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "from ms_chemical_space_explorer.methods.predict_with_integrated_similarity_flow import predict_with_integrated_similarity_flow\n", + "\n", + "result_analogue_isf = method_evaluator.benchmark_analogue_search(predict_with_integrated_similarity_flow)\n", + "print(result_analogue_isf)\n", + "result_neg_isf = method_evaluator.benchmark_exact_matching_within_ionmode(predict_with_integrated_similarity_flow, \"negative\")\n", + "print(result_neg_isf)\n", + "result_positive_isf = method_evaluator.benchmark_exact_matching_within_ionmode(predict_with_integrated_similarity_flow, \"positive\")\n", + "print(result_positive_isf)\n", + "result_across_ionmodes_isf = method_evaluator.exact_matches_across_ionization_modes(predict_with_integrated_similarity_flow)\n", + "print(result_across_ionmodes_isf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b891a34-aade-49e8-ad91-88b36cf0e011", + "metadata": {}, + "outputs": [], + "source": [ + "result_analogue_isf" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "22ba5663-1f8a-405a-b04b-0380e18c6987", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hello\n" + ] + } + ], + "source": [ + "print(\"hello\")" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "6b9198fb-c60b-4db5-8415-578b88275e5a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb b/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb new file mode 100644 index 0000000..dcad892 --- /dev/null +++ b/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb @@ -0,0 +1,349 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "df3d5e3f-2694-4eed-9fc1-a78483629412", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "7551it [00:07, 1015.57it/s]\n", + "7142it [00:12, 560.76it/s] \n", + "130901it [03:10, 685.40it/s] \n", + "25412it [00:32, 784.47it/s] \n", + "24911it [00:34, 718.09it/s] \n", + "25412it [00:45, 556.46it/s] \n" + ] + } + ], + "source": [ + "from matchms.importing import load_from_mgf\n", + "from tqdm import tqdm\n", + "\n", + "neg_val_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_validation_spectra.mgf\")))\n", + "neg_test_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_testing_spectra.mgf\")))\n", + "neg_train_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_training_spectra.mgf\")))\n", + "neg_spectra = 
neg_val_spectra + neg_test_spectra + neg_train_spectra\n", + "\n", + "pos_val_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/positive_validation_spectra.mgf\")))\n", + "pos_test_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/positive_testing_spectra.mgf\")))\n", + "pos_train_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/positive_training_spectra.mgf\")))\n", + "pos_spectra = pos_val_spectra + pos_test_spectra + pos_train_spectra" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2771fc81-ad6b-4b2d-90ea-c541641af1a5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 145594/145594 [00:25<00:00, 5637.57it/s]\n" + ] + } + ], + "source": [ + "neg_inchikeys = []\n", + "for spectrum in tqdm(neg_spectra):\n", + " neg_inchikeys.append(spectrum.get(\"inchikey\")[:14])\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "437e9455-49a9-435e-8bb5-cd3b3ed987f1", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 519580/519580 [00:46<00:00, 11219.75it/s]\n" + ] + } + ], + "source": [ + "pos_inchikeys = []\n", + "for spectrum in tqdm(pos_spectra):\n", + " pos_inchikeys.append(spectrum.get(\"inchikey\")[:14])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "df03a6d2-6755-46e1-a874-fc96c1a2b20c", + "metadata": {}, + "outputs": [], + "source": [ + "unique_neg_inchikeys = set(neg_inchikeys)\n", + "unique_pos_inchikeys = set(pos_inchikeys)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "aa775f3b-9c10-43cd-a6be-74dac2e698ce", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "18480" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(unique_neg_inchikeys)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f3a73395-68c2-4952-9171-96f67f597cc9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "36638" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(unique_pos_inchikeys)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "f9729ad5-d201-476a-87a1-78e267425aae", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "14801" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "overlapping_inchikeys = unique_neg_inchikeys & unique_pos_inchikeys\n", + "len(overlapping_inchikeys)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "07d148f7-aef6-4eda-a66f-984f0c65a31b", + 
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "40317" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "total_unique_inchikeys = unique_neg_inchikeys | unique_pos_inchikeys\n", + "len(total_unique_inchikeys)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0cd254e4-9b7f-41ed-a7ce-25675392102f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.36711560880025795" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "14801/40317" + ] + }, + { + "cell_type": "markdown", + "id": "3f5bb776-9910-4a4d-bc45-d8e049d63ce6", + "metadata": {}, + "source": [ + "# Check how many there are already in the val set\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "4b5ced16-ceae-439b-bb04-c72986798b8f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7551/7551 [00:01<00:00, 4630.22it/s]\n" + ] + } + ], + "source": [ + "neg_inchikeys_val = []\n", + "for spectrum in tqdm(neg_val_spectra):\n", + " neg_inchikeys_val.append(spectrum.get(\"inchikey\")[:14])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "0d2b00bd-56a3-4c67-ad3f-d54051a74701", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25412/25412 [00:02<00:00, 10220.59it/s]\n" + ] + } + ], + "source": [ + "pos_inchikeys_val = []\n", + "for spectrum in tqdm(pos_val_spectra):\n", + " pos_inchikeys_val.append(spectrum.get(\"inchikey\")[:14])" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "76a7736a-d4a4-42a5-bf42-d20b1a38060c", + "metadata": {}, + "outputs": [], + "source": [ + "unique_neg_inchikeys_val = set(neg_inchikeys_val)\n", + "unique_pos_inchikeys_val = set(pos_inchikeys_val)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "6fce2334-e43b-40ce-b459-75c1a6d1ed00", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "35" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "overlapping_inchikeys_val = unique_neg_inchikeys_val & unique_pos_inchikeys_val\n", + "len(overlapping_inchikeys_val)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "86668203-8e5b-40e9-934e-0ade3c25e9d8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7551" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(neg_inchikeys_val)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "bc2e25e7-5d85-4486-924d-c1d9651ba2bc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "25412" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(pos_inchikeys_val)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"e864dc8e-68bb-4d3d-8a27-57f3892581e3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 9f67cd2d1c48317fa159fba205d7a438926341f4 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Fri, 5 Dec 2025 11:16:15 +0100 Subject: [PATCH 05/45] Add tests --- tests/conftest.py | 79 +++++++++++++ tests/testPredictMS2DeepScoreSimilarity.py | 29 +++++ tests/test_SpectrumDataSet.py | 130 +++++++++++++++++++++ tests/test_evaluate_methods.py | 31 +++++ tests/test_methods.py | 58 +++++++++ 5 files changed, 327 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/testPredictMS2DeepScoreSimilarity.py create mode 100644 tests/test_SpectrumDataSet.py create mode 100644 tests/test_evaluate_methods.py create mode 100644 tests/test_methods.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..80ab525 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,79 @@ +import os +from pathlib import Path + +import numpy as np +from matchms.Spectrum import Spectrum +from ms2deepscore.models import load_model + + +TEST_RESOURCES_PATH = Path(__file__).parent / "test_data" + + +def create_test_spectra( + number_of_spectra_per_inchikey=3, + inchikey_inchi_pairs=None, + nr_of_inchikeys=3, +): + if inchikey_inchi_pairs is None: + inchikey_inchi_pairs = get_inchikey_inchi_pairs(nr_of_inchikeys) + spectra = [] + for i, inchikey_inchi_tuple in enumerate(inchikey_inchi_pairs): + inchikey, inchi, smiles, compound_name = inchikey_inchi_tuple + for j in range(number_of_spectra_per_inchikey): + spectra.append( + Spectrum( + 
mz=np.array([100 + i * 10.0, 500 + i * 1.0]), + intensities=np.array([1.0, 1.0 / (j + 1)]), + metadata={ + "precursor_mz": 111.1 + i * 10, + "inchikey": inchikey, + "inchi": inchi, + "smiles": smiles, + "compound_name": compound_name, + }, + ) + ) + return spectra + + +def ms2deepscore_model(): + return load_model(os.path.join(TEST_RESOURCES_PATH, "ms2deepscore_testmodel_v1.pt")) + + +def get_inchikey_inchi_pairs(number_of_pairs): + """Returns inchikey_inchi_pairs""" + inchikey_inchi_pairs = ( + ( + "RYYVLZVUVIJVGH-UHFFFAOYSA-N", + "InChI=1S/C8H10N4O2/c1-10-4-9-6-5(10)7(13)12(3)8(14)11(6)2/h4H,1-3H3", + "CN1C=NC2=C1C(=O)N(C(=O)N2C)C", + "Caffeine", + ), + ( + "ZPUCINDJVBIVPJ-LJISPDSOSA-N", + "InChI=1S/C17H21NO4/c1-18-12-8-9-13(18)15(17(20)21-2)14(10-12)22-16(19)11-6-4-3-5-7-11/h3-7,12-15H,8-10H2,1-2H3/t12-,13+,14-,15+/m0/s1", + "CN1[C@H]2CC[C@@H]1[C@H]([C@H](C2)OC(=O)C3=CC=CC=C3)C(=O)OC", + "Cocaine", + ), + ( + "RZVAJINKPMORJF-UHFFFAOYSA-N", + "InChI=1S/C8H9NO2/c1-6(10)9-7-2-4-8(11)5-3-7/h2-5,11H,1H3,(H,9,10)", + "CC(=O)NC1=CC=C(C=C1)O", + "Paracetemol", + ), + ( + "JGSARLDLIJGVTE-MBNYWOFBSA-N", + "InChI=1S/C16H18N2O4S/c1-16(2)12(15(21)22)18-13(20)11(14(18)23-16)17-10(19)8-9-6-4-3-5-7-9/h3-7,11-12,14H,8H2,1-2H3,(H,17,19)(H,21,22)/t11-,12+,14-/m1/s1", + "CC1([C@@H](N2[C@H](S1)[C@@H](C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C", + "Penicillin", + ), + ( + "WQZGKKKJIJFFOK-GASJEMHNSA-N", + "InChI=1S/C6H12O6/c7-1-2-3(8)4(9)5(10)6(11)12-2/h2-11H,1H2/t2-,3-,4+,5-,6?/m1/s1", + "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O", + "Glucose", + ), + ) + if number_of_pairs > len(inchikey_inchi_pairs): + raise ValueError("Not enough example compounds, add some in conftest") + return inchikey_inchi_pairs[:number_of_pairs] diff --git a/tests/testPredictMS2DeepScoreSimilarity.py b/tests/testPredictMS2DeepScoreSimilarity.py new file mode 100644 index 0000000..9741693 --- /dev/null +++ b/tests/testPredictMS2DeepScoreSimilarity.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest + +from 
ms2query.benchmarking.reference_methods.PredictMS2DeepScoreSimilarity import ( + predict_top_ms2deepscores, +) +from tests.conftest import create_test_spectra, ms2deepscore_model +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings + + +@pytest.mark.parametrize( + "method", + [ + predict_top_ms2deepscores, + ], +) +def test_predict_highest_ms2deepscore_similarity(method): + ms2ds_model = ms2deepscore_model() + test_spectra = create_test_spectra(1) + library_spectra = SpectraWithMS2DeepScoreEmbeddings(test_spectra, ms2ds_model) + query_spectra = SpectraWithMS2DeepScoreEmbeddings(test_spectra, ms2ds_model) + number_of_analogues = 2 + + indices, distances = method(library_spectra.embeddings, query_spectra.embeddings, k=number_of_analogues) + assert indices.shape == (len(test_spectra), number_of_analogues) + assert distances.shape == (len(test_spectra), number_of_analogues) + for i, row in enumerate(indices): + assert row[0] == i, "The highest predictions should be against itself" + assert np.allclose(distances[i][0], 1.0, atol=1e-5) diff --git a/tests/test_SpectrumDataSet.py b/tests/test_SpectrumDataSet.py new file mode 100644 index 0000000..d7b3cff --- /dev/null +++ b/tests/test_SpectrumDataSet.py @@ -0,0 +1,130 @@ +import numpy as np +import pytest + +from ms2query.benchmarking.SpectrumDataSet import ( + SpectraWithFingerprints, + SpectrumSetBase, + SpectraWithMS2DeepScoreEmbeddings, +) +from tests.conftest import create_test_spectra, ms2deepscore_model, get_inchikey_inchi_pairs + + +@pytest.mark.parametrize( + "library", + [ + SpectrumSetBase(create_test_spectra()), + SpectraWithFingerprints(create_test_spectra()), + SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), ms2deepscore_model()), + ], +) +def test_spectrum_set_base(library): + """Test all base functionality of SpectrumSetBase is implemented correctly also for all classes inheriting from it""" + # test correct init + assert len(library.spectra) == 9 + assert 
len(library.spectrum_indexes_per_inchikey) == 3 + assert sum(len(v) for v in library.spectrum_indexes_per_inchikey.values()) == 9 + + # test correct copying + new_copy = library.copy() + assert len(new_copy.spectra) == 9 + assert len(new_copy.spectrum_indexes_per_inchikey) == 3 + assert sum(len(v) for v in new_copy.spectrum_indexes_per_inchikey.values()) == 9 + + # test correctly adding spectra + new_copy.add_spectra(library) + new_number_of_spectra = 9 + 9 + assert len(new_copy.spectra) == new_number_of_spectra + assert len(new_copy.spectrum_indexes_per_inchikey) == 3 + assert sum(len(v) for v in new_copy.spectrum_indexes_per_inchikey.values()) == new_number_of_spectra + + # test the original is not edited when adding spectra + assert len(library.spectra) == 9 + assert len(library.spectrum_indexes_per_inchikey) == 3 + assert sum(len(v) for v in library.spectrum_indexes_per_inchikey.values()) == 9 + + # test correct subsetting + subset_indexes = [1, 4, 6, 7] + subset = library.subset_spectra(subset_indexes) + assert len(subset.spectra) == len(subset_indexes) + assert len(subset.spectrum_indexes_per_inchikey) == 3 + assert sum(len(v) for v in subset.spectrum_indexes_per_inchikey.values()) == len(subset_indexes) + assert isinstance(subset, library.__class__) + + +@pytest.mark.parametrize( + "library", + [ + SpectraWithFingerprints(create_test_spectra()), + SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), ms2deepscore_model()), + ], +) +def test_spectra_with_fingerprints(library): + """Test all functionality added in SpectraWithFingerprints also for all classes inheriting from it""" + # test correct init + assert len(library.inchikey_fingerprint_pairs) == 3 + + # test correct copying + new_copy = library.copy() + assert len(new_copy.inchikey_fingerprint_pairs) == 3 + + # test correctly adding inchikey_fingerprint_pairs when runnning add_spectra + for inchikey_inchi_pairs, expected_nr_of_inchikeys in ( + (get_inchikey_inchi_pairs(5)[2:], 5), # Some overlap + 
(get_inchikey_inchi_pairs(5)[3:], 5), # No overlap + (get_inchikey_inchi_pairs(3), 3), # Fully overlapping + (get_inchikey_inchi_pairs(1), 3), # Fully overlapping (but not all) + ): + spectra_to_add = SpectrumSetBase(create_test_spectra(2, inchikey_inchi_pairs=inchikey_inchi_pairs)) + new_copy = library.copy() + new_copy.add_spectra(spectra_to_add) + assert len(new_copy.inchikey_fingerprint_pairs) == expected_nr_of_inchikeys + for inchikey in library.inchikey_fingerprint_pairs: + assert np.array_equal( + new_copy.inchikey_fingerprint_pairs[inchikey], library.inchikey_fingerprint_pairs[inchikey] + ) + + # test the original is not edited when adding spectra + assert len(library.inchikey_fingerprint_pairs) == 3 + assert all( + np.array_equal(library.inchikey_fingerprint_pairs[key], value) + for key, value in SpectraWithFingerprints(create_test_spectra()).inchikey_fingerprint_pairs.items() + ) + + # test correct subsetting + subset_indexes = [1, 4, 6, 7] + subset = library.subset_spectra(subset_indexes) + assert len(subset.inchikey_fingerprint_pairs) == 3 + assert all( + np.array_equal(library.inchikey_fingerprint_pairs[key], value) + for key, value in subset.inchikey_fingerprint_pairs.items() + ) + assert hasattr(subset, "update_fingerprint_per_inchikey") + + +def test_spectra_with_embeddings(): + library = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), ms2deepscore_model()) + # test correct init + assert library.embeddings.shape == (9, 100) + + # test correct copying + new_copy = library.copy() + assert new_copy.embeddings.shape == (9, 100) + + # test correctly adding spectra + new_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(1), ms2deepscore_model()) + new_copy.add_spectra(new_spectra) + assert new_copy.embeddings.shape == (12, 100) + + # test the original is not edited when adding spectra + assert library.embeddings.shape == (9, 100) + + # test correct subsetting + subset_indexes = [1, 4, 6, 7] + subset = 
library.subset_spectra(subset_indexes) + assert subset.embeddings.shape == (len(subset_indexes), 100) + for i, index in enumerate(subset_indexes): + assert np.all(library.embeddings[index] == subset.embeddings[i]) + + # Check that subsetting on subset works. To make sure that a subset does not become of type SpectrumSetBase + subsetted_subset = subset.subset_spectra([0, 1]) + assert subsetted_subset.embeddings.shape == (2, 100) diff --git a/tests/test_evaluate_methods.py b/tests/test_evaluate_methods.py new file mode 100644 index 0000000..4bcb839 --- /dev/null +++ b/tests/test_evaluate_methods.py @@ -0,0 +1,31 @@ +import pytest +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings +from ms2query.benchmarking.EvaluateMethods import EvaluateMethods +from ms2query.benchmarking.reference_methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore +from ms2query.benchmarking.reference_methods.predict_best_possible_match import predict_best_possible_match +from tests.conftest import create_test_spectra, ms2deepscore_model +from ms2query.benchmarking.reference_methods.predict_highest_cosine import predict_highest_cosine + + +@pytest.mark.parametrize( + "method", + [predict_highest_ms2deepscore, predict_highest_cosine, predict_best_possible_match], +) +def test_evaluate_methods(method): + nr_of_spectra_per_inchikey = 6 + nr_of_inchikeys = 5 + dummy_spectra = create_test_spectra(nr_of_spectra_per_inchikey, nr_of_inchikeys=nr_of_inchikeys) + for i, spectrum in enumerate(dummy_spectra): + if i % 3 == 0: + spectrum.set("ionmode", "positive") + else: + spectrum.set("ionmode", "negative") + model = ms2deepscore_model() + reference_library = SpectraWithMS2DeepScoreEmbeddings(dummy_spectra[: nr_of_spectra_per_inchikey * 2], model) + validation_spectra = SpectraWithMS2DeepScoreEmbeddings(dummy_spectra[nr_of_spectra_per_inchikey * 2 :], model) + method_evaluator = EvaluateMethods(reference_library, validation_spectra) + # should be zero or 
below zero, since it is the difference with the perfect predictions + assert method_evaluator.benchmark_analogue_search(method) >= 0.0 + # # Should be 1 because we added a good match for each. + assert method_evaluator.benchmark_exact_matching_within_ionmode(method, "positive") == 1.0 + assert method_evaluator.exact_matches_across_ionization_modes(method) == 1.0 diff --git a/tests/test_methods.py b/tests/test_methods.py new file mode 100644 index 0000000..03cca1c --- /dev/null +++ b/tests/test_methods.py @@ -0,0 +1,58 @@ +import numpy as np +from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix + +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings +from ms2query.benchmarking.reference_methods.predict_highest_cosine import predict_highest_cosine +from tests.conftest import create_test_spectra, ms2deepscore_model +from ms2query.benchmarking.reference_methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore +from ms2query.benchmarking.reference_methods.predict_best_possible_match import predict_best_possible_match +from ms2query.benchmarking.reference_methods.predict_with_integrated_similarity_flow import ( + predict_with_integrated_similarity_flow, + integrated_similarity_flow, +) + +import pytest + + +@pytest.mark.parametrize( + "prediction_function", + [ + predict_highest_cosine, + predict_highest_ms2deepscore, + predict_best_possible_match, + ], +) +def test_all_methods(prediction_function): + model = ms2deepscore_model() + library_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), model) + test_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(1), model) + predicted_inchikeys, scores = prediction_function(library_spectra, test_spectra) + for i, spectrum in enumerate(test_spectra.spectra): + inchikey = spectrum.get("inchikey")[:14] + assert predicted_inchikeys[i] == inchikey + assert np.allclose(scores[i], np.array(1.0), atol=1e-5) + + +def 
test_predict_with_integrated_similarity_flow(): + model = ms2deepscore_model() + library_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), model) + test_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(1), model) + predicted_inchikeys, scores = predict_with_integrated_similarity_flow(library_spectra, test_spectra) + + assert predicted_inchikeys == ["RYYVLZVUVIJVGH", "ZPUCINDJVBIVPJ", "ZPUCINDJVBIVPJ"] + assert np.allclose(np.array([0.38829751082577607, 0.3919729335980483, 0.38774130710967564]), np.array(scores)) + + +def test_isf_computation(): + distances = [0.99, 0.99, 0.99, 0.5, 0.5] + fps = np.array([[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 0, 1], [1, 1, 1, 1, 0, 0], [0, 0, 1, 1, 0, 0], [0, 0, 1, 0, 1, 0]]) + similarities = jaccard_similarity_matrix(fps, fps) + result = integrated_similarity_flow(distances, similarities, [1, 1, 1, 1, 1]) + expected_result = [ + 0.715869, + 0.686481, + 0.736523, + 0.492107, + 0.359110, + ] + assert np.allclose(np.array(expected_result), np.array(result)) From 4d81ddf2ab12264092d82b21eaac38b720dc2224 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Fri, 5 Dec 2025 11:16:25 +0100 Subject: [PATCH 06/45] Add inits --- ms2query/benchmarking/__init__.py | 0 ms2query/benchmarking/reference_methods/__init__.py | 0 tests/__init__.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 ms2query/benchmarking/__init__.py create mode 100644 ms2query/benchmarking/reference_methods/__init__.py create mode 100644 tests/__init__.py diff --git a/ms2query/benchmarking/__init__.py b/ms2query/benchmarking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ms2query/benchmarking/reference_methods/__init__.py b/ms2query/benchmarking/reference_methods/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 From 8024263fcfc63bfb4f8c3bc4939abb1332b79676 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: 
Mon, 8 Dec 2025 16:08:25 +0100 Subject: [PATCH 07/45] Use ms2deepscore 2.6.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3fe7fa4..396af2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ pandas = ">=2.1.1" scipy= ">=1.14.0" matplotlib= ">=3.8.0" matchms= ">=0.30.0" -ms2deepscore= { git = "https://github.com/matchms/ms2deepscore.git", branch = "pytorch_update" } +ms2deepscore= ">=2.6.0" rdkit= ">2024.3.4" nmslib= ">=2.0.0" umap-learn= ">=0.5.7" From a9db98eaa9fed7dd72703696d22ce5556b710f88 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Mon, 8 Dec 2025 19:05:10 +0100 Subject: [PATCH 08/45] add new helper methods --- ms2query/database/compound_database.py | 33 ++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/ms2query/database/compound_database.py b/ms2query/database/compound_database.py index 80b931f..afc5b98 100644 --- a/ms2query/database/compound_database.py +++ b/ms2query/database/compound_database.py @@ -570,6 +570,39 @@ def get_compounds(self, comp_ids: List[str]) -> pd.DataFrame: .drop(columns="__order") .reset_index(drop=True)) + def get_all_compound_ids(self) -> List[str]: + """Return all compound IDs in ascending comp_id order. + """ + rows = self._conn.execute(f""" + SELECT comp_id FROM {self.table} + ORDER BY comp_id ASC + """).fetchall() + return [row["comp_id"] for row in rows] + + def get_all_fingerprints_and_comp_ids(self) -> Dict[str, AnyFP]: + """Return all compound IDs and their fingerprints in ascending comp_id order. 
+ """ + rows = self._conn.execute(f""" + SELECT comp_id, fingerprint_bits, fingerprint_counts, fingerprint_dense + FROM {self.table} + ORDER BY comp_id ASC + """).fetchall() + comp_ids = [] + fps = [] + for row in rows: + comp_ids.append(row["comp_id"]) + dense_blob = row["fingerprint_dense"] or b"" + bits_blob = row["fingerprint_bits"] or b"" + counts_blob = row["fingerprint_counts"] or b"" + if dense_blob: + fps.append(decode_dense_fp(dense_blob, dtype=self.fingerprint_dtype_dense)) + elif bits_blob or counts_blob: + bits, counts = decode_sparse_fp(bits_blob, counts_blob) + fps.append(bits if counts.size == 0 else (bits, counts)) + else: + fps.append(None) + return {"comp_ids": comp_ids, "fingerprints": fps} + def sql_query(self, query: str) -> pd.DataFrame: return pd.read_sql_query(query, self._conn) From 316531a0fb6cd39eaf9a254099319a85606dd50c Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Mon, 8 Dec 2025 22:49:53 +0100 Subject: [PATCH 09/45] improve by batch querying --- ms2query/database/ann_vector_index.py | 62 +++++++++++++++++++++------ 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/ms2query/database/ann_vector_index.py b/ms2query/database/ann_vector_index.py index 240c628..433907c 100644 --- a/ms2query/database/ann_vector_index.py +++ b/ms2query/database/ann_vector_index.py @@ -323,28 +323,63 @@ def _create_hnsw_index( def query( self, - vector: np.ndarray, + vectors: np.ndarray, k: int = 10, ef: Optional[int] = None, - ) -> List[Tuple[str, float]]: + num_threads: int = 0, + ) -> List[Tuple[str, float]] | List[List[Tuple[str, float]]]: """ Query for k nearest neighbors. - Returns list of (spec_id, similarity) tuples. + Parameters + ---------- + vectors : np.ndarray + Either a single vector of shape (dim,) or a batch of shape (N, dim). + k : int + Number of neighbors. + ef : Optional[int] + Optional per-query ef parameter for HNSW. + num_threads : int + Number of threads to use inside nmslib (0 = library default). 
+ + Returns + ------- + Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]] + - If a single vector is given, returns a list of (spec_id, similarity). + - If a batch is given, returns a list (per query) of such lists. """ if self._index is None: raise RuntimeError("Index not built or loaded.") - v = np.asarray(vector, dtype=np.float32).reshape(1, -1) - if v.shape[1] != self.dim: - raise ValueError(f"Query must have dim={self.dim}") + X = np.asarray(vectors, dtype=np.float32) + + single = False + if X.ndim == 1: + # Single query vector: (dim,) -> (1, dim) + if X.size != self.dim: + raise ValueError(f"Query must have dim={self.dim}") + X = X.reshape(1, -1) + single = True + elif X.ndim == 2: + if X.shape[1] != self.dim: + raise ValueError(f"Expected shape (N, {self.dim}), got {X.shape}") + else: + raise ValueError("vectors must be 1D or 2D array.") if ef is not None: self._index.setQueryTimeParams({"ef": ef}) - idxs, dists = self._index.knnQueryBatch(v, k=k)[0] - sims = 1.0 - np.asarray(dists, dtype=np.float32) # cosine distance -> similarity - return [(str(self._ids[i]), float(sims[j])) for j, i in enumerate(idxs)] + batch_results = self._index.knnQueryBatch(X, k=k, num_threads=num_threads) + + all_out: List[List[Tuple[str, float]]] = [] + for idxs, dists in batch_results: + idxs = np.asarray(idxs, dtype=np.int64) + dists = np.asarray(dists, dtype=np.float32) + sims = 1.0 - dists # cosine distance -> similarity + out = [(str(self._ids[i]), float(s)) for i, s in zip(idxs, sims)] + all_out.append(out) + + return all_out[0] if single else all_out def save_index(self, path_prefix: str) -> None: if self._index is None: @@ -487,14 +522,14 @@ def query( # Without re-ranking, return cosine similarities if not re_rank or self._csr is None or self._l1 is None: sims = 1.0 - dists - return [(int(self._comp_ids[i]), float(s)) for i, s in zip(idxs[:k], sims[:k])] + return [(self._comp_ids[i], float(s)) for i, s in zip(idxs[:k], sims[:k])] # Re-rank with exact Tanimoto Y 
= self._csr[idxs] tan = tanimoto_l1_query_vs_block(q, Y, sum1=float(q.sum()), sumsY=self._l1[idxs]) order = np.argsort(-tan)[:k] - return [(int(self._comp_ids[idxs[i]]), float(tan[i])) for i in order] + return [(self._comp_ids[idxs[i]], float(tan[i])) for i in order] def _normalize_query(self, query_fp) -> sp.csr_matrix: """Convert query to single-row CSR and validate.""" @@ -602,10 +637,11 @@ def save_index(self, path_prefix: str) -> None: if self._index is None: raise RuntimeError("Index not built.") - self._index.saveIndex(f"{path_prefix}.nmslib") + # Also save data so that loadIndex(..., load_data=True) works + self._index.saveIndex(f"{path_prefix}.nmslib", save_data=True) np.save(f"{path_prefix}.ids.npy", self._comp_ids) - meta = {**self._meta, "dim": self.dim, "space": self.space} + meta = {**self._meta, "dim": int(self.dim), "space": str(self.space)} with open(f"{path_prefix}.meta.json", "w") as f: json.dump(meta, f) From c827742cdfd2f13759cb86fc95277efaa0dfcae8 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Mon, 8 Dec 2025 22:53:30 +0100 Subject: [PATCH 10/45] adapt batch method and add further helper --- ms2query/ms2query_library.py | 77 +++++++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 15 deletions(-) diff --git a/ms2query/ms2query_library.py b/ms2query/ms2query_library.py index 84ea6d3..786e3f8 100644 --- a/ms2query/ms2query_library.py +++ b/ms2query/ms2query_library.py @@ -32,7 +32,7 @@ class MS2QueryLibrary: """ db: MS2QueryDatabase embedding_index: Optional[EmbeddingIndex] = None - fingerprint_index: Optional[FingerprintSparseIndex] = None + fingerprint_index: Optional[FingerprintSparseIndex] = None # for now: reference spectra only model_path: Optional[str] = None # internal: whether to apply spectrum normalization (sum=1) before embedding @@ -117,19 +117,17 @@ def query_embedding_index( """ self._ensure_index() spectra = _ensure_spectra_list(spectra) - - # Compute embeddings (L2-normalized) embeddings = 
self.compute_embeddings(spectra) + # batched call + batch_hits = self.embedding_index.query(embeddings, k=k, ef=ef) + results_all: List[List[Dict[str, Any]]] = [] - for qi in range(embeddings.shape[0]): - # TODO: make faster by querying batch-wise - # EmbeddingIndex.query returns list[(spec_id, similarity)] - hits = self.embedding_index.query(embeddings[qi], k=k, ef=ef) - # convert to standard structure - one = [] - for rk, (spec_id, score) in enumerate(hits, start=1): - one.append({"rank": rk, "spec_id": spec_id, "score": float(score)}) + for hits in batch_hits: + one = [ + {"rank": rk + 1, "spec_id": spec_id, "score": float(score)} + for rk, (spec_id, score) in enumerate(hits) + ] results_all.append(one) if not return_dataframe: @@ -147,7 +145,7 @@ def query_spectra_by_spectra( spectra: list[Spectrum], *, k_spectra: int = 10, - ef: Optional[int] = None, + ef: Optional[int] = None, ): """ Query the embedding index with spectra, return top-k_spectra per spectrum. @@ -166,7 +164,45 @@ def query_spectra_by_spectra( # Query spectral embeddings return self.query_embedding_index(spectra, k=k_spectra, ef=ef) + + def query_compounds_by_compounds( + self, + compounds: list, + *, + k_compounds: int = 10, + ): + """ + Query the fingerprint index with compounds, return top-k compounds per compound. + + Parameters + ---------- + compounds : list + Query compounds (expects list of SMILES strings). + k_compounds : int + Number of top compounds to return per query compound. + """ + if self.fingerprint_index is None: + raise RuntimeError("FingerprintSparseIndex is not set. 
Build or load it before querying.") + + # Compute fingerprints + fps = self.db.all_cdb.compute_fingerprints( + compounds, + count=False, + sparse=True, + ) + + results_all: List[List[Dict[str, Any]]] = [] + for qi, fp in enumerate(fps): + # FingerprintSparseIndex.query returns list[(comp_id, similarity)] + hits = self.fingerprint_index.query(fp, k=k_compounds) + # convert to standard structure + one = [] + for rk, (comp_id, score) in enumerate(hits, start=1): + one.append({"rank": rk, "comp_id": comp_id, "score": float(score)}) + results_all.append(one) + return results_all + def query_compounds_by_spectra( self, spectra: list[Spectrum], @@ -192,7 +228,7 @@ def query_compounds_by_spectra( if k_compounds > k_spectra: raise ValueError("k_compounds cannot be larger than k_spectra") - # Step1: Query spectral embeddings + # Query spectral embeddings results = self.query_spectra_by_spectra(spectra, k_spectra=k_spectra, ef=ef) # Pick k_compounds top compounds from the k_spectra hits (if possible) @@ -219,14 +255,25 @@ def query_compounds_by_spectra( def analogue_search( self, spectra: list[Spectrum], + *, + ef: Optional[int] = None, ): """ Perform an analogue search for the given spectra. TODO: implement analogue search logic here. 
""" - top_compounds = self.query_compounds_by_spectra(spectra) + # Query spectral embeddings (only top-1 spectra) + results = self.query_spectra_by_spectra(spectra, k_spectra=1, ef=ef) + spec_ids = results.spec_id.values + + # Get compounds of all retrieved spectra + analogue_compounds = self.db.metadata_by_spec_ids([x for x in spec_ids]).set_index("spec_id") + top_compounds = self.query_compounds_by_compounds(analogue_compounds.smiles.tolist()) + + #top_compounds = self.query_compounds_by_spectra(spectra) # TODO: implement analogue search logic here - return top_compounds.drop_duplicates("query_ix") + #return top_compounds.drop_duplicates("query_ix") + return self.query_compounds_by_compounds # ----------------------------- helpers / optional glue ----------------------------- From b0b7748bd0c19c824bcf425e867f32c8f3dab7c3 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Mon, 8 Dec 2025 22:55:09 +0100 Subject: [PATCH 11/45] further batch querying method update --- ms2query/database/ann_vector_index.py | 132 ++++++++++++++++++++++---- 1 file changed, 111 insertions(+), 21 deletions(-) diff --git a/ms2query/database/ann_vector_index.py b/ms2query/database/ann_vector_index.py index 433907c..44c06c0 100644 --- a/ms2query/database/ann_vector_index.py +++ b/ms2query/database/ann_vector_index.py @@ -485,51 +485,141 @@ def build_index( def query( self, - query_fp: Tuple[np.ndarray, np.ndarray] | sp.csr_matrix, + query_fp: ( + Tuple[np.ndarray, np.ndarray] + | sp.csr_matrix + | Sequence[Tuple[np.ndarray, np.ndarray]] + ), k: int = 10, *, ef: Optional[int] = None, re_rank: bool = True, candidate_multiplier: int = 5, - ) -> List[Tuple[int, float]]: + num_threads: int = 0, + ) -> List[Tuple[int, float]] | List[List[Tuple[int, float]]]: """ Query for k nearest neighbors. 
Parameters ---------- - query_fp : (indices, values) tuple or single-row CSR - k : Number of results - re_rank : Use exact Tanimoto re-ranking - candidate_multiplier : Fetch k * multiplier candidates for re-ranking + query_fp : + - Single query: + * (indices, values) tuple + * single-row CSR of shape (1, dim) + - Batched queries: + * CSR of shape (N, dim) + * Sequence of (indices, values) tuples + k : int + Number of results per query. + re_rank : bool + Use exact Tanimoto re-ranking. + candidate_multiplier : int + Fetch k * multiplier candidates for re-ranking. + num_threads : int + Number of threads to use inside nmslib (0 = library default). - Returns list of (comp_id, similarity) tuples. + Returns + ------- + Union[List[Tuple[int, float]], List[List[Tuple[int, float]]]] + - For a single query, returns a list of (comp_id, similarity). + - For multiple queries, returns a list (per query) of such lists. """ if self._index is None: raise RuntimeError("Index not built or loaded.") - q = self._normalize_query(query_fp) - if q.nnz == 0: - return [] + # ------------------------- + # Normalize input to CSR + # ------------------------- + single = False + + if isinstance(query_fp, sp.csr_matrix): + Q = query_fp.astype(np.float32, copy=False) + if Q.shape[1] != self.dim: + raise ValueError(f"CSR query must have shape (N, {self.dim})") + single = Q.shape[0] == 1 + + elif isinstance(query_fp, tuple): + # Single (indices, values) + Q = csr_row_from_tuple(query_fp, dim=self.dim) + single = True + + else: + # Assume sequence of (indices, values) tuples -> batched queries + Q = tuples_to_csr(query_fp, dim=self.dim) + single = Q.shape[0] == 1 + + if (Q.data < 0).any(): + raise ValueError("Query must be non-negative for Tanimoto.") + + # Handle completely empty queries quickly + row_nnz = Q.indptr[1:] - Q.indptr[:-1] + if row_nnz.sum() == 0: + if single: + return [] + return [[] for _ in range(Q.shape[0])] if ef is not None: self._index.setQueryTimeParams({"ef": ef}) fetch = 
max(k, k * candidate_multiplier) - idxs, dists = self._index.knnQueryBatch(q, k=fetch)[0] - idxs = np.asarray(idxs, dtype=np.int64) - dists = np.asarray(dists, dtype=np.float32) - # Without re-ranking, return cosine similarities + # ------------------------- + # ANN search for all queries + # ------------------------- + batch_results = self._index.knnQueryBatch(Q, k=fetch, num_threads=num_threads) + + # ------------------------- + # No re-ranking: cosine sims only + # ------------------------- if not re_rank or self._csr is None or self._l1 is None: - sims = 1.0 - dists - return [(self._comp_ids[i], float(s)) for i, s in zip(idxs[:k], sims[:k])] + all_out: List[List[Tuple[int, float]]] = [] + + for qi, (idxs, dists) in enumerate(batch_results): + if row_nnz[qi] == 0: + all_out.append([]) + continue + + idxs = np.asarray(idxs, dtype=np.int64) + dists = np.asarray(dists, dtype=np.float32) + + sims = 1.0 - dists + out = [ + (self._comp_ids[i], float(s)) + for i, s in zip(idxs[:k], sims[:k]) + ] + all_out.append(out) + + return all_out[0] if single else all_out + + # ------------------------- + # Exact Tanimoto re-ranking + # ------------------------- + all_out: List[List[Tuple[int, float]]] = [] - # Re-rank with exact Tanimoto - Y = self._csr[idxs] - tan = tanimoto_l1_query_vs_block(q, Y, sum1=float(q.sum()), sumsY=self._l1[idxs]) - order = np.argsort(-tan)[:k] + for qi, (idxs, dists) in enumerate(batch_results): + if row_nnz[qi] == 0: + all_out.append([]) + continue - return [(self._comp_ids[idxs[i]], float(tan[i])) for i in order] + idxs = np.asarray(idxs, dtype=np.int64) + + q_row = Q[qi] + Y = self._csr[idxs] + tan = tanimoto_l1_query_vs_block( + q_row, + Y, + sum1=float(q_row.sum()), + sumsY=self._l1[idxs], + ) + + order = np.argsort(-tan)[:k] + out = [ + (self._comp_ids[idxs[i]], float(tan[i])) + for i in order + ] + all_out.append(out) + + return all_out[0] if single else all_out def _normalize_query(self, query_fp) -> sp.csr_matrix: """Convert query to 
single-row CSR and validate.""" From 92b434cd43b262d08dfff11b3e4a8d18a9d60be2 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Mon, 8 Dec 2025 22:55:53 +0100 Subject: [PATCH 12/45] add fingerprint index --- ms2query/library_io.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/ms2query/library_io.py b/ms2query/library_io.py index 86815c0..d53e7f8 100644 --- a/ms2query/library_io.py +++ b/ms2query/library_io.py @@ -8,7 +8,7 @@ from tqdm import tqdm from ms2query import MS2QueryDatabase, MS2QueryLibrary from ms2query.data_processing.merging_utils import cluster_block, get_merged_spectra -from ms2query.database import EmbeddingIndex +from ms2query.database import EmbeddingIndex, FingerprintSparseIndex from ms2query.database.spectra_merging import _split_by_mode_charge @@ -18,6 +18,7 @@ _SQLITE_NAME = "ms2query_library.sqlite" _EMB_TABLE = "embeddings" _EMB_INDEX_BASENAME = "embedding_index" # will create embedding_index.{nmslib,ids.npy,meta.json} +_FP_INDEX_BASENAME = "fingerprint_index" # will create fingerprint_index.{nmslib,ids.npy,meta.json} def _handle_default_settings(settings: dict) -> dict: @@ -81,6 +82,7 @@ def create_new_library( model_path: str, additional_compound_file: Optional[str] = None, build_embedding_index: bool = True, + build_fingerprint_index: bool = True, embedding_index_params: Optional[dict] = None, compute_embeddings_batch_rows: int = 4096, **settings, @@ -101,6 +103,8 @@ def create_new_library( CSV/TSV file with additional compounds (inchikey/smiles/etc.). No fingerprints assumed here. build_embedding_index : bool Whether to build the nmslib cosine HNSW index over embeddings. + build_fingerprint_index : bool + Whether to build the FingerprintSparseIndex over compound fingerprints. 
embedding_index_params : dict Params for HNSW: {'M': int, 'ef_construction': int, 'post_init_ef': int, 'batch_rows': int} compute_embeddings_batch_rows : int @@ -139,6 +143,8 @@ def create_new_library( _print_progress(f"Inserted {creation_stats['n_inserted_spectra']} spectra.") _print_progress(f"Mapped {creation_stats['n_mapped']} spectra to compounds; " f"created {creation_stats['n_new_compounds']} new compounds.") + stats = ms2query_db.ref_cdb.compute_fingerprints_missing() + _print_progress(f"Computed fingerprints for {stats['updated']} compounds.") if additional_compound_file is not None: if not additional_compound_file.lower().endswith((".csv", ".tsv", ".txt")): @@ -199,6 +205,23 @@ def create_new_library( _print_progress(f"Saved EmbeddingIndex files with prefix: {emb_prefix}") lib.set_embedding_index(emb_index) + if build_fingerprint_index: + # TODO: this is not efficient yet; improve later + _print_progress("Building FingerprintSparseIndex ...") + results = lib.db.ref_cdb.get_all_fingerprints_and_comp_ids() + max_bits = [x[0][-1] for x in results["fingerprints"]] + fp_index = FingerprintSparseIndex(dim=int(max(max_bits) + 1)) + + fp_index.build_index( + results["fingerprints"], + results["comp_ids"], + ) + fp_prefix = str(out_dir / _FP_INDEX_BASENAME) + fp_index.save_index(fp_prefix) + _print_progress(f"Saved FingerprintSparseIndex files with prefix: {fp_prefix}") + lib.set_fingerprint_index(fp_index) + + # ----------------------------- # Manifest # ----------------------------- @@ -208,6 +231,7 @@ def create_new_library( "embedding_table": _EMB_TABLE, "model_path": model_path, # stored for convenience; not copied "embedding_index_prefix": _EMB_INDEX_BASENAME if build_embedding_index else None, + "fingerprint_index_prefix": _FP_INDEX_BASENAME if build_fingerprint_index else None, "settings": settings, } with open(out_dir / _MANIFEST_NAME, "w", encoding="utf-8") as f: @@ -262,6 +286,13 @@ def load_created_library(folder: str) -> MS2QueryLibrary: emb_index = 
EmbeddingIndex() emb_index.load_index(str(out_dir / emb_prefix)) lib.set_embedding_index(emb_index) + + # Load FingerprintSparseIndex if present + fp_prefix = manifest.get("fingerprint_index_prefix") + if fp_prefix: + fp_index = FingerprintSparseIndex() + fp_index.load_index(str(out_dir / fp_prefix)) + lib.set_fingerprint_index(fp_index) # (Optional) Load fingerprint index here if/when you add it later. From bf9bbf290b5adb58c16c1fb076b518dc5204151a Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Mon, 8 Dec 2025 23:06:23 +0100 Subject: [PATCH 13/45] larger linting/cleaning and batch implementation --- ms2query/ms2query_library.py | 220 ++++++++++++++++++++++------------- 1 file changed, 141 insertions(+), 79 deletions(-) diff --git a/ms2query/ms2query_library.py b/ms2query/ms2query_library.py index 786e3f8..ea6e316 100644 --- a/ms2query/ms2query_library.py +++ b/ms2query/ms2query_library.py @@ -37,10 +37,13 @@ class MS2QueryLibrary: # internal: whether to apply spectrum normalization (sum=1) before embedding _spectrum_sum_normalization_for_embedding: bool = True + # cached MS2DeepScore model _model: Any = field(default=None, init=False, repr=False) - # ----------------------------- lifecycle ----------------------------- + # ------------------------------------------------------------------ + # Lifecycle / internal helpers + # ----------------------------------------------------------------- def _ensure_model(self): """Lazy-load MS2DeepScore model if a model_path was provided.""" @@ -54,11 +57,21 @@ def _ensure_model(self): self._model.eval() return self._model - def _ensure_index(self): + def _ensure_embedding_index(self): if self.embedding_index is None: - raise RuntimeError("EmbeddingIndex is not set. Build or load it before querying.") + raise RuntimeError( + "EmbeddingIndex is not set. Build or load it before querying." 
+ ) + + def _ensure_fingerprint_index(self): + if self.fingerprint_index is None: + raise RuntimeError( + "FingerprintSparseIndex is not set. Build or load it before querying." + ) - # ----------------------------- core API ----------------------------- + # ------------------------------------------------------------------ + # Core API: spectra -> embeddings -> ANN over embeddings + # ----------------------------------------------------------------- def process_spectra(self, spectra: list[Spectrum]) -> List[Spectrum]: """ @@ -68,25 +81,24 @@ def process_spectra(self, spectra: list[Spectrum]) -> List[Spectrum]: # Hook point: insert matchms pipeline later (e.g., metadata fixes, peak processing, etc.) return list(spectra) - def compute_embeddings(self, spectra: list[Spectrum]) -> np.ndarray: + def compute_embeddings(self, spectra: Sequence[Spectrum]) -> np.ndarray: """ Compute MS2DeepScore embeddings for arbitrary query spectra. Spectra will be preprocessed via self.process_spectra(...) first. """ + spectra = _ensure_spectra_list(spectra) if not spectra: return np.empty((0, 0), dtype=np.float32) model = self._ensure_model() - - # Preprocess — keep spectral normalization symmetrical with DB embeddings spectra = self.process_spectra(spectra) - # Compute embeddings return compute_spectra_embeddings( - model, spectra, - normalize_spectrum=self._spectrum_sum_normalization_for_embedding - ) + model, + spectra, + normalize_spectrum=self._spectrum_sum_normalization_for_embedding, + ) def query_embedding_index( self, @@ -100,12 +112,12 @@ def query_embedding_index( Process spectra -> embed -> query EmbeddingIndex. Returns per-query results as a list of lists with dicts: - [{'rank':1, 'spec_id': '...', 'score': float}, ...] + [{'rank': 1, 'spec_id': '...', 'score': float}, ...] All IDs are **spec_id** strings (NOT internal index ids). Parameters ---------- - spectra : Spectrum | list[Spectrum] + spectra : Spectrum | Sequence[Spectrum] Query spectra. k : int Top-k to return. 
@@ -113,13 +125,18 @@ def query_embedding_index( nmslib ef (higher = better recall / slower). return_dataframe : bool If True, returns a tidy DataFrame with columns: - ['query_ix','rank','spec_id','score'] + ['query_ix', 'rank', 'spec_id', 'score'] """ - self._ensure_index() + self._ensure_embedding_index() spectra = _ensure_spectra_list(spectra) embeddings = self.compute_embeddings(spectra) - # batched call + if embeddings.size == 0: + return ( + [] if not return_dataframe else self._empty_result_df() + ) + + # Batched call: EmbeddingIndex.query returns List[List[(spec_id, score)]] batch_hits = self.embedding_index.query(embeddings, k=k, ef=ef) results_all: List[List[Dict[str, Any]]] = [] @@ -133,90 +150,91 @@ def query_embedding_index( if not return_dataframe: return results_all - rows = [] + rows: List[Dict[str, Any]] = [] for qi, lst in enumerate(results_all): for item in lst: rows.append({"query_ix": qi, **item}) - df = pd.DataFrame(rows, columns=["query_ix", "rank", "spec_id", "score"]) - return df + return pd.DataFrame(rows, columns=["query_ix", "rank", "spec_id", "score"]) def query_spectra_by_spectra( self, - spectra: list[Spectrum], + spectra: Union[Spectrum, Sequence[Spectrum]], *, k_spectra: int = 10, ef: Optional[int] = None, - ): + ): """ Query the embedding index with spectra, return top-k_spectra per spectrum. Parameters ---------- - spectra : list[Spectrum] + spectra : list[Spectrum] or Spectrum Query spectra. k_spectra : int Number of top spectra to retrieve from the embedding index. ef : Optional[int] nmslib ef parameter (higher = better recall / slower). 
""" - self._ensure_index() - spectra = _ensure_spectra_list(spectra) + return self.query_embedding_index( + spectra, k=k_spectra, ef=ef, return_dataframe=True + ) + + # ------------------------------------------------------------------ + # Core API: compounds / fingerprints + # ------------------------------------------------------------------ - # Query spectral embeddings - return self.query_embedding_index(spectra, k=k_spectra, ef=ef) - def query_compounds_by_compounds( self, - compounds: list, + compounds: Sequence[str], *, k_compounds: int = 10, - ): + ) -> List[List[Dict[str, Any]]]: """ Query the fingerprint index with compounds, return top-k compounds per compound. Parameters ---------- - compounds : list + compounds : Sequence[str] Query compounds (expects list of SMILES strings). k_compounds : int Number of top compounds to return per query compound. """ - if self.fingerprint_index is None: - raise RuntimeError("FingerprintSparseIndex is not set. Build or load it before querying.") + self._ensure_fingerprint_index() - # Compute fingerprints + # Compute fingerprints (sparse representation) fps = self.db.all_cdb.compute_fingerprints( compounds, count=False, sparse=True, ) + # Batched fingerprint ANN query + batch_hits = self.fingerprint_index.query(fps, k=k_compounds) + results_all: List[List[Dict[str, Any]]] = [] - for qi, fp in enumerate(fps): - # FingerprintSparseIndex.query returns list[(comp_id, similarity)] - hits = self.fingerprint_index.query(fp, k=k_compounds) - # convert to standard structure - one = [] - for rk, (comp_id, score) in enumerate(hits, start=1): - one.append({"rank": rk, "comp_id": comp_id, "score": float(score)}) + for hits in batch_hits: + one = [ + {"rank": rk + 1, "comp_id": comp_id, "score": float(score)} + for rk, (comp_id, score) in enumerate(hits) + ] results_all.append(one) return results_all def query_compounds_by_spectra( self, - spectra: list[Spectrum], + spectra: Union[Spectrum, Sequence[Spectrum]], *, k_spectra: int = 
100, k_compounds: int = 10, ef: Optional[int] = None, - ): + ) -> pd.DataFrame: """ - Query the embedding index with spectra, return top-k_compounds per spectrum. + Query the embedding index with spectra, then aggregate to compounds. Parameters ---------- - spectra : list[Spectrum] + spectra : list[Spectrum] or Spectrum Query spectra. k_spectra : int Number of top spectra to retrieve from the embedding index. @@ -229,54 +247,81 @@ def query_compounds_by_spectra( raise ValueError("k_compounds cannot be larger than k_spectra") # Query spectral embeddings - results = self.query_spectra_by_spectra(spectra, k_spectra=k_spectra, ef=ef) + results = self.query_spectra_by_spectra( + spectra, k_spectra=k_spectra, ef=ef + ) # DataFrame + + if results.empty: + return results # Pick k_compounds top compounds from the k_spectra hits (if possible) - spec_ids = results.spec_id.values + spec_ids = results["spec_id"].values - compounds = self.db.metadata_by_spec_ids([x for x in spec_ids]).set_index("spec_id") - compounds = compounds.merge(results, on="spec_id").sort_values(["query_ix", "rank"]) + compounds = ( + self.db.metadata_by_spec_ids(list(spec_ids)) + .set_index("spec_id") + ) + + compounds = ( + compounds.merge(results, on="spec_id") + .sort_values(["query_ix", "rank"]) + ) # Pick no more than k_compounds per query_ix - idx = compounds.groupby(['query_ix', 'rank'])['score'].idxmax() + idx = compounds.groupby(["query_ix", "rank"])["score"].idxmax() best_per_pair = compounds.loc[idx] # Within each query_ix, keep the top-k by score df_selected = ( - best_per_pair - .sort_values(['query_ix', 'score'], ascending=[True, False]) - .groupby('query_ix', group_keys=False) + best_per_pair.sort_values(["query_ix", "score"], ascending=[True, False]) + .groupby("query_ix", group_keys=False) .head(k_compounds) .reset_index(drop=True) ) - return df_selected def analogue_search( self, - spectra: list[Spectrum], + spectra: Union[Spectrum, Sequence[Spectrum]], *, + k_spectra: int = 1, + 
k_compounds: int = 10, ef: Optional[int] = None, - ): + ): """ Perform an analogue search for the given spectra. - TODO: implement analogue search logic here. + + Current behaviour: + - For each query spectrum, retrieve top-`k_spectra` library spectra. + - Get their compounds. + - Run compound-by-compound search in fingerprint space. """ - # Query spectral embeddings (only top-1 spectra) - results = self.query_spectra_by_spectra(spectra, k_spectra=1, ef=ef) - spec_ids = results.spec_id.values + # Step 1: top-k_spectra per query + spec_hits = self.query_spectra_by_spectra( + spectra, k_spectra=k_spectra, ef=ef + ) # DataFrame + if spec_hits.empty: + return [] + + spec_ids = spec_hits["spec_id"].values + + # Step 2: get compounds of all retrieved spectra + analogue_compounds = ( + self.db.metadata_by_spec_ids(list(spec_ids)) + .set_index("spec_id") + ) - # Get compounds of all retrieved spectra - analogue_compounds = self.db.metadata_by_spec_ids([x for x in spec_ids]).set_index("spec_id") - top_compounds = self.query_compounds_by_compounds(analogue_compounds.smiles.tolist()) + smiles = analogue_compounds["smiles"].tolist() - #top_compounds = self.query_compounds_by_spectra(spectra) - # TODO: implement analogue search logic here - #return top_compounds.drop_duplicates("query_ix") - return self.query_compounds_by_compounds + # Step 3: fingerprint-based compound search + top_compounds = self.query_compounds_by_compounds( + smiles, k_compounds=k_compounds + ) + return top_compounds - - # ----------------------------- helpers / optional glue ----------------------------- + # ------------------------------------------------------------------ + # Helpers / glue + # ------------------------------------------------------------------ def set_embedding_index(self, index: EmbeddingIndex) -> None: """Attach or replace the EmbeddingIndex.""" @@ -287,43 +332,60 @@ def set_fingerprint_index(self, index: FingerprintSparseIndex) -> None: self.fingerprint_index = index def 
query_by_spec_ids( - self, spec_ids: List[str], *, k: int = 10, ef: Optional[int] = None, return_dataframe: bool = False + self, + spec_ids: List[str], + *, + k: int = 10, + ef: Optional[int] = None, + return_dataframe: bool = False, ): """ Convenience: fetch embeddings for known spec_ids from SQLite and search. Requires that the embeddings are present in DB (table 'embeddings'). """ - if self.embedding_index is None: - raise RuntimeError("EmbeddingIndex is not set. Build or load it before querying.") + self._ensure_embedding_index() - # Pull precomputed embeddings from DB (already L2-normalized in SpectralDatabase.get_embeddings) - ids, X = self.db.ref_sdb.get_embeddings(ids=spec_ids, embeddings_table="embeddings", normalized=True) + # Pull precomputed embeddings from DB (already L2-normalized) + ids, X = self.db.ref_sdb.get_embeddings( + ids=spec_ids, + embeddings_table="embeddings", + normalized=True, + ) + + # If DB returns nothing, keep the old API behaviour if X.size == 0: return [] if not return_dataframe else self._empty_result_df() + # X is 2D: use batched query + batch_hits = self.embedding_index.query(X, k=k, ef=ef) + results_all: List[List[Dict[str, Any]]] = [] - for qi in range(X.shape[0]): - hits = self.embedding_index.query(X[qi], k=k, ef=ef) - one = [{"rank": rk + 1, "spec_id": sid, "score": float(score)} for rk, (sid, score) in enumerate(hits)] + for hits in batch_hits: + one = [ + {"rank": rk + 1, "spec_id": sid, "score": float(score)} + for rk, (sid, score) in enumerate(hits) + ] results_all.append(one) if not return_dataframe: return results_all - rows = [] + rows: List[Dict[str, Any]] = [] for qi, lst in enumerate(results_all): for item in lst: rows.append({"query_ix": qi, **item}) return pd.DataFrame(rows, columns=["query_ix", "rank", "spec_id", "score"]) @staticmethod - def _empty_result_df(): + def _empty_result_df() -> pd.DataFrame: return pd.DataFrame(columns=["query_ix", "rank", "spec_id", "score"]) # ----------------- helper functions 
--------------------- -def _ensure_spectra_list(spectra: Union[Spectrum, Sequence[Spectrum]]) -> List[Spectrum]: +def _ensure_spectra_list( + spectra: Union[Spectrum, Sequence[Spectrum]] +) -> List[Spectrum]: if isinstance(spectra, Spectrum): return [spectra] if isinstance(spectra, Sequence): From 0c3c920d01b368fd49425443e13c16192f66f19d Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Tue, 9 Dec 2025 08:29:25 +0100 Subject: [PATCH 14/45] linting --- ms2query/database/spectral_database.py | 11 ++++++----- ms2query/ms2query_library.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/ms2query/database/spectral_database.py b/ms2query/database/spectral_database.py index 5676bd4..86deb1d 100644 --- a/ms2query/database/spectral_database.py +++ b/ms2query/database/spectral_database.py @@ -303,7 +303,7 @@ def flush(batch) -> int: def get_embeddings( self, - ids: Optional[List[str]] = None, + spec_ids: Optional[List[str]] = None, *, embeddings_table: str = "embeddings", normalized: bool = True, @@ -314,13 +314,14 @@ def get_embeddings( If normalized=True, L2-normalize (recommended for cosine). """ cur = self._conn.cursor() - if ids is None: + if spec_ids is None: cur.execute(f"SELECT spec_id, d, vec FROM {embeddings_table} ORDER BY spec_id ASC;") else: - ph = ",".join("?" for _ in ids) + placeholders = ",".join("?" 
for _ in spec_ids) cur.execute( - f"SELECT spec_id, d, vec FROM {embeddings_table} WHERE spec_id IN ({ph}) ORDER BY spec_id ASC;", - ids) + f"""SELECT spec_id, d, vec FROM {embeddings_table} + WHERE spec_id IN ({placeholders}) ORDER BY spec_id ASC;""", + spec_ids) sids: List[str] = [] vecs: List[np.ndarray] = [] diff --git a/ms2query/ms2query_library.py b/ms2query/ms2query_library.py index ea6e316..0130b26 100644 --- a/ms2query/ms2query_library.py +++ b/ms2query/ms2query_library.py @@ -346,8 +346,8 @@ def query_by_spec_ids( self._ensure_embedding_index() # Pull precomputed embeddings from DB (already L2-normalized) - ids, X = self.db.ref_sdb.get_embeddings( - ids=spec_ids, + _, X = self.db.ref_sdb.get_embeddings( + spec_ids=spec_ids, embeddings_table="embeddings", normalized=True, ) From a59cc9a9c2a6537dee0e71e0005955b86c92905a Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Tue, 9 Dec 2025 08:32:03 +0100 Subject: [PATCH 15/45] add more getters --- ms2query/ms2query_database.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ms2query/ms2query_database.py b/ms2query/ms2query_database.py index 8cbb035..b887174 100644 --- a/ms2query/ms2query_database.py +++ b/ms2query/ms2query_database.py @@ -113,6 +113,9 @@ def fragments_by_spec_ids(self, spec_ids: List[int]): def metadata_by_spec_ids(self, spec_ids: List[int]) -> pd.DataFrame: return self.ref_sdb.get_metadata_by_ids(spec_ids) + + def embeddings_by_spec_ids(self, spec_ids: List[int]): + return self.ref_sdb.get_embeddings(spec_ids=spec_ids) # ---- by comp_id (inchikey14) ---- @@ -128,6 +131,10 @@ def metadata_by_comp_id(self, comp_id: str) -> pd.DataFrame: def compound(self, comp_id: str) -> Optional[Dict[str, Any]]: return self.ref_cdb.get_compound(comp_id) + + def embeddings_by_comp_id(self, comp_id: str): + spec_ids = self.spec_ids_by_comp_id(comp_id) + return self.ref_sdb.get_embeddings(spec_ids=spec_ids) # -------------------------------- convenience SQL ------------------------------ From 
ff6e2b673fff582a92eb8c067e853a6b78cf4694 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Tue, 9 Dec 2025 09:58:25 +0100 Subject: [PATCH 16/45] fix by_ids methods and linting --- ms2query/database/spectral_database.py | 124 +++++++++++++++++-------- 1 file changed, 85 insertions(+), 39 deletions(-) diff --git a/ms2query/database/spectral_database.py b/ms2query/database/spectral_database.py index 86deb1d..225123a 100644 --- a/ms2query/database/spectral_database.py +++ b/ms2query/database/spectral_database.py @@ -57,10 +57,19 @@ def _normalize_metadata(md: Dict[str, Any], fields: Iterable[str]) -> Dict[str, class SpectralDatabase: sqlite_path: str table: str = "spectra" - metadata_fields: List[str] = field(default_factory=lambda: [ - "precursor_mz", "ionmode", "smiles", "inchikey", "inchi", "name", - "instrument_type", "adduct", "collision_energy" - ]) + metadata_fields: List[str] = field( + default_factory=lambda: [ + "precursor_mz", + "ionmode", + "smiles", + "inchikey", + "inchi", + "name", + "instrument_type", + "adduct", + "collision_energy", + ] + ) spectrum_sum_normalization_for_embedding: bool = True _conn: sqlite3.Connection = field(init=False, repr=False) _ms2ds_model_path: Optional[str] = field(default=None, repr=False) @@ -81,7 +90,8 @@ def add_spectra(self, spectra: List[Spectrum]) -> List[str]: cur = self._conn.cursor() # Bulk-load speed PRAGMAs (safe for single-user/batch ingest) - cur.executescript(""" + cur.executescript( + """ PRAGMA journal_mode=WAL; PRAGMA synchronous=OFF; PRAGMA temp_store=MEMORY; @@ -130,14 +140,14 @@ def ids(self) -> List[str]: rows = cur.execute(f"SELECT spec_id FROM {self.table}").fetchall() return [str(row["spec_id"]) for row in rows] - def get_spectra_by_ids(self, specIDs: List[str]) -> List[Spectrum]: - """Retrieve full Spectrum objects for given specIDs (order preserved, missing IDs skipped).""" + def get_spectra_by_ids(self, spec_ids: List[str]) -> List[Spectrum]: + """Retrieve full Spectrum objects for given spec_ids 
(order preserved, missing IDs skipped).""" rows = self._fetch_rows_by_ids( - specIDs, cols="spec_id, mz_blob, intensity_blob, n_peaks, " + ", ".join(self.metadata_fields)) + spec_ids, cols="spec_id, mz_blob, intensity_blob, n_peaks, " + ", ".join(self.metadata_fields)) by_id = {row["spec_id"]: row for row in rows} result: List[Spectrum] = [] - for sid in specIDs: + for sid in spec_ids: row = by_id.get(sid) if row is None: continue @@ -149,13 +159,19 @@ def get_spectra_by_ids(self, specIDs: List[str]) -> List[Spectrum]: result.append(Spectrum(mz=mz, intensities=inten, metadata=md)) return result - def get_fragments_by_ids(self, specIDs: List[str]) -> List[Tuple[np.ndarray, np.ndarray]]: - """Retrieve (mz, intensity) arrays for given specIDs (order preserved, missing IDs skipped).""" - rows = self._fetch_rows_by_ids(specIDs, cols="spec_id, mz_blob, intensity_blob, n_peaks") + def get_fragments_by_ids(self, spec_ids: List[str]) -> List[Tuple[np.ndarray, np.ndarray]]: + """ + Retrieve (mz, intensity) arrays for given spec_ids. + + Order is preserved with respect to `spec_ids`. + Missing IDs are skipped. + """ + cols = "spec_id, mz_blob, intensity_blob, n_peaks" + rows = self._fetch_rows_by_ids(spec_ids, cols=cols) by_id = {row["spec_id"]: row for row in rows} out: List[Tuple[np.ndarray, np.ndarray]] = [] - for sid in specIDs: + for sid in spec_ids: row = by_id.get(sid) if row is None: continue @@ -165,16 +181,34 @@ def get_fragments_by_ids(self, specIDs: List[str]) -> List[Tuple[np.ndarray, np. out.append((mz, inten)) return out - def get_metadata_by_ids(self, specIDs: List[str]) -> pd.DataFrame: - """Retrieve metadata for given specIDs (order preserved).""" + def get_metadata_by_ids(self, spec_ids: List[str]) -> pd.DataFrame: + """ + Retrieve metadata for given spec_ids. + + Returns a DataFrame with **one row per requested spec_id** in the same + order as `spec_ids`. 
If a spec_id is not present in the database, a row + with that spec_id and metadata columns set to None/NaN is returned. + """ cols = ["spec_id"] + self.metadata_fields - rows = self._fetch_rows_by_ids(specIDs, cols=", ".join(cols)) - df = pd.DataFrame(rows, columns=cols) - if not df.empty: - order = {sid: i for i, sid in enumerate(specIDs)} - df["__order"] = df["spec_id"].map(order) - df = df.sort_values("__order").drop(columns="__order").reset_index(drop=True) - return df + if not spec_ids: + return pd.DataFrame(columns=cols) + + rows = self._fetch_rows_by_ids(spec_ids, cols=", ".join(cols)) + by_id = {row["spec_id"]: row for row in rows} + + records: List[Dict[str, Any]] = [] + for sid in spec_ids: + row = by_id.get(sid) + if row is None: + rec = {"spec_id": sid} + rec.update({k: None for k in self.metadata_fields}) + else: + rec = {"spec_id": sid} + for k in self.metadata_fields: + rec[k] = row[k] + records.append(rec) + + return pd.DataFrame.from_records(records, columns=cols) def sql_query(self, query: str) -> pd.DataFrame: """Run a raw SQL SELECT and return a DataFrame.""" @@ -225,7 +259,6 @@ def compute_embeddings_to_sqlite( - Uses `matchms.Spectrum` objects reconstructed from the stored peaks & metadata. - Stores raw float32 vectors (no extra header) with their dimension `d`. """ - # TODO: add batch_size to speed up? 
spectra_table = spectra_table or self.table self._ensure_schema() # spectra schema self.ensure_embeddings_schema(embeddings_table) @@ -254,7 +287,9 @@ def compute_embeddings_to_sqlite( model = self.load_ms2deepscore_model(model_path) inserted = 0 - buf: List[Tuple[str, bytes, bytes, int, float, str, Optional[int]]] = [] + buf: List[ + Tuple[str, bytes, bytes, int, float, str, Optional[int]] + ] = [] done_since_commit = 0 def flush(batch) -> int: @@ -265,24 +300,35 @@ def flush(batch) -> int: for sid, mz_blob, it_blob, n_peaks, prec_mz, ionmode, charge in batch: mz = _from_float32_bytes(mz_blob, int(n_peaks)) it = _from_float32_bytes(it_blob, int(n_peaks)) - spectrum = Spectrum(mz=mz, intensities=it, metadata={ - "precursor_mz": float(prec_mz) if prec_mz is not None else None, - "ionmode": ionmode, - "charge": charge, - "spec_id": sid, - }) + spectrum = Spectrum( + mz=mz, + intensities=it, + metadata={ + "precursor_mz": float(prec_mz) if prec_mz is not None else None, + "ionmode": ionmode, + "charge": charge, + "spec_id": sid, + }, + ) specs.append(spectrum) sids.append(sid) embeddings = compute_spectra_embeddings( - model, specs, - normalize_spectrum=self.spectrum_sum_normalization_for_embedding - ) + model, + specs, + normalize_spectrum=self.spectrum_sum_normalization_for_embedding, + ) dim = int(embeddings.shape[1]) - q = f"INSERT OR REPLACE INTO {embeddings_table} (spec_id, d, vec) VALUES (?, ?, ?);" + q = ( + f"INSERT OR REPLACE INTO {embeddings_table} " + f"(spec_id, d, vec) VALUES (?, ?, ?);" + ) with self._conn: for sid, embedding in zip(sids, embeddings): - self._conn.execute(q, (sid, dim, sqlite3.Binary(_as_float32_bytes(embedding)))) + self._conn.execute( + q, + (sid, dim, sqlite3.Binary(_as_float32_bytes(embedding))), + ) return len(batch) while True: @@ -360,13 +406,13 @@ def connection(self) -> sqlite3.Connection: return self._conn # ---------- internal ---------- - def _fetch_rows_by_ids(self, specIDs: List[str], cols: str) -> List[sqlite3.Row]: - if 
not specIDs: + def _fetch_rows_by_ids(self, spec_ids: List[str], cols: str) -> List[sqlite3.Row]: + if not spec_ids: return [] - placeholders = ",".join("?" for _ in specIDs) + placeholders = ",".join("?" for _ in spec_ids) sql = f"SELECT {cols} FROM {self.table} WHERE spec_id IN ({placeholders})" cur = self._conn.cursor() - return cur.execute(sql, specIDs).fetchall() + return cur.execute(sql, spec_ids).fetchall() def _ensure_schema(self): cur = self._conn.cursor() From c3d7b62bd0072924288fbccd818e4b41b0d98edf Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Tue, 9 Dec 2025 09:58:39 +0100 Subject: [PATCH 17/45] add test and clean --- ms2query/ms2query_library.py | 2 +- tests/test_spectral_database.py | 63 +++++++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/ms2query/ms2query_library.py b/ms2query/ms2query_library.py index 0130b26..978adae 100644 --- a/ms2query/ms2query_library.py +++ b/ms2query/ms2query_library.py @@ -352,7 +352,7 @@ def query_by_spec_ids( normalized=True, ) - # If DB returns nothing, keep the old API behaviour + # If DB returns nothing if X.size == 0: return [] if not return_dataframe else self._empty_result_df() diff --git a/tests/test_spectral_database.py b/tests/test_spectral_database.py index adfc296..fe7237e 100644 --- a/tests/test_spectral_database.py +++ b/tests/test_spectral_database.py @@ -1,4 +1,3 @@ -# test_spectral_database.py import numpy as np import pandas as pd import pytest @@ -149,10 +148,70 @@ def test_missing_ids_handling(tmp_db, spectra_small): # Implementation skips missing IDs but preserves order of the ones that exist assert [s.metadata["spec_id"] for s in out_spectra] == [ids[1]] - assert list(out_meta["spec_id"]) == [ids[1]] + + # get_metadata_by_ids now returns one row per requested ID + assert out_meta.shape[0] == 2 + assert list(out_meta["spec_id"]) == req + + # First row corresponds to missing ID: all metadata fields should be None/NaN + missing_row = out_meta.iloc[0] + 
assert missing_row["spec_id"] == req[0] + assert missing_row[tmp_db.metadata_fields].isna().all() + + # Exactly one row has *any* metadata filled (the real spec_id) + non_empty_rows = out_meta.dropna(how="all", subset=tmp_db.metadata_fields) + assert non_empty_rows.shape[0] == 1 + assert non_empty_rows.iloc[0]["spec_id"] == ids[1] + + # get_fragments_by_ids still skips missing IDs assert len(out_frags) == 1 +def test_get_metadata_by_ids_all_ids_included_even_if_same_compound(tmp_db): + # Two spectra with different peaks but same compound-level metadata + s1 = make_spectrum( + [100, 200, 300], + [10, 20, 30], + precursor_mz=250.0, + ionmode="positive", + inchikey="SAME-IK", + smiles="C", + name="compound-1", + ) + s2 = make_spectrum( + [110, 210, 310], + [5, 15, 25], + precursor_mz=260.0, + ionmode="positive", + inchikey="SAME-IK", # same compound + smiles="C", + name="compound-1", + ) + + ids = tmp_db.add_spectra([s1, s2]) + + # Request both spec_ids; we expect two rows, one per ID, same order + df = tmp_db.get_metadata_by_ids(ids) + + expected_cols = ["spec_id"] + tmp_db.metadata_fields + assert list(df.columns) == expected_cols + assert df.shape[0] == 2 + assert list(df["spec_id"]) == ids + + # Both rows should carry the same compound-level metadata (same inchikey/smiles) + assert df.loc[0, "inchikey"] == "SAME-IK" + assert df.loc[1, "inchikey"] == "SAME-IK" + assert df.loc[0, "smiles"] == "C" + assert df.loc[1, "smiles"] == "C" + + # If name is stored, it should be consistent across rows (but may be None) + assert df.loc[0, "name"] == df.loc[1, "name"] + + # The precursor m/z values differ per spectrum and should be preserved per ID + assert df.loc[0, "precursor_mz"] == pytest.approx(250.0) + assert df.loc[1, "precursor_mz"] == pytest.approx(260.0) + + def test_add_duplicates_are_ignored_and_ids_repeat(tmp_db, spectra_small): # First insert ids_first = tmp_db.add_spectra(spectra_small) From 62eef512b262273f4db3b6959d5345761ba8483a Mon Sep 17 00:00:00 2001 From: 
Florian Huber Date: Tue, 9 Dec 2025 11:02:31 +0100 Subject: [PATCH 18/45] cleaning and additional test --- tests/test_ann_vector_index.py | 79 ++++++++++++++++++++++++++++------ 1 file changed, 67 insertions(+), 12 deletions(-) diff --git a/tests/test_ann_vector_index.py b/tests/test_ann_vector_index.py index fe3971a..69f7f42 100644 --- a/tests/test_ann_vector_index.py +++ b/tests/test_ann_vector_index.py @@ -4,7 +4,13 @@ import numpy as np import pytest import scipy.sparse as sp -from ms2query.database.ann_vector_index import EmbeddingIndex, csr_row_from_tuple, l1_norms_csr, tuples_to_csr +from ms2query.database.ann_vector_index import ( + EmbeddingIndex, + FingerprintSparseIndex, # <-- added + csr_row_from_tuple, + l1_norms_csr, + tuples_to_csr, +) def _mk_unit_vecs(*rows): @@ -15,8 +21,8 @@ def _mk_unit_vecs(*rows): def test_build_index_and_query_dense(): - X = _mk_unit_vecs([1,0,0], [0,1,0], [0,0,1]) - ids = ["a","b","c"] + X = _mk_unit_vecs([1, 0, 0], [0, 1, 0], [0, 0, 1]) + ids = ["a", "b", "c"] idx = EmbeddingIndex(dim=3) idx.build_index(X, ids) # Query close to [1,0,0] @@ -27,11 +33,11 @@ def test_build_index_and_query_dense(): def test_build_index_normalizes_when_requested(): - X = np.array([[2.0,0,0],[0,2.0,0]], dtype=np.float32) - ids = ["x","y"] + X = np.array([[2.0, 0, 0], [0, 2.0, 0]], dtype=np.float32) + ids = ["x", "y"] idx = EmbeddingIndex(dim=3) idx.build_index(X, ids) - q = np.array([1.0,0,0], dtype=np.float32) + q = np.array([1.0, 0, 0], dtype=np.float32) out = idx.query(q, k=1) assert out[0][0] == "x" assert 0.99 <= out[0][1] <= 1.0 @@ -41,14 +47,14 @@ def test_query_errors_and_dim_check(): idx = EmbeddingIndex(dim=3) with pytest.raises(RuntimeError): idx.query(np.zeros(3, np.float32)) - idx.build_index(np.eye(3, dtype=np.float32), ["a","b","c"]) + idx.build_index(np.eye(3, dtype=np.float32), ["a", "b", "c"]) with pytest.raises(ValueError, match="dim=3"): idx.query(np.zeros(4, np.float32)) def test_save_and_load_roundtrip_dense(tmp_path): - X = 
_mk_unit_vecs([1,0,0],[0,1,0],[0,0,1]) - ids = ["a","b","c"] + X = _mk_unit_vecs([1, 0, 0], [0, 1, 0], [0, 0, 1]) + ids = ["a", "b", "c"] idx = EmbeddingIndex(dim=3) idx.build_index(X, ids) prefix = os.path.join(tmp_path, "emb") @@ -59,7 +65,7 @@ def test_save_and_load_roundtrip_dense(tmp_path): idx2.load_index(prefix) # Query should still work - res = idx2.query(np.array([1.0,0,0], dtype=np.float32), k=1) + res = idx2.query(np.array([1.0, 0, 0], dtype=np.float32), k=1) assert res[0][0] == "a" # meta persisted with open(prefix + ".meta.json") as f: @@ -72,7 +78,8 @@ def test_build_index_from_sqlite_streams_and_orders(batch_rows): conn = sqlite3.connect(":memory:") conn.execute("CREATE TABLE embeddings(spec_id TEXT, vec BLOB, d INTEGER)") # Add 3 vectors of dim 3 - conn.executemany( "INSERT INTO embeddings(spec_id, vec, d) VALUES (?,?,?)", + conn.executemany( + "INSERT INTO embeddings(spec_id, vec, d) VALUES (?,?,?)", [ ("id_1", np.array([1.0, 0.0, 0.0], np.float32).tobytes(), 3), ("id_2", np.array([1.0, 1.0, 0.0], np.float32).tobytes(), 3), @@ -106,7 +113,7 @@ def test_build_index_from_sqlite_errors(): # Empty table should error conn2 = sqlite3.connect(":memory:") - conn2.execute("CREATE TABLE embeddings(spec_id TEXT, vec BLOB, d INTEGER)") + conn2.execute("CREATE TABLE embeddings(spec_id TEXT, vec, d INTEGER)") with pytest.raises(ValueError, match="No rows"): idx.build_index_from_sqlite(conn2, embeddings_table="embeddings") @@ -158,3 +165,51 @@ def test_l1_norms_csr(): norms = l1_norms_csr(X) np.testing.assert_allclose(norms, [3.0, 1.0, 4.0]) assert norms.dtype == np.float64 + + +def test_save_and_load_roundtrip_fingerprint_sparse(tmp_path): + # Build a tiny sparse fingerprint matrix with 3 compounds, dim=5 + tuples = [ + (np.array([0, 3], dtype=np.int32), np.array([1.0, 0.5], dtype=np.float32)), + (np.array([1], dtype=np.int32), np.array([1.0], dtype=np.float32)), + (np.array([2, 4], dtype=np.int32), np.array([0.2, 2.0], dtype=np.float32)), + ] + csr = 
tuples_to_csr(tuples, dim=5) + comp_ids = np.array([10, 11, 12], dtype=int) + + idx = FingerprintSparseIndex(dim=5) + idx.build_index( + csr, + comp_ids, + keep_csr_for_rerank=True, + compute_l1_for_rerank=True, + ) + + # Basic sanity: query with the first fingerprint should hit comp_id=10 first + q = tuples[0] + res = idx.query(q, k=1) + assert res[0][0] == 10 + + # Save to disk + prefix = os.path.join(tmp_path, "fp") + idx.save_index(prefix) + + # New instance loads back + idx2 = FingerprintSparseIndex() + idx2.load_index(prefix) + + # Query should still work and return the same top compound + res2 = idx2.query(q, k=1) + assert res2[0][0] == 10 + + # CSR and L1 data should have been persisted + assert idx2._csr is not None + assert idx2._l1 is not None + assert idx2._csr.shape == csr.shape + assert idx2._l1.shape[0] == csr.shape[0] + + # Meta persisted and type is correct + with open(prefix + ".meta.json") as f: + meta = json.load(f) + assert meta["type"] == "FingerprintSparseIndex" + assert meta["space"] == "cosinesimil_sparse" From b4561c0e7a5e87ae7ed02f6b8d7830bd747b49bb Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Tue, 9 Dec 2025 15:59:20 +0100 Subject: [PATCH 19/45] Do correct subsetting of inchikey sets --- ms2query/benchmarking/SpectrumDataSet.py | 9 +++++++++ tests/test_SpectrumDataSet.py | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/ms2query/benchmarking/SpectrumDataSet.py b/ms2query/benchmarking/SpectrumDataSet.py index 0751d70..2446d30 100644 --- a/ms2query/benchmarking/SpectrumDataSet.py +++ b/ms2query/benchmarking/SpectrumDataSet.py @@ -108,6 +108,15 @@ def copy(self): new_instance.inchikey_fingerprint_pairs = copy.copy(self.inchikey_fingerprint_pairs) return new_instance + def subset_spectra(self, spectrum_indexes) -> "SpectraWithFingerprints": + """Returns a new instance of a subset of the spectra""" + new_instance = super().subset_spectra(spectrum_indexes) + # Only keep the fingerprints for which we have inchikeys. 
+ # Important note: This is not a deep copy! + # And the fingerprint is not reset (so it is not always actually matching the most common inchi) + new_instance.inchikey_fingerprint_pairs = {inchikey: self.inchikey_fingerprint_pairs[inchikey] for inchikey in new_instance.spectrum_indexes_per_inchikey.keys()} + return new_instance + class SpectraWithMS2DeepScoreEmbeddings(SpectraWithFingerprints): def __init__(self, spectra: List[Spectrum], ms2deepscore_model: SiameseSpectralModel, **kwargs): diff --git a/tests/test_SpectrumDataSet.py b/tests/test_SpectrumDataSet.py index d7b3cff..a4c8e2c 100644 --- a/tests/test_SpectrumDataSet.py +++ b/tests/test_SpectrumDataSet.py @@ -91,9 +91,9 @@ def test_spectra_with_fingerprints(library): ) # test correct subsetting - subset_indexes = [1, 4, 6, 7] + subset_indexes = [1, 6, 7] subset = library.subset_spectra(subset_indexes) - assert len(subset.inchikey_fingerprint_pairs) == 3 + assert len(subset.inchikey_fingerprint_pairs) == 2 assert all( np.array_equal(library.inchikey_fingerprint_pairs[key], value) for key, value in subset.inchikey_fingerprint_pairs.items() From 657990d65f0c048b63ddc9e725d677b5030cfa31 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Tue, 9 Dec 2025 20:09:09 +0100 Subject: [PATCH 20/45] fix --- ms2query/ms2query_library.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ms2query/ms2query_library.py b/ms2query/ms2query_library.py index 978adae..044874b 100644 --- a/ms2query/ms2query_library.py +++ b/ms2query/ms2query_library.py @@ -204,8 +204,6 @@ def query_compounds_by_compounds( # Compute fingerprints (sparse representation) fps = self.db.all_cdb.compute_fingerprints( compounds, - count=False, - sparse=True, ) # Batched fingerprint ANN query From 170537f6edae8004558fbc08844fd76f8d1b6839 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Tue, 9 Dec 2025 20:13:02 +0100 Subject: [PATCH 21/45] cleaning, fixing, linting --- ms2query/database/ann_vector_index.py | 2 +- ms2query/database/compound_database.py | 
9 +++++++++ ms2query/ms2query_library.py | 12 ++++++++---- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/ms2query/database/ann_vector_index.py b/ms2query/database/ann_vector_index.py index 44c06c0..042c296 100644 --- a/ms2query/database/ann_vector_index.py +++ b/ms2query/database/ann_vector_index.py @@ -727,7 +727,7 @@ def save_index(self, path_prefix: str) -> None: if self._index is None: raise RuntimeError("Index not built.") - # Also save data so that loadIndex(..., load_data=True) works + # Also save data so that load_index(..., load_data=True) works self._index.saveIndex(f"{path_prefix}.nmslib", save_data=True) np.save(f"{path_prefix}.ids.npy", self._comp_ids) diff --git a/ms2query/database/compound_database.py b/ms2query/database/compound_database.py index afc5b98..140aaea 100644 --- a/ms2query/database/compound_database.py +++ b/ms2query/database/compound_database.py @@ -467,6 +467,15 @@ def compute_fingerprints( Does NOT write to the database. Provide exactly one of (smiles, inchis). 
+ Parameters + ---------- + smiles : Optional[List[str]], optional + List of SMILES strings, by default None + inchis : Optional[List[str]], optional + List of InChI strings, by default None + progress_bar : bool, optional + Whether to show a progress bar, by default False + Returns the same shapes/types as compute_morgan_fingerprints: - dense: np.ndarray of shape (N, nbits) - sparse/binary: List[np.ndarray[uint32]] diff --git a/ms2query/ms2query_library.py b/ms2query/ms2query_library.py index 044874b..86596e6 100644 --- a/ms2query/ms2query_library.py +++ b/ms2query/ms2query_library.py @@ -185,8 +185,9 @@ def query_spectra_by_spectra( def query_compounds_by_compounds( self, - compounds: Sequence[str], *, + smiles: Optional[List[str]] = None, + inchis: Optional[List[str]] = None, k_compounds: int = 10, ) -> List[List[Dict[str, Any]]]: """ @@ -194,8 +195,10 @@ def query_compounds_by_compounds( Parameters ---------- - compounds : Sequence[str] - Query compounds (expects list of SMILES strings). + smiles : Optional[List[str]], optional + List of SMILES strings, by default None + inchis : Optional[List[str]], optional + List of InChI strings, by default None k_compounds : int Number of top compounds to return per query compound. """ @@ -203,7 +206,8 @@ def query_compounds_by_compounds( # Compute fingerprints (sparse representation) fps = self.db.all_cdb.compute_fingerprints( - compounds, + smiles=smiles, + inchis=inchis, ) # Batched fingerprint ANN query From 7f769977394a76cba5124f15f15d6f25444cca2a Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Wed, 10 Dec 2025 12:58:05 +0100 Subject: [PATCH 22/45] Add method for predicting using top 10 closest library spectra. 
--- .../predict_using_closest_tanimoto.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py diff --git a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py new file mode 100644 index 0000000..f844b77 --- /dev/null +++ b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py @@ -0,0 +1,66 @@ +import numpy as np +from ms2deepscore.vector_operations import cosine_similarity_matrix +from typing import Tuple, List + +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings +from ms2query.metrics import generalized_tanimoto_similarity_matrix + + +def predict_using_closest_tanimoto( + library_spectra: SpectraWithMS2DeepScoreEmbeddings, query_spectra: SpectraWithMS2DeepScoreEmbeddings, + nr_of_closest_inchikeys_to_select=10 +) -> Tuple[List[str], List[float]]: + """Predict best inchikey, by taking the average score over all spectra for the 10 closest related library inchikeys. 
+ (simplified version of old MS2Query) + """ + inchikeys_of_best_match = [] + single_highest_score = [] + for spectrum_idx in range(len(query_spectra.spectra)): + inchikey_of_best_match, score = predict_using_closest_tanimoto_single_spectrum( + library_spectra, query_spectra.subset_spectra([spectrum_idx]), nr_of_closest_inchikeys_to_select) + inchikeys_of_best_match.append(inchikey_of_best_match) + single_highest_score.append(score) + return inchikeys_of_best_match, single_highest_score + + +def predict_using_closest_tanimoto_single_spectrum(spectra_with_embeddings, single_spectrum_with_embeddings, + nr_of_closest_inchikeys_to_select) -> Tuple[str, float]: + if len(single_spectrum_with_embeddings.spectra) != 1: + raise ValueError("expected a single spectrum") + ms2deepscores = cosine_similarity_matrix(single_spectrum_with_embeddings.embeddings, + spectra_with_embeddings.embeddings)[0] + average_predicted_scores = {} + for inchikey, spectrum_indexes in spectra_with_embeddings.spectrum_indexes_per_inchikey.items(): + all_ms2deepscores_for_inchikey = ms2deepscores[spectrum_indexes] + if max(all_ms2deepscores_for_inchikey) > 0.7: + average_predicted_score = get_average_predictions_for_closely_related_metabolites( + spectra_with_embeddings, inchikey, ms2deepscores, nr_of_closest_inchikeys_to_select) + average_predicted_scores[inchikey] = average_predicted_score + + inchikey_with_highest_average_prediction, score = max(average_predicted_scores.items(), key=lambda item: item[1]) + return inchikey_with_highest_average_prediction, score + +def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddings, inchikey, + all_ms2deepscores, nr_of_closest_inchikeys_to_select): + """Calculates the average ms2deepscore predictions for top k closest inchikeys""" + top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k( + spectra_with_embeddings, inchikey,nr_of_closest_inchikeys_to_select) + + average_predicted_scores = [] + for top_inchikey in 
top_k_inchikeys: + matching_spectrum_indexes = spectra_with_embeddings.spectrum_indexes_per_inchikey[top_inchikey] + predicted_scores = all_ms2deepscores[matching_spectrum_indexes] + average_predicted_scores.append(predicted_scores.mean()) + average_predicted_score = sum(average_predicted_scores) / len(average_predicted_scores) + return average_predicted_score + +def get_inchikey_and_tanimoto_scores_for_top_k(spectra: SpectraWithMS2DeepScoreEmbeddings, inchikey, k) -> tuple[list[str], np.ndarray]: + """For an inchikey in a library the top k highest tanimoto scores in the library are predicted (including itself)""" + library_fingerprints = np.vstack(list(spectra.inchikey_fingerprint_pairs.values())) + fingerprint_single_inchikey = np.vstack(list([spectra.inchikey_fingerprint_pairs[inchikey]])) + similarity_scores = generalized_tanimoto_similarity_matrix(fingerprint_single_inchikey, library_fingerprints)[0] + inchikey_indexes_of_top_k = np.argpartition(similarity_scores, -k)[-k:] + tanimoto_scores_for_top_k = similarity_scores[inchikey_indexes_of_top_k] + all_inchikeys = list(spectra.inchikey_fingerprint_pairs.keys()) + top_inchikeys = [all_inchikeys[inchikey_index] for inchikey_index in inchikey_indexes_of_top_k] + return top_inchikeys, tanimoto_scores_for_top_k From 2c2e35ffa059fc65ca15a2c380e2533d5743ee15 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Wed, 10 Dec 2025 13:04:20 +0100 Subject: [PATCH 23/45] Added extra inchikey smiles examples --- tests/conftest.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 80ab525..80ec8e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,6 +73,18 @@ def get_inchikey_inchi_pairs(number_of_pairs): "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O", "Glucose", ), + ( + "MWOOGOJBHIARFG-UHFFFAOYSA-N", + "InChI=1S/C8H8O3/c1-11-8-4-6(5-9)2-3-7(8)10/h2-5,10H,1H3", + "COC1=C(C=CC(=C1)C=O)O", + "vanillin" + ), + ( + "ROHFNLRQFUQHCH-YFKPBYRVSA-N", + 
"InChI=1S/C6H13NO2/c1-4(2)3-5(7)6(8)9/h4-5H,3,7H2,1-2H3,(H,8,9)/t5-/m0/s1", + "CC(C)C[C@@H](C(=O)O)N", + "L-Leucine" + ) ) if number_of_pairs > len(inchikey_inchi_pairs): raise ValueError("Not enough example compounds, add some in conftest") From e646e2239c50460a5e14957400b66c5590b83138 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Wed, 10 Dec 2025 13:52:43 +0100 Subject: [PATCH 24/45] refactoring/linting --- ms2query/database/spec_to_compound_mapper.py | 3 +++ ms2query/database/spectral_database.py | 10 ---------- ms2query/ms2query_library.py | 13 ++++++++++--- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/ms2query/database/spec_to_compound_mapper.py b/ms2query/database/spec_to_compound_mapper.py index 7442437..f8d141a 100644 --- a/ms2query/database/spec_to_compound_mapper.py +++ b/ms2query/database/spec_to_compound_mapper.py @@ -27,6 +27,7 @@ class SpecToCompoundMap: sqlite_path: str table: str = "spec_to_comp" compound_table: str = "compounds" + _conn: sqlite3.Connection = field(init=False, repr=False) def __post_init__(self): @@ -84,6 +85,8 @@ def link_many(self, pairs: Iterable[Tuple[int, str]]): cur.execute("ROLLBACK") raise + # ---- getters: spec_id -> comp_id ---- + def get_comp_id_for_specs(self, spec_ids: List[str]) -> pd.DataFrame: """Return a DataFrame with columns [spec_id, comp_id] for the provided spec_ids.""" if not spec_ids: diff --git a/ms2query/database/spectral_database.py b/ms2query/database/spectral_database.py index 225123a..806a1b8 100644 --- a/ms2query/database/spectral_database.py +++ b/ms2query/database/spectral_database.py @@ -390,16 +390,6 @@ def get_embeddings( X = X / n return np.asarray(sids, dtype=object), X - def get_embedding_for_id( - self, - spec_id: str, - *, - embeddings_table: str = "embeddings", - normalized: bool = True, - ) -> Optional[np.ndarray]: - ids, X = self.get_embeddings([spec_id], embeddings_table=embeddings_table, normalized=normalized) - return X[0] if X.shape[0] else None - # expose raw 
connection for ANN builders @property def connection(self) -> sqlite3.Connection: diff --git a/ms2query/ms2query_library.py b/ms2query/ms2query_library.py index 86596e6..ab24b0a 100644 --- a/ms2query/ms2query_library.py +++ b/ms2query/ms2query_library.py @@ -26,8 +26,6 @@ class MS2QueryLibrary: Notes ----- - * `model_path` is optional; provide it if you want on-the-fly embedding of ad-hoc spectra. - If omitted, you can still query using precomputed embeddings fetched from SQLite. * EmbeddingIndex must be built/loaded elsewhere (creation handled by setup workflow). """ db: MS2QueryDatabase @@ -221,7 +219,16 @@ def query_compounds_by_compounds( ] results_all.append(one) - return results_all + # Convert to DataFrame + results_df = pd.DataFrame( + [ + {"query_ix": qi, **item} + for qi, lst in enumerate(results_all) + for item in lst + ], + columns=["query_ix", "rank", "comp_id", "score"], + ) + return results_df def query_compounds_by_spectra( self, From fb4cbd80ec728d394f342d5ee401c8dbf453400f Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Wed, 10 Dec 2025 14:17:02 +0100 Subject: [PATCH 25/45] Add test_get_inchikey_and_tanimoto_scores_from_top_k --- tests/test_predict_using_closest_tanimoto.py | 28 ++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 tests/test_predict_using_closest_tanimoto.py diff --git a/tests/test_predict_using_closest_tanimoto.py b/tests/test_predict_using_closest_tanimoto.py new file mode 100644 index 0000000..65933f2 --- /dev/null +++ b/tests/test_predict_using_closest_tanimoto.py @@ -0,0 +1,28 @@ +import numpy as np + +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings, SpectraWithFingerprints +from ms2query.benchmarking.reference_methods.predict_using_closest_tanimoto import ( + predict_using_closest_tanimoto, predict_using_closest_tanimoto_single_spectrum, + get_average_predictions_for_closely_related_metabolites, get_inchikey_and_tanimoto_scores_for_top_k) +from tests.conftest import 
ms2deepscore_model, create_test_spectra +import pytest + + + +@pytest.mark.parametrize( + "k", + [1, 3, 7], +) +def test_get_inchikey_and_tanimoto_scores_for_top_k(k): + spectra = SpectraWithFingerprints(create_test_spectra(nr_of_inchikeys=7)) + inchikey = list(spectra.inchikey_fingerprint_pairs.keys())[2] + + top_inchikeys, tanimoto_scores_for_top_k = get_inchikey_and_tanimoto_scores_for_top_k( + spectra, inchikey,k) + + assert inchikey in top_inchikeys + assert len(top_inchikeys) == k + assert len(tanimoto_scores_for_top_k) == k + assert len(set(top_inchikeys)) == k + assert tanimoto_scores_for_top_k[top_inchikeys.index(inchikey)] == 1.0, \ + "The exact match is expected to have a score of 1.0" \ No newline at end of file From 1933c8530e9199478b6ed531d0cda94c99702d2d Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Wed, 10 Dec 2025 14:17:53 +0100 Subject: [PATCH 26/45] Move top_k_selection outside average computation for easier testing --- .../predict_using_closest_tanimoto.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py index f844b77..6b5fa25 100644 --- a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py +++ b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py @@ -2,7 +2,7 @@ from ms2deepscore.vector_operations import cosine_similarity_matrix from typing import Tuple, List -from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings, SpectraWithFingerprints from ms2query.metrics import generalized_tanimoto_similarity_matrix @@ -33,19 +33,18 @@ def predict_using_closest_tanimoto_single_spectrum(spectra_with_embeddings, sing for inchikey, spectrum_indexes in spectra_with_embeddings.spectrum_indexes_per_inchikey.items(): 
all_ms2deepscores_for_inchikey = ms2deepscores[spectrum_indexes] if max(all_ms2deepscores_for_inchikey) > 0.7: + top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k( + spectra_with_embeddings, inchikey, nr_of_closest_inchikeys_to_select) average_predicted_score = get_average_predictions_for_closely_related_metabolites( - spectra_with_embeddings, inchikey, ms2deepscores, nr_of_closest_inchikeys_to_select) + spectra_with_embeddings, top_k_inchikeys, ms2deepscores) average_predicted_scores[inchikey] = average_predicted_score inchikey_with_highest_average_prediction, score = max(average_predicted_scores.items(), key=lambda item: item[1]) return inchikey_with_highest_average_prediction, score -def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddings, inchikey, - all_ms2deepscores, nr_of_closest_inchikeys_to_select): +def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddings, top_k_inchikeys, + all_ms2deepscores): """Calculates the average ms2deepscore predictions for top k closest inchikeys""" - top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k( - spectra_with_embeddings, inchikey,nr_of_closest_inchikeys_to_select) - average_predicted_scores = [] for top_inchikey in top_k_inchikeys: matching_spectrum_indexes = spectra_with_embeddings.spectrum_indexes_per_inchikey[top_inchikey] @@ -54,7 +53,8 @@ def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddi average_predicted_score = sum(average_predicted_scores) / len(average_predicted_scores) return average_predicted_score -def get_inchikey_and_tanimoto_scores_for_top_k(spectra: SpectraWithMS2DeepScoreEmbeddings, inchikey, k) -> tuple[list[str], np.ndarray]: +def get_inchikey_and_tanimoto_scores_for_top_k(spectra: SpectraWithFingerprints, inchikey, k + ) -> tuple[list[str], np.ndarray]: """For an inchikey in a library the top k highest tanimoto scores in the library are predicted (including itself)""" library_fingerprints = 
np.vstack(list(spectra.inchikey_fingerprint_pairs.values())) fingerprint_single_inchikey = np.vstack(list([spectra.inchikey_fingerprint_pairs[inchikey]])) From fd2bf89def21119b67481333b1c9a2bcaa6ed676 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Wed, 10 Dec 2025 14:18:11 +0100 Subject: [PATCH 27/45] Add test_get_average_predictions_for_closely_related_metabolites --- tests/test_predict_using_closest_tanimoto.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_predict_using_closest_tanimoto.py b/tests/test_predict_using_closest_tanimoto.py index 65933f2..e948187 100644 --- a/tests/test_predict_using_closest_tanimoto.py +++ b/tests/test_predict_using_closest_tanimoto.py @@ -8,6 +8,24 @@ import pytest +def test_get_average_predictions_for_closely_related_metabolites(): + test_spectra = create_test_spectra(nr_of_inchikeys=7) + # Select different number per inchikey (only one for the first) to check that it is correctly weighted. + test_spectra = test_spectra.copy()[2:] + spectra = SpectraWithFingerprints(test_spectra) + + inchikeys = list(spectra.inchikey_fingerprint_pairs.keys())[:3] + ms2deepscores = np.zeros(len(spectra.spectra)) + ms2deepscores[0] = 0.8 + ms2deepscores[[1,2,3]] = 0.6 + ms2deepscores[4] = 0.6 + ms2deepscores[5] = 0.8 + ms2deepscores[6] = 0.7 + # the average per inchikey is 0.8, 0.6, 0.7, so average overall should be 0.7 + average_predicted_score = get_average_predictions_for_closely_related_metabolites(spectra, + inchikeys, + ms2deepscores) + assert np.allclose(average_predicted_score, np.array(0.7), atol=1e-5) @pytest.mark.parametrize( "k", From 59aca8eca42a60682656ae5988f9e431ba2c0d1a Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Wed, 10 Dec 2025 15:33:04 +0100 Subject: [PATCH 28/45] Split select_inchikeys_with_highest_ms2deepscores to make more modular and to select top k highest ms2deepscores --- .../predict_using_closest_tanimoto.py | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) 
diff --git a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py index 6b5fa25..f353aeb 100644 --- a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py +++ b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py @@ -14,34 +14,47 @@ def predict_using_closest_tanimoto( (simplified version of old MS2Query) """ inchikeys_of_best_match = [] - single_highest_score = [] + highest_scores = [] for spectrum_idx in range(len(query_spectra.spectra)): inchikey_of_best_match, score = predict_using_closest_tanimoto_single_spectrum( library_spectra, query_spectra.subset_spectra([spectrum_idx]), nr_of_closest_inchikeys_to_select) inchikeys_of_best_match.append(inchikey_of_best_match) - single_highest_score.append(score) - return inchikeys_of_best_match, single_highest_score + highest_scores.append(score) + return inchikeys_of_best_match, highest_scores def predict_using_closest_tanimoto_single_spectrum(spectra_with_embeddings, single_spectrum_with_embeddings, - nr_of_closest_inchikeys_to_select) -> Tuple[str, float]: + nr_of_closest_inchikeys_to_select, + nr_of_inchikeys_with_highest_ms2deepscore_to_select) -> Tuple[str, float]: if len(single_spectrum_with_embeddings.spectra) != 1: raise ValueError("expected a single spectrum") ms2deepscores = cosine_similarity_matrix(single_spectrum_with_embeddings.embeddings, spectra_with_embeddings.embeddings)[0] + top_inchikeys = select_inchikeys_with_highest_ms2deepscore(spectra_with_embeddings, ms2deepscores, + nr_of_inchikeys_with_highest_ms2deepscore_to_select) average_predicted_scores = {} - for inchikey, spectrum_indexes in spectra_with_embeddings.spectrum_indexes_per_inchikey.items(): - all_ms2deepscores_for_inchikey = ms2deepscores[spectrum_indexes] - if max(all_ms2deepscores_for_inchikey) > 0.7: - top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k( - spectra_with_embeddings, inchikey, 
nr_of_closest_inchikeys_to_select) - average_predicted_score = get_average_predictions_for_closely_related_metabolites( - spectra_with_embeddings, top_k_inchikeys, ms2deepscores) - average_predicted_scores[inchikey] = average_predicted_score + for inchikey in top_inchikeys: + top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k( + spectra_with_embeddings, inchikey, nr_of_closest_inchikeys_to_select) + average_predicted_score = get_average_predictions_for_closely_related_metabolites( + spectra_with_embeddings, top_k_inchikeys, ms2deepscores) + average_predicted_scores[inchikey] = average_predicted_score inchikey_with_highest_average_prediction, score = max(average_predicted_scores.items(), key=lambda item: item[1]) return inchikey_with_highest_average_prediction, score +def select_inchikeys_with_highest_ms2deepscore(spectra_with_embeddings, ms2deepscores, nr_of_inchikeys_to_select=10): + highest_score_for_inchikey = [] + for inchikey, spectrum_indexes in spectra_with_embeddings.spectrum_indexes_per_inchikey.items(): + all_ms2deepscores_for_inchikey = ms2deepscores[spectrum_indexes] + highest_score_for_inchikey.append(max(all_ms2deepscores_for_inchikey)) + inchikey_indexes_with_highest_ms2deepscore = np.argpartition( + np.array(highest_score_for_inchikey), -nr_of_inchikeys_to_select)[-nr_of_inchikeys_to_select:] + + all_inchikeys = list(spectra_with_embeddings.inchikey_fingerprint_pairs.keys()) + top_inchikeys = [all_inchikeys[inchikey_index] for inchikey_index in inchikey_indexes_with_highest_ms2deepscore] + return top_inchikeys + def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddings, top_k_inchikeys, all_ms2deepscores): """Calculates the average ms2deepscore predictions for top k closest inchikeys""" From f223e4302fa249057077421b398e5797c8f29539 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Wed, 10 Dec 2025 15:33:41 +0100 Subject: [PATCH 29/45] Add test_select_inchikeys_with_highest_ms2deepscore --- 
tests/test_predict_using_closest_tanimoto.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_predict_using_closest_tanimoto.py b/tests/test_predict_using_closest_tanimoto.py index e948187..21b3a7c 100644 --- a/tests/test_predict_using_closest_tanimoto.py +++ b/tests/test_predict_using_closest_tanimoto.py @@ -3,11 +3,25 @@ from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings, SpectraWithFingerprints from ms2query.benchmarking.reference_methods.predict_using_closest_tanimoto import ( predict_using_closest_tanimoto, predict_using_closest_tanimoto_single_spectrum, - get_average_predictions_for_closely_related_metabolites, get_inchikey_and_tanimoto_scores_for_top_k) + get_average_predictions_for_closely_related_metabolites, get_inchikey_and_tanimoto_scores_for_top_k, + select_inchikeys_with_highest_ms2deepscore) from tests.conftest import ms2deepscore_model, create_test_spectra import pytest +def test_select_inchikeys_with_highest_ms2deepscore(): + test_spectra = create_test_spectra(nr_of_inchikeys=7) + spectra = SpectraWithFingerprints(test_spectra) + + ms2deepscores = np.zeros(len(test_spectra)) + ms2deepscores[2] = 0.4 + ms2deepscores[5] = 0.9 + ms2deepscores[7] = 0.6 + inchikeys_with_highest_ms2deepscore = select_inchikeys_with_highest_ms2deepscore(spectra, ms2deepscores, 3) + expected_inchikeys = list(spectra.spectrum_indexes_per_inchikey.keys())[:3] + assert set(expected_inchikeys) == set(inchikeys_with_highest_ms2deepscore) + print(inchikeys_with_highest_ms2deepscore) + def test_get_average_predictions_for_closely_related_metabolites(): test_spectra = create_test_spectra(nr_of_inchikeys=7) # Select different number per inchikey (only one for the first) to check that it is correctly weighted. 
From 16058ccb53a15814d79188475caa35e2e24f02b6 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Wed, 10 Dec 2025 15:51:55 +0100 Subject: [PATCH 30/45] Added nr_of_inchikeys_with_highest_ms2deepscore_to_select as parameter --- .../reference_methods/predict_using_closest_tanimoto.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py index f353aeb..f0a40c9 100644 --- a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py +++ b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py @@ -8,7 +8,8 @@ def predict_using_closest_tanimoto( library_spectra: SpectraWithMS2DeepScoreEmbeddings, query_spectra: SpectraWithMS2DeepScoreEmbeddings, - nr_of_closest_inchikeys_to_select=10 + nr_of_closest_inchikeys_to_select=10, + nr_of_inchikeys_with_highest_ms2deepscore_to_select=100 ) -> Tuple[List[str], List[float]]: """Predict best inchikey, by taking the average score over all spectra for the 10 closest related library inchikeys. 
(simplified version of old MS2Query) @@ -17,7 +18,8 @@ def predict_using_closest_tanimoto( highest_scores = [] for spectrum_idx in range(len(query_spectra.spectra)): inchikey_of_best_match, score = predict_using_closest_tanimoto_single_spectrum( - library_spectra, query_spectra.subset_spectra([spectrum_idx]), nr_of_closest_inchikeys_to_select) + library_spectra, query_spectra.subset_spectra([spectrum_idx]), + nr_of_closest_inchikeys_to_select, nr_of_inchikeys_with_highest_ms2deepscore_to_select) inchikeys_of_best_match.append(inchikey_of_best_match) highest_scores.append(score) return inchikeys_of_best_match, highest_scores From 1c2f1a733453be68ec8a3fc6bd53f4b9704b8149 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Wed, 10 Dec 2025 15:52:16 +0100 Subject: [PATCH 31/45] Added basic tests for predict using closest tanimoto score (checking correct types of output) --- tests/test_predict_using_closest_tanimoto.py | 23 ++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_predict_using_closest_tanimoto.py b/tests/test_predict_using_closest_tanimoto.py index 21b3a7c..bdb9741 100644 --- a/tests/test_predict_using_closest_tanimoto.py +++ b/tests/test_predict_using_closest_tanimoto.py @@ -9,6 +9,29 @@ import pytest +def test_predict_using_closest_tanimoto(): + """Only very basic test that the function runs and that the output is the right format""" + model = ms2deepscore_model() + library_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(nr_of_inchikeys=7), model) + test_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(1, nr_of_inchikeys=3), model) + predicted_inchikeys, scores = predict_using_closest_tanimoto(library_spectra, test_spectra, 3, 3) + + assert isinstance(predicted_inchikeys, list) + assert len(predicted_inchikeys) == 3 + assert isinstance(scores, list) + assert len(scores) == 3 + +def test_predict_using_closest_tanimoto_single_spectrum(): + """Only very basic test that the function runs and that the output 
is the right format""" + model = ms2deepscore_model() + library_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(nr_of_inchikeys=7), model) + test_spectra = SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(1, nr_of_inchikeys=1), model) + predicted_inchikey, score = predict_using_closest_tanimoto_single_spectrum(library_spectra, test_spectra, 3, 3) + + assert isinstance(predicted_inchikey, str) + assert len(predicted_inchikey) ==14 + assert isinstance(score, float) + def test_select_inchikeys_with_highest_ms2deepscore(): test_spectra = create_test_spectra(nr_of_inchikeys=7) spectra = SpectraWithFingerprints(test_spectra) From 69f30e754400d8e63f62a78202016f7c5178f578 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Wed, 10 Dec 2025 17:21:20 +0100 Subject: [PATCH 32/45] switch to batch/list handling for queries --- ms2query/database/spec_to_compound_mapper.py | 231 ++++++++++++++----- 1 file changed, 171 insertions(+), 60 deletions(-) diff --git a/ms2query/database/spec_to_compound_mapper.py b/ms2query/database/spec_to_compound_mapper.py index f8d141a..e7c5e4c 100644 --- a/ms2query/database/spec_to_compound_mapper.py +++ b/ms2query/database/spec_to_compound_mapper.py @@ -1,7 +1,7 @@ import sqlite3 from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple import pandas as pd from ms2query.data_processing import inchikey14_from_full from ms2query.database import CompoundDatabase @@ -11,10 +11,11 @@ # Mapping: spectrum <-> compound (spec_to_comp) # ================================================== + @dataclass class SpecToCompoundMap: - """This class manages the mapping between spectrum IDs and compound IDs. - + """Manage the mapping between spectrum IDs and compound IDs (inchikey14). 
+ Attributes ---------- sqlite_path : str @@ -36,50 +37,61 @@ def __post_init__(self): self._conn.row_factory = sqlite3.Row self._ensure_schema() - def close(self): + # ---------------- lifecycle ---------------- + + def close(self) -> None: try: self._conn.close() except Exception: pass - def _ensure_schema(self): + # ---------------- schema ---------------- + + def _ensure_schema(self) -> None: cur = self._conn.cursor() - # No strict FK enforcement (SpectralDatabase may have been created without FK pragma), - # here: index both sides for fast lookup. - cur.executescript(f""" + # No strict FK enforcement; index both sides for fast lookup. + cur.executescript( + f""" CREATE TABLE IF NOT EXISTS {self.table}( spec_id TEXT NOT NULL, - comp_id TEXT NOT NULL, + comp_id TEXT NOT NULL, PRIMARY KEY (spec_id), CHECK (length(comp_id) = 14) ); - CREATE INDEX IF NOT EXISTS idx_spec_to_comp_comp ON {self.table}(comp_id); - """) + CREATE INDEX IF NOT EXISTS idx_{self.table}_comp ON {self.table}(comp_id); + """ + ) self._conn.commit() - # ---------- API ---------- + # ---------------- API ---------------- - def link(self, spec_id: str, comp_id: str): - """Insert or replace a single mapping.""" + def link(self, spec_id: str, comp_id: str) -> None: + """Insert or replace a single mapping (spec_id -> comp_id).""" if not comp_id or len(comp_id) != 14: raise ValueError("comp_id must be inchikey14 (14 characters).") - self._conn.execute(f""" + self._conn.execute( + f""" INSERT INTO {self.table} (spec_id, comp_id) VALUES (?, ?) 
- ON CONFLICT(spec_id) DO UPDATE SET comp_id=excluded.comp_id - """, (spec_id, comp_id)) + ON CONFLICT(spec_id) DO UPDATE SET comp_id = excluded.comp_id + """, + (spec_id, comp_id), + ) self._conn.commit() - def link_many(self, pairs: Iterable[Tuple[int, str]]): - """Bulk link (spec_id, comp_id).""" + def link_many(self, pairs: Iterable[Tuple[str, str]]) -> None: + """Bulk link (spec_id, comp_id) pairs.""" cur = self._conn.cursor() cur.execute("BEGIN") try: - cur.executemany(f""" + cur.executemany( + f""" INSERT INTO {self.table} (spec_id, comp_id) VALUES (?, ?) - ON CONFLICT(spec_id) DO UPDATE SET comp_id=excluded.comp_id - """, list(pairs)) + ON CONFLICT(spec_id) DO UPDATE SET comp_id = excluded.comp_id + """, + list(pairs), + ) cur.execute("COMMIT") except Exception: cur.execute("ROLLBACK") @@ -87,25 +99,87 @@ def link_many(self, pairs: Iterable[Tuple[int, str]]): # ---- getters: spec_id -> comp_id ---- - def get_comp_id_for_specs(self, spec_ids: List[str]) -> pd.DataFrame: - """Return a DataFrame with columns [spec_id, comp_id] for the provided spec_ids.""" + def get_comp_id_for_spec(self, spec_id: str) -> Optional[str]: + """ + Return the comp_id for a single spec_id, or None if not mapped. + """ + df = self.get_comp_id_for_specs([spec_id]) + if df.empty: + return None + val = df.loc[0, "comp_id"] + return None if pd.isna(val) else str(val) + + def get_comp_id_for_specs(self, spec_ids: Sequence[str]) -> pd.DataFrame: + """ + Return a DataFrame with columns ['spec_id', 'comp_id'] for the provided spec_ids. + + Behaviour + --------- + - Only returns rows for spec_ids that actually have a mapping. + - spec_id values come from the database (TEXT → strings). + - The result may have fewer rows than requested. + """ + cols = ["spec_id", "comp_id"] if not spec_ids: - return pd.DataFrame(columns=["spec_id", "comp_id"]) + return pd.DataFrame(columns=cols) + placeholders = ",".join("?" 
* len(spec_ids)) rows = self._conn.execute( - f"SELECT spec_id, comp_id FROM {self.table} WHERE spec_id IN ({placeholders})", - spec_ids + f""" + SELECT spec_id, comp_id + FROM {self.table} + WHERE spec_id IN ({placeholders}) + """, + list(spec_ids), ).fetchall() - return pd.DataFrame(rows, columns=["spec_id", "comp_id"]) + + # rows already have the correct types from SQLite (TEXT for both columns) + return pd.DataFrame(rows, columns=cols) + + # ---- getters: comp_id -> spec_id(s) ---- def get_specs_for_comp(self, comp_id: str) -> List[str]: - """Return list of spec_ids for a given comp_id.""" - rows = self._conn.execute(f"SELECT spec_id FROM {self.table} WHERE comp_id = ?", (comp_id,)).fetchall() - return [r[0] for r in rows] + """ + Return list of spec_ids (as strings) for a single comp_id. + """ + rows = self._conn.execute( + f"SELECT spec_id FROM {self.table} WHERE comp_id = ?", + (comp_id,), + ).fetchall() + return [str(r[0]) for r in rows] + + def get_specs_for_comps(self, comp_ids: Sequence[str]) -> pd.DataFrame: + """ + Return a DataFrame with columns ['comp_id', 'spec_id'] for the given comp_ids. + + Behaviour + --------- + - For each comp_id, all mapped spec_ids are returned (1:N). + - Order of returned rows is determined by the underlying query. + - If a comp_id has no spectra, it will not appear in the result. + """ + cols = ["comp_id", "spec_id"] + if not comp_ids: + return pd.DataFrame(columns=cols) + + placeholders = ",".join("?" 
* len(comp_ids)) + rows = self._conn.execute( + f""" + SELECT comp_id, spec_id + FROM {self.table} + WHERE comp_id IN ({placeholders}) + """, + list(comp_ids), + ).fetchall() + return pd.DataFrame(rows, columns=cols) + + # ---- misc ---- def get_all_mappings(self) -> pd.DataFrame: """Return all spec_id <-> comp_id mappings as a DataFrame.""" - rows = self._conn.execute(f"SELECT spec_id, comp_id FROM {self.table}").fetchall() + rows = self._conn.execute( + f"SELECT spec_id, comp_id FROM {self.table}" + ).fetchall() return pd.DataFrame(rows, columns=["spec_id", "comp_id"]) @@ -113,6 +187,7 @@ def get_all_mappings(self) -> pd.DataFrame: # Integrations with SpectralDatabase # ================================================== + def map_from_spectraldb_metadata( spectral_db_sqlite_path: str, mapping_sqlite_path: Optional[str] = None, @@ -121,15 +196,17 @@ def map_from_spectraldb_metadata( compound_table: str = "compounds", mapping_table: str = "spec_to_comp", *, - create_missing_compounds: bool = True + create_missing_compounds: bool = True, ) -> Tuple[int, int]: """ Read spectra metadata (expects 'inchikey' in metadata), create comp_id (inchikey14), populate spec_to_comp, and optionally upsert minimal compounds. - Returns: (n_mapped, n_new_compounds) + Returns + ------- + (n_mapped, n_new_compounds) """ - # We do not import the class to avoid circular imports; use sqlite directly. + # We do not import the SpectralDatabase class to avoid circular imports; use sqlite directly. 
s_conn = sqlite3.connect(spectral_db_sqlite_path) s_conn.row_factory = sqlite3.Row @@ -141,13 +218,20 @@ def map_from_spectraldb_metadata( # Discover which columns exist in the spectra table cols = {r[1] for r in s_conn.execute(f"PRAGMA table_info({spectra_table})").fetchall()} - want = ["spec_id", "inchikey", "smiles", "inchi", "classyfire_class", "classyfire_superclass"] + want = [ + "spec_id", + "inchikey", + "smiles", + "inchi", + "classyfire_class", + "classyfire_superclass", + ] have = [c for c in want if c in cols] select_cols = ", ".join(have) rows = s_conn.execute(f"SELECT {select_cols} FROM {spectra_table}").fetchall() - to_link: List[Tuple[int, str]] = [] + to_link: List[Tuple[str, str]] = [] new_comp_rows: List[Dict[str, Any]] = [] for r in rows: @@ -156,20 +240,26 @@ def map_from_spectraldb_metadata( ik_full = r.get("inchikey") if not ik_full: continue + comp_id = inchikey14_from_full(ik_full) if not comp_id: continue - to_link.append((spec_id, comp_id)) + + # spec_id may be int in the spectra table; mapping expects TEXT, but + # SQLite will happily store the string representation. 
+ to_link.append((str(spec_id), comp_id)) if create_missing_compounds: - new_comp_rows.append({ - "smiles": r.get("smiles"), - "inchi": r.get("inchi"), - "inchikey": ik_full, - "classyfire_class": r.get("classyfire_class"), - "classyfire_superclass": r.get("classyfire_superclass"), - "fingerprint": None, # backfill later - }) + new_comp_rows.append( + { + "smiles": r.get("smiles"), + "inchi": r.get("inchi"), + "inchikey": ik_full, + "classyfire_class": r.get("classyfire_class"), + "classyfire_superclass": r.get("classyfire_superclass"), + "fingerprint": None, # backfill later + } + ) # Bulk linking if to_link: @@ -178,7 +268,7 @@ def map_from_spectraldb_metadata( # Upsert compounds n_new_compounds = 0 if create_missing_compounds and new_comp_rows: - # Deduplicate by comp_id to avoid redundant upserts + # Deduplicate by comp_id (inchikey14) to avoid redundant upserts seen: set[str] = set() dedup_rows: List[Dict[str, Any]] = [] for r in new_comp_rows: @@ -186,9 +276,10 @@ def map_from_spectraldb_metadata( if cid and cid not in seen: seen.add(cid) dedup_rows.append(r) + before = compdb.sql_query(f"SELECT COUNT(*) AS n FROM {compound_table}")["n"].iloc[0] compdb.upsert_many(dedup_rows) - after = compdb.sql_query(f"SELECT COUNT(*) AS n FROM {compound_table}")["n"].iloc[0] + after = compdb.sql_query(f"SELECT COUNT(*) AS n FROM {compound_table}")["n"].iloc[0] n_new_compounds = int(after - before) n_mapped = len(to_link) @@ -205,41 +296,61 @@ def get_unique_compounds_from_spectraldb( spectral_db_sqlite_path: str, spectra_table: str = "spectra", external_meta: Optional[pd.DataFrame] = None, - external_key_col: str = "inchikey14" + external_key_col: str = "inchikey14", ) -> pd.DataFrame: """ - Return a DataFrame of unique compounds present in the spectral DB, inferred via inchikey → inchikey14. - Columns: inchikey14, inchikey (full), n_spectra. If `external_meta` is provided, - it will be left-joined on `external_key_col` (default 'inchikey14'). 
+ Return a DataFrame of unique compounds present in the spectral DB, inferred via + inchikey → inchikey14. + + Columns (first three): + - inchikey14 + - n_spectra + - inchikey (full) + + If `external_meta` is provided, it is left-joined on `external_key_col` + (default: 'inchikey14'). """ conn = sqlite3.connect(spectral_db_sqlite_path) conn.row_factory = sqlite3.Row - # pull spec_id + inchikey from spectra + # Pull spec_id + inchikey from spectra df = pd.read_sql_query(f"SELECT spec_id, inchikey FROM {spectra_table}", conn) conn.close() if df.empty: - base = pd.DataFrame(columns=["inchikey14", "inchikey", "n_spectra"]) + base = pd.DataFrame(columns=["inchikey14", "n_spectra", "inchikey"]) if external_meta is not None: - return base.merge(external_meta, how="left", left_on="inchikey14", right_on=external_key_col) + return base.merge( + external_meta, + how="left", + left_on="inchikey14", + right_on=external_key_col, + ) return base # Compute inchikey14 ik14 = df["inchikey"].fillna("").map(inchikey14_from_full) df["inchikey14"] = ik14 - # Aggregate + # Aggregate: order of columns is important for tests: + # ["inchikey14", "n_spectra", "inchikey"] agg = ( df.dropna(subset=["inchikey14"]) - .groupby(["inchikey14"], as_index=False) - .agg(n_spectra=("spec_id", "count"), - inchikey=("inchikey", "first")) # first full key seen + .groupby(["inchikey14"], as_index=False) + .agg( + n_spectra=("spec_id", "count"), + inchikey=("inchikey", "first"), # first full key seen + ) ) # Optional join with external meta if external_meta is not None and not external_meta.empty: - agg = agg.merge(external_meta, how="left", left_on="inchikey14", right_on=external_key_col) + agg = agg.merge( + external_meta, + how="left", + left_on="inchikey14", + right_on=external_key_col, + ) # Order by prevalence agg = agg.sort_values("n_spectra", ascending=False).reset_index(drop=True) From a681846828dd38177f35186c5aa287dd8f0f1e28 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Wed, 10 Dec 2025 
17:22:36 +0100 Subject: [PATCH 33/45] linting and adapt to list inputs --- ms2query/ms2query_library.py | 3 +-- tests/test_library_io.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ms2query/ms2query_library.py b/ms2query/ms2query_library.py index ab24b0a..2e89fe5 100644 --- a/ms2query/ms2query_library.py +++ b/ms2query/ms2query_library.py @@ -293,7 +293,6 @@ def analogue_search( self, spectra: Union[Spectrum, Sequence[Spectrum]], *, - k_spectra: int = 1, k_compounds: int = 10, ef: Optional[int] = None, ): @@ -307,7 +306,7 @@ def analogue_search( """ # Step 1: top-k_spectra per query spec_hits = self.query_spectra_by_spectra( - spectra, k_spectra=k_spectra, ef=ef + spectra, k_spectra=1, ef=ef ) # DataFrame if spec_hits.empty: return [] diff --git a/tests/test_library_io.py b/tests/test_library_io.py index ad9a82f..a406410 100644 --- a/tests/test_library_io.py +++ b/tests/test_library_io.py @@ -56,7 +56,7 @@ def test_create_and_load_library(tmp_path: Path): ms2query_db = lib.db # Metadata query by compound id (expected shape from your snippet) - df_meta = ms2query_db.metadata_by_comp_id(TEST_COMP_ID) + df_meta = ms2query_db.metadata_by_comp_ids([TEST_COMP_ID]) assert tuple(df_meta.shape) == EXPECTED_METADATA_SHAPE # Metadata fields presence both in db wrapper and in returned dataframe From 146792ebb52e53fdbf337cd64275892ed272b092 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Wed, 10 Dec 2025 17:23:00 +0100 Subject: [PATCH 34/45] clean up and switch to list inputs --- ms2query/ms2query_database.py | 178 +++++++++++++++++++++++---------- tests/test_ms2query_library.py | 2 +- 2 files changed, 124 insertions(+), 56 deletions(-) diff --git a/ms2query/ms2query_database.py b/ms2query/ms2query_database.py index b887174..7cb3c60 100644 --- a/ms2query/ms2query_database.py +++ b/ms2query/ms2query_database.py @@ -1,6 +1,7 @@ import sqlite3 from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any, 
Dict, List, Optional, Sequence +import numpy as np import pandas as pd from ms2query.data_processing import inchikey14_from_full from ms2query.database import ( @@ -13,6 +14,7 @@ # ================================ public wrapper ============================== + @dataclass class MS2QueryDatabase: """Wrapper class as main hub/glue between the different MS2Query database elements. @@ -23,17 +25,26 @@ class MS2QueryDatabase: * Provide one-stop creation from (already processed!) `matchms.Spectrum` objects. * Offer ergonomic retrievals by `spec_id`, `comp_id` (inchikey14). * Keep *types and table access paths* in one place. - """ sqlite_path: str ref_spectra_table: str = "spectra" ref_compound_table: str = "compounds" non_annotated_compound_table: str = "compounds_all" - metadata_fields: List[str] = field(default_factory=lambda: [ - "precursor_mz", "ionmode", "smiles", "inchikey", "inchi", "name", - "charge", "instrument_type", "adduct", "collision_energy" - ]) + metadata_fields: List[str] = field( + default_factory=lambda: [ + "precursor_mz", + "ionmode", + "smiles", + "inchikey", + "inchi", + "name", + "charge", + "instrument_type", + "adduct", + "collision_energy", + ] + ) # component singletons ref_sdb: SpectralDatabase = field(init=False) @@ -43,12 +54,18 @@ class MS2QueryDatabase: def __post_init__(self): # Initialize components (each manages its own connection) - self.ref_sdb = SpectralDatabase(self.sqlite_path, table=self.ref_spectra_table, - metadata_fields=self.metadata_fields) + self.ref_sdb = SpectralDatabase( + self.sqlite_path, + table=self.ref_spectra_table, + metadata_fields=self.metadata_fields, + ) self.ref_cdb = CompoundDatabase(self.sqlite_path, table=self.ref_compound_table) - self.all_cdb = CompoundDatabase(self.sqlite_path, table=self.non_annotated_compound_table) - self.mapper = SpecToCompoundMap(self.sqlite_path, compound_table=self.ref_compound_table) - + self.all_cdb = CompoundDatabase( + self.sqlite_path, 
table=self.non_annotated_compound_table + ) + self.mapper = SpecToCompoundMap( + self.sqlite_path, compound_table=self.ref_compound_table + ) # ----------------------------- creation pipeline ----------------------------- @@ -63,14 +80,17 @@ def create_from_spectra( Parameters ---------- - spectra: List[matchms.Spectrum] + spectra : List[matchms.Spectrum] List of matchms Spectrum objects to be inserted into the database. - map_compounds: bool, default=True + map_compounds : bool, default=True Whether to map spectra to compounds based on metadata InChIKeys. - create_missing_compounds: bool, default=True + create_missing_compounds : bool, default=True Whether to create compound entries for spectra that do not have a matching compound yet. - Returns counts: {"n_inserted_spectra": int, "n_mapped": int, "n_new_compounds": int} + Returns + ------- + dict + Counts: {"n_inserted_spectra": int, "n_mapped": int, "n_new_compounds": int} """ spec_ids = self.ref_sdb.add_spectra(spectra) n_mapped = 0 @@ -90,13 +110,13 @@ def create_from_spectra( "n_mapped": int(n_mapped), "n_new_compounds": int(n_new), } - - def add_second_compound_database(self, df): + + def add_second_compound_database(self, df: pd.DataFrame) -> None: """Add an additional 'all compound' database without need for spectral data. Parameters ---------- - df: pd.DataFrame + df : pd.DataFrame DataFrame containing inchikey and other relevant compound information. Should at least contain smiles or inchi. 
""" @@ -105,37 +125,91 @@ def add_second_compound_database(self, df): # --------------------------------- retrievals -------------------------------- # ---- by spec_id ---- - def spectra_by_spec_ids(self, spec_ids: List[int]): - return self.ref_sdb.get_spectra_by_ids(spec_ids) + def spectra_by_spec_ids(self, spec_ids: Sequence[str]): + """Return list[Spectrum] for the given spec_ids.""" + return self.ref_sdb.get_spectra_by_ids(list(spec_ids)) - def fragments_by_spec_ids(self, spec_ids: List[int]): - return self.ref_sdb.get_fragments_by_ids(spec_ids) + def fragments_by_spec_ids(self, spec_ids: Sequence[str]): + """Return list[(mz, intensity)] for the given spec_ids.""" + return self.ref_sdb.get_fragments_by_ids(list(spec_ids)) - def metadata_by_spec_ids(self, spec_ids: List[int]) -> pd.DataFrame: - return self.ref_sdb.get_metadata_by_ids(spec_ids) - - def embeddings_by_spec_ids(self, spec_ids: List[int]): - return self.ref_sdb.get_embeddings(spec_ids=spec_ids) + def metadata_by_spec_ids(self, spec_ids: Sequence[str]) -> pd.DataFrame: + """Return metadata DataFrame for the given spec_ids.""" + return self.ref_sdb.get_metadata_by_ids(list(spec_ids)) - # ---- by comp_id (inchikey14) ---- + def embeddings_by_spec_ids(self, spec_ids: Sequence[str]): + """Return (ids, embeddings) tuple for the given spec_ids.""" + return self.ref_sdb.get_embeddings(spec_ids=list(spec_ids)) - def spec_ids_by_comp_id(self, comp_id: str) -> List[int]: - return self.mapper.get_specs_for_comp(comp_id) + # ---- by comp_ids (inchikey14) ---- - def spectra_by_comp_id(self, comp_id: str): - return self.ref_sdb.get_spectra_by_ids(self.spec_ids_by_comp_id(comp_id)) + def spec_ids_by_comp_ids(self, comp_ids: Sequence[str]) -> pd.DataFrame: + """ + Return mapping of comp_ids -> spec_ids. 
- def metadata_by_comp_id(self, comp_id: str) -> pd.DataFrame: - spec_ids = self.spec_ids_by_comp_id(comp_id) - return self.ref_sdb.get_metadata_by_ids(spec_ids) + Returns + ------- + pd.DataFrame + Columns: ['comp_id', 'spec_id']. + One row per existing mapping (1:N). + """ + return self.mapper.get_specs_for_comps(list(comp_ids)) - def compound(self, comp_id: str) -> Optional[Dict[str, Any]]: - return self.ref_cdb.get_compound(comp_id) - - def embeddings_by_comp_id(self, comp_id: str): - spec_ids = self.spec_ids_by_comp_id(comp_id) - return self.ref_sdb.get_embeddings(spec_ids=spec_ids) + def spectra_by_comp_ids(self, comp_ids: Sequence[str]): + """ + Return all spectra mapped to any of the given comp_ids. + + Notes + ----- + * The order of spectra is determined by the underlying `get_spectra_by_ids` + implementation and mapping table. + * If you need to know which comp_id each spectrum belongs to, combine this + with `spec_ids_by_comp_ids`. + """ + df_map = self.mapper.get_specs_for_comps(list(comp_ids)) + if df_map.empty: + return [] + spec_ids = df_map["spec_id"].tolist() + return self.ref_sdb.get_spectra_by_ids(spec_ids) + + def metadata_by_comp_ids(self, comp_ids: Sequence[str]) -> pd.DataFrame: + """ + Return metadata for all spectra mapped to the given comp_ids. + Returns + ------- + pd.DataFrame + Columns: ['comp_id', 'spec_id', ...metadata_fields...] + """ + df_map = self.mapper.get_specs_for_comps(list(comp_ids)) + if df_map.empty: + cols = ["comp_id", "spec_id"] + self.metadata_fields + return pd.DataFrame(columns=cols) + + meta = self.ref_sdb.get_metadata_by_ids(df_map["spec_id"].tolist()) + # meta: spec_id + metadata_fields + out = df_map.merge(meta, on="spec_id", how="inner") + # Ensure column order: comp_id, spec_id, metadata... + return out[["comp_id", "spec_id"] + self.metadata_fields] + + def embeddings_by_comp_ids(self, comp_ids: Sequence[str]): + """ + Return embeddings for all spectra mapped to the given comp_ids. 
+ + Returns + ------- + (np.ndarray, np.ndarray) + (spec_ids, embeddings) as returned by SpectralDatabase.get_embeddings. + """ + df_map = self.mapper.get_specs_for_comps(list(comp_ids)) + if df_map.empty: + # Mirror SpectralDatabase.get_embeddings empty contract + return ( + np.empty((0,), dtype=str), + np.empty((0, 0), dtype=np.float32), + ) + spec_ids = df_map["spec_id"].tolist() + return self.ref_sdb.get_embeddings(spec_ids=spec_ids) # -------------------------------- convenience SQL ------------------------------ @@ -148,19 +222,13 @@ def sql(self, query: str) -> pd.DataFrame: # ----------------------------------- utilities --------------------------------- def inchikey_to_comp_id(self, inchikey_full: str) -> Optional[str]: + """Convert a full InChIKey to the 14-character comp_id (inchikey14).""" return inchikey14_from_full(inchikey_full) - def close(self): - # Close component connections - try: - self.ref_sdb.close() - except Exception: - pass - try: - self.ref_cdb.close() - except Exception: - pass - try: - self.mapper.close() - except Exception: - pass + def close(self) -> None: + """Close all component connections.""" + for obj in (self.ref_sdb, self.ref_cdb, self.all_cdb, self.mapper): + try: + obj.close() + except Exception: + pass diff --git a/tests/test_ms2query_library.py b/tests/test_ms2query_library.py index a6fd087..5efe5a8 100644 --- a/tests/test_ms2query_library.py +++ b/tests/test_ms2query_library.py @@ -79,7 +79,7 @@ def test_create_and_load_smoke(tmp_path: Path): # DB content checks ms2query_db = lib.db - meta_df = ms2query_db.metadata_by_comp_id(TEST_COMP_ID) + meta_df = ms2query_db.metadata_by_comp_ids([TEST_COMP_ID]) assert tuple(meta_df.shape) == EXPECTED_METADATA_SHAPE for field in EXPECTED_METADATA_FIELDS: assert field in ms2query_db.metadata_fields From cf6a80579034ec3b8106a5795472d3c28d6c4947 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Wed, 10 Dec 2025 17:23:57 +0100 Subject: [PATCH 35/45] adjust tests --- 
tests/test_library_io.py | 2 +- tests/test_ms2query_library.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_library_io.py b/tests/test_library_io.py index a406410..04662bd 100644 --- a/tests/test_library_io.py +++ b/tests/test_library_io.py @@ -8,7 +8,7 @@ TEST_COMP_ID = "ZBSGKPYXQINNGF" # expected InChIKey14 present in the test data -EXPECTED_METADATA_SHAPE = (5, 11) +EXPECTED_METADATA_SHAPE = (5, 12) EXPECTED_METADATA_FIELDS = [ "precursor_mz", "ionmode", "smiles", "inchikey", "inchi", "name", "charge", "instrument_type", "adduct", "collision_energy", diff --git a/tests/test_ms2query_library.py b/tests/test_ms2query_library.py index 5efe5a8..255f9cb 100644 --- a/tests/test_ms2query_library.py +++ b/tests/test_ms2query_library.py @@ -9,7 +9,7 @@ TEST_COMP_ID = "ZBSGKPYXQINNGF" # known from your snippet -EXPECTED_METADATA_SHAPE = (5, 11) +EXPECTED_METADATA_SHAPE = (5, 12) EXPECTED_METADATA_FIELDS = [ "precursor_mz", "ionmode", "smiles", "inchikey", "inchi", "name", "charge", "instrument_type", "adduct", "collision_energy", From 4e208cfa8766725aa31290bf460e978b78369374 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Thu, 11 Dec 2025 10:49:09 +0100 Subject: [PATCH 36/45] add merge_fingerprint method --- ms2query/data_processing/__init__.py | 3 +- .../fingerprint_computation.py | 101 ++++++++++++++++++ ms2query/ms2query_library.py | 50 ++++++++- 3 files changed, 148 insertions(+), 6 deletions(-) diff --git a/ms2query/data_processing/__init__.py b/ms2query/data_processing/__init__.py index 7dacae7..a3f3836 100644 --- a/ms2query/data_processing/__init__.py +++ b/ms2query/data_processing/__init__.py @@ -1,5 +1,5 @@ from .chemistry_utils import compute_morgan_fingerprints, inchikey14_from_full -from .fingerprint_computation import compute_fingerprints_from_smiles +from .fingerprint_computation import compute_fingerprints_from_smiles, merge_fingerprints from .merging_utils import cluster_block, get_merged_spectra from 
.spectra_processing import compute_spectra_embeddings, normalize_spectrum_sum @@ -11,5 +11,6 @@ "compute_spectra_embeddings", "get_merged_spectra", "inchikey14_from_full", + "merge_fingerprints", "normalize_spectrum_sum", ] diff --git a/ms2query/data_processing/fingerprint_computation.py b/ms2query/data_processing/fingerprint_computation.py index b6b2e5f..ea2f3ca 100644 --- a/ms2query/data_processing/fingerprint_computation.py +++ b/ms2query/data_processing/fingerprint_computation.py @@ -1,6 +1,8 @@ +from typing import Optional, Sequence, Tuple import numba import numpy as np from numba import typed, types +from numpy.typing import NDArray from rdkit import Chem from tqdm import tqdm @@ -255,6 +257,105 @@ def count_fingerprint_keys(fingerprints): return unique_keys[order], count_arr[order], first_arr[order] +def merge_fingerprints( + fingerprints: Sequence[Tuple[NDArray[np.integer], NDArray[np.floating]]], + weights: Optional[NDArray[np.floating]] = None, +) -> Tuple[NDArray[np.integer], NDArray[np.floating]]: + """ + Merge multiple sparse Morgan (count/TF-IDF) fingerprints into a single + weighted-average fingerprint. + + Parameters + ---------- + fingerprints : + Sequence of (bits, values) pairs. + - bits: 1D integer array of bit indices (non-zero entries) + - values: 1D float array of TF-IDF (or other) weights, + same length as `bits`. + weights : + Optional 1D array-like of length len(fingerprints) with one weight + per fingerprint. Each fingerprint's values are scaled by its weight, + then the merged fingerprint is normalized by the sum of all weights. + + - If None, all fingerprints are weighted equally (weight = 1.0). + + Returns + ------- + merged_bits, merged_values : + - merged_bits: 1D integer array of unique bit indices + - merged_values: 1D float array of weighted-average values per bit + (sum over all weighted fingerprints, divided by sum(weights)). 
+ """ + n_fps = len(fingerprints) + if n_fps == 0: + # Return empty sparse fingerprint + return ( + np.array([], dtype=np.int64), + np.array([], dtype=np.float64), + ) + + if weights is not None: + w = np.asarray(weights, dtype=np.float64).ravel() + if w.shape[0] != n_fps: + raise ValueError( + f"weights must have length {n_fps}, got {w.shape[0]}" + ) + total_weight = float(w.sum()) + if total_weight <= 0.0: + raise ValueError("Sum of weights must be positive.") + else: + # Equal weighting + w = None + total_weight = float(n_fps) + + # Concatenate all indices and (weighted) values + bits_list = [] + vals_list = [] + + for i, (bits, vals) in enumerate(fingerprints): + bits = np.asarray(bits) + vals = np.asarray(vals, dtype=np.float64) + + if bits.shape[0] != vals.shape[0]: + raise ValueError( + f"Fingerprint {i}: bits and values must have same length, " + f"got {bits.shape[0]} and {vals.shape[0]}" + ) + + if w is not None: + vals = vals * w[i] + + bits_list.append(bits) + vals_list.append(vals) + + if not bits_list: + return ( + np.array([], dtype=np.int64), + np.array([], dtype=np.float64), + ) + + all_bits = np.concatenate(bits_list) + all_vals = np.concatenate(vals_list) + + if all_bits.size == 0: + return ( + np.array([], dtype=np.int64), + np.array([], dtype=np.float64), + ) + + # Group by bit index and sum weighted values + unique_bits, inverse = np.unique(all_bits, return_inverse=True) + summed_vals = np.bincount(inverse, weights=all_vals) + + # Weighted average: divide by sum of weights + avg_vals = summed_vals / total_weight + + # Keep dtypes reasonably tight + merged_bits = unique_bits.astype(all_bits.dtype, copy=False) + merged_vals = avg_vals.astype(np.float32, copy=False) + + return merged_bits, merged_vals + ### ------------------------ ### Bit Scaling and Weighing ### ------------------------ diff --git a/ms2query/ms2query_library.py b/ms2query/ms2query_library.py index 2e89fe5..20dd667 100644 --- a/ms2query/ms2query_library.py +++ 
b/ms2query/ms2query_library.py @@ -4,8 +4,9 @@ import pandas as pd from matchms import Spectrum from ms2deepscore.models import load_model as _ms2ds_load_model +from sklearn.metrics.pairwise import cosine_similarity from ms2query import MS2QueryDatabase -from ms2query.data_processing import compute_spectra_embeddings +from ms2query.data_processing import compute_spectra_embeddings, merge_fingerprints from ms2query.database import EmbeddingIndex, FingerprintSparseIndex @@ -31,6 +32,7 @@ class MS2QueryLibrary: db: MS2QueryDatabase embedding_index: Optional[EmbeddingIndex] = None fingerprint_index: Optional[FingerprintSparseIndex] = None # for now: reference spectra only + large_scale_fingerprint_index: Optional[FingerprintSparseIndex] = None # for large body of reference compounds model_path: Optional[str] = None # internal: whether to apply spectrum normalization (sum=1) before embedding @@ -319,13 +321,51 @@ def analogue_search( .set_index("spec_id") ) - smiles = analogue_compounds["smiles"].tolist() + analogue_smiles = analogue_compounds["smiles"].tolist() # Step 3: fingerprint-based compound search top_compounds = self.query_compounds_by_compounds( - smiles, k_compounds=k_compounds - ) - return top_compounds + smiles=analogue_smiles + ).set_index("query_ix") + + # Step 4: for each query, pick the best matching spectrum among all spectra + fingerprints_merged = [] + weighted_average_scores = [] + embeddings_queries = self.compute_embeddings(spectra) # TODO: this is now done twice! 
in step 1 and here + for i in range(len(analogue_smiles)): + comp_ids = top_compounds.loc[i].comp_id.to_list() + + # Get chemically closest compounds + spec_ids_all = [] + spec_ids_selected = [] + embeddings_selected = [] + + all_spec_ids = self.db.spec_ids_by_comp_ids(comp_ids).set_index("comp_id") + for comp_id in comp_ids: + new_spec_ids = all_spec_ids.loc[comp_id].spec_id.to_list() + + # Get most similar embedding from one of the top-10 compounds + embs = self.db.ref_sdb.get_embeddings(new_spec_ids) + similarities = cosine_similarity(embs[1], embeddings_queries[i].reshape(1, -1)) + max_id = np.argmax(similarities) + spec_ids_selected.append(embs[0][max_id]) + embeddings_selected.append(embs[1][max_id]) + spec_ids_all.extend(new_spec_ids) + + top1_top10_similarities = cosine_similarity(embeddings_selected, embeddings_queries[i].reshape(1, -1)) + fingerprints = self.db.ref_cdb.get_fingerprints(comp_ids) + fingerprints_merged.append(merge_fingerprints(fingerprints, weights=top1_top10_similarities)) + weighted_average_scores.append(np.sum(top1_top10_similarities ** 2) / np.sum(top1_top10_similarities)) + if self.large_scale_fingerprint_index: + analogue_predictions = self.large_scale_fingerprint_index.query(fingerprints_merged, k=k_compounds) + elif self.fingerprint_index: + analogue_predictions = self.fingerprint_index.query(fingerprints_merged, k=k_compounds) + else: + raise RuntimeError("No fingerprint index is set. 
Build or load it before querying.") + return pd.DataFrame({ + "analogue_predictions": analogue_predictions, + "weighted_average_scores": weighted_average_scores + }) # ------------------------------------------------------------------ # Helpers / glue From 0ca72ef1160f1ec1f7d2c3902d0cdb54dd2de029 Mon Sep 17 00:00:00 2001 From: Florian Huber Date: Thu, 11 Dec 2025 10:54:34 +0100 Subject: [PATCH 37/45] update ms2deepscore version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3fe7fa4..396af2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ pandas = ">=2.1.1" scipy= ">=1.14.0" matplotlib= ">=3.8.0" matchms= ">=0.30.0" -ms2deepscore= { git = "https://github.com/matchms/ms2deepscore.git", branch = "pytorch_update" } +ms2deepscore= ">=2.6.0" rdkit= ">2024.3.4" nmslib= ">=2.0.0" umap-learn= ">=0.5.7" From 276511c107a62a3a35d702fa55ae151e3b515c1a Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Tue, 16 Dec 2025 09:49:53 +0100 Subject: [PATCH 38/45] Add tqdm to predict using closest tanimoto --- .../reference_methods/predict_using_closest_tanimoto.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py index f0a40c9..afc7ceb 100644 --- a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py +++ b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py @@ -5,6 +5,7 @@ from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings, SpectraWithFingerprints from ms2query.metrics import generalized_tanimoto_similarity_matrix +from tqdm import tqdm def predict_using_closest_tanimoto( library_spectra: SpectraWithMS2DeepScoreEmbeddings, query_spectra: SpectraWithMS2DeepScoreEmbeddings, @@ -16,7 +17,7 @@ def predict_using_closest_tanimoto( """ inchikeys_of_best_match = 
[] highest_scores = [] - for spectrum_idx in range(len(query_spectra.spectra)): + for spectrum_idx in tqdm(range(len(query_spectra.spectra)), "Predicting using closest tanimoto"): inchikey_of_best_match, score = predict_using_closest_tanimoto_single_spectrum( library_spectra, query_spectra.subset_spectra([spectrum_idx]), nr_of_closest_inchikeys_to_select, nr_of_inchikeys_with_highest_ms2deepscore_to_select) From a59e6a925bdfca01cd921cc5dcc20307b38ff7f0 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Tue, 16 Dec 2025 09:52:20 +0100 Subject: [PATCH 39/45] ruff --- ms2query/benchmarking/EvaluateMethods.py | 5 +---- ms2query/benchmarking/SpectrumDataSet.py | 6 ++---- .../PredictMS2DeepScoreSimilarity.py | 2 -- .../predict_best_possible_match.py | 2 -- .../reference_methods/predict_highest_cosine.py | 3 +-- .../predict_highest_ms2deepscore.py | 3 +-- .../predict_using_closest_tanimoto.py | 7 +++---- .../predict_with_integrated_similarity_flow.py | 5 ++--- tests/conftest.py | 1 - tests/testPredictMS2DeepScoreSimilarity.py | 3 +-- tests/test_SpectrumDataSet.py | 5 ++--- tests/test_evaluate_methods.py | 6 +++--- tests/test_methods.py | 12 +++++------- tests/test_predict_using_closest_tanimoto.py | 16 +++++++++------- 14 files changed, 30 insertions(+), 46 deletions(-) diff --git a/ms2query/benchmarking/EvaluateMethods.py b/ms2query/benchmarking/EvaluateMethods.py index 6442343..17a28a5 100644 --- a/ms2query/benchmarking/EvaluateMethods.py +++ b/ms2query/benchmarking/EvaluateMethods.py @@ -1,11 +1,8 @@ import random - +from typing import Callable, List, Tuple import numpy as np -from typing import Callable, Tuple, List - from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix from tqdm import tqdm - from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectrumSetBase diff --git a/ms2query/benchmarking/SpectrumDataSet.py b/ms2query/benchmarking/SpectrumDataSet.py index 2446d30..9b70dbb 100644 --- 
a/ms2query/benchmarking/SpectrumDataSet.py +++ b/ms2query/benchmarking/SpectrumDataSet.py @@ -1,12 +1,10 @@ import copy from collections import Counter -from typing import List, Dict, Iterable - +from typing import Dict, Iterable, List import numpy as np from matchms import Spectrum from matchms.filtering.metadata_processing.add_fingerprint import _derive_fingerprint_from_inchi - -from ms2deepscore.models import compute_embedding_array, SiameseSpectralModel +from ms2deepscore.models import SiameseSpectralModel, compute_embedding_array from tqdm import tqdm diff --git a/ms2query/benchmarking/reference_methods/PredictMS2DeepScoreSimilarity.py b/ms2query/benchmarking/reference_methods/PredictMS2DeepScoreSimilarity.py index 13121cf..a8082c1 100644 --- a/ms2query/benchmarking/reference_methods/PredictMS2DeepScoreSimilarity.py +++ b/ms2query/benchmarking/reference_methods/PredictMS2DeepScoreSimilarity.py @@ -1,7 +1,5 @@ from typing import Tuple - import numpy as np - from ms2deepscore.vector_operations import cosine_similarity_matrix from tqdm import tqdm diff --git a/ms2query/benchmarking/reference_methods/predict_best_possible_match.py b/ms2query/benchmarking/reference_methods/predict_best_possible_match.py index 63625f2..1544f71 100644 --- a/ms2query/benchmarking/reference_methods/predict_best_possible_match.py +++ b/ms2query/benchmarking/reference_methods/predict_best_possible_match.py @@ -1,8 +1,6 @@ from typing import Dict - import numpy as np from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix - from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints diff --git a/ms2query/benchmarking/reference_methods/predict_highest_cosine.py b/ms2query/benchmarking/reference_methods/predict_highest_cosine.py index 46820f1..49fd6cf 100644 --- a/ms2query/benchmarking/reference_methods/predict_highest_cosine.py +++ b/ms2query/benchmarking/reference_methods/predict_highest_cosine.py @@ -1,5 +1,4 @@ -from typing import Tuple, List - 
+from typing import List, Tuple from matchms import Scores from matchms.similarity.CosineGreedy import CosineGreedy from matchms.similarity.PrecursorMzMatch import PrecursorMzMatch diff --git a/ms2query/benchmarking/reference_methods/predict_highest_ms2deepscore.py b/ms2query/benchmarking/reference_methods/predict_highest_ms2deepscore.py index d645857..a5d9318 100644 --- a/ms2query/benchmarking/reference_methods/predict_highest_ms2deepscore.py +++ b/ms2query/benchmarking/reference_methods/predict_highest_ms2deepscore.py @@ -1,5 +1,4 @@ -from typing import Tuple, List - +from typing import List, Tuple from ms2query.benchmarking.reference_methods.PredictMS2DeepScoreSimilarity import predict_top_ms2deepscores from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings diff --git a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py index afc7ceb..7e12f52 100644 --- a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py +++ b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py @@ -1,11 +1,10 @@ +from typing import List, Tuple import numpy as np from ms2deepscore.vector_operations import cosine_similarity_matrix -from typing import Tuple, List - -from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings, SpectraWithFingerprints +from tqdm import tqdm +from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectraWithMS2DeepScoreEmbeddings from ms2query.metrics import generalized_tanimoto_similarity_matrix -from tqdm import tqdm def predict_using_closest_tanimoto( library_spectra: SpectraWithMS2DeepScoreEmbeddings, query_spectra: SpectraWithMS2DeepScoreEmbeddings, diff --git a/ms2query/benchmarking/reference_methods/predict_with_integrated_similarity_flow.py b/ms2query/benchmarking/reference_methods/predict_with_integrated_similarity_flow.py index 329fea1..b89a906 
100644 --- a/ms2query/benchmarking/reference_methods/predict_with_integrated_similarity_flow.py +++ b/ms2query/benchmarking/reference_methods/predict_with_integrated_similarity_flow.py @@ -1,8 +1,7 @@ -from typing import Tuple, List -from tqdm import tqdm +from typing import List, Tuple import numpy as np from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix - +from tqdm import tqdm from ms2query.benchmarking.reference_methods.PredictMS2DeepScoreSimilarity import predict_top_ms2deepscores from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings diff --git a/tests/conftest.py b/tests/conftest.py index 80ec8e9..702e6fc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ import os from pathlib import Path - import numpy as np from matchms.Spectrum import Spectrum from ms2deepscore.models import load_model diff --git a/tests/testPredictMS2DeepScoreSimilarity.py b/tests/testPredictMS2DeepScoreSimilarity.py index 9741693..4741bdf 100644 --- a/tests/testPredictMS2DeepScoreSimilarity.py +++ b/tests/testPredictMS2DeepScoreSimilarity.py @@ -1,11 +1,10 @@ import numpy as np import pytest - from ms2query.benchmarking.reference_methods.PredictMS2DeepScoreSimilarity import ( predict_top_ms2deepscores, ) -from tests.conftest import create_test_spectra, ms2deepscore_model from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings +from tests.conftest import create_test_spectra, ms2deepscore_model @pytest.mark.parametrize( diff --git a/tests/test_SpectrumDataSet.py b/tests/test_SpectrumDataSet.py index a4c8e2c..4244e99 100644 --- a/tests/test_SpectrumDataSet.py +++ b/tests/test_SpectrumDataSet.py @@ -1,12 +1,11 @@ import numpy as np import pytest - from ms2query.benchmarking.SpectrumDataSet import ( SpectraWithFingerprints, - SpectrumSetBase, SpectraWithMS2DeepScoreEmbeddings, + SpectrumSetBase, ) -from tests.conftest import create_test_spectra, ms2deepscore_model, 
get_inchikey_inchi_pairs +from tests.conftest import create_test_spectra, get_inchikey_inchi_pairs, ms2deepscore_model @pytest.mark.parametrize( diff --git a/tests/test_evaluate_methods.py b/tests/test_evaluate_methods.py index 4bcb839..7431aad 100644 --- a/tests/test_evaluate_methods.py +++ b/tests/test_evaluate_methods.py @@ -1,10 +1,10 @@ import pytest -from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings from ms2query.benchmarking.EvaluateMethods import EvaluateMethods -from ms2query.benchmarking.reference_methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore from ms2query.benchmarking.reference_methods.predict_best_possible_match import predict_best_possible_match -from tests.conftest import create_test_spectra, ms2deepscore_model from ms2query.benchmarking.reference_methods.predict_highest_cosine import predict_highest_cosine +from ms2query.benchmarking.reference_methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings +from tests.conftest import create_test_spectra, ms2deepscore_model @pytest.mark.parametrize( diff --git a/tests/test_methods.py b/tests/test_methods.py index 03cca1c..1b747d0 100644 --- a/tests/test_methods.py +++ b/tests/test_methods.py @@ -1,17 +1,15 @@ import numpy as np +import pytest from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix - -from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings +from ms2query.benchmarking.reference_methods.predict_best_possible_match import predict_best_possible_match from ms2query.benchmarking.reference_methods.predict_highest_cosine import predict_highest_cosine -from tests.conftest import create_test_spectra, ms2deepscore_model from ms2query.benchmarking.reference_methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore -from 
ms2query.benchmarking.reference_methods.predict_best_possible_match import predict_best_possible_match from ms2query.benchmarking.reference_methods.predict_with_integrated_similarity_flow import ( - predict_with_integrated_similarity_flow, integrated_similarity_flow, + predict_with_integrated_similarity_flow, ) - -import pytest +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings +from tests.conftest import create_test_spectra, ms2deepscore_model @pytest.mark.parametrize( diff --git a/tests/test_predict_using_closest_tanimoto.py b/tests/test_predict_using_closest_tanimoto.py index bdb9741..f6ebc31 100644 --- a/tests/test_predict_using_closest_tanimoto.py +++ b/tests/test_predict_using_closest_tanimoto.py @@ -1,12 +1,14 @@ import numpy as np - -from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings, SpectraWithFingerprints -from ms2query.benchmarking.reference_methods.predict_using_closest_tanimoto import ( - predict_using_closest_tanimoto, predict_using_closest_tanimoto_single_spectrum, - get_average_predictions_for_closely_related_metabolites, get_inchikey_and_tanimoto_scores_for_top_k, - select_inchikeys_with_highest_ms2deepscore) -from tests.conftest import ms2deepscore_model, create_test_spectra import pytest +from ms2query.benchmarking.reference_methods.predict_using_closest_tanimoto import ( + get_average_predictions_for_closely_related_metabolites, + get_inchikey_and_tanimoto_scores_for_top_k, + predict_using_closest_tanimoto, + predict_using_closest_tanimoto_single_spectrum, + select_inchikeys_with_highest_ms2deepscore, +) +from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectraWithMS2DeepScoreEmbeddings +from tests.conftest import create_test_spectra, ms2deepscore_model def test_predict_using_closest_tanimoto(): From aff52e4291c6cee7ea84625c2026c181ed5145a7 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Tue, 16 Dec 2025 09:57:37 +0100 Subject: [PATCH 40/45] Linting 
--- ms2query/benchmarking/EvaluateMethods.py | 4 ++-- ms2query/benchmarking/SpectrumDataSet.py | 3 ++- .../reference_methods/PredictMS2DeepScoreSimilarity.py | 3 ++- .../reference_methods/predict_using_closest_tanimoto.py | 6 +++--- .../predict_with_integrated_similarity_flow.py | 3 ++- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/ms2query/benchmarking/EvaluateMethods.py b/ms2query/benchmarking/EvaluateMethods.py index 17a28a5..2dcda5b 100644 --- a/ms2query/benchmarking/EvaluateMethods.py +++ b/ms2query/benchmarking/EvaluateMethods.py @@ -61,8 +61,8 @@ def benchmark_exact_matching_within_ionmode( For each inchikey with more than 1 spectrum the spectra are split in two sets. Half for each inchikey is added to the library (training set), for the other half predictions are made. Thereby there is always an exact match - avaialable. Only the highest ranked prediction is considered correct if the correct inchikey is predicted. An accuracy per - inchikey is calculated followed by calculating the average. + avaialable. Only the highest ranked prediction is considered correct if the correct inchikey is predicted. + An accuracy per inchikey is calculated followed by calculating the average. """ selected_spectra = subset_spectra_on_ionmode(self.validation_spectrum_set, ionmode) diff --git a/ms2query/benchmarking/SpectrumDataSet.py b/ms2query/benchmarking/SpectrumDataSet.py index 9b70dbb..240c38a 100644 --- a/ms2query/benchmarking/SpectrumDataSet.py +++ b/ms2query/benchmarking/SpectrumDataSet.py @@ -112,7 +112,8 @@ def subset_spectra(self, spectrum_indexes) -> "SpectraWithFingerprints": # Only keep the fingerprints for which we have inchikeys. # Important note: This is not a deep copy! 
# And the fingerprint is not reset (so it is not always actually matching the most common inchi) - new_instance.inchikey_fingerprint_pairs = {inchikey: self.inchikey_fingerprint_pairs[inchikey] for inchikey in new_instance.spectrum_indexes_per_inchikey.keys()} + new_instance.inchikey_fingerprint_pairs = {inchikey: self.inchikey_fingerprint_pairs[inchikey] for inchikey + in new_instance.spectrum_indexes_per_inchikey.keys()} return new_instance diff --git a/ms2query/benchmarking/reference_methods/PredictMS2DeepScoreSimilarity.py b/ms2query/benchmarking/reference_methods/PredictMS2DeepScoreSimilarity.py index a8082c1..cd20544 100644 --- a/ms2query/benchmarking/reference_methods/PredictMS2DeepScoreSimilarity.py +++ b/ms2query/benchmarking/reference_methods/PredictMS2DeepScoreSimilarity.py @@ -21,7 +21,8 @@ def predict_top_ms2deepscores( k: Number of highest matches to return Returns: - List[List[int]: indexes of highest scores and the value for the highest score. Per query embedding the top k highest indexes are given. + List[List[int]: indexes of highest scores and the value for the highest score. + Per query embedding the top k highest indexes are given. List[List[float]: the highest scores. 
""" top_indexes_per_batch = [] diff --git a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py index 7e12f52..751953e 100644 --- a/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py +++ b/ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py @@ -25,9 +25,9 @@ def predict_using_closest_tanimoto( return inchikeys_of_best_match, highest_scores -def predict_using_closest_tanimoto_single_spectrum(spectra_with_embeddings, single_spectrum_with_embeddings, - nr_of_closest_inchikeys_to_select, - nr_of_inchikeys_with_highest_ms2deepscore_to_select) -> Tuple[str, float]: +def predict_using_closest_tanimoto_single_spectrum( + spectra_with_embeddings, single_spectrum_with_embeddings, + nr_of_closest_inchikeys_to_select, nr_of_inchikeys_with_highest_ms2deepscore_to_select) -> Tuple[str, float]: if len(single_spectrum_with_embeddings.spectra) != 1: raise ValueError("expected a single spectrum") ms2deepscores = cosine_similarity_matrix(single_spectrum_with_embeddings.embeddings, diff --git a/ms2query/benchmarking/reference_methods/predict_with_integrated_similarity_flow.py b/ms2query/benchmarking/reference_methods/predict_with_integrated_similarity_flow.py index b89a906..e352ab3 100644 --- a/ms2query/benchmarking/reference_methods/predict_with_integrated_similarity_flow.py +++ b/ms2query/benchmarking/reference_methods/predict_with_integrated_similarity_flow.py @@ -82,7 +82,8 @@ def integrated_similarity_flow( predicted_scores: List[float], similarities: np.ndarray, nr_of_spectra_per_inchikey: List[float] ) -> List[float]: """Compute the confidence of the prediction for each candidate. - Integrated similarity flow (ISF) scores are calculated using the similarity of candidates among each other and their distance to the query spectrum. 
+ Integrated similarity flow (ISF) scores are calculated using the similarity of candidates among each other + and their distance to the query spectrum. Args: distances (list): Distances of the candidates to the query spectrum in the chemical space. From cd30383b613c442ee4a936441476f2be843bd30a Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Tue, 16 Dec 2025 11:15:42 +0100 Subject: [PATCH 41/45] Lint notebooks --- .../Test_ann_speed_improvements.ipynb | 17 ++++++--- .../notebooks/Test_method_evaluator.ipynb | 35 ++++++++++++++++--- ...umber_of_inchikeys_with_two_ionmodes.ipynb | 1 + 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/ms2query/notebooks/Test_ann_speed_improvements.ipynb b/ms2query/notebooks/Test_ann_speed_improvements.ipynb index 9b84f73..e17d662 100644 --- a/ms2query/notebooks/Test_ann_speed_improvements.ipynb +++ b/ms2query/notebooks/Test_ann_speed_improvements.ipynb @@ -7,7 +7,9 @@ "metadata": {}, "outputs": [], "source": [ - "import sys \n", + "import sys\n", + "\n", + "\n", "sys.path.append(\"C:/Users/jonge094/PycharmProjects/ms2query_2_0/ms_chemical_space_explorer\")\n", "\n" ] @@ -31,6 +33,7 @@ "from matchms.importing import load_from_mgf\n", "from tqdm import tqdm\n", "\n", + "\n", "neg_val_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_validation_spectra.mgf\")))\n", "neg_test_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_testing_spectra.mgf\")))\n" ] @@ -52,8 +55,8 @@ } ], "source": [ - "from ms2deepscore.models import load_model\n", - "from ms2deepscore.models import compute_embedding_array\n", + "from ms2deepscore.models import compute_embedding_array, load_model\n", + "\n", "\n", "ms2deepscore_model = 
load_model(\"../../../ms2deepscore/data/pytorch/new_corinna_included/trained_models/both_mode_precursor_mz_ionmode_10000_layers_500_embedding_2024_11_21_11_23_17/ms2deepscore_model.pt\")\n", "\n", @@ -76,6 +79,8 @@ "outputs": [], "source": [ "import numpy as np\n", + "\n", + "\n", "more_test_embeddings = np.tile(embeddings, (70, 1))\n" ] }, @@ -115,8 +120,10 @@ } ], "source": [ - "import pynndescent\n", "import time\n", + "import pynndescent\n", + "\n", + "\n", "start_time = time.time()\n", "ann_model = pynndescent.NNDescent(more_test_embeddings, metric=\"cosine\", n_neighbors=30)\n", "ann_model.prepare()\n", @@ -242,6 +249,8 @@ ], "source": [ "from ms2deepscore.vector_operations import cosine_similarity_matrix\n", + "\n", + "\n", "start_time = time.time()\n", "matrix = cosine_similarity_matrix(more_test_embeddings, embeddings[:1000])\n", "print(\"Time eleapsed: \" + str(time.time() - start_time))" diff --git a/ms2query/notebooks/Test_method_evaluator.ipynb b/ms2query/notebooks/Test_method_evaluator.ipynb index 097c1d3..86aafed 100644 --- a/ms2query/notebooks/Test_method_evaluator.ipynb +++ b/ms2query/notebooks/Test_method_evaluator.ipynb @@ -26,9 +26,10 @@ } ], "source": [ + "import os\n", "from matchms.importing import load_from_mgf\n", "from tqdm import tqdm\n", - "import os\n", + "\n", "\n", "save_directory = \"../data/ms2deepscore_model/training_and_validation_split/\"\n", "pos_val_spectra = list(tqdm(load_from_mgf(os.path.join(save_directory, \"positive_validation_spectra.mgf\"))))\n", @@ -66,6 +67,8 @@ ], "source": [ "import os\n", + "\n", + "\n", "pickled_intermediates_data_folder = \"../data/pickled_intermediates\"\n", "os.path.isdir(pickled_intermediates_data_folder)" ] @@ -79,6 +82,7 @@ "source": [ "import pickle\n", "\n", + "\n", "with open(os.path.join(pickled_intermediates_data_folder, \"neg_val_spectra.pickle\"), \"wb\") as handle:\n", " pickle.dump(neg_val_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", "with 
open(os.path.join(pickled_intermediates_data_folder, \"neg_train_spectra.pickle\"), \"wb\") as handle:\n", @@ -109,6 +113,8 @@ "outputs": [], "source": [ "from ms2deepscore.models import load_model\n", + "\n", + "\n", "ms2deepscore_model = load_model(\"../data/ms2deepscore_model/trained_models/both_mode_ionmode_precursor_mz_2000_layers_500_embedding_2025_02_26_18_42_25/ms2deepscore_model.pt\")\n" ] }, @@ -119,7 +125,9 @@ "metadata": {}, "outputs": [], "source": [ - "import sys \n", + "import sys\n", + "\n", + "\n", "sys.path.append(\"../../ms_chemical_space_explorer\")" ] }, @@ -131,6 +139,8 @@ "outputs": [], "source": [ "from ms_chemical_space_explorer.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings\n", + "\n", + "\n", "library_spectra = SpectraWithMS2DeepScoreEmbeddings(neg_train_spectra + pos_train_spectra, ms2deepscore_model)" ] }, @@ -142,6 +152,8 @@ "outputs": [], "source": [ "import pickle\n", + "\n", + "\n", "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_library_embeddings.pickle\"), \"wb\") as handle:\n", " pickle.dump(library_spectra.embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)" ] @@ -154,6 +166,8 @@ "outputs": [], "source": [ "import pickle\n", + "\n", + "\n", "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_library_with_embeddings.pickle\"), \"wb\") as handle:\n", " pickle.dump(library_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)" ] @@ -178,6 +192,8 @@ "outputs": [], "source": [ "import pickle\n", + "\n", + "\n", "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_val_spectra_with_embeddings.pickle\"), \"wb\") as handle:\n", " pickle.dump(val_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)" ] @@ -197,7 +213,9 @@ "metadata": {}, "outputs": [], "source": [ - "import sys \n", + "import sys\n", + "\n", + "\n", "sys.path.append(\"../../ms_chemical_space_explorer\")" ] }, @@ -210,6 +228,8 @@ "source": [ "import os\n", "import pickle\n", + "\n", + "\n", 
"pickled_intermediates_data_folder = \"../data/pickled_intermediates\"\n", "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_library_with_embeddings.pickle\"), \"rb\") as file:\n", " library_spectra = pickle.load(file)\n", @@ -234,6 +254,7 @@ "source": [ "from ms_chemical_space_explorer.benchmarking.EvaluateMethods import EvaluateMethods\n", "\n", + "\n", "method_evaluator = EvaluateMethods(library_spectra, val_spectra)\n", "method_evaluator.training_spectrum_set.progress_bars = False\n", "method_evaluator.validation_spectrum_set.progress_bars = False" @@ -267,6 +288,7 @@ "source": [ "from ms_chemical_space_explorer.methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore\n", "\n", + "\n", "result_analogue = method_evaluator.benchmark_analogue_search(predict_highest_ms2deepscore)" ] }, @@ -310,6 +332,7 @@ "source": [ "from ms_chemical_space_explorer.methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore\n", "\n", + "\n", "result_positive = method_evaluator.benchmark_exact_matching_within_ionmode(predict_highest_ms2deepscore, \"positive\")" ] }, @@ -351,6 +374,7 @@ "source": [ "from ms_chemical_space_explorer.methods.predict_highest_ms2deepscore import predict_highest_ms2deepscore\n", "\n", + "\n", "result_neg = method_evaluator.benchmark_exact_matching_within_ionmode(predict_highest_ms2deepscore, \"negative\")" ] }, @@ -869,7 +893,10 @@ } ], "source": [ - "from ms_chemical_space_explorer.methods.predict_with_integrated_similarity_flow import predict_with_integrated_similarity_flow\n", + "from ms_chemical_space_explorer.methods.predict_with_integrated_similarity_flow import (\n", + " predict_with_integrated_similarity_flow,\n", + ")\n", + "\n", "\n", "result_analogue_isf = method_evaluator.benchmark_analogue_search(predict_with_integrated_similarity_flow)\n", "print(result_analogue_isf)\n", diff --git a/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb 
b/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb index dcad892..da86250 100644 --- a/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb +++ b/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb @@ -23,6 +23,7 @@ "from matchms.importing import load_from_mgf\n", "from tqdm import tqdm\n", "\n", + "\n", "neg_val_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_validation_spectra.mgf\")))\n", "neg_test_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_testing_spectra.mgf\")))\n", "neg_train_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_training_spectra.mgf\")))\n", From c97c5a62dfea4e550090c8199a1698289daba9a6 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Tue, 16 Dec 2025 11:21:48 +0100 Subject: [PATCH 42/45] Lint line length --- .../Test_ann_speed_improvements.ipynb | 56 +++---- .../notebooks/Test_method_evaluator.ipynb | 145 ++++-------------- ...umber_of_inchikeys_with_two_ionmodes.ipynb | 65 ++++---- tests/test_SpectrumDataSet.py | 3 +- 4 files changed, 83 insertions(+), 186 deletions(-) diff --git a/ms2query/notebooks/Test_ann_speed_improvements.ipynb b/ms2query/notebooks/Test_ann_speed_improvements.ipynb index e17d662..bacbad8 100644 --- a/ms2query/notebooks/Test_ann_speed_improvements.ipynb +++ b/ms2query/notebooks/Test_ann_speed_improvements.ipynb @@ -15,61 +15,45 @@ ] }, { - "cell_type": "code", - "execution_count": 16, - "id": "69613d68-6340-4985-b266-ebdb56b21271", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "7551it [00:05, 1452.14it/s]\n", - "7142it [00:04, 1502.96it/s]\n" - ] - } - ], + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ "from matchms.importing import 
load_from_mgf\n", "from tqdm import tqdm\n", "\n", "\n", - "neg_val_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_validation_spectra.mgf\")))\n", - "neg_test_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_testing_spectra.mgf\")))\n" - ] + "neg_val_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/negative_validation_spectra.mgf\")))\n", + "neg_test_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/negative_testing_spectra.mgf\")))\n" + ], + "id": "1958d1e9d021cd45" }, { - "cell_type": "code", - "execution_count": 17, - "id": "55f31b78-e6b8-42a2-8ba0-d5cea429520d", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\jonge094\\AppData\\Local\\miniconda3\\envs\\ms2query2\\lib\\site-packages\\ms2deepscore\\models\\load_model.py:34: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.\n", - " model_settings = torch.load(filename, map_location=device)\n", - "7551it [02:50, 44.16it/s]\n" - ] - } - ], + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ "from ms2deepscore.models import compute_embedding_array, load_model\n", "\n", "\n", - "ms2deepscore_model = load_model(\"../../../ms2deepscore/data/pytorch/new_corinna_included/trained_models/both_mode_precursor_mz_ionmode_10000_layers_500_embedding_2024_11_21_11_23_17/ms2deepscore_model.pt\")\n", + "ms2deepscore_model = load_model(\"../../../ms2deepscore/data/pytorch/new_corinna_included/trained_models/\"\n", + " \"both_mode_precursor_mz_ionmode_10000_layers_500_embedding_2024_11_21_11_23_17/ms2deepscore_model.pt\")\n", "\n", "embeddings = compute_embedding_array(ms2deepscore_model, neg_val_spectra)" - ] + ], + "id": "fe20d28468de50ca" }, { - "cell_type": "code", - "execution_count": null, - "id": "1fe9d805-561f-49c5-87f4-ce56c374db42", "metadata": {}, + "cell_type": "code", "outputs": [], - "source": [] + "execution_count": null, + "source": "", + "id": "cee3357668c3f96e" }, { "cell_type": "code", diff --git a/ms2query/notebooks/Test_method_evaluator.ipynb b/ms2query/notebooks/Test_method_evaluator.ipynb index 86aafed..5303dfe 100644 --- a/ms2query/notebooks/Test_method_evaluator.ipynb +++ b/ms2query/notebooks/Test_method_evaluator.ipynb @@ -106,30 +106,32 @@ ] }, { - "cell_type": "code", - "execution_count": 8, - "id": "48bd37e7-384a-4bbc-821d-946975bfdf5b", "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "from ms2deepscore.models import load_model\n", "\n", "\n", - "ms2deepscore_model = load_model(\"../data/ms2deepscore_model/trained_models/both_mode_ionmode_precursor_mz_2000_layers_500_embedding_2025_02_26_18_42_25/ms2deepscore_model.pt\")\n" - ] + "ms2deepscore_model = load_model(\n", + " \"../data/ms2deepscore_model/trained_models/\"\n", + 
" \"both_mode_ionmode_precursor_mz_2000_layers_500_embedding_2025_02_26_18_42_25/ms2deepscore_model.pt\")\n" + ], + "id": "d4805b43bea1e3a2" }, { - "cell_type": "code", - "execution_count": 10, - "id": "9710c95d-803d-4492-b279-13588afa9c27", "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "import sys\n", "\n", "\n", "sys.path.append(\"../../ms_chemical_space_explorer\")" - ] + ], + "id": "9ca0b8c6639168e1" }, { "cell_type": "code", @@ -185,18 +187,19 @@ ] }, { - "cell_type": "code", - "execution_count": 22, - "id": "3ba69a93-5ede-4920-b3a6-e7881e20eb39", "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "import pickle\n", "\n", "\n", - "with open(os.path.join(pickled_intermediates_data_folder, \"neg_pos_val_spectra_with_embeddings.pickle\"), \"wb\") as handle:\n", + "with open(os.path.join(pickled_intermediates_data_folder,\n", + " \"neg_pos_val_spectra_with_embeddings.pickle\"), \"wb\") as handle:\n", " pickle.dump(val_spectra, handle, protocol=pickle.HIGHEST_PROTOCOL)" - ] + ], + "id": "e6835fbcb3d01494" }, { "cell_type": "markdown", @@ -804,94 +807,10 @@ ] }, { - "cell_type": "code", - "execution_count": 7, - "id": "1dc19bef-69ab-471e-94ab-6a9caffd9030", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "IOStream.flush timed outepscore per batch of 500 embeddings: 77%|██████████████████████████████████████████████████████████████████▏ | 60/78 [33:53<22:39, 75.54s/it]\n", - "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████| 78/78 [42:49<00:00, 32.95s/it]\n", - "Calculating ISF score: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 38533/38533 [00:44<00:00, 857.97it/s]\n", - "Calculating analogue accuracy per inchikey: 
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 2015/2015 [00:01<00:00, 1076.80it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.35576217271302824\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Splitting spectra per inchikey: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 924/924 [00:00<00:00, 202462.49it/s]\n", - "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:56<00:00, 25.28s/it]\n", - "Calculating ISF score: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3303/3303 [00:04<00:00, 713.49it/s]\n", - "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████| 8/8 [03:30<00:00, 26.27s/it]\n", - "Calculating ISF score: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3721/3721 [00:04<00:00, 856.75it/s]\n", - "Calculating exact match accuracy per inchikey: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 868/868 [00:00<00:00, 176636.55it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.04118290017821497\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Splitting spectra per inchikey: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1837/1837 [00:00<00:00, 134548.79it/s]\n", - "Predicting highest ms2deepscore per batch of 500 embeddings: 84%|███████████████████████████████████████████████████████████████████████▎ | 26/31 [38:00<38:36, 
463.28s/it]IOStream.flush timed out\n", - "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████| 31/31 [41:05<00:00, 79.53s/it]\n", - "Calculating ISF score: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15218/15218 [00:20<00:00, 753.86it/s]\n", - "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████| 33/33 [12:25<00:00, 22.58s/it]\n", - "Calculating ISF score: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16010/16010 [00:22<00:00, 719.72it/s]\n", - "Calculating exact match accuracy per inchikey: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 1612/1612 [00:00<00:00, 236406.23it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.05296326960524268\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Splitting spectra per inchikey across ionmodes: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 2015/2015 [00:00<00:00, 3809.64it/s]\n", - "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████| 35/35 [13:13<00:00, 22.66s/it]\n", - "Calculating ISF score: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17128/17128 [00:21<00:00, 808.89it/s]\n", - "Predicting highest ms2deepscore per batch of 500 embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████| 13/13 [04:14<00:00, 19.57s/it]\n", - "Calculating ISF score: 
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6059/6059 [00:08<00:00, 714.94it/s]\n", - "Calculating exact match accuracy per inchikey: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 746/746 [00:00<00:00, 56840.41it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.0006614631666202545\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ "from ms_chemical_space_explorer.methods.predict_with_integrated_similarity_flow import (\n", " predict_with_integrated_similarity_flow,\n", @@ -900,23 +819,25 @@ "\n", "result_analogue_isf = method_evaluator.benchmark_analogue_search(predict_with_integrated_similarity_flow)\n", "print(result_analogue_isf)\n", - "result_neg_isf = method_evaluator.benchmark_exact_matching_within_ionmode(predict_with_integrated_similarity_flow, \"negative\")\n", + "result_neg_isf = method_evaluator.benchmark_exact_matching_within_ionmode(\n", + " predict_with_integrated_similarity_flow, \"negative\")\n", "print(result_neg_isf)\n", - "result_positive_isf = method_evaluator.benchmark_exact_matching_within_ionmode(predict_with_integrated_similarity_flow, \"positive\")\n", + "result_positive_isf = method_evaluator.benchmark_exact_matching_within_ionmode(\n", + " predict_with_integrated_similarity_flow, \"positive\")\n", "print(result_positive_isf)\n", - "result_across_ionmodes_isf = method_evaluator.exact_matches_across_ionization_modes(predict_with_integrated_similarity_flow)\n", + "result_across_ionmodes_isf = method_evaluator.exact_matches_across_ionization_modes(\n", + " predict_with_integrated_similarity_flow)\n", "print(result_across_ionmodes_isf)" - ] + ], + "id": "958a6efb624e75de" }, { - "cell_type": "code", - "execution_count": null, - "id": 
"1b891a34-aade-49e8-ad91-88b36cf0e011", "metadata": {}, + "cell_type": "code", "outputs": [], - "source": [ - "result_analogue_isf" - ] + "execution_count": null, + "source": "result_analogue_isf", + "id": "326a55ab85832258" }, { "cell_type": "code", @@ -966,4 +887,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb b/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb index da86250..271f9de 100644 --- a/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb +++ b/ms2query/notebooks/get_number_of_inchikeys_with_two_ionmodes.ipynb @@ -1,60 +1,51 @@ { "cells": [ { - "cell_type": "code", - "execution_count": 3, - "id": "df3d5e3f-2694-4eed-9fc1-a78483629412", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "7551it [00:07, 1015.57it/s]\n", - "7142it [00:12, 560.76it/s] \n", - "130901it [03:10, 685.40it/s] \n", - "25412it [00:32, 784.47it/s] \n", - "24911it [00:34, 718.09it/s] \n", - "25412it [00:45, 556.46it/s] \n" - ] - } - ], + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ "from matchms.importing import load_from_mgf\n", "from tqdm import tqdm\n", "\n", "\n", - "neg_val_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_validation_spectra.mgf\")))\n", - "neg_test_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_testing_spectra.mgf\")))\n", - "neg_train_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/negative_training_spectra.mgf\")))\n", + "neg_val_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/\"\n", + " \"negative_validation_spectra.mgf\")))\n", 
+ "neg_test_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/negative_testing_spectra.mgf\")))\n", + "neg_train_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/negative_training_spectra.mgf\")))\n", "neg_spectra = neg_val_spectra + neg_test_spectra + neg_train_spectra\n", "\n", - "pos_val_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/positive_validation_spectra.mgf\")))\n", - "pos_test_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/positive_testing_spectra.mgf\")))\n", - "pos_train_spectra = list(tqdm(load_from_mgf(\"../../../ms2deepscore/data/pytorch/new_corinna_included/training_and_validation_split/positive_training_spectra.mgf\")))\n", + "pos_val_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/positive_validation_spectra.mgf\")))\n", + "pos_test_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/positive_testing_spectra.mgf\")))\n", + "pos_train_spectra = list(tqdm(load_from_mgf(\n", + " \"../../../ms2deepscore/data/pytorch/new_corinna_included/\"\n", + " \"training_and_validation_split/positive_training_spectra.mgf\")))\n", "pos_spectra = pos_val_spectra + pos_test_spectra + pos_train_spectra" - ] + ], + "id": "fdf12d625e87127b" }, { - "cell_type": "code", - "execution_count": 5, - "id": "2771fc81-ad6b-4b2d-90ea-c541641af1a5", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - 
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 145594/145594 [00:25<00:00, 5637.57it/s]\n" - ] - } - ], + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ "neg_inchikeys = []\n", "for spectrum in tqdm(neg_spectra):\n", " neg_inchikeys.append(spectrum.get(\"inchikey\")[:14])\n", " " - ] + ], + "id": "9ce097a34a9e6656" }, { "cell_type": "code", diff --git a/tests/test_SpectrumDataSet.py b/tests/test_SpectrumDataSet.py index 4244e99..9b1b56f 100644 --- a/tests/test_SpectrumDataSet.py +++ b/tests/test_SpectrumDataSet.py @@ -17,7 +17,8 @@ ], ) def test_spectrum_set_base(library): - """Test all base functionality of SpectrumSetBase is implemented correctly also for all classes inheriting from it""" + """Test all base functionality of SpectrumSetBase is implemented correctly + also for all classes inheriting from it""" # test correct init assert len(library.spectra) == 9 assert len(library.spectrum_indexes_per_inchikey) == 3 From fd688c4249e974faafa0f48a4871a5def8ce8363 Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Tue, 16 Dec 2025 11:24:04 +0100 Subject: [PATCH 43/45] Exclude notebooks from ruff linting --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 396af2e..fe9495c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,9 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] +exclude = [ + "ms2query/notebooks/", +] line-length = 120 output-format = "grouped" From 8d9a4b7eea0b01e48a6f726e385da1be2224593c Mon Sep 17 00:00:00 2001 From: niekdejonge Date: Wed, 17 Dec 2025 10:18:04 +0100 Subject: [PATCH 44/45] Change 
SpectrumSetBase to SpectrumSet --- ms2query/benchmarking/EvaluateMethods.py | 14 +++++++------- ms2query/benchmarking/SpectrumDataSet.py | 9 +++++---- tests/test_SpectrumDataSet.py | 10 +++++----- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/ms2query/benchmarking/EvaluateMethods.py b/ms2query/benchmarking/EvaluateMethods.py index 2dcda5b..edced8c 100644 --- a/ms2query/benchmarking/EvaluateMethods.py +++ b/ms2query/benchmarking/EvaluateMethods.py @@ -3,7 +3,7 @@ import numpy as np from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix from tqdm import tqdm -from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectrumSetBase +from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectrumSet class EvaluateMethods: @@ -107,7 +107,7 @@ def get_accuracy_recall_curve(self): def predict_between_two_sets( - library: SpectrumSetBase, query_set_1: SpectrumSetBase, query_set_2: SpectrumSetBase, prediction_function + library: SpectrumSet, query_set_1: SpectrumSet, query_set_2: SpectrumSet, prediction_function ): """Makes predictions between query sets and the library, with the other query set added. 
@@ -123,7 +123,7 @@ def predict_between_two_sets( return predicted_inchikeys_1 + predicted_inchikeys_2 -def calculate_average_exact_match_accuracy(spectrum_set: SpectrumSetBase, predicted_inchikeys: List[str]): +def calculate_average_exact_match_accuracy(spectrum_set: SpectrumSet, predicted_inchikeys: List[str]): if len(spectrum_set.spectra) != len(predicted_inchikeys): raise ValueError("The number of spectra should be equal to the number of predicted inchikeys ") exact_match_accuracy_per_inchikey = [] @@ -139,7 +139,7 @@ def calculate_average_exact_match_accuracy(spectrum_set: SpectrumSetBase, predic return sum(exact_match_accuracy_per_inchikey) / len(exact_match_accuracy_per_inchikey) -def split_spectrum_set_per_inchikeys(spectrum_set: SpectrumSetBase) -> Tuple[SpectrumSetBase, SpectrumSetBase]: +def split_spectrum_set_per_inchikeys(spectrum_set: SpectrumSet) -> Tuple[SpectrumSet, SpectrumSet]: """Splits a spectrum set into two. For each inchikey with more than one spectrum the spectra are divided over the two sets""" indexes_set_1 = [] @@ -157,8 +157,8 @@ def split_spectrum_set_per_inchikeys(spectrum_set: SpectrumSetBase) -> Tuple[Spe def split_spectrum_set_per_inchikey_across_ionmodes( - spectrum_set: SpectrumSetBase, -) -> Tuple[SpectrumSetBase, SpectrumSetBase]: + spectrum_set: SpectrumSet, +) -> Tuple[SpectrumSet, SpectrumSet]: """Splits a spectrum set in two sets on ionmode. 
Only uses spectra for inchikeys with at least 1 pos and 1 neg""" all_pos_indexes = [] all_neg_indexes = [] @@ -190,7 +190,7 @@ def split_spectrum_set_per_inchikey_across_ionmodes( return pos_val_spectra, neg_val_spectra -def subset_spectra_on_ionmode(spectrum_set: SpectrumSetBase, ionmode) -> SpectrumSetBase: +def subset_spectra_on_ionmode(spectrum_set: SpectrumSet, ionmode) -> SpectrumSet: spectrum_indexes_to_keep = [] for i, spectrum in enumerate(spectrum_set.spectra): if spectrum.get("ionmode") == ionmode: diff --git a/ms2query/benchmarking/SpectrumDataSet.py b/ms2query/benchmarking/SpectrumDataSet.py index 240c38a..eb9fe65 100644 --- a/ms2query/benchmarking/SpectrumDataSet.py +++ b/ms2query/benchmarking/SpectrumDataSet.py @@ -8,7 +8,8 @@ from tqdm import tqdm -class SpectrumSetBase: + +class SpectrumSet: """Stores a spectrum dataset making it easy and fast to split on molecules""" def __init__(self, spectra: List[Spectrum], progress_bars=False): @@ -36,10 +37,10 @@ def _add_spectra_and_group_per_inchikey(self, spectra: List[Spectrum]): ] return updated_inchikeys - def add_spectra(self, new_spectra: "SpectrumSetBase"): + def add_spectra(self, new_spectra: "SpectrumSet"): return self._add_spectra_and_group_per_inchikey(new_spectra.spectra) - def subset_spectra(self, spectrum_indexes) -> "SpectrumSetBase": + def subset_spectra(self, spectrum_indexes) -> "SpectrumSet": """Returns a new instance of a subset of the spectra""" new_instance = copy.copy(self) new_instance._spectra = [] @@ -65,7 +66,7 @@ def copy(self): return new_instance -class SpectraWithFingerprints(SpectrumSetBase): +class SpectraWithFingerprints(SpectrumSet): """Stores a spectrum dataset making it easy and fast to split on molecules""" def __init__(self, spectra: List[Spectrum], fingerprint_type="daylight", nbits=4096): diff --git a/tests/test_SpectrumDataSet.py b/tests/test_SpectrumDataSet.py index 9b1b56f..bb3e232 100644 --- a/tests/test_SpectrumDataSet.py +++ b/tests/test_SpectrumDataSet.py @@ 
-3,7 +3,7 @@ from ms2query.benchmarking.SpectrumDataSet import ( SpectraWithFingerprints, SpectraWithMS2DeepScoreEmbeddings, - SpectrumSetBase, + SpectrumSet, ) from tests.conftest import create_test_spectra, get_inchikey_inchi_pairs, ms2deepscore_model @@ -11,13 +11,13 @@ @pytest.mark.parametrize( "library", [ - SpectrumSetBase(create_test_spectra()), + SpectrumSet(create_test_spectra()), SpectraWithFingerprints(create_test_spectra()), SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), ms2deepscore_model()), ], ) def test_spectrum_set_base(library): - """Test all base functionality of SpectrumSetBase is implemented correctly + """Test all base functionality of SpectrumSet is implemented correctly also for all classes inheriting from it""" # test correct init assert len(library.spectra) == 9 @@ -74,7 +74,7 @@ def test_spectra_with_fingerprints(library): (get_inchikey_inchi_pairs(3), 3), # Fully overlapping (get_inchikey_inchi_pairs(1), 3), # Fully overlapping (but not all) ): - spectra_to_add = SpectrumSetBase(create_test_spectra(2, inchikey_inchi_pairs=inchikey_inchi_pairs)) + spectra_to_add = SpectrumSet(create_test_spectra(2, inchikey_inchi_pairs=inchikey_inchi_pairs)) new_copy = library.copy() new_copy.add_spectra(spectra_to_add) assert len(new_copy.inchikey_fingerprint_pairs) == expected_nr_of_inchikeys @@ -125,6 +125,6 @@ def test_spectra_with_embeddings(): for i, index in enumerate(subset_indexes): assert np.all(library.embeddings[index] == subset.embeddings[i]) - # Check that subsetting on subset works. To make sure that a subset does not become of type SpectrumSetBase + # Check that subsetting on subset works. 
To make sure that a subset does not become of type SpectrumSet subsetted_subset = subset.subset_spectra([0, 1]) assert subsetted_subset.embeddings.shape == (2, 100) From ddff9a72f0a860ebbb2254100244170ae7521cff Mon Sep 17 00:00:00 2001 From: Niek de Jonge <76995965+niekdejonge@users.noreply.github.com> Date: Wed, 17 Dec 2025 13:32:54 +0100 Subject: [PATCH 45/45] Remove unnecessary blank line in SpectrumDataSet.py --- ms2query/benchmarking/SpectrumDataSet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ms2query/benchmarking/SpectrumDataSet.py b/ms2query/benchmarking/SpectrumDataSet.py index eb9fe65..d2f454d 100644 --- a/ms2query/benchmarking/SpectrumDataSet.py +++ b/ms2query/benchmarking/SpectrumDataSet.py @@ -8,7 +8,6 @@ from tqdm import tqdm - class SpectrumSet: """Stores a spectrum dataset making it easy and fast to split on molecules"""