From d65ffb17b77d17cc4d502999903a138b254ebfda Mon Sep 17 00:00:00 2001 From: lnxtree Date: Thu, 26 Mar 2026 15:19:05 +0800 Subject: [PATCH 1/2] fix(embedder): add _convert_to_numpy in base class and guard bf16->numpy on non-cpu --- FlagEmbedding/abc/inference/AbsEmbedder.py | 17 +++++++++++++++++ .../inference/embedder/decoder_only/base.py | 2 +- .../inference/embedder/decoder_only/icl.py | 4 ++-- .../inference/embedder/encoder_only/base.py | 2 +- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/FlagEmbedding/abc/inference/AbsEmbedder.py b/FlagEmbedding/abc/inference/AbsEmbedder.py index 1b61c0a1..1d9438f0 100644 --- a/FlagEmbedding/abc/inference/AbsEmbedder.py +++ b/FlagEmbedding/abc/inference/AbsEmbedder.py @@ -450,3 +450,20 @@ def _concatenate_results_from_multi_process(self, results_list: List[Union[torch return np.concatenate(results_list, axis=0) else: raise NotImplementedError("Unsupported type for results_list") + + def _convert_to_numpy(self, embeddings: torch.Tensor, device: Optional[str] = None) -> np.ndarray: + """Convert tensor embeddings to numpy with bf16-safe handling. + + NumPy does not support bfloat16, so we upcast to float32 only when + bf16 inference is enabled on non-CPU devices. + + Args: + embeddings (torch.Tensor): Embedding tensor. + device (Optional[str], optional): Inference device string. Defaults to ``None``. + + Returns: + np.ndarray: Embeddings in numpy format. + """ + if device != "cpu" and self.use_bf16 and embeddings.dtype == torch.bfloat16: + embeddings = embeddings.float() + return embeddings.cpu().numpy() \ No newline at end of file diff --git a/FlagEmbedding/inference/embedder/decoder_only/base.py b/FlagEmbedding/inference/embedder/decoder_only/base.py index 736e58a5..1fda6243 100644 --- a/FlagEmbedding/inference/embedder/decoder_only/base.py +++ b/FlagEmbedding/inference/embedder/decoder_only/base.py @@ -281,7 +281,7 @@ def encode_single_device( embeddings = cast(torch.Tensor, embeddings) if convert_to_numpy: - embeddings = embeddings.cpu().numpy() + embeddings = self._convert_to_numpy(embeddings, device=device) all_embeddings.append(embeddings) if convert_to_numpy: diff --git a/FlagEmbedding/inference/embedder/decoder_only/icl.py b/FlagEmbedding/inference/embedder/decoder_only/icl.py index c7495b6e..dc606f6c 100644 --- a/FlagEmbedding/inference/embedder/decoder_only/icl.py +++ b/FlagEmbedding/inference/embedder/decoder_only/icl.py @@ -437,7 +437,7 @@ def encode_queries_single_device( embeddings = cast(torch.Tensor, embeddings) if convert_to_numpy: - embeddings = embeddings.cpu().numpy() + embeddings = self._convert_to_numpy(embeddings, device=device) all_embeddings.append(embeddings) if convert_to_numpy: @@ -546,7 +546,7 @@ def encode_single_device( embeddings = cast(torch.Tensor, embeddings) if convert_to_numpy: - embeddings = embeddings.cpu().numpy() + embeddings = self._convert_to_numpy(embeddings, device=device) all_embeddings.append(embeddings) if convert_to_numpy: diff --git a/FlagEmbedding/inference/embedder/encoder_only/base.py b/FlagEmbedding/inference/embedder/encoder_only/base.py index e3c2abce..888c7747 100644 --- a/FlagEmbedding/inference/embedder/encoder_only/base.py +++ b/FlagEmbedding/inference/embedder/encoder_only/base.py @@ -262,7 +262,7 @@ def encode_single_device( embeddings = cast(torch.Tensor, embeddings) if convert_to_numpy: - embeddings = embeddings.cpu().numpy() + embeddings = self._convert_to_numpy(embeddings, device=device) all_embeddings.append(embeddings) if convert_to_numpy: From 2955a5a6e5122a3c45f3e2e81d8d7c2517bff1d5 Mon Sep 17 00:00:00 2001 From: lnxtree Date: Thu, 26 Mar 2026 19:14:12 +0800 Subject: [PATCH 2/2] fix(m3): make bf16 inference/train loading safe and unify numpy conversion - replace m3 embedder .cpu().numpy() paths with base _convert_to_numpy(...) for bf16-safe conversion - add torch_dtype plumbing in m3 runner model loading (AutoModel/colbert_linear/sparse_linear) to keep dtype behavior consistent --- .../embedder/encoder_only/m3/runner.py | 6 ++-- .../inference/embedder/decoder_only/icl.py | 4 +-- .../inference/embedder/encoder_only/m3.py | 29 ++++++++++++------- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py b/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py index aa7136af..3157a3d5 100644 --- a/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py +++ b/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py @@ -76,11 +76,13 @@ def get_model( ) colbert_linear = torch.nn.Linear( in_features=model.config.hidden_size, - out_features=model.config.hidden_size if colbert_dim <= 0 else colbert_dim + out_features=model.config.hidden_size if colbert_dim <= 0 else colbert_dim, + dtype=torch_dtype, ) sparse_linear = torch.nn.Linear( in_features=model.config.hidden_size, - out_features=1 + out_features=1, + dtype=torch_dtype, ) colbert_model_path = os.path.join(model_name_or_path, 'colbert_linear.pt') diff --git a/FlagEmbedding/inference/embedder/decoder_only/icl.py b/FlagEmbedding/inference/embedder/decoder_only/icl.py index dc606f6c..affba718 100644 --- a/FlagEmbedding/inference/embedder/decoder_only/icl.py +++ b/FlagEmbedding/inference/embedder/decoder_only/icl.py @@ -479,8 +479,8 @@ def encode_single_device( if device is None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if device == "cpu": + self.model.float() self.model.to(device) self.model.eval() diff --git a/FlagEmbedding/inference/embedder/encoder_only/m3.py b/FlagEmbedding/inference/embedder/encoder_only/m3.py index 1f2c24c9..c5678a40 100644 --- a/FlagEmbedding/inference/embedder/encoder_only/m3.py +++ b/FlagEmbedding/inference/embedder/encoder_only/m3.py @@ -431,23 +431,23 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list): ) if return_dense: - all_dense_embeddings.append(outputs['dense_vecs'].cpu().numpy()) + all_dense_embeddings.append(self._convert_to_numpy(outputs['dense_vecs'], device=device)) if return_sparse: token_weights = outputs['sparse_vecs'].squeeze(-1) all_lexical_weights.extend( list(map( _process_token_weights, - token_weights.cpu().numpy(), - inputs_batch['input_ids'].cpu().numpy().tolist() + self._convert_to_numpy(token_weights, device=device), + self._convert_to_numpy(inputs_batch['input_ids'], device=device).tolist() ))) if return_colbert_vecs: all_colbert_vecs.extend( list(map( _process_colbert_vecs, - outputs['colbert_vecs'].cpu().numpy(), - inputs_batch['attention_mask'].cpu().numpy() + self._convert_to_numpy(outputs['colbert_vecs'], device=device), + self._convert_to_numpy(inputs_batch['attention_mask'], device=device) ))) if return_dense: @@ -700,19 +700,28 @@ def _tokenize(texts: list, max_length: int): inx, inx].float(), colbert_scores[inx, inx].float() all_scores['colbert'].extend( - colbert_scores.cpu().numpy().tolist() + self._convert_to_numpy(colbert_scores, device=device).tolist() ) all_scores['sparse'].extend( - sparse_scores.cpu().numpy().tolist() + self._convert_to_numpy(sparse_scores, device=device).tolist() ) all_scores['dense'].extend( - dense_scores.cpu().numpy().tolist() + self._convert_to_numpy(dense_scores, device=device).tolist() ) all_scores['sparse+dense'].extend( - ((sparse_scores * weights_for_different_modes[1] + dense_scores * weights_for_different_modes[0])/(weights_for_different_modes[1]+weights_for_different_modes[0])).cpu().numpy().tolist() + self._convert_to_numpy( + (sparse_scores * weights_for_different_modes[1] + dense_scores * weights_for_different_modes[0]) + / (weights_for_different_modes[1] + weights_for_different_modes[0]), + device=device, + ).tolist() ) all_scores['colbert+sparse+dense'].extend( - ((colbert_scores * weights_for_different_modes[2] + sparse_scores * weights_for_different_modes[1] + dense_scores * weights_for_different_modes[0])/weight_sum).cpu().numpy().tolist() + self._convert_to_numpy( + (colbert_scores * weights_for_different_modes[2] + + sparse_scores * weights_for_different_modes[1] + + dense_scores * weights_for_different_modes[0]) / weight_sum, + device=device, + ).tolist() ) if one_input_pair: