
lvae 2d vanilla implementation#57

Open
mkeeler43 wants to merge 4 commits into main from lvae_2d

Conversation

@mkeeler43
Contributor

Basic lvae_2d and performance MLP plvae_2d implementations


@SoheylM SoheylM left a comment


Here are the points I identified:

  1. aes.py: log(s + eta) can produce NaN losses
  2. aes.py: the .to() override
  3. lvae_2d.py / plvae_2d.py: hard-coded [:25] that can fail on small datasets during visualization
  4. lvae_2d.py / plvae_2d.py: possible risk of OOM when encoding the full training set at once
  5. lvae_2d.py / plvae_2d.py: the Encoder and Decoder classes are duplicated. I understand this is done because of the CleanRL mindset, so to be decided whether we consider lvae and plvae close enough to share them.
  6. lvae_2d.py / plvae_2d.py: th.cuda.empty_cache() is called even when running on CPU/MPS
  7. lvae_2d.py: small typo (double period) in the URL docstring
  8. plvae_2d.py: a comment says the tensors won't be used in the predictor, but they are actually passed through; the hack works because a zero-width concat is a no-op
  9. lvae_2d.py / plvae_2d.py: epoch_report is called with batch=None and pbar=None, but I think this is wanted behavior
  10. Future work: add evaluate scripts for lvae_2d.py and plvae_2d.py

Scalar volume loss.
"""
s = z.std(0)
return torch.exp(torch.log(s + self.eta).mean())

Can eta be zero? In that case any latent dimension with zero standard deviation (if possible) leads to log(0) -> -inf, causing NaN losses. I suggest safeguarding against this by either 1) catching this case and returning whatever value is appropriate or 2) giving eta a non-zero default like 1e-8.
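The reviewer's second option could look like the following minimal sketch. The free function `volume_loss` and its `eta` default are illustrative stand-ins, not the repo's actual API; the point is only that a strictly positive `eta` keeps the log finite even when a latent dimension collapses to a constant.

```python
import torch

def volume_loss(z: torch.Tensor, eta: float = 1e-8) -> torch.Tensor:
    """Geometric mean of per-dimension stds, guarded against log(0).

    A collapsed dimension (std == 0) yields log(0 + eta) = log(eta),
    which is finite and differentiable as long as eta > 0.
    """
    s = z.std(0)
    return torch.exp(torch.log(s + eta).mean())

# A latent batch where the last dimension has collapsed to a constant:
z = torch.zeros(16, 4)
z[:, :3] = torch.randn(16, 3)

loss = volume_loss(z)
assert torch.isfinite(loss)  # finite thanks to the non-zero eta
```

With `eta = 0` the forward value may still come out as exp(-inf) = 0, but the gradient of log at 0 is infinite, so backward produces NaNs; the epsilon guards both cases.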

self._zstd: torch.Tensor | None = None
self._zmean: torch.Tensor | None = None

def to(self, device: torch.device | str) -> LeastVolumeAE_DynamicPruning:

If I am not mistaken (and can trust my powerful coding LLM), register_buffer calls already ensure these tensors move with the model. This override reassigns them as plain tensors (not buffers), breaking state-dict saving/loading.
Unless I missed why this needs to be defined explicitly, I would suggest removing this method, since registered buffers auto-move with nn.Module.to().
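To illustrate the point: a toy module (hypothetical, not the repo's class) shows that a registered buffer both appears in the state dict and follows the module's device through plain `nn.Module.to()`, with no custom override needed.

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffers are included in state_dict and are moved
        # by nn.Module.to() automatically.
        self.register_buffer("zstd", torch.ones(3))

m = Toy()
assert "zstd" in m.state_dict()   # saved/loaded with the model

m = m.to("cpu")                   # buffers follow the module's device
assert m.zstd.device.type == "cpu"
```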

# Generate interpolated designs
x_ints = []
for alpha in [0, 0.25, 0.5, 0.75, 1]:
z_ = (1 - alpha) * z[:25] + alpha * th.roll(z, -1, 0)[:25]

Hard-coded sample count may fail (and this is not the only place where it occurs in the code).

If the training set has fewer than 25 samples, this silently produces fewer visualizations, which one may be fine with, if only for visualization. But th.roll(z, -1, 0)[:25] will wrap around incorrectly for small datasets.

I would suggest using min(25, len(z)) or parametrizing the sample count.
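A sketch of the suggested fix, with the sample count as a parameter (the helper name `interpolation_pairs` is hypothetical; the slicing mirrors the snippet under review):

```python
import torch as th

def interpolation_pairs(z: th.Tensor, n_samples: int = 25):
    """Pair each of the first k latents with its successor, where
    k = min(n_samples, len(z)) so small datasets are handled explicitly."""
    k = min(n_samples, len(z))
    z_a = z[:k]
    z_b = th.roll(z, -1, 0)[:k]
    return z_a, z_b

z = th.randn(10, 4)                  # fewer than 25 samples
z_a, z_b = interpolation_pairs(z)
assert len(z_a) == len(z_b) == 10

# Interpolated designs, as in the snippet under review:
x_ints = [(1 - a) * z_a + a * z_b for a in [0, 0.25, 0.5, 0.75, 1]]
```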

# Generate interpolated designs
x_ints = []
for alpha in [0, 0.25, 0.5, 0.75, 1]:
z_ = (1 - alpha) * z[:25] + alpha * th.roll(z, -1, 0)[:25]

See the comment in lvae_2d.py about this.

with th.no_grad():
# Encode training designs
xs = x_train.to(device)
z = lvae.encode(xs)

Full dataset encoding may lead to OOM

xs = x_train.to(device)
z = lvae.encode(xs)

I think there is a risk of OOM when encoding the entire dataset at once during visualization.

Would it be worth encoding in batches to be on the safe side?
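Batched encoding could look like this sketch. The helper `encode_in_batches` is illustrative (a real call would pass `lvae.encode`); chunking bounds peak device memory and gathers results on CPU.

```python
import torch as th

@th.no_grad()
def encode_in_batches(encode, x: th.Tensor, device, batch_size: int = 256) -> th.Tensor:
    """Encode a large tensor chunk by chunk to bound peak GPU memory;
    results are moved back to CPU before concatenation."""
    zs = []
    for i in range(0, len(x), batch_size):
        zs.append(encode(x[i:i + batch_size].to(device)).cpu())
    return th.cat(zs, dim=0)

# Usage with a stand-in encoder (a real run would use lvae.encode and
# the actual device):
z = encode_in_batches(lambda t: t * 2.0, th.randn(1000, 8), "cpu", batch_size=128)
assert z.shape == (1000, 8)
```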

}
wandb.log(val_log_dict, commit=True)

th.cuda.empty_cache()

This is called every epoch regardless of whether CUDA is being used (as opposed to CPU/MPS). Harmless but wasteful.

What about doing this instead:
if th.cuda.is_available(): th.cuda.empty_cache()

}
wandb.log(val_log_dict, commit=True)

th.cuda.empty_cache()

This is called every epoch regardless of whether CUDA is being used (as opposed to CPU/MPS). Harmless but wasteful.

What about doing this instead:
if th.cuda.is_available(): th.cuda.empty_cache()

@@ -0,0 +1,455 @@
"""LVAE for 2D designs with plummet-based dynamic pruning. Adapted from https://github.com/IDEALLab/Least_Volume_ICLR2024..

Typo: double period at the end of the URL.

val_vol /= n

# Trigger pruning check at end of epoch
lvae.epoch_report(epoch=epoch, callbacks=[], batch=None, loss=losses, pbar=None)

This works because the callbacks list is empty (otherwise pbar and batch would have to be non-None values). Looking at the rest of the code, I understand this is done on purpose.

c_train_scaled = th.from_numpy(c_scaler.fit_transform(c_train.numpy())).to(c_train.dtype)
c_val_scaled = th.from_numpy(c_scaler.transform(c_val.numpy())).to(c_val.dtype)
else:
# Dummy tensors when not using conditions (won't be used in predictor)

I would clarify the comment along these lines:
zero-width tensors: concatenating them with pz is a no-op, so the predictor sees only the latent code.
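The no-op can be shown directly (shapes here are arbitrary examples, not the repo's dimensions): concatenating a tensor with zero columns adds nothing, so the result is exactly the latent code.

```python
import torch as th

z = th.randn(4, 8)          # latent codes
c = th.zeros(4, 0)          # zero-width dummy conditions

# Concatenating a zero-width tensor along dim=1 adds no columns,
# so the predictor input equals the latent code:
pz = th.cat([z, c], dim=1)
assert pz.shape == (4, 8)
assert th.equal(pz, z)
```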

@SoheylM
Contributor

SoheylM commented Feb 27, 2026

Hey Matt, just checking, how is the progress with this PR? SOH-13
