diff --git a/config/molt.yaml b/config/molt.yaml new file mode 100644 index 0000000..e85786e --- /dev/null +++ b/config/molt.yaml @@ -0,0 +1,127 @@ +# MoLT (Mixture of Linear Transforms) training configuration +# This file uses Lightning CLI's automatic class construction + +seed_everything: 42 + +trainer: + max_steps: 20_000 + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false + num_sanity_val_steps: 0 + accelerator: "gpu" + devices: [0] + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "molt" + name: "molt" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + replacement_model: + class_path: crosslayer_transcoder.metrics.replacement_model_accuracy.ReplacementModelAccuracy + init_args: + model_name: "openai-community/gpt2" + device_map: "cuda:0" + loader_batch_size: 2 + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.00001 + c_sparsity: 100 + use_tanh: true + + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + buffer_size: 2_000_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/crosslayer_transcoder/model/__init__.py b/crosslayer_transcoder/model/__init__.py index e4f0aae..c903ed9 100644 --- a/crosslayer_transcoder/model/__init__.py +++ b/crosslayer_transcoder/model/__init__.py @@ -3,12 +3,15 @@ """ from .clt import CrossLayerTranscoder -from .clt_lightning import CrossLayerTranscoderModule +from .clt_lightning import CrossLayerTranscoderModule, MoltModule +from .molt import Molt from .topk import BatchTopK, PerLayerBatchTopK, PerLayerTopK __all__ = [ "CrossLayerTranscoder", "CrossLayerTranscoderModule", + "Molt", + "MoltModule", "BatchTopK", "PerLayerTopK", "PerLayerBatchTopK", diff --git a/crosslayer_transcoder/model/clt_lightning.py b/crosslayer_transcoder/model/clt_lightning.py index 8e4a020..e7bd74e 100644 --- a/crosslayer_transcoder/model/clt_lightning.py +++ b/crosslayer_transcoder/model/clt_lightning.py @@ -2,7 +2,7 @@ import os import subprocess import time -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import lightning as L import psutil @@ -23,6 +23,7 @@ Decoder, ) from crosslayer_transcoder.model.jumprelu import JumpReLU +from crosslayer_transcoder.model.molt import Molt from crosslayer_transcoder.model.topk import BatchTopK @@ -30,7 +31,7 @@ class CrossLayerTranscoderModule(L.LightningModule): def __init__( self, # Pre-constructed modules - model: CrossLayerTranscoder, + model: Union[CrossLayerTranscoder, Molt], replacement_model: Optional[ReplacementModelAccuracy] = None, dead_features: Optional[DeadFeatures] = None, # Training parameters @@ -85,17 +86,23 @@ def __init__( self.beta2 = beta2 self.log_metrics_every = log_metrics_every - assert self.model.encoder.n_layers == self.model.decoder.n_layers, ( - "Encoder and decoder must have the same number of layers" - ) + if isinstance(self.model, Molt): + self.register_buffer( + "last_active", + torch.zeros((self.model.n_features,), dtype=torch.long), + ) + else: + assert self.model.encoder.n_layers == self.model.decoder.n_layers, ( + "Encoder and decoder must have the same number of layers" + ) - self.register_buffer( - "last_active", - torch.zeros( - (self.model.encoder.n_layers, self.model.encoder.d_features), - dtype=torch.long, - ), - ) + self.register_buffer( + "last_active", + torch.zeros( + (self.model.encoder.n_layers, self.model.encoder.d_features), + dtype=torch.long, + ), + ) def configure_model(self): # Apply compilation if requested @@ -565,3 +572,62 @@ def training_step(self, batch, batch_idx): torch.cuda.memory._record_memory_history(enabled=None) exit() return loss + + +class MoltModule(CrossLayerTranscoderModule): + def __init__( + self, + lambda_sparsity: float = 0.0002, + c_sparsity: float = 0.1, + use_tanh: bool = True, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._lambda = lambda_sparsity + self.c = c_sparsity + self.use_tanh = use_tanh + + def current_sparsity_penalty(self): + n_steps = self.trainer.max_steps + current_step = self.global_step + cur_lambda = self._lambda * (current_step / n_steps) + self.log("training/sparsity_penalty", cur_lambda) + return cur_lambda + + def forward(self, batch, layer): + return self.model.forward(batch, layer) + + def training_step(self, batch, batch_idx): + if batch_idx == 0: + self.model.initialize_standardizers(batch) + self.log("model/d_latents", self.model.d_latents) + self.log("model/n_features", self.model.n_features) + + layer = 8 + + resid, mlp_out = batch[:, 0], batch[:, 1] + resid = resid[:, layer] + mlp_out = mlp_out[:, layer] + gate, recons_norm, recons = self.model.forward(resid, layer) + + self.update_dead_features(gate) + mse = ( + recons_norm - self.model.output_standardizer.standardize(mlp_out, layer) + ) ** 2 + + norms = self.model.transform_norm() + weighted_norms = norms * gate + self.log("model/weighted_norms_mean", weighted_norms.detach().mean().cpu()) + + if self.use_tanh: + weighted_norms = torch.tanh(weighted_norms * self.c) + sparsity = self.current_sparsity_penalty() * weighted_norms.sum(dim=-1).mean() + self.log("training/sparsity_loss", sparsity) + self.log("L0", (gate > 0.0).float().sum() / gate.shape[0]) + + loss = mse.mean() + sparsity + self.log("training/mse", mse.mean()) + self.log("training/loss", loss) + + return loss diff --git a/crosslayer_transcoder/model/jumprelu.py b/crosslayer_transcoder/model/jumprelu.py index d21b7ec..c977830 100644 --- a/crosslayer_transcoder/model/jumprelu.py +++ b/crosslayer_transcoder/model/jumprelu.py @@ -52,7 +52,8 @@ def backward(ctx, grad_output): class JumpReLU(SerializableModule): def __init__(self, theta=0.0, bandwidth=1.0, n_layers=12, d_features=768 * 8): super().__init__() - self.theta = nn.Parameter(torch.full((1, n_layers, d_features), theta)) + shape = (1, n_layers, d_features) if n_layers > 1 else (1, d_features) + self.theta = nn.Parameter(torch.full(shape, theta)) self.register_buffer("bandwidth", torch.tensor(bandwidth)) self._init_theta = theta self.n_layers = n_layers diff --git a/crosslayer_transcoder/model/molt.py b/crosslayer_transcoder/model/molt.py new file mode 100644 index 0000000..86d091d --- /dev/null +++ b/crosslayer_transcoder/model/molt.py @@ -0,0 +1,95 @@ +import einops +import torch +import torch.nn as nn +from jaxtyping import Float + + +class Molt(nn.Module): + def __init__( + self, + d_acts: int, + N: int, + nonlinearity: nn.Module, + input_standardizer: nn.Module, + output_standardizer: nn.Module, + ranks: list[int] = [512, 256, 128, 64, 32], + ): + super().__init__() + + self.d_acts = d_acts + self.nonlinearity = nonlinearity + self.input_standardizer = input_standardizer + self.output_standardizer = output_standardizer + Us = [] + Vs = [] + rank_multiplier = 1 + n_features = 0 + d_latents = 0 + for rank in ranks: + Us.append(nn.Parameter(torch.empty(N * rank_multiplier, rank, d_acts))) + Vs.append(nn.Parameter(torch.empty(N * rank_multiplier, d_acts, rank))) + n_features += N * rank_multiplier + d_latents += N * rank_multiplier * rank + rank_multiplier *= 2 + self.n_features = n_features + self.e = nn.Linear(d_acts, n_features) + self.Us = nn.ParameterList(Us) + self.Vs = nn.ParameterList(Vs) + + print(f"d_latents (transcoder equivalent): {d_latents}") + self.d_latents = d_latents + + self.reset_parameters() + + def reset_parameters(self): + for U in self.Us: + nn.init.xavier_uniform_(U) + for V in self.Vs: + nn.init.xavier_uniform_(V) + + def transform_norm(self): + norms = [] + for U, V in zip(self.Us, self.Vs): + uv = einops.einsum( + U, + V, + "n_transforms d_transform d_acts_out, n_transforms d_acts_in d_transform -> n_transforms d_acts_in d_acts_out", + ) + norms.append(torch.norm(uv, dim=(1, 2))) + return torch.cat(norms, dim=0) + + def forward( + self, acts: Float[torch.Tensor, "batch_size d_acts"], layer: int + ) -> Float[torch.Tensor, "batch_size d_acts"]: + acts = self.input_standardizer(acts, layer) + pre_actvs = self.e(acts) + gate = self.nonlinearity(pre_actvs) # (batch, n_transforms) + + raw_recons = [] + for U, V in zip(self.Us, self.Vs): + latents = einops.einsum( + acts, + V, + "batch d_acts, n_transforms d_acts d_transform -> batch n_transforms d_transform", + ) + raw_recons.append( + einops.einsum( + latents, + U, + "batch n_transforms d_transform, n_transforms d_transform d_acts -> batch n_transforms d_acts", + ) + ) + + raw_recons = torch.cat(raw_recons, dim=1) + + weighted_recons = gate.unsqueeze(-1) * raw_recons + recons_norm = weighted_recons.sum(dim=1) + + recons = self.output_standardizer(recons_norm, layer) + return gate, recons_norm, recons + + def initialize_standardizers( + self, batch: Float[torch.Tensor, "batch_size io n_layers d_acts"] + ): + self.input_standardizer.initialize_from_batch(batch) + self.output_standardizer.initialize_from_batch(batch)