Skip to content

Latest commit

 

History

History
381 lines (293 loc) · 12.8 KB

File metadata and controls

381 lines (293 loc) · 12.8 KB

CortexGPT Stability Solutions: Fixing the 4-5x Loss Oscillations

Overview

Based on the mathematical analysis proving instability in CortexGPT's dual memory system, this document provides concrete, implementable solutions to stabilize training and eliminate the 4-5x loss oscillations.

Solution Architecture

1. Stabilized Memory Gating System

Replace hard softmax gating with smooth, temperature-controlled mixing:

class StabilizedMemoryGate(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim * 3, 3)
        self.temperature = nn.Parameter(torch.ones(1))
        self.gate_bias = nn.Parameter(torch.tensor([0.5, 0.3, 0.2]))  # Favor current
        
    def forward(self, current, stm_value, ltm_value):
        # Smooth gating with residual connection
        inputs = torch.cat([current, stm_value, ltm_value], dim=-1)
        raw_gates = self.gate_proj(inputs) + self.gate_bias
        
        # Temperature-controlled softmax
        gates = F.softmax(raw_gates / self.temperature.clamp(min=0.1), dim=-1)
        
        # Add residual to prevent gradient vanishing
        output = current + gates[:, 1:2] * (stm_value - current) + \
                           gates[:, 2:3] * (ltm_value - current)
        
        return output, gates

2. Differentiable Sparse Activation

Replace hard top-k selection with Gumbel-Softmax:

class DifferentiableSparseActivation(nn.Module):
    def __init__(self, dim, target_sparsity=0.15):  # Increased from 0.05
        super().__init__()
        self.dim = dim
        self.target_sparsity = target_sparsity
        self.temperature = nn.Parameter(torch.ones(1))
        
    def forward(self, x):
        # Compute importance scores
        scores = self.importance_net(x)
        
        # Gumbel-Softmax for differentiable selection
        if self.training:
            # Add Gumbel noise
            U = torch.rand_like(scores)
            gumbel = -torch.log(-torch.log(U + 1e-8) + 1e-8)
            scores_noisy = (scores + gumbel) / self.temperature
            
            # Soft top-k using continuous relaxation
            k = int(x.size(0) * self.target_sparsity)
            threshold = torch.topk(scores_noisy, k, sorted=False)[0].min()
            mask = torch.sigmoid((scores_noisy - threshold) * 10)
        else:
            # Hard selection during inference
            k = int(x.size(0) * self.target_sparsity)
            _, indices = torch.topk(scores, k)
            mask = torch.zeros_like(scores)
            mask.scatter_(0, indices, 1.0)
        
        # Apply mask with gradient flow
        return x * mask.unsqueeze(-1), mask

3. Stable Memory Consolidation

Implement gradual, differentiable consolidation:

