lblommesteyn/rust-transformers

Rust Transformers

A comprehensive implementation of advanced Transformer architectures in Rust using the Candle framework by Hugging Face.

Overview

This project implements a configurable Transformer model in Rust. It includes:

  • Multiple attention mechanisms (Multi-head, Flash Attention, ALiBi, Sliding Window)
  • Rotary Position Embeddings (RoPE)
  • Advanced feed-forward networks with various activation functions
  • Layer normalization and RMSNorm
  • Configurable architecture with support for encoder-only, decoder-only, and encoder-decoder models
  • Optimizations for efficient processing of long sequences
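
To make the difference between the two normalization layers listed above concrete, here is a minimal sketch in plain Rust (standard library only, no Candle tensors). The function names layer_norm and rms_norm are illustrative and are not taken from this crate, and the learned scale and bias parameters are left out for brevity.

```rust
/// Layer normalization without learned parameters:
/// subtract the mean, then divide by the standard deviation.
fn layer_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter().map(|v| (v - mean) * inv_std).collect()
}

/// RMSNorm without learned parameters:
/// divide by the root mean square, with no mean subtraction.
fn rms_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / n;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().map(|v| v * inv_rms).collect()
}
```

RMSNorm skips the centering step, so it needs one fewer reduction over the hidden dimension than layer normalization.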

Candle is a minimalist ML framework for Rust developed by Hugging Face. This project builds on its candle-core and candle-nn crates (see Dependencies).

Features

Attention Mechanisms

  • Multi-Head Attention: Standard attention with support for rotary embeddings
  • Flash Attention: Memory-efficient attention implementation based on "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"
  • ALiBi Attention: Attention with Linear Biases for improved extrapolation to longer sequences
  • Sliding Window Attention: Efficient attention for very long sequences with optional global tokens
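
As an illustration of how ALiBi biases attention, the following sketch computes per-head slopes and the additive bias matrix in plain Rust. It assumes a power-of-two head count and a causal (decoder-style) setting. The bias_max parameter is assumed to play the role of the exponent scale (8 in the ALiBi paper); whether that matches this crate's alibi_bias_max option is not confirmed here, and the function names are hypothetical.

```rust
/// Per-head ALiBi slopes (assumes num_heads is a power of two).
/// Head h (1-based) gets slope 2^(-bias_max * h / num_heads).
fn alibi_slopes(num_heads: usize, bias_max: f32) -> Vec<f32> {
    (1..=num_heads)
        .map(|h| 2f32.powf(-bias_max * h as f32 / num_heads as f32))
        .collect()
}

/// Additive bias for one head: -slope * (i - j) for keys j <= i.
/// Future positions (j > i) get negative infinity so softmax ignores them.
fn alibi_bias(seq_len: usize, slope: f32) -> Vec<Vec<f32>> {
    (0..seq_len)
        .map(|i| {
            (0..seq_len)
                .map(|j| {
                    if j <= i {
                        -slope * (i - j) as f32
                    } else {
                        f32::NEG_INFINITY
                    }
                })
                .collect()
        })
        .collect()
}
```

The bias is added to the scaled query-key scores before softmax, so more distant keys are penalized linearly and each head uses a different penalty rate.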

Embeddings

  • Token Embeddings: With optional token type embeddings
  • Positional Embeddings: Including sinusoidal, learned, and relative position embeddings
  • Rotary Embeddings: Rotation-based position encoding that makes attention scores depend on relative position
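
The rotary idea can be sketched in a few lines of plain Rust. This is a minimal illustration, not this crate's rotary.rs: it assumes the interleaved pairing convention, where elements (2k, 2k+1) form a pair, and the conventional base of 10000. The names apply_rope and dot are hypothetical.

```rust
/// Rotate each pair (x[2k], x[2k+1]) by the angle position * base^(-2k / dim).
fn apply_rope(x: &[f32], position: usize, base: f32) -> Vec<f32> {
    assert!(x.len() % 2 == 0, "head dimension must be even");
    let dim = x.len();
    let mut out = vec![0.0; dim];
    for k in 0..dim / 2 {
        let theta = base.powf(-2.0 * k as f32 / dim as f32);
        let (sin, cos) = (position as f32 * theta).sin_cos();
        let (a, b) = (x[2 * k], x[2 * k + 1]);
        out[2 * k] = a * cos - b * sin;
        out[2 * k + 1] = a * sin + b * cos;
    }
    out
}

/// Plain dot product, used to compare attention scores.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```

Because each pair is rotated by an angle proportional to its position, the dot product of a query rotated at position m with a key rotated at position n depends only on m - n. For example, positions (5, 3) and (2, 0) give the same score up to floating-point error.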

Model Components

  • Transformer Encoder: With configurable layers and attention mechanisms
  • Transformer Decoder: With self-attention, cross-attention, and feed-forward networks
  • Feed-Forward Networks: With various activation functions (GELU, SiLU, Swish, Mish, etc.)
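
For reference, three of the activation functions named above can be written as scalar functions in plain Rust. This is a sketch of the standard definitions, not a copy of activations.rs. The GELU shown is the common tanh approximation, and SiLU is the same function as Swish with beta = 1.

```rust
/// SiLU (Swish with beta = 1): x * sigmoid(x).
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// GELU, tanh approximation.
fn gelu_tanh(x: f32) -> f32 {
    let c = (2.0 / std::f32::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
}

/// Mish: x * tanh(softplus(x)), where softplus(x) = ln(1 + e^x).
fn mish(x: f32) -> f32 {
    x * x.exp().ln_1p().tanh()
}
```

All three are smooth, pass through zero at x = 0, and approach the identity for large positive inputs.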

Utilities

  • Activation Functions: Comprehensive set of activation functions
  • Masking Utilities: For causal and padding masks
  • Tensor Operations: Advanced operations for efficient tensor manipulation
  • Configuration System: Flexible configuration for model architecture and hyperparameters
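
The two kinds of mask mentioned above can be combined into a single additive mask. The sketch below uses plain Rust vectors rather than Candle tensors. The function names and the convention (0.0 for allowed, negative infinity for blocked) are assumptions for illustration and may differ from masking.rs.

```rust
/// Causal mask: query i may attend to key j only when j <= i.
fn causal_mask(seq_len: usize) -> Vec<Vec<bool>> {
    (0..seq_len)
        .map(|i| (0..seq_len).map(|j| j <= i).collect())
        .collect()
}

/// Combine the causal mask with a padding mask into an additive mask.
/// key_is_real[j] is false when position j is padding.
/// Allowed entries are 0.0; blocked entries are negative infinity.
fn combined_additive_mask(key_is_real: &[bool]) -> Vec<Vec<f32>> {
    let n = key_is_real.len();
    let causal = causal_mask(n);
    (0..n)
        .map(|i| {
            (0..n)
                .map(|j| {
                    if causal[i][j] && key_is_real[j] {
                        0.0
                    } else {
                        f32::NEG_INFINITY
                    }
                })
                .collect()
        })
        .collect()
}
```

Adding this mask to the attention scores before softmax drives the weight of every blocked position to zero.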

Dependencies

[dependencies]
candle-core = "0.3.3"
candle-nn = "0.3.3"
anyhow = "1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
rand = "0.8"
rayon = "1.7"
thiserror = "1.0"
safetensors = "0.3"
tokenizers = "0.13"
tracing = "0.1"
tracing-subscriber = "0.3"
clap = { version = "4.3", features = ["derive"] }
num-traits = "0.2"
indicatif = "0.17"
byteorder = "1.4"
memmap2 = "0.7"
half = "2.3"

Getting Started

  1. Make sure you have Rust and Cargo installed
  2. Clone this repository
  3. Run the example:
cargo run

Project Structure

src/
├── attention/
│   ├── multi_head.rs    # Multi-head attention implementation
│   ├── flash.rs         # Flash attention for memory efficiency
│   ├── alibi.rs         # ALiBi attention for length extrapolation
│   └── sliding_window.rs # Sliding window attention for long sequences
├── embeddings/
│   ├── token.rs         # Token embeddings
│   ├── positional.rs    # Positional embeddings
│   └── rotary.rs        # Rotary position embeddings
├── models/
│   ├── transformer.rs   # Main transformer model
│   ├── encoder.rs       # Transformer encoder
│   ├── decoder.rs       # Transformer decoder
│   └── layer.rs         # Encoder and decoder layers
├── utils/
│   ├── activations.rs   # Activation functions
│   ├── masking.rs       # Attention masking utilities
│   ├── tensor_ops.rs    # Tensor operations
│   └── config.rs        # Configuration system
└── main.rs              # Example usage

Advanced Usage

Creating a Transformer Model

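use candle_core::{DType, Device};
use candle_nn::{VarBuilder, VarMap};
// TransformerConfig, ActivationType, and Transformer come from this crate.

// Create a transformer configuration
let config = TransformerConfig::builder()
    .hidden_size(768)
    .num_attention_heads(12)
    .num_hidden_layers(12)
    .intermediate_size(3072)
    .hidden_act(ActivationType::Gelu)
    .max_position_embeddings(512)
    .use_rotary_embeddings(true)
    .build();

// Variable store with freshly initialized (untrained) weights on the CPU
let device = Device::Cpu;
let varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

// Initialize the model
let transformer = Transformer::new(&config, vb)?;

// Forward pass; input_ids and attention_mask are tensors prepared by the caller
let output = transformer.forward(&input_ids, &attention_mask, None, None, true)?;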
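
As a sanity check on the configuration values above, the following sketch derives the per-head dimension and a rough weight count. ShapeSummary and summarize are hypothetical helpers, not part of this crate. The count assumes a standard two-matrix feed-forward block and ignores biases, normalization parameters, and embeddings.

```rust
/// Hypothetical summary of shapes derived from a configuration.
struct ShapeSummary {
    head_dim: usize,
    params_per_layer: usize,
    total_layer_params: usize,
}

/// Returns None when hidden_size is not divisible by num_heads.
fn summarize(
    hidden_size: usize,
    num_heads: usize,
    num_layers: usize,
    intermediate_size: usize,
) -> Option<ShapeSummary> {
    if num_heads == 0 || hidden_size % num_heads != 0 {
        return None;
    }
    // Query, key, value, and output projections: four square matrices.
    let attention = 4 * hidden_size * hidden_size;
    // Up and down projections of the feed-forward block.
    let feed_forward = 2 * hidden_size * intermediate_size;
    let params_per_layer = attention + feed_forward;
    Some(ShapeSummary {
        head_dim: hidden_size / num_heads,
        params_per_layer,
        total_layer_params: params_per_layer * num_layers,
    })
}
```

With hidden_size 768, 12 heads, 12 layers, and intermediate_size 3072, this gives a head dimension of 64 and 7,077,888 weights per layer under the stated assumptions.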

Using Different Attention Mechanisms

// Using ALiBi attention
let config = TransformerConfig::builder()
    .attention_type(AttentionType::ALiBi)
    .alibi_bias_max(8.0)
    .build();

// Using Flash Attention
let config = TransformerConfig::builder()
    .attention_type(AttentionType::Flash)
    .build();

// Using Sliding Window Attention
let config = TransformerConfig::builder()
    .attention_type(AttentionType::SlidingWindow)
    .sliding_window(256)
    .use_global_tokens(true)
    .num_global_tokens(4)
    .build();
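
To show what a sliding window with global tokens permits, here is a small sketch of the attention pattern in plain Rust. It assumes that window is the maximum distance allowed on each side of a token and that the first num_global positions are the global tokens. Both are assumptions about semantics, and the function names are hypothetical rather than taken from sliding_window.rs.

```rust
/// Whether query position i may attend to key position j.
/// Global tokens attend everywhere and are attended to by everyone.
fn sliding_window_allowed(i: usize, j: usize, window: usize, num_global: usize) -> bool {
    let involves_global = i < num_global || j < num_global;
    involves_global || i.abs_diff(j) <= window
}

/// Number of (query, key) pairs that are actually computed.
fn count_allowed(seq_len: usize, window: usize, num_global: usize) -> usize {
    (0..seq_len)
        .map(|i| {
            (0..seq_len)
                .filter(|&j| sliding_window_allowed(i, j, window, num_global))
                .count()
        })
        .sum()
}
```

For a fixed window and a fixed number of global tokens, the count of allowed pairs grows linearly with seq_len, compared with seq_len squared for full attention.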

Performance Optimizations

This implementation includes several optimizations for efficient processing:

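  1. Flash Attention: Computes exact attention block by block without materializing the full attention matrix, which reduces memory usage and memory traffic
  2. Sliding Window Attention: Limits each token's attention to a local window, so cost grows linearly with sequence length for a fixed window instead of quadratically
  3. Rotary Embeddings: Encodes relative position in the query-key dot product without additional learned parameters
  4. ALiBi: Adds a distance-proportional penalty to attention scores, which improves extrapolation to sequences longer than those seen during training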
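
The memory saving in Flash Attention rests on computing softmax in a streaming fashion. The sketch below shows that idea for a single query in plain Rust: keys and values are processed in blocks while a running maximum, a running denominator, and an accumulator are maintained. It is a simplified illustration of the technique, not the code in flash.rs, and all names are hypothetical.

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Reference implementation: materializes every score for the query.
fn naive_attention(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let scores: Vec<f32> = keys.iter().map(|k| dot(q, k) * scale).collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = weights.iter().sum();
    let mut out = vec![0.0; values[0].len()];
    for (w, v) in weights.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += w / denom * x;
        }
    }
    out
}

/// Streaming version: only one block of scores exists at a time.
fn streaming_attention(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], block: usize) -> Vec<f32> {
    assert!(block > 0, "block size must be positive");
    let scale = 1.0 / (q.len() as f32).sqrt();
    let mut running_max = f32::NEG_INFINITY;
    let mut denom = 0.0f32;
    let mut acc = vec![0.0f32; values[0].len()];
    for (k_block, v_block) in keys.chunks(block).zip(values.chunks(block)) {
        let scores: Vec<f32> = k_block.iter().map(|k| dot(q, k) * scale).collect();
        let block_max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let new_max = running_max.max(block_max);
        // Rescale earlier partial sums to the new maximum (0.0 on the first block).
        let correction = (running_max - new_max).exp();
        denom *= correction;
        for a in acc.iter_mut() {
            *a *= correction;
        }
        for (s, v) in scores.iter().zip(v_block) {
            let w = (s - new_max).exp();
            denom += w;
            for (a, x) in acc.iter_mut().zip(v) {
                *a += w * x;
            }
        }
        running_max = new_max;
    }
    acc.iter().map(|a| a / denom).collect()
}
```

Both functions return the same result up to floating-point error for any block size. The streaming version never stores more than block scores at once, which is what lets a tiled implementation avoid the full attention matrix.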

Contributing

Contributions are welcome! Here are some ways you can contribute:

  1. Implement additional attention mechanisms
  2. Add support for more model architectures
  3. Optimize performance for specific hardware
  4. Add examples and benchmarks
  5. Improve documentation

License

MIT
