|
12 | 12 | import threading |
13 | 13 | import time |
14 | 14 | from dataclasses import asdict |
15 | | -from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union |
| 15 | +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union |
16 | 16 |
|
17 | 17 | import anyio |
18 | | -import httpx |
19 | 18 | from openai.types import CompletionUsage |
20 | 19 |
|
21 | 20 | from vendor.tau2.data_model.message import AssistantMessage, UserMessage |
22 | 21 | from vendor.tau2.user.user_simulator import UserSimulator |
23 | 22 |
|
24 | | -from ...models import EvaluationRow, InputMetadata, Message |
25 | | -from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory |
| 23 | +from ...models import EvaluationRow, InputMetadata, Message, RolloutStatus |
| 24 | +from ...types import TerminationReason, Trajectory, NonSkippableException |
26 | 25 |
|
27 | 26 | if TYPE_CHECKING: |
28 | 27 | from ..session.manager import GeneralMCPVectorEnv |
@@ -107,7 +106,7 @@ async def _execute_with_semaphore(idx): |
107 | 106 | ) |
108 | 107 |
|
109 | 108 | # Convert trajectory to EvaluationRow immediately |
110 | | - evaluation_row = evaluation_rows[idx] |
| 109 | + evaluation_row: EvaluationRow = evaluation_rows[idx] |
111 | 110 |
|
112 | 111 | # Handle multimodal content by extracting text from complex content structures |
113 | 112 | messages = [] |
@@ -137,16 +136,15 @@ async def _execute_with_semaphore(idx): |
137 | 136 | } |
138 | 137 |
|
139 | 138 | if trajectory.terminated: |
140 | | - if trajectory.termination_reason == TerminationReason.ERROR: |
141 | | - evaluation_row.rollout_status.status = "error" |
142 | | - evaluation_row.rollout_status.termination_reason = trajectory.control_plane_summary.get( |
143 | | - "error_message", None |
144 | | - ) |
145 | | - else: |
146 | | - evaluation_row.rollout_status.status = "finished" |
147 | | - evaluation_row.rollout_status.termination_reason = trajectory.termination_reason |
| 139 | + evaluation_row.rollout_status.termination_reason = trajectory.termination_reason |
| 140 | + evaluation_row.rollout_status.status = RolloutStatus.Status.FINISHED |
| 141 | + # preserve the true error message if there is one |
| 142 | + if trajectory.control_plane_summary.get("error_message"): |
| 143 | + evaluation_row.rollout_status.extra_info = { |
| 144 | + "error_message": trajectory.control_plane_summary.get("error_message") |
| 145 | + } |
148 | 146 | else: |
149 | | - evaluation_row.rollout_status.status = "running" |
| 147 | + evaluation_row.rollout_status.status = RolloutStatus.Status.RUNNING |
150 | 148 |
|
151 | 149 | return evaluation_row |
152 | 150 |
|
@@ -437,31 +435,18 @@ async def _execute_rollout( |
437 | 435 | logger.info( |
438 | 436 | f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}" |
439 | 437 | ) |
440 | | - |
441 | | - except asyncio.CancelledError: |
442 | | - failure_reason = "asyncio context cancelled" |
443 | | - logger.error( |
444 | | - f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True |
445 | | - ) |
446 | | - except (anyio.ClosedResourceError, anyio.BrokenResourceError): |
447 | | - failure_reason = "anyioconnection/resource error" |
448 | | - logger.error( |
449 | | - f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {failure_reason}", exc_info=True |
450 | | - ) |
451 | | - except Exception as e: |
452 | | - error_msg = str(e) if str(e) else f"{type(e).__name__}: Unexpected error" |
453 | | - logger.error(f"🚨 Error in rollout {session.dataset_row.id} {rollout_idx}: {error_msg}", exc_info=True) |
454 | | - failure_reason = error_msg |
| 438 | + except NonSkippableException as e: |
| 439 | + # terminate the rollout right away, with no retry, and preserve the current trajectory history. |
| 440 | + # for other types of exceptions, keep propagating them to upper layers and let the retry handler deal with them. |
| 441 | + trajectory.terminated = True |
| 442 | + trajectory.termination_reason = TerminationReason.NON_SKIPPABLE_ERROR |
| 443 | + trajectory.control_plane_summary.update({"error_message": str(e)}) |
| 444 | + logger.error(f"🚨 Rollout {rollout_idx} terminated due to non-skippable error: {str(e)}", exc_info=True) |
455 | 445 | finally: |
456 | | - if failure_reason: |
457 | | - trajectory.terminated = True |
458 | | - trajectory.termination_reason = TerminationReason.ERROR |
459 | | - trajectory.control_plane_summary.update({"error_message": f"{failure_reason}"}) |
460 | 446 | try: |
461 | 447 | await envs.connection_manager.reset_session(session) |
462 | 448 | except Exception as e: |
463 | 449 | logger.warning(f"Failed to reset session {session.session_id}: {type(e).__name__}: {e}", exc_info=True) |
464 | | - |
465 | 450 | try: |
466 | 451 | await envs.connection_manager.close_session(session) |
467 | 452 | except Exception as e: |
|
0 commit comments