From a07c77ab877a3cf39a6945dde1764ba5855585a6 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 31 Dec 2025 11:36:00 +0000 Subject: [PATCH 1/2] Update LinkedIn and GitHub username to ajaymauryabbn --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e90e0c8..0c86f98 100644 --- a/README.md +++ b/README.md @@ -248,8 +248,8 @@ MIT License - see [LICENSE](LICENSE) for details. ## Author -**Ajay Maurya** - AI Engineer -[LinkedIn](https://linkedin.com/in/ajaymaurya) | [GitHub](https://github.com/ajaymaurya) +**Ajay Maurya** - AI Engineer +[LinkedIn](https://linkedin.com/in/ajaymauryabbn) | [GitHub](https://github.com/ajaymauryabbn) --- From 32003f2e7cd12b35b734e48ceee42039e020ba7c Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 31 Dec 2025 11:48:37 +0000 Subject: [PATCH 2/2] Fix all linting errors to pass CI checks - Auto-fixed 544 linting errors with ruff (whitespace, imports, etc.) - Updated deprecated ruff configuration in pyproject.toml - Fixed ambiguous variable name in cli.py - Removed unused imports from sql_parser.py - Replaced bare except clauses with proper exception handling - All tests passing (26/26) --- examples/basic_usage.py | 26 +-- pyproject.toml | 2 + sql_eval/__init__.py | 4 +- sql_eval/cli.py | 107 +++++------ sql_eval/connectors/__init__.py | 16 +- sql_eval/connectors/base.py | 30 +-- sql_eval/connectors/mysql.py | 55 +++--- sql_eval/connectors/postgresql.py | 63 +++--- sql_eval/connectors/sqlite.py | 59 +++--- sql_eval/core/__init__.py | 16 +- sql_eval/core/evaluator.py | 129 +++++++------ sql_eval/core/models.py | 46 ++--- sql_eval/core/schema_loader.py | 85 +++++---- sql_eval/core/sql_parser.py | 191 +++++++++---------- sql_eval/datasets/__init__.py | 29 ++- sql_eval/llm_providers/__init__.py | 16 +- sql_eval/llm_providers/anthropic_provider.py | 35 ++-- sql_eval/llm_providers/base.py | 37 ++-- sql_eval/llm_providers/ollama_provider.py | 78 ++++---- sql_eval/llm_providers/openai_provider.py | 37 ++-- tests/test_core.py | 69 +++---- 21 files changed, 568 insertions(+), 562 deletions(-) diff --git a/examples/basic_usage.py b/examples/basic_usage.py index 72b7ad5..861d3fd 100644 --- a/examples/basic_usage.py +++ b/examples/basic_usage.py @@ -16,7 +16,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from sql_eval import Evaluator -from sql_eval.datasets import load_ecommerce, get_dataset_info +from sql_eval.datasets import get_dataset_info, load_ecommerce from sql_eval.llm_providers import get_provider @@ -29,7 +29,7 @@ def main(): provider_name = 'ollama' else: provider_name = 'openai' - + # Show dataset info print("\n" + "=" * 50) print("Dataset Information") @@ -38,21 +38,21 @@ def main(): print(f"Name: {info['name']}") print(f"Questions: {info['num_questions']}") print(f"Tables: {', '.join(info['tables'])}") - print(f"\nDifficulty breakdown:") + print("\nDifficulty breakdown:") for diff, count in info['difficulty_breakdown'].items(): print(f" {diff}: {count}") - + # Load dataset print("\n" + "=" * 50) print("Loading Dataset") print("=" * 50) test_cases, schema, db = load_ecommerce(with_db=True) print(f"Loaded {len(test_cases)} test cases") - + # Limit to first 5 for quick demo test_cases = test_cases[:5] print(f"Running evaluation on {len(test_cases)} questions (limited for demo)") - + # Initialize provider print(f"\nUsing LLM provider: {provider_name}") try: @@ -60,7 +60,7 @@ def main(): except Exception as e: print(f"Error initializing provider: {e}") return - + # Create evaluator evaluator = Evaluator( llm_provider=provider, @@ -68,12 +68,12 @@ def main(): db_connector=db, verbose=True ) - + # Run evaluation print("\n" + "=" * 50) print("Running Evaluation") print("=" * 50) - + try: report = evaluator.evaluate( test_cases, @@ -83,12 +83,12 @@ def main(): print(f"\nError during evaluation: {e}") print("\nTip: Make sure your LLM provider is properly configured") return - + # Show detailed results print("\n" + "=" * 50) print("Detailed Results") print("=" * 50) - + for result in report.results: status = "✓" if result.exact_match else "✗" print(f"\n{status} {result.question_id}: {result.question[:50]}...") @@ -96,11 +96,11 @@ def main(): print(f" Got: {result.generated_sql[:60]}...") if not result.exact_match and result.partial_scores: print(f" Structural score: {result.partial_scores.overall_score:.1%}") - + # Cleanup if db: db.disconnect() - + print("\n" + "=" * 50) print("Done!") print("=" * 50) diff --git a/pyproject.toml b/pyproject.toml index 3becfc5..97a78ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,8 @@ target-version = ['py310', 'py311', 'py312'] [tool.ruff] line-length = 100 + +[tool.ruff.lint] select = ["E", "F", "W", "I"] ignore = ["E501"] diff --git a/sql_eval/__init__.py b/sql_eval/__init__.py index 0a34f3a..4c572ff 100644 --- a/sql_eval/__init__.py +++ b/sql_eval/__init__.py @@ -8,11 +8,11 @@ __author__ = "Ajay Maurya" from sql_eval.core.evaluator import Evaluator -from sql_eval.core.models import EvaluationCase, EvaluationResult, EvaluationReport +from sql_eval.core.models import EvaluationCase, EvaluationReport, EvaluationResult __all__ = [ "Evaluator", - "EvaluationCase", + "EvaluationCase", "EvaluationResult", "EvaluationReport", ] diff --git a/sql_eval/cli.py b/sql_eval/cli.py index 0a795bf..03a6efc 100644 --- a/sql_eval/cli.py +++ b/sql_eval/cli.py @@ -3,11 +3,10 @@ sql-eval Command Line Interface """ -import sys -import json import argparse +import json +import sys from datetime import datetime -from pathlib import Path def main(): @@ -15,9 +14,9 @@ def main(): prog='sql-eval', description='Text-to-SQL Evaluation Framework' ) - + subparsers = parser.add_subparsers(dest='command', help='Commands') - + # Run command run_parser = subparsers.add_parser('run', help='Run evaluation') run_parser.add_argument( @@ -62,14 +61,14 @@ def main(): action='store_true', help='Suppress progress output' ) - + # List datasets command - list_parser = subparsers.add_parser('list', help='List available datasets') - + subparsers.add_parser('list', help='List available datasets') + # Dataset info command info_parser = subparsers.add_parser('info', help='Show dataset information') info_parser.add_argument('dataset', help='Dataset name') - + # Compare command compare_parser = subparsers.add_parser('compare', help='Compare multiple LLMs') compare_parser.add_argument( @@ -88,13 +87,13 @@ def main(): default=None, help='Limit number of questions' ) - + args = parser.parse_args() - + if args.command is None: parser.print_help() return - + if args.command == 'list': cmd_list() elif args.command == 'info': @@ -107,12 +106,12 @@ def main(): def cmd_list(): """List available datasets""" - from .datasets import list_datasets, get_dataset_info - + from .datasets import get_dataset_info, list_datasets + datasets = list_datasets() print("\nAvailable datasets:") print("-" * 40) - + for name in datasets: info = get_dataset_info(name) print(f" {name}") @@ -125,23 +124,23 @@ def cmd_list(): def cmd_info(dataset_name: str): """Show dataset information""" from .datasets import get_dataset_info - + try: info = get_dataset_info(dataset_name) except ValueError as e: print(f"Error: {e}") sys.exit(1) - + print(f"\nDataset: {info['name']}") print("=" * 40) print(f"Questions: {info['num_questions']}") print(f"Tables: {', '.join(info['tables'])}") print(f"Has seed data: {'Yes' if info['has_seed_data'] else 'No'}") - + print("\nDifficulty breakdown:") for diff, count in info['difficulty_breakdown'].items(): print(f" {diff}: {count}") - + print("\nCategory breakdown:") for cat, count in sorted(info['category_breakdown'].items(), key=lambda x: -x[1]): print(f" {cat}: {count}") @@ -149,14 +148,14 @@ def cmd_info(dataset_name: str): def cmd_run(args): """Run evaluation""" + from .core.evaluator import Evaluator from .datasets import load_dataset from .llm_providers import get_provider - from .core.evaluator import Evaluator - + print(f"\n{'='*50}") print("sql-eval v0.1.0") print(f"{'='*50}") - + # Load dataset print(f"\nLoading dataset: {args.dataset}") try: @@ -167,11 +166,11 @@ def cmd_run(args): except ValueError as e: print(f"Error: {e}") sys.exit(1) - + # Apply limit if args.limit: test_cases = test_cases[:args.limit] - + # Get LLM provider print(f"LLM: {args.llm}" + (f" ({args.model})" if args.model else "")) try: @@ -179,7 +178,7 @@ def cmd_run(args): except Exception as e: print(f"Error initializing LLM provider: {e}") sys.exit(1) - + # Run evaluation evaluator = Evaluator( llm_provider=provider, @@ -187,21 +186,21 @@ def cmd_run(args): db_connector=db, verbose=not args.quiet ) - + print(f"\nEvaluating {len(test_cases)} questions...") - + try: report = evaluator.evaluate(test_cases, run_execution_tests=args.with_execution) except Exception as e: print(f"\nError during evaluation: {e}") sys.exit(1) - + # Output results if args.output == 'json': output_json(report, args.output_file) elif args.output == 'html': output_html(report, args.output_file) - + # Cleanup if db: db.disconnect() @@ -209,29 +208,29 @@ def cmd_run(args): def cmd_compare(args): """Compare multiple LLMs""" + from .core.evaluator import Evaluator from .datasets import load_dataset from .llm_providers import get_provider - from .core.evaluator import Evaluator - - llm_names = [l.strip() for l in args.llms.split(',')] - + + llm_names = [llm_name.strip() for llm_name in args.llms.split(',')] + print(f"\n{'='*50}") print("sql-eval - LLM Comparison") print(f"{'='*50}") print(f"Dataset: {args.dataset}") print(f"LLMs: {', '.join(llm_names)}") - + # Load dataset once test_cases, schema, db = load_dataset(args.dataset, with_db=False) - + if args.limit: test_cases = test_cases[:args.limit] - + results = {} - + for llm_name in llm_names: print(f"\n--- Evaluating: {llm_name} ---") - + try: provider = get_provider(llm_name) evaluator = Evaluator( @@ -244,14 +243,14 @@ def cmd_compare(args): except Exception as e: print(f"Error with {llm_name}: {e}") results[llm_name] = None - + # Print comparison print(f"\n{'='*60}") print("COMPARISON RESULTS") print(f"{'='*60}") print(f"{'LLM':<20} {'Exact Match':<15} {'Structural':<15} {'Latency':<10}") print("-" * 60) - + for llm_name, report in results.items(): if report: print( @@ -267,7 +266,7 @@ def cmd_compare(args): def output_json(report, filepath=None): """Output report as JSON""" data = report.to_dict() - + if filepath: with open(filepath, 'w') as f: json.dump(data, f, indent=2) @@ -283,12 +282,12 @@ def output_html(report, filepath=None): """Output report as HTML""" if not filepath: filepath = f"sql_eval_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.html" - + html = generate_html_report(report) - + with open(filepath, 'w') as f: f.write(html) - + print(f"\nHTML report saved to: {filepath}") @@ -323,7 +322,7 @@ def generate_html_report(report) -> str:

