Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
logging,
replace_return_docstrings,
)
from transformers.utils.import_utils import is_torch_fx_available
from FlagEmbedding.utils.transformers_compat import is_torch_fx_available
from .configuration_minicpm_reranker import LayerWiseMiniCPMConfig
import re

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
logging,
replace_return_docstrings,
)
from transformers.utils.import_utils import is_torch_fx_available
from FlagEmbedding.utils.transformers_compat import is_torch_fx_available
from .configuration_minicpm_reranker import LayerWiseMiniCPMConfig
import re

Expand Down
Empty file added FlagEmbedding/utils/__init__.py
Empty file.
29 changes: 29 additions & 0 deletions FlagEmbedding/utils/transformers_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from packaging import version
import transformers

# Parsed version of the installed `transformers` package; the "0.0.0"
# fallback keeps the comparison safe if __version__ is ever missing.
TF_VER = version.parse(getattr(transformers, "__version__", "0.0.0"))
# True when transformers >= 5.0.0 (the release that removed is_torch_fx_available).
IS_TF_V5_OR_HIGHER = TF_VER >= version.parse("5.0.0")


# ------------- torch.fx availability -------------
# v5 removed is_torch_fx_available. We emulate it via feature detection.
def is_torch_fx_available():
    """Feature-detect torch.fx rather than relying on the helper removed in v5.

    Returns True when ``torch.fx`` imports cleanly, False on any failure
    (missing torch, broken install, etc.).
    """
    try:
        __import__("torch.fx")
    except Exception:
        return False
    return True


