
Bug: Incorrect handling of boolean state causes a memory leak on checkpoint resume and inefficient memory usage #1

@Koratahiu

Description


Describe the bug
The current PyTorch implementation of the optimizer handles boolean states in a way that leads to two significant issues:

  1. State Loading Failure: When a model and its optimizer state are saved and then loaded, the boolean state is not correctly recognized. PyTorch saves torch.bool tensors as torch.uint8. The code explicitly checks for torch.bool, causing it to miss the loaded uint8 state, create a new boolean tensor, and leave the old uint8 tensor orphaned in memory. This effectively creates a memory leak every time a checkpoint is resumed.
  2. Memory Inefficiency: The paper states, "SMMF circumvents this by storing the binary sign matrix (1-bit)". However, the PyTorch implementation uses torch.bool, which is stored as one byte per element and therefore consumes 8 bits, not 1 bit. This results in an 8x increase in memory usage for this state compared to the paper's claim, negating a key advantage of the method.

To Reproduce
Steps to reproduce the behavior:

  1. Initialize a model and the optimizer.
  2. Run a few training steps to populate the optimizer state.
  3. Save the state_dict of the optimizer using torch.save().
  4. Load the state_dict back into a new optimizer instance using torch.load().
  5. Observe the optimizer's state dictionary. A new torch.bool tensor will be created for state['sign'] while the original, now torch.uint8, tensor also remains in memory.
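
The steps above can be sketched with a minimal round-trip of the state alone (the dictionary layout here is illustrative; the actual optimizer class and state keys follow the repository's code):

```python
import io
import torch

# Stand-in for the optimizer's per-parameter state after a few steps.
state = {"sign": torch.zeros(10, dtype=torch.bool)}

# Steps 3-4: round-trip the state through torch.save / torch.load.
buf = io.BytesIO()
torch.save(state, buf)
buf.seek(0)
loaded = torch.load(buf)

# Step 5: depending on the PyTorch version, the loaded dtype may come
# back as torch.uint8 rather than torch.bool, which triggers the
# re-cast path quoted below.
print(loaded["sign"].dtype)
```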

Observed Behavior

The code block responsible for the state loading issue is:

              sign = state['sign']
              if sign.dtype != torch.bool:
                  sign = sign.type(torch.bool)

When loading a checkpoint, sign.dtype is torch.uint8, so the condition is met. However, the cast produces a new torch.bool tensor while the loaded uint8 tensor is left in place in the state dictionary, so both copies stay in memory instead of the loaded state being correctly re-cast.
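One way the quoted snippet could be corrected is to write the cast back into the state dict, and to use a zero-copy dtype view so no new allocation is made (a sketch; the `state['sign']` key mirrors the snippet above, and the helper name is hypothetical):

```python
import torch

def restore_sign(state):
    """Re-cast a loaded 'sign' state back to torch.bool without leaving
    the original uint8 tensor behind in the state dict."""
    sign = state["sign"]
    if sign.dtype != torch.bool:
        # view(torch.bool) reinterprets the same storage (uint8 and bool
        # are both one byte per element), so nothing is allocated and the
        # old uint8 tensor is not left orphaned in state["sign"].
        state["sign"] = sign.view(torch.bool)
    return state["sign"]

# Simulate a state loaded from a checkpoint as uint8.
state = {"sign": torch.tensor([0, 1, 1, 0], dtype=torch.uint8)}
sign = restore_sign(state)
```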

Furthermore, my testing with an SDXL model shows the memory impact of the 8-bit boolean representation. The expected overhead for a 1-bit sign state would be approximately 0.5GB, but the actual memory consumption is around 2.5GB, roughly half the size of the bf16 model weights themselves, because the state is effectively stored at 8 bits per element.
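A back-of-envelope check of these numbers, assuming roughly 2.6e9 optimizer-tracked elements (an approximate SDXL UNet parameter count, not a measured figure):

```python
# One byte per element (torch.bool) vs. true 1-bit packing.
n = 2.6e9
bytes_bool = n * 1   # torch.bool: 1 byte per element
bytes_bit = n / 8    # 1-bit packing: 8 elements per byte

print(bytes_bool / 2**30)  # ~2.4 GiB, consistent with the observed ~2.5GB
print(bytes_bit / 2**30)   # ~0.3 GiB, in the ballpark of the expected overhead
```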

Expected Behavior

  1. When loading a checkpoint and states, the optimizer should correctly recognize the torch.uint8 tensor as the saved boolean state and cast it back to torch.bool in-place, without allocating new memory.
  2. The memory footprint should be closer to the 1-bit representation described in the paper.

Suggested Fix

  1. For the state loading bug, the type-checking logic could be modified to handle the torch.uint8 case explicitly upon loading, ensuring the state is properly restored without duplication.
  2. For the memory inefficiency, consider implementing a bit-packing scheme where 8 boolean values are packed into a single uint8 byte. This would align the implementation with the paper's claim of 1-bit storage and deliver the expected memory savings.
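
The bit-packing idea in point 2 could look roughly like this (a sketch, not the repository's code; a real implementation would also need to store the original element count and shape to unpack exactly):

```python
import torch

# Bit weights for the 8 positions within one byte.
_WEIGHTS = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)

def pack_signs(sign: torch.Tensor) -> torch.Tensor:
    """Pack a flat bool tensor into a uint8 tensor, 8 values per byte."""
    bits = sign.to(torch.uint8).flatten()
    pad = (-bits.numel()) % 8          # zero-pad up to a multiple of 8
    if pad:
        bits = torch.cat([bits, bits.new_zeros(pad)])
    return (bits.view(-1, 8) * _WEIGHTS).sum(dim=1).to(torch.uint8)

def unpack_signs(packed: torch.Tensor, numel: int) -> torch.Tensor:
    """Inverse of pack_signs: expand each byte back into 8 booleans."""
    bits = packed.unsqueeze(1).bitwise_and(_WEIGHTS).ne(0)
    return bits.flatten()[:numel]

x = torch.tensor([True, False, True, True, False, False, True, False, True])
p = pack_signs(x)                       # 9 booleans -> 2 bytes
```

This trades a small amount of pack/unpack compute per step for the 8x storage reduction, matching the paper's 1-bit claim.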

Despite these issues, the optimizer is lightweight and powerful. Fixing the state management and memory representation would make it significantly more robust and efficient.
