Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions e2e/pipelines/test_named_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys

import pytest
from thinc.api import NumpyOps, get_current_ops, set_current_ops

from haystack import Document, Pipeline
from haystack.components.extractors import (
Expand Down Expand Up @@ -132,3 +133,34 @@ def _check_predictions(predicted, expected) -> None:
assert a.entity == b.entity
assert a.start == b.start
assert a.end == b.end

class TestNamedEntityExtractorDeviceRestoration:
def test_spacy_backend_restores_device_state(self):
"""
Verify that NamedEntityExtractor (spaCy) restores the previous Thinc Ops state
after the component runs.
"""
# 1. Setup a custom state
custom_ops = NumpyOps()
setattr(custom_ops, "owner", "test_marker")
set_current_ops(custom_ops)

try:
# 2. Initialize and run (triggering the context manager)
extractor = NamedEntityExtractor(backend="spacy", model="en_core_web_sm")

# Since _SpacyBackend is private, we access it via getattr to avoid IDE warnings
backend = getattr(extractor, "_backend")
select_device_method = getattr(backend, "_select_device")

with select_device_method():
# Inside the context, the state might change
pass

# 3. Verify state is restored
final_ops = get_current_ops()
assert getattr(final_ops, "owner", None) == "test_marker"

finally:
# Clean up global state
set_current_ops(NumpyOps())
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _embed_batch(self, texts_to_embed: list[str], batch_size: int) -> list[list[
):
batch = texts_to_embed[i : i + batch_size]

np_embeddings = self._client.feature_extraction(text=batch, truncate=truncate, normalize=normalize)
np_embeddings = self._client.feature_extraction(text=batch, truncate=truncate, normalize=normalize) # type: ignore

if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch):
raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}")
Expand All @@ -293,7 +293,9 @@ async def _embed_batch_async(self, texts_to_embed: list[str], batch_size: int) -
async def _runner(batch: list[str]) -> list[list[float]]:
async with sem:
np_embeddings = await self._async_client.feature_extraction(
text=batch, truncate=truncate, normalize=normalize
text=batch,
truncate=truncate,
normalize=normalize, # type: ignore
)

if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch):
Expand Down
11 changes: 3 additions & 8 deletions haystack/components/extractors/named_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
with LazyImport(message="Run 'pip install spacy'") as spacy_import:
import spacy
from spacy import Language as SpacyPipeline
from thinc.api import get_current_ops, set_current_ops


class NamedEntityExtractorBackend(Enum):
Expand Down Expand Up @@ -492,17 +493,11 @@ def _select_device(self) -> Iterator[None]:
"""
Context manager used to run spaCy models on a specific GPU in a scoped manner.
"""

# TODO: This won't restore the active device.
# Since there are no opaque API functions to determine
# the active device in spaCy/Thinc, we can't do much
# about it as a consumer unless we start poking into their
# internals.
device_id = self._device.to_spacy()
previous_ops = get_current_ops()
try:
if device_id >= 0:
spacy.require_gpu(device_id)
yield
finally:
if device_id >= 0:
spacy.require_cpu()
set_current_ops(previous_ops)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
fixes:
- |
Fixed a bug in ``NamedEntityExtractor`` where the spaCy/Thinc device state was not correctly
restored after execution, potentially affecting the device configuration of other spaCy
components in the same process.
Loading