Skip to content

Commit cc8a882

Browse files
committed
add heavyball
1 parent 489742c commit cc8a882

3 files changed

Lines changed: 30 additions & 2 deletions

File tree

model2vec/train/classifier.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
import lightning as pl
1010
import numpy as np
1111
import torch
12+
from heavyball import AdamW, Muon
1213
from lightning.pytorch.callbacks import Callback, EarlyStopping
1314
from lightning.pytorch.utilities.types import OptimizerLRScheduler
1415
from sklearn.metrics import jaccard_score
@@ -429,7 +430,7 @@ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: i
429430

430431
def configure_optimizers(self) -> OptimizerLRScheduler:
431432
"""Configure optimizer and learning rate scheduler."""
432-
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
433+
optimizer = Muon(self.model.parameters(), lr=self.learning_rate)
433434
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
434435
optimizer,
435436
mode="min",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ dev = [
6363
distill = ["torch", "transformers", "scikit-learn"]
6464
onnx = ["onnx", "torch"]
6565
# train also installs inference
66-
train = ["torch", "lightning", "scikit-learn", "skops"]
66+
train = ["torch", "lightning", "scikit-learn", "skops", "heavyball"]
6767
inference = ["scikit-learn", "skops"]
6868
tokenizer = ["transformers"]
6969

uv.lock

Lines changed: 27 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)