Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions config/molt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# MoLT (Mixture of Linear Transforms) training configuration
# This file uses Lightning CLI's automatic class construction

seed_everything: 42

trainer:
max_steps: 20_000
val_check_interval: 1000000000 # replacement model for single layer doesn't make sense
limit_val_batches: 1
check_val_every_n_epoch: null
enable_checkpointing: false
num_sanity_val_steps: 0
accelerator: "gpu"
devices: [0]
accumulate_grad_batches: 1
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: "molt"
name: "molt"
save_dir: "./wandb"
callbacks:
- class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback
init_args:
checkpoint_dir: "checkpoints"

model:
class_path: crosslayer_transcoder.model.clt_lightning.MoltModule
init_args:
model:
class_path: crosslayer_transcoder.model.molt.Molt
init_args:
d_acts: 768
N: 50
nonlinearity:
class_path: crosslayer_transcoder.model.jumprelu.JumpReLU
init_args:
theta: 0.03
bandwidth: 1.0
n_layers: 1
d_features: 1550
input_standardizer:
class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer
init_args:
n_layers: 12
activation_dim: 768
output_standardizer:
class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer
init_args:
n_layers: 12
activation_dim: 768

replacement_model:
class_path: crosslayer_transcoder.metrics.replacement_model_accuracy.ReplacementModelAccuracy
init_args:
model_name: "openai-community/gpt2"
device_map: "cuda:0"
loader_batch_size: 2

dead_features:
class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures
init_args:
n_features: 120
n_layers: 12
return_per_layer: true
return_log_freqs: true
return_neuron_indices: true

learning_rate: 2e-4
compile: true
lr_decay_step: 1000000000
lr_decay_factor: 0.1

lambda_sparsity: 0.00001
c_sparsity: 100
use_tanh: true

compute_dead_features: true
compute_dead_features_every: 500

data:
class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule
init_args:
buffer_size: 2_000_000
n_in_out: 2
n_layers: 12
activation_dim: 768
dtype: "float32"
max_batch_size: 50000

model_name: "openai-community/gpt2"
model_dtype: "float32"

dataset_name: "Skylion007/openwebtext"
dataset_split: "train"
max_sequence_length: 1024

generation_batch_size: 10
refresh_interval: 0.1

shared_memory_name: "activation_buffer"
timeout_seconds: 30

batch_size: 1000
num_workers: 10
prefetch_factor: 2
shuffle: true
persistent_workers: true
pin_memory: true

minimum_fill_threshold: 0.01

use_shared_memory: true

device_map: "cuda:0"
deployment_policy: "gpu_only"

wandb_logging:
enabled: true
project: "molt"
group: null
run_name: "data-generator"
tags: ["data-generation"]
save_dir: "./wandb"
log_interval: 5.0

