Skip to content

Commit b6a6ec5

Browse files
committed
wip
1 parent f058c32 commit b6a6ec5

6 files changed

Lines changed: 374 additions & 130 deletions

File tree

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,15 @@ async def default_agent_rollout_processor(
117117
) -> List[EvaluationRow]:
118118
dataset: Dataset = []
119119
for row in rows:
120-
agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger)
121-
await agent.setup()
122-
await agent.call_agent()
123-
dataset.append(agent.evaluation_row)
124-
if agent.mcp_client:
125-
await agent.mcp_client.cleanup()
120+
try:
121+
agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger)
122+
await agent.setup()
123+
await agent.call_agent()
124+
dataset.append(agent.evaluation_row)
125+
if agent.mcp_client:
126+
await agent.mcp_client.cleanup()
127+
except Exception as e:
128+
row.rollout_status.status = "error"
129+
row.rollout_status.error_message = str(e)
130+
dataset.append(row)
126131
return dataset

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,12 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
8787

8888
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
8989
async with semaphore:
90-
return await process_row(r)
90+
try:
91+
return await process_row(r)
92+
except Exception as e:
93+
r.rollout_status.status = "error"
94+
r.rollout_status.error_message = str(e)
95+
return r
9196

9297
tasks = [_sem_wrapper(row) for row in rows]
9398
dataset = list(await asyncio.gather(*tasks))

eval_protocol/pytest/evaluation_test.py

Lines changed: 111 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -401,133 +401,132 @@ def _log_eval_error(
401401
logger=active_logger,
402402
)
403403

