diff --git a/native/Cargo.lock b/native/Cargo.lock index 0b8f89c..c9d9ffd 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -2185,7 +2185,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -3385,7 +3385,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -3433,7 +3433,7 @@ dependencies = [ "portable-atomic", "portable-atomic-util", "serde_core", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -4848,8 +4848,8 @@ version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ - "heck 0.4.1", - "itertools 0.10.5", + "heck 0.5.0", + "itertools 0.14.0", "log", "multimap", "once_cell", @@ -4869,7 +4869,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.14.0", "proc-macro2", "quote", "syn 2.0.114", @@ -5525,7 +5525,7 @@ dependencies = [ "once_cell", "socket2 0.6.2", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] @@ -6014,7 +6014,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.11.0", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -6920,7 +6920,7 @@ dependencies = [ "getrandom 0.4.1", "once_cell", "rustix 1.1.3", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -7881,7 +7881,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 
0.61.2", ] [[package]] diff --git a/native/src/delta_reader/distributed.rs b/native/src/delta_reader/distributed.rs index 68b0719..7626924 100644 --- a/native/src/delta_reader/distributed.rs +++ b/native/src/delta_reader/distributed.rs @@ -1015,7 +1015,18 @@ pub fn read_checkpoint_part_arrow_ffi( } } - // 5. Build FFI structs in a temporary vec first, then write all at once. + // 5. Memory pool reservation for Arrow FFI data + let estimated_size: usize = arrays.iter().map(|(a, _)| a.get_buffer_memory_size()).sum(); + let _reservation = crate::memory_pool::MemoryReservation::try_new( + &crate::memory_pool::global_pool(), + estimated_size, + "arrow_ffi", + ) + .unwrap_or_else(|_| { + crate::memory_pool::MemoryReservation::empty(&crate::memory_pool::global_pool(), "arrow_ffi") + }); + + // 6. Build FFI structs in a temporary vec first, then write all at once. // If any schema conversion fails, nothing is written. let mut ffi_pairs: Vec<(FFI_ArrowArray, FFI_ArrowSchema)> = Vec::with_capacity(NUM_COLS); for (i, (array, field)) in arrays.iter().enumerate() { @@ -1026,7 +1037,7 @@ pub fn read_checkpoint_part_arrow_ffi( ffi_pairs.push((ffi_array, ffi_schema)); } - // 6. All FFI structs built successfully — write them all out. + // 7. All FFI structs built successfully — write them all out. 
for (i, (ffi_array, ffi_schema)) in ffi_pairs.into_iter().enumerate() { let array_ptr = array_addrs[i] as *mut FFI_ArrowArray; let schema_ptr = schema_addrs[i] as *mut FFI_ArrowSchema; diff --git a/native/src/disk_cache/background.rs b/native/src/disk_cache/background.rs index 48ae6b0..9a8290c 100644 --- a/native/src/disk_cache/background.rs +++ b/native/src/disk_cache/background.rs @@ -117,9 +117,16 @@ impl L2DiskCache { // For size-based mode: drain bytes and notify waiting senders if let Some((queued_bytes, backpressure)) = sb_state { - queued_bytes.fetch_sub(data_len, Ordering::Release); + let remaining = queued_bytes.fetch_sub(data_len, Ordering::Release) - data_len; let (_lock, cvar) = &*backpressure; cvar.notify_all(); + + // When queue fully drains, release overflow memory back to pool + if remaining == 0 { + if let Some(budget) = &cache.memory_budget { + budget.on_queue_drained(); + } + } } drop(permit); // Release permit when write completes diff --git a/native/src/disk_cache/mod.rs b/native/src/disk_cache/mod.rs index a392972..b701ba4 100644 --- a/native/src/disk_cache/mod.rs +++ b/native/src/disk_cache/mod.rs @@ -52,6 +52,7 @@ use std::sync::{Arc, Condvar, Mutex, RwLock}; use std::time::Duration; use crate::debug_println; +use crate::memory_pool::{self, DiskCacheMemoryBudget}; use tantivy::directory::OwnedBytes; use lru::SplitLruTable; @@ -183,6 +184,8 @@ pub struct L2DiskCache { thread_handles: Mutex>>, /// Dirty flag - set when manifest has uncommitted changes manifest_dirty: Arc, + /// Memory budget for write queue (staircase-up/cliff-down pattern) + memory_budget: Option, } #[allow(dead_code)] @@ -240,6 +243,39 @@ impl L2DiskCache { let manifest_dirty = Arc::new(std::sync::atomic::AtomicBool::new(false)); + // Create memory budget for the write queue — only when a JVM pool is + // explicitly configured. With UnlimitedMemoryPool (default), the write + // queue operates with its static max size and no expand/contract behavior. 
+ let memory_budget = if memory_pool::is_pool_configured() { + let max_budget = config.max_write_queue_budget as usize; // 0 = default (8x) + match &config.write_queue_mode { + WriteQueueMode::SizeBased { max_bytes } => { + Some(DiskCacheMemoryBudget::with_config( + &memory_pool::global_pool(), + *max_bytes as usize, + 500 * 1024 * 1024, // 500MB grow increment + max_budget, + )) + } + WriteQueueMode::Fragment { capacity } => { + // Estimate: each fragment slot can hold ~1MB of data + let estimated_bytes = (*capacity as usize) * 1024 * 1024; + if estimated_bytes > 0 { + Some(DiskCacheMemoryBudget::with_config( + &memory_pool::global_pool(), + estimated_bytes, + 500 * 1024 * 1024, + max_budget, + )) + } else { + None + } + } + } + } else { + None // UnlimitedMemoryPool — no budget tracking, static queue size + }; + let cache = Arc::new(Self { config: config.clone(), manifest: RwLock::new(manifest), @@ -252,6 +288,7 @@ impl L2DiskCache { shutdown_flag: Arc::clone(&shutdown_flag), thread_handles: Mutex::new(Vec::new()), manifest_dirty: Arc::clone(&manifest_dirty), + memory_budget, }); // Start background writer (uses Weak reference - doesn't prevent Drop) @@ -411,9 +448,22 @@ impl L2DiskCache { ) } + /// Try to expand the memory budget to accommodate new data. + /// Returns true if the budget has room (or was successfully expanded), false if denied. + fn try_expand_budget(&self, needed: usize) -> bool { + match &self.memory_budget { + Some(budget) => budget.ensure_capacity(needed), + None => true, // No budget configured — always allow + } + } + /// Cache data (async write via background thread). /// Blocks if the write queue is full (backpressure). /// Use this for prewarm operations where data must be written. + /// + /// Tries to expand the memory budget first. If the pool denies expansion, + /// proceeds with a blocking write anyway — the background writer will drain + /// the queue naturally, and the user explicitly requested this data be cached. 
pub fn put( &self, storage_loc: &str, @@ -431,7 +481,16 @@ impl L2DiskCache { self.trigger_eviction((self.max_bytes * 90) / 100); } - // Send to background writer with backpressure. + // Try to expand budget. If denied, proceed anyway — prewarm blocks on + // the channel send until the background writer drains and frees queue space. + if !self.try_expand_budget(data.len()) { + debug_println!( + "⚠️ L2DiskCache::put (prewarm): Memory budget denied expansion, \ + will block on queue backpressure until background writer drains" + ); + } + + // Send to background writer with backpressure (blocks if queue is full). let _ = self.write_tx.send(WriteRequest::Put { storage_loc: storage_loc.to_string(), split_id: split_id.to_string(), @@ -444,6 +503,9 @@ impl L2DiskCache { /// Cache data if the write queue has capacity, otherwise drop silently. /// Returns `true` if the write was enqueued, `false` if dropped. /// Use this for query-path opportunistic caching where dropping is acceptable. + /// + /// Tries to expand the memory budget first. If the pool denies expansion, + /// drops the entry — a cache miss is preferable to over-allocating. pub fn put_if_ready( &self, storage_loc: &str, @@ -459,6 +521,12 @@ impl L2DiskCache { self.trigger_eviction((self.max_bytes * 90) / 100); } + // Try to expand budget. If denied, drop — query path should not over-allocate. + if !self.try_expand_budget(data.len()) { + debug_println!("⚠️ L2DiskCache::put_if_ready: Memory budget denied expansion, dropping cache entry"); + return false; + } + self.write_tx.send_or_drop(WriteRequest::Put { storage_loc: storage_loc.to_string(), split_id: split_id.to_string(), @@ -474,8 +542,8 @@ impl L2DiskCache { } /// Cache data for the query path — blocks or drops depending on config. - /// When `drop_writes_when_full` is enabled, silently drops writes if the queue is full. - /// When disabled, behaves identically to `put()` (blocks until enqueued). + /// Tries to expand the memory budget first. 
If the pool denies expansion, + /// drops the entry — a cache miss is preferable to over-allocating. pub fn put_query_path( &self, storage_loc: &str, @@ -484,6 +552,12 @@ impl L2DiskCache { byte_range: Option>, data: &[u8], ) { + // Try to expand budget. If denied, drop — query path should not over-allocate. + if !self.try_expand_budget(data.len()) { + debug_println!("⚠️ L2DiskCache::put_query_path: Memory budget denied expansion, dropping cache entry"); + return; + } + if self.config.drop_writes_when_full { self.put_if_ready(storage_loc, split_id, component, byte_range, data); } else { diff --git a/native/src/disk_cache/types.rs b/native/src/disk_cache/types.rs index a001c98..b7ef288 100644 --- a/native/src/disk_cache/types.rs +++ b/native/src/disk_cache/types.rs @@ -73,6 +73,10 @@ pub struct DiskCacheConfig { /// When true, non-prewarm (query-path) writes are dropped if the write queue is full /// instead of blocking. Prewarm writes always block. Default: false (all writes block). pub drop_writes_when_full: bool, + /// Maximum memory budget for the write queue (bytes). Controls the hard cap for + /// staircase-up growth. 0 = default (8x the initial write queue size). + /// Only effective when a JVM memory pool is configured. 
+ pub max_write_queue_budget: u64, } impl Default for DiskCacheConfig { @@ -86,6 +90,7 @@ impl Default for DiskCacheConfig { mmap_cache_size: DEFAULT_MMAP_CACHE_SIZE, write_queue_mode: WriteQueueMode::default(), drop_writes_when_full: false, + max_write_queue_budget: 0, // Default: 8x initial } } } diff --git a/native/src/global_cache/l1_cache.rs b/native/src/global_cache/l1_cache.rs index 506461d..99c4b48 100644 --- a/native/src/global_cache/l1_cache.rs +++ b/native/src/global_cache/l1_cache.rs @@ -6,6 +6,7 @@ use std::sync::OnceLock; use quickwit_storage::ByteRangeCache; use crate::debug_println; +use crate::memory_pool::{self, MemoryReservation}; /// Global L1 ByteRangeCache shared across all SplitSearcher instances /// This provides memory-efficient caching - one bounded cache instead of per-split caches @@ -19,6 +20,13 @@ static CONFIGURED_L1_CACHE_CAPACITY: OnceLock>> = /// When true, all storage requests bypass L1 memory cache and go to L2 disk cache / L3 storage static DISABLE_L1_CACHE: OnceLock> = OnceLock::new(); +/// Memory reservation for the L1 cache capacity. Released when cache is reset or process exits. 
+static L1_CACHE_RESERVATION: OnceLock>> = OnceLock::new(); + +fn get_l1_reservation_holder() -> &'static std::sync::Mutex> { + L1_CACHE_RESERVATION.get_or_init(|| std::sync::Mutex::new(None)) +} + fn get_l1_cache_holder() -> &'static std::sync::RwLock> { GLOBAL_L1_CACHE.get_or_init(|| std::sync::RwLock::new(None)) } @@ -53,6 +61,11 @@ pub fn reset_global_l1_cache() { *guard = None; } + // Release the memory reservation + { + *get_l1_reservation_holder().lock().unwrap() = None; + } + // Then clear the cache itself (will be recreated on next access) { let holder = get_l1_cache_holder(); @@ -93,6 +106,24 @@ pub fn get_or_create_global_l1_cache() -> Option { // Create bounded L1 cache with configurable capacity let capacity = get_l1_cache_capacity_bytes(); + + // Reserve memory from the global pool for L1 cache — fail-fast if denied + let reservation = match MemoryReservation::try_new( + &memory_pool::global_pool(), + capacity as usize, + "l1_cache", + ) { + Ok(r) => r, + Err(e) => { + debug_println!( + "❌ GLOBAL_L1_CACHE: Memory pool denied L1 reservation of {} MB: {}. L1 cache will not be created.", + capacity / 1024 / 1024, e + ); + return None; + } + }; + *get_l1_reservation_holder().lock().unwrap() = Some(reservation); + let cache = ByteRangeCache::with_capacity( capacity, &quickwit_storage::STORAGE_METRICS.shortlived_cache, diff --git a/native/src/iceberg_reader/distributed.rs b/native/src/iceberg_reader/distributed.rs index 33fd650..e25ae88 100644 --- a/native/src/iceberg_reader/distributed.rs +++ b/native/src/iceberg_reader/distributed.rs @@ -360,7 +360,18 @@ pub fn read_iceberg_manifest_arrow_ffi( } } - // 5. Build FFI structs in a temporary vec first, then write all at once. + // 5. 
Memory pool reservation for Arrow FFI data + let estimated_size: usize = arrays.iter().map(|(a, _)| a.get_buffer_memory_size()).sum(); + let _reservation = crate::memory_pool::MemoryReservation::try_new( + &crate::memory_pool::global_pool(), + estimated_size, + "arrow_ffi", + ) + .unwrap_or_else(|_| { + crate::memory_pool::MemoryReservation::empty(&crate::memory_pool::global_pool(), "arrow_ffi") + }); + + // 6. Build FFI structs in a temporary vec first, then write all at once. // If any schema conversion fails, nothing is written. let mut ffi_pairs: Vec<(FFI_ArrowArray, FFI_ArrowSchema)> = Vec::with_capacity(NUM_COLS); for (i, (array, field)) in arrays.iter().enumerate() { @@ -371,7 +382,7 @@ pub fn read_iceberg_manifest_arrow_ffi( ffi_pairs.push((ffi_array, ffi_schema)); } - // 6. All FFI structs built successfully — write them all out. + // 7. All FFI structs built successfully — write them all out. for (i, (ffi_array, ffi_schema)) in ffi_pairs.into_iter().enumerate() { let array_ptr = array_addrs[i] as *mut FFI_ArrowArray; let schema_ptr = schema_addrs[i] as *mut FFI_ArrowSchema; diff --git a/native/src/index.rs b/native/src/index.rs index b62545a..6238b65 100644 --- a/native/src/index.rs +++ b/native/src/index.rs @@ -27,8 +27,17 @@ use tantivy::directory::MmapDirectory; use tantivy::tokenizer::{SimpleTokenizer, WhitespaceTokenizer, LowerCaser, RemoveLongFilter, TextAnalyzer as TantivyAnalyzer}; use crate::utils::{handle_error, with_arc_safe, arc_to_jlong, release_arc}; use crate::text_analyzer::DEFAULT_MAX_TOKEN_LENGTH; +use crate::memory_pool::{self, MemoryReservation}; +use std::collections::HashMap; use std::sync::{Arc, Mutex}; +use once_cell::sync::Lazy; + +/// Tracks MemoryReservations for IndexWriter instances, keyed by their registry ID. +/// When a writer is closed, its reservation is removed and dropped, releasing memory. 
+pub(crate) static WRITER_RESERVATIONS: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + #[no_mangle] pub extern "system" fn Java_io_indextables_tantivy4java_core_Index_nativeNew( mut env: JNIEnv, @@ -172,21 +181,38 @@ pub extern "system" fn Java_io_indextables_tantivy4java_core_Index_nativeWriter( heap_size: jint, num_threads: jint, ) -> jlong { + let heap_size_bytes = if heap_size > 0 { heap_size as usize } else { 50_000_000 }; // 50MB default + let num_threads_val = if num_threads > 0 { num_threads as usize } else { 1 }; + + // Reserve memory from the global pool for this writer's heap + let reservation = match MemoryReservation::try_new( + &memory_pool::global_pool(), + heap_size_bytes, + "index_writer", + ) { + Ok(r) => r, + Err(e) => { + handle_error(&mut env, &format!("Memory pool denied IndexWriter allocation: {}", e)); + return 0; + } + }; + let result = with_arc_safe::, Result>(ptr, |index_mutex| { let index = index_mutex.lock().unwrap(); - let heap_size_bytes = if heap_size > 0 { heap_size as usize } else { 50_000_000 }; // 50MB default - let num_threads_val = if num_threads > 0 { num_threads as usize } else { 1 }; - index.writer_with_num_threads(num_threads_val, heap_size_bytes) .map_err(|e| e.to_string()) }); - + match result { Some(Ok(writer)) => { let writer_arc = Arc::new(Mutex::new(writer)); - arc_to_jlong(writer_arc) + let writer_id = arc_to_jlong(writer_arc); + // Store reservation keyed by writer ID — released on nativeClose + WRITER_RESERVATIONS.lock().unwrap().insert(writer_id, reservation); + writer_id }, Some(Err(err)) => { + // reservation is dropped here, releasing the memory handle_error(&mut env, &err); 0 }, diff --git a/native/src/lib.rs b/native/src/lib.rs index 295da3d..dcccb73 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -22,6 +22,7 @@ use jni::sys::jstring; use jni::JNIEnv; mod debug; // Debug utilities and conditional logging +pub mod memory_pool; // Unified JVM-coordinated memory management mod runtime_manager; 
// Global Quickwit runtime manager for async-first architecture mod schema; mod document; @@ -68,6 +69,7 @@ pub extern "system" fn Java_io_indextables_tantivy4java_core_Tantivy_getVersion( _class: JClass, ) -> jstring { utils::install_panic_hook(); + utils::set_jvm(&env); let version = env.new_string("0.24.0").unwrap(); version.into_raw() } diff --git a/native/src/memory_pool/disk_cache_budget.rs b/native/src/memory_pool/disk_cache_budget.rs new file mode 100644 index 0000000..d0b117a --- /dev/null +++ b/native/src/memory_pool/disk_cache_budget.rs @@ -0,0 +1,468 @@ +// memory_pool/disk_cache_budget.rs - Staircase-up/cliff-down memory budget for L2 disk cache +// +// Pattern: +// - Start: acquire `base_grant` (configured max) from pool +// - Grow: when queue needs more, acquire `grow_increment` (500MB) chunks +// - Cliff-down: when queue drains to 0, release all overflow above base_grant +// - Drop: release everything including base_grant +// +// This minimizes JNI round-trips by keeping the base grant permanently and only +// making JNI calls when growing or when the queue empties completely. + +use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use crate::debug_println; +use super::pool::MemoryPool; + +/// Default grow increment: 500MB +const DEFAULT_GROW_INCREMENT: usize = 500 * 1024 * 1024; + +/// Default max grant multiplier: 8x the base grant +const DEFAULT_MAX_GRANT_MULTIPLIER: usize = 8; + +/// Memory budget for the L2 disk cache write queue. 
+/// +/// Manages a staircase-up/cliff-down grant pattern: +/// - Acquires `base_grant` upfront at creation +/// - Grows in `grow_increment` steps when needed (up to `max_grant` cap) +/// - Releases overflow (above base) when queue drains completely +/// - Releases base grant only on Drop +pub struct DiskCacheMemoryBudget { + pool: Arc, + /// Base grant that is retained permanently (configured max write queue size) + base_grant: usize, + /// Current total granted from pool (base + overflow increments) + total_granted: AtomicUsize, + /// Size of each growth increment + grow_increment: usize, + /// Maximum total grant (hard cap). Growth beyond this is denied. + max_grant: usize, + /// Serializes ensure_capacity and on_queue_drained to prevent races + /// between concurrent grow + drain operations on total_granted. + op_lock: Mutex<()>, +} + +impl DiskCacheMemoryBudget { + /// Create a new budget, acquiring `base_grant` bytes from the pool. + /// Max grant defaults to 8x the base grant. + /// + /// If the pool denies the base grant, returns a budget with 0 base + /// (best-effort operation continues without memory tracking). + pub fn new(pool: &Arc, base_grant: usize) -> Self { + Self::with_config(pool, base_grant, DEFAULT_GROW_INCREMENT, 0) + } + + /// Create with a custom grow increment. Max grant defaults to 8x the base grant. + pub fn with_increment( + pool: &Arc, + base_grant: usize, + grow_increment: usize, + ) -> Self { + Self::with_config(pool, base_grant, grow_increment, 0) + } + + /// Create with custom grow increment and max grant cap. + /// If `max_grant` is 0, defaults to `DEFAULT_MAX_GRANT_MULTIPLIER * base_grant`. 
+ pub fn with_config( + pool: &Arc, + base_grant: usize, + grow_increment: usize, + max_grant: usize, + ) -> Self { + let actual_base = if base_grant > 0 { + match pool.try_acquire(base_grant, "l2_write_queue") { + Ok(()) => { + debug_println!( + "📊 DiskCacheMemoryBudget: Acquired base grant of {} MB", + base_grant / 1024 / 1024 + ); + base_grant + } + Err(e) => { + debug_println!( + "⚠️ DiskCacheMemoryBudget: Pool denied base grant of {} MB: {}. Operating untracked.", + base_grant / 1024 / 1024, e + ); + 0 + } + } + } else { + 0 + }; + + let effective_max = if max_grant > 0 { + max_grant.max(actual_base) // max_grant must be >= base + } else { + actual_base.saturating_mul(DEFAULT_MAX_GRANT_MULTIPLIER) + }; + + debug_println!( + "📊 DiskCacheMemoryBudget: base={} MB, max={} MB, increment={} MB", + actual_base / 1024 / 1024, + effective_max / 1024 / 1024, + grow_increment / 1024 / 1024 + ); + + Self { + pool: Arc::clone(pool), + base_grant: actual_base, + total_granted: AtomicUsize::new(actual_base), + grow_increment, + max_grant: effective_max, + op_lock: Mutex::new(()), + } + } + + /// Ensure there is enough grant for `needed_bytes`. + /// If current grant is insufficient, acquires more in `grow_increment` chunks, + /// up to the `max_grant` cap. Returns true if sufficient grant is available, false if denied. 
+ pub fn ensure_capacity(&self, needed_bytes: usize) -> bool { + let _guard = self.op_lock.lock().unwrap(); + + let current = self.total_granted.load(Ordering::Acquire); + if current >= needed_bytes { + return true; + } + + // Check if we've already hit the cap + if current >= self.max_grant { + debug_println!( + "⚠️ DiskCacheMemoryBudget: At max grant cap ({} MB), cannot grow", + self.max_grant / 1024 / 1024 + ); + return false; + } + + // Need to grow — calculate how many increments, clamped to max_grant + let deficit = needed_bytes - current; + let increments = (deficit + self.grow_increment - 1) / self.grow_increment; + let mut grow_amount = increments * self.grow_increment; + + // Clamp so total_granted + grow_amount <= max_grant + let headroom = self.max_grant - current; + if grow_amount > headroom { + grow_amount = headroom; + } + + // If clamped growth won't satisfy the need, don't bother growing + if grow_amount == 0 || current + grow_amount < needed_bytes { + debug_println!( + "⚠️ DiskCacheMemoryBudget: Cannot satisfy {} MB (max cap {} MB)", + needed_bytes / 1024 / 1024, + self.max_grant / 1024 / 1024 + ); + return false; + } + + match self.pool.try_acquire(grow_amount, "l2_write_queue") { + Ok(()) => { + self.total_granted.fetch_add(grow_amount, Ordering::Release); + debug_println!( + "📊 DiskCacheMemoryBudget: Grew by {} MB (total now {} MB, max {} MB)", + grow_amount / 1024 / 1024, + (current + grow_amount) / 1024 / 1024, + self.max_grant / 1024 / 1024 + ); + true + } + Err(_) => { + debug_println!( + "⚠️ DiskCacheMemoryBudget: Pool denied growth of {} MB", + grow_amount / 1024 / 1024 + ); + false + } + } + } + + /// Called when the write queue drains completely (queued_bytes == 0). + /// Releases all overflow above the base grant back to the pool. 
+ pub fn on_queue_drained(&self) { + let _guard = self.op_lock.lock().unwrap(); + + let current = self.total_granted.load(Ordering::Acquire); + if current > self.base_grant { + let overflow = current - self.base_grant; + self.pool.release(overflow, "l2_write_queue"); + self.total_granted.store(self.base_grant, Ordering::Release); + debug_println!( + "📊 DiskCacheMemoryBudget: Queue drained, released {} MB overflow (retained {} MB base)", + overflow / 1024 / 1024, + self.base_grant / 1024 / 1024 + ); + } + } + + /// Current total granted bytes. + pub fn total_granted(&self) -> usize { + self.total_granted.load(Ordering::Relaxed) + } + + /// Base grant bytes (permanent floor). + pub fn base_grant(&self) -> usize { + self.base_grant + } + + /// Maximum grant cap (hard ceiling for growth). + pub fn max_grant(&self) -> usize { + self.max_grant + } +} + +impl Drop for DiskCacheMemoryBudget { + fn drop(&mut self) { + let total = self.total_granted.load(Ordering::Acquire); + if total > 0 { + self.pool.release(total, "l2_write_queue"); + debug_println!( + "📊 DiskCacheMemoryBudget: Released all {} MB on shutdown", + total / 1024 / 1024 + ); + } + } +} + +impl std::fmt::Debug for DiskCacheMemoryBudget { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DiskCacheMemoryBudget") + .field("base_grant", &self.base_grant) + .field("total_granted", &self.total_granted.load(Ordering::Relaxed)) + .field("max_grant", &self.max_grant) + .field("grow_increment", &self.grow_increment) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::memory_pool::UnlimitedMemoryPool; + + #[test] + fn test_budget_basic_lifecycle() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let base = 100 * 1024 * 1024; // 100MB + + { + let budget = DiskCacheMemoryBudget::new(&pool, base); + assert_eq!(budget.base_grant(), base); + assert_eq!(budget.total_granted(), base); + assert_eq!(pool.used(), base); + } + // Dropped — all released + 
assert_eq!(pool.used(), 0); + } + + #[test] + fn test_budget_grow_and_drain() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let base = 100 * 1024 * 1024; + let increment = 50 * 1024 * 1024; + + let budget = DiskCacheMemoryBudget::with_increment(&pool, base, increment); + assert_eq!(pool.used(), base); + + // Need more than base + let needed = base + 30 * 1024 * 1024; + assert!(budget.ensure_capacity(needed)); + assert_eq!(budget.total_granted(), base + increment); // Grew by one increment + assert_eq!(pool.used(), base + increment); + + // Queue drains — release overflow back to base + budget.on_queue_drained(); + assert_eq!(budget.total_granted(), base); + assert_eq!(pool.used(), base); + } + + #[test] + fn test_budget_multiple_grows() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let base = 50 * 1024 * 1024; + let increment = 25 * 1024 * 1024; + + let budget = DiskCacheMemoryBudget::with_increment(&pool, base, increment); + + // Need 120MB total — should grow by 3 increments (75MB) + assert!(budget.ensure_capacity(120 * 1024 * 1024)); + assert_eq!(budget.total_granted(), base + 3 * increment); + + // Drain — back to base + budget.on_queue_drained(); + assert_eq!(budget.total_granted(), base); + assert_eq!(pool.used(), base); + } + + #[test] + fn test_budget_drain_at_base_is_noop() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let base = 100 * 1024 * 1024; + + let budget = DiskCacheMemoryBudget::new(&pool, base); + budget.on_queue_drained(); // No overflow to release + assert_eq!(budget.total_granted(), base); + assert_eq!(pool.used(), base); + } + + #[test] + fn test_budget_zero_base() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let budget = DiskCacheMemoryBudget::new(&pool, 0); + assert_eq!(budget.base_grant(), 0); + assert_eq!(budget.total_granted(), 0); + assert_eq!(pool.used(), 0); + } + + // ======================================================================== + // Fail-fast denial 
tests + // ======================================================================== + + #[test] + fn test_budget_base_denied_when_pool_full() { + use crate::memory_pool::LimitedMemoryPool; + let pool: Arc = Arc::new(LimitedMemoryPool::new(50 * 1024 * 1024)); + + // Request base grant larger than pool capacity + let budget = DiskCacheMemoryBudget::new(&pool, 100 * 1024 * 1024); + // Budget should fall back to 0 base when pool denies + assert_eq!(budget.base_grant(), 0); + assert_eq!(budget.total_granted(), 0); + assert_eq!(pool.used(), 0); + } + + #[test] + fn test_budget_ensure_capacity_denied_returns_false() { + use crate::memory_pool::LimitedMemoryPool; + // Pool has 200MB capacity + let pool: Arc = Arc::new(LimitedMemoryPool::new(200 * 1024 * 1024)); + let base = 100 * 1024 * 1024; + let increment = 50 * 1024 * 1024; + + let budget = DiskCacheMemoryBudget::with_increment(&pool, base, increment); + assert_eq!(budget.base_grant(), base); + + // Need 250MB — pool only has 200MB total (100MB already used for base) + // Growth would need 150MB (3 increments) but only 100MB available + let result = budget.ensure_capacity(250 * 1024 * 1024); + assert!(!result, "Should return false when pool denies growth"); + + // total_granted should not have changed (no partial growth) + assert_eq!(budget.total_granted(), base); + assert_eq!(pool.used(), base); + } + + // ======================================================================== + // Max grant cap tests + // ======================================================================== + + #[test] + fn test_budget_default_max_grant_is_8x_base() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let base = 100 * 1024 * 1024; // 100MB + + let budget = DiskCacheMemoryBudget::new(&pool, base); + assert_eq!(budget.max_grant(), base * 8); // Default: 8x + } + + #[test] + fn test_budget_custom_max_grant() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let base = 100 * 1024 * 1024; + let max = 300 * 
1024 * 1024; // 300MB cap + + let budget = DiskCacheMemoryBudget::with_config( + &pool, base, 50 * 1024 * 1024, max, + ); + assert_eq!(budget.max_grant(), max); + } + + #[test] + fn test_budget_growth_clamped_by_max_grant() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let base = 100 * 1024 * 1024; // 100MB + let increment = 50 * 1024 * 1024; // 50MB + let max = 200 * 1024 * 1024; // 200MB cap + + let budget = DiskCacheMemoryBudget::with_config(&pool, base, increment, max); + assert_eq!(budget.base_grant(), base); + assert_eq!(budget.max_grant(), max); + + // Request 250MB — exceeds max_grant (200MB). Cannot satisfy, no growth occurs. + let result = budget.ensure_capacity(250 * 1024 * 1024); + assert!(!result); + // No partial growth — budget stays at base + assert_eq!(budget.total_granted(), base); + assert_eq!(pool.used(), base); + + // Request exactly 200MB (at the cap) — succeeds with 100MB growth + assert!(budget.ensure_capacity(200 * 1024 * 1024)); + assert_eq!(budget.total_granted(), max); + assert_eq!(pool.used(), max); + } + + #[test] + fn test_budget_growth_stops_at_max_grant() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let base = 50 * 1024 * 1024; // 50MB + let increment = 25 * 1024 * 1024; // 25MB + let max = 100 * 1024 * 1024; // 100MB cap + + let budget = DiskCacheMemoryBudget::with_config(&pool, base, increment, max); + + // Grow to 75MB (one increment) + assert!(budget.ensure_capacity(75 * 1024 * 1024)); + assert_eq!(budget.total_granted(), 75 * 1024 * 1024); + + // Grow to 100MB (another increment, hits cap exactly) + assert!(budget.ensure_capacity(100 * 1024 * 1024)); + assert_eq!(budget.total_granted(), 100 * 1024 * 1024); + + // Try to grow beyond cap — denied + assert!(!budget.ensure_capacity(101 * 1024 * 1024)); + assert_eq!(budget.total_granted(), 100 * 1024 * 1024); + } + + #[test] + fn test_budget_drain_resets_then_can_grow_again() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let 
base = 50 * 1024 * 1024; + let increment = 25 * 1024 * 1024; + let max = 100 * 1024 * 1024; + + let budget = DiskCacheMemoryBudget::with_config(&pool, base, increment, max); + + // Grow to 100MB (cap) + assert!(budget.ensure_capacity(100 * 1024 * 1024)); + assert_eq!(budget.total_granted(), max); + + // Drain — back to base + budget.on_queue_drained(); + assert_eq!(budget.total_granted(), base); + assert_eq!(pool.used(), base); + + // Can grow again after drain + assert!(budget.ensure_capacity(75 * 1024 * 1024)); + assert_eq!(budget.total_granted(), 75 * 1024 * 1024); + } + + #[test] + fn test_budget_max_grant_at_least_base() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let base = 100 * 1024 * 1024; + // Set max_grant smaller than base — should be raised to base + let budget = DiskCacheMemoryBudget::with_config( + &pool, base, 50 * 1024 * 1024, 50 * 1024 * 1024, + ); + assert_eq!(budget.max_grant(), base); // Raised to base + } + + #[test] + fn test_budget_zero_base_zero_max() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let budget = DiskCacheMemoryBudget::with_config(&pool, 0, 50 * 1024 * 1024, 0); + assert_eq!(budget.base_grant(), 0); + assert_eq!(budget.max_grant(), 0); // 8 * 0 = 0 + // Can't grow at all + assert!(!budget.ensure_capacity(1)); + } +} diff --git a/native/src/memory_pool/jni_bridge.rs b/native/src/memory_pool/jni_bridge.rs new file mode 100644 index 0000000..b31e524 --- /dev/null +++ b/native/src/memory_pool/jni_bridge.rs @@ -0,0 +1,256 @@ +// memory_pool/jni_bridge.rs - JNI functions for NativeMemoryManager Java class + +use std::sync::Arc; + +use jni::objects::{JClass, JObject}; +use jni::sys::{jboolean, jlong, jobject, JNI_FALSE, JNI_TRUE}; +use jni::JNIEnv; + +use super::jvm_pool::{JvmMemoryPool, JvmPoolConfig}; +use super::{set_global_pool, global_pool, is_pool_configured}; +use crate::debug_println; + +/// Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeSetAccountant +/// +/// Sets the 
global memory pool to a JVM-backed pool using the provided NativeMemoryAccountant. +/// Must be called before any native operations that use memory tracking. +/// +/// Returns true if set successfully, false if already set. +#[no_mangle] +pub extern "system" fn Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeSetAccountant( + mut env: JNIEnv, + _class: JClass, + accountant: JObject, + high_watermark: jni::sys::jdouble, + low_watermark: jni::sys::jdouble, + acquire_increment_bytes: jlong, + min_release_bytes: jlong, +) -> jboolean { + // Ensure JavaVM is captured for later JNI callbacks from JvmMemoryPool + crate::utils::set_jvm(&env); + + if accountant.is_null() { + debug_println!("MEMORY_POOL: nativeSetAccountant called with null accountant"); + return JNI_FALSE; + } + + if is_pool_configured() { + debug_println!("MEMORY_POOL: Pool already configured, ignoring set request"); + return JNI_FALSE; + } + + // Create a global reference so the accountant isn't GC'd + let global_ref = match env.new_global_ref(accountant) { + Ok(r) => r, + Err(e) => { + debug_println!("MEMORY_POOL: Failed to create global ref: {}", e); + return JNI_FALSE; + } + }; + + let config = JvmPoolConfig { + high_watermark: if high_watermark > 0.0 { high_watermark } else { 0.90 }, + low_watermark: if low_watermark > 0.0 { low_watermark } else { 0.25 }, + acquire_increment: if acquire_increment_bytes > 0 { + acquire_increment_bytes as usize + } else { + 64 * 1024 * 1024 + }, + min_release_amount: if min_release_bytes > 0 { + min_release_bytes as usize + } else { + 64 * 1024 * 1024 + }, + }; + + debug_println!( + "MEMORY_POOL: Creating JvmMemoryPool with config: {:?}", + config + ); + + match JvmMemoryPool::new(&mut env, global_ref, config) { + Ok(pool) => { + match set_global_pool(Arc::new(pool)) { + Ok(()) => { + debug_println!("MEMORY_POOL: Global JVM memory pool set successfully"); + JNI_TRUE + } + Err(_) => { + debug_println!("MEMORY_POOL: Failed to set pool (already set)"); + 
JNI_FALSE + } + } + } + Err(e) => { + debug_println!("MEMORY_POOL: Failed to create JvmMemoryPool: {}", e); + JNI_FALSE + } + } +} + +/// Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeGetUsedBytes +#[no_mangle] +pub extern "system" fn Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeGetUsedBytes( + _env: JNIEnv, + _class: JClass, +) -> jlong { + global_pool().used() as jlong +} + +/// Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeGetPeakBytes +#[no_mangle] +pub extern "system" fn Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeGetPeakBytes( + _env: JNIEnv, + _class: JClass, +) -> jlong { + global_pool().peak() as jlong +} + +/// Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeGetGrantedBytes +#[no_mangle] +pub extern "system" fn Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeGetGrantedBytes( + _env: JNIEnv, + _class: JClass, +) -> jlong { + let granted = global_pool().granted(); + if granted == usize::MAX { + -1 // Signal unlimited + } else { + granted as jlong + } +} + +/// Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeIsConfigured +#[no_mangle] +pub extern "system" fn Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeIsConfigured( + _env: JNIEnv, + _class: JClass, +) -> jboolean { + if is_pool_configured() { + JNI_TRUE + } else { + JNI_FALSE + } +} + +/// Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeResetPeak +/// +/// Resets the peak usage counter to current usage. Returns the old peak value. +/// Useful for monitoring windows — call at the start of each window to track +/// per-window peak usage. 
+#[no_mangle] +pub extern "system" fn Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeResetPeak( + _env: JNIEnv, + _class: JClass, +) -> jlong { + let old_peak = global_pool().reset_peak(); + debug_println!( + "📊 MEMORY_POOL: Peak reset (old peak: {} bytes, current used: {} bytes)", + old_peak, global_pool().used() + ); + old_peak as jlong +} + +/// Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeGetCategoryBreakdown +/// +/// Returns a Java HashMap with per-category memory usage. +#[no_mangle] +pub extern "system" fn Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeGetCategoryBreakdown( + mut env: JNIEnv, + _class: JClass, +) -> jobject { + let breakdown = global_pool().category_breakdown(); + + // Create Java HashMap + let hashmap_class = match env.find_class("java/util/HashMap") { + Ok(c) => c, + Err(_) => return std::ptr::null_mut(), + }; + let hashmap = match env.new_object(&hashmap_class, "()V", &[]) { + Ok(m) => m, + Err(_) => return std::ptr::null_mut(), + }; + + for (category, bytes) in &breakdown { + let key = match env.new_string(category) { + Ok(k) => k, + Err(_) => continue, + }; + let value = match env.new_object( + "java/lang/Long", + "(J)V", + &[jni::objects::JValue::Long(*bytes as i64)], + ) { + Ok(v) => v, + Err(_) => continue, + }; + + let _ = env.call_method( + &hashmap, + "put", + "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", + &[(&key).into(), (&value).into()], + ); + } + + hashmap.into_raw() +} + +/// Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeGetCategoryPeakBreakdown +/// +/// Returns a Java HashMap with per-category peak memory usage. +/// Unlike getCategoryBreakdown() which only shows currently-held memory, +/// this returns the maximum each category has ever held, even if now zero. 
+#[no_mangle] +pub extern "system" fn Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeGetCategoryPeakBreakdown( + mut env: JNIEnv, + _class: JClass, +) -> jobject { + let breakdown = global_pool().category_peak_breakdown(); + + let hashmap_class = match env.find_class("java/util/HashMap") { + Ok(c) => c, + Err(_) => return std::ptr::null_mut(), + }; + let hashmap = match env.new_object(&hashmap_class, "()V", &[]) { + Ok(m) => m, + Err(_) => return std::ptr::null_mut(), + }; + + for (category, bytes) in &breakdown { + let key = match env.new_string(category) { + Ok(k) => k, + Err(_) => continue, + }; + let value = match env.new_object( + "java/lang/Long", + "(J)V", + &[jni::objects::JValue::Long(*bytes as i64)], + ) { + Ok(v) => v, + Err(_) => continue, + }; + let _ = env.call_method( + &hashmap, + "put", + "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", + &[(&key).into(), (&value).into()], + ); + } + + hashmap.into_raw() +} + +/// Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeShutdown +/// +/// Signals that the JVM is shutting down. After this call, the pool will skip +/// JNI release callbacks to avoid calling releaseMemory() outside of task context. +/// Should be called before any shutdown hooks that trigger native resource cleanup. 
+#[no_mangle] +pub extern "system" fn Java_io_indextables_tantivy4java_memory_NativeMemoryManager_nativeShutdown( + _env: JNIEnv, + _class: JClass, +) { + debug_println!("📊 MEMORY_POOL: Shutdown signaled — skipping future JNI release callbacks"); + global_pool().shutdown(); +} diff --git a/native/src/memory_pool/jvm_pool.rs b/native/src/memory_pool/jvm_pool.rs new file mode 100644 index 0000000..bc75141 --- /dev/null +++ b/native/src/memory_pool/jvm_pool.rs @@ -0,0 +1,436 @@ +// memory_pool/jvm_pool.rs - JVM-backed memory pool with high/low watermark batching + +use std::collections::HashMap; +use std::fmt; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering::Relaxed}; +use std::sync::Mutex; + +use jni::objects::{GlobalRef, JMethodID, JValue}; +use jni::signature::ReturnType; +use jni::JNIEnv; + +use super::pool::{MemoryError, MemoryPool}; +use crate::debug_println; + +/// Default high watermark: acquire more from JVM when usage exceeds 90% of grant. +const DEFAULT_HIGH_WATERMARK: f64 = 0.90; +/// Default low watermark: release excess to JVM when usage drops below 25% of grant. +const DEFAULT_LOW_WATERMARK: f64 = 0.25; +/// Default minimum JNI acquire chunk: 64MB. +const DEFAULT_ACQUIRE_INCREMENT: usize = 64 * 1024 * 1024; +/// Default minimum amount to release back: 64MB. +const DEFAULT_MIN_RELEASE_AMOUNT: usize = 64 * 1024 * 1024; + +/// Configuration for JvmMemoryPool watermark behavior. +#[derive(Debug, Clone)] +pub struct JvmPoolConfig { + pub high_watermark: f64, + pub low_watermark: f64, + pub acquire_increment: usize, + pub min_release_amount: usize, +} + +impl Default for JvmPoolConfig { + fn default() -> Self { + Self { + high_watermark: DEFAULT_HIGH_WATERMARK, + low_watermark: DEFAULT_LOW_WATERMARK, + acquire_increment: DEFAULT_ACQUIRE_INCREMENT, + min_release_amount: DEFAULT_MIN_RELEASE_AMOUNT, + } + } +} + +/// A memory pool that coordinates with a Java NativeMemoryAccountant via JNI. 
+/// +/// Uses high/low watermark batching to minimize JNI round-trips: +/// - Most reserve/release calls are pure atomic operations (zero JNI). +/// - JNI calls happen only when usage crosses watermark thresholds. +/// +/// Thread-safe: all state is atomic or behind Mutex. +pub struct JvmMemoryPool { + /// Reference to the Java NativeMemoryAccountant object. + jvm_ref: GlobalRef, + /// Cached JNI method ID for acquireMemory(long) -> long. + acquire_mid: JMethodID, + /// Cached JNI method ID for releaseMemory(long) -> void. + release_mid: JMethodID, + + // Authoritative state + /// Total bytes the JVM has granted us. + jvm_granted: AtomicUsize, + /// Total bytes Rust code has reserved from us. + rust_used: AtomicUsize, + /// Peak usage observed. + peak: AtomicUsize, + + // Per-category tracking (current + peak) + categories: Mutex>, + + // Watermark configuration + config: JvmPoolConfig, + + // Mutex to serialize JNI calls (JNI method IDs are not Send in some impls) + jni_lock: Mutex<()>, + + /// When true, skip JNI release callbacks. Set during JVM shutdown to avoid + /// calling releaseMemory() outside of a task context (e.g., on shutdown hook + /// threads where Spark's TaskContext is unavailable). + shutting_down: AtomicBool, +} + +// Safety: JMethodID is a pointer that is valid for the lifetime of the JVM. +// GlobalRef prevents the Java object from being garbage collected. +// We serialize JNI calls through jni_lock. +unsafe impl Send for JvmMemoryPool {} +unsafe impl Sync for JvmMemoryPool {} + +impl fmt::Debug for JvmMemoryPool { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JvmMemoryPool") + .field("jvm_granted", &self.jvm_granted.load(Relaxed)) + .field("rust_used", &self.rust_used.load(Relaxed)) + .field("peak", &self.peak.load(Relaxed)) + .field("config", &self.config) + .finish() + } +} + +impl JvmMemoryPool { + /// Create a new JvmMemoryPool from a Java NativeMemoryAccountant object. 
+ /// + /// # Arguments + /// * `env` - JNI environment (used only during construction to cache method IDs) + /// * `accountant` - GlobalRef to Java NativeMemoryAccountant object + /// * `config` - Watermark configuration + pub fn new( + env: &mut JNIEnv, + accountant: GlobalRef, + config: JvmPoolConfig, + ) -> Result { + // Cache method IDs for acquireMemory and releaseMemory + let class = env + .get_object_class(&accountant) + .map_err(|e| MemoryError::JniError(format!("Failed to get accountant class: {}", e)))?; + + let acquire_mid = env + .get_method_id(&class, "acquireMemory", "(J)J") + .map_err(|e| { + MemoryError::JniError(format!("Failed to find acquireMemory method: {}", e)) + })?; + + let release_mid = env + .get_method_id(&class, "releaseMemory", "(J)V") + .map_err(|e| { + MemoryError::JniError(format!("Failed to find releaseMemory method: {}", e)) + })?; + + Ok(Self { + jvm_ref: accountant, + acquire_mid, + release_mid, + jvm_granted: AtomicUsize::new(0), + rust_used: AtomicUsize::new(0), + peak: AtomicUsize::new(0), + categories: Mutex::new(HashMap::new()), + config, + jni_lock: Mutex::new(()), + shutting_down: AtomicBool::new(false), + }) + } + + /// Execute a closure with the JNI environment for the current thread. + /// + /// This keeps the `JNIEnv` borrow scoped to the closure, avoiding the need + /// for an unsafe lifetime transmute. The `jni_lock` is held for the duration, + /// serializing all JNI calls. + fn with_jni_env(&self, f: impl FnOnce(&mut JNIEnv) -> R) -> Result { + let _lock = self.jni_lock.lock().unwrap(); + + let jvm = crate::utils::get_jvm().ok_or_else(|| { + MemoryError::JniError("JavaVM not available".to_string()) + })?; + + let mut env = jvm.attach_current_thread_permanently().map_err(|e| { + MemoryError::JniError(format!("Failed to attach thread: {}", e)) + })?; + + Ok(f(&mut env)) + } + + /// Call Java acquireMemory(bytes) → returns actual bytes granted. 
+ fn jni_acquire(&self, bytes: usize) -> Result { + self.with_jni_env(|env| { + let result = unsafe { + env.call_method_unchecked( + &self.jvm_ref, + self.acquire_mid, + ReturnType::Primitive(jni::signature::Primitive::Long), + &[JValue::Long(bytes as i64).as_jni()], + ) + }; + + match result { + Ok(val) => { + if env.exception_check().unwrap_or(false) { + env.exception_clear().ok(); + return Err(MemoryError::JniError( + "Java exception during acquireMemory".to_string(), + )); + } + let acquired = val.j().map_err(|e| { + MemoryError::JniError(format!("Failed to extract long result: {}", e)) + })?; + Ok(acquired as usize) + } + Err(e) => { + env.exception_clear().ok(); + Err(MemoryError::JniError(format!( + "JNI acquireMemory call failed: {}", + e + ))) + } + } + })? + } + + /// Call Java releaseMemory(bytes). + fn jni_release(&self, bytes: usize) -> Result<(), MemoryError> { + self.with_jni_env(|env| { + let result = unsafe { + env.call_method_unchecked( + &self.jvm_ref, + self.release_mid, + ReturnType::Primitive(jni::signature::Primitive::Void), + &[JValue::Long(bytes as i64).as_jni()], + ) + }; + + match result { + Ok(_) => { + if env.exception_check().unwrap_or(false) { + env.exception_clear().ok(); + return Err(MemoryError::JniError( + "Java exception during releaseMemory".to_string(), + )); + } + Ok(()) + } + Err(e) => { + env.exception_clear().ok(); + Err(MemoryError::JniError(format!( + "JNI releaseMemory call failed: {}", + e + ))) + } + } + })? 
+ } + + fn update_category(&self, category: &'static str, delta: isize) { + let mut cats = self.categories.lock().unwrap(); + let tracker = cats + .entry(category) + .or_insert_with(super::pool::CategoryTracker::new); + if delta > 0 { + let new_val = tracker.current.fetch_add(delta as usize, Relaxed) + delta as usize; + // Update per-category peak + let mut old_peak = tracker.peak.load(Relaxed); + while new_val > old_peak { + match tracker.peak.compare_exchange_weak(old_peak, new_val, Relaxed, Relaxed) { + Ok(_) => break, + Err(actual) => old_peak = actual, + } + } + } else { + tracker.current.fetch_sub((-delta) as usize, Relaxed); + } + } + + fn update_peak(&self) { + let current = self.rust_used.load(Relaxed); + let mut old_peak = self.peak.load(Relaxed); + // CAS loop converges quickly: each retry means another thread updated + // peak to a higher value, and once old_peak >= current the loop exits. + while current > old_peak { + match self.peak.compare_exchange_weak(old_peak, current, Relaxed, Relaxed) { + Ok(_) => break, + Err(actual) => old_peak = actual, + } + } + } + + /// Check if we need to acquire more from JVM (usage crossed high watermark). + fn needs_jvm_acquire(&self, new_used: usize) -> bool { + let granted = self.jvm_granted.load(Relaxed); + if granted == 0 { + return true; // No grant yet, must acquire + } + new_used as f64 > granted as f64 * self.config.high_watermark + } + + /// Check if we should release excess to JVM (usage crossed low watermark). 
+ fn should_jvm_release(&self) -> Option { + let granted = self.jvm_granted.load(Relaxed); + let used = self.rust_used.load(Relaxed); + + if granted == 0 { + return None; + } + + // When usage drops to zero, release the entire grant back to JVM + if used == 0 { + if granted >= self.config.min_release_amount { + return Some(granted); + } + return None; + } + + if (used as f64) < (granted as f64 * self.config.low_watermark) { + // Calculate how much to keep: enough headroom above current usage + let target_grant = if self.config.low_watermark > 0.0 { + (used as f64 / self.config.low_watermark) as usize + } else { + used + }; + let excess = granted.saturating_sub(target_grant); + if excess >= self.config.min_release_amount { + return Some(excess); + } + } + None + } +} + +impl MemoryPool for JvmMemoryPool { + fn try_acquire(&self, size: usize, category: &'static str) -> Result<(), MemoryError> { + if size == 0 { + return Ok(()); + } + + // Optimistic update: increment used + let new_used = self.rust_used.fetch_add(size, Relaxed) + size; + + // Check if we need more from JVM + if self.needs_jvm_acquire(new_used) { + let want = std::cmp::max(size, self.config.acquire_increment); + + match self.jni_acquire(want) { + Ok(acquired) if acquired >= size => { + self.jvm_granted.fetch_add(acquired, Relaxed); + } + Ok(acquired) => { + // JVM gave us less than we need — release what we got, undo, fail + if acquired > 0 { + let _ = self.jni_release(acquired); + } + self.rust_used.fetch_sub(size, Relaxed); + return Err(MemoryError::Denied { + requested: size, + available: acquired, + category: category.to_string(), + }); + } + Err(e) => { + self.rust_used.fetch_sub(size, Relaxed); + return Err(e); + } + } + } + + self.update_category(category, size as isize); + self.update_peak(); + Ok(()) + } + + fn release(&self, size: usize, category: &'static str) { + if size == 0 { + return; + } + + self.rust_used.fetch_sub(size, Relaxed); + self.update_category(category, -(size as isize)); + 
+ // Check if we should release excess to JVM. + // Skip JNI release during shutdown — the JVM is exiting and the + // accountant's task context may no longer be available. + if self.shutting_down.load(Relaxed) { + return; + } + + // Cap the release to `size` (the amount being freed by this call) to + // prevent releasing more to the JVM accountant than what the current + // thread's reservation held. The global pool batches acquisitions across + // threads, so the excess can be larger than any single thread's total. + // Without capping, a per-task accountant (e.g., Spark's ExecutionMemoryPool) + // would see releaseMemory(X) where X exceeds what that task acquired. + // Any remaining excess stays in jvm_granted and will be released by + // subsequent operations or on pool shutdown. + if let Some(excess) = self.should_jvm_release() { + let to_release = excess.min(size); + if to_release > 0 { + if let Ok(()) = self.jni_release(to_release) { + self.jvm_granted.fetch_sub(to_release, Relaxed); + } + } + // If JNI release fails, we just keep the grant — no harm done + } + } + + fn used(&self) -> usize { + self.rust_used.load(Relaxed) + } + + fn peak(&self) -> usize { + self.peak.load(Relaxed) + } + + fn reset_peak(&self) -> usize { + let current = self.rust_used.load(Relaxed); + self.peak.swap(current, Relaxed) + } + + fn granted(&self) -> usize { + self.jvm_granted.load(Relaxed) + } + + fn category_breakdown(&self) -> HashMap { + let cats = self.categories.lock().unwrap(); + cats.iter() + .map(|(k, v)| (k.to_string(), v.current.load(Relaxed))) + .filter(|(_, v)| *v > 0) + .collect() + } + + fn category_peak_breakdown(&self) -> HashMap { + let cats = self.categories.lock().unwrap(); + cats.iter() + .map(|(k, v)| (k.to_string(), v.peak.load(Relaxed))) + .filter(|(_, v)| *v > 0) + .collect() + } + + fn shutdown(&self) { + self.shutting_down.store(true, Relaxed); + } +} + +impl Drop for JvmMemoryPool { + fn drop(&mut self) { + // Skip JNI release during shutdown — the JVM 
is exiting and the + // accountant's task context may no longer be available. + if self.shutting_down.load(Relaxed) { + return; + } + // Release all remaining grant back to JVM + let granted = self.jvm_granted.swap(0, Relaxed); + if granted > 0 { + if let Err(e) = self.jni_release(granted) { + debug_println!( + "MEMORY_POOL: Failed to release {} bytes back to JVM during drop: {}", + granted, e + ); + } + } + } +} diff --git a/native/src/memory_pool/mod.rs b/native/src/memory_pool/mod.rs new file mode 100644 index 0000000..bee066b --- /dev/null +++ b/native/src/memory_pool/mod.rs @@ -0,0 +1,101 @@ +// memory_pool/mod.rs - Unified memory management for JVM-coordinated native allocations +// +// This module provides a memory pool that coordinates Rust-side memory allocations +// with an external JVM memory manager (e.g., Spark's TaskMemoryManager), following +// the pattern established by DataFusion Comet's CometUnifiedMemoryPool. +// +// Key design: +// - MemoryPool trait with try_acquire/release + category tracking +// - JvmMemoryPool uses high/low watermark batching to minimize JNI round-trips +// - UnlimitedMemoryPool for backward compatibility (no JVM coordination) +// - MemoryReservation as RAII guard for automatic cleanup +// - Global pool: lazy default (UnlimitedMemoryPool) that can be replaced once +// by set_global_pool before or after first use + +mod pool; +mod reservation; +mod jvm_pool; +mod jni_bridge; +mod disk_cache_budget; + +pub use pool::{MemoryPool, UnlimitedMemoryPool, MemoryError}; +#[cfg(test)] +pub use pool::LimitedMemoryPool; +pub use reservation::MemoryReservation; +pub use jvm_pool::JvmMemoryPool; +pub use jni_bridge::*; +pub use disk_cache_budget::DiskCacheMemoryBudget; + +use std::sync::{Arc, RwLock}; +use std::sync::atomic::{AtomicBool, Ordering}; + +use once_cell::sync::Lazy; + +/// Global memory pool state. +/// - Starts as UnlimitedMemoryPool (lazy default). 
+/// - Can be replaced exactly once via set_global_pool() (before or after first use). +/// - After explicit set, further set calls are rejected. +static GLOBAL_MEMORY_POOL: Lazy>> = + Lazy::new(|| RwLock::new(Arc::new(UnlimitedMemoryPool::default()))); + +/// Tracks whether set_global_pool has been called (prevents double-set). +static EXPLICITLY_CONFIGURED: AtomicBool = AtomicBool::new(false); + +/// Get the global memory pool. +pub fn global_pool() -> Arc { + GLOBAL_MEMORY_POOL.read().unwrap_or_else(|e| e.into_inner()).clone() +} + +/// Set the global memory pool. Can be called once to replace the default. +/// Returns Err if already explicitly set via a prior call to set_global_pool. +pub fn set_global_pool(pool: Arc) -> Result<(), Arc> { + // Atomically check-and-set the configured flag + if EXPLICITLY_CONFIGURED.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err() { + return Err(pool); + } + *GLOBAL_MEMORY_POOL.write().unwrap() = pool; + Ok(()) +} + +/// Check if a custom (non-default) memory pool has been explicitly configured. +pub fn is_pool_configured() -> bool { + EXPLICITLY_CONFIGURED.load(Ordering::SeqCst) +} + +/// Global search arena reservation. +/// +/// Instead of reserving 16MB per SplitSearcher (which could mean 16GB for 1000 +/// cached searchers), we reserve `max_concurrency × 16MB` once. This correctly +/// reflects that only `max_concurrency` searchers execute simultaneously, even +/// if many more are cached. +/// +/// Initialized lazily on first searcher creation. Fails fast if the pool denies. +static SEARCH_ARENA_RESERVATION: std::sync::OnceLock>> = + std::sync::OnceLock::new(); + +/// Size of each search arena slot (16MB). +pub const SEARCH_ARENA_SLOT_SIZE: usize = 16 * 1024 * 1024; + +/// Initialize the global search arena reservation if not already done. +/// Reserves `max_concurrency × 16MB` from the pool. +/// Returns Err if the pool denies the reservation. 
+pub fn init_search_arena() -> Result<(), MemoryError> { + let holder = SEARCH_ARENA_RESERVATION.get_or_init(|| std::sync::Mutex::new(None)); + let mut guard = holder.lock().unwrap(); + if guard.is_some() { + return Ok(()); // Already initialized + } + let max_threads = crate::split_searcher::cache_config::get_max_java_threads(); + let total_arena = max_threads * SEARCH_ARENA_SLOT_SIZE; + let reservation = MemoryReservation::try_new( + &global_pool(), + total_arena, + "search_results", + )?; + crate::debug_println!( + "📊 SEARCH_ARENA: Reserved {} MB for {} concurrent search slots ({} MB each)", + total_arena / 1024 / 1024, max_threads, SEARCH_ARENA_SLOT_SIZE / 1024 / 1024 + ); + *guard = Some(reservation); + Ok(()) +} diff --git a/native/src/memory_pool/pool.rs b/native/src/memory_pool/pool.rs new file mode 100644 index 0000000..a3ae25b --- /dev/null +++ b/native/src/memory_pool/pool.rs @@ -0,0 +1,457 @@ +// memory_pool/pool.rs - Core MemoryPool trait and UnlimitedMemoryPool implementation + +use std::collections::HashMap; +use std::fmt; +use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; +use std::sync::Mutex; + +/// Error type for memory pool operations. +#[derive(Debug, Clone)] +pub enum MemoryError { + /// Memory request was denied by the external manager. + Denied { + requested: usize, + available: usize, + category: String, + }, + /// JNI call to the memory manager failed. + JniError(String), +} + +impl fmt::Display for MemoryError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MemoryError::Denied { requested, available, category } => { + write!( + f, + "Memory request denied: requested {} bytes for '{}', available {} bytes", + requested, category, available + ) + } + MemoryError::JniError(msg) => write!(f, "JNI memory pool error: {}", msg), + } + } +} + +impl std::error::Error for MemoryError {} + +/// Trait for memory pools that coordinate native memory allocations. 
+/// +/// Implementations must be thread-safe (Send + Sync). +pub trait MemoryPool: Send + Sync + fmt::Debug { + /// Try to acquire `size` bytes in the given category. + /// Returns Ok(()) if granted, Err(MemoryError) if denied. + fn try_acquire(&self, size: usize, category: &'static str) -> Result<(), MemoryError>; + + /// Release `size` bytes back to the pool in the given category. + fn release(&self, size: usize, category: &'static str); + + /// Current total memory held by this pool across all categories. + fn used(&self) -> usize; + + /// Peak memory usage observed since creation or last reset. + fn peak(&self) -> usize; + + /// Reset peak usage counter to current usage. Returns the old peak value. + fn reset_peak(&self) -> usize; + + /// Total bytes granted by the external manager (for JVM pools) or usize::MAX (for unlimited). + fn granted(&self) -> usize; + + /// Per-category memory breakdown. + fn category_breakdown(&self) -> HashMap; + + /// Per-category peak memory breakdown. Returns the maximum bytes each category + /// has held since creation or last reset, even if currently zero. + fn category_peak_breakdown(&self) -> HashMap; + + /// Signal that the JVM is shutting down. Subsequent release() calls skip + /// JNI callbacks to avoid calling releaseMemory() outside of task context. + /// Default implementation is a no-op (for non-JVM pools). + fn shutdown(&self) {} +} + +/// Per-category tracking with current and peak usage. +#[derive(Debug)] +pub(super) struct CategoryTracker { + pub(super) current: AtomicUsize, + pub(super) peak: AtomicUsize, +} + +impl CategoryTracker { + pub(super) fn new() -> Self { + Self { + current: AtomicUsize::new(0), + peak: AtomicUsize::new(0), + } + } +} + +/// An unlimited memory pool that always grants requests. +/// Used as the default when no external memory manager is configured. +/// Provides local tracking for statistics without any JNI overhead. 
+#[derive(Debug)] +pub struct UnlimitedMemoryPool { + used: AtomicUsize, + peak: AtomicUsize, + categories: Mutex>, +} + +impl Default for UnlimitedMemoryPool { + fn default() -> Self { + Self { + used: AtomicUsize::new(0), + peak: AtomicUsize::new(0), + categories: Mutex::new(HashMap::new()), + } + } +} + +impl UnlimitedMemoryPool { + fn update_category(&self, category: &'static str, delta: isize) { + let mut cats = self.categories.lock().unwrap(); + let tracker = cats + .entry(category) + .or_insert_with(CategoryTracker::new); + if delta > 0 { + let new_val = tracker.current.fetch_add(delta as usize, Relaxed) + delta as usize; + // Update per-category peak + let mut old_peak = tracker.peak.load(Relaxed); + while new_val > old_peak { + match tracker.peak.compare_exchange_weak(old_peak, new_val, Relaxed, Relaxed) { + Ok(_) => break, + Err(actual) => old_peak = actual, + } + } + } else { + tracker.current.fetch_sub((-delta) as usize, Relaxed); + } + } + + fn update_peak(&self) { + let current = self.used.load(Relaxed); + let mut old_peak = self.peak.load(Relaxed); + while current > old_peak { + match self.peak.compare_exchange_weak(old_peak, current, Relaxed, Relaxed) { + Ok(_) => break, + Err(actual) => old_peak = actual, + } + } + } +} + +impl MemoryPool for UnlimitedMemoryPool { + fn try_acquire(&self, size: usize, category: &'static str) -> Result<(), MemoryError> { + self.used.fetch_add(size, Relaxed); + self.update_category(category, size as isize); + self.update_peak(); + Ok(()) + } + + fn release(&self, size: usize, category: &'static str) { + self.used.fetch_sub(size, Relaxed); + self.update_category(category, -(size as isize)); + } + + fn used(&self) -> usize { + self.used.load(Relaxed) + } + + fn peak(&self) -> usize { + self.peak.load(Relaxed) + } + + fn reset_peak(&self) -> usize { + let current = self.used.load(Relaxed); + self.peak.swap(current, Relaxed) + } + + fn granted(&self) -> usize { + usize::MAX + } + + fn category_breakdown(&self) -> HashMap 
{ + let cats = self.categories.lock().unwrap(); + cats.iter() + .map(|(k, v)| (k.to_string(), v.current.load(Relaxed))) + .filter(|(_, v)| *v > 0) + .collect() + } + + fn category_peak_breakdown(&self) -> HashMap { + let cats = self.categories.lock().unwrap(); + cats.iter() + .map(|(k, v)| (k.to_string(), v.peak.load(Relaxed))) + .filter(|(_, v)| *v > 0) + .collect() + } +} + +/// A memory pool with a hard capacity limit. Used for testing fail-fast behavior. +/// Returns `MemoryError::Denied` when a request would exceed the capacity. +#[cfg(test)] +#[derive(Debug)] +pub struct LimitedMemoryPool { + capacity: usize, + used: AtomicUsize, + peak: AtomicUsize, + categories: Mutex>, +} + +#[cfg(test)] +impl LimitedMemoryPool { + pub fn new(capacity: usize) -> Self { + Self { + capacity, + used: AtomicUsize::new(0), + peak: AtomicUsize::new(0), + categories: Mutex::new(HashMap::new()), + } + } + + fn update_category(&self, category: &'static str, delta: isize) { + let mut cats = self.categories.lock().unwrap(); + let tracker = cats.entry(category).or_insert_with(CategoryTracker::new); + if delta > 0 { + let new_val = tracker.current.fetch_add(delta as usize, Relaxed) + delta as usize; + let mut old_peak = tracker.peak.load(Relaxed); + while new_val > old_peak { + match tracker.peak.compare_exchange_weak(old_peak, new_val, Relaxed, Relaxed) { + Ok(_) => break, + Err(actual) => old_peak = actual, + } + } + } else { + tracker.current.fetch_sub((-delta) as usize, Relaxed); + } + } + + fn update_peak(&self) { + let current = self.used.load(Relaxed); + let mut old_peak = self.peak.load(Relaxed); + while current > old_peak { + match self.peak.compare_exchange_weak(old_peak, current, Relaxed, Relaxed) { + Ok(_) => break, + Err(actual) => old_peak = actual, + } + } + } +} + +#[cfg(test)] +impl MemoryPool for LimitedMemoryPool { + fn try_acquire(&self, size: usize, category: &'static str) -> Result<(), MemoryError> { + if size == 0 { + return Ok(()); + } + let current = 
self.used.load(Relaxed); + if current + size > self.capacity { + return Err(MemoryError::Denied { + requested: size, + available: self.capacity.saturating_sub(current), + category: category.to_string(), + }); + } + self.used.fetch_add(size, Relaxed); + self.update_category(category, size as isize); + self.update_peak(); + Ok(()) + } + + fn release(&self, size: usize, category: &'static str) { + self.used.fetch_sub(size, Relaxed); + self.update_category(category, -(size as isize)); + } + + fn used(&self) -> usize { self.used.load(Relaxed) } + fn peak(&self) -> usize { self.peak.load(Relaxed) } + fn reset_peak(&self) -> usize { + let current = self.used.load(Relaxed); + self.peak.swap(current, Relaxed) + } + fn granted(&self) -> usize { self.capacity } + fn category_breakdown(&self) -> HashMap { + let cats = self.categories.lock().unwrap(); + cats.iter() + .map(|(k, v)| (k.to_string(), v.current.load(Relaxed))) + .filter(|(_, v)| *v > 0) + .collect() + } + fn category_peak_breakdown(&self) -> HashMap { + let cats = self.categories.lock().unwrap(); + cats.iter() + .map(|(k, v)| (k.to_string(), v.peak.load(Relaxed))) + .filter(|(_, v)| *v > 0) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_unlimited_pool_basic() { + let pool = UnlimitedMemoryPool::default(); + assert_eq!(pool.used(), 0); + assert_eq!(pool.peak(), 0); + + pool.try_acquire(1000, "test").unwrap(); + assert_eq!(pool.used(), 1000); + assert_eq!(pool.peak(), 1000); + + pool.try_acquire(2000, "test").unwrap(); + assert_eq!(pool.used(), 3000); + assert_eq!(pool.peak(), 3000); + + pool.release(1500, "test"); + assert_eq!(pool.used(), 1500); + assert_eq!(pool.peak(), 3000); // peak unchanged + + pool.release(1500, "test"); + assert_eq!(pool.used(), 0); + } + + #[test] + fn test_unlimited_pool_categories() { + let pool = UnlimitedMemoryPool::default(); + + pool.try_acquire(100, "index_writer").unwrap(); + pool.try_acquire(200, "l1_cache").unwrap(); + pool.try_acquire(50, 
"index_writer").unwrap(); + + let breakdown = pool.category_breakdown(); + assert_eq!(*breakdown.get("index_writer").unwrap(), 150); + assert_eq!(*breakdown.get("l1_cache").unwrap(), 200); + assert_eq!(pool.used(), 350); + + pool.release(100, "index_writer"); + let breakdown = pool.category_breakdown(); + assert_eq!(*breakdown.get("index_writer").unwrap(), 50); + + // Release everything — current breakdown should be empty + pool.release(50, "index_writer"); + pool.release(200, "l1_cache"); + let breakdown = pool.category_breakdown(); + assert!(breakdown.is_empty(), "Current breakdown should be empty after full release"); + + // But peak breakdown should still show historical maximums + let peak_breakdown = pool.category_peak_breakdown(); + assert_eq!(*peak_breakdown.get("index_writer").unwrap(), 150); + assert_eq!(*peak_breakdown.get("l1_cache").unwrap(), 200); + } + + #[test] + fn test_unlimited_pool_reset_peak() { + let pool = UnlimitedMemoryPool::default(); + + pool.try_acquire(5000, "test").unwrap(); + assert_eq!(pool.peak(), 5000); + + pool.release(3000, "test"); + assert_eq!(pool.used(), 2000); + assert_eq!(pool.peak(), 5000); + + // Reset peak — returns old peak, sets to current used + let old_peak = pool.reset_peak(); + assert_eq!(old_peak, 5000); + assert_eq!(pool.peak(), 2000); + + // New acquire below old peak — peak tracks correctly from reset point + pool.try_acquire(1000, "test").unwrap(); + assert_eq!(pool.peak(), 3000); + } + + #[test] + fn test_unlimited_pool_always_grants() { + let pool = UnlimitedMemoryPool::default(); + // Even very large requests succeed + assert!(pool.try_acquire(1_000_000_000_000, "huge").is_ok()); + assert_eq!(pool.granted(), usize::MAX); + } + + #[test] + fn test_unlimited_pool_thread_safety() { + use std::sync::Arc; + use std::thread; + + static THREAD_CATS: [&str; 10] = [ + "thread_0", "thread_1", "thread_2", "thread_3", "thread_4", + "thread_5", "thread_6", "thread_7", "thread_8", "thread_9", + ]; + + let pool = 
Arc::new(UnlimitedMemoryPool::default()); + let mut handles = vec![]; + + for i in 0..10 { + let pool = pool.clone(); + handles.push(thread::spawn(move || { + let cat = THREAD_CATS[i]; + for _ in 0..100 { + pool.try_acquire(1000, cat).unwrap(); + } + for _ in 0..100 { + pool.release(1000, cat); + } + })); + } + + for h in handles { + h.join().unwrap(); + } + + assert_eq!(pool.used(), 0); + } + + // ======================================================================== + // LimitedMemoryPool tests — validate fail-fast denial behavior + // ======================================================================== + + #[test] + fn test_limited_pool_grants_within_capacity() { + let pool = LimitedMemoryPool::new(1000); + assert!(pool.try_acquire(500, "test").is_ok()); + assert!(pool.try_acquire(500, "test").is_ok()); + assert_eq!(pool.used(), 1000); + assert_eq!(pool.granted(), 1000); + } + + #[test] + fn test_limited_pool_denies_over_capacity() { + let pool = LimitedMemoryPool::new(1000); + pool.try_acquire(800, "test").unwrap(); + + // Request that would exceed capacity + let result = pool.try_acquire(300, "test"); + assert!(result.is_err()); + match result.unwrap_err() { + MemoryError::Denied { requested, available, category } => { + assert_eq!(requested, 300); + assert_eq!(available, 200); + assert_eq!(category, "test"); + } + other => panic!("Expected Denied, got: {:?}", other), + } + // Used should not have changed (no partial allocation) + assert_eq!(pool.used(), 800); + } + + #[test] + fn test_limited_pool_release_frees_capacity() { + let pool = LimitedMemoryPool::new(1000); + pool.try_acquire(1000, "test").unwrap(); + assert!(pool.try_acquire(1, "test").is_err()); // Full + + pool.release(500, "test"); + assert!(pool.try_acquire(500, "test").is_ok()); // Now fits + assert_eq!(pool.used(), 1000); + } + + #[test] + fn test_limited_pool_zero_capacity_denies_everything() { + let pool = LimitedMemoryPool::new(0); + assert!(pool.try_acquire(1, "test").is_err()); + // 
Zero-size requests still succeed + assert!(pool.try_acquire(0, "test").is_ok()); + } +} diff --git a/native/src/memory_pool/reservation.rs b/native/src/memory_pool/reservation.rs new file mode 100644 index 0000000..dbd7930 --- /dev/null +++ b/native/src/memory_pool/reservation.rs @@ -0,0 +1,290 @@ +// memory_pool/reservation.rs - RAII memory reservation guard + +use std::sync::Arc; + +use crate::debug_println; +use super::pool::{MemoryError, MemoryPool}; + +/// RAII guard that holds a memory reservation from a MemoryPool. +/// Automatically releases the reserved memory when dropped. +/// +/// # Example +/// ``` +/// let reservation = MemoryReservation::try_new(&pool, 1024, "index_writer")?; +/// // ... use the memory ... +/// // reservation is automatically released when it goes out of scope +/// ``` +pub struct MemoryReservation { + pool: Arc, + size: usize, + category: &'static str, +} + +impl MemoryReservation { + /// Try to create a new reservation, acquiring `size` bytes from the pool. + /// Returns Err if the pool denies the request. + pub fn try_new( + pool: &Arc, + size: usize, + category: &'static str, + ) -> Result { + if size == 0 { + return Ok(Self { + pool: Arc::clone(pool), + size: 0, + category, + }); + } + pool.try_acquire(size, category)?; + debug_println!( + "📊 MEMORY_POOL: Reserved {} bytes for '{}' (pool total: {} bytes)", + size, category, pool.used() + ); + Ok(Self { + pool: Arc::clone(pool), + size, + category, + }) + } + + /// Create a reservation that doesn't track any memory (zero-cost no-op). + pub fn empty(pool: &Arc, category: &'static str) -> Self { + Self { + pool: Arc::clone(pool), + size: 0, + category, + } + } + + /// Resize this reservation. Acquires more if growing, releases if shrinking. 
+ pub fn resize(&mut self, new_size: usize) -> Result<(), MemoryError> { + if new_size == self.size { + return Ok(()); + } + if new_size > self.size { + let additional = new_size - self.size; + self.pool.try_acquire(additional, self.category)?; + } else { + let decrease = self.size - new_size; + self.pool.release(decrease, self.category); + } + self.size = new_size; + Ok(()) + } + + /// Grow this reservation by `additional` bytes. + pub fn grow(&mut self, additional: usize) -> Result<(), MemoryError> { + self.resize(self.size + additional) + } + + /// Shrink this reservation by `amount` bytes. + pub fn shrink(&mut self, amount: usize) { + let new_size = self.size.saturating_sub(amount); + let _ = self.resize(new_size); + } + + /// Current size of this reservation in bytes. + pub fn size(&self) -> usize { + self.size + } + + /// Category this reservation belongs to. + pub fn category(&self) -> &'static str { + self.category + } + + /// Release all memory and reset to zero. + pub fn release_all(&mut self) { + if self.size > 0 { + self.pool.release(self.size, self.category); + self.size = 0; + } + } +} + +impl Drop for MemoryReservation { + fn drop(&mut self) { + if self.size > 0 { + self.pool.release(self.size, self.category); + debug_println!( + "📊 MEMORY_POOL: Released {} bytes for '{}' (pool total: {} bytes)", + self.size, self.category, self.pool.used() + ); + } + } +} + +impl std::fmt::Debug for MemoryReservation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MemoryReservation") + .field("size", &self.size) + .field("category", &self.category) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::memory_pool::UnlimitedMemoryPool; + + #[test] + fn test_reservation_basic_lifecycle() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + + { + let reservation = MemoryReservation::try_new(&pool, 1024, "test").unwrap(); + assert_eq!(reservation.size(), 1024); + assert_eq!(pool.used(), 1024); 
+ } + // Dropped — memory released + assert_eq!(pool.used(), 0); + } + + #[test] + fn test_reservation_resize_grow() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let mut reservation = MemoryReservation::try_new(&pool, 1000, "test").unwrap(); + assert_eq!(pool.used(), 1000); + + reservation.resize(2000).unwrap(); + assert_eq!(reservation.size(), 2000); + assert_eq!(pool.used(), 2000); + } + + #[test] + fn test_reservation_resize_shrink() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let mut reservation = MemoryReservation::try_new(&pool, 2000, "test").unwrap(); + + reservation.resize(500).unwrap(); + assert_eq!(reservation.size(), 500); + assert_eq!(pool.used(), 500); + } + + #[test] + fn test_reservation_grow_and_shrink() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let mut reservation = MemoryReservation::try_new(&pool, 1000, "test").unwrap(); + + reservation.grow(500).unwrap(); + assert_eq!(reservation.size(), 1500); + assert_eq!(pool.used(), 1500); + + reservation.shrink(300); + assert_eq!(reservation.size(), 1200); + assert_eq!(pool.used(), 1200); + } + + #[test] + fn test_reservation_release_all() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let mut reservation = MemoryReservation::try_new(&pool, 5000, "test").unwrap(); + assert_eq!(pool.used(), 5000); + + reservation.release_all(); + assert_eq!(reservation.size(), 0); + assert_eq!(pool.used(), 0); + + // Drop should be no-op since already released + drop(reservation); + assert_eq!(pool.used(), 0); + } + + #[test] + fn test_reservation_zero_size() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let reservation = MemoryReservation::try_new(&pool, 0, "test").unwrap(); + assert_eq!(reservation.size(), 0); + assert_eq!(pool.used(), 0); + drop(reservation); + assert_eq!(pool.used(), 0); + } + + #[test] + fn test_reservation_empty() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + let reservation = 
MemoryReservation::empty(&pool, "test"); + assert_eq!(reservation.size(), 0); + assert_eq!(pool.used(), 0); + } + + #[test] + fn test_multiple_reservations_same_pool() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + + let r1 = MemoryReservation::try_new(&pool, 100, "index_writer").unwrap(); + let r2 = MemoryReservation::try_new(&pool, 200, "l1_cache").unwrap(); + let r3 = MemoryReservation::try_new(&pool, 300, "merge").unwrap(); + + assert_eq!(pool.used(), 600); + + drop(r2); + assert_eq!(pool.used(), 400); + + drop(r1); + drop(r3); + assert_eq!(pool.used(), 0); + } + + #[test] + fn test_reservation_category_tracking() { + let pool: Arc = Arc::new(UnlimitedMemoryPool::default()); + + let _r1 = MemoryReservation::try_new(&pool, 100, "index_writer").unwrap(); + let _r2 = MemoryReservation::try_new(&pool, 200, "l1_cache").unwrap(); + + let breakdown = pool.category_breakdown(); + assert_eq!(*breakdown.get("index_writer").unwrap(), 100); + assert_eq!(*breakdown.get("l1_cache").unwrap(), 200); + } + + // ======================================================================== + // Fail-fast denial tests — validates that MemoryReservation propagates errors + // ======================================================================== + + #[test] + fn test_reservation_denied_when_pool_full() { + use crate::memory_pool::LimitedMemoryPool; + let pool: Arc = Arc::new(LimitedMemoryPool::new(1000)); + + // First reservation succeeds + let _r1 = MemoryReservation::try_new(&pool, 800, "test").unwrap(); + + // Second reservation should be denied (would exceed capacity) + let result = MemoryReservation::try_new(&pool, 300, "test"); + assert!(result.is_err(), "Should fail when pool is nearly full"); + + // Pool used should still be only from r1 (no partial allocation) + assert_eq!(pool.used(), 800); + } + + #[test] + fn test_reservation_denial_does_not_leak_memory() { + use crate::memory_pool::LimitedMemoryPool; + let pool: Arc = 
Arc::new(LimitedMemoryPool::new(500)); + + // Denied reservation should not leak any memory + let result = MemoryReservation::try_new(&pool, 1000, "big"); + assert!(result.is_err()); + assert_eq!(pool.used(), 0, "No memory should be reserved after denial"); + + // Pool should still be fully usable after denial + let _r = MemoryReservation::try_new(&pool, 500, "fits").unwrap(); + assert_eq!(pool.used(), 500); + } + + #[test] + fn test_reservation_grow_denied_when_pool_full() { + use crate::memory_pool::LimitedMemoryPool; + let pool: Arc = Arc::new(LimitedMemoryPool::new(1000)); + + let mut r = MemoryReservation::try_new(&pool, 800, "test").unwrap(); + + // Grow should fail if it would exceed capacity + let result = r.grow(300); + assert!(result.is_err()); + assert_eq!(r.size(), 800, "Reservation size unchanged after failed grow"); + assert_eq!(pool.used(), 800, "Pool used unchanged after failed grow"); + } +} diff --git a/native/src/parquet_companion/arrow_ffi_export.rs b/native/src/parquet_companion/arrow_ffi_export.rs index 143cdf1..0f649f5 100644 --- a/native/src/parquet_companion/arrow_ffi_export.rs +++ b/native/src/parquet_companion/arrow_ffi_export.rs @@ -13,7 +13,7 @@ use std::sync::Arc; use anyhow::{Context, Result}; use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema}; -use arrow_array::{RecordBatch, UInt32Array}; +use arrow_array::{Array, RecordBatch, UInt32Array}; use arrow_schema::{DataType, Field, Schema, TimeUnit}; use quickwit_storage::Storage; @@ -162,7 +162,20 @@ pub async fn batch_parquet_to_arrow_ffi( t_rename.elapsed().as_millis() ); - // Step 6: Export each column via Arrow FFI + // Step 6: Memory pool reservation for Arrow FFI data + let estimated_size: usize = renamed.columns().iter() + .map(|col| col.get_buffer_memory_size()) + .sum(); + let _ffi_reservation = crate::memory_pool::MemoryReservation::try_new( + &crate::memory_pool::global_pool(), + estimated_size, + "arrow_ffi", + ) + .unwrap_or_else(|_| { + 
crate::memory_pool::MemoryReservation::empty(&crate::memory_pool::global_pool(), "arrow_ffi") + }); + + // Step 7: Export each column via Arrow FFI let num_cols = renamed.num_columns(); if array_addrs.len() < num_cols || schema_addrs.len() < num_cols { anyhow::bail!( diff --git a/native/src/parquet_companion/arrow_ffi_import.rs b/native/src/parquet_companion/arrow_ffi_import.rs index b98e791..f6af4de 100644 --- a/native/src/parquet_companion/arrow_ffi_import.rs +++ b/native/src/parquet_companion/arrow_ffi_import.rs @@ -22,6 +22,7 @@ use tantivy::schema::{Schema as TantivySchema, Field}; use uuid::Uuid; use crate::debug_println; +use crate::memory_pool::{self, MemoryReservation}; use super::indexing::{arrow_row_to_tantivy_doc, add_arrow_value_to_doc, add_string_value_to_doc, convert_arrow_to_owned_value, is_complex_arrow_type}; use super::manifest::FastFieldMode; use super::name_mapping::NameMapping; @@ -83,6 +84,8 @@ struct PartitionWriter { partition_values: HashMap, /// Per-column statistics accumulators (populated when stats_columns is non-empty) accumulators: HashMap, + /// Memory reservation for the writer's heap — released on Drop/finalization. + _memory_reservation: MemoryReservation, } struct FieldMapping { @@ -329,6 +332,20 @@ fn create_partition_writer( stats_columns: &std::collections::HashSet, field_mapping: &[FieldMapping], ) -> Result { + // Reserve memory from the global pool for this writer's heap. + // Fail-fast: if the pool denies, propagate the error to Java. + let reservation = MemoryReservation::try_new( + &memory_pool::global_pool(), + heap_size, + "index_writer", + ).map_err(|e| { + anyhow::anyhow!( + "Memory pool denied Arrow FFI writer allocation of {} MB: {}. 
\ + Reduce heap size or increase pool capacity.", + heap_size / 1_000_000, e + ) + })?; + let index_dir = tempfile::tempdir() .context("Failed to create temp directory for partition writer")?; @@ -375,6 +392,7 @@ fn create_partition_writer( doc_count: 0, partition_values, accumulators, + _memory_reservation: reservation, }) } diff --git a/native/src/parquet_companion/augmented_directory.rs b/native/src/parquet_companion/augmented_directory.rs index 33d1084..244349a 100644 --- a/native/src/parquet_companion/augmented_directory.rs +++ b/native/src/parquet_companion/augmented_directory.rs @@ -240,6 +240,17 @@ impl ParquetAugmentedDirectory { let num_docs: u32 = self.manifest.total_rows.try_into() .map_err(|_| anyhow::anyhow!("total_rows {} exceeds u32::MAX", self.manifest.total_rows))?; + // Estimate transcode output size: ~8 bytes per doc per column (columnar overhead) + let estimated_transcode_bytes = num_docs as usize * columns.len() * 8; + let _transcode_reservation = crate::memory_pool::MemoryReservation::try_new( + &crate::memory_pool::global_pool(), + estimated_transcode_bytes, + "parquet_transcode", + ).map_err(|e| { + anyhow::anyhow!("Memory pool denied parquet transcode reservation of {} bytes: {}", + estimated_transcode_bytes, e) + })?; + // Transcode parquet columns to columnar bytes let parquet_columnar_bytes = transcode_columns_from_parquet( &columns, diff --git a/native/src/parquet_companion/cached_reader.rs b/native/src/parquet_companion/cached_reader.rs index cdbb8e7..5f147e6 100644 --- a/native/src/parquet_companion/cached_reader.rs +++ b/native/src/parquet_companion/cached_reader.rs @@ -20,6 +20,7 @@ use parquet::file::page_index::offset_index::{OffsetIndexMetaData, PageLocation} use quickwit_storage::Storage; use crate::debug_println; +use crate::memory_pool::{self, MemoryReservation}; use crate::perf_println; /// Cache key: (file_path, start_byte, end_byte) @@ -72,18 +73,54 @@ impl CoalesceConfig { } } -/// Shared byte-range cache across multiple 
CachedParquetReader instances. +/// Byte-range cache that tracks its memory usage through the global memory pool. +/// +/// Wraps an LRU cache of parquet byte ranges (dictionary pages, data pages) and +/// grows/shrinks a `MemoryReservation` on every insert/eviction. Because the +/// `JvmMemoryPool` uses high/low watermark batching, most grow/shrink calls are +/// just an atomic add — JNI round-trips only happen at watermark crossings. +pub struct TrackedByteRangeCache { + cache: lru::LruCache, + reservation: MemoryReservation, +} + +impl TrackedByteRangeCache { + /// Look up a cached byte range. Returns a clone of the cached bytes. + pub fn get(&mut self, key: &ByteRangeCacheKey) -> Option { + self.cache.get(key).cloned() + } + + /// Insert a byte range into the cache, tracking memory. + /// Uses `push()` to capture LRU evictions and shrink the reservation. + /// Growth is best-effort — if the pool denies, we still cache (bounded & transient). + pub fn put(&mut self, key: ByteRangeCacheKey, value: Bytes) { + let new_bytes = value.len(); + if let Some((_evicted_key, evicted_value)) = self.cache.push(key, value) { + self.reservation.shrink(evicted_value.len()); + } + let _ = self.reservation.grow(new_bytes); // best-effort + } +} + +/// Shared tracked byte-range cache across multiple CachedParquetReader instances. /// Caches fetched byte ranges (e.g. dictionary pages, data pages) to avoid /// redundant S3/Azure downloads when retrieving multiple docs from the same file. /// Uses LRU eviction to bound memory — least-recently-used entries are evicted /// when the cache exceeds MAX_BYTE_CACHE_ENTRIES. -pub type ByteRangeCache = Arc>>; +pub type ByteRangeCache = Arc>; -/// Create a new empty byte-range cache. +/// Create a new empty byte-range cache with memory pool tracking. 
pub fn new_byte_range_cache() -> ByteRangeCache { - Arc::new(Mutex::new(lru::LruCache::new( - std::num::NonZeroUsize::new(MAX_BYTE_CACHE_ENTRIES).unwrap(), - ))) + let reservation = MemoryReservation::empty( + &memory_pool::global_pool(), + "parquet_byte_range_cache", + ); + Arc::new(Mutex::new(TrackedByteRangeCache { + cache: lru::LruCache::new( + std::num::NonZeroUsize::new(MAX_BYTE_CACHE_ENTRIES).unwrap(), + ), + reservation, + })) } /// An AsyncFileReader that delegates to Quickwit's Storage trait. @@ -172,7 +209,7 @@ impl CachedParquetReader { fn cache_get(&self, range: &Range) -> Option { let cache = self.byte_cache.as_ref()?; let key = (self.path.clone(), range.start, range.end); - cache.lock().ok()?.get(&key).cloned() + cache.lock().ok()?.get(&key) } /// Store bytes in the cache @@ -237,7 +274,7 @@ impl AsyncFileReader for CachedParquetReader { // avoids the allocation+copy of .to_vec() let bytes = Bytes::from_owner(owned_bytes); - // Cache in L1 in-memory cache + // Cache in L1 in-memory cache (tracked by memory pool) if let Some(cache) = byte_cache { let key = (path_for_cache, range.start, range.end); if let Ok(mut guard) = cache.lock() { @@ -316,7 +353,7 @@ impl AsyncFileReader for CachedParquetReader { let bytes = fetched[fetch_idx].clone(); let range = &uncached_ranges[fetch_idx]; - // Cache in L1 in-memory cache + // Cache in L1 in-memory cache (tracked by memory pool) if let Some(ref cache) = byte_cache { let key = ( path_for_cache.clone(), diff --git a/native/src/parquet_companion/indexing.rs b/native/src/parquet_companion/indexing.rs index e182aaa..e7a8860 100644 --- a/native/src/parquet_companion/indexing.rs +++ b/native/src/parquet_companion/indexing.rs @@ -18,6 +18,7 @@ use tantivy::TantivyDocument; use quickwit_storage::Storage; +use crate::memory_pool::{self, MemoryReservation}; use super::manifest::*; use super::page_index::compute_page_locations_from_column_chunk; use super::schema_derivation::{ @@ -377,6 +378,19 @@ pub async fn 
create_split_from_parquet( let index = tantivy::Index::create_in_dir(&index_dir, tantivy_schema.clone()) .context("Failed to create tantivy index")?; + // Reserve memory from the global pool for this writer's heap. + let _writer_reservation = MemoryReservation::try_new( + &memory_pool::global_pool(), + parquet_config.writer_heap_size, + "index_writer", + ).map_err(|e| { + anyhow::anyhow!( + "Memory pool denied parquet companion writer allocation of {} MB: {}. \ + Reduce writer_heap_size or increase pool capacity.", + parquet_config.writer_heap_size / 1_000_000, e + ) + })?; + // Single-threaded writer ensures docs within each segment are in insertion order. // Merges are allowed (default LogMergePolicy) — the __pq_file_hash and // __pq_row_in_file fast fields make doc→parquet resolution merge-safe. diff --git a/native/src/quickwit_split/merge_impl.rs b/native/src/quickwit_split/merge_impl.rs index 61d713d..f37ffed 100644 --- a/native/src/quickwit_split/merge_impl.rs +++ b/native/src/quickwit_split/merge_impl.rs @@ -26,6 +26,7 @@ use quickwit_directories::UnionDirectory; use quickwit_query::get_quickwit_fastfield_normalizer_manager; use crate::debug_println; +use crate::memory_pool::{self, MemoryReservation}; use crate::runtime_manager::QuickwitRuntimeManager; use super::QuickwitSplitMetadata; use super::merge_config::InternalMergeConfig; @@ -625,6 +626,17 @@ pub async fn merge_split_directories_with_optimization( optimization.num_threads, if optimization.use_random_io { "Random" } else { "Sequential" }); + // Reserve 3x heap from memory pool: writer heap + mmap copies + Vec temporaries + let merge_memory = optimization.heap_size_bytes as usize * 3; + let _merge_reservation = MemoryReservation::try_new( + &memory_pool::global_pool(), + merge_memory, + "merge", + ).map_err(|e| { + anyhow::anyhow!("Memory pool denied merge reservation of {}MB: {}. \ + Reduce merge parallelism or increase pool capacity.", merge_memory / 1_000_000, e) + })?; + // 1. 
Create output directory with controlled I/O std::fs::create_dir_all(output_path)?; let output_directory = MmapDirectory::open(output_path)?; diff --git a/native/src/searcher/jni_index_writer.rs b/native/src/searcher/jni_index_writer.rs index c3f52fb..5ba826d 100644 --- a/native/src/searcher/jni_index_writer.rs +++ b/native/src/searcher/jni_index_writer.rs @@ -377,6 +377,8 @@ pub extern "system" fn Java_io_indextables_tantivy4java_core_IndexWriter_nativeW ptr: jlong, ) { // wait_merging_threads consumes the IndexWriter, so we need to remove it from the Arc registry + // Also release the memory reservation + crate::index::WRITER_RESERVATIONS.lock().unwrap().remove(&ptr); let writer_arc = { let mut registry = crate::utils::ARC_REGISTRY.lock().unwrap(); registry.remove(&ptr).and_then(|boxed| boxed.downcast::>>().ok().map(|b| *b)) @@ -474,5 +476,7 @@ pub extern "system" fn Java_io_indextables_tantivy4java_core_IndexWriter_nativeC _class: JClass, ptr: jlong, ) { + // Release the memory reservation for this writer (if any) + crate::index::WRITER_RESERVATIONS.lock().unwrap().remove(&ptr); release_arc(ptr); } diff --git a/native/src/split_cache_manager/jni_lifecycle.rs b/native/src/split_cache_manager/jni_lifecycle.rs index 0012285..eb39f22 100644 --- a/native/src/split_cache_manager/jni_lifecycle.rs +++ b/native/src/split_cache_manager/jni_lifecycle.rs @@ -237,7 +237,15 @@ pub extern "system" fn Java_io_indextables_tantivy4java_split_SplitCacheManager_ false } }; - debug_println!("RUST DEBUG: write_queue_mode={:?}, drop_writes_when_full={}", write_queue_mode, drop_writes_when_full); + let max_write_queue_budget = + match env.call_method(&tiered_config_obj, "getMaxWriteQueueBudget", "()J", &[]) { + Ok(result) => result.j().unwrap_or(0) as u64, + Err(e) => { + debug_println!("RUST DEBUG: Failed to get maxWriteQueueBudget: {:?}, defaulting to 0", e); + 0 + } + }; + debug_println!("RUST DEBUG: write_queue_mode={:?}, drop_writes_when_full={}, max_write_queue_budget={}", 
write_queue_mode, drop_writes_when_full, max_write_queue_budget); // Create disk cache config if we have a path debug_println!("RUST DEBUG: disk_path is_some={}", disk_path.is_some()); @@ -260,6 +268,7 @@ pub extern "system" fn Java_io_indextables_tantivy4java_split_SplitCacheManager_ mmap_cache_size: 0, // Use default (1024) write_queue_mode, drop_writes_when_full, + max_write_queue_budget, }; debug_println!("RUST DEBUG: Calling set_disk_cache with path: {}", path); diff --git a/native/src/split_searcher/aggregation_arrow_ffi.rs b/native/src/split_searcher/aggregation_arrow_ffi.rs index c1f9e70..c9b66de 100644 --- a/native/src/split_searcher/aggregation_arrow_ffi.rs +++ b/native/src/split_searcher/aggregation_arrow_ffi.rs @@ -10,9 +10,11 @@ use std::sync::Arc; use anyhow::{Context, Result}; use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema}; use arrow_array::{ - Float64Array, Int64Array, RecordBatch, StringArray, TimestampMicrosecondArray, + Array, Float64Array, Int64Array, RecordBatch, StringArray, TimestampMicrosecondArray, }; use arrow_schema::{DataType, Field, Schema, TimeUnit}; + +use crate::memory_pool::{self, MemoryReservation}; use tantivy::aggregation::agg_result::{ AggregationResult, BucketEntries, BucketEntry, BucketResult, MetricResult, RangeBucketEntry, @@ -114,6 +116,42 @@ pub fn export_record_batch_ffi( Ok(batch.num_rows()) } +/// Estimate the memory footprint of a RecordBatch in bytes. +/// +/// Sums the buffer sizes of all columns. This is a lower bound since it doesn't +/// account for FFI struct overhead, but captures the bulk of the memory. +fn estimate_record_batch_size(batch: &RecordBatch) -> usize { + batch + .columns() + .iter() + .map(|col| col.get_buffer_memory_size()) + .sum() +} + +/// Export a RecordBatch via FFI with memory pool tracking. +/// +/// Creates a MemoryReservation for the estimated batch size before export, +/// returning both the row count and the reservation. 
The caller must hold +/// the reservation alive until the Java side has consumed the FFI data. +pub fn export_record_batch_ffi_tracked( + batch: &RecordBatch, + array_addrs: &[i64], + schema_addrs: &[i64], +) -> Result<(usize, MemoryReservation)> { + let estimated_size = estimate_record_batch_size(batch); + + // Best-effort: if pool denies, proceed with empty reservation (data still exported) + let reservation = MemoryReservation::try_new( + &memory_pool::global_pool(), + estimated_size, + "arrow_ffi", + ) + .unwrap_or_else(|_| MemoryReservation::empty(&memory_pool::global_pool(), "arrow_ffi")); + + let row_count = export_record_batch_ffi(batch, array_addrs, schema_addrs)?; + Ok((row_count, reservation)) +} + /// Return a JSON string describing the Arrow schema for the given aggregation result. /// Format: {"columns": [{"name": "key", "type": "Utf8"}, ...], "row_count": N} pub fn aggregation_result_arrow_schema_json( diff --git a/native/src/split_searcher/jni_agg_arrow.rs b/native/src/split_searcher/jni_agg_arrow.rs index 0827b26..e63edec 100644 --- a/native/src/split_searcher/jni_agg_arrow.rs +++ b/native/src/split_searcher/jni_agg_arrow.rs @@ -22,9 +22,9 @@ use crate::runtime_manager::block_on_operation; use crate::searcher::aggregation::json_helpers::is_date_histogram_aggregation; use super::aggregation_arrow_ffi::{ - aggregation_result_arrow_schema_json, aggregation_result_to_record_batch, + aggregation_result_arrow_schema_json, aggregation_result_to_record_batch_with_hash_resolution, - export_record_batch_ffi, + export_record_batch_ffi_tracked, }; use super::jni_search::perform_search_async_impl_leaf_response_with_aggregations; @@ -127,7 +127,7 @@ pub extern "system" fn Java_io_indextables_tantivy4java_split_SplitSearcher_nati ctx.hash_resolution_map.as_ref(), )?; - let row_count = export_record_batch_ffi(&batch, arr_slice, sch_slice)?; + let (row_count, _reservation) = export_record_batch_ffi_tracked(&batch, arr_slice, sch_slice)?; perf_println!( "⏱️ AGG_FFI: 
nativeAggregateArrowFfi DONE — {} rows, {} cols, {}ms", @@ -321,7 +321,7 @@ pub extern "system" fn Java_io_indextables_tantivy4java_split_SplitCacheManager_ hash_resolution_map.as_ref(), )?; - let row_count = export_record_batch_ffi(&batch, arr_slice, sch_slice)?; + let (row_count, _reservation) = export_record_batch_ffi_tracked(&batch, arr_slice, sch_slice)?; perf_println!( "⏱️ AGG_FFI: nativeMultiSplitAggregateArrowFfi DONE — {} splits, {} rows, {}ms", @@ -607,7 +607,7 @@ pub extern "system" fn Java_io_indextables_tantivy4java_split_SplitCacheManager_ is_date_hist, hash_resolution_map.as_ref(), )?; - let row_count = export_record_batch_ffi(&batch, arr_slice, sch_slice)?; + let (row_count, _reservation) = export_record_batch_ffi_tracked(&batch, arr_slice, sch_slice)?; row_counts.insert(agg_name.clone(), serde_json::json!(row_count)); } diff --git a/native/src/split_searcher/jni_lifecycle.rs b/native/src/split_searcher/jni_lifecycle.rs index 2d17f23..67a88ac 100644 --- a/native/src/split_searcher/jni_lifecycle.rs +++ b/native/src/split_searcher/jni_lifecycle.rs @@ -833,6 +833,16 @@ pub extern "system" fn Java_io_indextables_tantivy4java_split_SplitSearcher_crea .map(|m| crate::parquet_companion::docid_mapping::build_file_hash_index(m)) .unwrap_or_default(); + // Initialize global search arena reservation (max_concurrency × 16MB). + // Fail-fast: if pool denies, searcher creation fails. + if let Err(e) = crate::memory_pool::init_search_arena() { + runtime.unregister_searcher(); + to_java_exception(&mut env, &anyhow::anyhow!( + "Memory pool denied search arena reservation: {}. 
\ + Reduce max concurrency or increase pool capacity.", e)); + return 0; + } + // Create clean struct-based context instead of complex tuple let cached_context = CachedSearcherContext { standalone_searcher: std::sync::Arc::new(searcher), @@ -867,6 +877,7 @@ pub extern "system" fn Java_io_indextables_tantivy4java_split_SplitSearcher_crea parquet_file_hash_index, has_merge_safe_tracking, pq_columns: std::sync::Arc::new(std::sync::RwLock::new(Vec::new())), + search_arena_reservation: None, // Global arena covers all concurrent slots }; let searcher_context = std::sync::Arc::new(cached_context); diff --git a/native/src/split_searcher/types.rs b/native/src/split_searcher/types.rs index 7595279..271b1e4 100644 --- a/native/src/split_searcher/types.rs +++ b/native/src/split_searcher/types.rs @@ -7,6 +7,7 @@ use std::ops::Range; use quickwit_storage::{Storage, ByteRangeCache}; use crate::perf_println; use crate::debug_println; +use crate::memory_pool::MemoryReservation; use crate::standalone_searcher::StandaloneSearcher; use crate::parquet_companion::manifest::ParquetManifest; @@ -108,6 +109,9 @@ pub(crate) struct CachedSearcherContext { // When true, doc retrieval uses merge-safe fast-field resolution. // When false, falls back to legacy segment-based positional resolution. pub(crate) has_merge_safe_tracking: bool, + // Memory pool reservation for search result arena (16MB pre-acquired budget). + // Held for the lifetime of the SplitSearcher. Released on close via Drop. + pub(crate) search_arena_reservation: Option, // Parquet companion mode: lazily-loaded __pq_file_hash and __pq_row_in_file Column handles. 
// Indexed as pq_columns[segment_ord] = Some((file_hash_col, row_in_file_col)) // Column supports O(1) random access via values_for_doc(doc_id) — no need to diff --git a/native/src/utils.rs b/native/src/utils.rs index 9952512..127adb6 100644 --- a/native/src/utils.rs +++ b/native/src/utils.rs @@ -18,13 +18,30 @@ */ use jni::JNIEnv; +use jni::JavaVM; use jni::sys::jlong; use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, OnceLock}; use std::sync::atomic::{AtomicU64, Ordering}; use once_cell::sync::Lazy; use crate::debug_println; +/// Global JavaVM reference, captured on first JNI call. +/// Required by JvmMemoryPool to attach threads and make JNI callbacks. +static GLOBAL_JVM: OnceLock = OnceLock::new(); + +/// Store the JavaVM reference. Should be called once during initialization. +pub fn set_jvm(env: &JNIEnv) { + GLOBAL_JVM.get_or_init(|| { + env.get_java_vm().expect("Failed to get JavaVM from JNIEnv") + }); +} + +/// Get the stored JavaVM reference. 
+pub fn get_jvm() -> Option<&'static JavaVM> { + GLOBAL_JVM.get() +} + /// Handle errors by throwing Java exceptions pub fn handle_error(env: &mut JNIEnv, error: &str) { let _ = env.throw_new("java/lang/RuntimeException", error); diff --git a/pom.xml b/pom.xml index be8748d..4be4398 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.indextables tantivy4java - 0.32.4 + 0.32.5 jar Tantivy4Java Experimental diff --git a/src/main/java/io/indextables/tantivy4java/config/GlobalCacheConfig.java b/src/main/java/io/indextables/tantivy4java/config/GlobalCacheConfig.java index b568d34..00ad3fc 100644 --- a/src/main/java/io/indextables/tantivy4java/config/GlobalCacheConfig.java +++ b/src/main/java/io/indextables/tantivy4java/config/GlobalCacheConfig.java @@ -178,6 +178,14 @@ private static native boolean initializeGlobalCache( ); static { - System.loadLibrary("tantivy4java"); + try { + Class.forName("io.indextables.tantivy4java.core.Tantivy"); + } catch (ClassNotFoundException e) { + try { + System.loadLibrary("tantivy4java"); + } catch (UnsatisfiedLinkError ule) { + // Library may already be loaded + } + } } } \ No newline at end of file diff --git a/src/main/java/io/indextables/tantivy4java/config/RuntimeManager.java b/src/main/java/io/indextables/tantivy4java/config/RuntimeManager.java index 7c2e68f..f36daab 100644 --- a/src/main/java/io/indextables/tantivy4java/config/RuntimeManager.java +++ b/src/main/java/io/indextables/tantivy4java/config/RuntimeManager.java @@ -10,8 +10,15 @@ public class RuntimeManager { static { - // Ensure native library is loaded - System.loadLibrary("tantivy4java"); + try { + Class.forName("io.indextables.tantivy4java.core.Tantivy"); + } catch (ClassNotFoundException e) { + try { + System.loadLibrary("tantivy4java"); + } catch (UnsatisfiedLinkError ule) { + // Library may already be loaded + } + } } /** diff --git a/src/main/java/io/indextables/tantivy4java/memory/NativeMemoryAccountant.java 
b/src/main/java/io/indextables/tantivy4java/memory/NativeMemoryAccountant.java new file mode 100644 index 0000000..e5d9849 --- /dev/null +++ b/src/main/java/io/indextables/tantivy4java/memory/NativeMemoryAccountant.java @@ -0,0 +1,41 @@ +package io.indextables.tantivy4java.memory; + +/** + * Interface for external memory managers to coordinate with native Rust allocations. + * + *

Implementations must be thread-safe. The native layer may call these methods + * from multiple threads concurrently. + * + *

Example: integrate with Spark's TaskMemoryManager: + *

{@code
+ * public class SparkMemoryAccountant implements NativeMemoryAccountant {
+ *     private final TaskMemoryManager taskMemoryManager;
+ *     private final MemoryConsumer consumer;
+ *
+ *     public long acquireMemory(long bytes) {
+ *         return taskMemoryManager.acquireExecutionMemory(bytes, consumer);
+ *     }
+ *
+ *     public void releaseMemory(long bytes) {
+ *         taskMemoryManager.releaseExecutionMemory(bytes, consumer);
+ *     }
+ * }
+ * }
+ */ +public interface NativeMemoryAccountant { + + /** + * Request memory from the external manager. + * + * @param bytes requested number of bytes + * @return actual bytes granted (may be less than requested; 0 = denied) + */ + long acquireMemory(long bytes); + + /** + * Release previously acquired memory back to the external manager. + * + * @param bytes number of bytes to release + */ + void releaseMemory(long bytes); +} diff --git a/src/main/java/io/indextables/tantivy4java/memory/NativeMemoryManager.java b/src/main/java/io/indextables/tantivy4java/memory/NativeMemoryManager.java new file mode 100644 index 0000000..22e7eb9 --- /dev/null +++ b/src/main/java/io/indextables/tantivy4java/memory/NativeMemoryManager.java @@ -0,0 +1,147 @@ +package io.indextables.tantivy4java.memory; + +import java.util.Collections; +import java.util.Map; + +/** + * Global configuration and monitoring for native memory management. + * + *

Call {@link #setAccountant(NativeMemoryAccountant)} once before any native + * operations to enable JVM-coordinated memory tracking. If not called, the native + * layer defaults to unlimited (untracked) mode. + * + *

Example usage: + *

{@code
+ * // At application startup (before any Index/SplitSearcher use)
+ * NativeMemoryManager.setAccountant(new SparkMemoryAccountant(taskMemoryManager));
+ *
+ * // Monitor memory usage
+ * NativeMemoryStats stats = NativeMemoryManager.getStats();
+ * System.out.println("Native memory used: " + stats.getUsedBytes());
+ * System.out.println("Peak usage: " + stats.getPeakBytes());
+ * stats.getCategoryBreakdown().forEach((cat, bytes) ->
+ *     System.out.println("  " + cat + ": " + bytes + " bytes"));
+ * }
+ */ +public class NativeMemoryManager { + + static { + // Trigger Tantivy's static initializer which properly extracts and loads + // the native library from the jar. System.loadLibrary() alone doesn't work + // because the native library is jar-embedded, not on java.library.path. + try { + Class.forName("io.indextables.tantivy4java.core.Tantivy"); + } catch (ClassNotFoundException e) { + // Fallback for environments where Tantivy class is not available + try { + System.loadLibrary("tantivy4java"); + } catch (UnsatisfiedLinkError ule) { + // Library may already be loaded by another class + } + } + } + + private NativeMemoryManager() { + // Static utility class + } + + /** + * Set the global memory accountant for native allocations. + * + *

Must be called before any Index, SplitSearcher, or SplitCacheManager use. + * Can only be called once; subsequent calls return false. + * + * @param accountant the memory accountant to use + * @return true if set successfully, false if already configured + */ + public static boolean setAccountant(NativeMemoryAccountant accountant) { + return setAccountant(accountant, 0.90, 0.25, 64 * 1024 * 1024, 64 * 1024 * 1024); + } + + /** + * Set the global memory accountant with custom watermark configuration. + * + * @param accountant the memory accountant to use + * @param highWatermark acquire more from JVM when usage exceeds this fraction of grant (default 0.90) + * @param lowWatermark release excess to JVM when usage drops below this fraction (default 0.25) + * @param acquireIncrementBytes minimum JNI acquire chunk size in bytes (default 64MB) + * @param minReleaseBytes minimum amount to release back in bytes (default 64MB) + * @return true if set successfully, false if already configured + */ + public static boolean setAccountant( + NativeMemoryAccountant accountant, + double highWatermark, + double lowWatermark, + long acquireIncrementBytes, + long minReleaseBytes) { + if (accountant == null) { + throw new IllegalArgumentException("accountant must not be null"); + } + return nativeSetAccountant(accountant, highWatermark, lowWatermark, + acquireIncrementBytes, minReleaseBytes); + } + + /** + * Check if a custom memory accountant has been configured. + */ + public static boolean isConfigured() { + return nativeIsConfigured(); + } + + /** + * Reset the peak usage counter to current usage. + * + *

Useful for monitoring windows — call at the start of each window to + * track per-window peak usage. + * + * @return the old peak value in bytes + */ + public static long resetPeak() { + return nativeResetPeak(); + } + + /** + * Get current native memory statistics. + */ + public static NativeMemoryStats getStats() { + long used = nativeGetUsedBytes(); + long peak = nativeGetPeakBytes(); + long granted = nativeGetGrantedBytes(); + Map breakdown = nativeGetCategoryBreakdown(); + if (breakdown == null) { + breakdown = Collections.emptyMap(); + } + Map peakBreakdown = nativeGetCategoryPeakBreakdown(); + if (peakBreakdown == null) { + peakBreakdown = Collections.emptyMap(); + } + return new NativeMemoryStats(used, peak, granted, breakdown, peakBreakdown); + } + + /** + * Signal that the JVM is shutting down. + * + *

After this call, the native pool skips JNI release callbacks to avoid + * calling {@code releaseMemory()} outside of a task context (e.g., on shutdown + * hook threads where Spark's TaskContext is unavailable). + * + *

Call this before any shutdown hooks that trigger native resource cleanup. + */ + public static void shutdown() { + nativeShutdown(); + } + + // Native methods + private static native boolean nativeSetAccountant( + Object accountant, double highWatermark, double lowWatermark, + long acquireIncrementBytes, long minReleaseBytes); + + private static native long nativeGetUsedBytes(); + private static native long nativeGetPeakBytes(); + private static native long nativeGetGrantedBytes(); + private static native long nativeResetPeak(); + private static native boolean nativeIsConfigured(); + private static native Map nativeGetCategoryBreakdown(); + private static native Map nativeGetCategoryPeakBreakdown(); + private static native void nativeShutdown(); +} diff --git a/src/main/java/io/indextables/tantivy4java/memory/NativeMemoryStats.java b/src/main/java/io/indextables/tantivy4java/memory/NativeMemoryStats.java new file mode 100644 index 0000000..67ffbcd --- /dev/null +++ b/src/main/java/io/indextables/tantivy4java/memory/NativeMemoryStats.java @@ -0,0 +1,105 @@ +package io.indextables.tantivy4java.memory; + +import java.util.Collections; +import java.util.Map; + +/** + * Snapshot of native memory pool statistics. + * + *

Categories tracked: + *

    + *
  • {@code index_writer} — IndexWriter heap budget
  • + *
  • {@code merge} — In-process merge operations (3x heap for copies + mmaps)
  • + *
  • {@code l1_cache} — L1 ByteRangeCache memory
  • + *
  • {@code l2_write_queue} — L2 disk cache write queue buffer
  • + *
  • {@code arrow_ffi} — Arrow RecordBatch FFI exports
  • + *
  • {@code search_results} — Search result pre-allocated arenas
  • + *
  • {@code parquet_transcode} — Parquet fast field transcoding buffers
  • + *
+ */ +public class NativeMemoryStats { + + private final long usedBytes; + private final long peakBytes; + private final long grantedBytes; + private final Map categoryBreakdown; + private final Map categoryPeakBreakdown; + + NativeMemoryStats(long usedBytes, long peakBytes, long grantedBytes, + Map categoryBreakdown, + Map categoryPeakBreakdown) { + this.usedBytes = usedBytes; + this.peakBytes = peakBytes; + this.grantedBytes = grantedBytes; + this.categoryBreakdown = Collections.unmodifiableMap(categoryBreakdown); + this.categoryPeakBreakdown = Collections.unmodifiableMap(categoryPeakBreakdown); + } + + // Backward-compatible constructor + NativeMemoryStats(long usedBytes, long peakBytes, long grantedBytes, + Map categoryBreakdown) { + this(usedBytes, peakBytes, grantedBytes, categoryBreakdown, Collections.emptyMap()); + } + + /** Current total bytes reserved by native code. */ + public long getUsedBytes() { + return usedBytes; + } + + /** Peak bytes observed since pool creation. */ + public long getPeakBytes() { + return peakBytes; + } + + /** + * Total bytes granted by the external memory manager. + * Returns -1 if using unlimited (untracked) mode. + */ + public long getGrantedBytes() { + return grantedBytes; + } + + /** Per-category current memory breakdown (only non-zero categories). */ + public Map getCategoryBreakdown() { + return categoryBreakdown; + } + + /** + * Per-category peak memory breakdown. + * + *

Returns the maximum bytes each category has ever held, even if + * currently zero. Useful for post-hoc analysis when all reservations + * have been released by the time this is called. + */ + public Map getCategoryPeakBreakdown() { + return categoryPeakBreakdown; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("NativeMemoryStats{"); + sb.append("used=").append(formatBytes(usedBytes)); + sb.append(", peak=").append(formatBytes(peakBytes)); + sb.append(", granted=").append(grantedBytes < 0 ? "unlimited" : formatBytes(grantedBytes)); + if (!categoryBreakdown.isEmpty()) { + sb.append(", categories={"); + boolean first = true; + for (Map.Entry entry : categoryBreakdown.entrySet()) { + if (!first) sb.append(", "); + sb.append(entry.getKey()).append("=").append(formatBytes(entry.getValue())); + first = false; + } + sb.append("}"); + } + sb.append("}"); + return sb.toString(); + } + + private static String formatBytes(long bytes) { + if (bytes < 1024) return bytes + "B"; + if (bytes < 1024 * 1024) return String.format("%.1fKB", bytes / 1024.0); + if (bytes < 1024L * 1024 * 1024) return String.format("%.1fMB", bytes / (1024.0 * 1024)); + return String.format("%.1fGB", bytes / (1024.0 * 1024 * 1024)); + } +} diff --git a/src/main/java/io/indextables/tantivy4java/memory/UnlimitedMemoryAccountant.java b/src/main/java/io/indextables/tantivy4java/memory/UnlimitedMemoryAccountant.java new file mode 100644 index 0000000..80277ec --- /dev/null +++ b/src/main/java/io/indextables/tantivy4java/memory/UnlimitedMemoryAccountant.java @@ -0,0 +1,18 @@ +package io.indextables.tantivy4java.memory; + +/** + * Default memory accountant that always grants the full requested amount. + * Used when no external memory manager is configured (backward-compatible default). 
+ */ +public class UnlimitedMemoryAccountant implements NativeMemoryAccountant { + + @Override + public long acquireMemory(long bytes) { + return bytes; + } + + @Override + public void releaseMemory(long bytes) { + // No-op: unlimited pool doesn't track externally + } +} diff --git a/src/main/java/io/indextables/tantivy4java/split/SplitCacheManager.java b/src/main/java/io/indextables/tantivy4java/split/SplitCacheManager.java index ceb9b31..5866766 100644 --- a/src/main/java/io/indextables/tantivy4java/split/SplitCacheManager.java +++ b/src/main/java/io/indextables/tantivy4java/split/SplitCacheManager.java @@ -75,6 +75,13 @@ public class SplitCacheManager implements AutoCloseable { Tantivy.initialize(); // Add shutdown hook to gracefully cleanup all cache instances Runtime.getRuntime().addShutdownHook(new Thread(() -> { + // Signal the native memory pool to skip JNI release callbacks + // during shutdown — there's no Spark TaskContext on this thread. + try { + io.indextables.tantivy4java.memory.NativeMemoryManager.shutdown(); + } catch (UnsatisfiedLinkError | NoClassDefFoundError e) { + // Memory pool may not be configured + } synchronized (instances) { for (SplitCacheManager manager : instances.values()) { try { @@ -519,6 +526,7 @@ public static class TieredCacheConfig { private int writeQueueCapacity = 16; // used ONLY when mode=FRAGMENT private long writeQueueMaxBytes = 2_147_483_648L; // used ONLY when mode=SIZE_BASED private boolean dropWritesWhenFull = false; // query-path writes drop instead of block + private long maxWriteQueueBudget = 0; // 0 = default (8x initial write queue size) /** * Set the disk cache directory path. @@ -663,6 +671,28 @@ public TieredCacheConfig withWriteQueueSizeLimit(long maxBytes) { return this; } + /** + * Set the maximum memory budget for the write queue. + * + *

This caps how much memory the write queue can acquire from the JVM memory pool + * via staircase-up growth. When the queue needs more memory than its initial allocation, + * it grows in 500MB increments up to this cap. When the queue drains, overflow is released. + * + *

Default is 0, which means 8x the initial write queue size. For example, with a 2GB + * write queue, the default cap is 16GB. Set this lower to bound memory growth in + * memory-constrained environments. + * + *

Only effective when a JVM memory pool is configured via + * {@link io.indextables.tantivy4java.memory.NativeMemoryManager}. + * + * @param bytes maximum budget in bytes (0 = default 8x initial) + * @return this TieredCacheConfig for method chaining + */ + public TieredCacheConfig withMaxWriteQueueBudget(long bytes) { + this.maxWriteQueueBudget = bytes; + return this; + } + /** * Enable dropping query-path writes when the write queue is full. * @@ -697,6 +727,8 @@ public TieredCacheConfig withDropWritesWhenFull(boolean drop) { public long getWriteQueueMaxBytes() { return writeQueueMaxBytes; } /** @return whether query-path writes are dropped when the queue is full */ public boolean isDropWritesWhenFull() { return dropWritesWhenFull; } + /** @return maximum write queue memory budget in bytes (0 = default 8x initial) */ + public long getMaxWriteQueueBudget() { return maxWriteQueueBudget; } /** * Convert compression algorithm to ordinal for native layer. diff --git a/src/test/java/io/indextables/tantivy4java/memory/MemoryDenialTest.java b/src/test/java/io/indextables/tantivy4java/memory/MemoryDenialTest.java new file mode 100644 index 0000000..41d6f0e --- /dev/null +++ b/src/test/java/io/indextables/tantivy4java/memory/MemoryDenialTest.java @@ -0,0 +1,139 @@ +package io.indextables.tantivy4java.memory; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.TestMethodOrder; +import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; +import org.junit.jupiter.api.BeforeAll; + +import java.util.concurrent.atomic.AtomicLong; + +import static org.junit.jupiter.api.Assertions.*; + +import io.indextables.tantivy4java.core.*; + +/** + * Tests that memory pool denial properly prevents operations (fail-fast behavior). + * + * IMPORTANT: This test class MUST run in its own forked JVM because it calls + * NativeMemoryManager.setAccountant() with a LimitedAccountant. 
Since the global + * pool can only be set once per JVM, this test class must not share a JVM with + * NativeMemoryManagerTest (which sets a TrackingAccountant). + * + * Configure in Maven surefire: + * -Dtest="MemoryDenialTest" with forkCount=1, reuseForks=false + * Or run separately: mvn test -Dtest="MemoryDenialTest" -DforkCount=1 -DreuseForks=false + */ +@TestMethodOrder(OrderAnnotation.class) +public class MemoryDenialTest { + + @BeforeAll + static void ensureNativeLoaded() { + try { + Class.forName("io.indextables.tantivy4java.core.Index"); + } catch (ClassNotFoundException e) { + fail("Native library not available"); + } + } + + // ======================================================================== + // Step 1: Configure a very tight memory limit + // ======================================================================== + + @Test + @Order(1) + void testSetLimitedAccountant() { + // Set a 10MB limit — enough for small operations but not for DEFAULT_HEAP_SIZE (50MB) + LimitedAccountant accountant = new LimitedAccountant(10_000_000); + boolean result = NativeMemoryManager.setAccountant(accountant); + assertTrue(result, "setAccountant should succeed on first call"); + assertTrue(NativeMemoryManager.isConfigured()); + } + + // ======================================================================== + // Step 2: Verify IndexWriter creation fails fast when pool denies + // ======================================================================== + + @Test + @Order(2) + void testIndexWriterDeniedWhenPoolTight() throws Exception { + try (SchemaBuilder builder = new SchemaBuilder()) { + builder.addTextField("title", true, false, "default", "position"); + try (Schema schema = builder.build()) { + try (Index index = new Index(schema, "", true)) { + // DEFAULT_HEAP_SIZE is 50MB, pool limit is 10MB — should be denied + Exception exception = assertThrows(Exception.class, () -> { + index.writer(Index.Memory.DEFAULT_HEAP_SIZE, 1); + }); + String msg = 
exception.getMessage(); + assertTrue(msg != null && msg.toLowerCase().contains("denied"), + "Exception should mention denial, got: " + msg); + } + } + } + } + + @Test + @Order(3) + void testSmallWriterSucceedsUnderLimit() throws Exception { + // MIN_HEAP_SIZE (15MB) might fit depending on watermark acquire increment. + // The JvmMemoryPool acquires in 64MB chunks by default, so even MIN_HEAP_SIZE + // will trigger a 64MB JNI acquire which exceeds our 10MB limit. + // This test verifies that the limit is properly enforced. + try (SchemaBuilder builder = new SchemaBuilder()) { + builder.addTextField("title", true, false, "default", "position"); + try (Schema schema = builder.build()) { + try (Index index = new Index(schema, "", true)) { + // Even MIN_HEAP_SIZE should be denied because the JvmMemoryPool + // will try to acquire max(15MB, 64MB) = 64MB from the JVM accountant, + // which exceeds our 10MB limit. + Exception exception = assertThrows(Exception.class, () -> { + index.writer(Index.Memory.MIN_HEAP_SIZE, 1); + }); + assertNotNull(exception.getMessage(), + "Should get a meaningful error message on denial"); + } + } + } + } + + @Test + @Order(4) + void testPoolStatsAfterDenial() { + // After denial, pool should still be functional for queries + NativeMemoryStats stats = NativeMemoryManager.getStats(); + assertNotNull(stats); + // Used should be 0 or very small (no successful reservations) + assertTrue(stats.getUsedBytes() >= 0); + // Granted should reflect what the accountant actually gave + assertTrue(stats.getGrantedBytes() >= 0); + } + + // ======================================================================== + // Helper: Limited Accountant (same as in NativeMemoryManagerTest) + // ======================================================================== + + static class LimitedAccountant implements NativeMemoryAccountant { + private final long maxBytes; + private final AtomicLong currentUsage = new AtomicLong(0); + + LimitedAccountant(long maxBytes) { + 
this.maxBytes = maxBytes; + } + + @Override + public synchronized long acquireMemory(long bytes) { + long current = currentUsage.get(); + if (current + bytes > maxBytes) { + return 0; // Denied + } + currentUsage.addAndGet(bytes); + return bytes; + } + + @Override + public void releaseMemory(long bytes) { + currentUsage.addAndGet(-bytes); + } + } +} diff --git a/src/test/java/io/indextables/tantivy4java/memory/NativeMemoryManagerTest.java b/src/test/java/io/indextables/tantivy4java/memory/NativeMemoryManagerTest.java new file mode 100644 index 0000000..44b48fd --- /dev/null +++ b/src/test/java/io/indextables/tantivy4java/memory/NativeMemoryManagerTest.java @@ -0,0 +1,570 @@ +package io.indextables.tantivy4java.memory; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.TestMethodOrder; +import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; +import org.junit.jupiter.api.BeforeAll; + +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.*; + +import io.indextables.tantivy4java.core.*; + +/** + * Integration tests for the unified memory management system. + * + * Tests are ordered: default-pool tests run first (@Order 1-49), + * then setAccountant tests run last (@Order 50+), since OnceLock + * means setAccountant can only be called once per JVM. 
+ */ +@TestMethodOrder(OrderAnnotation.class) +public class NativeMemoryManagerTest { + + @BeforeAll + static void ensureNativeLoaded() { + // Force native library load via Tantivy version check + try { + Class.forName("io.indextables.tantivy4java.core.Index"); + } catch (ClassNotFoundException e) { + fail("Native library not available"); + } + } + + // ======================================================================== + // Basic API Tests (default unlimited pool — must run before setAccountant) + // ======================================================================== + + @Test + @Order(1) + void testDefaultUnlimitedMode() { + // Without calling setAccountant, the pool should be in unlimited mode + NativeMemoryStats stats = NativeMemoryManager.getStats(); + assertNotNull(stats, "Stats should never be null"); + assertTrue(stats.getUsedBytes() >= 0, "Used bytes should be non-negative"); + assertTrue(stats.getPeakBytes() >= 0, "Peak bytes should be non-negative"); + // Default pool reports granted as -1 (unlimited) + assertEquals(-1, stats.getGrantedBytes(), "Default pool should report unlimited (-1)"); + } + + @Test + @Order(2) + void testStatsToString() { + NativeMemoryStats stats = NativeMemoryManager.getStats(); + String str = stats.toString(); + assertNotNull(str); + assertTrue(str.contains("NativeMemoryStats{"), "Should contain class name"); + assertTrue(str.contains("used="), "Should contain used field"); + assertTrue(str.contains("peak="), "Should contain peak field"); + assertTrue(str.contains("granted="), "Should contain granted field"); + } + + @Test + @Order(3) + void testUnlimitedAccountantAlwaysGrants() { + UnlimitedMemoryAccountant accountant = new UnlimitedMemoryAccountant(); + assertEquals(1000, accountant.acquireMemory(1000)); + assertEquals(1_000_000_000L, accountant.acquireMemory(1_000_000_000L)); + // releaseMemory is a no-op — just verify it doesn't throw + accountant.releaseMemory(1000); + } + + @Test + @Order(4) + void 
testSetAccountantRejectsNull() { + assertThrows(IllegalArgumentException.class, () -> { + NativeMemoryManager.setAccountant(null); + }); + } + + // ======================================================================== + // Mock Accountant Tests (pure Java, no native calls) + // ======================================================================== + + @Test + @Order(5) + void testTrackingAccountantRecordsOperations() { + TrackingAccountant accountant = new TrackingAccountant(); + + // Simulate what the native layer does + long acquired = accountant.acquireMemory(1024); + assertEquals(1024, acquired); + assertEquals(1024, accountant.getTotalAcquired()); + assertEquals(1, accountant.getAcquireCallCount()); + + accountant.releaseMemory(512); + assertEquals(512, accountant.getTotalReleased()); + assertEquals(1, accountant.getReleaseCallCount()); + } + + @Test + @Order(6) + void testLimitedAccountantDeniesExcessiveRequests() { + LimitedAccountant accountant = new LimitedAccountant(1_000_000); // 1MB limit + + assertEquals(500_000, accountant.acquireMemory(500_000)); // OK + assertEquals(500_000, accountant.acquireMemory(500_000)); // OK, at limit + assertEquals(0, accountant.acquireMemory(1)); // Denied, over limit + + accountant.releaseMemory(100_000); + assertEquals(100_000, accountant.acquireMemory(100_000)); // OK, freed some + } + + @Test + @Order(7) + void testLimitedAccountantPartialGrant() { + LimitedAccountant accountant = new LimitedAccountant(1_000_000); + + accountant.acquireMemory(800_000); + // Only 200KB left, requesting 500KB should get 0 (denied) + long acquired = accountant.acquireMemory(500_000); + assertEquals(0, acquired, "Should deny when insufficient memory"); + } + + // ======================================================================== + // Thread Safety Tests (pure Java accountant tests) + // ======================================================================== + + @Test + @Order(8) + void testTrackingAccountantThreadSafety() 
throws Exception { + TrackingAccountant accountant = new TrackingAccountant(); + int numThreads = 10; + int opsPerThread = 1000; + CountDownLatch latch = new CountDownLatch(numThreads); + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + + for (int i = 0; i < numThreads; i++) { + executor.submit(() -> { + try { + for (int j = 0; j < opsPerThread; j++) { + accountant.acquireMemory(100); + accountant.releaseMemory(100); + } + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + executor.shutdown(); + + assertEquals(numThreads * opsPerThread, accountant.getAcquireCallCount()); + assertEquals(numThreads * opsPerThread, accountant.getReleaseCallCount()); + assertEquals( + accountant.getTotalAcquired(), + accountant.getTotalReleased(), + "All acquired memory should be released" + ); + } + + @Test + @Order(9) + void testLimitedAccountantThreadSafety() throws Exception { + long limit = 10_000_000; // 10MB + LimitedAccountant accountant = new LimitedAccountant(limit); + int numThreads = 8; + int opsPerThread = 500; + CountDownLatch latch = new CountDownLatch(numThreads); + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + AtomicLong totalAcquired = new AtomicLong(0); + AtomicLong totalReleased = new AtomicLong(0); + AtomicLong deniedCount = new AtomicLong(0); + + for (int i = 0; i < numThreads; i++) { + executor.submit(() -> { + try { + for (int j = 0; j < opsPerThread; j++) { + long acquired = accountant.acquireMemory(10_000); + if (acquired > 0) { + totalAcquired.addAndGet(acquired); + accountant.releaseMemory(acquired); + totalReleased.addAndGet(acquired); + } else { + deniedCount.incrementAndGet(); + } + } + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + executor.shutdown(); + + assertEquals(totalAcquired.get(), totalReleased.get(), + "All acquired memory should be released"); + assertTrue(accountant.getCurrentUsage() == 0, + 
"Final usage should be zero"); + } + + // ======================================================================== + // NativeMemoryStats Tests + // ======================================================================== + + @Test + @Order(10) + void testStatsFormatBytes() { + // Test the toString formatting with different magnitudes + NativeMemoryStats small = new NativeMemoryStats(100, 200, -1, + Map.of("test", 100L)); + assertTrue(small.toString().contains("100B")); + + NativeMemoryStats medium = new NativeMemoryStats(1_500_000, 2_000_000, 5_000_000, + Map.of("l1_cache", 1_000_000L, "index_writer", 500_000L)); + String str = medium.toString(); + assertTrue(str.contains("MB"), "Should format as MB"); + assertTrue(str.contains("l1_cache"), "Should include category"); + assertTrue(str.contains("index_writer"), "Should include category"); + } + + @Test + @Order(11) + void testStatsCategoryBreakdownIsUnmodifiable() { + NativeMemoryStats stats = new NativeMemoryStats(0, 0, 0, + Map.of("test", 100L)); + assertThrows(UnsupportedOperationException.class, () -> { + stats.getCategoryBreakdown().put("hack", 999L); + }); + } + + // ======================================================================== + // Integration: IndexWriter category tracking (default pool) + // ======================================================================== + + @Test + @Order(20) + void testIndexWriterCategoryTracking() throws Exception { + // Record baseline before creating writer + long baselineUsed = NativeMemoryManager.getStats().getUsedBytes(); + + try (SchemaBuilder builder = new SchemaBuilder()) { + builder.addTextField("title", true, false, "default", "position"); + builder.addIntegerField("score", true, true, false); + + try (Schema schema = builder.build()) { + try (Index index = new Index(schema, "", true)) { + try (IndexWriter writer = index.writer(Index.Memory.DEFAULT_HEAP_SIZE, 1)) { + // While writer is open, index_writer category should be tracked + NativeMemoryStats 
duringWrite = NativeMemoryManager.getStats(); + long usedDuringWrite = duringWrite.getUsedBytes(); + Map breakdown = duringWrite.getCategoryBreakdown(); + + assertTrue(usedDuringWrite > baselineUsed, + "Used bytes should increase when writer is open: was " + baselineUsed + ", now " + usedDuringWrite); + assertTrue(breakdown.containsKey("index_writer"), + "Category breakdown should contain 'index_writer', got: " + breakdown); + long writerBytes = breakdown.get("index_writer"); + assertTrue(writerBytes >= Index.Memory.DEFAULT_HEAP_SIZE, + "index_writer should reserve at least DEFAULT_HEAP_SIZE (" + + Index.Memory.DEFAULT_HEAP_SIZE + "), got: " + writerBytes); + + // Add a document and commit (verifies operations work with tracking) + try (Document doc = new Document()) { + doc.addText("title", "Memory management test"); + doc.addInteger("score", 42); + writer.addDocument(doc); + } + writer.commit(); + } + + // After writer.close(), reservation should be released + NativeMemoryStats afterClose = NativeMemoryManager.getStats(); + Map afterBreakdown = afterClose.getCategoryBreakdown(); + long afterWriterBytes = afterBreakdown.getOrDefault("index_writer", 0L); + assertTrue(afterWriterBytes == 0 || !afterBreakdown.containsKey("index_writer"), + "index_writer category should be 0 or absent after close, got: " + afterWriterBytes); + } + } + } + } + + @Test + @Order(21) + void testResetPeak() throws Exception { + // Reset peak and verify it returns something reasonable + long oldPeak = NativeMemoryManager.resetPeak(); + assertTrue(oldPeak >= 0, "Old peak should be non-negative"); + + // After reset, peak should equal current used + NativeMemoryStats stats = NativeMemoryManager.getStats(); + assertEquals(stats.getUsedBytes(), stats.getPeakBytes(), + "After resetPeak, peak should equal current used"); + + // Create a writer to push peak up + try (SchemaBuilder builder = new SchemaBuilder()) { + builder.addTextField("title", true, false, "default", "position"); + try (Schema 
schema = builder.build()) { + try (Index index = new Index(schema, "", true)) { + try (IndexWriter writer = index.writer(Index.Memory.DEFAULT_HEAP_SIZE, 1)) { + NativeMemoryStats duringWrite = NativeMemoryManager.getStats(); + assertTrue(duringWrite.getPeakBytes() >= Index.Memory.DEFAULT_HEAP_SIZE, + "Peak should reflect writer allocation"); + } + } + } + } + + // After close, peak should still be high (peak tracks max) + NativeMemoryStats afterClose = NativeMemoryManager.getStats(); + assertTrue(afterClose.getPeakBytes() >= Index.Memory.DEFAULT_HEAP_SIZE, + "Peak should still reflect past writer allocation"); + + // Reset peak again — now it should drop to current (no writer) + NativeMemoryManager.resetPeak(); + NativeMemoryStats afterReset = NativeMemoryManager.getStats(); + assertTrue(afterReset.getPeakBytes() < Index.Memory.DEFAULT_HEAP_SIZE, + "After second reset with no writer, peak should be below DEFAULT_HEAP_SIZE, got: " + + afterReset.getPeakBytes()); + } + + @Test + @Order(22) + void testWriterCloseReleasesReservation() throws Exception { + long beforeUsed = NativeMemoryManager.getStats().getUsedBytes(); + + try (SchemaBuilder builder = new SchemaBuilder()) { + builder.addTextField("body", true, false, "default", "position"); + try (Schema schema = builder.build()) { + try (Index index = new Index(schema, "", true)) { + // Open writer — memory goes up + IndexWriter writer = index.writer(Index.Memory.DEFAULT_HEAP_SIZE, 1); + long duringUsed = NativeMemoryManager.getStats().getUsedBytes(); + assertTrue(duringUsed >= beforeUsed + Index.Memory.DEFAULT_HEAP_SIZE, + "Memory should increase by at least heap size during writer lifetime"); + + // Close writer — memory goes back down + writer.close(); + long afterUsed = NativeMemoryManager.getStats().getUsedBytes(); + assertTrue(afterUsed < duringUsed, + "Memory should decrease after writer close: during=" + duringUsed + ", after=" + afterUsed); + // Should be approximately back to baseline (within some tolerance 
for other state) + assertTrue(afterUsed - beforeUsed < Index.Memory.DEFAULT_HEAP_SIZE, + "Memory delta after close should be less than heap size: delta=" + (afterUsed - beforeUsed)); + } + } + } + } + + @Test + @Order(23) + void testMultipleWritersConcurrentCategories() throws Exception { + try (SchemaBuilder builder = new SchemaBuilder()) { + builder.addTextField("text", true, false, "default", "position"); + try (Schema schema = builder.build()) { + try (Index index1 = new Index(schema, "", true); + Index index2 = new Index(schema, "", true)) { + + try (IndexWriter writer1 = index1.writer(Index.Memory.DEFAULT_HEAP_SIZE, 1); + IndexWriter writer2 = index2.writer(Index.Memory.DEFAULT_HEAP_SIZE, 1)) { + + NativeMemoryStats stats = NativeMemoryManager.getStats(); + Map breakdown = stats.getCategoryBreakdown(); + + assertTrue(breakdown.containsKey("index_writer"), + "Should track index_writer category"); + long writerBytes = breakdown.get("index_writer"); + // Two writers, each at DEFAULT_HEAP_SIZE + assertTrue(writerBytes >= 2 * Index.Memory.DEFAULT_HEAP_SIZE, + "Two writers should reserve at least 2x DEFAULT_HEAP_SIZE (" + + (2 * Index.Memory.DEFAULT_HEAP_SIZE) + "), got: " + writerBytes); + } + + // Both closed — index_writer should be 0 + NativeMemoryStats afterClose = NativeMemoryManager.getStats(); + long afterWriterBytes = afterClose.getCategoryBreakdown().getOrDefault("index_writer", 0L); + assertEquals(0, afterWriterBytes, + "index_writer should be 0 after both writers closed"); + } + } + } + } + + @Test + @Order(24) + void testIndexWriterWithDefaultMemoryPool() throws Exception { + // Verify that basic indexing works with the default (unlimited) memory pool + try (SchemaBuilder builder = new SchemaBuilder()) { + builder.addTextField("title", true, false, "default", "position"); + builder.addIntegerField("score", true, true, false); + + try (Schema schema = builder.build()) { + try (Index index = new Index(schema, "", true)) { + try (IndexWriter writer = 
index.writer(Index.Memory.DEFAULT_HEAP_SIZE, 1)) { + try (Document doc = new Document()) { + doc.addText("title", "Memory management test"); + doc.addInteger("score", 42); + writer.addDocument(doc); + } + writer.commit(); + } + + // Verify stats are accessible + NativeMemoryStats stats = NativeMemoryManager.getStats(); + assertNotNull(stats); + assertTrue(stats.getUsedBytes() >= 0); + } + } + } + } + + // ======================================================================== + // End-to-End: setAccountant with JNI bridge (MUST RUN LAST — OnceLock) + // ======================================================================== + + @Test + @Order(50) + void testSetAccountantEndToEnd() throws Exception { + // This test calls setAccountant() which sets the OnceLock. + // All subsequent tests in this JVM will use this accountant. + TrackingAccountant accountant = new TrackingAccountant(); + boolean result = NativeMemoryManager.setAccountant(accountant); + assertTrue(result, "First setAccountant call should succeed"); + assertTrue(NativeMemoryManager.isConfigured(), "Pool should be configured after setAccountant"); + + // Second call should fail (OnceLock already set) + TrackingAccountant accountant2 = new TrackingAccountant(); + boolean result2 = NativeMemoryManager.setAccountant(accountant2); + assertFalse(result2, "Second setAccountant call should fail"); + + // Now do a real native operation — create an IndexWriter + try (SchemaBuilder builder = new SchemaBuilder()) { + builder.addTextField("title", true, false, "default", "position"); + try (Schema schema = builder.build()) { + try (Index index = new Index(schema, "", true)) { + try (IndexWriter writer = index.writer(Index.Memory.DEFAULT_HEAP_SIZE, 1)) { + // The JVM accountant should have been called via JNI + assertTrue(accountant.getAcquireCallCount() > 0, + "JVM accountant should have received acquire calls via JNI, got: " + + accountant.getAcquireCallCount()); + assertTrue(accountant.getTotalAcquired() >= 
Index.Memory.DEFAULT_HEAP_SIZE, + "JVM accountant should have acquired at least DEFAULT_HEAP_SIZE (" + + Index.Memory.DEFAULT_HEAP_SIZE + "), got: " + accountant.getTotalAcquired()); + + // Stats should reflect the JVM-backed pool + NativeMemoryStats stats = NativeMemoryManager.getStats(); + assertTrue(stats.getGrantedBytes() >= 0, + "JVM pool should report non-negative granted, got: " + stats.getGrantedBytes()); + assertTrue(stats.getUsedBytes() > 0, + "JVM pool should report positive used during writer lifetime"); + + // Add document and commit + try (Document doc = new Document()) { + doc.addText("title", "JVM accountant test"); + writer.addDocument(doc); + } + writer.commit(); + } + } + } + } + + // After writer close, memory should be released from the pool + NativeMemoryStats afterCloseStats = NativeMemoryManager.getStats(); + assertEquals(0, afterCloseStats.getUsedBytes(), + "Pool used bytes should be 0 after writer close"); + + // The JVM accountant may or may not have received release callbacks + // depending on watermark batching — the important thing is that the + // pool's internal tracking shows 0 used bytes. 
+ } + + @Test + @Order(51) + void testCategoryBreakdownWithJvmPool() throws Exception { + // Now running with JVM-backed pool (set in testSetAccountantEndToEnd) + // Verify categories still work correctly + try (SchemaBuilder builder = new SchemaBuilder()) { + builder.addTextField("content", true, false, "default", "position"); + try (Schema schema = builder.build()) { + try (Index index = new Index(schema, "", true)) { + try (IndexWriter writer = index.writer(Index.Memory.LARGE_HEAP_SIZE, 1)) { + NativeMemoryStats stats = NativeMemoryManager.getStats(); + Map breakdown = stats.getCategoryBreakdown(); + + assertTrue(breakdown.containsKey("index_writer"), + "JVM pool should track index_writer category, got: " + breakdown); + assertTrue(breakdown.get("index_writer") >= Index.Memory.LARGE_HEAP_SIZE, + "index_writer should reserve at least LARGE_HEAP_SIZE"); + } + } + } + } + } + + // ======================================================================== + // Helper: Tracking Accountant (records all operations) + // ======================================================================== + + /** + * A test accountant that tracks all acquire/release calls. + * Always grants the full request. 
+ */ + static class TrackingAccountant implements NativeMemoryAccountant { + private final AtomicLong totalAcquired = new AtomicLong(0); + private final AtomicLong totalReleased = new AtomicLong(0); + private final AtomicLong acquireCallCount = new AtomicLong(0); + private final AtomicLong releaseCallCount = new AtomicLong(0); + + @Override + public long acquireMemory(long bytes) { + acquireCallCount.incrementAndGet(); + totalAcquired.addAndGet(bytes); + return bytes; + } + + @Override + public void releaseMemory(long bytes) { + releaseCallCount.incrementAndGet(); + totalReleased.addAndGet(bytes); + } + + public long getTotalAcquired() { return totalAcquired.get(); } + public long getTotalReleased() { return totalReleased.get(); } + public long getAcquireCallCount() { return acquireCallCount.get(); } + public long getReleaseCallCount() { return releaseCallCount.get(); } + } + + // ======================================================================== + // Helper: Limited Accountant (enforces a memory limit) + // ======================================================================== + + /** + * A test accountant that enforces a hard memory limit. + * Returns 0 if the request would exceed the limit. + */ + static class LimitedAccountant implements NativeMemoryAccountant { + private final long maxBytes; + private final AtomicLong currentUsage = new AtomicLong(0); + + LimitedAccountant(long maxBytes) { + this.maxBytes = maxBytes; + } + + @Override + public synchronized long acquireMemory(long bytes) { + long current = currentUsage.get(); + if (current + bytes > maxBytes) { + return 0; // Denied + } + currentUsage.addAndGet(bytes); + return bytes; + } + + @Override + public void releaseMemory(long bytes) { + currentUsage.addAndGet(-bytes); + } + + public long getCurrentUsage() { return currentUsage.get(); } + public long getMaxBytes() { return maxBytes; } + } +}