ckpt_path: null
5 changes: 4 additions & 1 deletion crosslayer_transcoder/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
"""

from .clt import CrossLayerTranscoder
from .clt_lightning import CrossLayerTranscoderModule
from .clt_lightning import CrossLayerTranscoderModule, MoltModule
from .molt import Molt
from .topk import BatchTopK, PerLayerBatchTopK, PerLayerTopK

__all__ = [
"CrossLayerTranscoder",
"CrossLayerTranscoderModule",
"Molt",
"MoltModule",
"BatchTopK",
"PerLayerTopK",
"PerLayerBatchTopK",
Expand Down
90 changes: 78 additions & 12 deletions crosslayer_transcoder/model/clt_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import subprocess
import time
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import lightning as L
import psutil
Expand All @@ -23,14 +23,15 @@
Decoder,
)
from crosslayer_transcoder.model.jumprelu import JumpReLU
from crosslayer_transcoder.model.molt import Molt
from crosslayer_transcoder.model.topk import BatchTopK


class CrossLayerTranscoderModule(L.LightningModule):
def __init__(
self,
# Pre-constructed modules
model: CrossLayerTranscoder,
model: Union[CrossLayerTranscoder, Molt],
replacement_model: Optional[ReplacementModelAccuracy] = None,
dead_features: Optional[DeadFeatures] = None,
# Training parameters
Expand Down Expand Up @@ -85,17 +86,23 @@ def __init__(
self.beta2 = beta2
self.log_metrics_every = log_metrics_every

assert self.model.encoder.n_layers == self.model.decoder.n_layers, (
"Encoder and decoder must have the same number of layers"
)
if isinstance(self.model, Molt):
self.register_buffer(
"last_active",
torch.zeros((self.model.n_features,), dtype=torch.long),
)
else:
assert self.model.encoder.n_layers == self.model.decoder.n_layers, (
"Encoder and decoder must have the same number of layers"
)

self.register_buffer(
"last_active",
torch.zeros(
(self.model.encoder.n_layers, self.model.encoder.d_features),
dtype=torch.long,
),
)
self.register_buffer(
"last_active",
torch.zeros(
(self.model.encoder.n_layers, self.model.encoder.d_features),
dtype=torch.long,
),
)

def configure_model(self):
# Apply compilation if requested
Expand Down Expand Up @@ -565,3 +572,62 @@ def training_step(self, batch, batch_idx):
torch.cuda.memory._record_memory_history(enabled=None)
exit()
return loss


class MoltModule(CrossLayerTranscoderModule):
def __init__(
self,
lambda_sparsity: float = 0.0002,
c_sparsity: float = 0.1,
use_tanh: bool = True,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self._lambda = lambda_sparsity
self.c = c_sparsity
self.use_tanh = use_tanh

def current_sparsity_penalty(self):
n_steps = self.trainer.max_steps
current_step = self.global_step
cur_lambda = self._lambda * (current_step / n_steps)
self.log("training/sparsity_penalty", cur_lambda)
return cur_lambda

def forward(self, batch, layer):
return self.model.forward(batch, layer)

def training_step(self, batch, batch_idx):
if batch_idx == 0:
self.model.initialize_standardizers(batch)
self.log("model/d_latents", self.model.d_latents)
self.log("model/n_features", self.model.n_features)

layer = 8

resid, mlp_out = batch[:, 0], batch[:, 1]
resid = resid[:, layer]
mlp_out = mlp_out[:, layer]
gate, recons_norm, recons = self.model.forward(resid, layer)

self.update_dead_features(gate)
mse = (
recons_norm - self.model.output_standardizer.standardize(mlp_out, layer)
) ** 2

norms = self.model.transform_norm()
weighted_norms = norms * gate
self.log("model/weighted_norms_mean", weighted_norms.detach().mean().cpu())

if self.use_tanh:
weighted_norms = torch.tanh(weighted_norms * self.c)
sparsity = self.current_sparsity_penalty() * weighted_norms.sum(dim=-1).mean()
self.log("training/sparsity_loss", sparsity)
self.log("L0", (gate > 0.0).float().sum() / gate.shape[0])

loss = mse.mean() + sparsity
self.log("training/mse", mse.mean())
self.log("training/loss", loss)

return loss
3 changes: 2 additions & 1 deletion crosslayer_transcoder/model/jumprelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def backward(ctx, grad_output):
class JumpReLU(SerializableModule):
def __init__(self, theta=0.0, bandwidth=1.0, n_layers=12, d_features=768 * 8):
super().__init__()
self.theta = nn.Parameter(torch.full((1, n_layers, d_features), theta))
shape = (1, n_layers, d_features) if n_layers > 1 else (1, d_features)
self.theta = nn.Parameter(torch.full(shape, theta))
self.register_buffer("bandwidth", torch.tensor(bandwidth))
self._init_theta = theta
self.n_layers = n_layers
Expand Down
95 changes: 95 additions & 0 deletions crosslayer_transcoder/model/molt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import einops
import torch
import torch.nn as nn
from jaxtyping import Float


class Molt(nn.Module):
def __init__(
self,
d_acts: int,
N: int,
nonlinearity: nn.Module,
input_standardizer: nn.Module,
output_standardizer: nn.Module,
ranks: list[int] = [512, 256, 128, 64, 32],
):
super().__init__()

self.d_acts = d_acts
self.nonlinearity = nonlinearity
self.input_standardizer = input_standardizer
self.output_standardizer = output_standardizer
Us = []
Vs = []
rank_multiplier = 1
n_features = 0
d_latents = 0
for rank in ranks:
Us.append(nn.Parameter(torch.empty(N * rank_multiplier, rank, d_acts)))
Vs.append(nn.Parameter(torch.empty(N * rank_multiplier, d_acts, rank)))
n_features += N * rank_multiplier
d_latents += N * rank_multiplier * rank
rank_multiplier *= 2
self.n_features = n_features
self.e = nn.Linear(d_acts, n_features)
self.Us = nn.ParameterList(Us)
self.Vs = nn.ParameterList(Vs)

print(f"d_latents (transcoder equivalent): {d_latents}")
self.d_latents = d_latents

self.reset_parameters()

def reset_parameters(self):
for U in self.Us:
nn.init.xavier_uniform_(U)
for V in self.Vs:
nn.init.xavier_uniform_(V)

def transform_norm(self):
norms = []
for U, V in zip(self.Us, self.Vs):
uv = einops.einsum(
U,
V,
"n_transforms d_transform d_acts_out, n_transforms d_acts_in d_transform -> n_transforms d_acts_in d_acts_out",
)
norms.append(torch.norm(uv, dim=(1, 2)))
return torch.cat(norms, dim=0)

def forward(
self, acts: Float[torch.Tensor, "batch_size d_acts"], layer: int
) -> Float[torch.Tensor, "batch_size d_acts"]:
acts = self.input_standardizer(acts, layer)
pre_actvs = self.e(acts)
gate = self.nonlinearity(pre_actvs) # (batch, n_transforms)

raw_recons = []
for U, V in zip(self.Us, self.Vs):
latents = einops.einsum(
acts,
V,
"batch d_acts, n_transforms d_acts d_transform -> batch n_transforms d_transform",
)
raw_recons.append(
einops.einsum(
latents,
U,
"batch n_transforms d_transform, n_transforms d_transform d_acts -> batch n_transforms d_acts",
)
)

raw_recons = torch.cat(raw_recons, dim=1)

weighted_recons = gate.unsqueeze(-1) * raw_recons
recons_norm = weighted_recons.sum(dim=1)

recons = self.output_standardizer(recons_norm, layer)
return gate, recons_norm, recons

def initialize_standardizers(
self, batch: Float[torch.Tensor, "batch_size io n_layers d_acts"]
):
self.input_standardizer.initialize_from_batch(batch)
self.output_standardizer.initialize_from_batch(batch)