diff --git a/.gitignore b/.gitignore index 36f855d..cf5ef59 100644 --- a/.gitignore +++ b/.gitignore @@ -210,3 +210,19 @@ examples/RAG/compose/volumes # End of https://www.gitignore.io/api/vim,c++,cmake,python,synology /.idea + +# Ruff +.ruff_cache/ + +# ALM example - exclude large data/model files +examples/asset_lifecycle_management/data/ +examples/asset_lifecycle_management/moment/ +examples/asset_lifecycle_management/database/ +examples/asset_lifecycle_management/database_vanna/ + +# ALM example - exclude specific markdown files (keeping README.md files) +examples/asset_lifecycle_management/COMPARISON.md +examples/asset_lifecycle_management/INSTALLATION.md +examples/asset_lifecycle_management/MIGRATION_SUMMARY.md +examples/asset_lifecycle_management/README_MIGRATION.md +examples/asset_lifecycle_management/test_comparison_output/COMPARISON_SUMMARY.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..dfe1535 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,333 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Repository Overview + +This repository contains community examples for the NVIDIA NeMo Agent Toolkit (NAT). The repository is organized into three main directories: + +- **`examples/`** - Self-contained demonstrations of NAT features and integration patterns +- **`packages/`** - Reusable, production-ready NAT components that can be shared across examples and industries +- **`industries/`** - Complex, domain-specific workflows that showcase real-world applications + +All components use Python 3.11+ and are managed with `uv` for fast dependency resolution. 
+ +## Building and Testing + +### Initial Setup + +```bash +# Clone and fetch LFS files +git clone https://github.com/NVIDIA/NeMo-Agent-Toolkit-Examples.git +cd NeMo-Agent-Toolkit-Examples +git lfs install +git lfs fetch +git lfs pull + +# Create virtual environment +uv venv --python 3.13 --seed .venv +source .venv/bin/activate + +# Install development dependencies +uv sync --dev +``` + +### Running Tests + +```bash +# Run tests for all components +pytest + +# Run tests for a specific component +pytest examples/mcp_rag_demo/tests/ +pytest packages/nat_vanna_tool/tests/ + +# Run with coverage +pytest --cov +``` + +### Linting and Code Style + +```bash +# Run all checks (pre-commit, pylint, copyright, documentation) +./ci/scripts/checks.sh + +# Run only Python checks (pylint) +./ci/scripts/python_checks.sh + +# Run pre-commit hooks manually +pre-commit run --all-files --show-diff-on-failure +``` + +The repository uses: +- **ruff** for linting and import sorting (configured in root `pyproject.toml`) +- **yapf** for code formatting (max line length: 120) +- **pylint** for Python code quality checks +- **vale** for documentation linting + +### Installing Individual Components + +```bash +# Install an example +uv pip install -e examples/mcp_rag_demo + +# Install a package +uv pip install -e packages/nat_vanna_tool + +# Install an industry workflow +uv pip install -e industries/asset_lifecycle_management + +# Install with optional dependencies +uv pip install -e "packages/nat_vanna_tool[elasticsearch,postgres]" +``` + +## Architecture + +### NAT Component Registration Pattern + +All NAT components follow a plugin-based registration pattern using Python entry points: + +1. 
**Define a config class** that inherits from `FunctionBaseConfig`: + +```python +from nat.data_models.function import FunctionBaseConfig +from pydantic import Field + +class MyToolConfig(FunctionBaseConfig, name="my_tool"): + """Configuration for my tool.""" + param1: str = Field(description="Parameter 1") + param2: int = Field(default=5, description="Parameter 2") +``` + +2. **Create a function with `@register_function` decorator**: + +```python +from nat.cli.register_workflow import register_function +from nat.builder.builder import Builder +from nat.builder.function_info import FunctionInfo +from pydantic import BaseModel, Field + +@register_function(config_type=MyToolConfig) +async def my_tool_function(config: MyToolConfig, builder: Builder): + """Tool implementation.""" + + class MyToolInput(BaseModel): + query: str = Field(description="User query") + + async def _execute(query: str) -> str: + # Access config: config.param1, config.param2 + # Access builder for LLMs/embedders: await builder.get_llm(name) + return f"Result: {query}" + + yield FunctionInfo.from_fn( + _execute, + input_schema=MyToolInput, + description="Description of what this tool does" + ) +``` + +3. **Create a `register.py` entry point** that imports the function: + +```python +# src/nat_my_component/register.py +from .my_tool import my_tool_function +``` + +4. 
**Declare the entry point in `pyproject.toml`**: + +```toml +[project.entry-points.'nat.components'] +nat_my_component = "nat_my_component.register" +``` + +### YAML Configuration Structure + +NAT workflows are configured via hierarchical YAML files: + +```yaml +general: + use_uvloop: true + telemetry: + logging: + console: + _type: console + level: DEBUG + +llms: # Define available LLMs + reasoning_llm: + _type: nvidia_chat_model + model_name: "meta/llama-3.1-70b-instruct" + api_key: "${NVIDIA_API_KEY}" + +embedders: # Define available embedders + my_embedder: + _type: nvidia_embeddings + model_name: "nvidia/nv-embedqa-e5-v5" + api_key: "${NVIDIA_API_KEY}" + +functions: # Define available tools + my_tool: + _type: my_tool # References the name from FunctionBaseConfig + param1: "value" + llm_name: "reasoning_llm" # Reference to llms section + +workflow: # Orchestrate the agent + _type: react_agent + tool_names: [my_tool] + llm_name: reasoning_llm + max_iterations: 10 +``` + +Components reference each other by name strings (e.g., `llm_name: "reasoning_llm"`), which are resolved at runtime by NAT's builder. 
+ +### Directory Structure for New Components + +Use `nat workflow create` to generate the recommended structure: + +``` +examples/$EXAMPLE_NAME/ +├── configs/ # Symlink to src/nat_$EXAMPLE_NAME/configs/ +├── data/ # [Optional] Symlink to src/nat_$EXAMPLE_NAME/data/ +├── scripts/ # [Optional] Setup scripts +├── src/ +│ └── nat_$EXAMPLE_NAME/ # Module name must start with nat_ +│ ├── configs/ +│ │ └── config.yml +│ ├── data/ # [Optional] Data files +│ ├── __init__.py +│ └── register.py # Entry point +├── tests/ # pytest tests +├── README.md +└── pyproject.toml # Must register entry point +``` + +### Dependency Management + +- **Root `pyproject.toml`**: Defines shared linting rules and dev dependencies (pytest, ruff, yapf, vale) +- **Component `pyproject.toml`**: Each example/package/industry has its own dependencies +- **Version constraints**: Use at least 2-digit precision (e.g., `nvidia-nat~=1.2`, not `nvidia-nat==1`) +- **Local package dependencies**: Industries can reference packages via `[tool.uv.sources]`: + +```toml +[tool.uv.sources] +nat_vanna_tool = { path = "../../packages/nat_vanna_tool", editable = true } + +dependencies = [ + "nvidia-nat>=1.3.0", + "nat_vanna_tool", # References local package +] +``` + +### Multi-Framework Support + +Components can support multiple AI frameworks via `framework_wrappers`: + +```python +from nat.builder.builder import LLMFrameworkEnum + +@register_function( + config_type=MyToolConfig, + framework_wrappers=[LLMFrameworkEnum.LANGCHAIN] +) +async def my_tool(config: MyToolConfig, builder: Builder): + # Get LangChain-wrapped embedder + embedder = await builder.get_embedder( + config.embedder_name, + wrapper_type=LLMFrameworkEnum.LANGCHAIN + ) +``` + +## Common Development Tasks + +### Running NAT Workflows + +```bash +# Run a workflow with a config file +nat run --config_file examples/mcp_rag_demo/configs/support-ui.yml --input "your query" + +# Start an MCP server +nat mcp serve --config_file 
examples/mcp_rag_demo/configs/support-ui.yml --port 9904 + +# Start a UI server +nat serve --config_file examples/mcp_rag_demo/configs/mcp-client-for-ui.yml --port 8000 +``` + +### Adding a New Example + +1. Generate the structure: `nat workflow create` +2. Add dependencies to `pyproject.toml` with version constraints +3. Define config schema in `src/nat_$NAME/config_schemas.py` +4. Implement the function with `@register_function` decorator +5. Import the function in `register.py` +6. Add entry point to `pyproject.toml` +7. Create config files in `configs/` +8. Write tests in `tests/` using pytest +9. Document in README.md using `examples/.template/README.md` as a starting point + +### Creating a Reusable Package + +1. Create in `packages/nat_$NAME/` +2. Follow the same pattern as examples, but focus on reusability +3. Support optional dependencies via `[project.optional-dependencies]` +4. Document all configuration options in README.md +5. Examples and industries can then reference it via `[tool.uv.sources]` + +### Working with Existing Components + +When modifying existing examples/packages: +- Read the component's README.md first +- Check `pyproject.toml` for dependencies and entry points +- Look at `register.py` to see what functions are registered +- Review config files in `configs/` to understand usage +- Run tests after changes: `pytest path/to/component/tests/` + +## Important Notes + +### Git Commit Signing + +**All commits to this repository must be signed with GPG.** + +```bash +# Create a commit with GPG signature +git commit -S -m "Your commit message" + +# Amend the last commit to add a signature +git commit --amend -S --no-edit + +# Configure Git to always sign commits +git config --global commit.gpgsign true +``` + +When creating commits, always include: +- Clear, descriptive commit message +- Co-authored-by line: `Co-Authored-By: Claude Sonnet 4.5 ` when working with Claude + +### Code Style + +- All Python code must be compatible with ruff linting 
rules defined in root `pyproject.toml` +- Maximum line length: 120 characters +- Use isort for import organization (ruff handles this) +- Known first-party packages: `aiq`, `nat`, `nat_*`, `_utils` +- Known third-party packages: `agno`, `crewai`, `langchain`, `llama_index`, `mem0ai`, `redis`, `semantic_kernel`, `zep_cloud` + +### Licensing + +- All contributions must be Apache 2.0 licensed +- All files must include the SPDX copyright header (checked by CI) +- All dependencies must have Apache 2.0 compatible licenses + +### Security + +- Input validation is critical for tools that interact with databases or external systems +- Use whitelisting for categories, priorities, and other enum-like inputs (see `query_by_category_tool` in mcp_rag_demo for an example) +- Never include API keys or credentials in code or config files - use environment variables + +### Documentation Requirements + +Each component must include: +- Description of what the component does +- Key features +- Setup instructions +- How to run the component +- Expected results +- Troubleshooting tips (optional but recommended) diff --git a/examples/asset_lifecycle_management/.gitignore b/examples/asset_lifecycle_management/.gitignore new file mode 100644 index 0000000..b96852e --- /dev/null +++ b/examples/asset_lifecycle_management/.gitignore @@ -0,0 +1,21 @@ +# SQL tool exploration and comparison files (deferred for future work) +new_sql_tool_exploration/ + +# Output data files +output_data/*.json + +# Session-specific documentation +CLAUDE.md + +# Config documentation (consolidated into README.md) +configs/*.md +!configs/README.md + +# Test scripts +test_alm_workflow.py +test_e2b_sandbox.py + +# Evaluation and test output folders +eval_output_old/ +example_eval_output/ +test_comparison_output/ diff --git a/examples/asset_lifecycle_management/README.md b/examples/asset_lifecycle_management/README.md new file mode 100644 index 0000000..d38e4f5 --- /dev/null +++ 
b/examples/asset_lifecycle_management/README.md @@ -0,0 +1,911 @@ +# Asset Lifecycle Management Agent + +An AI-powered system for managing industrial assets throughout their lifecycle, built with NeMo Agent Toolkit. Currently focused on predictive maintenance for turbofan engines with plans to expand to full lifecycle management. + +Work done by: Vineeth Kalluru, Janaki Vamaraju, Sugandha Sharma, Ze Yang, and Viraj Modak + +## Overview + +Asset Lifecycle Management (ALM) spans acquisition, operation, upgrades, and retirement of industrial assets. This project delivers an agentic workflow that applies ALM ideas to real data. Today it focuses on the operation and maintenance slice: using time‑series sensor data to predict remaining useful life (RUL), detect anomalies, and recommend next steps. We use the NASA C‑MAPSS turbofan dataset as a practical, well‑studied benchmark with realistic signals and run‑to‑failure trajectories. The system is modular and backed by SQL (SQLite by default, PostgreSQL/MySQL supported), so extending into planning, commissioning, optimization, and decommissioning is straightforward as additional tools and integrations are added. 
+ +## Dataset + +Uses the **NASA Turbofan Engine Degradation Simulation Dataset (C-MAPSS)** with: +- **21 Sensor Measurements**: Temperature, pressure, vibration, and flow +- **3 Operational Settings**: Different flight conditions +- **Multiple Engine Units**: Each with unique degradation patterns +- **Run-to-Failure Data**: Complete lifecycle from healthy operation to failure + +## Architecture + +Multi-agent architecture designed for Asset Lifecycle Management with specialized tools for the Operation & Maintenance phase: +- **ReAct Agent Workflow**: Main orchestration using ReAct pattern for intelligent decision-making +- **SQL Retriever Tool**: Generates SQL queries using NIM LLM for asset data retrieval +- **RUL Prediction Tool**: XGBoost model for remaining useful life prediction to optimize maintenance scheduling +- **Anomaly Detection Tool**: Detects anomalies in sensor data using time series foundation models (MOMENT-1-Large by default, NV Tesseract NIM available as alternative) for early failure detection +- **Plotting Agents**: Multi-tool agent for data visualization and asset performance reporting +- **Vector Database**: ChromaDB for storing table schema, Vanna training queries, and asset documentation + +This architecture provides the foundation for comprehensive asset health monitoring, enabling data-driven maintenance decisions and extending asset operational life. 
+ +#### Agentic workflow architecture diagram w/ reasoning +![Agentic workflow w/ reasoning](imgs/pdm_agentic_worklow_light.png) + +## Setup and Installation + +> 📖 **For detailed installation instructions including database setup (PostgreSQL, MySQL, SQLite) and vector store configuration (ChromaDB, Elasticsearch), see [INSTALLATION.md](INSTALLATION.md)** + +### Prerequisites +- Python 3.11+ (< 3.13) +- Conda or Miniconda +- NVIDIA NIM API access +- Node.js v18+ (for web interface) + +### Hardware Requirements + +**CPU:** +- Minimum: 8 cores, 16GB RAM + +**GPU:** + +| Model Name | Minimum GPU Requirement | +|------------------------------------------|---------------------------------| +| qwen/qwen2.5-coder-32b-instruct | 2×A100 or 1×H100 | +| nvidia/llama-3.3-nemotron-super-49b-v1 | 2×A100 or 1×H100 | +| nvidia/llama-3.1-nemotron-nano-vl-8b-v1 | 1×A100 or 1×H100 | +| nvidia/nv-embed-v1 | 1×L40s or 1×A100 | + +- GPU is **not required** if you are using NVIDIA NIM cloud APIs. +- For local deployment, the table above lists the minimum recommended GPUs for each model. +- For most applications, a system with 8×A100 or 4×H100 GPUs will be more than sufficient (application sizing is still being finalized). + +**Operating System:** +- Linux (Ubuntu 20.04+ recommended) +- macOS (Intel or Apple Silicon) - Tested on MAC M2 Pro 32GB RAM +- Windows 10/11 with WSL2 + +**Storage:** +- **Base Installation**: 250MB (includes database, MOMENT library, and source code) +- **NASA Dataset**: ~50-100MB (when downloaded from Kaggle) +- **Working Space**: 500MB-1GB (for logs, additional outputs, temporary files) +- **Recommended Total**: 2-3GB free space for comfortable operation +**Approximate Memory Requirements for Local Model Storage (if not using NVIDIA endpoints):** + +| Model Name | Approx. 
Memory Requirement (BF16) | +|-----------------------------------------|-----------------------------------| +| qwen/qwen2.5-coder-32b-instruct | 64 GB | +| nvidia/llama-3.3-nemotron-super-49b-v1 | 98 GB | +| nvidia/llama-3.1-nemotron-nano-vl-8b-v1 | 16 GB | +| nvidia/nv-embed-v1 | ~14 GB | + +### 1. Create Conda Environment + +```bash +conda create -n alm python=3.11 +conda activate alm +``` + +### 2. Install NVIDIA NeMo Agent Toolkit + +1. Clone the NeMo Agent Toolkit repository version 1.2.1 to your local machine: + ```bash + git clone --branch v1.2.1 https://github.com/NVIDIA/NeMo-Agent-Toolkit.git nat-toolkit + cd nat-toolkit + ``` + +2. Initialize, fetch, and update submodules in the Git repository: + ```bash + git submodule update --init --recursive + ``` + +3. Fetch the datasets by downloading the LFS files: + ```bash + git lfs install + git lfs fetch + git lfs pull + ``` +4. Install the NeMo Agent Toolkit library: + To install the NeMo Agent Toolkit library along with all optional dependencies, including developer tools (`--all-groups`) and all dependencies needed for profiling and plugins (`--all-extras`) in the source repository, run the following: + ```bash + uv sync --all-groups --all-extras + ``` + +5. Install telemetry plugins: + ```bash + uv pip install -e '.[telemetry]' + ``` + +### 3. Install Asset Lifecycle Management Agent + +First, clone the GenerativeAIExamples repository inside the parent folder of NeMo-Agent-Toolkit and navigate to the Asset Lifecycle Management Agent folder: + +```bash +git clone https://github.com/NVIDIA/GenerativeAIExamples.git +cd GenerativeAIExamples/industries/manufacturing/asset_lifecycle_management_agent +``` + +Clone the MOMENT library from GitHub inside this Asset Lifecycle Management Agent folder. +This library is required to perform inference with MOMENT-1 time series foundational models for anomaly detection tasks during the Operation & Maintenance phase. 
More about it [here](https://huggingface.co/AutonLab/MOMENT-1-small). + +```bash +git clone https://github.com/moment-timeseries-foundation-model/moment.git +``` + +Change the pyproject.toml file inside the cloned library: + +```bash +cd moment +vi pyproject.toml +``` + +Change the NumPy and Transformers dependencies: + +```bash +... +dependencies = [ + "huggingface-hub==0.24.0", + "numpy==1.25.2", # --> to "numpy==1.26.2" + "torch~=2.0", + "transformers==4.33.3", # --> to "transformers>=4.33.3,<5.0.0" +] +... +``` + +Go back to the Asset Lifecycle Management Agent folder: + +```bash +cd .. +``` + +Change the path to the cloned MOMENT library in `/path/to/asset_lifecycle_management_agent/pyproject.toml` if necessary. + +Change it from: +```bash +[tool.uv.sources] +momentfm = { path = "/Users/vikalluru/Documents/GenerativeAIExamples/industries/manufacturing/asset_lifecycle_management_agent/moment", editable = true } +``` +to: +```bash +[tool.uv.sources] +momentfm = { path = "/your/path/to/asset_lifecycle_management_agent/moment", editable = true } +``` + +This ensures that the MOMENT library will be installed from our cloned version instead of the PyPI release. +Now install the ALM workflow: + +```bash +uv pip install -e . +``` + +#### Installation Options + +**Base Installation** (default - includes ChromaDB + SQLite): +```bash +uv pip install -e . +``` + +**Optional Database Support:** +- PostgreSQL: `uv pip install -e ".[postgres]"` +- MySQL: `uv pip install -e ".[mysql]"` +- All databases: `uv pip install -e ".[all-databases]"` + +**Optional Vector Store:** +- Elasticsearch: `uv pip install -e ".[elasticsearch]"` + +### [Optional] Verify if all prerequisite packages are installed +```bash +uv pip list | grep -E "nvidia-nat|nvidia-nat-ragaai|nvidia-nat-phoenix|vanna|chromadb|xgboost|pytest|torch|matplotlib" +``` + +### 4. Database Setup + +1. Download the [NASA Turbofan Dataset](https://www.kaggle.com/datasets/behrad3d/nasa-cmaps) +2. 
Extract files to the `data/` directory +3. Run the setup script: +```bash +python setup_database.py +``` + +### 5. Configure Paths + +**Important**: You need to replace the absolute path `/Users/vikalluru/Documents/GenerativeAIExamples/industries/manufacturing/asset_lifecycle_management_agent/` with your preferred workspace path in the following files: + +1. **`configs/config-reasoning.yml`** - Update the `db_path` and `output_folder` paths +2. **`pyproject.toml`** - Update the MOMENT library path (if you changed it in step 3) + +For example, if your workspace is at `/home/user/my_workspace/`, you would replace: +- `/Users/vikalluru/Documents/GenerativeAIExamples/industries/manufacturing/asset_lifecycle_management_agent/` +- with `/home/user/my_workspace/` + +**Note**: All other paths in the config file can be provided as relative paths from your workspace directory. Only the MOMENT library path in `pyproject.toml` needs to be an absolute path. + +The `db_path` should point to the database inside your data directory: +```yaml +db_path: "data/nasa_turbo.db" +``` + +#### Output Folder Configuration + +Create an empty folder for the output data and configure the `output_folder` path. 
You have two options: + +**Option 1: Relative Path (Recommended)** +```yaml +output_folder: "output_data" +``` +- Path is relative to where you run the workflow +- **Recommended**: Always run the workflow from the `asset_lifecycle_management_agent/` directory +- Creates `output_data/` folder in your project directory + +**Option 2: Absolute Path** +```yaml +output_folder: "/absolute/path/to/your/output_data" +``` +- Works regardless of where you run the workflow from +- Provides consistent output location + +**Best Practice**: We recommend using relative paths and always running the workflow from the `asset_lifecycle_management_agent/` directory: + +```bash +cd /path/to/GenerativeAIExamples/industries/manufacturing/asset_lifecycle_management_agent/ +# Run all workflow commands from here +nat serve --config_file=configs/config-reasoning.yml +``` + +This ensures that: +- All relative paths work correctly +- Output files are organized within your project +- Configuration remains portable across different machines + +#### Setting Up Workspace Utilities + +**IMPORTANT**: The code generation assistant requires a `utils` folder inside your `output_data` directory for RUL transformation tasks. + +**Setup Instructions:** + +1. Create the output_data directory (if it doesn't exist): +```bash +mkdir -p output_data +``` + +2. Copy the pre-built utility functions from the template: +```bash +cp -r utils_template output_data/utils +``` + +3. Verify the setup: +```bash +ls output_data/utils/ +# You should see: __init__.py rul_utils.py +``` + +**What's included:** +- `apply_piecewise_rul_transformation(file_path, maxlife=100)` - Transforms RUL data to create realistic "knee" patterns +- `show_utilities()` - Display available utility functions + +These utilities are automatically available to the code generation assistant when running in the Docker sandbox (mapped as `/workspace/utils/`). 
The system will only import these utilities when specifically needed for RUL transformations, preventing unnecessary module loading errors (`ModuleNotFoundError: No module named 'utils'` will not occur). + +**How It Works:** +- When you start the sandbox with `output_data/` as the mount point, the `utils/` folder becomes accessible at `/workspace/utils/` +- The code generation assistant only imports utils when performing RUL transformations +- For regular tasks (data retrieval, plotting, etc.), utils are not imported, avoiding module errors + +**Note**: If you move your `output_data` folder, make sure the `utils` subfolder comes with it, or copy it from `utils_template` again. + +### 6. Vanna SQL Agent Training (Automatic) + +**Important**: The Vanna SQL agent training happens automatically when you start the workflow server. The `vanna_training_data.yaml` file contains pre-configured domain-specific knowledge that will be loaded automatically during server startup. + +This training data includes: +- **Synthetic DDL statements**: Table schemas for all NASA turbofan datasets +- **Domain documentation**: Detailed explanations of database structure and query patterns +- **Example queries**: Common SQL patterns for turbofan data analysis +- **Question-SQL pairs**: Natural language to SQL mappings + +The automatic training helps the SQL agent understand: +- How to distinguish between training, test, and RUL tables +- Proper handling of remaining useful life calculations +- Domain-specific terminology and query patterns +- Table relationships and data structure + +Training configuration is specified in `configs/config-reasoning.yml`: +```yaml +vanna_training_data_path: "vanna_training_data.yaml" +``` + +**Note**: If you modify your database structure or add new query patterns, update the `vanna_training_data.yaml` file accordingly to maintain optimal SQL generation performance. + +### 7. 
Set Environment Variables + +Set the required environment variables for the workflow: + +1. Create a `.env` file from the template and update it with your actual values: + ```bash + cp env_template.txt .env + ``` + + Then edit the `.env` file and replace the placeholder values with your actual keys: + ```bash + # Replace the placeholder values with your actual keys + NVIDIA_API_KEY="your-actual-nvidia-api-key" + CATALYST_ACCESS_KEY="your-actual-catalyst-access-key" # Optional + CATALYST_SECRET_KEY="your-actual-catalyst-secret-key" # Optional + ``` + +2. Source the file to export the variables: + ```bash + source .env + ``` + +**Note**: The `env_template.txt` file contains placeholder values. Copy it to `.env` and replace them with your actual API keys before sourcing the file. + +Verify that the NVIDIA API key is set: + +```bash +echo $NVIDIA_API_KEY +``` + +## Launch Server and UI + +### Start FastAPI Server + +With other frameworks like LangGraph or CrewAI, users are expected to develop a FastAPI server to interact with their agentic workflow. Fortunately, NeMo Agent Toolkit offers this out of the box with the simple `nat serve --config_file ` command. + +Start the server now: + +```bash +nat serve --config_file=configs/config-reasoning.yml +``` + +You should see something like this, which indicates that the server started successfully: + +```bash +... +... +INFO: Application startup complete. +INFO: Uvicorn running on http://localhost:8000 (Press CTRL+C to quit) +``` + +During startup, you'll see Vanna training logs as the SQL agent automatically loads the domain knowledge from `vanna_training_data.yaml` (as described in Section 6). + +### Start Modern Web UI (Recommended) + +We now provide a **custom modern web interface** inspired by the NVIDIA AIQ Research Assistant design! This UI offers a superior experience for Asset Lifecycle Management workflows compared to the generic NeMo-Agent-Toolkit-UI. 
+ +**In a new terminal**, navigate to the frontend directory and start the UI: + +```bash +cd frontend +npm install # First time only +npm start +``` + +The UI will be available at `http://localhost:3000` + +**Features of the Modern UI:** +- 🎨 Clean, professional NVIDIA-branded design +- 📊 Embedded visualization display for plots and charts +- 🎯 Quick-start example prompts for common queries +- ⚙️ Configurable settings panel +- 🌓 Dark/Light theme support +- 📱 Fully responsive mobile design +- 🔄 Real-time streaming responses + +See `frontend/README.md` for detailed documentation. + +### Start Code Execution Sandbox + +The code generation assistant requires a standalone Python sandbox that can execute the generated code. This step starts that sandbox. + +Note: You will need a system that can run Docker. If you are running this on a macOS laptop without Docker Desktop, try [Colima](https://github.com/abiosoft/colima). + +Navigate to the NeMo Agent Toolkit code execution directory: + +```bash +cd /path-to/NeMo-Agent-Toolkit/src/nat/tool/code_execution/ +``` + +Start the sandbox by running the script with your output folder path: + +```bash +./local_sandbox/start_local_sandbox.sh local-sandbox /path-to-output-folder-as-specified-in-config-yml/ +``` + +For example: + +```bash +./local_sandbox/start_local_sandbox.sh local-sandbox /path-to/GenerativeAIExamples/industries/manufacturing/asset_lifecycle_management_agent/output_data/ +``` + +[Optional] Verify the sandbox is running correctly: + +```bash +# Test code execution (the main endpoint) +curl -X POST http://localhost:6000/execute \ + -H 'Content-Type: application/json' \ + -d '{"generated_code": "print(\"Hello from sandbox!\")", "timeout": 10, "language": "python"}' +``` + +To stop the sandbox when you're done, stop the Docker container: + +```bash +docker stop local-sandbox +``` + +### Alternative: E2B Cloud Sandbox (No Docker Required) + +E2B provides cloud-hosted code execution without requiring Docker 
installation. This is ideal for development environments, CI/CD pipelines, or machines without Docker. + +**Setup:** + +1. **Get E2B API Key** + - Sign up at https://e2b.dev/auth/sign-up + - Get your API key from https://e2b.dev/dashboard + +2. **Install E2B SDK** + ```bash + uv pip install -e ".[e2b]" + ``` + +3. **Set API Key** + ```bash + export E2B_API_KEY="your-e2b-api-key" + ``` + +4. **Update Configuration** + + In `configs/config-reasoning.yml`, change the `code_execution_tool` to use E2B: + + ```yaml + code_generation_assistant: + _type: code_generation_assistant + llm_name: "coding_llm" + code_execution_tool: "e2b_code_execution" # Change from "code_execution" to "e2b_code_execution" + verbose: true + + # Uncomment this section + e2b_code_execution: + _type: e2b_code_execution + e2b_api_key: "${E2B_API_KEY}" + workspace_files_dir: "output_data" + timeout: 30.0 + max_output_characters: 2000 + ``` + +**Comparison:** + +| Feature | Local Docker | E2B Cloud | +|---------|-------------|-----------| +| Setup | Requires Docker + container | Just API key | +| Speed (cold start) | ~2-5 seconds | ~150ms | +| Execution | Fast (local) | + network overhead | +| File access | Direct mount | Upload/download | +| Database | Direct access | Must upload (~50-100MB) | +| Cost | Free (local resources) | API usage based | +| Network | Not required | Required | +| Best for | Development, large databases | CI/CD, cloud deployments | + +**E2B Resources:** +- Documentation: https://e2b.dev/docs +- Configuration Guide: See `configs/README.md` for detailed E2B and SSL setup +- Test Script: `test_e2b_sandbox.py` for standalone testing + +## Workspace Utilities + +The Asset Lifecycle Management Agent includes a powerful **workspace utilities system** that provides pre-built, reliable functions for common data processing tasks. This eliminates the need for the code generation assistant to implement complex algorithms from scratch, resulting in more reliable and consistent results. 
+ +### How Workspace Utilities Work + +The utilities are located in `/workspace/utils/` (which maps to your `output_data/utils/` directory). Instead of asking the LLM to generate complex transformation code through multiple agent layers (which can lose context and introduce errors), this system provides pre-tested utility functions that can be invoked with simple instructions. + +**Architecture Benefits**: +- **Reliability**: Pre-tested, robust implementations instead of generated code +- **Consistency**: Same results every time, no variation in algorithm implementation +- **Simplicity**: Reasoning agent just needs to specify "use RUL utility" instead of detailed pseudo-code +- **Error Handling**: Comprehensive validation and user-friendly error messages +- **In-Place Operations**: Files are modified directly, avoiding unnecessary copies + +### Available Utilities + +#### RUL Transformation Utilities + +**`apply_piecewise_rul_transformation(file_path, maxlife=100)`** +- Transforms RUL data to create realistic "knee" patterns +- **Input**: JSON file with engine time series data +- **Output**: pandas DataFrame with original data plus new 'transformed_RUL' column (also saves file in-place) +- **Pattern**: RUL stays constant at `MAXLIFE` until remaining cycles drop below threshold, then decreases linearly to 0 +- **Use case**: Creating realistic RUL patterns for comparison with predicted values + +### Usage in Workflows + +**For Users**: When interacting with the system, you can request complex data transformations knowing that reliable utilities will handle the implementation. For example: + +``` +"Transform the actual RUL data for engine 24 to piecewise representation with MAXLIFE=100" +``` + +**For Developers**: The code generation assistant automatically uses these utilities when available. The system prompts include instructions to: +1. Check if a task can be accomplished using workspace utilities +2. Import utilities with proper path setup +3. 
Use utilities instead of custom implementations + +### Example Workflow + +1. **User Request**: "Compare actual vs predicted RUL for engine unit 24" +2. **System Process**: + - Retrieves ground truth data from database + - Predicts RUL using the model + - **Uses utility**: `utils.apply_piecewise_rul_transformation(data_file, maxlife=100)` (returns DataFrame) + - Generates comparison visualization +3. **Result**: Clean, reliable transformation with consistent knee pattern + +### Adding Custom Utilities + +You can extend the utilities by adding new functions to `/output_data/utils/`: + +1. **Create your utility function** in `utils/` directory +2. **Import it** in `utils/__init__.py` +3. **Document it** in the help system +4. **Update system prompts** if needed (optional) + +**Example utility structure**: +```python +def your_custom_utility(file_path: str, param: int = 100) -> str: + """ + Your custom utility function. + + Args: + file_path: Path to input file + param: Your parameter + + Returns: + Success message with details + """ + # Implementation with error handling + # ... + return "✅ Custom utility executed successfully!" +``` + +### Best Practices + +1. **Prefer Utilities**: Always check if existing utilities can handle your task +2. **Error Handling**: Utilities include comprehensive validation - no need to duplicate +3. **In-Place Operations**: Utilities modify files directly, avoiding data duplication +4. **Consistent Interface**: All utilities return descriptive success messages +5. 
**Documentation**: Use `utils.show_utilities()` to discover available functions + +### Alternative: Generic NeMo-Agent-Toolkit UI + +If you prefer the generic NeMo Agent Toolkit UI instead of our custom interface: + +```bash +git clone https://github.com/NVIDIA/NeMo-Agent-Toolkit-UI.git +cd NeMo-Agent-Toolkit-UI +npm ci +npm run dev +``` +The UI is available at `http://localhost:3000` + +**Configure UI Settings:** +- Click the Settings icon (bottom left) +- Set HTTP URL to `/chat/stream` (recommended) +- Configure theme and WebSocket URL as needed +- Check "Enable intermediate results" and "Enable intermediate results by default" if you prefer to see all agent calls while the workflow runs + +**Note:** The custom modern UI (described above) provides better visualization embedding, domain-specific examples, and a more polished experience tailored for Asset Lifecycle Management workflows. + +## Anomaly Detection Options + +The Asset Lifecycle Management Agent provides two anomaly detection approaches for sensor data analysis: + +### MOMENT-1-Large (Default) + +**Overview**: Local time-series foundation model that detects anomalies without requiring external API calls. + +**Configuration**: Pre-configured in `configs/config-reasoning.yaml`: +```yaml +anomaly_detection: + _type: moment_anomaly_detection_tool + output_folder: "output_data" +``` + +**Usage Pattern**: +1. Retrieve sensor data using SQL retriever (saves as JSON file) +2. Pass JSON file path to anomaly_detection tool +3. Tool adds 'is_anomaly' boolean column to the data +4. Visualize with plot_anomaly tool + +**Example**: +``` +Retrieve and detect anomalies in sensor 4 measurements for engine number 78 in train FD001 dataset. +``` + +**Advantages**: +- No API key required +- Fast local execution +- Works offline +- No usage costs + +### NV Tesseract (Alternative - NVIDIA NIM) + +**Overview**: NVIDIA's production-grade anomaly detection foundation model accessible via NIM endpoints. 
Provides advanced anomaly analysis with forecasting capabilities. + +**Setup**: + +1. **Update Configuration**: In `configs/config-reasoning.yaml`, uncomment the NV Tesseract section: + + ```yaml + nv_tesseract_anomaly_detection: + _type: nv_tesseract_anomaly_detection + llm_name: "reasoning_llm" # NIM endpoint with NV Tesseract model + model_name: "nvidia/nv-anomaly-tesseract-1.0" + lookback_period: 30 # Number of time steps to analyze + forecast_horizon: 10 # Number of time steps to forecast + ``` + +2. **Update Tool List**: In the `data_analysis_assistant` tool_names, change `anomaly_detection` to `nv_tesseract_anomaly_detection` + +3. **Update System Prompt**: Change references from `anomaly_detection` to `nv_tesseract_anomaly_detection` in the system prompt + +**Usage Pattern**: +- Provide unit_number and dataset_name directly +- No need to pre-fetch data as JSON +- Tool queries database and performs analysis + +**Example**: +``` +Detect anomalies for unit 78 in train_FD001 dataset using NV Tesseract. +``` + +**Advantages**: +- Production-grade accuracy +- Includes forecasting +- Detailed anomaly explanations +- Identifies specific problematic sensors +- Anomaly score (0-1 scale) + +**Comparison**: + +| Feature | MOMENT-1-Large | NV Tesseract NIM | +|---------|----------------|------------------| +| Setup | Pre-configured | Requires NIM access | +| API Key | Not required | NVIDIA_API_KEY required | +| Cost | Free (local) | API usage based | +| Execution | Local | Cloud NIM endpoint | +| Input | JSON file path | unit_number + dataset_name | +| Output | Binary anomaly flags | Detailed analysis + scores | +| Forecasting | No | Yes (configurable horizon) | +| Best for | Quick checks, offline use | Production, detailed analysis | + +**Note**: Both options work with the same downstream visualization tools (plot_anomaly) and integrate seamlessly with the agentic workflow. 
+
+## Example Prompts
+
+Test the system with these prompts:
+
+**Data Retrieval:**
+```
+Retrieve the time in cycles and operational setting 1 from the FD001 test table for unit number 1 and plot its value vs time.
+```
+
+![Data Retrieval Example](imgs/test_prompt_1.png)
+
+**Visualization:**
+```
+Retrieve real RUL of each unit in the FD001 test dataset. Then plot a distribution of it.
+```
+
+![Visualization Example](imgs/test_prompt_2.png)
+
+
+**Anomaly Detection**
+```
+Retrieve and detect anomalies in sensor 4 measurements for engine number 78 in train FD001 dataset.
+```
+
+![Anomaly Detection Example](imgs/test_prompt_4.png)
+
+**Workspace Utilities Demo**
+```
+Retrieve RUL values and time in cycles for engine unit 24 from FD001 train dataset. Use the piecewise RUL transformation code utility to perform piecewise RUL transformation on the ground truth RUL values with MAXLIFE=100. Finally, plot a comparison line chart with RUL values and their transformed values across time.
+```
+
+*This example demonstrates how to discover and use workspace utilities directly. The system will show available utilities and then apply the RUL transformation using the pre-built, reliable utility functions.*
+
+**Prediction and Comparison (Uses Workspace Utilities)**
+```
+Perform the following steps:
+
+1. Retrieve the time in cycles, all sensor measurements, and ground truth RUL values, partitioned by unit number, for engine unit 24 from FD001 train dataset.
+2. Use the retrieved data to predict the Remaining Useful Life (RUL).
+3. Use the piecewise RUL transformation code utility to apply piecewise RUL transformation only to the observed RUL column with MAXLIFE of 100.
+4. Generate a plot that compares the transformed RUL values and the predicted RUL values across time.
+``` +![Prediction Example](imgs/test_prompt_3.png) + +*Note: This example automatically uses the workspace `apply_piecewise_rul_transformation` utility to create realistic knee-pattern RUL data for comparison, resulting in much cleaner and more meaningful visualizations.* + +## Observability (Optional) + +### Monitor Your System with Phoenix + +Ensure that Phoenix tracing-related information is present in the config file. + +Uncomment this portion of `asset_lifecycle_management_agent/configs/config-reasoning.yml` file: + +```yaml +... + # Uncomment this to enable tracing + # tracing: + # phoenix: + # _type: phoenix + # endpoint: http://localhost:6006/v1/traces + # project: alm-test # You can replace this with your preferred project name +... +``` + +```bash +# Docker (recommended) +docker run -p 6006:6006 -p 4317:4317 arizephoenix/phoenix:latest + +# Or install as package +uv pip install arize-phoenix +phoenix serve +``` +Access the dashboard at `http://localhost:6006` to monitor traces, performance, and costs. + +### Monitor Your System with Catalyst + +Follow the instructions [here](https://github.com/NVIDIA/NeMo-Agent-Toolkit/blob/develop/docs/source/workflows/observe/observe-workflow-with-catalyst.md) to set up your RAGA AI profile. + +Ensure you update the CATALYST-related environment variables in the `.env` file (uncomment and set the values) and source that file again: + +```bash +CATALYST_ACCESS_KEY="xxxxxxxxxxxxxxxx" # Change this to your RAGA AI Access key +CATALYST_SECRET_KEY="xxxxxxxxxxxxxxxxxxxxxxxx" # Change this to your RAGA AI Secret key +CATALYST_ENDPOINT=https://catalyst.raga.ai/api # Don't change this +``` + +Uncomment this portion of `asset_lifecycle_management_agent/configs/config-reasoning.yml` file to enable Catalyst tracing: + +```yaml +... 
+ # Uncomment this to enable tracing
+ # tracing:
+ # catalyst:
+ # _type: catalyst
+ # project: "alm-test" # You can replace this with your preferred project name
+ # dataset: "alm-dataset" # You can replace this with your preferred dataset name
+...
+```
+
+You should see Catalyst initialization-related information in the terminal when you launch the workflow server.
+
+## [Optional] Testing the Workflow
+
+NeMo Agent Toolkit provides the flexibility to run workflows not just through terminal commands (`nat serve`) but also programmatically in Python, which helps in seamless CI/CD pipeline integration.
+
+You can test the workflow by running the `test_alm_workflow.py` file using pytest instead of starting the server, which provides a Pythonic way of building and running the workflow programmatically. This approach is particularly valuable for continuous integration and deployment systems, allowing automated validation of workflow components and streamlined deployment processes.
+
+Ensure that you have set the `$NVIDIA_API_KEY` environment variable before running:
+
+```bash
+pytest test_alm_workflow.py -m e2e -v
+```
+
+To run individual tests in the file, replace `<test_name>` with the name (or a substring) of the test to run:
+
+```bash
+pytest test_alm_workflow.py -k "<test_name>" -v
+```
+
+## Evaluation
+
+This example comes with 25 curated queries and reference answers that form our evaluation dataset. You can access this in the `eval_data/eval_set_master.json` file.
+
+### Multimodal Evaluation with Vision-Language Models
+
+We have implemented an innovative **Multimodal LLM Judge Evaluator** for agentic workflow evaluation, specifically designed for Asset Lifecycle Management tasks that generate both text and visual outputs.
+ +**Why Custom Multimodal Evaluation?** + +The built-in evaluators in NeMo Agent Toolkit have significant limitations: +- **Text-Only Evaluation**: Cannot assess visual outputs like plots and charts +- **Rigid String Matching**: Uses LangChain's `TrajectoryEvalChain` which only looks for exact patterns +- **No Visual Understanding**: Cannot evaluate whether generated plots match expected visualizations +- **Generic Prompts**: Not tailored for asset management and maintenance domain + +**Our Innovative Multimodal Approach:** + +**Vision-Language Model Evaluation** - Uses `nvidia/llama-3.1-nemotron-nano-vl-8b-v1` (VLM) to evaluate both text semantic correctness and visual plot accuracy in a unified evaluation framework. + +**Key Innovation: Dual-Mode Intelligent Evaluation** +- **Text Evaluation Mode**: When no plots are present, evaluates semantic correctness of text responses +- **Visual Evaluation Mode**: When plot images are detected, automatically switches to visual analysis mode to assess plot accuracy against reference descriptions +- **Automatic Plot Detection**: System automatically discovers plot file paths in responses and includes actual plot images in the evaluation + +**Advantages:** +- ✅ **Unified Multimodal Assessment**: Single VLM evaluates both text quality and visual accuracy +- ✅ **Intelligent Mode Switching**: Automatically detects whether to evaluate text or plots +- ✅ **Visual Understanding**: Can assess if generated plots show correct data patterns, axis labels, trends +- ✅ **Simple Scoring**: Supports only three scores from 0.0 (fully incorrect), 0.5 (partially correct) to 1.0 (fully correct) +- ✅ **Domain-Specific**: Tailored prompts for Asset Lifecycle Management and maintenance visualization patterns + +We have created a smaller version of this dataset in `eval_data/eval_set_test.json` to help with quick checks before running the larger evaluation workflow. 
+ +### Evaluate with NAT + +Update the config file with the path to the evaluation set. + +In `asset_lifecycle_management_agent/configs/config-reasoning.yml`: +```yaml +eval: + general: + output: + dir: "eval_output" + cleanup: true + dataset: + _type: json + file_path: "eval_data/eval_set_master.json" # Path to eval dataset + query_delay: 10 # Change this to increase delay between running queries, useful if your underlying API (like build.nvidia.com) has requests/second or rate limits + max_concurrent: 1 # Change this to the number of eval set entries that should be processed concurrently. Keep it at 1 to ensure smooth execution +``` + +Now, run this command: + +```bash +nat eval --config_file configs/config-reasoning.yml +``` + +You should see an `eval_output` folder generated in your working directory with `multimodal_eval_output.json`. We have provided you with an example output in `eval_output/example_multimodal_eval_output.json`. + +### Model Performance Notes + +Our evaluation results show that **GPT 4.1 mini model achieves higher accuracies** when used as the ReAct agent compared to other models. If you're looking to maximize evaluation performance, consider configuring GPT 4.1 mini as your reasoning agent in the workflow configuration. + +Add this to the config file: +``` +analyst_llm: + _type: openai + model_name: "gpt-4.1-mini" +``` + +## Known Issues + +- **Rate Limiting**: When using hosted build.nvidia.com endpoints, you may receive `[429] Too Many Requests` errors. This happens because the agentic workflow can generate a high volume of requests in a short period, exceeding the service's rate limits. To avoid these errors, consider running models locally instead of relying on the hosted endpoints. + +- **Code Generation Failures**: Sometimes the code generation assistant cannot generate correct code and reaches the maximum retry limit. In this case, you may see a workflow response like "I seem to have a problem." 
Try running the query again - we are actively working to improve the code generation assistant experience. + +- **Poor Response Quality**: If you're not getting good responses with the provided LLMs: + - First, switch the reasoning model if the generated plan appears incorrect + - Then, swap the analyst LLM to a model that excels at both tool calling and instruction following + - You typically won't need to replace the embedding model, SQL model, code generation model, or evaluation model. + +- **Evaluation Benchmark**: The workflow currently achieves an `average_score` of 0.75 or above on the master evaluation dataset. We are actively working to improve this score toward 1.0. + +## Next Steps + +The Asset Lifecycle Management Agent provides a foundation for comprehensive industrial asset management. Planned enhancements include: + +**Operation & Maintenance Phase:** +- Memory layer for context retention across maintenance sessions +- Parallel tool execution for faster responses +- Action recommendation agent for maintenance prioritization +- Real-time fault detection agent +- Integration with NVIDIA's NV-Tesseract foundation models for improved time-series accuracy +- Integration with NeMo Retriever for enhanced data source context + +**Expanded ALM Capabilities:** +- **Planning & Acquisition**: Tools for asset specification analysis, vendor comparison, and TCO (Total Cost of Ownership) calculation +- **Deployment & Commissioning**: Integration with commissioning checklists, validation protocols, and asset registration systems +- **Upgrades & Optimization**: Performance benchmarking tools, upgrade recommendation engines, and ROI analysis +- **Decommissioning & Disposal**: End-of-life planning tools, environmental compliance tracking, and asset value recovery optimization + +**Evaluation & Quality:** +- Expansion of evaluation dataset with complex queries involving advanced SQL queries like CTEs +- Additional evaluation metrics for ALM-specific tasks +--- + 
+**Resources:** +- [NeMo Agent Toolkit Documentation](https://docs.nvidia.com/nemo-agent-toolkit/) +- [Phoenix Observability](https://phoenix.arize.com/) +- [NV-Tesseract Models](https://developer.nvidia.com/blog/new-nvidia-nv-tesseract-time-series-models-advance-dataset-processing-and-anomaly-detection/) diff --git a/examples/asset_lifecycle_management/configs/README.md b/examples/asset_lifecycle_management/configs/README.md new file mode 100644 index 0000000..2222396 --- /dev/null +++ b/examples/asset_lifecycle_management/configs/README.md @@ -0,0 +1,776 @@ +# SQL Query and Retrieve Tool Configuration Guide + +This comprehensive guide explains how to configure the SQL Query and Retrieve Tool, covering both vector store backends and SQL database connections. + +## Table of Contents +1. [Vector Store Configuration](#vector-store-configuration) +2. [SQL Database Configuration](#sql-database-configuration) +3. [Complete Configuration Examples](#complete-configuration-examples) +4. [Troubleshooting](#troubleshooting) + +--- + +## Vector Store Configuration + +### Overview + +The tool supports **two vector store backends** for storing Vanna AI SQL training data: +- **ChromaDB** (local, file-based) - Default +- **Elasticsearch** (distributed, server-based) + +Both vector stores provide identical functionality and store the same data (DDL, documentation, question-SQL pairs). + +### Quick Start - Vector Stores + +#### Option 1: ChromaDB (Recommended for Development) + +```yaml +functions: + - name: my_sql_tool + type: generate_sql_query_and_retrieve_tool + llm_name: nim_llm + embedding_name: nim_embeddings + + # ChromaDB Configuration (DEFAULT) + vector_store_type: chromadb + vector_store_path: ./vanna_vector_store + + # Database and other settings... 
+ db_connection_string_or_path: ./database.db + db_type: sqlite + output_folder: ./output + vanna_training_data_path: ./training_data.yaml +``` + +**Requirements:** +- No additional services required +- No extra Python packages needed + +#### Option 2: Elasticsearch (Recommended for Production) + +```yaml +functions: + - name: my_sql_tool + type: generate_sql_query_and_retrieve_tool + llm_name: nim_llm + embedding_name: nim_embeddings + + # Elasticsearch Configuration + vector_store_type: elasticsearch + elasticsearch_url: http://localhost:9200 + elasticsearch_index_name: vanna_sql_vectors # Optional + elasticsearch_username: elastic # Optional + elasticsearch_password: changeme # Optional + + # Database and other settings... + db_connection_string_or_path: ./database.db + db_type: sqlite + output_folder: ./output + vanna_training_data_path: ./training_data.yaml +``` + +**Requirements:** +- Elasticsearch service must be running +- Install: `pip install elasticsearch` + +### Detailed Comparison - Vector Stores + +| Feature | ChromaDB | Elasticsearch | +|---------|----------|---------------| +| **Setup Complexity** | Simple | Moderate | +| **External Services** | None required | Requires ES cluster | +| **Storage Type** | Local file-based | Distributed | +| **High Availability** | No | Yes (with clustering) | +| **Horizontal Scaling** | No | Yes | +| **Best For** | Dev, testing, single-server | Production, multi-user | +| **Authentication** | File system | API key or basic auth | +| **Performance** | Fast for single-user | Fast for multi-user | +| **Backup** | Copy directory | ES snapshots | + +### When to Use Each Vector Store + +#### Use ChromaDB When: +✅ Getting started or prototyping +✅ Single-server deployment +✅ Local development environment +✅ Simple setup required +✅ No existing Elasticsearch infrastructure +✅ Small to medium data volume + +#### Use Elasticsearch When: +✅ Production environment +✅ Multiple instances/users need access +✅ Need high availability 
and clustering +✅ Already have Elasticsearch infrastructure +✅ Need advanced search capabilities +✅ Distributed deployment required +✅ Large scale deployments + +### Vector Store Configuration Parameters + +#### Common Parameters (Both Vector Stores) +```yaml +llm_name: string # LLM to use +embedding_name: string # Embedding model to use +db_connection_string_or_path: string # Database connection +db_type: string # 'sqlite', 'postgres', or 'sql' +output_folder: string # Output directory +vanna_training_data_path: string # Training data YAML file +``` + +#### ChromaDB-Specific Parameters +```yaml +vector_store_type: chromadb # Set to 'chromadb' +vector_store_path: string # Directory for ChromaDB storage +``` + +#### Elasticsearch-Specific Parameters +```yaml +vector_store_type: elasticsearch # Set to 'elasticsearch' +elasticsearch_url: string # ES URL (e.g., http://localhost:9200) +elasticsearch_index_name: string # Index name (default: vanna_vectors) +elasticsearch_username: string # Optional: for basic auth +elasticsearch_password: string # Optional: for basic auth +elasticsearch_api_key: string # Optional: alternative to username/password +``` + +### Elasticsearch Authentication + +Choose one of these authentication methods: + +#### Option 1: API Key (Recommended) +```yaml +elasticsearch_api_key: your-api-key-here +``` + +#### Option 2: Basic Auth +```yaml +elasticsearch_username: elastic +elasticsearch_password: changeme +``` + +#### Option 3: No Auth (Development Only) +```yaml +# Omit all auth parameters +``` + +### Data Migration Between Vector Stores + +#### From ChromaDB to Elasticsearch +1. Export training data from ChromaDB +2. Update configuration to use Elasticsearch +3. Run tool - it will auto-initialize Elasticsearch with training data + +#### From Elasticsearch to ChromaDB +1. Training data is reloaded from YAML file automatically +2. Update configuration to use ChromaDB +3. 
Run tool - it will auto-initialize ChromaDB + +### Vector Store Troubleshooting + +#### ChromaDB Issues +**Problem:** `FileNotFoundError` or permission errors +**Solution:** Ensure directory exists and has write permissions + +**Problem:** Slow performance +**Solution:** ChromaDB is single-threaded, consider Elasticsearch for better performance + +#### Elasticsearch Issues +**Problem:** `ConnectionError` or `ConnectionTimeout` +**Solution:** Verify Elasticsearch is running: `curl http://localhost:9200` + +**Problem:** `AuthenticationException` +**Solution:** Check username/password or API key + +**Problem:** Index already exists with different mapping +**Solution:** Delete index and let tool recreate: `curl -X DELETE http://localhost:9200/vanna_vectors` + +--- + +## SQL Database Configuration + +### Overview + +The tool supports **multiple SQL database types** through a unified `db_connection_string_or_path` parameter: +- **SQLite** (local, file-based) - Default +- **PostgreSQL** (open-source RDBMS) +- **MySQL** (open-source RDBMS) +- **SQL Server** (Microsoft database) +- **Oracle** (enterprise database) +- **Any SQLAlchemy-compatible database** + +### Quick Start - SQL Databases + +#### Option 1: SQLite (File-Based, No Server Required) + +```yaml +db_connection_string_or_path: ./database.db # Just a file path +db_type: sqlite +``` + +**Requirements:** +- No additional services required +- No extra Python packages needed (sqlite3 is built-in) + +#### Option 2: PostgreSQL + +```yaml +db_connection_string_or_path: postgresql://user:password@localhost:5432/database +db_type: postgres +``` + +**Requirements:** +- PostgreSQL server must be running +- Install: `pip install psycopg2-binary` + +#### Option 3: MySQL + +```yaml +db_connection_string_or_path: mysql+pymysql://user:password@localhost:3306/database +db_type: sql # Generic SQL via SQLAlchemy +``` + +**Requirements:** +- MySQL server must be running +- Install: `pip install pymysql sqlalchemy` + +#### Option 4: 
SQL Server + +```yaml +db_connection_string_or_path: mssql+pyodbc://user:pass@host:1433/db?driver=ODBC+Driver+17+for+SQL+Server +db_type: sql # Generic SQL via SQLAlchemy +``` + +**Requirements:** +- SQL Server must be running +- Install: `pip install pyodbc sqlalchemy` +- Install ODBC Driver for SQL Server + +#### Option 5: Oracle + +```yaml +db_connection_string_or_path: oracle+cx_oracle://user:password@host:1521/?service_name=service +db_type: sql # Generic SQL via SQLAlchemy +``` + +**Requirements:** +- Oracle database must be running +- Install: `pip install cx_Oracle sqlalchemy` + +### Detailed Comparison - SQL Databases + +| Feature | SQLite | PostgreSQL | MySQL | SQL Server | Oracle | +|---------|--------|------------|-------|------------|--------| +| **Setup** | None | Server required | Server required | Server required | Server required | +| **Cost** | Free | Free | Free | Licensed | Licensed | +| **Use Case** | Dev/testing | Production | Production | Enterprise | Enterprise | +| **Concurrent Users** | Limited | Excellent | Excellent | Excellent | Excellent | +| **File-Based** | Yes | No | No | No | No | +| **Advanced Features** | Basic | Advanced | Good | Advanced | Advanced | +| **Python Driver** | Built-in | psycopg2 | pymysql | pyodbc | cx_Oracle | + +### When to Use Each Database + +#### Use SQLite When: +✅ Development and testing +✅ Prototyping and demos +✅ Single-user applications +✅ No server infrastructure required +✅ Small to medium data volume +✅ Embedded applications +✅ Quick setup needed + +#### Use PostgreSQL When: +✅ Production deployments +✅ Multi-user applications +✅ Need advanced SQL features +✅ Open-source preference +✅ Need strong data integrity +✅ Complex queries and analytics +✅ GIS data support needed + +#### Use MySQL When: +✅ Web applications +✅ Read-heavy workloads +✅ Need wide compatibility +✅ Open-source preference +✅ Large-scale deployments +✅ Replication required + +#### Use SQL Server When: +✅ Microsoft ecosystem +✅ 
Enterprise applications +✅ .NET integration needed +✅ Advanced analytics (T-SQL) +✅ Business intelligence +✅ Existing SQL Server infrastructure + +#### Use Oracle When: +✅ Large enterprise deployments +✅ Mission-critical applications +✅ Need advanced features (RAC, Data Guard) +✅ Existing Oracle infrastructure +✅ High-availability requirements +✅ Maximum performance needed + +### Connection String Formats + +#### SQLite +``` +Format: /path/to/database.db +Example: ./data/sales.db +Example: /var/app/database.db +``` + +#### PostgreSQL +``` +Format: postgresql://username:password@host:port/database +Example: postgresql://admin:secret@db.example.com:5432/sales_db +Example: postgresql://user:pass@localhost:5432/mydb +``` + +#### MySQL +``` +Format: mysql+pymysql://username:password@host:port/database +Example: mysql+pymysql://root:password@localhost:3306/inventory +Example: mysql+pymysql://dbuser:pass@192.168.1.10:3306/analytics +``` + +#### SQL Server +``` +Format: mssql+pyodbc://user:pass@host:port/db?driver=ODBC+Driver+XX+for+SQL+Server +Example: mssql+pyodbc://sa:MyPass@localhost:1433/sales?driver=ODBC+Driver+17+for+SQL+Server +Example: mssql+pyodbc://user:pwd@server:1433/db?driver=ODBC+Driver+18+for+SQL+Server +``` + +#### Oracle +``` +Format: oracle+cx_oracle://username:password@host:port/?service_name=service +Example: oracle+cx_oracle://admin:secret@localhost:1521/?service_name=ORCLPDB +Example: oracle+cx_oracle://user:pass@oracledb:1521/?service_name=PROD +``` + +### Database Configuration Parameters + +```yaml +db_connection_string_or_path: string # Path (SQLite) or connection string (others) +db_type: string # 'sqlite', 'postgres', or 'sql' +``` + +**db_type values:** +- `sqlite` - For SQLite databases (uses connect_to_sqlite internally) +- `postgres` or `postgresql` - For PostgreSQL databases (uses connect_to_postgres) +- `sql` - For generic SQL databases via SQLAlchemy (MySQL, SQL Server, Oracle, etc.) 
+ +### SQL Database Troubleshooting + +#### SQLite Issues +**Problem:** `database is locked` error +**Solution:** Close all connections or use WAL mode + +**Problem:** `unable to open database file` +**Solution:** Check file path and permissions + +#### PostgreSQL Issues +**Problem:** `connection refused` +**Solution:** Check PostgreSQL is running: `systemctl status postgresql` + +**Problem:** `authentication failed` +**Solution:** Verify credentials and check pg_hba.conf + +**Problem:** `database does not exist` +**Solution:** Create database: `createdb database_name` + +#### MySQL Issues +**Problem:** `Access denied for user` +**Solution:** Check credentials and user permissions: `GRANT ALL ON db.* TO 'user'@'host'` + +**Problem:** `Can't connect to MySQL server` +**Solution:** Check MySQL is running: `systemctl status mysql` + +#### SQL Server Issues +**Problem:** `Login failed for user` +**Solution:** Check SQL Server authentication mode and user permissions + +**Problem:** `ODBC Driver not found` +**Solution:** Install ODBC Driver: Download from Microsoft + +**Problem:** `SSL Provider: No credentials are available` +**Solution:** Add `TrustServerCertificate=yes` to connection string + +#### Oracle Issues +**Problem:** `ORA-12541: TNS:no listener` +**Solution:** Start Oracle listener: `lsnrctl start` + +**Problem:** `ORA-01017: invalid username/password` +**Solution:** Verify credentials and user exists + +**Problem:** `cx_Oracle.DatabaseError` +**Solution:** Check Oracle client libraries are installed + +### Required Python Packages by Database + +```bash +# SQLite (built-in, no packages needed) +# Already included with Python + +# PostgreSQL +pip install psycopg2-binary + +# MySQL +pip install pymysql sqlalchemy + +# SQL Server +pip install pyodbc sqlalchemy +# Also install: Microsoft ODBC Driver for SQL Server + +# Oracle +pip install cx_Oracle sqlalchemy +# Also install: Oracle Instant Client + +# Generic SQL (covers MySQL, SQL Server, Oracle via 
SQLAlchemy) +pip install sqlalchemy +``` + +--- + +## Complete Configuration Examples + +### Example 1: SQLite with ChromaDB (Simplest Setup) +```yaml +functions: + - name: simple_sql_tool + type: generate_sql_query_and_retrieve_tool + llm_name: nim_llm + embedding_name: nim_embeddings + # Vector store + vector_store_type: chromadb + vector_store_path: ./vanna_vector_store + # Database + db_connection_string_or_path: ./database.db + db_type: sqlite + # Output + output_folder: ./output + vanna_training_data_path: ./training_data.yaml +``` + +### Example 2: PostgreSQL with Elasticsearch (Production Setup) +```yaml +functions: + - name: production_sql_tool + type: generate_sql_query_and_retrieve_tool + llm_name: nim_llm + embedding_name: nim_embeddings + # Vector store + vector_store_type: elasticsearch + elasticsearch_url: http://elasticsearch:9200 + elasticsearch_username: elastic + elasticsearch_password: changeme + # Database + db_connection_string_or_path: postgresql://dbuser:dbpass@postgres:5432/analytics + db_type: postgres + # Output + output_folder: ./output + vanna_training_data_path: ./training_data.yaml +``` + +### Example 3: MySQL with ChromaDB +```yaml +functions: + - name: mysql_sql_tool + type: generate_sql_query_and_retrieve_tool + llm_name: nim_llm + embedding_name: nim_embeddings + # Vector store + vector_store_type: chromadb + vector_store_path: ./vanna_vector_store + # Database + db_connection_string_or_path: mysql+pymysql://root:password@localhost:3306/sales + db_type: sql + # Output + output_folder: ./output + vanna_training_data_path: ./training_data.yaml +``` + +--- + +## Architecture Notes + +Both vector stores: +- Use the same NVIDIA embedding models +- Store identical training data +- Provide the same vector similarity search +- Are managed automatically by VannaManager +- Support the same training data YAML format + +The tool automatically: +- Detects if vector store needs initialization +- Loads training data from YAML file +- Creates 
embeddings using NVIDIA models +- Manages vector store lifecycle + +### Performance Tips + +#### ChromaDB +- Keep on SSD for faster I/O +- Regular directory backups +- Monitor disk space + +#### Elasticsearch +- Use SSD-backed storage +- Configure appropriate heap size +- Enable index caching +- Use snapshots for backups +- Monitor cluster health + +--- + +## Quick Reference + +### Configuration Matrix + +| Database | Vector Store | db_type | Connection Format | +|----------|--------------|---------|-------------------| +| SQLite | ChromaDB | sqlite | `./database.db` | +| SQLite | Elasticsearch | sqlite | `./database.db` | +| PostgreSQL | ChromaDB | postgres | `postgresql://user:pass@host:port/db` | +| PostgreSQL | Elasticsearch | postgres | `postgresql://user:pass@host:port/db` | +| MySQL | ChromaDB | sql | `mysql+pymysql://user:pass@host:port/db` | +| MySQL | Elasticsearch | sql | `mysql+pymysql://user:pass@host:port/db` | +| SQL Server | ChromaDB | sql | `mssql+pyodbc://user:pass@host:port/db?driver=...` | +| SQL Server | Elasticsearch | sql | `mssql+pyodbc://user:pass@host:port/db?driver=...` | +| Oracle | ChromaDB | sql | `oracle+cx_oracle://user:pass@host:port/?service_name=...` | +| Oracle | Elasticsearch | sql | `oracle+cx_oracle://user:pass@host:port/?service_name=...` | + +### Recommended Combinations + +| Use Case | Vector Store | Database | Why | +|----------|--------------|----------|-----| +| **Development** | ChromaDB | SQLite | Simplest setup, no servers | +| **Production (Small)** | ChromaDB | PostgreSQL | Reliable, open-source | +| **Production (Large)** | Elasticsearch | PostgreSQL | Scalable, distributed | +| **Enterprise** | Elasticsearch | SQL Server/Oracle | Advanced features, HA | +| **Web App** | ChromaDB | MySQL | Standard web stack | +| **Analytics** | Elasticsearch | PostgreSQL | Complex queries, multi-user | + +### Default Values + +```yaml +vector_store_type: chromadb # Default +elasticsearch_index_name: vanna_vectors # Default ES 
index +db_type: sqlite # Default +``` + +--- + +## Additional Resources + +For more detailed examples, see: +- **`config_examples.yaml`** - Complete working examples with all combinations of vector stores and databases +- **`vanna_manager.py`** - Implementation details for connection management +- **`vanna_util.py`** - Vector store implementations (ChromaDB and Elasticsearch) + +--- + +## E2B Code Execution Test Setup + +This guide helps you test the E2B cloud sandbox implementation without running the full ALM workflow. + +### Prerequisites + +1. **E2B API Key** + - Sign up at https://e2b.dev/auth/sign-up + - Get your API key from https://e2b.dev/dashboard + - Free tier available for testing + +2. **Install E2B Dependencies** + ```bash + cd examples/asset_lifecycle_management + uv pip install -e ".[e2b]" + ``` + +3. **Set Environment Variables** + ```bash + export E2B_API_KEY="your-e2b-api-key-here" + export NVIDIA_API_KEY="your-nvidia-api-key" # For LLM + ``` + +4. **Prepare Workspace** + ```bash + # Ensure output_data directory exists + mkdir -p output_data + + # Copy utilities if testing utils import + cp -r utils_template output_data/utils + ``` + +### Configuration + +To use E2B cloud sandbox instead of local Docker: + +```yaml +functions: + # E2B Cloud Sandbox + e2b_code_execution: + _type: e2b_code_execution + e2b_api_key: "${E2B_API_KEY}" + workspace_files_dir: "output_data" + timeout: 30.0 + max_output_characters: 2000 + + code_generation_assistant: + _type: code_generation_assistant + llm_name: "coding_llm" + code_execution_tool: "e2b_code_execution" # Use E2B instead of local Docker + output_folder: "output_data" +``` + +### Comparison: Local Docker vs E2B Cloud + +| Feature | Local Docker | E2B Cloud | +|---------|-------------|-----------| +| Setup | Requires Docker + container | Just API key | +| Speed (cold start) | ~2-5 seconds | ~150ms | +| Speed (execution) | Fast | + file transfer overhead | +| File access | Mounted volume | Upload/download | +| 
Database | Direct access | Must upload | +| Cost | Free (local resources) | API usage based | +| Network | Not required | Required | + +### Troubleshooting E2B + +**Error: "E2B SDK not installed"** +```bash +uv pip install e2b-code-interpreter +``` + +**Error: "E2B API key not set"** +```bash +export E2B_API_KEY="your-key-here" +echo $E2B_API_KEY # Verify it's set +``` + +**Error: "workspace_files_dir not found"** +```bash +mkdir -p output_data +``` + +**Files not downloading** +- Check E2B dashboard for quota limits +- Verify file extensions in `e2b_sandbox.py` (currently: .json, .html, .png, .jpg, .csv, .pdf) +- Check E2B sandbox logs for file creation + +--- + +## SSL Certificate Setup for E2B + +This section explains how to ensure SSL certificates are properly configured for E2B cloud sandbox integration. + +### Quick SSL Certificate Check + +Run this on any machine to verify SSL certificates: + +```bash +# Test if SSL certificates are working +python -c "import ssl; import urllib.request; urllib.request.urlopen('https://api.e2b.dev')" + +# If successful, prints nothing (exit code 0) +# If failed, shows: "SSL: CERTIFICATE_VERIFY_FAILED" +``` + +### SSL Setup by Environment + +#### Ubuntu/Debian + +```bash +# Update system certificates +sudo apt-get update +sudo apt-get install -y ca-certificates + +# Update Python SSL certificates +pip install --upgrade certifi + +# Verify installation +python -c "import certifi; print('Certificates location:', certifi.where())" +``` + +#### macOS (Anaconda/Homebrew Python) + +```bash +# Option 1: Update certifi package +pip install --upgrade certifi + +# Option 2: For Anaconda Python +conda install -c conda-forge ca-certificates certifi openssl + +# Verify +python -c "import certifi; print('Certificates location:', certifi.where())" +``` + +#### Docker Containers + +Add this to your Dockerfile: + +```dockerfile +# Install system certificates +RUN apt-get update && apt-get install -y ca-certificates + +# Install Python 
certificates +RUN pip install --upgrade certifi + +# Verify SSL works +RUN python -c "import ssl; print('SSL configured')" +``` + +### SSL Troubleshooting + +**Error: "SSL: CERTIFICATE_VERIFY_FAILED"** + +```bash +# Update certifi +pip install --upgrade certifi + +# Set explicit certificate path +export SSL_CERT_FILE=$(python -c "import certifi; print(certifi.where())") + +# Or update system certificates +sudo apt-get install --reinstall ca-certificates # Ubuntu/Debian +``` + +**Error: "certificate verify failed: unable to get local issuer certificate"** + +```bash +# Ubuntu/Debian +sudo apt-get update +sudo apt-get install --reinstall ca-certificates + +# Python +pip install --upgrade certifi + +# Verify +python -c "import requests; requests.get('https://api.e2b.dev')" +``` + +### Best Practices Checklist + +Before deploying E2B to a remote machine: + +- [ ] System CA certificates installed (`ca-certificates` package) +- [ ] Python certifi updated (`pip install --upgrade certifi`) +- [ ] SSL test passes (test with `urllib.request.urlopen('https://api.e2b.dev')`) +- [ ] E2B SDK installed (`pip install e2b-code-interpreter`) +- [ ] E2B_API_KEY environment variable set +- [ ] Test E2B connection works (`Sandbox.create()`) +- [ ] NAT workflow configured correctly +- [ ] Docker containers have certificates (if using Docker) + +**Note**: The Asset Lifecycle Management example uses **local Docker sandbox by default**, so E2B and SSL are optional. E2B is provided as an alternative for cloud-based code execution. 
+ +--- + +## Support Resources + +- **NAT Documentation**: https://docs.nvidia.com/nemo-agent-toolkit/ +- **E2B Documentation**: https://e2b.dev/docs +- **E2B Discord**: https://discord.gg/U7KEcGErtQ +- **GitHub Issues**: https://github.com/NVIDIA/NeMo-Agent-Toolkit-Examples/issues diff --git a/examples/asset_lifecycle_management/configs/config-reasoning.yaml b/examples/asset_lifecycle_management/configs/config-reasoning.yaml new file mode 100644 index 0000000..8e2e864 --- /dev/null +++ b/examples/asset_lifecycle_management/configs/config-reasoning.yaml @@ -0,0 +1,383 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +general: + use_uvloop: true + telemetry: + logging: + console: + _type: console + level: DEBUG + # level: INFO + # file: + # _type: file + # path: "alm.log" + # level: DEBUG + # tracing: + # phoenix: + # _type: phoenix + # endpoint: http://localhost:6006/v1/traces + # project: alm-agent + # catalyst: + # _type: catalyst + # project: "alm-agent" + # dataset: "alm-agent" + +llms: + # SQL query generation model + sql_llm: + _type: nim + model_name: "qwen/qwen2.5-coder-32b-instruct" + + # Data analysis and tool calling model + analyst_llm: + _type: nim + model_name: "qwen/qwen2.5-coder-32b-instruct" + + # Python code generation model + coding_llm: + _type: nim + model_name: "qwen/qwen2.5-coder-32b-instruct" + + # Main reasoning and planning model + reasoning_llm: + _type: nim + model_name: "nvidia/llama-3.3-nemotron-super-49b-v1" + + # Multimodal evaluation model (Vision-Language Model) + multimodal_judging_llm: + _type: nim + model_name: nvidia/llama-3.1-nemotron-nano-vl-8b-v1 + +embedders: + # Text embedding model for vector database operations + vanna_embedder: + _type: nim + model_name: "nvidia/llama-3_2-nv-embedqa-1b-v2" + +functions: + # Original SQL tool (existing implementation for comparison) + sql_retriever_old: + _type: generate_sql_query_and_retrieve_tool + llm_name: "sql_llm" + embedding_name: "vanna_embedder" + # Vector store configuration + vector_store_type: "chromadb" # Optional, chromadb is default + vector_store_path: "database" + # Database configuration + db_type: "sqlite" # Optional, sqlite is default + db_connection_string_or_path: "database/nasa_turbo.db" + # Output configuration + output_folder: "output_data" + vanna_training_data_path: "vanna_training_data.yaml" + + # New SQL tool (package-based implementation for comparison) + sql_retriever_vanna: + _type: vanna_sql_tool + llm_name: "sql_llm" + embedding_name: "vanna_embedder" + # Vector store configuration + vector_store_type: "chromadb" + vector_store_path: "database_vanna" # Different 
path to avoid conflicts + # Database configuration + db_type: "sqlite" + db_connection_string_or_path: "database/nasa_turbo.db" + # Output configuration + output_folder: "output_data" + training_data_path: "vanna_training_data.yaml" + auto_train_on_init: true + + # Active SQL tool (default to old for now, change after comparison) + sql_retriever: + _type: generate_sql_query_and_retrieve_tool + llm_name: "sql_llm" + embedding_name: "vanna_embedder" + vector_store_type: "chromadb" + vector_store_path: "database" + db_type: "sqlite" + db_connection_string_or_path: "database/nasa_turbo.db" + output_folder: "output_data" + vanna_training_data_path: "vanna_training_data.yaml" + + predict_rul: + _type: predict_rul_tool + output_folder: "output_data" + scaler_path: "models/scaler_model.pkl" + model_path: "models/xgb_model_fd001.pkl" + + # MOMENT-based Anomaly Detection (Default) + # Uses MOMENT-1-Large foundation model for local anomaly detection + anomaly_detection: + _type: moment_anomaly_detection_tool + output_folder: "output_data" + + # NV Tesseract Anomaly Detection (Alternative - NVIDIA NIM) + # Uncomment this section and change "anomaly_detection" to use "nv_tesseract_anomaly_detection" + # in the tool_names list and in system prompts if you want to switch to NV Tesseract + # Requires: NVIDIA NIM endpoint with nv-anomaly-tesseract-1.0 model + # nv_tesseract_anomaly_detection: + # _type: nv_tesseract_anomaly_detection + # llm_name: "reasoning_llm" # NIM endpoint for NV Tesseract + # model_name: "nvidia/nv-anomaly-tesseract-1.0" + # lookback_period: 30 # Number of time steps to analyze + # forecast_horizon: 10 # Number of time steps to forecast + + plot_distribution: + _type: plot_distribution_tool + output_folder: "output_data" + + plot_line_chart: + _type: plot_line_chart_tool + output_folder: "output_data" + + plot_comparison: + _type: plot_comparison_tool + output_folder: "output_data" + + plot_anomaly: + _type: plot_anomaly_tool + output_folder: "output_data" + 
+ code_generation_assistant: + _type: code_generation_assistant + llm_name: "coding_llm" + code_execution_tool: "code_execution" # Change to "e2b_code_execution" to use E2B cloud sandbox + verbose: true + + # Local Docker Sandbox (Default) + # Requires: Docker running + local_sandbox container started + # Start: cd /path/to/NeMo-Agent-Toolkit/src/nat/tool/code_execution/ && ./local_sandbox/start_local_sandbox.sh local-sandbox /path/to/output_data/ + code_execution: + _type: code_execution + uri: http://127.0.0.1:6000/execute + sandbox_type: "local" + max_output_characters: 2000 + + # E2B Cloud Sandbox (Alternative - No Docker Required) + # Uncomment this section and change code_execution_tool to "e2b_code_execution" in code_generation_assistant + # Requires: E2B_API_KEY environment variable set + # Install: uv pip install -e ".[e2b]" + # e2b_code_execution: + # _type: e2b_code_execution + # e2b_api_key: "${E2B_API_KEY}" + # workspace_files_dir: "output_data" + # timeout: 30.0 + # max_output_characters: 2000 + + data_analysis_assistant: + _type: react_agent + llm_name: "analyst_llm" + max_iterations: 20 + max_retries: 2 + tool_names: [ + "sql_retriever", + "predict_rul", + "plot_distribution", + "plot_line_chart", + "plot_comparison", + "anomaly_detection", + "plot_anomaly", + "code_generation_assistant" + ] + parse_agent_response_max_retries: 2 + system_prompt: | + ### TASK DESCRIPTION #### + You are a helpful data analysis assistant specializing in Asset Lifecycle Management tasks, currently focused on predictive maintenance for turbofan engines. + **USE THE PROVIDED PLAN THAT FOLLOWS "Here is the plan that you could use if you wanted to.."** + + ### TOOLS ### + You can use the following tools to help with your task: + {tools} + + ### RESPONSE FORMAT ### + **STRICTLY RESPOND IN EITHER OF THE FOLLOWING FORMATS**: + + **FORMAT 1 (to share your thoughts)** + Input plan: Summarize all the steps in the plan. 
+ Executing step: the step you are currently executing from the plan along with any instructions provided + Thought: describe how you are going to execute the step + + **FORMAT 2 (to return the final answer)** + Input plan: Summarize all the steps in the plan. + Executing step: the step you are currently executing from the plan along with any instructions provided + Thought: describe how you are going to execute the step + Final Answer: the final answer to the original input question including the relative file paths of the generated files in the + `output_data/` directory (e.g., output_data/filename.json). + + **FORMAT 3 (when using a tool)** + Input plan: Summarize all the steps in the plan. + Executing step: the step you are currently executing from the plan along with any instructions provided + Thought: describe how you are going to execute the step + Action: the action to take, should be one of [{tool_names}] + Action Input: the input to the tool (if there is no required input, include "Action Input: None") + Observation: wait for the tool to finish execution and return the result + + ### HOW TO CHOOSE THE RIGHT TOOL ### + Follow these guidelines while deciding the right tool to use: + **Ensure that tool calls do not use single quotes or double quotes within the key-value pairs.** + + 1. **SQL Retrieval Tool** + - Use this tool to retrieve data from the database. + - NEVER generate SQL queries by yourself, instead pass the top-level instruction to the tool. + + 2. **Prediction Tools** + - Use predict_rul for RUL prediction requests. + - Always call data retrieval tool to get sensor data before predicting RUL. + + 3. **Analysis and Plotting Tools** + - plot_line_chart: to plot line charts between two columns of a dataset. + - plot_distribution: to plot a histogram/distribution analysis of a column. + - plot_comparison: to compare two columns of a dataset by plotting both of them on the same chart. + + 4. 
**Anomaly Detection Tools** + - Use anomaly_detection tool for state-of-the-art foundation model-based anomaly detection. + - **Two options available**: + * MOMENT-1-Large (default): Local foundation model for anomaly detection + * NV Tesseract (alternative): NVIDIA NIM-based anomaly detection (if configured) + - **MOMENT requires JSON data**: First use sql_retriever to get sensor data, then pass the JSON file path to anomaly_detection. + - **NV Tesseract input**: Provide unit_number and dataset_name directly (e.g., unit_number=1, dataset_name='train_FD001') + - **OUTPUT**: Creates enhanced sensor data with added 'is_anomaly' boolean column or anomaly analysis JSON. + - Use plot_anomaly to create interactive visualizations of anomaly detection results. + + 5. **Code Generation Guidelines** + When using code_generation_assistant, provide comprehensive instructions in a single parameter: + • Include complete task description with user context and requirements + • Specify available data files and their structure (columns, format, location) + • Combine multiple related tasks into bullet points within one instruction + • Mention specific output requirements (HTML files, JSON data, visualizations) + • The tool automatically generates and executes Python code, returning results and file paths. + + 6. **File Path Handling** + - When giving instructions to the code_generation_assistant, use only the filename itself (for example, filename.json). Do not include any folder paths, since the code_generation_assistant already operates within the outputs directory. 
+ + ### TYPICAL WORKFLOW FOR EXECUTING A PLAN ### + + First, Data Extraction and analysis + - Use SQL retrieval tool to fetch required data + - **Use code_generation_assistant to perform any data processing using Python code ONLY IF INSTRUCTED TO DO SO.** + Finally, Data Visualization + - Use existing plotting tools to generate plots + - Use predict_rul, anomaly_detection or any other relevant tools to perform analysis + Finally, return the result to the user + - Return processed information to calling agent. + +workflow: + _type: reasoning_agent + augmented_fn: "data_analysis_assistant" + llm_name: "reasoning_llm" + verbose: true + reasoning_prompt_template: | + ### DESCRIPTION ### + You are a Data Analysis Reasoning and Planning Expert specialized in Asset Lifecycle Management, with expertise in analyzing turbofan engine sensor data and predictive maintenance tasks. + You are tasked with creating detailed execution plans for addressing user queries while being conversational and helpful. + + Your Role and Capabilities:** + - Expert in Asset Lifecycle Management, turbofan engine data analysis, predictive maintenance, and anomaly detection + - Provide conversational responses while maintaining technical accuracy + - Create step-by-step execution plans using available tools which will be invoked by a data analysis assistant + + **You are given a data analysis assistant to execute your plan, all you have to do is generate the plan** + DO NOT USE MARKDOWN FORMATTING IN YOUR RESPONSE. + + ### ASSISTANT DESCRIPTION ### + {augmented_function_desc} + + ### TOOLS AVAILABLE TO THE ASSISTANT ### + {tools} + + ### CONTEXT ### + You work with turbofan engine sensor data from multiple engines in a fleet. The data contains: + - **Time series data** from different engines, each with unique wear patterns and operational history separated into + four datasets (FD001, FD002, FD003, FD004), each dataset is further divided into training and test subsets. 
+ - **26 data columns**: unit number, time in cycles, 3 operational settings, and 21 sensor measurements + - **Engine lifecycle**: Engines start operating normally, then develop faults that grow until system failure + - **Asset Lifecycle Management - Operation & Maintenance Phase**: Predict Remaining Useful Life (RUL) - how many operational cycles before failure + - **Data characteristics**: Contains normal operational variation, sensor noise, and progressive fault development + This context helps you understand user queries about engine health, sensor patterns, failure prediction, and maintenance planning. + REMEMBER TO RELY ON DATA ANALYSIS ASSITANT TO RETRIEVE DATA FROM THE DATABASE. + + ### SPECIAL CONSTRAINTS ### + Create execution plans for Asset Lifecycle Management tasks (currently focused on predictive maintenance and sensor data analysis). For other queries, use standard reasoning. + Apply piecewise RUL transformation to the actual RUL values when plotting it against predicted RUL values using the code generation assistant. + + ### GUIDELINES ### + **DO NOT use predict_rul tool to fetch RUL data unless the user explicitly uses the word "predict" or something similar, this is because there is also ground truth RUL data in the database which the user might request sometimes.** + **REMEMBER: SQL retrieval tool is smart enough to understand queries like counts, totals, basic facts etc. It can use UNIQUE(), COUNT(), SUM(), AVG(), MIN(), MAX() to answer simple queries. NO NEED TO USE CODE GENERATION ASSISTANT FOR SIMPLE QUERIES.** + **CODE GENERATION ASSISTANT IS COSTLY AND UNRELIABLE MOST OF THE TIMES. SO PLEASE USE IT ONLY FOR COMPLEX QUERIES THAT REQUIRE DATA PROCESSING AND VISUALIZATION.** + + **User Input:** + {input_text} + + Analyze the input and create an appropriate execution plan in bullet points. 
+ +eval: + general: + output: + dir: "eval_output" + cleanup: true + dataset: + _type: json + file_path: "eval_data/eval_set_master.json" + query_delay: 10 # seconds between queries + max_concurrent: 1 # process queries sequentially + + evaluators: + multimodal_eval: + _type: multimodal_llm_judge_evaluator + llm_name: "multimodal_judging_llm" + judge_prompt: | + You are an expert evaluator for Asset Lifecycle Management agentic workflows, with expertise in predictive maintenance tasks. + Your task is to evaluate how well a generated response (which may include both text and visualizations) + matches the reference answer for a given question. + + Question: {question} + Reference Answer: {reference_answer} + Generated Response: {generated_answer} + + IMPORTANT: You MUST provide your response ONLY as a valid JSON object. + Do not include any text before or after the JSON. + + # EVALUATION LOGIC + Your evaluation mode is determined by whether actual plot images are attached to this message: + - If PLOT IMAGES are attached → Perform ONLY PLOT EVALUATION by examining the actual plot images + - If NO IMAGES are attached → Perform ONLY TEXT EVALUATION of the text response + + DO NOT confuse text mentions of plots/files with actual attached images. + Only evaluate plots if you can actually see plot images in this message. 
+ + ## TEXT EVALUATION (only when no images are attached) + Check if the generated text answer semantically matches the reference answer: + - 1.0: Generated answer fully matches the reference answer semantically + - 0.5: Generated answer partially matches with some missing/incorrect elements + - 0.0: Generated answer does not match the reference answer semantically + + ## PLOT EVALUATION (only when images are attached) + Use the reference answer as expected plot description and check how well the actual plot matches: + - 1.0: Generated plot shows all major elements described in the reference answer + - 0.5: Generated plot shows some elements but missing significant aspects + - 0.0: Generated plot does not match the reference answer description + + # RESPONSE FORMAT + You MUST respond with ONLY this JSON format: + {{ + "score": 0.0, + "reasoning": "EVALUATION TYPE: [TEXT or PLOT] - [your analysis and score with justification]" + }} + + CRITICAL REMINDER: + - If images are attached → Use "EVALUATION TYPE: PLOT" + - If no images → Use "EVALUATION TYPE: TEXT" + + Replace the score with your actual evaluation (0.0, 0.5, or 1.0). 
diff --git a/examples/asset_lifecycle_management/eval_data/eval_set_master.json b/examples/asset_lifecycle_management/eval_data/eval_set_master.json new file mode 100644 index 0000000..ed1eda4 --- /dev/null +++ b/examples/asset_lifecycle_management/eval_data/eval_set_master.json @@ -0,0 +1,262 @@ +[ + { + "id": "1", + "question": "What is the ground truth remaining useful life (RUL) of unit_number 59 in dataset FD001", + "answer": "114 ", + "type": "text", + "category": "retrieval", + "subcategory": "easy", + "original_id": "1", + "source": "eval_set" + }, + { + "id": "2", + "question": "What is the ground truth RUL of unit_number 20 in dataset FD001", + "answer": "16 ", + "type": "text", + "category": "retrieval", + "subcategory": "easy", + "original_id": "2", + "source": "eval_set" + }, + { + "id": "3", + "question": "How many units have ground truth RUL of 100 or more in dataset FD003", + "answer": "33 ", + "type": "text", + "category": "retrieval", + "subcategory": "easy", + "original_id": "3", + "source": "eval_set" + }, + { + "id": "4", + "question": "How many units have ground truth RUL of 50 or less in dataset FD002", + "answer": "88 ", + "type": "text", + "category": "retrieval", + "subcategory": "easy", + "original_id": "4", + "source": "eval_set" + }, + { + "id": "5", + "question": "Report the unit_number of the units that have ground truth RUL equal to 155 in FD002", + "answer": "6, 141, 165 ", + "type": "text", + "category": "retrieval", + "subcategory": "easy", + "original_id": "5", + "source": "eval_set" + }, + { + "id": "6", + "question": "In the dataset FD004, how many units have ground truth RUL equal to 10 and what are their unit numbers?", + "answer": "4 units; unit numbers: 40, 82, 174, 184", + "type": "text", + "category": "retrieval", + "subcategory": "medium", + "original_id": "6", + "source": "eval_set" + }, + { + "id": "7", + "question": "In dataset train_FD004, what was the operational_setting_3 at time_in_cycles 20 for unit_number 107", 
+ "answer": "100 ", + "type": "text", + "category": "retrieval", + "subcategory": "easy", + "original_id": "7", + "source": "eval_set" + }, + { + "id": "8", + "question": "In dataset train_FD004, what was the 3rd operational setting at time 20 for unit_number 107", + "answer": "100 ", + "type": "text", + "category": "retrieval", + "subcategory": "medium", + "original_id": "8", + "source": "eval_set" + }, + { + "id": "9", + "question": "In dataset test_FD002, what are the values of the three operational setting for unit_number 56 at time_in_cycles 10", + "answer": "10.0026, 0.25, 100 ", + "type": "text", + "category": "retrieval", + "subcategory": "medium", + "original_id": "9", + "source": "eval_set" + }, + { + "id": "10", + "question": "In dataset test_FD003, what is the value of sensor_measurement_4 for unit_number 25 at time_in_cycles 20", + "answer": "1409.26 ", + "type": "text", + "category": "retrieval", + "subcategory": "easy", + "original_id": "10", + "source": "eval_set" + }, + { + "id": "11", + "question": "How many units have operational_setting_3 equal to 100 in dataset train_FD001 at time_in_cycles 40?", + "answer": "100 ", + "type": "text", + "category": "retrieval", + "subcategory": "medium", + "original_id": "11", + "source": "eval_set" + }, + { + "id": "12", + "question": "How many units have operational_setting_3 equal to 100 in dataset train_FD001?", + "answer": "100 ", + "type": "text", + "category": "retrieval", + "subcategory": "medium", + "original_id": "12", + "source": "eval_set" + }, + { + "id": "13", + "question": "In dataset train_FD003, what was sensor_measurement_20 and sensor_measurement_21 for unit 1 at time_in_cycles 10", + "answer": "38.94, 23.4781 ", + "type": "text", + "category": "retrieval", + "subcategory": "easy", + "original_id": "13", + "source": "eval_set" + }, + { + "id": "14", + "question": "For dataset test_FD004, what is the ground truth remaining useful life of unit 60", + "answer": "139 ", + "type": "text", + 
"category": "retrieval", + "subcategory": "easy", + "original_id": "14", + "source": "eval_set" + }, + { + "id": "15", + "question": "Using the data in test_FD002, predict the remaining useful life of unit_number 10 at time_in_cycles 84", + "answer": "79 ", + "type": "text", + "category": "prediction", + "subcategory": "medium", + "original_id": "15", + "source": "eval_set" + }, + { + "id": "16", + "question": "Given the data in test_FD003, predict the RUL of unit_number 30", + "answer": "89 ", + "type": "text", + "category": "prediction", + "subcategory": "medium", + "original_id": "16", + "source": "eval_set" + }, + { + "id": "17", + "question": "In dataset train_FD004, plot sensor_measurement1 vs time_in_cycles for unit_number 107", + "answer": "Line chart showing sensor_measurement1 values on y-axis ranging from 445.00 to 518.67 plotted against time_in_cycles in x-axisfor unit 107 in dataset FD004.", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "easy", + "original_id": "1", + "source": "eval_set" + }, + { + "id": "18", + "question": "In dataset train_FD004, plot the variation of sensor_measurement1 over time for unit_number 107", + "answer": "Line chart displaying the variation of sensor_measurement1 values on y-axis over time cycles in x-axis for unit 107 in dataset FD004. The plot should illustrate how sensor_measurement1 changes across different time cycles, demonstrating the temporal variation pattern of this sensor reading.", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "easy", + "original_id": "2", + "source": "eval_set" + }, + { + "id": "19", + "question": "In dataset train_FD002, plot operational_setting_3 vs time_in_cycles for unit_number 200", + "answer": "Line chart showing operational_setting_3 values on y-axis ranging against time_in_cycles in x-axis for unit 200 in dataset FD002. 
Only two values 100 and 60 should be visible on the plot.", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "easy", + "original_id": "3", + "source": "eval_set" + }, + { + "id": "20", + "question": "Plot a histogram showing distribution of values of operational_setting_3 over time for unit_number 200 in dataset train_FD002", + "answer": "Histogram Two bars for 100 and 60 with higher bar for 100", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "medium", + "original_id": "4", + "source": "eval_set" + }, + { + "id": "21", + "question": "In dataset test_FD001 plot a histogram showing the distribution of operational_setting_3 across all units", + "answer": "Constant value 100, so just one high bar for 100", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "medium", + "original_id": "5", + "source": "eval_set" + }, + { + "id": "22", + "question": "In dataset test_FD001 plot operational_setting_3 as a function of time_in_cycles for units 10, 20, 30, 40", + "answer": "Line chart showing operational_setting_3 values on y-axis ranging against time_in_cycles in x-axis with a constant line parallel to x-axis with value y-axis as 100", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "medium", + "original_id": "6", + "source": "eval_set" + }, + { + "id": "23", + "question": "Retrieve RUL of all units from the FD001 and plot their distribution using a histogram", + "answer": "Histogram showing distribution of RUL values for all units in FD001 dataset. Should contain 100 data points representing different RUL values ranging from 7 to 145 cycles. The distribution should show 71 unique RUL values with varying frequencies. 
The plot should display the spread and frequency of remaining useful life values across all engine units in the dataset.", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "medium", + "original_id": "7", + "source": "eval_set" + }, + { + "id": "24", + "question": "Retrieve time in cycles, all sensor measurements and RUL value for engine unit 24 from FD001 test and RUL tables. Predict RUL for it. Finally, generate a plot to compare actual RUL value with predicted RUL value across time.", + "answer": "A Plot showing both actual RUL values and predicted RUL values trend (in y-axis) plotted against time in cycles (in x-axis) for engine unit 24", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "hard", + "original_id": "8", + "source": "eval_set" + }, + { + "id": "25", + "question": "Retrieve and detect anomalies in sensor 4 measurements for engine number 78 in train FD001 dataset.", + "answer": "A Plot showing observed values and anomalies in sensor 4 measurements for engine number 78.", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "hard", + "original_id": "9", + "source": "eval_set" + }, + { + "id": "26", + "question": "Perform the following steps: 1.Retrieve the time in cycles, all sensor measurements, and ground truth RUL values for engine unit 24 from FD001 train dataset. 2.Use the retrieved data to predict the Remaining Useful Life (RUL). 3.Use the piece wise RUL transformation code utility to apply piecewise RUL transformation only to the observed RUL column. 
4.Generate a plot that compares the transformed RUL values and the predicted RUL values across time.", + "answer": "Two line charts with transformed RUL values in Knee pattern and predicted RUL values across time", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "hard", + "original_id": "10", + "source": "eval_set" + } +] \ No newline at end of file diff --git a/examples/asset_lifecycle_management/eval_data/eval_set_test.json b/examples/asset_lifecycle_management/eval_data/eval_set_test.json new file mode 100644 index 0000000..c691d46 --- /dev/null +++ b/examples/asset_lifecycle_management/eval_data/eval_set_test.json @@ -0,0 +1,42 @@ +[ + { + "id": "1", + "question": "In dataset test_FD001 plot operational_setting_3 as a function of time_in_cycles for units 10, 20, 30, 40", + "answer": "Line chart showing operational_setting_3 values on y-axis ranging against time_in_cycles in x-axis with a constant line parallel to x-axis with value y-axis as 100", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "medium", + "original_id": "6", + "source": "eval_set" + }, + { + "id": "2", + "question": "What is the ground truth remaining useful life (RUL) of unit_number 59 in dataset test FD001", + "answer": "114 " + }, + { + "id": "2", + "question": "What is the ground truth RUL of unit_number 20 in dataset test FD001", + "answer": "16 " + }, + { + "id": "3", + "question": "In dataset train_FD004, plot sensor_measurement1 vs time_in_cycles for unit_number 107", + "answer": "Line chart showing sensor_measurement1 values on y-axis ranging from 445.00 to 518.67 plotted against time_in_cycles in x-axis for unit 107 in dataset FD004.", + "type": "text_plus_plot", + "category": "visualization", + "subcategory": "easy", + "original_id": "1", + "source": "eval_set" + }, + { + "id": "4", + "question": "Retrieve and detect anomalies in sensor 4 measurements for engine number 78.", + "answer": "A Plot showing observed values and 
"""
RUL (Remaining Useful Life) transformation utilities.

Provides pre-built functions for transforming RUL data to create realistic patterns
for Asset Lifecycle Management and predictive maintenance tasks.
"""

import logging

import pandas as pd

logger = logging.getLogger(__name__)


def apply_piecewise_rul_transformation(
    df: pd.DataFrame,
    maxlife: int = 100,
    time_col: str = 'time_in_cycles',
    rul_col: str = 'RUL'
) -> pd.DataFrame:
    """
    Transform RUL data to create realistic "knee" patterns.

    This function applies a piecewise transformation to RUL (Remaining Useful Life)
    values to create a more realistic degradation pattern commonly seen in
    predictive maintenance:
    - RUL stays constant at ``maxlife`` while the true remaining cycles exceed it
    - Below the threshold, RUL decreases linearly to 0 as the equipment nears failure

    This creates the characteristic "knee" pattern seen in actual equipment degradation.

    Args:
        df: pandas DataFrame with time series data containing RUL values
        maxlife: Maximum life threshold for the piecewise function (default: 100).
            RUL values above this are capped at maxlife.
        time_col: Name of the time/cycle column (default: 'time_in_cycles').
            Informational only — it is not used in the computation.
        rul_col: Name of the RUL column to transform (default: 'RUL')

    Returns:
        A copy of ``df`` with an additional 'transformed_RUL' column; the input
        DataFrame is never mutated.

    Raises:
        ValueError: If ``df`` is not a DataFrame or ``rul_col`` is missing.

    Example:
        >>> df = pd.DataFrame({'time_in_cycles': [1, 2, 3], 'RUL': [150, 100, 50]})
        >>> df_transformed = apply_piecewise_rul_transformation(df, maxlife=100)
        >>> print(df_transformed['transformed_RUL'])
        0    100
        1    100
        2     50
        Name: transformed_RUL, dtype: int64
    """
    # Validate inputs before doing any work.
    if not isinstance(df, pd.DataFrame):
        raise ValueError(f"Expected pandas DataFrame, got {type(df)}")

    if rul_col not in df.columns:
        raise ValueError(
            f"RUL column '{rul_col}' not found in DataFrame. "
            f"Available columns: {list(df.columns)}"
        )

    if time_col not in df.columns:
        # The time column is not used in the math, so a missing column is only
        # worth a warning — callers often pass partial frames.
        logger.warning(
            f"Time column '{time_col}' not found in DataFrame, but continuing anyway. "
            f"Available columns: {list(df.columns)}"
        )

    # Work on a copy to avoid mutating the caller's DataFrame.
    df_copy = df.copy()

    logger.info(f"Applying piecewise RUL transformation with maxlife={maxlife}")
    logger.debug(f"Input RUL range: [{df_copy[rul_col].min()}, {df_copy[rul_col].max()}]")

    # Vectorized piecewise rule: values above `maxlife` are capped at `maxlife`,
    # values at or below pass through unchanged, and NaNs stay NaN.
    # Series.clip runs in C, preserves the column dtype, and replaces the
    # original per-element Python `apply`, which was O(n) interpreter calls.
    df_copy['transformed_RUL'] = df_copy[rul_col].clip(upper=maxlife)

    logger.info(
        f"✅ Transformation complete! Added 'transformed_RUL' column. "
        f"Output range: [{df_copy['transformed_RUL'].min()}, {df_copy['transformed_RUL'].max()}]"
    )
    logger.debug(f"Total rows processed: {len(df_copy)}")

    return df_copy
" + f"Available columns: {list(df.columns)}" + ) + + # Create a copy to avoid modifying the original + df_copy = df.copy() + + logger.info(f"Applying piecewise RUL transformation with maxlife={maxlife}") + logger.debug(f"Input RUL range: [{df_copy[rul_col].min()}, {df_copy[rul_col].max()}]") + + # Apply piecewise transformation + def transform_rul(rul_value): + """Apply the piecewise transformation to a single RUL value.""" + if pd.isna(rul_value): + return rul_value # Keep NaN values as NaN + if rul_value > maxlife: + return maxlife + return rul_value + + # Apply transformation to create new column + df_copy['transformed_RUL'] = df_copy[rul_col].apply(transform_rul) + + logger.info( + f"✅ Transformation complete! Added 'transformed_RUL' column. " + f"Output range: [{df_copy['transformed_RUL'].min()}, {df_copy['transformed_RUL'].max()}]" + ) + logger.debug(f"Total rows processed: {len(df_copy)}") + + return df_copy + + +def show_utilities(): + """ + Display available utility functions and their usage. + + Prints a formatted list of all available utilities in this workspace, + including descriptions and example usage. + """ + utilities_info = """ + ================================================================================ + WORKSPACE UTILITIES - Asset Lifecycle Management + ================================================================================ + + Available utility functions: + + 1. apply_piecewise_rul_transformation(df, maxlife=100, time_col='time_in_cycles', rul_col='RUL') + + Description: + Transforms RUL (Remaining Useful Life) data to create realistic "knee" patterns + commonly seen in predictive maintenance scenarios. 
+ + Parameters: + - df: pandas DataFrame with time series data + - maxlife: Maximum life threshold (default: 100) + - time_col: Name of time/cycle column (default: 'time_in_cycles') + - rul_col: Name of RUL column to transform (default: 'RUL') + + Returns: + DataFrame with original data plus new 'transformed_RUL' column + + Example: + df_transformed = utils.apply_piecewise_rul_transformation(df, maxlife=100) + print(df_transformed[['time_in_cycles', 'RUL', 'transformed_RUL']]) + + 2. show_utilities() + + Description: + Displays this help message with all available utilities. + + Example: + utils.show_utilities() + + ================================================================================ + """ + print(utilities_info) + + +if __name__ == "__main__": + # Simple test + print("RUL Utilities Module") + print("=" * 50) + show_utilities() + diff --git a/examples/asset_lifecycle_management/pyproject.toml b/examples/asset_lifecycle_management/pyproject.toml new file mode 100644 index 0000000..8124477 --- /dev/null +++ b/examples/asset_lifecycle_management/pyproject.toml @@ -0,0 +1,83 @@ +[build-system] +build-backend = "setuptools.build_meta" +requires = ["setuptools >= 64"] + +[project] +name = "nat_alm_agent" +dynamic = ["version"] +dependencies = [ + "nvidia-nat[profiling, langchain, telemetry]>=1.3.0", + "momentfm", + "vanna==0.7.9", + "chromadb", + "sqlalchemy>=2.0.0", + "xgboost", + "matplotlib", + "torch", + "pytest", + "pytest-asyncio" +] +requires-python = ">=3.11,<3.13" +description = "Asset Lifecycle Management workflow using NeMo Agent Toolkit for comprehensive industrial asset management from acquisition through retirement" +classifiers = ["Programming Language :: Python"] +authors = [{ name = "Vineeth Kalluru" }] +maintainers = [{ name = "NVIDIA Corporation" }] + +[project.optional-dependencies] +elasticsearch = [ + "elasticsearch>=8.0.0" +] +postgres = [ + "psycopg2-binary>=2.9.0" +] +mysql = [ + "pymysql>=1.0.0" +] +sqlserver = [ + "pyodbc>=4.0.0" +] 
+oracle = [ + "cx_Oracle>=8.0.0" +] +e2b = [ + "e2b-code-interpreter>=0.2.0" +] +all-databases = [ + "psycopg2-binary>=2.9.0", + "pymysql>=1.0.0", + "pyodbc>=4.0.0", + "cx_Oracle>=8.0.0" +] +all = [ + "elasticsearch>=8.0.0", + "psycopg2-binary>=2.9.0", + "pymysql>=1.0.0", + "pyodbc>=4.0.0", + "cx_Oracle>=8.0.0", + "e2b-code-interpreter>=0.2.0" +] + +[project.entry-points.'nat.components'] +nat_alm_agent = "nat_alm_agent.register" + +[tool.uv.sources] +momentfm = { path = "./moment", editable = true } + +[tool.setuptools] +packages = ["nat_alm_agent"] +package-dir = {"" = "src"} + +[tool.setuptools.dynamic] +version = {attr = "nat_alm_agent.__version__"} + +[tool.pytest.ini_options] +asyncio_mode = "auto" +markers = [ + "e2e: end-to-end tests that run full workflows", +] +testpaths = [ + ".", +] +python_files = ["test_*.py", "*_test.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] diff --git a/examples/asset_lifecycle_management/setup_database.py b/examples/asset_lifecycle_management/setup_database.py new file mode 100644 index 0000000..972ebf1 --- /dev/null +++ b/examples/asset_lifecycle_management/setup_database.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +NASA Turbofan Engine Dataset to SQLite Database Converter + +This script converts the NASA Turbofan Engine Degradation Simulation Dataset (C-MAPSS) +from text files into a structured SQLite database for use with the Asset Lifecycle Management agent. + +The NASA dataset contains: +- Training data: Engine run-to-failure trajectories +- Test data: Engine trajectories of unknown remaining cycles +- RUL data: Ground truth remaining useful life values + +Dataset structure: +- unit_number: Engine unit identifier +- time_in_cycles: Operational time cycles +- operational_setting_1, 2, 3: Operating conditions +- sensor_measurement_1 to 21: Sensor readings +""" + +import sqlite3 +import pandas as pd +import numpy as np +import os +from pathlib import Path +import logging + +# Set up logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class NASADatasetProcessor: + """Processes NASA Turbofan Engine Dataset and creates SQLite database.""" + + def __init__(self, data_dir: str = "data", db_path: str = "database/nasa_turbo.db"): + """ + Initialize the processor. 
+ + Args: + data_dir: Directory containing NASA dataset text files + db_path: Path where SQLite database will be created + """ + self.data_dir = Path(data_dir) + self.db_path = Path(db_path) + + # Ensure database directory exists + self.db_path.parent.mkdir(exist_ok=True) + + # Define column names for the dataset + self.columns = [ + 'unit_number', 'time_in_cycles', + 'operational_setting_1', 'operational_setting_2', 'operational_setting_3', + 'sensor_measurement_1', 'sensor_measurement_2', 'sensor_measurement_3', + 'sensor_measurement_4', 'sensor_measurement_5', 'sensor_measurement_6', + 'sensor_measurement_7', 'sensor_measurement_8', 'sensor_measurement_9', + 'sensor_measurement_10', 'sensor_measurement_11', 'sensor_measurement_12', + 'sensor_measurement_13', 'sensor_measurement_14', 'sensor_measurement_15', + 'sensor_measurement_16', 'sensor_measurement_17', 'sensor_measurement_18', + 'sensor_measurement_19', 'sensor_measurement_20', 'sensor_measurement_21' + ] + + # Sensor descriptions for metadata + self.sensor_descriptions = { + 'sensor_measurement_1': 'Total temperature at fan inlet (°R)', + 'sensor_measurement_2': 'Total temperature at LPC outlet (°R)', + 'sensor_measurement_3': 'Total temperature at HPC outlet (°R)', + 'sensor_measurement_4': 'Total temperature at LPT outlet (°R)', + 'sensor_measurement_5': 'Pressure at fan inlet (psia)', + 'sensor_measurement_6': 'Total pressure in bypass-duct (psia)', + 'sensor_measurement_7': 'Total pressure at HPC outlet (psia)', + 'sensor_measurement_8': 'Physical fan speed (rpm)', + 'sensor_measurement_9': 'Physical core speed (rpm)', + 'sensor_measurement_10': 'Engine pressure ratio (P50/P2)', + 'sensor_measurement_11': 'Static pressure at HPC outlet (psia)', + 'sensor_measurement_12': 'Ratio of fuel flow to Ps30 (pps/psi)', + 'sensor_measurement_13': 'Corrected fan speed (rpm)', + 'sensor_measurement_14': 'Corrected core speed (rpm)', + 'sensor_measurement_15': 'Bypass Ratio', + 'sensor_measurement_16': 'Burner 
fuel-air ratio', + 'sensor_measurement_17': 'Bleed Enthalpy', + 'sensor_measurement_18': 'Required fan speed', + 'sensor_measurement_19': 'Required fan conversion speed', + 'sensor_measurement_20': 'High-pressure turbines Cool air flow', + 'sensor_measurement_21': 'Low-pressure turbines Cool air flow' + } + + def read_data_file(self, file_path: Path) -> pd.DataFrame: + """ + Read a NASA dataset text file and return as DataFrame. + + Args: + file_path: Path to the text file + + Returns: + DataFrame with proper column names + """ + try: + # Read space-separated text file + df = pd.read_csv(file_path, sep='\s+', header=None, names=self.columns) + logger.info(f"Loaded {len(df)} records from {file_path.name}") + return df + except Exception as e: + logger.error(f"Error reading {file_path}: {e}") + return pd.DataFrame() + + def process_training_data(self, conn: sqlite3.Connection): + """Process training data files and create database tables.""" + logger.info("Processing training data...") + + training_files = [ + 'train_FD001.txt', 'train_FD002.txt', 'train_FD003.txt', 'train_FD004.txt' + ] + + for file_name in training_files: + file_path = self.data_dir / file_name + if file_path.exists(): + df = self.read_data_file(file_path) + if not df.empty: + # Calculate RUL for training data (max cycle - current cycle) + df['RUL'] = df.groupby('unit_number')['time_in_cycles'].transform('max') - df['time_in_cycles'] + + # Create separate table for each dataset (e.g., train_FD001) + table_name = file_name.replace('.txt', '') + df.to_sql(table_name, conn, if_exists='replace', index=False) + logger.info(f"Created {table_name} table with {len(df)} records") + else: + logger.warning(f"Training file not found: {file_path}") + + def process_test_data(self, conn: sqlite3.Connection): + """Process test data files and create database tables.""" + logger.info("Processing test data...") + + test_files = [ + 'test_FD001.txt', 'test_FD002.txt', 'test_FD003.txt', 'test_FD004.txt' + ] + + for 
file_name in test_files: + file_path = self.data_dir / file_name + if file_path.exists(): + df = self.read_data_file(file_path) + if not df.empty: + # Create separate table for each dataset (e.g., test_FD001) + table_name = file_name.replace('.txt', '') + df.to_sql(table_name, conn, if_exists='replace', index=False) + logger.info(f"Created {table_name} table with {len(df)} records") + else: + logger.warning(f"Test file not found: {file_path}") + + def process_rul_data(self, conn: sqlite3.Connection): + """Process RUL (Remaining Useful Life) data files.""" + logger.info("Processing RUL data...") + + rul_files = [ + 'RUL_FD001.txt', 'RUL_FD002.txt', 'RUL_FD003.txt', 'RUL_FD004.txt' + ] + + for file_name in rul_files: + file_path = self.data_dir / file_name + if file_path.exists(): + try: + # RUL files contain one RUL value per line for each test engine + rul_values = pd.read_csv(file_path, header=None, names=['RUL']) + rul_values['unit_number'] = range(1, len(rul_values) + 1) + + # Create separate table for each dataset (e.g., RUL_FD001) + table_name = file_name.replace('.txt', '') + rul_values[['unit_number', 'RUL']].to_sql(table_name, conn, if_exists='replace', index=False) + logger.info(f"Created {table_name} table with {len(rul_values)} records") + except Exception as e: + logger.error(f"Error reading RUL file {file_path}: {e}") + else: + logger.warning(f"RUL file not found: {file_path}") + + def create_metadata_tables(self, conn: sqlite3.Connection): + """Create metadata tables with sensor descriptions and dataset information.""" + logger.info("Creating metadata tables...") + + # Sensor metadata + sensor_metadata = pd.DataFrame([ + {'sensor_name': sensor, 'description': desc} + for sensor, desc in self.sensor_descriptions.items() + ]) + sensor_metadata.to_sql('sensor_metadata', conn, if_exists='replace', index=False) + + # Dataset metadata + dataset_metadata = pd.DataFrame([ + {'dataset': 'FD001', 'description': 'Sea level conditions', 'fault_modes': 1}, + 
{'dataset': 'FD002', 'description': 'Sea level conditions', 'fault_modes': 6}, + {'dataset': 'FD003', 'description': 'High altitude conditions', 'fault_modes': 1}, + {'dataset': 'FD004', 'description': 'High altitude conditions', 'fault_modes': 6} + ]) + dataset_metadata.to_sql('dataset_metadata', conn, if_exists='replace', index=False) + + logger.info("Created metadata tables") + + def create_indexes(self, conn: sqlite3.Connection): + """Create database indexes for better query performance.""" + logger.info("Creating database indexes...") + + datasets = ['FD001', 'FD002', 'FD003', 'FD004'] + indexes = [] + + # Create indexes for each dataset's tables + for dataset in datasets: + indexes.extend([ + f"CREATE INDEX IF NOT EXISTS idx_train_{dataset}_unit ON train_{dataset}(unit_number)", + f"CREATE INDEX IF NOT EXISTS idx_train_{dataset}_cycle ON train_{dataset}(time_in_cycles)", + f"CREATE INDEX IF NOT EXISTS idx_test_{dataset}_unit ON test_{dataset}(unit_number)", + f"CREATE INDEX IF NOT EXISTS idx_test_{dataset}_cycle ON test_{dataset}(time_in_cycles)", + f"CREATE INDEX IF NOT EXISTS idx_RUL_{dataset}_unit ON RUL_{dataset}(unit_number)" + ]) + + for index_sql in indexes: + try: + conn.execute(index_sql) + except Exception as e: + logger.warning(f"Failed to create index: {e}") + + conn.commit() + logger.info("Created database indexes") + + def create_views(self, conn: sqlite3.Connection): + """Create useful database views for common queries.""" + logger.info("Creating database views...") + + # View for latest sensor readings per engine + latest_readings_view = """ + CREATE VIEW IF NOT EXISTS latest_sensor_readings AS + SELECT t1.* + FROM training_data t1 + INNER JOIN ( + SELECT unit_number, dataset, MAX(time_in_cycles) as max_cycle + FROM training_data + GROUP BY unit_number, dataset + ) t2 ON t1.unit_number = t2.unit_number + AND t1.dataset = t2.dataset + AND t1.time_in_cycles = t2.max_cycle + """ + + # View for engine health summary + engine_health_view = """ + 
CREATE VIEW IF NOT EXISTS engine_health_summary AS + SELECT + unit_number, + dataset, + MAX(time_in_cycles) as total_cycles, + MIN(RUL) as final_rul, + AVG(sensor_measurement_1) as avg_fan_inlet_temp, + AVG(sensor_measurement_11) as avg_hpc_outlet_pressure, + AVG(sensor_measurement_21) as avg_lpt_cool_air_flow + FROM training_data + GROUP BY unit_number, dataset + """ + + conn.execute(latest_readings_view) + conn.execute(engine_health_view) + conn.commit() + logger.info("Created database views") + + def validate_database(self, conn: sqlite3.Connection): + """Validate the created database by running sample queries.""" + logger.info("Validating database...") + + validation_queries = [ + ("Training data count", "SELECT COUNT(*) FROM training_data"), + ("Test data count", "SELECT COUNT(*) FROM test_data"), + ("RUL data count", "SELECT COUNT(*) FROM rul_data"), + ("Unique engines in training", "SELECT COUNT(DISTINCT unit_number) FROM training_data"), + ("Datasets available", "SELECT DISTINCT dataset FROM training_data"), + ] + + for description, query in validation_queries: + try: + result = conn.execute(query).fetchone() + logger.info(f"{description}: {result[0] if isinstance(result[0], (int, float)) else result}") + except Exception as e: + logger.error(f"Validation query failed - {description}: {e}") + + def process_dataset(self): + """Main method to process the entire NASA dataset.""" + logger.info(f"Starting NASA dataset processing...") + logger.info(f"Data directory: {self.data_dir.absolute()}") + logger.info(f"Database path: {self.db_path.absolute()}") + + # Check if data directory exists + if not self.data_dir.exists(): + logger.error(f"Data directory not found: {self.data_dir}") + logger.info("Please download the NASA Turbofan Engine Degradation Simulation Dataset") + logger.info("and place the text files in the 'data' directory") + return False + + try: + # Connect to SQLite database + with sqlite3.connect(self.db_path) as conn: + logger.info(f"Connected to 
def main() -> int:
    """CLI entry point: parse arguments and build the SQLite database.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Convert NASA Turbofan Dataset to SQLite")
    parser.add_argument("--data-dir", default="data",
                        help="Directory containing NASA dataset text files")
    parser.add_argument("--db-path", default="database/nasa_turbo.db",
                        help="Path for output SQLite database")

    args = parser.parse_args()

    processor = NASADatasetProcessor(args.data_dir, args.db_path)
    if not processor.process_dataset():
        print("\n❌ Database creation failed. Check the logs above.")
        return 1

    # The summary below names the tables the processor actually creates
    # (the original message advertised non-existent tables such as
    # 'training_data'/'test_data'/'rul_data' base tables).
    print(f"\n✅ Database created successfully at: {args.db_path}")
    print("\nDatabase contains the following tables:")
    print("- train_FD001..train_FD004: Engine run-to-failure trajectories")
    print("- test_FD001..test_FD004: Engine test trajectories")
    print("- RUL_FD001..RUL_FD004: Ground truth RUL values")
    print("- sensor_metadata: Sensor descriptions")
    print("- dataset_metadata: Dataset information")
    return 0


if __name__ == "__main__":
    # raise SystemExit instead of the site-injected exit() builtin, which is
    # not guaranteed to exist when the script runs without the site module.
    raise SystemExit(main())
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Code execution package for E2B cloud sandbox integration. +""" + +from .e2b_code_execution_tool import e2b_code_execution_tool + +__all__ = ["e2b_code_execution_tool"] diff --git a/examples/asset_lifecycle_management/src/nat_alm_agent/code_execution/e2b_code_execution_tool.py b/examples/asset_lifecycle_management/src/nat_alm_agent/code_execution/e2b_code_execution_tool.py new file mode 100644 index 0000000..5da5040 --- /dev/null +++ b/examples/asset_lifecycle_management/src/nat_alm_agent/code_execution/e2b_code_execution_tool.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import logging
from pydantic import BaseModel, Field

from nat.builder.builder import Builder
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
from nat.data_models.function import FunctionBaseConfig

logger = logging.getLogger(__name__)


class E2BCodeExecutionToolConfig(FunctionBaseConfig, name="e2b_code_execution"):
    """
    Tool for executing Python code in E2B cloud-hosted sandbox environment.

    E2B provides:
    - Cloud-hosted execution (no Docker setup required)
    - Automatic workspace management
    - File upload/download capabilities
    - Fast sandbox startup (~150ms)

    Use this instead of local Docker sandbox when:
    - Docker is not available or desired
    - Running in CI/CD environments
    - Multi-user or production deployments
    - Network access is available
    """

    # Secret credential; prefer the ${E2B_API_KEY} env-var form in YAML configs
    # so the key is not committed to source control.
    e2b_api_key: str = Field(
        description="E2B API key (get from https://e2b.dev/dashboard). Can use ${E2B_API_KEY} env var"
    )
    workspace_files_dir: str = Field(
        description="Path to local workspace directory for file uploads/downloads (e.g., 'output_data')"
    )
    # NOTE(review): seconds — passed both as the sandbox default and per-call
    # timeout_seconds below; confirm E2B's own units match.
    timeout: float = Field(
        default=30.0,
        description="Timeout in seconds for code execution (E2B needs more time for file transfers)"
    )
    max_output_characters: int = Field(
        default=2000,
        description="Maximum number of characters in stdout/stderr"
    )


@register_function(config_type=E2BCodeExecutionToolConfig)
async def e2b_code_execution_tool(config: E2BCodeExecutionToolConfig, builder: Builder):
    """
    Execute Python code in E2B cloud sandbox.

    This tool:
    1. Creates ephemeral E2B sandbox
    2. Uploads workspace files (utils/, database/, etc.)
    3. Executes the provided code
    4. Downloads generated files
    5. Returns execution results

    The sandbox is automatically cleaned up after execution.
    """
    # Imported lazily so the e2b dependency is only required when this tool is
    # actually configured.
    from .e2b_sandbox import E2BSandbox

    class CodeExecutionInputSchema(BaseModel):
        generated_code: str = Field(description="Python code to execute in E2B sandbox")

    # Create E2B sandbox instance once at registration time; it is shared by
    # every subsequent _execute_code invocation.
    sandbox = E2BSandbox(
        api_key=config.e2b_api_key,
        workspace_files_dir=config.workspace_files_dir,
        timeout=config.timeout
    )

    logger.info("E2B code execution tool initialized")

    async def _execute_code(generated_code: str) -> dict:
        """
        Execute code in E2B cloud sandbox.

        Args:
            generated_code: Python code to execute

        Returns:
            Dictionary containing:
            - process_status: "completed", "error", or "timeout"
            - stdout: Standard output
            - stderr: Standard error
            - downloaded_files: List of downloaded file paths
        """
        logger.info("Executing code in E2B cloud sandbox...")

        try:
            result = await sandbox.execute_code(
                generated_code=generated_code,
                timeout_seconds=config.timeout,
                language="python",
                max_output_characters=config.max_output_characters,
            )

            # Log downloaded files
            if result.get("downloaded_files"):
                logger.info(f"Downloaded {len(result['downloaded_files'])} files from E2B sandbox")
                for file_path in result["downloaded_files"]:
                    logger.debug(f" - {file_path}")

            return result

        except Exception as e:
            # Errors are reported through the result dict (never raised) so the
            # calling agent can read them from stderr like any other run.
            logger.exception(f"Error executing code in E2B sandbox: {e}")
            return {
                "process_status": "error",
                "stdout": "",
                "stderr": f"Execution error: {str(e)}",
                "downloaded_files": []
            }

    # Yielding (rather than returning) follows the register_function
    # async-generator protocol: the framework receives the FunctionInfo here
    # and resumes this generator for teardown when the workflow shuts down.
    yield FunctionInfo.from_fn(
        fn=_execute_code,
        input_schema=CodeExecutionInputSchema,
        description="""Executes the provided Python code in an E2B cloud-hosted sandbox.

    E2B provides isolated cloud execution without requiring Docker:
    - Automatic workspace setup with utils and database files
    - File upload/download for inputs and outputs
    - Fast sandbox startup (~150ms)
    - Secure execution environment

    Returns a dictionary with execution status, stdout, stderr, and list of downloaded files.
    Generated files are automatically downloaded to the local workspace directory."""
    )
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)


class E2BSandbox:
    """
    E2B cloud-hosted sandbox with file transfer capabilities.

    This sandbox provides:
    - Cloud-hosted Python execution (no Docker required)
    - Automatic workspace setup with utils and database files
    - File upload/download for inputs and outputs
    - Isolated execution environment
    """

    def __init__(
        self,
        api_key: str,
        workspace_files_dir: str,
        timeout: float = 30.0
    ):
        """
        Initialize E2B sandbox.

        Args:
            api_key: E2B API key (will be set as E2B_API_KEY environment variable)
            workspace_files_dir: Local directory path for file transfers (e.g., "output_data")
            timeout: Default timeout in seconds for sandbox operations

        Note: no cloud resources are allocated here; the sandbox itself is
        created per-call inside ``execute_code``.
        """
        # E2B SDK v2.x requires API key as environment variable
        # NOTE(review): mutates process-wide state (os.environ); a second instance
        # with a different key silently overwrites the first one's key.
        if api_key:
            os.environ['E2B_API_KEY'] = api_key
        self.workspace_dir = Path(workspace_files_dir)  # local side of uploads/downloads
        self.timeout = timeout  # default timeout, forwarded by callers as needed

        logger.info(f"E2B Sandbox initialized with workspace: {self.workspace_dir}")

    def _setup_workspace(self, sandbox) -> None:
        """
        Upload workspace files to E2B sandbox.

        Uploads to /home/user in the sandbox:
        1. utils/ directory - Pre-built utility functions (uploaded to /home/user/utils/)
        2. database/ - SQLite database file (uploaded to /home/user/database/nasa_turbo.db)

        Args:
            sandbox: Active E2B sandbox instance
        """
        logger.info("Setting up E2B sandbox workspace...")

        # 1. Upload utils directory (only top-level *.py files; subdirectories
        # of utils/ are not walked).
        utils_path = self.workspace_dir / "utils"
        if utils_path.exists() and utils_path.is_dir():
            logger.info(f"Uploading utils from {utils_path}")

            # Upload each Python file in utils using files.write() API
            for file_path in utils_path.glob("*.py"):
                with open(file_path, 'r') as f:
                    content = f.read()
                target_path = f"/home/user/utils/{file_path.name}"
                sandbox.files.write(target_path, content)
                logger.debug(f"Uploaded {file_path.name} to {target_path}")
        else:
            # Missing utils is tolerated: the generated code may not need them.
            logger.warning(f"Utils directory not found at {utils_path}")

        # 2. Upload database file — note it is looked up NEXT TO the workspace
        # dir (workspace_dir.parent/database), not inside it.
        db_path = self.workspace_dir.parent / "database" / "nasa_turbo.db"
        if db_path.exists():
            logger.info(f"Uploading database from {db_path}")

            # Read database as bytes and upload using files.write() API
            # (the open file object is handed to the SDK directly).
            with open(db_path, 'rb') as f:
                sandbox.files.write("/home/user/database/nasa_turbo.db", f)
            logger.debug(f"Uploaded database ({db_path.stat().st_size} bytes)")
        else:
            logger.warning(f"Database not found at {db_path}")

        logger.info("Workspace setup complete")

    def _download_outputs(self, sandbox, output_extensions: tuple = ('.json', '.html', '.png', '.jpg', '.csv', '.pdf')) -> list[str]:
        """
        Download generated files from sandbox to local filesystem.

        Args:
            sandbox: Active E2B sandbox instance
            output_extensions: Tuple of file extensions to download

        Returns:
            List of local file paths that were downloaded (empty on listing failure;
            per-file failures are logged and skipped).
        """
        logger.info("Downloading output files from E2B sandbox...")
        downloaded_files = []

        try:
            # List all files in /home/user directory using files.list() API
            # (non-recursive: only direct children of /home/user are seen).
            files = sandbox.files.list("/home/user")

            for file_info in files:
                # Defensive: tolerate SDK returning plain strings instead of entry objects.
                file_name = file_info.name if hasattr(file_info, 'name') else str(file_info)

                # Skip directories and non-output files
                if not any(file_name.endswith(ext) for ext in output_extensions):
                    continue

                # Skip files that are in subdirectories (utils, database)
                # NOTE(review): entry names from a single-directory listing are
                # presumably basenames, so this check may never trigger — confirm
                # against the SDK's FileInfo shape.
                if '/' in file_name:
                    continue

                try:
                    # Read file content from sandbox using files.read() API
                    sandbox_path = f"/home/user/{file_name}"
                    content = sandbox.files.read(sandbox_path)

                    # Write to local filesystem
                    local_path = self.workspace_dir / file_name

                    # Handle both text and binary content
                    if isinstance(content, bytes):
                        local_path.write_bytes(content)
                    else:
                        local_path.write_text(content)

                    downloaded_files.append(str(local_path))
                    logger.debug(f"Downloaded {file_name} to {local_path}")

                except Exception as e:
                    # Best-effort: one bad file must not abort the rest of the downloads.
                    logger.error(f"Failed to download {file_name}: {e}")

            logger.info(f"Downloaded {len(downloaded_files)} files")

        except Exception as e:
            logger.error(f"Error listing/downloading files: {e}")

        return downloaded_files

    async def execute_code(
        self,
        generated_code: str,
        timeout_seconds: float = 10.0,
        language: str = "python",
        max_output_characters: int = 2000,
    ) -> dict[str, str]:
        """
        Execute code in E2B cloud sandbox.

        Args:
            generated_code: Python code to execute
            timeout_seconds: Execution timeout
                NOTE(review): currently unused in the body — no timeout is passed
                to Sandbox.create() or run_code(); confirm intended behavior.
            language: Programming language (currently only "python")
            max_output_characters: Maximum characters in output

        Returns:
            Dictionary with:
            - process_status: "completed", "error", or "timeout"
            - stdout: Standard output from execution
            - stderr: Standard error from execution
            - downloaded_files: List of downloaded file paths (E2B-specific)
              (annotation is loose: this value is a list[str], not str)
        """
        if language != "python":
            return {
                "process_status": "error",
                "stdout": "",
                "stderr": f"Language '{language}' not supported. E2B sandbox only supports Python.",
                "downloaded_files": []
            }

        logger.info("Executing code in E2B cloud sandbox...")

        try:
            # Import E2B SDK lazily so a missing optional dependency surfaces
            # as a structured error result instead of an import-time crash.
            try:
                from e2b_code_interpreter import Sandbox
            except ImportError:
                return {
                    "process_status": "error",
                    "stdout": "",
                    "stderr": "E2B SDK not installed. Run: pip install e2b-code-interpreter",
                    "downloaded_files": []
                }

            # Create E2B sandbox using Sandbox.create() method
            # Note: E2B SDK v2.x reads API key from E2B_API_KEY environment variable
            # The timeout parameter is for sandbox lifecycle, not code execution timeout
            # NOTE(review): the SDK calls below are synchronous, so this coroutine
            # blocks the event loop for the whole run — consider asyncio.to_thread.
            with Sandbox.create() as sandbox:
                logger.debug("E2B sandbox created successfully")

                # Setup workspace (upload utils, database, etc.)
                self._setup_workspace(sandbox)

                # Execute the code
                logger.debug(f"Executing code ({len(generated_code)} chars)...")
                execution = sandbox.run_code(generated_code)

                # Parse execution results
                stdout = ""
                stderr = ""
                status = "completed"

                # Extract output from execution object; hasattr probing tolerates
                # shape differences across SDK versions.
                if hasattr(execution, 'logs'):
                    stdout = execution.logs.stdout if hasattr(execution.logs, 'stdout') else ""
                    stderr = execution.logs.stderr if hasattr(execution.logs, 'stderr') else ""

                # Check for text output (e.g. the value of a final expression)
                if hasattr(execution, 'text') and execution.text:
                    stdout += str(execution.text)

                # Check for errors
                if hasattr(execution, 'error') and execution.error:
                    status = "error"
                    stderr += str(execution.error)

                # Download generated files
                downloaded_files = self._download_outputs(sandbox)

                # Truncate output if needed (silently — no truncation marker is added)
                if len(stdout) > max_output_characters:
                    stdout = stdout[:max_output_characters] + "\n"

                if len(stderr) > max_output_characters:
                    stderr = stderr[:max_output_characters] + "\n"

                logger.info(f"Execution {status}: {len(downloaded_files)} files downloaded")

                return {
                    "process_status": status,
                    "stdout": stdout,
                    "stderr": stderr,
                    "downloaded_files": downloaded_files
                }

        except Exception as e:
            logger.exception(f"E2B sandbox execution failed: {e}")
            return {
                "process_status": "error",
                "stdout": "",
                "stderr": f"E2B sandbox error: {str(e)}",
                "downloaded_files": []
            }
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Evaluators package for Asset Lifecycle Management agent. + +This package contains evaluator implementations for assessing the quality +of responses from the Asset Lifecycle Management agent workflow. +""" + +from .llm_judge_evaluator import LLMJudgeEvaluator +from .multimodal_llm_judge_evaluator import MultimodalLLMJudgeEvaluator + +__all__ = [ + "LLMJudgeEvaluator", + "MultimodalLLMJudgeEvaluator", +] \ No newline at end of file diff --git a/examples/asset_lifecycle_management/src/nat_alm_agent/evaluators/llm_judge_evaluator.py b/examples/asset_lifecycle_management/src/nat_alm_agent/evaluators/llm_judge_evaluator.py new file mode 100644 index 0000000..56f4954 --- /dev/null +++ b/examples/asset_lifecycle_management/src/nat_alm_agent/evaluators/llm_judge_evaluator.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import re

from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate

from nat.eval.evaluator.base_evaluator import BaseEvaluator
from nat.eval.evaluator.evaluator_model import EvalInputItem, EvalOutputItem

logger = logging.getLogger(__name__)


class LLMJudgeEvaluator(BaseEvaluator):
    """
    LLM-as-a-Judge evaluator that uses a large language model to evaluate
    how well the generated response matches the reference answer.
    """

    # (regex, divisor) pairs used as a fallback when the judge does not return
    # JSON. The divisor normalizes the captured number to the 0-1 range.
    # ORDER MATTERS: "/100" must be tried before "/10", otherwise "80/100"
    # would match the "/10" pattern with the wrong scale.
    # (Fixes a bug where membership tests like `'/10' in pattern` never matched
    # the regex text, so "8/10" was clamped to 1.0 instead of scaled to 0.8.)
    _SCORE_PATTERNS: tuple = (
        (r'"?score"?[:\s]*([0-9]*\.?[0-9]+)', 1.0),   # "score": 0.8 / score 0.8
        (r'([0-9]*\.?[0-9]+)\s*/\s*100', 100.0),      # "80/100"
        (r'([0-9]*\.?[0-9]+)\s*/\s*10', 10.0),        # "8/10"
        (r'([0-9]*\.?[0-9]+)\s*%', 100.0),            # "80%"
    )

    def __init__(
        self,
        llm: BaseChatModel,
        judge_prompt: str,
        max_concurrency: int = 4,
    ):
        """
        Args:
            llm: Chat model used as the judge.
            judge_prompt: Prompt template with {question}, {reference_answer}
                and {generated_answer} placeholders.
            max_concurrency: Maximum number of items evaluated concurrently.
        """
        super().__init__(max_concurrency=max_concurrency, tqdm_desc="LLM Judge Evaluating")
        self.llm = llm
        self.judge_prompt = judge_prompt

        # Create the prompt template once; reused for every item.
        self.prompt_template = ChatPromptTemplate.from_template(self.judge_prompt)
        logger.debug("LLM Judge evaluator initialized with custom prompt.")

    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:
        """
        Evaluate a single EvalInputItem using LLM-as-a-judge.

        The judge_prompt should contain placeholders for:
        - {question}: The original question/input
        - {reference_answer}: The expected/reference answer
        - {generated_answer}: The model's generated answer

        The LLM should return a JSON object with 'score' and 'reasoning' fields;
        non-JSON replies fall back to regex-based score extraction.

        Returns:
            EvalOutputItem with a score in [0, 1] and a structured reasoning dict.
            Any exception is caught and reported as a score of 0.0.
        """
        question = str(item.input_obj) if item.input_obj else ""
        reference_answer = str(item.expected_output_obj) if item.expected_output_obj else ""
        generated_answer = str(item.output_obj) if item.output_obj else ""

        try:
            # Format the prompt with the actual values
            messages = self.prompt_template.format_messages(
                question=question,
                reference_answer=reference_answer,
                generated_answer=generated_answer
            )

            # Get LLM response
            response = await self.llm.ainvoke(messages)
            response_text = response.content

            score, reasoning = self._parse_judgment(response_text)

            return EvalOutputItem(
                id=item.id,
                score=score,
                reasoning={
                    "question": question,
                    "reference_answer": reference_answer,
                    "generated_answer": generated_answer,
                    "llm_judgment": reasoning,
                    "raw_response": response_text
                }
            )

        except Exception as e:
            logger.exception("Error evaluating item %s: %s", item.id, e)
            return EvalOutputItem(
                id=item.id,
                score=0.0,
                reasoning={
                    "error": f"LLM evaluation failed: {str(e)}",
                    "question": question,
                    "reference_answer": reference_answer,
                    "generated_answer": generated_answer
                }
            )

    def _parse_judgment(self, response_text: str) -> tuple:
        """
        Turn the judge's raw reply into a (score, reasoning) pair.

        Tries JSON first (raw, then fenced ```json blocks), then falls back to
        regex-based score extraction. The score is clamped to [0.0, 1.0].
        """
        eval_result = self._parse_json_response(response_text)

        if eval_result is None:
            # No parseable JSON at all: regex fallback, full text as reasoning.
            return self._extract_score_from_text(response_text), response_text

        if isinstance(eval_result, dict) and 'score' in eval_result:
            score = eval_result.get('score', 0.0)
            reasoning = eval_result.get('reasoning', response_text)
        else:
            # Parsed JSON but not the expected {"score": ..., "reasoning": ...} shape.
            score = self._extract_score_from_text(response_text)
            reasoning = response_text

        # Ensure score is numeric and between 0 and 1
        if isinstance(score, (int, float)):
            score = max(0.0, min(1.0, float(score)))
        else:
            score = 0.0
            reasoning = f"Could not parse score from LLM response: {response_text}"

        return score, reasoning

    @staticmethod
    def _parse_json_response(response_text: str):
        """
        Parse the reply as JSON, accepting either a raw JSON object or one
        wrapped in a markdown code fence. Returns None if nothing parses.
        """
        try:
            return json.loads(response_text)
        except json.JSONDecodeError:
            pass

        # Look for JSON within markdown code blocks (```json or just ```)
        json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', response_text, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                return None
        return None

    def _extract_score_from_text(self, text: str) -> float:
        """
        Extract a numeric score from text response if JSON parsing fails.
        Looks for patterns like "Score: 0.8" or "8/10" or "80%", normalizes
        each to the 0-1 range, and clamps the result.
        """
        lowered = text.lower()

        for pattern, divisor in self._SCORE_PATTERNS:
            match = re.search(pattern, lowered)
            if not match:
                continue
            try:
                value = float(match.group(1)) / divisor
            except ValueError:
                continue
            return max(0.0, min(1.0, value))

        # Default to 0.0 if no score found
        logger.warning("Could not extract score from text: %s", text)
        return 0.0
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pydantic import Field

from nat.builder.builder import EvalBuilder
from nat.builder.evaluator import EvaluatorInfo
from nat.cli.register_workflow import register_evaluator
from nat.data_models.evaluator import EvaluatorBaseConfig


class LLMJudgeEvaluatorConfig(EvaluatorBaseConfig, name="llm_judge"):
    """Configuration for LLM-as-a-Judge evaluator."""

    llm_name: str = Field(description="Name of the LLM to use as judge")
    judge_prompt: str = Field(
        description="Prompt template for the judge LLM. Should include {question}, {reference_answer}, and {generated_answer} placeholders",
        # NOTE: literal braces in the JSON example are doubled ({{ }}) because this
        # string is later run through str.format-style templating.
        default="""You are an expert evaluator for Asset Lifecycle Management systems. Your task is to evaluate how well a generated answer matches the reference answer for a given question.

Question: {question}

Reference Answer: {reference_answer}

Generated Answer: {generated_answer}

Please evaluate the generated answer against the reference answer considering:
1. Factual accuracy and correctness
2. Completeness of the response
3. Technical accuracy for Asset Lifecycle Management context
4. Relevance to the question asked

Provide your evaluation as a JSON object with the following format:
{{
    "score": <numeric score between 0.0 and 1.0>,
    "reasoning": "<brief explanation of how you arrived at the score>"
}}

The score should be:
- 1.0: Perfect match, completely accurate and complete
- 0.8-0.9: Very good, minor differences but essentially correct
- 0.6-0.7: Good, mostly correct with some inaccuracies or missing details
- 0.4-0.5: Fair, partially correct but with significant issues
- 0.2-0.3: Poor, mostly incorrect but some relevant information
- 0.0-0.1: Very poor, completely incorrect or irrelevant"""
    )


@register_evaluator(config_type=LLMJudgeEvaluatorConfig)
async def register_llm_judge_evaluator(config: LLMJudgeEvaluatorConfig, builder: EvalBuilder):
    """Register the LLM Judge evaluator with NeMo Agent Toolkit."""
    from nat.builder.framework_enum import LLMFrameworkEnum

    from .llm_judge_evaluator import LLMJudgeEvaluator

    # Get the LLM instance (LangChain wrapper, since the evaluator uses ainvoke).
    llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    # Create the evaluator instance
    evaluator = LLMJudgeEvaluator(
        llm=llm,
        judge_prompt=config.judge_prompt,
        max_concurrency=builder.get_max_concurrency()
    )

    yield EvaluatorInfo(
        config=config,
        evaluate_fn=evaluator.evaluate,
        description="LLM-as-a-Judge Evaluator for Asset Lifecycle Management"
    )
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Multimodal LLM Judge Evaluator

An enhanced evaluator that uses llama-3.2-90b-instruct to evaluate both text and visual outputs
from agentic workflows. This evaluator is specifically designed for Asset Lifecycle Management
responses that may include plots and visualizations.
"""

import json
import logging
import os
import re
from typing import Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate

from nat.eval.evaluator.base_evaluator import BaseEvaluator
from nat.eval.evaluator.evaluator_model import EvalInputItem, EvalOutputItem

try:
    from PIL import Image
    import base64
    from io import BytesIO
    HAS_PIL = True
except ImportError:
    HAS_PIL = False
    logging.warning("PIL not available. Image evaluation will be disabled.")

logger = logging.getLogger(__name__)


class MultimodalLLMJudgeEvaluator(BaseEvaluator):
    """
    Enhanced multimodal LLM Judge evaluator using llama-3.2-90b-instruct that can evaluate
    responses containing both text and visual elements (plots).

    This evaluator automatically detects plot paths in responses and includes
    visual analysis in the evaluation process using a unified prompt.
    """

    # (regex, divisor) pairs for fallback score extraction; divisor normalizes
    # the captured number to 0-1. ORDER MATTERS: "/100" before "/10", otherwise
    # "80/100" would be scaled by the wrong divisor.
    # (Fixes a bug where `'/10' in pattern` membership checks never matched the
    # regex text, so "8/10" became 8.0 and snapped to 1.0 instead of 0.8.)
    _SCORE_PATTERNS: tuple = (
        (r'"?score"?[:\s]*([0-9]*\.?[0-9]+)', 1.0),   # "score": 0.8 / score 0.8
        (r'([0-9]*\.?[0-9]+)\s*/\s*100', 100.0),      # "80/100"
        (r'([0-9]*\.?[0-9]+)\s*/\s*10', 10.0),        # "8/10"
        (r'([0-9]*\.?[0-9]+)\s*%', 100.0),            # "80%"
    )

    def __init__(
        self,
        llm: BaseChatModel,
        judge_prompt: str,
        max_concurrency: int = 4,
    ):
        """
        Args:
            llm: Chat model used as judge; must support vision for plot evaluation.
            judge_prompt: Template with {question}, {reference_answer} and
                {generated_answer} placeholders, shared by text and visual modes.
            max_concurrency: Maximum items evaluated concurrently.
        """
        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Multimodal LLM Judge Evaluating")
        self.llm = llm
        self.judge_prompt = judge_prompt

        # Create the prompt template once; reused for every item.
        self.prompt_template = ChatPromptTemplate.from_template(self.judge_prompt)

        logger.debug("Multimodal LLM Judge evaluator initialized.")
        logger.debug("Model: llama-3.2-90b-instruct")

    @classmethod
    def from_config(
        cls,
        llm: BaseChatModel,
        judge_prompt: str,
        max_concurrency: int = 4,
        **kwargs
    ):
        """Create MultimodalLLMJudgeEvaluator from configuration parameters.

        Extra keyword arguments are accepted for config compatibility but ignored.
        """
        return cls(
            llm=llm,
            judge_prompt=judge_prompt,
            max_concurrency=max_concurrency
        )

    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:
        """
        Evaluate a single EvalInputItem that may contain text and/or visual elements.

        This method uses a unified evaluation approach that handles both text-only
        and text+visual responses with a single comprehensive prompt. Any exception
        is caught and reported as a score of 0.0.
        """
        question = str(item.input_obj) if item.input_obj else ""
        reference_answer = str(item.expected_output_obj) if item.expected_output_obj else ""
        generated_answer = str(item.output_obj) if item.output_obj else ""

        try:
            # Check if the response contains plots
            plot_paths = self._extract_plot_paths(generated_answer)

            # Use unified evaluation for both text-only and text+visual responses
            return await self._evaluate_unified(
                item, question, reference_answer, generated_answer, plot_paths
            )

        except Exception as e:
            logger.exception("Error evaluating item %s: %s", item.id, e)
            return EvalOutputItem(
                id=item.id,
                score=0.0,
                reasoning={
                    "error": f"Evaluation failed: {str(e)}",
                    "question": question,
                    "reference_answer": reference_answer,
                    "generated_answer": generated_answer,
                    "plot_paths": [],
                    "num_images_analyzed": 0,
                    "evaluation_type": "ERROR"
                }
            )

    def _extract_plot_paths(self, response: str) -> list[str]:
        """Extract all PNG file paths from the generated response.

        Only paths that exist on the local filesystem are kept. If the response
        *mentions* plot generation but no referenced file exists, recent PNGs in
        ./output_data are used as a heuristic fallback.
        """
        plot_paths = []

        # Look for PNG file paths in the response with improved patterns
        png_patterns = [
            r'([^\s\[\]]+\.png)',               # Original pattern but excluding brackets
            r'([/][^\s\[\]]+\.png)',            # Paths starting with /
            r'([A-Za-z]:[^\s\[\]]+\.png)',      # Windows paths starting with drive letter
            r'file://([^\s\[\]]+\.png)',        # file:// URLs
            r'\[([^\[\]]+\.png)\]',             # Paths inside square brackets
            r'located at ([^\s]+\.png)',        # "located at path.png" pattern
            r'saved.*?([/][^\s]+\.png)',        # "saved at /path.png" pattern
            r'`([^`]+\.png)`',                  # Paths inside backticks
        ]

        for pattern in png_patterns:
            for match in re.findall(pattern, response):
                # Clean up the match - remove any trailing punctuation
                clean_match = match.rstrip('.,;:!?)]')
                # Check if the file actually exists
                if os.path.exists(clean_match):
                    plot_paths.append(clean_match)

        # Also look for responses that mention plot/chart generation even if file doesn't exist
        # This helps with cases where files are generated after response but before evaluation
        plot_indicators = [
            r'plot.*generated', r'chart.*generated', r'histogram.*generated',
            r'visualization.*generated', r'\.png.*generated', r'plot.*saved',
            r'chart.*saved', r'saved.*\.png'
        ]

        has_plot_indicator = any(re.search(indicator, response, re.IGNORECASE)
                                 for indicator in plot_indicators)

        # If we detect plot generation language but no existing files,
        # try to find PNG files in the output_data directory that might be related
        if has_plot_indicator and not plot_paths:
            output_dir = os.path.join(os.getcwd(), "output_data")
            if os.path.exists(output_dir):
                png_files = [f for f in os.listdir(output_dir) if f.endswith('.png')]
                # Add the most recently modified PNG files
                for png_file in png_files[-3:]:  # Last 3 PNG files as a heuristic
                    plot_paths.append(os.path.join(output_dir, png_file))

        return list(set(plot_paths))  # Remove duplicates

    async def _evaluate_unified(
        self,
        item: EvalInputItem,
        question: str,
        reference_answer: str,
        generated_answer: str,
        plot_paths: list[str]
    ) -> EvalOutputItem:
        """
        Unified evaluation method that handles both text-only and text+visual responses.
        Uses a single comprehensive prompt that works for both scenarios.
        """
        try:
            # Load and encode images if plot paths are provided
            image_data_list = []
            valid_plot_paths = []

            if plot_paths and HAS_PIL:
                for plot_path in plot_paths:
                    image_data = self._load_and_encode_image(plot_path)
                    if image_data:
                        image_data_list.append(image_data)
                        valid_plot_paths.append(plot_path)

            # Determine evaluation type based on whether we have valid images
            has_visuals = len(image_data_list) > 0

            logger.info(
                "Evaluation for item %s: has_visuals=%s, plot_paths=%s, valid_plot_paths=%s, image_data_count=%d",
                item.id, has_visuals, plot_paths, valid_plot_paths, len(image_data_list)
            )

            # Use the configured judge_prompt and add explicit evaluation mode instruction
            prompt_text = self.judge_prompt.format(
                question=question,
                reference_answer=reference_answer,
                generated_answer=generated_answer
            )

            # Add explicit instruction based on whether we have visuals
            if has_visuals:
                prompt_text += f"\n\n🚨 CRITICAL OVERRIDE 🚨\nYou can see {len(image_data_list)} plot image(s) attached to this message.\nYou MUST respond with 'EVALUATION TYPE: PLOT' and evaluate the attached images against the reference description.\nIGNORE any text analysis - focus ONLY on the visual plot content."
                logger.info("Using PLOT evaluation mode for item %s with %d images", item.id, len(image_data_list))
            else:
                prompt_text += "\n\n🚨 CRITICAL OVERRIDE 🚨\nNo images are attached to this message.\nYou MUST respond with 'EVALUATION TYPE: TEXT' and evaluate only the text content.\nDo NOT attempt plot evaluation."
                logger.info("Using TEXT evaluation mode for item %s", item.id)

            # Call LLM using LangChain
            if has_visuals:
                # Call with images using LangChain multimodal capability
                response_text = await self._call_visual_api_langchain(
                    prompt_text, image_data_list
                )
            else:
                # Call without images (text-only)
                response_text = await self._call_api_langchain(
                    question, reference_answer, generated_answer
                )

            # Parse the response
            logger.info("LLM response for item %s: %s...", item.id, response_text[:200])
            score, reasoning = self._parse_evaluation_response(response_text)

            # Build reasoning object
            reasoning_obj = {
                "question": question,
                "reference_answer": reference_answer,
                "generated_answer": generated_answer,
                "llm_judgment": reasoning,
                "plot_paths": valid_plot_paths,
                "num_images_analyzed": len(image_data_list),
                "evaluation_type": "PLOT" if has_visuals else "TEXT"
            }

            return EvalOutputItem(
                id=item.id,
                score=score,
                reasoning=reasoning_obj
            )

        except Exception as e:
            logger.exception("Error in unified evaluation for item %s: %s", item.id, e)
            return EvalOutputItem(
                id=item.id,
                score=0.0,
                reasoning={
                    "error": f"Unified evaluation failed: {str(e)}",
                    "question": question,
                    "reference_answer": reference_answer,
                    "generated_answer": generated_answer,
                    "plot_paths": [],
                    "num_images_analyzed": 0,
                    "evaluation_type": "ERROR"
                }
            )

    async def _call_api_langchain(
        self,
        question: str,
        reference_answer: str,
        generated_answer: str
    ) -> str:
        """Call the API using LangChain for text-only evaluation."""
        messages = self.prompt_template.format_messages(
            question=question,
            reference_answer=reference_answer,
            generated_answer=generated_answer
        )

        response = await self.llm.ainvoke(messages)
        return response.content

    async def _call_visual_api_langchain(
        self,
        prompt_text: str,
        image_data_list: list[str]
    ) -> str:
        """Call the API using LangChain for visual evaluation with multiple images."""
        # Create content with text and all images
        content = [
            {
                "type": "text",
                "text": prompt_text
            }
        ]

        # Add all images to the content as base64 data URLs
        for image_data in image_data_list:
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{image_data}"
                }
            })

        messages = [
            HumanMessage(content=content)
        ]

        response = await self.llm.ainvoke(messages)
        return response.content

    def _load_and_encode_image(self, image_path: str) -> Optional[str]:
        """Load an image file and encode it as base64 PNG; None on failure."""
        try:
            with Image.open(image_path) as img:
                # Convert to RGB if necessary
                if img.mode != 'RGB':
                    img = img.convert('RGB')

                # Save to bytes buffer
                buffer = BytesIO()
                img.save(buffer, format='PNG')
                buffer.seek(0)

                # Encode as base64
                return base64.b64encode(buffer.getvalue()).decode('utf-8')

        except Exception as e:
            logger.exception("Error loading image from %s: %s", image_path, e)
            return None

    def _parse_evaluation_response(self, response_text: str) -> tuple[float, str]:
        """Parse the evaluation response and extract (score, reasoning).

        Tries JSON first (raw, then fenced ```json blocks), then regex fallback.
        The final score is snapped to the discrete rubric {0.0, 0.5, 1.0}.
        """
        eval_result = self._parse_json_response(response_text)

        if eval_result is None:
            # No parseable JSON: regex fallback (already snapped to the rubric).
            return self._extract_score_from_text(response_text), response_text

        if isinstance(eval_result, dict) and 'score' in eval_result:
            score = eval_result.get('score', 0.0)
            reasoning = eval_result.get('reasoning', response_text)
        else:
            # Parsed JSON but not the expected shape
            score = self._extract_score_from_text(response_text)
            reasoning = response_text

        # Ensure score is valid (0.0, 0.5, or 1.0)
        if isinstance(score, (int, float)):
            score = self._snap_score(float(score))
        else:
            score = 0.0
            reasoning = f"Could not parse score from LLM response: {response_text}"

        return score, reasoning

    @staticmethod
    def _parse_json_response(response_text: str):
        """Parse a raw or markdown-fenced JSON object; return None on failure."""
        try:
            return json.loads(response_text)
        except json.JSONDecodeError:
            pass

        # Look for JSON within markdown code blocks (```json or just ```)
        json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', response_text, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                return None
        return None

    @staticmethod
    def _snap_score(value: float) -> float:
        """Round a 0-1 value to the nearest rubric score: 0.0, 0.5, or 1.0."""
        if value <= 0.25:
            return 0.0
        if value <= 0.75:
            return 0.5
        return 1.0

    def _extract_score_from_text(self, text: str) -> float:
        """
        Extract a numeric score from text response if JSON parsing fails.
        Looks for patterns like "Score: 0.8" or "8/10" or "80%", normalizes to
        the 0-1 range, and maps to 0.0, 0.5, or 1.0.
        """
        lowered = text.lower()

        for pattern, divisor in self._SCORE_PATTERNS:
            match = re.search(pattern, lowered)
            if not match:
                continue
            try:
                value = float(match.group(1)) / divisor
            except ValueError:
                continue
            # Clamp, then map onto the discrete rubric
            return self._snap_score(max(0.0, min(1.0, value)))

        # Default to 0.0 if no score found
        logger.warning("Could not extract score from text: %s", text)
        return 0.0
# --- file: examples/asset_lifecycle_management/src/nat_alm_agent/evaluators/multimodal_llm_judge_evaluator_register.py ---
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pydantic import Field

from nat.builder.builder import EvalBuilder
from nat.builder.evaluator import EvaluatorInfo
from nat.cli.register_workflow import register_evaluator
from nat.data_models.evaluator import EvaluatorBaseConfig


class MultimodalLLMJudgeEvaluatorConfig(EvaluatorBaseConfig, name="multimodal_llm_judge_evaluator"):
    """Configuration for Multimodal LLM Judge evaluator with text and visual evaluation capabilities."""

    # Judge model; must be vision-capable for multimodal evaluation.
    llm_name: str = Field(
        description="Name of the LLM to use as judge (should support vision for multimodal evaluation)")
    # NOTE: the JSON example below uses doubled braces so the template engine
    # does not treat it as {placeholders}. The <score>/<reasoning> markers
    # were restored here -- they had been stripped during a copy/paste of
    # the original file (the example read `"score": ,`).
    judge_prompt: str = Field(
        description=("Prompt template for the judge LLM. Should include {question}, {reference_answer}, and "
                     "{generated_answer} placeholders. This prompt works for both text-only and multimodal "
                     "evaluation."),
        default="""You are an expert evaluator for Asset Lifecycle Management agentic workflows.
Your task is to evaluate how well a generated response (which may include both text and visualizations) matches the reference answer for a given question.

Question: {question}

Reference Answer: {reference_answer}

Generated Response: {generated_answer}

Please evaluate the complete response considering:

TEXT EVALUATION:
1. Factual accuracy and correctness of technical information
2. Completeness of the response (does it answer all parts of the question?)
3. Technical accuracy for Asset Lifecycle Management context (RUL predictions, sensor data analysis, etc.)
4. Appropriate use of Asset Lifecycle Management and predictive maintenance terminology and concepts

VISUAL EVALUATION (if plots/charts are present):
1. Does the visualization show the correct data/variables as specified in the reference?
2. Are the axes labeled correctly and with appropriate ranges?
3. Does the plot type (line chart, bar chart, distribution, etc.) match what was requested?
4. Are the data values, trends, and patterns approximately correct?
5. Is the visualization clear and appropriate for Asset Lifecycle Management analysis?
6. Does the plot help answer the original question effectively?

COMBINED EVALUATION:
1. Do the text and visual elements complement each other appropriately?
2. Does the overall response provide a complete answer?
3. Is the combination more helpful than text or visuals alone would be?

For Asset Lifecycle Management context, pay special attention to:
- RUL (Remaining Useful Life) predictions and trends
- Sensor data patterns and operational settings
- Time-series data representation
- Unit/engine-specific data filtering
- Dataset context (FD001, FD002, etc.)

Provide your evaluation as a JSON object with the following format:
{{
    "score": <0.0, 0.5, or 1.0>,
    "reasoning": "<brief explanation of the score>"
}}

The score should be:
- 1.0: Completely correct response - text and any visuals match reference accurately, comprehensive and helpful
- 0.5: Partially correct response - some elements correct but significant issues in text or visuals
- 0.0: Completely wrong response - major errors in text or visuals that make the response unhelpful""")


@register_evaluator(config_type=MultimodalLLMJudgeEvaluatorConfig)
async def register_multimodal_llm_judge_evaluator(config: MultimodalLLMJudgeEvaluatorConfig, builder: EvalBuilder):
    """Register the Multimodal LLM Judge evaluator with NeMo Agent Toolkit.

    Yields an EvaluatorInfo whose evaluate_fn is bound to a
    MultimodalLLMJudgeEvaluator instance built from the config.
    """
    from nat.builder.framework_enum import LLMFrameworkEnum

    from .multimodal_llm_judge_evaluator import MultimodalLLMJudgeEvaluator

    # Resolve the judge LLM through the LangChain wrapper.
    llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    evaluator = MultimodalLLMJudgeEvaluator(
        llm=llm,
        judge_prompt=config.judge_prompt,
        max_concurrency=builder.get_max_concurrency())

    yield EvaluatorInfo(
        config=config,
        evaluate_fn=evaluator.evaluate,
        description=("Multimodal LLM Judge Evaluator with Text and Visual Evaluation Capabilities "
                     "for Asset Lifecycle Management"))
# --- file boundary: examples/asset_lifecycle_management/src/nat_alm_agent/plotting/__init__.py ---
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Plotting package for Asset Lifecycle Management agent.

This package contains components for data visualization, plotting tools,
and code generation assistance for Asset Lifecycle Management workflows.
"""

from . import plot_comparison_tool
from . import plot_distribution_tool
from . import plot_line_chart_tool
from . import plot_anomaly_tool
from . import code_generation_assistant
# NOTE(review): this wildcard import re-exports every public name from
# plot_utils even though __all__ below lists only the tool modules --
# consider importing the needed helpers explicitly so the package's
# public surface matches __all__.
from .plot_utils import *

# Tool modules registered by this package (plot_utils helpers are
# re-exported via the star import above but intentionally not listed).
__all__ = [
    "plot_comparison_tool",
    "plot_distribution_tool",
    "plot_line_chart_tool",
    "plot_anomaly_tool",
    "code_generation_assistant",
]
# --- file boundary: examples/asset_lifecycle_management/src/nat_alm_agent/plotting/code_generation_assistant.py ---
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, Dict

from pydantic import Field, BaseModel

from nat.builder.builder import Builder
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
from nat.data_models.function import FunctionBaseConfig
from nat.data_models.component_ref import LLMRef, FunctionRef
from nat.builder.framework_enum import LLMFrameworkEnum

logger = logging.getLogger(__name__)


class CodeGenerationAssistantConfig(FunctionBaseConfig, name="code_generation_assistant"):
    """
    NeMo Agent Toolkit function to generate and execute Python code based on input instructions and context.
    This tool combines code generation with direct execution, returning results and any generated files.
    """
    llm_name: LLMRef = Field(description="The LLM to use for code generation")
    code_execution_tool: FunctionRef = Field(description="The code execution tool to run generated code")
    output_folder: str = Field(description="The path to the output folder for generated files", default="/output_data")
    verbose: bool = Field(description="Enable verbose logging", default=True)
    max_retries: int = Field(description="Maximum number of retries if code execution fails", default=0)


@register_function(config_type=CodeGenerationAssistantConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def code_generation_assistant(config: CodeGenerationAssistantConfig, builder: Builder):
    """Register the code-generation-and-execution function with the toolkit."""

    class CodeGenerationInputSchema(BaseModel):
        instructions: str = Field(
            description="Complete instructions including context, data information, and requirements "
                        "for the code to be generated")

    # Resolve the judge LLM and the sandboxed code-execution tool.
    llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    code_execution_fn = await builder.get_function(config.code_execution_tool)

    max_retries = config.max_retries

    async def _generate_and_execute_code(instructions: str) -> str:
        """
        Generate and execute code based on complete instructions.

        Args:
            instructions: Complete instructions including context, data information,
                and requirements for what the code should do.

        Returns:
            String containing execution results and summary.
        """

        system_prompt = """You are an expert Python developer. Generate MINIMAL, EFFICIENT code.

**CRITICAL OUTPUT REQUIREMENT:**
OUTPUT ONLY THE CODE. NO COMMENTS. NO DOCSTRINGS. NO EXPLANATIONS.
Generate only the code needed. Your response must contain ONLY executable Python code which will be DIRECTLY EXECUTED IN A SANDBOX.

**DATABASE PATH:**
For SQLite operations, the database file is located at: '/workspace/database/nasa_turbo.db'
ALWAYS use this exact path when connecting to the database.

**UTILITIES (OPTIONAL - ONLY FOR RUL TRANSFORMATIONS):**
ONLY IF the task involves piecewise RUL transformation, you may use:
- utils.apply_piecewise_rul_transformation(df, maxlife=100, time_col='time_in_cycles', rul_col='RUL')
  Takes a pandas DataFrame and returns it with an added 'transformed_RUL' column using piecewise transformation.

To use utilities (ONLY if needed for RUL transformation):
```python
import sys
sys.path.append("/workspace")
import utils

# Your code using the utility
df_transformed = utils.apply_piecewise_rul_transformation(df, maxlife=100)
```

DO NOT import utils unless specifically needed for RUL transformation tasks.

**CODE REQUIREMENTS:**
1. Generate COMPLETE, SYNTACTICALLY CORRECT Python code
2. ALWAYS finish the complete code - never stop mid-statement
3. EVERY if/elif statement MUST have a complete return statement or action
4. NO comments, NO docstrings, NO explanations
5. Use minimal variable names (df, fig, data, etc.)
6. **CRITICAL FILE PATH RULE**: Use ONLY the filename directly (e.g., "filename.json"), NOT "output_data/filename.json"
7. **DATABASE PATH RULE**: Use '/workspace/database/nasa_turbo.db' for SQLite connections
8. **IF YOU STILL NEED TO SAVE FILES, THEN PRINT FILE NAMES TO STDOUT. (eg: print("Saved file to: filename.json"))**

GENERATE CODE ONLY. NO COMMENTS. NO EXPLANATIONS."""

        user_prompt = """**INSTRUCTIONS:**
{instructions}. Generate a Python code that fulfills these instructions."""

        if config.verbose:
            logger.info("Generating code with instructions: %s", instructions)

        try:
            from langchain_core.prompts.chat import ChatPromptTemplate

            # Generate the initial code with a system/user prompt template.
            prompt = ChatPromptTemplate.from_messages([("system", system_prompt), ("user", user_prompt)])
            coding_chain = prompt | llm

            llm_response = await coding_chain.ainvoke({"instructions": instructions})

            # Strip markdown fences / explanatory text, keeping only code.
            raw_code = llm_response.content.strip() if hasattr(llm_response, 'content') else str(llm_response).strip()
            code = _clean_generated_code(raw_code)

            if config.verbose:
                logger.info("Generated code length: %d characters", len(code))
                logger.info("Generated code:\n%s", code)

            def _looks_incomplete(candidate: str) -> bool:
                """Heuristic: truncated LLM output rarely ends on a closing token."""
                truncated = not candidate.endswith((')', '"', "'", ';'))
                incomplete_fig_write = 'fig.write' in candidate and 'fig.write_html(' not in candidate
                return truncated or incomplete_fig_write

            error_info = ""
            for attempt in range(max_retries + 1):
                if config.verbose:
                    logger.info("Attempt %d/%d: Executing generated code...", attempt + 1, max_retries + 1)

                execution_failed = False
                error_info = ""

                if _looks_incomplete(code):
                    # Skip execution when the code looks truncated.
                    execution_failed = True
                    error_info = ("Code generation was incomplete - code appears to be truncated "
                                  "or has incomplete statements")
                    logger.warning("Code appears incomplete: %s", code[-100:])
                else:
                    # Run the code in the sandboxed execution tool.
                    execution_result = await code_execution_fn.ainvoke({"generated_code": code})

                    if config.verbose:
                        logger.info("Execution result: %s", execution_result)

                    process_status = execution_result.get('process_status', 'unknown')
                    raw_stdout = execution_result.get('stdout', '')
                    stderr = execution_result.get('stderr', '')

                    # E2B-style sandboxes may return stdout/stderr as lists.
                    if isinstance(raw_stdout, list):
                        raw_stdout = ''.join(raw_stdout)
                    if isinstance(stderr, list):
                        stderr = ''.join(stderr)

                    # Best-effort unwrap of a nested JSON result envelope.
                    actual_stdout, actual_stderr = raw_stdout, stderr
                    try:
                        if raw_stdout.startswith('{"') and raw_stdout.endswith('}\n'):
                            import json
                            nested = json.loads(raw_stdout.strip())
                            if nested.get('process_status') == 'error' or nested.get('stderr'):
                                process_status = 'error'
                                actual_stdout = nested.get('stdout', '')
                                actual_stderr = nested.get('stderr', '')
                    except Exception:
                        # Deliberately best-effort: fall back to the raw output,
                        # but leave a trace instead of silently swallowing.
                        logger.debug("Could not unwrap nested execution result", exc_info=True)

                    if process_status in ('completed', 'success') and not actual_stderr:
                        # Success: summarize any generated files and the output.
                        generated_files = _extract_file_paths(actual_stdout, config.output_folder)
                        file_count = len(generated_files)

                        if file_count > 0:
                            file_list = ', '.join(f.split('/')[-1] for f in generated_files)
                            result_msg = f"Code executed successfully. Generated {file_count} file(s): {file_list}"
                        else:
                            result_msg = "Code executed successfully."

                        if actual_stdout:
                            clean_output = actual_stdout.strip().replace('\n', ' ')
                            result_msg += f"\n\nOutput: {clean_output}"

                        logger.info("Code generation successful: %s", result_msg)
                        return result_msg

                    # Execution failed: collect diagnostics for the fix prompt.
                    execution_failed = True
                    error_info = ""
                    if actual_stderr:
                        error_info += f"Error: {actual_stderr.strip()}"
                    if actual_stdout:
                        error_info += f"\nOutput: {actual_stdout.strip()}"

                if execution_failed and attempt < max_retries:
                    logger.warning("Execution failed, asking LLM to fix (attempt %d)...", attempt + 1)

                    fix_prompt_text = f"""The previous code needs to be fixed. Please analyze the issue and generate corrected Python code.

ORIGINAL INSTRUCTIONS: {instructions}

PREVIOUS CODE:
{code}

ISSUE TO FIX:
{error_info}

Please generate corrected Python code that fixes the problem. Follow all requirements:
- Use '/workspace/database/nasa_turbo.db' for database connections
- Only import utils if doing RUL transformations (use sys.path.append("/workspace"))
- Generate only executable Python code
- No comments or explanations
- Handle file paths correctly (use only filename, not paths)
- Complete all code blocks properly
- Ensure the code is complete and not truncated

CORRECTED CODE:"""

                    try:
                        # BUGFIX: send the fix request as plain messages instead of
                        # building another ChatPromptTemplate -- the embedded code
                        # and error text may contain literal braces that the
                        # template engine would misinterpret as {variables}.
                        fix_response = await llm.ainvoke([("system", system_prompt), ("user", fix_prompt_text)])
                        raw_fixed_code = (fix_response.content.strip()
                                          if hasattr(fix_response, 'content') else str(fix_response).strip())
                        code = _clean_generated_code(raw_fixed_code)

                        if config.verbose:
                            logger.info("Generated corrected code:\n%s", code)

                    except Exception as e:
                        logger.error("Failed to generate corrected code: %s", e)
                        break
                elif execution_failed:
                    # Max retries reached.
                    break

            # All attempts failed.
            failure_msg = f"Code generation failed after {max_retries + 1} attempts."
            if error_info:
                error_text = error_info.strip().replace('\n', ' ')
                failure_msg += f" Last error: {error_text}"
            failure_msg += " Consider using alternative approaches."

            logger.error(failure_msg)
            return failure_msg

        except Exception as e:
            logger.error("Error in code generation and execution: %s", e)
            return f"Error in code generation and execution: {str(e)}"

    yield FunctionInfo.from_fn(
        fn=_generate_and_execute_code,
        input_schema=CodeGenerationInputSchema,
        description="""Generate and execute Python code based on complete instructions.
    Accepts comprehensive instructions including context, data information, and requirements in a single parameter.
    Includes retry logic with max_retries parameter for handling execution failures.
    Returns a summary with execution status, generated files, and output details.
    Specializes in data analysis, visualization, and file processing tasks.
    Include all necessary context, data file information, and requirements in the instructions parameter.""")

    if config.verbose:
        logger.info("Code generation assistant initialized successfully")


def _clean_generated_code(raw_code: str) -> str:
    """
    Clean generated code by removing markdown formatting and explanatory text.

    Args:
        raw_code: Raw code string from LLM response.

    Returns:
        Cleaned code string with only executable code.
    """
    code = raw_code.strip()

    # Strip markdown code fences if present.
    if code.startswith("```python"):
        code = code[9:]
    elif code.startswith("```"):
        code = code[3:]

    if code.endswith("```"):
        code = code[:-3]

    code = code.strip()

    # Cut off trailing explanatory prose that sometimes follows the code.
    explanatory_patterns = [
        "\nThis script performs",
        "\nThis code performs",
        "\nThe script does",
        "\nThe code does",
        "\nExplanation:",
        "\nSummary:",
        "\nThe above code",
        "\nThis will",
        "\nThe generated code",
    ]

    for pattern in explanatory_patterns:
        if pattern in code:
            code = code.split(pattern)[0].strip()
            break

    # Drop individual lines that read like explanations rather than code.
    code_tokens = ['=', '(', ')', '[', ']', '{', '}', 'import', 'from', 'def',
                   'class', 'if', 'for', 'while', 'try', 'except', 'with', '#']
    clean_lines = []
    for line in code.split('\n'):
        stripped_line = line.strip()
        if (stripped_line.startswith('This script')
                or stripped_line.startswith('This code')
                or stripped_line.startswith('The script')
                or stripped_line.startswith('The code')
                or stripped_line.startswith('Explanation:')
                or (stripped_line and not any(tok in stripped_line for tok in code_tokens))):
            continue
        clean_lines.append(line)

    return '\n'.join(clean_lines).strip()


def _extract_file_paths(stdout: str, output_folder: str) -> list:
    """Extract generated file paths from execution output.

    NOTE(review): the second pattern matches *any* token with a known file
    extension, so file names mentioned in prose are also picked up; the
    final set() de-duplicates, but false positives are possible.
    """
    import re
    import os

    files = []
    # Common phrasings a script uses to announce a saved file.
    patterns = [
        r'saved to[:\s]+([^\s\n]+\.(?:html|png|jpg|jpeg|pdf|csv|json))',
        r'([^\s\n]+\.(?:html|png|jpg|jpeg|pdf|csv|json))',
        r'Plot saved to[:\s]+([^\s\n]+)',
        r'File saved[:\s]+([^\s\n]+)',
    ]

    for pattern in patterns:
        for match in re.findall(pattern, stdout, re.IGNORECASE):
            file_path = match.strip().strip('"\'')
            if file_path and not file_path.startswith('#'):
                # Resolve relative paths against the configured output folder.
                if not os.path.isabs(file_path):
                    full_path = os.path.join(output_folder, file_path.lstrip('./'))
                else:
                    full_path = file_path
                files.append(full_path)

    return list(set(files))  # Remove duplicates
# --- file boundary: examples/asset_lifecycle_management/src/nat_alm_agent/plotting/plot_anomaly_tool.py ---
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import pandas as pd
from typing import Optional
from pydantic import Field, BaseModel

from nat.builder.builder import Builder
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
from nat.data_models.function import FunctionBaseConfig

from .plot_utils import create_anomaly_plot_from_data

logger = logging.getLogger(__name__)


class PlotAnomalyToolConfig(FunctionBaseConfig, name="plot_anomaly_tool"):
    """
    NeMo Agent Toolkit function to create anomaly detection visualizations.
    """
    output_folder: str = Field(description="The path to the output folder to save plots.", default="./output_data")


@register_function(config_type=PlotAnomalyToolConfig)
async def plot_anomaly_tool(config: PlotAnomalyToolConfig, builder: Builder):
    """Register the anomaly-visualization function with the toolkit."""

    class PlotAnomalyInputSchema(BaseModel):
        anomaly_data_json_path: str = Field(
            description="Path to JSON file containing sensor data with is_anomaly column")
        sensor_name: str = Field(description="Name of the sensor to plot", default="sensor_measurement_1")
        engine_unit: int = Field(description="Engine unit number", default=5)
        plot_title: Optional[str] = Field(description="Custom title for the plot", default=None)

    def load_json_data(json_path: str) -> Optional[pd.DataFrame]:
        """Load a JSON file (resolved against the output folder) into a DataFrame."""
        from .plot_utils import resolve_relative_path
        try:
            resolved_path = resolve_relative_path(json_path, config.output_folder)
            with open(resolved_path, 'r') as f:
                data = json.load(f)
            return pd.DataFrame(data)
        except Exception as e:
            logger.error("Error loading JSON data from %s: %s", json_path, e)
            return None

    # Plotting logic lives in plot_utils.py for thread safety.

    async def _response_fn(anomaly_data_json_path: str,
                           sensor_name: str = "sensor_measurement_1",
                           engine_unit: int = 5,
                           plot_title: Optional[str] = None) -> str:
        """
        Create anomaly detection visualization from sensor data with an is_anomaly column.
        """
        try:
            data_df = load_json_data(anomaly_data_json_path)
            if data_df is None:
                return f"Failed to load anomaly data from {anomaly_data_json_path}"

            # Fail fast with a clear message instead of a raw KeyError below.
            if 'is_anomaly' not in data_df.columns:
                return f"Data from {anomaly_data_json_path} is missing the required 'is_anomaly' column"

            logger.info("Loaded anomaly data: %s", data_df.shape)

            # Create the plot using the thread-safe utility function.
            html_filepath, png_filepath = create_anomaly_plot_from_data(
                data_df, sensor_name, engine_unit,
                config.output_folder, plot_title)

            if html_filepath is None:
                return "Failed to create anomaly visualization plot"

            # Count flagged rows (equivalent to len(df[df['is_anomaly'] == True])).
            anomaly_count = int((data_df['is_anomaly'] == True).sum())  # noqa: E712

            png_relpath = (os.path.relpath(png_filepath, config.output_folder)
                           if png_filepath else 'Not generated')

            response_parts = [
                "ANOMALY DETECTION VISUALIZATION COMPLETED SUCCESSFULLY",
                "",
                "Plot Details:",
                f"  • Sensor: {sensor_name}",
                f"  • Engine Unit: {engine_unit}",
                f"  • Data Points: {len(data_df)}",
                f"  • Anomalous Points: {anomaly_count}",
                "",
                "Output Files:",
                f"  • Interactive HTML: {os.path.relpath(html_filepath, config.output_folder)}",
                f"  • PNG Image: {png_relpath}",
                "",
                "Visualization Features:",
                "  • Blue line shows observed sensor readings",
                "  • Red markers highlight detected anomalies",
                "  • Interactive plot with zoom and hover capabilities",
                "",
                "ANOMALY PLOT GENERATION COMPLETE",
            ]

            return "\n".join(response_parts)

        except Exception as e:
            logger.error("Error in plot_anomaly_tool: %s", e)
            return f"Error creating anomaly plot: {str(e)}"

    description = """
    Create interactive anomaly detection visualizations from sensor data with is_anomaly column.

    This tool takes a single JSON file containing sensor data with an added 'is_anomaly' boolean column
    (typically output from MOMENT anomaly detection tool) and creates a clean visualization.

    Features:
    - Interactive HTML plot with zoom and hover capabilities
    - Blue line for observed sensor readings
    - Red markers for detected anomalies
    - Automatic time axis detection (cycle, time_in_cycles, etc.)
    - PNG export for reports and documentation
    - Customizable plot titles

    Input:
    - anomaly_data_json_path: Path to JSON file with sensor data and is_anomaly column [REQUIRED]
    - sensor_name: Name of sensor column to plot (default: "sensor_measurement_1")
    - engine_unit: Engine unit number for labeling (default: 5)
    - plot_title: Custom title for the plot (optional)

    Output:
    - Interactive HTML visualization file
    - PNG image file (if successfully generated)
    - Summary of plot generation with file paths
    """

    yield FunctionInfo.from_fn(_response_fn,
                               input_schema=PlotAnomalyInputSchema,
                               description=description)
# --- file boundary: examples/asset_lifecycle_management/src/nat_alm_agent/plotting/plot_comparison_tool.py ---
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import pandas as pd

from pydantic import Field, BaseModel

from nat.builder.builder import Builder
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
from nat.data_models.function import FunctionBaseConfig

logger = logging.getLogger(__name__)


def verify_json_path(file_path: str, output_folder: str = None) -> str:
    """
    Verify that the input is a valid path to a JSON file.

    Args:
        file_path (str): Path to verify
        output_folder (str): Output folder to use as base for relative paths

    Returns:
        str: Verified file path

    Raises:
        ValueError: If input is not a string, not a JSON file, or contains invalid JSON
        FileNotFoundError: If file does not exist
    """
    if not isinstance(file_path, str):
        raise ValueError("Input must be a string path to a JSON file")

    if not file_path.lower().endswith('.json'):
        raise ValueError("Input must be a path to a JSON file (ending with .json)")

    # Resolve path relative to output folder if provided.
    if output_folder:
        from .plot_utils import resolve_relative_path
        resolved_path = resolve_relative_path(file_path, output_folder)
    else:
        resolved_path = file_path

    if not os.path.exists(resolved_path):
        raise FileNotFoundError(f"JSON file not found at path: {file_path}")

    try:
        with open(resolved_path, 'r') as f:
            json.load(f)  # Verify file contains valid JSON.
    except json.JSONDecodeError:
        raise ValueError(f"File at {file_path} does not contain valid JSON data")

    return resolved_path


class PlotComparisonToolConfig(FunctionBaseConfig, name="plot_comparison_tool"):
    """
    NeMo Agent Toolkit function to plot comparison of two y-axis columns against an x-axis column.
    """
    output_folder: str = Field(description="The path to the output folder to save plots.", default="./output_data")


@register_function(config_type=PlotComparisonToolConfig)
async def plot_comparison_tool(config: PlotComparisonToolConfig, builder: Builder):
    """Register the comparison-plot function with the toolkit."""

    class PlotComparisonInputSchema(BaseModel):
        data_json_path: str = Field(description="The path to the JSON file containing the data")
        x_axis_column: str = Field(description="The column name for x-axis data", default="time_in_cycles")
        y_axis_column_1: str = Field(description="The first column name for y-axis data", default="actual_RUL")
        y_axis_column_2: str = Field(description="The second column name for y-axis data", default="predicted_RUL")
        plot_title: str = Field(description="The title for the plot", default="Comparison Plot")

    from .plot_utils import create_comparison_plot, load_data_from_json

    async def _response_fn(data_json_path: str,
                           x_axis_column: str,
                           y_axis_column_1: str,
                           y_axis_column_2: str,
                           plot_title: str) -> str:
        """
        Process the input message and generate comparison plot.
        """
        try:
            # Load data up-front so missing columns produce a clear message.
            df = load_data_from_json(data_json_path, config.output_folder)
            if df is None or df.empty:
                return "Could not load data or data is empty from the provided JSON file"

            required_columns = [x_axis_column, y_axis_column_1, y_axis_column_2]
            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                return (f"Data from {data_json_path} must contain columns: {required_columns}. "
                        f"Missing: {missing_columns}")

            # Delegate plot creation to the shared utility.
            html_filepath, png_filepath = create_comparison_plot(
                output_dir=config.output_folder,
                data_json_path=data_json_path,
                x_col=x_axis_column,
                y_col_1=y_axis_column_1,
                y_col_2=y_axis_column_2,
                title=plot_title)

            # Report paths relative to the output folder.
            html_relpath = os.path.relpath(html_filepath, config.output_folder)
            file_info = f"- HTML File: {html_relpath}"
            if png_filepath:
                png_relpath = os.path.relpath(png_filepath, config.output_folder)
                file_info += f"\n- PNG File: {png_relpath}"

            # Return a clear completion message that the LLM will understand.
            return f"""TASK COMPLETED SUCCESSFULLY

Comparison plot has been generated and saved in multiple formats.

Chart Details:
- Type: Comparison plot with two lines (Plotly)
- X-axis: {x_axis_column}
- Y-axis Line 1: {y_axis_column_1} (dashed teal)
- Y-axis Line 2: {y_axis_column_2} (solid green)
- Title: {plot_title}
{file_info}

✅ CHART GENERATION COMPLETE - NO FURTHER ACTION NEEDED"""

        except FileNotFoundError as e:
            error_msg = f"Required data file ('{data_json_path}') not found for comparison plot: {e}"
            logger.error(error_msg)
            return error_msg
        except KeyError as ke:
            error_msg = f"Missing required columns in '{data_json_path}' for comparison plot: {ke}"
            logger.error(error_msg)
            return error_msg
        except ValueError as ve:
            error_msg = f"Data validation error for comparison plot: {ve}"
            logger.error(error_msg)
            return error_msg
        except Exception as e:
            error_msg = f"Error generating comparison plot: {e}"
            logger.error(error_msg)
            return error_msg

    prompt = """
    Generate interactive comparison plot between two columns from JSON data using Plotly.

    Input:
    - data_json_path: Path to the JSON file containing the data
    - x_axis_column: Column name for x-axis data
    - y_axis_column_1: Column name for first y-axis data
    - y_axis_column_2: Column name for second y-axis data
    - plot_title: Title for the plot

    Output:
    - HTML file containing the interactive comparison plot
    - PNG file containing the static comparison plot
    """
    # BUGFIX: the try/except/finally previously sat *after* the yield, so
    # GeneratorExit (raised at the yield point when the workflow shuts the
    # generator down) was never caught and the cleanup log never ran.
    try:
        yield FunctionInfo.from_fn(_response_fn,
                                   input_schema=PlotComparisonInputSchema,
                                   description=prompt)
    except GeneratorExit:
        logger.info("Plot comparison function exited early!")
    finally:
        logger.info("Cleaning up plot_comparison_tool workflow.")
# --- file boundary: examples/asset_lifecycle_management/src/nat_alm_agent/plotting/plot_distribution_tool.py ---
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def verify_json_path(file_path: str, output_folder: str = None) -> str:
    """
    Verify that the input is a valid path to a JSON file.

    Args:
        file_path (str): Path to verify
        output_folder (str): Output folder to use as base for relative paths

    Returns:
        str: Verified (resolved) file path

    Raises:
        ValueError: If input is not a string, does not end with .json, or the
            file does not contain valid JSON (the underlying
            json.JSONDecodeError is chained as the cause)
        FileNotFoundError: If file does not exist
    """
    if not isinstance(file_path, str):
        raise ValueError("Input must be a string path to a JSON file")

    if not file_path.lower().endswith('.json'):
        raise ValueError("Input must be a path to a JSON file (ending with .json)")

    # Resolve path relative to output folder if provided
    if output_folder:
        from .plot_utils import resolve_relative_path
        resolved_path = resolve_relative_path(file_path, output_folder)
    else:
        resolved_path = file_path

    if not os.path.exists(resolved_path):
        raise FileNotFoundError(f"JSON file not found at path: {file_path}")

    try:
        with open(resolved_path, 'r') as f:
            json.load(f)  # Verify file contains valid JSON
    except json.JSONDecodeError as e:
        # BUG FIX: chain the decode error (`from e`) so the root cause stays
        # visible; the docstring previously claimed JSONDecodeError escaped,
        # but it is converted to ValueError here.
        raise ValueError(f"File at {file_path} does not contain valid JSON data") from e

    return resolved_path
+ """ + output_folder: str = Field(description="The path to the output folder to save plots.", default="./output_data") + +@register_function(config_type=PlotDistributionToolConfig) +async def plot_distribution_tool( + config: PlotDistributionToolConfig, builder: Builder +): + class PlotDistributionInputSchema(BaseModel): + data_json_path: str = Field(description="The path to the JSON file containing the data") + column_name: str = Field(description="The column name to create distribution plot for", default="RUL") + plot_title: str = Field(description="The title for the plot", default="Distribution Plot") + + from .plot_utils import create_distribution_plot, load_data_from_json + + async def _response_fn(data_json_path: str, column_name: str, plot_title: str) -> str: + """ + Process the input message and generate distribution histogram file. + """ + data_json_path = verify_json_path(data_json_path, config.output_folder) + try: + # Load data to validate column exists + df = load_data_from_json(data_json_path, config.output_folder) + if df is None or df.empty: + return "Could not load data or data is empty from the provided JSON file" + + if column_name not in df.columns: + return f"Column '{column_name}' not found in data. 
Available columns: {df.columns.tolist()}" + + # Use utility function to create plot + html_filepath, png_filepath = create_distribution_plot( + output_dir=config.output_folder, + data_json_path=data_json_path, + column_name=column_name, + title=plot_title + ) + + # Build file information for response (relative paths from output folder) + html_relpath = os.path.relpath(html_filepath, config.output_folder) + file_info = f"- HTML File: {html_relpath}" + if png_filepath: + png_relpath = os.path.relpath(png_filepath, config.output_folder) + file_info += f"\n- PNG File: {png_relpath}" + + # Return a clear completion message that the LLM will understand + return f"""TASK COMPLETED SUCCESSFULLY + +Distribution histogram has been generated and saved in multiple formats. + +Chart Details: +- Type: Distribution histogram (30 bins, Plotly) +- Column: {column_name} +- Title: {plot_title} +{file_info} + +✅ CHART GENERATION COMPLETE - NO FURTHER ACTION NEEDED""" + + except FileNotFoundError as e: + error_msg = f"Required data file ('{data_json_path}') not found for distribution plot: {e}" + logger.error(error_msg) + return error_msg + except KeyError as ke: + error_msg = f"Missing expected column '{column_name}' in '{data_json_path}' for distribution plot: {ke}" + logger.error(error_msg) + return error_msg + except Exception as e: + error_msg = f"Error generating distribution histogram: {e}" + logger.error(error_msg) + return error_msg + + prompt = """ + Generate interactive distribution histogram from JSON data using Plotly. 
+ Input: + - data_json_path: Path to the JSON file containing the data + - column_name: Column name for the distribution histogram + - plot_title: Title for the plot + + Output: + - HTML file containing the interactive distribution histogram + - PNG file containing the static distribution histogram + """ + yield FunctionInfo.from_fn(_response_fn, + input_schema=PlotDistributionInputSchema, + description=prompt) + try: + pass + except GeneratorExit: + logger.info("Plot distribution function exited early!") + finally: + logger.info("Cleaning up plot_distribution_tool workflow.") diff --git a/examples/asset_lifecycle_management/src/nat_alm_agent/plotting/plot_line_chart_tool.py b/examples/asset_lifecycle_management/src/nat_alm_agent/plotting/plot_line_chart_tool.py new file mode 100644 index 0000000..4981c48 --- /dev/null +++ b/examples/asset_lifecycle_management/src/nat_alm_agent/plotting/plot_line_chart_tool.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def verify_json_path(file_path: str, output_folder: str = None) -> str:
    """
    Verify that the input is a valid path to a JSON file.

    Args:
        file_path (str): Path to verify
        output_folder (str): Output folder to use as base for relative paths

    Returns:
        str: Verified (resolved) file path

    Raises:
        ValueError: If input is not a string, does not end with .json, or the
            file does not contain valid JSON (the underlying
            json.JSONDecodeError is chained as the cause)
        FileNotFoundError: If file does not exist
    """
    if not isinstance(file_path, str):
        raise ValueError("Input must be a string path to a JSON file")

    if not file_path.lower().endswith('.json'):
        raise ValueError("Input must be a path to a JSON file (ending with .json)")

    # Resolve path relative to output folder if provided
    if output_folder:
        from .plot_utils import resolve_relative_path
        resolved_path = resolve_relative_path(file_path, output_folder)
    else:
        resolved_path = file_path

    if not os.path.exists(resolved_path):
        raise FileNotFoundError(f"JSON file not found at path: {file_path}")

    try:
        with open(resolved_path, 'r') as f:
            json.load(f)  # Verify file contains valid JSON
    except json.JSONDecodeError as e:
        # BUG FIX: chain the decode error (`from e`) so the root cause stays
        # visible; the docstring previously claimed JSONDecodeError escaped,
        # but it is converted to ValueError here.
        raise ValueError(f"File at {file_path} does not contain valid JSON data") from e

    return resolved_path
+ """ + output_folder: str = Field(description="The path to the output folder to save plots.", default="./output_data") + +@register_function(config_type=PlotLineChartToolConfig) +async def plot_line_chart_tool( + config: PlotLineChartToolConfig, builder: Builder +): + class PlotLineChartInputSchema(BaseModel): + data_json_path: str = Field(description="The path to the JSON file containing the data") + x_axis_column: str = Field(description="The column name for x-axis data", default="time_in_cycles") + y_axis_column: str = Field(description="The column name for y-axis data", default="RUL") + plot_title: str = Field(description="The title for the plot", default="Line Chart") + + from .plot_utils import create_line_chart, load_data_from_json + + async def _response_fn(data_json_path: str, x_axis_column: str, y_axis_column: str, plot_title: str) -> str: + """ + Process the input message and generate line chart. + """ + data_json_path = verify_json_path(data_json_path, config.output_folder) + + try: + # Load data to validate columns exist + df = load_data_from_json(data_json_path, config.output_folder) + if df is None or df.empty: + return "Could not load data or data is empty from the provided JSON file" + + # Check required columns + required_columns = [x_axis_column, y_axis_column] + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + return f"Data from {data_json_path} must contain columns: {required_columns}. 
Missing: {missing_columns}" + + # Use utility function to create plot + html_filepath, png_filepath = create_line_chart( + output_dir=config.output_folder, + data_json_path=data_json_path, + x_col=x_axis_column, + y_col=y_axis_column, + title=plot_title + ) + + # Build file information for response (relative paths from output folder) + html_relpath = os.path.relpath(html_filepath, config.output_folder) + file_info = f"- HTML File: {html_relpath}" + if png_filepath: + png_relpath = os.path.relpath(png_filepath, config.output_folder) + file_info += f"\n- PNG File: {png_relpath}" + + # Return a clear completion message that the LLM will understand + return f"""TASK COMPLETED SUCCESSFULLY + +Line chart has been generated and saved in multiple formats. + +Chart Details: +- Type: Line chart with markers (Plotly) +- X-axis: {x_axis_column} +- Y-axis: {y_axis_column} +- Title: {plot_title} +{file_info} + +✅ CHART GENERATION COMPLETE - NO FURTHER ACTION NEEDED""" + + except FileNotFoundError as e: + error_msg = f"Required data file ('{data_json_path}') not found for line chart: {e}" + logger.error(error_msg) + return error_msg + except KeyError as ke: + error_msg = f"Missing required columns in '{data_json_path}' for line chart: {ke}" + logger.error(error_msg) + return error_msg + except ValueError as ve: + error_msg = f"Data validation error for line chart: {ve}" + logger.error(error_msg) + return error_msg + except Exception as e: + error_msg = f"Error generating line chart: {e}" + logger.error(error_msg) + return error_msg + + prompt = """ + Generate interactive line chart from JSON data using Plotly. 
+ + Input: + - data_json_path: Path to the JSON file containing the data + - x_axis_column: Column name for x-axis data + - y_axis_column: Column name for y-axis data + - plot_title: Title for the plot + + Output: + - HTML file containing the interactive line chart + - PNG file containing the static line chart + """ + yield FunctionInfo.from_fn(_response_fn, + input_schema=PlotLineChartInputSchema, + description=prompt) + try: + pass + except GeneratorExit: + logger.info("Plot line chart function exited early!") + finally: + logger.info("Cleaning up plot_line_chart_tool workflow.") diff --git a/examples/asset_lifecycle_management/src/nat_alm_agent/plotting/plot_utils.py b/examples/asset_lifecycle_management/src/nat_alm_agent/plotting/plot_utils.py new file mode 100644 index 0000000..3077f0d --- /dev/null +++ b/examples/asset_lifecycle_management/src/nat_alm_agent/plotting/plot_utils.py @@ -0,0 +1,618 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import pandas as pd +from typing import Optional, Tuple + +logger = logging.getLogger(__name__) + +def resolve_relative_path(file_path: str, output_folder: str) -> str: + """ + Resolve file path relative to output folder. 
def resolve_relative_path(file_path: str, output_folder: str) -> str:
    """
    Resolve file path relative to output folder.

    Lookup order:
      * absolute path: as-is if it exists; else its basename inside
        ``output_folder``; else the original path unchanged
      * relative path: inside ``output_folder`` if that exists; else relative
        to the current working directory; else inside ``output_folder`` as a
        default (the caller handles the missing file)

    Args:
        file_path: Input file path (can be relative or absolute)
        output_folder: Output folder to use as base for relative paths

    Returns:
        Resolved path (may still point at a non-existent file)
    """
    if os.path.isabs(file_path):
        if os.path.exists(file_path):
            return file_path
        # Fall back to the file's basename inside the output folder.
        candidate = os.path.join(output_folder, os.path.basename(file_path))
        if os.path.exists(candidate):
            return candidate
        # Neither exists; hand back the original (caller reports the error).
        return file_path

    candidate = os.path.join(output_folder, file_path)
    if os.path.exists(candidate):
        return candidate
    if os.path.exists(file_path):
        return file_path
    return candidate


def load_data_from_json(json_path: str, output_folder: str = None) -> Optional[pd.DataFrame]:
    """Load data from JSON file into a pandas DataFrame.

    Returns None (and logs the reason) when the file is missing, contains
    invalid JSON, or cannot be converted into a DataFrame.
    """
    # Resolve path relative to output folder if provided.
    resolved_path = resolve_relative_path(json_path, output_folder) if output_folder else json_path
    try:
        with open(resolved_path, 'r') as f:
            data = json.load(f)
        return pd.DataFrame(data)
    except FileNotFoundError:
        # Lazy %-style args are the logging-module convention (no formatting
        # cost when the level is disabled).
        logger.error("JSON file not found at %s (resolved to %s)", json_path, resolved_path)
        return None
    except json.JSONDecodeError:
        logger.error("Could not decode JSON from %s", json_path)
        return None
    except Exception as e:
        logger.error("Error loading data from '%s': %s", json_path, e)
        return None
def save_plotly_as_png(fig, filepath: str, width: int = 650, height: int = 450) -> bool:
    """
    Save plotly figure as PNG using matplotlib backend.

    Re-renders the Plotly figure with matplotlib (Agg backend) rather than a
    Plotly image exporter. Only 'scatter' and 'histogram' traces are handled;
    any other trace type is silently skipped, so the PNG is an approximation
    of the interactive figure.

    Args:
        fig: Plotly figure to convert.
        filepath (str): Destination PNG path.
        width (int): Target width in pixels (converted to inches at /100).
        height (int): Target height in pixels (converted to inches at /100).

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib
        matplotlib.use('Agg')  # Non-interactive backend

        # Create matplotlib figure (figsize is in inches; px / 100)
        fig_mpl, ax = plt.subplots(figsize=(width/100, height/100))

        # Plot each trace with simplified approach
        for i, trace in enumerate(fig.data):
            if trace.type == 'scatter':
                # Handle line properties
                line_style = '-'
                color = '#1f77b4'  # default color

                # Extract line properties safely (attributes may be unset)
                if hasattr(trace, 'line') and trace.line:
                    if hasattr(trace.line, 'dash') and trace.line.dash == 'dash':
                        line_style = '--'
                    if hasattr(trace.line, 'color') and trace.line.color:
                        color = trace.line.color

                # Extract marker color (takes precedence for better Plotly color preservation)
                if hasattr(trace, 'marker') and trace.marker and hasattr(trace.marker, 'color') and trace.marker.color:
                    color = trace.marker.color

                # Extract name safely
                name = trace.name if hasattr(trace, 'name') and trace.name else f'Trace {i+1}'

                # Plot based on mode ('markers', 'lines', 'lines+markers', ...)
                mode = getattr(trace, 'mode', 'lines')
                if 'markers' in mode:
                    if mode == 'markers':
                        # Only markers, no lines
                        ax.plot(trace.x, trace.y, 'o',
                                color=color, label=name, markersize=6)
                    else:
                        # Both markers and lines
                        ax.plot(trace.x, trace.y, 'o-',
                                linestyle=line_style, color=color,
                                label=name, linewidth=2, markersize=4)
                else:
                    ax.plot(trace.x, trace.y, linestyle=line_style,
                            color=color, label=name, linewidth=2)

            elif trace.type == 'histogram':
                # Handle histogram properties
                color = '#e17160'  # default color
                if hasattr(trace, 'marker') and trace.marker and hasattr(trace.marker, 'color'):
                    color = trace.marker.color

                name = trace.name if hasattr(trace, 'name') and trace.name else f'Histogram {i+1}'
                # NOTE(review): bin count is hard-coded to 30 here, matching
                # the nbinsx=30 used by the histogram builders in this module
                # — confirm if new callers use a different bin count.
                ax.hist(trace.x, bins=30, alpha=0.8, color=color,
                        edgecolor='white', linewidth=0.5, label=name)

        # Apply layout safely (each nested attribute may be absent/None)
        layout = fig.layout
        if hasattr(layout, 'title') and layout.title and hasattr(layout.title, 'text') and layout.title.text:
            ax.set_title(layout.title.text)
        if hasattr(layout, 'xaxis') and layout.xaxis and hasattr(layout.xaxis, 'title') and layout.xaxis.title and hasattr(layout.xaxis.title, 'text'):
            ax.set_xlabel(layout.xaxis.title.text)
        if hasattr(layout, 'yaxis') and layout.yaxis and hasattr(layout.yaxis, 'title') and layout.yaxis.title and hasattr(layout.yaxis.title, 'text'):
            ax.set_ylabel(layout.yaxis.title.text)

        # Show legend if there are multiple traces or if any trace has a name
        if len(fig.data) > 1 or (len(fig.data) == 1 and hasattr(fig.data[0], 'name') and fig.data[0].name):
            ax.legend()

        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(filepath, dpi=150, bbox_inches='tight')
        plt.close()

        logger.info(f"PNG saved using matplotlib: {filepath}")
        return True

    except Exception as e:
        # Best-effort: PNG export failure is reported via the return value so
        # callers can still deliver the HTML output.
        logger.error(f"Matplotlib PNG generation failed: {e}")
        return False
def create_comparison_plot(output_dir: str, data_json_path: str, x_col: str,
                           y_col_1: str, y_col_2: str, title: str) -> Tuple[str, Optional[str]]:
    """
    Generate comparison plot in both HTML and PNG formats.

    Args:
        output_dir: Folder the HTML/PNG files are written to (created if needed)
        data_json_path: JSON file with the tabular data
        x_col: Column plotted on the x-axis (data is sorted by this column)
        y_col_1: First y-series (dashed teal line)
        y_col_2: Second y-series (solid green line)
        title: Plot title

    Returns:
        Tuple[str, Optional[str]]: (html_filepath, png_filepath); the PNG
        entry is None when the matplotlib fallback rendering fails.

    Raises:
        ValueError: when the data cannot be loaded or is empty
        KeyError: when a required column is missing
    """
    import plotly.graph_objects as go
    import plotly.offline as pyo

    df = load_data_from_json(data_json_path, output_dir)
    if df is None or df.empty:
        raise ValueError(f"Could not load data or data is empty from {data_json_path}")

    # Check required columns
    required_columns = [x_col, y_col_1, y_col_2]
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise KeyError(f"Data from {data_json_path} must contain columns: {required_columns}. "
                       f"Missing: {missing_columns}")

    # Sort by x-axis column for proper line plotting
    df_sorted = df.sort_values(x_col)

    # Create the comparison plot
    fig = go.Figure()

    # Add first line (dashed)
    # NOTE(review): the hover templates were garbled in the source; the
    # '<br>' separators and '<extra></extra>' trailer are the standard Plotly
    # idiom reconstructed from the remaining fragments — confirm against the
    # original file.
    fig.add_trace(go.Scatter(
        x=df_sorted[x_col],
        y=df_sorted[y_col_1],
        mode='lines',
        name=y_col_1,
        line=dict(color='#20B2AA', width=3, dash='dash'),
        hovertemplate=f'{x_col}: %{{x}}<br>' +
                      f'{y_col_1}: %{{y:.1f}}<br>' +
                      '<extra></extra>'
    ))

    # Add second line (solid)
    fig.add_trace(go.Scatter(
        x=df_sorted[x_col],
        y=df_sorted[y_col_2],
        mode='lines',
        name=y_col_2,
        line=dict(color='#2E8B57', width=3),
        hovertemplate=f'{x_col}: %{{x}}<br>' +
                      f'{y_col_2}: %{{y:.1f}}<br>' +
                      '<extra></extra>'
    ))

    # Update layout
    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=16)),
        xaxis=dict(title=dict(text=x_col, font=dict(size=14)), gridcolor='lightgray', gridwidth=0.5),
        yaxis=dict(title=dict(text='Value', font=dict(size=14)), gridcolor='lightgray', gridwidth=0.5),
        width=800, height=450, plot_bgcolor='white',
        legend=dict(x=1, y=0, xanchor='right', yanchor='bottom',
                    bgcolor='rgba(255,255,255,0.8)', bordercolor='gray', borderwidth=1),
        hovermode='closest'
    )

    # Pad the y-axis range by 5% of the combined data span (floored at 0)
    y_min = min(df_sorted[y_col_1].min(), df_sorted[y_col_2].min())
    y_max = max(df_sorted[y_col_1].max(), df_sorted[y_col_2].max())
    y_range = y_max - y_min
    fig.update_yaxes(range=[max(0, y_min - y_range * 0.05), y_max + y_range * 0.05])

    # Save files
    os.makedirs(output_dir, exist_ok=True)

    # HTML file (self-contained: plotly.js is inlined)
    html_filepath = os.path.join(output_dir, f"comparison_plot_{y_col_1}_vs_{y_col_2}.html")
    html_content = pyo.plot(fig, output_type='div', include_plotlyjs=True)
    full_html = f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>{title}</title>
</head>
<body>
    {html_content}
</body>
</html>"""

    with open(html_filepath, 'w', encoding='utf-8') as f:
        f.write(full_html)
    logger.info(f"Comparison plot HTML saved: {html_filepath}")

    # PNG file (best-effort matplotlib re-render)
    png_filepath = os.path.join(output_dir, f"comparison_plot_{y_col_1}_vs_{y_col_2}.png")
    png_success = save_plotly_as_png(fig, png_filepath, width=800, height=450)

    return html_filepath, png_filepath if png_success else None
def create_line_chart(output_dir: str, data_json_path: str, x_col: str,
                      y_col: str, title: str) -> Tuple[str, Optional[str]]:
    """
    Generate line chart in both HTML and PNG formats.

    Args:
        output_dir: Folder the HTML/PNG files are written to (created if needed)
        data_json_path: JSON file with the tabular data
        x_col: Column plotted on the x-axis (data is sorted by this column)
        y_col: Column plotted on the y-axis
        title: Plot title

    Returns:
        Tuple[str, Optional[str]]: (html_filepath, png_filepath); the PNG
        entry is None when the matplotlib fallback rendering fails.

    Raises:
        ValueError: when the data cannot be loaded or is empty
        KeyError: when a required column is missing
    """
    import plotly.graph_objects as go
    import plotly.offline as pyo

    df = load_data_from_json(data_json_path, output_dir)
    if df is None or df.empty:
        raise ValueError(f"Could not load data or data is empty from {data_json_path}")

    # Check required columns
    required_columns = [x_col, y_col]
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise KeyError(f"Data from {data_json_path} must contain columns: {required_columns}. Missing: {missing_columns}")

    # Sort by x-axis column
    df_sorted = df.sort_values(x_col)

    # Create line chart
    # NOTE(review): hover template reconstructed ('<br>'/'<extra></extra>'
    # tags were stripped in the source) — confirm against the original file.
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df_sorted[x_col],
        y=df_sorted[y_col],
        mode='lines+markers',
        name=y_col,
        line=dict(color='#1f77b4', width=3),
        marker=dict(size=6, color='#1f77b4'),
        hovertemplate=f'{x_col}: %{{x}}<br>' +
                      f'{y_col}: %{{y:.2f}}<br>' +
                      '<extra></extra>'
    ))

    # Update layout
    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=16)),
        xaxis=dict(title=dict(text=x_col, font=dict(size=14)), gridcolor='lightgray', gridwidth=0.5),
        yaxis=dict(title=dict(text=y_col, font=dict(size=14)), gridcolor='lightgray', gridwidth=0.5),
        width=650, height=450, plot_bgcolor='white', showlegend=False, hovermode='closest'
    )

    # Pad the y-axis range by 5%; skipped for constant series (zero span)
    y_min = df_sorted[y_col].min()
    y_max = df_sorted[y_col].max()
    y_range = y_max - y_min
    if y_range > 0:
        fig.update_yaxes(range=[y_min - y_range * 0.05, y_max + y_range * 0.05])

    # Save files
    os.makedirs(output_dir, exist_ok=True)

    # HTML file (self-contained: plotly.js is inlined)
    html_filepath = os.path.join(output_dir, f"line_chart_{x_col}_vs_{y_col}.html")
    html_content = pyo.plot(fig, output_type='div', include_plotlyjs=True)
    full_html = f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>{title}</title>
</head>
<body>
    {html_content}
</body>
</html>"""

    with open(html_filepath, 'w', encoding='utf-8') as f:
        f.write(full_html)
    logger.info(f"Line chart HTML saved: {html_filepath}")

    # PNG file (best-effort matplotlib re-render)
    png_filepath = os.path.join(output_dir, f"line_chart_{x_col}_vs_{y_col}.png")
    png_success = save_plotly_as_png(fig, png_filepath, width=650, height=450)

    return html_filepath, png_filepath if png_success else None
def create_distribution_plot(output_dir: str, data_json_path: str, column_name: str,
                             title: str) -> Tuple[str, Optional[str]]:
    """
    Generate distribution histogram in both HTML and PNG formats.

    Args:
        output_dir: Folder the HTML/PNG files are written to (created if needed)
        data_json_path: JSON file with the tabular data
        column_name: Column whose value distribution is plotted (30 bins)
        title: Plot title

    Returns:
        Tuple[str, Optional[str]]: (html_filepath, png_filepath); the PNG
        entry is None when the matplotlib fallback rendering fails.

    Raises:
        ValueError: when the data cannot be loaded or is empty
        KeyError: when the requested column is missing
    """
    import plotly.graph_objects as go
    import plotly.offline as pyo

    df = load_data_from_json(data_json_path, output_dir)
    if df is None or df.empty:
        raise ValueError(f"Could not load data or data is empty from {data_json_path}")

    if column_name not in df.columns:
        raise KeyError(f"Data from {data_json_path} must contain '{column_name}' column. "
                       f"Found: {df.columns.tolist()}")

    # Create histogram (30 bins, matching the PNG fallback renderer)
    # NOTE(review): hover template reconstructed ('<br>'/'<extra></extra>'
    # tags were stripped in the source) — confirm against the original file.
    fig = go.Figure()
    fig.add_trace(go.Histogram(
        x=df[column_name],
        nbinsx=30,
        name=column_name,
        marker=dict(color='#e17160', line=dict(color='white', width=1)),
        opacity=0.8,
        hovertemplate='Range: %{x}<br>' +
                      'Count: %{y}<br>' +
                      '<extra></extra>'
    ))

    # Update layout
    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=14)),
        xaxis=dict(title=dict(text=column_name, font=dict(size=12)), gridcolor='lightgray', gridwidth=0.5),
        yaxis=dict(title=dict(text='Frequency', font=dict(size=12)), gridcolor='lightgray', gridwidth=0.5),
        width=650, height=450, plot_bgcolor='white', showlegend=False, hovermode='closest'
    )

    # Save files
    os.makedirs(output_dir, exist_ok=True)

    # HTML file (self-contained: plotly.js is inlined)
    html_filepath = os.path.join(output_dir, f"distribution_plot_{column_name}.html")
    html_content = pyo.plot(fig, output_type='div', include_plotlyjs=True)
    full_html = f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>{title}</title>
</head>
<body>
    {html_content}
</body>
</html>"""

    with open(html_filepath, 'w', encoding='utf-8') as f:
        f.write(full_html)
    logger.info(f"Distribution plot HTML saved: {html_filepath}")

    # PNG file (best-effort matplotlib re-render)
    png_filepath = os.path.join(output_dir, f"distribution_plot_{column_name}.png")
    png_success = save_plotly_as_png(fig, png_filepath, width=650, height=450)

    return html_filepath, png_filepath if png_success else None
def create_moment_anomaly_visualization(df: pd.DataFrame, anomaly_indices,
                                        anomaly_scores, sensor_name: str,
                                        output_dir: str, engine_unit: int,
                                        dataset_name: str) -> Tuple[Optional[str], Optional[str]]:
    """Create interactive plot for MOMENT-based anomaly detection results for a single sensor.

    Args:
        df: Sensor data, one column per sensor plus an optional time column.
        anomaly_indices: Per-row anomaly flags; treated as a boolean mask via
            np.where — TODO confirm callers never pass integer positions.
        anomaly_scores: Not used by this function; kept for interface
            compatibility with callers.
        sensor_name: Column to plot.
        output_dir: Folder for the HTML/PNG output (created if needed).
        engine_unit: Engine unit number (used in title and filenames).
        dataset_name: Not used by this function; kept for interface
            compatibility with callers.

    Returns:
        (html_filepath, png_filepath). Both are None on failure; the PNG
        entry alone is None when PNG rendering fails.
    """
    try:
        import plotly.graph_objects as go
        import numpy as np

        if sensor_name not in df.columns:
            raise ValueError(f"Sensor '{sensor_name}' not found in data. "
                             f"Available sensors: {df.columns.tolist()}")

        # Create a simple single plot
        fig = go.Figure()

        # Pick an x-axis: first known time-like column, else the row index
        time_columns = ['time_in_cycles', 'cycle', 'time', 'timestamp']
        x_axis = None
        x_title = "Index"

        for col in time_columns:
            if col in df.columns:
                x_axis = df[col]
                x_title = col.replace('_', ' ').title()
                break

        if x_axis is None:
            x_axis = df.index
            x_title = "Index"

        # Plot all sensor readings as blue line (Observed)
        fig.add_trace(
            go.Scatter(
                x=x_axis,
                y=df[sensor_name],
                mode='lines',
                name='Observed',
                line=dict(color='blue', width=2),
                opacity=0.8
            )
        )

        # Plot anomalous points as red markers
        if len(anomaly_indices) > 0 and np.any(anomaly_indices):
            # Find where anomalies are True
            anomaly_positions = np.where(anomaly_indices)[0]

            # Make sure we don't go beyond the dataframe length
            valid_positions = anomaly_positions[anomaly_positions < len(df)]

            if len(valid_positions) > 0:
                # BUG FIX: use positional NumPy indexing — x_axis may be a
                # plain Index (which has no .iloc), and .iloc with these
                # positions is what is intended, not label-based lookup.
                anomaly_x = np.asarray(x_axis)[valid_positions]
                anomaly_y = df[sensor_name].to_numpy()[valid_positions]

                fig.add_trace(
                    go.Scatter(
                        x=anomaly_x,
                        y=anomaly_y,
                        mode='markers',
                        name='Anomaly',
                        marker=dict(color='red', size=6, symbol='circle'),
                        opacity=0.9
                    )
                )

        fig.update_layout(
            title=f'MOMENT Anomaly Detection - {sensor_name} (Engine {engine_unit})',
            xaxis_title=x_title,
            yaxis_title=f"{sensor_name}",
            height=400,
            showlegend=True,
            font=dict(size=12),
            template="plotly_white"
        )

        # Save as HTML
        os.makedirs(output_dir, exist_ok=True)
        html_filename = f"moment_anomaly_detection_{sensor_name}_engine{engine_unit}.html"
        html_filepath = os.path.join(output_dir, html_filename)
        fig.write_html(html_filepath)

        # Save as PNG using the safe function from plot_utils
        png_filename = f"moment_anomaly_detection_{sensor_name}_engine{engine_unit}.png"
        png_filepath = os.path.join(output_dir, png_filename)
        png_success = save_plotly_as_png(fig, png_filepath, width=1200, height=400)

        logger.info(f"MOMENT anomaly visualization saved: HTML={html_filepath}, PNG={'Success' if png_success else 'Failed'}")

        return html_filepath, png_filepath if png_success else None

    except ImportError:
        logger.error("Plotly not available for visualization")
        return None, None
    except Exception as e:
        logger.error(f"Error creating MOMENT anomaly visualization: {e}")
        return None, None
def create_anomaly_plot_from_data(data_df: pd.DataFrame, sensor_name: str, engine_unit: int,
                                  output_dir: str, plot_title: str = None) -> Tuple[Optional[str], Optional[str]]:
    """
    Create anomaly detection visualization plot from sensor data with is_anomaly column.

    Args:
        data_df: DataFrame containing sensor data with 'is_anomaly' boolean column
        sensor_name: Name of the sensor column to plot
        engine_unit: Engine unit number for labeling
        output_dir: Directory to save plot files
        plot_title: Custom title for the plot (currently unused — the title is
            derived from sensor_name/engine_unit; kept for interface
            compatibility)

    Returns:
        Tuple of (html_filepath, png_filepath). Both are None on failure; the
        PNG entry alone is None when PNG rendering fails.
    """
    try:
        import plotly.graph_objects as go
        import numpy as np

        if sensor_name not in data_df.columns:
            raise ValueError(f"Sensor '{sensor_name}' not found in data. Available sensors: {data_df.columns.tolist()}")

        if 'is_anomaly' not in data_df.columns:
            raise ValueError("'is_anomaly' column not found in data. "
                             "Make sure to use output from MOMENT anomaly detection tool.")

        # Create figure
        fig = go.Figure()

        # Determine time axis (check for various time column names)
        time_columns = ['time_in_cycles', 'cycle', 'time', 'timestamp']
        x_axis = None
        x_title = "Index"

        for col in time_columns:
            if col in data_df.columns:
                x_axis = data_df[col]
                x_title = col.replace('_', ' ').title()
                break

        if x_axis is None:
            x_axis = data_df.index
            x_title = "Index"

        # Plot all sensor readings as blue line (Observed)
        fig.add_trace(
            go.Scatter(
                x=x_axis,
                y=data_df[sensor_name],
                mode='lines',
                name='Observed',
                line=dict(color='blue', width=2),
                opacity=0.8
            )
        )

        # BUG FIX: the original collected index *labels* (.index.values) and
        # fed them to positional .iloc, which is wrong for any non-default
        # index and crashes when x_axis is a plain Index (no .iloc). Use
        # positions from the boolean mask and NumPy positional indexing.
        anomaly_mask = np.asarray(data_df['is_anomaly'], dtype=bool)
        anomaly_positions = np.flatnonzero(anomaly_mask)

        # Plot anomalous points as red markers
        if anomaly_positions.size > 0:
            anomaly_x = np.asarray(x_axis)[anomaly_positions]
            anomaly_y = data_df[sensor_name].to_numpy()[anomaly_positions]

            fig.add_trace(
                go.Scatter(
                    x=anomaly_x,
                    y=anomaly_y,
                    mode='markers',
                    name='Anomaly',
                    marker=dict(color='red', size=8, symbol='circle'),
                    opacity=0.9
                )
            )

        fig.update_layout(
            title=f'Anomaly Detection - {sensor_name} (Engine {engine_unit})',
            xaxis_title=x_title,
            yaxis_title=f"{sensor_name}",
            height=500,
            showlegend=True,
            font=dict(size=12),
            template="plotly_white"
        )

        # Save files
        os.makedirs(output_dir, exist_ok=True)

        # HTML file
        html_filename = f"anomaly_plot_{sensor_name}_engine{engine_unit}.html"
        html_filepath = os.path.join(output_dir, html_filename)
        fig.write_html(html_filepath)

        # PNG file using thread-safe function
        png_filename = f"anomaly_plot_{sensor_name}_engine{engine_unit}.png"
        png_filepath = os.path.join(output_dir, png_filename)
        png_success = save_plotly_as_png(fig, png_filepath, width=1200, height=500)

        logger.info(f"Anomaly plot saved: HTML={html_filepath}, PNG={'Success' if png_success else 'Failed'}")

        return html_filepath, png_filepath if png_success else None

    except ImportError:
        logger.error("Plotly not available for visualization")
        return None, None
    except Exception as e:
        logger.error(f"Error creating anomaly plot: {e}")
        return None, None
import moment_predict_rul_tool + +__all__ = [ + "moment_anomaly_detection_tool", + "moment_predict_rul_tool", +] \ No newline at end of file diff --git a/examples/asset_lifecycle_management/src/nat_alm_agent/predictors/moment_anomaly_detection_tool.py b/examples/asset_lifecycle_management/src/nat_alm_agent/predictors/moment_anomaly_detection_tool.py new file mode 100644 index 0000000..e7e8ac1 --- /dev/null +++ b/examples/asset_lifecycle_management/src/nat_alm_agent/predictors/moment_anomaly_detection_tool.py @@ -0,0 +1,422 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import logging +import os +import pandas as pd +import numpy as np +from typing import List, Tuple, Optional +from pydantic import Field, BaseModel + +from nat.builder.builder import Builder +from nat.builder.function_info import FunctionInfo +from nat.cli.register_workflow import register_function +from nat.data_models.function import FunctionBaseConfig + +# Note: Visualization is now handled by the separate plot_anomaly_tool + +logger = logging.getLogger(__name__) + +# Global model instance - initialized once when module is loaded +_MOMENT_MODEL: Optional[object] = None +_MODEL_DEVICE: Optional[str] = None + +def _initialize_moment_model(): + """Initialize MOMENT model once and cache it globally.""" + global _MOMENT_MODEL, _MODEL_DEVICE + + if _MOMENT_MODEL is not None: + logger.info("MOMENT model already initialized, reusing cached instance") + return _MOMENT_MODEL, _MODEL_DEVICE + + try: + logger.info("Initializing MOMENT-1-small model (one-time setup)...") + import time + start_time = time.time() + + from momentfm import MOMENTPipeline + import torch + + # Initialize MOMENT pipeline for anomaly detection + model_name = "MOMENT-1-small" + _MOMENT_MODEL = MOMENTPipeline.from_pretrained( + f"AutonLab/{model_name}", + model_kwargs={"task_name": "reconstruction"} + ) + _MOMENT_MODEL.init() + + # Move model to device + _MODEL_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + _MOMENT_MODEL = _MOMENT_MODEL.to(_MODEL_DEVICE).float() + + logger.info(f"MOMENT model initialized and cached in {time.time() - start_time:.2f} seconds on {_MODEL_DEVICE}") + return _MOMENT_MODEL, _MODEL_DEVICE + + except Exception as e: + logger.error(f"Failed to initialize MOMENT model: {e}") + raise RuntimeError(f"MOMENT model initialization failed: {e}") + +# Pre-initialize the model when module is imported (optional - can be lazy loaded) +try: + _initialize_moment_model() + logger.info("MOMENT model pre-loaded successfully") +except Exception as e: + 
logger.warning(f"MOMENT model pre-loading failed, will initialize on first use: {e}") + _MOMENT_MODEL = None + _MODEL_DEVICE = None + + +class TimeSeriesAnomalyDetectionToolConfig(FunctionBaseConfig, name="moment_anomaly_detection_tool"): + """ + NeMo Agent Toolkit function to perform anomaly detection using MOMENT-1-small foundation model. + """ + output_folder: str = Field(description="The path to the output folder to save results.", default="./output_data") + +@register_function(config_type=TimeSeriesAnomalyDetectionToolConfig) +async def moment_anomaly_detection_tool( + config: TimeSeriesAnomalyDetectionToolConfig, builder: Builder +): + class MomentAnomalyDetectionInputSchema(BaseModel): + sensor_data_json_path: str = Field(description="Path to JSON file containing sensor data (from sql_retriever tool)") + engine_unit: int = Field(description="Engine unit number to analyze", default=5) + sensor_name: str = Field(description="Name of the sensor to analyze and plot (e.g., 'sensor_measurement_1', 'sensor_measurement_4')", default="sensor_measurement_1") + + def prepare_time_series_data_for_moment(df: pd.DataFrame, sensor_name: str, max_seq_len: int = 224) -> List[np.ndarray]: + """Prepare time series data for MOMENT model input. 
+ + MOMENT expects input shape: (batch_size, num_channels, seq_len) + For single sensor analysis: (1, 1, seq_len) where seq_len <= 512 + + Args: + df: DataFrame with sensor data + sensor_name: Name of the sensor column to process + max_seq_len: Maximum sequence length (224 for MOMENT-1-small optimal) + + Returns: + List of sequences with shape (1, 1, seq_len) - all patch-aligned + """ + try: + # Select single sensor column + sensor_data = df[sensor_name].values + logger.info(f"Original sensor data shape: {sensor_data.shape}") + + # Normalize the data + from sklearn.preprocessing import StandardScaler + scaler = StandardScaler() + normalized_data = scaler.fit_transform(sensor_data.reshape(-1, 1)).flatten() + logger.info(f"Normalized sensor data shape: {normalized_data.shape}") + + # Split data into chunks of max_seq_len + sequences = [] + total_length = len(normalized_data) + PATCH_LEN = 8 # MOMENT's default patch length + + i = 0 + while i < total_length: + chunk = normalized_data[i:i + max_seq_len] + + # Truncate to largest multiple of PATCH_LEN (discard non-aligned timesteps) + current_len = len(chunk) + aligned_len = (current_len // PATCH_LEN) * PATCH_LEN + + if aligned_len > 0: # Only keep if we have at least one complete patch + chunk = chunk[:aligned_len] + sequence = chunk.reshape(1, 1, -1) + sequences.append(sequence) + logger.info(f"Truncated sequence from {current_len} to {aligned_len} (discarded {current_len - aligned_len} timesteps)") + else: + logger.info(f"Skipped sequence of length {current_len} (less than one patch)") + + i += max_seq_len + + logger.info(f"Created {len(sequences)} sequences, shapes: {[seq.shape for seq in sequences]}") + + return sequences + + except Exception as e: + logger.error(f"Error preparing time series data for MOMENT: {e}") + return None + + def create_moment_dataset(sequences: List[np.ndarray]): + """Create a dataset compatible with MOMENT from sequences (all same length after truncation).""" + import torch + from 
torch.utils.data import TensorDataset + + data_tensors = [] + labels = [] + + for seq in sequences: + # seq shape: (1, 1, seq_len) -> squeeze to (1, seq_len) + seq_squeezed = seq.squeeze(0) # Remove first dimension: (1, seq_len) + data_tensors.append(torch.FloatTensor(seq_squeezed)) + labels.append(torch.tensor(0)) # Dummy label + + # All sequences now have patch-aligned lengths - stacking will work + data = torch.stack(data_tensors) # (num_sequences, 1, seq_len) + labels = torch.stack(labels) # (num_sequences,) + + logger.info(f"Dataset created - data shape: {data.shape}") + + return TensorDataset(data, labels) + + def detect_anomalies_with_moment(sequences: List[np.ndarray], threshold_percentile: float) -> Tuple[np.ndarray, np.ndarray]: + """Detect anomalies using MOMENT-1-small foundation model following the official tutorial. + + Args: + sequences: List of sequences with shape (1, 1, seq_len) + threshold_percentile: Percentile for anomaly threshold + + Returns: + anomalies: Boolean array indicating anomalies + anomaly_scores: Array of reconstruction error scores (per timestep) + """ + logger.info("Starting MOMENT-based anomaly detection...") + + from torch.utils.data import DataLoader + from tqdm import tqdm + import torch + + # Use pre-initialized global model or initialize if needed + model, device = _initialize_moment_model() + + logger.info(f"Using cached MOMENT-1-small model for anomaly detection") + logger.info(f"Number of sequences to process: {len(sequences)}") + if sequences: + logger.info(f"Each sequence shape: {sequences[0].shape}") + + # Create dataset without masks + dataset = create_moment_dataset(sequences) # Simplified call + dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False) + logger.info(f"Using device: {device}") + + # Process batches following the tutorial pattern + model.eval() + trues, preds = [], [] + with torch.no_grad(): + for batch_data in tqdm(dataloader, total=len(dataloader), desc="Processing batches"): 
+ # Unpack - now only batch_x and batch_labels (no masks) + if len(batch_data) == 2: + batch_x, batch_labels = batch_data + else: + batch_x, batch_masks, batch_labels = batch_data # Fallback for old format + + batch_x = batch_x.to(device).float() + + logger.info(f"Input batch_x shape: {batch_x.shape}") + + # MOMENT forward pass WITHOUT input_mask + output = model(x_enc=batch_x) # Simplified - no mask parameter + + logger.info(f"Output reconstruction shape: {output.reconstruction.shape}") + + # Continue with existing processing... + batch_x_np = batch_x.detach().cpu().numpy() + reconstruction_np = output.reconstruction.detach().cpu().numpy() + + # Handle potential shape differences (if any) + if batch_x_np.shape != reconstruction_np.shape: + logger.warning(f"Shape mismatch: input {batch_x_np.shape} vs reconstruction {reconstruction_np.shape}") + min_seq_len = min(batch_x_np.shape[-1], reconstruction_np.shape[-1]) + batch_x_np = batch_x_np[..., :min_seq_len] + reconstruction_np = reconstruction_np[..., :min_seq_len] + logger.info(f"Aligned shapes: input {batch_x_np.shape}, reconstruction {reconstruction_np.shape}") + + # Flatten to 1D for each sample in the batch + for i in range(batch_x_np.shape[0]): + true_seq = batch_x_np[i].flatten() + pred_seq = reconstruction_np[i].flatten() + + trues.append(true_seq) + preds.append(pred_seq) + + # Concatenate all results + trues = np.concatenate(trues, axis=0) + preds = np.concatenate(preds, axis=0) + + logger.info(f"Final concatenated shapes - trues: {trues.shape}, preds: {preds.shape}") + + # Ensure shapes match for calculation (they should already match due to our alignment above) + if len(trues) != len(preds): + min_length = min(len(trues), len(preds)) + logger.warning(f"Final shape mismatch: trues={len(trues)}, preds={len(preds)}. 
Trimming to {min_length}") + trues = trues[:min_length] + preds = preds[:min_length] + else: + logger.info(f"Shapes are aligned: trues={len(trues)}, preds={len(preds)}") + + # Calculate anomaly scores using MSE (following tutorial) + anomaly_scores = (trues - preds) ** 2 + + # Determine anomaly threshold + threshold = np.percentile(anomaly_scores, threshold_percentile) + anomalies = anomaly_scores > threshold + + logger.info(f"MOMENT Anomaly Detection: {np.sum(anomalies)} anomalies detected out of {len(anomalies)} timesteps") + logger.info(f"Anomaly threshold ({threshold_percentile}th percentile): {threshold:.6f}") + logger.info(f"Anomaly scores range: {np.min(anomaly_scores):.6f} - {np.max(anomaly_scores):.6f}") + + return anomalies + + + + async def _response_fn( + sensor_data_json_path: str, + engine_unit: int = 5, + sensor_name: str = "sensor_measurement_1" + ) -> str: + """ + Perform anomaly detection using MOMENT-1-Small foundation model on JSON data from sql_retriever. + """ + # Set default parameters (not exposed to LLM) + threshold_percentile = 95.0 + + try: + if not sensor_data_json_path.lower().endswith('.json'): + return "sensor_data_json_path must be a path to a JSON file (ending with .json)" + + if not os.path.exists(sensor_data_json_path): + return f"JSON file not found at path: {sensor_data_json_path}" + + # Load data from JSON file (output from sql_retriever) + from ..plotting.plot_utils import load_data_from_json + combined_df = load_data_from_json(sensor_data_json_path, config.output_folder) + + if combined_df is None or combined_df.empty: + return f"Could not load data or data is empty from JSON file: {sensor_data_json_path}" + + # Filter for specific engine unit if specified + if 'unit_number' in combined_df.columns: + engine_data = combined_df[combined_df['unit_number'] == engine_unit] + if engine_data.empty: + return f"No data found for engine unit {engine_unit} in the provided JSON file. 
Available units: {sorted(combined_df['unit_number'].unique())}" + + # Sort by cycle for proper time series analysis + if 'time_in_cycles' in engine_data.columns: + engine_data = engine_data.sort_values('time_in_cycles').reset_index(drop=True) + + logger.info(f"Engine data shape: {engine_data.shape}") + logger.info(f"Analyzing sensor: {sensor_name}") + logger.info(f"MOMENT sequence length: 512") + + # Prepare time series data for MOMENT (single sensor) + sequences = prepare_time_series_data_for_moment(engine_data, sensor_name, max_seq_len=224) + + if sequences is None: + return "Failed to prepare time series data for MOMENT analysis" + + logger.info("Starting MOMENT-based anomaly detection...") + anomaly_indices = detect_anomalies_with_moment(sequences, threshold_percentile) + + # Add is_anomaly column to the original dataframe + # Handle case where MOMENT output length differs from input length + if len(anomaly_indices) == len(engine_data): + engine_data['is_anomaly'] = anomaly_indices + elif len(anomaly_indices) < len(engine_data): + # MOMENT output is shorter - pad with False for remaining timesteps + padded_anomalies = np.zeros(len(engine_data), dtype=bool) + padded_anomalies[:len(anomaly_indices)] = anomaly_indices + engine_data['is_anomaly'] = padded_anomalies + logger.warning(f"MOMENT output length ({len(anomaly_indices)}) < input length ({len(engine_data)}). Padded with False.") + else: + # MOMENT output is longer - trim to match input length + engine_data['is_anomaly'] = anomaly_indices[:len(engine_data)] + logger.warning(f"MOMENT output length ({len(anomaly_indices)}) > input length ({len(engine_data)}). 
Trimmed to match.") + + # Calculate summary statistics using the final anomaly column + final_anomalies = engine_data['is_anomaly'] + total_anomalies = np.sum(final_anomalies) + anomaly_rate = total_anomalies / len(final_anomalies) * 100 + + # Save results + os.makedirs(config.output_folder, exist_ok=True) + + # Save the original data with is_anomaly column added + # For saving, we want to save relative to output_folder if the original path was relative + if not os.path.isabs(sensor_data_json_path): + save_path = os.path.join(config.output_folder, os.path.basename(sensor_data_json_path)) + else: + # If it was an absolute path, create a results file in output folder + results_filename = f"moment_anomaly_results_engine{engine_unit}.json" + save_path = os.path.join(config.output_folder, results_filename) + + engine_data.to_json(save_path, orient='records', indent=2) + results_filepath = save_path + + # Build comprehensive response + response_parts = [ + "MOMENT-1-Small FOUNDATION MODEL ANOMALY DETECTION COMPLETED SUCCESSFULLY", + "", + f"Analysis Details:", + f" • Engine Unit: {engine_unit}", + f" • Source Data: {os.path.basename(sensor_data_json_path)}", + f" • Sensor Analyzed: {sensor_name}", + f" • Model: MOMENT-1-Small Foundation Model", + f" • Max Sequence Length: 512", + f" • Threshold Percentile: {threshold_percentile}%", + "", + f"Anomaly Detection Results:", + f" • Total Timesteps Analyzed: {len(final_anomalies)}", + f" • Anomalous Timesteps Detected: {total_anomalies}", + f" • Anomaly Rate: {anomaly_rate:.2f}%", + "", + f"Output Files Generated:", + f" • Enhanced Data with is_anomaly Column: {os.path.relpath(results_filepath, config.output_folder)}" + ] + + response_parts.extend([ + "", + f"Key Insights:", + f" • MOMENT-1-Small foundation model provides state-of-the-art time series anomaly detection", + f" • Pre-trained on diverse time series data for superior pattern recognition without additional training", + f" • {total_anomalies} anomalous time periods 
identified out of {len(final_anomalies)} analyzed sequences", + "", + f"Output Format:", + f" • Original sensor data with added 'is_anomaly' boolean column", + f" • Use the enhanced JSON file with plot_anomaly_tool for visualization", + "", + "MOMENT-1-Small ANOMALY DETECTION COMPLETE" + ]) + + return "\n".join(response_parts) + + except Exception as e: + error_msg = f"Error performing MOMENT-based anomaly detection: {e}" + logger.error(error_msg) + return error_msg + + description = """ + Perform state-of-the-art anomaly detection using MOMENT-1-Small foundation model on sensor data from JSON files. + Outputs detailed anomaly detection results. Use plot_anomaly_tool afterward for visualization. + + Input: + - sensor_data_json_path: File path to a JSON containing sensor data. The file must include timestamp and engine unit number columns along with sensor data columns. + - engine_unit: Engine unit number to analyze (default: 5) + - sensor_name: Name of the specific sensor to analyze and plot (e.g., 'sensor_measurement_1', 'sensor_measurement_4', 'sensor_measurement_7', 'sensor_measurement_11') (default: 'sensor_measurement_1') + + Output: + - JSON file containing original sensor data with added 'is_anomaly' boolean column + - Comprehensive analysis summary with key insights + """ + + yield FunctionInfo.from_fn(_response_fn, + input_schema=MomentAnomalyDetectionInputSchema, + description=description) + try: + pass + except GeneratorExit: + logger.info("moment based anomaly detection function exited early!") + finally: + logger.info("Cleaning up moment based anomaly detection workflow.") \ No newline at end of file diff --git a/examples/asset_lifecycle_management/src/nat_alm_agent/predictors/moment_predict_rul_tool.py b/examples/asset_lifecycle_management/src/nat_alm_agent/predictors/moment_predict_rul_tool.py new file mode 100644 index 0000000..d90b3d7 --- /dev/null +++ b/examples/asset_lifecycle_management/src/nat_alm_agent/predictors/moment_predict_rul_tool.py @@ 
-0,0 +1,483 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import pandas as pd +import numpy as np +from typing import List, Tuple, Optional +from pydantic import Field, BaseModel + +from nat.builder.builder import Builder +from nat.builder.function_info import FunctionInfo +from nat.cli.register_workflow import register_function +from nat.data_models.function import FunctionBaseConfig + +logger = logging.getLogger(__name__) + +# Global MOMENT model cache - key is forecast_horizon, value is (model, device) +_MOMENT_MODEL_CACHE: dict = {} + +def _initialize_moment_model(forecast_horizon: int = 96): + """Initialize MOMENT model and cache it for specific forecast horizon.""" + global _MOMENT_MODEL_CACHE + + # Check if we already have a model for this forecast horizon + if forecast_horizon in _MOMENT_MODEL_CACHE: + logger.info(f"MOMENT model already initialized for horizon {forecast_horizon}, reusing cached instance") + return _MOMENT_MODEL_CACHE[forecast_horizon] + + try: + logger.info(f"Initializing MOMENT-1-small model for forecasting horizon {forecast_horizon}...") + import time + start_time = time.time() + + from momentfm import MOMENTPipeline + import torch + + # Initialize MOMENT pipeline for forecasting + model_name = "MOMENT-1-small" + model = MOMENTPipeline.from_pretrained( + 
f"AutonLab/{model_name}", + model_kwargs={ + 'task_name': 'forecasting', + 'forecast_horizon': forecast_horizon + } + ) + model.init() + + # Move model to device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device).float() + + # Cache the model for this forecast horizon + _MOMENT_MODEL_CACHE[forecast_horizon] = (model, device) + + logger.info(f"MOMENT model initialized and cached in {time.time() - start_time:.2f} seconds on {device}") + return model, device + + except Exception as e: + logger.error(f"Failed to initialize MOMENT model: {e}") + raise RuntimeError(f"MOMENT model initialization failed: {e}") + +# Don't pre-initialize the model since forecast_horizon may vary +# Initialize on first use with the correct forecast_horizon +logger.info("MOMENT model will be initialized on first use with correct forecast_horizon") + + +class MomentPredictRulToolConfig(FunctionBaseConfig, name="moment_predict_rul_tool"): + """ + NeMo Agent Toolkit function to predict RUL using MOMENT-1-small foundation model forecasting. 
+ """ + forecast_horizon: int = Field(description="Number of future timesteps to forecast for trend analysis", default=50) + failure_threshold: float = Field(description="Degradation threshold in normalized space to indicate failure", default=-2.0) + max_rul_cycles: int = Field(description="Maximum RUL prediction to cap unrealistic values", default=500) + output_folder: str = Field(description="Path to output folder to save results", default="./output_data") + +@register_function(config_type=MomentPredictRulToolConfig) +async def moment_predict_rul_tool( + config: MomentPredictRulToolConfig, builder: Builder +): + class MomentPredictRulInputSchema(BaseModel): + sensor_data_json_path: str = Field(description="Path to JSON file containing sensor measurements data for RUL prediction") + engine_unit: int = Field(description="Specific engine unit to analyze (optional, analyzes all if not specified)", default=None) + + def load_data_from_json(json_path: str, output_folder: str = None) -> pd.DataFrame: + """Load data from JSON file into a pandas DataFrame.""" + try: + # Resolve path relative to output folder if provided + if output_folder: + if os.path.isabs(json_path): + # If absolute path exists, use it + if os.path.exists(json_path): + resolved_path = json_path + else: + # If absolute path doesn't exist, try relative to output folder + basename = os.path.basename(json_path) + resolved_path = os.path.join(output_folder, basename) + else: + # If relative path, first try relative to output folder + relative_to_output = os.path.join(output_folder, json_path) + if os.path.exists(relative_to_output): + resolved_path = relative_to_output + else: + # Then try as provided (relative to current working directory) + resolved_path = json_path + else: + resolved_path = json_path + + with open(resolved_path, 'r') as f: + data = json.load(f) + df = pd.DataFrame(data) + logger.info(f"Loaded data with shape: {df.shape}") + return df + except FileNotFoundError: + logger.error(f"JSON file 
not found at {json_path}") + raise FileNotFoundError(f"JSON file not found at {json_path}") + except json.JSONDecodeError: + logger.error(f"Could not decode JSON from {json_path}") + raise ValueError(f"Invalid JSON format in {json_path}") + except Exception as e: + logger.error(f"Error loading data from '{json_path}': {e}") + raise + + def prepare_sensor_data_for_moment(df: pd.DataFrame, feature_columns: List[str]) -> np.ndarray: + """Prepare sensor data for MOMENT input with proper normalization.""" + from sklearn.preprocessing import StandardScaler + + # Extract sensor data + sensor_data = df[feature_columns].values + logger.info(f"Raw sensor data shape: {sensor_data.shape}") + + # Normalize the data + scaler = StandardScaler() + normalized_data = scaler.fit_transform(sensor_data) + logger.info(f"Normalized sensor data shape: {normalized_data.shape}") + + return normalized_data, scaler + + def forecast_sensor_degradation(model, device, sensor_data: np.ndarray, + sequence_length: int, forecast_horizon: int) -> np.ndarray: + """Use MOMENT model to forecast sensor degradation patterns. + + Process each sensor individually as MOMENT forecasting works better with univariate time series. + """ + import torch + + # MOMENT model expects exactly 512 timesteps for proper operation + expected_seq_len = 512 + + if len(sensor_data) < expected_seq_len: + # Pad with zeros if insufficient data (front padding to preserve recent trends) + padded_data = np.zeros((expected_seq_len, sensor_data.shape[1])) + padded_data[-len(sensor_data):] = sensor_data + input_data = padded_data + logger.warning(f"Data length {len(sensor_data)} < expected_seq_len {expected_seq_len}. 
Padded with zeros.") + else: + # Take the last 512 timesteps + input_data = sensor_data[-expected_seq_len:] + + logger.info(f"Final input data shape for MOMENT: {input_data.shape}") + + num_sensors = input_data.shape[1] + seq_len = input_data.shape[0] + + logger.info(f"Processing {num_sensors} sensors individually, each with {seq_len} timesteps") + + all_forecasts = [] + + # Process each sensor individually (univariate forecasting) + for sensor_idx in range(num_sensors): + try: + # Extract single sensor time series + sensor_ts = input_data[:, sensor_idx] # Shape: (seq_len,) + + # Convert to tensor format: (batch_size=1, num_channels=1, seq_len) + input_tensor = torch.FloatTensor(sensor_ts).unsqueeze(0).unsqueeze(0).to(device) # (1, 1, seq_len) + + # Create input mask (all True since we don't have missing values) + input_mask = torch.ones(seq_len, dtype=torch.bool).unsqueeze(0).to(device) # (1, seq_len) + + with torch.no_grad(): + # MOMENT forecasting for single sensor + forecast_output = model(x_enc=input_tensor, input_mask=input_mask) + forecast = forecast_output.forecast.cpu().numpy() # Shape: (1, 1, forecast_horizon) + + # Extract forecast for this sensor + sensor_forecast = forecast.squeeze() # Shape: (forecast_horizon,) + all_forecasts.append(sensor_forecast) + + logger.info(f"Sensor {sensor_idx}: input shape {input_tensor.shape}, forecast shape {sensor_forecast.shape}") + + except Exception as e: + logger.warning(f"Error forecasting sensor {sensor_idx}: {e}. 
Using zero forecast.") + # Use zero forecast as fallback + zero_forecast = np.zeros(forecast_horizon) + all_forecasts.append(zero_forecast) + + # Combine all sensor forecasts + combined_forecast = np.array(all_forecasts) # Shape: (num_sensors, forecast_horizon) + logger.info(f"Combined forecast shape: {combined_forecast.shape}") + + return combined_forecast + + def calculate_rul_from_degradation(current_values: np.ndarray, + forecasted_values: np.ndarray, + forecast_horizon: int, + failure_threshold: float, + max_rul_cycles: int) -> float: + """Calculate RUL based on sensor degradation trends. + + Args: + current_values: Current sensor values, shape (num_sensors,) + forecasted_values: Forecasted sensor values, shape (num_sensors, forecast_horizon) + forecast_horizon: Number of forecast timesteps + failure_threshold: Degradation threshold for failure + max_rul_cycles: Maximum RUL prediction + """ + + # Calculate degradation rate across all sensors + degradation_rates = [] + + for i in range(len(current_values)): + # Calculate degradation rate for each sensor using final forecasted value + final_forecasted_value = forecasted_values[i, -1] # Last forecasted timestep + sensor_degradation_rate = (final_forecasted_value - current_values[i]) / forecast_horizon + degradation_rates.append(sensor_degradation_rate) + + # Use average degradation rate + avg_degradation_rate = np.mean(degradation_rates) + current_degradation = np.mean(current_values) + + logger.info(f"Current degradation level: {current_degradation:.4f}") + logger.info(f"Average degradation rate: {avg_degradation_rate:.6f}") + logger.info(f"Individual sensor degradation rates: {[f'{rate:.6f}' for rate in degradation_rates]}") + + if avg_degradation_rate < 0: # System is degrading + # Calculate cycles until failure threshold is reached + cycles_to_failure = (failure_threshold - current_degradation) / avg_degradation_rate + rul_prediction = max(1, min(cycles_to_failure, max_rul_cycles)) + else: + # System is 
improving or stable - predict high RUL + rul_prediction = max_rul_cycles * 0.8 # 80% of max as conservative estimate + + return float(rul_prediction) + + def predict_rul_for_engine_unit(df: pd.DataFrame, unit_id: int, feature_columns: List[str], + model, device) -> Tuple[float, dict]: + """Predict RUL for a specific engine unit using MOMENT forecasting.""" + + unit_data = df[df['unit_number'] == unit_id].copy() + + # Sort by time for proper time series analysis + if 'time_in_cycles' in unit_data.columns: + unit_data = unit_data.sort_values('time_in_cycles').reset_index(drop=True) + + logger.info(f"Processing engine unit {unit_id} with {len(unit_data)} timesteps") + + # Prepare sensor data + normalized_data, scaler = prepare_sensor_data_for_moment(unit_data, feature_columns) + + if len(normalized_data) < 10: # Need minimum data for meaningful prediction + logger.warning(f"Insufficient data for unit {unit_id} ({len(normalized_data)} timesteps)") + return config.max_rul_cycles * 0.5, {"status": "insufficient_data"} + + # Forecast sensor degradation using MOMENT's expected sequence length + forecast = forecast_sensor_degradation( + model, device, normalized_data, + 512, # MOMENT expects 512 timesteps + config.forecast_horizon + ) + + # Calculate RUL + current_values = normalized_data[-1] # Last known sensor values (shape: num_sensors) + forecasted_values = forecast # Forecasted sensor values (shape: num_sensors, forecast_horizon) + + rul_prediction = calculate_rul_from_degradation( + current_values, forecasted_values, + config.forecast_horizon, config.failure_threshold, config.max_rul_cycles + ) + + # Additional metrics for analysis + final_forecasted_values = forecasted_values[:, -1] # Last timestep of each sensor forecast + metrics = { + "rul_prediction": rul_prediction, + "current_degradation": float(np.mean(current_values)), + "forecast_degradation": float(np.mean(final_forecasted_values)), + "degradation_rate": float((np.mean(final_forecasted_values) - 
np.mean(current_values)) / config.forecast_horizon), + "data_points": len(unit_data), + "sensors_analyzed": len(feature_columns), + "forecast_horizon": config.forecast_horizon, + "status": "success" + } + + return rul_prediction, metrics + + async def _response_fn( + sensor_data_json_path: str, + engine_unit: int = None + ) -> str: + """ + Predict RUL using MOMENT-1-small foundation model forecasting. + """ + try: + # Validate file path + if not sensor_data_json_path.lower().endswith('.json'): + return "sensor_data_json_path must be a path to a JSON file (ending with .json)" + + if not os.path.exists(sensor_data_json_path): + return f"JSON file not found at path: {sensor_data_json_path}" + + # Load data + df = load_data_from_json(sensor_data_json_path, config.output_folder) + + if df.empty: + return f"No data found in JSON file: {sensor_data_json_path}" + + # Define required sensor columns (same as traditional RUL models) + required_columns = [ + 'sensor_measurement_2', 'sensor_measurement_3', 'sensor_measurement_4', + 'sensor_measurement_7', 'sensor_measurement_8', 'sensor_measurement_11', + 'sensor_measurement_12', 'sensor_measurement_13', 'sensor_measurement_15', + 'sensor_measurement_17', 'sensor_measurement_20', 'sensor_measurement_21' + ] + + feature_columns = [col for col in df.columns if col in required_columns] + if not feature_columns: + return f"No valid sensor columns found. Available columns: {df.columns.tolist()}" + + logger.info(f"Using {len(feature_columns)} sensor features: {feature_columns}") + + # Initialize MOMENT model + model, device = _initialize_moment_model(config.forecast_horizon) + + # Determine which engines to process + if 'unit_number' in df.columns: + if engine_unit is not None: + engine_units = [engine_unit] if engine_unit in df['unit_number'].unique() else [] + if not engine_units: + available_units = sorted(df['unit_number'].unique()) + return f"Engine unit {engine_unit} not found. 
Available units: {available_units}" + else: + engine_units = sorted(df['unit_number'].unique()) + else: + # No unit column - treat as single engine + df['unit_number'] = 1 + engine_units = [1] + + logger.info(f"Processing {len(engine_units)} engine units: {engine_units}") + + # Process each engine unit + results = [] + all_predictions = [] + unit_metrics = {} + + for unit_id in engine_units: + try: + rul_pred, metrics = predict_rul_for_engine_unit( + df, unit_id, feature_columns, model, device + ) + + results.append({ + "unit_number": unit_id, + "predicted_RUL": rul_pred, + **metrics + }) + + # Add predictions to all timesteps for this unit + unit_data = df[df['unit_number'] == unit_id] + unit_predictions = [rul_pred] * len(unit_data) + all_predictions.extend(unit_predictions) + unit_metrics[unit_id] = metrics + + logger.info(f"Unit {unit_id}: Predicted RUL = {rul_pred:.1f} cycles") + + except Exception as e: + logger.error(f"Error processing unit {unit_id}: {e}") + results.append({ + "unit_number": unit_id, + "predicted_RUL": None, + "status": "error", + "error": str(e) + }) + + if not all_predictions: + return "No successful RUL predictions could be generated" + + # Add predictions to original DataFrame + df_result = df.copy() + if 'RUL' in df_result.columns: + df_result = df_result.rename(columns={'RUL': 'actual_RUL'}) + df_result['predicted_RUL'] = all_predictions + + # Save results back to the original JSON file (consistent with predict_rul_tool) + # For saving, we want to save relative to output_folder if the original path was relative + if not os.path.isabs(sensor_data_json_path): + save_path = os.path.join(config.output_folder, os.path.basename(sensor_data_json_path)) + else: + save_path = sensor_data_json_path + + results_json = df_result.to_dict('records') + with open(save_path, 'w') as f: + json.dump(results_json, f, indent=2) + + logger.info(f"MOMENT RUL prediction results saved back to file: {save_path}") + results_filepath = save_path + + # Generate 
summary statistics + valid_predictions = [p for p in all_predictions if p is not None] + avg_rul = np.mean(valid_predictions) + min_rul = np.min(valid_predictions) + max_rul = np.max(valid_predictions) + std_rul = np.std(valid_predictions) + + # Build response similar to predict_rul_tool format (relative path from output folder) + results_relpath = os.path.relpath(results_filepath, config.output_folder) + response = f"""RUL predictions generated successfully! 📊 + +**Model Used:** MOMENT-1-Small Foundation Model (Time Series Forecasting) + +**Prediction Summary:** +- **Total predictions:** {len(valid_predictions)} +- **Average RUL:** {avg_rul:.2f} cycles +- **Minimum RUL:** {min_rul:.2f} cycles +- **Maximum RUL:** {max_rul:.2f} cycles +- **Standard Deviation:** {std_rul:.2f} cycles + +**Results saved to:** {results_relpath} + +The predictions have been added to the original dataset with column name 'predicted_RUL'. The original JSON file has been updated with the RUL predictions. +All columns from the original dataset have been preserved, and the predicted RUL column has been renamed to 'predicted_RUL' and the actual RUL column has been renamed to 'actual_RUL'.""" + + return response + + except Exception as e: + error_msg = f"Error performing MOMENT-based RUL prediction: {e}" + logger.error(error_msg) + return error_msg + + description = """ + Predict RUL (Remaining Useful Life) using MOMENT-1-small foundation model with time series forecasting. + + This tool leverages the power of foundation models to forecast sensor degradation patterns and predict + remaining useful life without requiring domain-specific training data. 
+ + Input: + - sensor_data_json_path: Path to JSON file containing sensor measurements + - engine_unit: Specific engine unit to analyze (optional, analyzes all units if not specified) + + Required Sensor Columns: + • sensor_measurement_2, sensor_measurement_3, sensor_measurement_4 + • sensor_measurement_7, sensor_measurement_8, sensor_measurement_11 + • sensor_measurement_12, sensor_measurement_13, sensor_measurement_15 + • sensor_measurement_17, sensor_measurement_20, sensor_measurement_21 + + Output: + - RUL predictions for each engine unit based on sensor forecasting + - Detailed degradation analysis and trend metrics + - Original JSON file updated with predictions added as 'predicted_RUL' column + - Foundation model insights and confidence indicators + """ + + yield FunctionInfo.from_fn(_response_fn, + input_schema=MomentPredictRulInputSchema, + description=description) + try: + pass + except GeneratorExit: + logger.info("MOMENT RUL prediction function exited early!") + finally: + logger.info("Cleaning up MOMENT RUL prediction workflow.") diff --git a/examples/asset_lifecycle_management/src/nat_alm_agent/predictors/nv_tesseract_anomaly_detection_tool.py b/examples/asset_lifecycle_management/src/nat_alm_agent/predictors/nv_tesseract_anomaly_detection_tool.py new file mode 100644 index 0000000..49895d0 --- /dev/null +++ b/examples/asset_lifecycle_management/src/nat_alm_agent/predictors/nv_tesseract_anomaly_detection_tool.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NV Tesseract-based anomaly detection tool using NVIDIA NIM."""

import json
from typing import Optional

from nat.builder.builder import Builder
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
from nat.data_models.function import FunctionBaseConfig
from pydantic import BaseModel, Field


class NVTesseractAnomalyDetectionToolConfig(FunctionBaseConfig, name="nv_tesseract_anomaly_detection"):
    """Configuration for NV Tesseract anomaly detection tool."""

    llm_name: str = Field(description="Name of the LLM to use for NV Tesseract NIM")
    model_name: str = Field(
        default="nvidia/nv-anomaly-tesseract-1.0",
        description="NIM model name for anomaly detection"
    )
    lookback_period: int = Field(
        default=30,
        description="Number of time steps to look back for anomaly detection"
    )
    forecast_horizon: int = Field(
        default=10,
        description="Number of time steps to forecast"
    )


class AnomalyDetectionInput(BaseModel):
    """Input schema for anomaly detection."""

    unit_number: int = Field(description="Unit number to analyze")
    dataset_name: str = Field(description="Dataset name (e.g., 'train_FD001', 'test_FD002')")


@register_function(config_type=NVTesseractAnomalyDetectionToolConfig)
async def nv_tesseract_anomaly_detection_tool(
    config: NVTesseractAnomalyDetectionToolConfig, builder: Builder
):
    """
    NV Tesseract-based anomaly detection using NVIDIA NIM.

    This tool uses NVIDIA's NV Tesseract foundation model for time-series anomaly detection.
    It analyzes sensor data from turbofan engines to identify anomalous patterns.
    """

    async def _detect_anomalies(unit_number: int, dataset_name: str) -> str:
        """
        Detect anomalies in sensor data using NV Tesseract NIM.

        Args:
            unit_number: Unit number to analyze
            dataset_name: Dataset name (e.g., 'train_FD001', 'test_FD002')

        Returns:
            JSON string containing anomaly detection results (always a string,
            including the error cases).
        """
        # Get the LLM (NIM endpoint)
        llm = await builder.get_llm(config.llm_name)

        # Get SQL retriever to fetch sensor data
        sql_retriever = await builder.get_function("sql_retriever")

        # Query to fetch sensor data for the unit
        sensor_query = (
            f"Retrieve all sensor readings for unit {unit_number} from {dataset_name} dataset. "
            f"Include all sensor columns (sensor_1 through sensor_21) and the time_in_cycles column."
        )

        # Fetch data using SQL retriever
        sql_result = await sql_retriever.ainvoke({"query": sensor_query})

        # FIX: coerce to str before .lower() — a non-string retriever result
        # previously either raised AttributeError or silently skipped parsing
        # and made this function return None despite its `-> str` contract.
        result_text = sql_result if isinstance(sql_result, str) else str(sql_result)

        if not sql_result or "error" in result_text.lower():
            return json.dumps({
                "error": f"Failed to retrieve sensor data for unit {unit_number}",
                "details": result_text
            })

        # Parse the SQL result to extract sensor data
        try:
            # The retriever returns a table-like string; skip the header line
            # and pull the numeric tokens out of each data line.
            data_lines = result_text.strip().split('\n')

            sensor_values = []
            for line in data_lines[1:]:  # Skip header
                if not line.strip():
                    continue
                # FIX: parse tokens with float() directly — the previous
                # isdigit()-based filter rejected scientific notation and let
                # malformed tokens like '1.2.3' through to a crashing float().
                values = []
                for token in line.split():
                    try:
                        values.append(float(token))
                    except ValueError:
                        continue
                if values:
                    sensor_values.append(values)

            if not sensor_values:
                return json.dumps({
                    "error": "No valid sensor data found",
                    "raw_result": result_text
                })

            # Prepare data for NV Tesseract: analyze only the most recent
            # lookback window (slicing handles short series naturally).
            lookback_data = sensor_values[-config.lookback_period:]

            # Format prompt for NV Tesseract NIM
            prompt = f"""Analyze the following time-series sensor data for anomalies:

Dataset: {dataset_name}
Unit: {unit_number}
Lookback Period: {config.lookback_period} time steps
Forecast Horizon: {config.forecast_horizon} time steps

Sensor Data (most recent {len(lookback_data)} readings):
{json.dumps(lookback_data, indent=2)}

Task: Detect anomalies in the sensor readings and provide:
1. Anomaly score (0-1, where 1 is highly anomalous)
2. Identified anomalous time steps
3. Most anomalous sensors
4. Brief explanation of detected patterns
5. Forecast for next {config.forecast_horizon} time steps

Return the analysis as a JSON object.
"""

            # Call NV Tesseract NIM
            response = await llm.acomplete(prompt)

            # Extract the response text
            response_text = response.text if hasattr(response, 'text') else str(response)

            # Try to parse as JSON, or return as-is
            try:
                result = json.loads(response_text)
            except json.JSONDecodeError:
                result = {
                    "analysis": response_text,
                    "unit_number": unit_number,
                    "dataset": dataset_name,
                    "data_points_analyzed": len(lookback_data)
                }

            # Add metadata
            result["model"] = config.model_name
            result["lookback_period"] = config.lookback_period
            result["forecast_horizon"] = config.forecast_horizon

            return json.dumps(result, indent=2)

        except Exception as e:
            return json.dumps({
                "error": f"Error processing sensor data: {str(e)}",
                "raw_result": result_text
            })

    yield FunctionInfo.from_fn(
        _detect_anomalies,
        input_schema=AnomalyDetectionInput,
        description=(
            "Detect anomalies in turbofan engine sensor data using NV Tesseract foundation model. "
            "Analyzes time-series sensor readings to identify unusual patterns and forecast future values. "
            "Provides anomaly scores, identifies problematic sensors, and explains detected patterns."
        )
    )
def verify_json_path(file_path: str, output_folder: str = None) -> str:
    """
    Verify that the input is a valid path to a JSON file.

    Resolution order: absolute paths are used as-is, falling back to
    ``output_folder/basename`` when the absolute path is missing; relative
    paths are tried under ``output_folder`` first, then relative to the
    current working directory.

    Args:
        file_path (str): Path to verify
        output_folder (str): Output folder to use as base for relative paths

    Returns:
        str: The resolved file path on success, or a human-readable error
        message when the input is not a string, not a .json path, does not
        exist, or does not contain valid JSON. NOTE: this function reports
        every failure via its return value — it never raises (the previous
        docstring's "Raises" section was inaccurate).
    """
    if not isinstance(file_path, str):
        return "Input must be a string path to a JSON file"

    if not file_path.lower().endswith('.json'):
        return "Input must be a path to a JSON file (ending with .json)"

    # Resolve path relative to the output folder if one is provided.
    if output_folder:
        if os.path.isabs(file_path):
            if os.path.exists(file_path):
                # Absolute path exists - use it directly.
                resolved_path = file_path
            else:
                # Absolute path missing - retry with its basename under the output folder.
                resolved_path = os.path.join(output_folder, os.path.basename(file_path))
        else:
            # Relative path - prefer the output folder, then the current directory.
            relative_to_output = os.path.join(output_folder, file_path)
            if os.path.exists(relative_to_output):
                resolved_path = relative_to_output
            else:
                resolved_path = file_path
    else:
        resolved_path = file_path

    if not os.path.exists(resolved_path):
        return f"JSON file not found at path: {file_path}"

    try:
        with open(resolved_path, 'r') as f:
            json.load(f)  # Verify file contains valid JSON
    except json.JSONDecodeError:
        return f"File at {resolved_path} does not contain valid JSON data"

    return resolved_path
+ """ + # Runtime configuration parameters + scaler_path: str = Field(description="Path to the trained StandardScaler model.", default="./models/scaler_model.pkl") + model_path: str = Field(description="Path to the trained XGBoost model.", default="./models/xgb_model_fd001.pkl") + output_folder: str = Field(description="The path to the output folder to save prediction results.", default="./output_data") + +@register_function(config_type=PredictRulToolConfig) +async def predict_rul_tool( + config: PredictRulToolConfig, builder: Builder +): + class PredictRulInputSchema(BaseModel): + json_file_path: str = Field(description="Path to a JSON file containing sensor measurements data for RUL prediction") + + def load_data_from_json(json_path: str, output_folder: str = None): + """Load data from JSON file into a pandas DataFrame.""" + import pandas as pd + try: + # Resolve path relative to output folder if provided + if output_folder: + # Import here to avoid circular imports + import os.path + if os.path.isabs(json_path): + # If absolute path exists, use it + if os.path.exists(json_path): + resolved_path = json_path + else: + # If absolute path doesn't exist, try relative to output folder + basename = os.path.basename(json_path) + resolved_path = os.path.join(output_folder, basename) + else: + # If relative path, first try relative to output folder + relative_to_output = os.path.join(output_folder, json_path) + if os.path.exists(relative_to_output): + resolved_path = relative_to_output + else: + # Then try as provided (relative to current working directory) + resolved_path = json_path + else: + resolved_path = json_path + + with open(resolved_path, 'r') as f: + data = json.load(f) + return pd.DataFrame(data) + except FileNotFoundError: + logger.warn(f"JSON file not found at {json_path}") + return None + except json.JSONDecodeError: + logger.warn(f"Could not decode JSON from {json_path}") + return None + except Exception as e: + logger.warn(f"Error loading data from 
'{json_path}': {e}") + return None + + def predict_rul_from_data(data_json_path: str, scaler_path: str, model_path: str, output_dir: str): + """ + Load data and trained models to make RUL predictions. + + Args: + data_json_path (str): Path to the input JSON data file. + scaler_path (str): Path to the trained StandardScaler model. + model_path (str): Path to the trained XGBoost model. + output_dir (str): Directory to save prediction results (unused - kept for compatibility). + + Returns: + tuple: (predictions array, original file path) + """ + import pandas as pd + + # Suppress warnings + warnings.filterwarnings("ignore", message="X does not have valid feature names") + + # Load the data + df = load_data_from_json(data_json_path, output_dir) + if df is None or df.empty: + raise ValueError(f"Could not load data or data is empty from {data_json_path}") + + # Prepare features for prediction (exclude non-feature columns if present) + required_columns = ['sensor_measurement_2', + 'sensor_measurement_3', + 'sensor_measurement_4', + 'sensor_measurement_7', + 'sensor_measurement_8', + 'sensor_measurement_11', + 'sensor_measurement_12', + 'sensor_measurement_13', + 'sensor_measurement_15', + 'sensor_measurement_17', + 'sensor_measurement_20', + 'sensor_measurement_21'] + feature_columns = [col for col in df.columns if col in required_columns] + if not feature_columns: + raise ValueError(f"No valid feature columns found in the data. 
Available columns: {df.columns.tolist()}") + + X_test = df[feature_columns].values + logger.info(f"Using {len(feature_columns)} features for prediction: {feature_columns}") + + # Load the StandardScaler + try: + scaler_loaded = joblib.load(scaler_path) + logger.info(f"Successfully loaded scaler from {scaler_path}") + except Exception as e: + raise FileNotFoundError(f"Could not load scaler from {scaler_path}: {e}") + + # Transform the test data using the loaded scaler + X_test_scaled = scaler_loaded.transform(X_test) + + # Load the XGBoost model + try: + with open(model_path, 'rb') as f: + xgb_model = pickle.load(f) + logger.info(f"Successfully loaded XGBoost model from {model_path}") + except Exception as e: + raise FileNotFoundError(f"Could not load XGBoost model from {model_path}: {e}") + + # Make predictions + y_pred = xgb_model.predict(X_test_scaled) + logger.info(f"Generated {len(y_pred)} RUL predictions") + + # Create results DataFrame + results_df = df.copy() + results_df['predicted_RUL'] = y_pred + + # Save results back to the original JSON file (resolve path for saving) + if output_dir: + # For saving, we want to save relative to output_dir if the original path was relative + import os.path + if not os.path.isabs(data_json_path): + save_path = os.path.join(output_dir, os.path.basename(data_json_path)) + else: + save_path = data_json_path + else: + save_path = data_json_path + + results_json = results_df.to_dict('records') + with open(save_path, 'w') as f: + json.dump(results_json, f, indent=2) + + logger.info(f"Prediction results saved back to file: {save_path}") + + return y_pred, save_path + + async def _response_fn(json_file_path: str) -> str: + """ + Process the input message and generate RUL predictions using trained XGBoost models. 
+ """ + logger.info(f"Input message: {json_file_path}") + data_json_path = verify_json_path(json_file_path, config.output_folder) + try: + predictions, output_filepath = predict_rul_from_data( + data_json_path=data_json_path, + scaler_path=config.scaler_path, + model_path=config.model_path, + output_dir=config.output_folder + ) + + # Generate summary statistics + avg_rul = np.mean(predictions) + min_rul = np.min(predictions) + max_rul = np.max(predictions) + std_rul = np.std(predictions) + + # Create response with prediction summary (relative path from output folder) + output_relpath = os.path.relpath(output_filepath, config.output_folder) + response = f"""RUL predictions generated successfully! 📊 + +**Model Used:** XGBoost (Traditional Machine Learning) + +**Prediction Summary:** +- **Total predictions:** {len(predictions)} +- **Average RUL:** {avg_rul:.2f} cycles +- **Minimum RUL:** {min_rul:.2f} cycles +- **Maximum RUL:** {max_rul:.2f} cycles +- **Standard Deviation:** {std_rul:.2f} cycles + +**Results saved to:** {output_relpath} + +The predictions have been added to the original dataset with column name 'predicted_RUL'. The original JSON file has been updated with the RUL predictions. +All columns from the original dataset have been preserved, and a new 'predicted_RUL' column has been added.""" + + return response + + except FileNotFoundError as e: + error_msg = f"Required file not found for RUL prediction: {e}. Please ensure all model files and data are available." + logger.warn(error_msg) + return error_msg + except ValueError as ve: + error_msg = f"Data validation error for RUL prediction: {ve}. Check the input data format." + logger.warn(error_msg) + return error_msg + except Exception as e: + error_msg = f"Error during RUL prediction: {e}" + logger.warn(error_msg) + return error_msg + + prompt = """ + Predict RUL (Remaining Useful Life) for turbofan engines using trained XGBoost machine learning models. 
+ + Input: + - Path to a JSON file containing sensor measurements + + Required columns: + * sensor_measurement_2 + * sensor_measurement_3 + * sensor_measurement_4 + * sensor_measurement_7 + * sensor_measurement_8 + * sensor_measurement_11 + * sensor_measurement_12 + * sensor_measurement_13 + * sensor_measurement_15 + * sensor_measurement_17 + * sensor_measurement_20 + * sensor_measurement_21 + + Process: + 1. Load and preprocess data using StandardScaler + 2. Generate predictions using trained XGBoost model + 3. Calculate summary statistics (mean, min, max, std dev) + 4. Save predictions to JSON file + + Output: + - RUL predictions for each unit + - Summary statistics of predictions + - Updated JSON file with predictions added as 'predicted_RUL' column + """ + yield FunctionInfo.from_fn(_response_fn, + input_schema=PredictRulInputSchema, + description=prompt) + try: + pass + except GeneratorExit: + logger.info("Predict RUL function exited early!") + finally: + logger.info("Cleaning up predict_rul_tool workflow.") diff --git a/examples/asset_lifecycle_management/src/nat_alm_agent/register.py b/examples/asset_lifecycle_management/src/nat_alm_agent/register.py new file mode 100644 index 0000000..324c92f --- /dev/null +++ b/examples/asset_lifecycle_management/src/nat_alm_agent/register.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=unused-import +# flake8: noqa + +# Import any tools which need to be automatically registered here +# NOTE: Both SQL implementations are available for comparison testing: +# 1. sql_retriever_old: Original implementation (generate_sql_query_and_retrieve_tool) +# 2. sql_retriever_vanna: New package-based implementation (vanna_sql_tool from nat_vanna_tool) +# The vanna_sql_tool is auto-registered via entry points in nat_vanna_tool package +from .retrievers import generate_sql_query_and_retrieve_tool +from .predictors import predict_rul_tool +from .predictors import moment_predict_rul_tool +from .plotting import plot_distribution_tool +from .plotting import plot_comparison_tool +from .plotting import plot_line_chart_tool +from .plotting import plot_anomaly_tool +from .plotting import code_generation_assistant +from .predictors import moment_anomaly_detection_tool +from .predictors import nv_tesseract_anomaly_detection_tool +from .evaluators import llm_judge_evaluator_register +from .evaluators import multimodal_llm_judge_evaluator_register +# NOTE: E2B code execution tool for cloud-based sandbox (alternative to local Docker) +from .code_execution import e2b_code_execution_tool diff --git a/examples/asset_lifecycle_management/src/nat_alm_agent/retrievers/__init__.py b/examples/asset_lifecycle_management/src/nat_alm_agent/retrievers/__init__.py new file mode 100644 index 0000000..dbf4b8d --- /dev/null +++ b/examples/asset_lifecycle_management/src/nat_alm_agent/retrievers/__init__.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Retrievers package for Asset Lifecycle Management agent. + +This package contains components for data retrieval and SQL query generation +for Asset Lifecycle Management workflows (currently focused on predictive maintenance). +""" + +from .vanna_manager import VannaManager +from .vanna_util import * +from . import generate_sql_query_and_retrieve_tool + +__all__ = [ + "VannaManager", + "generate_sql_query_and_retrieve_tool", +] \ No newline at end of file diff --git a/examples/asset_lifecycle_management/src/nat_alm_agent/retrievers/generate_sql_query_and_retrieve_tool.py b/examples/asset_lifecycle_management/src/nat_alm_agent/retrievers/generate_sql_query_and_retrieve_tool.py new file mode 100644 index 0000000..dd34e0b --- /dev/null +++ b/examples/asset_lifecycle_management/src/nat_alm_agent/retrievers/generate_sql_query_and_retrieve_tool.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import json
import logging
import os
from typing import Optional

from pydantic import Field, BaseModel

from nat.builder.builder import Builder
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
from nat.data_models.function import FunctionBaseConfig
from nat.builder.framework_enum import LLMFrameworkEnum

logger = logging.getLogger(__name__)

class GenerateSqlQueryAndRetrieveToolConfig(FunctionBaseConfig, name="generate_sql_query_and_retrieve_tool"):
    """
    NeMo Agent Toolkit function to generate SQL queries and retrieve data.

    Supports multiple database types through flexible connection configuration.
    """
    # Runtime configuration parameters
    llm_name: str = Field(description="The name of the LLM to use for the function.")
    embedding_name: str = Field(description="The name of the embedding to use for the function.")

    # Vector store configuration
    # Exactly one backend is used: ChromaDB (local path) or Elasticsearch (URL + auth).
    vector_store_type: str = Field(
        default="chromadb",
        description="Type of vector store: 'chromadb' or 'elasticsearch'"
    )
    vector_store_path: Optional[str] = Field(
        default=None,
        description="Path to ChromaDB vector store (required if vector_store_type='chromadb')"
    )
    elasticsearch_url: Optional[str] = Field(
        default=None,
        description="Elasticsearch URL (required if vector_store_type='elasticsearch', e.g., 'http://localhost:9200')"
    )
    elasticsearch_index_name: str = Field(
        default="vanna_vectors",
        description="Elasticsearch index name (used if vector_store_type='elasticsearch')"
    )
    # Either basic auth (username/password) or an API key may be supplied; all optional.
    elasticsearch_username: Optional[str] = Field(
        default=None,
        description="Elasticsearch username for basic auth (optional)"
    )
    elasticsearch_password: Optional[str] = Field(
        default=None,
        description="Elasticsearch password for basic auth (optional)"
    )
    elasticsearch_api_key: Optional[str] = Field(
        default=None,
        description="Elasticsearch API key for authentication (optional)"
    )

    # Database configuration
    db_connection_string_or_path: str = Field(
        description=(
            "Database connection (path for SQLite, connection string for others). Format depends on db_type:\n"
            "- sqlite: Path to .db file (e.g., './database.db')\n"
            "- postgres: Connection string (e.g., 'postgresql://user:pass@host:port/db')\n"
            "- sql: SQLAlchemy connection string (e.g., 'mysql+pymysql://user:pass@host/db')"
        )
    )
    db_type: str = Field(
        default="sqlite",
        description="Type of database: 'sqlite', 'postgres', or 'sql' (generic SQL via SQLAlchemy)"
    )

    # Output configuration
    output_folder: str = Field(description="The path to the output folder to use for the function.")
    vanna_training_data_path: str = Field(description="The path to the YAML file containing Vanna training data.")

@register_function(config_type=GenerateSqlQueryAndRetrieveToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def generate_sql_query_and_retrieve_tool(
    config: GenerateSqlQueryAndRetrieveToolConfig, builder: Builder
):
    """
    Generate a SQL query for a given question and retrieve the data from the database.
    """
    class GenerateSqlQueryInputSchema(BaseModel):
        input_question_in_english: str = Field(description="User's question in plain English to generate SQL query for")

    # Create Vanna instance
    vanna_llm_config = builder.get_llm_config(config.llm_name)
    vanna_embedder_config = builder.get_embedder_config(config.embedding_name)

    # Imported here because the LangChain wrapper is only needed by this function.
    from langchain_core.prompts.chat import ChatPromptTemplate

    llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    # System prompt used to summarize SQL results for the end user.
    system_prompt = """
    You are an intelligent SQL query assistant that analyzes database query results and provides appropriate responses.

    Your responsibilities:
    1. Analyze the SQL query results and determine the best response format.
    2. For data extraction queries (multiple rows/complex data): recommend saving to JSON file and provide summary.
    3. 
For simple queries (single values, counts, yes/no, simple lookups): provide DIRECT answers without file storage. + 4. Always be helpful and provide context about the results. + + Guidelines: + + - If results contain multiple rows or complex data AND the query is for data analysis/processing: recommend saving to file + - If results are simple (single value, count, or small lookup): provide only the direct answer even if a file was created for the results. + - Always mention the SQL query that was executed. + - Important: Do not use template variables or placeholders in your response. Provide actual values and descriptions. + + Be conversational and helpful. Explain what was found. + """ + # CRITICAL INSTRUCTION: If the question asks for unit numbers or IDs (e.g., "what are their unit numbers"): + # - Provide the COMPLETE list of ALL unit numbers from the data + # - Never say "not shown in sample" or "additional values" + # - Extract all unit_number values from the complete dataset, not just the sample + # - If you see unit numbers 40, 82, 174, 184 in the data, list ALL of them explicitly + # """ + + user_prompt = """ + Original Question: {original_question} + + SQL Query Executed: {sql_query} + + Query Results: + - Number of rows: {num_rows} + - Number of columns: {num_columns} + - Columns: {columns} + - Sample data (first few rows): {sample_data} + + Output directory: {output_dir} + """ + + prompt = ChatPromptTemplate.from_messages([("system", system_prompt), ("user", user_prompt)]) + output_message = prompt | llm + + from .vanna_manager import VannaManager + + # Create a VannaManager instance with full configuration + # This will trigger immediate Vanna instance creation and training during initialization + vanna_manager = VannaManager.create_with_config( + vanna_llm_config=vanna_llm_config, + vanna_embedder_config=vanna_embedder_config, + vector_store_type=config.vector_store_type, + vector_store_path=config.vector_store_path, + 
elasticsearch_url=config.elasticsearch_url, + elasticsearch_index_name=config.elasticsearch_index_name, + elasticsearch_username=config.elasticsearch_username, + elasticsearch_password=config.elasticsearch_password, + elasticsearch_api_key=config.elasticsearch_api_key, + db_connection_string_or_path=config.db_connection_string_or_path, + db_type=config.db_type, + training_data_path=config.vanna_training_data_path + ) + + def get_vanna_instance(): + """ + Get the pre-initialized Vanna instance from VannaManager. + Training has already been completed during VannaManager initialization. + """ + return vanna_manager.get_instance() + + async def _response_fn(input_question_in_english: str) -> str: + # Process the input_question_in_english and generate output using VannaManager + logger.info(f"RESPONSE: Starting question processing for: {input_question_in_english}") + + sql = None + try: + # CRITICAL: Ensure VannaManager instance is created before using it + # This creates the instance if it doesn't exist (lazy initialization) + vn_instance = get_vanna_instance() + + # Use VannaManager for safe SQL generation + sql = vanna_manager.generate_sql_safe(question=input_question_in_english) + logger.info(f"Generated SQL: {sql}") + + except Exception as e: + logger.error(f"RESPONSE: Exception during generate_sql_safe: {e}") + return f"Error generating SQL: {e}" + + # vn_instance is already available from above + + if not vn_instance.run_sql_is_set: + return f"Database is not connected via Vanna: {sql}" + + try: + df = vn_instance.run_sql(sql) + if df is None: + return f"Vanna run_sql returned None: {sql}" + if df.empty: + return f"No data found for the generated SQL: {sql}" + + num_rows = df.shape[0] + num_columns = df.shape[1] + columns = df.columns.tolist() + + # Get sample data (first 3 rows for preview) + sample_data = df.head(3).to_dict('records') + + # Use LLM to generate intelligent response + response = await output_message.ainvoke({ + "original_question": 
input_question_in_english, + "sql_query": sql, + "num_rows": num_rows, + "num_columns": num_columns, + "columns": ", ".join(columns), + "sample_data": json.dumps(sample_data, indent=2), + "output_dir": config.output_folder + }) + + # Check if LLM response suggests saving data (look for keywords or patterns) + llm_response = response.content if hasattr(response, 'content') else str(response) + # Clean up the LLM response and add file save confirmation + # Remove any object references that might have slipped through + import re + llm_response = re.sub(r',\[object Object\],?', '', llm_response) + + # if "save" in llm_response.lower(): + # Clean the question for filename + clean_question = re.sub(r'[^\w\s-]', '', input_question_in_english.lower()) + clean_question = re.sub(r'\s+', '_', clean_question.strip())[:30] + suggested_filename = f"{clean_question}_results.json" + + sql_output_path = os.path.join(config.output_folder, suggested_filename) + + # Save the data to JSON file + os.makedirs(config.output_folder, exist_ok=True) + json_result = df.to_json(orient="records") + with open(sql_output_path, 'w') as f: + json.dump(json.loads(json_result), f, indent=4) + + logger.info(f"Data saved to {sql_output_path}") + + llm_response += f"\n\nData has been saved to file: {suggested_filename}" + + return llm_response + + # return llm_response + + except Exception as e: + return f"Error running SQL query '{sql}': {e}" + + description = """ + Use this tool to automatically generate SQL queries for the user's question, retrieve the data from the SQL database and provide a summary of the data or save the data in a JSON file. + Do not provide SQL query as input, only a question in plain english. + + Input: + - input_question_in_english: User's question or a question that you think is relevant to the user's question in plain english + + Output: Status of the generated SQL query's execution along with the output path. 
+ """ + yield FunctionInfo.from_fn(_response_fn, + input_schema=GenerateSqlQueryInputSchema, + description=description) + try: + pass + except GeneratorExit: + logger.info("Generate SQL query and retrieve function exited early!") + finally: + logger.info("Cleaning up generate_sql_query_and_retrieve_tool workflow.") diff --git a/examples/asset_lifecycle_management/src/nat_alm_agent/retrievers/vanna_manager.py b/examples/asset_lifecycle_management/src/nat_alm_agent/retrievers/vanna_manager.py new file mode 100644 index 0000000..bce296b --- /dev/null +++ b/examples/asset_lifecycle_management/src/nat_alm_agent/retrievers/vanna_manager.py @@ -0,0 +1,552 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VannaManager - A simplified manager for Vanna instances +""" +import os +import logging +import threading +import hashlib +from typing import Dict, Optional +from .vanna_util import NIMVanna, ElasticNIMVanna, initVanna, NVIDIAEmbeddingFunction + +logger = logging.getLogger(__name__) + +class VannaManager: + """ + A simplified singleton manager for Vanna instances. 
class VannaManager:
    """
    A simplified singleton manager for Vanna instances.

    Key features:
    - Singleton pattern to ensure only one instance per configuration
    - Thread-safe operations
    - Simple instance management
    - Support for multiple database types: SQLite, generic SQL, and PostgreSQL
    """

    # One manager per unique configuration key (see create_config_key).
    _instances: Dict[str, 'VannaManager'] = {}
    _lock = threading.Lock()

    def __new__(cls, config_key: str, *args, **kwargs):
        """Ensure singleton pattern per configuration.

        Extra positional/keyword arguments are accepted (and ignored here) so
        that constructing the class with the full ``__init__`` signature does
        not make ``__new__`` raise a TypeError.
        """
        with cls._lock:
            if config_key not in cls._instances:
                logger.debug("VannaManager: Creating new singleton instance for config: %s", config_key)
                instance = super().__new__(cls)
                instance._initialized = False
                cls._instances[config_key] = instance
            else:
                logger.debug("VannaManager: Returning existing singleton instance for config: %s", config_key)
            return cls._instances[config_key]

    def __init__(self, config_key: str, vanna_llm_config=None, vanna_embedder_config=None,
                 vector_store_type: str = "chromadb", vector_store_path: str = None,
                 elasticsearch_url: str = None, elasticsearch_index_name: str = "vanna_vectors",
                 elasticsearch_username: str = None, elasticsearch_password: str = None,
                 elasticsearch_api_key: str = None,
                 db_connection_string_or_path: str = None, db_type: str = "sqlite",
                 training_data_path: str = None, nvidia_api_key: str = None):
        """Initialize the VannaManager; create the Vanna instance immediately if fully configured.

        Args:
            config_key: Unique key for this configuration
            vanna_llm_config: LLM configuration object
            vanna_embedder_config: Embedder configuration object
            vector_store_type: Type of vector store - 'chromadb' or 'elasticsearch'
            vector_store_path: Path to ChromaDB vector store (required if vector_store_type='chromadb')
            elasticsearch_url: Elasticsearch URL (required if vector_store_type='elasticsearch')
            elasticsearch_index_name: Elasticsearch index name
            elasticsearch_username: Elasticsearch username for basic auth
            elasticsearch_password: Elasticsearch password for basic auth
            elasticsearch_api_key: Elasticsearch API key
            db_connection_string_or_path: Database connection (path for SQLite, connection string for others)
            db_type: Type of database - 'sqlite', 'postgres', or 'sql' (generic SQL with SQLAlchemy)
            training_data_path: Path to YAML training data file
            nvidia_api_key: NVIDIA API key (optional, can use NVIDIA_API_KEY env var)
        """
        # __init__ runs on every construction of the singleton; skip
        # re-initialization for an already-configured instance.
        if getattr(self, '_initialized', False):
            return

        self.config_key = config_key
        self.lock = threading.Lock()

        # Store configuration
        self.vanna_llm_config = vanna_llm_config
        self.vanna_embedder_config = vanna_embedder_config
        self.vector_store_type = vector_store_type
        self.vector_store_path = vector_store_path
        self.elasticsearch_url = elasticsearch_url
        self.elasticsearch_index_name = elasticsearch_index_name
        self.elasticsearch_username = elasticsearch_username
        self.elasticsearch_password = elasticsearch_password
        self.elasticsearch_api_key = elasticsearch_api_key
        self.db_connection_string_or_path = db_connection_string_or_path
        self.db_type = db_type
        self.training_data_path = training_data_path
        self.nvidia_api_key = nvidia_api_key or os.getenv("NVIDIA_API_KEY")

        # Create and initialize the Vanna instance immediately when the full
        # configuration is available; otherwise defer to get_instance().
        self.vanna_instance = None
        if self._has_full_config():
            logger.debug("VannaManager: Initializing with immediate Vanna instance creation")
            self.vanna_instance = self._create_instance()
        elif any([vanna_llm_config, vanna_embedder_config, vector_store_path,
                  elasticsearch_url, self.db_connection_string_or_path]):
            logger.debug("VannaManager: Partial configuration provided, Vanna instance will be created later")
        else:
            logger.debug("VannaManager: No configuration provided, Vanna instance will be created later")

        self._initialized = True
        logger.debug("VannaManager initialized for config: %s", config_key)

    def _has_vector_config(self) -> bool:
        """Return True when the selected vector store has its required connection info."""
        return bool((self.vector_store_type == "chromadb" and self.vector_store_path)
                    or (self.vector_store_type == "elasticsearch" and self.elasticsearch_url))

    def _has_full_config(self) -> bool:
        """Return True when everything needed to build a Vanna instance is present."""
        return all([self.vanna_llm_config, self.vanna_embedder_config,
                    self._has_vector_config(), self.db_connection_string_or_path])

    def get_instance(self, vanna_llm_config=None, vanna_embedder_config=None,
                     vector_store_type: str = None, vector_store_path: str = None,
                     elasticsearch_url: str = None,
                     db_connection_string_or_path: str = None, db_type: str = None,
                     training_data_path: str = None, nvidia_api_key: str = None):
        """
        Get the Vanna instance. If not created during init, create it now with provided parameters.

        Raises:
            RuntimeError: if the configuration is still incomplete.
        """
        with self.lock:
            if self.vanna_instance is None:
                logger.debug("VannaManager: No instance created during init, creating now...")

                # Merge any newly provided parameters into the stored configuration.
                self.vanna_llm_config = vanna_llm_config or self.vanna_llm_config
                self.vanna_embedder_config = vanna_embedder_config or self.vanna_embedder_config
                self.vector_store_type = vector_store_type or self.vector_store_type
                self.vector_store_path = vector_store_path or self.vector_store_path
                self.elasticsearch_url = elasticsearch_url or self.elasticsearch_url
                self.db_connection_string_or_path = db_connection_string_or_path or self.db_connection_string_or_path
                self.db_type = db_type or self.db_type
                self.training_data_path = training_data_path or self.training_data_path
                self.nvidia_api_key = nvidia_api_key or self.nvidia_api_key

                if self._has_full_config():
                    self.vanna_instance = self._create_instance()
                else:
                    raise RuntimeError("VannaManager: Missing required configuration parameters")
            else:
                logger.debug("VannaManager: Returning pre-initialized Vanna instance (ID: %s)", id(self.vanna_instance))
                logger.debug("VannaManager: Vector store type: %s", self.vector_store_type)
                self._log_vector_store_status()

            return self.vanna_instance

    def _log_vector_store_status(self):
        """Best-effort debug logging of the vector store's population state."""
        try:
            if self.vector_store_type == "chromadb" and self.vector_store_path:
                if os.path.exists(self.vector_store_path):
                    folders = [d for d in os.listdir(self.vector_store_path)
                               if os.path.isdir(os.path.join(self.vector_store_path, d))]
                    logger.debug("VannaManager: ChromaDB contains %d collections/folders", len(folders))
                    if folders:
                        logger.debug("VannaManager: ChromaDB folders: %s", folders)
                else:
                    logger.debug("VannaManager: ChromaDB directory does not exist")
            elif self.vector_store_type == "elasticsearch":
                logger.debug("VannaManager: Using Elasticsearch at %s", self.elasticsearch_url)
        except Exception as e:
            logger.warning("VannaManager: Could not check vector store status: %s", e)

    def _create_instance(self):
        """
        Create a new Vanna instance using the stored configuration.
        Returns NIMVanna (ChromaDB) or ElasticNIMVanna (Elasticsearch) based on vector_store_type.
        """
        logger.info("VannaManager: Creating instance for %s", self.config_key)
        logger.debug("VannaManager: Vector store type: %s", self.vector_store_type)
        logger.debug("VannaManager: Database connection: %s", self.db_connection_string_or_path)
        logger.debug("VannaManager: Database type: %s", self.db_type)
        logger.debug("VannaManager: Training data path: %s", self.training_data_path)

        # Embedding function is shared by both vector store backends.
        embedding_function = NVIDIAEmbeddingFunction(
            api_key=self.nvidia_api_key,
            model=self.vanna_embedder_config.model_name
        )

        # LLM configuration (common for both backends).
        llm_config = {
            "api_key": self.nvidia_api_key,
            "model": self.vanna_llm_config.model_name
        }

        # Create instance based on vector store type.
        if self.vector_store_type == "chromadb":
            logger.debug("VannaManager: Creating NIMVanna with ChromaDB at %s", self.vector_store_path)
            vn_instance = NIMVanna(
                VectorConfig={
                    "client": "persistent",
                    "path": self.vector_store_path,
                    "embedding_function": embedding_function
                },
                LLMConfig=llm_config
            )
        elif self.vector_store_type == "elasticsearch":
            logger.debug("VannaManager: Creating ElasticNIMVanna with Elasticsearch at %s (index: %s)",
                         self.elasticsearch_url, self.elasticsearch_index_name)
            es_config = {
                "url": self.elasticsearch_url,
                "index_name": self.elasticsearch_index_name,
                "embedding_function": embedding_function
            }
            # API key authentication takes precedence over basic auth.
            if self.elasticsearch_api_key:
                es_config["api_key"] = self.elasticsearch_api_key
                logger.debug("VannaManager: Using Elasticsearch API key authentication")
            elif self.elasticsearch_username and self.elasticsearch_password:
                es_config["username"] = self.elasticsearch_username
                es_config["password"] = self.elasticsearch_password
                logger.debug("VannaManager: Using Elasticsearch basic authentication")

            vn_instance = ElasticNIMVanna(VectorConfig=es_config, LLMConfig=llm_config)
        else:
            raise ValueError(
                f"Unsupported vector store type: {self.vector_store_type}. "
                "Supported types: 'chromadb', 'elasticsearch'"
            )

        # Connect to the database based on type.
        logger.debug("VannaManager: Connecting to %s database...", self.db_type)
        if self.db_type == "sqlite":
            self._connect_to_sqlite(vn_instance, self.db_connection_string_or_path)
        elif self.db_type in ("postgres", "postgresql"):
            self._connect_to_postgres(vn_instance, self.db_connection_string_or_path)
        elif self.db_type == "sql":
            self._connect_to_sql(vn_instance, self.db_connection_string_or_path)
        else:
            raise ValueError(
                f"Unsupported database type: {self.db_type}. "
                "Supported types: 'sqlite', 'postgres', 'sql'"
            )

        # Allow the LLM to see data for database introspection.
        vn_instance.allow_llm_to_see_data = True
        logger.debug("VannaManager: Set allow_llm_to_see_data = True")

        # Train only when the vector store is empty. The fresh instance is
        # passed explicitly because self.vanna_instance is not assigned yet at
        # this point (previously the Elasticsearch check always saw None and
        # therefore always re-trained).
        if self._needs_initialization(vn_instance):
            logger.info("VannaManager: Vector store needs initialization, starting training...")
            try:
                initVanna(vn_instance, self.training_data_path)
                logger.info("VannaManager: Vector store initialization complete")
            except Exception as e:
                logger.error("VannaManager: Error during initialization: %s", e)
                raise
        else:
            logger.debug("VannaManager: Vector store already initialized, skipping training")

        logger.info("VannaManager: Instance created successfully")
        return vn_instance

    def _connect_to_sqlite(self, vn_instance, path_or_url: str):
        """
        Connect a Vanna instance to a SQLite database.

        Vanna 0.7.9's connect_to_sqlite has broken URL detection: it tries to
        download every argument with requests.get(). Local files are therefore
        wired up with a direct sqlite3 connection instead.
        """
        db_path = path_or_url
        if not os.path.isabs(db_path):
            db_path = os.path.abspath(db_path)

        if os.path.exists(db_path):
            import sqlite3
            import pandas as pd

            def run_sql_sqlite(sql: str):
                """Execute SQL on the local SQLite database, returning a DataFrame."""
                conn = sqlite3.connect(db_path)
                try:
                    return pd.read_sql_query(sql, conn)
                finally:
                    conn.close()

            vn_instance.run_sql = run_sql_sqlite
            vn_instance.run_sql_is_set = True
            logger.debug("VannaManager: Connected to local SQLite database: %s", db_path)
        else:
            # File doesn't exist locally; let Vanna try (it may be a URL).
            logger.warning("VannaManager: Database file not found: %s", db_path)
            vn_instance.connect_to_sqlite(path_or_url)

    def _connect_to_postgres(self, vn_instance: "NIMVanna", connection_string: str):
        """
        Connect to a PostgreSQL database.

        Args:
            vn_instance: The Vanna instance to connect
            connection_string: PostgreSQL connection string in format:
                postgresql://user:password@host:port/database

        Raises:
            ImportError: if the psycopg2 driver is not installed.
        """
        try:
            import psycopg2  # noqa: F401 -- verify driver availability up front

            logger.info("Connecting to PostgreSQL database...")

            if connection_string.startswith("postgresql://"):
                # Already a SQLAlchemy-style URL; hand it to Vanna directly.
                vn_instance.connect_to_postgres(url=connection_string)
            else:
                # Assume a bare psycopg2-style string and normalize it.
                vn_instance.connect_to_postgres(url=f"postgresql://{connection_string}")

            logger.info("Successfully connected to PostgreSQL database")
        except ImportError:
            logger.error(
                "psycopg2 is required for PostgreSQL connections. "
                "Install it with: pip install psycopg2-binary"
            )
            raise
        except Exception as e:
            logger.error("Error connecting to PostgreSQL: %s", e)
            raise

    def _connect_to_sql(self, vn_instance: "NIMVanna", connection_string: str):
        """
        Connect to a generic SQL database using SQLAlchemy.

        Args:
            vn_instance: The Vanna instance to connect
            connection_string: SQLAlchemy-compatible connection string, e.g.:
                - MySQL: mysql+pymysql://user:password@host:port/database
                - PostgreSQL: postgresql://user:password@host:port/database
                - SQL Server: mssql+pyodbc://user:password@host:port/database?driver=ODBC+Driver+17+for+SQL+Server
                - Oracle: oracle+cx_oracle://user:password@host:port/?service_name=service

        Raises:
            ImportError: if SQLAlchemy is not installed.
        """
        try:
            from sqlalchemy import create_engine

            logger.info("Connecting to SQL database via SQLAlchemy...")
            engine = create_engine(connection_string)
            vn_instance.connect_to_sqlalchemy(engine)
            logger.info("Successfully connected to SQL database")
        except ImportError:
            logger.error(
                "SQLAlchemy is required for generic SQL connections. "
                "Install it with: pip install sqlalchemy"
            )
            raise
        except Exception as e:
            logger.error("Error connecting to SQL database: %s", e)
            raise

    def _needs_initialization(self, vn_instance=None) -> bool:
        """
        Check if the vector store needs initialization by checking if it's empty.

        For ChromaDB: checks directory existence and contents.
        For Elasticsearch: checks whether the instance already has training data.

        Args:
            vn_instance: Vanna instance to inspect for the Elasticsearch check.
                Falls back to self.vanna_instance when omitted; passing it
                explicitly matters during _create_instance, when
                self.vanna_instance has not been assigned yet.
        """
        if vn_instance is None:
            vn_instance = self.vanna_instance

        logger.debug("VannaManager: Checking if vector store needs initialization...")
        logger.debug("VannaManager: Vector store type: %s", self.vector_store_type)

        try:
            if self.vector_store_type == "chromadb":
                logger.debug("VannaManager: Checking ChromaDB at: %s", self.vector_store_path)
                if not os.path.exists(self.vector_store_path):
                    logger.debug("VannaManager: ChromaDB directory does not exist -> needs initialization")
                    return True

                # ChromaDB creates subdirectories when data is stored.
                folders = [d for d in os.listdir(self.vector_store_path)
                           if os.path.isdir(os.path.join(self.vector_store_path, d))]
                logger.debug("VannaManager: Found %d folders in ChromaDB", len(folders))
                if folders:
                    logger.debug("VannaManager: ChromaDB is populated -> skipping initialization")
                    return False
                logger.debug("VannaManager: ChromaDB is empty -> needs initialization")
                return True

            if self.vector_store_type == "elasticsearch":
                logger.debug("VannaManager: Checking Elasticsearch at: %s", self.elasticsearch_url)
                try:
                    if hasattr(vn_instance, 'get_training_data'):
                        # Explicit None check: get_training_data may return a
                        # DataFrame, whose bare truthiness raises ValueError.
                        training_data = vn_instance.get_training_data()
                        if training_data is not None and len(training_data) > 0:
                            logger.debug("VannaManager: Elasticsearch has %d training data entries -> skipping initialization",
                                         len(training_data))
                            return False
                    logger.debug("VannaManager: Elasticsearch has no training data -> needs initialization")
                    return True
                except Exception as e:
                    logger.debug("VannaManager: Error checking Elasticsearch data (%s) -> needs initialization", e)
                    return True

            logger.warning("VannaManager: Unknown vector store type: %s", self.vector_store_type)
            return True

        except Exception as e:
            logger.warning("VannaManager: Could not check vector store status: %s", e)
            logger.warning("VannaManager: Defaulting to needs initialization = True")
            return True

    def generate_sql_safe(self, question: str) -> str:
        """
        Generate SQL for a question with error handling.

        Raises:
            RuntimeError: if no Vanna instance is available.
            ValueError: if the model returns an empty SQL string.
        """
        with self.lock:
            if self.vanna_instance is None:
                raise RuntimeError("VannaManager: No instance available")

            try:
                logger.debug("VannaManager: Generating SQL for question: %s", question)
                # allow_llm_to_see_data=True enables database introspection.
                sql = self.vanna_instance.generate_sql(question=question, allow_llm_to_see_data=True)
                if not sql or sql.strip() == "":
                    raise ValueError("Empty SQL response")
                return sql
            except Exception as e:
                logger.error("VannaManager: Error in SQL generation: %s", e)
                raise

    def force_reset(self):
        """Force reset the instance (useful for cleanup)."""
        with self.lock:
            if self.vanna_instance:
                logger.debug("VannaManager: Resetting instance for %s", self.config_key)
                self.vanna_instance = None

    def get_stats(self) -> Dict:
        """Get manager statistics."""
        return {
            "config_key": self.config_key,
            "instance_id": id(self.vanna_instance) if self.vanna_instance else None,
            "has_instance": self.vanna_instance is not None,
            "db_type": self.db_type,
        }

    @classmethod
    def create_with_config(cls, vanna_llm_config, vanna_embedder_config,
                           vector_store_type: str = "chromadb", vector_store_path: str = None,
                           elasticsearch_url: str = None, elasticsearch_index_name: str = "vanna_vectors",
                           elasticsearch_username: str = None, elasticsearch_password: str = None,
                           elasticsearch_api_key: str = None,
                           db_connection_string_or_path: str = None, db_type: str = "sqlite",
                           training_data_path: str = None, nvidia_api_key: str = None):
        """
        Create (or fetch) a VannaManager with full configuration.

        Uses create_config_key to ensure singleton behavior per configuration.

        Args:
            vanna_llm_config: LLM configuration object
            vanna_embedder_config: Embedder configuration object
            vector_store_type: Type of vector store - 'chromadb' or 'elasticsearch'
            vector_store_path: Path to ChromaDB vector store (required if vector_store_type='chromadb')
            elasticsearch_url: Elasticsearch URL (required if vector_store_type='elasticsearch')
            elasticsearch_index_name: Elasticsearch index name
            elasticsearch_username: Elasticsearch username for basic auth
            elasticsearch_password: Elasticsearch password for basic auth
            elasticsearch_api_key: Elasticsearch API key
            db_connection_string_or_path: Database connection (path for SQLite, connection string for others)
            db_type: Type of database - 'sqlite', 'postgres', or 'sql'
            training_data_path: Path to YAML training data file
            nvidia_api_key: NVIDIA API key (optional)
        """
        config_key = create_config_key(
            vanna_llm_config, vanna_embedder_config,
            vector_store_type, vector_store_path, elasticsearch_url,
            db_connection_string_or_path, db_type
        )

        # Fetch/construct the singleton for this configuration key.
        instance = cls(config_key)

        # If the singleton was created without configuration, populate it now.
        if getattr(instance, 'vanna_llm_config', None) is None:
            instance.vanna_llm_config = vanna_llm_config
            instance.vanna_embedder_config = vanna_embedder_config
            instance.vector_store_type = vector_store_type
            instance.vector_store_path = vector_store_path
            instance.elasticsearch_url = elasticsearch_url
            instance.elasticsearch_index_name = elasticsearch_index_name
            instance.elasticsearch_username = elasticsearch_username
            instance.elasticsearch_password = elasticsearch_password
            instance.elasticsearch_api_key = elasticsearch_api_key
            instance.db_connection_string_or_path = db_connection_string_or_path
            instance.db_type = db_type
            instance.training_data_path = training_data_path
            instance.nvidia_api_key = nvidia_api_key or os.getenv("NVIDIA_API_KEY")

        # Create the Vanna instance eagerly if it does not exist yet.
        if instance.vanna_instance is None:
            logger.debug("VannaManager: Creating Vanna instance for existing singleton")
            instance.vanna_instance = instance._create_instance()

        return instance


def create_config_key(vanna_llm_config, vanna_embedder_config,
                      vector_store_type: str, vector_store_path: str, elasticsearch_url: str,
                      db_connection_string_or_path: str, db_type: str = "sqlite") -> str:
    """
    Create a unique configuration key for the VannaManager singleton.

    Args:
        vanna_llm_config: LLM configuration object
        vanna_embedder_config: Embedder configuration object
        vector_store_type: Type of vector store
        vector_store_path: Path to ChromaDB vector store
        elasticsearch_url: Elasticsearch URL
        db_connection_string_or_path: Database connection (path for SQLite, connection string for others)
        db_type: Type of database

    Returns:
        str: 12-hex-digit key derived from an MD5 hash (non-cryptographic use).
    """
    vector_id = vector_store_path if vector_store_type == "chromadb" else elasticsearch_url
    config_str = (f"{vanna_llm_config.model_name}_{vanna_embedder_config.model_name}_"
                  f"{vector_store_type}_{vector_id}_{db_connection_string_or_path}_{db_type}")
    return hashlib.md5(config_str.encode()).hexdigest()[:12]
+ +"""Vanna utilities for SQL generation using NVIDIA NIM services.""" + +import logging + +from langchain_nvidia import ChatNVIDIA, NVIDIAEmbeddings +from tqdm import tqdm +from vanna.base import VannaBase +from vanna.chromadb import ChromaDB_VectorStore + +logger = logging.getLogger(__name__) + +class NIMCustomLLM(VannaBase): + """Custom LLM implementation for Vanna using NVIDIA NIM.""" + + def __init__(self, config=None): + VannaBase.__init__(self, config=config) + + if not config: + raise ValueError("config must be passed") + + # default parameters - can be overrided using config + self.temperature = 0.7 + + if "temperature" in config: + self.temperature = config["temperature"] + + # If only config is passed + if "api_key" not in config: + raise ValueError("config must contain a NIM api_key") + + if "model" not in config: + raise ValueError("config must contain a NIM model") + + api_key = config["api_key"] + model = config["model"] + + # Initialize ChatNVIDIA client + self.client = ChatNVIDIA( + api_key=api_key, + model=model, + temperature=self.temperature, + ) + self.model = model + + def system_message(self, message: str) -> dict: + """Create a system message.""" + return { + "role": "system", + "content": message + "\n DO NOT PRODUCE MARKDOWN, ONLY RESPOND IN PLAIN TEXT", + } + + def user_message(self, message: str) -> dict: + """Create a user message.""" + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> dict: + """Create an assistant message.""" + return {"role": "assistant", "content": message} + + def submit_prompt(self, prompt, **kwargs) -> str: + """Submit a prompt to the LLM.""" + if prompt is None: + raise Exception("Prompt is None") + + if len(prompt) == 0: + raise Exception("Prompt is empty") + + # Count the number of tokens in the message log + # Use 4 as an approximation for the number of characters per token + num_tokens = 0 + for message in prompt: + num_tokens += len(message["content"]) / 4 + 
logger.debug(f"Using model {self.model} for {num_tokens} tokens (approx)") + + logger.debug(f"Submitting prompt with {len(prompt)} messages") + logger.debug(f"Prompt content preview: {str(prompt)[:500]}...") + + try: + response = self.client.invoke(prompt) + logger.debug(f"Response type: {type(response)}") + logger.debug(f"Response content type: {type(response.content)}") + logger.debug( + f"Response content length: {len(response.content) if response.content else 0}" + ) + logger.debug( + f"Response content preview: {response.content[:200] if response.content else 'None'}..." + ) + return response.content + except Exception as e: + logger.error(f"Error in submit_prompt: {e}") + logger.error(f"Error type: {type(e)}") + import traceback + + logger.error(f"Full traceback: {traceback.format_exc()}") + raise + +class NIMVanna(ChromaDB_VectorStore, NIMCustomLLM): + """Vanna implementation using NVIDIA NIM for LLM and ChromaDB for vector storage.""" + + def __init__(self, VectorConfig=None, LLMConfig=None): + ChromaDB_VectorStore.__init__(self, config=VectorConfig) + NIMCustomLLM.__init__(self, config=LLMConfig) + + +class ElasticVectorStore(VannaBase): + """ + Elasticsearch-based vector store for Vanna. + + This class provides vector storage and retrieval capabilities using Elasticsearch's + dense_vector field type and kNN search functionality. 
+ + Configuration: + config: Dictionary with the following keys: + - url: Elasticsearch connection URL (e.g., "http://localhost:9200") + - index_name: Name of the Elasticsearch index to use (default: "vanna_vectors") + - api_key: Optional API key for authentication + - username: Optional username for basic auth + - password: Optional password for basic auth + - embedding_function: Function to generate embeddings (required) + """ + + def __init__(self, config=None): + VannaBase.__init__(self, config=config) + + if not config: + raise ValueError("config must be passed for ElasticVectorStore") + + # Elasticsearch connection parameters + self.url = config.get("url", "http://localhost:9200") + self.index_name = config.get("index_name", "vanna_vectors") + self.api_key = config.get("api_key") + self.username = config.get("username") + self.password = config.get("password") + + # Embedding function (required) + if "embedding_function" not in config: + raise ValueError("embedding_function must be provided in config") + self.embedding_function = config["embedding_function"] + + # Initialize Elasticsearch client + self._init_elasticsearch_client() + + # Create index if it doesn't exist + self._create_index_if_not_exists() + + logger.info(f"ElasticVectorStore initialized with index: {self.index_name}") + + def _init_elasticsearch_client(self): + """Initialize the Elasticsearch client with authentication.""" + try: + from elasticsearch import Elasticsearch + except ImportError: + raise ImportError( + "elasticsearch package is required for ElasticVectorStore. 
" + "Install it with: pip install elasticsearch" + ) + + # Build client kwargs + client_kwargs = {} + + if self.api_key: + client_kwargs["api_key"] = self.api_key + elif self.username and self.password: + client_kwargs["basic_auth"] = (self.username, self.password) + + self.es_client = Elasticsearch(self.url, **client_kwargs) + + # Test connection (try but don't fail if ping doesn't work) + try: + if self.es_client.ping(): + logger.info(f"Successfully connected to Elasticsearch at {self.url}") + else: + logger.warning(f"Elasticsearch ping failed, but will try to proceed at {self.url}") + except Exception as e: + logger.warning(f"Elasticsearch ping check failed ({e}), but will try to proceed") + + def _create_index_if_not_exists(self): + """Create the Elasticsearch index with appropriate mappings if it doesn't exist.""" + if self.es_client.indices.exists(index=self.index_name): + logger.debug(f"Index {self.index_name} already exists") + return + + # Get embedding dimension by creating a test embedding + test_embedding = self._generate_embedding("test") + embedding_dim = len(test_embedding) + + # Index mapping with dense_vector field for embeddings + index_mapping = { + "mappings": { + "properties": { + "id": {"type": "keyword"}, + "text": {"type": "text"}, + "embedding": { + "type": "dense_vector", + "dims": embedding_dim, + "index": True, + "similarity": "cosine" + }, + "metadata": {"type": "object", "enabled": True}, + "type": {"type": "keyword"}, # ddl, documentation, sql + "created_at": {"type": "date"} + } + } + } + + self.es_client.indices.create(index=self.index_name, body=index_mapping) + logger.info(f"Created Elasticsearch index: {self.index_name}") + + def _generate_embedding(self, text: str) -> list[float]: + """Generate embedding for a given text using the configured embedding function.""" + if hasattr(self.embedding_function, 'embed_query'): + # NVIDIA embedding function returns [[embedding]] + result = self.embedding_function.embed_query(text) + if 
isinstance(result, list) and len(result) > 0: + if isinstance(result[0], list): + return result[0] # Extract the inner list + return result # type: ignore[return-value] + return result # type: ignore[return-value] + elif callable(self.embedding_function): + # Generic callable + result = self.embedding_function(text) + if isinstance(result, list) and len(result) > 0: + if isinstance(result[0], list): + return result[0] + return result # type: ignore[return-value] + return result # type: ignore[return-value] + else: + raise ValueError("embedding_function must be callable or have embed_query method") + + def add_ddl(self, ddl: str, **kwargs) -> str: + """ + Add a DDL statement to the vector store. + + Args: + ddl: The DDL statement to store + **kwargs: Additional metadata + + Returns: + Document ID + """ + import hashlib + from datetime import datetime + + # Generate document ID + doc_id = hashlib.md5(ddl.encode()).hexdigest() + + # Generate embedding + embedding = self._generate_embedding(ddl) + + # Create document + doc = { + "id": doc_id, + "text": ddl, + "embedding": embedding, + "type": "ddl", + "metadata": kwargs, + "created_at": datetime.utcnow().isoformat() + } + + # Index document + self.es_client.index(index=self.index_name, id=doc_id, document=doc) + logger.debug(f"Added DDL to Elasticsearch: {doc_id}") + + return doc_id + + def add_documentation(self, documentation: str, **kwargs) -> str: + """ + Add documentation to the vector store. 
+ + Args: + documentation: The documentation text to store + **kwargs: Additional metadata + + Returns: + Document ID + """ + import hashlib + from datetime import datetime + + doc_id = hashlib.md5(documentation.encode()).hexdigest() + embedding = self._generate_embedding(documentation) + + doc = { + "id": doc_id, + "text": documentation, + "embedding": embedding, + "type": "documentation", + "metadata": kwargs, + "created_at": datetime.utcnow().isoformat() + } + + self.es_client.index(index=self.index_name, id=doc_id, document=doc) + logger.debug(f"Added documentation to Elasticsearch: {doc_id}") + + return doc_id + + def add_question_sql(self, question: str, sql: str, **kwargs) -> str: + """ + Add a question-SQL pair to the vector store. + + Args: + question: The natural language question + sql: The corresponding SQL query + **kwargs: Additional metadata + + Returns: + Document ID + """ + import hashlib + from datetime import datetime + + # Combine question and SQL for embedding + combined_text = f"Question: {question}\nSQL: {sql}" + doc_id = hashlib.md5(combined_text.encode()).hexdigest() + embedding = self._generate_embedding(question) + + doc = { + "id": doc_id, + "text": combined_text, + "embedding": embedding, + "type": "sql", + "metadata": { + "question": question, + "sql": sql, + **kwargs + }, + "created_at": datetime.utcnow().isoformat() + } + + self.es_client.index(index=self.index_name, id=doc_id, document=doc) + logger.debug(f"Added question-SQL pair to Elasticsearch: {doc_id}") + + return doc_id + + def get_similar_question_sql(self, question: str, **kwargs) -> list: + """ + Retrieve similar question-SQL pairs using vector similarity search. 
+ + Args: + question: The question to find similar examples for + **kwargs: Additional parameters (e.g., top_k) + + Returns: + List of similar documents + """ + top_k = kwargs.get("top_k", 10) + + # Generate query embedding + query_embedding = self._generate_embedding(question) + + # Build kNN search query + search_query = { + "knn": { + "field": "embedding", + "query_vector": query_embedding, + "k": top_k, + "num_candidates": top_k * 2, + "filter": {"term": {"type": "sql"}} + }, + "_source": ["text", "metadata", "type"] + } + + # Execute search + response = self.es_client.search(index=self.index_name, body=search_query) + + # Extract results + results = [] + for hit in response["hits"]["hits"]: + source = hit["_source"] + results.append({ + "question": source["metadata"].get("question", ""), + "sql": source["metadata"].get("sql", ""), + "score": hit["_score"] + }) + + logger.debug(f"Found {len(results)} similar question-SQL pairs") + return results + + def get_related_ddl(self, question: str, **kwargs) -> list: + """ + Retrieve related DDL statements using vector similarity search. + + Args: + question: The question to find related DDL for + **kwargs: Additional parameters (e.g., top_k) + + Returns: + List of related DDL statements + """ + top_k = kwargs.get("top_k", 10) + query_embedding = self._generate_embedding(question) + + search_query = { + "knn": { + "field": "embedding", + "query_vector": query_embedding, + "k": top_k, + "num_candidates": top_k * 2, + "filter": {"term": {"type": "ddl"}} + }, + "_source": ["text"] + } + + response = self.es_client.search(index=self.index_name, body=search_query) + + results = [hit["_source"]["text"] for hit in response["hits"]["hits"]] + logger.debug(f"Found {len(results)} related DDL statements") + return results + + def get_related_documentation(self, question: str, **kwargs) -> list: + """ + Retrieve related documentation using vector similarity search. 
+ + Args: + question: The question to find related documentation for + **kwargs: Additional parameters (e.g., top_k) + + Returns: + List of related documentation + """ + top_k = kwargs.get("top_k", 10) + query_embedding = self._generate_embedding(question) + + search_query = { + "knn": { + "field": "embedding", + "query_vector": query_embedding, + "k": top_k, + "num_candidates": top_k * 2, + "filter": {"term": {"type": "documentation"}} + }, + "_source": ["text"] + } + + response = self.es_client.search(index=self.index_name, body=search_query) + + results = [hit["_source"]["text"] for hit in response["hits"]["hits"]] + logger.debug(f"Found {len(results)} related documentation entries") + return results + + def remove_training_data(self, id: str, **kwargs) -> bool: + """ + Remove a training data entry by ID. + + Args: + id: The document ID to remove + **kwargs: Additional parameters + + Returns: + True if successful + """ + try: + self.es_client.delete(index=self.index_name, id=id) + logger.debug(f"Removed training data: {id}") + return True + except Exception as e: + logger.error(f"Error removing training data {id}: {e}") + return False + + def generate_embedding(self, data: str, **kwargs) -> list[float]: + """ + Generate embedding for given data (required by Vanna base class). + + Args: + data: Text to generate embedding for + **kwargs: Additional parameters + + Returns: + Embedding vector + """ + return self._generate_embedding(data) + + def get_training_data(self, **kwargs) -> list: + """ + Get all training data from the vector store (required by Vanna base class). 
+ + Args: + **kwargs: Additional parameters + + Returns: + List of training data entries + """ + try: + # Query all documents + query = { + "query": {"match_all": {}}, + "size": 10000 # Adjust based on expected data size + } + + response = self.es_client.search(index=self.index_name, body=query) + + training_data = [] + for hit in response["hits"]["hits"]: + source = hit["_source"] + training_data.append({ + "id": hit["_id"], + "type": source.get("type"), + "text": source.get("text"), + "metadata": source.get("metadata", {}) + }) + + return training_data + except Exception as e: + logger.error(f"Error getting training data: {e}") + return [] + + +class ElasticNIMVanna(ElasticVectorStore, NIMCustomLLM): + """ + Vanna implementation using NVIDIA NIM for LLM and Elasticsearch for vector storage. + + This class combines ElasticVectorStore for vector operations with NIMCustomLLM + for SQL generation, providing an alternative to ChromaDB-based storage. + + Example: + >>> vanna = ElasticNIMVanna( + ... VectorConfig={ + ... "url": "http://localhost:9200", + ... "index_name": "my_sql_vectors", + ... "username": "elastic", + ... "password": "changeme", + ... "embedding_function": NVIDIAEmbeddingFunction( + ... api_key="your-api-key", + ... model="nvidia/llama-3.2-nv-embedqa-1b-v2" + ... ) + ... }, + ... LLMConfig={ + ... "api_key": "your-api-key", + ... "model": "meta/llama-3.1-70b-instruct" + ... } + ... ) + """ + + def __init__(self, VectorConfig=None, LLMConfig=None): + ElasticVectorStore.__init__(self, config=VectorConfig) + NIMCustomLLM.__init__(self, config=LLMConfig) + + +class NVIDIAEmbeddingFunction: + """ + A class that can be used as a replacement for chroma's DefaultEmbeddingFunction. + It takes in input (text or list of texts) and returns embeddings using NVIDIA's API. + + This class fixes two major interface compatibility issues between ChromaDB and NVIDIA embeddings: + + 1. 
INPUT FORMAT MISMATCH: + - ChromaDB passes ['query text'] (list) to embed_query() + - But langchain_nvidia's embed_query() expects 'query text' (string) + - When list is passed, langchain does [text] internally → [['query text']] → API 500 error + - FIX: Detect list input and extract string before calling langchain + + 2. OUTPUT FORMAT MISMATCH: + - ChromaDB expects embed_query() to return [[embedding_vector]] (list of embeddings) + - But langchain returns [embedding_vector] (single embedding vector) + - This causes: TypeError: 'float' object cannot be converted to 'Sequence' + - FIX: Wrap single embedding in list: return [embeddings] + """ + + def __init__(self, api_key, model="nvidia/llama-3.2-nv-embedqa-1b-v2"): + """ + Initialize the embedding function with the API key and model name. + + Parameters: + - api_key (str): The API key for authentication. + - model (str): The model name to use for embeddings. + Default: nvidia/llama-3.2-nv-embedqa-1b-v2 (tested and working) + """ + self.api_key = api_key + self.model = model + + logger.info(f"Initializing NVIDIA embeddings with model: {model}") + logger.debug(f"API key length: {len(api_key) if api_key else 0}") + + self.embeddings = NVIDIAEmbeddings( + api_key=api_key, model_name=model, input_type="query", truncate="NONE" + ) + logger.info("Successfully initialized NVIDIA embeddings") + + def __call__(self, input): + """ + Call method to make the object callable, as required by chroma's EmbeddingFunction interface. + + NOTE: This method is used by ChromaDB for batch embedding operations. + The embed_query() method above handles the single query case with the critical fixes. + + Parameters: + - input (str or list): The input data for which embeddings need to be generated. + + Returns: + - embedding (list): The embedding vector(s) for the input data. 
+ """ + logger.debug(f"__call__ method called with input type: {type(input)}") + logger.debug(f"__call__ input: {input}") + + # Ensure input is a list, as required by ChromaDB + if isinstance(input, str): + input_data = [input] + else: + input_data = input + + logger.debug(f"Processing {len(input_data)} texts for embedding") + + # Generate embeddings for each text + embeddings = [] + for i, text in enumerate(input_data): + logger.debug(f"Embedding text {i+1}/{len(input_data)}: {text[:50]}...") + embedding = self.embeddings.embed_query(text) + embeddings.append(embedding) + + logger.debug(f"Generated {len(embeddings)} embeddings") + # Always return a list of embeddings for ChromaDB + return embeddings + + def name(self): + """ + Returns a custom name for the embedding function. + + Returns: + str: The name of the embedding function. + """ + return "NVIDIA Embedding Function" + + def embed_query(self, input: str) -> list[list[float]]: + """ + Generate embeddings for a single query. + + ChromaDB calls this method with ['query text'] (list) but langchain_nvidia expects 'query text' (string). + We must extract the string from the list to prevent API 500 errors. + + ChromaDB expects this method to return [[embedding_vector]] (list of embeddings) + but langchain returns [embedding_vector] (single embedding). We wrap it in a list. 
+ """ + logger.debug(f"Embedding query: {input}") + logger.debug(f"Input type: {type(input)}") + logger.debug(f"Using model: {self.model}") + + # Handle ChromaDB's list input format + # ChromaDB sometimes passes a list instead of a string + # Extract the string from the list if needed + if isinstance(input, list): + if len(input) == 1: + query_text = input[0] + logger.debug(f"Extracted string from list: {query_text}") + else: + logger.error(f"Unexpected list length: {len(input)}") + raise ValueError( + f"Expected single string or list with one element, got list with {len(input)} elements" + ) + else: + query_text = input + + try: + # Call langchain_nvidia with the extracted string + embeddings = self.embeddings.embed_query(query_text) + logger.debug( + f"Successfully generated embeddings of length: {len(embeddings) if embeddings else 0}" + ) + + # Wrap single embedding in list for ChromaDB compatibility + # ChromaDB expects a list of embeddings, even for a single query + return [embeddings] + except Exception as e: + logger.error(f"Error generating embeddings for query: {e}") + logger.error(f"Error type: {type(e)}") + logger.error(f"Query text: {query_text}") + import traceback + + logger.error(f"Full traceback: {traceback.format_exc()}") + raise + + def embed_documents(self, input: list[str]) -> list[list[float]]: + """ + Generate embeddings for multiple documents. + + This function expects a list of strings. If it's a list of lists of strings, flatten it to handle cases + where the input is unexpectedly nested. 
+ """ + logger.debug(f"Embedding {len(input)} documents...") + logger.debug(f"Using model: {self.model}") + + try: + embeddings = self.embeddings.embed_documents(input) + logger.debug("Successfully generated document embeddings") + return embeddings + except Exception as e: + logger.error(f"Error generating document embeddings: {e}") + logger.error(f"Error type: {type(e)}") + logger.error(f"Input documents count: {len(input)}") + import traceback + + logger.error(f"Full traceback: {traceback.format_exc()}") + raise + + +def chunk_documentation(text: str, max_chars: int = 1500) -> list: + """ + Split long documentation into smaller chunks to avoid token limits. + + Args: + text: The documentation text to chunk + max_chars: Maximum characters per chunk (approximate) + + Returns: + List of text chunks + """ + if len(text) <= max_chars: + return [text] + + chunks = [] + # Split by paragraphs first + paragraphs = text.split('\n\n') + current_chunk = "" + + for paragraph in paragraphs: + # If adding this paragraph would exceed the limit, save current chunk and start new one + if len(current_chunk) + len(paragraph) + 2 > max_chars and current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = paragraph + else: + if current_chunk: + current_chunk += "\n\n" + paragraph + else: + current_chunk = paragraph + + # Add the last chunk if it exists + if current_chunk.strip(): + chunks.append(current_chunk.strip()) + + # If any chunk is still too long, split it further + final_chunks = [] + for chunk in chunks: + if len(chunk) > max_chars: + # Split long chunk into sentences + sentences = chunk.split('. ') + temp_chunk = "" + for sentence in sentences: + if len(temp_chunk) + len(sentence) + 2 > max_chars and temp_chunk: + final_chunks.append(temp_chunk.strip() + ".") + temp_chunk = sentence + else: + if temp_chunk: + temp_chunk += ". 
" + sentence + else: + temp_chunk = sentence + if temp_chunk.strip(): + final_chunks.append(temp_chunk.strip()) + else: + final_chunks.append(chunk) + + return final_chunks + +def initVanna(vn, training_data_path: str = None): + """ + Initialize and train a Vanna instance for SQL generation using configurable training data. + + This function configures a Vanna SQL generation agent with training data loaded from a YAML file, + making it scalable for different SQL data sources with different contexts. + + Args: + vn: Vanna instance to be trained and configured + training_data_path: Path to YAML file containing training data. If None, no training is applied. + + Returns: + None: Modifies the Vanna instance in-place + + Example: + >>> from vanna.chromadb import ChromaDB_VectorStore + >>> vn = NIMCustomLLM(config) & ChromaDB_VectorStore() + >>> vn.connect_to_sqlite("path/to/database.db") + >>> initVanna(vn, "path/to/training_data.yaml") + >>> # Vanna is now ready to generate SQL queries + """ + import json + import os + import logging + + logger = logging.getLogger(__name__) + logger.info("=== Starting Vanna initialization ===") + + # Get and train DDL from sqlite_master + logger.info("Loading DDL from sqlite_master...") + try: + df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null") + ddl_count = len(df_ddl) + logger.info(f"Found {ddl_count} DDL statements in sqlite_master") + + for i, ddl in enumerate(df_ddl['sql'].to_list(), 1): + logger.debug(f"Training DDL {i}/{ddl_count}: {ddl[:100]}...") + vn.train(ddl=ddl) + + logger.info(f"Successfully trained {ddl_count} DDL statements from sqlite_master") + except Exception as e: + logger.error(f"Error loading DDL from sqlite_master: {e}") + raise + + # Load and apply training data from YAML file + if training_data_path: + logger.info(f"Training data path provided: {training_data_path}") + + if os.path.exists(training_data_path): + logger.info(f"Training data file exists, loading YAML...") + + try: + 
import yaml + with open(training_data_path, 'r') as f: + training_data = yaml.safe_load(f) + + logger.info(f"Successfully loaded YAML training data") + logger.info(f"Training data keys: {list(training_data.keys()) if training_data else 'None'}") + + # Train synthetic DDL statements + synthetic_ddl = training_data.get("synthetic_ddl", []) + logger.info(f"Found {len(synthetic_ddl)} synthetic DDL statements") + + ddl_trained = 0 + for i, ddl_statement in enumerate(synthetic_ddl, 1): + if ddl_statement.strip(): # Only train non-empty statements + logger.debug(f"Training synthetic DDL {i}: {ddl_statement[:100]}...") + vn.train(ddl=ddl_statement) + ddl_trained += 1 + else: + logger.warning(f"Skipping empty synthetic DDL statement at index {i}") + + logger.info(f"Successfully trained {ddl_trained}/{len(synthetic_ddl)} synthetic DDL statements") + + # Train documentation with chunking + documentation = training_data.get("documentation", "") + if documentation.strip(): + logger.info(f"Training documentation ({len(documentation)} characters)") + logger.debug(f"Documentation preview: {documentation[:200]}...") + + # Chunk documentation to avoid token limits + doc_chunks = chunk_documentation(documentation) + logger.info(f"Split documentation into {len(doc_chunks)} chunks") + + for i, chunk in enumerate(doc_chunks, 1): + try: + logger.debug(f"Training documentation chunk {i}/{len(doc_chunks)} ({len(chunk)} chars)") + vn.train(documentation=chunk) + except Exception as e: + logger.error(f"Error training documentation chunk {i}: {e}") + # Continue with other chunks + + logger.info(f"Successfully trained {len(doc_chunks)} documentation chunks") + else: + logger.warning("No documentation found or documentation is empty") + + # Train example queries + example_queries = training_data.get("example_queries", []) + logger.info(f"Found {len(example_queries)} example queries") + + queries_trained = 0 + for i, query_data in enumerate(example_queries, 1): + sql = query_data.get("sql", "") 
+ if sql.strip(): # Only train non-empty queries + logger.debug(f"Training example query {i}: {sql[:100]}...") + vn.train(sql=sql) + queries_trained += 1 + else: + logger.warning(f"Skipping empty example query at index {i}") + + logger.info(f"Successfully trained {queries_trained}/{len(example_queries)} example queries") + + # Train question-SQL pairs + question_sql_pairs = training_data.get("question_sql_pairs", []) + logger.info(f"Found {len(question_sql_pairs)} question-SQL pairs") + + pairs_trained = 0 + for i, pair in enumerate(question_sql_pairs, 1): + question = pair.get("question", "") + sql = pair.get("sql", "") + if question.strip() and sql.strip(): # Only train non-empty pairs + logger.debug(f"Training question-SQL pair {i}: Q='{question[:50]}...' SQL='{sql[:50]}...'") + vn.train(question=question, sql=sql) + pairs_trained += 1 + else: + if not question.strip(): + logger.warning(f"Skipping question-SQL pair {i}: empty question") + if not sql.strip(): + logger.warning(f"Skipping question-SQL pair {i}: empty SQL") + + logger.info(f"Successfully trained {pairs_trained}/{len(question_sql_pairs)} question-SQL pairs") + + # Summary + total_trained = ddl_trained + len(doc_chunks) + queries_trained + pairs_trained + logger.info(f"=== Training Summary ===") + logger.info(f" Synthetic DDL: {ddl_trained}") + logger.info(f" Documentation chunks: {len(doc_chunks)}") + logger.info(f" Example queries: {queries_trained}") + logger.info(f" Question-SQL pairs: {pairs_trained}") + logger.info(f" Total items trained: {total_trained}") + + except yaml.YAMLError as e: + logger.error(f"Error parsing YAML file {training_data_path}: {e}") + raise + except Exception as e: + logger.error(f"Error loading training data from {training_data_path}: {e}") + raise + else: + logger.warning(f"Training data file does not exist: {training_data_path}") + logger.warning("Proceeding without YAML training data") + else: + logger.info("No training data path provided, skipping YAML training") + + 
logger.info("=== Vanna initialization completed ===") + diff --git a/examples/asset_lifecycle_management/utils_template/__init__.py b/examples/asset_lifecycle_management/utils_template/__init__.py new file mode 100644 index 0000000..d5528a3 --- /dev/null +++ b/examples/asset_lifecycle_management/utils_template/__init__.py @@ -0,0 +1,11 @@ +""" +Workspace utilities for Asset Lifecycle Management tasks. + +These pre-built utility functions provide reliable, tested implementations +for common data processing tasks, particularly for predictive maintenance workflows. +""" + +from .rul_utils import apply_piecewise_rul_transformation, show_utilities + +__all__ = ['apply_piecewise_rul_transformation', 'show_utilities'] + diff --git a/examples/asset_lifecycle_management/utils_template/rul_utils.py b/examples/asset_lifecycle_management/utils_template/rul_utils.py new file mode 100644 index 0000000..afebd9f --- /dev/null +++ b/examples/asset_lifecycle_management/utils_template/rul_utils.py @@ -0,0 +1,146 @@ +""" +RUL (Remaining Useful Life) transformation utilities. + +Provides pre-built functions for transforming RUL data to create realistic patterns +for Asset Lifecycle Management and predictive maintenance tasks. +""" + +import pandas as pd +import logging + +logger = logging.getLogger(__name__) + + +def apply_piecewise_rul_transformation( + df: pd.DataFrame, + maxlife: int = 100, + time_col: str = 'time_in_cycles', + rul_col: str = 'RUL' +) -> pd.DataFrame: + """ + Transform RUL data to create realistic "knee" patterns. + + This function applies a piecewise transformation to RUL (Remaining Useful Life) values + to create a more realistic degradation pattern commonly seen in predictive maintenance: + - RUL stays constant at MAXLIFE until the remaining cycles drop below the threshold + - Then RUL decreases linearly to 0 as the equipment approaches failure + + This creates the characteristic "knee" pattern seen in actual equipment degradation. 
+ + Args: + df: pandas DataFrame with time series data containing RUL values + maxlife: Maximum life threshold for the piecewise function (default: 100) + RUL values above this will be capped at maxlife + time_col: Name of the time/cycle column (default: 'time_in_cycles') + rul_col: Name of the RUL column to transform (default: 'RUL') + + Returns: + pandas DataFrame with original data plus new 'transformed_RUL' column + + Raises: + ValueError: If required columns are missing from the DataFrame + + Example: + >>> df = pd.DataFrame({'time_in_cycles': [1, 2, 3], 'RUL': [150, 100, 50]}) + >>> df_transformed = apply_piecewise_rul_transformation(df, maxlife=100) + >>> print(df_transformed['transformed_RUL']) + 0 100 + 1 100 + 2 50 + Name: transformed_RUL, dtype: int64 + """ + # Validate inputs + if not isinstance(df, pd.DataFrame): + raise ValueError(f"Expected pandas DataFrame, got {type(df)}") + + if rul_col not in df.columns: + raise ValueError( + f"RUL column '{rul_col}' not found in DataFrame. " + f"Available columns: {list(df.columns)}" + ) + + if time_col not in df.columns: + logger.warning( + f"Time column '{time_col}' not found in DataFrame, but continuing anyway. " + f"Available columns: {list(df.columns)}" + ) + + # Create a copy to avoid modifying the original + df_copy = df.copy() + + logger.info(f"Applying piecewise RUL transformation with maxlife={maxlife}") + logger.debug(f"Input RUL range: [{df_copy[rul_col].min()}, {df_copy[rul_col].max()}]") + + # Apply piecewise transformation + def transform_rul(rul_value): + """Apply the piecewise transformation to a single RUL value.""" + if pd.isna(rul_value): + return rul_value # Keep NaN values as NaN + if rul_value > maxlife: + return maxlife + return rul_value + + # Apply transformation to create new column + df_copy['transformed_RUL'] = df_copy[rul_col].apply(transform_rul) + + logger.info( + f"✅ Transformation complete! Added 'transformed_RUL' column. 
" + f"Output range: [{df_copy['transformed_RUL'].min()}, {df_copy['transformed_RUL'].max()}]" + ) + logger.debug(f"Total rows processed: {len(df_copy)}") + + return df_copy + + +def show_utilities(): + """ + Display available utility functions and their usage. + + Prints a formatted list of all available utilities in this workspace, + including descriptions and example usage. + """ + utilities_info = """ + ================================================================================ + WORKSPACE UTILITIES - Asset Lifecycle Management + ================================================================================ + + Available utility functions: + + 1. apply_piecewise_rul_transformation(df, maxlife=100, time_col='time_in_cycles', rul_col='RUL') + + Description: + Transforms RUL (Remaining Useful Life) data to create realistic "knee" patterns + commonly seen in predictive maintenance scenarios. + + Parameters: + - df: pandas DataFrame with time series data + - maxlife: Maximum life threshold (default: 100) + - time_col: Name of time/cycle column (default: 'time_in_cycles') + - rul_col: Name of RUL column to transform (default: 'RUL') + + Returns: + DataFrame with original data plus new 'transformed_RUL' column + + Example: + df_transformed = utils.apply_piecewise_rul_transformation(df, maxlife=100) + print(df_transformed[['time_in_cycles', 'RUL', 'transformed_RUL']]) + + 2. show_utilities() + + Description: + Displays this help message with all available utilities. 
+ + Example: + utils.show_utilities() + + ================================================================================ + """ + print(utilities_info) + + +if __name__ == "__main__": + # Simple test + print("RUL Utilities Module") + print("=" * 50) + show_utilities() + diff --git a/examples/asset_lifecycle_management/vanna_training_data.yaml b/examples/asset_lifecycle_management/vanna_training_data.yaml new file mode 100644 index 0000000..a4b8b21 --- /dev/null +++ b/examples/asset_lifecycle_management/vanna_training_data.yaml @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Vanna SQL Agent Training Configuration +# ===================================== +# This YAML file contains all the training data needed to configure a Vanna SQL generation agent +# for your specific database and domain. Each section serves a different purpose in training +# the agent to understand your database structure and generate appropriate SQL queries. 
+ +training_config: + # Basic metadata about this training configuration + description: "Training data for NASA Turbofan Engine Asset Lifecycle Management (predictive maintenance) SQL generation" + version: "1.0" + # You should update these fields to describe your specific domain and use case + +# SYNTHETIC DDL STATEMENTS +# ======================== +# Purpose: Define table structures that may not be fully captured in the actual database schema +# When to use: +# - When you have tables that aren't in the main database but need to be referenced +# - When you want to ensure the agent knows about specific table structures +# - When you need to supplement incomplete schema information from sqlite_master +# How to populate: +# - Include CREATE TABLE statements for any tables the agent should know about +# - Focus on tables that are central to your domain but might be missing from auto-discovery +# - Use exact DDL syntax as if you were creating the tables manually +# - Include all columns with proper data types to help the agent understand structure +synthetic_ddl: + - "CREATE TABLE IF NOT EXISTS RUL_FD001 (\"unit_number\" INTEGER, \"RUL\" INTEGER)" + - "CREATE TABLE IF NOT EXISTS RUL_FD002 (\"unit_number\" INTEGER, \"RUL\" INTEGER)" + - "CREATE TABLE IF NOT EXISTS RUL_FD003 (\"unit_number\" INTEGER, \"RUL\" INTEGER)" + - "CREATE TABLE IF NOT EXISTS RUL_FD004 (\"unit_number\" INTEGER, \"RUL\" INTEGER)" + - "CREATE TABLE IF NOT EXISTS train_FD001 (\"unit_number\" INTEGER, \"time_in_cycles\" INTEGER, \"operational_setting_1\" REAL, \"operational_setting_2\" REAL, \"operational_setting_3\" REAL, \"sensor_measurement_1\" REAL, \"sensor_measurement_2\" REAL, \"sensor_measurement_3\" REAL, \"sensor_measurement_4\" REAL, \"sensor_measurement_5\" REAL, \"sensor_measurement_6\" REAL, \"sensor_measurement_7\" REAL, \"sensor_measurement_8\" REAL, \"sensor_measurement_9\" REAL, \"sensor_measurement_10\" REAL, \"sensor_measurement_11\" REAL, \"sensor_measurement_12\" REAL, 
\"sensor_measurement_13\" REAL, \"sensor_measurement_14\" REAL, \"sensor_measurement_15\" REAL, \"sensor_measurement_16\" REAL, \"sensor_measurement_17\" INTEGER, \"sensor_measurement_18\" INTEGER, \"sensor_measurement_19\" REAL, \"sensor_measurement_20\" REAL, \"sensor_measurement_21\" REAL)" + - "CREATE TABLE IF NOT EXISTS test_FD001 (\"unit_number\" INTEGER, \"time_in_cycles\" INTEGER, \"operational_setting_1\" REAL, \"operational_setting_2\" REAL, \"operational_setting_3\" REAL, \"sensor_measurement_1\" REAL, \"sensor_measurement_2\" REAL, \"sensor_measurement_3\" REAL, \"sensor_measurement_4\" REAL, \"sensor_measurement_5\" REAL, \"sensor_measurement_6\" REAL, \"sensor_measurement_7\" REAL, \"sensor_measurement_8\" REAL, \"sensor_measurement_9\" REAL, \"sensor_measurement_10\" REAL, \"sensor_measurement_11\" REAL, \"sensor_measurement_12\" REAL, \"sensor_measurement_13\" REAL, \"sensor_measurement_14\" REAL, \"sensor_measurement_15\" REAL, \"sensor_measurement_16\" REAL, \"sensor_measurement_17\" INTEGER, \"sensor_measurement_18\" INTEGER, \"sensor_measurement_19\" REAL, \"sensor_measurement_20\" REAL, \"sensor_measurement_21\" REAL)" + - "CREATE TABLE IF NOT EXISTS train_FD002 (\"unit_number\" INTEGER, \"time_in_cycles\" INTEGER, \"operational_setting_1\" REAL, \"operational_setting_2\" REAL, \"operational_setting_3\" REAL, \"sensor_measurement_1\" REAL, \"sensor_measurement_2\" REAL, \"sensor_measurement_3\" REAL, \"sensor_measurement_4\" REAL, \"sensor_measurement_5\" REAL, \"sensor_measurement_6\" REAL, \"sensor_measurement_7\" REAL, \"sensor_measurement_8\" REAL, \"sensor_measurement_9\" REAL, \"sensor_measurement_10\" REAL, \"sensor_measurement_11\" REAL, \"sensor_measurement_12\" REAL, \"sensor_measurement_13\" REAL, \"sensor_measurement_14\" REAL, \"sensor_measurement_15\" REAL, \"sensor_measurement_16\" REAL, \"sensor_measurement_17\" INTEGER, \"sensor_measurement_18\" INTEGER, \"sensor_measurement_19\" REAL, \"sensor_measurement_20\" REAL, 
\"sensor_measurement_21\" REAL)" + - "CREATE TABLE IF NOT EXISTS test_FD002 (\"unit_number\" INTEGER, \"time_in_cycles\" INTEGER, \"operational_setting_1\" REAL, \"operational_setting_2\" REAL, \"operational_setting_3\" REAL, \"sensor_measurement_1\" REAL, \"sensor_measurement_2\" REAL, \"sensor_measurement_3\" REAL, \"sensor_measurement_4\" REAL, \"sensor_measurement_5\" REAL, \"sensor_measurement_6\" REAL, \"sensor_measurement_7\" REAL, \"sensor_measurement_8\" REAL, \"sensor_measurement_9\" REAL, \"sensor_measurement_10\" REAL, \"sensor_measurement_11\" REAL, \"sensor_measurement_12\" REAL, \"sensor_measurement_13\" REAL, \"sensor_measurement_14\" REAL, \"sensor_measurement_15\" REAL, \"sensor_measurement_16\" REAL, \"sensor_measurement_17\" INTEGER, \"sensor_measurement_18\" INTEGER, \"sensor_measurement_19\" REAL, \"sensor_measurement_20\" REAL, \"sensor_measurement_21\" REAL)" + - "CREATE TABLE IF NOT EXISTS train_FD003 (\"unit_number\" INTEGER, \"time_in_cycles\" INTEGER, \"operational_setting_1\" REAL, \"operational_setting_2\" REAL, \"operational_setting_3\" REAL, \"sensor_measurement_1\" REAL, \"sensor_measurement_2\" REAL, \"sensor_measurement_3\" REAL, \"sensor_measurement_4\" REAL, \"sensor_measurement_5\" REAL, \"sensor_measurement_6\" REAL, \"sensor_measurement_7\" REAL, \"sensor_measurement_8\" REAL, \"sensor_measurement_9\" REAL, \"sensor_measurement_10\" REAL, \"sensor_measurement_11\" REAL, \"sensor_measurement_12\" REAL, \"sensor_measurement_13\" REAL, \"sensor_measurement_14\" REAL, \"sensor_measurement_15\" REAL, \"sensor_measurement_16\" REAL, \"sensor_measurement_17\" INTEGER, \"sensor_measurement_18\" INTEGER, \"sensor_measurement_19\" REAL, \"sensor_measurement_20\" REAL, \"sensor_measurement_21\" REAL)" + - "CREATE TABLE IF NOT EXISTS test_FD003 (\"unit_number\" INTEGER, \"time_in_cycles\" INTEGER, \"operational_setting_1\" REAL, \"operational_setting_2\" REAL, \"operational_setting_3\" REAL, \"sensor_measurement_1\" REAL, 
\"sensor_measurement_2\" REAL, \"sensor_measurement_3\" REAL, \"sensor_measurement_4\" REAL, \"sensor_measurement_5\" REAL, \"sensor_measurement_6\" REAL, \"sensor_measurement_7\" REAL, \"sensor_measurement_8\" REAL, \"sensor_measurement_9\" REAL, \"sensor_measurement_10\" REAL, \"sensor_measurement_11\" REAL, \"sensor_measurement_12\" REAL, \"sensor_measurement_13\" REAL, \"sensor_measurement_14\" REAL, \"sensor_measurement_15\" REAL, \"sensor_measurement_16\" REAL, \"sensor_measurement_17\" INTEGER, \"sensor_measurement_18\" INTEGER, \"sensor_measurement_19\" REAL, \"sensor_measurement_20\" REAL, \"sensor_measurement_21\" REAL)" + - "CREATE TABLE IF NOT EXISTS train_FD004 (\"unit_number\" INTEGER, \"time_in_cycles\" INTEGER, \"operational_setting_1\" REAL, \"operational_setting_2\" REAL, \"operational_setting_3\" REAL, \"sensor_measurement_1\" REAL, \"sensor_measurement_2\" REAL, \"sensor_measurement_3\" REAL, \"sensor_measurement_4\" REAL, \"sensor_measurement_5\" REAL, \"sensor_measurement_6\" REAL, \"sensor_measurement_7\" REAL, \"sensor_measurement_8\" REAL, \"sensor_measurement_9\" REAL, \"sensor_measurement_10\" REAL, \"sensor_measurement_11\" REAL, \"sensor_measurement_12\" REAL, \"sensor_measurement_13\" REAL, \"sensor_measurement_14\" REAL, \"sensor_measurement_15\" REAL, \"sensor_measurement_16\" REAL, \"sensor_measurement_17\" INTEGER, \"sensor_measurement_18\" INTEGER, \"sensor_measurement_19\" REAL, \"sensor_measurement_20\" REAL, \"sensor_measurement_21\" REAL)" + - "CREATE TABLE IF NOT EXISTS test_FD004 (\"unit_number\" INTEGER, \"time_in_cycles\" INTEGER, \"operational_setting_1\" REAL, \"operational_setting_2\" REAL, \"operational_setting_3\" REAL, \"sensor_measurement_1\" REAL, \"sensor_measurement_2\" REAL, \"sensor_measurement_3\" REAL, \"sensor_measurement_4\" REAL, \"sensor_measurement_5\" REAL, \"sensor_measurement_6\" REAL, \"sensor_measurement_7\" REAL, \"sensor_measurement_8\" REAL, \"sensor_measurement_9\" REAL, 
\"sensor_measurement_10\" REAL, \"sensor_measurement_11\" REAL, \"sensor_measurement_12\" REAL, \"sensor_measurement_13\" REAL, \"sensor_measurement_14\" REAL, \"sensor_measurement_15\" REAL, \"sensor_measurement_16\" REAL, \"sensor_measurement_17\" INTEGER, \"sensor_measurement_18\" INTEGER, \"sensor_measurement_19\" REAL, \"sensor_measurement_20\" REAL, \"sensor_measurement_21\" REAL)" + +# DOMAIN DOCUMENTATION +# ==================== +# Purpose: Provide context about your database structure, business rules, and query patterns +# When to use: Always - this is crucial for helping the agent understand your domain +# How to populate: +# - Use second-person language ("You are working with...", "When you see...") +# - Explain the business context and what the data represents +# - Define important query patterns and conventions specific to your domain +# - Include any business rules or logical distinctions the agent should understand +# - Explain column meanings, especially if they're not self-evident +# - Provide guidance on default behaviors when queries are ambiguous +# - Think of this as training documentation for a new team member who needs to understand your database +documentation: | + You are working with a SQL database containing train and test splits of four different datasets: FD001, FD002, FD003, FD004. + Each dataset consists of multiple multivariate time series from different engines of the same type. 
+ + DATABASE STRUCTURE YOU'LL WORK WITH: + The data is organized into separate tables for each dataset that you'll need to query: + + Training Tables: train_FD001, train_FD002, train_FD003, train_FD004 + Test Tables: test_FD001, test_FD002, test_FD003, test_FD004 + RUL Tables: RUL_FD001, RUL_FD002, RUL_FD003, RUL_FD004 + + When you query training and test tables, you'll find 26 columns with identical structure: + - unit_number: INTEGER - Identifier for each engine unit + - time_in_cycles: INTEGER - Time step in operational cycles + - operational_setting_1: REAL - First operational setting affecting performance + - operational_setting_2: REAL - Second operational setting affecting performance + - operational_setting_3: REAL - Third operational setting affecting performance + - sensor_measurement_1 through sensor_measurement_21: REAL/INTEGER - Twenty-one sensor measurements + + When you query RUL tables, you'll find 2 columns: + - unit_number: INTEGER - Engine unit identifier + - RUL: INTEGER - Remaining Useful Life value for that test unit + + QUERY PATTERNS YOU SHOULD USE: + + Table References: + - When you see "train_FD001" or "dataset train_FD001" → Use table train_FD001 + - When you see "test_FD002" or "dataset test_FD002" → Use table test_FD002 + - When you see "FD003" (without train/test prefix) → Determine from context whether to use train_FD003 or test_FD003 + - For RUL queries: Use the specific RUL table (RUL_FD001, RUL_FD002, RUL_FD003, or RUL_FD004) + + Counting Patterns You Should Follow: + - When asked "How many units" → Use COUNT(DISTINCT unit_number) to count unique engines + - When asked "How many records/data points/measurements/entries/rows" → Use COUNT(*) to count all records + + RUL Handling (CRITICAL - YOU MUST DISTINGUISH): + + 1. 
GROUND TRUTH RUL (for test data): + - Use when you see requests for "actual RUL", "true RUL", "ground truth", or "what is the RUL" + - You should query the specific RUL table: SELECT RUL FROM RUL_FD001 WHERE unit_number=N + - For time-series with ground truth: ((SELECT MAX(time_in_cycles) FROM test_FDxxx WHERE unit_number=N) + (SELECT RUL FROM RUL_FDxxx WHERE unit_number=N) - time_in_cycles) + + 2. PREDICTED/CALCULATED RUL (for training data or prediction requests): + - Use when you see requests to "predict RUL", "calculate RUL", "estimate RUL", or "find RUL" for training data + - For training data: You should calculate as remaining cycles until failure = (MAX(time_in_cycles) - current_time_in_cycles + 1) + - Your training RUL query should be: SELECT unit_number, time_in_cycles, (MAX(time_in_cycles) OVER (PARTITION BY unit_number) - time_in_cycles + 1) AS predicted_RUL FROM train_FDxxx + + DEFAULT BEHAVIOR YOU SHOULD FOLLOW: If unclear, assume the user wants PREDICTION (since this is more common) + + Column Names You'll Use (consistent across all training and test tables): + - unit_number: Engine identifier + - time_in_cycles: Time step + - operational_setting_1, operational_setting_2, operational_setting_3: Operational settings + - sensor_measurement_1, sensor_measurement_2, ..., sensor_measurement_21: Sensor readings + + IMPORTANT NOTES FOR YOUR QUERIES: + - Each dataset (FD001, FD002, FD003, FD004) has its own separate RUL table + - RUL tables do NOT have a 'dataset' column - they are dataset-specific by table name + - Training tables contain data until engine failure + - Test tables contain data that stops before failure + - RUL tables provide the actual remaining cycles for test units + + ENGINE OPERATION CONTEXT FOR YOUR UNDERSTANDING: + You are working with engine data where each engine starts with different degrees of initial wear and manufacturing variation. 
+ The engine operates normally at the start of each time series and develops a fault at some point during the series. + In the training set, the fault grows in magnitude until system failure. + In the test set, the time series ends some time prior to system failure. + Your objective is to help predict the number of remaining operational cycles before failure in the test set. + +# EXAMPLE QUERIES +# =============== +# Purpose: Teach the agent common SQL patterns and query structures for your domain +# When to use: Include 3-7 diverse examples that cover the main query patterns you expect +# How to populate: +# - Choose queries that represent different SQL concepts (JOINs, aggregations, window functions, etc.) +# - Focus on domain-specific patterns that are unique to your use case +# - Include complex queries that demonstrate proper table relationships +# - Add a description to explain what pattern each query demonstrates +# - Prioritize quality over quantity - better to have 5 great examples than 20 mediocre ones +example_queries: + - description: "JOIN pattern between training and RUL tables" + sql: "SELECT t.unit_number, t.time_in_cycles, t.operational_setting_1, r.RUL FROM train_FD001 AS t JOIN RUL_FD001 AS r ON t.unit_number = r.unit_number WHERE t.unit_number = 1 ORDER BY t.time_in_cycles" + + - description: "Aggregation with multiple statistical functions" + sql: "SELECT unit_number, AVG(sensor_measurement_1) AS avg_sensor1, MAX(sensor_measurement_2) AS max_sensor2, MIN(sensor_measurement_3) AS min_sensor3 FROM train_FD002 GROUP BY unit_number" + + - description: "Test table filtering with time-based conditions" + sql: "SELECT * FROM test_FD003 WHERE time_in_cycles > 50 AND sensor_measurement_1 > 500 ORDER BY unit_number, time_in_cycles" + + - description: "Window function for predicted RUL calculation on training data" + sql: "SELECT unit_number, time_in_cycles, (MAX(time_in_cycles) OVER (PARTITION BY unit_number) - time_in_cycles + 1) AS predicted_RUL FROM 
train_FD004 WHERE unit_number <= 3 ORDER BY unit_number, time_in_cycles" + + - description: "Direct RUL table query with filtering" + sql: "SELECT unit_number, RUL FROM RUL_FD001 WHERE RUL > 100 ORDER BY RUL DESC" + + - description: "Retrieve all sensor readings for a specific unit (for anomaly detection)" + sql: "SELECT time_in_cycles, sensor_measurement_1, sensor_measurement_2, sensor_measurement_3, sensor_measurement_4, sensor_measurement_5, sensor_measurement_6, sensor_measurement_7, sensor_measurement_8, sensor_measurement_9, sensor_measurement_10, sensor_measurement_11, sensor_measurement_12, sensor_measurement_13, sensor_measurement_14, sensor_measurement_15, sensor_measurement_16, sensor_measurement_17, sensor_measurement_18, sensor_measurement_19, sensor_measurement_20, sensor_measurement_21 FROM train_FD001 WHERE unit_number = 1 ORDER BY time_in_cycles" + +# QUESTION-SQL PAIRS +# ================== +# Purpose: Train the agent to map natural language questions to specific SQL queries +# When to use: Include 5-10 pairs that cover the most common user questions in your domain +# How to populate: +# - Use realistic questions that your users would actually ask +# - Cover edge cases and domain-specific terminology +# - Include both simple and complex question patterns +# - Focus on questions that demonstrate important business logic distinctions +# - Include variations of similar questions to improve robustness +# - Make sure questions cover different table types and query patterns +question_sql_pairs: + - question: "Get time cycles and operational setting 1 for unit 1 from test FD001" + sql: "SELECT time_in_cycles, operational_setting_1 FROM test_FD001 WHERE unit_number = 1" + + - question: "What is the actual remaining useful life for unit 1 in test dataset FD001" + sql: "SELECT RUL FROM RUL_FD001 WHERE unit_number = 1" + + - question: "Predict the remaining useful life for each time cycle of unit 1 in training dataset FD001" + sql: "SELECT unit_number, 
time_in_cycles, (MAX(time_in_cycles) OVER (PARTITION BY unit_number) - time_in_cycles + 1) AS predicted_RUL FROM train_FD001 WHERE unit_number = 1 ORDER BY time_in_cycles" + + - question: "How many units are in the training data for FD002" + sql: "SELECT COUNT(DISTINCT unit_number) FROM train_FD002" + + - question: "Calculate RUL for training data in FD003" + sql: "SELECT unit_number, time_in_cycles, (MAX(time_in_cycles) OVER (PARTITION BY unit_number) - time_in_cycles + 1) AS predicted_RUL FROM train_FD003 ORDER BY unit_number, time_in_cycles" + + - question: "Get ground truth RUL values for all units in test FD002" + sql: "SELECT unit_number, RUL FROM RUL_FD002 ORDER BY unit_number" + + - question: "Retrieve all sensor readings for unit 1 from train_FD001 dataset" + sql: "SELECT time_in_cycles, sensor_measurement_1, sensor_measurement_2, sensor_measurement_3, sensor_measurement_4, sensor_measurement_5, sensor_measurement_6, sensor_measurement_7, sensor_measurement_8, sensor_measurement_9, sensor_measurement_10, sensor_measurement_11, sensor_measurement_12, sensor_measurement_13, sensor_measurement_14, sensor_measurement_15, sensor_measurement_16, sensor_measurement_17, sensor_measurement_18, sensor_measurement_19, sensor_measurement_20, sensor_measurement_21 FROM train_FD001 WHERE unit_number = 1 ORDER BY time_in_cycles" + + - question: "Get all sensor data for unit 5 from test FD002 dataset for anomaly detection" + sql: "SELECT time_in_cycles, sensor_measurement_1, sensor_measurement_2, sensor_measurement_3, sensor_measurement_4, sensor_measurement_5, sensor_measurement_6, sensor_measurement_7, sensor_measurement_8, sensor_measurement_9, sensor_measurement_10, sensor_measurement_11, sensor_measurement_12, sensor_measurement_13, sensor_measurement_14, sensor_measurement_15, sensor_measurement_16, sensor_measurement_17, sensor_measurement_18, sensor_measurement_19, sensor_measurement_20, sensor_measurement_21 FROM test_FD002 WHERE unit_number = 5 ORDER BY 
time_in_cycles"