From c74e92eb59b604ad9c146a7938e0621b3fa2dbea Mon Sep 17 00:00:00 2001 From: ganeshflexiana Date: Mon, 30 Mar 2026 19:03:02 +0530 Subject: [PATCH] BDH development --- DOCUMENTATION.md | 134 +++++++++++++ Dockerfile | 19 ++ Dockerfile.gpu | 14 ++ README.md | 23 --- bdh.py | 440 ++++++++++++++++++++++++++++++++++++++--- bdh_app/.DS_Store | Bin 0 -> 6148 bytes bdh_app/__init__.py | 1 + bdh_app/evaluation.py | 291 +++++++++++++++++++++++++++ bdh_app/memory.py | 272 +++++++++++++++++++++++++ bdh_app/training.py | 272 +++++++++++++++++++++++++ compare_checkpoints.py | 7 + generate.py | 210 ++++++++++++++++++++ main.py | 44 +++++ memory.py | 7 + requirements.txt | 2 + train.py | 125 +----------- 16 files changed, 1684 insertions(+), 177 deletions(-) create mode 100644 DOCUMENTATION.md create mode 100644 Dockerfile create mode 100644 Dockerfile.gpu create mode 100644 bdh_app/.DS_Store create mode 100644 bdh_app/__init__.py create mode 100644 bdh_app/evaluation.py create mode 100644 bdh_app/memory.py create mode 100644 bdh_app/training.py create mode 100755 compare_checkpoints.py create mode 100755 generate.py create mode 100644 main.py create mode 100644 memory.py diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md new file mode 100644 index 0000000..4960a3b --- /dev/null +++ b/DOCUMENTATION.md @@ -0,0 +1,134 @@ +# Baby Dragon Hatchling (BDH) Guide + +This guide explains what the project does, how to run it end-to-end, and how to interpret the outputs. + +## What This Generates + +- Training checkpoints at every 10% of the run. +- A consolidated generation log and a Markdown report comparing checkpoints. +- A loss curve plot (loss vs steps). +- A memory run log and a small loss plot for the model-memory fine-tune. + +## Deliverables Summary + +1. **End-to-End Training**: `train.py` supports configurable steps and automatic checkpointing at every 10% of progress. +2. **Checkpoint Comparison**: `compare_checkpoints.py` evaluates all checkpoints with a fixed prompt and creates a report + consolidated log. +3. **Memory Script**: `memory.py` demonstrates: + - **Fast Memory (Context)**: A lightweight memory store ingests facts from a prompt and answers from memory without keeping the fact in the user prompt. + - **Model Memory (Weights)**: Facts consolidated into model weights via a short fine-tune. +4. **One-Click Run**: `main.py` runs training, evaluation, and memory demos in order. + +## How to Run + +### Using Docker (Recommended for Reproducibility) + +1. **Build the CPU image**: + + ```bash + docker build -t bdh-project -f Dockerfile . + ``` + +2. **Build the GPU image** (CUDA): + + ```bash + docker build -t bdh-project-gpu -f Dockerfile.gpu . + ``` + +3. **Run Training** (saves checkpoints to `outputs/`): + + ```bash + docker run -v $(pwd)/outputs:/app/outputs bdh-project python train.py --max_iters 1000 + ``` + +4. **Compare Checkpoints**: + + ```bash + docker run -v $(pwd)/outputs:/app/outputs bdh-project python compare_checkpoints.py --input_dir outputs/training/checkpoints --output_dir outputs/evaluation + ``` + +5. **Run Memory Script** (requires training checkpoints): + ```bash + docker run -v $(pwd)/outputs:/app/outputs bdh-project python memory.py + ``` +6. 
**Run All in Order**:
   ```bash
   docker run -v $(pwd)/outputs:/app/outputs bdh-project python main.py --max_iters 1000
   ```

**GPU usage**:

```bash
docker run --gpus all -v $(pwd)/outputs:/app/outputs bdh-project-gpu python main.py --max_iters 1000
```

### Running Locally

Ensure you have Python 3.10 or newer installed; step 1 below installs PyTorch and the other dependencies.

1. **Install requirements**:

   ```bash
   pip install numpy torch requests psutil pandas matplotlib
   ```

2. **Train**:

   ```bash
   python train.py --max_iters 1000 --out_dir outputs
   ```

   - Control training length with `--max_iters` (e.g., 100 for a quick smoke test, 3000 for a full run).
   - CPU training is slow; more iterations mean longer runtime, so use a smaller `--max_iters` for quick tests.

3. **Visuals & Logs**:

   - Training logs (loss vs. step) are saved to `outputs/training/logs/training_log.csv`.
   - Run `python compare_checkpoints.py --input_dir outputs/training/checkpoints` to see how text generation improves across checkpoints.
   - Consolidated generations are written to `outputs/evaluation/checkpoint_generations.log`.

4. **Run Memory Script** (requires training checkpoints):
   ```bash
   python memory.py
   ```
5. **Run All in Order**:
   ```bash
   python main.py --max_iters 1000
   ```

## Outputs and Where to Find Them

- Training outputs: `outputs/training/`
  - Checkpoints: `outputs/training/checkpoints/`
  - Training log CSV: `outputs/training/logs/training_log.csv`
- Evaluation outputs: `outputs/evaluation/`
  - Report: `outputs/evaluation/report.md`
  - Consolidated generations log: `outputs/evaluation/checkpoint_generations.log`
  - Loss plot: `outputs/evaluation/figs/loss_curve.png`
  - Checkpoint samples: `outputs/evaluation/figs/checkpoint_samples.png`
  - Output diversity plot: `outputs/evaluation/figs/output_quality.png`
  - Multi-prompt runs: `outputs/evaluation/multi_prompt/`
- Memory outputs: `outputs/memory/`
  - Memory log: `outputs/memory/memory_log.txt`
  - Fine-tune loss plot: `outputs/memory/figs/memory_fine_tuning_loss.png`

## Results Interpretation

- **Training**: Loss decreases over time. Around steps 100-200 the model starts forming recognizable words; by step 1000 it should produce Shakespeare-like structure.
- **Checkpoint Comparison**: Early checkpoints produce noise; later checkpoints produce English-like text.
- **Fast Memory**: The memory layer recalls facts without putting them in the user prompt (the log also shows the model's attempt with a memory-injected prompt, for transparency).
- **Model Memory**: After fine-tuning, the model recalls the fact even without any context provided at runtime.

## What the Loss Curve Means

- The loss curve (`outputs/evaluation/figs/loss_curve.png`) shows how prediction error changes over training steps.
- Lower loss means the model is better at predicting the next byte.
- A steady downward trend is expected; it indicates the model is learning patterns from the dataset.
- Minor bumps are normal (training is noisy), but the overall trend should go down.
- If validation loss is shown, it should track training loss but stay slightly higher.

## Additional Visuals Included

- **Checkpoint samples**: A small panel showing early/mid/final outputs for quick comparison.
- **Output diversity**: Unique-character ratio per checkpoint, showing output variation over time.
- **Before vs After table**: Inline table in the report with early/mid/late outputs.
- **Multi-prompt checks**: The same checkpoints evaluated on 3 prompts for consistency.
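
## Programmatic Generation (Sketch)

Checkpoints can also be loaded directly from Python. The snippet below is a minimal sketch, assuming the default `outputs/` layout produced by `train.py`; the checkpoint filename is illustrative, so point it at any file under `outputs/training/checkpoints/`.

```python
# Minimal sketch: load a checkpoint saved by train.py and sample from it.
import torch

import bdh

ckpt = torch.load(
    "outputs/training/checkpoints/bdh_checkpoint_step_1000.pt",  # illustrative path
    map_location="cpu",
    weights_only=False,  # the checkpoint dict stores a BDHConfig object
)
model = bdh.BDH(ckpt["config"])
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# BDH is byte-level: encode the prompt as UTF-8 bytes.
prompt = torch.tensor(bytearray("To be or ", "utf-8"), dtype=torch.long).unsqueeze(0)
with torch.no_grad():
    out = model.generate(prompt, max_new_tokens=100, temperature=0.8, top_k=3)
print(bytes(out.to(torch.uint8).squeeze(0)).decode(errors="replace"))
```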
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..283c90f
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,19 @@
+# CPU-only, minimal base image for smaller size
+FROM python:3.11-slim
+
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+ENV PIP_INDEX_URL=https://download.pytorch.org/whl/cpu
+ENV PIP_EXTRA_INDEX_URL=https://pypi.org/simple
+
+WORKDIR /app
+
+# Copy and install deps first for better layer caching
+COPY requirements.txt /app/requirements.txt
+RUN pip install --no-cache-dir -r /app/requirements.txt
+
+# Copy project files
+COPY . /app
+
+# Default command: show help
+CMD ["python", "main.py", "--help"]
diff --git a/Dockerfile.gpu b/Dockerfile.gpu
new file mode 100644
index 0000000..cc338a1
--- /dev/null
+++ b/Dockerfile.gpu
@@ -0,0 +1,14 @@
+# GPU-enabled image (CUDA runtime)
+FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime
+
+WORKDIR /app
+
+# Install dependencies
+COPY requirements.txt /app/requirements.txt
+RUN pip install --no-cache-dir -r /app/requirements.txt
+
+# Copy project files
+COPY . /app
+
+# Default command: show help
+CMD ["python", "main.py", "--help"]
diff --git a/README.md b/README.md
index a48d4f0..d548cb3 100644
--- a/README.md
+++ b/README.md
@@ -50,29 +50,6 @@ BDH follows **Transformer-like scaling laws**, maintaining parameter efficiency
 
 ***
 
-## Latest research update: Sudoku Benchmark
-
-Note: The Sudoku Extreme result refers to Pathway’s internal BDH implementation, not to the current open-source repository. This repository contains the implementation of the baseline variant as described in our [public paper](https://arxiv.org/abs/2509.26507) and does not reproduce the 97.4% benchmark result out of the box. See the dedicated Extreme Sudoku research blog post for additional benchmark context and the reported results.
-
-On Sudoku Extreme, BDH reaches 97.4% accuracy across roughly 250,000 difficult puzzles, without chain-of-thought, solution backtracking, or external tool use, while leading LLMs struggle to perform on the benchmark at all.
-
-Language is not enough for intelligence. Transformers process information token by token with limited internal state, which makes search-heavy, non-linguistic reasoning tasks like Sudoku awkward. BDH uses a larger latent reasoning space with intrinsic memory that supports learning and adaptation during use.
-
-We believe that the future of AI will belong to systems that can reason natively across domains, that can hold multiple possibilities in a rich latent space, and that can converge on solutions without needing to verbalize every step. BDH is our answer to that challenge. It is designed to be a universal reasoning system that can speak our language without being trapped inside it. And yes, it solves Sudoku.
-
-Read more: [Post-transformers: Sudoku Bench](https://pathway.com/research/beyond-transformers-sudoku-bench)
-
-### Performance Comparison
-
-| Model | Sudoku Extreme Accuracy | Relative Cost |
-|------|------------------------|--------------|
-| Pathway BDH | 97.4% | 10× lower, No chain-of-thought |
-| Leading LLMs (O3-mini, DeepSeek R1, Claude 3.7 8K) | ~0% | High (chain-of-thought) |
-
-*Table 1: Performance comparison on extreme Sudoku benchmarks (~250,000 difficult puzzles).*
-*Source: Pathway internal data and https://arxiv.org/pdf/2506.21734 for the Leading LLMs’ accuracy score. Pathway’s approach reflects top-1 accuracy and does not rely on chain-of-thought nor solution backtracking.*
-
-
 ## Installation and Training
 
 ```bash
diff --git a/bdh.py b/bdh.py
index 4cfff79..9f614bd 100644
--- a/bdh.py
+++ b/bdh.py
@@ -1,5 +1,18 @@
 # Copyright 2025 Pathway Technology, Inc.
 
+"""
+BDH (Baby Dragon Hatchling) Model Implementation
+
+This module implements a neural network architecture for modeling byte sequences.
+The model uses a transformer-like architecture with attention mechanisms to learn patterns in byte sequences.
+This is useful for tasks like text generation, where we treat text as sequences of bytes.
+
+Key concepts for ML beginners:
+- Neural networks learn to predict the next item in a sequence by finding patterns in data
+- Attention mechanisms allow the model to focus on relevant parts of the input when making predictions
+- Embeddings convert discrete tokens (like bytes) into continuous vectors that the network can process
+"""
+
 import dataclasses
 import math
 
@@ -10,6 +23,38 @@
 
 @dataclasses.dataclass
 class BDHConfig:
+    """
+    Configuration class for the BDH model architecture.
+
+    This dataclass stores all the hyperparameters (settings) that control the model's structure.
+    Hyperparameters are values set before training that determine how the model learns.
+    Changing these values changes the model's capacity (ability to learn) and training behavior.
+
+    Attributes:
+        n_layer: Number of transformer layers in the model. Each layer processes the input
+            sequentially, allowing the model to learn increasingly complex patterns.
+            More layers = more capacity but slower training and more memory usage.
+
+        n_embd: Embedding dimension - the size of vectors used to represent each byte.
+            Higher values allow richer representations but require more computation.
+            Think of this as how many "features" we use to describe each byte.
+
+        dropout: Dropout probability - randomly sets some neurons to zero during training.
+            This prevents overfitting (memorizing training data instead of learning patterns).
+            Value between 0 (no dropout) and 1 (all neurons dropped); 0.1 means 10% are dropped.
+
+        n_head: Number of attention heads. Attention heads allow the model to focus on different
+            aspects of the input simultaneously (like looking at syntax and semantics separately).
+            More heads = more parallel processing but more parameters.
+
+        mlp_internal_dim_multiplier: Controls the size of internal layers in the feed-forward network.
+            This is multiplied by n_embd to determine the hidden layer size.
+            Larger values = more capacity in the MLP (Multi-Layer Perceptron).
+
+        vocab_size: Size of the vocabulary - the number of unique tokens the model can process.
+            For byte-level models, this is 256 (one for each possible byte value 0-255).
+            The model learns to predict which of these 256 values comes next.
+    """
     n_layer: int = 6
     n_embd: int = 256
     dropout: float = 0.1
@@ -19,9 +64,37 @@ class BDHConfig:
 
 
 def get_freqs(n, theta, dtype):
+    """
+    Generate frequency values for Rotary Position Embedding (RoPE).
+
+    RoPE is a technique that encodes position information directly into attention calculations.
+    Instead of adding position embeddings, we rotate the query and key vectors based on their positions.
+    This helps the model understand the order of elements in sequences.
+
+    The frequencies determine how fast the rotation changes as we move through positions.
+    Higher frequencies = faster rotation = the model can distinguish positions that are closer together.
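+
+    Concretely, with the default pairwise quantization the i-th frequency works out to
+    1 / theta ** ((2 * floor(i / 2)) / n), so frequencies come in matched pairs and
+    decay geometrically from 1 (at i = 0) toward 1 / theta (as i approaches n).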
+ + Args: + n: Number of frequency components to generate. More components = finer-grained position encoding. + theta: Base frequency parameter. Higher values = slower rotation = model focuses on longer-range patterns. + dtype: Data type for the tensor (e.g., float32). Important for numerical precision. + + Returns: + A tensor of frequencies that will be used to rotate attention vectors based on position. + """ def quantize(t, q=2): + """ + Quantize (round down to nearest multiple of q) the input tensor. + + Quantization reduces the number of unique frequency values, which can help with + generalization and reduce overfitting to specific positions. + q=2 means we round down to the nearest even number. + """ return (t / q).floor() * q + # Generate frequencies using a geometric progression (exponential decay) + # The formula creates frequencies that decrease exponentially, which is useful because + # position differences matter more for nearby tokens than distant ones return ( 1.0 / (theta ** (quantize(torch.arange(0, n, 1, dtype=dtype)) / n)) @@ -30,122 +103,362 @@ def quantize(t, q=2): class Attention(torch.nn.Module): + """ + Attention mechanism with Rotary Position Embedding (RoPE). + + Attention is the core mechanism that allows the model to focus on relevant parts of the input + when making predictions. For example, when predicting the next word after "The cat sat on the", + the model should pay attention to "cat" and "sat" to predict "mat" or "floor". + + This implementation uses RoPE, which encodes position information through rotations rather than + addition. This is more mathematically elegant and often performs better than standard position embeddings. + + The attention mechanism works by: + 1. Computing how much each position should attend to every other position (attention scores) + 2. Using these scores to weight and combine the value vectors + 3. This creates a context-aware representation that considers the entire sequence + """ def __init__(self, config): + """ + Initialize the attention module. + + Sets up the frequency buffers needed for RoPE. These frequencies are pre-computed + and stored as buffers (non-trainable parameters) since they don't change during training. + """ super().__init__() self.config = config - nh = config.n_head - D = config.n_embd + nh = config.n_head # Number of attention heads + D = config.n_embd # Embedding dimension + # N is the dimension of the expanded space for attention computation + # This is larger than D to give the model more capacity in the attention mechanism N = config.mlp_internal_dim_multiplier * D // nh + # Store frequencies as a buffer (non-trainable, but part of the model state) + # The view(1, 1, 1, N) reshapes for broadcasting across batches, heads, and sequence positions self.freqs = torch.nn.Buffer( get_freqs(N, theta=2**16, dtype=torch.float32).view(1, 1, 1, N) ) @staticmethod def phases_cos_sin(phases): + """ + Convert phase values to cosine and sine components for rotation. + + In RoPE, we rotate vectors using rotation matrices, which are built from cos and sin values. + The phase represents how much to rotate (based on position and frequency). + + The modulo operation (phases % 1) ensures phases are in [0, 1), then we scale to [0, 2π) + because trigonometric functions are periodic with period 2π. 
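
        Worked example: a phase of 0.25 becomes (0.25 % 1) * 2 * pi = pi / 2,
        so the rotation coefficients are (cos, sin) = (0, 1).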
+ + Args: + phases: Tensor of phase values (position * frequency for each position-frequency pair) + + Returns: + Tuple of (cosine, sine) tensors that will be used to rotate the attention vectors + """ + # Normalize phases to [0, 1) range, then scale to [0, 2π) for trigonometric functions phases = (phases % 1) * (2 * math.pi) + # Compute cosine and sine components for rotation + # These will be used to rotate vectors: rotated = original * cos + rotated_original * sin phases_cos = torch.cos(phases) phases_sin = torch.sin(phases) return phases_cos, phases_sin @staticmethod def rope(phases, v): + """ + Apply Rotary Position Embedding (RoPE) to a vector. + + RoPE rotates pairs of dimensions in the vector based on position. This encodes position + information directly into the vector representation, allowing the model to understand + where each token is in the sequence. + + The rotation is done in 2D planes: for dimensions [0, 1, 2, 3, ...], we rotate + (0,1), (2,3), (4,5), etc. as pairs. This is more efficient than rotating all dimensions. + + Args: + phases: Phase values (position * frequency) that determine rotation angles + v: The vector to rotate (typically query or key vectors in attention) + + Returns: + The rotated vector, which now contains position information encoded through rotation + """ + # Create rotated version by swapping pairs and negating every other element + # [..., 1::2] gets odd indices (1, 3, 5, ...), [..., ::2] gets even indices (0, 2, 4, ...) + # The negative sign on odd indices creates the rotation effect + # This is equivalent to rotating in 2D planes: (dim0, dim1), (dim2, dim3), etc. v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size()) + # Get rotation coefficients (cos and sin) based on position-dependent phases phases_cos, phases_sin = Attention.phases_cos_sin(phases) + # Apply rotation: combine original vector (scaled by cos) with rotated vector (scaled by sin) + # This is the standard 2D rotation formula: x' = x*cos - y*sin, y' = x*sin + y*cos + # The result maintains the vector's magnitude but changes its direction based on position return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype) def forward(self, Q, K, V): + """ + Compute attention with RoPE. + + Attention computes how much each position should focus on every other position. + The output is a weighted combination of value vectors, where weights come from + how similar query and key vectors are (after rotation with position information). + + Args: + Q: Query vectors - "what am I looking for?" (shape: batch, heads, sequence_length, dim) + K: Key vectors - "what do I contain?" (here, K is the same as Q for self-attention) + V: Value vectors - "what information do I provide?" 
(shape: batch, heads, sequence_length, dim) + + Returns: + Attention output - context-aware representations weighted by attention scores + """ + # Ensure frequencies are in float32 for numerical stability assert self.freqs.dtype == torch.float32 + # This is self-attention: queries and keys come from the same source assert K is Q + # Extract sequence length T from the input shape + # Shape is (batch, heads, sequence_length, dimension) _, _, T, _ = Q.size() + # Compute rotation phases for each position in the sequence + # We create a tensor of positions [0, 1, 2, ..., T-1] and multiply by frequencies + # This gives us position-dependent rotation angles for each frequency component r_phases = ( torch.arange( 0, T, device=self.freqs.device, dtype=self.freqs.dtype, - ).view(1, 1, -1, 1) - ) * self.freqs + ).view(1, 1, -1, 1) # Reshape for broadcasting: (1, 1, T, 1) + ) * self.freqs # Broadcast multiply: each position gets rotated by all frequencies + # Apply RoPE to queries and keys to encode position information QR = self.rope(r_phases, Q) - KR = QR + KR = QR # Since K == Q, rotated keys are the same as rotated queries - # Current attention + # Compute attention scores: how similar is each query to each key? + # QR @ KR.mT computes dot products between all query-key pairs + # tril(diagonal=-1) creates a lower triangular matrix, ensuring each position only + # attends to previous positions (causal masking - can't see the future when predicting) scores = (QR @ KR.mT).tril(diagonal=-1) + # Weight and combine value vectors using attention scores + # Higher scores = more attention = that value contributes more to the output return scores @ V class BDH(nn.Module): + """ + BDH (Baby Dragon Hatchling) Model - A transformer-like architecture for byte sequences. + + This model processes sequences of bytes (0-255) and learns to predict the next byte. + It uses a sparse attention mechanism where the model learns to encode and decode + information through learned projection matrices. + + Architecture overview: + 1. Embedding: Convert byte tokens to dense vectors + 2. Multiple transformer layers: Each layer refines the representation + 3. Language model head: Converts final representation to probability distribution over bytes + + The "sparse" aspect comes from using ReLU activation, which creates sparse (mostly zero) + representations that can be more efficient and interpretable than dense representations. + """ def __init__(self, config: BDHConfig): + """ + Initialize the BDH model with the given configuration. + + Creates all the learnable parameters (weights) that will be optimized during training. + These parameters start with small random values and are updated via gradient descent + to minimize prediction error on the training data. 
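
        Args:
            config: BDHConfig carrying the hyperparameters (n_layer, n_embd,
                n_head, mlp_internal_dim_multiplier, dropout, vocab_size)
                described above.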
+ """ super().__init__() + # Ensure vocabulary size is specified (needed for embedding and output layers) assert config.vocab_size is not None self.config = config - nh = config.n_head - D = config.n_embd + nh = config.n_head # Number of attention heads + D = config.n_embd # Embedding dimension + # N is the expanded dimension for sparse representations + # Larger than D to allow the model to represent information in a higher-dimensional space N = config.mlp_internal_dim_multiplier * D // nh + + # Decoder: Projects from sparse high-dimensional space back to embedding dimension + # This learns how to combine information from the sparse representation + # Shape: (nh * N, D) - takes expanded representation and compresses to D dimensions self.decoder = nn.Parameter(torch.zeros((nh * N, D)).normal_(std=0.02)) + + # Encoder: Projects from embedding dimension to sparse high-dimensional space + # This learns how to expand the representation to capture more information + # Shape: (nh, D, N) - expands D-dimensional vectors to N-dimensional sparse vectors per head self.encoder = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02)) + # Attention mechanism: Allows the model to focus on relevant parts of the sequence self.attn = Attention(config) + # Layer normalization: Stabilizes training by normalizing activations + # elementwise_affine=False means no learnable scale/shift (just normalization) + # This helps gradients flow better and speeds up training self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False) + + # Embedding layer: Converts byte tokens (0-255) to dense D-dimensional vectors + # Each of the 256 possible bytes gets a learnable vector representation self.embed = nn.Embedding(config.vocab_size, D) + + # Dropout: Randomly zeros some activations during training to prevent overfitting + # This forces the model to learn robust features that don't depend on specific neurons self.drop = nn.Dropout(config.dropout) + + # Encoder for value vectors: Similar to self.encoder but used for value projection + # This allows the model to learn different encodings for attention values vs queries/keys self.encoder_v = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02)) + # Language model head: Final projection from embeddings to vocabulary logits + # Converts the D-dimensional representation to a 256-dimensional vector + # Each dimension represents the model's confidence that the next byte is that value + # Shape: (D, vocab_size) - projects from embedding space to vocabulary space self.lm_head = nn.Parameter( torch.zeros((D, config.vocab_size)).normal_(std=0.02) ) + # Initialize all weights with small random values + # This ensures the model starts with diverse, non-zero gradients self.apply(self._init_weights) def _init_weights(self, module): + """ + Initialize weights for different layer types. + + Proper weight initialization is crucial for training. If weights are too large, + gradients explode; if too small, gradients vanish. The standard deviation of 0.02 + is a common choice that works well for transformer models. + + Args: + module: The neural network module to initialize (Linear layer, Embedding, etc.) 
+ """ if isinstance(module, nn.Linear): + # Linear layers: Initialize weights with small random values from normal distribution + # mean=0.0, std=0.02 ensures weights start small but non-zero + # Small initial weights help prevent exploding gradients early in training nn.init.normal_(module.weight, mean=0.0, std=0.02) + # Biases start at zero - this is standard practice for most layers + # Zero bias means the layer starts neutral, learning bias values during training if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): + # Embedding layers: Also initialized with small random values + # This ensures each token starts with a unique but small vector representation + # The model will learn to make these representations meaningful during training nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx, targets=None): + """ + Forward pass through the model. + + This is where the model processes input sequences and makes predictions. + The input is a sequence of byte indices, and the output is a probability distribution + over possible next bytes for each position. + + Args: + idx: Input tensor of byte indices (shape: batch_size, sequence_length) + Each value is an integer 0-255 representing a byte + targets: Optional target byte indices for computing loss during training + If provided, the model computes how wrong its predictions are + + Returns: + logits: Raw prediction scores for each possible next byte (before softmax) + Shape: (batch_size, sequence_length, vocab_size) + Higher values = model thinks that byte is more likely + loss: Cross-entropy loss if targets provided, None otherwise + Lower loss = better predictions = model is learning correctly + """ C = self.config - B, T = idx.size() - D = C.n_embd - nh = C.n_head + # Extract dimensions from input + B, T = idx.size() # B = batch size, T = sequence length + D = C.n_embd # Embedding dimension + nh = C.n_head # Number of attention heads + # N is the expanded dimension for sparse representations N = D * C.mlp_internal_dim_multiplier // nh + # Step 1: Convert byte indices to dense vector representations (embeddings) + # Each byte (0-255) is mapped to a D-dimensional vector that the model learns + # unsqueeze(1) adds a dimension for heads: (B, T, D) -> (B, 1, T, D) x = self.embed(idx).unsqueeze(1) - # actually helps with training - x = self.ln(x) # B, 1, T, D + # Step 2: Normalize embeddings (helps with training stability) + # Layer normalization centers and scales the activations, preventing extreme values + # This is done early to ensure the rest of the network receives well-scaled inputs + x = self.ln(x) # Shape: (B, 1, T, D) + # Step 3: Process through multiple transformer layers + # Each layer refines the representation, allowing the model to learn increasingly + # complex patterns. More layers = deeper understanding but more computation. 
for level in range(C.n_layer): - x_latent = x @ self.encoder - - x_sparse = F.relu(x_latent) # B, nh, T, N - + # 3a: Encode input to sparse high-dimensional representation + # This expands the D-dimensional vectors to N-dimensional vectors per head + # The model learns which dimensions to activate for different patterns + x_latent = x @ self.encoder # (B, 1, T, D) @ (nh, D, N) -> (B, nh, T, N) + + # 3b: Apply ReLU to create sparse representation + # ReLU sets negative values to zero, creating a sparse (mostly zeros) representation + # Sparsity can help the model focus on important features and reduce overfitting + # Only positive activations pass through, creating a "rectified" representation + x_sparse = F.relu(x_latent) # Shape: (B, nh, T, N) + + # 3c: Apply attention mechanism + # Attention allows each position to focus on relevant parts of the sequence + # Q and K are both x_sparse (self-attention), V is the original x + # This creates context-aware representations that consider the whole sequence yKV = self.attn( - Q=x_sparse, - K=x_sparse, - V=x, + Q=x_sparse, # Queries: "what am I looking for?" + K=x_sparse, # Keys: "what do I contain?" (same as queries for self-attention) + V=x, # Values: "what information do I provide?" ) + # Normalize attention output for stability yKV = self.ln(yKV) - y_latent = yKV @ self.encoder_v - y_sparse = F.relu(y_latent) - xy_sparse = x_sparse * y_sparse # B, nh, T, N - + # 3d: Encode attention output to sparse representation + # Similar to step 3a, but for the attention output + # This allows the model to learn different sparse patterns for attended information + y_latent = yKV @ self.encoder_v # (B, nh, T, D) @ (nh, D, N) -> (B, nh, T, N) + y_sparse = F.relu(y_latent) # Create sparse representation of attention output + + # 3e: Combine sparse representations through element-wise multiplication + # This is a gating mechanism: x_sparse acts as a gate for y_sparse + # Only features that are active in both representations remain active + # This creates a more selective, focused representation + xy_sparse = x_sparse * y_sparse # Shape: (B, nh, T, N) + + # 3f: Apply dropout for regularization + # Randomly zeros some activations during training to prevent overfitting + # This forces the model to learn redundant, robust features xy_sparse = self.drop(xy_sparse) + # 3g: Decode sparse representation back to embedding dimension + # Transpose and reshape to combine all heads: (B, nh, T, N) -> (B, 1, T, nh*N) + # Then project back to D dimensions using the decoder matrix + # This learns how to combine information from the sparse representation yMLP = ( xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh) @ self.decoder - ) # B, 1, T, D + ) # Result shape: (B, 1, T, D) + + # 3h: Normalize the MLP output y = self.ln(yMLP) - x = self.ln(x + y) - + + # 3i: Residual connection: add the layer's output to its input + # This creates a "highway" for gradients to flow through, enabling deeper networks + # The model can learn to make small adjustments (y) to the input (x) + # Without residual connections, deep networks are hard to train + x = self.ln(x + y) # Normalize the sum for stability + + # Step 4: Project final representation to vocabulary logits + # Reshape to remove head dimension: (B, 1, T, D) -> (B, T, D) + # Then project to vocabulary size: (B, T, D) @ (D, vocab_size) -> (B, T, vocab_size) + # Each of the 256 values represents the model's confidence for that byte logits = x.view(B, T, D) @ self.lm_head + + # Step 5: Compute loss if targets are provided (training 
mode) loss = None if targets is not None: + # Cross-entropy loss measures how wrong the predictions are + # It compares the predicted probability distribution to the true next byte + # Lower loss = predictions are closer to truth = model is learning + # Reshape to (batch*sequence, vocab_size) and (batch*sequence,) for the loss function loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) return logits, loss @@ -158,14 +471,77 @@ def generate( temperature: float = 1.0, top_k: int | None = None, ) -> torch.Tensor: + """ + Generate new tokens (bytes) by sampling from the model's predictions. + + This is the text generation process: the model predicts the next byte, we sample + from that prediction, add it to the sequence, and repeat. This creates new sequences + that follow patterns the model learned during training. + + The generation is autoregressive: each new token is based on all previous tokens. + This is why language models can create coherent text - they consider the full context. + + Args: + idx: Initial sequence of byte indices to start generation from (the "prompt") + Shape: (batch_size, sequence_length) + max_new_tokens: Maximum number of new bytes to generate + The model will generate exactly this many tokens + temperature: Controls randomness in sampling + - temperature = 1.0: Use model's confidence as-is (default) + - temperature < 1.0: Make predictions more confident (less random) + - temperature > 1.0: Make predictions less confident (more random) + Lower temperature = more conservative, higher = more creative + top_k: If specified, only consider the top-k most likely tokens + This prevents the model from choosing very unlikely tokens + None means consider all 256 possible bytes + + Returns: + Extended sequence with original prompt + generated tokens + Shape: (batch_size, original_length + max_new_tokens) + """ + # @torch.no_grad() disables gradient computation for efficiency + # We don't need gradients during generation (only during training) + # This saves memory and speeds up inference + + # Generate tokens one at a time, building on the previous sequence for _ in range(max_new_tokens): + # Use the current sequence as input (includes original prompt + generated tokens so far) idx_cond = idx - logits, _ = self(idx_cond) - logits = logits[:, -1, :] / temperature + + # Get model's predictions for the next token + # The model outputs logits (raw scores) for all 256 possible bytes + logits, _ = self(idx_cond) # Shape: (batch, sequence_length, vocab_size) + + # Take only the last position's predictions (we only care about the next token) + # Divide by temperature to adjust randomness + # Higher temperature = divide by larger number = smaller differences = more random + # Lower temperature = divide by smaller number = larger differences = more confident + logits = logits[:, -1, :] / temperature # Shape: (batch, vocab_size) + + # Optional: Apply top-k filtering to restrict to most likely tokens if top_k is not None: + # Find the k-th highest logit value values, _ = torch.topk(logits, min(top_k, logits.size(-1))) + # Set all logits below the k-th highest to negative infinity + # This makes their probability zero after softmax + # This prevents the model from choosing very unlikely tokens logits[logits < values[:, [-1]]] = float("-inf") - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1) - idx = torch.cat((idx, idx_next), dim=1) + + # Convert logits to probabilities using softmax + # Softmax ensures probabilities sum to 1.0 
and are all positive + # Higher logits become higher probabilities, but the relationship is non-linear + # This creates a probability distribution over the 256 possible bytes + probs = F.softmax(logits, dim=-1) # Shape: (batch, vocab_size) + + # Sample one token from the probability distribution + # multinomial randomly picks a token, but more likely tokens are picked more often + # This adds randomness while still favoring the model's confident predictions + # Without sampling, we'd always pick the most likely token (greedy decoding) + idx_next = torch.multinomial(probs, num_samples=1) # Shape: (batch, 1) + + # Append the new token to the sequence + # This becomes the input for the next iteration + idx = torch.cat((idx, idx_next), dim=1) # Shape: (batch, sequence_length + 1) + + # Return the complete sequence: original prompt + all generated tokens return idx diff --git a/bdh_app/.DS_Store b/bdh_app/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..51da145c762e8bee07e86d9c2f31f0075a8c565b GIT binary patch literal 6148 zcmeHK%Sr<=6g{b3R0g3-m%$Gx_y@yK7q0yPwJK6*J5pMqyZHuxSkK9g(wSy0BKMZ$ zBscfuafaprAT;~uYhVgs!XykT6in`7(Pm5K%V1(ppm}Z~JSOPBaf_FS(g+22fxh;Fp$_LDFHDI1&P41fUTO5jvj{&1% z|67h?bKKyy#2V>)9M6zdH&?(Fa0OfeSKveep4qZVW$3vp;0m|`KMH7nh)lvdV`rGQ z4raP?s$JLUWGw5lEMLr6XY35wLJLME8a3o6Mld@4F~`*zJ42%*q>SokK3h+gr?WOPJut}> p*Qr2bxpxU*Mf=EQYP9yKGUMuuouPV>KhcSP5lDh~<_i3R0$ limit else "") + + +def plot_checkpoint_samples(results, output_path): + if not results: + return False + steps = [results[0], results[len(results) // 2], results[-1]] + fig, axes = plt.subplots(3, 1, figsize=(12, 5.5)) + for ax, r in zip(axes, steps): + preview = sanitize_preview(r["text"], limit=200) + ax.axis("off") + ax.text( + 0, + 0.5, + f"Step {r['step']} | Loss {r.get('loss', 'N/A')}\n{preview}", + fontfamily="monospace", + fontsize=9, + va="center", + ) + fig.tight_layout() + fig.savefig(output_path, dpi=150) + plt.close() + return True + + +def plot_output_quality(results, output_path): + if not results: + return False + steps = [] + ratios = [] + for r in results: + text = r["text"] + if not text: + continue + steps.append(r["step"]) + ratios.append(len(set(text)) / max(1, len(text))) + plt.figure(figsize=(10, 4)) + plt.plot(steps, ratios, marker="o", linewidth=1.5) + plt.xlabel("Step") + plt.ylabel("Unique character ratio") + plt.title("Output Diversity over Checkpoints") + plt.grid(True, alpha=0.3) + plt.savefig(output_path, dpi=150) + plt.close() + return True + + +class Evaluator: + def __init__(self, input_dir, output_dir, prompt, device="auto", log_file=None): + self.input_dir = input_dir + self.output_dir = output_dir + self.prompt = prompt + self.device = device + self.log_file = log_file + + def _resolve_device(self): + device = self.device + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cpu" and torch.backends.mps.is_available(): + device = "mps" + return device + + def run(self): + device = self._resolve_device() + print(f"Using device: {device}") + + os.makedirs(self.output_dir, exist_ok=True) + figs_dir = os.path.join(self.output_dir, "figs") + os.makedirs(figs_dir, exist_ok=True) + + checkpoints = find_checkpoints(self.input_dir) + print(f"Found {len(checkpoints)} checkpoints in {self.input_dir}") + + results = [] + for step, path in checkpoints: + print(f"Eval step {step}...", end="\r") + model, _, meta = load_checkpoint(path) + text = generate_text(model, self.prompt, device=device) + results.append({ + 'step': step, + 'loss': 
meta['loss'], + 'text': text + }) + + csv_file = None + if os.path.basename(os.path.normpath(self.input_dir)) == "checkpoints": + parent = os.path.dirname(os.path.normpath(self.input_dir)) + log_dir = os.path.join(parent, "logs") + csv_candidate = os.path.join(log_dir, "training_log.csv") + if os.path.exists(csv_candidate): + csv_file = csv_candidate + if not csv_file or not os.path.exists(csv_file): + candidate = os.path.join(self.input_dir, "training_log.csv") + if os.path.exists(candidate): + csv_file = candidate + + loss_plot_path = os.path.join(figs_dir, "loss_curve.png") + has_plot = False + if csv_file: + checkpoint_steps = [r["step"] for r in results] + has_plot = plot_loss(csv_file, loss_plot_path, checkpoint_steps=checkpoint_steps) + + samples_path = os.path.join(figs_dir, "checkpoint_samples.png") + plot_checkpoint_samples(results, samples_path) + quality_path = os.path.join(figs_dir, "output_quality.png") + plot_output_quality(results, quality_path) + + report_path = os.path.join(self.output_dir, "report.md") + with open(report_path, "w", encoding="utf-8") as f: + f.write("# BDH Training Report\n\n") + if has_plot: + f.write("## Training Loss\n\n") + f.write(f"![Loss Curve](figs/loss_curve.png)\n\n") + f.write("## Quick Before vs After\n\n") + if results: + early = results[0] + mid = results[len(results) // 2] + late = results[-1] + f.write("| Stage | Step | Sample Output |\n") + f.write("|-------|------|---------------|\n") + f.write(f"| Early | {early['step']} | {sanitize_preview(early['text'], 160)} |\n") + f.write(f"| Mid | {mid['step']} | {sanitize_preview(mid['text'], 160)} |\n") + f.write(f"| Late | {late['step']} | {sanitize_preview(late['text'], 160)} |\n") + f.write("\n") + f.write("## Checkpoint Samples\n\n") + f.write(f"![Checkpoint Samples](figs/checkpoint_samples.png)\n\n") + f.write("## Output Diversity\n\n") + f.write(f"![Output Diversity](figs/output_quality.png)\n\n") + f.write("## Checkpoint Evaluations\n\n") + f.write(f"**Prompt:** `{self.prompt}`\n\n") + f.write("| Step | Loss | Generated Text Preview |\n") + f.write("|------|------|------------------------|\n") + for r in results: + preview = r['text'][:100].replace('\n', ' ').replace('|', '\\|') + f.write(f"| {r['step']} | {r.get('loss', 'N/A')} | {preview}... 
|\n") + f.write("\n\n## Full Generation Outputs\n\n") + for r in results: + f.write(f"### Step {r['step']}\n") + f.write(f"**Loss:** {r.get('loss', 'N/A')}\n\n") + f.write("```\n") + f.write(r['text']) + f.write("\n```\n\n") + + print(f"\nReport generated at {report_path}") + + log_path = self.log_file or os.path.join(self.output_dir, "checkpoint_generations.log") + with open(log_path, "w", encoding="utf-8") as f: + f.write(f"Prompt: {self.prompt}\n") + f.write("=" * 80 + "\n") + for r in results: + f.write(f"\nSTEP {r['step']} | Loss: {r.get('loss', 'N/A')}\n") + f.write("-" * 80 + "\n") + f.write(r["text"]) + f.write("\n") + print(f"Consolidated log saved to {log_path}") + + +def run_evaluation(input_dir, output_dir, prompt, device="auto", log_file=None): + Evaluator(input_dir, output_dir, prompt, device, log_file).run() + + +def run_multi_prompt_evaluation(input_dir, output_dir, prompts, device="auto"): + eval_dir = os.path.join(output_dir, "multi_prompt") + os.makedirs(eval_dir, exist_ok=True) + for idx, prompt in enumerate(prompts, start=1): + prompt_tag = f"prompt_{idx}" + eval_out = os.path.join(eval_dir, prompt_tag) + Evaluator( + input_dir=input_dir, + output_dir=eval_out, + prompt=prompt, + device=device, + log_file=None, + ).run() + + +def main(argv=None): + parser = argparse.ArgumentParser(description="Evaluate BDH checkpoints") + parser.add_argument("--input_dir", type=str, required=True, help="Directory containing checkpoints") + parser.add_argument("--output_dir", type=str, default="outputs/evaluation", help="Directory to save report and visuals") + parser.add_argument("--prompt", type=str, default="To be or not to be") + parser.add_argument("--device", type=str, default="auto") + parser.add_argument("--log_file", type=str, default=None, help="Path to write consolidated generations log") + args = parser.parse_args(argv) + run_evaluation(args.input_dir, args.output_dir, args.prompt, args.device, args.log_file) + + +if __name__ == "__main__": + main() diff --git a/bdh_app/memory.py b/bdh_app/memory.py new file mode 100644 index 0000000..233ffa2 --- /dev/null +++ b/bdh_app/memory.py @@ -0,0 +1,272 @@ +""" +Memory demonstration module for BDH. +""" + +import sys +import os +import argparse + +import torch +import matplotlib.pyplot as plt + +import bdh + +# Use Agg backend for headless environments (e.g., Docker) +plt.switch_backend("Agg") + +# Configuration +FACT = "The capital of JiriLand is DragonCity." 
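# FACT is the synthetic fact the demos teach; QUERY is the cloze-style prompt
# used to probe recall; EXPECTED is the completion the success checks look for.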
+QUERY = "The capital of JiriLand is" +EXPECTED = "DragonCity" +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +if DEVICE == "cpu" and torch.backends.mps.is_available(): + DEVICE = "mps" + + +class Tee(object): + def __init__(self, filename): + self.terminal = sys.stdout + self.log = open(filename, "w", encoding="utf-8") + + def write(self, message): + self.terminal.write(message) + self.log.write(message) + self.log.flush() + + def flush(self): + self.terminal.flush() + self.log.flush() + + +class FastMemory: + def __init__(self): + self.facts = [] + + def add(self, fact): + self.facts.append(fact) + + def retrieve(self, query): + query_tokens = set(query.lower().split()) + hits = [] + for fact in self.facts: + fact_tokens = set(fact.lower().split()) + if query_tokens.intersection(fact_tokens): + hits.append(fact) + return hits + + +def find_latest_checkpoint(base_dir): + if not os.path.exists(base_dir): + return None + candidates = [] + for name in os.listdir(base_dir): + if name.startswith("bdh_checkpoint_step_") and name.endswith(".pt"): + try: + step = int(name.replace("bdh_checkpoint_step_", "").replace(".pt", "")) + except ValueError: + step = -1 + candidates.append((step, os.path.join(base_dir, name))) + if not candidates: + final_model = os.path.join(base_dir, "bdh_model_final.pt") + return final_model if os.path.exists(final_model) else None + candidates.sort(key=lambda x: x[0]) + return candidates[-1][1] + + +def get_model(train_out_dir): + checkpoint_dir = os.path.join(train_out_dir, "checkpoints") + ckpt_path = find_latest_checkpoint(checkpoint_dir) + if ckpt_path and os.path.exists(ckpt_path): + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False) + config = checkpoint.get("config", bdh.BDHConfig()) + model = bdh.BDH(config) + state_dict = checkpoint.get("model_state_dict", checkpoint) + model.load_state_dict(state_dict) + model = model.to(DEVICE) + print(f"Loaded checkpoint: {ckpt_path}") + return model + raise FileNotFoundError( + "No trained checkpoint found. Run train.py first to generate checkpoints." + ) + + +def generate(model, prompt, max_new=20, top_k=1, temperature=0.8): + model.eval() + prompt_bytes = bytearray(prompt, "utf-8") + prompt_tensor = torch.tensor(prompt_bytes, dtype=torch.long, device=DEVICE).unsqueeze(0) + with torch.no_grad(): + out = model.generate(prompt_tensor, max_new_tokens=max_new, temperature=temperature, top_k=top_k) + + decoded = bytes(out.to(torch.uint8).to("cpu").squeeze(0)).decode(errors="replace") + return decoded[len(prompt):] + + +def sanitize_text(text): + cleaned = "".join(ch for ch in text if ch.isprintable() or ch in "\n\t ") + return cleaned.replace("\r", "").strip() + + +def extract_fact_from_prompt(prompt_text): + marker = "The capital of " + if marker not in prompt_text: + return None + try: + start = prompt_text.index(marker) + snippet = prompt_text[start:] + parts = snippet.split(" is ") + if len(parts) < 2: + return None + city_part = parts[1] + city = city_part.split(".")[0].strip() + if city: + return f"{marker}{parts[0].split(marker)[1]} is {city}." + except Exception: + return None + return None + + +def memory_answer(query, facts): + for fact in facts: + if fact.startswith("The capital of ") and " is " in fact: + return fact.split(" is ")[1].replace(".", "").strip() + return "" + + +def demo_fast_memory(model): + print("\n" + "=" * 50) + print("FAST MEMORY: Runtime Context Injection") + print("=" * 50) + + fast_memory = FastMemory() + + print("1. 
Baseline: trained model without memory") + print(f" Prompt: '{QUERY}'") + output = sanitize_text(generate(model, QUERY)) + print(f" Output: '{output}'") + + if "DragonCity" in output: + print(" (Unexpected: Model randomly guessed it!)") + else: + print(" (Expected: Model doesn't know the fact)") + + print("\n2. Move fact from a prompt into fast memory") + ingest_prompt = f"Fact: {FACT}\nQuestion: {QUERY}" + print(f" Ingest Prompt: '{ingest_prompt.replace(chr(10), ' ')}'") + extracted = extract_fact_from_prompt(ingest_prompt) + if extracted: + fast_memory.add(extracted) + print(f" Stored in memory: '{extracted}'") + else: + print(" WARNING: No fact extracted from prompt.") + + print("\n3. Answer using fast memory (no fact in prompt)") + retrieved = fast_memory.retrieve(QUERY) + memory_response = memory_answer(QUERY, retrieved) + print(f" Memory Recall: '{memory_response}'") + + memory_context = " ".join([f"Fact: {f}" for f in retrieved]) + context_prompt = f"{memory_context}\nQuestion: {QUERY}\nAnswer:" + model_with_memory = sanitize_text(generate(model, context_prompt)) + print(f" Model + Memory Prompt Output: '{model_with_memory}'") + + if EXPECTED in memory_response: + print(" SUCCESS: Memory layer recalled the fact.") + else: + print(" FAIL: Memory layer did not recall the fact.") + + +def demo_model_memory(model, out_dir): + print("\n" + "=" * 50) + print("MODEL MEMORY: Consolidate Facts into Weights") + print("=" * 50) + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + print("1. Verifying model doesn't know the fact initially...") + output = sanitize_text(generate(model, QUERY)) + print(f" Output: '{output}'") + + print("\n2. Fine-tuning model on the fact (Consolidating to Long-Term Memory)...") + data = bytearray(FACT * 10, "utf-8") + x_train = torch.tensor(data[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0) + y_train = torch.tensor(data[1:], dtype=torch.long, device=DEVICE).unsqueeze(0) + + model.train() + steps = 150 + print(f" Training for {steps} steps...") + + losses = [] + for i in range(steps): + _, loss = model(x_train, y_train) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + current_loss = loss.item() + losses.append(current_loss) + if i % 20 == 0 or i == steps - 1: + print(f" Step {i}: loss {current_loss:.4f}") + + plt.figure(figsize=(10, 6)) + plt.plot(losses, label='Fine-tuning Loss') + plt.xlabel('Step') + plt.ylabel('Loss') + plt.title('Quick Fine-tuning Loss (Model Memory)') + plt.grid(True, alpha=0.3) + plt.legend() + loss_fig_path = os.path.join(out_dir, "figs", "memory_fine_tuning_loss.png") + plt.savefig(loss_fig_path) + plt.close() + print(f" (Saved fine-tuning loss plot to {loss_fig_path})") + + print("\n3. 
Testing recall WITHOUT context...") + output = sanitize_text(generate(model, QUERY)) + print(f" Prompt: '{QUERY}'") + print(f" Output: '{output}'") + + if "DragonCity" in output: + print(" SUCCESS: Model internalized the fact into weights!") + else: + print(" FAIL: Model failed to memorize.") + + +class MemoryDemo: + def __init__(self, train_out_dir, out_dir): + self.train_out_dir = train_out_dir + self.out_dir = out_dir + + def run(self): + os.makedirs(self.out_dir, exist_ok=True) + figs_dir = os.path.join(self.out_dir, "figs") + os.makedirs(figs_dir, exist_ok=True) + log_path = os.path.join(self.out_dir, "memory_log.txt") + + sys.stdout = Tee(log_path) + print(f"Logging to {log_path}") + print(f"Using device: {DEVICE}") + print("\nCLIENT SUMMARY") + print("- Fast Memory = external memory store, separate from the prompt.") + print("- Model Memory = facts learned into weights after a short fine-tune.") + print("- Success criteria:") + print(" 1) Fast memory returns the fact without placing it in the prompt.") + print(" 2) Model memory recalls the fact even when memory is cleared.") + + model = get_model(self.train_out_dir) + demo_fast_memory(model) + demo_model_memory(model, self.out_dir) + + +def run_memory_demo(train_out_dir, out_dir): + MemoryDemo(train_out_dir, out_dir).run() + + +def main(argv=None): + parser = argparse.ArgumentParser(description="BDH memory demos") + parser.add_argument("--train_out_dir", type=str, default="outputs/training", help="Training output directory") + parser.add_argument("--out_dir", type=str, default="outputs/memory", help="Memory output directory") + args = parser.parse_args(argv) + run_memory_demo(args.train_out_dir, args.out_dir) + + +if __name__ == "__main__": + main() diff --git a/bdh_app/training.py b/bdh_app/training.py new file mode 100644 index 0000000..50104e4 --- /dev/null +++ b/bdh_app/training.py @@ -0,0 +1,272 @@ +# Copyright Pathway Technology, Inc. + +""" +Training module for BDH. + +This module provides a CLI-compatible entrypoint via main(). 
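
Typical usage (via the train.py wrapper):

    python train.py --max_iters 1000 --out_dir outputs/training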
+""" + +import os +import argparse +import csv +from contextlib import nullcontext + +import bdh +import numpy as np +import requests +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Optional: For monitoring memory usage +try: + import psutil # type: ignore[import-not-found] + HAS_PSUTIL = True +except ImportError: + HAS_PSUTIL = False + print("Install psutil for memory monitoring: pip install psutil") + +# Device selection +if torch.cuda.is_available(): + device = torch.device("cuda") +elif torch.backends.mps.is_available(): + device = torch.device("mps") + os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0' +else: + device = torch.device("cpu") + +# Data type selection +if torch.cuda.is_available(): + if torch.cuda.is_bf16_supported(): + dtype = "bfloat16" + else: + dtype = "float16" +elif torch.backends.mps.is_available(): + dtype = "float32" +else: + dtype = "float32" + +ptdtype = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, +}[dtype] + +ctx = ( + torch.amp.autocast(device_type=device.type, dtype=ptdtype) + if device.type == "cuda" + else nullcontext() +) + +scaler = torch.amp.GradScaler(device=device.type, enabled=(dtype == "float16" and device.type == "cuda")) + +torch.manual_seed(1337) +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + +print(f"Using device: {device} with dtype {dtype}") + +BDH_CONFIG = bdh.BDHConfig( + mlp_internal_dim_multiplier=48, + n_layer=5, +) + +BLOCK_SIZE = 384 +BATCH_SIZE = 12 + +if device.type == "cpu": + print("WARNING: Training on CPU will be very slow. Consider using a GPU if available.") + print("Memory optimizations applied: reduced batch size, block size, and model dimensions") +elif device.type == "mps": + print("MPS detected. Using float32 (required for MPS training).") + print(f"Model config: n_layer={BDH_CONFIG.n_layer}, mlp_mult={BDH_CONFIG.mlp_internal_dim_multiplier}") + print(f"Training config: batch_size={BATCH_SIZE}, block_size={BLOCK_SIZE}") + print("Note: float32 uses 2x more memory than float16. 
Monitor RAM usage.") + +MAX_ITERS = 3000 +LEARNING_RATE = 1e-3 +WEIGHT_DECAY = 0.1 +LOG_FREQ = 10 + +input_file_path = os.path.join(os.path.dirname(__file__), "..", "input.txt") +input_file_path = os.path.abspath(input_file_path) + + +def fetch_data(): + if not os.path.exists(input_file_path): + data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" + with open(input_file_path, "w") as f: + f.write(requests.get(data_url).text) + + +def get_batch(split): + if not os.path.exists(input_file_path): + fetch_data() + assert os.path.exists(input_file_path), f"Dataset missing at {input_file_path}" + + data = np.memmap(input_file_path, dtype=np.uint8, mode="r") + if split == "train": + data = data[: int(0.9 * len(data))] + else: + data = data[int(0.9 * len(data)) :] + + ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,)) + x = torch.stack( + [torch.from_numpy((data[i : i + BLOCK_SIZE]).astype(np.int64)) for i in ix] + ) + y = torch.stack( + [ + torch.from_numpy((data[i + 1 : i + 1 + BLOCK_SIZE]).astype(np.int64)) + for i in ix + ] + ) + if torch.cuda.is_available(): + x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to( + device, non_blocking=True + ) + else: + x, y = x.to(device), y.to(device) + return x, y + + +def estimate_loss(model, split, eval_steps=10): + losses = [] + model.eval() + with torch.no_grad(): + for _ in range(eval_steps): + x, y = get_batch(split) + with ctx: + _, loss = model(x, y) + losses.append(loss.item()) + model.train() + return sum(losses) / len(losses) + + +class Trainer: + def __init__(self, max_iters, batch_size, out_dir): + self.max_iters = max_iters + self.batch_size = batch_size + self.out_dir = out_dir + + def run(self): + global MAX_ITERS, BATCH_SIZE + MAX_ITERS = self.max_iters + BATCH_SIZE = self.batch_size + + checkpoint_dir = os.path.join(self.out_dir, "checkpoints") + log_dir = os.path.join(self.out_dir, "logs") + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(log_dir, exist_ok=True) + + print(f"Training outputs will be saved to: {self.out_dir}") + print(f" Checkpoints: {checkpoint_dir}") + print(f" Logs: {log_dir}") + + fetch_data() + model = bdh.BDH(BDH_CONFIG).to(device) + + if device.type == "cuda": + model = torch.compile(model) + elif device.type == "mps": + print("Skipping torch.compile on MPS") + else: + print("Skipping torch.compile on CPU") + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LEARNING_RATE, + weight_decay=WEIGHT_DECAY + ) + + x, y = get_batch("train") + loss_acc = 0 + loss_steps = 0 + + csv_path = os.path.join(log_dir, "training_log.csv") + csv_file = open(csv_path, "w", newline="") + csv_writer = csv.writer(csv_file) + csv_writer.writerow(["step", "train_loss", "val_loss"]) + + checkpoint_intervals = [int(MAX_ITERS * (i / 10.0)) for i in range(1, 11)] + checkpoint_intervals = sorted(list(set(s for s in checkpoint_intervals if s > 0))) + print(f"Checkpoints will be saved at steps: {checkpoint_intervals}") + + for step in range(MAX_ITERS): + with ctx: + _, loss = model(x, y) + + x, y = get_batch("train") + loss_acc += loss.item() + loss_steps += 1 + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if device.type == "cuda" and step % 50 == 0: + torch.cuda.empty_cache() + + if step % LOG_FREQ == 0: + avg_loss = loss_acc / loss_steps + val_loss = estimate_loss(model, "val", eval_steps=5) + print(f"Step: {step}/{MAX_ITERS} train {avg_loss:.4f} | val {val_loss:.4f}") + 
csv_writer.writerow([step, avg_loss, val_loss]) + csv_file.flush() + loss_acc = 0 + loss_steps = 0 + + target_step = step + 1 + if target_step in checkpoint_intervals: + ckpt_name = f"bdh_checkpoint_step_{target_step}.pt" + ckpt_path = os.path.join(checkpoint_dir, ckpt_name) + checkpoint = { + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'config': BDH_CONFIG, + 'step': target_step, + 'loss': loss.item(), + } + torch.save(checkpoint, ckpt_path) + print(f" --> Checkpoint saved: {ckpt_path}") + + csv_file.close() + print("Training done.") + + final_model_path = os.path.join(checkpoint_dir, "bdh_model_final.pt") + torch.save(model.state_dict(), final_model_path) + print(f"Final model weights saved to {final_model_path}") + + model.eval() + print("Generating sample...") + prompt = torch.tensor( + bytearray("To be or ", "utf-8"), + dtype=torch.long, + device=device + ).unsqueeze(0) + ret = model.generate(prompt, max_new_tokens=100, top_k=3) + try: + ret_decoded = bytes(ret.to(torch.uint8).to("cpu").squeeze(0)).decode( + errors="backslashreplace" + ) + print(ret_decoded) + except Exception: + print("(Could not decode output bytes to string)") + + +def run_training(max_iters, batch_size, out_dir): + Trainer(max_iters, batch_size, out_dir).run() + + +def main(argv=None): + parser = argparse.ArgumentParser(description="Train BDH model") + parser.add_argument("--max_iters", type=int, default=3000, help="Maximum number of training iterations") + parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Batch size") + parser.add_argument("--out_dir", type=str, default="outputs/training", help="Base output directory") + + args = parser.parse_args(argv) + run_training(args.max_iters, args.batch_size, args.out_dir) + + +if __name__ == "__main__": + main() diff --git a/compare_checkpoints.py b/compare_checkpoints.py new file mode 100755 index 0000000..fffca23 --- /dev/null +++ b/compare_checkpoints.py @@ -0,0 +1,7 @@ +"""CLI wrapper for BDH checkpoint evaluation.""" + +from bdh_app.evaluation import main + + +if __name__ == "__main__": + main() diff --git a/generate.py b/generate.py new file mode 100755 index 0000000..2922d13 --- /dev/null +++ b/generate.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +""" +Generate text using a trained BDH model. + +This script loads a saved checkpoint and generates text from a prompt. +""" + +import torch +import bdh +import sys +import os + +def load_checkpoint(checkpoint_path): + """ + Load a checkpoint file and return model, config, and metadata. + + Args: + checkpoint_path: Path to the checkpoint file (.pt) + + Returns: + model: Loaded BDH model + config: BDHConfig used for the model + metadata: Dictionary with training info (step, loss, etc.) 
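For reference, the checkpoint schedule computed in `Trainer.run` lands at every 10% of the run; a minimal standalone sketch of the same arithmetic (assuming `--max_iters 1000`):

```python
# Same arithmetic as Trainer.run: one checkpoint per 10% of the run.
MAX_ITERS = 1000
intervals = [int(MAX_ITERS * (i / 10.0)) for i in range(1, 11)]
intervals = sorted({s for s in intervals if s > 0})
print(intervals)  # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
```

The set/filter only matters for very short runs: e.g. `MAX_ITERS = 5` rounds several intervals to the same step and collapses to `[1, 2, 3, 4, 5]`.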
diff --git a/compare_checkpoints.py b/compare_checkpoints.py
new file mode 100755
index 0000000..fffca23
--- /dev/null
+++ b/compare_checkpoints.py
@@ -0,0 +1,7 @@
+"""CLI wrapper for BDH checkpoint evaluation."""
+
+from bdh_app.evaluation import main
+
+
+if __name__ == "__main__":
+    main()
diff --git a/generate.py b/generate.py
new file mode 100755
index 0000000..2922d13
--- /dev/null
+++ b/generate.py
@@ -0,0 +1,210 @@
+#!/usr/bin/env python3
+"""
+Generate text using a trained BDH model.
+
+This script loads a saved checkpoint and generates text from a prompt.
+"""
+
+import argparse
+import torch
+import bdh
+import sys
+import os
+
+def load_checkpoint(checkpoint_path):
+    """
+    Load a checkpoint file and return model, config, and metadata.
+
+    Args:
+        checkpoint_path: Path to the checkpoint file (.pt)
+
+    Returns:
+        model: Loaded BDH model
+        config: BDHConfig used for the model
+        metadata: Dictionary with training info (step, loss, etc.)
+    """
+    print(f"Loading checkpoint from {checkpoint_path}...")
+
+    # Load checkpoint (use CPU for loading, then move to device)
+    # weights_only=False is needed because checkpoint contains custom BDHConfig object
+    # This is safe since it's our own checkpoint file
+    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+
+    # Extract config (if saved in checkpoint) or use default
+    if 'config' in checkpoint:
+        config = checkpoint['config']
+        print("Loaded config from checkpoint")
+    else:
+        # Try to infer from checkpoint filename or use default
+        config = bdh.BDHConfig()
+        print("Using default config (config not found in checkpoint)")
+
+    # Create model with the config
+    print(f"Creating model: n_layer={config.n_layer}, n_embd={config.n_embd}, "
+          f"mlp_mult={config.mlp_internal_dim_multiplier}")
+    model = bdh.BDH(config)
+
+    # Load model weights
+    if 'model_state_dict' in checkpoint:
+        model.load_state_dict(checkpoint['model_state_dict'])
+        print("Model weights loaded successfully")
+    else:
+        # Assume the checkpoint IS the state_dict
+        model.load_state_dict(checkpoint)
+        print("Loaded checkpoint as state_dict")
+
+    # Extract metadata
+    metadata = {
+        'step': checkpoint.get('step', 'unknown'),
+        'loss': checkpoint.get('loss', 'unknown'),
+        'epoch': checkpoint.get('epoch', 'unknown'),
+    }
+
+    return model, config, metadata
+
+def generate_text(model, prompt_text, max_new_tokens=200, temperature=1.0, top_k=3, device="cpu"):
+    """
+    Generate text from a prompt using the model.
+
+    Args:
+        model: BDH model (moved to `device` and put in eval mode here)
+        prompt_text: Text prompt to start generation from
+        max_new_tokens: Maximum number of new tokens to generate
+        temperature: Sampling temperature (higher = more random)
+        top_k: Only consider top-k most likely tokens
+        device: Device to run generation on ("cpu", "mps", or "cuda")
+
+    Returns:
+        Generated text (prompt + generated continuation)
+    """
+    # Set model to evaluation mode
+    model.eval()
+    model = model.to(device)
+
+    # Convert prompt to tensor
+    prompt_bytes = bytearray(prompt_text, "utf-8")
+    prompt_tensor = torch.tensor(prompt_bytes, dtype=torch.long, device=device).unsqueeze(0)
+
+    print(f"\nPrompt: '{prompt_text}'")
+    print(f"Generating {max_new_tokens} tokens...")
+    print("-" * 60)
+
+    # Generate
+    with torch.no_grad():
+        generated = model.generate(
+            prompt_tensor,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_k=top_k,
+        )
+
+    # Decode generated bytes to text
+    generated_bytes = bytes(generated.to(torch.uint8).to("cpu").squeeze(0))
+    generated_text = generated_bytes.decode(errors="backslashreplace")
+
+    return generated_text
+
+def main():
+    """Main function to load model and generate text."""
+
+    parser = argparse.ArgumentParser(description="Generate text using a trained BDH model")
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        default="bdh_checkpoint_step_100.pt",
+        help="Path to checkpoint file (default: bdh_checkpoint_step_100.pt)"
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="To be or not to be",
+        help="Text prompt to start generation from (default: 'To be or not to be')"
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=200,
+        help="Maximum number of tokens to generate (default: 200)"
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=1.0,
+        help="Sampling temperature - higher = more random (default: 1.0)"
+    )
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=3,
+        help="Only consider top-k most likely tokens (default: 3)"
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="auto",
+        choices=["auto", "cpu", "mps", "cuda"],
+        help="Device to use (default: auto - detects best available)"
+    )
+
+    args = parser.parse_args()
+
+    # Determine device
+    if args.device == "auto":
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
+    else:
+        device = args.device
+
+    print(f"Using device: {device}")
+
+    # Check if checkpoint exists
+    if not os.path.exists(args.checkpoint):
+        print(f"Error: Checkpoint file '{args.checkpoint}' not found!")
+        print("Available checkpoint files:")
+        for f in os.listdir("."):
+            if f.endswith(".pt"):
+                print(f"  - {f}")
+        sys.exit(1)
+
+    # Load model
+    try:
+        model, config, metadata = load_checkpoint(args.checkpoint)
+        print("\nCheckpoint info:")
+        print(f"  Step: {metadata['step']}")
+        print(f"  Loss: {metadata['loss']}")
+    except Exception as e:
+        print(f"Error loading checkpoint: {e}")
+        import traceback
+        traceback.print_exc()
+        sys.exit(1)
+
+    # Generate text
+    try:
+        generated = generate_text(
+            model,
+            args.prompt,
+            max_new_tokens=args.max_tokens,
+            temperature=args.temperature,
+            top_k=args.top_k,
+            device=device
+        )
+
+        print("\n" + "=" * 60)
+        print("GENERATED TEXT:")
+        print("=" * 60)
+        print(generated)
+        print("=" * 60)
+
+    except Exception as e:
+        print(f"Error during generation: {e}")
+        import traceback
+        traceback.print_exc()
+        sys.exit(1)
+
+if __name__ == "__main__":
+    main()
+
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..ccf7a33
--- /dev/null
+++ b/main.py
@@ -0,0 +1,44 @@
+"""Run BDH training, evaluation, and memory demos in order."""
+
+import argparse
+
+from bdh_app.training import run_training
+from bdh_app.evaluation import run_evaluation, run_multi_prompt_evaluation
+from bdh_app.memory import run_memory_demo
+
+
+def main():
+    parser = argparse.ArgumentParser(description="BDH end-to-end runner")
+    parser.add_argument("--max_iters", type=int, default=3000, help="Training iterations")
+    parser.add_argument("--batch_size", type=int, default=12, help="Training batch size")
+    parser.add_argument("--base_out_dir", type=str, default="outputs", help="Base output directory")
+    parser.add_argument("--prompt", type=str, default="To be or not to be", help="Evaluation prompt")
+    args = parser.parse_args()
+
+    train_out_dir = f"{args.base_out_dir}/training"
+    eval_out_dir = f"{args.base_out_dir}/evaluation"
+    memory_out_dir = f"{args.base_out_dir}/memory"
+
+    run_training(args.max_iters, args.batch_size, train_out_dir)
+    run_evaluation(
+        input_dir=f"{train_out_dir}/checkpoints",
+        output_dir=eval_out_dir,
+        prompt=args.prompt,
+        device="auto",
+        log_file=None,
+    )
+    run_multi_prompt_evaluation(
+        input_dir=f"{train_out_dir}/checkpoints",
+        output_dir=eval_out_dir,
+        prompts=[
+            "My lord, I shall",
+            "The king hath",
+            "Upon this day",
+        ],
+        device="auto",
+    )
+    run_memory_demo(train_out_dir=train_out_dir, out_dir=memory_out_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/memory.py b/memory.py
new file mode 100644
index 0000000..cbb45e5
--- /dev/null
+++ b/memory.py
@@ -0,0 +1,7 @@
+"""CLI wrapper for BDH memory demos."""
+
+from bdh_app.memory import main
+
+
+if __name__ == "__main__":
+    main()
diff --git a/requirements.txt b/requirements.txt
index 8ad30cc..07200ce 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,5 @@
 torch
 numpy
 requests
+pandas
+matplotlib
diff --git a/train.py b/train.py
index 6b982d8..6408a9e 100644
--- a/train.py
+++ b/train.py
@@ -1,126 +1,7 @@
-# Copyright Pathway Technology, Inc.
+"""CLI wrapper for BDH training."""
 
-import os
-from contextlib import nullcontext
-
-import bdh
-import numpy as np
-import requests
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-# On a Mac you can also try
-# device=torch.device('mps')
-
-dtype = (
-    "bfloat16"
-    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
-    else "float16"
-)  # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
-ptdtype = {
-    "float32": torch.float32,
-    "bfloat16": torch.bfloat16,
-    "float16": torch.float16,
-}[dtype]
-ctx = (
-    torch.amp.autocast(device_type=device.type, dtype=ptdtype)
-    if "cuda" in device.type
-    else nullcontext()
-)
-scaler = torch.amp.GradScaler(device=device.type, enabled=(dtype == "float16"))
-torch.manual_seed(1337)
-torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
-torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
-print(f"Using device: {device} with dtype {dtype}")
-
-
-# Configuration
-BDH_CONFIG = bdh.BDHConfig()
-BLOCK_SIZE = 512
-BATCH_SIZE = 32
-MAX_ITERS = 3000
-LEARNING_RATE = 1e-3
-WEIGHT_DECAY = 0.1
-LOG_FREQ = 100
-
-input_file_path = os.path.join(os.path.dirname(__file__), "input.txt")
-
-
-# Fetch the tiny Shakespeare dataset
-def fetch_data():
-    if not os.path.exists(input_file_path):
-        data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
-        with open(input_file_path, "w") as f:
-            f.write(requests.get(data_url).text)
-
-
-def get_batch(split):
-    # treat the file as bytes
-    data = np.memmap(input_file_path, dtype=np.uint8, mode="r")
-    if split == "train":
-        data = data[: int(0.9 * len(data))]
-    else:
-        data = data[int(0.9 * len(data)) :]
-    ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
-    x = torch.stack(
-        [torch.from_numpy((data[i : i + BLOCK_SIZE]).astype(np.int64)) for i in ix]
-    )
-    y = torch.stack(
-        [
-            torch.from_numpy((data[i + 1 : i + 1 + BLOCK_SIZE]).astype(np.int64))
-            for i in ix
-        ]
-    )
-    if torch.cuda.is_available():
-        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
-        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(
-            device, non_blocking=True
-        )
-    else:
-        x, y = x.to(device), y.to(device)
-    return x, y
-
-
-def eval(model):
-    model.eval()
+from bdh_app.training import main
 
 
 if __name__ == "__main__":
-    fetch_data()
-
-    model = bdh.BDH(BDH_CONFIG).to(device)
-    model = torch.compile(model)
-    optimizer = torch.optim.AdamW(
-        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
-    )
-
-    x, y = get_batch("train")
-
-    loss_acc = 0
-    loss_steps = 0
-    for step in range(MAX_ITERS):
-        with ctx:
-            logits, loss = model(x, y)
-        x, y = get_batch("train")
-        loss_acc += loss
-        loss_steps += 1
-        scaler.scale(loss).backward()
-        scaler.step(optimizer)
-        scaler.update()
-        optimizer.zero_grad()
-        if step % LOG_FREQ == 0:
-            print(f"Step: {step}/{MAX_ITERS} loss {loss_acc.item() / loss_steps:.3}")
-            loss_acc = 0
-            loss_steps = 0
-    print("Training done, now generating a sample ")
-    model.eval()
-    prompt = torch.tensor(
-        bytearray("To be or ", "utf-8"), dtype=torch.long, device=device
-    ).unsqueeze(0)
-    ret = model.generate(prompt, max_new_tokens=100, top_k=3)
-    ret_decoded = bytes(ret.to(torch.uint8).to("cpu").squeeze(0)).decode(
-        errors="backslashreplace"
-    )
-    print(ret_decoded)
+    main()
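To sanity-check a checkpoint produced by `Trainer.run` without going through `generate.py`, a minimal loading sketch (the checkpoint path and step number below are illustrative assumptions, not files this patch guarantees to exist):

```python
import torch
import bdh

# Checkpoints store a BDHConfig object, so weights_only=False is required.
ckpt = torch.load(
    "outputs/training/checkpoints/bdh_checkpoint_step_1000.pt",
    map_location="cpu",
    weights_only=False,
)
model = bdh.BDH(ckpt["config"])
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Byte-level prompt, as in the trainer's sample generation.
prompt = torch.tensor(bytearray("To be or ", "utf-8"), dtype=torch.long).unsqueeze(0)
with torch.no_grad():
    out = model.generate(prompt, max_new_tokens=100, top_k=3)
print(bytes(out.to(torch.uint8).squeeze(0)).decode(errors="backslashreplace"))
```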