|
3 | 3 | import os |
4 | 4 | import re |
5 | 5 | from dataclasses import replace |
6 | | -from typing import Any, Callable, Dict, List, Literal, Optional, Union |
| 6 | +from typing import Any, Callable, List, Literal, Optional, Union |
| 7 | + |
| 8 | +from litellm import cost_per_token |
7 | 9 |
|
8 | 10 | from eval_protocol.dataset_logger.dataset_logger import DatasetLogger |
9 | | -from eval_protocol.models import EvalMetadata, EvaluationRow, Status |
| 11 | +from eval_protocol.models import CostMetrics, EvalMetadata, EvaluationRow, Status |
10 | 12 | from eval_protocol.pytest.rollout_processor import RolloutProcessor |
11 | 13 | from eval_protocol.pytest.types import ( |
12 | 14 | CompletionParams, |
@@ -435,3 +437,31 @@ def extract_effort_tag(params: dict) -> Optional[str]: |
435 | 437 | except Exception: |
436 | 438 | return None |
437 | 439 | return None |
| 440 | + |
| 441 | + |
| 442 | +def calculate_cost_metrics_for_row(row: EvaluationRow) -> None: |
| 443 | + """Calculate and set cost metrics for an EvaluationRow based on its usage data.""" |
| 444 | + if not row.execution_metadata.usage: |
| 445 | + return |
| 446 | + |
| 447 | + model_id = ( |
| 448 | + row.input_metadata.completion_params.get("model", "unknown") |
| 449 | + if row.input_metadata.completion_params |
| 450 | + else "unknown" |
| 451 | + ) |
| 452 | + usage = row.execution_metadata.usage |
| 453 | + |
| 454 | + input_tokens = usage.prompt_tokens or 0 |
| 455 | + output_tokens = usage.completion_tokens or 0 |
| 456 | + |
| 457 | + input_cost, output_cost = cost_per_token( |
| 458 | + model=model_id, prompt_tokens=input_tokens, completion_tokens=output_tokens |
| 459 | + ) |
| 460 | + total_cost = input_cost + output_cost |
| 461 | + |
| 462 | + # Set all cost metrics on the row |
| 463 | + row.execution_metadata.cost_metrics = CostMetrics( |
| 464 | + input_cost_usd=input_cost, |
| 465 | + output_cost_usd=output_cost, |
| 466 | + total_cost_usd=total_cost, |
| 467 | + ) |
0 commit comments