Skip to content

Commit 159a12f

Browse files
authored
differentiate skippable exception and non-skippable exception (#96)
* custom rollout exception and termination_reason=interrupted * add nonskippableerror * format
1 parent ff034a8 commit 159a12f

6 files changed

Lines changed: 62 additions & 53 deletions

File tree

eval_protocol/mcp/execution/manager.py

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,16 @@
1212
import threading
1313
import time
1414
from dataclasses import asdict
15-
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union
15+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
1616

1717
import anyio
18-
import httpx
1918
from openai.types import CompletionUsage
2019

2120
from vendor.tau2.data_model.message import AssistantMessage, UserMessage
2221
from vendor.tau2.user.user_simulator import UserSimulator
2322

24-
from ...models import EvaluationRow, InputMetadata, Message
25-
from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory
23+
from ...models import EvaluationRow, InputMetadata, Message, RolloutStatus
24+
from ...types import TerminationReason, Trajectory, NonSkippableException
2625

2726
if TYPE_CHECKING:
2827
from ..session.manager import GeneralMCPVectorEnv
@@ -107,7 +106,7 @@ async def _execute_with_semaphore(idx):
107106
)
108107

109108
# Convert trajectory to EvaluationRow immediately
110-
evaluation_row = evaluation_rows[idx]
109+
evaluation_row: EvaluationRow = evaluation_rows[idx]
111110

112111
# Handle multimodal content by extracting text from complex content structures
113112
messages = []
@@ -137,16 +136,15 @@ async def _execute_with_semaphore(idx):
137136
}
138137

139138
if trajectory.terminated:
140-
if trajectory.termination_reason == TerminationReason.ERROR:
141-
evaluation_row.rollout_status.status = "error"
142-
evaluation_row.rollout_status.termination_reason = trajectory.control_plane_summary.get(
143-
"error_message", None
144-
)
145-
else:
146-
evaluation_row.rollout_status.status = "finished"
147-
evaluation_row.rollout_status.termination_reason = trajectory.termination_reason
139+
evaluation_row.rollout_status.termination_reason = trajectory.termination_reason
140+
evaluation_row.rollout_status.status = RolloutStatus.Status.FINISHED
141+
# preserve the true error mesage if there are any
142+
if trajectory.control_plane_summary.get("error_message"):
143+
evaluation_row.rollout_status.extra_info = {
144+
"error_message": trajectory.control_plane_summary.get("error_message")
145+
}
148146
else:
149-
evaluation_row.rollout_status.status = "running"
147+
evaluation_row.rollout_status.status = RolloutStatus.Status.RUNNING
150148

151149
return evaluation_row
152150

