Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
597 changes: 597 additions & 0 deletions DEMO/Using_Hooks.ipynb

Large diffs are not rendered by default.

692 changes: 0 additions & 692 deletions DEMO/custom_preprocessing_and_postprocessing_hooks.ipynb

This file was deleted.

174 changes: 173 additions & 1 deletion DEMO/files/search_spellcheck_hook.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
388 changes: 387 additions & 1 deletion DEMO/files/vectorstore_search_dataflow.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions src/classifai/indexers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,17 @@
VectorStoreSearchInput,
VectorStoreSearchOutput,
)
from .hooks import (
CapitalisationStandardisingHook,
DeduplicationHook,
HookBase,
)
from .main import VectorStore

__all__ = [
"CapitalisationStandardisingHook",
"DeduplicationHook",
"HookBase",
"VectorStore",
"VectorStoreEmbedInput",
"VectorStoreEmbedOutput",
Expand Down
8 changes: 8 additions & 0 deletions src/classifai/indexers/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .default_hooks import CapitalisationStandardisingHook, DeduplicationHook
from .hook_factory import HookBase

__all__ = [
"CapitalisationStandardisingHook",
"DeduplicationHook",
"HookBase",
]
7 changes: 7 additions & 0 deletions src/classifai/indexers/hooks/default_hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .postprocessing import DeduplicationHook
from .preprocessing import CapitalisationStandardisingHook

__all__ = [
"CapitalisationStandardisingHook",
"DeduplicationHook",
]
70 changes: 70 additions & 0 deletions src/classifai/indexers/hooks/default_hooks/postprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
import pandas as pd

from classifai.exceptions import HookError
from classifai.indexers.dataclasses import VectorStoreSearchOutput
from classifai.indexers.hooks.hook_factory import HookBase


class DeduplicationHook(HookBase):
    """A post-processing hook to remove duplicate knowledgebase entries, i.e. entries with the same label."""

    def _mean_score(self, scores):
        """Aggregates a group of scores by taking their mean."""
        return np.mean(scores)

    def _max_score(self, scores):
        """Aggregates a group of scores by taking their maximum."""
        return np.max(scores)

    def __init__(self, score_aggregation_method: str = "max"):
        """Initialises the hook with the specified method for assigning scores to deduplicated entries.

        Args:
            score_aggregation_method (str): Method for assigning score to the deduplicated entry.
                Must be one of "max" or "mean". Defaults to "max".
                A future update will introduce a 'softmax' option.

        Raises:
            HookError: If ``score_aggregation_method`` is not one of "max" or "mean".
        """
        # Initialise the base class first so `self.hook_type` exists when building the
        # HookError context below (previously an invalid method raised AttributeError
        # because `hook_type` was only set by super().__init__() at the end).
        super().__init__(hook_type="post_processing")
        # Dispatch table from method name to aggregation callable.
        aggregators = {"max": self._max_score, "mean": self._mean_score}
        if score_aggregation_method not in aggregators:
            raise HookError(
                "Invalid method for DeduplicationHook. Must be one of 'max', or 'mean'.",
                context={self.hook_type: "Deduplication", "method": score_aggregation_method},
            )
        self.score_aggregation_method = score_aggregation_method
        self.score_aggregator = aggregators[score_aggregation_method]

    def __call__(self, input_data: VectorStoreSearchOutput) -> VectorStoreSearchOutput:
        """Aggregates retrieved knowledgebase entries corresponding to the same label.

        Args:
            input_data (VectorStoreSearchOutput): Search results, one row per retrieved entry.

        Returns:
            VectorStoreSearchOutput: Deduplicated results with aggregated scores and re-computed ranks.
        """
        # 1) Group on two levels - first on query_id, then on doc_id, to ensure that entries with the same label are
        # deduplicated within the results for each query. Note that there is a 1-1 mapping between query_id and query_text,
        # so no extra grouping is made, but this excludes query_text from the columns to be processed.
        # 2) For each group, aggregate the score using the specified method, and assign a new column 'idxmax' to the unique id
        # of the entry with the best score. This will allow us to retain the metadata of the best scoring entry.
        df_gpby = (
            input_data.groupby(["query_id", "query_text", "doc_id"])
            .aggregate(
                score=("score", self.score_aggregator),
                idxmax=("score", "idxmax"),
                rank=("rank", "min"),
            )
            .reset_index()
        )
        # For each query, re-assign ranks based on the new aggregated scores, to the remaining entries, to ensure that the best
        # scoring entry for each label is ranked highest.
        for query in df_gpby["query_id"].unique():
            batch = df_gpby[df_gpby["query_id"] == query]
            # factorize on the negated score yields dense ranks, best score first (1-based).
            new_rank = pd.factorize(-batch["score"], sort=True)[0] + 1
            df_gpby.loc[batch.index, "rank"] = new_rank
        # Finally, we re-merge the deduplicated results with the original input dataframe,
        # to retrieve the metadata of the best scoring entry for each label, and return the processed output.
        for col in set(input_data.columns).difference(set(df_gpby.columns)):
            df_gpby[col] = df_gpby["idxmax"].map(input_data[col])
        # We sort the output by query_id and doc_id to ensure a consistent order of results for each query,
        # and validate the output against the dataclass schema.
        processed_output = input_data.__class__.validate(
            df_gpby[input_data.columns].sort_values(by=["query_id", "doc_id"], ascending=[True, True])
        )
        return processed_output
63 changes: 63 additions & 0 deletions src/classifai/indexers/hooks/default_hooks/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from classifai.exceptions import HookError
from classifai.indexers.dataclasses import (
VectorStoreEmbedInput,
VectorStoreReverseSearchInput,
VectorStoreSearchInput,
)
from classifai.indexers.hooks.hook_factory import HookBase


class CapitalisationStandardisingHook(HookBase):
    """A pre-processing hook to handle upper-/lower-/sentence-/title-casing."""

    def __init__(self, method: str = "lower", colname: str | list[str] = "query"):
        """Initialises the hook with the specified method for standardising capitalisation.

        Args:
            method (str): Method for standardisation. Must be one of "lower" (like this),
                "upper" (LIKE THIS), "sentence" (Like this), or "title" (Like This).
                Defaults to "lower".
            colname (str | list[str]): The name of one of the fields of the Input object which is/are
                to be capitalised.
                Defaults to "query".

        Raises:
            HookError: If ``method`` is not one of the supported options.
        """
        super().__init__(method=method, colname=colname, hook_type="pre_processing")
        # Dispatch table from method name to the string transformation it applies.
        transforms = {
            "lower": str.lower,
            "upper": str.upper,
            # str.capitalize lowercases the remainder of the string; guard empty/falsy values.
            "sentence": lambda text: text.capitalize() if text else text,
            "title": str.title,
        }
        if method not in transforms:
            raise HookError(
                "Invalid method for CapitalisationStandardisingHook. "
                "Must be one of 'lower', 'upper', 'sentence', or 'title'.",
                context={self.hook_type: "Capitalisation", "method": method},
            )
        self.method = transforms[method]
        self.colname = colname

    def __call__(
        self, input_data: VectorStoreSearchInput | VectorStoreReverseSearchInput | VectorStoreEmbedInput
    ) -> VectorStoreSearchInput | VectorStoreReverseSearchInput | VectorStoreEmbedInput:
        """Standardises capitalisation in the input data as specified by 'method' attribute.

        Args:
            input_data: The input object whose configured column(s) are to be re-cased.

        Returns:
            A validated copy of ``input_data`` with the configured column(s) re-cased.

        Raises:
            HookError: If a configured column is missing or is not of string ('object') dtype.
        """
        # Normalise into a local list instead of mutating self.colname, so calling the
        # hook does not rewrite its own configuration state.
        colnames = [self.colname] if isinstance(self.colname, str) else self.colname
        for col in colnames:
            if col not in input_data.columns:
                raise HookError(
                    "Invalid column name passed.", context={"pre_processing": "Capitalisation", "colname": col}
                )
            if col not in input_data.select_dtypes(include=["object"]).columns:
                raise HookError(
                    f"colname is of type {input_data[col].dtype}, expected 'str'.",
                    context={"pre_processing": "Capitalisation", "colname": col},
                )

        processed_input = input_data.copy()
        for col in colnames:
            processed_input[col] = processed_input[col].apply(self.method)
        # Ensure the processed input still conforms to the schema
        processed_input = input_data.__class__.validate(processed_input)
        return processed_input
48 changes: 48 additions & 0 deletions src/classifai/indexers/hooks/hook_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod

from classifai.exceptions import HookError
from classifai.indexers.dataclasses import (
VectorStoreEmbedInput,
VectorStoreEmbedOutput,
VectorStoreReverseSearchInput,
VectorStoreReverseSearchOutput,
VectorStoreSearchInput,
VectorStoreSearchOutput,
)


class HookBase(ABC):
    """Abstract base class for all pre- and post-processing hooks requiring customisation."""

    def __init__(self, **kwargs):
        """Sets any attributes required by the hook.

        Args:
            **kwargs: Arbitrary hook configuration, stored on ``self.kwargs``. The
                reserved key ``hook_type`` (e.g. "pre_processing" or "post_processing")
                labels the hook and defaults to "generic" when absent.
        """
        # Honour a hook_type supplied via kwargs. Previously this attribute was
        # unconditionally set to "generic", silently ignoring the hook_type that
        # every subclass passes in. kwargs is stored unmodified for subclasses.
        self.hook_type: str = kwargs.get("hook_type", "generic")
        self.kwargs = kwargs

    @abstractmethod
    def __call__(
        self,
        data: VectorStoreSearchOutput
        | VectorStoreReverseSearchOutput
        | VectorStoreEmbedOutput
        | VectorStoreSearchInput
        | VectorStoreReverseSearchInput
        | VectorStoreEmbedInput,
    ) -> (
        VectorStoreSearchOutput
        | VectorStoreReverseSearchOutput
        | VectorStoreEmbedOutput
        | VectorStoreSearchInput
        | VectorStoreReverseSearchInput
        | VectorStoreEmbedInput
    ):
        """Defines the behavior of the hook when called.

        Subclasses must return an object of the same type as ``data``.

        Raises:
            HookError: If the processed data is not of the same type as the input.
        """
        processed_data = data  # Placeholder for processing logic
        if not isinstance(processed_data, type(data)):
            raise HookError(
                f"Processed data must be of the same type as input. "
                f"Expected {type(data).__name__}, got {type(processed_data).__name__}.",
                context={"hook_type": self.hook_type},
            )
        return processed_data
46 changes: 32 additions & 14 deletions src/classifai/indexers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def _create_vector_store_index(self): # noqa: C901
},
) from e

