A comprehensive implementation of advanced Transformer architectures in Rust using the Candle framework by Hugging Face.
This project implements a configurable Transformer model in Rust with the following features:
- Multiple attention mechanisms (Multi-head, Flash Attention, ALiBi, Sliding Window)
- Rotary Position Embeddings (RoPE)
- Advanced feed-forward networks with various activation functions
- Layer normalization and RMSNorm
- Configurable architecture with support for encoder-only, decoder-only, and encoder-decoder models
- Optimizations for efficient processing of long sequences
The implementation is built on Candle, a minimalist ML framework for Rust developed by Hugging Face.
- Multi-Head Attention: Standard attention with support for rotary embeddings
- Flash Attention: Memory-efficient attention implementation based on "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"
- ALiBi Attention: Attention with Linear Biases for improved extrapolation to longer sequences
- Sliding Window Attention: Efficient attention for very long sequences with optional global tokens
- Token Embeddings: With optional token type embeddings
- Positional Embeddings: Including sinusoidal, learned, and relative position embeddings
- Rotary Embeddings: Rotation-based position encoding applied to the query and key vectors, so that attention scores reflect relative position
- Transformer Encoder: With configurable layers and attention mechanisms
- Transformer Decoder: With self-attention, cross-attention, and feed-forward networks
- Feed-Forward Networks: With various activation functions (GELU, SiLU, Swish, Mish, etc.)
- Activation Functions: The activation functions (GELU, SiLU, Swish, Mish, etc.) that can be selected through the configuration
- Masking Utilities: For causal and padding masks
- Tensor Operations: Advanced operations for efficient tensor manipulation
- Configuration System: Flexible configuration for model architecture and hyperparameters
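
As a rough illustration of what the Masking Utilities entry above refers to, the following sketch builds a causal mask and a padding mask in plain Rust and combines them into an additive mask for pre-softmax attention scores. It is not taken from `masking.rs`: the function names, the `Vec<Vec<bool>>` representation, and the pad id of `0` are assumptions made for this example, and the real utilities presumably operate on Candle tensors instead.

```rust
/// Causal mask: query position `i` may attend to key position `j` only when `j <= i`.
fn causal_mask(seq_len: usize) -> Vec<Vec<bool>> {
    (0..seq_len)
        .map(|i| (0..seq_len).map(|j| j <= i).collect())
        .collect()
}

/// Padding mask: `true` for real tokens, `false` for padding tokens.
/// `pad_id` is a hypothetical padding token id chosen by the caller.
fn padding_mask(token_ids: &[u32], pad_id: u32) -> Vec<bool> {
    token_ids.iter().map(|&t| t != pad_id).collect()
}

/// A key is visible to a query only if the causal mask allows it
/// and the key itself is not padding.
fn combine(causal: &[Vec<bool>], keep: &[bool]) -> Vec<Vec<bool>> {
    causal
        .iter()
        .map(|row| row.iter().zip(keep).map(|(&c, &k)| c && k).collect())
        .collect()
}

/// Additive form: 0.0 where attention is allowed, negative infinity where it
/// is blocked, so that blocked positions receive zero weight after softmax.
fn to_additive(mask: &[Vec<bool>]) -> Vec<Vec<f32>> {
    mask.iter()
        .map(|row| {
            row.iter()
                .map(|&allowed| if allowed { 0.0 } else { f32::NEG_INFINITY })
                .collect()
        })
        .collect()
}

fn main() {
    // Three real tokens followed by one padding token (assumed pad id: 0).
    let ids = [5u32, 9, 3, 0];
    let mask = combine(&causal_mask(ids.len()), &padding_mask(&ids, 0));
    let additive = to_additive(&mask);
    for row in &additive {
        println!("{:?}", row);
    }
}
```

The additive form is convenient because it can be summed directly with the attention scores before the softmax. Note that a row in which every position is blocked would produce NaN after softmax, so sequences are assumed to start with a real token.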
[dependencies]
candle-core = "0.3.3"
candle-nn = "0.3.3"
anyhow = "1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
rand = "0.8"
rayon = "1.7"
thiserror = "1.0"
safetensors = "0.3"
tokenizers = "0.13"
tracing = "0.1"
tracing-subscriber = "0.3"
clap = { version = "4.3", features = ["derive"] }
num-traits = "0.2"
indicatif = "0.17"
byteorder = "1.4"
memmap2 = "0.7"
half = "2.3"

- Make sure you have Rust and Cargo installed
- Clone this repository
- Run the example:
cargo run

src/
├── attention/
│ ├── multi_head.rs # Multi-head attention implementation
│ ├── flash.rs # Flash attention for memory efficiency
│ ├── alibi.rs # ALiBi attention for length extrapolation
│ └── sliding_window.rs # Sliding window attention for long sequences
├── embeddings/
│ ├── token.rs # Token embeddings
│ ├── positional.rs # Positional embeddings
│ └── rotary.rs # Rotary position embeddings
├── models/
│ ├── transformer.rs # Main transformer model
│ ├── encoder.rs # Transformer encoder
│ ├── decoder.rs # Transformer decoder
│ └── layer.rs # Encoder and decoder layers
├── utils/
│ ├── activations.rs # Activation functions
│ ├── masking.rs # Attention masking utilities
│ ├── tensor_ops.rs # Tensor operations
│ └── config.rs # Configuration system
└── main.rs # Example usage
// Create a transformer configuration
let config = TransformerConfig::builder()
    .hidden_size(768)
    .num_attention_heads(12)
    .num_hidden_layers(12)
    .intermediate_size(3072)
    .hidden_act(ActivationType::Gelu)
    .max_position_embeddings(512)
    .use_rotary_embeddings(true)
    .build();
// Initialize the model
let transformer = Transformer::new(&config, vb)?;
// Forward pass
let output = transformer.forward(&input_ids, &attention_mask, None, None, true)?;

// Using ALiBi attention
let config = TransformerConfig::builder()
    .attention_type(AttentionType::ALiBi)
    .alibi_bias_max(8.0)
    .build();

// Using Flash Attention
let config = TransformerConfig::builder()
    .attention_type(AttentionType::Flash)
    .build();

// Using Sliding Window Attention
let config = TransformerConfig::builder()
    .attention_type(AttentionType::SlidingWindow)
    .sliding_window(256)
    .use_global_tokens(true)
    .num_global_tokens(4)
    .build();

This implementation includes several optimizations for efficient processing:
- Flash Attention: Reduces memory usage and improves speed for attention computation
- Sliding Window Attention: Enables processing of very long sequences by limiting attention to a local window
- Rotary Embeddings: Provides better relative position information without additional parameters
- ALiBi: Improves extrapolation to sequences longer than those seen during training
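
To make the ALiBi entry more concrete, here is a minimal sketch of how per-head slopes and the resulting linear bias can be computed. It is written in plain Rust without Candle and is not the code in `alibi.rs`. The function names are hypothetical, and it assumes that `alibi_bias_max` scales the exponent as `bias_max * h / num_heads`, which for 8 heads and a value of 8.0 gives the geometric sequence 1/2, 1/4, ..., 1/256 described in the ALiBi paper.

```rust
/// Per-head slopes: head `h` (1-based) gets 2^(-bias_max * h / num_heads).
/// Assumption: this mirrors how the `alibi_bias_max` setting is interpreted.
fn alibi_slopes(num_heads: usize, bias_max: f32) -> Vec<f32> {
    (1..=num_heads)
        .map(|h| 2f32.powf(-bias_max * h as f32 / num_heads as f32))
        .collect()
}

/// Bias added to the attention scores of one head: a penalty that grows
/// linearly with the distance between query position `i` and key position `j`.
/// Under a causal mask only the entries with `j <= i` are actually used.
fn alibi_bias(slope: f32, seq_len: usize) -> Vec<Vec<f32>> {
    (0..seq_len)
        .map(|i| {
            (0..seq_len)
                .map(|j| {
                    let distance = (i as f32 - j as f32).abs();
                    0.0 - slope * distance
                })
                .collect()
        })
        .collect()
}

fn main() {
    let slopes = alibi_slopes(8, 8.0);
    println!("slopes: {:?}", slopes);

    // Bias matrix for the first head over a sequence of four positions.
    let bias = alibi_bias(slopes[0], 4);
    for row in &bias {
        println!("{:?}", row);
    }
}
```

Because the bias depends only on the distance between positions and has no learned parameters, the same formula applies unchanged to sequences longer than those seen during training.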
Contributions are welcome! Here are some ways you can contribute:
- Implement additional attention mechanisms
- Add support for more model architectures
- Optimize performance for specific hardware
- Add examples and benchmarks
- Improve documentation
Licensed under the MIT License.