
VishwamAI

Hybrid LLM combining Multi-Head Latent Attention · Sparse MoE · Selective SSM · Titans Neural Memory — with built-in anti-hallucination and Triton/Numba acceleration.

Python 3.10+ · PyTorch 2.3+ · MIT License


Architecture at a Glance

VishwamAI is a decoder-only language model that combines four complementary subsystems in a single forward pass:

Subsystem                              Role                                    Efficiency gain
MLA (Multi-Head Latent Attention)      Token-to-token reasoning                14× KV-cache reduction
Sparse MoE (SwiGLU experts)            Specialised FFN processing              Only 3% of params fire per token
Selective SSM (Mamba-2 style)          Long-context O(n) sequence modelling    Linear vs O(n²) attention
Titans LTM (Neural Long-Term Memory)   Persistent memory beyond context window Infinite effective context

Layer schedule

50% MoE · 43% SSM · 7% MLA
Attention appears only on every 6th layer; SSM everywhere else keeps the cost curve linear.

Each block follows Pre-LN:

x <- x + sublayer(RMSNorm(x))

At attention blocks, the Titans LTM is queried first (MAC — Memory As Context) and added residually before the attention sublayer.
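
A minimal sketch of that block pattern in PyTorch (the class and wiring are hypothetical simplifications, not the real model.py API):

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class PreLNBlock(nn.Module):
    """Pre-LN residual block. `memory` is set only on attention blocks (MAC)."""
    def __init__(self, dim: int, sublayer: nn.Module, memory: nn.Module | None = None):
        super().__init__()
        self.norm = RMSNorm(dim)
        self.sublayer = sublayer
        self.memory = memory

    def forward(self, x):
        if self.memory is not None:            # Memory As Context: read the LTM first,
            x = x + self.memory(self.norm(x))  # add residually before the sublayer
        return x + self.sublayer(self.norm(x))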


Key Innovations

Multi-Head Latent Attention (MLA)

Standard transformers cache full K and V tensors — expensive at long contexts. MLA caches a compressed latent c_KV of shape (B, T, kv_latent_dim) and a decoupled RoPE key k_rope of shape (B, T, H, rope_head_dim). K, V are decompressed at attention time via learned up-projections.

Cache per token: kv_latent_dim + H × rope_head_dim   vs   2 × num_heads × head_dim   → ~14× reduction

The RoPE component is cached at absolute positions, giving correct positional encoding during incremental decoding.
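
Plugging illustrative numbers into that formula shows where a ~14× figure comes from (the sizes below are hypothetical, not the shipped config):

# Illustrative sizes only; hypothetical, not VishwamAI's actual config values.
num_heads, head_dim = 16, 128
kv_latent_dim       = 160    # shared compressed latent, cached once per token
rope_head_dim       = 8      # decoupled RoPE key, cached per head

standard_cache = 2 * num_heads * head_dim                   # K + V: 4096 values/token
mla_cache      = kv_latent_dim + num_heads * rope_head_dim  # 288 values/token
print(standard_cache / mla_cache)                           # ~14.2x with these numbers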

Sparse Mixture-of-Experts

Top-k routing with SwiGLU expert FFNs. The load-balance auxiliary loss uses the Switch Transformer formulation:

aux_loss = E × Σ_i (f_i × P_i)

where f_i is the hard dispatch fraction (non-differentiable) and P_i is the soft router probability (differentiable). This ensures gradients flow only through the soft path. Weight is applied once by VishwamAIConfig.moe_aux_loss_weight.
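
A minimal sketch of that loss for top-1 routing (the helper below is hypothetical; it assumes router logits of shape (tokens, experts)):

import torch
import torch.nn.functional as F

def switch_aux_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """Load-balance loss E * sum_i f_i * P_i (top-1 case for brevity)."""
    num_experts = router_logits.shape[-1]
    probs = F.softmax(router_logits, dim=-1)              # (T, E)
    P = probs.mean(dim=0)                                 # soft prob per expert: differentiable
    top1 = probs.argmax(dim=-1)                           # hard assignment: integer, no gradient
    f = F.one_hot(top1, num_experts).float().mean(dim=0)  # dispatch fraction per expert
    return num_experts * torch.sum(f * P)                 # scaled by moe_aux_loss_weight outside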

Selective SSM (Mamba-2 style)

A, B, C, and Δ are all input-dependent (selective), giving the model fine-grained control over what to remember across a sequence. The sequential scan is replaceable with a parallel associative scan kernel (Triton on GPU).
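
For reference, a minimal sequential scan with a diagonal per-channel state (an illustrative sketch, not the ssm.py implementation):

import torch

def selective_scan(u, delta, A, B, C):
    """Reference selective scan for one sequence.
    u: (T, D) inputs; delta: (T, D) input-dependent step sizes; A: (D, N)
    per-channel diagonal state matrix; B, C: (T, N) input-dependent projections."""
    T, D = u.shape
    h = torch.zeros(D, A.shape[1])             # hidden state per channel
    ys = []
    for t in range(T):
        dA = torch.exp(delta[t, :, None] * A)  # (D, N) discretised decay
        dB = delta[t, :, None] * B[t]          # (D, N) discretised input gate
        h = dA * h + dB * u[t, :, None]        # selective state update
        ys.append(h @ C[t])                    # (D,) readout
    return torch.stack(ys)                     # (T, D)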

Titans Neural Long-Term Memory

An online-learned memory matrix M that persists across forward passes. During generation, a surprise-driven momentum SGD step writes new information to M:

error    = v - M @ k                 # prediction error
grad     = -outer(error, k)          # gradient of ½‖error‖² w.r.t. M
momentum = β·momentum + (1-β)·grad   # EMA of the gradient
M       -= lr · momentum             # net: M += lr·EMA(outer(v - M@k, k))

Read: q @ M.T — associative retrieval with no positional limit.
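
The same rule as runnable PyTorch (a minimal sketch; the real TitansNLTM wraps this with the surprise-driven gating described above):

import torch

class TinyTitansMemory:
    """Minimal associative memory with the momentum-SGD write rule above."""
    def __init__(self, d_k: int, d_v: int, lr: float = 0.1, beta: float = 0.9):
        self.M = torch.zeros(d_v, d_k)            # memory matrix
        self.momentum = torch.zeros_like(self.M)
        self.lr, self.beta = lr, beta

    def write(self, k: torch.Tensor, v: torch.Tensor) -> None:
        error = v - self.M @ k                    # prediction error, shape (d_v,)
        grad = -torch.outer(error, k)             # gradient of 0.5*||error||**2 w.r.t. M
        self.momentum = self.beta * self.momentum + (1 - self.beta) * grad
        self.M -= self.lr * self.momentum         # net: M += lr * EMA(outer(error, k))

    def read(self, q: torch.Tensor) -> torch.Tensor:
        return self.M @ q                         # associative retrieval, shape (d_v,)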

Anti-Hallucination System

Three-layer human-in-the-loop correction system:

  1. ConfidenceHead — per-token confidence score in [0, 1]
  2. SelfCorrector — flags low-confidence spans, runs verifier, proposes retraction — never applied without human approval
  3. OnlineReinforcer — human-triggered only: writes correction to Titans LTM + contrastive gradient step

Retraction format:

I was wrong about: "...". The correct answer is: "...". [reason]

Quick Start

Installation

git clone https://github.com/kasinadhsarma/VishwamAI.git
cd VishwamAI
pip install -r requirements.txt

Optional accelerators (recommended; auto-detected at runtime):

pip install triton>=2.3.0     # Triton GPU kernels (optional, auto-detected)
pip install numba>=0.59.0     # Numba CPU JIT (optional, auto-detected)

Inference

import torch
from vishwamai import VishwamAI, small_config
from vishwamai.tokenizer import ByteTokenizer

cfg   = small_config()
model = VishwamAI(cfg).eval()
tok   = ByteTokenizer()

ids     = torch.tensor([tok.encode("The meaning of life is")])
out_ids = model.generate(ids, max_new_tokens=128, temperature=0.8, top_p=0.9)
print(tok.decode(out_ids[0].tolist()))

Anti-Hallucination Inference

from vishwamai import (
    AntiHallucinationVishwamAI, ConfidenceHead,
    VishwamAI, small_config,
)
from vishwamai.tokenizer import ByteTokenizer

cfg        = small_config()
model      = VishwamAI(cfg).eval()
confidence = ConfidenceHead(cfg)
tok        = ByteTokenizer()

ah = AntiHallucinationVishwamAI(model, confidence, tok)
result = ah.ask("What is the capital of Australia?", max_new_tokens=64)

print(result["answer"])       # final answer
print(result["confidence"])   # mean token confidence in [0, 1]
print(result["needs_review"]) # True when a correction was flagged

# Human-triggered correction (updates LTM + model weights):
if result["needs_review"]:
    ah.correct(
        prompt       = "What is the capital of Australia?",
        wrong_answer = result["raw_answer"],
        correction   = "Canberra",
    )

Training

Pretraining (DDP, multi-GPU)

torchrun --nproc_per_node=8 scripts/pretrain.py \
    --config small \
    --data_path /data/tokens.bin \
    --batch_size 32 \
    --max_steps 100000 \
    --lr 3e-4 \
    --output_dir checkpoints/pretrain

GRPO / RLVR Fine-tuning (no human labels)

from vishwamai import VishwamAI, medium_config, GRPOConfig, GRPOTrainer, MathVerifier
from vishwamai.tokenizer import SentencePieceTokenizer

cfg       = medium_config()
model     = VishwamAI(cfg)
ref_model = VishwamAI(cfg)   # frozen reference copy
tok       = SentencePieceTokenizer("tokenizer.model")
verifier  = MathVerifier()

trainer = GRPOTrainer(
    model, ref_model, tok, verifier,
    GRPOConfig(group_size=8, max_new_tokens=512, lr=1e-5),
)
for prompts, answers in dataloader:
    stats = trainer.step(prompts, answers)
    print(stats)

GRPO generates group_size responses per prompt, scores them with a verifier (CodeVerifier / MathVerifier / FormatVerifier), and computes group-relative advantage:

A_i = (r_i - mean(r)) / (std(r) + eps)
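
In code, the per-group advantage is a one-liner (sketch):

import torch

def group_relative_advantage(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """rewards: (group_size,) verifier scores for one prompt's sampled responses."""
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# e.g. group_size=8, three responses pass the verifier:
print(group_relative_advantage(torch.tensor([1., 0., 0., 1., 0., 0., 1., 0.])))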

Train a tokenizer

python scripts/train_tokenizer.py \
    --corpus /data/text.txt \
    --vocab_size 32000 \
    --output tokenizer.model

Checkpointing with SafeTensors

VishwamAI saves weights in SafeTensors format — no pickle, memory-mapped, cross-framework safe.

from vishwamai import save_model, load_model
from vishwamai.checkpoint import export_embeddings

# Save
save_model(model, "checkpoints/step_10000", step=10000)

# Load
model = load_model("checkpoints/step_10000", device="cuda")

# Export embeddings for downstream use (RAG / vector search)
export_embeddings(model, "embeddings/", fmt="safetensors")
# -> embeddings/embeddings.safetensors

MCP Server

VishwamAI exposes all inference and memory operations via the Model Context Protocol so any MCP-compatible host (Claude Desktop, custom agents) can use it as a first-class AI provider.

# stdio transport (Claude Desktop)
python -m vishwamai.mcp_server --checkpoint checkpoints/small

# HTTP/SSE transport
python -m vishwamai.mcp_server --checkpoint checkpoints/medium \
    --transport http --host 0.0.0.0 --port 8765

Available tools: vishwamai_generate, vishwamai_generate_verified, vishwamai_load_model, vishwamai_save_model, vishwamai_ltm_read, vishwamai_ltm_reset, vishwamai_tokenize, vishwamai_detokenize, vishwamai_correct.


Kernel Acceleration

VishwamAI automatically selects the fastest available backend for each operation:

Operation         Triton (CUDA)            PyTorch fallback                  Numba (CPU)
RMSNorm           FusedRMSNorm kernel      Manual rsqrt                      n/a
SwiGLU expert     Fused silu(gate)*up      F.silu inline                     n/a
Flash Attention   Triton kernel            F.scaled_dot_product_attention    n/a
SSM scan          Parallel assoc. scan     Log-sum-exp scan                  Python loop
MoE dispatch      Triton scatter-add       Python loop over experts          n/a
BPE tokenise      n/a                      n/a                               @njit merge loop

No Triton? No problem — every kernel degrades gracefully to pure PyTorch.
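
The selection follows the usual optional-import pattern; the sketch below is illustrative, and fused_rms_norm is a hypothetical kernel name:

import torch

try:
    import triton  # noqa: F401 (presence check only)
    HAS_TRITON = torch.cuda.is_available()
except ImportError:
    HAS_TRITON = False

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Dispatch to the fused kernel when available, else pure PyTorch."""
    if HAS_TRITON:
        from vishwamai.kernels.rms_norm import fused_rms_norm  # hypothetical name
        return fused_rms_norm(x, weight, eps)
    # Pure-PyTorch fallback: always available, numerically equivalent.
    return weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)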


Model Configurations

Config            Layers   Hidden   Heads   Experts       Dtype / backend
small_config()    24       1024     8       32 (top-2)    bfloat16
medium_config()   32       2048     16      64 (top-4)    bfloat16
large_config()    48       4096     32      256 (top-8)   bfloat16
tpu_v3_config()   24       1536     12      32 (top-2)    bfloat16 / JAX
tpu_v4_config()   32       4096     32      128 (top-4)   bfloat16 / JAX
xpu_config()      32       2048     16      64 (top-4)    bfloat16 / IPEX

from vishwamai import small_config, VishwamAI

cfg = small_config()
print(cfg.layer_type(5))   # 'A' — attention layer (every 6th)
print(cfg.layer_type(6))   # 'S' — SSM layer
print(cfg.layer_type(7))   # 'M' — MoE layer

Project Structure

VishwamAI/
├── vishwamai/
│   ├── config.py              # VishwamAIConfig + layer schedule + preset configs
│   ├── model.py               # VishwamAI, VishwamAIBlock, KVCache
│   ├── attention.py           # MultiHeadLatentAttention + decoupled RoPE
│   ├── moe.py                 # SparseMoELayer, TopKRouter, Expert
│   ├── ssm.py                 # SelectiveSSMLayer (Mamba-2 style)
│   ├── memory.py              # TitansNLTM (neural long-term memory)
│   ├── tokenizer.py           # SentencePiece / HF / Byte tokenizer backends
│   ├── checkpoint.py          # SafeTensors save/load, embedding export
│   ├── training.py            # PretrainTrainer + GRPOTrainer + Verifiers
│   ├── anti_hallucination.py  # ConfidenceHead + SelfCorrector + OnlineReinforcer
│   ├── mcp_server.py          # MCP server (stdio + HTTP/SSE)
│   ├── jax_backend.py         # Unified hardware backend switcher
│   └── kernels/
│       ├── rms_norm.py        # Triton fused RMSNorm
│       ├── swiglu.py          # Triton fused SwiGLU expert
│       ├── flash_attention.py # Triton flash attention
│       ├── ssm_scan.py        # Triton/PyTorch/Python SSM scan
│       ├── moe_router.py      # Triton MoE scatter-add dispatch
│       ├── tpu_ops.py         # JAX/XLA TPU kernels
│       └── cpu_ops.py         # Numba BPE + corpus tokenisation
├── scripts/
│   ├── pretrain.py            # DDP pretraining (torchrun)
│   └── train_tokenizer.py     # HF BPE / SentencePiece tokenizer training
├── tests/
│   └── test_architecture.py   # Smoke tests (forward, generate, param count)
├── demo/
│   └── demo_anti_hallucination.py
├── docs/
│   └── vishwamai.tex          # Architecture specification (LaTeX)
├── requirements.txt
├── AGENTS.md                  # AI agent / contributor instructions
└── README.md

Hardware Requirements

Tier          Hardware            Notes
Development   RTX 3080 (10 GB)    small_config(), batch 4, bf16
Training      A100 40 GB × 4      medium_config(), batch 32, bf16
Full scale    H100 × 8+           large_config(), torchrun, bf16
TPU           v3-8 / v4-32        tpu_v3_config() / tpu_v4_config()
CPU-only      32 GB RAM           Numba kernels, no Triton, slow

Contributing

  1. Fork the repository
  2. Read AGENTS.md — it explains code conventions and PR guidelines
  3. Create a feature branch
  4. Add tests (pytest tests/)
  5. Ensure all tests pass: pytest -x
  6. Submit a pull request

Citation

@software{vishwamai2025,
  title   = {VishwamAI: Hybrid LLM with MLA, Sparse MoE, Selective SSM, and Titans Memory},
  author  = {Kasinadh Sarma},
  year    = {2025},
  url     = {https://github.com/kasinadhsarma/VishwamAI}
}

License

MIT License — see LICENSE for details.


VishwamAI — efficient, honest, and self-correcting language intelligence.