def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput:
def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput: # noqa: C901
"""Converts text (provided via a `VectorStoreEmbedInput` object) into vector embeddings using the `Vectoriser` and
returns a `VectorStoreEmbedOutput` dataframe with columns `id`, `text`, and `embedding`.

Expand All @@ -382,8 +382,11 @@ def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput:
# ---- Preprocess hook -> HookError
if "embed_preprocess" in self.hooks:
try:
modified_query = self.hooks["embed_preprocess"](query)
query = VectorStoreEmbedInput.validate(modified_query)
if not isinstance(self.hooks["embed_preprocess"], list):
self.hooks["embed_preprocess"] = [self.hooks["embed_preprocess"]]
for hook in self.hooks["embed_preprocess"]:
query = hook(query)
query = VectorStoreEmbedInput.validate(query)
except Exception as e:
raise HookError(
"embed_preprocess hook raised an exception.",
Expand Down Expand Up @@ -421,8 +424,11 @@ def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput:
# ---- Postprocess hook -> HookError
if "embed_postprocess" in self.hooks:
try:
modified_results_df = self.hooks["embed_postprocess"](results_df)
results_df = VectorStoreEmbedOutput.validate(modified_results_df)
if not isinstance(self.hooks["embed_postprocess"], list):
self.hooks["embed_postprocess"] = [self.hooks["embed_postprocess"]]
for hook in self.hooks["embed_postprocess"]:
results_df = hook(results_df)
results_df = VectorStoreEmbedOutput.validate(results_df)
except Exception as e:
raise HookError(
"embed_postprocess hook raised an exception.",
Expand All @@ -431,7 +437,7 @@ def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput:

return results_df

def reverse_search( # noqa: C901
def reverse_search( # noqa: C901, PLR0912
self, query: VectorStoreReverseSearchInput, max_n_results: int = 100, partial_match: bool = False
) -> VectorStoreReverseSearchOutput:
"""Reverse searches the `VectorStore` using a `VectorStoreReverseSearchInput` object
Expand Down Expand Up @@ -473,8 +479,11 @@ def reverse_search( # noqa: C901
# ---- Preprocess hook -> HookError
if "reverse_search_preprocess" in self.hooks:
try:
modified_query = self.hooks["reverse_search_preprocess"](query)
query = VectorStoreReverseSearchInput.validate(modified_query)
if not isinstance(self.hooks["reverse_search_preprocess"], list):
self.hooks["reverse_search_preprocess"] = [self.hooks["reverse_search_preprocess"]]
for hook in self.hooks["reverse_search_preprocess"]:
query = hook(query)
query = VectorStoreReverseSearchInput.validate(query)
except Exception as e:
raise HookError(
"reverse_search_preprocess hook raised an exception.",
Expand Down Expand Up @@ -531,8 +540,11 @@ def reverse_search( # noqa: C901
# ---- Postprocess hook -> HookError
if "reverse_search_postprocess" in self.hooks:
try:
modified_result_df = self.hooks["reverse_search_postprocess"](result_df)
result_df = VectorStoreReverseSearchOutput.validate(modified_result_df)
if not isinstance(self.hooks["reverse_search_postprocess"], list):
self.hooks["reverse_search_postprocess"] = [self.hooks["reverse_search_postprocess"]]
for hook in self.hooks["reverse_search_postprocess"]:
result_df = hook(result_df)
result_df = VectorStoreReverseSearchOutput.validate(result_df)
except Exception as e:
raise HookError(
"reverse_search_postprocess hook raised an exception.",
Expand Down Expand Up @@ -587,8 +599,11 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V
# ---- Preprocess hook -> DataValidationError if it returns invalid shape/type
if "search_preprocess" in self.hooks:
try:
modified_query = self.hooks["search_preprocess"](query)
query = VectorStoreSearchInput.validate(modified_query)
if not isinstance(self.hooks["search_preprocess"], list):
self.hooks["search_preprocess"] = [self.hooks["search_preprocess"]]
for hook in self.hooks["search_preprocess"]:
query = hook(query)
query = VectorStoreSearchInput.validate(query)
except Exception as e:
raise HookError(
"search_preprocess hook raised an exception.",
Expand Down Expand Up @@ -703,8 +718,11 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V
# ---- Postprocess hook -> DataValidationError if it returns invalid shape/type
if "search_postprocess" in self.hooks:
try:
modified_result_df = self.hooks["search_postprocess"](result_df)
result_df = VectorStoreSearchOutput.validate(modified_result_df)
if not isinstance(self.hooks["search_postprocess"], list):
self.hooks["search_postprocess"] = [self.hooks["search_postprocess"]]
for hook in self.hooks["search_postprocess"]:
result_df = hook(result_df)
result_df = VectorStoreSearchOutput.validate(result_df)
except Exception as e:
raise HookError(
"search_postprocessing hook raised an exception.",
Expand Down
Loading