diff --git a/DEMO/files/agent_classification.png b/DEMO/files/agent_classification.png
new file mode 100644
index 0000000..ed631e2
Binary files /dev/null and b/DEMO/files/agent_classification.png differ
diff --git a/DEMO/files/e2e_genai_hook.png b/DEMO/files/e2e_genai_hook.png
new file mode 100644
index 0000000..162e94a
Binary files /dev/null and b/DEMO/files/e2e_genai_hook.png differ
diff --git a/DEMO/files/human_in_the_loop.png b/DEMO/files/human_in_the_loop.png
new file mode 100644
index 0000000..629abcd
Binary files /dev/null and b/DEMO/files/human_in_the_loop.png differ
diff --git a/DEMO/files/scoring_threshold.png b/DEMO/files/scoring_threshold.png
new file mode 100644
index 0000000..6f23804
Binary files /dev/null and b/DEMO/files/scoring_threshold.png differ
diff --git a/DEMO/files/vectorstore_search.png b/DEMO/files/vectorstore_search.png
new file mode 100644
index 0000000..0dc40a7
Binary files /dev/null and b/DEMO/files/vectorstore_search.png differ
diff --git a/DEMO/generative_ai_classifier.ipynb b/DEMO/generative_ai_classifier.ipynb
new file mode 100644
index 0000000..8487b59
--- /dev/null
+++ b/DEMO/generative_ai_classifier.ipynb
@@ -0,0 +1,1504 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Making an AI Classifier with ClassifAI Hooks and Google GenAI ✨"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Introduction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Through the use of semantic/vector search, ClassifAI enables users to build a system that retrieves candidate classification texts for a sample of text with an unknown label. This is done by pre-computing and storing vector representations of text samples with known labels; a text sample with an unknown label can then be vectorised and its embedded representation compared to the embedded representations of the labelled samples. 
Using this process, a ranking of potential candidate labels based on how similar the unlabelled text is to each labelled sample can be obtained. This then offers the user flexibility in how to derive a final classification from the retrieved candidate classifications. \n", + "\n", + "\n", + "There are many posible ways to use the candidate list, and there are also many ways derive a final classification from the ranking list to obtain a single prediction label for a sample of text.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Common strategies include:\n", + "\n", + "1. A human-in-the-loop approach where a human reads the candidate ranking list and unlabelled text to classify what the unknown label should be. This reduces the workload of the human by providing then with the top-K candidate ranking, insteaed of the user having to consider every possible label category, which can be numerous.\n", + "\n", + "2. Automatically selecting the top ranked candidate. This is an efficient approach with minimal additional computation required - however it can lead to inaccurate results. It can help to calibrate on a test collection and set confidence thresholds that determine if a sample should be automatically classified. These thresholds can be calibrated with test datasets.\n", + "\n", + "3. Using a designated AI model to make the final decision. In this case, a RAG AI model utilises the candidate ranking as a set of retrieved results from ClassifAI and makes a final decision based on the information provided." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
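Strategy 2 above can be sketched in a few lines. This is an illustrative sketch, not part of the ClassifAI API: it assumes a candidate ranking held in a pandas DataFrame with the `doc_id`, `rank` and similarity `score` columns used later in this demo, and the threshold value is purely illustrative and would need calibrating on a test dataset.

```python
import pandas as pd

# Hypothetical sketch of strategy 2: auto-accept the top-ranked candidate
# only when its similarity score clears a calibrated confidence threshold.
CONFIDENCE_THRESHOLD = 0.8  # illustrative value - calibrate on a test dataset


def auto_classify(results: pd.DataFrame, threshold: float = CONFIDENCE_THRESHOLD):
    """Return the top candidate's label if confident, else None (defer to a human)."""
    top = results.sort_values("rank").iloc[0]
    if top["score"] >= threshold:
        return top["doc_id"]
    return None  # below threshold: route to human-in-the-loop review


ranking = pd.DataFrame(
    {"doc_id": ["701", "702"], "rank": [1, 2], "score": [0.91, 0.62]}
)
print(auto_classify(ranking))  # prints "701"
```

Raising the threshold trades coverage for precision: more samples fall back to manual review, but the auto-assigned labels become more reliable.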
\n", + " \"Strategy\n", + " \"Strategy\n", + " \"Strategy\n", + "
\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook we demonstrate an approach to strategy 3, using Google's Gemini generative AI agent together with ClassifAI's hooks functionality, to automatically classify ClassifAI candidate results in a single ClassifAI pipeline." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Geneerative AI Schematic](./files/e2e_genai_hook.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### What is covered in this notebook:\n", + "\n", + "1. A ClassifAI recap on how to obtain search results,\n", + "\n", + "2. An introduction to using the GCP Generative AI models,\n", + "\n", + "3. Custom Hooks in ClassifAI,\n", + "\n", + "4. Setting up a Generateive AI model to make the final classification in ClassifAI - Systems Prompts, Response formatting, ClassifAI Hook operations,\n", + "\n", + "5. A full implementation of the code (you can skip to this section for a view of the final solution that can be copied or used in your own work)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Requirements / Prerequisites\n", + "\n", + "This notebook is designed so that you can execute the code cells and follow along with the implementation. To do this you will need:\n", + "\n", + "1. Download this notebook from your local machine, and the DEMO/data/fake_soc_dataset.csv file which will be used as a test dataset,\n", + "\n", + "2. Set up a virtual environment with ClassifAI[gcp] installed for this ipynb instance - view ClassifAI installation instructions in the repo,\n", + "\n", + "3. Ensure that you have Google Cloud Platform project set up - with the Generative AI API enabled to use Google's Gemini models." 
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 1: Using ClassifAI to Get Candidate Rankings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!uv pip install \"https://github.com/datasciencecampus/classifai/releases/download/v0.2.1/classifai-0.2.1-py3-none-any.whl[gcp]\"\n",
+ "\n",
+ "## you will also need to authenticate to GCP and set up your project and bucket for the next steps: 'gcloud auth login' and 'gcloud config set project [PROJECT_ID]'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will start with a script that creates a ClassifAI `VectorStore` object, which we use to create a vector database of labelled samples. Later we can use this `VectorStore` object to find labelled text samples that are similar to the text of a provided unlabelled sample."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from classifai.indexers import VectorStore\n",
+ "from classifai.vectorisers import GcpVectoriser\n",
+ "\n",
+ "# Initialise the vectoriser\n",
+ "demo_vectoriser = GcpVectoriser(project_id=\"YOUR PROJECT ID HERE\", vertexai=True)\n",
+ "\n",
+ "\n",
+ "# Initialise the vector store, pointing the demo to the demo test data\n",
+ "demo_vectorstore = VectorStore(\n",
+ " file_name=\"./data/fake_soc_dataset.csv\",\n",
+ " data_type=\"csv\",\n",
+ " vectoriser=demo_vectoriser,\n",
+ " output_dir=\"./demo_vdb\",\n",
+ " overwrite=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The above cell should indicate a successfully created `VectorStore`, which is now in memory as `demo_vectorstore`. 
It contains vectors for the `fake_soc_dataset`, which has many dummy example texts with corresponding dummy SOC labels.\n",
+ "\n",
+ "\n",
+ "We can call the `demo_vectorstore.search()` method, passing it a `VectorStoreSearchInput` dataclass object with an unlabelled sample of text, to get back a candidate ranking of the `fake_soc_dataset` entries whose text is most similar to our unlabelled sample."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from classifai.indexers.dataclasses import VectorStoreSearchInput\n",
+ "\n",
+ "# Create a search input\n",
+ "search_input = VectorStoreSearchInput({\"id\": [\"1\"], \"query\": [\"a photographer hired for wedding events\"]})\n",
+ "# call the VectorStore search method\n",
+ "demo_vectorstore.search(search_input, n_results=5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you successfully ran the above code cell, you should now see a pandas dataframe output with 5 candidate results. The unlabelled input sample was `a photographer hired for wedding events`, which is shown in the `query_text` column. The `doc_text` column of each row shows the corresponding top 5 most similar samples from the `fake_soc_dataset.csv` file, with the `doc_id` column indicating the predicted label. The `rank` and `score` columns provide additional relevant information.\n",
+ "\n",
+ "This showcases the core functionality of the ClassifAI package: being able to build a semantic search engine for your labelled datasets and get candidate result sets for unlabelled samples. \n",
+ "\n",
+ "As described in the introduction section, these candidate lists can be used in different ways to arrive at a final decision about the correct label for our sample. \n",
+ "\n",
+ "ClassifAI provides the 'shortlist' of options. In this demo we're now going to show how we can use a Generative AI model to automatically make the final classification."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "![ClassifAI Search Diagram](./files/vectorstore_search_dataflow.svg)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "See additional resources on ClassifAI on the ClassifAI GitHub repo:\n",
+ "\n",
+ "- https://github.com/datasciencecampus/classifai"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 2: Generative AI"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This section provides a brief overview of how Google's Generative AI (genai) API enables us to use the AI models.\n",
+ "\n",
+ "Broadly, Generative AI models accept some text as input, which can be in the form of an instruction, telling the model to _generate_ some output text. \n",
+ "\n",
+ "We will see how we can set up a connection to one of Google's cloud-hosted generative models and 'prompt' it to produce some output. We'll also explore basic generative model principles such as 'system prompting' and 'structured outputs', which we'll use later to construct a generative model that makes a final classification of an input, suitable for classifying our `VectorStore` candidate results.\n",
+ "\n",
+ "This section of the tutorial is purely about the Generative model capabilities; later we will see how these models can be used in a pipeline with ClassifAI to make classifications. Here we showcase a simple 'weather classifier' which will indicate if a sample of text contains a description of good or bad weather."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Prompting the Google Genai API"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!uv pip install google-genai"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Firstly, we can establish an API connection to a generative model hosted by Google and 'prompt' it to write a simple story about the weather. This showcases how we can pass text to a generative model and get some text response back."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from google import genai\n",
+ "\n",
+ "# Set up the API key for authentication\n",
+ "client = genai.Client(api_key=\"YOUR API KEY HERE\")\n",
+ "\n",
+ "# Define the model to use\n",
+ "chosen_model = \"gemini-2.5-flash\"  # Replace with the appropriate Gemini Flash 2.5 model name if different"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Send a text prompt to the model\n",
+ "prompt = \"Write a very short description about today's weather.\"\n",
+ "\n",
+ "\n",
+ "response = client.models.generate_content(\n",
+ " model=chosen_model,\n",
+ " contents=prompt,\n",
+ ")\n",
+ "\n",
+ "# Print the response\n",
+ "print(\"Generated Text:\")\n",
+ "print(response.text)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now that we have a Genai client set up to work with the Google cloud models, we can instruct the generative model on how to 'behave' by choosing how we prompt it. For example, relevant to our ClassifAI task, we can ask the model to classify inputs into categories."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt2 = \"\"\"\n",
+ "\n",
+ "Classify the following text into 'good weather' or 'bad weather':\n",
+ "\n",
+ "The rain is crashing down in Scotland.\n",
+ "\n",
+ "\"\"\"\n",
+ "\n",
+ "response = client.models.generate_content(\n",
+ " model=chosen_model,\n",
+ " contents=prompt2,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# Print the response\n",
+ "print(\"Generated Text:\")\n",
+ "print(response.text)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you change `prompt2` to different weather descriptions you will likely see some output from the model indicating one of our two classes. But there are two core problems with this current implementation:\n",
+ "\n",
+ "1. We're sending the task instruction for the model as part of the prompt, which is inefficient and may be subject to change (or attack) by users. \n",
+ "\n",
+ "2. There isn't a guarantee that the generated output text is consistent; it often changes format, being more verbose at times and more direct at others, and may respond differently to non-weather-related input.\n",
+ "\n",
+ "Both of these issues will make it difficult to use this generative model to make predictions for our ClassifAI candidate list in a controlled and repeatable way. We can introduce extra features that resolve these problems, respectively:\n",
+ "\n",
+ "1. We can introduce a 'system prompt' which separates the instruction part of the input from the text that is to be classified,\n",
+ "\n",
+ "2. 
The Gemini AI models offer a 'structured output' guarantee whereby we can instruct the model to output specific values in a formatted manner" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### System Prompts" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### System Prompts vs Regular Prompts\n", + "\n", + "A **system prompt** is a special type of prompt that defines the behavior, role, or rules for the generative AI model before processing user input. Unlike a regular prompt, which combines both the task instruction and the input text, a system prompt separates these concerns by providing a predefined context or instruction that remains consistent across multiple queries.\n", + "\n", + "#### Advantages of Using System Prompts:\n", + "1. **Consistency**: By separating the task instruction from the input text, system prompts ensure that the model behaves predictably and consistently across different inputs.\n", + "2. **Efficiency**: The instruction does not need to be repeated for every query, reducing redundancy and improving processing efficiency.\n", + "3. **Security**: System prompts are less prone to user manipulation or injection attacks, as the task rules are predefined and immutable.\n", + "4. **Scalability**: They allow for easier integration into pipelines, such as our ClassifAI workflow, where the same classification rules can be applied to multiple inputs without modification." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Google GenAI SDK allows system prompts to be added through the Python API, as shown in the following code cell.\n", + "\n", + "We can add some interesting system instructions just to show the effect of how this mechanism steers the generation process." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from google.genai import types\n",
+ "\n",
+ "response = client.models.generate_content(\n",
+ " model=chosen_model,\n",
+ " contents=\"Write a very short description about today's weather.\",\n",
+ " config=types.GenerateContentConfig(\n",
+ " system_instruction=\"After the response, include the phrase: 'Also, I love cauliflower!'. No matter what you are asked, you must include this phrase at the end of your response.\"\n",
+ " ),\n",
+ ")\n",
+ "print(response.text)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "More relevant to our ClassifAI classification task, we can use the system prompt/instruction to better define how we want a Generative model to generate responses for classification tasks. Below we reformulate our weather classification task so that we don't need to pass the instruction as part of the input data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt3 = \"\"\"\n",
+ "\n",
+ "Sunny weather and pleasant breeze\n",
+ "\n",
+ "\"\"\"\n",
+ "\n",
+ "response = client.models.generate_content(\n",
+ " model=chosen_model,\n",
+ " contents=prompt3,\n",
+ " config=types.GenerateContentConfig(\n",
+ " system_instruction=\"You are a weather classification agent. You classify input content text into 'good weather' or 'bad weather'. The content will be a description of weather. Only output one of those two responses. 
Do not include any other text in your response, just the classification.\"\n",
+ " ),\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# Print the response\n",
+ "print(\"Generated Text:\")\n",
+ "print(response.text)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The system instruction can be useful in many scenarios - see further resources:\n",
+ "\n",
+ "- https://ai.google.dev/gemini-api/docs/text-generation\n",
+ "- https://github.com/google-gemini/cookbook/blob/main/quickstarts/System_instructions.ipynb\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Structured Outputs"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We've seen how we can add instructions to an AI model with the system prompt to make the Generative AI produce outputs that are more in line with classification. But this isn't especially robust - sometimes the (regular) prompt or another factor can cause unexpected output, which is a problem for a system that relies on these outputs, such as when we want to use this AI with ClassifAI to make final classifications.\n",
+ "\n",
+ "In this section we explore a feature of the Genai API that allows us to enforce specific outputs, rather than just including that requirement as part of an instruction in the system prompt."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The Google Genai documentation gives the following example of how Pydantic can be used to provide a structure that the model should adhere to when generating text. 
You can read more about structured outputs and this example at: \n", + "* http://ai.google.dev/gemini-api/docs/structured-output?example=recipe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pydantic import BaseModel, Field\n", + "\n", + "\n", + "class Ingredient(BaseModel):\n", + " name: str = Field(description=\"Name of the ingredient.\")\n", + " quantity: str = Field(description=\"Quantity of the ingredient, including units.\")\n", + "\n", + "\n", + "class Recipe(BaseModel):\n", + " recipe_name: str = Field(description=\"The name of the recipe.\")\n", + " prep_time_minutes: int | None = Field(description=\"Optional time in minutes to prepare the recipe.\")\n", + " ingredients: list[Ingredient]\n", + " instructions: list[str]\n", + "\n", + "\n", + "googledemoprompt = \"\"\"\n", + "Please extract the recipe from the following text.\n", + "The user wants to make delicious chocolate chip cookies.\n", + "They need 2 and 1/4 cups of all-purpose flour, 1 teaspoon of baking soda,\n", + "1 teaspoon of salt, 1 cup of unsalted butter (softened), 3/4 cup of granulated sugar,\n", + "3/4 cup of packed brown sugar, 1 teaspoon of vanilla extract, and 2 large eggs.\n", + "For the best part, they'll need 2 cups of semisweet chocolate chips.\n", + "First, preheat the oven to 375°F (190°C). Then, in a small bowl, whisk together the flour,\n", + "baking soda, and salt. In a large bowl, cream together the butter, granulated sugar, and brown sugar\n", + "until light and fluffy. Beat in the vanilla and eggs, one at a time. Gradually beat in the dry\n", + "ingredients until just combined. Finally, stir in the chocolate chips. 
Drop by rounded tablespoons\n",
+ "onto ungreased baking sheets and bake for 9 to 11 minutes.\n",
+ "\"\"\"\n",
+ "\n",
+ "response = client.models.generate_content(\n",
+ " model=chosen_model,\n",
+ " contents=googledemoprompt,\n",
+ " config={\n",
+ " \"response_mime_type\": \"application/json\",\n",
+ " \"response_json_schema\": Recipe.model_json_schema(),\n",
+ " },\n",
+ ")\n",
+ "\n",
+ "recipe = Recipe.model_validate_json(response.text)\n",
+ "print(recipe)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can adapt this recipe extraction example to our weather classifier system, adding to our system prompt so that it outputs a specific label for the prompted weather description:\n",
+ "\n",
+ "* 1 - for good weather\n",
+ "* 0 - for bad weather\n",
+ "\n",
+ "We use a Pydantic model to create a `WeatherLabel` class which defines a schema for the Google Genai model to adhere to when generating a text response."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt4 = \"\"\"\n",
+ "\n",
+ "Sunny weather and pleasant breeze\n",
+ "\n",
+ "\"\"\"\n",
+ "\n",
+ "\n",
+ "SYSTEM_INSTRUCTION = \"\"\"\n",
+ "You are a strict weather classifier agent. Your job is to classify the input text as either good weather or bad weather.\n",
+ "Return JSON only that matches the provided schema.\n",
+ "\n",
+ "Labeling rules:\n",
+ "- label = 1 (good weather): positive/pleasant conditions (e.g., sunny, clear skies, warm, mild, calm).\n",
+ "- label = 0 (bad weather): negative/hazardous/unpleasant conditions (e.g., rain, storm, snow, ice, fog, extreme heat/cold, high winds).\n",
+ "- If both good and bad are mentioned, label based on the overall/most salient condition. 
If unclear, choose 0.\n",
+ "Do not include any extra keys.\n",
+ "\"\"\"\n",
+ "\n",
+ "\n",
+ "# A Pydantic model defining the output schema\n",
+ "class WeatherLabel(BaseModel):\n",
+ " label: int = Field(ge=0, le=1, description=\"1 = good weather, 0 = bad weather\")\n",
+ "\n",
+ "\n",
+ "response = client.models.generate_content(\n",
+ " model=chosen_model,\n",
+ " contents=prompt4,\n",
+ " config=types.GenerateContentConfig(\n",
+ " system_instruction=SYSTEM_INSTRUCTION,\n",
+ " response_mime_type=\"application/json\",\n",
+ " response_json_schema=WeatherLabel.model_json_schema(),\n",
+ " ),\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# Print the response\n",
+ "print(\"Generated Text:\")\n",
+ "print(response.text)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The output from the model should now adhere much more strictly to text that looks suitable for classification in a traditional sense - predicting a class label value."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In practice:\n",
+ "\n",
+ "- `WeatherLabel`: a schema that enforces shape/types (label must be int 0/1).\n",
+ "- `SYSTEM_INSTRUCTION`: defines semantics of the task (“what is good vs bad weather”, potentially how to handle mixed cases).\n",
+ "\n",
+ "\n",
+ "We will use both of these features to build a Generative AI agent for classification with the ClassifAI package.\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 3: Hooks and Custom ClassifAI Workflows"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this part of the demo we briefly recap how custom logic can be added to the ClassifAI VectorStore pipeline using hooks, which will be necessary for adding a generative classifier."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "There is a longer, dedicated notebook tutorial on creating Hooks and how to use them to define custom workflows available from the ClassifAI repo:\n",
+ "\n",
+ "https://github.com/datasciencecampus/classifai/blob/main/DEMO/custom_preprocessing_and_postprocessing_hooks.ipynb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### The Concept of Hooks in ClassifAI"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `VectorStore` is the central component of the ClassifAI package. It provides the ability to create a vector database from labelled text samples, and then do similarity search and other operations on that `VectorStore`, including `search()`, `reverse_search()` and embedding raw text with the `embed()` method.\n",
+ "\n",
+ "Each of these methods has associated `dataclass` input and output objects - these are Pandas-DataFrame-like objects that specify what data needs to be included when calling the various methods of the `VectorStore`. For example, the `search()` method takes as input a `VectorStoreSearchInput` object and returns a `VectorStoreSearchOutput` object, both of which specify specific columns of data. This ensures data is passed to our API correctly, and also ensures that the response from the `VectorStore` is generated correctly. We used `VectorStoreSearchInput` in part 1 of this demo to create inputs to the `search` method.\n",
+ "\n",
+ "Finally, *Hooks* allow users of the package to write custom functions that modify the data in the dataclass objects at certain points in the package codebase - before and after a specific method is called. Hooks can apply any user-defined operation on a specific dataclass object, as long as they return a valid dataclass object of the same kind. 
For example, users could write a spell-checking function that corrects the text of queries input to the search method. See the example hooks in the demo notebook referenced in the previous cells for implementation examples."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below is an illustration of how a `VectorStoreSearchInput` object can be transformed by a 'remove punctuation' hook. Note that the column names and content of the dataframe remain consistent.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "![ClassifAI Search Diagram](./files/vectorstore_search_dataflow.svg)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For the case of using Generative AI as a classifier, in this demo we set up the classifier as a Hook, calling the Genai client inside the hook function to operate on the set of results returned from the `VectorStore.search()` API."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Example hook, capitalising the candidate result texts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def capitalise_candidate_texts(input_data):\n",
+ " # convert the result texts to CAPITAL letters\n",
+ " input_data[\"doc_text\"] = input_data[\"doc_text\"].str.upper()\n",
+ "\n",
+ " print(type(input_data))\n",
+ "\n",
+ " return input_data\n",
+ "\n",
+ "\n",
+ "# Initialise the vector store, pointing the demo to the demo test data\n",
+ "demo_vectorstore = VectorStore(\n",
+ " file_name=\"./data/fake_soc_dataset.csv\",\n",
+ " data_type=\"csv\",\n",
+ " vectoriser=demo_vectoriser,\n",
+ " output_dir=\"./demo_vdb\",\n",
+ " overwrite=True,\n",
+ " hooks={\"search_postprocess\": capitalise_candidate_texts},\n",
+ ")\n",
+ "\n",
+ "\n",
+ "query_df = VectorStoreSearchInput({\"id\": [1], \"query\": [\"apple merchant\"]})\n",
+ "\n",
+ "demo_vectorstore.search(query_df)"
+ ]
+ },
+ {
+ 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "You should see that in the resulting `VectorStoreSearchOutput` object that ll of the `doc_text` entries are fully capitalised, which was the action performed by our hook." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 4: Building a Generative Classifier for ClassifAI" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this section we will:\n", + "\n", + "1. Ceate a Generative AI agent that can do classifications on VectorStore results using a system prompt and structured output to guide the model,\n", + "\n", + "2. Create a Custom Hook function that passes the VectorStore Search results to the generateive model for classification,\n", + "\n", + "3. Write custom pre-processing and post-processing steps that will execute inside the hook,\n", + "\n", + "4. Use the Genai in a custom search postprocessing hook of a VectorStore to get final classification results." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Overview" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We plan to add our generative AI model in a post processsing hook to interact with search results and determine a final result.\n", + "\n", + "For this we will need to format the results data as a text prompt for the model, provide the model with instructions on how to do the classificatiin task and what format to output its prediction. Then we will also need to add additional logic to reduce the original dataframe down to a final classification based on the Generative model's output classificaiton." 
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "![Generative AI Schematic](./files/e2e_genai_hook.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Defining a System Prompt and Structured Output"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `VectorStoreSearchOutput` search results are ranked, with a corresponding ID and text entry for each candidate. We design our system prompt to inform the model that it will receive N candidates and must select the ID of the entry it 'thinks' corresponds best to the input query. We instruct the model to output an ID from 1 to N, the ID corresponding to the candidate entry number.\n",
+ "\n",
+ "We also specify several guidelines to the model and then provide an example of how the results object will be presented to the Generative model (we will write the code that formats the VectorStore results object to this XML format shortly)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def make_classification_system_prompt(n: int) -> str:\n",
+ " CLASSIFICATION_SYSTEM_PROMPT = \"\"\"You are an AI assistant designed to classify a user query based on the provided context. You will be provided with candidate entries retrieved from a knowledge base, each containing an ID and a text description. Your task is to analyze the user query and the text of the context entries to determine which of the entries best matches the user query.\n",
+ "\n",
+ " Guidelines:\n",
+ " 1. Always prioritize the provided context when making your classification.\n",
+ " 2. The context will be provided as an XML structure containing multiple entries. Each entry includes an ID and a text description.\n",
+ " 3. The IDs will be integer values from 1 to {n}, corresponding to the {n} candidate entries.\n",
+ " 4. 
Use the text of the entries to determine the most relevant classification for the user query.\n", + " 5. Your output must be a JSON object that adheres to the following schema:\n", + " - The JSON object must contain a single key, `classification`.\n", + " - The value of `classification` must be an integer between 1 and {n}, representing the ID of the best matching entry.\n", + " - If no classification can be determined due to ambiguity or insufficient information, the value of `classification` must be `-1`.\n", + "\n", + " Example of the required JSON output:\n", + " {{\n", + " 'classification': 1\n", + " }}\n", + "\n", + " The XML structure for the context and user query will be as follows:\n", + " \n", + " \n", + " 0\n", + " [Text from the first entry]\n", + " \n", + " \n", + " 1\n", + " [Text from the second entry]\n", + " \n", + " ...\n", + " \n", + " {n}\n", + " [Text from the fifth entry]\n", + " \n", + " \n", + "\n", + " \n", + " [The user query will be inserted here]\n", + " \n", + "\n", + " Your task is to analyze the context and the user query, and return the classification in the required structured format.\n", + " \"\"\"\n", + "\n", + " return CLASSIFICATION_SYSTEM_PROMPT.format(n=n)\n", + "\n", + "\n", + "system_prompt = make_classification_system_prompt(n=7)\n", + "print(system_prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we specify a Pydantic class model that will guide the structure of the output. Similar to our ealier example, we now want the generative model to output an ID between 1 and N, adhering the the System Prompt above." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Literal\n", + "\n", + "from pydantic import BaseModel, Field, conint, create_model\n", + "\n", + "\n", + "def make_classification_model(n: int) -> type[BaseModel]:\n", + " if n < 1:\n", + " raise ValueError(\"n must be >= 1\")\n", + "\n", + " PositiveId = conint(ge=1, le=n) # type: ignore[valid-type]\n", + "\n", + " return create_model(\n", + " \"ClassificationResponseModel\",\n", + " classification=(\n", + " Literal[-1] | PositiveId,\n", + " Field(description=f\"-1 if unclassifiable, else an integer in [1, {n}] (no 0).\"),\n", + " ),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we have created a function that instructs the model to output either -1 or a value between 1 and N, where N can be scaled dynamically depending on how many results are passed to the model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Pre-processing and post-processing for the GenAI agent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now need methods for:\n", + "\n", + "- formatting the query and results object into a string to present to the agent\n", + "- formatting the response from the agent and selecting the final result based on the agent's generation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our query-results formatting is straightforward - we want to convert from the VectorStoreSearchOutput object to the XML format specified earlier in the system prompt. 
A Python function to do this is as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "\n", + "def format_prompt_with_vectorstore_results(results_df) -> str:\n", + " # Extract the user query (assuming all rows have the same query_id and query_text)\n", + " user_query = results_df[\"query_text\"].iloc[0]\n", + "\n", + " # Build the <context> section\n", + " context_entries = \"\\n\".join(\n", + " f\" <entry>\\n <id>{idx + 1}</id>\\n <text>{row['doc_text']}</text>\\n </entry>\"\n", + " for idx, row in results_df.iterrows()\n", + " )\n", + "\n", + " # Combine everything into the final prompt\n", + " formatted_prompt = f\"\"\"<context>\n", + "{context_entries}\n", + "</context>\n", + "\n", + "<user_query>\n", + " {user_query}\n", + "</user_query>\"\"\"\n", + "\n", + " return formatted_prompt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By creating a dummy VectorStoreSearchOutput object, we can see how the above function transforms the data that will be passed as the prompt to the agent model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.indexers.dataclasses import VectorStoreSearchOutput\n", + "\n", + "# Create a sample dataframe adhering to the searchOutputSchema\n", + "sample_data = {\n", + " \"query_id\": [\"1\", \"1\", \"1\", \"1\", \"1\"],\n", + " \"query_text\": [\n", + " \"a photographer hired for wedding events\",\n", + " \"a photographer hired for wedding events\",\n", + " \"a photographer hired for wedding events\",\n", + " \"a photographer hired for wedding events\",\n", + " \"a photographer hired for wedding events\",\n", + " ],\n", + " \"doc_id\": [\"701\", \"702\", \"703\", \"704\", \"705\"],\n", + " \"doc_text\": [\n", + " \"Wedding photographer available for hire.\",\n", + " \"Professional event photographer for weddings.\",\n", + " \"Experienced celebration event photographer.\",\n", + " \"Photographer 
specializing in nature photography.\",\n", + " \"A fashion model photographer.\",\n", + " ],\n", + " \"rank\": [0, 1, 2, 3, 4],\n", + " \"score\": [0.95, 0.90, 0.85, 0.80, 0.75],\n", + "}\n", + "\n", + "# Wrap the validated dataframe in the VectorStoreSearchOutput class\n", + "sample_results_object = VectorStoreSearchOutput(sample_data)\n", + "\n", + "\n", + "# pass the sample results object to the formatting function and print the output\n", + "print(format_prompt_with_vectorstore_results(sample_results_object))\n", + "\n", + "print(\"\\n\\n-----\\n\\n\")\n", + "# showing the original result object for comparison with the XML string\n", + "sample_results_object" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Focusing on the output of the model, we expect either a value of -1 or a value between 1 and N. We write a function that validates that output and then reduces the original VectorStoreSearchOutput object down to a single-row result based on the output label. If the agent returns -1 or an invalid response, the function plays it safe and returns the original results object unmodified. This design choice could be adapted to other behaviours, such as falling back to the top-ranked item when the agent declines to classify." 
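The fallback behaviour described above can be sketched without the DataFrame machinery. The `pick_candidate` helper below is a hypothetical, stdlib-only illustration of the same logic (it is not part of ClassifAI and skips the Pydantic validation step):

```python
import json


def pick_candidate(agent_text: str, candidates: list[str]) -> list[str]:
    """Return the single chosen candidate, or the full list on -1/invalid output."""
    try:
        label = int(json.loads(agent_text)["classification"])
    except (json.JSONDecodeError, KeyError, TypeError, ValueError):
        return candidates  # unparseable response: play it safe
    if 1 <= label <= len(candidates):
        return [candidates[label - 1]]  # 1-based label -> 0-based index
    return candidates  # -1 or out-of-range: keep the full ranking


docs = ["wedding photographer", "event photographer", "nature photographer"]
print(pick_candidate('{"classification": 2}', docs))   # single chosen candidate
print(pick_candidate('{"classification": -1}', docs))  # full list unchanged
```

Any failure mode (bad JSON, missing key, out-of-range label) falls through to the full candidate list, mirroring the "play it safe" design.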
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "\n", + "def format_agent_classification(\n", + " agent_generated_text: str, results_df: VectorStoreSearchOutput\n", + ") -> VectorStoreSearchOutput:\n", + " # Parse the generated text\n", + " try:\n", + " response = json.loads(agent_generated_text)\n", + " validation_model = make_classification_model(n=results_df.shape[0])\n", + " validated_response = validation_model(**response)\n", + " except (json.JSONDecodeError, ValueError):\n", + " # If parsing or validation fails, return the original DataFrame\n", + " return results_df\n", + "\n", + " # Extract the classification\n", + " classification = validated_response.classification\n", + "\n", + " # Validate the classification value is in the expected range\n", + " MIN_INDEX = 1\n", + " MAX_INDEX = len(results_df)\n", + " if int(classification) < MIN_INDEX or int(classification) > MAX_INDEX:\n", + " return results_df\n", + "\n", + " # Otherwise, keep only the row for the chosen candidate, converting the 1-based classification to a 0-based position\n", + " result = results_df.iloc[[classification - 1]].reset_index(drop=True)\n", + "\n", + " return VectorStoreSearchOutput(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can try this function out using our example results object and, for now, a manually written agent JSON response." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "genai_response = {\"classification\": \"3\"}\n", + "\n", + "classified_result_df = format_agent_classification(json.dumps(genai_response), sample_results_object)\n", + "\n", + "classified_result_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you change the classification label value, you'll see different results being selected; if the classification label is set to -1, 0, a value greater than 5, or some other invalid value, the full dataset is returned." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Instantiating the Agent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now want to pull all this together and demonstrate the generative agent performing the above task." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from google import genai\n", + "from google.genai import types\n", + "\n", + "# Set up the API key for authentication\n", + "client = genai.Client(api_key=\"YOUR API KEY HERE\")\n", + "\n", + "# Define the model to use\n", + "chosen_model = \"gemini-2.5-flash\" # Replace with the appropriate Gemini Flash 2.5 model name if different\n", + "\n", + "\n", + "prompt = format_prompt_with_vectorstore_results(sample_results_object)\n", + "\n", + "\n", + "response = client.models.generate_content(\n", + " model=chosen_model,\n", + " contents=prompt,\n", + " config=types.GenerateContentConfig(\n", + " system_instruction=make_classification_system_prompt(n=len(sample_results_object)),\n", + " response_mime_type=\"application/json\",\n", + " response_json_schema=make_classification_model(n=len(sample_results_object)).model_json_schema(),\n", + " ),\n", + ")\n", + "\n", + "\n", + "classification_df = format_agent_classification(response.text, sample_results_object)\n", + "\n", + 
"classification_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Done! The model takes a few seconds to respond, but we now have a complete flow: a VectorStore result is formatted into a prompt, the model is given a system instruction and structured Pydantic guidance on what to produce, and a post-processing function extracts the final answer." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Building a Hook to process results with the agent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "All that's left to do is to write the post-processing search() hook, which will be called on the VectorStoreSearchOutput object after the method produces it. We provide the full hook implementation in the next cell with extra comments, tying together the various components we've discussed up until now." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# instantiate the genai Client outside the function to avoid re-creating it on every call\n", + "client = genai.Client(api_key=\"YOUR API KEY HERE\")\n", + "chosen_model = \"gemini-2.5-flash\" # we could choose another model, for example gemini-3 variants are available\n", + "\n", + "# we've already defined make_classification_system_prompt, format_prompt_with_vectorstore_results and make_classification_model, so we can reuse those directly without redefining them\n", + "\n", + "\n", + "def rag_classifier(input_data: VectorStoreSearchOutput) -> VectorStoreSearchOutput:\n", + " # our result df might contain multiple queries, so we group by query_id to ensure we classify each unique query separately\n", + " per_query = list(input_data.groupby([\"query_id\"]))\n", + "\n", + " # create an empty list to store the classified results for each query\n", + " classified_results = []\n", + "\n", + " # 
iterate over the per-query groups\n", + " for _, group in per_query:\n", + " # set a manual maximum limit on how many samples the model can classify: the more samples we include, the more context we provide, but also the more expensive the model call becomes. Here we take the top 5 results for each query, but this could be adjusted based on the expected number of results and the cost/benefit trade-off of providing more context to the model.\n", + " # the code can handle any value of N; this is a practical limit relating to the generative model's context window and effectiveness.\n", + " group = group.head(5) # noqa: PLW2901\n", + "\n", + " # pass the group dataframe to the formatting function to get the prompt for this query\n", + " prompt = format_prompt_with_vectorstore_results(group)\n", + "\n", + " # send the prompt to the model and get the response, passing the same system instruction and structured json schema as before\n", + " response = client.models.generate_content(\n", + " model=chosen_model,\n", + " contents=prompt,\n", + " config=types.GenerateContentConfig(\n", + " # there are many other config options we could experiment with here, for example temperature, top_k, top_p (nucleus sampling), and max_output_tokens\n", + " system_instruction=make_classification_system_prompt(n=group.shape[0]),\n", + " response_mime_type=\"application/json\",\n", + " response_json_schema=make_classification_model(n=group.shape[0]).model_json_schema(),\n", + " ),\n", + " )\n", + "\n", + " # pass the result to our postprocessing function, which will either return the chosen classification or, on failure, the original result\n", + " classified_group = format_agent_classification(response.text, group)\n", + "\n", + " # the classification for that query is added to the classified_results list\n", + " classified_results.append(classified_group)\n", + "\n", + " # we convert the results to a single dataframe\n", + " final_result = 
pd.concat(classified_results).reset_index(drop=True)\n", + "\n", + " # we wrap the final result in the VectorStoreSearchOutput class to ensure it adheres to the expected schema\n", + " hook_output = VectorStoreSearchOutput(final_result)\n", + "\n", + " return hook_output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's see how that works when we apply it to a VectorStore as a hook!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.indexers import VectorStore\n", + "from classifai.vectorisers import GcpVectoriser\n", + "\n", + "# Initialise the vectoriser\n", + "demo_vectoriser = GcpVectoriser(project_id=\"YOUR PROJECT ID HERE\", vertexai=True)\n", + "\n", + "# Initialise the vector store, pointing the demo to the demo test data\n", + "demo_vectorstore = VectorStore(\n", + " file_name=\"./data/fake_soc_dataset.csv\",\n", + " data_type=\"csv\",\n", + " vectoriser=demo_vectoriser,\n", + " output_dir=\"./demo_vdb\",\n", + " overwrite=True,\n", + " hooks={\"search_postprocess\": rag_classifier},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.indexers import VectorStoreSearchInput\n", + "\n", + "search_input = VectorStoreSearchInput(\n", + " {\"id\": [\"1\", \"2\"], \"query\": [\"a photographer hired for wedding events\", \"Tax expert advisor\"]}\n", + ")\n", + "\n", + "result = demo_vectorstore.search(search_input, n_results=5)\n", + "\n", + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's it! Without the added generative hook we would see a set of candidate results in each case here. 
By dynamically removing the GenAI hook, we can see which results were considered by the generative RAG model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "demo_vectorstore.hooks = {}\n", + "result_no_genai = demo_vectorstore.search(search_input, n_results=5)\n", + "result_no_genai" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## That's it! We've walked through:\n", + "\n", + "- Establishing connections to Google Generative AI models,\n", + "\n", + "- Configuring AI model input and output for classification tasks and how models should behave,\n", + "\n", + "- Setting up pre- and post-processing functions to handle the model input and output.\n", + "\n", + "In the last part of this notebook (below) we provide a full code implementation of a VectorStore with a GenAI classifier hook, but here are some additional final points to consider:\n", + "\n", + "- For each query and result, we call the Google API sequentially - which works, but can be slow - in the final implementation below we show how the calls can be made asynchronously for each query to speed things up.\n", + "\n", + "- We didn't add any checks to ensure the user query cannot break the pre-processing formatting function, for example by containing punctuation or markup that interferes with the XML structure.\n", + "\n", + "- There are many models available, with different trade-offs in cost and capability, as well as many configuration arguments that can affect performance. While this guide shows good practices using system prompts and other features, it is worth taking time to review the Google documentation on how best to use these models and achieve consistent performance." 
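As an illustration of hardening the formatter against the second point, user-supplied text could be escaped before it is interpolated into the XML-style prompt. A stdlib-only sketch (this `safe_xml_text` helper is illustrative and not part of the notebook's implementation):

```python
from xml.sax.saxutils import escape


def safe_xml_text(raw: str) -> str:
    """Escape characters that could break or spoof the XML prompt structure."""
    return escape(raw, {'"': "&quot;", "'": "&apos;"})


# A query containing markup can no longer terminate the surrounding element early
print(safe_xml_text("</user_query><entry>injected</entry>"))
```

Applying this to both the query text and the candidate document text would keep the prompt structure intact regardless of input.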
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tying it all together" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following code cells provide a complete implementation of the generative AI RAG agent (reimporting dependencies and reinitialising all variables).\n", + "\n", + "In this version we make one key change: an _asynchronous_ version of the API calls to the genai client, to speed up the processing of queries. Note that ClassifAI does not currently support async code in hooks, therefore we nest our agent code in an outer synchronous function.\n", + "\n", + "`Warning:` To run this final section of the code you will need to execute it outside of the notebook environment, as it contains async code that does not operate nicely within the Jupyter Notebook environment. We recommend executing it in an external script with ClassifAI[gcp] installed in the environment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "import json\n", + "from typing import Literal\n", + "\n", + "import pandas as pd\n", + "from google import genai\n", + "from google.genai import types\n", + "from pydantic import BaseModel, Field, conint, create_model\n", + "\n", + "from classifai.indexers.dataclasses import VectorStoreSearchInput, VectorStoreSearchOutput\n", + "\n", + "\n", + "# this function generates the system prompt for the classification task: it takes the number of candidate entries as input and returns a formatted string that includes instructions for the model on how to classify the user query based on the provided context. The prompt specifies the expected format of the input and output, and provides guidelines for how the model should prioritize the context when making its classification\n", + "def make_classification_system_prompt(n: int) -> str:\n", + " CLASSIFICATION_SYSTEM_PROMPT = \"\"\"You are an AI assistant designed to classify a user query based on the provided context. 
You will be provided with candidate entries retrieved from a knowledge base, each containing an ID and a text description. Your task is to analyze the user query and the text of the context entries to determine which of the entries best matches the user query.\n", + "\n", + " Guidelines:\n", + " 1. Always prioritize the provided context when making your classification.\n", + " 2. The context will be provided as an XML structure containing multiple entries. Each entry includes an ID and a text description.\n", + " 3. The IDs will be integer values from 1 to {n}, corresponding to the {n} candidate entries.\n", + " 4. Use the text of the entries to determine the most relevant classification for the user query.\n", + " 5. Your output must be a JSON object that adheres to the following schema:\n", + " - The JSON object must contain a single key, `classification`.\n", + " - The value of `classification` must be an integer between 1 and {n}, representing the ID of the best matching entry.\n", + " - If no classification can be determined due to ambiguity or insufficient information, the value of `classification` must be `-1`.\n", + "\n", + " Example of the required JSON output:\n", + " {{\n", + " 'classification': 1\n", + " }}\n", + "\n", + " The XML structure for the context and user query will be as follows:\n", + " <context>\n", + " <entry>\n", + " <id>1</id>\n", + " <text>[Text from the first entry]</text>\n", + " </entry>\n", + " <entry>\n", + " <id>2</id>\n", + " <text>[Text from the second entry]</text>\n", + " </entry>\n", + " ...\n", + " <entry>\n", + " <id>{n}</id>\n", + " <text>[Text from the last entry]</text>\n", + " </entry>\n", + " </context>\n", + "\n", + " <user_query>\n", + " [The user query will be inserted here]\n", + " </user_query>\n", + "\n", + " Your task is to analyze the context and the user query, and return the classification in the required structured format.\n", + " \"\"\"\n", + "\n", + " return CLASSIFICATION_SYSTEM_PROMPT.format(n=n)\n", + "\n", + "\n", + "# this function generates a Pydantic model for validating the model's response: it takes the number of candidate entries as input and 
returns a dynamically created Pydantic model class that includes a single field 'classification' which can be either -1 or an integer between 1 and n. This model is used to ensure that the response from the generative model adheres to the expected format and value constraints\n", + "def make_classification_model(n: int) -> type[BaseModel]:\n", + " if n < 1:\n", + " raise ValueError(\"n must be >= 1\")\n", + "\n", + " PositiveId = conint(ge=1, le=n) # type: ignore[valid-type]\n", + "\n", + " return create_model(\n", + " \"ClassificationResponseModel\",\n", + " classification=(\n", + " Literal[-1] | PositiveId,\n", + " Field(description=f\"-1 if unclassifiable, else an integer in [1, {n}] (no 0).\"),\n", + " ),\n", + " )\n", + "\n", + "\n", + "# this function converts the VectorStoreSearchOutput to a formatted string that can be passed as a prompt to the generative model: it extracts the user query and the candidate entries and formats them into the required XML structure as specified in the system prompt\n", + "def format_prompt_with_vectorstore_results(results_df) -> str:\n", + " # Extract the user query (assuming all rows have the same query_id and query_text)\n", + " user_query = results_df[\"query_text\"].iloc[0]\n", + "\n", + " # Build the <context> section\n", + " context_entries = \"\\n\".join(\n", + " f\" <entry>\\n <id>{idx + 1}</id>\\n <text>{row['doc_text']}</text>\\n </entry>\"\n", + " for idx, row in results_df.iterrows()\n", + " )\n", + "\n", + " # Combine everything into the final prompt\n", + " formatted_prompt = f\"\"\"<context>\n", + "{context_entries}\n", + "</context>\n", + "\n", + "<user_query>\n", + " {user_query}\n", + "</user_query>\"\"\"\n", + "\n", + " return formatted_prompt\n", + "\n", + "\n", + "# this function takes the generative model's response and the original results dataframe: it attempts to parse and validate the model's response, and if successful it filters the original results dataframe to return only the row corresponding to the classification returned by the model. 
If parsing or validation fails, or if the classification is out of range, it returns the original results dataframe unfiltered\n", + "def format_agent_classification(\n", + " agent_generated_text: str, results_df: VectorStoreSearchOutput\n", + ") -> VectorStoreSearchOutput:\n", + " # Parse the generated text\n", + " try:\n", + " response = json.loads(agent_generated_text)\n", + " validation_model = make_classification_model(n=results_df.shape[0])\n", + " validated_response = validation_model(**response)\n", + " except (json.JSONDecodeError, ValueError):\n", + " # If parsing or validation fails, return the original DataFrame\n", + " return results_df\n", + "\n", + " # Extract the classification\n", + " classification = validated_response.classification\n", + "\n", + " # Validate the classification value is in the expected range\n", + " MIN_INDEX = 1\n", + " MAX_INDEX = len(results_df)\n", + " if int(classification) < MIN_INDEX or int(classification) > MAX_INDEX:\n", + " return results_df\n", + "\n", + " # Otherwise, keep only the row for the chosen candidate, converting the 1-based classification to a 0-based position\n", + " result = results_df.iloc[[classification - 1]].reset_index(drop=True)\n", + "\n", + " return VectorStoreSearchOutput(result)\n", + "\n", + "\n", + "# here we create a fresh client instance so that we have access to the async interface; per the SDK documentation, we create a new client instance for the async code and close it after use, to avoid potential issues with reusing the same client instance across sync and async code\n", + "client = genai.Client(api_key=\"YOUR API KEY HERE\")\n", + "chosen_model = \"gemini-2.5-flash\" # could use another model here if desired, for example a gemini-3 variant\n", + "MAX_CONCURRENCY = 8\n", + "\n", + "\n", + "# this function is an async version of the classification function for a single query group: it takes the async client, the results dataframe for a single query, and 
a semaphore to limit concurrency. It formats the prompt, sends the request to the model, and returns the classified result\n", + "async def _classify_result(\n", + " aclient,\n", + " result: VectorStoreSearchOutput,\n", + " sem: asyncio.Semaphore,\n", + ") -> VectorStoreSearchOutput:\n", + " group = result.head(5)\n", + " prompt = format_prompt_with_vectorstore_results(group)\n", + " n = group.shape[0]\n", + "\n", + " async with sem:\n", + " response = await aclient.models.generate_content(\n", + " model=chosen_model,\n", + " contents=prompt,\n", + " config=types.GenerateContentConfig(\n", + " system_instruction=make_classification_system_prompt(n=n),\n", + " response_mime_type=\"application/json\",\n", + " response_json_schema=make_classification_model(n=n).model_json_schema(),\n", + " ),\n", + " )\n", + "\n", + " return format_agent_classification(response.text, group)\n", + "\n", + "\n", + "# this function is the main async function that orchestrates the classification of the results: it groups the results by query, creates async tasks for each group, and gathers the results. 
Finally, it concatenates the classified results and returns them wrapped in the VectorStoreSearchOutput class\n", + "async def rag_classifier_async(\n", + " input_data: VectorStoreSearchOutput,\n", + ") -> VectorStoreSearchOutput:\n", + " # our result df might contain multiple queries, so we group by query_id to ensure we classify each unique query separately\n", + " per_query = [group for _, group in input_data.groupby([\"query_id\"])]\n", + "\n", + " # This is the SDK-provided async interface\n", + " aclient = client.aio\n", + " sem = asyncio.Semaphore(MAX_CONCURRENCY)\n", + "\n", + " try:\n", + " tasks = [_classify_result(aclient, group, sem) for group in per_query]\n", + " classified_results = await asyncio.gather(*tasks)\n", + " finally:\n", + " # close the async client when finished, as recommended by the SDK docs\n", + " await aclient.aclose()\n", + "\n", + " # we convert the results to a single dataframe\n", + " final_result = pd.concat(classified_results).reset_index(drop=True)\n", + " return VectorStoreSearchOutput(final_result)\n", + "\n", + "\n", + "# this function is a simple wrapper around the async classification function to allow it to be used as a hook in the VectorStore: it runs the async function using asyncio.run and returns the result\n", + "def rag_hook(input_data: VectorStoreSearchOutput) -> VectorStoreSearchOutput:\n", + " return asyncio.run(rag_classifier_async(input_data))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.indexers import VectorStore\n", + "from classifai.vectorisers import GcpVectoriser\n", + "\n", + "# We instantiate a vectoriser (we could use any Vectoriser from ClassifAI here, but we'll use the GCP one to keep it consistent with the rest of the demo)\n", + "imp_vectoriser = GcpVectoriser(project_id=\"YOUR GCP PROJECT ID HERE\", vertexai=True)\n", + "\n", + "\n", + "# Build the vectorstore which points to the demo data and includes 
the RAG classification hook in the search_postprocess step, this means that every time we call search on this vectorstore, after it retrieves the candidate entries it will pass them to the rag_hook function which will classify the results using the Gemini model before returning them\n", + "imp_vectorstore = VectorStore(\n", + " file_name=\"./data/fake_soc_dataset.csv\",\n", + " data_type=\"csv\",\n", + " vectoriser=imp_vectoriser,\n", + " output_dir=\"./demo_vdb\",\n", + " overwrite=True,\n", + " hooks={\"search_postprocess\": rag_hook},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create a new demo input dataframe to showcase the RAG classification being performed asynchronously.\n", + "imp_search_input = VectorStoreSearchInput(\n", + " {\n", + " \"id\": [\"1\", \"2\", \"3\", \"4\", \"5\"],\n", + " \"query\": [\n", + " \"a photographer hired for wedding events\",\n", + " \"Tax expert advisor\",\n", + " \"farmer of potatoes\",\n", + " \"Tax expert advisor\",\n", + " \"farmer of potatoes\",\n", + " ],\n", + " }\n", + ")\n", + "\n", + "# finally run the VectorStore search method which will call the RAG classification hook we have defined on the search results\n", + "imp_vectorstore.search(imp_search_input, n_results=5)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "classifai", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index c384170..8f36fb6 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -12,6 +12,7 @@ - Batch processing of input files to handle large 
datasets. - Support for CSV file format (additional formats may be added in future updates). - Integration with a custom embedder for generating vector embeddings. +- Support for user-defined hooks for preprocessing and postprocessing. - Logging for tracking progress and handling errors during processing. Dependencies: