Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/eval-callback/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,6 @@ func main() {
ctxParams.NBatch = 512
ctxParams.NSeqMax = 1
ctxParams.NThreads = int32(*threads)
ctxParams.Logits = 1

// NOTE: In a real implementation, we would set eval callbacks here:
// ctxParams.CbEval = callbackFunctionPointer
Expand Down
1 change: 0 additions & 1 deletion examples/simple-chat-with-loader/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ func main() {
// ctxParams.NUbatch = 512 // Keep default value
ctxParams.NSeqMax = 1 // Set max sequences to 1 for simple use case
ctxParams.NThreads = int32(*threads)
ctxParams.Logits = 1 // true as uint8

fmt.Printf("Setting context size to: %d\n", *ctx)
fmt.Printf("Context params NCtx: %d\n", ctxParams.NCtx)
Expand Down
1 change: 0 additions & 1 deletion examples/simple-chat/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ func main() {
// ctxParams.NUbatch = 512 // Keep default value
ctxParams.NSeqMax = 1 // Set max sequences to 1 for simple use case
ctxParams.NThreads = int32(*threads)
ctxParams.Logits = 1 // true as uint8

fmt.Printf("Setting context size to: %d\n", *ctx)
fmt.Printf("Context params NCtx: %d\n", ctxParams.NCtx)
Expand Down
2 changes: 0 additions & 2 deletions examples/speculative/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ func main() {
ctxParamsTgt.NCtx = uint32(*ctx)
ctxParamsTgt.NThreads = int32(*threads)
ctxParamsTgt.NThreadsBatch = int32(*threads)
ctxParamsTgt.Logits = 1

ctxTgt, err := gollama.Init_from_model(modelTgt, ctxParamsTgt)
if err != nil {
Expand All @@ -156,7 +155,6 @@ func main() {
ctxParamsDft.NCtx = uint32(*ctx)
ctxParamsDft.NThreads = int32(*threads)
ctxParamsDft.NThreadsBatch = int32(*threads)
ctxParamsDft.Logits = 1

ctxDft, err := gollama.Init_from_model(modelDft, ctxParamsDft)
if err != nil {
Expand Down
9 changes: 6 additions & 3 deletions ffi.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@ var (
&ffi.TypeUint8, // use_mlock
&ffi.TypeUint8, // check_tensors
&ffi.TypeUint8, // use_extra_bufts
&ffi.TypeUint8, // no_host
nil,
}[0],
}

// LlamaContextParams FFI type
// Layout MUST match struct llama_context_params in llama.h (b6862).
ffiTypeLlamaContextParams = ffi.Type{
Type: ffi.Struct,
Elements: &[]*ffi.Type{
&ffi.TypeUint32, // seed
&ffi.TypeUint32, // n_ctx
&ffi.TypeUint32, // n_batch
&ffi.TypeUint32, // n_ubatch
Expand All @@ -45,6 +46,7 @@ var (
&ffi.TypeSint32, // rope_scaling_type
&ffi.TypeSint32, // pooling_type
&ffi.TypeSint32, // attention_type
&ffi.TypeSint32, // flash_attn_type
&ffi.TypeFloat, // rope_freq_base
&ffi.TypeFloat, // rope_freq_scale
&ffi.TypeFloat, // yarn_ext_factor
Expand All @@ -59,11 +61,12 @@ var (
&ffi.TypeSint32, // type_v
&ffi.TypePointer, // abort_callback
&ffi.TypePointer, // abort_callback_data
&ffi.TypeUint8, // logits
&ffi.TypeUint8, // embeddings
&ffi.TypeUint8, // offload_kqv
&ffi.TypeUint8, // flash_attn
&ffi.TypeUint8, // no_perf
&ffi.TypeUint8, // op_offload
&ffi.TypeUint8, // swa_full
&ffi.TypeUint8, // kv_unified
nil,
}[0],
}
Expand Down
6 changes: 3 additions & 3 deletions ffi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ func (s *FFISuite) TestFFIContextDefaultParams() {
return
}

s.Assert().NotZero(params.Seed, "Seed should not be zero in default params")
s.Assert().NotZero(params.NBatch, "NBatch should not be zero in default params")
s.T().Logf("FFI Context default params: Seed=%d, NCtx=%d, NBatch=%d, NThreads=%d",
params.Seed, params.NCtx, params.NBatch, params.NThreads)
s.Assert().NotZero(params.NUbatch, "NUbatch should not be zero in default params")
s.T().Logf("FFI Context default params: NCtx=%d, NBatch=%d, NUbatch=%d, NSeqMax=%d, NThreads=%d, FlashAttnType=%d",
params.NCtx, params.NBatch, params.NUbatch, params.NSeqMax, params.NThreads, params.FlashAttnType)
}

// Tests FFI-based sampler chain parameter retrieval
Expand Down
42 changes: 29 additions & 13 deletions gollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ const (
LLAMA_ATTENTION_TYPE_NON_CAUSAL LlamaAttentionType = 1
)

type LlamaFlashAttnType int32

const (
LLAMA_FLASH_ATTN_TYPE_AUTO LlamaFlashAttnType = -1
LLAMA_FLASH_ATTN_TYPE_DISABLED LlamaFlashAttnType = 0
LLAMA_FLASH_ATTN_TYPE_ENABLED LlamaFlashAttnType = 1
)

type LlamaSplitMode int32

const (
Expand Down Expand Up @@ -282,11 +290,15 @@ type LlamaModelParams struct {
UseMlock uint8 // force system to keep model in RAM (bool as uint8)
CheckTensors uint8 // validate model tensor data (bool as uint8)
UseExtraBufts uint8 // use extra buffer types (bool as uint8)
NoHost uint8 // bypass host buffer allowing extra buffers to be used (bool as uint8)
}

// Context parameters
//
// Layout MUST match struct llama_context_params in llama.h for the bundled
// llama.cpp build (b6862). The struct is passed/returned BY VALUE across the
// FFI boundary, so any drift silently lands fields on the wrong C offsets.
type LlamaContextParams struct {
Seed uint32 // RNG seed, -1 for random
NCtx uint32 // text context, 0 = from model
NBatch uint32 // logical maximum batch size
NUbatch uint32 // physical maximum batch size
Expand All @@ -296,25 +308,27 @@ type LlamaContextParams struct {
RopeScalingType LlamaRopeScalingType // RoPE scaling type
PoolingType LlamaPoolingType // pooling type for embeddings
AttentionType LlamaAttentionType // attention type
FlashAttnType LlamaFlashAttnType // when to enable Flash Attention
RopeFreqBase float32 // RoPE base frequency
RopeFreqScale float32 // RoPE frequency scaling factor
YarnExtFactor float32 // YaRN extrapolation mix factor
YarnAttnFactor float32 // YaRN magnitude scaling factor
YarnBetaFast float32 // YaRN low correction dim
YarnBetaSlow float32 // YaRN high correction dim
YarnOrigCtx uint32 // YaRN original context size
DefragThold float32 // defragment the KV cache if holes/size > thold
DefragThold float32 // [DEPRECATED] defragment the KV cache if holes/size > thold
CbEval uintptr // evaluation callback
CbEvalUserData uintptr // user data for evaluation callback
TypeK int32 // data type for K cache
TypeV int32 // data type for V cache
AbortCallback uintptr // abort callback
AbortCallbackData uintptr // user data for abort callback
Logits uint8 // whether to compute and return logits (bool as uint8)
Embeddings uint8 // whether to compute and return embeddings (bool as uint8)
Offload_kqv uint8 // whether to offload K, Q, V to GPU (bool as uint8)
FlashAttn uint8 // whether to use flash attention (bool as uint8)
NoPerf uint8 // whether to measure performance (bool as uint8)
Embeddings uint8 // whether to extract embeddings, together with logits (bool as uint8)
Offload_kqv uint8 // whether to offload KQV ops (incl. KV cache) to GPU (bool as uint8)
NoPerf uint8 // whether to skip performance timings (bool as uint8)
OpOffload uint8 // offload host tensor operations to device (bool as uint8)
SwaFull uint8 // use full-size SWA cache (bool as uint8)
KvUnified uint8 // use a unified KV buffer across input sequences (bool as uint8)
}

// Model quantize parameters
Expand Down Expand Up @@ -929,7 +943,6 @@ func Context_default_params() LlamaContextParams {

// Last resort: return hardcoded defaults
return LlamaContextParams{
Seed: LLAMA_DEFAULT_SEED,
NCtx: 0, // Auto-detect from model
NBatch: 2048,
NUbatch: 512,
Expand All @@ -939,12 +952,14 @@ func Context_default_params() LlamaContextParams {
RopeScalingType: LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
PoolingType: LLAMA_POOLING_TYPE_UNSPECIFIED,
AttentionType: LLAMA_ATTENTION_TYPE_CAUSAL,
FlashAttnType: LLAMA_FLASH_ATTN_TYPE_AUTO,
DefragThold: -1.0, // Disabled by default
Logits: 0, // Disabled by default
Embeddings: 0, // Disabled by default
Offload_kqv: 1, // Enable by default
FlashAttn: 0, // Disabled by default
NoPerf: 0, // Enable performance measurement by default
OpOffload: 1, // Enable by default (matches llama.cpp)
SwaFull: 1, // Enable by default (matches llama.cpp)
KvUnified: 0, // Disabled by default (matches llama.cpp)
}
}

Expand Down Expand Up @@ -1507,7 +1522,6 @@ func ContextDefaultParams() LlamaContextParams {
}
// Return default values for non-Darwin platforms - blocks ROADMAP "wait for purego struct support"
return LlamaContextParams{
Seed: LLAMA_DEFAULT_SEED,
NCtx: 0, // 0 = from model
NBatch: 2048,
NUbatch: 512,
Expand All @@ -1517,6 +1531,7 @@ func ContextDefaultParams() LlamaContextParams {
RopeScalingType: LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
PoolingType: LLAMA_POOLING_TYPE_UNSPECIFIED,
AttentionType: LLAMA_ATTENTION_TYPE_CAUSAL,
FlashAttnType: LLAMA_FLASH_ATTN_TYPE_AUTO,
RopeFreqBase: 0.0, // 0.0 = from model
RopeFreqScale: 0.0, // 0.0 = from model
YarnExtFactor: -1.0,
Expand All @@ -1527,11 +1542,12 @@ func ContextDefaultParams() LlamaContextParams {
DefragThold: -1.0,
TypeK: -1,
TypeV: -1,
Logits: 0,
Embeddings: 0,
Offload_kqv: 1,
FlashAttn: 0,
NoPerf: 0,
OpOffload: 1,
SwaFull: 1,
KvUnified: 0,
}
}

Expand Down