Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions FlagEmbedding/abc/inference/AbsEmbedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,20 @@ def _concatenate_results_from_multi_process(self, results_list: List[Union[torch
return np.concatenate(results_list, axis=0)
else:
raise NotImplementedError("Unsupported type for results_list")

def _convert_to_numpy(self, embeddings: torch.Tensor, device: Optional[str] = None) -> np.ndarray:
"""Convert tensor embeddings to numpy with bf16-safe handling.

NumPy does not support bfloat16, so we upcast to float32 only when
bf16 inference is enabled on non-CPU devices.

Args:
embeddings (torch.Tensor): Embedding tensor.
device (Optional[str], optional): Inference device string. Defaults to ``None``.

Returns:
np.ndarray: Embeddings in numpy format.
"""
if device != "cpu" and self.use_bf16 and embeddings.dtype == torch.bfloat16:
embeddings = embeddings.float()
return embeddings.cpu().numpy()
6 changes: 4 additions & 2 deletions FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,13 @@ def get_model(
)
colbert_linear = torch.nn.Linear(
in_features=model.config.hidden_size,
out_features=model.config.hidden_size if colbert_dim <= 0 else colbert_dim
out_features=model.config.hidden_size if colbert_dim <= 0 else colbert_dim,
dtype=torch_dtype,
)
sparse_linear = torch.nn.Linear(
in_features=model.config.hidden_size,
out_features=1
out_features=1,
dtype=torch_dtype,
)

colbert_model_path = os.path.join(model_name_or_path, 'colbert_linear.pt')
Expand Down
2 changes: 1 addition & 1 deletion FlagEmbedding/inference/embedder/decoder_only/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def encode_single_device(
embeddings = cast(torch.Tensor, embeddings)

if convert_to_numpy:
embeddings = embeddings.cpu().numpy()
embeddings = self._convert_to_numpy(embeddings, device=device)
all_embeddings.append(embeddings)

if convert_to_numpy:
Expand Down
8 changes: 4 additions & 4 deletions FlagEmbedding/inference/embedder/decoder_only/icl.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def encode_queries_single_device(
embeddings = cast(torch.Tensor, embeddings)

if convert_to_numpy:
embeddings = embeddings.cpu().numpy()
embeddings = self._convert_to_numpy(embeddings, device=device)
all_embeddings.append(embeddings)

if convert_to_numpy:
Expand Down Expand Up @@ -479,8 +479,8 @@ def encode_single_device(
if device is None:
device = self.target_devices[0]

if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
if device == "cpu":
self.model.float()

self.model.to(device)
self.model.eval()
Expand Down Expand Up @@ -546,7 +546,7 @@ def encode_single_device(
embeddings = cast(torch.Tensor, embeddings)

if convert_to_numpy:
embeddings = embeddings.cpu().numpy()
embeddings = self._convert_to_numpy(embeddings, device=device)
all_embeddings.append(embeddings)

if convert_to_numpy:
Expand Down
2 changes: 1 addition & 1 deletion FlagEmbedding/inference/embedder/encoder_only/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def encode_single_device(
embeddings = cast(torch.Tensor, embeddings)

if convert_to_numpy:
embeddings = embeddings.cpu().numpy()
embeddings = self._convert_to_numpy(embeddings, device=device)
all_embeddings.append(embeddings)

if convert_to_numpy:
Expand Down
29 changes: 19 additions & 10 deletions FlagEmbedding/inference/embedder/encoder_only/m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,23 +431,23 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
)

if return_dense:
all_dense_embeddings.append(outputs['dense_vecs'].cpu().numpy())
all_dense_embeddings.append(self._convert_to_numpy(outputs['dense_vecs'], device=device))

if return_sparse:
token_weights = outputs['sparse_vecs'].squeeze(-1)
all_lexical_weights.extend(
list(map(
_process_token_weights,
token_weights.cpu().numpy(),
inputs_batch['input_ids'].cpu().numpy().tolist()
self._convert_to_numpy(token_weights, device=device),
self._convert_to_numpy(inputs_batch['input_ids'], device=device).tolist()
)))

if return_colbert_vecs:
all_colbert_vecs.extend(
list(map(
_process_colbert_vecs,
outputs['colbert_vecs'].cpu().numpy(),
inputs_batch['attention_mask'].cpu().numpy()
self._convert_to_numpy(outputs['colbert_vecs'], device=device),
self._convert_to_numpy(inputs_batch['attention_mask'], device=device)
)))

if return_dense:
Expand Down Expand Up @@ -700,19 +700,28 @@ def _tokenize(texts: list, max_length: int):
inx, inx].float(), colbert_scores[inx, inx].float()

all_scores['colbert'].extend(
colbert_scores.cpu().numpy().tolist()
self._convert_to_numpy(colbert_scores, device=device).tolist()
)
all_scores['sparse'].extend(
sparse_scores.cpu().numpy().tolist()
self._convert_to_numpy(sparse_scores, device=device).tolist()
)
all_scores['dense'].extend(
dense_scores.cpu().numpy().tolist()
self._convert_to_numpy(dense_scores, device=device).tolist()
)
all_scores['sparse+dense'].extend(
((sparse_scores * weights_for_different_modes[1] + dense_scores * weights_for_different_modes[0])/(weights_for_different_modes[1]+weights_for_different_modes[0])).cpu().numpy().tolist()
self._convert_to_numpy(
(sparse_scores * weights_for_different_modes[1] + dense_scores * weights_for_different_modes[0])
/ (weights_for_different_modes[1] + weights_for_different_modes[0]),
device=device,
).tolist()
)
all_scores['colbert+sparse+dense'].extend(
((colbert_scores * weights_for_different_modes[2] + sparse_scores * weights_for_different_modes[1] + dense_scores * weights_for_different_modes[0])/weight_sum).cpu().numpy().tolist()
self._convert_to_numpy(
(colbert_scores * weights_for_different_modes[2]
+ sparse_scores * weights_for_different_modes[1]
+ dense_scores * weights_for_different_modes[0]) / weight_sum,
device=device,
).tolist()
)

if one_input_pair:
Expand Down
Loading