Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ from chemap import compute_fingerprints, DatasetLoader, FingerprintConfig


ds_loader = DatasetLoader()
# Load a single dataset from a local file
smiles = ds_loader.load("tests/data/smiles.csv")
# or load a dataset collection from a DOI based registry (e.g., Zenodo)
files = ds_loader.load_collection("10.5281/zenodo.18682050")
# pass one of the absolute file paths from files
smiles = ds_loader.load(files[0])

# ----------------------------
# RDKit: Morgan (folded, dense)
Expand Down
62 changes: 60 additions & 2 deletions chemap/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pathlib
import re
import pandas as pd
import pooch

Expand Down Expand Up @@ -33,6 +34,30 @@ def load(self, source: str, **kwargs) -> list:
else:
raise ValueError(f"Source {source} unknown.")

def load_collection(self, source: str, **kwargs) -> list:
"""
Loads a dataset collection from a DOI-based registry (e.g. Zenodo).

Parameters
-------------
source:
A DOI.

Returns
-------------
list of downloaded filenames from the registry.

Raises
-------------
ValueError if DOI not present.
"""
doi_pattern = r'(10\.\d{4,9}/[-._;()/:a-zA-Z0-9]+)'

if not source.startswith("doi") or not bool(re.search(doi_pattern, source)):
ValueError(f"Could not detect DOI in source {source}.")

return self._from_registry(source, **kwargs)

def _from_local_file(self, path, smiles_column: str = "smiles") -> list:
"""
Loads a dataset from local file.
Expand Down Expand Up @@ -67,10 +92,13 @@ def _from_local_file(self, path, smiles_column: str = "smiles") -> list:
else:
raise ValueError(f"Fileformat {suffix} not supported.")

if smiles_column not in df.columns:
column_map = {col.lower(): col for col in df.columns}
target_col = column_map.get(smiles_column.lower())

if not target_col:
raise ValueError(f"Smiles column {smiles_column} not in dataframe.")

return df[smiles_column].tolist()
return df[target_col].tolist()

def _from_web(self, url: str, **kwargs) -> list:
"""
Expand All @@ -93,3 +121,33 @@ def _from_web(self, url: str, **kwargs) -> list:
)

return self._from_local_file(file_path, **kwargs)

def _from_registry(self, doi: str, **kwargs) -> list:
"""
Loads a dataset collection from DOI-based registry (e.g., Zenodo).

Parameters
-------------
doi:
A valid DOI string.

Returns
-------------
list of strings with absolute path for all downloaded files.

Raises
-------------
ValueError if file type unsupported.
ValueError if smiles column not present.
"""
if not doi.startswith("doi"):
doi = f"doi:{doi}"

client = pooch.create(
path=self.cache_dir,
base_url=f"{doi}/",
registry=None,
)
client.load_registry_from_doi()

return [client.fetch(f, progressbar=True) for f in client.registry]