From 5213d8219a46a0355b70f84808bf4b424530ff3a Mon Sep 17 00:00:00 2001 From: Luke Roantree <206716137+lukeroantreeONS@users.noreply.github.com> Date: Thu, 5 Mar 2026 15:41:58 +0000 Subject: [PATCH 1/8] feat(indexers): add framework for flexible, common hooks to be provided as part of the package --- src/classifai/indexers/__init__.py | 6 ++ src/classifai/indexers/hooks/__init__.py | 7 +++ .../indexers/hooks/default_hooks/__init__.py | 5 ++ .../hooks/default_hooks/postprocessing.py | 0 .../hooks/default_hooks/preprocessing.py | 60 +++++++++++++++++++ src/classifai/indexers/hooks/hook_factory.py | 53 ++++++++++++++++ 6 files changed, 131 insertions(+) create mode 100644 src/classifai/indexers/hooks/__init__.py create mode 100644 src/classifai/indexers/hooks/default_hooks/__init__.py create mode 100644 src/classifai/indexers/hooks/default_hooks/postprocessing.py create mode 100644 src/classifai/indexers/hooks/default_hooks/preprocessing.py create mode 100644 src/classifai/indexers/hooks/hook_factory.py diff --git a/src/classifai/indexers/__init__.py b/src/classifai/indexers/__init__.py index 4bc8680..9c181f4 100644 --- a/src/classifai/indexers/__init__.py +++ b/src/classifai/indexers/__init__.py @@ -35,9 +35,15 @@ VectorStoreSearchInput, VectorStoreSearchOutput, ) +from .hooks import ( + CapitalisationStandardisingHook, + HookBase, +) from .main import VectorStore __all__ = [ + "CapitalisationStandardisingHook", + "HookBase", "VectorStore", "VectorStoreEmbedInput", "VectorStoreEmbedOutput", diff --git a/src/classifai/indexers/hooks/__init__.py b/src/classifai/indexers/hooks/__init__.py new file mode 100644 index 0000000..e879892 --- /dev/null +++ b/src/classifai/indexers/hooks/__init__.py @@ -0,0 +1,7 @@ +from .default_hooks import CapitalisationStandardisingHook +from .hook_factory import HookBase + +__all__ = [ + "CapitalisationStandardisingHook", + "HookBase", +] diff --git a/src/classifai/indexers/hooks/default_hooks/__init__.py 
class CapitalisationStandardisingHook(HookBase):
    """A pre-processing hook to handle upper-/lower-/sentence-/title-casing."""

    def __init__(self, method: str = "lower", colname: str | list[str] = "query"):
        """Initialises the hook with the specified method for standardising capitalisation.

        Args:
            method (str): Method for standardisation. Must be one of "lower" (like this),
                "upper" (LIKE THIS), "sentence" (Like this), or "title" (Like This).
                Defaults to "lower".
            colname (str | list[str]): The name(s) of the field(s) of the Input object which
                is/are to be capitalised. Defaults to "query".

        Raises:
            HookError: If ``method`` is not one of the supported options.
        """
        super().__init__(method=method, colname=colname, hook_type="pre_processing")
        if method not in {"lower", "upper", "sentence", "title"}:
            raise HookError(
                "Invalid method for CapitalisationStandardisingHook. "
                "Must be one of 'lower', 'upper', 'sentence', or 'title'.",
                context={self.hook_type: "Capitalisation", "method": method},
            )
        # Resolve the method name to a callable once at construction time.
        self.method = {
            "lower": str.lower,
            "upper": str.upper,
            # .capitalize() on "" is a no-op anyway, but keep the guard for None-safety.
            "sentence": lambda text: text.capitalize() if text else text,
            "title": str.title,
        }[method]
        # Normalise to a list up front so __call__ never mutates hook state
        # (the previous implementation rewrote self.colname on first call).
        self.colname = [colname] if isinstance(colname, str) else colname

    def __call__(
        self, input_data: VectorStoreSearchInput | VectorStoreReverseSearchInput | VectorStoreEmbedInput
    ) -> VectorStoreSearchInput | VectorStoreReverseSearchInput | VectorStoreEmbedInput:
        """Standardises capitalisation in the input data as specified by the 'method' attribute.

        Raises:
            HookError: If any configured column is missing from the input, or is not a
                string-typed ('object' dtype) column.
        """
        for col in self.colname:
            if col not in input_data.columns:
                raise HookError(
                    "Invalid column name passed.", context={"pre_processing": "Capitalisation", "colname": col}
                )
            if col not in input_data.select_dtypes(include=["object"]).columns:
                raise HookError(
                    f"colname is of type {input_data[col].dtype}, expected 'str'.",
                    context={"pre_processing": "Capitalisation", "colname": col},
                )

        processed_input = input_data.copy()
        for col in self.colname:
            processed_input[col] = processed_input[col].apply(self.method)
        # Ensure the processed input still conforms to the schema
        processed_input = input_data.__class__.validate(processed_input)
        return processed_input
class HookBase(ABC):
    """Abstract base class for all pre- and post-processing hooks requiring customisation."""

    def __init__(self, **kwargs):
        """Sets any attributes required by the hook.

        Args:
            **kwargs: Arbitrary hook configuration, stored on ``self.kwargs`` for use by
                subclasses. If a ``hook_type`` kwarg is supplied it overrides the default
                ``"generic"`` hook type (used in error reporting to identify the stage).
        """
        # Honour the hook_type kwarg: subclasses pass e.g. hook_type="pre_processing"
        # via super().__init__ and expect it to take effect. Previously the value was
        # stored in kwargs but never applied, leaving every hook labelled "generic".
        self.hook_type: str = kwargs.get("hook_type", "generic")
        self.kwargs = kwargs

    @abstractmethod
    def __call__(
        self,
        data: VectorStoreSearchOutput
        | VectorStoreReverseSearchOutput
        | VectorStoreEmbedOutput
        | VectorStoreSearchInput
        | VectorStoreReverseSearchInput
        | VectorStoreEmbedInput,
    ) -> (
        VectorStoreSearchOutput
        | VectorStoreReverseSearchOutput
        | VectorStoreEmbedOutput
        | VectorStoreSearchInput
        | VectorStoreReverseSearchInput
        | VectorStoreEmbedInput
    ):
        """Defines the behavior of the hook when called.

        Subclasses must return an object of the same dataclass they received.
        The base implementation is a pass-through that enforces the same-type
        contract; subclasses may reuse it via ``super().__call__(processed)``.

        Raises:
            HookError: If the processed data is not of the same type as the input.
        """
        processed_data = data  # Placeholder for processing logic
        if not isinstance(processed_data, type(data)):
            raise HookError(
                f"Processed data must be of the same type as input. "
                f"Expected {type(data).__name__}, got {type(processed_data).__name__}.",
                context={"hook_type": self.hook_type},
            )
        return processed_data
class DeduplicationHook(HookBase):
    """A post-processing hook to remove duplicate knowledgebase entries, i.e. entries with the same label."""

    def _mean_score(self, scores):
        # Aggregate duplicate scores by their arithmetic mean.
        return np.mean(scores)

    def _max_score(self, scores):
        # Aggregate duplicate scores by keeping the best (largest) one.
        return np.max(scores)

    def __init__(self, score_aggregation_method: str = "max"):
        """Initialises the hook with the specified method for assigning scores to deduplicated entries.

        Args:
            score_aggregation_method (str): Method for assigning score to the deduplicated entry.
                Must be one of "max" or "mean". Defaults to "max".
                A future update will introduce a 'softmax' option.

        Raises:
            HookError: If ``score_aggregation_method`` is not one of the supported options.
        """
        # Initialise the base class first: the error context below reads
        # self.hook_type, which previously did not exist yet, turning the
        # intended HookError into an AttributeError.
        super().__init__(hook_type="post_processing")
        if score_aggregation_method not in ["max", "mean"]:
            raise HookError(
                "Invalid method for DeduplicationHook. Must be one of 'max', or 'mean'.",
                context={self.hook_type: "Deduplication", "method": score_aggregation_method},
            )
        self.score_aggregation_method = score_aggregation_method
        if self.score_aggregation_method == "max":
            self.score_aggregator = self._max_score
        # Softmax not supported until normalisation is implemented.
        # elif self.score_aggregation_method == "softmax":
        #     self.score_aggregator = ...
        elif self.score_aggregation_method == "mean":
            self.score_aggregator = self._mean_score

    def __call__(self, input_data: VectorStoreSearchOutput) -> VectorStoreSearchOutput:
        """Aggregates retrieved knowledgebase entries corresponding to the same label.

        Groups results by (query_id, query_text, doc_id), aggregates each group's
        score with the configured aggregator, keeps the metadata of the
        best-scoring original row, and re-ranks per query by descending
        aggregated score (ranks start from 1).
        """
        df_gpby = (
            input_data.groupby(["query_id", "query_text", "doc_id"])
            .aggregate(
                score=("score", self.score_aggregator),
                idxmax=("score", "idxmax"),
                rank=("rank", "min"),
            )
            .reset_index()
        )

        for query in df_gpby["query_id"].unique():
            batch = df_gpby[df_gpby["query_id"] == query]
            # Rank by descending aggregated score (not by the old rank) so the
            # ordering stays robust across aggregation methods; +1 makes ranks 1-based.
            new_rank = pd.factorize(-batch["score"], sort=True)[0] + 1
            df_gpby.loc[batch.index, "rank"] = new_rank

        # Carry over any extra columns from the best-scoring original row,
        # using the stored index of the group-wise score maximum.
        for col in set(input_data.columns).difference(set(df_gpby.columns)):
            df_gpby[col] = df_gpby["idxmax"].map(input_data[col])

        processed_output = input_data.__class__.validate(df_gpby[input_data.columns])
        return processed_output
self.kwargs = kwargs - self._setup() - - def _setup(self): # noqa: B027 - """Performs any setup / initialisation required by the hook.""" - pass @abstractmethod def __call__( From 4365bdac68e304ecbe3b439e11be67b830d93e5e Mon Sep 17 00:00:00 2001 From: Luke Roantree <206716137+lukeroantreeONS@users.noreply.github.com> Date: Fri, 13 Mar 2026 16:16:04 +0000 Subject: [PATCH 3/8] docs(hooks): improved docstring in CapitalisationStandardisingHook --- src/classifai/indexers/hooks/default_hooks/preprocessing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/classifai/indexers/hooks/default_hooks/preprocessing.py b/src/classifai/indexers/hooks/default_hooks/preprocessing.py index f598f6b..90a0542 100644 --- a/src/classifai/indexers/hooks/default_hooks/preprocessing.py +++ b/src/classifai/indexers/hooks/default_hooks/preprocessing.py @@ -14,8 +14,9 @@ def __init__(self, method: str = "lower", colname: str = "query"): """Inititialises the hook with the specified method for standardising capitalisation. Args: - method (str): Method for standardisation. Must be one of "lower", "upper", "sentence" - or "title". Defaults to "lower". + method (str): Method for standardisation. Must be one of "lower" (like this), + "upper" (LIKE THIS), "sentence" (Like this), or "title" (Like This). + Defaults to "lower". colname (str): The name of one of the fields of the Input object which is to be capitalised. Defaults to "query". 
""" From b438b12ca939012d3d7c4bffde68bbbf0951991d Mon Sep 17 00:00:00 2001 From: Luke Roantree <206716137+lukeroantreeONS@users.noreply.github.com> Date: Fri, 13 Mar 2026 16:19:47 +0000 Subject: [PATCH 4/8] feat(hooks): allow multiple columns to be passed to CapitalisationStandardisingHook --- .../hooks/default_hooks/preprocessing.py | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/classifai/indexers/hooks/default_hooks/preprocessing.py b/src/classifai/indexers/hooks/default_hooks/preprocessing.py index 90a0542..f21539c 100644 --- a/src/classifai/indexers/hooks/default_hooks/preprocessing.py +++ b/src/classifai/indexers/hooks/default_hooks/preprocessing.py @@ -10,14 +10,15 @@ class CapitalisationStandardisingHook(HookBase): """A pre-processing hook to handle upper-/lower-/sentence-/title-casing.""" - def __init__(self, method: str = "lower", colname: str = "query"): + def __init__(self, method: str = "lower", colname: str | list[str] = "query"): """Inititialises the hook with the specified method for standardising capitalisation. Args: method (str): Method for standardisation. Must be one of "lower" (like this), "upper" (LIKE THIS), "sentence" (Like this), or "title" (Like This). Defaults to "lower". - colname (str): The name of one of the fields of the Input object which is to be capitalised. + colname (str | list[str]): The name of one of the fields of the Input object which is/are + to be capitalised. Defaults to "query". 
""" super().__init__(method=method, colname=colname, hook_type="pre_processing") @@ -41,16 +42,22 @@ def __call__( self, input_data: VectorStoreSearchInput | VectorStoreReverseSearchInput | VectorStoreEmbedInput ) -> VectorStoreSearchInput | VectorStoreReverseSearchInput | VectorStoreEmbedInput: """Standardises capitalisation in the input data as specified by 'method' attribute.""" - if self.colname not in input_data.columns: - raise HookError("Invalid column name passed.", context={"pre_processing": "Capitalisation"}) - if self.colname not in input_data.select_dtypes(include=["object"]).columns: - raise HookError( - f"colname is of type {input_data[self.colname].dtype}, expected 'str'.", - context={"pre_processing": "Capitalisation"}, - ) + if isinstance(self.colname, str): + self.colname = [self.colname] + for col in self.colname: + if col not in input_data.columns: + raise HookError( + "Invalid column name passed.", context={"pre_processing": "Capitalisation", "colname": col} + ) + if col not in input_data.select_dtypes(include=["object"]).columns: + raise HookError( + f"colname is of type {input_data[col].dtype}, expected 'str'.", + context={"pre_processing": "Capitalisation", "colname": col}, + ) processed_input = input_data.copy() - processed_input[self.colname] = processed_input[self.colname].apply(self.method) + for col in self.colname: + processed_input[col] = processed_input[col].apply(self.method) # Ensure the processed input still conforms to the schema processed_input = input_data.__class__.validate(processed_input) return processed_input From 565a7070d0de776737384b76bd38151c5f4ace6a Mon Sep 17 00:00:00 2001 From: Luke Roantree <206716137+lukeroantreeONS@users.noreply.github.com> Date: Fri, 13 Mar 2026 16:39:52 +0000 Subject: [PATCH 5/8] fix(hooks): make deduplicated ranking start from 1 not 0 --- src/classifai/indexers/hooks/default_hooks/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/classifai/indexers/hooks/default_hooks/postprocessing.py b/src/classifai/indexers/hooks/default_hooks/postprocessing.py index c984a86..8aeeb51 100644 --- a/src/classifai/indexers/hooks/default_hooks/postprocessing.py +++ b/src/classifai/indexers/hooks/default_hooks/postprocessing.py @@ -53,7 +53,7 @@ def __call__(self, input_data: VectorStoreSearchOutput) -> VectorStoreSearchOutp for query in df_gpby["query_id"].unique(): batch = df_gpby[df_gpby["query_id"] == query] - new_rank = pd.factorize(batch["rank"], sort=True)[0] + new_rank = pd.factorize(batch["rank"], sort=True)[0] + 1 df_gpby.loc[batch.index, "rank"] = new_rank for col in set(input_data.columns).difference(set(df_gpby.columns)): From 6a03300de3bbd51678722965c71dec0d21e3dbb4 Mon Sep 17 00:00:00 2001 From: Luke Roantree <206716137+lukeroantreeONS@users.noreply.github.com> Date: Fri, 13 Mar 2026 17:05:16 +0000 Subject: [PATCH 6/8] chore(hooks): rerank in duplcation based on new score not new rank, to keep robust across aggregation methods --- src/classifai/indexers/hooks/default_hooks/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/classifai/indexers/hooks/default_hooks/postprocessing.py b/src/classifai/indexers/hooks/default_hooks/postprocessing.py index 8aeeb51..75d68d0 100644 --- a/src/classifai/indexers/hooks/default_hooks/postprocessing.py +++ b/src/classifai/indexers/hooks/default_hooks/postprocessing.py @@ -53,7 +53,7 @@ def __call__(self, input_data: VectorStoreSearchOutput) -> VectorStoreSearchOutp for query in df_gpby["query_id"].unique(): batch = df_gpby[df_gpby["query_id"] == query] - new_rank = pd.factorize(batch["rank"], sort=True)[0] + 1 + new_rank = pd.factorize(-batch["score"], sort=True)[0] + 1 df_gpby.loc[batch.index, "rank"] = new_rank for col in set(input_data.columns).difference(set(df_gpby.columns)): From 37f4c8de152630cf69d24798b56daf2e7838f289 Mon Sep 17 00:00:00 2001 From: Luke Roantree 
<206716137+lukeroantreeONS@users.noreply.github.com> Date: Tue, 17 Mar 2026 16:39:43 +0000 Subject: [PATCH 7/8] feat(hooks): let each hook type be a list of hooks or individual hook, update documentation and examples for hooks --- DEMO/Using_Hooks.ipynb | 590 +++++++++++++++ ...eprocessing_and_postprocessing_hooks.ipynb | 692 ------------------ DEMO/files/search_spellcheck_hook_clipped.svg | 175 +++++ .../vectorstore_search_dataflow_clipped.svg | 361 +++++++++ src/classifai/indexers/__init__.py | 2 + .../hooks/default_hooks/postprocessing.py | 18 +- src/classifai/indexers/main.py | 46 +- 7 files changed, 1174 insertions(+), 710 deletions(-) create mode 100644 DEMO/Using_Hooks.ipynb delete mode 100644 DEMO/custom_preprocessing_and_postprocessing_hooks.ipynb create mode 100644 DEMO/files/search_spellcheck_hook_clipped.svg create mode 100644 DEMO/files/vectorstore_search_dataflow_clipped.svg diff --git a/DEMO/Using_Hooks.ipynb b/DEMO/Using_Hooks.ipynb new file mode 100644 index 0000000..5c67de4 --- /dev/null +++ b/DEMO/Using_Hooks.ipynb @@ -0,0 +1,590 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# VectorStore pre- and post- processing with _Hooks_ 🪝\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "This notebook provides a guide on how to use pre- and post-processing 'hooks' to simplify additional tasks you may\n", + "have around input sanitisation / standardisation, result aggregation, auto-coding, linking in external datasets, etc.\n", + "\n", + "Hooks operate on the input and output dataclasses of each of our `VectorStore` methods, and the key requirement for them is that each will output an object of the same dataclass that it received. 
This allows hooks to be 'chained', allowing more complex pre-/post-processing tasks to be split into discrete steps, each defined by its own hook.\n", + "\n", + "Hooks provide a way to modify the data flow of the ClassifAI package, to facilitate tasks such as:\n", + "\n", + "- Removal of punctuation from input queries before the `VectorStore` search process begins,\n", + "- Standardising capitalisation of all text in an input query to the `Vectorstore` search process,\n", + "- Deduplication of search results based on the doc_id column,\n", + "- Prevention of users from retrieving certain documents in your `VectorStore`,\n", + "- Removing hate speech from any input text.\n", + "\n", + "There are three main ways of using hooks;\n", + "\n", + " - Using one of the `default_hooks` we provide as part of ClassifAI,\n", + " - Defining your own callable function,\n", + " - Extending the abstract `HookBase` class to create a custom, flexible hook with implicit validation.\n", + "\n", + "#### Key Sections:\n", + "- Recap: How the dataclasses and data flow works for interactions with the VectorStore,\n", + "- Using Hooks: Introduction to defining and using your own hooks for different tasks,\n", + "- Default Hooks: Introduction to using the more flexible hooks provided by ClassifAI for common tasks,\n", + "- Extending the `HookBase` class: Make your own custom flexible hook.\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Recap of VectorStore Dataclasses\n", + "\n", + "The majority of the following points are already covered in the recommended first notebook demo, [general_workflow_demo.ipynb](./general_workflow_demo.ipynb). \n", + "If you are unfamiliar with the package you may wish to work through this demo first, in order to better understand the `VectorStore` object and its' methods and dataclasses." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### ClassifAI uses Pandas dataframe-like dataclasses to specify what data need to be passed as input to the VectorStore methods/functions, and what data can be expected to be returned by those methods\n", + "\n", + "The **VectorStore** class, responsible for performing different actions with your data, has **three key methods/functions**:\n", + "\n", + "1. **`search()`** \n", + " - Takes in a body of text and searches the vector store for semantically similar knowledgebase samples.\n", + "\n", + "2. **`reverse_search()`** \n", + " - Takes in document IDs and searches the vector store for entries with those IDs.\n", + "\n", + "3. **`embed()`** \n", + " - Takes in a body of text and uses the vectoriser model to convert the text into embeddings.\n", + "\n", + "---\n", + "\n", + "For each of these three core methods, we have created an **input dataclass** and an **output dataclass**. These dataclasses define pandas-like objects that specify what data needs to be passed to each method and also perform runtime checks to ensure you've passed the correct columns in a dataframe to the appropriate VectorStore method.\n", + "\n", + "For example, the figure below illustrates the input and output dataclasses of the `VectorStore.search()` method:\n", + "\n", + "
\n", + "\n", + "![VectorStore Search Dataflow](./files/vectorstore_search_dataflow_clipped.svg)\n", + "\n", + "
\n", + "\n", + "This shows that the `VectorStore.search()` method expects:\n", + "- An **input dataclass object** with columns `[id, query]`. \n", + "- To output an **output dataclass object** with columns `[query_id, query_text, doc_id, doc_text, rank, score]`.\n", + "\n", + "The reverse_search() and embed() VectorStore functions have their own input and output data classes with their own validity column data checks. The names of each set are intuitively:\n", + "\n", + "| **VectorStore Method** | **Input Dataclass** | **Output Dataclass** |\n", + "|-------------------------------|-----------------------------|-----------------------------|\n", + "| `VectorStore.search()` | `VectorStoreSearchInput` | `VectorStoreSearchOutput` |\n", + "| `VectorStore.reverse_search()` | `VectorStoreReverseSearchInput` | `VectorStoreReverseSearchOutput` |\n", + "| `VectorStore.embed()` | `VectorStoreEmbedInput` | `VectorStoreEmbedOutput` |\n", + "\n", + "---\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hooks and custom dataflows\n", + "\n", + "'Hooks' allow users to manipulate the content of ClassifAI dataclass objects as they enter or leave the `VectorStore`.\n", + "\n", + "The requirements for a hook are:\n", + "\n", + " - To be callable (e.g. a function),\n", + " - To accept one input parameter, which is an instance of a dataclass,\n", + " - To output a valid instance of the same type.\n", + "\n", + "For example: you might want to preprocess the input to the `VectorStore.search()` method to remove punctuation from the texts:\n", + "\n", + "
\n", + "\n", + "![VectorStore Search Dataflow](./files/search_spellcheck_hook_clipped.svg)\n", + "\n", + "
\n", + "\n", + "Hooks can be attached to a `VectorStore` to run every time the `VectorStore` search method is called. \n", + "You can also apply other hooks to other dataclasses and their respective `VectorStore` methods and chain togtether these custom operations that manipulate the input and output dataclasses of the `VectorStore` methods.\n", + "\n", + "For example, implmenting 2 hooks for the input and output dataclasses of the `VectorStore` search method would provide a dataflow:\n", + "\n", + "![End to end Search with 2 hooks](./files/pre_and_post_search_hooks.png)\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using Hooks\n", + "\n", + "In this section, we'll define a custom data sanitisation function which removes punctuation from input user queries, and attach it to a `VectorStore` as a search pre-processing hook.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "### Demo Data \n", + "\n", + "This demo uses a mock dataset that is freely available on the ClassifAI repo, if yo have not downloaded the entire DEMO folder to run this notebook, the minimum data you require is the `DEMO/data/testdata.csv` file, which you should place in your working directory in a `DEMO` folder - (or you can just change the filepath later in this demo notebook)\n", + "\n", + "#### If you can run the following cell in this notebook, you should be good to go!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.vectorisers import HuggingFaceVectoriser\n", + "\n", + "print(\"done!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Alternatively, to test without running a notebook, run the following from your command line; \n", + "\n", + "```shell\n", + "python -c \"import classifai\"\n", + "```\n", + "\n", + "---\n", + "\n", + "We'll name our hook `remove_punctuation()`, and define it like this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import string\n", + "\n", + "from classifai.indexers.dataclasses import VectorStoreSearchInput\n", + "\n", + "\n", + "def remove_punctuation(input_data: VectorStoreSearchInput) -> VectorStoreSearchInput:\n", + " # we want to modify the 'texts' field in the input_data pydantic model, which is a list of texts\n", + " # this line removes punctuation from each string with list comprehension\n", + " sanitised_texts = [x.translate(str.maketrans(\"\", \"\", string.punctuation)) for x in input_data[\"query\"]]\n", + " input_data[\"query\"] = sanitised_texts\n", + " # Return the dictionary of input data with desired modified values at each desired key\n", + " return input_data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we attach it to a `VectorStore`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.indexers import VectorStore\n", + "\n", + "vectoriser = HuggingFaceVectoriser(model_name=\"sentence-transformers/all-MiniLM-L6-v2\")\n", + "\n", + "my_vector_store = VectorStore(\n", + " file_name=\"data/fake_soc_dataset.csv\",\n", + " data_type=\"csv\",\n", + " vectoriser=vectoriser,\n", + " overwrite=True,\n", + " hooks={\"search_preprocess\": remove_punctuation},\n", + ")" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "Now if we search some query against the `VectorStore` we created, punctuation should be stripped from the query before it is converted to a vector and searched against the knowledgebase:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.indexers.dataclasses import VectorStoreSearchInput\n", + "\n", + "input_data = VectorStoreSearchInput({\"id\": [1], \"query\": [\"a fruit and vegetable farmer!!!\"]})\n", + "\n", + "my_vector_store.search(input_data, n_results=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Default Hooks\n", + "\n", + "ClassifAI provides hooks out-of-the-box for some common tasks, which may speed up your development substantially.\n", + "\n", + "We'll briefly cover two of these here; \n", + "\n", + " - A `CapitalisationStandardisingHook` which can be operate on any `Input` `VectorStore` dataclass,\n", + " - A `DeduplicationHook` which operates on the `VectorStoreSearchOutput` dataclass." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.indexers.hooks.default_hooks import CapitalisationStandardisingHook, DeduplicationHook\n", + "\n", + "# This hook enforces standardisation of capitalisation for inputs, which helps improve matching for some vectorisers.\n", + "# You can specify the dataframe column (or list of columns) which this should apply to, and the method by which to\n", + "# standardise the text.\n", + "# The options for capitalisation standardisation are;\n", + "# - 'upper' LIKE THIS\n", + "# - 'lower' like this\n", + "# - 'sentence' Like this\n", + "# - 'title' Like This\n", + "cap_hook = CapitalisationStandardisingHook(method=\"upper\", colname=\"query\")\n", + "\n", + "# This hook replaces several VectorStore matches with the same label with the single 'best' one for a given query.\n", + "# You can control how this deduplication is reflected in the score via the score_aggregation_method parameter;\n", + "# - 'max' keep only the best individual match,\n", + "# - 'mean' keep the text and metadata of the best match, but recalculate the score as the mean of all retrieved matches for that label.\n", + "dedup_hook = DeduplicationHook(score_aggregation_method=\"max\")\n", + "\n", + "\n", + "# This is a quick additional hook, which we will use to show how several hooks can be chained in a series for the same task\n", + "def quick_second_hook(df):\n", + " df[\"test\"] = 1\n", + " return df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now if we define a new `VectorStore` with these hooks, we can see them in action (and compare against the output of the previous `VectorStore`). \n", + "Note also how you can use multiple hooks in a chain for the same task, by passing a list of hooks instead of a single hook." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_vector_store_with_hooks = VectorStore(\n", + " file_name=\"data/fake_soc_dataset.csv\",\n", + " data_type=\"csv\",\n", + " vectoriser=vectoriser,\n", + " overwrite=True,\n", + " hooks={\n", + " \"search_preprocess\": cap_hook,\n", + " \"search_postprocess\": [dedup_hook, quick_second_hook],\n", + " },\n", + ")\n", + "\n", + "# We define some input data to use for our comparison:\n", + "input_data = VectorStoreSearchInput({\"id\": [1], \"query\": [\"a fruit and vegetable farmer!!!\"]})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Output for our first `VectorStore` (which had the punctuation removal hook only):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_vector_store.search(input_data, n_results=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that in this case all five matches correspond to the same label.\n", + "\n", + "Output for our new `VectorStore`, with multiple hooks - including the ClassifAI 'default' hooks:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_vector_store_with_hooks.search(input_data, n_results=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Extending the HookBase class\n", + "\n", + "As well as the flexible hooks we provide for common tasks, we recognise that creating custom flexible callable hooks is not straightforward, due to the strict requirement that they can be called with a single input parameter which must be a `VectorStore` dataclass.\n", + "\n", + "To make this more accessible, we offer an abstract `HookBase` class which underpins the other default hooks offered by ClassifAI.\n", + "This base class acts as scaffolding to create a callable object which 
accepts additional configuration - as seen in the other default hooks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.indexers.hooks import HookBase" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "An extension of `HookBase` will have two methods;\n", + "\n", + " - `__init__()`, for setting attributes to be referred to later, and any other initial setup. It can take multiple arguments.\n", + " - `__call__()`, to perform the 'hook' action. It should accept only one argument, which is a `VectorStore` dataclass instance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class MyCustomHook(HookBase):\n", + " def __init__(self, new_col_name: str, new_col_value):\n", + " self.new_col_name = new_col_name\n", + " self.new_col_value = new_col_value\n", + "\n", + " def __call__(self, data):\n", + " data[self.new_col_name] = self.new_col_value\n", + " return data\n", + "\n", + "\n", + "my_custom_hook = MyCustomHook(new_col_name=\"test_col\", new_col_value=12345)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's test it out:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_vector_store_with_custom_hook = VectorStore(\n", + " file_name=\"data/fake_soc_dataset.csv\",\n", + " data_type=\"csv\",\n", + " vectoriser=vectoriser,\n", + " overwrite=True,\n", + " hooks={\n", + " \"search_postprocess\": my_custom_hook,\n", + " },\n", + ")\n", + "\n", + "my_vector_store_with_custom_hook.search(input_data, n_results=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Adding Hooks to a VectorStore when loading from filespace\n", + "\n", + "ClassifAI allows you to create your VectorStore once, and then persist it on local storage so that it can be reloaded for reuse 
later - _without_ having to embed your knowledgebase again.\n", + "\n", + "If you've followed through with the above code cells you may have noticed that every time we've instantiated a VectorStore it has saved a new folder to filespace (overwriting each time).\n", + "\n", + "We can use the `VectorStore.from_filespace()` class method to load the `VectorStore` back into memory.\n", + "\n", + "*Important:* any hooks you applied in previous sessions are _not_ saved to the filespace.\n", + "Instead, you must pass them in again via the `hooks` parameter, as in the original construction.\n", + "There is no strict requirement to use the same hooks upon reloading, so if you want to update, add or remove hooks you don't need to re-create your `VectorStore` from scratch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# you can see we've reused the vectoriser and hooks from before\n", + "reloaded_vector_store = VectorStore.from_filespace(\n", + " folder_path=\"./fake_soc_dataset/\", # YOU MAY NEED TO CHANGE THIS LINE TO THE CORRECT PATH\n", + " vectoriser=vectoriser,\n", + " hooks={\n", + " \"search_preprocess\": remove_punctuation,\n", + " \"search_postprocess\": quick_second_hook,\n", + " },\n", + ")\n", + "reloaded_vector_store.search(input_data, n_results=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Injecting Data into our classification results with a hook\n", + "\n", + "You may have some additional contextual information that you wanted to add in your pipeline.\n", + "For example; some taxonomical definitions for the labels.\n", + "\n", + "We may want to inject this extra information only after the search retrieval rather than store it as metadata in the knowledgebase; this can be achieved with a hook.\n", + "\n", + "Here is some example taxonomy information for our test dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": {}, + "outputs": [], + "source": [ + "official_id_definitions = {\n", + " \"101\": \"Fruit farmer: Grows and harvests fruits such as apples, oranges, and berries.\",\n", + " \"102\": \"dairy farmer: Manages cows for milk production and processes dairy products.\",\n", + " \"103\": \"construction laborer: Performs physical tasks on construction sites, such as digging and carrying materials.\",\n", + " \"104\": \"carpenter: Constructs, installs, and repairs wooden frameworks and structures.\",\n", + " \"105\": \"electrician: Installs, maintains, and repairs electrical systems in buildings and equipment.\",\n", + " \"106\": \"plumber: Installs and repairs water, gas, and drainage systems in homes and businesses.\",\n", + " \"107\": \"software developer: Designs, writes, and tests computer programs and applications.\",\n", + " \"108\": \"data analyst: Analyzes data to provide insights and support decision-making.\",\n", + " \"109\": \"accountant: Prepares and examines financial records, ensuring accuracy and compliance with regulations.\",\n", + " \"110\": \"teacher: Educates students in schools, colleges, or universities.\",\n", + " \"111\": \"nurse: Provides medical care and support to patients in hospitals, clinics, or homes.\",\n", + " \"112\": \"chef: Prepares and cooks meals in restaurants, hotels, or other food establishments.\",\n", + " \"113\": \"graphic designer: Creates visual concepts for advertisements, websites, and branding.\",\n", + " \"114\": \"mechanic: Repairs and maintains vehicles and machinery.\",\n", + " \"115\": \"photographer: Captures images for events, advertising, or artistic purposes.\",\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.indexers.dataclasses import VectorStoreSearchOutput\n", + "\n", + "\n", + "def add_id_definitions(input_data: VectorStoreSearchOutput) -> VectorStoreSearchOutput:\n", + " # Map the 'doc_id' column to the 
corresponding definitions from the dictionary\n", + " input_data.loc[:, \"id_definition\"] = input_data[\"doc_id\"].map(official_id_definitions)\n", + "\n", + " return input_data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now pair this with our deduplicating hook in a chain:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### and let's try the search again!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "multi_input_data = VectorStoreSearchInput(\n", + " {\n", + " \"id\": [1, 2],\n", + " \"query\": [\"a fruit and vegetable farmer!!!\", \"Digital marketing@\"],\n", + " }\n", + ")\n", + "\n", + "reloaded_vector_store = VectorStore.from_filespace(\n", + " folder_path=\"./fake_soc_dataset/\", # YOU MAY NEED TO CHANGE THIS LINE TO THE CORRECT PATH\n", + " vectoriser=vectoriser,\n", + " hooks={\n", + " \"search_postprocess\": [dedup_hook, add_id_definitions],\n", + " },\n", + ")\n", + "reloaded_vector_store.search(multi_input_data, n_results=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see a few null values in that last output because our demo list of extra data wasn't exhaustive, but where an ID does match our 'official_id_definitions' data we see the data being added correctly."
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/DEMO/custom_preprocessing_and_postprocessing_hooks.ipynb b/DEMO/custom_preprocessing_and_postprocessing_hooks.ipynb deleted file mode 100644 index 5e8ae2e..0000000 --- a/DEMO/custom_preprocessing_and_postprocessing_hooks.ipynb +++ /dev/null @@ -1,692 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# VectorStore pre- and post- processing logic with Hooks🪝" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Overview\n", - "\n", - "This notebook provides a guide on how to implement custom, user-defined pre- and post-processing 'hooks'. 
Hooks provide a way to modify the traditional data flow of the ClassifAI package so that you might, for example:\n", - "\n", - "- Remove punctuation from input queries before the VectorStore search process begins,\n", - "- Capitalising all text in an input query to the Vectorstore search process,\n", - "- Deduplicate results based on the doc_id column so that duplicate knowledgebase entries are not returned,\n", - "- Prevent users of the package from retrieving certain documents in your vectorstore,\n", - "- Removing hate speech from any input text.\n", - "\n", - "\n", - "Hooks work by defining functions that operate on the input and output dataclasses of each of our VectorStore functions/methods.\n", - "\n", - "Key Sections:\n", - "- a recap of how the dataclasses for the VectorStore work, and how they ensure the proper flow of data in our package,\n", - "- how hooks can be implemented by working with the dataclass objects,\n", - "- examples of several different hook implementations, some of which were already mentioned above." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Recap of VectorStore Dataclasses\n", - "\n", - "The majority of the following points are already covered in the recommended first notebook demo, [general_workflow_demo.ipynb](./general_workflow_demo.ipynb). So if you are unfamiliar with the package, that is a good place to start before this notebook, and for an intro to the VectorStore, its methods, and how it works with dataclasses." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### ClassifAI uses Pandas dataframe-like dataclasses to specify what data need to be passed as input to the VectorStore methods/functions, and what data can be expected to be returned by those methods\n", - "\n", - "The **VectorStore** class, responsible for performing different actions with your data, has **three key methods/functions**:\n", - "\n", - "1. 
**`search()`** \n", - " - Takes in a body of text and searches the vector store for semantically similar knowledgebase samples.\n", - "\n", - "2. **`reverse_search()`** \n", - " - Takes in document IDs and searches the vector store for entries with those IDs.\n", - "\n", - "3. **`embed()`** \n", - " - Takes in a body of text and uses the vectoriser model to convert the text into embeddings.\n", - "\n", - "---\n", - "\n", - "For each of these three core methods, we have created an **input dataclass** and an **output dataclass**. These dataclasses define pandas-like objects that specify what data needs to be passed to each method and also perform runtime checks to ensure you've passed the correct columns in a dataframe to the appropriate VectorStore method.\n", - "\n", - "For example, the figure below illustrates the input and output dataclasses of the `VectorStore.search()` method:\n", - "\n", - "![VectorStore Search Dataflow](./files/vectorstore_search_dataflow.svg)\n", - "\n", - "This shows that the `VectorStore.search()` method expects:\n", - "- An **input dataclass object** with columns `[id, query]`. \n", - "- To output an **output dataclass object** with columns `[query_id, query_text, doc_id, doc_text, rank, score]`.\n", - "\n", - "The use of these dataclasses both helps the user of the package to understand what data needs to be provided to the Vectorstore and how a user should interact with the objects being returned by these VectorStore functions. Additionally, this ensures robustness of the package by checking that the correct columns are present in the data before operating on it. \n", - "\n", - "The reverse_search() and embed() VectorStore functions have their own input and output data classes with their own validity column data checks. 
The names of each set are intuitively:\n", - "\n", - "| **VectorStore Method** | **Input Dataclass** | **Output Dataclass** |\n", - "|-------------------------------|-----------------------------|-----------------------------|\n", - "| `VectorStore.search()` | `VectorStoreSearchInput` | `VectorStoreSearchOutput` |\n", - "| `VectorStore.reverse_search()` | `VectorStoreReverseSearchInput` | `VectorStoreReverseSearchOutput` |\n", - "| `VectorStore.embed()` | `VectorStoreEmbedInput` | `VectorStoreEmbedOutput` |\n", - "\n", - "Users of the package can use the schema of each of these input and output dataclasses to understand how to interface with these main methods of the VectorStore class.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Hooks and custom dataflows\n", - "\n", - "We have implemented 'hooks' where users can write a function that will manipulate the content of a dataclasses object before or after it passes through the VectorStore. \n", - "\n", - "As long as your custom hook function takes as input an instance of a dataclass, and outputs a valid instance of the same type, then your custom function should run as a part of the end to end VectorStore process.\n", - "\n", - "For example: you might want to preprocess the input to the VectorStore.search() method to remove punctuation from the texts:\n", - "\n", - "![VectorStore Search Dataflow](./files/search_spellcheck_hook.svg)\n", - "\n", - "\n", - "\n", - "In a later part of the demo, we showcase how to implement this punctuation removing function, and apply it to the vectorstore. The important concept here is that the hook function takes in a `VectorStoreSearchInput` object, and outputs a valid `VectorStoreSearchInput` object. \n", - "\n", - "\n", - "This can then be attached to a VectorStore to run every time the VectorStore search method is called. 
You can also apply other hooks to other dataclasses and their respective VectorStore methods and chain togtether these custom operations that manipulate the input and output dataclasses of the VectorStore methods. \n", - "\n", - "For example, implmenting 2 hooks for the input and output dataclasses of the VectorStore search method would provide a dataflow:\n", - "\n", - "![End to end Search with 2 hooks](./files/pre_and_post_search_hooks.png)\n", - "\n", - "\n", - "The above diagram shows a case where two hooks would be implemented: One that operates on the dataclass `VectorStoreSearchInput`that is passed to the Vectortore search method; and a second hook operating on the `VectorStoreSearchOutput` dataclass that is returned from the VectorStore search method.\n", - "\n", - "\n", - "Hooks can perform pretty much any operation, as long as they accept and return a valid dataclass object - we hope that this provides a lot of freedom to users to be able to transform and manipulate data as needed using ClassifAI." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example Hook implementations\n", - "\n", - "This section now shows how to define your hook functions, and inject them into the VectorStore so that the hooks run when the corresponding method is called.\n", - "\n", - "Specifically we'll look at:\n", - "- a pre-processing function that removes punctuation from input user queries,\n", - "- a post-processing function removes results rows that have duplicate ids to other rows of the results.\n", - "\n", - "- We will then make a final post-processing function that injects additional SOC definition data to the VectorStore results dataframe and show how this can be chained together with the deduplication code, to make a multi-step post-processing function!" 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Pre-requisite\n", - "\n", - "If you are new to the package, its recommended to follow through the ```general_workflow.ipynb``` notebook tutorial first. That interactive DEMO will showcase the core features of the ```ClassifAI package```. This current notebook provides examples of how to modify the flow of data which is initially described in the general_workflow.ipynb notebook.\n", - "\n", - "Check out the ClassifAI repository DEMO folder for all our notebook walkthrough tutorials including those mentioned above:\n", - "\n", - "https://github.com/datasciencecampus/classifai/tree/main/DEMO \n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "\n", - "#### If you can run the following cell in this notebook, you should be good to go!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from classifai.vectorisers import HuggingFaceVectoriser\n", - "\n", - "print(\"done!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Alternatively, to test without running a notebook, run the following from your command line; \n", - "\n", - "```shell\n", - "python -c \"import classifai\"\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Demo Data \n", - "\n", - "\n", - "This demo uses a mock dataset that is freely available on the ClassifAI repo, if yo have not downloaded the entire DEMO folder to run this notebook, the minimum data you require is the `DEMO/data/testdata.csv` file, which you should place in your working directory in a `DEMO` folder - (or you can just change the filepath later in this demo notebook)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Normal vectorstore setup\n", - "\n", - "We can start by loading a normal vectorstore up with no additional preprocessing/hooks. 
We can use one of our fake example known datasets is known to have several rows of data with the same ID value. (You can get this from the github repo at the folder location specified in the code)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from classifai.indexers import VectorStore\n", - "\n", - "vectoriser = HuggingFaceVectoriser(model_name=\"sentence-transformers/all-MiniLM-L6-v2\")\n", - "\n", - "\n", - "my_vector_store = VectorStore(\n", - " file_name=\"data/fake_soc_dataset.csv\",\n", - " data_type=\"csv\",\n", - " vectoriser=vectoriser,\n", - " overwrite=True,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The below code uses our dataclasses to set up some data to pass to the VectorStore search method, notice that:\n", - " * an exclaimation mark in the query (that in some cases we may want to sanitise) is shown in the results. \n", - " * Also the results for the below query should also show several rows with the same ```'doc_id'``` value (because our example data file had multiple entries with the same id label)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from classifai.indexers.dataclasses import VectorStoreSearchInput\n", - "\n", - "input_data = VectorStoreSearchInput({\"id\": [1], \"query\": [\"a fruit and vegetable farmer!!!\"]})\n", - "\n", - "my_vector_store.search(input_data, n_results=10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Making pre- and post- processing hooks \n", - "\n", - "So lets write some functions that will remove punctuation on the user's input query, before the main logic of the Vectorstore.search() method begins, and remove rows with duplicate IDs from the results dataframe just before the results are retutned from the Vectorstore.search() method" - ] - }, - { - "cell_type": "code", - "execution_count": null, - 
"metadata": {}, - "outputs": [], - "source": [ - "input_data = VectorStoreSearchInput({\"id\": [1], \"query\": [\"a fruit and vegetable farmer!!!\"]})\n", - "\n", - "input_data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import string\n", - "\n", - "from classifai.indexers.dataclasses import VectorStoreSearchOutput\n", - "\n", - "\n", - "def remove_punctuation(input_data: VectorStoreSearchInput) -> VectorStoreSearchInput:\n", - " # we want to modify the 'texts' field in the input_data pydantic model, which is a list of texts\n", - " # this line removes punctuation from each string with list comprehension\n", - " sanitized_texts = [x.translate(str.maketrans(\"\", \"\", string.punctuation)) for x in input_data[\"query\"]]\n", - "\n", - " input_data[\"query\"] = sanitized_texts\n", - "\n", - " # Return the dictionary of input data with desired modified values at each desired key\n", - " return input_data\n", - "\n", - "\n", - "def drop_duplicates(input_data: VectorStoreSearchOutput) -> VectorStoreSearchOutput:\n", - " # we want to depuplicate the ranking attribute of the pydantic model which is a pandas dataframe\n", - " # specifically we want to drop all but the first occurrence of each unique 'doc_id' value for each subset of query results\n", - " input_data = input_data.drop_duplicates(subset=[\"query_id\", \"doc_id\"], keep=\"first\")\n", - "\n", - " # BE CAREFUL: drop_duplicates returns an object of type DataFrame, not VectorStoreSearchOutput so we need to convert back to that type after this operation\n", - " input_data = VectorStoreSearchOutput(input_data)\n", - "\n", - " return input_data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Adding our Hooks to the VectorStore\n", - "\n", - "Now when we initialise the Vectorstore we can declare our custom functions in the hooks dictionary.\n", - "\n", - "The Vectorstore codebase looks for specifically named 
dictionary entries in the Hooks dictionary, to decide what pre and post processing hooks to run. There are hooks for each major methods of VectorStore class.\n", - "\n", - "Each dictionary entry uses the method name of the class and '_preprocessor' or '_postprocessor' appended to the name. Currenlty the implemented method hooks are:\n", - "\n", - "- for the VectorStore class:\n", - " * search_preprocess\n", - " * search_postprocess\n", - " * reverse_search_preprocess\n", - " * reverse_search_postprocess\n", - "\n", - "\n", - "For our case in this excercise, we are implementig the search_preprocessor and search_postprocessor methods in the VectorStore.\n", - "\n", - "\n", - "However if we could also add to add a preprocessing or postprocessing hook to a VectorStore reverse search method in a similar manner" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "my_vector_store_with_hooks = VectorStore(\n", - " file_name=\"data/fake_soc_dataset.csv\",\n", - " data_type=\"csv\",\n", - " vectoriser=vectoriser,\n", - " overwrite=True,\n", - " hooks={\n", - " \"search_preprocess\": remove_punctuation,\n", - " \"search_postprocess\": drop_duplicates,\n", - " },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Our hooks will run with the VectorStore search method\n", - "\n", - "\n", - "Now we've passed our desired additional functions to our VectorStore initialisation and those hook should run accordingly - lets see:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "input_data = VectorStoreSearchInput({\"id\": [1], \"query\": [\"a fruit and vegetable farmer!!!\"]})\n", - "\n", - "my_vector_store_with_hooks.search(input_data, n_results=10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Oops!\n", - 
"\n", - "Notice how in the above dataframe, the rank column now leaps over some values in each ranking. \n", - "\n", - "We didn't reset the ranking values, per query, when we removed duplicate rows...\n", - "\n", - "lets redo that now in a new function and hook it up to our preprocessing hook.\n", - "\n", - "\n", - "##### Notice how this time, we changed the name of our paramter in our custom hook functions, thats because it doesn't matter what the name of the parameter is, we just need to understand that it will take in one argument - the pydantic object associated with the method." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def drop_duplicates_and_reset_rank(input_object: VectorStoreSearchOutput) -> VectorStoreSearchOutput:\n", - " # Remove duplicates based on 'query_id' and 'doc_id'\n", - " input_object = input_object.drop_duplicates(subset=[\"query_id\", \"doc_id\"], keep=\"first\")\n", - "\n", - " # Reset the rank column per query_id using .loc to avoid SettingWithCopyWarning\n", - " input_object.loc[:, \"rank\"] = input_object.groupby(\"query_id\").cumcount()\n", - "\n", - " # convert the DataFrame back to the pydantic validated object\n", - " input_object = VectorStoreSearchOutput(input_object)\n", - "\n", - " return input_object" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "From the cell below, you can see another way to set hooks - by directly accessing the hooks attribute of a running vectorstore:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# and lets access the hooks directly from the vector store instance to modify them\n", - "my_vector_store_with_hooks.hooks[\"search_postprocess\"] = drop_duplicates_and_reset_rank" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### done - now lets run that query again" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "my_vector_store_with_hooks.search(input_data, n_results=10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### This of course still works well when you pass multiple queries as we wrote it to separate on query_id column:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "multi_input_data = VectorStoreSearchInput(\n", - " {\n", - " \"id\": [1, 2],\n", - " \"query\": [\"a fruit and vegetable farmer!!!\", \"Digital marketing@\"],\n", - " }\n", - ")\n", - "\n", - "my_vector_store_with_hooks.search(multi_input_data, n_results=10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Adding Hooks to a VectorStore when loading from filespace\n", - "\n", - "ClassifAI allows you to create your VectorStore once, and then save it to file space so that it can be loaded back in later and reused - without having to create all the vectors again.\n", - "\n", - "If you've followed through with the above code cells you may have noticed that every time we've instantiated a VectorStore it has saved a new folder to filespace (overwriting each time).\n", - "\n", - "Use the VectorStore.from_filespace() class method to load the VectorStore back into memory.\n", - "\n", - "Important: any hooks you applied in previous sessions are not saved to the filespace (it can be difficult to serialise functions). The from_filespace() class method has a hook parameter, similar to the VectorStore constructor we saw earlier. When loading from filespace in this way, you must reaplly the hook functions using this parameter or by setting the attribute after loading, as seen above.\n", - "\n", - "\n", - "The following code cells show an example of loading the VectorStore, that was saved to filespace in this demo, back into memory and reapply the hooks on instantiation." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# you can see we've reused the vectoriser and hooks from before\n", - "\n", - "\n", - "reloaded_vector_store = VectorStore.from_filespace(\n", - " folder_path=\"./fake_soc_dataset/\", # YOU MAY NEED TO CHANGE THIS LINE TO THE CORRECT PATH\n", - " vectoriser=vectoriser,\n", - " hooks={\n", - " \"search_preprocess\": remove_punctuation,\n", - " \"search_postprocess\": drop_duplicates,\n", - " },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can then continue to use the vectorstore as seem earlier" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "reloaded_vector_store.search(input_data, n_results=10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Injecting Data into our classification results with a hook\n", - "\n", - "What if we had some additional context information that we wanted to add in our pipeline. It could be some official taxonomy definitions about our doc_id labels, such as SIC or SOC code definitions.\n", - "\n", - "We may want to inject this extra information that's not directly stored as metadata in the knowledgebase, so that a downstream component (such as a RAG agent) can use the additional information" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### But we also want keep our existing hook logic that removes punctuation..." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "official_id_definitions = {\n", - " \"101\": \"Fruit farmer: Grows and harvests fruits such as apples, oranges, and berries.\",\n", - " \"102\": \"dairy farmer: Manages cows for milk production and processes dairy products.\",\n", - " \"103\": \"construction laborer: Performs physical tasks on construction sites, such as digging and carrying materials.\",\n", - " \"104\": \"carpenter: Constructs, installs, and repairs wooden frameworks and structures.\",\n", - " \"105\": \"electrician: Installs, maintains, and repairs electrical systems in buildings and equipment.\",\n", - " \"106\": \"plumber: Installs and repairs water, gas, and drainage systems in homes and businesses.\",\n", - " \"107\": \"software developer: Designs, writes, and tests computer programs and applications.\",\n", - " \"108\": \"data analyst: Analyzes data to provide insights and support decision-making.\",\n", - " \"109\": \"accountant: Prepares and examines financial records, ensuring accuracy and compliance with regulations.\",\n", - " \"110\": \"teacher: Educates students in schools, colleges, or universities.\",\n", - " \"111\": \"nurse: Provides medical care and support to patients in hospitals, clinics, or homes.\",\n", - " \"112\": \"chef: Prepares and cooks meals in restaurants, hotels, or other food establishments.\",\n", - " \"113\": \"graphic designer: Creates visual concepts for advertisements, websites, and branding.\",\n", - " \"114\": \"mechanic: Repairs and maintains vehicles and machinery.\",\n", - " \"115\": \"photographer: Captures images for events, advertising, or artistic purposes.\",\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def add_id_definitions(input_data: VectorStoreSearchOutput) -> VectorStoreSearchOutput:\n", - " # Map the 'doc_id' column to the corresponding definitions 
from the dictionary\n", - " input_data.loc[:, \"id_definition\"] = input_data[\"doc_id\"].map(official_id_definitions)\n", - "\n", - " return input_data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### We can now combine this with our deduplicating hook in a new function that runs both" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def process_results(validated_input_object: VectorStoreSearchOutput) -> VectorStoreSearchOutput:\n", - " # First, remove duplicates and reset rank\n", - " validated_input_object = drop_duplicates_and_reset_rank(validated_input_object)\n", - "\n", - " # Then, add ID definitions\n", - " validated_input_object = add_id_definitions(validated_input_object)\n", - "\n", - " # Return the final processed dataframe\n", - " return validated_input_object" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### lets once again update the postprocessing hook on our vectorstore" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "my_vector_store_with_hooks.hooks[\"search_postprocess\"] = process_results" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### and lets try the search again!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "multi_input_data = VectorStoreSearchInput(\n", - " {\n", - " \"id\": [1, 2],\n", - " \"query\": [\"a fruit and vegetable farmer!!!\", \"Digital marketing@\"],\n", - " }\n", - ")\n", - "\n", - "my_vector_store_with_hooks.search(multi_input_data, n_results=10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can see a few null values in that last output because our demo list of extra data wasn't exhaustive, but where an ID does match our 'official_id_definitions' data we see the data being added correctly." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Roundup\n", - "\n", - "- We wrote and combined several hooks on the Vectorstore class to:\n", - " - remove punctuation from queries before the ```VectorStore.search()``` method is executed\n", - " - remove duplicates from the results list per query ranking and fixed the ranking\n", - " - injected data into our dataflow outside of constructing a vectorstore\n", - " - chained several Vectorstore.search() postprocessing steps together into one function that calls other functions\n", - "\n", - "- In this scenario we effectively showed how to deduplicate the rows of the results dataframe and add additional context columns of information in the form of the id_definitions. Hopefully, it is clear that you can add many pre- or post-processing steps this way, or by writing all steps in one big function - Hooks give you the flexibility and choice here.\n", - "\n", - "- Hooks let you disrupt the normal flow of data in the VectorStores. In this case we just had a small amount of dictionary data being added in, however the hooks allow for more complex scenarios:\n", - " - using a 3rd party API to do automated corrective spell checking before passing your queries to the search method\n", - " - making an SQL query call to a database to get the extra information you want to inject in each row\n", - " - handle errors when the API or database fails and choose what should be returned in these cases\n", - "\n", - "\n", - "### Key Takeaway:\n", - "- When writing your custom hook, remember that your custom hook function should take a single argument - a specific dataclass, and it should output that same dataclass with the modified rows, columns and values. 
How you implement the logic to update the values is up to you but it must satify the requirements of that dataclass type.\n", - "\n", - "- Depending on which kind of hook you are writing, you need to adhere to the rules of the corresponding dataclass for that hook. For example, in the above demonstration we focused on writing search() preprocessing hooks that manipulate the VectorStoreSearchInput dataclass. However, if you were to write a reverse_search() preprocessing hook, your hook function would need to manipulate the VectorStoreReverseSearchInput dataclass, which has a different set of rules for the columns that must be present and the datatypes of those columns. This extends to each of the hook categories, each of which corresponds to a specific dataclass with its own ruleset.\n", - "\n", - "\n", - "### Next Steps and Challenges:\n", - "\n", - "#### We focused soley on showcasing pre- and post-processing hooks for the VectorStore search method in this notebook:\n", - "\n", - "- See if you can implement some pre- and post- processing hooks for the VectorStore reverse search method:\n", - " - try adding a new column of data to the reverse search results \n", - " - make it so that if the user tries to reverse search for a specific ID that is 'secret' then that row is removed from the input data." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.10" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/DEMO/files/search_spellcheck_hook_clipped.svg b/DEMO/files/search_spellcheck_hook_clipped.svg new file mode 100644 index 0000000..fd217ec --- /dev/null +++ b/DEMO/files/search_spellcheck_hook_clipped.svg @@ -0,0 +1,175 @@ + + + + + + + + + + + + + + + + + + + + + + removePunctuation + () + + + takes in + gives out + VectorStoreSearchInput + VectorStoreSearchInput + + + + + + + User defined ‘hook’ function + + + + + + + + + diff --git a/DEMO/files/vectorstore_search_dataflow_clipped.svg b/DEMO/files/vectorstore_search_dataflow_clipped.svg new file mode 100644 index 0000000..a1f03a6 --- /dev/null +++ b/DEMO/files/vectorstore_search_dataflow_clipped.svg @@ -0,0 +1,361 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + VectorStore + . 
search() + + + takes + i + n + g + ives out + + + + + + + + + + + + + VectorStoreSearchInput + VectorStoreSearchOutput + + + diff --git a/src/classifai/indexers/__init__.py b/src/classifai/indexers/__init__.py index cf9865d..0d43b7c 100644 --- a/src/classifai/indexers/__init__.py +++ b/src/classifai/indexers/__init__.py @@ -39,12 +39,14 @@ ) from .hooks import ( CapitalisationStandardisingHook, + DeduplicationHook, HookBase, ) from .main import VectorStore __all__ = [ "CapitalisationStandardisingHook", + "DeduplicationHook", "HookBase", "VectorStore", "VectorStoreEmbedInput", diff --git a/src/classifai/indexers/hooks/default_hooks/postprocessing.py b/src/classifai/indexers/hooks/default_hooks/postprocessing.py index 75d68d0..793f4be 100644 --- a/src/classifai/indexers/hooks/default_hooks/postprocessing.py +++ b/src/classifai/indexers/hooks/default_hooks/postprocessing.py @@ -41,6 +41,11 @@ def __init__(self, score_aggregation_method: str = "max"): def __call__(self, input_data: VectorStoreSearchOutput) -> VectorStoreSearchOutput: """Aggregates retrieved knowledgebase entries corresponding to the same label.""" + # 1) Group on two levels - first on query_id, then on doc_id, to ensure that entries with the same label are + # deduplicated within the results for each query. Note that there is a 1-1 mapping between query_id and query_text, + # so no extra grouping is made, but this excludes query_text from the columns to be processed. + # 2) For each group, aggregate the score using the specified method, and assign a new column 'idxmax' to the unique id + # of the entry with the best score. This will allow us to retain the metadata of the best scoring entry. 
df_gpby = ( input_data.groupby(["query_id", "query_text", "doc_id"]) .aggregate( @@ -50,14 +55,19 @@ def __call__(self, input_data: VectorStoreSearchOutput) -> VectorStoreSearchOutp ) .reset_index() ) - + # For each query, re-assign ranks based on the new aggregated scores, to the remaining entries, to ensure that the best + # scoring entry for each label is ranked highest. for query in df_gpby["query_id"].unique(): batch = df_gpby[df_gpby["query_id"] == query] new_rank = pd.factorize(-batch["score"], sort=True)[0] + 1 df_gpby.loc[batch.index, "rank"] = new_rank - + # Finally, we re-merge the deduplicated results with the original input dataframe, + # to retrieve the metadata of the best scoring entry for each label, and return the processed output. for col in set(input_data.columns).difference(set(df_gpby.columns)): df_gpby[col] = df_gpby["idxmax"].map(input_data[col]) - - processed_output = input_data.__class__.validate(df_gpby[input_data.columns]) + # We sort the output by query_id and doc_id to ensure a consistent order of results for each query, + # and validate the output against the dataclass schema. + processed_output = input_data.__class__.validate( + df_gpby[input_data.columns].sort_values(by=["query_id", "doc_id"], ascending=[True, True]) + ) return processed_output diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index 14d64ee..497dd0a 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -356,7 +356,7 @@ def _create_vector_store_index(self): # noqa: C901 }, ) from e - def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput: + def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput: # noqa: C901 """Converts text (provided via a `VectorStoreEmbedInput` object) into vector embeddings using the `Vectoriser` and returns a `VectorStoreEmbedOutput` dataframe with columns `id`, `text`, and `embedding`. 
@@ -382,8 +382,11 @@ def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput: # ---- Preprocess hook -> HookError if "embed_preprocess" in self.hooks: try: - modified_query = self.hooks["embed_preprocess"](query) - query = VectorStoreEmbedInput.validate(modified_query) + if not isinstance(self.hooks["embed_preprocess"], list): + self.hooks["embed_preprocess"] = [self.hooks["embed_preprocess"]] + for hook in self.hooks["embed_preprocess"]: + query = hook(query) + query = VectorStoreEmbedInput.validate(query) except Exception as e: raise HookError( "embed_preprocess hook raised an exception.", @@ -421,8 +424,11 @@ def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput: # ---- Postprocess hook -> HookError if "embed_postprocess" in self.hooks: try: - modified_results_df = self.hooks["embed_postprocess"](results_df) - results_df = VectorStoreEmbedOutput.validate(modified_results_df) + if not isinstance(self.hooks["embed_postprocess"], list): + self.hooks["embed_postprocess"] = [self.hooks["embed_postprocess"]] + for hook in self.hooks["embed_postprocess"]: + results_df = hook(results_df) + results_df = VectorStoreEmbedOutput.validate(results_df) except Exception as e: raise HookError( "embed_postprocess hook raised an exception.", @@ -431,7 +437,7 @@ def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput: return results_df - def reverse_search( # noqa: C901 + def reverse_search( # noqa: C901, PLR0912 self, query: VectorStoreReverseSearchInput, max_n_results: int = 100, partial_match: bool = False ) -> VectorStoreReverseSearchOutput: """Reverse searches the `VectorStore` using a `VectorStoreReverseSearchInput` object @@ -473,8 +479,11 @@ def reverse_search( # noqa: C901 # ---- Preprocess hook -> HookError if "reverse_search_preprocess" in self.hooks: try: - modified_query = self.hooks["reverse_search_preprocess"](query) - query = VectorStoreReverseSearchInput.validate(modified_query) + if not 
isinstance(self.hooks["reverse_search_preprocess"], list): + self.hooks["reverse_search_preprocess"] = [self.hooks["reverse_search_preprocess"]] + for hook in self.hooks["reverse_search_preprocess"]: + query = hook(query) + query = VectorStoreReverseSearchInput.validate(query) except Exception as e: raise HookError( "reverse_search_preprocess hook raised an exception.", @@ -531,8 +540,11 @@ def reverse_search( # noqa: C901 # ---- Postprocess hook -> HookError if "reverse_search_postprocess" in self.hooks: try: - modified_result_df = self.hooks["reverse_search_postprocess"](result_df) - result_df = VectorStoreReverseSearchOutput.validate(modified_result_df) + if not isinstance(self.hooks["reverse_search_postprocess"], list): + self.hooks["reverse_search_postprocess"] = [self.hooks["reverse_search_postprocess"]] + for hook in self.hooks["reverse_search_postprocess"]: + result_df = hook(result_df) + result_df = VectorStoreReverseSearchOutput.validate(result_df) except Exception as e: raise HookError( "reverse_search_postprocess hook raised an exception.", @@ -587,8 +599,11 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V # ---- Preprocess hook -> DataValidationError if it returns invalid shape/type if "search_preprocess" in self.hooks: try: - modified_query = self.hooks["search_preprocess"](query) - query = VectorStoreSearchInput.validate(modified_query) + if not isinstance(self.hooks["search_preprocess"], list): + self.hooks["search_preprocess"] = [self.hooks["search_preprocess"]] + for hook in self.hooks["search_preprocess"]: + query = hook(query) + query = VectorStoreSearchInput.validate(query) except Exception as e: raise HookError( "search_preprocess hook raised an exception.", @@ -703,8 +718,11 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V # ---- Postprocess hook -> DataValidationError if it returns invalid shape/type if "search_postprocess" in self.hooks: try: - modified_result_df = 
self.hooks["search_postprocess"](result_df) - result_df = VectorStoreSearchOutput.validate(modified_result_df) + if not isinstance(self.hooks["search_postprocess"], list): + self.hooks["search_postprocess"] = [self.hooks["search_postprocess"]] + for hook in self.hooks["search_postprocess"]: + result_df = hook(result_df) + result_df = VectorStoreSearchOutput.validate(result_df) except Exception as e: raise HookError( "search_postprocessing hook raised an exception.", From 01c72210e9597fd724bc7ff90d7c1e19c5f72651 Mon Sep 17 00:00:00 2001 From: Luke Roantree <206716137+lukeroantreeONS@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:55:24 +0000 Subject: [PATCH 8/8] chore(hooks): update .svg files for demo, add missing Markdown paragraph in demo, remove commented out softmax code --- DEMO/Using_Hooks.ipynb | 13 +- DEMO/files/search_spellcheck_hook.svg | 174 +++++++- DEMO/files/search_spellcheck_hook_clipped.svg | 175 -------- DEMO/files/vectorstore_search_dataflow.svg | 388 +++++++++++++++++- .../vectorstore_search_dataflow_clipped.svg | 361 ---------------- .../hooks/default_hooks/postprocessing.py | 3 - 6 files changed, 570 insertions(+), 544 deletions(-) delete mode 100644 DEMO/files/search_spellcheck_hook_clipped.svg delete mode 100644 DEMO/files/vectorstore_search_dataflow_clipped.svg diff --git a/DEMO/Using_Hooks.ipynb b/DEMO/Using_Hooks.ipynb index 5c67de4..bdba52b 100644 --- a/DEMO/Using_Hooks.ipynb +++ b/DEMO/Using_Hooks.ipynb @@ -78,7 +78,7 @@ "\n", "
\n", "\n", - "![VectorStore Search Dataflow](./files/vectorstore_search_dataflow_clipped.svg)\n", + "![VectorStore Search Dataflow](./files/vectorstore_search_dataflow.svg)\n", "\n", "
\n", "\n", @@ -115,7 +115,7 @@ "\n", "
\n", "\n", - "![VectorStore Search Dataflow](./files/search_spellcheck_hook_clipped.svg)\n", + "![VectorStore Search Dataflow](./files/search_spellcheck_hook.svg)\n", "\n", "
\n", "\n", @@ -173,6 +173,13 @@ "\n", "---\n", "\n", + "#### Defining a custom hook:\n", + "\n", + "Hooks can be quite simple functions, all they need to do is accept a single dataframe-like input (a `VectorStore` dataclass), and output an object of the same type.\n", + "To make sure your output is valid at the end, we recommend using the input object's class' `.validate()` method.\n", + "\n", + "One example of where we may want to define a custom hook is to strip punctuation from some input text in a `search()` query; we'll now show a short worked example to implement this with a custom hook function.\n", + "\n", "We'll name our hook `remove_punctuation()`, and define it like this:" ] }, @@ -193,7 +200,7 @@ " sanitised_texts = [x.translate(str.maketrans(\"\", \"\", string.punctuation)) for x in input_data[\"query\"]]\n", " input_data[\"query\"] = sanitised_texts\n", " # Return the dictionary of input data with desired modified values at each desired key\n", - " return input_data" + " return input_data.__class__.validate(input_data)" ] }, { diff --git a/DEMO/files/search_spellcheck_hook.svg b/DEMO/files/search_spellcheck_hook.svg index f9bc1e2..f0a0c3d 100644 --- a/DEMO/files/search_spellcheck_hook.svg +++ b/DEMO/files/search_spellcheck_hook.svg @@ -1 +1,173 @@ -removePunctuation()takes ingives outVectorStoreSearchInputVectorStoreSearchInputUser defined ‘hook’ function \ No newline at end of file + + + + + + + + + + + + + + + + + + + + + + + removePunctuation() + + + takes in + gives out + VectorStoreSearchInput + VectorStoreSearchInput + + + + + + + User defined ‘hook’ function + + + + + + + + + + diff --git a/DEMO/files/search_spellcheck_hook_clipped.svg b/DEMO/files/search_spellcheck_hook_clipped.svg deleted file mode 100644 index fd217ec..0000000 --- a/DEMO/files/search_spellcheck_hook_clipped.svg +++ /dev/null @@ -1,175 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - removePunctuation - () - - - takes in - gives out - VectorStoreSearchInput - 
VectorStoreSearchInput - - - - - - - User defined ‘hook’ function - - - - - - - - - diff --git a/DEMO/files/vectorstore_search_dataflow.svg b/DEMO/files/vectorstore_search_dataflow.svg index 743875f..f278801 100644 --- a/DEMO/files/vectorstore_search_dataflow.svg +++ b/DEMO/files/vectorstore_search_dataflow.svg @@ -1 +1,387 @@ -VectorStore. search()takesingives outVectorStoreSearchInputVectorStoreSearchOutput \ No newline at end of file + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + VectorStore + . search() + + + takes + i + n + g + ives out + + + + + + + + + + + + + VectorStoreSearchInput + VectorStoreSearchOutput + + + diff --git a/DEMO/files/vectorstore_search_dataflow_clipped.svg b/DEMO/files/vectorstore_search_dataflow_clipped.svg deleted file mode 100644 index a1f03a6..0000000 --- a/DEMO/files/vectorstore_search_dataflow_clipped.svg +++ /dev/null @@ -1,361 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - VectorStore - . search() - - - takes - i - n - g - ives out - - - - - - - - - - - - - VectorStoreSearchInput - VectorStoreSearchOutput - - - diff --git a/src/classifai/indexers/hooks/default_hooks/postprocessing.py b/src/classifai/indexers/hooks/default_hooks/postprocessing.py index 793f4be..01700a2 100644 --- a/src/classifai/indexers/hooks/default_hooks/postprocessing.py +++ b/src/classifai/indexers/hooks/default_hooks/postprocessing.py @@ -31,9 +31,6 @@ def __init__(self, score_aggregation_method: str = "max"): self.score_aggregation_method = score_aggregation_method if self.score_aggregation_method == "max": self.score_aggregator = self._max_score - # Softmax not supported until normalisation is implemented. - # elif self.score_aggregation_method == "softmax": - # self.score_aggregator = ... elif self.score_aggregation_method == "mean": self.score_aggregator = self._mean_score