class GradualConsolidation(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.consolidation_gate = nn.Linear(config.dim * 2, 1)
        
    def consolidate(self, stm, ltm, force=False):
        if not force and not self.should_consolidate():
            return
        
        # Gradual transfer based on importance
        candidates = []
        for i, memory in enumerate(stm.memories):
            if stm.access_counts[i] >= self.config.consolidation_threshold:
                # Compute transfer probability
                importance = stm.importance_scores[i]
                age = stm.current_time - stm.creation_times[i]
                
                # Smooth consolidation decision
                features = torch.cat([memory['key'], memory['value']], dim=-1)
                transfer_prob = torch.sigmoid(self.consolidation_gate(features))
                
                if transfer_prob > 0.5 or force:
                    candidates.append((memory, importance * transfer_prob))
        
        # Transfer with decay in STM
        for memory, weight in candidates:
            # Add to LTM
            ltm.store(memory['key'], memory['value'], weight)
            
            # Decay in STM (not remove)
            idx = stm.memories.index(memory)
            stm.importance_scores[idx] *= 0.5  # Gradual decay
    
    def should_consolidate(self):
        # Consolidate at regular intervals, not randomly
        return self.step_count % self.config.consolidation_interval == 0

4. Gradient Flow Improvements

Add residual connections and normalization:

class StabilizedCortexGPT(nn.Module):
    def __init__(self, config, vocab_size, dim):
        super().__init__()
        # ... existing init code ...
        
        # Add layer normalization
        self.memory_norm = nn.LayerNorm(dim)
        self.output_norm = nn.LayerNorm(dim)
        
        # Gradient scaling for memory components
        self.memory_scale = nn.Parameter(torch.ones(3))
        
    def forward(self, input_ids):
        # ... existing code ...
        
        # Normalized memory retrieval
        stm_value = self.memory_norm(self.stm.retrieve(current)[0])
        ltm_value = self.memory_norm(self.ltm.retrieve(current)[0])
        
        # Scaled gradients for different memory types
        stm_value = stm_value * self.memory_scale[0]
        ltm_value = ltm_value * self.memory_scale[1]
        current_scaled = current * self.memory_scale[2]
        
        # Stabilized gating
        output, gates = self.memory_gate(current_scaled, stm_value, ltm_value)
        
        # Add residual connection
        output = self.output_norm(output + current)
        
        return output

5. Loss Function Modifications

Implement smooth loss with regularization:

class StabilizedLoss(nn.Module):
    def __init__(self, vocab_size, smoothing=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.smoothing = smoothing
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')
        
    def forward(self, logits, targets, gates=None, sparsity_mask=None):
        # Label smoothing for stability
        with torch.no_grad():
            true_dist = torch.zeros_like(logits)
            true_dist.scatter_(2, targets.unsqueeze(2), 1.0)
            true_dist = true_dist * (1 - self.smoothing) + \
                       self.smoothing / self.vocab_size
        
        # KL divergence loss (more stable than CE)
        log_probs = F.log_softmax(logits, dim=-1)
        kl_loss = F.kl_div(log_probs, true_dist, reduction='none').sum(dim=-1)
        
        # Regularization terms
        reg_loss = 0
        
        if gates is not None:
            # Encourage smooth gate transitions
            gate_smoothness = torch.mean((gates[1:] - gates[:-1])**2)
            reg_loss += 0.1 * gate_smoothness
            
            # Prevent gate collapse
            gate_entropy = -torch.mean(gates * torch.log(gates + 1e-8))
            reg_loss += 0.05 * (2.0 - gate_entropy)  # Encourage entropy ~2
        
        if sparsity_mask is not None:
            # Encourage target sparsity
            actual_sparsity = sparsity_mask.mean()
            target_sparsity = 0.15
            sparsity_loss = (actual_sparsity - target_sparsity)**2
            reg_loss += 0.1 * sparsity_loss
        
        return kl_loss.mean() + reg_loss

6. Training Procedure Modifications

Implement stable training loop:

class StableTrainer:
    def __init__(self, model, config):
        self.model = model
        self.config = config
        
        # Separate optimizers for different components
        self.optimizer_main = torch.optim.AdamW(
            [p for n, p in model.named_parameters() 
             if 'memory' not in n],
            lr=config.lr,
            weight_decay=0.01
        )
        
        self.optimizer_memory = torch.optim.AdamW(
            [p for n, p in model.named_parameters() 
             if 'memory' in n],
            lr=config.lr * 0.1,  # Slower learning for memory
            weight_decay=0.001
        )
        
        # Gradient clipping values
        self.clip_main = 1.0
        self.clip_memory = 0.5
        
        # EMA for stable training
        self.ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
        
    def train_step(self, batch):
        # Forward pass
        logits, metadata = self.model(batch['input_ids'], return_metadata=True)
        
        # Compute loss with stabilization
        loss = self.criterion(
            logits, 
            batch['labels'],
            gates=metadata.get('gates'),
            sparsity_mask=metadata.get('sparsity_mask')
        )
        
        # Backward with gradient clipping
        loss.backward()
        
        # Clip gradients separately
        torch.nn.utils.clip_grad_norm_(
            [p for n, p in self.model.named_parameters() if 'memory' not in n],
            self.clip_main
        )
        torch.nn.utils.clip_grad_norm_(
            [p for n, p in self.model.named_parameters() if 'memory' in n],
            self.clip_memory
        )
        
        # Optimize
        self.optimizer_main.step()
        self.optimizer_memory.step()
        
        # Update EMA
        self.ema.update()
        
        # Clear gradients
        self.optimizer_main.zero_grad()
        self.optimizer_memory.zero_grad()
        
        return loss.item(), metadata

7. Memory System Improvements

Implement bounded memory growth:

class BoundedMemoryBuffer:
    def __init__(self, capacity, dim, decay_rate=0.95):
        self.capacity = capacity
        self.dim = dim
        self.decay_rate = decay_rate
        
        # Fixed-size buffers
        self.keys = torch.zeros(capacity, dim)
        self.values = torch.zeros(capacity, dim)
        self.importance = torch.zeros(capacity)
        self.age = torch.zeros(capacity)
        self.valid = torch.zeros(capacity, dtype=torch.bool)
        
        self.write_ptr = 0
        
    def store(self, key, value, importance=1.0):
        # Always write to current position (circular buffer)
        self.keys[self.write_ptr] = key.detach()
        self.values[self.write_ptr] = value.detach()
        self.importance[self.write_ptr] = importance
        self.age[self.write_ptr] = 0
        self.valid[self.write_ptr] = True
        
        # Move write pointer
        self.write_ptr = (self.write_ptr + 1) % self.capacity
        
        # Age all memories
        self.age[self.valid] += 1
        
        # Decay importance
        self.importance *= self.decay_rate
        
        # Invalidate very old memories
        self.valid[self.age > self.capacity * 2] = False

Implementation Checklist

  1. Immediate Changes (Can implement now):

    • Replace softmax gates with stabilized version
    • Add gradient clipping for memory components
    • Implement regular consolidation intervals
    • Add residual connections around memory retrieval
  2. Short-term Changes (1-2 days):

    • Implement differentiable sparse activation
    • Add layer normalization throughout
    • Separate optimizers for different components
    • Implement EMA for model parameters
  3. Medium-term Changes (1 week):

    • Replace memory buffers with bounded versions
    • Implement smooth consolidation mechanism
    • Add comprehensive regularization to loss
    • Implement gradient scaling parameters

Expected Improvements

After implementing these solutions:

  1. Loss Stability: Oscillations reduced from 4-5x to <1.2x
  2. Gradient Flow: Consistent gradient magnitudes throughout network
  3. Memory Utilization: Both STM and LTM contribute meaningfully
  4. Sparsity: Increased to 15% without information bottleneck
  5. Convergence: Stable convergence within 10-20 epochs

Validation Metrics

Monitor these metrics to confirm stability:

def compute_stability_metrics(trainer, validation_loader):
    metrics = {
        'loss_variance': [],
        'gate_entropy': [],
        'gradient_norm': [],
        'memory_utilization': [],
        'sparsity_actual': []
    }
    
    for batch in validation_loader:
        loss, metadata = trainer.evaluate_batch(batch)
        
        metrics['loss_variance'].append(loss)
        metrics['gate_entropy'].append(
            -torch.sum(metadata['gates'] * torch.log(metadata['gates'] + 1e-8))
        )
        metrics['gradient_norm'].append(
            compute_gradient_norm(trainer.model)
        )
        metrics['memory_utilization'].append({
            'stm': metadata['memory_confidence']['stm'],
            'ltm': metadata['memory_confidence']['ltm']
        })
        metrics['sparsity_actual'].append(
            metadata['sparsity_mask'].mean()
        )
    
    return metrics

Conclusion

These solutions address the fundamental mathematical instabilities in CortexGPT while preserving its innovative dual-memory architecture. The key insight is replacing discontinuous operations (hard gates, top-k selection, random consolidation) with smooth, differentiable alternatives that maintain stable gradient flow.