embeddings.py
"""
Multilingual embedding model management.
"""
from sentence_transformers import SentenceTransformer
from typing import List, Union
import numpy as np
import logging
import config
import torch
logger = logging.getLogger(__name__)
# Global model cache
_embedding_model = None
def load_embedding_model(model_name: str = None) -> SentenceTransformer:
"""
Load the multilingual embedding model with caching.
Args:
model_name: Name of the model to load (default from config)
Returns:
Loaded SentenceTransformer model
"""
global _embedding_model
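    # Note: once a model has been cached here, later calls return it as-is,
    # so a different model_name passed afterwards is ignored.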
    if _embedding_model is not None:
        return _embedding_model

    if model_name is None:
        model_name = config.EMBEDDING_MODEL_NAME

    logger.info(f"Loading embedding model: {model_name}")
    logger.info("This may take a few minutes on first run...")

    # Set cache directory
    cache_dir = str(config.MODELS_CACHE_DIR)

    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    _embedding_model = SentenceTransformer(
        model_name,
        cache_folder=cache_dir,
        device=device
    )

    logger.info(f"Model loaded on device: {device}")
    logger.info(f"Embedding dimension: {_embedding_model.get_sentence_embedding_dimension()}")

    return _embedding_model


def embed_texts(
    texts: List[str],
    batch_size: int = 32,
    show_progress: bool = True,
    is_query: bool = False
) -> np.ndarray:
"""
Embed a list of texts using the multilingual model.
Args:
texts: List of texts to embed
batch_size: Batch size for encoding
show_progress: Whether to show progress bar
is_query: If True, add query prefix for E5 models
Returns:
Numpy array of embeddings, shape (len(texts), embedding_dim)
"""
model = load_embedding_model()
# Add E5 prefix if using E5 model (check for 'e5' in model name)
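    # E5-family models are trained with asymmetric "query: " / "passage: "
    # prefixes; config.E5_QUERY_PREFIX and config.E5_PASSAGE_PREFIX are
    # expected to hold those strings.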
if "e5" in config.EMBEDDING_MODEL_NAME.lower():
prefix = config.E5_QUERY_PREFIX if is_query else config.E5_PASSAGE_PREFIX
texts = [prefix + text for text in texts]
# Encode
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=show_progress,
convert_to_numpy=True,
normalize_embeddings=True # Normalize for cosine similarity
)
return embeddings
def embed_query(query: str) -> np.ndarray:
"""
Embed a single query text.
Args:
query: Query text to embed
Returns:
Numpy array of shape (embedding_dim,)
"""
embeddings = embed_texts([query], batch_size=1, show_progress=False, is_query=True)
return embeddings[0]
def embed_passages(passages: List[str], batch_size: int = 32) -> np.ndarray:
"""
Embed a list of passage texts (documents/chunks).
Args:
passages: List of passage texts to embed
batch_size: Batch size for encoding
Returns:
Numpy array of embeddings, shape (len(passages), embedding_dim)
"""
return embed_texts(passages, batch_size=batch_size, show_progress=True, is_query=False)
def compute_similarity(query_embedding: np.ndarray, passage_embeddings: np.ndarray) -> np.ndarray:
"""
Compute cosine similarity between a query and multiple passages.
Args:
query_embedding: Query embedding, shape (embedding_dim,)
passage_embeddings: Passage embeddings, shape (n_passages, embedding_dim)
Returns:
Similarity scores, shape (n_passages,)
"""
# Cosine similarity (embeddings are already normalized)
similarities = np.dot(passage_embeddings, query_embedding)
return similarities
if __name__ == "__main__":
    # Test embedding functionality
    print("Testing Multilingual Embedding Model")
    print("=" * 60)

    # Test texts in different languages
    test_texts = [
        "What is the treatment for diabetes?",  # English
        "मधुमेह का इलाज क्या है?",  # Hindi
        "நீரிழிவு நோய்க்கான சிகிச்சை என்ன?",  # Tamil
        "డయాబెటిస్ చికిత్స ఏమిటి?",  # Telugu
    ]

    print("\n1. Loading model...")
    model = load_embedding_model()

    print("\n2. Embedding test texts...")
    embeddings = embed_passages(test_texts)
    print(f"\nEmbeddings shape: {embeddings.shape}")
    print(f"Expected shape: ({len(test_texts)}, {config.EMBEDDING_DIMENSION})")

    print("\n3. Testing cross-lingual similarity...")
    query = "diabetes treatment"
    query_emb = embed_query(query)
    similarities = compute_similarity(query_emb, embeddings)

    print(f"\nQuery: '{query}'")
    print("\nSimilarities:")
    for text, sim in zip(test_texts, similarities):
        print(f"  {sim:.4f} - {text}")

    print("\n4. Testing multilingual queries...")
    queries = [
        "diabetes treatment",  # English
        "मधुमेह उपचार",  # Hindi
    ]

    for q in queries:
        q_emb = embed_query(q)
        sims = compute_similarity(q_emb, embeddings)
        print(f"\nQuery: '{q}'")
        print(f"Top match: {test_texts[np.argmax(sims)]} (similarity: {np.max(sims):.4f})")