Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions LLMs/torch_examples/batch_agg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import torch

def standardize(M: torch.Tensor) -> torch.Tensor:
    """
    standardize standardizes each row of the input matrix M to have mean 0 and std 1.

    Arguments:
        M: input tensor of shape (n, p)

    Returns:
        A tensor of shape (n, p) where each row has mean ~0 and std ~1.
        A constant row (std 0) maps to all zeros instead of NaN/Inf.
    """
    # keepdim=True broadcasts the (n, 1) statistics across each row.
    mean = M.mean(dim=1, keepdim=True)
    std = M.std(dim=1, keepdim=True)
    # eps guard keeps constant rows finite, matching the 1e-8 guard used
    # elsewhere in this module (normalize_per_batch, sample_gaussian_pairs).
    return (M - mean) / (std + 1e-8)

def _standardize_columns(batch: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Standardize each column of `batch` (shape (b, p)) to mean 0, std 1.

    A singleton batch has no well-defined sample std (torch.std with the
    default unbiased estimator returns NaN for one element), so it is
    treated as std 0 and the eps guard keeps the division finite.
    """
    mean = batch.mean(dim=0, keepdim=True)
    if batch.shape[0] > 1:
        std = batch.std(dim=0, keepdim=True)
    else:
        std = torch.zeros_like(mean)
    return (batch - mean) / (std + eps)

def normalize_per_batch(p: int, num_samples: int = 100, batch_size: int = 10):
    """
    normalize_per_batch generates num_samples random p-dim vectors,
    splits them into consecutive batches of size batch_size (the last
    batch may be smaller), and standardizes each batch to have mean 0
    and std 1 per dimension.

    Arguments:
        p: dimension of the vectors
        num_samples: total number of samples to generate
        batch_size: size of each batch

    Returns:
        A tensor of shape (num_samples, p) containing the standardized vectors.

    Raises:
        ValueError: if batch_size exceeds num_samples.
    """
    if batch_size > num_samples:
        raise ValueError("batch_size must be at most number of samples.")
    mat = torch.randn((num_samples, p), dtype=torch.float32)
    # torch.split yields consecutive chunks; the final chunk holds the
    # remainder when num_samples is not a multiple of batch_size, which
    # matches the "last batch may be smaller" contract above. (The previous
    # implementation scattered samples into *random* batch ids in the uneven
    # case — contradicting the docstring, allowing empty batches, and using
    # a biased variance inconsistent with the unbiased std of the even path.)
    chunks = torch.split(mat, batch_size, dim=0)
    standardized = [_standardize_columns(chunk) for chunk in chunks]
    return torch.cat(standardized, dim=0)

def sample_gaussian_pairs(p: int, num_samples: int = 10000, eps: float = 1e-8):
    """
    sample_gaussian_pairs draws num_samples pairs of p-dimensional vectors
    from the standard normal distribution and returns their normalized
    inner products (cosine similarities). Highlights that high-dimensional
    random vectors are almost orthogonal.

    Arguments:
        p: dimension of the vectors
        num_samples: number of pairs to sample
        eps: small value to avoid division by zero

    Returns:
        A tensor of shape (num_samples,) containing the normalized inner products.
    """
    draws = torch.randn((2 * num_samples, p), dtype=torch.float32).view(num_samples, 2, p)
    first, second = draws[:, 0, :], draws[:, 1, :]
    dots = (first * second).sum(dim=1)
    scale = first.norm(dim=1) * second.norm(dim=1) + eps
    return dots / scale
65 changes: 65 additions & 0 deletions tests/test_batch_agg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@

import torch
import pytest
from LLMs.torch_examples.batch_agg import standardize, normalize_per_batch, sample_gaussian_pairs

def test_standardize_rows_mean_zero_std_one():
    """Every row of the standardized matrix has mean ~0 and unbiased std ~1."""
    torch.manual_seed(0)
    data = torch.randn((3, 4), dtype=torch.float32)
    out = standardize(data)
    per_row_mean = out.mean(dim=1)
    per_row_std = out.std(dim=1, unbiased=True)
    assert torch.allclose(per_row_mean, torch.zeros_like(per_row_mean), atol=1e-6)
    assert torch.allclose(per_row_std, torch.ones_like(per_row_std), atol=1e-6)


def test_normalize_per_batch_even_split():
    """With num_samples divisible by batch_size, every batch is standardized per dim."""
    torch.manual_seed(0)
    dim, n_total, b = 5, 20, 4
    out = normalize_per_batch(dim, n_total, b)
    assert out.shape == (n_total, dim)
    # Walk consecutive batches and verify per-dimension mean ~0 and std ~1.
    for start in range(0, n_total, b):
        chunk = out[start:start + b, :]
        chunk_mean = chunk.mean(dim=0)
        chunk_std = chunk.std(dim=0, unbiased=True)
        assert torch.allclose(chunk_mean, torch.zeros_like(chunk_mean), atol=1e-5)
        assert torch.allclose(chunk_std, torch.ones_like(chunk_std), atol=1e-5)


def test_normalize_per_batch_uneven_split():
    """Uneven split: output keeps its shape and has roughly unit global stats."""
    torch.manual_seed(0)
    dim, n_total, b = 3, 11, 4
    out = normalize_per_batch(dim, n_total, b)
    assert out.shape == (n_total, dim)
    # Batch membership is an implementation detail here, so only the
    # aggregate statistics are checked: global mean ~0 and std ~1.
    global_mean = out.mean(dim=0)
    global_std = out.std(dim=0, unbiased=True)
    assert torch.allclose(global_mean, torch.zeros_like(global_mean), atol=0.2)
    assert torch.allclose(global_std, torch.ones_like(global_std), atol=0.2)


def test_normalize_per_batch_invalid_batch_size():
    # A batch_size larger than num_samples must raise ValueError
    # (guard at the top of normalize_per_batch).
    with pytest.raises(ValueError):
        normalize_per_batch(2, num_samples=5, batch_size=10)


def test_sample_gaussian_pairs_shape_and_range():
    """High-dimensional random pairs are nearly orthogonal (small cosines)."""
    torch.manual_seed(0)
    dim, n_pairs = 100, 1000
    cosines = sample_gaussian_pairs(dim, n_pairs)
    assert cosines.shape == (n_pairs,)
    # At dim=100, the vast majority of normalized inner products are small.
    fraction_small = (cosines.abs() < 0.5).float().mean()
    assert fraction_small > 0.95