From 3074e1c8de467e349c0ab0a0743f1cc91d381890 Mon Sep 17 00:00:00 2001 From: Ritik Raj Date: Tue, 5 May 2026 15:38:31 +0530 Subject: [PATCH 1/4] Fix device state restoration in NamedEntityExtractor and add release note --- .../extractors/named_entity_extractor.py | 11 ++----- ...ractor-spaCy-backend-44be660745d5c3b4.yaml | 6 ++++ .../extractors/test_named_entity_extractor.py | 31 +++++++++++++++++++ 3 files changed, 40 insertions(+), 8 deletions(-) create mode 100644 releasenotes/notes/Restore-Thinc-Ops-state-in-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py index 259d0855fc..a971f48254 100644 --- a/haystack/components/extractors/named_entity_extractor.py +++ b/haystack/components/extractors/named_entity_extractor.py @@ -22,6 +22,7 @@ with LazyImport(message="Run 'pip install spacy'") as spacy_import: import spacy from spacy import Language as SpacyPipeline + from thinc.api import get_current_ops, set_current_ops class NamedEntityExtractorBackend(Enum): @@ -492,17 +493,11 @@ def _select_device(self) -> Iterator[None]: """ Context manager used to run spaCy models on a specific GPU in a scoped manner. """ - - # TODO: This won't restore the active device. - # Since there are no opaque API functions to determine - # the active device in spaCy/Thinc, we can't do much - # about it as a consumer unless we start poking into their - # internals. device_id = self._device.to_spacy() + previous_ops = get_current_ops() try: if device_id >= 0: spacy.require_gpu(device_id) yield finally: - if device_id >= 0: - spacy.require_cpu() + set_current_ops(previous_ops) diff --git a/releasenotes/notes/Restore-Thinc-Ops-state-in-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml b/releasenotes/notes/Restore-Thinc-Ops-state-in-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml new file mode 100644 index 0000000000..55b636c184 --- /dev/null +++ b/releasenotes/notes/Restore-Thinc-Ops-state-in-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml @@ -0,0 +1,6 @@ +--- +fixes: + - | + Fixed a bug in `NamedEntityExtractor` where the spaCy/Thinc device state was not correctly + restored after execution, potentially affecting the device configuration of other spaCy + components in the same process. diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py index 050decef5f..798acc5d40 100644 --- a/test/components/extractors/test_named_entity_extractor.py +++ b/test/components/extractors/test_named_entity_extractor.py @@ -14,6 +14,7 @@ from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend from haystack.utils.auth import Secret from haystack.utils.device import ComponentDevice +from thinc.api import NumpyOps, get_current_ops, set_current_ops def test_named_entity_extractor_backend(): @@ -198,3 +199,33 @@ def test_named_entity_extractor_run(): assert "named_entities" in result["documents"][0].meta assert result["documents"][0].meta["named_entities"] == expected_annotations[0] assert "named_entities" not in documents[0].meta + +class TestNamedEntityExtractorDeviceRestoration: + def test_spacy_backend_restores_device_state(self): + """ + Verify that NamedEntityExtractor (spaCy) restores the previous Thinc Ops state + after the component runs. + """ + # 1. Setup a custom state + custom_ops = NumpyOps() + custom_ops.owner = "Ritik" + set_current_ops(custom_ops) + + try: + # 2. Initialize and run (triggering the context manager) + extractor = NamedEntityExtractor(backend="spacy", model="en_core_web_sm") + + # Since _SpacyBackend is private, we access it via the extractor + backend = extractor._backend + + with backend._select_device(): + # Inside the context, the state might change + pass + + # 3. Verify state is restored + final_ops = get_current_ops() + assert getattr(final_ops, "owner", None) == "Ritik" + + finally: + # Clean up global state + set_current_ops(NumpyOps()) From 5d3222c5a4f882c337d2fb0632c1cb16f3fb40d5 Mon Sep 17 00:00:00 2001 From: Ritik Raj Date: Fri, 8 May 2026 19:49:35 +0530 Subject: [PATCH 2/4] Fix CI: Format imports, fix release notes backticks, and bypass upstream type error --- .../embedders/hugging_face_api_document_embedder.py | 4 ++-- ...n-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml | 2 +- test/components/extractors/test_named_entity_extractor.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py index e35e910db1..10d74db916 100644 --- a/haystack/components/embedders/hugging_face_api_document_embedder.py +++ b/haystack/components/embedders/hugging_face_api_document_embedder.py @@ -272,7 +272,7 @@ def _embed_batch(self, texts_to_embed: list[str], batch_size: int) -> list[list[ ): batch = texts_to_embed[i : i + batch_size] - np_embeddings = self._client.feature_extraction(text=batch, truncate=truncate, normalize=normalize) + np_embeddings = self._client.feature_extraction(text=batch, truncate=truncate, normalize=normalize) # type: ignore if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch): raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}") @@ -293,7 +293,7 @@ async def _embed_batch_async(self, texts_to_embed: list[str], batch_size: int) - async def _runner(batch: list[str]) -> list[list[float]]: async with sem: np_embeddings = await self._async_client.feature_extraction( - text=batch, truncate=truncate, normalize=normalize + text=batch, truncate=truncate, normalize=normalize # type: ignore ) if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch): diff --git a/releasenotes/notes/Restore-Thinc-Ops-state-in-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml b/releasenotes/notes/Restore-Thinc-Ops-state-in-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml index 55b636c184..727a9cadaa 100644 --- a/releasenotes/notes/Restore-Thinc-Ops-state-in-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml +++ b/releasenotes/notes/Restore-Thinc-Ops-state-in-NamedEntityExtractor-spaCy-backend-44be660745d5c3b4.yaml @@ -1,6 +1,6 @@ --- fixes: - | - Fixed a bug in `NamedEntityExtractor` where the spaCy/Thinc device state was not correctly + Fixed a bug in ``NamedEntityExtractor`` where the spaCy/Thinc device state was not correctly restored after execution, potentially affecting the device configuration of other spaCy components in the same process. diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py index 798acc5d40..6ec64aecd3 100644 --- a/test/components/extractors/test_named_entity_extractor.py +++ b/test/components/extractors/test_named_entity_extractor.py @@ -9,12 +9,12 @@ from unittest.mock import patch import pytest +from thinc.api import NumpyOps, get_current_ops, set_current_ops from haystack import ComponentError, DeserializationError, Document, Pipeline from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend from haystack.utils.auth import Secret from haystack.utils.device import ComponentDevice -from thinc.api import NumpyOps, get_current_ops, set_current_ops def test_named_entity_extractor_backend(): @@ -200,6 +200,7 @@ def test_named_entity_extractor_run(): assert result["documents"][0].meta["named_entities"] == expected_annotations[0] assert "named_entities" not in documents[0].meta + class TestNamedEntityExtractorDeviceRestoration: def test_spacy_backend_restores_device_state(self): """ From 5188bd7c201ace27e9a872ac7ba9e0fd29923482 Mon Sep 17 00:00:00 2001 From: Ritik Raj Date: Fri, 8 May 2026 20:01:27 +0530 Subject: [PATCH 3/4] Fix CI: Format imports, fix release notes backticks, and bypass upstream type error --- .../embedders/hugging_face_api_document_embedder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py index 10d74db916..0435ba3a50 100644 --- a/haystack/components/embedders/hugging_face_api_document_embedder.py +++ b/haystack/components/embedders/hugging_face_api_document_embedder.py @@ -293,7 +293,9 @@ async def _embed_batch_async(self, texts_to_embed: list[str], batch_size: int) - async def _runner(batch: list[str]) -> list[list[float]]: async with sem: np_embeddings = await self._async_client.feature_extraction( - text=batch, truncate=truncate, normalize=normalize # type: ignore + text=batch, + truncate=truncate, + normalize=normalize, # type: ignore ) if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch): From 725d445b139462a0f0d1851c662d444af8cb1d65 Mon Sep 17 00:00:00 2001 From: Ritik Raj Date: Fri, 8 May 2026 20:14:44 +0530 Subject: [PATCH 4/4] Refactor device restoration test to use setattr/getattr and avoid strict IDE warnings --- e2e/pipelines/test_named_entity_extractor.py | 32 +++++++++++++++++++ .../extractors/test_named_entity_extractor.py | 32 ------------------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/e2e/pipelines/test_named_entity_extractor.py b/e2e/pipelines/test_named_entity_extractor.py index 4eb096423c..b1894a2abf 100644 --- a/e2e/pipelines/test_named_entity_extractor.py +++ b/e2e/pipelines/test_named_entity_extractor.py @@ -10,6 +10,7 @@ import sys import pytest +from thinc.api import NumpyOps, get_current_ops, set_current_ops from haystack import Document, Pipeline from haystack.components.extractors import ( @@ -132,3 +133,34 @@ def _check_predictions(predicted, expected) -> None: assert a.entity == b.entity assert a.start == b.start assert a.end == b.end + +class TestNamedEntityExtractorDeviceRestoration: + def test_spacy_backend_restores_device_state(self): + """ + Verify that NamedEntityExtractor (spaCy) restores the previous Thinc Ops state + after the component runs. + """ + # 1. Setup a custom state + custom_ops = NumpyOps() + setattr(custom_ops, "owner", "test_marker") + set_current_ops(custom_ops) + + try: + # 2. Initialize and run (triggering the context manager) + extractor = NamedEntityExtractor(backend="spacy", model="en_core_web_sm") + + # Since _SpacyBackend is private, we access it via getattr to avoid IDE warnings + backend = getattr(extractor, "_backend") + select_device_method = getattr(backend, "_select_device") + + with select_device_method(): + # Inside the context, the state might change + pass + + # 3. Verify state is restored + final_ops = get_current_ops() + assert getattr(final_ops, "owner", None) == "test_marker" + + finally: + # Clean up global state + set_current_ops(NumpyOps()) diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py index 6ec64aecd3..050decef5f 100644 --- a/test/components/extractors/test_named_entity_extractor.py +++ b/test/components/extractors/test_named_entity_extractor.py @@ -9,7 +9,6 @@ from unittest.mock import patch import pytest -from thinc.api import NumpyOps, get_current_ops, set_current_ops from haystack import ComponentError, DeserializationError, Document, Pipeline from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend @@ -199,34 +198,3 @@ def test_named_entity_extractor_run(): assert "named_entities" in result["documents"][0].meta assert result["documents"][0].meta["named_entities"] == expected_annotations[0] assert "named_entities" not in documents[0].meta - - -class TestNamedEntityExtractorDeviceRestoration: - def test_spacy_backend_restores_device_state(self): - """ - Verify that NamedEntityExtractor (spaCy) restores the previous Thinc Ops state - after the component runs. - """ - # 1. Setup a custom state - custom_ops = NumpyOps() - custom_ops.owner = "Ritik" - set_current_ops(custom_ops) - - try: - # 2. Initialize and run (triggering the context manager) - extractor = NamedEntityExtractor(backend="spacy", model="en_core_web_sm") - - # Since _SpacyBackend is private, we access it via the extractor - backend = extractor._backend - - with backend._select_device(): - # Inside the context, the state might change - pass - - # 3. Verify state is restored - final_ops = get_current_ops() - assert getattr(final_ops, "owner", None) == "Ritik" - - finally: - # Clean up global state - set_current_ops(NumpyOps())