# feat: Use HF datasets for data logic #33
base: `main`

Changes from all commits: `1238bcb`, `7108909`, `0c459ba`, `55ab807`, `4444f02`, `0dc8d7a`, `0e09571`, `73590a6`, `3797f60`, `d474c59`
**New file: `tokenlearn/datacards/dataset_card_template.md`** (`@@ -0,0 +1,71 @@`)

````markdown
---
{{ card_data }}
---

# {{ repo_id or dataset_name }} Dataset Card

This dataset was created with [Tokenlearn](https://github.com/MinishLab/tokenlearn) for training [Model2Vec](https://github.com/MinishLab/model2vec) models. It contains mean token embeddings produced by a sentence transformer, used as training targets for static embedding distillation.

## Dataset Details

| Field | Value |
|---|---|
| **Source dataset** | [{{ source_dataset }}](https://huggingface.co/datasets/{{ source_dataset }}) |
| **Source split** | `{{ source_split }}` |
| **Embedding model** | [{{ model_name }}](https://huggingface.co/{{ model_name }}) |
| **Embedding dimension** | {{ embedding_dim }} |
| **Rows** | {{ num_rows }} |

## Dataset Structure

| Column | Type | Description |
|---|---|---|
| `text` | `string` | Truncated input text |
| `embedding` | `list[float32]` | Mean token embedding from `{{ model_name }}`, excluding BOS/EOS tokens |

## Usage

Load with the `datasets` library:

```python
from datasets import load_dataset

dataset = load_dataset("{{ repo_id or dataset_name }}")
```

Train a Model2Vec model on this dataset using Tokenlearn:

```bash
python -m tokenlearn.train \
    --model-name "{{ model_name }}" \
    --data-path "{{ repo_id or dataset_name }}" \
    --save-path "<path-to-save-model>"
```

## Creation

This dataset was created using the `tokenlearn-featurize` CLI:

```bash
python -m tokenlearn.featurize \
    --model-name "{{ model_name }}" \
    --dataset-path "{{ source_dataset }}" \
    --dataset-name "{{ source_name }}" \
    --dataset-split "{{ source_split }}" \
    --output-dir "<output-dir>"
```

## Library Authors

Tokenlearn was developed by the [Minish](https://github.com/MinishLab) team, consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled).

## Citation

```
@article{minishlab2024model2vec,
  author = {Tulkens, Stephan and {van Dongen}, Thomas},
  title  = {Model2Vec: Fast State-of-the-Art Static Embeddings},
  year   = {2024},
  url    = {https://github.com/MinishLab/model2vec}
}
```
````
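The `embedding` column described in the card is a mean over per-token vectors with the first (BOS) and last (EOS) positions dropped, i.e. `embedding[1:-1].mean(axis=0)` in the PR's code. A stdlib-only toy of that pooling, with made-up 2-dimensional vectors:

```python
# Toy mean pooling over per-token vectors, dropping the first (BOS) and
# last (EOS) rows, mirroring embedding[1:-1].mean(axis=0) without numpy.
tokens = [
    [9.0, 9.0],  # BOS position (excluded)
    [3.0, 4.0],
    [5.0, 6.0],
    [9.0, 9.0],  # EOS position (excluded)
]
inner = tokens[1:-1]
mean_embedding = [sum(col) / len(inner) for col in zip(*inner)]
print(mean_embedding)  # [4.0, 5.0]
```

All values here are invented for illustration; the real pipeline pools the float token embeddings returned by the sentence transformer.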
**`tokenlearn/featurize.py`**
```diff
@@ -1,22 +1,89 @@
 import argparse
 import json
 import logging
+import shutil
 from pathlib import Path
 from typing import Iterator

 import numpy as np
-from datasets import load_dataset
+from datasets import Dataset, Features, Sequence, Value, concatenate_datasets, load_dataset, load_from_disk
+from huggingface_hub import DatasetCard, DatasetCardData
 from more_itertools import batched
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
 from transformers.tokenization_utils import PreTrainedTokenizer

+_DATASET_CARD_TEMPLATE = Path(__file__).parent / "datacards" / "dataset_card_template.md"
+
 _SAVE_EVERY = 32

+_FEATURES = Features({"text": Value("string"), "embedding": Sequence(Value("float32"))})
+
 logger = logging.getLogger(__name__)


+def _save_checkpoint(checkpoints_dir: Path, texts: list[str], embeddings: list[np.ndarray], part_idx: int) -> None:
+    """Save a checkpoint part as a HuggingFace dataset."""
+    part = Dataset.from_dict(
+        {"text": texts, "embedding": [e.tolist() for e in embeddings]},
+        features=_FEATURES,
+    )
+    part.save_to_disk(str(checkpoints_dir / f"part_{part_idx:08d}"))
```
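One detail worth noting in `part_{part_idx:08d}`: the zero-padding makes plain lexicographic sorting match numeric order, which is what the later `sorted(checkpoints_dir.glob("part_*/"))` relies on. A quick stdlib check (indices chosen arbitrarily):

```python
# Zero-padded part names sort lexicographically in the same order as
# their numeric indices, so sorted() recovers creation order.
names = [f"part_{i:08d}" for i in (10, 2, 0, 1)]
print(sorted(names))
```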
> **Contributor:** Regarding saving: you can directly save this as a parquet file. To do this, you can apparently first force it into a single shard, and then save it to parquet:
>
> ```python
> single_shard = ds.shard(num_shards=1, index=0)
> single_shard.to_parquet(f"shard_{part_idx:08d}.parquet")
> ```
>
> Note that this only works for datasets, not DatasetDicts.
```diff
+
+
+def _compact_checkpoints(checkpoints_dir: Path, output_dir: Path, keep_checkpoints: bool) -> None:
```
> **Contributor:** If you do the above, you'd just need to write the metadata. But I think you can get away with not writing any metadata, tbh.
| """Compact checkpoint parts into a single standard HuggingFace dataset.""" | ||
| part_dirs = sorted(checkpoints_dir.glob("part_*/")) | ||
| if not part_dirs: | ||
| return | ||
|
|
||
| logger.info("Compacting checkpoints into final dataset...") | ||
| # Build the compacted dataset in a sibling temp dir, then replace output_dir. | ||
| tmp_dir = output_dir.parent / f"{output_dir.name}.tmp" | ||
| if tmp_dir.exists(): | ||
| shutil.rmtree(tmp_dir) | ||
| # Load all parts and concatenate them into a single dataset, then save to the temp dir. | ||
| dataset = concatenate_datasets([load_from_disk(str(d)) for d in part_dirs]) | ||
| dataset.save_to_disk(str(tmp_dir)) | ||
| if output_dir.exists(): | ||
| # Remove the old output dir before renaming the temp dir to avoid leaving stale Arrow files from previous runs. | ||
| shutil.rmtree(output_dir) | ||
| tmp_dir.rename(output_dir) | ||
| if not keep_checkpoints: | ||
| shutil.rmtree(checkpoints_dir) | ||
| logger.info(f"Dataset saved to {output_dir}") | ||
|
|
||
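`_compact_checkpoints` swaps a freshly built directory into place via a sibling `.tmp` dir, so a crash mid-compaction never leaves a half-written `output_dir`. A stdlib-only sketch of that build-then-rename pattern, with plain text files standing in for Arrow shards (all paths and contents are made up):

```python
import shutil
import tempfile
from pathlib import Path

base = Path(tempfile.mkdtemp())
output_dir = base / "out"
output_dir.mkdir()
(output_dir / "stale.txt").write_text("old")  # leftover from a previous run

# Build the replacement in a sibling temp dir first.
tmp_dir = output_dir.parent / f"{output_dir.name}.tmp"
tmp_dir.mkdir()
(tmp_dir / "data.txt").write_text("new")

# Drop the stale dir, then swap the fresh build into place.
shutil.rmtree(output_dir)
tmp_dir.rename(output_dir)

print((output_dir / "data.txt").read_text())   # new
print((output_dir / "stale.txt").exists())     # False
```

The `rmtree` followed by `rename` is not fully atomic, but it guarantees the output is either the old complete state or the new complete state, never a mix.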
```diff
+
+
+def _create_dataset_card(
+    output_dir: Path,
+    model_name: str,
+    source_dataset: str,
+    source_name: str,
+    source_split: str,
+    num_rows: int,
+    embedding_dim: int,
+    repo_id: str | None = None,
+) -> DatasetCard:
+    """Create a dataset card, save it to the output directory, and return it."""
+    card_data = DatasetCardData(
+        language="en",
+        tags=["tokenlearn", "embeddings", "model2vec"],
+    )
+    card = DatasetCard.from_template(
+        card_data,
+        template_path=str(_DATASET_CARD_TEMPLATE),
+        repo_id=repo_id,
+        dataset_name=output_dir.name,
+        model_name=model_name,
+        source_dataset=source_dataset,
+        source_name=source_name,
+        source_split=source_split,
+        num_rows=num_rows,
+        embedding_dim=embedding_dim,
+    )
+    card.save(output_dir / "README.md")
+    return card
+
+
 def featurize(  # noqa C901
     dataset: Iterator[dict[str, str]],
     model: SentenceTransformer,
```
@@ -25,15 +92,19 @@ def featurize( # noqa C901 | |
| batch_size: int, | ||
| text_key: str, | ||
| max_length: int | None = None, | ||
| keep_checkpoints: bool = False, | ||
| ) -> None: | ||
| """Make a directory and dump all kinds of data in it.""" | ||
| output_dir_path = Path(output_dir) | ||
| output_dir_path.mkdir(parents=True, exist_ok=True) | ||
| checkpoints_dir = Path(str(output_dir_path) + ".checkpoints") | ||
| checkpoints_dir.mkdir(exist_ok=True) | ||
|
|
||
| # Ugly hack | ||
| largest_batch = max([int(x.stem.split("_")[1]) for x in list(output_dir_path.glob("*.json"))], default=0) | ||
| if largest_batch: | ||
| logger.info(f"Resuming from batch {largest_batch}, skipping previous batches.") | ||
| part_dirs = sorted(checkpoints_dir.glob("part_*/")) | ||
| part_idx = len(part_dirs) | ||
| rows_done = sum(len(load_from_disk(str(d))) for d in part_dirs) | ||
| if rows_done: | ||
| logger.info(f"Resuming from {rows_done} previously written rows ({part_idx} checkpoint parts).") | ||
|
|
||
| texts = [] | ||
| embeddings = [] | ||
|
|
@@ -53,7 +124,7 @@ def featurize( # noqa C901 | |
| if i * batch_size >= max_means: | ||
|
> **Contributor:** Maybe rewrite to `max_rows`, or call the other variable
```diff
             logger.info(f"Reached maximum number of means: {max_means}")
             break
-        if largest_batch and i <= largest_batch:
+        if i * batch_size < rows_done:
```
> **Contributor:** You compute `i * batch_size` twice.
```diff
             continue
         batch = [x[text_key] for x in batch]

@@ -65,13 +136,14 @@ def featurize(  # noqa C901
             texts.append(_truncate_text(tokenizer, text))
             embeddings.append(embedding[1:-1].float().mean(axis=0).cpu().numpy())
         if i and i % _SAVE_EVERY == 0:
-            json.dump(texts, open(output_dir_path / f"feature_{i}.json", "w"), indent=4)
-            np.save(output_dir_path / f"feature_{i}.npy", embeddings)
+            _save_checkpoint(checkpoints_dir, texts, embeddings, part_idx)
+            part_idx += 1
```
> **Contributor:** This is kind of random.

> **Contributor:** So if someone were to switch batch size after resuming, the resume logic would still work, I guess. But you wouldn't be able to guess … So relying on … Reinterpret …
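To make the reviewers' concern concrete, here is a stdlib-only toy (all values made up) of the `i * batch_size < rows_done` skip condition. When `rows_done` is not a multiple of the current batch size, as could happen after changing `--batch-size` between runs, a batch that was only partially persisted gets skipped in full:

```python
# Toy resume: skip any batch i with i * batch_size < rows_done,
# keep the rest. rows_done = 3 persisted rows, batch_size = 2.
batch_size = 2
rows_done = 3
batches = [["r0", "r1"], ["r2", "r3"], ["r4"]]
kept = [b for i, b in enumerate(batches) if i * batch_size >= rows_done]
print(kept)  # [['r4']]
```

Here batch 1 (`r2`, `r3`) is dropped even though only `r0`..`r2` were ever written, so `r3` is silently lost; with a consistent batch size, `rows_done` stays batch-aligned and the skip is exact.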
```diff
             texts = []
             embeddings = []
     if texts:
-        json.dump(texts, open(output_dir_path / f"feature_{i}.json", "w"), indent=4)
-        np.save(output_dir_path / f"feature_{i}.npy", embeddings)
+        _save_checkpoint(checkpoints_dir, texts, embeddings, part_idx)

+    _compact_checkpoints(checkpoints_dir, output_dir_path, keep_checkpoints)


 def _truncate_text(tokenizer: PreTrainedTokenizer, text: str) -> str:
```
@@ -141,6 +213,17 @@ def main() -> None: | |
| help="Batch size to use for encoding the texts.", | ||
| ) | ||
| parser.add_argument("--max-length", type=int, default=None, help="Maximum token length for the tokenizer.") | ||
| parser.add_argument( | ||
| "--keep-checkpoints", | ||
| action="store_true", | ||
| help="Keep checkpoint parts after compaction (default: delete them).", | ||
| ) | ||
| parser.add_argument( | ||
| "--push-to-hub", | ||
| type=str, | ||
| default=None, | ||
| help="HuggingFace Hub repo ID to push the dataset to after featurizing (e.g., 'username/my-dataset').", | ||
| ) | ||
|
|
||
| args = parser.parse_args() | ||
|
|
||
|
|
@@ -159,7 +242,36 @@ def main() -> None: | |
| streaming=args.no_streaming, | ||
| ) | ||
|
|
||
| featurize(iter(dataset), model, output_dir, args.max_means, args.batch_size, args.key, max_length=args.max_length) | ||
| featurize( | ||
| iter(dataset), | ||
| model, | ||
| output_dir, | ||
| args.max_means, | ||
| args.batch_size, | ||
| args.key, | ||
| max_length=args.max_length, | ||
| keep_checkpoints=args.keep_checkpoints, | ||
| ) | ||
|
|
||
| output_dir_path = Path(output_dir) | ||
| if (output_dir_path / "dataset_info.json").exists(): | ||
| ds = load_from_disk(output_dir) | ||
| card = _create_dataset_card( | ||
| output_dir=output_dir_path, | ||
| model_name=args.model_name, | ||
| source_dataset=args.dataset_path, | ||
| source_name=args.dataset_name, | ||
| source_split=args.dataset_split, | ||
| num_rows=len(ds), | ||
| embedding_dim=len(ds[0]["embedding"]), | ||
| repo_id=args.push_to_hub, | ||
| ) | ||
| if args.push_to_hub: | ||
| logger.info(f"Pushing dataset to Hub: {args.push_to_hub}") | ||
| ds.push_to_hub(args.push_to_hub) | ||
| card.push_to_hub(args.push_to_hub) | ||
| else: | ||
| logger.warning("No data was written — skipping dataset card and Hub push.") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
> **Contributor:** Maybe we can add a Zenodo citation for the software as well? If that's possible. (Not this PR, just future.)