diff --git a/README.md b/README.md index e90e0c8..0c86f98 100644 --- a/README.md +++ b/README.md @@ -248,8 +248,8 @@ MIT License - see [LICENSE](LICENSE) for details. ## Author -**Ajay Maurya** - AI Engineer -[LinkedIn](https://linkedin.com/in/ajaymaurya) | [GitHub](https://github.com/ajaymaurya) +**Ajay Maurya** - AI Engineer +[LinkedIn](https://linkedin.com/in/ajaymauryabbn) | [GitHub](https://github.com/ajaymauryabbn) --- diff --git a/examples/basic_usage.py b/examples/basic_usage.py index 72b7ad5..861d3fd 100644 --- a/examples/basic_usage.py +++ b/examples/basic_usage.py @@ -16,7 +16,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from sql_eval import Evaluator -from sql_eval.datasets import load_ecommerce, get_dataset_info +from sql_eval.datasets import get_dataset_info, load_ecommerce from sql_eval.llm_providers import get_provider @@ -29,7 +29,7 @@ def main(): provider_name = 'ollama' else: provider_name = 'openai' - + # Show dataset info print("\n" + "=" * 50) print("Dataset Information") @@ -38,21 +38,21 @@ def main(): print(f"Name: {info['name']}") print(f"Questions: {info['num_questions']}") print(f"Tables: {', '.join(info['tables'])}") - print(f"\nDifficulty breakdown:") + print("\nDifficulty breakdown:") for diff, count in info['difficulty_breakdown'].items(): print(f" {diff}: {count}") - + # Load dataset print("\n" + "=" * 50) print("Loading Dataset") print("=" * 50) test_cases, schema, db = load_ecommerce(with_db=True) print(f"Loaded {len(test_cases)} test cases") - + # Limit to first 5 for quick demo test_cases = test_cases[:5] print(f"Running evaluation on {len(test_cases)} questions (limited for demo)") - + # Initialize provider print(f"\nUsing LLM provider: {provider_name}") try: @@ -60,7 +60,7 @@ def main(): except Exception as e: print(f"Error initializing provider: {e}") return - + # Create evaluator evaluator = Evaluator( llm_provider=provider, @@ -68,12 +68,12 @@ def main(): db_connector=db, verbose=True ) - + # Run evaluation print("\n" + "=" * 50) print("Running Evaluation") print("=" * 50) - + try: report = evaluator.evaluate( test_cases, @@ -83,12 +83,12 @@ def main(): print(f"\nError during evaluation: {e}") print("\nTip: Make sure your LLM provider is properly configured") return - + # Show detailed results print("\n" + "=" * 50) print("Detailed Results") print("=" * 50) - + for result in report.results: status = "✓" if result.exact_match else "✗" print(f"\n{status} {result.question_id}: {result.question[:50]}...") @@ -96,11 +96,11 @@ def main(): print(f" Got: {result.generated_sql[:60]}...") if not result.exact_match and result.partial_scores: print(f" Structural score: {result.partial_scores.overall_score:.1%}") - + # Cleanup if db: db.disconnect() - + print("\n" + "=" * 50) print("Done!") print("=" * 50) diff --git a/pyproject.toml b/pyproject.toml index 3becfc5..97a78ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,8 @@ target-version = ['py310', 'py311', 'py312'] [tool.ruff] line-length = 100 + +[tool.ruff.lint] select = ["E", "F", "W", "I"] ignore = ["E501"] diff --git a/sql_eval/__init__.py b/sql_eval/__init__.py index 0a34f3a..4c572ff 100644 --- a/sql_eval/__init__.py +++ b/sql_eval/__init__.py @@ -8,11 +8,11 @@ __author__ = "Ajay Maurya" from sql_eval.core.evaluator import Evaluator -from sql_eval.core.models import EvaluationCase, EvaluationResult, EvaluationReport +from sql_eval.core.models import EvaluationCase, EvaluationReport, EvaluationResult __all__ = [ "Evaluator", - "EvaluationCase", + "EvaluationCase", "EvaluationResult", "EvaluationReport", ] diff --git a/sql_eval/cli.py b/sql_eval/cli.py index 0a795bf..03a6efc 100644 --- a/sql_eval/cli.py +++ b/sql_eval/cli.py @@ -3,11 +3,10 @@ sql-eval Command Line Interface """ -import sys -import json import argparse +import json +import sys from datetime import datetime -from pathlib import Path def main(): @@ -15,9 +14,9 @@ def main(): prog='sql-eval', description='Text-to-SQL Evaluation Framework' ) - + subparsers = parser.add_subparsers(dest='command', help='Commands') - + # Run command run_parser = subparsers.add_parser('run', help='Run evaluation') run_parser.add_argument( @@ -62,14 +61,14 @@ def main(): action='store_true', help='Suppress progress output' ) - + # List datasets command - list_parser = subparsers.add_parser('list', help='List available datasets') - + subparsers.add_parser('list', help='List available datasets') + # Dataset info command info_parser = subparsers.add_parser('info', help='Show dataset information') info_parser.add_argument('dataset', help='Dataset name') - + # Compare command compare_parser = subparsers.add_parser('compare', help='Compare multiple LLMs') compare_parser.add_argument( @@ -88,13 +87,13 @@ def main(): default=None, help='Limit number of questions' ) - + args = parser.parse_args() - + if args.command is None: parser.print_help() return - + if args.command == 'list': cmd_list() elif args.command == 'info': @@ -107,12 +106,12 @@ def main(): def cmd_list(): """List available datasets""" - from .datasets import list_datasets, get_dataset_info - + from .datasets import get_dataset_info, list_datasets + datasets = list_datasets() print("\nAvailable datasets:") print("-" * 40) - + for name in datasets: info = get_dataset_info(name) print(f" {name}") @@ -125,23 +124,23 @@ def cmd_list(): def cmd_info(dataset_name: str): """Show dataset information""" from .datasets import get_dataset_info - + try: info = get_dataset_info(dataset_name) except ValueError as e: print(f"Error: {e}") sys.exit(1) - + print(f"\nDataset: {info['name']}") print("=" * 40) print(f"Questions: {info['num_questions']}") print(f"Tables: {', '.join(info['tables'])}") print(f"Has seed data: {'Yes' if info['has_seed_data'] else 'No'}") - + print("\nDifficulty breakdown:") for diff, count in info['difficulty_breakdown'].items(): print(f" {diff}: {count}") - + print("\nCategory breakdown:") for cat, count in sorted(info['category_breakdown'].items(), key=lambda x: -x[1]): print(f" {cat}: {count}") @@ -149,14 +148,14 @@ def cmd_info(dataset_name: str): def cmd_run(args): """Run evaluation""" + from .core.evaluator import Evaluator from .datasets import load_dataset from .llm_providers import get_provider - from .core.evaluator import Evaluator - + print(f"\n{'='*50}") print("sql-eval v0.1.0") print(f"{'='*50}") - + # Load dataset print(f"\nLoading dataset: {args.dataset}") try: @@ -167,11 +166,11 @@ def cmd_run(args): except ValueError as e: print(f"Error: {e}") sys.exit(1) - + # Apply limit if args.limit: test_cases = test_cases[:args.limit] - + # Get LLM provider print(f"LLM: {args.llm}" + (f" ({args.model})" if args.model else "")) try: @@ -179,7 +178,7 @@ def cmd_run(args): except Exception as e: print(f"Error initializing LLM provider: {e}") sys.exit(1) - + # Run evaluation evaluator = Evaluator( llm_provider=provider, @@ -187,21 +186,21 @@ def cmd_run(args): db_connector=db, verbose=not args.quiet ) - + print(f"\nEvaluating {len(test_cases)} questions...") - + try: report = evaluator.evaluate(test_cases, run_execution_tests=args.with_execution) except Exception as e: print(f"\nError during evaluation: {e}") sys.exit(1) - + # Output results if args.output == 'json': output_json(report, args.output_file) elif args.output == 'html': output_html(report, args.output_file) - + # Cleanup if db: db.disconnect() @@ -209,29 +208,29 @@ def cmd_run(args): def cmd_compare(args): """Compare multiple LLMs""" + from .core.evaluator import Evaluator from .datasets import load_dataset from .llm_providers import get_provider - from .core.evaluator import Evaluator - - llm_names = [l.strip() for l in args.llms.split(',')] - + + llm_names = [llm_name.strip() for llm_name in args.llms.split(',')] + print(f"\n{'='*50}") print("sql-eval - LLM Comparison") print(f"{'='*50}") print(f"Dataset: {args.dataset}") print(f"LLMs: {', '.join(llm_names)}") - + # Load dataset once test_cases, schema, db = load_dataset(args.dataset, with_db=False) - + if args.limit: test_cases = test_cases[:args.limit] - + results = {} - + for llm_name in llm_names: print(f"\n--- Evaluating: {llm_name} ---") - + try: provider = get_provider(llm_name) evaluator = Evaluator( @@ -244,14 +243,14 @@ def cmd_compare(args): except Exception as e: print(f"Error with {llm_name}: {e}") results[llm_name] = None - + # Print comparison print(f"\n{'='*60}") print("COMPARISON RESULTS") print(f"{'='*60}") print(f"{'LLM':<20} {'Exact Match':<15} {'Structural':<15} {'Latency':<10}") print("-" * 60) - + for llm_name, report in results.items(): if report: print( @@ -267,7 +266,7 @@ def cmd_compare(args): def output_json(report, filepath=None): """Output report as JSON""" data = report.to_dict() - + if filepath: with open(filepath, 'w') as f: json.dump(data, f, indent=2) @@ -283,12 +282,12 @@ def output_html(report, filepath=None): """Output report as HTML""" if not filepath: filepath = f"sql_eval_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.html" - + html = generate_html_report(report) - + with open(filepath, 'w') as f: f.write(html) - + print(f"\nHTML report saved to: {filepath}") @@ -323,7 +322,7 @@ def generate_html_report(report) -> str:
Generated: {report.timestamp}
LLM: {report.llm_provider} ({report.llm_model})
- +| Difficulty | Accuracy | Correct | Total |
|---|---|---|---|
| {diff} | @@ -357,14 +356,14 @@ def generate_html_report(report) -> str:{stats['total']} |
| Category | Accuracy | Correct | Total |
|---|---|---|---|
| {cat} | @@ -373,14 +372,14 @@ def generate_html_report(report) -> str:{stats['total']} |
| ID | Question | Result | Latency | {r.latency_ms:.0f}ms | """ - + html += """
|---|