Hybrid LLM combining Multi-Head Latent Attention · Sparse MoE · Selective SSM · Titans Neural Memory — with built-in anti-hallucination and Triton/Numba acceleration.
VishwamAI is a decoder-only language model that combines four complementary subsystems in a single forward pass:
| Subsystem | Role | Efficiency gain |
|---|---|---|
| MLA (Multi-Head Latent Attention) | Token-to-token reasoning | 14× KV-cache reduction |
| Sparse MoE (SwiGLU experts) | Specialised FFN processing | Only 3% params fire per token |
| Selective SSM (Mamba-2 style) | Long-context O(n) sequence modelling | Linear vs O(n²) attention |
| Titans LTM (Neural Long-Term Memory) | Persistent memory beyond context window | Infinite effective context |
50% MoE · 43% SSM · 7% MLA
Attention appears only every 6th layer; SSM and MoE layers everywhere else keep the cost curve linear.
Each block follows Pre-LN:
x <- x + sublayer(RMSNorm(x))
At attention blocks, the Titans LTM is queried first (MAC — Memory As Context) and added residually before the attention sublayer.
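In code, the block pattern looks roughly like this (a minimal sketch with stand-in `sublayer` and `memory` objects; the real wiring lives in `VishwamAIBlock` in `model.py`):

```python
import torch
import torch.nn as nn

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # un-weighted RMSNorm, enough to illustrate the Pre-LN pattern
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

class PreLNBlockSketch(nn.Module):
    """Illustrative only: `sublayer` stands in for MLA / SSM / MoE,
    `memory` for the Titans LTM (present only on attention blocks)."""
    def __init__(self, sublayer: nn.Module, memory=None):
        super().__init__()
        self.sublayer = sublayer
        self.memory = memory

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = rms_norm(x)
        if self.memory is not None:        # MAC: memory readout added before attention
            h = h + self.memory.read(h)
        return x + self.sublayer(h)        # x <- x + sublayer(RMSNorm(x))
```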
Standard transformers cache full K and V tensors — expensive at long contexts. MLA caches a compressed latent c_KV of shape (B, T, kv_latent_dim) and a decoupled RoPE key k_rope of shape (B, T, H, rope_head_dim). K, V are decompressed at attention time via learned up-projections.
Cache: kv_latent_dim + H×rope_head_dim vs 2 × num_heads × head_dim → ~14× reduction
The RoPE component is cached at absolute positions, giving correct positional encoding during incremental decoding.
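A minimal sketch of the compress/decompress path, with made-up dimensions (the real module is `MultiHeadLatentAttention` in `attention.py`):

```python
import torch
import torch.nn as nn

B, T, D = 2, 16, 1024                               # hypothetical sizes
H, head_dim, kv_latent_dim = 8, 128, 256

down_kv = nn.Linear(D, kv_latent_dim, bias=False)   # compress hidden state -> latent
up_k    = nn.Linear(kv_latent_dim, H * head_dim, bias=False)
up_v    = nn.Linear(kv_latent_dim, H * head_dim, bias=False)

x = torch.randn(B, T, D)
c_kv = down_kv(x)                                   # (B, T, kv_latent_dim) -- this is what gets cached
# k_rope of shape (B, T, H, rope_head_dim) is cached alongside it at absolute positions

k = up_k(c_kv).view(B, T, H, head_dim)              # decompressed at attention time
v = up_v(c_kv).view(B, T, H, head_dim)
```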
Top-k routing with SwiGLU expert FFNs. The load-balance auxiliary loss uses the Switch Transformer formulation:
aux_loss = E × Σ_i (f_i × P_i)
where f_i is the hard dispatch fraction (non-differentiable) and P_i is the soft router probability (differentiable). This ensures gradients flow only through the soft path. Weight is applied once by VishwamAIConfig.moe_aux_loss_weight.
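A sketch of that loss for a top-k router (tensor names are illustrative; the real router is `TopKRouter` in `moe.py`):

```python
import torch
import torch.nn.functional as F

tokens, E, k = 1024, 32, 2                       # hypothetical sizes
router_logits = torch.randn(tokens, E)

P_soft = F.softmax(router_logits, dim=-1)        # P_i path: differentiable
top_idx = P_soft.topk(k, dim=-1).indices         # hard top-k dispatch: non-differentiable
dispatch = F.one_hot(top_idx, E).sum(dim=1).float()

f = dispatch.mean(dim=0)                         # f_i: fraction of tokens dispatched to expert i
P = P_soft.mean(dim=0)                           # P_i: mean router probability of expert i
aux_loss = E * torch.sum(f * P)                  # gradients reach the router only through P
```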
A, B, C, and Δ are all input-dependent (selective), giving the model fine-grained control over what to remember across a sequence. Sequential scan replaceable with a parallel associative scan kernel (Triton on GPU).
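The reference sequential form of the scan looks roughly like this (shapes are illustrative; the shipped `SelectiveSSMLayer` and the Triton kernel in `kernels/ssm_scan.py` are the real implementations):

```python
import torch

def selective_scan(x, A, B, C, delta):
    """x, delta: (T, D); A: (D, N); B, C: (T, N).
    B, C and delta vary per token; A enters through the delta-dependent discretisation."""
    T, D = x.shape
    N = A.shape[-1]
    h = torch.zeros(D, N)
    ys = []
    for t in range(T):
        dA = torch.exp(delta[t].unsqueeze(-1) * A)   # input-dependent discretised transition
        dB = delta[t].unsqueeze(-1) * B[t]
        h = dA * h + dB * x[t].unsqueeze(-1)         # selective state update
        ys.append(h @ C[t])                          # per-step readout
    return torch.stack(ys)                           # (T, D)
```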
An online-learned memory matrix M that persists across forward passes. During generation, a surprise-driven momentum SGD step writes new information to M:
```
error    = v - M @ k                  # prediction error
grad     = -outer(error, k)           # gradient of ½‖error‖² w.r.t. M
momentum = β·momentum + (1-β)·grad    # EMA of the gradient
M       -= lr · momentum              # net: M += lr·EMA(outer(v - M@k, k))
```
Read: q @ M.T — associative retrieval with no positional limit.
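A runnable toy version of the write/read cycle (dimensions and hyperparameters are made up; the real module is `TitansNLTM` in `memory.py`):

```python
import torch

d_k, d_v, lr, beta = 64, 64, 0.1, 0.9
M = torch.zeros(d_v, d_k)              # persistent memory matrix
momentum = torch.zeros(d_v, d_k)

def ltm_write(k: torch.Tensor, v: torch.Tensor) -> None:
    global M, momentum
    error = v - M @ k                                  # prediction error ("surprise")
    grad = -torch.outer(error, k)                      # gradient of ½‖error‖² w.r.t. M
    momentum = beta * momentum + (1 - beta) * grad     # EMA of the gradient
    M = M - lr * momentum

def ltm_read(q: torch.Tensor) -> torch.Tensor:
    return q @ M.T                                     # associative retrieval

k, v = torch.randn(d_k), torch.randn(d_v)
ltm_write(k, v)
print(torch.cosine_similarity(ltm_read(k), v, dim=0))  # high similarity: v is recalled from its key
```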
Three-layer human-in-the-loop correction system:
- ConfidenceHead — per-token confidence score in [0, 1]
- SelfCorrector — flags low-confidence spans, runs verifier, proposes retraction — never applied without human approval
- OnlineReinforcer — human-triggered only: writes correction to Titans LTM + contrastive gradient step
Retraction format:
I was wrong about: "...". The correct answer is: "...". [reason]
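The confidence head itself can be pictured as a small scoring head over the hidden states (a guess at the general shape, not the actual implementation in `anti_hallucination.py`):

```python
import torch
import torch.nn as nn

class ConfidenceHeadSketch(nn.Module):
    """Illustrative per-token confidence scorer in [0, 1]."""
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (B, T, hidden_dim) -> confidence: (B, T)
        return torch.sigmoid(self.proj(hidden_states)).squeeze(-1)
```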
```bash
git clone https://github.com/kasinadhsarma/VishwamAI.git
cd VishwamAI
pip install -r requirements.txt
```

GPU users (recommended):

```bash
pip install "triton>=2.3.0"   # Triton GPU kernels (optional, auto-detected)
pip install "numba>=0.59.0"   # Numba CPU JIT (optional, auto-detected)
```

```python
import torch
from vishwamai import VishwamAI, small_config
from vishwamai.tokenizer import ByteTokenizer
cfg = small_config()
model = VishwamAI(cfg).eval()
tok = ByteTokenizer()
ids = torch.tensor([tok.encode("The meaning of life is")])
out_ids = model.generate(ids, max_new_tokens=128, temperature=0.8, top_p=0.9)
print(tok.decode(out_ids[0].tolist()))
```

```python
from vishwamai import (
    AntiHallucinationVishwamAI, ConfidenceHead,
    VishwamAI, small_config,
)
from vishwamai.tokenizer import ByteTokenizer
cfg = small_config()
model = VishwamAI(cfg).eval()
confidence = ConfidenceHead(cfg)
tok = ByteTokenizer()
ah = AntiHallucinationVishwamAI(model, confidence, tok)
result = ah.ask("What is the capital of Australia?", max_new_tokens=64)
print(result["answer"]) # final answer
print(result["confidence"]) # mean token confidence in [0, 1]
print(result["needs_review"]) # True when a correction was flagged
# Human-triggered correction (updates LTM + model weights):
if result["needs_review"]:
ah.correct(
prompt = "What is the capital of Australia?",
wrong_answer = result["raw_answer"],
correction = "Canberra",
)torchrun --nproc_per_node=8 scripts/pretrain.py \
--config small \
--data_path /data/tokens.bin \
--batch_size 32 \
--max_steps 100000 \
--lr 3e-4 \
--output_dir checkpoints/pretrain
```

```python
from vishwamai import VishwamAI, medium_config, GRPOConfig, GRPOTrainer, MathVerifier
from vishwamai.tokenizer import SentencePieceTokenizer
cfg = medium_config()
model = VishwamAI(cfg)
ref_model = VishwamAI(cfg) # frozen reference copy
tok = SentencePieceTokenizer("tokenizer.model")
verifier = MathVerifier()
trainer = GRPOTrainer(
    model, ref_model, tok, verifier,
    GRPOConfig(group_size=8, max_new_tokens=512, lr=1e-5),
)

for prompts, answers in dataloader:
    stats = trainer.step(prompts, answers)
    print(stats)
```

GRPO generates `group_size` responses per prompt, scores them with a verifier (`CodeVerifier` / `MathVerifier` / `FormatVerifier`), and computes group-relative advantage:
A_i = (r_i - mean(r)) / (std(r) + eps)
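For one group of sampled responses, that is just (verifier rewards here are made up):

```python
import torch

rewards = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0])  # one group of group_size=8
eps = 1e-6
advantages = (rewards - rewards.mean()) / (rewards.std() + eps)
print(advantages)   # positive for above-average responses, negative for the rest
```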
```bash
python scripts/train_tokenizer.py \
--corpus /data/text.txt \
--vocab_size 32000 \
--output tokenizer.model
```

VishwamAI saves weights in SafeTensors format — no pickle, memory-mapped, cross-framework safe.

```python
from vishwamai import save_model, load_model
from vishwamai.checkpoint import export_embeddings
# Save
save_model(model, "checkpoints/step_10000", step=10000)
# Load
model = load_model("checkpoints/step_10000", device="cuda")
# Export embeddings for downstream use (RAG / vector search)
export_embeddings(model, "embeddings/", fmt="safetensors")
# -> embeddings/embeddings.safetensors
```

VishwamAI exposes all inference and memory operations via the Model Context Protocol, so any MCP-compatible host (Claude Desktop, custom agents) can use it as a first-class AI provider.

```bash
# stdio transport (Claude Desktop)
python -m vishwamai.mcp_server --checkpoint checkpoints/small
# HTTP/SSE transport
python -m vishwamai.mcp_server --checkpoint checkpoints/medium \
    --transport http --host 0.0.0.0 --port 8765
```

Available tools: `vishwamai_generate`, `vishwamai_generate_verified`, `vishwamai_load_model`, `vishwamai_save_model`, `vishwamai_ltm_read`, `vishwamai_ltm_reset`, `vishwamai_tokenize`, `vishwamai_detokenize`, `vishwamai_correct`.
VishwamAI automatically selects the fastest available backend for each operation:
| Operation | Triton (CUDA) | PyTorch fallback | Numba (CPU) |
|---|---|---|---|
| RMSNorm | FusedRMSNorm kernel | Manual `rsqrt` | — |
| SwiGLU expert | Fused `silu(gate)*up` | `F.silu` inline | — |
| Flash Attention | Triton kernel | `F.scaled_dot_product_attention` | — |
| SSM scan | Parallel assoc. scan | Log-sum-exp PyTorch | Python loop |
| MoE dispatch | Triton scatter-add | Python `for e` loop | — |
| BPE tokenise | — | — | `@njit` merge loop |
No Triton? No problem — every kernel degrades gracefully to pure PyTorch.
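The detection pattern is the usual optional-import dance (function and flag names below are illustrative, not the exact ones in `vishwamai/kernels/`):

```python
import torch

try:
    import triton  # noqa: F401
    HAS_TRITON = torch.cuda.is_available()
except ImportError:
    HAS_TRITON = False

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # pure-PyTorch fallback (manual rsqrt); when HAS_TRITON, the fused kernel is used instead
    return weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
```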
| Config | Layers | Hidden | Heads | Experts | Dtype / backend |
|---|---|---|---|---|---|
| `small_config()` | 24 | 1024 | 8 | 32 (top-2) | bfloat16 |
| `medium_config()` | 32 | 2048 | 16 | 64 (top-4) | bfloat16 |
| `large_config()` | 48 | 4096 | 32 | 256 (top-8) | bfloat16 |
| `tpu_v3_config()` | 24 | 1536 | 12 | 32 (top-2) | bfloat16 / JAX |
| `tpu_v4_config()` | 32 | 4096 | 32 | 128 (top-4) | bfloat16 / JAX |
| `xpu_config()` | 32 | 2048 | 16 | 64 (top-4) | bfloat16 / IPEX |
```python
from vishwamai import small_config, VishwamAI
cfg = small_config()
print(cfg.layer_type(5)) # 'A' — attention layer (every 6th)
print(cfg.layer_type(6)) # 'S' — SSM layer
print(cfg.layer_type(7))  # 'M' — MoE layer
```

```
VishwamAI/
├── vishwamai/
│ ├── config.py # VishwamAIConfig + layer schedule + preset configs
│ ├── model.py # VishwamAI, VishwamAIBlock, KVCache
│ ├── attention.py # MultiHeadLatentAttention + decoupled RoPE
│ ├── moe.py # SparseMoELayer, TopKRouter, Expert
│ ├── ssm.py # SelectiveSSMLayer (Mamba-2 style)
│ ├── memory.py # TitansNLTM (neural long-term memory)
│ ├── tokenizer.py # SentencePiece / HF / Byte tokenizer backends
│ ├── checkpoint.py # SafeTensors save/load, embedding export
│ ├── training.py # PretrainTrainer + GRPOTrainer + Verifiers
│ ├── anti_hallucination.py # ConfidenceHead + SelfCorrector + OnlineReinforcer
│ ├── mcp_server.py # MCP server (stdio + HTTP/SSE)
│ ├── jax_backend.py # Unified hardware backend switcher
│ └── kernels/
│ ├── rms_norm.py # Triton fused RMSNorm
│ ├── swiglu.py # Triton fused SwiGLU expert
│ ├── flash_attention.py # Triton flash attention
│ ├── ssm_scan.py # Triton/PyTorch/Python SSM scan
│ ├── moe_router.py # Triton MoE scatter-add dispatch
│ ├── tpu_ops.py # JAX/XLA TPU kernels
│ └── cpu_ops.py # Numba BPE + corpus tokenisation
├── scripts/
│ ├── pretrain.py # DDP pretraining (torchrun)
│ └── train_tokenizer.py # HF BPE / SentencePiece tokenizer training
├── tests/
│ └── test_architecture.py # Smoke tests (forward, generate, param count)
├── demo/
│ └── demo_anti_hallucination.py
├── docs/
│ └── vishwamai.tex # Architecture specification (LaTeX)
├── requirements.txt
├── AGENTS.md # AI agent / contributor instructions
└── README.md
```
| Tier | Hardware | Notes |
|---|---|---|
| Development | RTX 3080 (10 GB) | small_config(), batch 4, bf16 |
| Training | A100 40 GB × 4 | medium_config(), batch 32, bf16 |
| Full scale | H100 × 8+ | large_config(), torchrun, bf16 |
| TPU | v3-8 / v4-32 | tpu_v3_config() / tpu_v4_config() |
| CPU-only | 32 GB RAM | Numba kernels, no Triton, slow |
- Fork the repository
- Read `AGENTS.md` — it explains code conventions and PR guidelines
- Create a feature branch
- Add tests (`pytest tests/`)
- Ensure all tests pass: `pytest -x`
- Submit a pull request
```bibtex
@software{vishwamai2025,
title = {VishwamAI: Hybrid LLM with MLA, Sparse MoE, Selective SSM, and Titans Memory},
author = {Kasinadh Sarma},
year = {2025},
url = {https://github.com/kasinadhsarma/VishwamAI}
}
```

MIT License — see LICENSE for details.
VishwamAI — efficient, honest, and self-correcting language intelligence.