diff --git a/e2e/pipelines/test_named_entity_extractor.py b/e2e/pipelines/test_named_entity_extractor.py index 4eb096423c..b1894a2abf 100644 --- a/e2e/pipelines/test_named_entity_extractor.py +++ b/e2e/pipelines/test_named_entity_extractor.py @@ -10,6 +10,7 @@ import sys import pytest +from thinc.api import NumpyOps, get_current_ops, set_current_ops from haystack import Document, Pipeline from haystack.components.extractors import ( @@ -132,3 +133,34 @@ def _check_predictions(predicted, expected) -> None: assert a.entity == b.entity assert a.start == b.start assert a.end == b.end + +class TestNamedEntityExtractorDeviceRestoration: + def test_spacy_backend_restores_device_state(self): + """ + Verify that NamedEntityExtractor (spaCy) restores the previous Thinc Ops state + after the component runs. + """ + # 1. Setup a custom state + custom_ops = NumpyOps() + setattr(custom_ops, "owner", "test_marker") + set_current_ops(custom_ops) + + try: + # 2. Initialize and run (triggering the context manager) + extractor = NamedEntityExtractor(backend="spacy", model="en_core_web_sm") + + # Since _SpacyBackend is private, we access it via getattr to avoid IDE warnings + backend = getattr(extractor, "_backend") + select_device_method = getattr(backend, "_select_device") + + with select_device_method(): + # Inside the context, the state might change + pass + + # 3. Verify state is restored + final_ops = get_current_ops() + assert getattr(final_ops, "owner", None) == "test_marker" + + finally: + # Clean up global state + set_current_ops(NumpyOps()) diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py index e35e910db1..0435ba3a50 100644 --- a/haystack/components/embedders/hugging_face_api_document_embedder.py +++ b/haystack/components/embedders/hugging_face_api_document_embedder.py @@ -272,7 +272,7 @@ def _embed_batch(self, texts_to_embed: list[str], batch_size: int) -> list[list[ ): batch = texts_to_embed[i : i + batch_size] - np_embeddings = self._client.feature_extraction(text=batch, truncate=truncate, normalize=normalize) + np_embeddings = self._client.feature_extraction(text=batch, truncate=truncate, normalize=normalize) # type: ignore if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch): raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}") @@ -293,7 +293,9 @@ async def _embed_batch_async(self, texts_to_embed: list[str], batch_size: int) - async def _runner(batch: list[str]) -> list[list[float]]: async with sem: np_embeddings = await self._async_client.feature_extraction( - text=batch, truncate=truncate, normalize=normalize + text=batch, + truncate=truncate, + normalize=normalize, # type: ignore ) if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch): diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py index 259d0855fc..a971f48254 100644 --- a/haystack/components/extractors/named_entity_extractor.py +++ b/haystack/components/extractors/named_entity_extractor.py @@ -22,6 +22,7 @@ with LazyImport(message="Run 'pip install spacy'") as spacy_import: import spacy from spacy import Language as SpacyPipeline + from thinc.api import get_current_ops, set_current_ops class NamedEntityExtractorBackend(Enum): @@ -492,17 +493,11 @@ def _select_device(self) -> Iterator[None]: """ Context manager used to run spaCy models on a specific GPU in a scoped manner. """ - - # TODO: This won't restore the active device. - # Since there are no opaque API functions to determine - # the active device in spaCy/Thinc, we can't do much - # about it as a consumer unless we start poking into their - # internals. device_id = self._device.to_spacy() + previous_ops = get_current_ops() try: if device_id >= 0: spacy.require_gpu(device_id) yield finally: - if device_id >= 0: - spacy.require_cpu() + set_current_ops(previous_ops) diff --git a/releasenotes/notes/Restore-Thinc-Ops-state-in-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml b/releasenotes/notes/Restore-Thinc-Ops-state-in-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml new file mode 100644 index 0000000000..727a9cadaa --- /dev/null +++ b/releasenotes/notes/Restore-Thinc-Ops-state-in-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml @@ -0,0 +1,6 @@ +--- +fixes: + - | + Fixed a bug in ``NamedEntityExtractor`` where the spaCy/Thinc device state was not correctly + restored after execution, potentially affecting the device configuration of other spaCy + components in the same process.