feat(bitnet): native-ternary 1.58 inference end-to-end (microsoft/bitnet-b1.58-2B-4T)#159
feat(bitnet): native-ternary 1.58 inference end-to-end (microsoft/bitnet-b1.58-2B-4T)#159gburd wants to merge 2 commits into
Conversation
|
This is impressive — native I2_S inference running end-to-end on the real 2B model is a great result, and the pieces I'd call lowest-risk are genuinely solid: the The thing I'd want to resolve before merge is architecture, in It's a ~2359-line parallel inference stack — it reimplements GQA attention, RMSNorm (a straight dup of A few more, smaller:
On CI: the |
Adds CPU inference for microsoft/bitnet-b1.58-2B-4T directly against
the native I2_S ternary weights, verified end-to-end (Paris 94.5%).
Pipeline:
larql convert gguf-to-vindex --keep-quant [--dense-only] [--f16] \
--output <vindex> ggml-model-i2_s.gguf
larql-server <vindex> # eager-loads the BitNet model
POST /v1/infer {prompt, top, mode:dense}
larql-compute
ternary_matvec.rs: ternary {-1,0,+1} x f32 matvec kernel
(BitLinearWeight: packed I2_S bytes + per-tensor scale; integer
trit accumulation, scale applied once per row).
larql-models
quant/ggml/tq.rs: correct strided I2_S decode (see the companion
decode-fix PR; included here so the feature is self-contained).
loading/gguf: load_gguf_keep_quant retains raw I2_S bytes +
captures the per-tensor trailing scale (I2S_SCALE_SUFFIX).
larql-vindex
extract/bitnet_writer.rs: writes bitnet/ artifacts (I2_S bytes
re-packed to the kernel's contiguous layout + concatenated
per-row scales) and stamps bitnet_layout into index.json.
extract/bitnet_loader.rs: reads them back into BitLinearWeight.
extract/build: build_vindex_dense_only — skips gate-vector +
HNSW clustering + dense projection duplication (native-ternary
/v1/infer at ~3 GB on-disk, ~1.5 min build vs ~30 min).
config: VindexConfig.bitnet_layout (Option, skip-if-None; dense
vindexes round-trip unchanged).
format/weights/write_f32.rs: WriteWeightsOptions.skip_attn/
skip_ffn; norm writing decoupled from the attention-projection
gate (BitNet needs every RMSNorm incl. attn_sub_norm/ffn_sub_norm
even when the I2_S projections are skipped).
larql-inference
ternary.rs: BitnetModel + BitnetFfn + predict_bitnet (single-shot
prefill forward over 30 layers, GQA attention, RoPE), KV cache +
decode_step + generate, sampling (temperature/top-k/top-p),
load_bitnet_model (assembles dense norms/embed/lm_head + I2_S
BitLinears from a --keep-quant vindex).
larql-cli
convert gguf-to-vindex --keep-quant / --dense-only flags;
--keep-quant rejected at --level browse (needs the dense
norms/embed/lm_head).
Qualified on the real microsoft/bitnet-b1.58-2B-4T GGUF (24-core):
'The capital of France is' -> Paris 94.5%
'Two plus two equals' -> four 30.7%
dense-only build 1.5 min, vindex 3.0 GB, server RSS ~4.8 GB.
cargo test -p larql-inference -p larql-vindex -p larql-models
--lib: 2668 passed, 0 failed.
Server wiring (/v1/infer dispatch, SSE streaming, walk-mode) lives
in larql-server and depends on the server-reliability PR; this PR
is the engine/codec/vindex layer that stands alone and builds
clean on main.
…tatus Addresses PR chrishayuk#159 review (architecture concern + docs), without a risky refactor of the verified forward pass: - ternary.rs module header gains a 'Relationship to the engine machinery' section spelling out, component by component, what is reused vs deliberately forked and why: * RoPE -> reused (attention::rope) * RMSNorm -> NOT reused: residual::rms_norm* allocate an Array2 per call; the hot path needs the alloc-free &[f32]->&mut[f32] rmsnorm_into. Same numerics; collapses to a one-liner if residual gains an _into variant. * GQA attn -> inline: BitNet inserts an extra attn_sub_norm between QK and O, and the projections are ternary (BitLinearWeight), so the dense attention kernels don't apply. * FFN -> ReLU^2 over ternary projections with a mid-FFN sub-norm; matches no existing FFN forward. * KV/sampling -> the genuinely-forkable part; reimplemented for standalone verifiability, to be folded into the engine once the quantized-activation kernel path (FormatRoute/quant-registry) exists. - Documented the dual I2_S layout (strided in tq.rs for GGUF decode, contiguous in the kernel/writer) at the module level \u2014 the latent footgun the reviewer flagged. - ternary_matvec.rs header now states plainly that it is a CORRECT REFERENCE, not yet the optimized fast path: the 'no multiplies' claim describes the algorithm; the current code is still scalar f32 sign-select. The vectorised int8-activation kernel is future work that slots into the backend dispatch. Process note (for the PR thread, not code): the bundled tq.rs decode change is the same fix as the standalone decode PR; once that lands this branch will be rebased to drop the duplicate. Tests unchanged + green: ternary 11 (models) + 27 (compute).
9647313 to
0556661
Compare
|
Thanks — this is the most useful review of the three. Rebased onto main (clears the inherited Architecture — I took the "write down explicitly why and what minimal divergence it needs" option rather than forcing the refactor, because the reuse isn't free in the hot path and I didn't want to regress a forward pass that's verified correct against the real 2B model. The module header now has a component-by-component 'Relationship to the engine machinery' section:
Other items:
|
|
awesome, thanks.. will merge tonight.. apprecitate the contibution. i've been doing lots of refactors, so apologies for the rebase request, so will get this in, before i mess with anything else.. and will cleanup anything from there. thanks again |
Builds on the I2_S quantisation already merged (#148). Adds CPU inference directly against the native I2_S ternary weights — no dequant-to-f16 at convert time — verified end-to-end on the real
microsoft/bitnet-b1.58-2B-4TGGUF.Result
Pipeline
What's here (engine / codec / vindex layer)
BitLinearWeight: packed I2_S + per-tensor scale, integer trit accumulation).load_gguf_keep_quantretains raw I2_S bytes + the per-tensor trailing scale; the strided I2_S decode fix (also filed standalone as the decode-fix PR, included here so this is self-contained).bitnet_writer/bitnet_loader(bitnet/ artifacts +bitnet_layoutin index.json);build_vindex_dense_only(skips gate-vectors + HNSW clustering + dense-projection duplication: ~3 GB on-disk, ~1.5 min build vs ~30 min);WriteWeightsOptions.skip_attn/skip_ffnwith norm-writing decoupled from the attention-projection gate (BitNet needsattn_sub_norm/ffn_sub_normeven when projections are skipped).VindexConfig.bitnet_layoutisOption/skip-if-None so dense vindexes round-trip unchanged.BitnetModel+predict_bitnet(single-shot prefill, GQA + RoPE), KV cache +decode_step+generate, sampling (temp/top-k/top-p),load_bitnet_model.--keep-quant/--dense-onlyflags.cargo test -p larql-inference -p larql-vindex -p larql-models --lib: 2668 passed.Notes
/v1/inferdispatch, SSE streaming, and walk-mode live inlarql-serverand ride on the server-reliability PR (concurrent-load single-flight, timeout); happy to send that server wiring as a follow-up once the reliability + this land.Re: the closed cloud-proxy PR (#149) — fully agree it's out of scope; this is purely local-weights inference.