Describe the bug
The current PyTorch implementation of the optimizer handles boolean states in a way that leads to two significant issues:
- State Loading Failure: When a model and its optimizer state are saved and then loaded, the boolean state is not correctly recognized. PyTorch saves torch.bool tensors as torch.uint8. The code explicitly checks for torch.bool, causing it to miss the loaded uint8 state, create a new boolean tensor, and leave the old uint8 tensor orphaned in memory. This effectively creates a memory leak every time a checkpoint is resumed.
- Memory Inefficiency: The paper states, "SMMF circumvents this by storing the binary sign matrix (1-bit)". However, the PyTorch implementation uses torch.bool, which stores each element in a full byte and therefore consumes 8 bits per element, not 1 bit. This results in an 8x increase in memory usage for this state compared to the paper's claim, negating a key advantage of the method.
To Reproduce
Steps to reproduce the behavior:
- Initialize a model and the optimizer.
- Run a few training steps to populate the optimizer state.
- Save the state_dict of the optimizer using torch.save().
- Load the state_dict back into a new optimizer instance using torch.load().
- Observe the optimizer's state dictionary. A new torch.bool tensor will be created for state['sign'] while the original, now torch.uint8, tensor also remains in memory.
Observed Behavior
The code block responsible for the state loading issue is:
sign = state['sign']
if sign.dtype != torch.bool:
    sign = sign.type(torch.bool)
When loading a checkpoint, sign.dtype is torch.uint8, so the condition is met. However, sign.type(torch.bool) allocates a brand-new tensor rather than re-casting the loaded state, so both the new bool tensor and the original uint8 tensor remain resident, leading to redundant memory usage.
Furthermore, my testing with an SDXL model shows the memory impact of the 8-bit boolean issue. The expected overhead would be approximately 0.5GB, but the actual memory consumption is around 2.5GB, roughly half the size of the bf16 model weights themselves, consistent with an 8-bit-per-element state.
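As a rough sanity check on those numbers (the element count here is an assumption; the SDXL UNet is on the order of 2.6 billion parameters, one sign entry per parameter):

```python
# Back-of-envelope check of the memory figures above.
# Assumption: ~2.6e9 sign entries (one per parameter of the SDXL UNet).
num_elems = 2.6e9

actual_gb = num_elems * 1 / 1e9   # torch.bool: one full byte per element
claimed_gb = num_elems / 8 / 1e9  # the paper's 1-bit sign matrix

print(f"8-bit storage: {actual_gb:.2f} GB")   # close to the observed ~2.5 GB
print(f"1-bit storage: {claimed_gb:.2f} GB")  # near the expected ~0.5 GB
```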
Expected Behavior
- When loading a checkpoint and its states, the optimizer should correctly recognize the torch.uint8 tensor as the saved boolean state and cast it back to torch.bool in-place, without allocating new memory.
- The memory footprint should be closer to the 1-bit representation described in the paper.
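One way to meet the in-place requirement above, as a sketch rather than the optimizer's actual code: since torch.uint8 and torch.bool share a 1-byte element size, Tensor.view(torch.bool) can reinterpret the loaded buffer without copying it.

```python
import torch

# Demonstration: reinterpret a loaded uint8 tensor as bool without a copy.
# `loaded` is a stand-in for state['sign'] as it comes back from a checkpoint.
loaded = torch.tensor([1, 0, 1, 1], dtype=torch.uint8)
as_bool = loaded.view(torch.bool)

print(as_bool.dtype)                            # torch.bool
print(as_bool.data_ptr() == loaded.data_ptr())  # True: both share the same storage
```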
Suggested Fix
- For the state loading bug, the type-checking logic could be modified to handle the torch.uint8 case explicitly upon loading, ensuring the state is properly restored without duplication.
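A minimal sketch of such a check (restore_sign is a hypothetical helper name; the real integration point depends on the optimizer's step/load code):

```python
import torch

def restore_sign(state: dict) -> torch.Tensor:
    """Hypothetical helper: restore the saved boolean sign state.

    Handles the uint8 tensor that comes back from a checkpoint, replacing the
    state-dict entry so the stale uint8 copy is not kept alive alongside it.
    """
    sign = state['sign']
    if sign.dtype == torch.uint8:
        # Same 1-byte element size, so view() reinterprets without allocating.
        state['sign'] = sign.view(torch.bool)
    elif sign.dtype != torch.bool:
        # Fallback for any other dtype: an explicit cast (this does allocate).
        state['sign'] = sign.to(torch.bool)
    return state['sign']
```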
- For the memory inefficiency, consider implementing a bit-packing solution where 8 boolean values are packed into a single uint8 byte. This would align the implementation with the paper's claim of 1-bit storage and deliver the expected memory savings.
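A sketch of what that packing could look like (pack_bits/unpack_bits are illustrative helpers, not part of the optimizer's API):

```python
import torch

# Bit weights for the 8 positions within one byte.
_WEIGHTS = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)

def pack_bits(bools: torch.Tensor) -> torch.Tensor:
    """Pack a bool tensor into uint8, 8 values per byte (illustrative helper)."""
    flat = bools.flatten().to(torch.uint8)
    pad = (-flat.numel()) % 8  # zero-pad up to a multiple of 8
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    # Each row of 8 bits collapses to one byte via a weighted sum.
    return (flat.view(-1, 8) * _WEIGHTS).sum(dim=1).to(torch.uint8)

def unpack_bits(packed: torch.Tensor, numel: int) -> torch.Tensor:
    """Recover the first `numel` boolean values from a packed uint8 tensor."""
    bits = packed.unsqueeze(1).bitwise_and(_WEIGHTS).ne(0)
    return bits.flatten()[:numel]
```

With this scheme the sign state for N elements occupies ceil(N/8) bytes, matching the paper's 1-bit claim at the cost of pack/unpack work on each step.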
Despite these issues, the optimizer is lightweight and powerful. Fixing the state management and memory representation would make it significantly more robust and efficient.