feat(bitnet): native-ternary 1.58 inference end-to-end (microsoft/bitnet-b1.58-2B-4T) #159

Open
gburd wants to merge 2 commits into chrishayuk:main from gburd:pr/bitnet-inference

@gburd gburd commented Jun 17, 2026

Contributor

Builds on the I2_S quantisation already merged (#148). Adds CPU inference directly against the native I2_S ternary weights — no dequant-to-f16 at convert time — verified end-to-end on the real microsoft/bitnet-b1.58-2B-4T GGUF.

Result

| prompt | output |
| --- | --- |
| "The capital of France is" | Paris 94.5% |
| "Two plus two equals" | four 30.7% |

Pipeline

larql convert gguf-to-vindex --keep-quant [--dense-only] [--f16] -o <vindex> ggml-model-i2_s.gguf
larql-server <vindex>                 # eager-loads the BitNet model
POST /v1/infer {"prompt":..., "top":5, "mode":"dense"}

What's here (engine / codec / vindex layer)

  • larql-compute: ternary {-1,0,+1}×f32 matvec kernel (BitLinearWeight: packed I2_S + per-tensor scale, integer trit accumulation).
  • larql-models: load_gguf_keep_quant retains raw I2_S bytes + the per-tensor trailing scale; the strided I2_S decode fix (also filed standalone as the decode-fix PR, included here so this is self-contained).
  • larql-vindex: bitnet_writer/bitnet_loader (bitnet/ artifacts + bitnet_layout in index.json); build_vindex_dense_only (skips gate-vectors + HNSW clustering + dense-projection duplication: ~3 GB on-disk, ~1.5 min build vs ~30 min); WriteWeightsOptions.skip_attn/skip_ffn with norm-writing decoupled from the attention-projection gate (BitNet needs attn_sub_norm/ffn_sub_norm even when projections are skipped). VindexConfig.bitnet_layout is Option/skip-if-None so dense vindexes round-trip unchanged.
  • larql-inference: BitnetModel + predict_bitnet (single-shot prefill, GQA + RoPE), KV cache + decode_step + generate, sampling (temp/top-k/top-p), load_bitnet_model.
  • larql-cli: --keep-quant / --dense-only flags.
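
For readers unfamiliar with the idea behind the larql-compute item, here is a minimal sketch of a ternary matvec with a per-tensor scale. It is not the PR's code: `TernaryMatrix`, `from_trits`, and the 2-bit encoding (four trits per byte, lowest bits first, codes 0/1/2 for -1/0/+1) are assumptions made for illustration and are not claimed to match the real I2_S layout or `BitLinearWeight`.

```rust
/// Hypothetical stand-in for a packed ternary weight matrix.
/// Assumed encoding (NOT the real I2_S layout): 4 trits per byte,
/// 2 bits each, lowest bits first; code 0 => -1, 1 => 0, 2 => +1.
pub struct TernaryMatrix {
    pub rows: usize,
    pub cols: usize,
    pub packed: Vec<u8>, // row-major, each row padded to a whole byte
    pub scale: f32,      // one scale for the whole tensor
}

impl TernaryMatrix {
    /// Pack a row-major slice of trits (-1, 0, +1) into 2-bit codes.
    pub fn from_trits(rows: usize, cols: usize, trits: &[i8], scale: f32) -> Self {
        assert_eq!(trits.len(), rows * cols);
        let row_bytes = (cols + 3) / 4;
        let mut packed = vec![0u8; rows * row_bytes];
        for r in 0..rows {
            for c in 0..cols {
                let t = trits[r * cols + c];
                assert!((-1..=1).contains(&t), "not a trit: {}", t);
                let code = (t + 1) as u8; // -1, 0, +1 -> 0, 1, 2
                packed[r * row_bytes + c / 4] |= code << ((c % 4) * 2);
            }
        }
        Self { rows, cols, packed, scale }
    }

    /// y = scale * (W x). Each weight selects add, subtract, or skip,
    /// and the scale is applied once per output row.
    pub fn matvec(&self, x: &[f32]) -> Vec<f32> {
        assert_eq!(x.len(), self.cols);
        let row_bytes = (self.cols + 3) / 4;
        let mut y = Vec::with_capacity(self.rows);
        for r in 0..self.rows {
            let row = &self.packed[r * row_bytes..(r + 1) * row_bytes];
            let mut acc = 0.0f32;
            for (c, &xc) in x.iter().enumerate() {
                match (row[c / 4] >> ((c % 4) * 2)) & 0b11 {
                    0 => acc -= xc, // weight -1
                    2 => acc += xc, // weight +1
                    _ => {}         // weight 0 contributes nothing
                }
            }
            y.push(acc * self.scale);
        }
        y
    }
}
```

The accumulation here is still in f32; it only avoids the per-weight multiply by branching on the 2-bit code.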

cargo test -p larql-inference -p larql-vindex -p larql-models --lib: 2668 passed.

Notes

  • The /v1/infer dispatch, SSE streaming, and walk-mode live in larql-server and ride on the server-reliability PR (concurrent-load single-flight, timeout); happy to send that server wiring as a follow-up once the reliability PR and this one land.
  • Memory caveat for transparency: server RSS is ~4.8 GB on this 2B model — the f16 embeddings + lm_head expand to f32 in RAM and are held as two copies; tying them (BitNet ties lm_head==embed) + f16-resident is a known follow-up toward ~1.4 GB. The dense-only build removes the ~8 GB of walk-mode/projection overhead.

Re: the closed cloud-proxy PR (#149) — fully agree it's out of scope; this is purely local-weights inference.

@chrishayuk

Owner

This is impressive — native I2_S inference running end-to-end on the real 2B model is a great result, and the pieces I'd call lowest-risk are genuinely solid: the ternary_matvec kernel is clean and well-tested against a dequant reference, and the vindex-format work (bitnet_layout as Option + skip_serializing_if, with a round-trip test proving existing vindexes are unchanged) is properly non-breaking. The --dense-only build path is a nice win too.

The thing I'd want to resolve before merge is architecture, in ternary.rs:

It's a ~2359-line parallel inference stack — it reimplements GQA attention, RMSNorm (a straight dup of larql_compute::residual::rms_norm), the KV cache, prefill/decode_step/generate, and sampling/argmax, all of which already exist in attention/, kv_dispatch, forward/, and the StatePolicy engine framework. It does correctly reuse attention::rope, which shows the seams are reachable. BitNet's real differences — the attention/FFN sub-norms and the squared-ReLU FFN — are legitimate, but they don't require forking the KV-cache, sampling, or argmax. Could you either reuse residual::rms_norm + attention::{gqa,decode} and ride the existing decode path, or write down explicitly why BitNet can't and what minimal divergence it actually needs? Right now it sits beside the engine machinery rather than in it, which is a maintenance cost as both evolve.
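
To make the "squared-ReLU FFN with a sub-norm" difference concrete, here is a rough sketch of such a block. It is not taken from ternary.rs: the gated form (`relu²(gate·x) * (up·x)`), the placement of the sub-norm just before the down projection, and the function names are assumptions based on the publicly described BitNet b1.58 architecture, and plain dense `Vec<Vec<f32>>` matrices stand in for the ternary projections.

```rust
/// Squared ReLU: max(v, 0)^2.
pub fn relu_squared(v: f32) -> f32 {
    let r = v.max(0.0);
    r * r
}

/// Dense row-major matvec, standing in for a ternary projection.
fn matvec(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// RMSNorm: x / sqrt(mean(x^2) + eps) * weight.
fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}

/// Assumed FFN shape: down( sub_norm( relu^2(gate x) * (up x) ) ).
pub fn relu2_ffn(
    x: &[f32],
    gate: &[Vec<f32>],
    up: &[Vec<f32>],
    down: &[Vec<f32>],
    sub_norm_weight: &[f32],
    eps: f32,
) -> Vec<f32> {
    let g = matvec(gate, x);
    let u = matvec(up, x);
    // Elementwise gate: squared ReLU on the gate branch times the up branch.
    let h: Vec<f32> = g
        .iter()
        .zip(&u)
        .map(|(&gv, &uv)| relu_squared(gv) * uv)
        .collect();
    // Mid-FFN sub-norm before projecting back down.
    let h = rms_norm(&h, sub_norm_weight, eps);
    matvec(down, &h)
}
```

Nothing in this sketch touches the KV cache, sampling, or argmax, which is the reviewer's point: the divergence is confined to the norm placement and the activation.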

A few more, smaller:

  • Re-validate the bundled tq.rs decode change against the existing I2_S consumers from #148 (feat/vindex build by ternary quantisation for BitNet 1.58), since it changes the shared decoder semantics. Ideally land the decode fix (#156, fix(ggml): correct I2_S decode to microsoft strided block layout) standalone first so it bakes before this depends on it.
  • Document the dual I2_S layout (strided in tq.rs for GGUF decode, contiguous in the kernel/writer) at the format-spec level — it's intentional but a latent footgun.
  • The kernel is a correct reference, not the fast path yet (the module header's "integer accumulation / no multiplies" is aspirational — it's still scalar f32 × ±1.0). The optimized version (int8-quantized activations + a NEON/AVX2 sign-select kernel, dispatched through the backend) lines up exactly with the FormatRoute/quant-registry work in the roadmap — happy to point at where it'd slot in.
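
As a rough picture of the optimized direction mentioned in the last bullet, the sketch below quantizes activations to int8 with an absmax scale and accumulates against ternary weights in i32, so the inner loop is integer add/subtract only. It is a scalar illustration under assumptions, not the roadmap design: the function names are hypothetical, the weights are unpacked `i8` trits for readability, and no NEON/AVX2 or backend dispatch is shown.

```rust
/// Absmax-quantize activations to int8.
/// Returns (q, scale) such that x[i] is approximately q[i] as f32 * scale.
pub fn quantize_absmax_i8(x: &[f32]) -> (Vec<i8>, f32) {
    let absmax = x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    if absmax == 0.0 {
        return (vec![0; x.len()], 1.0);
    }
    let scale = absmax / 127.0;
    let q: Vec<i8> = x
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

/// Integer dot product of ternary weights with int8 activations:
/// +1 adds, -1 subtracts, 0 skips. No multiplies in the loop.
pub fn ternary_dot_i8(trits: &[i8], q: &[i8]) -> i32 {
    trits
        .iter()
        .zip(q)
        .map(|(&t, &a)| match t {
            1 => a as i32,
            -1 => -(a as i32),
            _ => 0,
        })
        .sum()
}

/// y = (W q) * activation_scale * weight_scale, rescaled once per row.
pub fn ternary_matvec_i8(rows: &[Vec<i8>], x: &[f32], weight_scale: f32) -> Vec<f32> {
    let (q, act_scale) = quantize_absmax_i8(x);
    rows.iter()
        .map(|row| ternary_dot_i8(row, &q) as f32 * act_scale * weight_scale)
        .collect()
}
```

Unlike the f32 reference kernel, this path is lossy: each activation carries a rounding error of at most half the activation scale, so it would need its own tolerance when compared against the dequant reference.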

On CI: the decode.rs ~88% coverage failure is not from this PR — that file is unmodified here; it's env-gated branches inherited from the perf work that just landed on main (now fixed there), so a rebase should clear it. The fmt failures split into inherited (clears on rebase) + your new files (cargo fmt).

gburd added 2 commits June 19, 2026 09:53
Adds CPU inference for microsoft/bitnet-b1.58-2B-4T directly against
the native I2_S ternary weights, verified end-to-end (Paris 94.5%).

Pipeline:
  larql convert gguf-to-vindex --keep-quant [--dense-only] [--f16] \
    --output <vindex> ggml-model-i2_s.gguf
  larql-server <vindex>          # eager-loads the BitNet model
  POST /v1/infer {prompt, top, mode:dense}

larql-compute
  ternary_matvec.rs: ternary {-1,0,+1} x f32 matvec kernel
  (BitLinearWeight: packed I2_S bytes + per-tensor scale; integer
  trit accumulation, scale applied once per row).

larql-models
  quant/ggml/tq.rs: correct strided I2_S decode (see the companion
  decode-fix PR; included here so the feature is self-contained).
  loading/gguf: load_gguf_keep_quant retains raw I2_S bytes +
  captures the per-tensor trailing scale (I2S_SCALE_SUFFIX).

larql-vindex
  extract/bitnet_writer.rs: writes bitnet/ artifacts (I2_S bytes
  re-packed to the kernel's contiguous layout + concatenated
  per-row scales) and stamps bitnet_layout into index.json.
  extract/bitnet_loader.rs: reads them back into BitLinearWeight.
  extract/build: build_vindex_dense_only — skips gate-vector +
  HNSW clustering + dense projection duplication (native-ternary
  /v1/infer at ~3 GB on-disk, ~1.5 min build vs ~30 min).
  config: VindexConfig.bitnet_layout (Option, skip-if-None; dense
  vindexes round-trip unchanged).
  format/weights/write_f32.rs: WriteWeightsOptions.skip_attn/
  skip_ffn; norm writing decoupled from the attention-projection
  gate (BitNet needs every RMSNorm incl. attn_sub_norm/ffn_sub_norm
  even when the I2_S projections are skipped).

larql-inference
  ternary.rs: BitnetModel + BitnetFfn + predict_bitnet (single-shot
  prefill forward over 30 layers, GQA attention, RoPE), KV cache +
  decode_step + generate, sampling (temperature/top-k/top-p),
  load_bitnet_model (assembles dense norms/embed/lm_head + I2_S
  BitLinears from a --keep-quant vindex).

larql-cli
  convert gguf-to-vindex --keep-quant / --dense-only flags;
  --keep-quant rejected at --level browse (needs the dense
  norms/embed/lm_head).

Qualified on the real microsoft/bitnet-b1.58-2B-4T GGUF (24-core):
  'The capital of France is' -> Paris 94.5%
  'Two plus two equals'      -> four  30.7%
  dense-only build 1.5 min, vindex 3.0 GB, server RSS ~4.8 GB.

cargo test -p larql-inference -p larql-vindex -p larql-models
--lib: 2668 passed, 0 failed.

Server wiring (/v1/infer dispatch, SSE streaming, walk-mode) lives
in larql-server and depends on the server-reliability PR; this PR
is the engine/codec/vindex layer that stands alone and builds
clean on main.
…tatus

Addresses PR chrishayuk#159 review (architecture concern + docs), without a
risky refactor of the verified forward pass:

- ternary.rs module header gains a 'Relationship to the engine
  machinery' section spelling out, component by component, what is
  reused vs deliberately forked and why:
    * RoPE      -> reused (attention::rope)
    * RMSNorm   -> NOT reused: residual::rms_norm* allocate an
                   Array2 per call; the hot path needs the alloc-free
                   &[f32]->&mut[f32] rmsnorm_into.  Same numerics;
                   collapses to a one-liner if residual gains an
                   _into variant.
    * GQA attn  -> inline: BitNet inserts an extra attn_sub_norm
                   between QK and O, and the projections are ternary
                   (BitLinearWeight), so the dense attention kernels
                   don't apply.
    * FFN       -> ReLU^2 over ternary projections with a mid-FFN
                   sub-norm; matches no existing FFN forward.
    * KV/sampling -> the genuinely-forkable part; reimplemented for
                   standalone verifiability, to be folded into the
                   engine once the quantized-activation kernel path
                   (FormatRoute/quant-registry) exists.

- Documented the dual I2_S layout (strided in tq.rs for GGUF decode,
  contiguous in the kernel/writer) at the module level: the latent
  footgun the reviewer flagged.

- ternary_matvec.rs header now states plainly that it is a CORRECT
  REFERENCE, not yet the optimized fast path: the 'no multiplies'
  claim describes the algorithm; the current code is still scalar
  f32 sign-select.  The vectorised int8-activation kernel is future
  work that slots into the backend dispatch.

Process note (for the PR thread, not code): the bundled tq.rs decode
change is the same fix as the standalone decode PR; once that lands
this branch will be rebased to drop the duplicate.

Tests unchanged + green: ternary 11 (models) + 27 (compute).
gburd force-pushed the pr/bitnet-inference branch from 9647313 to 0556661 on June 19, 2026 14:23
gburd commented Jun 19, 2026

Contributor Author

Thanks — this is the most useful review of the three. Rebased onto main (clears the inherited decode.rs coverage failure) and cargo fmt'd the new files.

Architecture — I took the "write down explicitly why and what minimal divergence it needs" option rather than forcing the refactor, because the reuse isn't free in the hot path and I didn't want to regress a forward pass that's verified correct against the real 2B model. The module header now has a component-by-component 'Relationship to the engine machinery' section:

  • RoPE — reused (attention::rope), as you noticed.
  • RMSNorm: not reused. The concrete reason: larql_compute::residual::rms_norm* take &Array2<f32> and allocate a new array per call; the BitNet forward runs a per-token, allocation-free &[f32] -> &mut [f32] norm (rmsnorm_into) in the inner loop. Routing through the Array2 form adds an allocation per position per layer. Same numerics (same eps, same x/rms*weight). If residual grows an _into variant, this collapses to a one-line call; happy to add that to residual instead if you'd prefer the shared seam.
  • GQA attention — inline because BitNet inserts an extra attn_sub_norm between the QK product and the O projection (no hook in the standard path), and the Q/K/V/O projections are themselves ternary (BitLinearWeight), so they can't call the dense attention kernels regardless.
  • FFN — ReLU² (not SwiGLU) over ternary projections with a mid-FFN ffn_sub_norm.
  • KV cache / decode_step / generate / sampling — agreed, these are the genuinely-forkable parts. Documented as a known maintenance cost, to be folded into the engine once the int8-activation kernel + shared KV path exist (which lines up with the FormatRoute/quant-registry work you pointed at — yes, please point me at where it'd slot in).
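
For reference, the allocation-free norm shape described in the RMSNorm bullet can be sketched as below. The exact parameter list of the real `rmsnorm_into` is not shown in this thread, so this signature is an assumption; the numerics follow the `x / rms * weight` form stated above, and the allocating wrapper is included only to show how an `Array2`-style entry point could delegate to it.

```rust
/// Write RMSNorm(x) * weight into `out` without allocating.
/// Assumed signature; only the &[f32] -> &mut [f32] shape comes from the thread.
pub fn rmsnorm_into(x: &[f32], weight: &[f32], eps: f32, out: &mut [f32]) {
    assert_eq!(x.len(), weight.len());
    assert_eq!(x.len(), out.len());
    if x.is_empty() {
        return;
    }
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    for ((o, &xi), &wi) in out.iter_mut().zip(x).zip(weight) {
        *o = xi * inv_rms * wi;
    }
}

/// Allocating convenience wrapper: one Vec per call, same numerics.
pub fn rmsnorm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mut out = vec![0.0; x.len()];
    rmsnorm_into(x, weight, eps, &mut out);
    out
}
```

In a per-token loop the caller would allocate `out` once and reuse it for every layer and position, which is the saving the bullet describes.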

Other items:

  • Dual I2_S layout — documented at the module level (strided in tq.rs for GGUF decode, contiguous in the kernel/writer; the writer re-packs). Flagged as the intentional-but-latent footgun it is.
  • Kernel is a reference, not the fast path — corrected the ternary_matvec header: the 'no multiplies' line describes the algorithm; the current code is still scalar f32 sign-select. Marked the vectorised int8 path as future work.
  • Land #156 first: agreed. #156 (fix(ggml): correct I2_S decode to microsoft strided block layout) is the standalone decode fix; once it merges I'll rebase this branch to drop the bundled tq.rs change so it depends on the merged version rather than carrying a duplicate.
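
To illustrate why a dual layout is a footgun, here is a toy repack from a strided 2-bit packing to a contiguous one. Neither layout below is claimed to be the real I2_S format: the `group` parameter, the bit order, and all function names are hypothetical, chosen only to show that the same bytes decode to different element orders depending on which layout the reader assumes.

```rust
/// Toy strided layout (hypothetical): within a block of `group` bytes,
/// byte j holds elements j, j+group, j+2*group, j+3*group, high bits first.
pub fn decode_strided(bytes: &[u8], group: usize) -> Vec<u8> {
    assert!(group > 0 && bytes.len() % group == 0);
    let mut codes = vec![0u8; bytes.len() * 4];
    for (b, block) in bytes.chunks(group).enumerate() {
        let base = b * group * 4;
        for (j, &byte) in block.iter().enumerate() {
            for k in 0..4 {
                codes[base + k * group + j] = (byte >> (6 - 2 * k)) & 0b11;
            }
        }
    }
    codes
}

/// Toy contiguous layout (hypothetical): byte j holds elements
/// 4j..4j+3, lowest bits first.
pub fn pack_contiguous(codes: &[u8]) -> Vec<u8> {
    codes
        .chunks(4)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .fold(0u8, |acc, (k, &code)| acc | ((code & 0b11) << (2 * k)))
        })
        .collect()
}

/// Inverse of `pack_contiguous`, returning one 2-bit code per element.
pub fn decode_contiguous(bytes: &[u8]) -> Vec<u8> {
    bytes
        .iter()
        .flat_map(|&byte| (0..4).map(move |k| (byte >> (2 * k)) & 0b11))
        .collect()
}

/// Writer-side repack: strided in, contiguous out.
pub fn repack(bytes: &[u8], group: usize) -> Vec<u8> {
    pack_contiguous(&decode_strided(bytes, group))
}
```

A round-trip test of this kind (decode strided, repack, decode contiguous, compare element order) is one cheap way to pin the relationship between the two layouts.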

cargo test -p larql-inference -p larql-compute --lib ternary: 11 + 27 passed.

@chrishayuk

Owner

awesome, thanks.. will merge tonight.. appreciate the contribution. i've been doing lots of refactors, so apologies for the rebase request. will get this in before i mess with anything else, and will clean up anything from there. thanks again
