diff --git a/functioncalling_finetuning/evals/evaluation_script.py b/functioncalling_finetuning/evals/evaluation_script.py index c5c4813..4257d51 100644 --- a/functioncalling_finetuning/evals/evaluation_script.py +++ b/functioncalling_finetuning/evals/evaluation_script.py @@ -171,14 +171,10 @@ def load_hermes_dataset(max_samples=500): return samples -async def get_model_response_async(model, system_prompt, user_query, semaphore): +async def get_model_response_async(model, system_prompt, user_query, semaphore, client): """Get OpenAI model response asynchronously with rate limiting.""" async with semaphore: - client = AsyncOpenAI( - api_key=API_KEY, - base_url=BASE_URL, # Adjust if using a different endpoint - ) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) @@ -346,12 +342,12 @@ def print_model_summary(model, results): print(f"Format Valid: {format_valid:.3f}") -async def process_single_sample(sample: Dict[str, Any], model: str, semaphore: asyncio.Semaphore) -> Dict[str, Any]: +async def process_single_sample(sample: Dict[str, Any], model: str, semaphore: asyncio.Semaphore, client) -> Dict[str, Any]: """Process a single sample asynchronously.""" # Get model response response = await get_model_response_async( - model, sample["system_prompt"], sample["user_query"], semaphore + model, sample["system_prompt"], sample["user_query"], semaphore, client ) # print(f"Sample {sample['id']} response: {response}") @@ -386,6 +382,12 @@ async def run_benchmark_async(eval_samples: List[Dict[str, Any]], models: List[s all_results = [] + # Create shared client for all requests + client = AsyncOpenAI( + api_key=API_KEY, + base_url=BASE_URL + ) + for model in models: print(f"\nEvaluating {model}...") start_time = time.time() @@ -397,7 +399,7 @@ async def run_benchmark_async(eval_samples: List[Dict[str, Any]], models: List[s tasks = [] for sample in eval_samples: task = asyncio.create_task( - process_single_sample(sample, model, semaphore) + process_single_sample(sample, model, semaphore, client) ) tasks.append(task)