404+
max_retry = int(os.getenv("EP_MAX_RETRY", "0"))
405+
404406
for i in range(num_runs):
405-
# Regenerate outputs each run by deep-copying the pristine dataset
406-
# so model responses are not reused across runs.
407407
run_id = generate_id()
408-
fresh_dataset = [r.model_copy(deep=True) for r in data]
409-
410-
# apply new run_id to fresh_dataset
411-
for row in fresh_dataset:
412-
row.run_id = run_id
413-
414-
# generate new rollout_id for each row
415-
for row in fresh_dataset:
416-
row.rollout_id = generate_id()
417-
418-
# log the fresh_dataset
419-
for row in fresh_dataset:
420-
active_logger.log(row)
421-
422-
# filter out rows that already have completed rollouts via checkpointing
423-
rows_to_process = []
424-
completed_rollout_ids = set()
425-
426-
finished_logs = active_logger.read()
427-
428-
for finished_row in finished_logs:
429-
# need to add finished rows to all_results so that we can aggregate them later.
430-
all_results.append(finished_row)
431-
# TODO: need to also add the num_run to track which run the row belongs to.
432-
# TODO: ask why we made row_id optional in the first place. checkpointing won't work without some ID.
433-
if finished_row.input_metadata and finished_row.input_metadata.row_id:
434-
completed_rollout_ids.add(finished_row.input_metadata.row_id)
435-
436-
for row in fresh_dataset:
437-
row_id = row.input_metadata.row_id if row.input_metadata else None
438-
if row_id not in completed_rollout_ids:
439-
rows_to_process.append(row)
440-
441-
if len(rows_to_process) < len(fresh_dataset):
442-
print(
443-
f"Checkpointing: Found {len(fresh_dataset) - len(rows_to_process)} completed rows, processing {len(rows_to_process)} remaining rows"
444-
)
445-
446-
if rows_to_process:
447-
processed_dataset = execute_function(
448-
rollout_processor, rows=rows_to_process, config=config
449-
)
450-
451-
if mode == "pointwise":
452-
# Pointwise mode: apply the evaluator function to each row
453-
for row in processed_dataset:
454-
result = execute_with_params(
408+
retry_attempt = 0
409+
current_data = data
410+
411+
while retry_attempt <= max_retry:
412+
if retry_attempt > 0:
413+
logged_rows = active_logger.read()
414+
failed_rows = [
415+
row
416+
for row in logged_rows
417+
if row.rollout_status
418+
and row.rollout_status.status == "error"
419+
and row.run_id == run_id
420+
]
421+
if not failed_rows:
422+
break
423+
current_data = failed_rows
424+
425+
# Regenerate outputs each run by deep-copying the pristine dataset
426+
# so model responses are not reused across runs.
427+
fresh_dataset = [r.model_copy(deep=True) for r in current_data]
428+
429+
# apply new run_id to fresh_dataset
430+
for row in fresh_dataset:
431+
row.run_id = run_id
432+
433+
# generate new rollout_id for each row
434+
for row in fresh_dataset:
435+
row.rollout_id = generate_id()
436+
437+
# log the fresh_dataset
438+
for row in fresh_dataset:
439+
active_logger.log(row)
440+
441+
processed_dataset = execute_function(rollout_processor, rows=fresh_dataset, config=config)
442+
443+
if mode == "pointwise":
444+
# Pointwise mode: apply the evaluator function to each row
445+
for row in processed_dataset:
446+
result = execute_with_params(
447+
test_func,
448+
processed_row=row,
449+
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
450+
)
451+
if result is None or not isinstance(result, EvaluationRow):
452+
raise ValueError(
453+
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
454+
)
455+
# TODO: not this simple, only append ones that are not error
456+
all_results[i].append(result)
457+
else:
458+
# Batch mode: call the test function with the full dataset
459+
results = execute_with_params(
455460
test_func,
456-
processed_row=row,
461+
processed_dataset=processed_dataset,
457462
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
458463
)
459-
if result is None or not isinstance(result, EvaluationRow):
464+
if results is None:
460465
raise ValueError(
461466
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
462467
)
463-
all_results[i].append(result)
464-
else:
465-
# Batch mode: call the test function with the full dataset
466-
results = execute_with_params(
467-
test_func,
468-
processed_dataset=processed_dataset,
469-
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
470-
)
471-
if results is None:
472-
raise ValueError(
473-
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
474-
)
475-
if not isinstance(results, list):
476-
raise ValueError(
477-
f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
478-
)
479-
if not results:
480-
raise ValueError(
481-
f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test."
482-
)
483-
if not all(isinstance(r, EvaluationRow) for r in results):
484-
raise ValueError(
485-
f"Test function {test_func.__name__} returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
468+
if not isinstance(results, list):
469+
raise ValueError(
470+
f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
471+
)
472+
if not results:
473+
raise ValueError(
474+
f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test."
475+
)
476+
if not all(isinstance(r, EvaluationRow) for r in results):
477+
raise ValueError(
478+
f"Test function {test_func.__name__} returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
479+
)
480+
# TODO: not this simple, only append ones that are not error
481+
all_results[i] = results
482+
483+
retry_attempt += 1
484+
485+
scores = [
486+
sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result)
487+
for result in all_results
488+
]
489+
agg_score = aggregate(scores, aggregation_method)
490+
score_std = statistics.stdev(scores) if len(scores) > 1 else 0.0
491+
492+
# Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats)
493+
ci_low: float | None = None
494+
ci_high: float | None = None
495+
if aggregation_method == "mean":
496+
try:
497+
result_ci = compute_fixed_set_mu_ci(
498+
[item for sublist in all_results for item in sublist]
486499
)
487-
all_results[i] = results
488-
489-
scores = [
490-
sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result)
491-
for result in all_results
492-
]
493-
agg_score = aggregate(scores, aggregation_method)
494-
score_std = statistics.stdev(scores) if len(scores) > 1 else 0.0
495-
496-
# Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats)
497-
ci_low: float | None = None
498-
ci_high: float | None = None
499-
if aggregation_method == "mean":
500-
try:
501-
result_ci = compute_fixed_set_mu_ci([item for sublist in all_results for item in sublist])
502-
mu_ci_low, mu_ci_high = result_ci[1], result_ci[2]
503-
if mu_ci_low is not None and mu_ci_high is not None:
504-
ci_low = float(mu_ci_low)
505-
ci_high = float(mu_ci_high)
506-
# Keep agg_score as-is (mean over scores). For equal repeats per question these match.
507-
except Exception:
508-
ci_low = None
509-
ci_high = None
510-
511-
# Determine if the evaluation passed based on threshold
512-
passed = None
513-
514-
if threshold is not None:
515-
success_passed, std_passed = True, True
516-
517-
success_passed = agg_score >= threshold.success
500+
mu_ci_low, mu_ci_high = result_ci[1], result_ci[2]
501+
if mu_ci_low is not None and mu_ci_high is not None:
502+
ci_low = float(mu_ci_low)
503+
ci_high = float(mu_ci_high)
504+
# Keep agg_score as-is (mean over scores). For equal repeats per question these match.
505+
except Exception:
506+
ci_low = None
507+
ci_high = None
518508