sql-eval Evaluation Report

Generated: {report.timestamp}

LLM: {report.llm_provider} ({report.llm_model})

- +

Summary

@@ -343,12 +342,12 @@ def generate_html_report(report) -> str:
Total Questions
- +

Results by Difficulty

""" - + for diff, stats in report.accuracy_by_difficulty.items(): html += f""" @@ -357,14 +356,14 @@ def generate_html_report(report) -> str: """ - + html += """
DifficultyAccuracyCorrectTotal
{diff}{stats['total']}
- +

Results by Category

""" - + for cat, stats in sorted(report.accuracy_by_category.items(), key=lambda x: -x[1]['accuracy']): html += f""" @@ -373,14 +372,14 @@ def generate_html_report(report) -> str: """ - + html += """
CategoryAccuracyCorrectTotal
{cat}{stats['total']}
- +

Individual Results

""" - + for r in report.results: status = '✓ Pass' if r.exact_match else '✗ Fail' question_short = r.question[:60] + '...' if len(r.question) > 60 else r.question @@ -391,7 +390,7 @@ def generate_html_report(report) -> str: """ - + html += """
IDQuestionResultLatency
{r.latency_ms:.0f}ms
diff --git a/sql_eval/connectors/__init__.py b/sql_eval/connectors/__init__.py index 25c0a05..adbacec 100644 --- a/sql_eval/connectors/__init__.py +++ b/sql_eval/connectors/__init__.py @@ -3,9 +3,9 @@ """ from .base import BaseConnector -from .sqlite import SQLiteConnector -from .postgresql import PostgreSQLConnector from .mysql import MySQLConnector +from .postgresql import PostgreSQLConnector +from .sqlite import SQLiteConnector def get_connector( @@ -15,12 +15,12 @@ def get_connector( ) -> BaseConnector: """ Factory function to get database connector by type - + Args: db_type: One of 'sqlite', 'postgresql', 'mysql' connection_string: Database connection string **kwargs: Additional connector-specific arguments - + Returns: Configured database connector instance """ @@ -31,18 +31,18 @@ def get_connector( 'pg': PostgreSQLConnector, 'mysql': MySQLConnector, } - + db_type = db_type.lower() - + if db_type not in connectors: available = ', '.join(set(connectors.keys())) raise ValueError( f"Unknown database type: {db_type}. " f"Available types: {available}" ) - + connector_class = connectors[db_type] - + if connection_string: return connector_class(connection_string=connection_string, **kwargs) return connector_class(**kwargs) diff --git a/sql_eval/connectors/base.py b/sql_eval/connectors/base.py index 9c249e2..d649ba3 100644 --- a/sql_eval/connectors/base.py +++ b/sql_eval/connectors/base.py @@ -3,57 +3,57 @@ """ from abc import ABC, abstractmethod -from typing import Optional, Any +from typing import Optional class BaseConnector(ABC): """Abstract base class for database connectors""" - + def __init__(self, connection_string: str = None, **kwargs): self.connection_string = connection_string self.config = kwargs self._connection = None - + @property @abstractmethod def db_type(self) -> str: """Return database type (e.g., 'postgresql', 'mysql')""" pass - + @abstractmethod def connect(self) -> None: """Establish database connection""" pass - + @abstractmethod def disconnect(self) -> None: """Close database connection""" pass - + @abstractmethod def execute(self, sql: str, params: Optional[tuple] = None) -> list[dict]: """ Execute SQL query and return results - + Args: sql: SQL query to execute params: Optional query parameters - + Returns: List of dictionaries (each dict is a row) """ pass - + @abstractmethod def get_schema(self) -> dict: """ Extract schema from database - + Returns: Dictionary with tables and columns info """ pass - + def execute_safe( self, sql: str, @@ -61,7 +61,7 @@ def execute_safe( ) -> tuple[bool, Optional[list[dict]], Optional[str]]: """ Execute SQL with timeout and error handling - + Returns: Tuple of (success, results, error_message) """ @@ -70,13 +70,13 @@ def execute_safe( return True, results, None except Exception as e: return False, None, str(e) - + def __enter__(self): self.connect() return self - + def __exit__(self, exc_type, exc_val, exc_tb): self.disconnect() - + def __repr__(self) -> str: return f"{self.__class__.__name__}()" diff --git a/sql_eval/connectors/mysql.py b/sql_eval/connectors/mysql.py index af7ff2a..997f992 100644 --- a/sql_eval/connectors/mysql.py +++ b/sql_eval/connectors/mysql.py @@ -4,19 +4,20 @@ import os from typing import Optional + from .base import BaseConnector class MySQLConnector(BaseConnector): """ MySQL database connector - + Usage: # Using connection string conn = MySQLConnector( connection_string="mysql://user:pass@localhost:3306/mydb" ) - + # Using individual parameters conn = MySQLConnector( host="localhost", @@ -26,7 +27,7 @@ class MySQLConnector(BaseConnector): password="pass" ) """ - + def __init__( self, connection_string: str = None, @@ -38,19 +39,19 @@ def __init__( **kwargs ): super().__init__(connection_string=connection_string, **kwargs) - + self.host = host or os.environ.get("MYSQL_HOST", "localhost") self.port = port or int(os.environ.get("MYSQL_PORT", 3306)) self.database = database or os.environ.get("MYSQL_DATABASE") self.user = user or os.environ.get("MYSQL_USER") self.password = password or os.environ.get("MYSQL_PASSWORD") - + self._connection = None - + @property def db_type(self) -> str: return "mysql" - + def connect(self) -> None: """Connect to MySQL database""" try: @@ -60,7 +61,7 @@ def connect(self) -> None: "mysql-connector-python not installed. " "Install with: pip install mysql-connector-python" ) - + self._connection = mysql.connector.connect( host=self.host, port=self.port, @@ -68,26 +69,26 @@ def connect(self) -> None: user=self.user, password=self.password ) - + def disconnect(self) -> None: """Close MySQL connection""" if self._connection: self._connection.close() self._connection = None - + def execute(self, sql: str, params: Optional[tuple] = None) -> list[dict]: """Execute SQL and return results as list of dicts""" if not self._connection: self.connect() - + cursor = self._connection.cursor(dictionary=True) - + try: if params: cursor.execute(sql, params) else: cursor.execute(sql) - + # Check if query returns results if cursor.description: rows = cursor.fetchall() @@ -100,30 +101,30 @@ def execute(self, sql: str, params: Optional[tuple] = None) -> list[dict]: raise e finally: cursor.close() - + def get_schema(self) -> dict: """Extract schema from MySQL database""" if not self._connection: self.connect() - + schema = {'tables': {}} - + # Get all tables tables = self.execute(""" - SELECT table_name - FROM information_schema.tables + SELECT table_name + FROM information_schema.tables WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE' """) - + for table in tables: table_name = table['TABLE_NAME'] if 'TABLE_NAME' in table else table['table_name'] columns = {} foreign_keys = [] - + # Get column info col_info = self.execute(f""" - SELECT + SELECT COLUMN_NAME as column_name, DATA_TYPE as data_type, IS_NULLABLE as is_nullable, @@ -134,7 +135,7 @@ def get_schema(self) -> dict: AND table_name = '{table_name}' ORDER BY ordinal_position """) - + for col in col_info: col_name = col.get('column_name') or col.get('COLUMN_NAME') columns[col_name] = { @@ -143,10 +144,10 @@ def get_schema(self) -> dict: 'primary_key': (col.get('column_key') or col.get('COLUMN_KEY')) == 'PRI', 'default': col.get('column_default') or col.get('COLUMN_DEFAULT') } - + # Get foreign keys fk_info = self.execute(f""" - SELECT + SELECT COLUMN_NAME as column_name, REFERENCED_TABLE_NAME as ref_table, REFERENCED_COLUMN_NAME as ref_column @@ -155,7 +156,7 @@ def get_schema(self) -> dict: AND table_name = '{table_name}' AND REFERENCED_TABLE_NAME IS NOT NULL """) - + for fk in fk_info: col_name = fk.get('column_name') or fk.get('COLUMN_NAME') ref_table = fk.get('ref_table') or fk.get('REFERENCED_TABLE_NAME') @@ -164,10 +165,10 @@ def get_schema(self) -> dict: 'column': col_name, 'references': f"{ref_table}.{ref_column}" }) - + schema['tables'][table_name] = { 'columns': columns, 'foreign_keys': foreign_keys } - + return schema diff --git a/sql_eval/connectors/postgresql.py b/sql_eval/connectors/postgresql.py index be013e7..9fd184e 100644 --- a/sql_eval/connectors/postgresql.py +++ b/sql_eval/connectors/postgresql.py @@ -4,19 +4,20 @@ import os from typing import Optional + from .base import BaseConnector class PostgreSQLConnector(BaseConnector): """ PostgreSQL database connector - + Usage: # Using connection string conn = PostgreSQLConnector( connection_string="postgresql://user:pass@localhost:5432/mydb" ) - + # Using individual parameters conn = PostgreSQLConnector( host="localhost", @@ -26,7 +27,7 @@ class PostgreSQLConnector(BaseConnector): password="pass" ) """ - + def __init__( self, connection_string: str = None, @@ -38,20 +39,20 @@ def __init__( **kwargs ): super().__init__(connection_string=connection_string, **kwargs) - + # Allow individual params or connection string self.host = host or os.environ.get("PGHOST", "localhost") self.port = port or int(os.environ.get("PGPORT", 5432)) self.database = database or os.environ.get("PGDATABASE") self.user = user or os.environ.get("PGUSER") self.password = password or os.environ.get("PGPASSWORD") - + self._connection = None - + @property def db_type(self) -> str: return "postgresql" - + def connect(self) -> None: """Connect to PostgreSQL database""" try: @@ -62,7 +63,7 @@ def connect(self) -> None: "psycopg2 not installed. " "Install with: pip install psycopg2-binary" ) - + if self.connection_string: self._connection = psycopg2.connect(self.connection_string) else: @@ -73,30 +74,30 @@ def connect(self) -> None: user=self.user, password=self.password ) - + def disconnect(self) -> None: """Close PostgreSQL connection""" if self._connection: self._connection.close() self._connection = None - + def execute(self, sql: str, params: Optional[tuple] = None) -> list[dict]: """Execute SQL and return results as list of dicts""" if not self._connection: self.connect() - + import psycopg2.extras - + cursor = self._connection.cursor( cursor_factory=psycopg2.extras.RealDictCursor ) - + try: if params: cursor.execute(sql, params) else: cursor.execute(sql) - + # Check if query returns results if cursor.description: rows = cursor.fetchall() @@ -109,47 +110,47 @@ def execute(self, sql: str, params: Optional[tuple] = None) -> list[dict]: raise e finally: cursor.close() - + def get_schema(self) -> dict: """Extract schema from PostgreSQL database""" if not self._connection: self.connect() - + schema = {'tables': {}} - + # Get all tables tables = self.execute(""" - SELECT table_name - FROM information_schema.tables - WHERE table_schema = 'public' + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'public' AND table_type = 'BASE TABLE' """) - + for table in tables: table_name = table['table_name'] columns = {} foreign_keys = [] - + # Get column info col_info = self.execute(""" - SELECT + SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns - WHERE table_schema = 'public' + WHERE table_schema = 'public' AND table_name = %s ORDER BY ordinal_position """, (table_name,)) - + for col in col_info: columns[col['column_name']] = { 'type': col['data_type'].upper(), 'nullable': col['is_nullable'] == 'YES', 'default': col['column_default'] } - + # Get primary keys pk_info = self.execute(""" SELECT kcu.column_name @@ -160,11 +161,11 @@ def get_schema(self) -> dict: AND tc.table_name = %s AND tc.constraint_type = 'PRIMARY KEY' """, (table_name,)) - + for pk in pk_info: if pk['column_name'] in columns: columns[pk['column_name']]['primary_key'] = True - + # Get foreign keys fk_info = self.execute(""" SELECT @@ -180,16 +181,16 @@ def get_schema(self) -> dict: AND tc.table_name = %s AND tc.constraint_type = 'FOREIGN KEY' """, (table_name,)) - + for fk in fk_info: foreign_keys.append({ 'column': fk['column_name'], 'references': f"{fk['foreign_table']}.{fk['foreign_column']}" }) - + schema['tables'][table_name] = { 'columns': columns, 'foreign_keys': foreign_keys } - + return schema diff --git a/sql_eval/connectors/sqlite.py b/sql_eval/connectors/sqlite.py index 20f2791..15aaf83 100644 --- a/sql_eval/connectors/sqlite.py +++ b/sql_eval/connectors/sqlite.py @@ -5,28 +5,29 @@ """ import sqlite3 -from typing import Optional from pathlib import Path +from typing import Optional + from .base import BaseConnector class SQLiteConnector(BaseConnector): """ SQLite database connector - + Great for: - Local testing without setup - Bundled sample databases - Fast evaluation runs - + Usage: # In-memory database conn = SQLiteConnector(":memory:") - + # File-based database conn = SQLiteConnector("path/to/database.db") """ - + def __init__( self, database: str = ":memory:", @@ -37,11 +38,11 @@ def __init__( self.database = database self.timeout = timeout self._connection: Optional[sqlite3.Connection] = None - + @property def db_type(self) -> str: return "sqlite" - + def connect(self) -> None: """Connect to SQLite database""" self._connection = sqlite3.connect( @@ -52,26 +53,26 @@ def connect(self) -> None: self._connection.execute("PRAGMA foreign_keys = ON") # Return rows as dictionaries self._connection.row_factory = sqlite3.Row - + def disconnect(self) -> None: """Close SQLite connection""" if self._connection: self._connection.close() self._connection = None - + def execute(self, sql: str, params: Optional[tuple] = None) -> list[dict]: """Execute SQL and return results as list of dicts""" if not self._connection: self.connect() - + cursor = self._connection.cursor() - + try: if params: cursor.execute(sql, params) else: cursor.execute(sql) - + # Check if query returns results if cursor.description: rows = cursor.fetchall() @@ -82,32 +83,32 @@ def execute(self, sql: str, params: Optional[tuple] = None) -> list[dict]: return [] finally: cursor.close() - + def execute_script(self, sql_script: str) -> None: """Execute multiple SQL statements""" if not self._connection: self.connect() - + self._connection.executescript(sql_script) self._connection.commit() - + def get_schema(self) -> dict: """Extract schema from SQLite database""" if not self._connection: self.connect() - + schema = {'tables': {}} - + # Get all tables tables = self.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'" ) - + for table in tables: table_name = table['name'] columns = {} foreign_keys = [] - + # Get column info col_info = self.execute(f"PRAGMA table_info('{table_name}')") for col in col_info: @@ -117,7 +118,7 @@ def get_schema(self) -> dict: 'primary_key': bool(col['pk']), 'default': col['dflt_value'] } - + # Get foreign keys fk_info = self.execute(f"PRAGMA foreign_key_list('{table_name}')") for fk in fk_info: @@ -125,26 +126,26 @@ def get_schema(self) -> dict: 'column': fk['from'], 'references': f"{fk['table']}.{fk['to']}" }) - + schema['tables'][table_name] = { 'columns': columns, 'foreign_keys': foreign_keys } - + return schema - + def load_schema_file(self, filepath: str) -> None: """Load schema from SQL file""" with open(filepath, 'r') as f: schema_sql = f.read() self.execute_script(schema_sql) - + def load_seed_data(self, filepath: str) -> None: """Load seed data from SQL file""" with open(filepath, 'r') as f: seed_sql = f.read() self.execute_script(seed_sql) - + @classmethod def from_files( cls, @@ -154,20 +155,20 @@ def from_files( ) -> "SQLiteConnector": """ Create connector and load schema/seed data - + Args: schema_file: Path to SQL file with CREATE TABLE statements seed_file: Optional path to SQL file with INSERT statements database: Database path (default: in-memory) - + Returns: Configured SQLiteConnector """ connector = cls(database=database) connector.connect() connector.load_schema_file(schema_file) - + if seed_file and Path(seed_file).exists(): connector.load_seed_data(seed_file) - + return connector diff --git a/sql_eval/core/__init__.py b/sql_eval/core/__init__.py index dfb5159..03e9ac3 100644 --- a/sql_eval/core/__init__.py +++ b/sql_eval/core/__init__.py @@ -2,25 +2,25 @@ Core evaluation components """ +from .evaluator import Evaluator from .models import ( - QueryStatus, - Difficulty, Category, - TableSchema, DatabaseSchema, + Difficulty, EvaluationCase, - EvaluationResult, EvaluationReport, + EvaluationResult, + FailurePattern, PartialScores, - FailurePattern + QueryStatus, + TableSchema, ) from .schema_loader import SchemaLoader, SchemaValidator -from .sql_parser import SQLParser, SQLComparator -from .evaluator import Evaluator +from .sql_parser import SQLComparator, SQLParser __all__ = [ "QueryStatus", - "Difficulty", + "Difficulty", "Category", "TableSchema", "DatabaseSchema", diff --git a/sql_eval/core/evaluator.py b/sql_eval/core/evaluator.py index 03544dd..91aba96 100644 --- a/sql_eval/core/evaluator.py +++ b/sql_eval/core/evaluator.py @@ -3,38 +3,37 @@ """ import time +from collections import defaultdict from datetime import datetime from typing import Optional -from collections import defaultdict +from ..connectors.base import BaseConnector +from ..llm_providers.base import BaseLLMProvider from .models import ( DatabaseSchema, EvaluationCase, - EvaluationResult, EvaluationReport, - PartialScores, + EvaluationResult, FailurePattern, - QueryStatus + QueryStatus, ) -from .sql_parser import SQLParser, SQLComparator -from ..llm_providers.base import BaseLLMProvider -from ..connectors.base import BaseConnector +from .sql_parser import SQLComparator, SQLParser class Evaluator: """ Main evaluation engine for Text-to-SQL systems - + Usage: from sql_eval import Evaluator from sql_eval.llm_providers import OpenAIProvider - + provider = OpenAIProvider() evaluator = Evaluator(llm_provider=provider, schema=schema) report = evaluator.evaluate(test_cases) print(report.get_summary()) """ - + def __init__( self, llm_provider: BaseLLMProvider, @@ -50,7 +49,7 @@ def __init__( self.verbose = verbose self.comparator = SQLComparator() self.parser = SQLParser() - + def evaluate( self, test_cases: list[EvaluationCase], @@ -60,13 +59,13 @@ def evaluate( """Run full evaluation on test cases""" if not self.schema: raise ValueError("Schema is required for evaluation") - + examples = few_shot_examples or self.few_shot_examples results = [] - + if self.verbose: self._print_header(len(test_cases)) - + for i, case in enumerate(test_cases): result = self._evaluate_single( case, @@ -74,17 +73,17 @@ def evaluate( examples=examples ) results.append(result) - + if self.verbose: self._print_progress(i + 1, len(test_cases), result) - + report = self._generate_report(results) - + if self.verbose: print("\n" + report.get_summary()) - + return report - + def evaluate_single( self, question: str, @@ -103,7 +102,7 @@ def evaluate_single( run_execution=run_execution and self.db is not None, examples=examples or self.few_shot_examples ) - + def _evaluate_single( self, case: EvaluationCase, @@ -115,7 +114,7 @@ def _evaluate_single( generated_sql = "" status = QueryStatus.SUCCESS error_msg = None - + try: generated_sql = self.llm.generate_sql( question=case.natural_language_question, @@ -125,27 +124,27 @@ def _evaluate_single( except Exception as e: status = QueryStatus.GENERATION_ERROR error_msg = str(e) - + latency_ms = (time.time() - start_time) * 1000 - + exact_match = False if status == QueryStatus.SUCCESS: exact_match = self.comparator.exact_match( generated_sql, case.ground_truth_sql ) - + partial_scores = None if status == QueryStatus.SUCCESS: partial_scores = self.comparator.structural_match( generated_sql, case.ground_truth_sql ) - + execution_match = False gt_result = None gen_result = None - + if run_execution and status == QueryStatus.SUCCESS: execution_match, gt_result, gen_result, exec_error = self._compare_execution( generated_sql, @@ -154,7 +153,7 @@ def _evaluate_single( if exec_error and not error_msg: error_msg = exec_error status = QueryStatus.EXECUTION_ERROR - + return EvaluationResult( question_id=case.question_id, question=case.natural_language_question, @@ -171,7 +170,7 @@ def _evaluate_single( difficulty=case.difficulty, category=case.category ) - + def _compare_execution( self, generated_sql: str, @@ -185,7 +184,7 @@ def _compare_execution( return match, gt_result, gen_result, None except Exception as e: return False, None, None, str(e) - + def _results_equal( self, result1: list[dict], @@ -194,38 +193,38 @@ def _results_equal( """Compare query results (order-insensitive by default)""" if result1 is None or result2 is None: return result1 == result2 - + if len(result1) != len(result2): return False - + if not result1: return True - + def normalize_row(row): return tuple(sorted((str(k).lower(), str(v)) for k, v in row.items())) - + set1 = set(normalize_row(r) for r in result1) set2 = set(normalize_row(r) for r in result2) - + return set1 == set2 - + def _generate_report(self, results: list[EvaluationResult]) -> EvaluationReport: """Generate summary report from results""" total = len(results) - + exact_matches = sum(1 for r in results if r.exact_match) execution_matches = sum(1 for r in results if r.execution_match) structural_matches = sum( - 1 for r in results + 1 for r in results if r.partial_scores and r.partial_scores.overall_score >= 0.8 ) - + avg_latency = sum(r.latency_ms for r in results) / total if total else 0 - + accuracy_by_difficulty = self._group_accuracy(results, 'difficulty') accuracy_by_category = self._group_accuracy(results, 'category') failure_patterns = self._analyze_failures(results) - + return EvaluationReport( total_questions=total, exact_match_accuracy=exact_matches / total if total else 0, @@ -240,7 +239,7 @@ def _generate_report(self, results: list[EvaluationResult]) -> EvaluationReport: llm_model=self.llm.model, timestamp=datetime.now().isoformat() ) - + def _group_accuracy( self, results: list[EvaluationResult], @@ -248,55 +247,55 @@ def _group_accuracy( ) -> dict[str, dict]: """Group accuracy by difficulty or category""" groups = defaultdict(lambda: {'correct': 0, 'total': 0}) - + for r in results: key = getattr(r, group_by, 'unknown') groups[key]['total'] += 1 if r.exact_match: groups[key]['correct'] += 1 - + for key in groups: total = groups[key]['total'] correct = groups[key]['correct'] groups[key]['accuracy'] = correct / total if total else 0 - + return dict(groups) - + def _analyze_failures(self, results: list[EvaluationResult]) -> list[FailurePattern]: """Analyze common failure patterns""" patterns = defaultdict(lambda: {'count': 0, 'examples': []}) - + for r in results: if r.exact_match or r.status != QueryStatus.SUCCESS: continue - + if not r.partial_scores: continue - + if not r.partial_scores.tables_match: patterns['wrong_tables']['count'] += 1 patterns['wrong_tables']['examples'].append(r.question_id) - + if not r.partial_scores.columns_match: patterns['wrong_columns']['count'] += 1 patterns['wrong_columns']['examples'].append(r.question_id) - + if not r.partial_scores.joins_match: patterns['wrong_joins']['count'] += 1 patterns['wrong_joins']['examples'].append(r.question_id) - + if not r.partial_scores.where_match: patterns['wrong_conditions']['count'] += 1 patterns['wrong_conditions']['examples'].append(r.question_id) - + if not r.partial_scores.groupby_match: patterns['wrong_groupby']['count'] += 1 patterns['wrong_groupby']['examples'].append(r.question_id) - + if not r.partial_scores.aggregations_match: patterns['wrong_aggregations']['count'] += 1 patterns['wrong_aggregations']['examples'].append(r.question_id) - + pattern_descriptions = { 'wrong_tables': 'Incorrect table selection', 'wrong_columns': 'Missing or incorrect columns in SELECT', @@ -305,7 +304,7 @@ def _analyze_failures(self, results: list[EvaluationResult]) -> list[FailurePatt 'wrong_groupby': 'Missing or incorrect GROUP BY', 'wrong_aggregations': 'Wrong aggregation functions' } - + failure_patterns = [] for name, data in sorted(patterns.items(), key=lambda x: -x[1]['count']): if data['count'] > 0: @@ -315,19 +314,19 @@ def _analyze_failures(self, results: list[EvaluationResult]) -> list[FailurePatt count=data['count'], example_question_ids=data['examples'][:5] )) - + return failure_patterns - + def _print_header(self, total: int): """Print evaluation header""" print("\n" + "=" * 50) - print(f"sql-eval v0.1.0") + print("sql-eval v0.1.0") print("=" * 50) print(f"LLM: {self.llm.provider_name} ({self.llm.model})") print(f"Questions: {total}") print(f"Execution tests: {'Yes' if self.db else 'No'}") print("-" * 50) - + def _print_progress(self, current: int, total: int, result: EvaluationResult): """Print progress indicator""" status_icon = "✓" if result.exact_match else "✗" @@ -339,7 +338,7 @@ def _print_progress(self, current: int, total: int, result: EvaluationResult): class DatasetLoader: """Load test cases from various formats""" - + @staticmethod def from_csv( filepath: str, @@ -351,7 +350,7 @@ def from_csv( ) -> list[EvaluationCase]: """Load test cases from CSV file""" import csv - + cases = [] with open(filepath, 'r', encoding='utf-8') as f: reader = csv.DictReader(f) @@ -364,15 +363,15 @@ def from_csv( category=row.get(category_col, "complex") )) return cases - + @staticmethod def from_json(filepath: str) -> list[EvaluationCase]: """Load test cases from JSON file""" import json - + with open(filepath, 'r', encoding='utf-8') as f: data = json.load(f) - + cases = [] for item in data: cases.append(EvaluationCase( @@ -384,7 +383,7 @@ def from_json(filepath: str) -> list[EvaluationCase]: tags=item.get('tags', []) )) return cases - + @staticmethod def from_dict_list(data: list[dict]) -> list[EvaluationCase]: """Load test cases from list of dictionaries""" diff --git a/sql_eval/core/models.py b/sql_eval/core/models.py index c0eb13d..07ab000 100644 --- a/sql_eval/core/models.py +++ b/sql_eval/core/models.py @@ -3,8 +3,8 @@ """ from dataclasses import dataclass, field -from typing import Optional, Any from enum import Enum +from typing import Optional class QueryStatus(Enum): @@ -52,7 +52,7 @@ class DatabaseSchema: """Complete database schema""" tables: dict[str, TableSchema] relationships: list[dict] = field(default_factory=list) - + def to_prompt_string(self) -> str: """Convert schema to string format for LLM prompts""" lines = [] @@ -68,15 +68,15 @@ def to_prompt_string(self) -> str: if col_info.get('nullable') is False: col_str += " NOT NULL" lines.append(col_str) - + if table.foreign_keys: lines.append(" Foreign Keys:") for fk in table.foreign_keys: lines.append(f" - {fk['column']} -> {fk['references']}") lines.append("") - + return "\n".join(lines) - + def to_ddl_string(self) -> str: """Convert schema to CREATE TABLE statements""" statements = [] @@ -91,13 +91,13 @@ def to_ddl_string(self) -> str: if col_info.get('default') is not None: col_def += f" DEFAULT {col_info['default']}" cols.append(col_def) - + for fk in table.foreign_keys: cols.append(f" FOREIGN KEY ({fk['column']}) REFERENCES {fk['references']}") - + stmt = f"CREATE TABLE {table_name} (\n" + ",\n".join(cols) + "\n);" statements.append(stmt) - + return "\n\n".join(statements) @@ -124,7 +124,7 @@ class PartialScores: orderby_match: bool = False aggregations_match: bool = False overall_score: float = 0.0 - + def to_dict(self) -> dict: return { "tables_match": self.tables_match, @@ -146,22 +146,22 @@ class EvaluationResult: ground_truth_sql: str generated_sql: str status: QueryStatus - + # Core metrics exact_match: bool = False execution_match: bool = False partial_scores: Optional[PartialScores] = None - + # Execution details ground_truth_result: Optional[list] = None generated_result: Optional[list] = None error_message: Optional[str] = None latency_ms: float = 0.0 - + # Metadata from test case difficulty: str = "medium" category: str = "complex" - + def to_dict(self) -> dict: return { "question_id": self.question_id, @@ -198,23 +198,23 @@ class EvaluationReport: execution_accuracy: float structural_accuracy: float avg_latency_ms: float - + # Breakdown accuracy_by_category: dict[str, dict] accuracy_by_difficulty: dict[str, dict] - + # Error analysis common_failure_patterns: list[FailurePattern] - + # Individual results results: list[EvaluationResult] - + # Metadata llm_provider: str = "" llm_model: str = "" dataset_name: str = "" timestamp: str = "" - + def get_summary(self) -> str: """Get a text summary of the report""" lines = [ @@ -229,24 +229,24 @@ def get_summary(self) -> str: "", "BY DIFFICULTY:", ] - + for diff, stats in self.accuracy_by_difficulty.items(): lines.append(f" {diff}: {stats['accuracy']:.1%} ({stats['correct']}/{stats['total']})") - + lines.append("") lines.append("BY CATEGORY:") for cat, stats in self.accuracy_by_category.items(): lines.append(f" {cat}: {stats['accuracy']:.1%} ({stats['correct']}/{stats['total']})") - + if self.common_failure_patterns: lines.append("") lines.append("COMMON FAILURE PATTERNS:") for i, pattern in enumerate(self.common_failure_patterns[:5], 1): lines.append(f" {i}. {pattern.pattern_name} ({pattern.count} cases)") - + lines.append("=" * 50) return "\n".join(lines) - + def to_dict(self) -> dict: return { "summary": { diff --git a/sql_eval/core/schema_loader.py b/sql_eval/core/schema_loader.py index 375b4d2..d5b0189 100644 --- a/sql_eval/core/schema_loader.py +++ b/sql_eval/core/schema_loader.py @@ -2,62 +2,63 @@ Schema loading from various sources """ -import re import json +import re from pathlib import Path from typing import Optional, Union -from .models import TableSchema, DatabaseSchema + +from .models import DatabaseSchema, TableSchema class SchemaLoader: """Load database schema from various sources""" - + @staticmethod def from_sql_file(filepath: Union[str, Path]) -> DatabaseSchema: """ Parse CREATE TABLE statements from SQL file - + Args: filepath: Path to .sql file with CREATE TABLE statements - + Returns: DatabaseSchema object """ with open(filepath, 'r') as f: sql_content = f.read() return SchemaLoader.from_ddl(sql_content) - + @staticmethod def from_ddl(ddl_string: str) -> DatabaseSchema: """ Parse CREATE TABLE statements from DDL string - + Args: ddl_string: String containing CREATE TABLE statements - + Returns: DatabaseSchema object """ tables = {} relationships = [] - + # Find all CREATE TABLE statements create_pattern = r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?[`"\']?(\w+)[`"\']?\s*\((.*?)\);' matches = re.findall(create_pattern, ddl_string, re.IGNORECASE | re.DOTALL) - + for table_name, columns_str in matches: table_name = table_name.lower() columns = {} foreign_keys = [] - + # Split by comma, but handle nested parentheses column_defs = SchemaLoader._split_column_definitions(columns_str) - + for col_def in column_defs: col_def = col_def.strip() if not col_def: continue - + # Check for constraints if col_def.upper().startswith('PRIMARY KEY'): continue @@ -82,20 +83,20 @@ def from_ddl(ddl_string: str) -> DatabaseSchema: col_info = SchemaLoader._parse_column(col_def) if col_info: columns[col_info['name']] = col_info['details'] - + tables[table_name] = TableSchema( name=table_name, columns=columns, foreign_keys=foreign_keys ) - + return DatabaseSchema(tables=tables, relationships=relationships) - + @staticmethod def from_json(filepath: Union[str, Path]) -> DatabaseSchema: """ Load schema from JSON file - + Expected format: { "tables": { @@ -111,7 +112,7 @@ def from_json(filepath: Union[str, Path]) -> DatabaseSchema: """ with open(filepath, 'r') as f: data = json.load(f) - + tables = {} for table_name, table_data in data.get('tables', {}).items(): tables[table_name] = TableSchema( @@ -120,12 +121,12 @@ def from_json(filepath: Union[str, Path]) -> DatabaseSchema: foreign_keys=table_data.get('foreign_keys', []), description=table_data.get('description') ) - + return DatabaseSchema( tables=tables, relationships=data.get('relationships', []) ) - + @staticmethod def from_dict(schema_dict: dict) -> DatabaseSchema: """Load schema from a dictionary""" @@ -140,19 +141,19 @@ def from_dict(schema_dict: dict) -> DatabaseSchema: foreign_keys=table_data.get('foreign_keys', []), description=table_data.get('description') ) - + return DatabaseSchema( tables=tables, relationships=schema_dict.get('relationships', []) ) - + @staticmethod def _split_column_definitions(columns_str: str) -> list[str]: """Split column definitions handling nested parentheses""" result = [] current = "" depth = 0 - + for char in columns_str: if char == '(': depth += 1 @@ -165,57 +166,57 @@ def _split_column_definitions(columns_str: str) -> list[str]: current = "" else: current += char - + if current.strip(): result.append(current.strip()) - + return result - + @staticmethod def _parse_column(col_def: str) -> Optional[dict]: """Parse a single column definition""" # Match: column_name TYPE [(size)] [constraints...] pattern = r'^[`"\']?(\w+)[`"\']?\s+(\w+)(?:\s*\(([^)]+)\))?\s*(.*)?$' match = re.match(pattern, col_def.strip(), re.IGNORECASE) - + if not match: return None - + col_name = match.group(1).lower() col_type = match.group(2).upper() col_size = match.group(3) constraints = match.group(4) or "" - + # Build full type string if col_size: full_type = f"{col_type}({col_size})" else: full_type = col_type - + details = { 'type': full_type, 'nullable': 'NOT NULL' not in constraints.upper(), 'primary_key': 'PRIMARY KEY' in constraints.upper() } - + # Check for default value default_match = re.search(r'DEFAULT\s+([^\s,]+)', constraints, re.IGNORECASE) if default_match: details['default'] = default_match.group(1) - + # Check for references (inline foreign key) ref_match = re.search(r'REFERENCES\s+(\w+)\s*\(\s*(\w+)\s*\)', constraints, re.IGNORECASE) if ref_match: details['references'] = f"{ref_match.group(1)}.{ref_match.group(2)}" - + return {'name': col_name, 'details': details} - + @staticmethod def _parse_foreign_key(fk_def: str) -> Optional[dict]: """Parse a foreign key constraint""" pattern = r'FOREIGN\s+KEY\s*\(\s*[`"\']?(\w+)[`"\']?\s*\)\s*REFERENCES\s+[`"\']?(\w+)[`"\']?\s*\(\s*[`"\']?(\w+)[`"\']?\s*\)' match = re.search(pattern, fk_def, re.IGNORECASE) - + if match: return { 'column': match.group(1).lower(), @@ -226,33 +227,33 @@ def _parse_foreign_key(fk_def: str) -> Optional[dict]: class SchemaValidator: """Validate database schema for common issues""" - + @staticmethod def validate(schema: DatabaseSchema) -> list[str]: """ Validate schema and return list of warnings/errors """ issues = [] - + # Check for empty schema if not schema.tables: issues.append("ERROR: Schema has no tables") return issues - + # Check each table for table_name, table in schema.tables.items(): # Check for empty tables if not table.columns: issues.append(f"WARNING: Table '{table_name}' has no columns") - + # Check for primary key has_pk = any( - col.get('primary_key', False) + col.get('primary_key', False) for col in table.columns.values() ) if not has_pk: issues.append(f"WARNING: Table '{table_name}' has no primary key") - + # Validate foreign keys for fk in table.foreign_keys: ref_table = fk['references'].split('.')[0] @@ -261,5 +262,5 @@ def validate(schema: DatabaseSchema) -> list[str]: f"ERROR: Foreign key in '{table_name}' references " f"non-existent table '{ref_table}'" ) - + return issues diff --git a/sql_eval/core/sql_parser.py b/sql_eval/core/sql_parser.py index aaf43a1..560dd11 100644 --- a/sql_eval/core/sql_parser.py +++ b/sql_eval/core/sql_parser.py @@ -4,12 +4,11 @@ import re from typing import Optional + from .models import PartialScores try: import sqlparse - from sqlparse.sql import IdentifierList, Identifier, Where, Parenthesis - from sqlparse.tokens import Keyword, DML, Wildcard HAS_SQLPARSE = True except ImportError: HAS_SQLPARSE = False @@ -17,7 +16,7 @@ class SQLParser: """Parse SQL queries and extract components""" - + @staticmethod def normalize(sql: str) -> str: """ @@ -29,7 +28,7 @@ def normalize(sql: str) -> str: """ if not sql: return "" - + if HAS_SQLPARSE: normalized = sqlparse.format( sql, @@ -51,27 +50,27 @@ def normalize(sql: str) -> str: ] for kw in keywords: normalized = re.sub( - rf'\b{kw}\b', - kw, - normalized, + rf'\b{kw}\b', + kw, + normalized, flags=re.IGNORECASE ) - + return normalized.strip().rstrip(';').strip() - + @staticmethod def extract_components(sql: str) -> dict: """ Extract SQL components for structural comparison - + Returns: - dict with keys: tables, columns, joins, where_conditions, + dict with keys: tables, columns, joins, where_conditions, group_by, order_by, aggregations, limit """ sql = sql.strip() if not sql: return SQLParser._empty_components() - + components = { 'tables': [], 'columns': [], @@ -84,38 +83,38 @@ def extract_components(sql: str) -> dict: 'distinct': False, 'subqueries': [] } - + sql_upper = sql.upper() - + # Check for DISTINCT components['distinct'] = 'DISTINCT' in sql_upper - + # Extract tables from FROM clause components['tables'] = SQLParser._extract_tables(sql) - + # Extract columns from SELECT clause components['columns'] = SQLParser._extract_select_columns(sql) - + # Extract JOINs components['joins'] = SQLParser._extract_joins(sql) - + # Extract WHERE conditions components['where_conditions'] = SQLParser._extract_where(sql) - + # Extract GROUP BY components['group_by'] = SQLParser._extract_group_by(sql) - + # Extract ORDER BY components['order_by'] = SQLParser._extract_order_by(sql) - + # Extract aggregations components['aggregations'] = SQLParser._extract_aggregations(sql) - + # Extract LIMIT components['limit'] = SQLParser._extract_limit(sql) - + return components - + @staticmethod def _empty_components() -> dict: return { @@ -130,12 +129,12 @@ def _empty_components() -> dict: 'distinct': False, 'subqueries': [] } - + @staticmethod def _extract_tables(sql: str) -> list[str]: """Extract table names from FROM clause and JOINs""" tables = set() - + # Match FROM table from_match = re.search( r'\bFROM\s+([`"\']?\w+[`"\']?(?:\s+(?:AS\s+)?\w+)?)', @@ -145,38 +144,38 @@ def _extract_tables(sql: str) -> list[str]: if from_match: table = from_match.group(1).split()[0].strip('`"\'').lower() tables.add(table) - + # Match JOIN tables join_pattern = r'\bJOIN\s+([`"\']?\w+[`"\']?)' for match in re.finditer(join_pattern, sql, re.IGNORECASE): tables.add(match.group(1).strip('`"\'').lower()) - + return sorted(list(tables)) - + @staticmethod def _extract_select_columns(sql: str) -> list[str]: """Extract column names from SELECT clause""" columns = [] - + # Find SELECT ... FROM select_match = re.search( r'\bSELECT\s+(.*?)\s+FROM\b', sql, re.IGNORECASE | re.DOTALL ) - + if not select_match: return columns - + select_clause = select_match.group(1) - + # Handle SELECT * if select_clause.strip() == '*': return ['*'] - + # Split by comma (handling nested parentheses) parts = SQLParser._split_by_comma(select_clause) - + for part in parts: part = part.strip() # Get alias or column name @@ -188,42 +187,42 @@ def _extract_select_columns(sql: str) -> list[str]: identifiers = re.findall(r'[`"\']?(\w+)[`"\']?', part) if identifiers: columns.append(identifiers[-1].lower()) - + return columns - + @staticmethod def _extract_joins(sql: str) -> list[dict]: """Extract JOIN information""" joins = [] - + join_pattern = r'(LEFT\s+|RIGHT\s+|INNER\s+|OUTER\s+|FULL\s+|CROSS\s+)?JOIN\s+([`"\']?\w+[`"\']?)(?:\s+(?:AS\s+)?(\w+))?\s+ON\s+([^JOIN]+?)(?=(?:LEFT|RIGHT|INNER|OUTER|FULL|CROSS)?\s*JOIN|\bWHERE\b|\bGROUP\b|\bORDER\b|\bLIMIT\b|$)' - + for match in re.finditer(join_pattern, sql, re.IGNORECASE | re.DOTALL): join_type = (match.group(1) or '').strip().upper() or 'INNER' table = match.group(2).strip('`"\'').lower() alias = match.group(3).lower() if match.group(3) else None condition = match.group(4).strip() - + joins.append({ 'type': join_type + ' JOIN', 'table': table, 'alias': alias, 'condition': SQLParser._normalize_condition(condition) }) - + return joins - + @staticmethod def _extract_where(sql: str) -> list[str]: """Extract WHERE conditions""" conditions = [] - + where_match = re.search( r'\bWHERE\s+(.*?)(?:\bGROUP\s+BY\b|\bORDER\s+BY\b|\bLIMIT\b|\bHAVING\b|$)', sql, re.IGNORECASE | re.DOTALL ) - + if where_match: where_clause = where_match.group(1).strip() # Split by AND/OR (simplified) @@ -232,20 +231,20 @@ def _extract_where(sql: str) -> list[str]: part = part.strip() if part: conditions.append(SQLParser._normalize_condition(part)) - + return conditions - + @staticmethod def _extract_group_by(sql: str) -> list[str]: """Extract GROUP BY columns""" columns = [] - + match = re.search( r'\bGROUP\s+BY\s+(.*?)(?:\bHAVING\b|\bORDER\s+BY\b|\bLIMIT\b|$)', sql, re.IGNORECASE | re.DOTALL ) - + if match: group_clause = match.group(1).strip() parts = SQLParser._split_by_comma(group_clause) @@ -255,20 +254,20 @@ def _extract_group_by(sql: str) -> list[str]: identifiers = re.findall(r'[`"\']?(\w+)[`"\']?', part) if identifiers: columns.append(identifiers[-1]) - + return columns - + @staticmethod def _extract_order_by(sql: str) -> list[dict]: """Extract ORDER BY columns""" order = [] - + match = re.search( r'\bORDER\s+BY\s+(.*?)(?:\bLIMIT\b|$)', sql, re.IGNORECASE | re.DOTALL ) - + if match: order_clause = match.group(1).strip() parts = SQLParser._split_by_comma(order_clause) @@ -282,20 +281,20 @@ def _extract_order_by(sql: str) -> list[dict]: if col_match: col = col_match.group(1).strip('`"\'').lower() order.append({'column': col, 'direction': direction}) - + return order - + @staticmethod def _extract_aggregations(sql: str) -> list[str]: """Extract aggregation functions used""" aggs = set() - + agg_pattern = r'\b(COUNT|SUM|AVG|MIN|MAX|GROUP_CONCAT|STRING_AGG)\s*\(' for match in re.finditer(agg_pattern, sql, re.IGNORECASE): aggs.add(match.group(1).upper()) - + return sorted(list(aggs)) - + @staticmethod def _extract_limit(sql: str) -> Optional[int]: """Extract LIMIT value""" @@ -303,14 +302,14 @@ def _extract_limit(sql: str) -> Optional[int]: if match: return int(match.group(1)) return None - + @staticmethod def _split_by_comma(text: str) -> list[str]: """Split by comma, respecting parentheses""" parts = [] current = "" depth = 0 - + for char in text: if char == '(': depth += 1 @@ -323,12 +322,12 @@ def _split_by_comma(text: str) -> list[str]: current = "" else: current += char - + if current.strip(): parts.append(current.strip()) - + return parts - + @staticmethod def _normalize_condition(condition: str) -> str: """Normalize a condition for comparison""" @@ -349,28 +348,28 @@ def _normalize_condition(condition: str) -> str: class SQLComparator: """Compare SQL queries at multiple levels""" - + def __init__(self): self.parser = SQLParser() - + def exact_match(self, sql1: str, sql2: str) -> bool: """Check if two SQL queries are identical (normalized)""" norm1 = SQLParser.normalize(sql1) norm2 = SQLParser.normalize(sql2) return norm1 == norm2 - + def structural_match(self, sql1: str, sql2: str) -> PartialScores: """ Compare SQL structure and return partial scores """ comp1 = SQLParser.extract_components(sql1) comp2 = SQLParser.extract_components(sql2) - + scores = PartialScores() - + # Compare tables (order-insensitive) scores.tables_match = set(comp1['tables']) == set(comp2['tables']) - + # Compare columns (order-insensitive, handle * expansion) cols1 = set(comp1['columns']) cols2 = set(comp2['columns']) @@ -379,28 +378,28 @@ def structural_match(self, sql1: str, sql2: str) -> PartialScores: scores.columns_match = bool(cols1 & cols2) or (cols1 == cols2) else: scores.columns_match = cols1 == cols2 - + # Compare JOINs scores.joins_match = self._compare_joins(comp1['joins'], comp2['joins']) - + # Compare WHERE conditions (relaxed matching) scores.where_match = self._compare_conditions( comp1['where_conditions'], comp2['where_conditions'] ) - + # Compare GROUP BY scores.groupby_match = set(comp1['group_by']) == set(comp2['group_by']) - + # Compare ORDER BY scores.orderby_match = self._compare_order_by( comp1['order_by'], comp2['order_by'] ) - + # Compare aggregations scores.aggregations_match = set(comp1['aggregations']) == set(comp2['aggregations']) - + # Calculate overall score weights = { 'tables_match': 0.20, @@ -411,66 +410,66 @@ def structural_match(self, sql1: str, sql2: str) -> PartialScores: 'orderby_match': 0.05, 'aggregations_match': 0.10 } - + scores.overall_score = sum( weights[key] * (1.0 if getattr(scores, key) else 0.0) for key in weights ) - + return scores - + def _compare_joins(self, joins1: list, joins2: list) -> bool: """Compare JOIN clauses""" if len(joins1) != len(joins2): return False - + if not joins1: return True - + # Compare tables being joined (order might differ) tables1 = {j['table'] for j in joins1} tables2 = {j['table'] for j in joins2} - + return tables1 == tables2 - + def _compare_conditions(self, conds1: list, conds2: list) -> bool: """Compare WHERE conditions (relaxed matching)""" if not conds1 and not conds2: return True - + if len(conds1) != len(conds2): return False - + # Normalize and compare as sets set1 = set(conds1) set2 = set(conds2) - + return set1 == set2 - + def _compare_order_by(self, order1: list, order2: list) -> bool: """Compare ORDER BY clauses""" if len(order1) != len(order2): return False - + if not order1: return True - + # Order matters for ORDER BY for o1, o2 in zip(order1, order2): if o1['column'] != o2['column']: return False if o1['direction'] != o2['direction']: return False - + return True - + def get_differences(self, sql1: str, sql2: str) -> list[str]: """Get human-readable list of differences""" differences = [] - + comp1 = SQLParser.extract_components(sql1) comp2 = SQLParser.extract_components(sql2) - + # Check tables tables1 = set(comp1['tables']) tables2 = set(comp2['tables']) @@ -481,7 +480,7 @@ def get_differences(self, sql1: str, sql2: str) -> list[str]: differences.append(f"Missing tables: {missing}") if extra: differences.append(f"Extra tables: {extra}") - + # Check columns cols1 = set(comp1['columns']) cols2 = set(comp2['columns']) @@ -492,23 +491,23 @@ def get_differences(self, sql1: str, sql2: str) -> list[str]: differences.append(f"Missing columns: {missing}") if extra: differences.append(f"Extra columns: {extra}") - + # Check JOINs if len(comp1['joins']) != len(comp2['joins']): differences.append( f"JOIN count mismatch: {len(comp1['joins'])} vs {len(comp2['joins'])}" ) - + # Check GROUP BY gb1 = set(comp1['group_by']) gb2 = set(comp2['group_by']) if gb1 != gb2: differences.append(f"GROUP BY mismatch: {gb1} vs {gb2}") - + # Check aggregations agg1 = set(comp1['aggregations']) agg2 = set(comp2['aggregations']) if agg1 != agg2: differences.append(f"Aggregation mismatch: {agg1} vs {agg2}") - + return differences diff --git a/sql_eval/datasets/__init__.py b/sql_eval/datasets/__init__.py index af70512..c109458 100644 --- a/sql_eval/datasets/__init__.py +++ b/sql_eval/datasets/__init__.py @@ -6,11 +6,10 @@ from pathlib import Path from typing import Tuple +from ..connectors.sqlite import SQLiteConnector +from ..core.evaluator import DatasetLoader from ..core.models import DatabaseSchema, EvaluationCase from ..core.schema_loader import SchemaLoader -from ..core.evaluator import DatasetLoader -from ..connectors.sqlite import SQLiteConnector - DATASETS_DIR = Path(__file__).parent @@ -30,27 +29,27 @@ def load_dataset( ) -> Tuple[list[EvaluationCase], DatabaseSchema, SQLiteConnector]: """ Load a bundled dataset - + Args: name: Dataset name (e.g., 'ecommerce') with_db: Whether to create SQLite database with seed data - + Returns: Tuple of (test_cases, schema, db_connector or None) """ dataset_path = DATASETS_DIR / name - + if not dataset_path.exists(): available = list_datasets() raise ValueError( f"Dataset '{name}' not found. " f"Available datasets: {available}" ) - + # Load schema schema_file = dataset_path / 'schema.sql' schema = SchemaLoader.from_sql_file(schema_file) - + # Load test cases questions_file = dataset_path / 'questions.json' if questions_file.exists(): @@ -61,7 +60,7 @@ def load_dataset( test_cases = DatasetLoader.from_csv(questions_csv) else: raise ValueError(f"No questions file found in {dataset_path}") - + # Optionally create database db_connector = None if with_db: @@ -70,24 +69,24 @@ def load_dataset( schema_file=str(schema_file), seed_file=str(seed_file) if seed_file.exists() else None ) - + return test_cases, schema, db_connector def get_dataset_info(name: str) -> dict: """Get information about a dataset""" dataset_path = DATASETS_DIR / name - + if not dataset_path.exists(): raise ValueError(f"Dataset '{name}' not found") - + # Count questions questions_file = dataset_path / 'questions.json' if questions_file.exists(): with open(questions_file, 'r') as f: questions = json.load(f) num_questions = len(questions) - + # Count by difficulty difficulty_counts = {} category_counts = {} @@ -100,12 +99,12 @@ def get_dataset_info(name: str) -> dict: num_questions = 0 difficulty_counts = {} category_counts = {} - + # Count tables in schema schema_file = dataset_path / 'schema.sql' schema = SchemaLoader.from_sql_file(schema_file) num_tables = len(schema.tables) - + return { 'name': name, 'path': str(dataset_path), diff --git a/sql_eval/llm_providers/__init__.py b/sql_eval/llm_providers/__init__.py index c1f22da..5665a9e 100644 --- a/sql_eval/llm_providers/__init__.py +++ b/sql_eval/llm_providers/__init__.py @@ -2,10 +2,10 @@ LLM Providers for sql-eval """ -from .base import BaseLLMProvider -from .openai_provider import OpenAIProvider, AzureOpenAIProvider from .anthropic_provider import AnthropicProvider +from .base import BaseLLMProvider from .ollama_provider import OllamaProvider, SQLCoderProvider +from .openai_provider import AzureOpenAIProvider, OpenAIProvider def get_provider( @@ -15,12 +15,12 @@ def get_provider( ) -> BaseLLMProvider: """ Factory function to get LLM provider by name - + Args: provider_name: One of 'openai', 'anthropic', 'ollama', 'sqlcoder', 'azure' model: Optional model name override **kwargs: Additional provider-specific arguments - + Returns: Configured LLM provider instance """ @@ -33,18 +33,18 @@ def get_provider( 'azure': AzureOpenAIProvider, 'azure_openai': AzureOpenAIProvider, } - + provider_name = provider_name.lower() - + if provider_name not in providers: available = ', '.join(providers.keys()) raise ValueError( f"Unknown provider: {provider_name}. " f"Available providers: {available}" ) - + provider_class = providers[provider_name] - + if model: return provider_class(model=model, **kwargs) return provider_class(**kwargs) diff --git a/sql_eval/llm_providers/anthropic_provider.py b/sql_eval/llm_providers/anthropic_provider.py index 435c0e7..1033a5b 100644 --- a/sql_eval/llm_providers/anthropic_provider.py +++ b/sql_eval/llm_providers/anthropic_provider.py @@ -4,15 +4,16 @@ import os from typing import Optional -from .base import BaseLLMProvider + from ..core.models import DatabaseSchema +from .base import BaseLLMProvider class AnthropicProvider(BaseLLMProvider): """Anthropic API provider (Claude models)""" - + DEFAULT_MODEL = "claude-sonnet-4-20250514" - + def __init__( self, model: str = None, @@ -26,11 +27,11 @@ def __init__( self.temperature = temperature self.max_tokens = max_tokens self._client = None - + @property def provider_name(self) -> str: return "anthropic" - + @property def client(self): """Lazy initialization of Anthropic client""" @@ -44,7 +45,7 @@ def client(self): "Install with: pip install anthropic" ) return self._client - + def generate_sql( self, question: str, @@ -52,9 +53,9 @@ def generate_sql( examples: Optional[list[dict]] = None ) -> str: """Generate SQL using Anthropic API""" - + prompt = self.build_prompt(question, schema, examples) - + try: response = self.client.messages.create( model=self.model, @@ -67,18 +68,18 @@ def generate_sql( ], system="You are an expert SQL query generator. Return only the SQL query without any explanations or markdown formatting." ) - + # Extract text from response sql = "" for block in response.content: if block.type == "text": sql += block.text - + return self.clean_sql(sql) - + except Exception as e: raise RuntimeError(f"Anthropic API error: {str(e)}") - + def build_prompt( self, question: str, @@ -86,9 +87,9 @@ def build_prompt( examples: Optional[list[dict]] = None ) -> str: """Build prompt optimized for Claude""" - + schema_str = schema.to_ddl_string() - + prompt = f"""Given the database schema below, generate a SQL query to answer the question. @@ -103,14 +104,14 @@ def build_prompt( - No trailing semicolon """ - + if examples: prompt += "\n\n" for i, ex in enumerate(examples, 1): prompt += f"Question: {ex['question']}\n" prompt += f"SQL: {ex['sql']}\n\n" prompt += "\n" - + prompt += f"\n\n{question}\n\n\nSQL:" - + return prompt diff --git a/sql_eval/llm_providers/base.py b/sql_eval/llm_providers/base.py index f5dee88..62b0a09 100644 --- a/sql_eval/llm_providers/base.py +++ b/sql_eval/llm_providers/base.py @@ -4,22 +4,23 @@ from abc import ABC, abstractmethod from typing import Optional + from ..core.models import DatabaseSchema class BaseLLMProvider(ABC): """Abstract base class for all LLM providers""" - + def __init__(self, model: str, **kwargs): self.model = model self.config = kwargs - + @property @abstractmethod def provider_name(self) -> str: """Return the provider name (e.g., 'openai', 'anthropic')""" pass - + @abstractmethod def generate_sql( self, @@ -29,17 +30,17 @@ def generate_sql( ) -> str: """ Generate SQL from natural language question - + Args: question: Natural language question schema: Database schema examples: Optional few-shot examples [{'question': ..., 'sql': ...}] - + Returns: Generated SQL query string """ pass - + def build_prompt( self, question: str, @@ -48,11 +49,11 @@ def build_prompt( ) -> str: """ Build the prompt for SQL generation - + Can be overridden by subclasses for custom prompting strategies """ schema_str = schema.to_ddl_string() - + prompt = f"""You are an expert SQL query generator. Given the database schema below, generate a SQL query to answer the user's question. DATABASE SCHEMA: @@ -65,18 +66,18 @@ def build_prompt( - Use table aliases for clarity in JOINs - Do not include a trailing semicolon """ - + if examples: prompt += "\nEXAMPLES:\n" for i, ex in enumerate(examples, 1): prompt += f"\nExample {i}:\n" prompt += f"Question: {ex['question']}\n" prompt += f"SQL: {ex['sql']}\n" - + prompt += f"\nNow generate SQL for this question:\nQuestion: {question}\n\nSQL:" - + return prompt - + def clean_sql(self, sql: str) -> str: """ Clean up generated SQL @@ -86,9 +87,9 @@ def clean_sql(self, sql: str) -> str: """ if not sql: return "" - + sql = sql.strip() - + # Remove markdown code blocks if sql.startswith("```"): lines = sql.split('\n') @@ -98,7 +99,7 @@ def clean_sql(self, sql: str) -> str: if lines and lines[-1].strip() == "```": lines = lines[:-1] sql = '\n'.join(lines) - + # Remove common prefixes prefixes_to_remove = [ "sql:", "SQL:", "Here is the SQL:", "The SQL query is:", @@ -107,11 +108,11 @@ def clean_sql(self, sql: str) -> str: for prefix in prefixes_to_remove: if sql.lower().startswith(prefix.lower()): sql = sql[len(prefix):] - + # Strip whitespace and trailing semicolon sql = sql.strip().rstrip(';').strip() - + return sql - + def __repr__(self) -> str: return f"{self.__class__.__name__}(model='{self.model}')" diff --git a/sql_eval/llm_providers/ollama_provider.py b/sql_eval/llm_providers/ollama_provider.py index 1a8a9c9..8bb4404 100644 --- a/sql_eval/llm_providers/ollama_provider.py +++ b/sql_eval/llm_providers/ollama_provider.py @@ -4,38 +4,38 @@ No data leaves your machine when using this provider. """ -import json from typing import Optional -from .base import BaseLLMProvider + from ..core.models import DatabaseSchema +from .base import BaseLLMProvider class OllamaProvider(BaseLLMProvider): """ Ollama provider for fully local LLM inference - + Your data NEVER leaves your machine. - + Supported models: - codellama (recommended for SQL) - llama3 - mistral - sqlcoder (specialized for SQL) - deepseek-coder - + Usage: # First, ensure Ollama is running: # $ ollama serve - + # Pull a model: # $ ollama pull codellama - + provider = OllamaProvider(model="codellama") """ - + DEFAULT_MODEL = "codellama" DEFAULT_BASE_URL = "http://localhost:11434" - + def __init__( self, model: str = None, @@ -48,11 +48,11 @@ def __init__( self.base_url = base_url or self.DEFAULT_BASE_URL self.temperature = temperature self.num_predict = num_predict - + @property def provider_name(self) -> str: return "ollama" - + def generate_sql( self, question: str, @@ -60,7 +60,7 @@ def generate_sql( examples: Optional[list[dict]] = None ) -> str: """Generate SQL using local Ollama instance""" - + try: import requests except ImportError: @@ -68,9 +68,9 @@ def generate_sql( "requests package not installed. " "Install with: pip install requests" ) - + prompt = self.build_prompt(question, schema, examples) - + try: response = requests.post( f"{self.base_url}/api/generate", @@ -85,17 +85,17 @@ def generate_sql( }, timeout=120 # 2 minute timeout for slower machines ) - + if response.status_code != 200: raise RuntimeError( f"Ollama returned status {response.status_code}: {response.text}" ) - + result = response.json() sql = result.get("response", "") - + return self.clean_sql(sql) - + except requests.exceptions.ConnectionError: raise RuntimeError( "Cannot connect to Ollama. Make sure it's running:\n" @@ -110,7 +110,7 @@ def generate_sql( ) except Exception as e: raise RuntimeError(f"Ollama error: {str(e)}") - + def build_prompt( self, question: str, @@ -118,9 +118,9 @@ def build_prompt( examples: Optional[list[dict]] = None ) -> str: """Build prompt optimized for local models""" - + schema_str = schema.to_ddl_string() - + # Simpler, more direct prompt for smaller models prompt = f"""Generate a SQL query for the question below. @@ -132,25 +132,25 @@ def build_prompt( - No explanations - Standard SQL syntax """ - + if examples: prompt += "\nExamples:\n" for ex in examples[:3]: # Limit examples for smaller context prompt += f"Q: {ex['question']}\nSQL: {ex['sql']}\n\n" - + prompt += f"Q: {question}\nSQL:" - + return prompt - + def check_connection(self) -> bool: """Check if Ollama is running and accessible""" try: import requests response = requests.get(f"{self.base_url}/api/tags", timeout=5) return response.status_code == 200 - except: + except Exception: return False - + def list_models(self) -> list[str]: """List available models in Ollama""" try: @@ -159,7 +159,7 @@ def list_models(self) -> list[str]: if response.status_code == 200: data = response.json() return [m["name"] for m in data.get("models", [])] - except: + except Exception: pass return [] @@ -167,18 +167,18 @@ def list_models(self) -> list[str]: class SQLCoderProvider(OllamaProvider): """ Specialized provider for SQLCoder model - + SQLCoder is fine-tuned specifically for text-to-SQL tasks. - + Usage: # Pull SQLCoder: # $ ollama pull sqlcoder - + provider = SQLCoderProvider() """ - + DEFAULT_MODEL = "sqlcoder" - + def build_prompt( self, question: str, @@ -186,9 +186,9 @@ def build_prompt( examples: Optional[list[dict]] = None ) -> str: """Build prompt in SQLCoder's expected format""" - + schema_str = schema.to_ddl_string() - + # SQLCoder uses a specific prompt format prompt = f"""### Task Generate a SQL query to answer [QUESTION]{question}[/QUESTION] @@ -201,15 +201,15 @@ def build_prompt( Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION] [SQL] """ - + return prompt - + def clean_sql(self, sql: str) -> str: """Clean SQLCoder's output format""" sql = super().clean_sql(sql) - + # SQLCoder sometimes outputs [/SQL] tag if "[/SQL]" in sql: sql = sql.split("[/SQL]")[0] - + return sql.strip() diff --git a/sql_eval/llm_providers/openai_provider.py b/sql_eval/llm_providers/openai_provider.py index 2cccc24..08b5b1b 100644 --- a/sql_eval/llm_providers/openai_provider.py +++ b/sql_eval/llm_providers/openai_provider.py @@ -4,15 +4,16 @@ import os from typing import Optional -from .base import BaseLLMProvider + from ..core.models import DatabaseSchema +from .base import BaseLLMProvider class OpenAIProvider(BaseLLMProvider): """OpenAI API provider (GPT-4, GPT-3.5, etc.)""" - + DEFAULT_MODEL = "gpt-4o-mini" - + def __init__( self, model: str = None, @@ -26,11 +27,11 @@ def __init__( self.temperature = temperature self.max_tokens = max_tokens self._client = None - + @property def provider_name(self) -> str: return "openai" - + @property def client(self): """Lazy initialization of OpenAI client""" @@ -44,7 +45,7 @@ def client(self): "Install with: pip install openai" ) return self._client - + def generate_sql( self, question: str, @@ -52,9 +53,9 @@ def generate_sql( examples: Optional[list[dict]] = None ) -> str: """Generate SQL using OpenAI API""" - + prompt = self.build_prompt(question, schema, examples) - + try: response = self.client.chat.completions.create( model=self.model, @@ -71,17 +72,17 @@ def generate_sql( temperature=self.temperature, max_tokens=self.max_tokens ) - + sql = response.choices[0].message.content return self.clean_sql(sql) - + except Exception as e: raise RuntimeError(f"OpenAI API error: {str(e)}") class AzureOpenAIProvider(BaseLLMProvider): """Azure OpenAI API provider""" - + def __init__( self, model: str, @@ -99,11 +100,11 @@ def __init__( self.temperature = temperature self.max_tokens = max_tokens self._client = None - + @property def provider_name(self) -> str: return "azure_openai" - + @property def client(self): """Lazy initialization of Azure OpenAI client""" @@ -121,7 +122,7 @@ def client(self): "Install with: pip install openai" ) return self._client - + def generate_sql( self, question: str, @@ -129,9 +130,9 @@ def generate_sql( examples: Optional[list[dict]] = None ) -> str: """Generate SQL using Azure OpenAI API""" - + prompt = self.build_prompt(question, schema, examples) - + try: response = self.client.chat.completions.create( model=self.model, # This is the deployment name in Azure @@ -148,9 +149,9 @@ def generate_sql( temperature=self.temperature, max_tokens=self.max_tokens ) - + sql = response.choices[0].message.content return self.clean_sql(sql) - + except Exception as e: raise RuntimeError(f"Azure OpenAI API error: {str(e)}") diff --git a/tests/test_core.py b/tests/test_core.py index 4c624a2..a123539 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -3,72 +3,73 @@ """ import pytest + from sql_eval.core.models import ( + DatabaseSchema, EvaluationCase, EvaluationResult, - QueryStatus, PartialScores, - DatabaseSchema, - TableSchema + QueryStatus, + TableSchema, ) -from sql_eval.core.sql_parser import SQLParser, SQLComparator from sql_eval.core.schema_loader import SchemaLoader +from sql_eval.core.sql_parser import SQLComparator, SQLParser class TestSQLParser: """Tests for SQL parsing functionality""" - + def test_normalize_basic(self): sql = "select * from users" normalized = SQLParser.normalize(sql) assert "SELECT" in normalized assert "FROM" in normalized - + def test_normalize_removes_semicolon(self): sql = "SELECT * FROM users;" normalized = SQLParser.normalize(sql) assert not normalized.endswith(';') - + def test_extract_tables_single(self): sql = "SELECT * FROM users" components = SQLParser.extract_components(sql) assert 'users' in components['tables'] - + def test_extract_tables_with_join(self): sql = "SELECT * FROM users JOIN orders ON users.id = orders.user_id" components = SQLParser.extract_components(sql) assert 'users' in components['tables'] assert 'orders' in components['tables'] - + def test_extract_columns_star(self): sql = "SELECT * FROM users" components = SQLParser.extract_components(sql) assert '*' in components['columns'] - + def test_extract_columns_specific(self): sql = "SELECT id, name, email FROM users" components = SQLParser.extract_components(sql) assert 'id' in components['columns'] assert 'name' in components['columns'] assert 'email' in components['columns'] - + def test_extract_aggregations(self): sql = "SELECT COUNT(*), SUM(amount) FROM orders" components = SQLParser.extract_components(sql) assert 'COUNT' in components['aggregations'] assert 'SUM' in components['aggregations'] - + def test_extract_group_by(self): sql = "SELECT category, COUNT(*) FROM products GROUP BY category" components = SQLParser.extract_components(sql) assert 'category' in components['group_by'] - + def test_extract_order_by(self): sql = "SELECT * FROM products ORDER BY price DESC" components = SQLParser.extract_components(sql) assert len(components['order_by']) == 1 assert components['order_by'][0]['direction'] == 'DESC' - + def test_extract_limit(self): sql = "SELECT * FROM users LIMIT 10" components = SQLParser.extract_components(sql) @@ -77,44 +78,44 @@ def test_extract_limit(self): class TestSQLComparator: """Tests for SQL comparison functionality""" - + def setup_method(self): self.comparator = SQLComparator() - + def test_exact_match_identical(self): sql1 = "SELECT * FROM users" sql2 = "SELECT * FROM users" assert self.comparator.exact_match(sql1, sql2) - + def test_exact_match_case_insensitive(self): sql1 = "SELECT * FROM users" sql2 = "select * from users" assert self.comparator.exact_match(sql1, sql2) - + def test_exact_match_whitespace_insensitive(self): sql1 = "SELECT * FROM users" sql2 = "SELECT * FROM users" assert self.comparator.exact_match(sql1, sql2) - + def test_exact_match_different(self): sql1 = "SELECT * FROM users" sql2 = "SELECT * FROM orders" assert not self.comparator.exact_match(sql1, sql2) - + def test_structural_match_tables(self): sql1 = "SELECT id FROM users" sql2 = "SELECT name FROM users" scores = self.comparator.structural_match(sql1, sql2) assert scores.tables_match assert not scores.columns_match - + def test_structural_match_joins(self): sql1 = "SELECT * FROM users JOIN orders ON users.id = orders.user_id" sql2 = "SELECT * FROM users JOIN orders ON users.id = orders.user_id" scores = self.comparator.structural_match(sql1, sql2) assert scores.tables_match assert scores.joins_match - + def test_get_differences(self): sql1 = "SELECT id FROM users" sql2 = "SELECT name FROM orders" @@ -124,7 +125,7 @@ def test_get_differences(self): class TestSchemaLoader: """Tests for schema loading functionality""" - + def test_parse_simple_table(self): ddl = """ CREATE TABLE users ( @@ -138,7 +139,7 @@ def test_parse_simple_table(self): assert 'id' in schema.tables['users'].columns assert 'name' in schema.tables['users'].columns assert 'email' in schema.tables['users'].columns - + def test_parse_primary_key(self): ddl = """ CREATE TABLE users ( @@ -148,7 +149,7 @@ def test_parse_primary_key(self): """ schema = SchemaLoader.from_ddl(ddl) assert schema.tables['users'].columns['id']['primary_key'] - + def test_parse_foreign_key(self): ddl = """ CREATE TABLE orders ( @@ -160,7 +161,7 @@ def test_parse_foreign_key(self): schema = SchemaLoader.from_ddl(ddl) assert len(schema.tables['orders'].foreign_keys) == 1 assert schema.tables['orders'].foreign_keys[0]['column'] == 'user_id' - + def test_parse_multiple_tables(self): ddl = """ CREATE TABLE users (id INTEGER PRIMARY KEY); @@ -169,7 +170,7 @@ def test_parse_multiple_tables(self): """ schema = SchemaLoader.from_ddl(ddl) assert len(schema.tables) == 3 - + def test_to_prompt_string(self): ddl = "CREATE TABLE users (id INTEGER PRIMARY KEY, name VARCHAR(100));" schema = SchemaLoader.from_ddl(ddl) @@ -181,7 +182,7 @@ def test_to_prompt_string(self): class TestModels: """Tests for data models""" - + def test_evaluation_case_creation(self): case = EvaluationCase( question_id="Q001", @@ -192,7 +193,7 @@ def test_evaluation_case_creation(self): ) assert case.question_id == "Q001" assert case.difficulty == "easy" - + def test_evaluation_result_to_dict(self): result = EvaluationResult( question_id="Q001", @@ -205,9 +206,9 @@ def test_evaluation_result_to_dict(self): ) d = result.to_dict() assert d['question_id'] == "Q001" - assert d['exact_match'] == True + assert d['exact_match'] assert d['status'] == 'success' - + def test_partial_scores_to_dict(self): scores = PartialScores( tables_match=True, @@ -216,14 +217,14 @@ def test_partial_scores_to_dict(self): overall_score=0.75 ) d = scores.to_dict() - assert d['tables_match'] == True - assert d['joins_match'] == False + assert d['tables_match'] + assert not d['joins_match'] assert d['overall_score'] == 0.75 class TestDatabaseSchema: """Tests for DatabaseSchema class""" - + def test_to_ddl_string(self): schema = DatabaseSchema( tables={