Wavelet-Mamba Dictionary Compression (WMDC) is a state-of-the-art learned image compression model that combines wavelet transforms, Visual State Space (VSS) modules, and dictionary-based coding with entropic optimal transport. The model achieves superior rate-distortion performance through three key innovations:
- Frequency-Disentangled Mamba (FDM): Efficiently captures long-range dependencies while preserving high-frequency details
- Spatially-Adaptive EOT Dictionary Attention: Dynamically adapts dictionary utilization using unbalanced optimal transport
- Markovian Slice-Based Context Model: Autoregressive entropy modeling with latent residual prediction
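To give a feel for the second innovation, the sketch below pictures dictionary attention as an entropic optimal transport problem between spatial positions and dictionary atoms, solved with balanced Sinkhorn iterations under uniform marginals and a squared-distance cost. This is a generic toy illustration, not the WMDC implementation: the actual routing (including the unbalanced variant selected by `--routing_mode unbalanced_eot`) lives in the model code, and every name here is illustrative.

```python
import numpy as np

def sinkhorn_attention(queries, dictionary, eps=1.0, n_iters=50):
    """Toy balanced entropic-OT attention between N query positions and
    K dictionary atoms (uniform marginals). Illustrative sketch only."""
    n, k = queries.shape[0], dictionary.shape[0]
    # Cost: squared Euclidean distance between each query and each atom.
    cost = ((queries[:, None, :] - dictionary[None, :, :]) ** 2).sum(-1)
    K_mat = np.exp(-cost / eps)                       # Gibbs kernel
    a, b = np.full(n, 1.0 / n), np.full(k, 1.0 / k)   # uniform marginals
    u, v = np.ones(n), np.ones(k)
    for _ in range(n_iters):                          # Sinkhorn scaling
        u = a / (K_mat @ v)
        v = b / (K_mat.T @ u)
    plan = u[:, None] * K_mat * v[None, :]            # transport plan
    # Row-normalize the plan to get per-position attention weights.
    return plan / plan.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
attn = sinkhorn_attention(rng.normal(size=(16, 8)), rng.normal(size=(4, 8)))
print(attn.shape)  # (16, 4): one weight vector over 4 atoms per position
```

Each row of `attn` is a non-negative distribution over dictionary atoms, which is what makes the transport plan usable as attention weights.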
## Requirements

- Python 3.11
- CUDA 11.8+ (for GPU support)
- `torch`, `torchvision`
## Installation

- Clone the repository and navigate to the project directory:

  ```bash
  cd WMDC
  ```

- Create and activate a virtual environment:

  ```bash
  python -m venv .venv
  source .venv/bin/activate  # On Windows: .venv\Scripts\activate
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

- Install vmamba (selective scan operations):

  ```bash
  cd vmamba
  pip install -e .
  cd ..
  ```

## Training

Basic training command:

```bash
python train.py -d /path/to/dataset --save_path checkpoints
```

Single-GPU training:
```bash
python train.py \
    -d /path/to/dataset \
    --save_path checkpoints \
    --epochs 100 \
    --batch-size 8 \
    --lr 1e-4 \
    --aux-lr 1e-3 \
    --lambda 0.0130 \
    --patch-size 256 \
    --metric mse \
    --seed 2026
```

Multi-GPU distributed training:
```bash
accelerate launch train.py \
    -d /path/to/dataset \
    --save_path checkpoints \
    --epochs 10 \
    --batch-size 8 \
    --lambda 0.0018 \
    --routing_mode softmax
```

Training arguments:
- `-d, --dataset`: Path to image dataset (required)
- `--save_path`: Directory to save checkpoints (default: `checkpoints`)
- `-e, --epochs`: Number of training epochs (default: 100)
- `--batch-size`: Batch size (default: 8)
- `--lr`: Learning rate (default: 1e-4)
- `--aux-lr`: Auxiliary learning rate (default: 1e-3)
- `--lambda`: Rate-distortion weight (default: None)
- `--patch-size`: Training patch size (default: 256)
- `--clip_max_norm`: Gradient clipping (default: 1.0)
- `--checkpoint`: Path to resume from a checkpoint
- `--metric`: Loss metric, `mse` or `ms-ssim` (default: `mse`)
- `--routing_mode`: Routing mode, `softmax`, `balanced_eot`, or `unbalanced_eot` (default: `softmax`)
- `--seed`: Random seed (default: 2026)
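The `--lambda` flag trades off rate against distortion. A common form of the objective in learned compression (used, e.g., by CompressAI-style codecs; the exact weighting in `train.py` may differ) scales the MSE to the 8-bit range and adds the rate in bits per pixel:

```python
def rate_distortion_loss(mse, bits, num_pixels, lmbda=0.0018):
    """Generic rate-distortion objective L = lambda * 255^2 * MSE + bpp.
    Illustrative sketch only, not the exact loss used by train.py."""
    bpp = bits / num_pixels        # rate: bits per pixel
    distortion = 255 ** 2 * mse    # MSE rescaled to the 8-bit pixel range
    return lmbda * distortion + bpp, bpp

# A 256x256 patch compressed to 98304 bits at MSE 0.001:
loss, bpp = rate_distortion_loss(mse=0.001, bits=98304, num_pixels=256 * 256)
```

Larger `lmbda` values penalize distortion more, producing higher-bitrate, higher-quality models, which is why a checkpoint directory name encodes it (e.g. `lambda_0.0018_mse`).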
## Evaluation

Evaluate on a dataset:

```bash
python eval.py \
    --dataset /path/to/test/images \
    --checkpoint /path/to/checkpoint.pth \
    --output results \
    --cuda
```

Example with the Kodak dataset:
```bash
python eval.py \
    --dataset /kaggle/input/datasets/kodak-test \
    --checkpoint ./checkpoints/lambda_0.0018_mse/checkpoint_best.pth.tar \
    --output kodak \
    --cuda \
    --routing_mode softmax
```

Evaluation arguments:
- `--dataset`: Path to test images (required)
- `--checkpoint`: Path to model checkpoint (required)
- `--output`: Output directory for results (default: `output`)
- `--cuda`: Use GPU for inference
- `--routing_mode`: Routing mode, `softmax`, `balanced_eot`, or `unbalanced_eot` (default: `softmax`)
## Analysis and Visualization

Visualize attention maps on an image directory:

```bash
python analyze/visualize_attention.py \
    --img_dir /path/to/image \
    --checkpoint /path/to/checkpoint.pth \
    --slice 3 \
    --cuda \
    -o attention_maps.pdf \
    --routing_mode softmax
```

Visualize latent features:
```bash
python analyze/visualize_latents.py \
    --image /path/to/image.png \
    --checkpoint /path/to/checkpoint.pth \
    --routing_mode softmax \
    --output latent_visualization
```

Visualize patches:
```bash
python analyze/visualize_patches.py \
    -i /path/to/image.png \
    -c /path/to/checkpoint.pth \
    --cuda \
    --routing_mode softmax \
    -o visual_comparison
```

Rate-distortion plot:
```bash
python plot_rd.py --checkpoint /path/to/checkpoint.pth --dataset /path/to/dataset
```

Profile model performance:
```bash
python analyze/profile_model.py \
    --checkpoint /path/to/checkpoint.pth \
    --input_size 256 \
    --cuda
```

Analyze bit allocation:
```bash
python analyze/analysis_bit_allocation.py \
    --checkpoint /path/to/checkpoint.pth \
    --dataset /path/to/dataset \
    --output bit_analysis_results
```

Plot bitrate-distortion latency:
```bash
python analyze/plot_bd_latency.py \
    --results_dir /path/to/results \
    --output bd_latency_plot.png
```

## Dataset Format

The training script expects images in standard image formats (PNG, JPEG). For best results, the dataset should be organized as:
```
dataset/
├── image1.png
├── image2.png
├── ...
└── imageN.png
```
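A minimal loader matching this flat layout might look like the following sketch. `FlatImageFolder` is a hypothetical name for illustration; `train.py` uses its own dataset class, and a real loader would also decode the image and take a random `--patch-size` crop.

```python
import os

class FlatImageFolder:
    """List PNG/JPEG files in a flat directory (illustrative sketch)."""
    EXTS = {".png", ".jpg", ".jpeg"}

    def __init__(self, root):
        self.paths = sorted(
            os.path.join(root, f) for f in os.listdir(root)
            if os.path.splitext(f)[1].lower() in self.EXTS
        )
        if not self.paths:
            raise RuntimeError(f"no images found in {root}")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # A real loader would open the file and random-crop to patch size here.
        return self.paths[idx]
```

Non-image files in the directory are simply skipped, so stray metadata files will not break training.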
## Model Architecture

The WMDC model combines:

- Wavelet transforms for multi-scale decomposition
- Visual State Space (VSS) modules for feature encoding
- Multi-scale distribution coding for entropy modeling
- Spatial context models (CSM) with Triton backend optimization

See `models/WMDC.py` for implementation details.
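For intuition on the first component, a single level of a 2-D orthonormal Haar DWT splits an image into a low-pass subband and three high-frequency detail subbands. This is only a textbook illustration of the multi-scale decomposition; the wavelet actually used by WMDC may differ.

```python
import numpy as np

def haar_dwt2(x):
    """One-level 2-D orthonormal Haar DWT of an (H, W) array with even H, W.
    Returns (LL, LH, HL, HH) subbands, each (H/2, W/2). Illustrative only."""
    a = x[0::2, 0::2]  # top-left pixel of each 2x2 block
    b = x[0::2, 1::2]  # top-right
    c = x[1::2, 0::2]  # bottom-left
    d = x[1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2.0  # low-pass: local average
    lh = (a - b + c - d) / 2.0  # horizontal detail
    hl = (a + b - c - d) / 2.0  # vertical detail
    hh = (a - b - c + d) / 2.0  # diagonal detail
    return ll, lh, hl, hh
```

Because the transform is orthonormal, the subbands preserve the signal energy exactly, and smooth regions concentrate their energy in LL, which is what makes entropy coding of the detail subbands cheap.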
## Outputs

Training outputs:

- Checkpoints saved to the `--save_path` directory
- Logs saved to the `logs/` directory
- TensorBoard logs for training visualization

Evaluation outputs:

- Metrics saved to the results directory
- Compressed images saved for visual inspection
- Rate-distortion data

The model reports:

- PSNR: Peak Signal-to-Noise Ratio
- MS-SSIM: Multi-Scale Structural Similarity
- BPP: Bits Per Pixel (actual compressed size)
- Loss: Rate-distortion loss value
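PSNR and BPP can be reproduced directly from the decoded image and the actual bitstream size; a minimal sketch is below (MS-SSIM needs a windowed multi-scale computation and is omitted; `eval.py` computes all metrics itself).

```python
import numpy as np

def psnr(orig, recon, max_val=255.0):
    """PSNR in dB between two images in the [0, max_val] range."""
    mse = np.mean((orig.astype(np.float64) - recon.astype(np.float64)) ** 2)
    return float("inf") if mse == 0 else 10.0 * np.log10(max_val ** 2 / mse)

def bpp(compressed_bytes, height, width):
    """Bits per pixel from the actual compressed file size."""
    return 8.0 * compressed_bytes / (height * width)
```

For example, a 512x512 image compressed to 65536 bytes comes out at exactly 2.0 bpp.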
## Acknowledgments

- CompressAI: https://github.com/InterDigitalInc/CompressAI
- VMamba (Visual State Space): https://github.com/MzeroMiko/VMamba.git

## License

See the LICENSE file for details.