diff --git a/examples/eval-callback/main.go b/examples/eval-callback/main.go index e938a80..87f783b 100644 --- a/examples/eval-callback/main.go +++ b/examples/eval-callback/main.go @@ -436,7 +436,6 @@ func main() { ctxParams.NBatch = 512 ctxParams.NSeqMax = 1 ctxParams.NThreads = int32(*threads) - ctxParams.Logits = 1 // NOTE: In a real implementation, we would set eval callbacks here: // ctxParams.CbEval = callbackFunctionPointer diff --git a/examples/simple-chat-with-loader/main.go b/examples/simple-chat-with-loader/main.go index 5629094..b9c35bd 100644 --- a/examples/simple-chat-with-loader/main.go +++ b/examples/simple-chat-with-loader/main.go @@ -116,7 +116,6 @@ func main() { // ctxParams.NUbatch = 512 // Keep default value ctxParams.NSeqMax = 1 // Set max sequences to 1 for simple use case ctxParams.NThreads = int32(*threads) - ctxParams.Logits = 1 // true as uint8 fmt.Printf("Setting context size to: %d\n", *ctx) fmt.Printf("Context params NCtx: %d\n", ctxParams.NCtx) diff --git a/examples/simple-chat/main.go b/examples/simple-chat/main.go index 2e878d7..e7a3d98 100644 --- a/examples/simple-chat/main.go +++ b/examples/simple-chat/main.go @@ -101,7 +101,6 @@ func main() { // ctxParams.NUbatch = 512 // Keep default value ctxParams.NSeqMax = 1 // Set max sequences to 1 for simple use case ctxParams.NThreads = int32(*threads) - ctxParams.Logits = 1 // true as uint8 fmt.Printf("Setting context size to: %d\n", *ctx) fmt.Printf("Context params NCtx: %d\n", ctxParams.NCtx) diff --git a/examples/speculative/main.go b/examples/speculative/main.go index d96dd81..d7454c9 100644 --- a/examples/speculative/main.go +++ b/examples/speculative/main.go @@ -132,7 +132,6 @@ func main() { ctxParamsTgt.NCtx = uint32(*ctx) ctxParamsTgt.NThreads = int32(*threads) ctxParamsTgt.NThreadsBatch = int32(*threads) - ctxParamsTgt.Logits = 1 ctxTgt, err := gollama.Init_from_model(modelTgt, ctxParamsTgt) if err != nil { @@ -156,7 +155,6 @@ func main() { ctxParamsDft.NCtx = uint32(*ctx) ctxParamsDft.NThreads = int32(*threads) ctxParamsDft.NThreadsBatch = int32(*threads) - ctxParamsDft.Logits = 1 ctxDft, err := gollama.Init_from_model(modelDft, ctxParamsDft) if err != nil { diff --git a/ffi.go b/ffi.go index 7e6990e..a43c6fa 100644 --- a/ffi.go +++ b/ffi.go @@ -27,15 +27,16 @@ var ( &ffi.TypeUint8, // use_mlock &ffi.TypeUint8, // check_tensors &ffi.TypeUint8, // use_extra_bufts + &ffi.TypeUint8, // no_host nil, }[0], } // LlamaContextParams FFI type + // Layout MUST match struct llama_context_params in llama.h (b6862). ffiTypeLlamaContextParams = ffi.Type{ Type: ffi.Struct, Elements: &[]*ffi.Type{ - &ffi.TypeUint32, // seed &ffi.TypeUint32, // n_ctx &ffi.TypeUint32, // n_batch &ffi.TypeUint32, // n_ubatch @@ -45,6 +46,7 @@ var ( &ffi.TypeSint32, // rope_scaling_type &ffi.TypeSint32, // pooling_type &ffi.TypeSint32, // attention_type + &ffi.TypeSint32, // flash_attn_type &ffi.TypeFloat, // rope_freq_base &ffi.TypeFloat, // rope_freq_scale &ffi.TypeFloat, // yarn_ext_factor @@ -59,11 +61,12 @@ var ( &ffi.TypeSint32, // type_v &ffi.TypePointer, // abort_callback &ffi.TypePointer, // abort_callback_data - &ffi.TypeUint8, // logits &ffi.TypeUint8, // embeddings &ffi.TypeUint8, // offload_kqv - &ffi.TypeUint8, // flash_attn &ffi.TypeUint8, // no_perf + &ffi.TypeUint8, // op_offload + &ffi.TypeUint8, // swa_full + &ffi.TypeUint8, // kv_unified nil, }[0], } diff --git a/ffi_test.go b/ffi_test.go index fad4e93..bbd0db0 100644 --- a/ffi_test.go +++ b/ffi_test.go @@ -98,10 +98,10 @@ func (s *FFISuite) TestFFIContextDefaultParams() { return } - s.Assert().NotZero(params.Seed, "Seed should not be zero in default params") s.Assert().NotZero(params.NBatch, "NBatch should not be zero in default params") - s.T().Logf("FFI Context default params: Seed=%d, NCtx=%d, NBatch=%d, NThreads=%d", - params.Seed, params.NCtx, params.NBatch, params.NThreads) + s.Assert().NotZero(params.NUbatch, "NUbatch should not be zero in default params") + s.T().Logf("FFI Context default params: NCtx=%d, NBatch=%d, NUbatch=%d, NSeqMax=%d, NThreads=%d, FlashAttnType=%d", + params.NCtx, params.NBatch, params.NUbatch, params.NSeqMax, params.NThreads, params.FlashAttnType) } // Tests FFI-based sampler chain parameter retrieval diff --git a/gollama.go b/gollama.go index eddb167..dcb1bc1 100644 --- a/gollama.go +++ b/gollama.go @@ -190,6 +190,14 @@ const ( LLAMA_ATTENTION_TYPE_NON_CAUSAL LlamaAttentionType = 1 ) +type LlamaFlashAttnType int32 + +const ( + LLAMA_FLASH_ATTN_TYPE_AUTO LlamaFlashAttnType = -1 + LLAMA_FLASH_ATTN_TYPE_DISABLED LlamaFlashAttnType = 0 + LLAMA_FLASH_ATTN_TYPE_ENABLED LlamaFlashAttnType = 1 +) + type LlamaSplitMode int32 const ( @@ -282,11 +290,15 @@ type LlamaModelParams struct { UseMlock uint8 // force system to keep model in RAM (bool as uint8) CheckTensors uint8 // validate model tensor data (bool as uint8) UseExtraBufts uint8 // use extra buffer types (bool as uint8) + NoHost uint8 // bypass host buffer allowing extra buffers to be used (bool as uint8) } // Context parameters +// +// Layout MUST match struct llama_context_params in llama.h for the bundled +// llama.cpp build (b6862). The struct is passed/returned BY VALUE across the +// FFI boundary, so any drift silently lands fields on the wrong C offsets. type LlamaContextParams struct { - Seed uint32 // RNG seed, -1 for random NCtx uint32 // text context, 0 = from model NBatch uint32 // logical maximum batch size NUbatch uint32 // physical maximum batch size @@ -296,6 +308,7 @@ type LlamaContextParams struct { RopeScalingType LlamaRopeScalingType // RoPE scaling type PoolingType LlamaPoolingType // pooling type for embeddings AttentionType LlamaAttentionType // attention type + FlashAttnType LlamaFlashAttnType // when to enable Flash Attention RopeFreqBase float32 // RoPE base frequency RopeFreqScale float32 // RoPE frequency scaling factor YarnExtFactor float32 // YaRN extrapolation mix factor @@ -303,18 +316,19 @@ type LlamaContextParams struct { YarnBetaFast float32 // YaRN low correction dim YarnBetaSlow float32 // YaRN high correction dim YarnOrigCtx uint32 // YaRN original context size - DefragThold float32 // defragment the KV cache if holes/size > thold + DefragThold float32 // [DEPRECATED] defragment the KV cache if holes/size > thold CbEval uintptr // evaluation callback CbEvalUserData uintptr // user data for evaluation callback TypeK int32 // data type for K cache TypeV int32 // data type for V cache AbortCallback uintptr // abort callback AbortCallbackData uintptr // user data for abort callback - Logits uint8 // whether to compute and return logits (bool as uint8) - Embeddings uint8 // whether to compute and return embeddings (bool as uint8) - Offload_kqv uint8 // whether to offload K, Q, V to GPU (bool as uint8) - FlashAttn uint8 // whether to use flash attention (bool as uint8) - NoPerf uint8 // whether to measure performance (bool as uint8) + Embeddings uint8 // whether to extract embeddings, together with logits (bool as uint8) + Offload_kqv uint8 // whether to offload KQV ops (incl. KV cache) to GPU (bool as uint8) + NoPerf uint8 // whether to skip performance timings (bool as uint8) + OpOffload uint8 // offload host tensor operations to device (bool as uint8) + SwaFull uint8 // use full-size SWA cache (bool as uint8) + KvUnified uint8 // use a unified KV buffer across input sequences (bool as uint8) } // Model quantize parameters @@ -929,7 +943,6 @@ func Context_default_params() LlamaContextParams { // Last resort: return hardcoded defaults return LlamaContextParams{ - Seed: LLAMA_DEFAULT_SEED, NCtx: 0, // Auto-detect from model NBatch: 2048, NUbatch: 512, @@ -939,12 +952,14 @@ func Context_default_params() LlamaContextParams { RopeScalingType: LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, PoolingType: LLAMA_POOLING_TYPE_UNSPECIFIED, AttentionType: LLAMA_ATTENTION_TYPE_CAUSAL, + FlashAttnType: LLAMA_FLASH_ATTN_TYPE_AUTO, DefragThold: -1.0, // Disabled by default - Logits: 0, // Disabled by default Embeddings: 0, // Disabled by default Offload_kqv: 1, // Enable by default - FlashAttn: 0, // Disabled by default NoPerf: 0, // Enable performance measurement by default + OpOffload: 1, // Enable by default (matches llama.cpp) + SwaFull: 1, // Enable by default (matches llama.cpp) + KvUnified: 0, // Disabled by default (matches llama.cpp) } } @@ -1507,7 +1522,6 @@ func ContextDefaultParams() LlamaContextParams { } // Return default values for non-Darwin platforms - blocks ROADMAP "wait for purego struct support" return LlamaContextParams{ - Seed: LLAMA_DEFAULT_SEED, NCtx: 0, // 0 = from model NBatch: 2048, NUbatch: 512, @@ -1517,6 +1531,7 @@ func ContextDefaultParams() LlamaContextParams { RopeScalingType: LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, PoolingType: LLAMA_POOLING_TYPE_UNSPECIFIED, AttentionType: LLAMA_ATTENTION_TYPE_CAUSAL, + FlashAttnType: LLAMA_FLASH_ATTN_TYPE_AUTO, RopeFreqBase: 0.0, // 0.0 = from model RopeFreqScale: 0.0, // 0.0 = from model YarnExtFactor: -1.0, @@ -1527,11 +1542,12 @@ func ContextDefaultParams() LlamaContextParams { DefragThold: -1.0, TypeK: -1, TypeV: -1, - Logits: 0, Embeddings: 0, Offload_kqv: 1, - FlashAttn: 0, NoPerf: 0, + OpOffload: 1, + SwaFull: 1, + KvUnified: 0, } }