@@ -437,31 +435,18 @@ async def _execute_rollout(
437435
logger.info(
438436
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
439437
)
440-
441-
except asyncio.CancelledError:
442-
failure_reason = "asyncio context cancelled"
443-
logger.error(
444-
f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True
445-
)
446-
except (anyio.ClosedResourceError, anyio.BrokenResourceError):
447-
failure_reason = "anyioconnection/resource error"
448-
logger.error(
449-
f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True
450-
)
451-
except Exception as e:
452-
error_msg = str(e) if str(e) else f"{type(e).__name__}: Unexpected error"
453-
logger.error(f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {error_msg}", exc_info=True)
454-
failure_reason = error_msg
438+
except NonSkippableException as e:
439+
# terminate the rollout right away, no retry and preserve the current trajectory history.
440+
# for other types of exceptions, keep propagate them to upper layers and handle them with retry handler.
441+
trajectory.terminated = True
442+
trajectory.termination_reason = TerminationReason.NON_SKIPPABLE_ERROR
443+
trajectory.control_plane_summary.update({"error_message": str(e)})
444+
logger.error(f"🚨 Rollout {rollout_idx} terminated due to non-skippable error: {str(e)}", exc_info=True)
455445
finally:
456-
if failure_reason:
457-
trajectory.terminated = True
458-
trajectory.termination_reason = TerminationReason.ERROR
459-
trajectory.control_plane_summary.update({"error_message": f"{failure_reason}"})
460446
try:
461447
await envs.connection_manager.reset_session(session)
462448
except Exception as e:
463449
logger.warning(f"Failed to reset session {session.session_id}: {type(e).__name__}: {e}", exc_info=True)
464-
465450
try:
466451
await envs.connection_manager.close_session(session)
467452
except Exception as e:

eval_protocol/models.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from datetime import datetime
3+
from enum import Enum
34
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
45

56
from openai.types import CompletionUsage
@@ -11,6 +12,7 @@
1112

1213
from eval_protocol.get_pep440_version import get_pep440_version
1314
from eval_protocol.human_id import generate_id
15+
from eval_protocol.types import TerminationReason
1416

1517

1618
class ChatCompletionContentPartTextParam(BaseModel):
@@ -285,14 +287,20 @@ class RolloutStatus(BaseModel):
285287

286288
"""
287289
running: Unfinished rollout which is still in progress.
288-
finished: Rollout finished successfully.
289-
error: Rollout failed.
290-
stopped: Rollout terminated unexpectedly (e.g. max step, control plane signal, user stop).
290+
finished: Rollout finished.
291+
error: Rollout failed due to unexpected error. The rollout record should be discard.
291292
"""
292-
status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.")
293-
termination_reason: Optional[str] = Field(
294-
"", description="reason of the rollout status, mapped to values in TerminationReason"
293+
294+
class Status(str, Enum):
295+
RUNNING = "running"
296+
FINISHED = "finished"
297+
ERROR = "error"
298+
299+
status: Status = Field(Status.RUNNING, description="Status of the rollout.")
300+
termination_reason: Optional[TerminationReason] = Field(
301+
None, description="reason of the rollout status, mapped to values in TerminationReason"
295302
)
303+
extra_info: Optional[Dict[str, Any]] = Field(None, description="Extra information about the rollout status.")
296304

297305

298306
class EvaluationRow(BaseModel):

eval_protocol/pytest/utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Callable, Dict, List, Literal, Optional, Union
77

88
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
9-
from eval_protocol.models import EvalMetadata, EvaluationRow
9+
from eval_protocol.models import EvalMetadata, EvaluationRow, RolloutStatus
1010
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1111
from eval_protocol.pytest.types import (
1212
CompletionParams,
@@ -248,7 +248,7 @@ async def rollout_processor_with_retry(
248248
"""
249249

250250
try:
251-
queue = asyncio.Queue()
251+
queue: asyncio.Queue[EvaluationRow] = asyncio.Queue()
252252
retry_counts = {r.execution_metadata.rollout_id: 0 for r in fresh_dataset}
253253
failed_permanently = []
254254

@@ -257,7 +257,7 @@ async def retry_handler(failed_row: EvaluationRow):
257257
current_attempts = retry_counts.get(rollout_id, 0)
258258

259259
if current_attempts >= max_retry:
260-
assert failed_row.rollout_status and failed_row.rollout_status.status == "error", (
260+
assert failed_row.rollout_status and failed_row.rollout_status.status == RolloutStatus.Status.ERROR, (
261261
f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status"
262262
)
263263
failed_permanently.append(failed_row)
@@ -273,10 +273,10 @@ async def retry_handler(failed_row: EvaluationRow):
273273

274274
try:
275275
retry_result = await retry_tasks[0]
276-
retry_result.rollout_status.status = "finished"
276+
retry_result.rollout_status.status = RolloutStatus.Status.FINISHED
277277
await queue.put(retry_result)
278278
except Exception as e:
279-
failed_row.rollout_status.status = "error"
279+
failed_row.rollout_status.status = RolloutStatus.Status.ERROR
280280
failed_row.rollout_status.termination_reason = str(e)
281281
asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry
282282

@@ -299,11 +299,11 @@ async def initial_processor():
299299

300300
try:
301301
result = await task
302-
result.rollout_status.status = "finished"
302+
result.rollout_status.status = RolloutStatus.Status.FINISHED
303303
await queue.put(result)
304304
except Exception as e:
305305
failed_row = fresh_dataset[task_index]
306-
failed_row.rollout_status.status = "error"
306+
failed_row.rollout_status.status = RolloutStatus.Status.ERROR
307307
failed_row.rollout_status.termination_reason = str(e)
308308
asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task
309309

@@ -317,7 +317,7 @@ async def initial_processor():
317317
finished_row = await queue.get()
318318

319319
# only permanent failure rows are put on the queue, so we can check for them here
320-
if finished_row.rollout_status and finished_row.rollout_status.status == "error":
320+
if finished_row.rollout_status and finished_row.rollout_status.status == RolloutStatus.Status.ERROR:
321321
if max_retry > 0 and os.getenv("EP_FAIL_ON_MAX_RETRY", "true") != "false":
322322
raise RuntimeError(
323323
f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. Errors: {finished_row.rollout_status.termination_reason}"

eval_protocol/types/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .types import DatasetRow, MCPSession, MCPToolCall, TerminationReason, Trajectory
2+
from .errors import NonSkippableException
23

3-
__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow"]
4+
__all__ = ["MCPSession", "MCPToolCall", "TerminationReason", "Trajectory", "DatasetRow", "NonSkippableException"]

eval_protocol/types/errors.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
class NonSkippableException(Exception):
2+
"""
3+
A type of custom exception raised during rollout or evaluation. This error means the rollout/evaluation result is not skippable and need to be
4+
processed explicitly.
5+
6+
For example, if the policy (llm) returns 400 User error, we need to end the rollout but keep the trajectory.
7+
It differs from other exceptions such as network error, which are retriable and the trajectory should be discarded if
8+
it fails eventually after retries.
9+
"""
10+
11+
pass

eval_protocol/types/types.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ class TerminationReason(str, Enum):
1212
MAX_STEPS: Trajectory ends because we hit the step limit
1313
CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. env goal reached or failure condition)
1414
USER_STOP: Trajectory ends because the simulated user signals to stop
15-
ERROR: Trajectory ends because of an error
15+
SKIPPABLE_ERROR: Trajectory ends because of an error, this trajectory can be discarded/skipped during postprocessing/evaluation.
16+
NON_SKIPPABLE_ERROR: Trajectory is interrupted due to some non-skippable error (e.g. policy returns unexpected response and we need to terminate the rollout).
1617
STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop")
1718
LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length")
1819
TOOL_CALLS: Trajectory ends by the policy with a hanging tool call response (mapped to llm response stop reason "tool_calls")
@@ -21,7 +22,8 @@ class TerminationReason(str, Enum):
2122
MAX_STEPS = "max_steps"
2223
CONTROL_PLANE_SIGNAL = "control_plane_signal"
2324
USER_STOP = "user_stop"
24-
ERROR = "error"
25+
SKIPPABLE_ERROR = "skippable_error"
26+
NON_SKIPPABLE_ERROR = "non_skippable_error"
2527
STOP = "stop"
2628
LENGTH = "length"
2729
TOOL_CALLS = "tool_calls"
@@ -38,8 +40,10 @@ def from_str(cls, value: str) -> "TerminationReason":
3840
return cls.CONTROL_PLANE_SIGNAL
3941
elif value == "user_stop":
4042
return cls.USER_STOP
41-
elif value == "error":
42-
return cls.ERROR
43+
elif value == "skippable_error":
44+
return cls.SKIPPABLE_ERROR
45+
elif value == "non_skippable_error":
46+
return cls.NON_SKIPPABLE_ERROR
4347
elif value == "tool_calls":
4448
return cls.TOOL_CALLS
4549
else:

0 commit comments

Comments
 (0)