|
| 1 | +"""Default training-aware :class:`RolloutProcessor` for Fireworks RFT. |
| 2 | +
|
| 3 | +Unlike :class:`SingleTurnRolloutProcessor`, which uses LiteLLM chat |
| 4 | +completions and discards token-level information, this processor drives |
| 5 | +``FireworksV1CompletionsClient`` against a Fireworks ``/v1/completions`` |
| 6 | +endpoint and surfaces the per-sample token ids / inference logprobs |
| 7 | +needed by reinforcement-fine-tuning training (GRPO, CISPO, DAPO, GSPO). |
| 8 | +
|
| 9 | +Why this exists |
| 10 | +--------------- |
| 11 | +RFT training loops consume token-level data per completion (prompt ids, |
| 12 | +completion ids, inference logprobs). ``SingleTurnRolloutProcessor`` was |
| 13 | +never meant to produce that; customers who need it have to write a |
| 14 | +bespoke :class:`RolloutProcessor` (see the FrozenLake example in |
| 15 | +``fw-ai/cookbook``, which is ~800 lines). This processor promotes that |
| 16 | +bespoke pattern into an Eval Protocol default so managed Fireworks RFT |
| 17 | +jobs can wire it up for every customer evaluator bundle without |
| 18 | +customer code changes. |
| 19 | +
|
| 20 | +What it puts on the row |
| 21 | +----------------------- |
| 22 | +Besides the usual ``EvaluationRow.messages`` update (append the first |
| 23 | +completion as an ``assistant`` turn so existing evaluators keep working), |
| 24 | +this processor writes the following keys to |
| 25 | +``EvaluationRow.execution_metadata.extra``: |
| 26 | +
|
| 27 | +* ``prompt_ids`` — ``list[int]`` (shared across the N completions) |
| 28 | +* ``completion_ids`` — ``list[list[int]]`` (one per completion) |
| 29 | +* ``inference_logprobs``— ``list[list[float]]`` aligned to completion tokens |
| 30 | +* ``completions_text`` — ``list[str]`` (one per completion) |
| 31 | +* ``truncated`` — ``list[bool]`` (True when ``finish_reason == 'length'``) |
| 32 | +* ``finish_reasons`` — ``list[str]`` |
| 33 | +
|
| 34 | +Shape choice |
| 35 | +------------ |
| 36 | +Keys are ``list[list[...]]`` keyed by completion index rather than the |
| 37 | +flattened concat convention used by |
| 38 | +:class:`OpenEnvRolloutProcessor` (which is multi-turn and has no natural |
| 39 | +per-completion structure). Single-turn RFT samples ``n>1`` completions |
| 40 | +per prompt for advantage estimation, so a per-completion shape is what |
| 41 | +the training adapter actually needs. |
| 42 | +
|
| 43 | +Ergonomics |
| 44 | +---------- |
| 45 | +The processor reads all sampling knobs (``model``, ``temperature``, |
| 46 | +``max_tokens``, ``n``) from ``config.completion_params``, matching |
| 47 | +:class:`SingleTurnRolloutProcessor`. Customer evaluator bundles don't |
| 48 | +need to reference this class — the managed RFT launcher swaps it in. |
| 49 | +""" |
| 50 | + |
| 51 | +from __future__ import annotations |
| 52 | + |
| 53 | +import asyncio |
| 54 | +import logging |
| 55 | +import os |
| 56 | +import time |
| 57 | +from typing import Any |
| 58 | + |
| 59 | +from openai.types import CompletionUsage |
| 60 | + |
| 61 | +from eval_protocol.dataset_logger import default_logger |
| 62 | +from eval_protocol.models import EvaluationRow, Message |
| 63 | +from eval_protocol.pytest.rollout_processor import RolloutProcessor |
| 64 | +from eval_protocol.pytest.types import RolloutProcessorConfig |
| 65 | + |
| 66 | +logger = logging.getLogger(__name__) |
| 67 | + |
| 68 | + |
def _as_list_of_dicts(messages: list[Message]) -> list[dict[str, Any]]:
    """Serialise each message into the plain-dict form handed to the client.

    Every element is produced by calling the message's own
    ``dump_mdoel_for_chat_completion_request()`` method; input order is kept.
    """
    # NOTE(review): the "mdoel" spelling is the method name as called on
    # ``Message`` here — presumably it matches the upstream API; confirm
    # before renaming.
    payload: list[dict[str, Any]] = []
    for message in messages:
        payload.append(message.dump_mdoel_for_chat_completion_request())
    return payload
