Conversation
SoheylM
left a comment
Here are the points I identified:
- aes.py: log(s + eta) that can lead to NaN
- aes.py: the .to() override
- lvae_2d.py & plvae_2d.py: hard-coded [:25] that can fail on small datasets during visualization
- lvae_2d.py & plvae_2d.py: possible risk of OOM when encoding the full training set at once
- lvae_2d.py & plvae_2d.py: Encoder and Decoder classes are duplicated. I understand this is done because of the CleanRL mindset, so to be decided whether we consider lvae and plvae close enough to share these.
- lvae_2d.py & plvae_2d.py: th.cuda.empty_cache() is called even when running on CPU/MPS
- lvae_2d.py: small typo (double period) in the URL docstring
- plvae_2d.py: a comment says the tensors won't be used in the predictor, but they are actually passed through; the hack works because a zero-width concat is a no-op
- lvae_2d.py & plvae_2d.py: epoch_report is called with batch=None and pbar=None, but I think this is intended behavior
- Future work: add evaluate scripts for lvae_2d.py and plvae_2d.py
    Scalar volume loss.
    """
    s = z.std(0)
    return torch.exp(torch.log(s + self.eta).mean())
Can eta be zero? In that case any latent dimension with zero standard deviation (if possible) leads to log(0) -> -inf, causing NaN losses. I suggest safeguarding against this by either 1) catching this case and returning whatever value is appropriate or 2) giving eta a non-zero default like 1e-8.
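A minimal sketch of the second suggestion (the function name and the 1e-8 default are illustrative, not from the PR):

```python
import torch

def volume_loss(z: torch.Tensor, eta: float = 1e-8) -> torch.Tensor:
    """Geometric mean of per-dimension std; eta > 0 keeps log finite
    even when a latent dimension has collapsed (std == 0)."""
    s = z.std(0)
    return torch.exp(torch.log(s + eta).mean())

# Latent batch whose second dimension is constant, i.e. std == 0
z = torch.tensor([[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]])
loss = volume_loss(z)  # finite thanks to the eta floor
```

With eta = 0 the same input would produce log(0) = -inf inside the mean, which is exactly the failure mode described above.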
    self._zstd: torch.Tensor | None = None
    self._zmean: torch.Tensor | None = None

    def to(self, device: torch.device | str) -> LeastVolumeAE_DynamicPruning:
If I am not mistaken (and can trust my powerful coding LLM), the register_buffer calls already ensure these tensors move with the model, and this override reassigns them as plain tensors (not buffers), breaking state-dict saving/loading.
Unless I missed why this needs to be specifically defined, I would suggest removing this method, since registered buffers move automatically with nn.Module.to().
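A minimal sketch of the behavior being relied on (the class here is a hypothetical stand-in, only the buffer names mirror the PR); a dtype cast stands in for a device move since a GPU may not be available:

```python
import torch
import torch.nn as nn

class LatentStats(nn.Module):
    """Hypothetical module: buffers registered via register_buffer
    are tracked by state_dict and follow nn.Module.to() automatically,
    so no custom to() override should be needed."""
    def __init__(self):
        super().__init__()
        self.register_buffer("_zstd", torch.ones(4))
        self.register_buffer("_zmean", torch.zeros(4))

m = LatentStats()
# Module.to() moves/casts buffers along with parameters
m.to(torch.float64)
```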
    # Generate interpolated designs
    x_ints = []
    for alpha in [0, 0.25, 0.5, 0.75, 1]:
        z_ = (1 - alpha) * z[:25] + alpha * th.roll(z, -1, 0)[:25]
Hard-coded sample count may fail (and this is not the only place where it occurs in the code).
If the training set has fewer than 25 samples, this silently produces fewer visualizations, which may be fine for visualization alone, but th.roll(z, -1, 0)[:25] will wrap around incorrectly for small datasets.
I would suggest using min(25, len(z)) or parametrizing the sample count.
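A sketch of the suggested fix, with the sample count clamped to the batch size and the roll applied within the selected samples only (the function name is illustrative, not from the PR):

```python
import torch as th

def interpolation_pairs(z: th.Tensor, n_samples: int = 25) -> list[th.Tensor]:
    """Interpolate each selected latent toward its neighbor,
    clamping the sample count to the batch size."""
    n = min(n_samples, len(z))       # suggested guard against small datasets
    z_a = z[:n]
    z_b = th.roll(z[:n], -1, 0)      # roll within the selection, so pairs stay consistent
    return [(1 - alpha) * z_a + alpha * z_b for alpha in [0, 0.25, 0.5, 0.75, 1]]

z = th.randn(7, 8)                   # fewer than 25 samples
z_ints = interpolation_pairs(z)
```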
    # Generate interpolated designs
    x_ints = []
    for alpha in [0, 0.25, 0.5, 0.75, 1]:
        z_ = (1 - alpha) * z[:25] + alpha * th.roll(z, -1, 0)[:25]
See the comment in lvae_2d.py about this.
    with th.no_grad():
        # Encode training designs
        xs = x_train.to(device)
        z = lvae.encode(xs)
Full-dataset encoding may lead to OOM:

    xs = x_train.to(device)
    z = lvae.encode(xs)

I think there is a risk of OOM if the entire dataset is encoded at once during visualization.
Would it be worth encoding in batches to be on the safe side?
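One possible shape for the batched version (the helper name and the toy stand-in for lvae.encode are illustrative, not from the PR):

```python
import torch as th

@th.no_grad()
def encode_in_batches(encode, x: th.Tensor, device, batch_size: int = 256) -> th.Tensor:
    """Encode in chunks, moving only one batch to the device at a time
    and collecting results on CPU to bound peak memory."""
    chunks = []
    for xb in th.split(x, batch_size):
        chunks.append(encode(xb.to(device)).cpu())
    return th.cat(chunks)

def toy_encode(x: th.Tensor) -> th.Tensor:
    # Stand-in for lvae.encode, purely for illustration
    return x.mean(dim=1, keepdim=True)

x_train = th.randn(1000, 16)
z = encode_in_batches(toy_encode, x_train, device="cpu", batch_size=256)
```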
    }
    wandb.log(val_log_dict, commit=True)

    th.cuda.empty_cache()
This is called every epoch regardless of whether CUDA is being used (as opposed to CPU/MPS). Harmless but wasteful.
What about doing this instead:
if th.cuda.is_available(): th.cuda.empty_cache()
    }
    wandb.log(val_log_dict, commit=True)

    th.cuda.empty_cache()
This is called every epoch regardless of whether CUDA is being used (as opposed to CPU/MPS). Harmless but wasteful.
What about doing this instead:
if th.cuda.is_available(): th.cuda.empty_cache()
    @@ -0,0 +1,455 @@
    """LVAE for 2D designs with plummet-based dynamic pruning. Adapted from https://github.com/IDEALLab/Least_Volume_ICLR2024..
Typo: double period at the end of the URL.
    val_vol /= n

    # Trigger pruning check at end of epoch
    lvae.epoch_report(epoch=epoch, callbacks=[], batch=None, loss=losses, pbar=None)
This works because the callbacks list is empty (otherwise pbar and batch would have to be non-None). Checking the rest of the code, I understand this is done on purpose.
    c_train_scaled = th.from_numpy(c_scaler.fit_transform(c_train.numpy())).to(c_train.dtype)
    c_val_scaled = th.from_numpy(c_scaler.transform(c_val.numpy())).to(c_val.dtype)
    else:
        # Dummy tensors when not using conditions (won't be used in predictor)
I would clarify the comment to something like:
"zero-width tensors: concatenating with pz is a no-op, so the predictor sees only the latent"
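A quick sketch of why the hack works: concatenating a zero-width tensor along dim=1 returns a tensor equal to the latent alone (the tensor names here are illustrative):

```python
import torch as th

z = th.randn(8, 4)        # latent batch
c_dummy = th.zeros(8, 0)  # dummy conditions with zero width

# Concatenating a zero-width tensor along dim=1 is a no-op,
# so the predictor input equals the latent exactly
zc = th.cat([z, c_dummy], dim=1)
```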
Hey Matt, just checking in, how is the progress on this PR? SOH-13
Basic lvae_2d and performance MLP plvae_2d implementations