diff --git a/README.md b/README.md index 4fe0e73..5c1ec45 100644 --- a/README.md +++ b/README.md @@ -2,28 +2,11 @@ # MS2Query 2.0 -more to come... - -## Basic workflow (so far): - -### Library generation -```python -from ms2query.create_new_library import create_new_library - -ms2query_lib = create_new_library( - spectra_files=["spectra.mgf"], - annotation_files=[], - output_folder="my_ms2query_folder/", - model_path="models/ms2deepscore.pt" -) -``` - -### Loading already generated library -```python -from ms2query.create_new_library import load_created_library - -lib = load_created_library("my_ms2query_folder/") -``` +A first basic implementation is out now, more to follow soon... +The new MS2Query appraoch has a higher accuracy and has a much simpler and faster underlying algorithm. We will hopefully soon share a first preprint as well, showing all the benchmarking. +The current runably version still requires to create the library files, which takes some time for the first run. +Soon this will be much easier and faster. We will add downloadable precomputed files, make MS2Query pip installable, add a database and allow faster MS2DeepScore searching. +The tutorial for the current prototype can be found in notebooks/tutorial. diff --git a/ms2query/benchmarking/reference_methods/readme.md b/ms2query/benchmarking/reference_methods/readme.md deleted file mode 100644 index 18eab2f..0000000 --- a/ms2query/benchmarking/reference_methods/readme.md +++ /dev/null @@ -1 +0,0 @@ -All files here are not important to core functionality and in fact not used, but might be nice to use for future testing. They are currently not yet used in the notebooks. Once they are used there they can probably be removed here. diff --git a/ms2query/benchmarking/AnnotatedSpectrumSet.py b/ms2query/ms2query_development/AnnotatedSpectrumSet.py similarity index 99% rename from ms2query/benchmarking/AnnotatedSpectrumSet.py rename to ms2query/ms2query_development/AnnotatedSpectrumSet.py index 93ff130..bda0601 100644 --- a/ms2query/benchmarking/AnnotatedSpectrumSet.py +++ b/ms2query/ms2query_development/AnnotatedSpectrumSet.py @@ -6,7 +6,7 @@ from matchms.importing import load_spectra from ms2deepscore.models import SiameseSpectralModel from tqdm import tqdm -from ms2query.benchmarking.Embeddings import Embeddings +from ms2query.ms2query_development.Embeddings import Embeddings class AnnotatedSpectrumSet: diff --git a/ms2query/benchmarking/Embeddings.py b/ms2query/ms2query_development/Embeddings.py similarity index 100% rename from ms2query/benchmarking/Embeddings.py rename to ms2query/ms2query_development/Embeddings.py diff --git a/ms2query/benchmarking/Fingerprints.py b/ms2query/ms2query_development/Fingerprints.py similarity index 98% rename from ms2query/benchmarking/Fingerprints.py rename to ms2query/ms2query_development/Fingerprints.py index 4d7a9c1..4e20cce 100644 --- a/ms2query/benchmarking/Fingerprints.py +++ b/ms2query/ms2query_development/Fingerprints.py @@ -5,8 +5,8 @@ from matchms.filtering.metadata_processing.add_fingerprint import _derive_fingerprint_from_inchi from numpy.typing import NDArray from tqdm import tqdm -from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet from ms2query.metrics import generalized_tanimoto_similarity_matrix +from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet class Fingerprints: diff --git a/ms2query/ms2query_development/ReferenceLibrary.py b/ms2query/ms2query_development/ReferenceLibrary.py new file mode 100644 index 0000000..d9fa9b4 --- /dev/null +++ b/ms2query/ms2query_development/ReferenceLibrary.py @@ -0,0 +1,220 @@ +from collections import defaultdict +from pathlib import Path +from typing import Sequence +import numpy as np +import pandas as pd +from matchms.importing import load_spectra +from matchms.Spectrum import Spectrum +from ms2deepscore.models import SiameseSpectralModel, load_model +from ms2deepscore.vector_operations import cosine_similarity_matrix +from tqdm import tqdm +from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet +from ms2query.ms2query_development.Embeddings import Embeddings, _to_json_serializable +from ms2query.ms2query_development.Fingerprints import Fingerprints +from ms2query.ms2query_development.TopKTanimotoScores import TopKTanimotoScores + + +class ReferenceLibrary: + # Set default file names to enable save and load per library + embedding_file_name = "embeddings.npz" + top_k_tanimoto_scores_file_name = "top_k_tanimoto_scores.parquet" + reference_metadata_file_name = "library_metadata.parquet" + ms2deepscore_model_file_name = "ms2deepscore_model.pt" + metadata_to_store = [ + "precursor_mz", + "retention_time", + "collision_energy", + "compound_name", + "smiles", + "inchikey", + ] + fingerprint_type = "daylight" + fingerprint_nbits = 4096 + top_k_inchikeys = 8 + + def __init__( + self, + ms2deepscore_model: SiameseSpectralModel, + reference_embeddings: Embeddings, + top_k_tanimoto_scores: TopKTanimotoScores, + reference_metadata: pd.DataFrame, + ): + self.ms2deepscore_model = ms2deepscore_model + self.reference_embeddings = reference_embeddings + self.top_k_tanimoto_scores = top_k_tanimoto_scores + self.reference_metadata = reference_metadata + + # Check that the loaded files match + if _to_json_serializable(ms2deepscore_model.model_settings.get_dict()) != reference_embeddings.model_settings: + raise ValueError( + "The settings of the ms2deepscore model do not match the model used for creating the library embeddings" + ) + if list(self.reference_metadata["spectrum_hashes"]) != [ + str(spectrum_hash) for spectrum_hash in reference_embeddings.index_to_spectrum_hash + ]: + raise ValueError("The loaded metadata does not match the used embeddings") + if {inchikey[:14] for inchikey in reference_metadata["inchikey"]} != set( + top_k_tanimoto_scores.top_k_inchikeys_and_scores.index + ): + raise ValueError("The inchikeys in the metadata and in the top_k_tanimoto_scores do not match") + + # Get the spectrum_indices_per_inchikey + self.spectrum_indices_per_inchikey = defaultdict(list) + for lib_spec_index, inchikey in enumerate(reference_metadata["inchikey"]): + self.spectrum_indices_per_inchikey[inchikey[:14]].append(lib_spec_index) + + @classmethod + def load_from_directory(cls, library_file_directory) -> "ReferenceLibrary": + reference_embeddings_file = library_file_directory / cls.embedding_file_name + top_k_tanimoto_scores_file = library_file_directory / cls.top_k_tanimoto_scores_file_name + reference_metadata_file = library_file_directory / cls.reference_metadata_file_name + ms2deepscore_model_file_name = library_file_directory / cls.ms2deepscore_model_file_name + return cls.load_from_files( + ms2deepscore_model_file_name, reference_embeddings_file, top_k_tanimoto_scores_file, reference_metadata_file + ) + + @classmethod + def load_from_files( + cls, + ms2deepscore_model_file_name, + reference_embeddings_file, + top_k_tanimoto_scores_file, + reference_metadata_file, + ) -> "ReferenceLibrary": + return cls( + load_model(ms2deepscore_model_file_name), + Embeddings.load(reference_embeddings_file), + TopKTanimotoScores.load(top_k_tanimoto_scores_file), + pd.read_parquet(reference_metadata_file), + ) + + @classmethod + def create_from_spectra( + cls, + library_spectra: Sequence[Spectrum], + ms2deepscore_model_file_name: str, + store_file_directory=None, + store_files=True, + ) -> "ReferenceLibrary": + """Creates all the files needed for MS2Query and stores them""" + if store_file_directory is None: + store_file_directory = Path(ms2deepscore_model_file_name).parent + else: + store_file_directory = Path(store_file_directory) + if store_files: + # Check the files don't exist yet + for file in ( + store_file_directory / cls.embedding_file_name, + store_file_directory / cls.top_k_tanimoto_scores_file_name, + store_file_directory / cls.reference_metadata_file_name, + ): + if file.exists(): + raise FileExistsError(f"There is already a file stored with the name {file}") + + # library_spectra = list(tqdm(load_spectra(library_spectra_file), "Loading library spectra")) + library_spectrum_set = AnnotatedSpectrumSet.create_spectrum_set(library_spectra) + ms2deepscore_model = load_model(ms2deepscore_model_file_name) + library_spectrum_set.add_embeddings(ms2deepscore_model) + + fingerprints = Fingerprints.from_spectrum_set(library_spectrum_set, cls.fingerprint_type, cls.fingerprint_nbits) + top_k_tanimoto_scores = TopKTanimotoScores.calculate_from_fingerprints( + fingerprints, fingerprints, cls.top_k_inchikeys + ) + reference_metadata = extract_metadata_from_library( + library_spectrum_set, + cls.metadata_to_store, + ) + + if store_files: + reference_metadata.to_parquet(store_file_directory / cls.reference_metadata_file_name) + top_k_tanimoto_scores.save(store_file_directory / cls.top_k_tanimoto_scores_file_name) + library_spectrum_set.embeddings.save(store_file_directory / cls.embedding_file_name) + return cls(ms2deepscore_model, library_spectrum_set.embeddings, top_k_tanimoto_scores, reference_metadata) + + def run_ms2query( + self, + query_spectra: Sequence[Spectrum], + batch_size: int = 1000, + ) -> pd.DataFrame: + + query_embeddings = Embeddings.create_from_spectra(query_spectra, self.ms2deepscore_model) + + num_of_query_embeddings = query_embeddings.embeddings.shape[0] + + library_index_highest_ms2deepscore = np.zeros((num_of_query_embeddings), dtype=int) + ms2query_scores = [] + for start_idx in tqdm( + range(0, num_of_query_embeddings, batch_size), + desc="Predicting highest ms2deepscore per batch of " + + str(min(batch_size, num_of_query_embeddings)) + + " embeddings", + ): + # Do MS2DeepScore predictions for batch + end_idx = min(start_idx + batch_size, num_of_query_embeddings) + selected_query_embeddings = query_embeddings.embeddings[start_idx:end_idx] + score_matrix = cosine_similarity_matrix(selected_query_embeddings, self.reference_embeddings.embeddings) + highest_score_idx = np.argmax(score_matrix, axis=1) + library_index_highest_ms2deepscore[start_idx:end_idx] = highest_score_idx + + # get predicted inchikeys + predicted_inchikeys = self.reference_metadata.iloc[highest_score_idx]["inchikey"] + # Compute MS2Query reliability score + ms2query_scores.extend( + get_ms2query_reliability_prediction( + predicted_inchikeys, self.spectrum_indices_per_inchikey, self.top_k_tanimoto_scores, score_matrix + ) + ) + + # construct results df + results = self.reference_metadata.iloc[library_index_highest_ms2deepscore] + results["ms2query_reliability_prediction"] = ms2query_scores + return results + + +def run_ms2query_from_files( + query_spectrum_file, + ms2deepscore_model_file_name, + reference_embeddings_file, + top_k_tanimoto_scores_file, + reference_metadata_file, + save_file_location, +): + ms2query_library = ReferenceLibrary.load_from_files( + ms2deepscore_model_file_name, + reference_embeddings_file, + top_k_tanimoto_scores_file, + reference_metadata_file, + ) + + query_spectra = list(tqdm(load_spectra(query_spectrum_file), desc="loading_in_query_spectra")) + results_df = ms2query_library.run_ms2query(query_spectra) + results_df.to_csv(save_file_location) + + +def get_ms2query_reliability_prediction( + predicted_inchikeys: list[str], + spectrum_indices_per_inchikey, + top_k_tanimoto_scores: TopKTanimotoScores, + ms2deepscore_score_matrix, +) -> list[float]: + ms2query_scores = [] + for query_spectrum_index, library_inchikey in enumerate(predicted_inchikeys): + top_k_inchikeys = top_k_tanimoto_scores.select_top_k_inchikeys(library_inchikey[:14]) + maximum_ms2deepscores = np.zeros(top_k_tanimoto_scores.k, dtype=float) + for i, inchikey in enumerate(top_k_inchikeys): + spectrum_indexes = spectrum_indices_per_inchikey[inchikey] + highest_ms2deepscore = np.max(ms2deepscore_score_matrix[query_spectrum_index, spectrum_indexes]) + maximum_ms2deepscores[i] = highest_ms2deepscore + ms2query_scores.append(np.mean(maximum_ms2deepscores)) + # todo get the spectrum hashes instead of the indexes for lookup later. + return ms2query_scores + + +def extract_metadata_from_library(spectra: AnnotatedSpectrumSet, metadata_to_collect: list): + collected_metadata = {key: [] for key in metadata_to_collect} + collected_metadata["spectrum_hashes"] = [] + for spectrum in tqdm(spectra.spectra, desc="Extracting metadata df from spectra"): + for metadata_key in metadata_to_collect: + collected_metadata[metadata_key].append(spectrum.get(metadata_key)) + collected_metadata["spectrum_hashes"].append(str(spectrum.__hash__())) + return pd.DataFrame(collected_metadata) diff --git a/ms2query/benchmarking/TopKTanimotoScores.py b/ms2query/ms2query_development/TopKTanimotoScores.py similarity index 69% rename from ms2query/benchmarking/TopKTanimotoScores.py rename to ms2query/ms2query_development/TopKTanimotoScores.py index 07c0912..c930ecc 100644 --- a/ms2query/benchmarking/TopKTanimotoScores.py +++ b/ms2query/ms2query_development/TopKTanimotoScores.py @@ -1,7 +1,8 @@ +from pathlib import Path import numpy as np import pandas as pd -from ms2query.benchmarking.Fingerprints import Fingerprints from ms2query.metrics import generalized_tanimoto_similarity_matrix +from ms2query.ms2query_development.Fingerprints import Fingerprints class TopKTanimotoScores: @@ -27,13 +28,21 @@ def _create_multi_index( combined_data = np.empty((len(inchikey_indexes), self.k * 2), dtype=object) combined_data[:, 0::2] = top_k_inchikeys combined_data[:, 1::2] = tanimoto_scores_for_top_k - return pd.DataFrame(combined_data, index=inchikey_indexes, columns=columns) + df = pd.DataFrame(combined_data, index=inchikey_indexes, columns=columns) + + # Cast score columns to float64 + score_cols = [(rank, "score") for rank in [f"Rank_{i + 1}" for i in range(self.k)]] + df[score_cols] = df[score_cols].astype(float) + + return df @classmethod def calculate_from_fingerprints(cls, query_fingerprints: Fingerprints, target_fingerprints: Fingerprints, k): """ Gets the top k highest inchikeys and scores for each inchikey in query_fingerprints from target_fingerprints """ + if target_fingerprints.fingerprints.shape[0] < k: + raise ValueError("K cannot be larger than the number of fingerprints") similarity_scores = generalized_tanimoto_similarity_matrix( query_fingerprints.fingerprints, target_fingerprints.fingerprints ) @@ -67,3 +76,30 @@ def get_all_average_tanimoto_scores(self) -> dict[str, float]: average_per_inchikey_df = scores_df.mean(axis=1) return average_per_inchikey_df.to_dict() + + def save(self, path: str | Path) -> None: + """Save the TopKTanimotoScores to disk as a parquet file. + + Args: + path: File path without extension, e.g. "/data/top_k_scores". + """ + Path(path).with_suffix(".parquet").parent.mkdir(parents=True, exist_ok=True) + self.top_k_inchikeys_and_scores.to_parquet(Path(path).with_suffix(".parquet")) + + @classmethod + def load(cls, path: str | Path) -> "TopKTanimotoScores": + """Load a previously saved TopKTanimotoScores from disk. + + Args: + path: File path without extension, e.g. "/data/top_k_scores". + + Returns: + A fully reconstructed TopKTanimotoScores instance. + """ + df = pd.read_parquet(Path(path).with_suffix(".parquet")) + df.columns.names = ["result_rank", "attribute"] + + instance = cls.__new__(cls) + instance.k = len(df.columns.get_level_values("result_rank").unique()) + instance.top_k_inchikeys_and_scores = df + return instance diff --git a/ms2query/benchmarking/__init__.py b/ms2query/ms2query_development/__init__.py similarity index 100% rename from ms2query/benchmarking/__init__.py rename to ms2query/ms2query_development/__init__.py diff --git a/ms2query/benchmarking/reference_methods/EvaluateExactMatchSearch.py b/ms2query/ms2query_development/reference_methods/EvaluateExactMatchSearch.py similarity index 98% rename from ms2query/benchmarking/reference_methods/EvaluateExactMatchSearch.py rename to ms2query/ms2query_development/reference_methods/EvaluateExactMatchSearch.py index eced66b..a57368e 100644 --- a/ms2query/benchmarking/reference_methods/EvaluateExactMatchSearch.py +++ b/ms2query/ms2query_development/reference_methods/EvaluateExactMatchSearch.py @@ -1,7 +1,7 @@ import random from typing import Callable, List, Tuple from tqdm import tqdm -from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet +from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet class EvaluateExactMatchSearchAcrossIonmodes: diff --git a/ms2query/benchmarking/MS2DeepScoresForTopInChikeys.py b/ms2query/ms2query_development/reference_methods/MS2DeepScoresForTopInChikeys.py similarity index 92% rename from ms2query/benchmarking/MS2DeepScoresForTopInChikeys.py rename to ms2query/ms2query_development/reference_methods/MS2DeepScoresForTopInChikeys.py index d2519c7..883ab35 100644 --- a/ms2query/benchmarking/MS2DeepScoresForTopInChikeys.py +++ b/ms2query/ms2query_development/reference_methods/MS2DeepScoresForTopInChikeys.py @@ -1,12 +1,12 @@ import numpy as np from ms2deepscore.vector_operations import cosine_similarity_matrix from tqdm import tqdm -from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet -from ms2query.benchmarking.Fingerprints import Fingerprints -from ms2query.benchmarking.predict_top_k_ms2deepscore import ( +from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet +from ms2query.ms2query_development.Fingerprints import Fingerprints +from ms2query.ms2query_development.reference_methods.predict_top_k_ms2deepscore import ( select_inchikeys_with_highest_ms2deepscore, ) -from ms2query.benchmarking.TopKTanimotoScores import TopKTanimotoScores +from ms2query.ms2query_development.TopKTanimotoScores import TopKTanimotoScores def calculate_MS2DeepScoresForTopKInChikeys_from_spectra( @@ -77,6 +77,9 @@ def calculate_MS2DeepScoresForTopKInChikeys( class MS2DeepScoresForTopKInChikeys: """Stores the MS2DeepScores and Tanimoto scores for the top k closest lib spectra + This is only needed for the benchmarking and development (in the notebooks) + and is not used for running the final verison of MS2Query + This allows for quick testing of different reranking strategies. E.g. get_mean is similar to the original MS2Query, but it can also be used to make matrixes with both MS2DeepScore and tanimoto scores to train small reranking models. diff --git a/ms2query/benchmarking/reference_methods/__init__.py b/ms2query/ms2query_development/reference_methods/__init__.py similarity index 100% rename from ms2query/benchmarking/reference_methods/__init__.py rename to ms2query/ms2query_development/reference_methods/__init__.py diff --git a/ms2query/benchmarking/reference_methods/predict_best_possible_match.py b/ms2query/ms2query_development/reference_methods/predict_best_possible_match.py similarity index 93% rename from ms2query/benchmarking/reference_methods/predict_best_possible_match.py rename to ms2query/ms2query_development/reference_methods/predict_best_possible_match.py index 1b82e04..e27d7ec 100644 --- a/ms2query/benchmarking/reference_methods/predict_best_possible_match.py +++ b/ms2query/ms2query_development/reference_methods/predict_best_possible_match.py @@ -1,7 +1,7 @@ from typing import Dict from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix -from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet -from ms2query.benchmarking.Fingerprints import Fingerprints +from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet +from ms2query.ms2query_development.Fingerprints import Fingerprints def predict_best_possible_match( diff --git a/ms2query/benchmarking/predict_top_k_ms2deepscore.py b/ms2query/ms2query_development/reference_methods/predict_top_k_ms2deepscore.py similarity index 96% rename from ms2query/benchmarking/predict_top_k_ms2deepscore.py rename to ms2query/ms2query_development/reference_methods/predict_top_k_ms2deepscore.py index 1da3b2e..149807c 100644 --- a/ms2query/benchmarking/predict_top_k_ms2deepscore.py +++ b/ms2query/ms2query_development/reference_methods/predict_top_k_ms2deepscore.py @@ -2,8 +2,8 @@ import numpy as np from ms2deepscore.vector_operations import cosine_similarity_matrix from tqdm import tqdm -from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet -from ms2query.benchmarking.Embeddings import Embeddings +from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet +from ms2query.ms2query_development.Embeddings import Embeddings def predict_top_k_ms2deepscores( diff --git a/ms2query/ms2query_development/reference_methods/readme.md b/ms2query/ms2query_development/reference_methods/readme.md new file mode 100644 index 0000000..289ff6a --- /dev/null +++ b/ms2query/ms2query_development/reference_methods/readme.md @@ -0,0 +1 @@ +All files here are not important to core functionality and in fact not used for running the final version of MS2Query. However MS2DeepSCoresForTopInChikeys and predict_top-k_ms2deepscore are core to the benchmarking and experimenting in the notebooks. predict_best_possible_match and EvaluateExactMatchSearch are not yet used for the benchmarking, but both should still be done, so this could be used for that. diff --git a/ms2query/notebooks/develop_and_compare_ms2query_2.ipynb b/ms2query/notebooks/develop_and_compare_ms2query_2.ipynb index 4b244b2..80a3ba6 100644 --- a/ms2query/notebooks/develop_and_compare_ms2query_2.ipynb +++ b/ms2query/notebooks/develop_and_compare_ms2query_2.ipynb @@ -252,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "6fecef86-84af-40e9-bf48-d81d62f8bcf4", "metadata": {}, "outputs": [ @@ -270,7 +270,7 @@ } ], "source": [ - "from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet\n", + "from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet\n", "\n", "pos_test_spectra = AnnotatedSpectrumSet.create_spectrum_set(pos_test_spectra)\n", "pos_train_spectra = AnnotatedSpectrumSet.create_spectrum_set(pos_train_spectra)\n", @@ -279,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "92a0e90c-d435-4a0e-b9ba-bc7ab2139ca6", "metadata": {}, "outputs": [ @@ -294,7 +294,7 @@ } ], "source": [ - "from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet\n", + "from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet\n", "\n", "neg_val_spectra = AnnotatedSpectrumSet.create_spectrum_set(neg_val_spectra)\n", "neg_test_spectra = AnnotatedSpectrumSet.create_spectrum_set(neg_test_spectra)\n", @@ -428,7 +428,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "3453717f-b50d-4623-bd87-b1287c0fe916", "metadata": {}, "outputs": [ @@ -444,7 +444,7 @@ } ], "source": [ - "from ms2query.benchmarking.Fingerprints import Fingerprints\n", + "from ms2query.ms2query_development.Fingerprints import Fingerprints\n", "val_fingerprints = Fingerprints.from_spectrum_set(pos_val_spectra, \"daylight\", 4096)\n", "train_fingerprints = Fingerprints.from_spectrum_set(pos_train_spectra, \"daylight\", 4096)\n", "val_and_train_fingerprints = Fingerprints.combine_fingerprints(val_fingerprints, train_fingerprints)" @@ -461,12 +461,12 @@ }, { "cell_type": "code", - "execution_count": 102, + "execution_count": null, "id": "cc395ecc-9473-4c9d-979f-ff8c087293d5", "metadata": {}, "outputs": [], "source": [ - "from ms2query.benchmarking.TopKTanimotoScores import TopKTanimotoScores\n", + "from ms2query.ms2query_development.TopKTanimotoScores import TopKTanimotoScores\n", "top_k_tanimoto_scores = TopKTanimotoScores.calculate_from_fingerprints(\n", " train_fingerprints,\n", " train_fingerprints, k=100,)" @@ -488,7 +488,7 @@ "metadata": {}, "outputs": [], "source": [ - "from ms2query.benchmarking.predict_top_k_ms2deepscore import select_inchikeys_with_highest_ms2deepscore\n", + "from ms2query.ms2query_development.reference_methods.predict_top_k_ms2deepscore import select_inchikeys_with_highest_ms2deepscore\n", "\n", "inchikeys_with_highest_ms2deepscores = select_inchikeys_with_highest_ms2deepscore(\n", " pos_val_spectra, pos_train_spectra, nr_of_inchikeys_to_select=1, batch_size=1000,)\n" @@ -509,7 +509,7 @@ } ], "source": [ - "from ms2query.benchmarking.MS2DeepScoresForTopInChikeys import calculate_MS2DeepScoresForTopKInChikeys\n", + "from ms2query.ms2query_development.reference_methods.MS2DeepScoresForTopInChikeys import calculate_MS2DeepScoresForTopKInChikeys\n", "close_tanimoto_scores = calculate_MS2DeepScoresForTopKInChikeys(pos_train_spectra, pos_val_spectra, top_k_tanimoto_scores, inchikeys_with_highest_ms2deepscores)" ] }, @@ -2630,10 +2630,10 @@ } ], "source": [ - "from ms2query.benchmarking.Fingerprints import Fingerprints\n", - "from ms2query.benchmarking.TopKTanimotoScores import TopKTanimotoScores\n", - "from ms2query.benchmarking.predict_top_k_ms2deepscore import select_inchikeys_with_highest_ms2deepscore\n", - "from ms2query.benchmarking.MS2DeepScoresForTopInChikeys import calculate_MS2DeepScoresForTopKInChikeys\n", + "from ms2query.ms2query_development.Fingerprints import Fingerprints\n", + "from ms2query.ms2query_development.TopKTanimotoScores import TopKTanimotoScores\n", + "from ms2query.ms2query_development.reference_methods.predict_top_k_ms2deepscore import select_inchikeys_with_highest_ms2deepscore\n", + "from ms2query.ms2query_development.reference_methods.MS2DeepScoresForTopInChikeys import calculate_MS2DeepScoresForTopKInChikeys\n", "\n", "neg_val_fingerprints = Fingerprints.from_spectrum_set(neg_val_spectra, \"daylight\", 4096)\n", "neg_train_fingerprints = Fingerprints.from_spectrum_set(neg_train_spectra, \"daylight\", 4096)\n", @@ -2714,7 +2714,7 @@ "metadata": {}, "outputs": [], "source": [ - "from ms2query.benchmarking.predict_top_k_ms2deepscore import predict_top_k_ms2deepscores\n", + "from ms2query.ms2query_development.reference_methods.predict_top_k_ms2deepscore import predict_top_k_ms2deepscores\n", "\n", "indexes, ms2deepscores = predict_top_k_ms2deepscores(neg_train_spectra.embeddings, neg_val_spectra.embeddings)\n", "neg_predicted_inchikeys_ms2deepscore = [neg_train_spectra.spectra[int(index)].get(\"inchikey\")[:14] for index in indexes]\n", diff --git a/ms2query/notebooks/tutorial.ipynb b/ms2query/notebooks/tutorial.ipynb new file mode 100644 index 0000000..411561f --- /dev/null +++ b/ms2query/notebooks/tutorial.ipynb @@ -0,0 +1,196 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "70e60490", + "metadata": {}, + "source": [ + "# Tutorial for running MS2Query and creating your own libraries" + ] + }, + { + "cell_type": "markdown", + "id": "4d564033", + "metadata": {}, + "source": [ + "# optional: download a matchms cleaned library\n", + "The code below downloads an already matchms cleaned library and an MS2DeepScore model. You can also use your own library, but make sure you know what you are doing and clean the library first. If you just have a few reference spectra, it is probably best to combine your spectra with the reference spectra below to make sure the MS2Query search works properly. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7df418e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The file ./zenodo_files\\data_split_inchikeys.json already exists, the file won't be downloaded\n", + "The file ./zenodo_files\\merged_and_cleaned_libraries_1.mgf already exists, the file won't be downloaded\n", + "The file ./zenodo_files\\ms2deepscore_model.pt already exists, the file won't be downloaded\n" + ] + } + ], + "source": [ + "import requests\n", + "import os\n", + "from tqdm import tqdm\n", + "\n", + "def download_file(link, file_name):\n", + " response = requests.get(link, stream=True)\n", + " if os.path.exists(file_name):\n", + " print(f\"The file {file_name} already exists, the file won't be downloaded\")\n", + " return\n", + " total_size = int(response.headers.get('content-length', 0))\n", + "\n", + " with open(file_name, \"wb\") as f, tqdm(desc=\"Downloading file\", total=total_size, unit='B', unit_scale=True, unit_divisor=1024,) as bar:\n", + " for chunk in response.iter_content(chunk_size=1024):\n", + " if chunk:\n", + " f.write(chunk)\n", + " bar.update(len(chunk)) # Update progress bar by the chunk size\n", + "folder_to_store_zenodo_files = \"./zenodo_files\"\n", + "os.makedirs(folder_to_store_zenodo_files, exist_ok=True)\n", + "\n", + "download_file(\"https://zenodo.org/records/16882111/files/merged_and_cleaned_libraries_1.mgf?download=1\", \n", + " os.path.join(folder_to_store_zenodo_files, \"merged_and_cleaned_libraries_1.mgf\"))\n", + "download_file(\"https://zenodo.org/records/17826815/files/ms2deepscore_model.pt?download=1\", \n", + " os.path.join(folder_to_store_zenodo_files, \"ms2deepscore_model.pt\"))" + ] + }, + { + "cell_type": "markdown", + "id": "fa99d2dd", + "metadata": {}, + "source": [ + "# Specify file location \n", + "Replace with your file names" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6db579a9", + "metadata": {}, + "outputs": [], + "source": [ + "library_spectra_file = os.path.join(folder_to_store_zenodo_files, \"merged_and_cleaned_libraries_1.mgf\")\n", + "ms2deepscore_model_file_name = os.path.join(folder_to_store_zenodo_files, \"ms2deepscore_model.pt\")\n", + "query_spectrum_file = \"replace_with_your_lib_spectra.mgf\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f521c1c4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "1017531it [09:51, 1720.67it/s]\n" + ] + } + ], + "source": [ + "from matchms.importing import load_from_mgf\n", + "from tqdm import tqdm\n", + "\n", + "library_spectra = list(tqdm(load_from_mgf(library_spectra_file)))\n", + "query_spectra = list(tqdm(load_from_mgf(query_spectrum_file)))" + ] + }, + { + "cell_type": "markdown", + "id": "64534dd6", + "metadata": {}, + "source": [ + "# Create the reference library files\n", + "The code below will precompute everything needed to run MS2Query. It will save this in the same folder as your ms2deepscore model. \n", + "The files created are \"embeddings.npz\", \"top_k_tanimoto_scores.parquet\", \"library_metadata.parquet\". " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "669a489a", + "metadata": {}, + "outputs": [], + "source": [ + "from ms2query.ms2query_development.ReferenceLibrary import ReferenceLibrary\n", + "reference_library = ReferenceLibrary.create_from_spectra(library_spectra, ms2deepscore_model_file_name)" + ] + }, + { + "cell_type": "markdown", + "id": "dfd0a633", + "metadata": {}, + "source": [ + "# Run MS2Query\n", + "The code above only has to be run once after that you can load the library faster from the saved files. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6cf44e2", + "metadata": {}, + "outputs": [], + "source": [ + "# no need to run if you just created the libary above\n", + "reference_library = ReferenceLibrary.load_from_directory(folder_to_store_zenodo_files)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "566b2b0d", + "metadata": {}, + "outputs": [], + "source": [ + "results = reference_library.run_ms2query(query_spectra)" + ] + }, + { + "cell_type": "markdown", + "id": "98ce9660", + "metadata": {}, + "source": [ + "print(results)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7e3b533", + "metadata": {}, + "outputs": [], + "source": [ + "results.to_csv(\"ms2query_results.csv\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ms2query_2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ms2query/readme.md b/ms2query/readme.md new file mode 100644 index 0000000..b2b1267 --- /dev/null +++ b/ms2query/readme.md @@ -0,0 +1,21 @@ +## Basic workflow (so far): +This is for creating the database. This is not yet fully functional, so for now please use the notebooks/tutorial.ipynb if you already want to try out the prototype. + +### Library generation +```python +from ms2query.create_new_library import create_new_library + +ms2query_lib = create_new_library( + spectra_files=["spectra.mgf"], + annotation_files=[], + output_folder="my_ms2query_folder/", + model_path="models/ms2deepscore.pt" +) +``` + +### Loading already generated library +```python +from ms2query.create_new_library import load_created_library + +lib = load_created_library("my_ms2query_folder/") +``` diff --git a/pyproject.toml b/pyproject.toml index fbaac97..b416b53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ ms2deepscore= ">=2.6.0" rdkit= ">2024.3.4" nmslib= ">=2.0.0" umap-learn= ">=0.5.7" +pyarrow= ">=14.0.1" [tool.poetry.group.dev.dependencies] decorator = "^5.1.1" diff --git a/tests/helper_functions.py b/tests/helper_functions.py index 82d1709..5c234b4 100644 --- a/tests/helper_functions.py +++ b/tests/helper_functions.py @@ -4,8 +4,8 @@ import numpy as np from matchms.Spectrum import Spectrum from ms2deepscore.models import load_model -from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet -from ms2query.benchmarking.Fingerprints import Fingerprints +from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet +from ms2query.ms2query_development.Fingerprints import Fingerprints TEST_RESOURCES_PATH = Path(__file__).parent / "test_data" @@ -87,6 +87,36 @@ def get_inchikey_inchi_pairs(number_of_pairs): "CC(C)C[C@@H](C(=O)O)N", "L-Leucine", ), + ( + "BSYNRYMUTXBXSQ-UHFFFAOYSA-N", + "InChI=1S/C9H8O4/c1-6(10)13-8-5-3-2-4-7(8)9(11)12/h2-5H,1H3,(H,11,12)", + "CC(=O)OC1=CC=CC=C1C(=O)O", + "Aspirin", + ), + ( + "WHUUTDBJXJRKMK-VKHMYHEASA-N", + "InChI=1S/C5H9NO4/c6-3(5(9)10)1-2-4(7)8/h3H,1-2,6H2,(H,7,8)(H,9,10)/t3-/m0/s1", + "C(CC(=O)O)[C@@H](C(=O)O)N", + "L-Glutamic acid", + ), + ( + "ZKHQWZAMYRWXGA-KQYNXXCUSA-N", + "InChI=1S/C10H14N5O7P/c11-8-5-9(13-2-12-8)15(3-14-5)10-7(17)6(16)4(22-10)1-21-23(18,19)20/h2-4,6-7,10,16-17H,1H2,(H2,11,12,13)(H2,18,19,20)/t4-,6-,7-,10-/m1/s1", + "C1=NC(=C2C(=N1)N(C=N2)[C@H]3[C@@H]([C@@H]([C@H](O3)COP(=O)(O)O)O)O)N", + "AMP", + ), + ( + "GVJHHUAWPYXKBD-UHFFFAOYSA-N", + "InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3", + "CCO", + "Ethanol", + ), + ( + "IKHGUXGNUITLKF-XPULMUKRSA-N", + "InChI=1S/C9H13NO3/c1-6(11)13-8-4-2-7(3-5-8)9(10)12/h2-6,11H,1H3,(H2,10,12)/t6-/m0/s1", + "C[C@@H](C1=CC=C(C=C1)C(=O)N)O", + "Salbutamol", + ), ) if number_of_pairs > len(inchikey_inchi_pairs): raise ValueError("Not enough example compounds, add some in conftest") diff --git a/tests/test_benchmarking/test_top_k_tanimoto_scores.py b/tests/test_benchmarking/test_top_k_tanimoto_scores.py deleted file mode 100644 index 2c8ecd2..0000000 --- a/tests/test_benchmarking/test_top_k_tanimoto_scores.py +++ /dev/null @@ -1,28 +0,0 @@ -import numpy as np -import pytest -from ms2query.benchmarking.TopKTanimotoScores import TopKTanimotoScores -from tests.helper_functions import make_test_fingerprints - - -def test_methods_top_k_tanimoto_scores(): - scores = np.array([[1.0, 0.8], [0.9, 0.7], [0.8, 0.8]]) - inchikeys_in_top_2 = np.array([["A", "B"], ["B", "C"], ["C", "A"]]) - inchikeys = np.array(["A", "B", "C"]) - top_scores = TopKTanimotoScores(scores, inchikeys_in_top_2, inchikeys) - - assert top_scores.select_top_k_inchikeys_and_scores("A") == {"A": 1.0, "B": 0.8} - - assert top_scores.select_top_k_inchikeys("A") == ["A", "B"] - - assert top_scores.select_average_score("A") == pytest.approx(0.9) - - assert top_scores.get_all_average_tanimoto_scores() == {"A": 0.9, "B": 0.8, "C": 0.8} - - -def test_calculate_from_fingerprints(): - fingerprints = make_test_fingerprints(nbits=5, nr_of_inchikeys=5) - top_scores = TopKTanimotoScores.calculate_from_fingerprints(fingerprints, fingerprints, 2) - assert top_scores.select_top_k_inchikeys_and_scores("AAAAAAAAAAAAAE") == { - "AAAAAAAAAAAAAD": 0.75, - "AAAAAAAAAAAAAE": 1.0, - } diff --git a/tests/test_benchmarking/test_Fingerprints.py b/tests/test_ms2query_development/test_Fingerprints.py similarity index 96% rename from tests/test_benchmarking/test_Fingerprints.py rename to tests/test_ms2query_development/test_Fingerprints.py index 6ab8083..6bffc74 100644 --- a/tests/test_benchmarking/test_Fingerprints.py +++ b/tests/test_ms2query_development/test_Fingerprints.py @@ -1,8 +1,8 @@ import numpy as np import pytest from matchms.filtering.metadata_processing.add_fingerprint import _derive_fingerprint_from_inchi -from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet -from ms2query.benchmarking.Fingerprints import Fingerprints, get_similarity_matrix +from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet +from ms2query.ms2query_development.Fingerprints import Fingerprints, get_similarity_matrix from tests.helper_functions import create_test_spectra, get_inchikey_inchi_pairs diff --git a/tests/test_benchmarking/test_MS2DeepScoresForTopInChikeys.py b/tests/test_ms2query_development/test_MS2DeepScoresForTopInChikeys.py similarity index 78% rename from tests/test_benchmarking/test_MS2DeepScoresForTopInChikeys.py rename to tests/test_ms2query_development/test_MS2DeepScoresForTopInChikeys.py index f0ed9df..a065fed 100644 --- a/tests/test_benchmarking/test_MS2DeepScoresForTopInChikeys.py +++ b/tests/test_ms2query_development/test_MS2DeepScoresForTopInChikeys.py @@ -1,5 +1,5 @@ -from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet -from ms2query.benchmarking.MS2DeepScoresForTopInChikeys import ( +from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet +from ms2query.ms2query_development.reference_methods.MS2DeepScoresForTopInChikeys import ( calculate_MS2DeepScoresForTopKInChikeys_from_spectra, ) from tests.helper_functions import create_test_spectra, ms2deepscore_model diff --git a/tests/test_ms2query_development/test_ReferenceLibrary.py b/tests/test_ms2query_development/test_ReferenceLibrary.py new file mode 100644 index 0000000..77b5bf9 --- /dev/null +++ b/tests/test_ms2query_development/test_ReferenceLibrary.py @@ -0,0 +1,62 @@ +import os +import pandas as pd +from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet +from ms2query.ms2query_development.Fingerprints import Fingerprints +from ms2query.ms2query_development.ReferenceLibrary import ( + ReferenceLibrary, + extract_metadata_from_library, +) +from ms2query.ms2query_development.TopKTanimotoScores import TopKTanimotoScores +from tests.helper_functions import TEST_RESOURCES_PATH, create_test_spectra, ms2deepscore_model + + +def test_run_ms2query(): + model = ms2deepscore_model() + library_spectra = AnnotatedSpectrumSet.create_spectrum_set(create_test_spectra(nr_of_inchikeys=7)) + test_spectra = create_test_spectra(1, nr_of_inchikeys=3) + library_spectra.add_embeddings(model) + fingerprints = Fingerprints.from_spectrum_set(library_spectra, "daylight", 100) + top_k_tanimoto_scores = TopKTanimotoScores.calculate_from_fingerprints(fingerprints, fingerprints, 3) + metadata_library = extract_metadata_from_library( + library_spectra, + [ + "precursor_mz", + "collision_energy", + "compound_name", + "smiles", + "inchikey", + ], + ) + + results = ReferenceLibrary(model, library_spectra.embeddings, top_k_tanimoto_scores, metadata_library).run_ms2query( + test_spectra + ) + print(results) + + +def test_create_library(tmp_path): + lib_spectra = create_test_spectra(nr_of_inchikeys=10, number_of_spectra_per_inchikey=3) + # save_as_mgf(lib_spectra, os.path.join(tmp_path, "library_spectra.mgf")) + ms2deepscore_model_file = os.path.join(TEST_RESOURCES_PATH, "ms2deepscore_testmodel_v1.pt") + ReferenceLibrary.create_from_spectra(lib_spectra, ms2deepscore_model_file, tmp_path) + assert (tmp_path / ReferenceLibrary.embedding_file_name).exists() + assert (tmp_path / ReferenceLibrary.top_k_tanimoto_scores_file_name).exists() + assert (tmp_path / ReferenceLibrary.reference_metadata_file_name).exists() + + +def test_create_and_use_library(tmp_path): + lib_spectra = create_test_spectra(nr_of_inchikeys=10, number_of_spectra_per_inchikey=3) + ms2deepscore_model_file = os.path.join(TEST_RESOURCES_PATH, "ms2deepscore_testmodel_v1.pt") + ms2query_library = ReferenceLibrary.create_from_spectra(lib_spectra, ms2deepscore_model_file, tmp_path) + test_spectra = create_test_spectra(1, nr_of_inchikeys=3) + results = ms2query_library.run_ms2query(test_spectra) + + ms2query_library_2 = ReferenceLibrary.load_from_files( + ms2deepscore_model_file, + tmp_path / ReferenceLibrary.embedding_file_name, + tmp_path / ReferenceLibrary.top_k_tanimoto_scores_file_name, + tmp_path / ReferenceLibrary.reference_metadata_file_name, + ) + + results_2 = ms2query_library_2.run_ms2query(test_spectra) + pd.testing.assert_frame_equal(results, results_2) diff --git a/tests/test_benchmarking/test_SpectrumDataSet.py b/tests/test_ms2query_development/test_SpectrumDataSet.py similarity index 98% rename from tests/test_benchmarking/test_SpectrumDataSet.py rename to tests/test_ms2query_development/test_SpectrumDataSet.py index 3e252eb..94baa36 100644 --- a/tests/test_benchmarking/test_SpectrumDataSet.py +++ b/tests/test_ms2query_development/test_SpectrumDataSet.py @@ -1,6 +1,6 @@ import os import pytest -from ms2query.benchmarking.AnnotatedSpectrumSet import ( +from ms2query.ms2query_development.AnnotatedSpectrumSet import ( AnnotatedSpectrumSet, ) from tests.helper_functions import create_test_spectra, ms2deepscore_model diff --git a/tests/test_benchmarking/test_embeddings.py b/tests/test_ms2query_development/test_embeddings.py similarity index 95% rename from tests/test_benchmarking/test_embeddings.py rename to tests/test_ms2query_development/test_embeddings.py index 86fbbbd..5ad42c9 100644 --- a/tests/test_benchmarking/test_embeddings.py +++ b/tests/test_ms2query_development/test_embeddings.py @@ -1,6 +1,6 @@ import os import pytest -from ms2query.benchmarking.Embeddings import Embeddings, calculate_ms2deepscore_df +from ms2query.ms2query_development.Embeddings import Embeddings, calculate_ms2deepscore_df from tests.helper_functions import create_test_spectra, get_library_and_test_spectra_not_identical, ms2deepscore_model diff --git a/tests/test_benchmarking/test_methods.py b/tests/test_ms2query_development/test_methods.py similarity index 78% rename from tests/test_benchmarking/test_methods.py rename to tests/test_ms2query_development/test_methods.py index b584df4..240c11b 100644 --- a/tests/test_benchmarking/test_methods.py +++ b/tests/test_ms2query_development/test_methods.py @@ -1,6 +1,6 @@ import numpy as np -from ms2query.benchmarking.Fingerprints import Fingerprints -from ms2query.benchmarking.reference_methods.predict_best_possible_match import predict_best_possible_match +from ms2query.ms2query_development.Fingerprints import Fingerprints +from ms2query.ms2query_development.reference_methods.predict_best_possible_match import predict_best_possible_match from tests.helper_functions import ( get_library_and_test_spectra_not_identical, ) diff --git a/tests/test_benchmarking/test_predict_top_ms2deepscores.py b/tests/test_ms2query_development/test_predict_top_ms2deepscores.py similarity index 90% rename from tests/test_benchmarking/test_predict_top_ms2deepscores.py rename to tests/test_ms2query_development/test_predict_top_ms2deepscores.py index e830017..7cd00ce 100644 --- a/tests/test_benchmarking/test_predict_top_ms2deepscores.py +++ b/tests/test_ms2query_development/test_predict_top_ms2deepscores.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet -from ms2query.benchmarking.predict_top_k_ms2deepscore import ( +from ms2query.ms2query_development.AnnotatedSpectrumSet import AnnotatedSpectrumSet +from ms2query.ms2query_development.reference_methods.predict_top_k_ms2deepscore import ( predict_top_k_ms2deepscores, select_inchikeys_with_highest_ms2deepscore, ) diff --git a/tests/test_ms2query_development/test_top_k_tanimoto_scores.py b/tests/test_ms2query_development/test_top_k_tanimoto_scores.py new file mode 100644 index 0000000..e15bb15 --- /dev/null +++ b/tests/test_ms2query_development/test_top_k_tanimoto_scores.py @@ -0,0 +1,80 @@ +import numpy as np +import pandas as pd +import pytest +from ms2query.ms2query_development.TopKTanimotoScores import TopKTanimotoScores +from tests.helper_functions import make_test_fingerprints + + +def test_methods_top_k_tanimoto_scores(): + scores = np.array([[1.0, 0.8], [0.9, 0.7], [0.8, 0.8]]) + inchikeys_in_top_2 = np.array([["A", "B"], ["B", "C"], ["C", "A"]]) + inchikeys = np.array(["A", "B", "C"]) + top_scores = TopKTanimotoScores(scores, inchikeys_in_top_2, inchikeys) + + assert top_scores.select_top_k_inchikeys_and_scores("A") == {"A": 1.0, "B": 0.8} + + assert top_scores.select_top_k_inchikeys("A") == ["A", "B"] + + assert top_scores.select_average_score("A") == pytest.approx(0.9) + + assert top_scores.get_all_average_tanimoto_scores() == {"A": 0.9, "B": 0.8, "C": 0.8} + + +def test_calculate_from_fingerprints(): + fingerprints = make_test_fingerprints(nbits=5, nr_of_inchikeys=5) + top_scores = TopKTanimotoScores.calculate_from_fingerprints(fingerprints, fingerprints, 2) + assert top_scores.select_top_k_inchikeys_and_scores("AAAAAAAAAAAAAE") == { + "AAAAAAAAAAAAAD": 0.75, + "AAAAAAAAAAAAAE": 1.0, + } + + +@pytest.fixture +def sample_scores(): + """Creates a simple TopKTanimotoScores instance for testing.""" + tanimoto_scores = np.array( + [ + [0.9, 0.7, 0.5], + [0.8, 0.6, 0.4], + [0.95, 0.85, 0.75], + ] + ) + top_k_inchikeys = np.array( + [ + ["INCHI_A", "INCHI_B", "INCHI_C"], + ["INCHI_B", "INCHI_C", "INCHI_A"], + ["INCHI_C", "INCHI_A", "INCHI_B"], + ] + ) + inchikey_indexes = np.array(["QUERY_1", "QUERY_2", "QUERY_3"]) + return TopKTanimotoScores(tanimoto_scores, top_k_inchikeys, inchikey_indexes) + + +# ----- save and load tests ----- +def test_save_creates_parquet_file(sample_scores, tmp_path): + sample_scores.save(tmp_path / "test_scores") + assert (tmp_path / "test_scores.parquet").exists() + + +def test_save_creates_parent_directories(sample_scores, tmp_path): + sample_scores.save(tmp_path / "nested" / "dir" / "test_scores") + assert (tmp_path / "nested" / "dir" / "test_scores.parquet").exists() + + +def test_roundtrip_produces_identical_object(sample_scores, tmp_path): + sample_scores.save(tmp_path / "test_scores") + loaded = TopKTanimotoScores.load(tmp_path / "test_scores") + + assert loaded.k == sample_scores.k + pd.testing.assert_frame_equal(loaded.top_k_inchikeys_and_scores, sample_scores.top_k_inchikeys_and_scores) + assert sample_scores.select_top_k_inchikeys_and_scores("QUERY_1") == loaded.select_top_k_inchikeys_and_scores( + "QUERY_1" + ) + assert sample_scores.select_top_k_inchikeys("QUERY_2") == loaded.select_top_k_inchikeys("QUERY_2") + assert sample_scores.select_average_score("QUERY_3") == pytest.approx(loaded.select_average_score("QUERY_3")) + + +def test_roundtrip_accepts_string_path(sample_scores, tmp_path): + sample_scores.save(str(tmp_path / "test_scores")) + loaded = TopKTanimotoScores.load(str(tmp_path / "test_scores")) + pd.testing.assert_frame_equal(loaded.top_k_inchikeys_and_scores, sample_scores.top_k_inchikeys_and_scores)