| 72 | + |
| 73 | + |
def _append_extra(row: EvaluationRow, updates: dict[str, Any]) -> None:
    """Merge *updates* into ``row.execution_metadata.extra``.

    Keys already present in ``extra`` but absent from *updates* are kept;
    keys present in both take the value from *updates*. A fresh dict is
    assigned back onto the row, so the previous ``extra`` mapping is never
    mutated in place.

    Args:
        row: Row whose ``execution_metadata.extra`` is replaced by the merge.
        updates: Key/value pairs to add or overwrite.

    Raises:
        ValueError: If ``row.execution_metadata`` is ``None``. There is then
            nowhere to attach the extras, and silently dropping them would
            lose the data this helper exists to record.
    """
    metadata = row.execution_metadata
    if metadata is None:
        # Previously this case was half-guarded on read and then crashed with
        # an AttributeError on write; fail with an explicit message instead.
        raise ValueError("EvaluationRow.execution_metadata is not set; cannot attach extra rollout data")

    merged: dict[str, Any] = dict(metadata.extra or {})
    merged.update(updates)
    metadata.extra = merged
| 80 | + |
| 81 | + |
class FireworksTrainingRolloutProcessor(RolloutProcessor):
    """Single-turn rollout with token-level outputs attached for RFT training.

    Args:
        drop_trailing_assistant_messages: When True (default), strip trailing
            assistant messages from the input conversation before sampling
            — matches :class:`SingleTurnRolloutProcessor` behaviour.
        tokenizer_name_or_path: Override for HuggingFace tokenizer lookup.
            When not set, the model id from ``completion_params["model"]`` is
            used (via ``FireworksV1CompletionsClient``'s default behaviour).
        api_key: Override for the Fireworks API key. Defaults to the
            ``FIREWORKS_API_KEY`` env var at first use.
        base_url: Override for the Fireworks API base URL. Defaults to the
            ``FIREWORKS_BASE_URL`` env var if set, else the SDK default.
    """

    def __init__(
        self,
        *,
        drop_trailing_assistant_messages: bool = True,
        tokenizer_name_or_path: str | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
    ) -> None:
        self.drop_trailing_assistant_messages = drop_trailing_assistant_messages
        self.tokenizer_name_or_path = tokenizer_name_or_path
        self._api_key = api_key
        self._base_url = base_url
        # Clients are created lazily by ``_client_for`` and cached per
        # (model id, temperature, max_tokens): the sampling settings are
        # passed to the client constructor, so a client built for one set of
        # settings must not be handed out for another.
        self._clients: dict[tuple[str, float, int], Any] = {}

    def setup(self) -> None:
        """Validate the Fireworks SDK / tokenizer deps up front.

        Only performs the import; a missing dependency surfaces here as an
        import error rather than in the middle of a rollout.
        """
        # Defer the heavy import to setup so processor construction is cheap.
        from eval_protocol.integrations.fireworks_v1_completions_client import (  # noqa: F401
            FireworksV1CompletionsClient,
        )

    async def acleanup(self) -> None:
        """Close every cached client (best effort) and empty the cache."""
        for client in self._clients.values():
            try:
                await client.close()
            except Exception:
                # Deliberate best-effort: one failing close() must not stop
                # the remaining clients from being closed.
                logger.debug("FireworksV1CompletionsClient.close() failed", exc_info=True)
        self._clients.clear()

    def _client_for(self, *, model_id: str, temperature: float, max_tokens: int) -> Any:
        """Get-or-create a ``FireworksV1CompletionsClient`` for these settings.

        Args:
            model_id: Model identifier passed to the client.
            temperature: Sampling temperature passed to the client.
            max_tokens: Completion token limit passed to the client.

        Returns:
            The cached client for ``(model_id, temperature, max_tokens)``, or
            a newly constructed one (which is then cached).
        """
        cache_key = (model_id, temperature, max_tokens)
        cached = self._clients.get(cache_key)
        if cached is not None:
            return cached

        from eval_protocol.integrations.fireworks_v1_completions_client import (
            FireworksV1CompletionsClient,
        )

        client = FireworksV1CompletionsClient(
            model_id=model_id,
            tokenizer_name_or_path=self.tokenizer_name_or_path,
            # Explicit constructor overrides win; otherwise read the env at
            # first use so late-set variables are still honoured.
            api_key=self._api_key or os.getenv("FIREWORKS_API_KEY"),
            base_url=self._base_url or os.getenv("FIREWORKS_BASE_URL"),
            temperature=temperature,
            max_tokens=max_tokens,
            logprobs=True,
        )
        self._clients[cache_key] = client
        return client

    def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
        """Return one asyncio.Task per input row.

        Each task acquires ``config.semaphore``, samples ``n`` completions for
        the row's prompt, appends the first completion to ``row.messages`` as
        an assistant turn, fills ``row.execution_metadata`` and writes the
        per-completion token data into ``row.execution_metadata.extra``.

        Args:
            rows: Rows to roll out; each must have non-empty ``messages``.
            config: Supplies ``completion_params`` (``model`` is required;
                ``temperature``, ``max_tokens`` and ``n`` default to 1.0, 256
                and 1) and the ``semaphore`` bounding concurrent rows.

        Returns:
            A list of already-scheduled tasks, one per row, in input order.
            Validation errors (``ValueError``) are raised inside the tasks,
            not by this call.
        """

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()

            if not row.messages:
                raise ValueError("EvaluationRow.messages is empty")

            # Copy so per-row bookkeeping never aliases the shared config dict.
            completion_params = dict(config.completion_params or {})
            row.input_metadata.completion_params = completion_params

            model_id = completion_params.get("model")
            if not model_id:
                raise ValueError("completion_params.model is required")
            temperature = float(completion_params.get("temperature", 1.0))
            max_tokens = int(completion_params.get("max_tokens", 256))
            completions_per_prompt = int(completion_params.get("n", 1))
            if completions_per_prompt < 1:
                raise ValueError(f"n must be >= 1, got {completions_per_prompt}")

            messages_for_request: list[Message] = list(row.messages)
            if self.drop_trailing_assistant_messages:
                while messages_for_request and messages_for_request[-1].role == "assistant":
                    messages_for_request.pop()

            client = self._client_for(model_id=str(model_id), temperature=temperature, max_tokens=max_tokens)

            prompt_messages = _as_list_of_dicts(messages_for_request)
            prompt_token_ids = client.build_prompt_token_ids(messages=prompt_messages, tools=row.tools)

            # Fire N parallel calls against the *same* prompt_token_ids. Each
            # call produces one completion. We sample in parallel because
            # Fireworks /v1/completions handles n=1 most reliably; requesting
            # n>1 sometimes collapses to a single choice on partial failures,
            # and we'd rather surface per-completion retry behaviour.
            async def _one_completion() -> dict[str, Any]:
                return await client.create_completion_from_prompt_ids(
                    prompt_token_ids=prompt_token_ids, tools=row.tools
                )

            results = await asyncio.gather(*[_one_completion() for _ in range(completions_per_prompt)])

            completion_ids: list[list[int]] = []
            completions_text: list[str] = []
            inference_logprobs: list[list[float]] = []
            truncated: list[bool] = []
            finish_reasons: list[str] = []

            for result in results:
                completion_ids.append(list(result.get("completion_ids") or []))
                inference_logprobs.append(list(result.get("completion_logprobs") or []))
                finish_reason = str(result.get("finish_reason") or "unknown")
                finish_reasons.append(finish_reason)
                truncated.append(finish_reason == "length")
                # Use the parsed assistant content of the first choice; a
                # missing choice/message/content yields the empty string.
                # NOTE(review): an earlier comment promised a fallback to the
                # raw choice text, but none is implemented — confirm whether
                # one is wanted before adding it.
                choice = (result.get("choices") or [{}])[0]
                message = choice.get("message") or {}
                text = str(message.get("content") or "")
                completions_text.append(text)

            # ``n >= 1`` was validated above, so there is always a first result.
            first_result = results[0]
            prompt_ids = list(first_result.get("prompt_ids") or prompt_token_ids)
            first_message = (first_result.get("choices") or [{}])[0].get("message") or {}
            first_tool_calls = first_message.get("tool_calls")

            # Append the first completion as the assistant turn so that
            # existing evaluators that inspect ``last_assistant_message`` keep
            # working without modification.
            row.messages = list(messages_for_request) + [
                Message(
                    role="assistant",
                    content=completions_text[0] if completions_text else "",
                    tool_calls=first_tool_calls,
                    logprobs=inference_logprobs[0] if inference_logprobs else None,
                )
            ]

            # The prompt is counted once (it is shared by all N completions);
            # completion tokens are summed over every sampled completion.
            completion_token_total = sum(len(ids) for ids in completion_ids)

            row.execution_metadata.finish_reason = finish_reasons[0] if finish_reasons else None
            row.execution_metadata.tool_call_count = len(first_tool_calls) if first_tool_calls else 0
            row.execution_metadata.usage = CompletionUsage(
                prompt_tokens=len(prompt_ids),
                completion_tokens=completion_token_total,
                total_tokens=len(prompt_ids) + completion_token_total,
            )
            row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time

            _append_extra(
                row,
                {
                    "prompt_ids": prompt_ids,
                    "completion_ids": completion_ids,
                    "inference_logprobs": inference_logprobs,
                    "completions_text": completions_text,
                    "truncated": truncated,
                    "finish_reasons": finish_reasons,
                },
            )

            default_logger.log(row)
            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            # The semaphore bounds concurrent *rows*; each admitted row still
            # issues its N completion requests in parallel.
            async with semaphore:
                return await process_row(r)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]
0 commit comments