diff --git a/FlagEmbedding/abc/inference/AbsEmbedder.py b/FlagEmbedding/abc/inference/AbsEmbedder.py index fe0da04e..1b61c0a1 100644 --- a/FlagEmbedding/abc/inference/AbsEmbedder.py +++ b/FlagEmbedding/abc/inference/AbsEmbedder.py @@ -49,6 +49,7 @@ def __init__( model_name_or_path: str, normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval devices: Optional[Union[str, int, List[str], List[int]]] = None, @@ -62,6 +63,7 @@ def __init__( self.model_name_or_path = model_name_or_path self.normalize_embeddings = normalize_embeddings self.use_fp16 = use_fp16 + self.use_bf16 = use_bf16 self.query_instruction_for_retrieval = query_instruction_for_retrieval self.query_instruction_format = query_instruction_format self.target_devices = self.get_target_devices(devices) @@ -81,6 +83,13 @@ def __init__( self.model = None self.pool = None + def get_model_torch_dtype(self) -> torch.dtype: + if self.use_bf16: + return torch.bfloat16 + if self.use_fp16: + return torch.float16 + return torch.float32 + def stop_self_pool(self): if self.pool is not None: self.stop_multi_process_pool(self.pool) diff --git a/FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py b/FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py index 517ea0a0..866aafd2 100644 --- a/FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py +++ b/FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py @@ -71,7 +71,7 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re model = AutoModel.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation="flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, 
from_tf=bool(".ckpt" in model_args.model_name_or_path), @@ -152,7 +152,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: model = AutoModel.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation="flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), diff --git a/FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py b/FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py index f2fce2b4..5a5c790a 100644 --- a/FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py +++ b/FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py @@ -71,7 +71,7 @@ def get_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str, model = AutoModel.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation="flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), @@ -150,7 +150,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_d model = AutoModel.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation="flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), diff --git a/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py b/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py index bd3de30d..aa7136af 100644 --- a/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py +++ 
b/FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py @@ -1,7 +1,7 @@ import os import torch import logging -from typing import Tuple +from typing import Tuple, Optional from transformers import ( AutoModel, AutoConfig, AutoTokenizer, PreTrainedTokenizer @@ -44,7 +44,8 @@ def get_model( model_name_or_path: str, trust_remote_code: bool = False, colbert_dim: int = -1, - cache_dir: str = None + cache_dir: str = None, + torch_dtype: Optional[torch.dtype] = None, ): """Get the model. @@ -54,6 +55,7 @@ def get_model( trust_remote_code (bool, optional): trust_remote_code to use when loading models from HF. Defaults to ``False``. colbert_dim (int, optional): Colbert dim to set. Defaults to ``-1``. cache_dir (str, optional): HF cache dir to store the model. Defaults to ``None``. + torch_dtype (Optional[torch.dtype], optional): Torch dtype used when loading model weights. Defaults to ``None``. Returns: dict: A dictionary containing the model, colbert linear and sparse linear. @@ -69,7 +71,8 @@ def get_model( model = AutoModel.from_pretrained( model_name_or_path, cache_dir=cache_folder, - trust_remote_code=trust_remote_code + trust_remote_code=trust_remote_code, + torch_dtype=torch_dtype, ) colbert_linear = torch.nn.Linear( in_features=model.config.hidden_size, diff --git a/FlagEmbedding/finetune/reranker/decoder_only/base/load_model.py b/FlagEmbedding/finetune/reranker/decoder_only/base/load_model.py index 82173bdf..6caa047c 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/base/load_model.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/base/load_model.py @@ -67,7 +67,7 @@ def get_model(model_args: RerankerModelArguments): model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation="flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in 
model_args.model_name_or_path), @@ -135,7 +135,7 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str): model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation="flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), diff --git a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py index 7ae0586a..eff0d3b3 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py @@ -77,7 +77,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit): model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, # torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation="flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), @@ -131,7 +131,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit): model = LayerWiseMiniCPMForCausalLM.from_pretrained( model_args.model_name_or_path, # torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16, - use_flash_attention_2=True if model_args.use_flash_attn else False, + attn_implementation="flash_attention_2" if model_args.use_flash_attn else None, token=model_args.token, cache_dir=model_args.cache_dir, from_tf=bool(".ckpt" in model_args.model_name_or_path), diff --git a/FlagEmbedding/inference/auto_embedder.py b/FlagEmbedding/inference/auto_embedder.py index 18f1dab8..fc646ea8 100644 --- 
a/FlagEmbedding/inference/auto_embedder.py +++ b/FlagEmbedding/inference/auto_embedder.py @@ -26,6 +26,7 @@ def from_finetuned( model_class: Optional[Union[str, EmbedderModelClass]] = None, normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, devices: Optional[Union[str, List[str]]] = None, pooling_method: Optional[str] = None, @@ -102,6 +103,7 @@ def from_finetuned( model_name_or_path, normalize_embeddings=normalize_embeddings, use_fp16=use_fp16, + use_bf16=use_bf16, query_instruction_for_retrieval=query_instruction_for_retrieval, query_instruction_format=query_instruction_format, devices=devices, diff --git a/FlagEmbedding/inference/embedder/decoder_only/base.py b/FlagEmbedding/inference/embedder/decoder_only/base.py index 3765cf0f..736e58a5 100644 --- a/FlagEmbedding/inference/embedder/decoder_only/base.py +++ b/FlagEmbedding/inference/embedder/decoder_only/base.py @@ -60,6 +60,7 @@ def __init__( model_name_or_path: str, normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, query_instruction_format: str = "Instruct: {}\nQuery: {}", # specify the format of query_instruction_for_retrieval devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"] @@ -77,6 +78,7 @@ def __init__( model_name_or_path, normalize_embeddings=normalize_embeddings, use_fp16=use_fp16, + use_bf16=use_bf16, query_instruction_for_retrieval=query_instruction_for_retrieval, query_instruction_format=query_instruction_format, devices=devices, @@ -95,7 +97,8 @@ def __init__( self.model = AutoModel.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code, - cache_dir=cache_dir + cache_dir=cache_dir, + torch_dtype=self.get_model_torch_dtype(), ) if self.kwargs.get("pooling_method", "last_token") != "last_token": @@ -211,8 +214,8 @@ def encode_single_device( if device is 
None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if device == "cpu": + self.model.float() self.model.to(device) self.model.eval() diff --git a/FlagEmbedding/inference/embedder/decoder_only/icl.py b/FlagEmbedding/inference/embedder/decoder_only/icl.py index b8ea6466..c7495b6e 100644 --- a/FlagEmbedding/inference/embedder/decoder_only/icl.py +++ b/FlagEmbedding/inference/embedder/decoder_only/icl.py @@ -68,6 +68,7 @@ def __init__( model_name_or_path: str, normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, query_instruction_format: str = "{}\n{}", # specify the format of query_instruction_for_retrieval suffix: str = '\n', @@ -90,6 +91,7 @@ def __init__( model_name_or_path, normalize_embeddings=normalize_embeddings, use_fp16=use_fp16, + use_bf16=use_bf16, query_instruction_for_retrieval=query_instruction_for_retrieval, query_instruction_format=query_instruction_format, devices=devices, @@ -108,7 +110,8 @@ def __init__( self.model = AutoModel.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code, - cache_dir=cache_dir + cache_dir=cache_dir, + torch_dtype=self.get_model_torch_dtype(), ) self.examples_for_task = examples_for_task self.examples_instruction_format = examples_instruction_format @@ -340,8 +343,8 @@ def encode_queries_single_device( if device is None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if device == "cpu": + self.model.float() self.model.to(device) self.model.eval() diff --git a/FlagEmbedding/inference/embedder/encoder_only/base.py b/FlagEmbedding/inference/embedder/encoder_only/base.py index 71a23057..e3c2abce 100644 --- a/FlagEmbedding/inference/embedder/encoder_only/base.py +++ b/FlagEmbedding/inference/embedder/encoder_only/base.py @@ -42,6 +42,7 @@ def __init__( model_name_or_path: str, 
normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"] @@ -60,6 +61,7 @@ def __init__( model_name_or_path, normalize_embeddings=normalize_embeddings, use_fp16=use_fp16, + use_bf16=use_bf16, query_instruction_for_retrieval=query_instruction_for_retrieval, query_instruction_format=query_instruction_format, devices=devices, @@ -79,7 +81,8 @@ def __init__( self.model = AutoModel.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code, - cache_dir=cache_dir + cache_dir=cache_dir, + torch_dtype=self.get_model_torch_dtype(), ) def encode_queries( @@ -192,8 +195,8 @@ def encode_single_device( if device is None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if device == "cpu": + self.model.float() self.model.to(device) self.model.eval() diff --git a/FlagEmbedding/inference/embedder/encoder_only/m3.py b/FlagEmbedding/inference/embedder/encoder_only/m3.py index 2082003a..1f2c24c9 100644 --- a/FlagEmbedding/inference/embedder/encoder_only/m3.py +++ b/FlagEmbedding/inference/embedder/encoder_only/m3.py @@ -52,6 +52,7 @@ def __init__( model_name_or_path: str, normalize_embeddings: bool = True, use_fp16: bool = True, + use_bf16: bool = False, query_instruction_for_retrieval: Optional[str] = None, query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"] @@ -73,6 +74,7 @@ def __init__( model_name_or_path, normalize_embeddings=normalize_embeddings, use_fp16=use_fp16, + use_bf16=use_bf16, query_instruction_for_retrieval=query_instruction_for_retrieval, 
query_instruction_format=query_instruction_format, devices=devices, @@ -96,7 +98,8 @@ def __init__( model_name_or_path, trust_remote_code=trust_remote_code, colbert_dim=colbert_dim, - cache_dir=cache_dir + cache_dir=cache_dir, + torch_dtype=self.get_model_torch_dtype(), ), tokenizer=self.tokenizer, sentence_pooling_method=pooling_method, @@ -334,8 +337,8 @@ def encode_single_device( if device is None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if device == "cpu": + self.model.float() self.model.to(device) self.model.eval() @@ -630,8 +633,8 @@ def _tokenize(texts: list, max_length: int): if device is None: device = self.target_devices[0] - if device == "cpu": self.use_fp16 = False - if self.use_fp16: self.model.half() + if device == "cpu": + self.model.float() self.model.to(device) self.model.eval()