Refactor spectrum set #4
Open: niekdejonge wants to merge 151 commits into main from refactor_spectrum_set (base: main)
Changes from all commits (151 commits)
All commits are by niekdejonge:

- fa4afdd Add SpectrumDataSet for handling sets of spectra, for method developm…
- e882f00 Add EvaluateMethods a general method for benchmarking analogue search…
- 922bea8 Implement base line methods
- 40c6347 Add notebooks used for testing (still have to be adapted to work with…
- 9f67cd2 Add tests
- 4d81ddf Add inits
- 8024263 Use ms2deepscore 2.6.0
- b4561c0 Do correct subsetting of inchikey sets
- 7f76997 Add method for predicting using top 10 closest library spectra.
- 2c2e35f Added extra inchikey smiles examples
- fb4cbd8 Add test_get_inchikey_and_tanimoto_scores_from_top_k
- 1933c85 Move top_k_selection outside average computation for easier testing
- fd2bf89 Add test_get_average_predictions_for_closely_related_metabolites
- 59aca8e Split select_inchikeys_with_highest_ms2deepscores to make more modula…
- f223e43 Add test_select_inchikeys_with_highest_ms2deepscore
- 16058cc Added nr_of_inchikeys_with_highest_ms2deepscore_to_select as parameter
- 1c2f1a7 Added basic tests for predict using closest tanimoto score (checking …
- 276511c Add tqdm to predict using closest tanimoto
- a59e6a9 ruff
- aff52e4 Linting
- cd30383 Lint notebooks
- c97c5a6 Lint line length
- fd688c4 Exclude notebooks from ruff linting
- 48e9dea Merge branch 'main' into add_benchmarking_method
- 8d9a4b7 Change SpectrumSetBase to SpectrumSet
- e95b5f2 Remove an enter
- 682d3cd Add Fingerprint class
- cd41c88 Select most common inchi per inchikey in SpectrumSet
- 5d41fd9 Add subset_fingerprints
- 5759d1b Add test_fingerprints
- a03d958 Add embeddings and fingerprints to SpectrumSet
- e758343 Move and update Fingerprints
- 6fd710b Fix bug for Embeddings when checking unique input
- 3f75740 Use hidden properties for embeddings and fingerprints internally
- 9205286 Use SpectrumSet in predict_highest_cosine.py
- add78e5 Use Embeddings class for predicting top ms2deepscore
- 21ede03 Update predict with integrated similarity flow to use SpectrumSet
- 4c8a8dc Make predict_using_closest_tanimoto.py use SpectrumSet
- 21ab6fd Update EvaluateMethods to use SpectrumSet
- 8a02acf Let predict_best_possible_match.py use SpectrumSet
- 6b54141 Add some extra typehinting
- cf467e7 Make a TopKTanimotoScores class (still need to integrate)
- 1f81d35 Change embeddings to add create_from_spectra as class method and add …
- 25ac2ec Move Embeddings to embeddings
- 6559fec Make index_to_spectrum_hash inmutable and replace add_embeddings with…
- 33626ec Add __eq__ method to embeddings
- 17f1682 Add test_subset_embeddings
- e0d0a16 Add extra tests and fix bug in combine-embeddings
- 7fb3347 Add test_combine_embeddings
- d9a5499 Refactor Fingerprints, init from fingerprints and class methods do co…
- 0191d50 Updated tests for Fingerprints
- 0944fdb change fingerprints.index_to_inchikey to fingerprints.inchikey
- ed2ce28 Add method for creating Fingerprints from SpectrumSet
- e1bf3b3 Refactor SpectrumSet init takes in attributes classmethods handle com…
- 18b4e86 Update type hinting
- af1eaec Add test fingerprint from spectra
- 8da343e Add cloning, copy and __eq__ for spectrumSet
- 91f02d4 Rename SpectrumSet to AnnotatedSpectrumSet
- 97eb6a2 Add check that inchikey is available
- a02f36f Small bug fixes
- f19f9d8 Update test_SpectrumDataSet.py
- 7722382 rename to AnnotatedSpectrumSet.py
- 009af51 Add inchikeys as property to AnnotatedSpectrumSet.py
- 39b5414 Add fingerprints as input to predict_best_possible_match.py
- 56e3a8b Refactor EvaluateMethods and split on analogue search benchmarking an…
- 11e3ace Create EvaluateAnalogueSearch.py
- d2244db Split EvaluateExactMatchSearch in within and across ionmodes
- e1ddd1b Organize methods for within and across ionmodes
- 5125b90 Rename file to EvaluateExactMatchSearch.py
- 61909b8 Add type hint
- 2803df7 First test to evaluate methods
- b10d276 Replace old Spectrum calss with new.
- 5bea2c3 Use Fingerprints in predict_using_closest_tanimoto.py
- 510d4a8 backup because of laptop crash
- 7b4e7f0 Merge pull request #5 from matchms/backup
- 4816fac Fixed type hinting
- 1836fb6 Fix some type hinting
- a726560 Fix type hint
- 77cc1c5 Fix type hint
- f01e67d Fix bug getting inchikeys from library spectra
- 47c436a Add type hint
- 1c85fdd fix test_methods
- 385c8c3 Add check that list is provided to get_fingerprints
- d399aa2 Fix bug in get_inchikey_and_tanimoto_scores_for_top_k
- 1a173db fix test_get_inchikey_and_tanimoto_scores_for_top_k to use Fingerprin…
- c3e0fc3 Fix test
- 7147523 fix test
- 9c98d91 Add type hinting
- a8d5e17 Clarify valueerror
- 3fd1652 Fix bug in select_inchikeys_with_highest_ms2deepscore if tuple is use…
- 7ebd55c Rename indexes to indices
- 8a0fd43 ruff linting
- fab8d45 Ruff linting
- 9ad8dc7 Move bencharking tests to separate folder
- 61edadc rename conftest to helperfunctions
- 72287be Fix test_predict_using_closest_tanimoto import error
- f40f781 Update TopKTanimotoScores
- 2e643b4 Add create fingerprint functionality to test helper functions
- aff2fd0 Add tests for top_k_tanimoto_scores
- 3607bf5 Extended test_calculate from fingerprints
- 381d630 Add calculate_ms2deepscore_df
- 90f1491 Add selct_inchikeys_with_highest_ms2deepscore
- 17ec7c4 rename predict_top_ms2deepscores
- e88a4ed Move get_library_and_test_spectra functions to helper functions
- e7ed35e change reference to predict_top_ms2deepscores
- 7dc3d0a Add optional precomputed embeddings select_inchikeys_with_highest_ms2…
- aa7b6d9 change reference to predict_top_ms2deepscores
- afdfd1b Fix bug with numba threading
- c809095 Fix but in select_inchikeys_with_highest_ms2deepscore
- 2f0e63c Make predict_using_closest_tanimoto use TopKTanimotoScores
- d6418f4 ruff linting
- f7b1b8f add progress bar when creating annotatedspectrumset and remove progre…
- 5354887 Add progress bar for hashing of spectra in embeddings
- f784303 Remove progress bar setting
- 0a750ac Add len to AnnotatedSpectrumsET
- 3e37b4c fix type hint
- 549d1b2 rename predict_top_k_ms2deepscore and make select_inchikeys_with_high…
- e130779 Create MS2DeepScoresForTopKInChikeys
- fef5a20 remove unused import
- 8108df6 First test for MS2DeepScoresForTopinchikeys
- 7c921e2 Fix pylance warning
- e118177 Fix type hinting issue
- 8bb6cda fix pylance issue
- eec6e36 Add has_embedding property to AnnotatedSpectrumSet
- 3e2eeb3 Add warning when adding two spectrumsets where one has embeddings and…
- 0a87c16 small changes
- 7070b1b fix bug when copying AnnotatedSpectrumSet without embeddings
- 1d48ca6 Replace raising error in __eq__ with returning NotImplemented
- 328c69d Added repr ans str for AnnotatedSpectrumSet
- a385313 remove integration test from pyproject.toml
- 41e4c4b Add subset spectra on metadata
- a11bd46 Replace combine embeddings with __add__
- f080a5e Added an embedding setter with checks to AnnotatedSpectrumSet
- 83c6c7d Added save and load models for embeddings
- 73455a1 Added save and load options for AnnotatedSpectrumSet
- 11729e2 ruff
- 7406739 Add main cleaned notebook
- 49cc2ec Add gradient boosting machine notebook
- 604f9de Add entropy run notebook
- f3cda29 Remove old notebook
- 3bec761 Add readme with explanation of notebooks
- bba9b62 Remove old notebook
- 09d545d spelling error fix
- 95bba35 Remove EvaluateAnalogueSearch since now handled in notebooks
- cd310cf Fix some notebook mistakes
- 67320f4 Remove unused predict cosine
- a5cb3fc Remove integrated similarity flow
- 42a18ae Remove unused import
- 5831a62 Move main methods from reference_methods folder
- 6f96d76 Remove predict highest ms2deepscore
- 431e802 Add readme describing reference methods folder
New file (@@ -0,0 +1,169 @@). Judging by the class it defines and the PR's commits, this is likely AnnotatedSpectrumSet.py:
```python
import os
from collections import defaultdict
from typing import Iterable, List, Mapping, Optional, Sequence

from matchms import Spectrum
from matchms.exporting import save_spectra
from matchms.importing import load_spectra
from ms2deepscore.models import SiameseSpectralModel
from tqdm import tqdm

from ms2query.benchmarking.Embeddings import Embeddings


class AnnotatedSpectrumSet:
    """Stores a spectrum dataset making it easy and fast to split on molecules"""

    def __init__(
        self,
        spectra: Sequence[Spectrum],
        spectrum_indices_per_inchikey: Mapping[str, Iterable[int]],
        embeddings: Optional[Embeddings] = None,
    ):
        self._spectra = tuple(spectrum.clone() for spectrum in spectra)
        self.spectrum_indices_per_inchikey: dict[str, tuple[int, ...]] = {
            key: tuple(values) for key, values in spectrum_indices_per_inchikey.items()
        }
        # Goes through the embeddings setter, which validates the spectrum hashes
        self.embeddings = embeddings

    @classmethod
    def create_spectrum_set(cls, spectra: Sequence[Spectrum]) -> "AnnotatedSpectrumSet":
        spectrum_indices_per_inchikey = defaultdict(list)
        for spectrum_index, spectrum in enumerate(tqdm(spectra, desc="Create mapping from inchikey to spectrum")):
            inchikey = spectrum.get("inchikey")
            if inchikey is None:
                raise ValueError("AnnotatedSpectrumSet expects spectra that all have an inchikey")
            spectrum_indices_per_inchikey[inchikey[:14]].append(spectrum_index)
        return cls(spectra, spectrum_indices_per_inchikey)

    def __add__(self, other) -> "AnnotatedSpectrumSet":
        """Adds two spectrum sets together"""
        if not isinstance(other, AnnotatedSpectrumSet):
            return NotImplemented
        spectra = self.spectra + other.spectra
        # Reindex the second set's indices so they point into the combined spectra tuple
        starting_index = len(self.spectra)
        reindexed_indices_per_inchikey = {}
        for inchikey, list_of_spectrum_indices in other.spectrum_indices_per_inchikey.items():
            reindexed_indices_per_inchikey[inchikey] = [v + starting_index for v in list_of_spectrum_indices]
        # Combine indices
        spectrum_indices_per_inchikey = defaultdict(list)
        for indices_per_inchikey in (self.spectrum_indices_per_inchikey, reindexed_indices_per_inchikey):
            for inchikey, indices in indices_per_inchikey.items():
                spectrum_indices_per_inchikey[inchikey].extend(indices)

        # Combine embeddings
        if self.has_embeddings != other.has_embeddings:
            print("Only one of the two sets has embeddings, so embeddings are not added")
        embeddings = None
        if self.has_embeddings and other.has_embeddings:
            embeddings = self.embeddings + other.embeddings
        return AnnotatedSpectrumSet(spectra, spectrum_indices_per_inchikey, embeddings=embeddings)

    def subset_spectra(self, spectrum_indices) -> "AnnotatedSpectrumSet":
        """Returns a new instance with a subset of the spectra"""
        spectra = [self._spectra[index] for index in spectrum_indices]
        new_instance = AnnotatedSpectrumSet.create_spectrum_set(spectra)
        if self.has_embeddings:
            new_instance._embeddings = self.embeddings.subset_embeddings(spectra)
        return new_instance

    def subset_spectra_on_metadata(self, metadata_key: str, values_to_keep: set) -> "AnnotatedSpectrumSet":
        """Creates a subset of the spectra by checking a specific metadata key

        E.g. subset_spectra_on_metadata("ionmode", {"positive"}) will return only the spectra in positive ion mode
        """
        spectrum_indices_to_keep = []
        for spectrum_index, spectrum in enumerate(tqdm(self.spectra, desc="Checking spectra for correct metadata")):
            if spectrum.get(metadata_key) in values_to_keep:
                spectrum_indices_to_keep.append(spectrum_index)
        return self.subset_spectra(spectrum_indices_to_keep)

    def spectra_per_inchikey(self, inchikey) -> List[Spectrum]:
        """Returns all spectra stored for the given (planar) inchikey"""
        matching_spectra = []
        for index in self.spectrum_indices_per_inchikey[inchikey]:
            matching_spectra.append(self._spectra[index])
        return matching_spectra

    def add_embeddings(self, model: SiameseSpectralModel):
        self._embeddings = Embeddings.create_from_spectra(self._spectra, model)

    @property
    def has_embeddings(self) -> bool:
        return self._embeddings is not None

    @property
    def spectra(self):
        return self._spectra

    @property
    def embeddings(self) -> Embeddings:
        if self._embeddings is None:
            raise ValueError("First run the 'add_embeddings' method")
        return self._embeddings

    @embeddings.setter
    def embeddings(self, embeddings: Optional[Embeddings]):
        if embeddings is None:
            self._embeddings = embeddings
            return
        if embeddings.index_to_spectrum_hash != tuple(spectrum.__hash__() for spectrum in self.spectra):
            raise ValueError(
                "The embeddings spectrum hashes don't match the spectrum hashes, make sure you use matching embeddings"
            )
        self._embeddings = embeddings

    @property
    def inchikeys(self):
        return tuple(self.spectrum_indices_per_inchikey.keys())

    def __copy__(self):
        return AnnotatedSpectrumSet(self.spectra, self.spectrum_indices_per_inchikey, self._embeddings)

    def __eq__(self, other: object):
        if not isinstance(other, AnnotatedSpectrumSet):
            return NotImplemented
        if self.spectra != other.spectra:
            return False
        if self.spectrum_indices_per_inchikey != other.spectrum_indices_per_inchikey:
            return False
        if self._embeddings != other._embeddings:
            return False
        return True

    def __len__(self):
        return len(self._spectra)

    def __repr__(self):
        return (
            f"AnnotatedSpectrumSet(nr_of_spectra={len(self)}, "
            f"nr_of_unique_inchikeys={len(self.inchikeys)}, "
            f"has_embeddings={self.has_embeddings})"
        )

    def __str__(self):
        with_embeddings = ""
        if self.has_embeddings:
            with_embeddings = "with embeddings"
        return f"{len(self)} spectra for {len(self.inchikeys)} inchikeys {with_embeddings}"

    def save(self, save_file: str) -> None:
        """Save spectra to the specified path"""
        save_spectra(list(self._spectra), save_file)

        if self._embeddings is not None:
            embedding_save_name = os.path.splitext(save_file)[0] + "_embeddings.npz"
            print(f"Saving embeddings at {embedding_save_name}")
            self._embeddings.save(embedding_save_name)

    @classmethod
    def load(cls, spectrum_file: str) -> "AnnotatedSpectrumSet":
        """Load mass spectra into an AnnotatedSpectrumSet; if embeddings are available they are loaded too"""
        spectra = list(load_spectra(spectrum_file))

        embedding_file_name = os.path.splitext(spectrum_file)[0] + "_embeddings.npz"
        instance = cls.create_spectrum_set(spectra)
        if os.path.exists(embedding_file_name):
            instance.embeddings = Embeddings.load(embedding_file_name)
        return instance
```
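The core bookkeeping in `__add__` above is the index shift: indices of the second set refer to positions in its own spectra tuple, so they must be offset by the length of the first set before the two per-inchikey mappings are merged. A minimal stand-alone sketch of that step (the `combine_indexed_sets` helper and the example inchikey dicts are hypothetical stand-ins, not part of the PR):

```python
from collections import defaultdict


def combine_indexed_sets(left, right, left_size):
    """Merge two inchikey -> spectrum-index mappings.

    Indices in `right` refer to positions in a second spectra tuple, so they
    are shifted by `left_size` (the length of the first tuple) before merging.
    """
    combined = defaultdict(list)
    for inchikey, indices in left.items():
        combined[inchikey].extend(indices)
    for inchikey, indices in right.items():
        combined[inchikey].extend(index + left_size for index in indices)
    return dict(combined)


# Hypothetical example: sets of 3 and 2 spectra sharing one inchikey
left = {"AAAAAAAAAAAAAA": [0, 2], "BBBBBBBBBBBBBB": [1]}
right = {"AAAAAAAAAAAAAA": [0], "CCCCCCCCCCCCCC": [1]}
print(combine_indexed_sets(left, right, left_size=3))
# {'AAAAAAAAAAAAAA': [0, 2, 3], 'BBBBBBBBBBBBBB': [1], 'CCCCCCCCCCCCCC': [4]}
```

Because the shifted indices land after all indices of the first set, spectra remain addressable by their position in the concatenated tuple, which is exactly what `AnnotatedSpectrumSet.__add__` relies on.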
New file (@@ -0,0 +1,140 @@). Judging by the class it defines, this is likely Embeddings.py:
```python
import json
from pathlib import Path
from typing import Sequence

import numpy as np
import pandas as pd
from matchms import Spectrum
from ms2deepscore.models import SiameseSpectralModel, compute_embedding_array
from ms2deepscore.vector_operations import cosine_similarity_matrix
from tqdm import tqdm


class Embeddings:
    """Stores embeddings for a list of mass spectra"""

    def __init__(self, embeddings: np.ndarray, spectrum_hashes: tuple, model_settings: dict):
        if len(spectrum_hashes) != embeddings.shape[0]:
            raise ValueError("Number of spectrum hashes does not match number of embeddings")
        self.index_to_spectrum_hash = spectrum_hashes
        self._spectrum_hash_to_index = {
            spectrum_hash: index for index, spectrum_hash in enumerate(self.index_to_spectrum_hash)
        }
        self.model_settings = model_settings
        self._embeddings = embeddings

    @classmethod
    def create_from_spectra(cls, spectra: Sequence[Spectrum], model: SiameseSpectralModel) -> "Embeddings":
        index_to_spectrum_hash = tuple(spectrum.__hash__() for spectrum in tqdm(spectra, desc="Hashing spectra"))
        if len(set(index_to_spectrum_hash)) != len(spectra):
            raise ValueError("There are duplicated spectra in the spectrum list")

        model_settings = model.model_settings.get_dict()
        embeddings: np.ndarray = compute_embedding_array(model, spectra)  # type: ignore
        return cls(embeddings, index_to_spectrum_hash, model_settings)

    def __add__(self, other: "Embeddings") -> "Embeddings":
        if not isinstance(other, Embeddings):
            return NotImplemented
        if self.model_settings != other.model_settings:
            raise ValueError("Model settings of merged embeddings do not match")
        if not set(self.index_to_spectrum_hash).isdisjoint(other.index_to_spectrum_hash):
            raise ValueError("There are repeated spectra in the embeddings that are added together")
        combined_embeddings = np.vstack([self._embeddings, other._embeddings])
        index_to_spectrum_hash = self.index_to_spectrum_hash + other.index_to_spectrum_hash
        return Embeddings(combined_embeddings, index_to_spectrum_hash, self.model_settings)

    def subset_embeddings(self, spectra):
        spectrum_hashes = tuple(spectrum.__hash__() for spectrum in spectra)
        try:
            embedding_indices = [self._spectrum_hash_to_index[spectrum_hash] for spectrum_hash in spectrum_hashes]
        except KeyError as error:
            raise ValueError("The given spectra are not stored in Embeddings") from error
        embeddings = self._embeddings[embedding_indices].copy()
        return Embeddings(embeddings, spectrum_hashes, self.model_settings)

    @property
    def embeddings(self):
        return self._embeddings.view()

    @property
    def model_settings(self):
        return self._model_settings.copy()

    @model_settings.setter
    def model_settings(self, model_settings):
        self._model_settings: dict = _to_json_serializable(model_settings)

    def copy(self) -> "Embeddings":
        return Embeddings(
            embeddings=self._embeddings.copy(),
            spectrum_hashes=tuple(self.index_to_spectrum_hash),
            model_settings=dict(self.model_settings),
        )

    def __eq__(self, other) -> bool:
        if not isinstance(other, Embeddings):
            return NotImplemented
        if self.model_settings != other.model_settings:
            print("Model settings not equal")
            return False
        if self.index_to_spectrum_hash != other.index_to_spectrum_hash:
            print("Index to spectrum hash not equal")
            return False
        return np.array_equal(self.embeddings, other.embeddings)

    def save(self, path: str | Path) -> None:
        """Save embeddings to a .npz file with metadata stored alongside.

        Args:
            path: File path. A '.npz' extension will be added if not present.
        """
        path = Path(path).with_suffix(".npz")
        metadata = {
            "model_settings": self.model_settings,
            "index_to_spectrum_hash": list(self.index_to_spectrum_hash),
        }
        np.savez_compressed(
            path,
            embeddings=self._embeddings,
            metadata=np.array(json.dumps(metadata)),
        )

    @classmethod
    def load(cls, path: str | Path) -> "Embeddings":
        """Load embeddings from a saved .npz file.

        Args:
            path: Path to the saved .npz file.
        """
        path = Path(path).with_suffix(".npz")
        with np.load(path, allow_pickle=False) as data:
            embeddings = data["embeddings"]
            metadata = json.loads(data["metadata"].item())
        return cls(
            embeddings=embeddings,
            spectrum_hashes=tuple(metadata["index_to_spectrum_hash"]),
            model_settings=metadata["model_settings"],
        )


def calculate_ms2deepscore_df(query_embeddings: Embeddings, library_embeddings: Embeddings):
    """Returns a DataFrame where the index and column labels are the spectrum hashes"""
    ms2deepscores = cosine_similarity_matrix(query_embeddings.embeddings, library_embeddings.embeddings)
    return pd.DataFrame(
        ms2deepscores, index=query_embeddings.index_to_spectrum_hash, columns=library_embeddings.index_to_spectrum_hash
    )


def _to_json_serializable(obj):
    """Converts a dict to be JSON serializable, so it is the same when loaded"""
    if isinstance(obj, dict):
        return {key: _to_json_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_json_serializable(item) for item in obj]
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj
```
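The point of normalizing `model_settings` through `_to_json_serializable` in the setter is that the settings stored in memory must compare equal to the settings after a `json.dumps`/`json.loads` roundtrip through the `.npz` metadata: JSON has no tuples or numpy scalars, so both must be converted up front. A simplified stdlib-only sketch of that roundtrip property (numpy branches dropped; `to_json_serializable` and the `"layers"`/`"dropout"` keys are hypothetical stand-ins):

```python
import json


def to_json_serializable(obj):
    """Simplified stand-in: recursively turn tuples into lists so that
    json.loads(json.dumps(x)) == x holds for the normalized value.
    (The real helper also converts numpy scalars and arrays.)"""
    if isinstance(obj, dict):
        return {key: to_json_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_json_serializable(item) for item in obj]
    return obj


settings = {"layers": (500, 500), "dropout": 0.2}
normalized = to_json_serializable(settings)
roundtripped = json.loads(json.dumps(normalized))
print(roundtripped == normalized)  # True: the tuple was normalized to a list up front
print(roundtripped == settings)    # False: the raw dict still contains a tuple
```

Without this normalization, `Embeddings.__eq__` would report loaded embeddings as unequal to the originals purely because tuples came back as lists.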