Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
431e0a0
Merge pull request #1 from matchms/fix_pyproject
niekdejonge Dec 5, 2025
fa4afdd
Add SpectrumDataSet for handling sets of spectra, for method developm…
niekdejonge Dec 5, 2025
e882f00
Add EvaluateMethods a general method for benchmarking analogue search…
niekdejonge Dec 5, 2025
922bea8
Implement base line methods
niekdejonge Dec 5, 2025
40c6347
Add notebooks used for testing (still have to be adapted to work with…
niekdejonge Dec 5, 2025
9f67cd2
Add tests
niekdejonge Dec 5, 2025
4d81ddf
Add inits
niekdejonge Dec 5, 2025
8024263
Use ms2deepscore 2.6.0
niekdejonge Dec 8, 2025
a9db98e
add new helper methods
florian-huber Dec 8, 2025
7865da8
Merge branch 'main' of https://github.com/matchms/ms2query_2
florian-huber Dec 8, 2025
316531a
improve by batch querying
florian-huber Dec 8, 2025
c827742
adapt batch method and add further helper
florian-huber Dec 8, 2025
b0b7748
further batch querying method update
florian-huber Dec 8, 2025
92b434c
add fingerprint index
florian-huber Dec 8, 2025
bf9bbf2
larger linting/cleaning and batch implementation
florian-huber Dec 8, 2025
0c3c920
linting
florian-huber Dec 9, 2025
a59cc9a
add more getters
florian-huber Dec 9, 2025
ff6e2b6
fix by_ids methods and linting
florian-huber Dec 9, 2025
c3d7b62
add test and clean
florian-huber Dec 9, 2025
62eef51
cleaning and additional test
florian-huber Dec 9, 2025
b4561c0
Do correct subsetting of inchikey sets
niekdejonge Dec 9, 2025
657990d
fix
florian-huber Dec 9, 2025
170537f
cleaning, fixing, linting
florian-huber Dec 9, 2025
7f76997
Add method for predicting using top 10 closest library spectra.
niekdejonge Dec 10, 2025
2c2e35f
Added extra inchikey smiles examples
niekdejonge Dec 10, 2025
e646e22
refactoring/linting
florian-huber Dec 10, 2025
fb4cbd8
Add test_get_inchikey_and_tanimoto_scores_from_top_k
niekdejonge Dec 10, 2025
1933c85
Move top_k_selection outside average computation for easier testing
niekdejonge Dec 10, 2025
fd2bf89
Add test_get_average_predictions_for_closely_related_metabolites
niekdejonge Dec 10, 2025
59aca8e
Split select_inchikeys_with_highest_ms2deepscores to make more modula…
niekdejonge Dec 10, 2025
f223e43
Add test_select_inchikeys_with_highest_ms2deepscore
niekdejonge Dec 10, 2025
16058cc
Added nr_of_inchikeys_with_highest_ms2deepscore_to_select as parameter
niekdejonge Dec 10, 2025
1c2f1a7
Added basic tests for predict using closest tanimoto score (checking …
niekdejonge Dec 10, 2025
69f30e7
switch to batch/list handling for queries
florian-huber Dec 10, 2025
a681846
linting and adapt to list inputs
florian-huber Dec 10, 2025
146792e
clean up and switch to list inputs
florian-huber Dec 10, 2025
cf6a805
adjust tests
florian-huber Dec 10, 2025
4e208cf
add merge_fingerprint method
florian-huber Dec 11, 2025
0ca72ef
update ms2deepscore version
florian-huber Dec 11, 2025
276511c
Add tqdm to predict using closest tanimoto
niekdejonge Dec 16, 2025
a59e6a9
ruff
niekdejonge Dec 16, 2025
aff52e4
Linting
niekdejonge Dec 16, 2025
cd30383
Lint notebooks
niekdejonge Dec 16, 2025
c97c5a6
Lint line length
niekdejonge Dec 16, 2025
fd688c4
Exclude notebooks from ruff linting
niekdejonge Dec 16, 2025
48e9dea
Merge branch 'main' into add_benchmarking_method
niekdejonge Dec 16, 2025
8d9a4b7
Change SpectrumSetBase to SpectrumSet
niekdejonge Dec 17, 2025
ddff9a7
Remove unnecessary blank line in SpectrumDataSet.py
niekdejonge Dec 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 202 additions & 0 deletions ms2query/benchmarking/EvaluateMethods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import random
from typing import Callable, List, Tuple
import numpy as np
from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix
from tqdm import tqdm
from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectrumSet


class EvaluateMethods:
    """Benchmark analogue-search and exact-match prediction methods.

    A ``prediction_function`` takes a library (training) set and a query
    (validation) set and returns, per query spectrum, the predicted library
    inchikey and a prediction score.
    """

    def __init__(
        self, training_spectrum_set: SpectraWithFingerprints, validation_spectrum_set: SpectraWithFingerprints
    ):
        self.training_spectrum_set = training_spectrum_set
        self.validation_spectrum_set = validation_spectrum_set

        # Disable the spectrum sets' own progress bars; the benchmark loops
        # below show their own tqdm bars.
        self.training_spectrum_set.progress_bars = False
        self.validation_spectrum_set.progress_bars = False

    def benchmark_analogue_search(
        self,
        prediction_function: Callable[
            [SpectraWithFingerprints, SpectraWithFingerprints], Tuple[List[str], List[float]]
        ],
    ) -> float:
        """Score analogue search by Tanimoto similarity of predicted vs. actual molecules.

        For each unique validation inchikey, the Tanimoto score between the
        predicted and actual fingerprint is averaged over its spectra
        (a missing prediction counts as 0.0). The per-inchikey averages are
        then averaged into a single score, so every inchikey contributes
        equally regardless of its number of spectra.
        """
        predicted_inchikeys, _ = prediction_function(self.training_spectrum_set, self.validation_spectrum_set)
        average_scores_per_inchikey = []

        # Calculate score per unique inchikey
        for inchikey in tqdm(
            self.validation_spectrum_set.spectrum_indexes_per_inchikey.keys(),
            desc="Calculating analogue accuracy per inchikey",
        ):
            matching_spectrum_indexes = self.validation_spectrum_set.spectrum_indexes_per_inchikey[inchikey]
            prediction_scores = []
            for index in matching_spectrum_indexes:
                predicted_inchikey = predicted_inchikeys[index]
                if predicted_inchikey is None:
                    # No prediction was made for this spectrum: count as a miss.
                    prediction_scores.append(0.0)
                else:
                    predicted_fingerprint = self.training_spectrum_set.inchikey_fingerprint_pairs[predicted_inchikey]
                    actual_fingerprint = self.validation_spectrum_set.inchikey_fingerprint_pairs[inchikey]
                    tanimoto_for_prediction = calculate_tanimoto_score_between_pair(
                        predicted_fingerprint, actual_fingerprint
                    )
                    prediction_scores.append(tanimoto_for_prediction)

            average_scores_per_inchikey.append(sum(prediction_scores) / len(prediction_scores))
        average_over_all_inchikeys = sum(average_scores_per_inchikey) / len(average_scores_per_inchikey)
        return average_over_all_inchikeys

    def benchmark_exact_matching_within_ionmode(
        self,
        prediction_function: Callable[
            [SpectraWithFingerprints, SpectraWithFingerprints], Tuple[List[str], List[float]]
        ],
        ionmode: str,
    ) -> float:
        """Test the accuracy at retrieving exact matches from the library

        For each inchikey with more than 1 spectrum the spectra are split in two sets. Half for each inchikey is added
        to the library (training set), for the other half predictions are made. Thereby there is always an exact match
        available. A prediction only counts as correct when the highest ranked predicted inchikey equals the true
        inchikey. An accuracy per inchikey is calculated, followed by calculating the average over inchikeys.
        """
        selected_spectra = subset_spectra_on_ionmode(self.validation_spectrum_set, ionmode)

        set_1, set_2 = split_spectrum_set_per_inchikeys(selected_spectra)

        predicted_inchikeys = predict_between_two_sets(self.training_spectrum_set, set_1, set_2, prediction_function)

        # Merge set_2 back into set_1 so spectrum order matches the
        # concatenated prediction list (set_1 predictions first, then set_2).
        set_1.add_spectra(set_2)
        return calculate_average_exact_match_accuracy(set_1, predicted_inchikeys)

    def exact_matches_across_ionization_modes(
        self,
        prediction_function: Callable[
            [SpectraWithFingerprints, SpectraWithFingerprints], Tuple[List[str], List[float]]
        ],
    ):
        """Test the accuracy at retrieving exact matches from the library if only available in other ionisation mode

        Each validation spectrum is matched against the training set with the other validation spectra of the same
        inchikey, but other ionisation mode, added to the library.
        """
        pos_set, neg_set = split_spectrum_set_per_inchikey_across_ionmodes(self.validation_spectrum_set)
        predicted_inchikeys = predict_between_two_sets(
            self.training_spectrum_set, pos_set, neg_set, prediction_function
        )
        # Merge neg_set back into pos_set so spectrum order matches the
        # concatenated prediction list (pos_set predictions first, then neg_set).
        pos_set.add_spectra(neg_set)
        return calculate_average_exact_match_accuracy(pos_set, predicted_inchikeys)

    def get_accuracy_recall_curve(self):
        """This method should test the recall accuracy balance.
        All of the used methods use a threshold which indicates quality of prediction.
        A method that can predict well when a prediction is accurate is beneficial.
        We need a method to test this.

        One method is generating a recall accuracy curve. This could be done for both the analogue search predictions
        and the exact match predictions. By returning the predicted score for a match this method could create an
        accuracy recall plot.
        """
        raise NotImplementedError


def predict_between_two_sets(
    library: SpectrumSet, query_set_1: SpectrumSet, query_set_2: SpectrumSet, prediction_function
):
    """Makes predictions between query sets and the library, with the other query set added.

    Each query set is predicted against a copy of the library into which the
    other query set has been merged, guaranteeing that an exact match is
    present. This is necessary for testing exact matching."""
    combined_predictions = []
    for queries, extra_library_spectra in ((query_set_1, query_set_2), (query_set_2, query_set_1)):
        extended_library = library.copy()
        extended_library.add_spectra(extra_library_spectra)
        predictions, _ = prediction_function(extended_library, queries)
        combined_predictions.extend(predictions)
    return combined_predictions


def calculate_average_exact_match_accuracy(spectrum_set: SpectrumSet, predicted_inchikeys: List[str]):
    """Return the mean over inchikeys of the fraction of spectra predicted exactly correct.

    Raises:
        ValueError: If the number of predictions does not match the number of spectra.
    """
    if len(spectrum_set.spectra) != len(predicted_inchikeys):
        raise ValueError("The number of spectra should be equal to the number of predicted inchikeys ")
    per_inchikey_accuracies = []
    for inchikey, spectrum_indexes in tqdm(
        spectrum_set.spectrum_indexes_per_inchikey.items(), desc="Calculating exact match accuracy per inchikey"
    ):
        nr_correct = sum(1 for spectrum_idx in spectrum_indexes if predicted_inchikeys[spectrum_idx] == inchikey)
        per_inchikey_accuracies.append(nr_correct / len(spectrum_indexes))
    return sum(per_inchikey_accuracies) / len(per_inchikey_accuracies)


def split_spectrum_set_per_inchikeys(spectrum_set: SpectrumSet) -> Tuple[SpectrumSet, SpectrumSet]:
    """Splits a spectrum set into two.

    For each inchikey with more than one spectrum the spectra are randomly
    divided over the two sets. Inchikeys with a single spectrum are skipped,
    since no exact match could then be added to the library.
    """
    indexes_set_1 = []
    indexes_set_2 = []
    for inchikey in tqdm(spectrum_set.spectrum_indexes_per_inchikey.keys(), desc="Splitting spectra per inchikey"):
        val_spectrum_indexes_matching_inchikey = spectrum_set.spectrum_indexes_per_inchikey[inchikey]
        if len(val_spectrum_indexes_matching_inchikey) == 1:
            # all single spectra are excluded from this test, since no exact match can be added to the library
            continue
        # Shuffle a copy via random.sample: random.shuffle would mutate the
        # index list stored inside spectrum_set.spectrum_indexes_per_inchikey
        # in place, silently reordering the caller's spectrum set.
        shuffled_indexes = random.sample(
            val_spectrum_indexes_matching_inchikey, len(val_spectrum_indexes_matching_inchikey)
        )
        split_index = len(shuffled_indexes) // 2
        indexes_set_1.extend(shuffled_indexes[:split_index])
        indexes_set_2.extend(shuffled_indexes[split_index:])
    return spectrum_set.subset_spectra(indexes_set_1), spectrum_set.subset_spectra(indexes_set_2)


def split_spectrum_set_per_inchikey_across_ionmodes(
    spectrum_set: SpectrumSet,
) -> Tuple[SpectrumSet, SpectrumSet]:
    """Splits a spectrum set in two sets on ionmode. Only uses spectra for inchikeys with at least 1 pos and 1 neg"""
    all_pos_indexes = []
    all_neg_indexes = []
    for inchikey in tqdm(
        spectrum_set.spectrum_indexes_per_inchikey.keys(),
        desc="Splitting spectra per inchikey across ionmodes",
    ):
        indexes_for_inchikey = spectrum_set.spectrum_indexes_per_inchikey[inchikey]
        positive_indexes = [
            idx for idx in indexes_for_inchikey if spectrum_set.spectra[idx].get("ionmode") == "positive"
        ]
        negative_indexes = [
            idx for idx in indexes_for_inchikey if spectrum_set.spectra[idx].get("ionmode") == "negative"
        ]
        # Only keep inchikeys that occur in both ionization modes.
        if positive_indexes and negative_indexes:
            all_pos_indexes.extend(positive_indexes)
            all_neg_indexes.extend(negative_indexes)

    pos_val_spectra = spectrum_set.subset_spectra(all_pos_indexes)
    neg_val_spectra = spectrum_set.subset_spectra(all_neg_indexes)
    return pos_val_spectra, neg_val_spectra


def subset_spectra_on_ionmode(spectrum_set: SpectrumSet, ionmode) -> SpectrumSet:
    """Return a new set containing only the spectra whose "ionmode" metadata equals *ionmode*."""
    matching_indexes = [
        index for index, spectrum in enumerate(spectrum_set.spectra) if spectrum.get("ionmode") == ionmode
    ]
    return spectrum_set.subset_spectra(matching_indexes)


def calculate_tanimoto_score_between_pair(fingerprint_1: np.ndarray, fingerprint_2: np.ndarray) -> float:
    """Return the Tanimoto (Jaccard) similarity between two molecular fingerprints.

    The annotations previously said ``str``, but the values stored in
    ``inchikey_fingerprint_pairs`` (and passed in here) are 1-D numpy
    fingerprint arrays. Each is wrapped into a 1xN matrix because
    jaccard_similarity_matrix operates on matrices.
    """
    return jaccard_similarity_matrix(np.array([fingerprint_1]), np.array([fingerprint_2]))[0][0]
138 changes: 138 additions & 0 deletions ms2query/benchmarking/SpectrumDataSet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import copy
from collections import Counter
from typing import Dict, Iterable, List
import numpy as np
from matchms import Spectrum
from matchms.filtering.metadata_processing.add_fingerprint import _derive_fingerprint_from_inchi
from ms2deepscore.models import SiameseSpectralModel, compute_embedding_array
from tqdm import tqdm


class SpectrumSet:
    """Stores a spectrum dataset making it easy and fast to split on molecules.

    Spectra are grouped by the first 14 characters of their inchikey (the
    2D-structure part), so subsets per molecule can be created quickly.
    """

    def __init__(self, spectra: List[Spectrum], progress_bars=False):
        # All stored spectra; only ever appended to, so stored indexes stay valid.
        self._spectra = []
        # Maps a 14-char inchikey to the indexes of its spectra in self._spectra.
        self.spectrum_indexes_per_inchikey = {}
        self.progress_bars = progress_bars
        # init spectra
        self._add_spectra_and_group_per_inchikey(spectra)

    def _add_spectra_and_group_per_inchikey(self, spectra: List[Spectrum]):
        """Append spectra and register their indexes per inchikey.

        Returns the set of inchikeys that received new spectra, so subclasses
        can update derived data (e.g. fingerprints) for only those inchikeys.
        """
        starting_index = len(self._spectra)
        updated_inchikeys = set()
        for i, spectrum in enumerate(
            tqdm(spectra, desc="Adding spectra and grouping per Inchikey", disable=not self.progress_bars)
        ):
            self._spectra.append(spectrum)
            spectrum_index = starting_index + i
            # Only the 2D-structure part (first 14 characters) of the inchikey is used.
            inchikey = spectrum.get("inchikey")[:14]
            updated_inchikeys.add(inchikey)
            self.spectrum_indexes_per_inchikey.setdefault(inchikey, []).append(spectrum_index)
        return updated_inchikeys

    def add_spectra(self, new_spectra: "SpectrumSet"):
        """Add all spectra of another SpectrumSet to this one; returns the updated inchikeys."""
        return self._add_spectra_and_group_per_inchikey(new_spectra.spectra)

    def subset_spectra(self, spectrum_indexes) -> "SpectrumSet":
        """Returns a new instance of a subset of the spectra"""
        new_instance = copy.copy(self)
        new_instance._spectra = []
        new_instance.spectrum_indexes_per_inchikey = {}
        new_instance._add_spectra_and_group_per_inchikey([self._spectra[index] for index in spectrum_indexes])
        return new_instance

    def spectra_per_inchikey(self, inchikey) -> List[Spectrum]:
        """Return all spectra stored for the given (14-char) inchikey."""
        return [self._spectra[index] for index in self.spectrum_indexes_per_inchikey[inchikey]]

    @property
    def spectra(self):
        # Read-only access; note this exposes the internal list itself.
        return self._spectra

    def copy(self):
        """Return a copy that can have spectra added without affecting this instance.

        The spectrum list and index mapping are copied; the Spectrum objects
        themselves are shared between the copies."""
        new_instance = copy.copy(self)
        new_instance._spectra = self._spectra.copy()
        new_instance.spectrum_indexes_per_inchikey = copy.deepcopy(self.spectrum_indexes_per_inchikey)
        return new_instance


class SpectraWithFingerprints(SpectrumSet):
    """SpectrumSet that additionally stores one molecular fingerprint per inchikey."""

    def __init__(self, spectra: List[Spectrum], fingerprint_type="daylight", nbits=4096):
        super().__init__(spectra)
        self.fingerprint_type = fingerprint_type
        self.nbits = nbits
        # One fingerprint per 14-char inchikey, derived from its most common inchi.
        self.inchikey_fingerprint_pairs: Dict[str, np.ndarray] = {}
        # init fingerprints
        self.update_fingerprint_per_inchikey(self.spectrum_indexes_per_inchikey.keys())

    def add_spectra(self, new_spectra: "SpectraWithFingerprints"):
        """Add spectra from another set, reusing its fingerprints when possible.

        Fingerprints are merged directly when the other set uses identical
        fingerprint settings and shares no inchikeys with this set; otherwise
        they are recomputed for all updated inchikeys.
        """
        updated_inchikeys = super().add_spectra(new_spectra)
        if hasattr(new_spectra, "inchikey_fingerprint_pairs"):
            if new_spectra.nbits == self.nbits and new_spectra.fingerprint_type == self.fingerprint_type:
                if len(self.inchikey_fingerprint_pairs.keys() & new_spectra.inchikey_fingerprint_pairs.keys()) == 0:
                    self.inchikey_fingerprint_pairs = (
                        self.inchikey_fingerprint_pairs | new_spectra.inchikey_fingerprint_pairs
                    )
                    # Return the updated inchikeys on this fast path too, matching
                    # the parent's contract (previously returned None here).
                    return updated_inchikeys
        self.update_fingerprint_per_inchikey(updated_inchikeys)
        return updated_inchikeys

    def update_fingerprint_per_inchikey(self, inchikeys_to_update: Iterable[str]):
        """(Re)compute the fingerprint of each given inchikey from its most common inchi.

        Raises:
            ValueError: If no fingerprint could be derived for an inchi.
        """
        for inchikey in tqdm(
            inchikeys_to_update, desc="Adding fingerprints to Inchikeys", disable=not self.progress_bars
        ):
            spectra = self.spectra_per_inchikey(inchikey)
            # Spectra sharing a 14-char inchikey can still have different full
            # inchis; the most frequent one is taken as representative.
            most_common_inchi = Counter([spectrum.get("inchi") for spectrum in spectra]).most_common(1)[0][0]
            fingerprint = _derive_fingerprint_from_inchi(
                most_common_inchi, fingerprint_type=self.fingerprint_type, nbits=self.nbits
            )
            if not isinstance(fingerprint, np.ndarray):
                raise ValueError(f"Fingerprint could not be set for InChI: {most_common_inchi}")
            self.inchikey_fingerprint_pairs[inchikey] = fingerprint

    def copy(self):
        """Return a copy that can be modified without affecting this instance.

        The fingerprint dict is shallow-copied, so the fingerprint arrays
        themselves are shared between the copies."""
        new_instance = super().copy()
        new_instance.inchikey_fingerprint_pairs = copy.copy(self.inchikey_fingerprint_pairs)
        return new_instance

    def subset_spectra(self, spectrum_indexes) -> "SpectraWithFingerprints":
        """Returns a new instance of a subset of the spectra"""
        new_instance = super().subset_spectra(spectrum_indexes)
        # Only keep the fingerprints for which we have inchikeys.
        # Important note: This is not a deep copy!
        # And the fingerprint is not reset (so it is not always actually matching the most common inchi)
        new_instance.inchikey_fingerprint_pairs = {
            inchikey: self.inchikey_fingerprint_pairs[inchikey]
            for inchikey in new_instance.spectrum_indexes_per_inchikey
        }
        return new_instance


class SpectraWithMS2DeepScoreEmbeddings(SpectraWithFingerprints):
    """SpectraWithFingerprints that also stores one MS2DeepScore embedding per spectrum."""

    def __init__(self, spectra: List[Spectrum], ms2deepscore_model: SiameseSpectralModel, **kwargs):
        super().__init__(spectra, **kwargs)
        self.ms2deepscore_model = ms2deepscore_model
        # One embedding row per spectrum, in the same order as self.spectra.
        self.embeddings: np.ndarray = compute_embedding_array(self.ms2deepscore_model, spectra)

    def add_spectra(self, new_spectra: "SpectraWithMS2DeepScoreEmbeddings"):
        """Add spectra from another set, reusing its embeddings when available."""
        super().add_spectra(new_spectra)
        # NOTE(review): reused embeddings may come from a different model than
        # self.ms2deepscore_model — confirm callers only merge same-model sets.
        embeddings_to_add = (
            new_spectra.embeddings
            if hasattr(new_spectra, "embeddings")
            else compute_embedding_array(self.ms2deepscore_model, new_spectra.spectra)
        )
        self.embeddings = np.vstack([self.embeddings, embeddings_to_add])

    def subset_spectra(self, spectrum_indexes) -> "SpectraWithMS2DeepScoreEmbeddings":
        """Returns a new instance of a subset of the spectra"""
        new_instance = super().subset_spectra(spectrum_indexes)
        new_instance.embeddings = self.embeddings[spectrum_indexes]
        return new_instance
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Tuple
import numpy as np
from ms2deepscore.vector_operations import cosine_similarity_matrix
from tqdm import tqdm


def predict_top_ms2deepscores(
    library_embeddings: np.ndarray, query_embeddings: np.ndarray, batch_size: int = 500, k=1
) -> Tuple[np.ndarray, np.ndarray]:
    """Memory efficient way of calculating the highest MS2DeepScores

    When doing large matrix multiplications the memory footprint of storing the output matrix can be large.
    E.g. when doing 500.000 vs 10.000 spectra this is a very large matrix. On a laptop this can result in very
    slow run times. If only the highest MS2DeepScore is needed processing in batches prevents using too much memory

    Args:
        library_embeddings: The embeddings of the library spectra
        query_embeddings: The embeddings of the query spectra
        batch_size: The number of query embeddings processed at the same time.
            Setting a lower batch_size results in a lower memory footprint.
        k: Number of highest matches to return

    Returns:
        A tuple of two arrays of shape (num_queries, k):
        the library indexes of the top-k highest scores per query embedding
        (highest first), and the corresponding scores.
    """
    num_of_query_embeddings = query_embeddings.shape[0]
    if num_of_query_embeddings == 0:
        # np.vstack on an empty list raises ValueError; return empty results instead.
        return np.empty((0, k), dtype=int), np.empty((0, k))
    top_indexes_per_batch = []
    top_scores_per_batch = []
    # loop over the batches
    for start_idx in tqdm(
        range(0, num_of_query_embeddings, batch_size),
        desc=f"Predicting highest ms2deepscore per batch of {min(batch_size, num_of_query_embeddings)} embeddings",
    ):
        end_idx = min(start_idx + batch_size, num_of_query_embeddings)
        selected_query_embeddings = query_embeddings[start_idx:end_idx]
        score_matrix = cosine_similarity_matrix(selected_query_embeddings, library_embeddings)
        # argsort ascending, keep the last k columns, reverse -> top-k descending.
        top_n_idx = np.argsort(score_matrix, axis=1)[:, -k:][:, ::-1]
        top_n_scores = np.take_along_axis(score_matrix, top_n_idx, axis=1)
        top_indexes_per_batch.append(top_n_idx)
        top_scores_per_batch.append(top_n_scores)
    return np.vstack(top_indexes_per_batch), np.vstack(top_scores_per_batch)
Empty file.
Loading
Loading