From 90957832acdfa937939467d9715c166050f46327 Mon Sep 17 00:00:00 2001 From: lnxtree Date: Thu, 26 Mar 2026 12:17:14 +0800 Subject: [PATCH 1/4] feat(embedder): add use_bf16 support and unify inference dtype behavior add use_bf16 to auto embedder and all embedder constructors --- FlagEmbedding/abc/inference/AbsEmbedder.py | 9 +++++++++ FlagEmbedding/inference/auto_embedder.py | 2 ++ .../inference/embedder/decoder_only/base.py | 9 ++++++--- .../inference/embedder/decoder_only/icl.py | 8 +++++--- .../inference/embedder/encoder_only/base.py | 9 ++++++--- FlagEmbedding/inference/embedder/encoder_only/m3.py | 13 ++++++++----- 6 files changed, 36 insertions(+), 14 deletions(-) diff --git a/FlagEmbedding/abc/inference/AbsEmbedder.py b/FlagEmbedding/abc/inference/AbsEmbedder.py index fe0da04e..1b61c0a1 100644 --- a/FlagEmbedding/abc/inference/AbsEmbedder.py +++ b/FlagEmbedding/abc/inference/AbsEmbedder.py @@ -49,6 +49,7 @@ def __init__( model_name_or_path: str, normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval devices: Optional[Union[str, int, List[str], List[int]]] = None, @@ -62,6 +63,7 @@ def __init__( self.model_name_or_path = model_name_or_path self.normalize_embeddings = normalize_embeddings self.use_fp16 = use_fp16 + self.use_bf16 = use_bf16 self.query_instruction_for_retrieval = query_instruction_for_retrieval self.query_instruction_format = query_instruction_format self.target_devices = self.get_target_devices(devices) @@ -81,6 +83,13 @@ def __init__( self.model = None self.pool = None + def get_model_torch_dtype(self) -> torch.dtype: + if self.use_bf16: + return torch.bfloat16 + if self.use_fp16: + return torch.float16 + return torch.float32 + def stop_self_pool(self): if self.pool is not None: self.stop_multi_process_pool(self.pool) diff --git 
a/FlagEmbedding/inference/auto_embedder.py b/FlagEmbedding/inference/auto_embedder.py index 18f1dab8..fc646ea8 100644 --- a/FlagEmbedding/inference/auto_embedder.py +++ b/FlagEmbedding/inference/auto_embedder.py @@ -26,6 +26,7 @@ def from_finetuned( model_class: Optional[Union[str, EmbedderModelClass]] = None, normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, devices: Optional[Union[str, List[str]]] = None, pooling_method: Optional[str] = None, @@ -102,6 +103,7 @@ def from_finetuned( model_name_or_path, normalize_embeddings=normalize_embeddings, use_fp16=use_fp16, + use_bf16=use_bf16, query_instruction_for_retrieval=query_instruction_for_retrieval, query_instruction_format=query_instruction_format, devices=devices, diff --git a/FlagEmbedding/inference/embedder/decoder_only/base.py b/FlagEmbedding/inference/embedder/decoder_only/base.py index 3765cf0f..736e58a5 100644 --- a/FlagEmbedding/inference/embedder/decoder_only/base.py +++ b/FlagEmbedding/inference/embedder/decoder_only/base.py @@ -60,6 +60,7 @@ def __init__( model_name_or_path: str, normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, query_instruction_format: str = "Instruct: {}\nQuery: {}", # specify the format of query_instruction_for_retrieval devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"] @@ -77,6 +78,7 @@ def __init__( model_name_or_path, normalize_embeddings=normalize_embeddings, use_fp16=use_fp16, + use_bf16=use_bf16, query_instruction_for_retrieval=query_instruction_for_retrieval, query_instruction_format=query_instruction_format, devices=devices, @@ -95,7 +97,8 @@ def __init__( self.model = AutoModel.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code, - cache_dir=cache_dir + cache_dir=cache_dir, + torch_dtype=self.get_model_torch_dtype(), ) if
self.kwargs.get("pooling_method", "last_token") != "last_token": @@ -211,8 +214,8 @@ def encode_single_device( if device is None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if device == "cpu": + self.model.float() self.model.to(device) self.model.eval() diff --git a/FlagEmbedding/inference/embedder/decoder_only/icl.py b/FlagEmbedding/inference/embedder/decoder_only/icl.py index b8ea6466..5ca64b00 100644 --- a/FlagEmbedding/inference/embedder/decoder_only/icl.py +++ b/FlagEmbedding/inference/embedder/decoder_only/icl.py @@ -68,6 +68,7 @@ def __init__( model_name_or_path: str, normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, query_instruction_format: str = "{}\n{}", # specify the format of query_instruction_for_retrieval suffix: str = '\n', @@ -90,6 +91,7 @@ def __init__( model_name_or_path, normalize_embeddings=normalize_embeddings, use_fp16=use_fp16, + use_bf16=use_bf16, query_instruction_for_retrieval=query_instruction_for_retrieval, query_instruction_format=query_instruction_format, devices=devices, @@ -108,7 +110,8 @@ def __init__( self.model = AutoModel.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code, - cache_dir=cache_dir + cache_dir=cache_dir, + torch_dtype=self.get_model_torch_dtype(), ) self.examples_for_task = examples_for_task self.examples_instruction_format = examples_instruction_format @@ -340,8 +343,7 @@ def encode_queries_single_device( if device is None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if device == "cpu": self.model.float() self.model.to(device) self.model.eval() diff --git a/FlagEmbedding/inference/embedder/encoder_only/base.py b/FlagEmbedding/inference/embedder/encoder_only/base.py index 71a23057..e3c2abce 100644 --- a/FlagEmbedding/inference/embedder/encoder_only/base.py +++ 
b/FlagEmbedding/inference/embedder/encoder_only/base.py @@ -42,6 +42,7 @@ def __init__( model_name_or_path: str, normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"] @@ -60,6 +61,7 @@ def __init__( model_name_or_path, normalize_embeddings=normalize_embeddings, use_fp16=use_fp16, + use_bf16=use_bf16, query_instruction_for_retrieval=query_instruction_for_retrieval, query_instruction_format=query_instruction_format, devices=devices, @@ -79,7 +81,8 @@ def __init__( self.model = AutoModel.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code, - cache_dir=cache_dir + cache_dir=cache_dir, + torch_dtype=self.get_model_torch_dtype(), ) def encode_queries( @@ -192,8 +195,8 @@ def encode_single_device( if device is None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if device == "cpu": + self.model.float() self.model.to(device) self.model.eval() diff --git a/FlagEmbedding/inference/embedder/encoder_only/m3.py b/FlagEmbedding/inference/embedder/encoder_only/m3.py index 2082003a..1f2c24c9 100644 --- a/FlagEmbedding/inference/embedder/encoder_only/m3.py +++ b/FlagEmbedding/inference/embedder/encoder_only/m3.py @@ -52,6 +52,7 @@ def __init__( model_name_or_path: str, normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"] @@ -73,6 +74,7 @@ def __init__( model_name_or_path, normalize_embeddings=normalize_embeddings,
use_fp16=use_fp16, + use_bf16=use_bf16, query_instruction_for_retrieval=query_instruction_for_retrieval, query_instruction_format=query_instruction_format, devices=devices, @@ -96,7 +98,8 @@ def __init__( model_name_or_path, trust_remote_code=trust_remote_code, colbert_dim=colbert_dim, - cache_dir=cache_dir + cache_dir=cache_dir, + torch_dtype=self.get_model_torch_dtype(), ), tokenizer=self.tokenizer, sentence_pooling_method=pooling_method, @@ -334,8 +337,8 @@ def encode_single_device( if device is None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if device == "cpu": + self.model.float() self.model.to(device) self.model.eval() @@ -630,8 +633,8 @@ def _tokenize(texts: list, max_length: int): if device is None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if device == "cpu": + self.model.float() self.model.to(device) self.model.eval() From 5d6fa9a69ef086ef277200aa677f23f75d5bb9f8 Mon Sep 17 00:00:00 2001 From: lnxtree Date: Thu, 26 Mar 2026 13:37:05 +0800 Subject: [PATCH 2/4] fix: add bf16 interface to EncoderOnlyEmbedderM3Runner.get_model --- .../finetune/embedder/encoder_only/m3/runner.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py b/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py index bd3de30d..aa7136af 100644 --- a/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py +++ b/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py @@ -1,7 +1,7 @@ import os import torch import logging -from typing import Tuple +from typing import Tuple, Optional from transformers import ( AutoModel, AutoConfig, AutoTokenizer, PreTrainedTokenizer @@ -44,7 +44,8 @@ def get_model( model_name_or_path: str, trust_remote_code: bool = False, colbert_dim: int = -1, - cache_dir: str = None + cache_dir: str = None, + torch_dtype: 
Optional[torch.dtype] = None, ): """Get the model. @@ -54,6 +55,7 @@ def get_model( trust_remote_code (bool, optional): trust_remote_code to use when loading models from HF. Defaults to ``False``. colbert_dim (int, optional): Colbert dim to set. Defaults to ``-1``. cache_dir (str, optional): HF cache dir to store the model. Defaults to ``None``. + torch_dtype (Optional[torch.dtype], optional): Torch dtype used when loading model weights. Defaults to ``None``. Returns: dict: A dictionary containing the model, colbert linear and sparse linear. @@ -69,7 +71,8 @@ def get_model( model = AutoModel.from_pretrained( model_name_or_path, cache_dir=cache_folder, - trust_remote_code=trust_remote_code + trust_remote_code=trust_remote_code, + torch_dtype=torch_dtype, ) colbert_linear = torch.nn.Linear( in_features=model.config.hidden_size, From 6679caaf654d02b7d12c921bd80f6699a7a1f6a5 Mon Sep 17 00:00:00 2001 From: lnxtree Date: Thu, 26 Mar 2026 14:00:41 +0800 Subject: [PATCH 3/4] fix: code format for inference.embedder.decoder_only.icl --- FlagEmbedding/inference/embedder/decoder_only/icl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/FlagEmbedding/inference/embedder/decoder_only/icl.py b/FlagEmbedding/inference/embedder/decoder_only/icl.py index 5ca64b00..c7495b6e 100644 --- a/FlagEmbedding/inference/embedder/decoder_only/icl.py +++ b/FlagEmbedding/inference/embedder/decoder_only/icl.py @@ -343,7 +343,8 @@ def encode_queries_single_device( if device is None: device = self.target_devices[0] - if device == "cpu": self.model.float() + if device == "cpu": + self.model.float() self.model.to(device) self.model.eval() From 19edba7b08d9f56827d2611185823591b0b743ff Mon Sep 17 00:00:00 2001 From: lnxtree Date: Thu, 26 Mar 2026 14:17:14 +0800 Subject: [PATCH 4/4] feat: fix the interface of attn_implementation in embedder.decoder_only.*.load_model and reranker.decoder_only.*.load_model --- .../finetune/embedder/decoder_only/base/load_model.py | 4 ++--
.../finetune/embedder/decoder_only/icl/load_model.py | 4 ++-- .../finetune/reranker/decoder_only/base/load_model.py | 4 ++-- .../finetune/reranker/decoder_only/layerwise/load_model.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py b/FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py index 517ea0a0..866aafd2 100644 --- a/FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py +++ b/FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py @@ -71,7 +71,7 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re model = AutoModel.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), @@ -152,7 +152,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: model = AutoModel.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), diff --git a/FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py b/FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py index f2fce2b4..5a5c790a 100644 --- a/FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py +++ b/FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py @@ -71,7 +71,7 @@ def get_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str, model = AutoModel.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - 
use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), @@ -150,7 +150,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_d model = AutoModel.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), diff --git a/FlagEmbedding/finetune/reranker/decoder_only/base/load_model.py b/FlagEmbedding/finetune/reranker/decoder_only/base/load_model.py index 82173bdf..6caa047c 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/base/load_model.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/base/load_model.py @@ -67,7 +67,7 @@ def get_model(model_args: RerankerModelArguments): model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), @@ -135,7 +135,7 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str): model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), diff --git 
a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py index 7ae0586a..eff0d3b3 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py @@ -77,7 +77,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit): model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, # torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), @@ -131,7 +131,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit): model = LayerWiseMiniCPMForCausalLM.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path),