Skip to content

ha466/HRM_RDT

Folders and files

NameName
Last commit message
Last commit date

Latest commit

Β 

History

6 Commits
Β 
Β 
Β 
Β 

Repository files navigation

🧬 HybridMythos: Hierarchical Recurrent-Depth Transformer

HybridMythos is an open-source, theoretical architecture that merges the infinite-depth looping of OpenMythos with the brain-inspired, dual-timescale cognitive processing of the Hierarchical Reasoning Model (HRM).

By replacing the standard "flat" recurrent block of a looped transformer with a structured, nested-loop hierarchy, HybridMythos achieves the parameter efficiency of a continuous latent reasoner while maintaining the strategic oversight necessary for complex, multi-hop problem-solving.


πŸ’‘ The Core Concept

Modern large language models struggle with deep reasoning because they are fundamentally shallow, relying on explicit, token-heavy Chain-of-Thought (CoT) to solve algorithmic problems.

HybridMythos solves this by combining two breakthrough paradigms:

  1. OpenMythos (Recurrent-Depth Transformer): Reuses the same transformer weights in a continuous loop, allowing the model to "think" deeper in latent space without expanding its parameter count. It maintains stability across infinite loops using a strictly constrained LTI (Linear Time-Invariant) injection matrix.
  2. HRM (Hierarchical Reasoning Model): Abandons flat loops in favor of two coupled recurrent modules operating at different timescales: a high-level module for slow, abstract planning and a low-level module for rapid, detailed execution.

The Synergy: HybridMythos uses HRM's "Manager/Worker" architecture inside the OpenMythos stability loop. It gives the looped transformer a structured brain that can dynamically course-correct during long reasoning trajectories.


πŸ—οΈ Architecture

The model forward pass is divided into three functional stages:

1. Prelude [Run Once] Standard transformer layers process the input tokens to generate the initial hidden state ($h_0$) and the encoded injection vector ($e$).

2. Hierarchical Recurrent Block [The Core Engine] Instead of a single loop, the recurrent block maintains two hidden states ($h_{high}$ and $h_{low}$) through a nested execution:

  • The Outer Loop (H-Module / "Manager"): Updates occasionally to set the global reasoning direction and establish a new local equilibrium for the worker.
  • The Inner Loop (L-Module / "Worker"): Runs multiple fast iterations per H-Module step, rapidly refining details and grinding through computations.
  • Stability Injection: At every step, the original input $e$ is injected using learned matrices $A$ and $B$, where the spectral radius $\rho(A) < 1$. This prevents the deep recurrent loops from exploding into NaN values.

3. Coda [Run Once] Standard transformer layers process the final $h_{low}$ state into the output logits.


πŸš€ Quick Start & PyTorch Implementation

Installation

Clone the required base repositories (from our conversation history context):

git clone https://github.com/kyegomez/OpenMythos.git
git clone https://github.com/sapientinc/HRM.git

The PyTorch Wrapper

Here is the structural PyTorch pseudo-code to instantiate the hybrid model:

Please note: While the core architectural logic, loop structures, and training hyperparameters are drawn directly from the provided source texts, the standard PyTorch boilerplate (e.g., the standard training loop mechanics, optimizer setup, and dummy dataloaders) is filled in using external general programming knowledge to give you fully runnable scripts. You may want to independently verify and adapt the standard PyTorch boilerplate to your specific data environment.

1. hybrid_model.py (The Architecture)

This file creates the HybridMythos class, merging the OpenMythos Prelude/Coda with the HRM dual-timescale modules and the LTI stability injection.

import torch
import torch.nn as nn
from open_mythos.main import MythosConfig, OpenMythos
# Assuming you have cloned sapientinc/HRM and adapted their modules
from hrm_modules import HighLevelModule, LowLevelModule 

class HybridMythos(nn.Module):
    def __init__(self, cfg: MythosConfig):
        super().__init__()
        # 1. Prelude and Coda from OpenMythos (1-2 layers)
        mythos_base = OpenMythos(cfg)
        self.prelude = mythos_base.prelude
        self.coda = mythos_base.coda
        
        # 2. Dual-timescale reasoning modules from HRM
        self.h_module = HighLevelModule(cfg.dim)
        self.l_module = LowLevelModule(cfg.dim)
        
        # 3. LTI Stability matrices for input injection
        # A is constrained to a negative diagonal to ensure spectral radius < 1
        self.injection_A = nn.Parameter(torch.randn(cfg.dim)) 
        self.injection_B = nn.Parameter(torch.randn(cfg.dim, cfg.dim))

    def forward(self, input_ids, n_outer=4, n_inner=4):
        # Prelude: Generate initial hidden state and continuous injection vector e
        h_low, e = self.prelude(input_ids)
        h_high = h_low.clone()
        
        # Hierarchical Recurrent Block
        for t in range(n_outer):
            # Slow "CEO" Manager Update (runs once per outer loop)
            h_high = self.h_module(h_high, h_low, e)
            
            for s in range(n_inner):
                # Fast "Worker" Update (runs multiple times per outer loop)
                h_low = self.l_module(h_low, h_high, e)
                
                # Stability Injection (Ensuring Spectral Radius < 1)
                # h = A*h + B*e + Transformer(h,e)
                negative_A = -torch.exp(self.injection_A) # Enforce negative diagonal
                h_low = (negative_A * h_low) + (e @ self.injection_B) + h_low
                
        # Coda: Generate final output logits
        logits = self.coda(h_low)
        return logits

    def generate(self, input_ids, max_new_tokens=512, n_outer=8, n_inner=8):
        # Auto-regressive generation loop for inference
        generated_ids = input_ids
        for _ in range(max_new_tokens):
            logits = self.forward(generated_ids, n_outer=n_outer, n_inner=n_inner)
            next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
        return generated_ids

2. train.py (The Training Script)

This script is configured for your RTX Pro 6000 to pre-train the ~700M parameter model on a chunk of your 4B token dataset. It keeps the loop count small and fixed (max_loop=4) during training for stability.

import torch
import torch.optim as optim
import argparse
from hybrid_model import HybridMythos
from open_mythos.main import MythosConfig

def train():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--grad_accum", type=int, default=8)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--max_loop", type=int, default=4)  # Small fixed loops for training
    args = parser.parse_args()

    # 700M Parameter Configuration
    cfg = MythosConfig(
        vocab_size=128256,
        dim=2048,
        n_heads=32,
        n_layers_prelude=2,
        n_layers_coda=2,
        max_loop_iters=args.max_loop, 
        n_experts=8,
        n_shared_experts=2,
        n_experts_per_tok=2,
        expert_dim=512,
        attn_type="mla"
    )

    model = HybridMythos(cfg).cuda()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # --- External Data Loading Boilerplate ---
    # Replace this block with your actual Dataloader mapped to your 'data/pretrain/' path
    dummy_data = torch.randint(0, cfg.vocab_size, (args.batch_size, 512)).cuda()
    dummy_labels = torch.randint(0, cfg.vocab_size, (args.batch_size, 512)).cuda()
    # -----------------------------------------

    model.train()
    for epoch in range(args.epochs):
        optimizer.zero_grad()
        
        # Forward pass with fixed, small training loops
        logits = model(dummy_data, n_outer=args.max_loop, n_inner=args.max_loop)
        
        loss = loss_fn(logits.view(-1, cfg.vocab_size), dummy_labels.view(-1))
        loss.backward()
        
        # Step optimizer
        optimizer.step()
        print(f"Epoch {epoch} | Loss: {loss.item()}")

    torch.save(model.state_dict(), "checkpoint_700m.pt")
    print("Pre-training complete. Model saved.")

if __name__ == "__main__":
    train()

3. runner.py (The Inference Script)

This is where the magic happens at deployment. You load your trained 700M parameter model and crank the outer and inner loops to 8.

import torch
from hybrid_model import HybridMythos
from open_mythos.main import MythosConfig

def run_inference():
    # Must match the exact configuration from training
    cfg = MythosConfig(
        vocab_size=128256,
        dim=2048,
        n_heads=32,
        n_layers_prelude=2,
        n_layers_coda=2,
        n_experts=8,
        n_shared_experts=2,
        n_experts_per_tok=2,
        expert_dim=512,
        attn_type="mla"
    )

    # Initialize and load weights
    model = HybridMythos(cfg).cuda()
    try:
        model.load_state_dict(torch.load("checkpoint_700m.pt"))
        print("Loaded checkpoint_700m.pt successfully.")
    except FileNotFoundError:
        print("Checkpoint not found. Running with random weights for demonstration.")

    model.eval()

    # --- External Tokenizer Boilerplate ---
    # You would load your actual tokenizer here (e.g., Qwen/Qwen2.5-7B)
    prompt_ids = torch.randint(0, cfg.vocab_size, (1, 16)).cuda() 
    # --------------------------------------

    print("Generating response with deep latent reasoning...")
    with torch.no_grad():
        # Crank loops for god-tier reasoning (1.6B effective depth)
        output_ids = model.generate(
            prompt_ids, 
            max_new_tokens=128, 
            n_outer=8,  # Slow planning loops
            n_inner=8   # Fast detail loops
        )

    print("Generation complete.")
    print("Output shape:", output_ids.shape)

if __name__ == "__main__":
    run_inference()

🧠 Training & Scaling Strategy

As discussed in our configuration strategies:

  1. Efficient Pre-training: Train your ~700M parameter base model on your dataset with a fixed, small loop count (e.g., n_outer=4, n_inner=4).
  2. Instruction Tuning: Freeze the hierarchical recurrent block and only fine-tune the Prelude, Coda, and injection matrices to prevent catastrophic forgetting of reasoning pathways.
  3. Inference-Time Scaling: At inference time, dynamically increase the loops (e.g., n_outer=8, n_inner=8). The 700M parameter model will execute a deeper reasoning chain.

πŸ”— Alternative Open-Source Implementations

If you are looking for a pre-built datacenter-ready alternative that leverages similar mechanics, check out Hierarchos (necat101/Hierarchos). Hierarchos natively synergizes HRM-style reasoning with an RWKV v8 linear-attention backbone for $O(1)$ inference cost, and integrates Titans neural memory for lifelong context retention.

πŸ“œ Acknowledgements & References

  • OpenMythos: Recurrent-Depth Transformer architecture by Kye Gomez.
  • HRM: Hierarchical Reasoning Model by Sapient Intelligence.
  • Hierarchos: Hybrid HRM/Titans/RWKV framework by Makhi Burroughs.

About

HybridMythos: Hierarchical Recurrent-Depth Transformer HybridMythos is an open-source, theoretical architecture that merges the infinite-depth looping of OpenMythos with the brain-inspired, dual-timescale cognitive processing of the Hierarchical Reasoning Model (HRM).

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages