Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions FlagEmbedding/abc/inference/AbsEmbedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
model_name_or_path: str,
normalize_embeddings: bool = True,
use_fp16: bool = True,
use_bf16: bool = False,
query_instruction_for_retrieval: Optional[str] = None,
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval
devices: Optional[Union[str, int, List[str], List[int]]] = None,
Expand All @@ -62,6 +63,7 @@ def __init__(
self.model_name_or_path = model_name_or_path
self.normalize_embeddings = normalize_embeddings
self.use_fp16 = use_fp16
self.use_bf16 = use_bf16
self.query_instruction_for_retrieval = query_instruction_for_retrieval
self.query_instruction_format = query_instruction_format
self.target_devices = self.get_target_devices(devices)
Expand All @@ -81,6 +83,13 @@ def __init__(
self.model = None
self.pool = None

def get_model_torch_dtype(self) -> torch.dtype:
    """Resolve the torch dtype to load model weights in.

    Precedence: ``use_bf16`` wins over ``use_fp16``; when neither flag is
    set the model is loaded in full precision.

    Returns:
        torch.dtype: ``torch.bfloat16``, ``torch.float16``, or ``torch.float32``.
    """
    if self.use_bf16:
        return torch.bfloat16
    return torch.float16 if self.use_fp16 else torch.float32

def stop_self_pool(self):
if self.pool is not None:
self.stop_multi_process_pool(self.pool)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
model = AutoModel.from_pretrained(
model_args.model_name_or_path,
# torch_dtype=torch.bfloat16,
use_flash_attention_2=True if model_args.use_flash_attn else False,
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
token=model_args.token,
cache_dir=model_args.cache_dir,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
Expand Down Expand Up @@ -152,7 +152,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
model = AutoModel.from_pretrained(
model_args.model_name_or_path,
# torch_dtype=torch.bfloat16,
use_flash_attention_2=True if model_args.use_flash_attn else False,
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
token=model_args.token,
cache_dir=model_args.cache_dir,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str,
model = AutoModel.from_pretrained(
model_args.model_name_or_path,
# torch_dtype=torch.bfloat16,
use_flash_attention_2=True if model_args.use_flash_attn else False,
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
token=model_args.token,
cache_dir=model_args.cache_dir,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
Expand Down Expand Up @@ -150,7 +150,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_d
model = AutoModel.from_pretrained(
model_args.model_name_or_path,
# torch_dtype=torch.bfloat16,
use_flash_attention_2=True if model_args.use_flash_attn else False,
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
token=model_args.token,
cache_dir=model_args.cache_dir,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
Expand Down
9 changes: 6 additions & 3 deletions FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import torch
import logging
from typing import Tuple
from typing import Tuple, Optional
from transformers import (
AutoModel, AutoConfig,
AutoTokenizer, PreTrainedTokenizer
Expand Down Expand Up @@ -44,7 +44,8 @@ def get_model(
model_name_or_path: str,
trust_remote_code: bool = False,
colbert_dim: int = -1,
cache_dir: str = None
cache_dir: str = None,
torch_dtype: Optional[torch.dtype] = None,
):
"""Get the model.

Expand All @@ -54,6 +55,7 @@ def get_model(
trust_remote_code (bool, optional): trust_remote_code to use when loading models from HF. Defaults to ``False``.
colbert_dim (int, optional): Colbert dim to set. Defaults to ``-1``.
cache_dir (str, optional): HF cache dir to store the model. Defaults to ``None``.
torch_dtype (Optional[torch.dtype], optional): Torch dtype used when loading model weights. Defaults to ``None``.

Returns:
dict: A dictionary containing the model, colbert linear and sparse linear.
Expand All @@ -69,7 +71,8 @@ def get_model(
model = AutoModel.from_pretrained(
model_name_or_path,
cache_dir=cache_folder,
trust_remote_code=trust_remote_code
trust_remote_code=trust_remote_code,
dtype=torch_dtype,
)
colbert_linear = torch.nn.Linear(
in_features=model.config.hidden_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def get_model(model_args: RerankerModelArguments):
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
# torch_dtype=torch.bfloat16,
use_flash_attention_2=True if model_args.use_flash_attn else False,
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
token=model_args.token,
cache_dir=model_args.cache_dir,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
Expand Down Expand Up @@ -135,7 +135,7 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
# torch_dtype=torch.bfloat16,
use_flash_attention_2=True if model_args.use_flash_attn else False,
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
token=model_args.token,
cache_dir=model_args.cache_dir,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
# torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
use_flash_attention_2=True if model_args.use_flash_attn else False,
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
token=model_args.token,
cache_dir=model_args.cache_dir,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
Expand Down Expand Up @@ -131,7 +131,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
model = LayerWiseMiniCPMForCausalLM.from_pretrained(
model_args.model_name_or_path,
# torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
use_flash_attention_2=True if model_args.use_flash_attn else False,
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
token=model_args.token,
cache_dir=model_args.cache_dir,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
Expand Down
2 changes: 2 additions & 0 deletions FlagEmbedding/inference/auto_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def from_finetuned(
model_class: Optional[Union[str, EmbedderModelClass]] = None,
normalize_embeddings: bool = True,
use_fp16: bool = True,
use_bf16: bool = False,
query_instruction_for_retrieval: Optional[str] = None,
devices: Optional[Union[str, List[str]]] = None,
pooling_method: Optional[str] = None,
Expand Down Expand Up @@ -102,6 +103,7 @@ def from_finetuned(
model_name_or_path,
normalize_embeddings=normalize_embeddings,
use_fp16=use_fp16,
use_bf16=use_bf16,
query_instruction_for_retrieval=query_instruction_for_retrieval,
query_instruction_format=query_instruction_format,
devices=devices,
Expand Down
9 changes: 6 additions & 3 deletions FlagEmbedding/inference/embedder/decoder_only/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
model_name_or_path: str,
normalize_embeddings: bool = True,
use_fp16: bool = True,
use_bf16: bool = False,
query_instruction_for_retrieval: Optional[str] = None,
query_instruction_format: str = "Instruct: {}\nQuery: {}", # specify the format of query_instruction_for_retrieval
devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
Expand All @@ -77,6 +78,7 @@ def __init__(
model_name_or_path,
normalize_embeddings=normalize_embeddings,
use_fp16=use_fp16,
use_bf16=use_bf16,
query_instruction_for_retrieval=query_instruction_for_retrieval,
query_instruction_format=query_instruction_format,
devices=devices,
Expand All @@ -95,7 +97,8 @@ def __init__(
self.model = AutoModel.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
cache_dir=cache_dir
cache_dir=cache_dir,
dtype=self.get_model_torch_dtype(),
)

if self.kwargs.get("pooling_method", "last_token") != "last_token":
Expand Down Expand Up @@ -211,8 +214,8 @@ def encode_single_device(
if device is None:
device = self.target_devices[0]

if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
if device == "cpu":
self.model.float()

self.model.to(device)
self.model.eval()
Expand Down
9 changes: 6 additions & 3 deletions FlagEmbedding/inference/embedder/decoder_only/icl.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
model_name_or_path: str,
normalize_embeddings: bool = True,
use_fp16: bool = True,
use_bf16: bool = False,
query_instruction_for_retrieval: Optional[str] = None,
query_instruction_format: str = "<instruct>{}\n<query>{}", # specify the format of query_instruction_for_retrieval
suffix: str = '\n<response>',
Expand All @@ -90,6 +91,7 @@ def __init__(
model_name_or_path,
normalize_embeddings=normalize_embeddings,
use_fp16=use_fp16,
use_bf16=use_bf16,
query_instruction_for_retrieval=query_instruction_for_retrieval,
query_instruction_format=query_instruction_format,
devices=devices,
Expand All @@ -108,7 +110,8 @@ def __init__(
self.model = AutoModel.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
cache_dir=cache_dir
cache_dir=cache_dir,
torch_dtype=self.get_model_torch_dtype(),
)
self.examples_for_task = examples_for_task
self.examples_instruction_format = examples_instruction_format
Expand Down Expand Up @@ -340,8 +343,8 @@ def encode_queries_single_device(
if device is None:
device = self.target_devices[0]

if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
if device == "cpu":
self.model.float()

self.model.to(device)
self.model.eval()
Expand Down
9 changes: 6 additions & 3 deletions FlagEmbedding/inference/embedder/encoder_only/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
model_name_or_path: str,
normalize_embeddings: bool = True,
use_fp16: bool = True,
use_bf16: bool = False,
query_instruction_for_retrieval: Optional[str] = None,
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval
devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
Expand All @@ -60,6 +61,7 @@ def __init__(
model_name_or_path,
normalize_embeddings=normalize_embeddings,
use_fp16=use_fp16,
use_bf16=use_bf16,
query_instruction_for_retrieval=query_instruction_for_retrieval,
query_instruction_format=query_instruction_format,
devices=devices,
Expand All @@ -79,7 +81,8 @@ def __init__(
self.model = AutoModel.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
cache_dir=cache_dir
cache_dir=cache_dir,
dtype=self.get_model_torch_dtype(),
)

def encode_queries(
Expand Down Expand Up @@ -192,8 +195,8 @@ def encode_single_device(
if device is None:
device = self.target_devices[0]

if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
if device == "cpu":
self.model.float()

self.model.to(device)
self.model.eval()
Expand Down
13 changes: 8 additions & 5 deletions FlagEmbedding/inference/embedder/encoder_only/m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
model_name_or_path: str,
normalize_embeddings: bool = True,
use_fp16: bool = True,
use_bf16: bool = False,
query_instruction_for_retrieval: Optional[str] = None,
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval
devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
Expand All @@ -73,6 +74,7 @@ def __init__(
model_name_or_path,
normalize_embeddings=normalize_embeddings,
use_fp16=use_fp16,
use_bf16=use_bf16,
query_instruction_for_retrieval=query_instruction_for_retrieval,
query_instruction_format=query_instruction_format,
devices=devices,
Expand All @@ -96,7 +98,8 @@ def __init__(
model_name_or_path,
trust_remote_code=trust_remote_code,
colbert_dim=colbert_dim,
cache_dir=cache_dir
cache_dir=cache_dir,
torch_dtype=self.get_model_torch_dtype(),
),
tokenizer=self.tokenizer,
sentence_pooling_method=pooling_method,
Expand Down Expand Up @@ -334,8 +337,8 @@ def encode_single_device(
if device is None:
device = self.target_devices[0]

if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
if device == "cpu":
self.model.float()

self.model.to(device)
self.model.eval()
Expand Down Expand Up @@ -630,8 +633,8 @@ def _tokenize(texts: list, max_length: int):
if device is None:
device = self.target_devices[0]

if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
if device == "cpu":
self.model.float()

self.model.to(device)
self.model.eval()
Expand Down
Loading