# normalized_semantic_chunker.py
import os
import hmac
import time
import json
import logging
import re
import psutil
import tiktoken
import torch
import asyncio
import multiprocessing
import numpy as np
from logging.handlers import RotatingFileHandler
from sentence_transformers import SentenceTransformer
from typing import List, Optional, Dict, Union, Any
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, BackgroundTasks
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field, ConfigDict
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from contextlib import asynccontextmanager
RESERVED_SYSTEM_MEMORY_GB = 4.0 # GB of RAM to keep free for the OS and other processes
BASE_MEMORY_PER_WORKER_GB = 1.0 # Base GB of RAM allocated per worker
MEMORY_PER_SENTENCE_GB = 0.0004 # Additional GB of RAM per sentence per worker
DOC_SIZE_FACTOR_SCALER = (
3000 # Scaler used in document size factor for CPU-based worker calculation
)
VERY_LARGE_DOC_SENTENCE_THRESHOLD = (
20000 # Sentences to be considered a very large document
)
LARGE_DOC_SENTENCE_THRESHOLD = 10000 # Sentences to be considered a large document
MIN_SENTENCES_FOR_PARALLEL = (
100 # Minimum sentences to use parallel processing instead of sequential
)
WORKERS_VERY_LARGE_DOC = 1 # Max workers for very large documents
WORKERS_LARGE_DOC = 2 # Max workers for large documents
STEP_SIZE_VERY_LARGE_DOC_THRESHOLD = (
15000 # Sentence count threshold for using smallest step size
)
STEP_SIZE_LARGE_DOC_THRESHOLD = (
5000 # Sentence count threshold for using medium step size
)
STEP_SIZE_DEFAULT = 10 # Default step size for smaller documents
STEP_SIZE_LARGE_DOC = 5 # Step size for large documents
STEP_SIZE_VERY_LARGE_DOC = 3 # Step size for very large documents
ALLOWED_EXTENSIONS = {"txt", "md", "json"}
def _get_int_env(name: str, default: int) -> int:
value = os.environ.get(name)
if value is None:
return default
value = value.strip()
if not value:
return default
return int(value)
EMBEDDER_MODEL = os.environ.get(
"EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)
MAX_FILE_SIZE = _get_int_env("MAX_FILE_SIZE", 10 * 1024 * 1024) # 10MB default
MAX_CHUNK_TEXT_SIZE = _get_int_env("MAX_CHUNK_TEXT_SIZE", 100_000) # 100k chars per chunk
MAX_WORKERS = max(
1, _get_int_env("MAX_WORKERS", min(multiprocessing.cpu_count() - 1, 4))
)
CACHE_TIMEOUT = _get_int_env("CACHE_TIMEOUT", 3600) # 1 hour in seconds
WORKER_TIMEOUT = _get_int_env("WORKER_TIMEOUT", 300) # 5 minutes default
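# Illustrative only: the values above can be overridden through environment variables before
# the service starts. The server command, module path, port and token below are assumptions
# for the sake of the example, not part of this file.
#
#   MAX_FILE_SIZE=20971520 MAX_WORKERS=2 API_TOKEN=change-me \
#       uvicorn normalized_semantic_chunker:app --host 0.0.0.0 --port 8000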
def _validate_config() -> None:
"""Validate critical configuration values. Raises RuntimeError on invalid config."""
errors = []
if MAX_FILE_SIZE <= 0:
errors.append(f"MAX_FILE_SIZE must be > 0, got {MAX_FILE_SIZE}")
if MAX_CHUNK_TEXT_SIZE <= 0:
errors.append(f"MAX_CHUNK_TEXT_SIZE must be > 0, got {MAX_CHUNK_TEXT_SIZE}")
if WORKER_TIMEOUT <= 0:
errors.append(f"WORKER_TIMEOUT must be > 0, got {WORKER_TIMEOUT}")
if CACHE_TIMEOUT <= 0:
errors.append(f"CACHE_TIMEOUT must be > 0, got {CACHE_TIMEOUT}")
if MAX_WORKERS < 1:
errors.append(f"MAX_WORKERS must be >= 1, got {MAX_WORKERS}")
if errors:
raise RuntimeError("Invalid configuration:\n" + "\n".join(f" - {e}" for e in errors))
@asynccontextmanager
async def lifespan(app: FastAPI):
global _model_lock, _gpu_lock
_validate_config() # Fail fast on invalid env-var configuration
    # Initialize the locks in the running event loop (not at module import time)
_model_lock = asyncio.Lock()
_gpu_lock = asyncio.Lock()
logger.info(
f"Loading embedding model {EMBEDDER_MODEL} during application startup..."
)
try:
await _get_model(EMBEDDER_MODEL)
async with _gpu_lock:
await _load_model_to_gpu(EMBEDDER_MODEL)
logger.info("Embedding model loaded and ready on device.")
except Exception as e:
logger.error(f"Failed to load embedding model: {str(e)}")
raise RuntimeError(f"Cannot start: embedding model failed to load: {e}") from e
yield
# Cleanup
logger.info("Application shutting down, cleaning up resources...")
try:
# Cleanup model cache
async with _model_lock:
for model_name in list(_model_cache.keys()):
if model_name in _model_cache:
del _model_cache[model_name]
logger.info(f"Removed model {model_name} from cache")
# Cleanup GPU memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("GPU memory cleared during shutdown")
except Exception as e:
logger.error(f"Error during cleanup: {str(e)}")
app = FastAPI(
title="Normalized Semantic Chunker",
description="API for processing and chunking text documents into smaller, semantically coherent segments",
version="1.0.0",
lifespan=lifespan,
)
API_TOKEN = os.environ.get("API_TOKEN", "").strip()
_security = HTTPBearer(auto_error=False)
async def verify_token(
credentials: HTTPAuthorizationCredentials = Depends(_security),
):
"""Verify Bearer token if API_TOKEN is configured. No-op when API_TOKEN is unset."""
if not API_TOKEN:
return # Auth disabled
if credentials is None or not hmac.compare_digest(credentials.credentials, API_TOKEN):
raise HTTPException(status_code=403, detail="Invalid or missing API token")
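# Illustrative only: a route defined elsewhere in this module can opt into the check above
# via FastAPI's dependency injection. The path and handler below are placeholders.
#
#   @app.post("/chunk", dependencies=[Depends(verify_token)])
#   async def chunk_endpoint(...): ...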
# Create logs directory if it doesn't exist
logs_dir = Path("logs")
logs_dir.mkdir(exist_ok=True)
# Configure logging (this sets up root logger with a console handler)
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Get the logger
logger = logging.getLogger(__name__)
# Create a file handler for error logs
error_log_path = logs_dir / "errors.log"
file_handler = RotatingFileHandler(
error_log_path,
maxBytes=10485760, # 10 MB
backupCount=5, # Keep 5 backup logs
encoding="utf-8",
)
# Set the file handler to only log errors and critical messages
file_handler.setLevel(logging.ERROR)
# Create a formatter
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - %(message)s - %(pathname)s:%(lineno)d"
)
file_handler.setFormatter(formatter)
# Add the handler to the logger
logger.addHandler(file_handler)
# Create a singleton for model caching with expiration
_model_cache: dict = {}
_model_last_used: dict = {}
_model_lock: Optional[asyncio.Lock] = None  # initialized in lifespan
_gpu_lock: Optional[asyncio.Lock] = None  # initialized in lifespan; only one model in VRAM at a time
_current_gpu_model: Optional[str] = None  # name of the model currently held in VRAM
async def _get_model(model_name: str) -> SentenceTransformer:
"""Get model from cache or load it into RAM with cache expiration.
Args:
model_name (str): Name or path of the model to use.
Returns:
SentenceTransformer: The loaded model instance.
"""
global _model_cache, _model_last_used
current_time = time.time()
async with _model_lock:
        # Check for expired models first (RAM cache only; VRAM eviction is handled by _evict_gpu_model)
expired_models = [
name
for name, last_used in _model_last_used.items()
if current_time - last_used > CACHE_TIMEOUT
]
# Remove expired models from RAM
for name in expired_models:
if name in _model_cache and name != model_name:
logger.info(f"Removing expired model {name} from cache")
del _model_cache[name]
del _model_last_used[name]
# Update or load the requested model
if model_name not in _model_cache:
# Create models directory if it doesn't exist
models_dir = Path("models")
models_dir.mkdir(exist_ok=True)
# Local path for the model
local_model_path = models_dir / model_name.replace("/", "_")
try:
if local_model_path.exists():
# Load from local storage
logger.info(f"Loading model from local storage: {local_model_path}")
_model_cache[model_name] = SentenceTransformer(
str(local_model_path)
)
else:
# Download and save model
logger.info(
f"Downloading model {model_name} and saving to {local_model_path}"
)
_model_cache[model_name] = SentenceTransformer(model_name)
_model_cache[model_name].save(str(local_model_path))
except Exception as e:
logger.error(f"Error loading model {model_name}: {str(e)}")
raise
# Update last used timestamp
_model_last_used[model_name] = current_time
return _model_cache[model_name]
async def _evict_gpu_model() -> None:
"""Rimuove il modello corrente dalla VRAM e lo mantiene in RAM.
DEVE essere chiamata mentre si detiene _gpu_lock.
"""
global _current_gpu_model
if _current_gpu_model is None or not torch.cuda.is_available():
return
model = _model_cache.get(_current_gpu_model)
if model is not None:
        _model_cache[_current_gpu_model] = model.cpu()  # reassign so the CPU copy replaces the CUDA copy (avoids a leak)
logger.info(f"Evicted model '{_current_gpu_model}' from VRAM, kept in RAM")
torch.cuda.empty_cache()
_current_gpu_model = None
async def _load_model_to_gpu(model_name: str) -> SentenceTransformer:
"""Carica il modello richiesto in VRAM, evictando quello precedente se necessario.
DEVE essere chiamata mentre si detiene _gpu_lock.
Il modello deve essere già in _model_cache (caricato da _get_model).
"""
global _current_gpu_model
if not torch.cuda.is_available():
        # No GPU available: return the model from RAM as-is
return _model_cache[model_name]
if _current_gpu_model == model_name:
        # Model is already in VRAM, nothing to do
return _model_cache[model_name]
    # Evict the previous model from VRAM
if _current_gpu_model is not None:
await _evict_gpu_model()
    # Load the new model into VRAM
device = torch.device("cuda")
_model_cache[model_name] = _model_cache[model_name].to(device)
_current_gpu_model = model_name
logger.info(f"Loaded model '{model_name}' to VRAM")
return _model_cache[model_name]
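# Expected call pattern for the two helpers above (mirrored by get_embeddings further down):
# the GPU lock serializes model switches so only one model occupies VRAM at a time.
#
#   await _get_model(model_name)              # ensure the model is in the RAM cache
#   async with _gpu_lock:
#       active = await _load_model_to_gpu(model_name)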
def split_into_sentences(doc: str) -> List[str]:
"""Split a document into sentences using regex pattern matching.
Args:
doc (str): The input document text to be split into sentences.
Returns:
List[str]: A list of sentences, with each sentence stripped of leading/trailing whitespace.
Note:
The function handles common edge cases like:
- Titles (Mr., Mrs., Dr., etc.)
- Common abbreviations (i.e., e.g., etc.)
- Decimal numbers
- Ellipsis
- Quotes and brackets
"""
# Define a pattern that looks for sentence boundaries but doesn't include them in the split
# Instead of splitting directly at punctuation, we'll look for patterns that indicate sentence endings
pattern = r"""
# Match sentence ending punctuation followed by space and capital letter
# Negative lookbehind for common titles and abbreviations
(?<![A-Z][a-z]\.) # Not an abbreviation like U.S.
(?<!Mr\.)(?<!Mrs\.)(?<!Dr\.)(?<!Prof\.)(?<!Sr\.)(?<!Jr\.)(?<!Ms\.) # Not a title
(?<!i\.e\.)(?<!e\.g\.)(?<!vs\.)(?<!etc\.) # Not a common abbreviation
(?<!\d\.)(?<!\.\d) # Not a decimal or numbered list
(?<!\.\.\.) # Not an ellipsis
[\.!\?] # Sentence ending punctuation
\s+ # One or more whitespace
(?=[A-Z]) # Followed by capital letter
"""
# Find all positions where we should split
split_positions = []
for match in re.finditer(pattern, doc, re.VERBOSE):
# Split after the punctuation and space
split_positions.append(match.end())
# Use the positions to extract sentences
sentences = []
start = 0
for pos in split_positions:
if pos > start:
sentences.append(doc[start:pos].strip())
start = pos
# Add the last sentence if there's remaining text
if start < len(doc):
sentences.append(doc[start:].strip())
# Filter out empty sentences
return [s for s in sentences if s]
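# Rough illustration of the intended splitting behaviour (not an exhaustive test):
#   split_into_sentences("It rained all day. We stayed inside. The roads, i.e. the main ones, flooded.")
#   -> ["It rained all day.", "We stayed inside.", "The roads, i.e. the main ones, flooded."]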
async def get_embeddings(
doc: List[str],
model: str = EMBEDDER_MODEL,
batch_size: int = 8,
verbosity: bool = False,
convert_to_numpy: bool = True,
normalize_embeddings: bool = True,
) -> dict[str, List[float]]:
"""Generate embeddings for a list of text strings using a Sentence Transformer model.
Args:
doc (List[str]): List of text strings to generate embeddings for.
model (str, optional): Name or path of the model to use.
batch_size (int, optional): Batch size for embedding generation.
verbosity (bool, optional): If True, shows all log messages and progress bars.
convert_to_numpy (bool, optional): Whether to convert output to numpy array.
normalize_embeddings (bool, optional): Whether to normalize embeddings.
Returns:
dict[str, List[float]]: Dictionary mapping input strings to their embeddings.
Raises:
HTTPException: If there's an error during the embedding process.
"""
if not doc:
logger.warning("get_embeddings called with empty document list — returning empty dict")
return {}
try:
        # Load the model into RAM if it is not already cached
await _get_model(model)
        # Adjust the batch size for large documents
effective_batch_size = batch_size
if len(doc) > 1000:
effective_batch_size = max(1, batch_size // 2)
if verbosity:
logger.info(
f"Large document detected, reducing batch size to {effective_batch_size}"
)
        # Make sure the requested model is the one in VRAM (serializes switches between models)
async with _gpu_lock:
active_model = await _load_model_to_gpu(model)
            # Run encode in a separate thread so the event loop is not blocked
embeddings = await asyncio.to_thread(
active_model.encode,
doc,
batch_size=effective_batch_size,
show_progress_bar=verbosity,
convert_to_numpy=convert_to_numpy,
normalize_embeddings=normalize_embeddings,
)
        # Build the result dictionary outside the lock (pure CPU work).
        # For duplicate sentences only the first embedding is kept: SentenceTransformer
        # is deterministic (same text gives the same vector), so first/last are
        # equivalent, but "first" is more predictable and follows the original order.
result = {}
for sentence, embedding in zip(doc, embeddings):
if sentence not in result:
result[sentence] = embedding.tolist()
return result
except HTTPException:
raise
except Exception as e:
logger.error(f"Error generating embeddings: {str(e)}")
if torch.cuda.is_available():
torch.cuda.empty_cache()
raise HTTPException(
status_code=500,
detail=f"An unexpected error occurred: {str(e)}",
)
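# Sketch of typical use (assumes an asyncio context; the sentences are placeholders):
#   embs = await get_embeddings(["First sentence.", "Second sentence."])
#   vec = embs["First sentence."]   # list[float], unit-normalised by default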
def calculate_similarity(
embeddings_dict: dict[str, List[float]], sentences: List[str]
) -> List[float]:
"""Calculate similarity scores between consecutive vectors in sentence order.
Args:
embeddings_dict (dict[str, List[float]]): Dictionary mapping sentences to vectors
sentences (List[str]): Sentences in original order
Returns:
        List[float]: Cosine distances between consecutive vector pairs (0.0 means identical direction)
"""
if len(sentences) <= 1:
return []
try:
# Extract vectors in sentence order
vectors = [embeddings_dict[sentence] for sentence in sentences]
# Process in batches for large documents to prevent OOM errors
batch_size = 5000 # Adjust based on memory constraints
if len(vectors) > batch_size:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
similarities = []
for i in range(0, len(vectors) - 1, batch_size):
# Compute similarities for pairs (i, i+1) through (end_idx-1, end_idx).
# Both slices have equal length (end_idx - i), so element-wise ops are safe.
end_idx = min(i + batch_size, len(vectors) - 1)
batch_vectors1 = torch.tensor(
vectors[i:end_idx], dtype=torch.float32, device=device
)
batch_vectors2 = torch.tensor(
vectors[i + 1 : end_idx + 1], dtype=torch.float32, device=device
)
# Calculate norms
norms1 = torch.linalg.norm(batch_vectors1, dim=1)
norms2 = torch.linalg.norm(batch_vectors2, dim=1)
# Calculate dot products
dot_products = torch.sum(batch_vectors1 * batch_vectors2, dim=1)
# Calculate similarities
batch_similarities = 1 - (dot_products / (norms1 * norms2))
similarities.extend(
[round(float(sim), 5) for sim in batch_similarities.cpu()]
)
# Clean up batch memory
del (
batch_vectors1,
batch_vectors2,
norms1,
norms2,
dot_products,
batch_similarities,
)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return similarities
else:
# For smaller documents, process all at once
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vectors_tensor = torch.tensor(vectors, dtype=torch.float32, device=device)
# Get consecutive pairs
vectors1 = vectors_tensor[:-1]
vectors2 = vectors_tensor[1:]
# Calculate norms
norms = torch.linalg.norm(vectors_tensor, dim=1)
norms1 = norms[:-1]
norms2 = norms[1:]
# Calculate dot products
dot_products = torch.sum(vectors1 * vectors2, dim=1)
# Calculate similarities (using 1 - cosine similarity for angular distance)
cosine_similarities = dot_products / (norms1 * norms2)
angular_similarities = 1 - cosine_similarities
# Move to CPU and round
result = [round(float(sim), 5) for sim in angular_similarities.cpu()]
# Cleanup GPU memory
if torch.cuda.is_available():
del vectors_tensor, vectors1, vectors2, norms, norms1, norms2
del dot_products, cosine_similarities, angular_similarities
torch.cuda.empty_cache()
return result
except Exception as e:
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.error(f"Error calculating similarities: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Error calculating similarities: {str(e)}",
)
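# Quick sanity check of the distance convention used above (illustrative values):
#   calculate_similarity({"a": [1.0, 0.0], "b": [1.0, 0.0], "c": [0.0, 1.0]}, ["a", "b", "c"])
#   -> [0.0, 1.0]   # identical directions give 0.0, orthogonal directions give 1.0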
def _count_tokens_for_text(args: tuple[str, str]) -> int:
"""Count tokens in a text string using the specified encoding.
Args:
args (tuple): Tuple containing:
- text (str): Text to count tokens for
- encoding_name (str): Name of the tiktoken encoding to use
Returns:
int: Number of tokens in the text
Raises:
Exception: If tiktoken fails. Caller (ProcessPoolExecutor) handles the exception.
"""
text, encoding_name = args
encoding = tiktoken.get_encoding(encoding_name)
return len(encoding.encode(text))
def _group_chunks_by_similarity(
sentences: List[str], distance: List[float], percentile: int
) -> tuple[dict[str, int], int, float, float]:
"""Group sentences into chunks based on similarity distances and a percentile threshold.
Args:
sentences (List[str]): List of sentences to group into chunks.
distance (List[float]): List of similarity distances between consecutive sentences.
percentile (int): Percentile value to use as threshold for chunk boundaries.
Returns:
tuple[dict[str, int], int, float, float]: A tuple containing:
- Dictionary mapping text chunks to their token counts
- Maximum token count across all chunks
- Average token count across all chunks
- Standard deviation of token counts
"""
try:
if not (1 <= percentile <= 99):
logger.warning(f"percentile={percentile} is outside valid range [1, 99], clamping")
percentile = max(1, min(99, percentile))
breakpoint = np.percentile(distance, percentile)
indices_above_th = [i for i, x in enumerate(distance) if x > breakpoint]
chunks = []
start_index = 0
for index in indices_above_th:
combined_text = " ".join(sentences[start_index : index + 1])
chunks.append(combined_text)
start_index = index + 1
if start_index < len(sentences):
chunks.append(" ".join(sentences[start_index:]))
# Calculate token counts with multiprocessing for large documents
if len(chunks) > 100:
with ProcessPoolExecutor(
max_workers=min(MAX_WORKERS, len(chunks) // 10 + 1)
) as executor:
token_args = [(chunk_text, "cl100k_base") for chunk_text in chunks]
token_counts = list(executor.map(_count_tokens_for_text, token_args))
else:
# For smaller sets, avoid the overhead of multiprocessing
token_counts = [
_count_tokens_for_text((chunk_text, "cl100k_base"))
for chunk_text in chunks
]
# Create dictionary mapping chunks to token counts
chunks_with_tokens = {
chunk: count for chunk, count in zip(chunks, token_counts)
}
max_tokens = max(token_counts) if token_counts else 0
average_tokens = sum(token_counts) / len(token_counts) if token_counts else 0
std_dev = np.std(token_counts) if len(token_counts) > 1 else 0
return chunks_with_tokens, max_tokens, average_tokens, std_dev
except Exception as e:
logger.error(f"Error in _group_chunks_by_similarity: {str(e)}")
# Return safe fallback in case of error
if not sentences:
return {}, 0, 0, 0
# Single chunk fallback
combined_text = " ".join(sentences)
token_count = _count_tokens_for_text((combined_text, "cl100k_base"))
return {combined_text: token_count}, token_count, token_count, 0
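# Worked micro-example of the breakpoint logic above (made-up numbers, 5 sentences):
#   distance = [0.1, 0.8, 0.2, 0.9], percentile = 75 -> breakpoint = 0.825
#   only distance[3] exceeds it, so sentences 0-3 become one chunk and sentence 4 another.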
def _process_percentile_range(
args: tuple[List[str], List[float], int, int],
) -> tuple[Optional[dict[str, int]], Optional[int], Optional[float]]:
"""Process a single percentile and check if it produces valid chunks.
Args:
args (tuple): Tuple containing:
- sentences (List[str]): List of sentences to process
- distance (List[float]): List of similarity distances
- max_tokens (int): Maximum tokens allowed per chunk
- percentile (int): Percentile to check
Returns:
tuple[Optional[dict[str, int]], Optional[int], Optional[float]]:
- Dictionary mapping chunks to token counts (if valid)
- Percentile used (if valid)
- Average tokens (if valid)
- Or (None, None, None) if invalid
"""
try:
sentences, distance, max_tokens, percentile = args
chunks_with_tokens, threshold_tokens, average_tokens, std_dev = (
_group_chunks_by_similarity(sentences, distance, percentile)
)
# Calculate 95th percentile value using z-score of 1.645
estimated_95th_percentile = average_tokens + (1.645 * std_dev)
if estimated_95th_percentile <= max_tokens:
return chunks_with_tokens, percentile, average_tokens
return None, None, None
except Exception as e:
logger.error(f"Error processing percentile {percentile}: {str(e)}")
return None, None, None
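# Numeric illustration of the acceptance rule above (made-up numbers):
#   average_tokens = 300, std_dev = 60 -> estimated 95th percentile = 300 + 1.645 * 60 = 398.7
#   a max_tokens of 512 accepts this percentile; a max_tokens of 350 rejects it.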
def _find_optimal_chunks(
sentences: List[str], distance: List[float], max_tokens: int
) -> tuple[dict[str, int], int, float]:
"""Find optimal chunk groupings that fit within a token limit using statistical approach.
Args:
sentences (List[str]): List of sentences to group into chunks.
distance (List[float]): List of similarity distances between consecutive sentences.
max_tokens (int): Maximum number of tokens allowed per chunk.
Returns:
tuple[dict[str, int], int, float]: A tuple containing:
- Dictionary mapping text chunks to their token counts
- Percentile value used for grouping (0 if no suitable grouping found)
- Average token count across all chunks
Note:
Uses a statistical approach by calculating the estimated 95th percentile
of token counts to ensure most chunks stay below the token limit.
"""
# Use a targeted percentile range for more efficient exploration
percentile_steps = 5
for percentile in range(95, 0, -percentile_steps):
chunks_with_tokens, max_token_val, average_tokens, std_dev = (
_group_chunks_by_similarity(sentences, distance, percentile)
)
# Calculate 95th percentile value using z-score of 1.645
estimated_95th_percentile = average_tokens + (1.645 * std_dev)
if estimated_95th_percentile <= max_tokens:
logger.info(
f"Sequential fallback found valid percentile {percentile} with 95th percentile estimate {estimated_95th_percentile:.2f} <= {max_tokens}"
)
return chunks_with_tokens, percentile, average_tokens
# If no valid percentile found, return the entire text as a single chunk
logger.warning(
"No valid chunking found using sequential approach - returning single chunk. "
"Consider increasing max_tokens or providing a longer document."
)
fallback_chunks = {
" ".join(sentences): _count_tokens_for_text(
(" ".join(sentences), "cl100k_base")
)
}
return fallback_chunks, 0, 0
def parallel_find_optimal_chunks(
sentences: List[str],
distance: List[float],
max_tokens: int,
start_percentile: int = 99,
verbosity: bool = True,
) -> tuple[dict[str, int], int, float]:
"""Find optimal chunks using parallel batch processing starting from highest percentiles.
Uses available CPU cores to process multiple percentiles simultaneously, collecting all
valid results and selecting the highest percentile that satisfies the constraints.
Args:
sentences: List of sentences to group
distance: List of similarity distances
max_tokens: Maximum tokens per chunk
        start_percentile: Starting percentile for search (default: 99)
        verbosity: If True, log memory and worker-sizing details (default: True)
Returns:
tuple[dict[str, int], int, float]: Dictionary mapping chunks to token counts,
percentile used, and average tokens
"""
# Get available memory in GB and estimate memory needs based on document size
available_memory_gb = psutil.virtual_memory().available / (1024**3)
if verbosity:
logger.info(f"Available system memory: {available_memory_gb:.2f} GB")
# Estimate memory needed per worker (rough heuristic based on sentence count)
estimated_mem_per_worker_gb = max(
BASE_MEMORY_PER_WORKER_GB, len(sentences) * MEMORY_PER_SENTENCE_GB
)
if verbosity:
logger.info(
f"Estimated memory per worker: {estimated_mem_per_worker_gb:.2f} GB"
)
# Calculate max workers based on available memory (leave RESERVED_SYSTEM_MEMORY_GB for system)
# Ensure estimated_mem_per_worker_gb is not zero to prevent DivisionByZeroError
if estimated_mem_per_worker_gb <= 0:
logger.warning(
"Estimated memory per worker is zero or negative, defaulting memory_based_max_workers to 1."
)
memory_based_max_workers = 1
else:
memory_based_max_workers = max(
1,
int(
(available_memory_gb - RESERVED_SYSTEM_MEMORY_GB)
/ estimated_mem_per_worker_gb
),
)
if verbosity:
logger.info(f"Memory-based worker limit: {memory_based_max_workers}")
# Also consider document size for worker scaling
doc_size_factor = min(1.0, DOC_SIZE_FACTOR_SCALER / max(1, len(sentences)))
cpu_based_max_workers = max(
1, min(MAX_WORKERS, int(multiprocessing.cpu_count() * doc_size_factor))
)
if verbosity:
logger.info(f"CPU-based worker limit: {cpu_based_max_workers}")
# Take the minimum of memory-based and CPU-based worker counts
max_workers = min(memory_based_max_workers, cpu_based_max_workers)
# Hard cap based on sentence count for very large documents as extra safety
if len(sentences) > VERY_LARGE_DOC_SENTENCE_THRESHOLD:
max_workers = min(max_workers, WORKERS_VERY_LARGE_DOC)
if verbosity:
logger.info(
f"Very large document (>{VERY_LARGE_DOC_SENTENCE_THRESHOLD} sentences), capping at {WORKERS_VERY_LARGE_DOC} worker(s)"
)
elif len(sentences) > LARGE_DOC_SENTENCE_THRESHOLD:
max_workers = min(max_workers, WORKERS_LARGE_DOC)
if verbosity:
logger.info(
f"Large document (>{LARGE_DOC_SENTENCE_THRESHOLD} sentences), capping at {WORKERS_LARGE_DOC} worker(s)"
)
# For very small documents, skip parallel processing
if len(sentences) < MIN_SENTENCES_FOR_PARALLEL:
if verbosity:
logger.info(
f"Small document (<{MIN_SENTENCES_FOR_PARALLEL} sentences), using sequential processing"
)
return _find_optimal_chunks(sentences, distance, max_tokens)
if verbosity:
logger.info(f"Using {max_workers} workers for parallel processing")
try:
# Scale step size inversely with document size
if len(sentences) > STEP_SIZE_VERY_LARGE_DOC_THRESHOLD:
initial_step_size = STEP_SIZE_VERY_LARGE_DOC
elif len(sentences) > STEP_SIZE_LARGE_DOC_THRESHOLD:
initial_step_size = STEP_SIZE_LARGE_DOC
else:
initial_step_size = STEP_SIZE_DEFAULT
if verbosity:
logger.info(f"Using initial step size of {initial_step_size}")
for batch_start in range(start_percentile, 0, -initial_step_size * max_workers):
batch_end = max(0, batch_start - initial_step_size * max_workers)
current_percentiles = range(batch_start, batch_end, -initial_step_size)
process_args = [
(sentences, distance, max_tokens, p) for p in current_percentiles
]
valid_results = []
# Use a timeout for worker processes to prevent hung processes
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(_process_percentile_range, args)
for args in process_args
]
try:
for future in as_completed(futures, timeout=WORKER_TIMEOUT):
try:
chunks_with_tokens, percentile, average_tokens = future.result()
if chunks_with_tokens is not None:
valid_results.append(
(chunks_with_tokens, percentile, average_tokens)
)
except Exception as e:
logger.error(f"Error processing future: {str(e)}")
continue
except TimeoutError:
logger.error(
f"Percentile search timed out after {WORKER_TIMEOUT}s. "
f"Proceeding with results collected so far ({len(valid_results)} valid)."
)
if valid_results:
# We found at least one valid percentile, now refine with finer grain
valid_results.sort(key=lambda x: x[1], reverse=True)
best_valid_percentile = valid_results[0][1]
# Perform a refined search around the best valid percentile
refined_start = min(99, best_valid_percentile + initial_step_size)
refined_end = max(1, best_valid_percentile - initial_step_size)
# Ensure range is non-empty: stop must be strictly less than start
refined_end = min(refined_end, refined_start - 1)
refined_percentiles = range(refined_start, refined_end, -1)
refined_args = [
(sentences, distance, max_tokens, p) for p in refined_percentiles
]
refined_results = []
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(_process_percentile_range, args)
for args in refined_args
]
try:
for future in as_completed(futures, timeout=WORKER_TIMEOUT):
try:
chunks_with_tokens, percentile, average_tokens = (
future.result()
)
if chunks_with_tokens is not None:
refined_results.append(
(chunks_with_tokens, percentile, average_tokens)
)
except Exception as e:
logger.error(
f"Error processing future in refinement: {str(e)}"
)
continue
except TimeoutError:
logger.error(
f"Refined percentile search timed out after {WORKER_TIMEOUT}s. "
f"Proceeding with results collected so far ({len(refined_results)} valid)."
)
# Combine results and select the best one
all_results = valid_results + refined_results
all_results.sort(key=lambda x: x[1], reverse=True)
best_chunks_with_tokens, best_percentile, best_average_tokens = (
all_results[0]
)
if verbosity:
logger.info(
f"Selected the highest valid percentile: {best_percentile}"
)
return best_chunks_with_tokens, best_percentile, best_average_tokens
if verbosity:
logger.info(
"No valid chunking found in parallel processing, falling back to sequential"
)
return _find_optimal_chunks(sentences, distance, max_tokens)
except Exception as e:
logger.error(
f"Error in parallel processing: {str(e)}, falling back to sequential approach"
)
return _find_optimal_chunks(sentences, distance, max_tokens)
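# Numeric illustration of the worker sizing above (made-up machine: 24 GB free RAM, 8 CPUs,
# MAX_WORKERS=4, 6000-sentence document):
#   per-worker estimate = max(1.0, 6000 * 0.0004) = 2.4 GB
#   memory-based limit  = int((24 - 4) / 2.4) = 8
#   doc_size_factor     = min(1.0, 3000 / 6000) = 0.5 -> cpu-based limit = min(4, int(8 * 0.5)) = 4
#   workers used        = min(8, 4) = 4, with an initial step size of 5 (6000 > 5000 sentences)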
async def merge_undersized_chunks(
chunks: List[dict],
min_token_threshold: float,
max_tokens: int,
model: str = EMBEDDER_MODEL,
verbosity: bool = False,
max_passes: int = 3,
) -> List[dict]:
"""
Merge chunks that are below a minimum token threshold with semantically similar neighbors.
Uses multiple passes to maximize the number of small chunks that get merged.
Args:
chunks (List[dict]): List of chunks with 'text' and 'token_count' keys
min_token_threshold (float): Minimum token threshold (e.g., 5th percentile).
This threshold remains FIXED across all passes to ensure convergence.
max_tokens (int): Maximum allowed tokens for a chunk
model (str): Embedding model to use
verbosity (bool): If True, shows all log messages and progress bars
max_passes (int): Maximum number of merge passes (default: 3, range: 1-5)
Returns:
List[dict]: Updated list of chunks after merging small ones
"""
# Track initial state for final logging
initial_undersized = sum(
1 for chunk in chunks if chunk["token_count"] < min_token_threshold
)
if initial_undersized == 0:
logger.info("Merge skipped: no chunks below threshold")
return chunks
# If most chunks are undersized, adjust the threshold (only once, before passes)
adjusted_threshold = min_token_threshold
if initial_undersized > len(chunks) * 0.5:
logger.info("Too many undersized chunks, adjusting threshold to 80%")
adjusted_threshold = min_token_threshold * 0.8
initial_undersized = sum(
1 for chunk in chunks if chunk["token_count"] < adjusted_threshold
)
# Working copy of chunks that will be modified across passes
current_chunks = list(chunks)
total_merged = 0
# Multi-pass merge loop
for pass_num in range(1, max_passes + 1):
# Identify undersized chunks for this pass (using fixed threshold)
undersized_indices = [
i
for i, chunk in enumerate(current_chunks)
if chunk["token_count"] < adjusted_threshold
]
if not undersized_indices:
if verbosity:
logger.info(
f"Pass {pass_num}: No undersized chunks remaining, stopping early"
)
break
# Sort by token count (ascending) to process smallest first
undersized_indices.sort(key=lambda i: current_chunks[i]["token_count"])
# Calculate embeddings for all current chunks
chunk_texts = [chunk["text"] for chunk in current_chunks]
embeddings_dict = await get_embeddings(
doc=chunk_texts,
model=model,
batch_size=8,
verbosity=False, # Suppress per-pass embedding logs
convert_to_numpy=True,
normalize_embeddings=True,
)