519-
if threshold.standard_deviation is not None:
520-
std_passed = score_std <= threshold.standard_deviation
509+
# Determine if the evaluation passed based on threshold
510+
passed = None
511+
512+
if threshold is not None:
513+
success_passed, std_passed = True, True
514+
515+
success_passed = agg_score >= threshold.success
516+
517+
if threshold.standard_deviation is not None:
518+
std_passed = score_std <= threshold.standard_deviation
521519

522-
passed = success_passed and std_passed
520+
passed = success_passed and std_passed
523521

524522
# Update eval metadata status and passed field for all results
525523
for result in all_results:
526524
for r in result:
527-
if r.eval_metadata is not None:
528-
r.eval_metadata.status = "finished"
529-
r.eval_metadata.passed = passed
530-
default_logger.log(r)
525+
if r.rollout_status is not None:
526+
if r.rollout_status.status != "error":
527+
r.rollout_status.status = "finished"
528+
r.rollout_status.passed = passed
529+
active_logger.log(r)
531530

532531
# Optional: print and/or persist a summary artifact for CI
533532
try:

eval_protocol/pytest/plugin.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
max_dataset_rows value set in the decorator).
1313
"""
1414

15-
import os
1615
import logging
16+
import os
1717
from typing import Optional
1818

1919

@@ -32,17 +32,13 @@ def pytest_addoption(parser) -> None:
3232
"--ep-print-summary",
3333
action="store_true",
3434
default=False,
35-
help=(
36-
"Print a concise summary line (suite/model/effort/agg score) at the end of each evaluation_test."
37-
),
35+
help=("Print a concise summary line (suite/model/effort/agg score) at the end of each evaluation_test."),
3836
)
3937
group.addoption(
4038
"--ep-summary-json",
4139
action="store",
4240
default=None,
43-
help=(
44-
"Write a JSON summary artifact at the given path (e.g., ./outputs/aime_low.json)."
45-
),
41+
help=("Write a JSON summary artifact at the given path (e.g., ./outputs/aime_low.json)."),
4642
)
4743
group.addoption(
4844
"--ep-input-param",
@@ -63,6 +59,13 @@ def pytest_addoption(parser) -> None:
6359
"Values: low|medium|high"
6460
),
6561
)
62+
group.addoption(
63+
"--ep-max-retry",
64+
action="store",
65+
type=int,
66+
default=None,
67+
help=("Failed rollouts (with rollout_status.status == 'error') will be retried up to this many times."),
68+
)
6669

6770

6871
def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
@@ -104,10 +107,15 @@ def pytest_configure(config) -> None:
104107
if summary_json_path:
105108
os.environ["EP_SUMMARY_JSON"] = summary_json_path
106109

110+
max_retry = config.getoption("--ep-max-retry")
111+
if max_retry is not None:
112+
os.environ["EP_MAX_RETRY"] = str(max_retry)
113+
107114
# Allow ad-hoc overrides of input params via CLI flags
108115
try:
109116
import json as _json
110117
import pathlib as _pathlib
118+
111119
merged: dict = {}
112120
input_params_opts = config.getoption("--ep-input-param")
113121
if input_params_opts:
@@ -140,5 +148,3 @@ def pytest_configure(config) -> None:
140148
except Exception:
141149
# best effort, do not crash pytest session
142150
pass
143-
144-

tests/pytest/test_tau_bench_airline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval
5858

5959
rows.append(eval_row)
6060

61-
return rows
61+
return rows[0:1]
6262

6363

6464
@evaluation_test(
@@ -68,7 +68,7 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval
6868
rollout_input_params=[{"temperature": 0.8, "max_tokens": 4096, "reasoning_effort": "low"}],
6969
rollout_processor=default_mcp_gym_rollout_processor,
7070
passed_threshold={"success": 0.4, "standard_deviation": 0.1},
71-
num_runs=8,
71+
num_runs=1,
7272
mode="pointwise",
7373
max_concurrent_rollouts=50,
7474
server_script_path="examples/tau2_mcp/server.py",

0 commit comments

Comments
 (0)