# ------------- other utilities that moved -------------
# Pattern:
# try the new location first (v5), then fall back to v4 path, else provide a safe default.
def import_from_candidates(candidates, default=None):
    """Return the first attribute found among (module, name) candidate pairs.

    Each candidate is tried in order: the module is imported and the named
    attribute fetched. Any failure (import error, missing attribute, or an
    exception raised at import time) moves on to the next candidate.
    Returns ``default`` when no candidate resolves.
    """
    for module_path, attr_name in candidates:
        try:
            module = __import__(module_path, fromlist=[attr_name])
        except Exception:
            continue
        try:
            return getattr(module, attr_name)
        except Exception:
            continue
    return default
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
include_package_data=True,
install_requires=[
'torch>=1.6.0',
'transformers>=4.44.2',
'transformers>=4.44.2,<6.0.0',
'datasets>=2.19.0',
'accelerate>=0.20.1',
'sentence_transformers',
Expand Down
42 changes: 42 additions & 0 deletions tests/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# FlagEmbedding Tests

This directory contains tests for the FlagEmbedding library, including compatibility tests for Transformers 5.0.

## Test Files

- `test_imports_v5.py`: Tests that imports work with Transformers v5, particularly the compatibility layer for `is_torch_fx_available`.
- `test_infer_embedder_basic.py`: Tests basic functionality of BGE embedder models with a small public checkpoint.
- `test_infer_reranker_basic.py`: Tests basic functionality of reranker models.

## Running Tests

1. create a python venv `python -m venv pytest_venv`
2. activate venv `source pytest_venv/bin/activate`
3. install pytest `pip install pytest`
4. install flagembedding package in development mode: `pip install -e .`

Then run the tests using pytest:

```bash
# Run all tests
pytest tests/

# Run a specific test file
pytest tests/test_imports_v5.py

# Run with verbose output
pytest -v tests/
```

## Transformers 5.0 Compatibility

The tests verify that FlagEmbedding works with Transformers 5.0, which removed the `is_torch_fx_available` function.
The compatibility layer in `FlagEmbedding/utils/transformers_compat.py` provides this function for backward compatibility.

**Note:** Transformers 5.0 requires Python 3.10 or higher. If you're using Python 3.9 or lower, you'll need to upgrade your Python version to test with Transformers 5.0.

To test with a specific version of transformers (with Python 3.10+):

```bash
pip install transformers==5.0.0
pytest tests/
```
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
Common pytest fixtures and configuration for FlagEmbedding tests.
"""

import os
import pytest
import torch
from packaging import version
import transformers

# Check if we're using transformers v5+
# NOTE(review): this duplicates the version check in
# FlagEmbedding/utils/transformers_compat.py — consider importing
# TF_VER / IS_TF_V5_OR_HIGHER from there to keep a single source of truth.
TF_VER = version.parse(getattr(transformers, "__version__", "0.0.0"))
IS_TF_V5_OR_HIGHER = TF_VER >= version.parse("5.0.0")


@pytest.fixture(scope="session")
def device():
    """Session-wide device string: "cuda" when a GPU is visible, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


@pytest.fixture(scope="session")
def transformers_version():
    """Expose the parsed transformers version (TF_VER) to tests."""
    return TF_VER
51 changes: 51 additions & 0 deletions tests/test_imports_v5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Test that imports work with Transformers v5.

This test verifies that the compatibility layer in FlagEmbedding/utils/transformers_compat.py
properly handles the removal of is_torch_fx_available in Transformers v5.
"""

import pytest
import transformers
from packaging import version

# Import the compatibility layer
from FlagEmbedding.utils.transformers_compat import is_torch_fx_available

# Check if we're using transformers v5+
# (the "0.0.0" fallback keeps the comparison safe if __version__ is missing)
TF_VER = version.parse(getattr(transformers, "__version__", "0.0.0"))
IS_TF_V5_OR_HIGHER = TF_VER >= version.parse("5.0.0")


# Import the files mentioned in issue #1561 that use is_torch_fx_available
def test_import_modeling_minicpm_reranker_inference():
    """The inference-side MiniCPM reranker module must import cleanly."""
    from FlagEmbedding.inference.reranker.decoder_only.models import (
        modeling_minicpm_reranker as reranker_module,
    )

    assert reranker_module.LayerWiseMiniCPMForCausalLM is not None


def test_import_modeling_minicpm_reranker_finetune():
    """The finetune-side layerwise MiniCPM reranker module must import cleanly."""
    from FlagEmbedding.finetune.reranker.decoder_only.layerwise import (
        modeling_minicpm_reranker as reranker_module,
    )

    assert reranker_module.LayerWiseMiniCPMForCausalLM is not None


@pytest.mark.skipif(not IS_TF_V5_OR_HIGHER, reason="Only relevant for Transformers v5+")
def test_is_torch_fx_available_v5():
    """The compat shim must run without raising on Transformers v5+."""
    outcome = is_torch_fx_available()
    # Whether torch.fx is actually present varies by environment;
    # only the return type is pinned here.
    assert isinstance(outcome, bool)


def test_transformers_version(transformers_version):
    """Sanity-check that the session fixture reports a transformers version."""
    ver = transformers_version
    assert ver is not None
    print(f"Transformers version: {ver}")
66 changes: 66 additions & 0 deletions tests/test_infer_embedder_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
Test basic functionality of BGE embedder models with Transformers v5.

This test loads a small/public BGE checkpoint and runs a single encode on toy strings,
verifying that the shape/dtype are correct and that cosine similarity is sane.
"""
import pytest
import torch
import numpy as np
from FlagEmbedding import FlagModel

def cosine_similarity(a, b):
    """Compute cosine similarity between two 1D vectors.

    Returns 0.0 when either vector has zero norm: the original formula would
    divide by zero and yield NaN, which silently falsifies the downstream
    ``0 <= similarity <= 1`` assertions (NaN comparisons are always False).
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        return 0.0
    return float(np.dot(a, b) / norm_product)

def test_bge_embedder_basic(device):
    """Encode two related sentences with a BGE embedder and sanity-check the vectors."""
    embedder = FlagModel("BAAI/bge-base-en-v1.5", device=device)

    text_q = "What is the capital of France?"
    text_p = "Paris is the capital and most populous city of France."

    vec_q = embedder.encode(text_q)
    vec_p = embedder.encode(text_p)

    # Single-string inputs should yield 1D float arrays with no NaNs.
    for vec in (vec_q, vec_p):
        assert isinstance(vec, np.ndarray)
        assert vec.ndim == 1
        assert not np.isnan(vec).any()

    # Related texts should land well above random in cosine similarity.
    sim = cosine_similarity(vec_q, vec_p)
    assert 0 <= sim <= 1
    assert sim > 0.5

def test_bge_embedder_batch(device):
    """Batch-encode two queries and validate the resulting 2D embedding matrix."""
    embedder = FlagModel("BAAI/bge-base-en-v1.5", device=device)

    texts = [
        "What is the capital of France?",
        "Who wrote Romeo and Juliet?"
    ]

    matrix = embedder.encode(texts)

    # Batch input should yield a (batch_size, embedding_dim) array.
    assert isinstance(matrix, np.ndarray)
    assert matrix.ndim == 2
    assert matrix.shape[0] == len(texts)
    assert not np.isnan(matrix).any()
62 changes: 62 additions & 0 deletions tests/test_infer_reranker_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Test basic functionality of reranker models with Transformers v5.

This test instantiates a lightweight reranker and calls compute_score on query/doc pairs
to validate the forward pass.
"""

import pytest
import torch
import numpy as np
from FlagEmbedding import FlagReranker


def test_reranker_basic(device):
    """Score a single query/passage pair and validate the output type and range."""
    reranker = FlagReranker("BAAI/bge-reranker-base", device=device)

    question = "What is the capital of France?"
    doc = "Paris is the capital and most populous city of France."

    scores = reranker.compute_score([(question, doc)])
    top = scores[0]

    assert isinstance(top, float)
    # Raw reranker scores are model-dependent; this is a loose sanity window.
    assert -100 < top < 100


def test_reranker_batch(device):
    """Score several passages against one query and check the ranking is sane."""
    reranker = FlagReranker("BAAI/bge-reranker-base", device=device)

    question = "What is the capital of France?"
    docs = [
        "Paris is the capital and most populous city of France.",
        "Berlin is the capital and largest city of Germany.",
        "London is the capital and largest city of England and the United Kingdom.",
    ]

    scores = reranker.compute_score([(question, doc) for doc in docs])

    assert isinstance(scores, list)
    assert len(scores) == len(docs)
    assert all(isinstance(s, float) for s in scores)

    # The relevant passage (Paris, index 0) must rank highest.
    assert scores[0] == max(scores), "Paris should have the highest score"
Loading