Commit b0de110

feat: add FireworksTrainingRolloutProcessor for RFT (FIR2-1351)

A new default RolloutProcessor that drives Fireworks /v1/completions via
`FireworksV1CompletionsClient` and surfaces the per-sample token-level
data required by reinforcement fine-tuning training (GRPO, CISPO, DAPO,
GSPO).

Problem
-------
The existing `SingleTurnRolloutProcessor` uses LiteLLM chat completions
and discards token ids + inference logprobs, so scored `EvaluationRow`s
are fine for evaluation but cannot feed a training loop. Today, teams
that need training-ready rollouts write a bespoke `RolloutProcessor`
(the FrozenLake example in fw-ai/cookbook is ~800 lines). This puts
token ids / logprobs out of reach of every customer evaluator bundle
unless they rewrite their own processor.

What it does
------------
For each `EvaluationRow`, `FireworksTrainingRolloutProcessor`:

* Reads model / temperature / max_tokens / n from `completion_params`.
* Builds prompt token ids locally via `FireworksV1CompletionsClient`.
* Fires `n` parallel `/v1/completions` calls from the same
  `prompt_token_ids`, so each completion gets independent retry
  behaviour rather than collapsing on partial server failures.
* Appends the first completion as the assistant message so existing
  evaluators that inspect `last_assistant_message()` keep scoring.
* Populates `EvaluationRow.execution_metadata.extra` with:
  - `prompt_ids: list[int]` (shared across completions)
  - `completion_ids: list[list[int]]` (per-completion)
  - `inference_logprobs: list[list[float]]` (aligned to completion tokens)
  - `completions_text: list[str]`
  - `truncated: list[bool]` (`finish_reason == 'length'`)
  - `finish_reasons: list[str]`
* Merges into pre-existing `extra` rather than clobbering it.
* Caches one client per model id; closes them all via `acleanup()`.

Shape rationale
---------------
OpenEnvRolloutProcessor already writes flat `prompt_ids` /
`completion_ids` concatenated across turns (multi-turn, per-episode
agent rollouts). Single-turn RFT samples n>1 completions per prompt for
advantage estimation and needs per-completion indexing, hence the
`list[list[...]]` shape here. The training adapter on the consumer side
can key into either convention without loss of generality.

Tests
-----
8 new unit tests stub `FireworksV1CompletionsClient` so no network calls
or tokenizers are needed; existing `SingleTurnRolloutProcessor` suite
still passes.

Fixes FIR2-1351
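
To illustrate why the per-completion `list[list[...]]` shape is convenient on the consumer side, here is a minimal sketch of a training adapter. It is not part of this commit: the helper names `group_advantages` and `to_training_samples`, the `rewards` list (one scalar per completion, assumed to come from an evaluator), and the hand-written `extra` payload are all hypothetical. The advantage formula is the GRPO-style group normalisation, used here only as an example.

```python
from statistics import mean, pstdev
from typing import Any


def group_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """GRPO-style advantage: normalise each reward within its prompt group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def to_training_samples(extra: dict[str, Any], rewards: list[float]) -> list[dict[str, Any]]:
    """Zip the per-completion lists from `extra` into one record per completion."""
    completion_ids = extra["completion_ids"]
    logprobs = extra["inference_logprobs"]
    if not (len(completion_ids) == len(logprobs) == len(rewards)):
        raise ValueError("expected one reward and one logprob list per completion")

    advantages = group_advantages(rewards)
    samples = []
    for ids, lps, adv, was_truncated in zip(completion_ids, logprobs, advantages, extra["truncated"]):
        if len(ids) != len(lps):
            raise ValueError("logprobs must align one-to-one with completion tokens")
        samples.append(
            {
                "prompt_ids": extra["prompt_ids"],  # shared by every completion
                "completion_ids": ids,
                "inference_logprobs": lps,
                "advantage": adv,
                "truncated": was_truncated,
            }
        )
    return samples


# Hand-written payload in the shape the commit message describes (values are made up).
extra = {
    "prompt_ids": [1, 2, 3],
    "completion_ids": [[10, 11], [12, 13, 14]],
    "inference_logprobs": [[-0.2, -0.4], [-0.1, -0.3, -0.9]],
    "completions_text": ["yes", "no way"],
    "truncated": [False, True],
    "finish_reasons": ["stop", "length"],
}
rewards = [1.0, 0.0]  # hypothetical evaluator scores, one per completion
samples = to_training_samples(extra, rewards)
```

Because every list is indexed by completion, the adapter never has to recover completion boundaries from a flattened token stream.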
1 parent 0655f89 commit b0de110

3 files changed

Lines changed: 567 additions & 0 deletions

eval_protocol/pytest/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,10 @@
     "MCPGymRolloutProcessor": (".default_mcp_gym_rollout_processor", "MCPGymRolloutProcessor"),
     "NoOpRolloutProcessor": (".default_no_op_rollout_processor", "NoOpRolloutProcessor"),
     "SingleTurnRolloutProcessor": (".default_single_turn_rollout_process", "SingleTurnRolloutProcessor"),
+    "FireworksTrainingRolloutProcessor": (
+        ".default_fireworks_training_rollout_processor",
+        "FireworksTrainingRolloutProcessor",
+    ),
     "RemoteRolloutProcessor": (".remote_rollout_processor", "RemoteRolloutProcessor"),
     "GithubActionRolloutProcessor": (".github_action_rollout_processor", "GithubActionRolloutProcessor"),
     "RolloutProcessor": (".rollout_processor", "RolloutProcessor"),
@@ -102,6 +106,7 @@ def __dir__():
     "MCPGymRolloutProcessor",
     "RolloutProcessor",
     "SingleTurnRolloutProcessor",
+    "FireworksTrainingRolloutProcessor",
     "RemoteRolloutProcessor",
     "GithubActionRolloutProcessor",
     "NoOpRolloutProcessor",
@@ -132,6 +137,9 @@ def __dir__():
     from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor as MCPGymRolloutProcessor
     from .default_no_op_rollout_processor import NoOpRolloutProcessor as NoOpRolloutProcessor
     from .default_single_turn_rollout_process import SingleTurnRolloutProcessor as SingleTurnRolloutProcessor
+    from .default_fireworks_training_rollout_processor import (
+        FireworksTrainingRolloutProcessor as FireworksTrainingRolloutProcessor,
+    )
     from .remote_rollout_processor import RemoteRolloutProcessor as RemoteRolloutProcessor
     from .github_action_rollout_processor import GithubActionRolloutProcessor as GithubActionRolloutProcessor
     from .evaluation_test import evaluation_test as evaluation_test
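
The name-to-`(module, attribute)` table and the `__dir__` hook in this diff suggest that the package resolves its exports lazily. The rest of `__init__.py` is not shown, so the following is only a sketch of how such a table is commonly wired up with module-level `__getattr__` (PEP 562), not the package's actual code. The table `_LAZY`, the helper `_lazy_getattr`, and the stand-in module `demo` are hypothetical, and stdlib targets replace the relative `.module` paths so the example runs on its own.

```python
import importlib
import types

# Hypothetical lookup table: exported name -> (module to import, attribute to fetch).
_LAZY = {
    "OrderedDict": ("collections", "OrderedDict"),
    "dumps": ("json", "dumps"),
}


def _lazy_getattr(name: str):
    """Import the owning module only when the attribute is first requested."""
    try:
        module_name, attr = _LAZY[name]
    except KeyError:
        raise AttributeError(f"module 'demo' has no attribute {name!r}") from None
    return getattr(importlib.import_module(module_name), attr)


# Stand-in for a package __init__; in a real package these two hooks would be
# plain module-level functions named __getattr__ and __dir__.
demo = types.ModuleType("demo")
demo.__getattr__ = _lazy_getattr
demo.__dir__ = lambda: sorted(_LAZY)

encoded = demo.dumps([1, 2])  # triggers the import of json on first access
```

Under this pattern a new export needs an entry in the table and in the `__dir__` listing, which matches the first two hunks; the third hunk adds the same name as a regular import statement.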
Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
+"""Default training-aware :class:`RolloutProcessor` for Fireworks RFT.
+
+Unlike :class:`SingleTurnRolloutProcessor`, which uses LiteLLM chat
+completions and discards token-level information, this processor drives
+``FireworksV1CompletionsClient`` against a Fireworks ``/v1/completions``
+endpoint and surfaces the per-sample token ids / inference logprobs
+needed by reinforcement-fine-tuning training (GRPO, CISPO, DAPO, GSPO).
+
+Why this exists
+---------------
+RFT training loops consume token-level data per completion (prompt ids,
+completion ids, inference logprobs). ``SingleTurnRolloutProcessor`` was
+never meant to produce that; customers who need it have to write a
+bespoke :class:`RolloutProcessor` (see the FrozenLake example in
+``fw-ai/cookbook``, which is ~800 lines). This processor promotes that
+bespoke pattern into an Eval Protocol default so managed Fireworks RFT
+jobs can wire it up for every customer evaluator bundle without
+customer code changes.
+
+What it puts on the row
+-----------------------
+Besides the usual ``EvaluationRow.messages`` update (append the first
+completion as an ``assistant`` turn so existing evaluators keep working),
+this processor writes the following keys to
+``EvaluationRow.execution_metadata.extra``:
+
+* ``prompt_ids``         - ``list[int]`` (shared across the N completions)
+* ``completion_ids``     - ``list[list[int]]`` (one per completion)
+* ``inference_logprobs`` - ``list[list[float]]`` aligned to completion tokens
+* ``completions_text``   - ``list[str]`` (one per completion)
+* ``truncated``          - ``list[bool]`` (True when ``finish_reason == 'length'``)
+* ``finish_reasons``     - ``list[str]``
+
+Shape choice
+------------
+Keys are ``list[list[...]]`` keyed by completion index rather than the
+flattened concat convention used by
+:class:`OpenEnvRolloutProcessor` (which is multi-turn and has no natural
+per-completion structure). Single-turn RFT samples ``n>1`` completions
+per prompt for advantage estimation, so a per-completion shape is what
+the training adapter actually needs.
+
+Ergonomics
+----------
+The processor reads all sampling knobs (``model``, ``temperature``,
+``max_tokens``, ``n``) from ``config.completion_params``, matching
+:class:`SingleTurnRolloutProcessor`. Customer evaluator bundles don't
+need to reference this class; the managed RFT launcher swaps it in.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import time
+from typing import Any
+
+from openai.types import CompletionUsage
+
+from eval_protocol.dataset_logger import default_logger
+from eval_protocol.models import EvaluationRow, Message
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
+from eval_protocol.pytest.types import RolloutProcessorConfig
+
+logger = logging.getLogger(__name__)
+
+
+def _as_list_of_dicts(messages: list[Message]) -> list[dict[str, Any]]:
+    """Convert ``EvaluationRow.messages`` into the dict shape the client expects."""
+    return [m.dump_mdoel_for_chat_completion_request() for m in messages]
+
+
+def _append_extra(row: EvaluationRow, updates: dict[str, Any]) -> None:
+    """Merge *updates* into ``row.execution_metadata.extra`` without clobbering."""
+    current = row.execution_metadata.extra if row.execution_metadata else None
+    merged: dict[str, Any] = dict(current) if current else {}
+    merged.update(updates)
+    row.execution_metadata.extra = merged
+
+
+class FireworksTrainingRolloutProcessor(RolloutProcessor):
+    """Single-turn rollout with token-level outputs attached for RFT training.
+
+    Args:
+        drop_trailing_assistant_messages: When True (default), strip trailing
+            assistant messages from the input conversation before sampling;
+            matches :class:`SingleTurnRolloutProcessor` behaviour.
+        tokenizer_name_or_path: Override for HuggingFace tokenizer lookup.
+            When not set, the model id from ``completion_params["model"]`` is
+            used (via ``FireworksV1CompletionsClient``'s default behaviour).
+        api_key: Override for the Fireworks API key. Defaults to the
+            ``FIREWORKS_API_KEY`` env var at first use.
+        base_url: Override for the Fireworks API base URL. Defaults to the
+            ``FIREWORKS_BASE_URL`` env var if set, else the SDK default.
+    """
+
+    def __init__(
+        self,
+        *,
+        drop_trailing_assistant_messages: bool = True,
+        tokenizer_name_or_path: str | None = None,
+        api_key: str | None = None,
+        base_url: str | None = None,
+    ) -> None:
+        self.drop_trailing_assistant_messages = drop_trailing_assistant_messages
+        self.tokenizer_name_or_path = tokenizer_name_or_path
+        self._api_key = api_key
+        self._base_url = base_url
+        # One client per model id per processor instance; cached lazily in _client_for().
+        self._clients: dict[str, Any] = {}
+
+    def setup(self) -> None:
+        """Validate the Fireworks SDK / tokenizer deps up front."""
+        # Defer the heavy import to setup so processor construction is cheap.
+        from eval_protocol.integrations.fireworks_v1_completions_client import (  # noqa: F401
+            FireworksV1CompletionsClient,
+        )
+
+    async def acleanup(self) -> None:
+        for client in self._clients.values():
+            try:
+                await client.close()
+            except Exception:
+                logger.debug("FireworksV1CompletionsClient.close() failed", exc_info=True)
+        self._clients.clear()
+
+    def _client_for(self, *, model_id: str, temperature: float, max_tokens: int) -> Any:
+        """Get-or-create a ``FireworksV1CompletionsClient`` for *model_id*."""
+        cached = self._clients.get(model_id)
+        if cached is not None:
+            return cached
+
+        from eval_protocol.integrations.fireworks_v1_completions_client import (
+            FireworksV1CompletionsClient,
+        )
+
+        client = FireworksV1CompletionsClient(
+            model_id=model_id,
+            tokenizer_name_or_path=self.tokenizer_name_or_path,
+            api_key=self._api_key or os.getenv("FIREWORKS_API_KEY"),
+            base_url=self._base_url or os.getenv("FIREWORKS_BASE_URL"),
+            temperature=temperature,
+            max_tokens=max_tokens,
+            logprobs=True,
+        )
+        self._clients[model_id] = client
+        return client
+
+    def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
+        """Return one asyncio.Task per input row."""
+
+        async def process_row(row: EvaluationRow) -> EvaluationRow:
+            start_time = time.perf_counter()
+
+            if not row.messages:
+                raise ValueError("EvaluationRow.messages is empty")
+
+            completion_params = dict(config.completion_params or {})
+            row.input_metadata.completion_params = completion_params
+
+            model_id = completion_params.get("model")
+            if not model_id:
+                raise ValueError("completion_params.model is required")
+            temperature = float(completion_params.get("temperature", 1.0))
+            max_tokens = int(completion_params.get("max_tokens", 256))
+            completions_per_prompt = int(completion_params.get("n", 1))
+            if completions_per_prompt < 1:
+                raise ValueError(f"n must be >= 1, got {completions_per_prompt}")
+
+            messages_for_request: list[Message] = list(row.messages)
+            if self.drop_trailing_assistant_messages:
+                while messages_for_request and messages_for_request[-1].role == "assistant":
+                    messages_for_request.pop()
+
+            client = self._client_for(model_id=str(model_id), temperature=temperature, max_tokens=max_tokens)
+
+            prompt_messages = _as_list_of_dicts(messages_for_request)
+            prompt_token_ids = client.build_prompt_token_ids(messages=prompt_messages, tools=row.tools)
+
+            # Fire N parallel calls against the *same* prompt_token_ids. Each
+            # call produces one completion. We sample in parallel because
+            # Fireworks /v1/completions handles n=1 most reliably; requesting
+            # n>1 sometimes collapses to a single choice on partial failures,
+            # and we'd rather surface per-completion retry behaviour.
+            async def _one_completion() -> dict[str, Any]:
+                return await client.create_completion_from_prompt_ids(
+                    prompt_token_ids=prompt_token_ids, tools=row.tools
+                )
+
+            results = await asyncio.gather(*[_one_completion() for _ in range(completions_per_prompt)])
+
+            completion_ids: list[list[int]] = []
+            completions_text: list[str] = []
+            inference_logprobs: list[list[float]] = []
+            truncated: list[bool] = []
+            finish_reasons: list[str] = []
+
+            for result in results:
+                completion_ids.append(list(result.get("completion_ids") or []))
+                inference_logprobs.append(list(result.get("completion_logprobs") or []))
+                finish_reason = str(result.get("finish_reason") or "unknown")
+                finish_reasons.append(finish_reason)
+                truncated.append(finish_reason == "length")
+                # Use the parsed assistant content if the client produced it;
+                # otherwise record an empty string.
+                choice = (result.get("choices") or [{}])[0]
+                message = choice.get("message") or {}
+                text = str(message.get("content") or "")
+                completions_text.append(text)
+
+            first_result = results[0]
+            prompt_ids = list(first_result.get("prompt_ids") or prompt_token_ids)
+            first_message = (first_result.get("choices") or [{}])[0].get("message") or {}
+            first_tool_calls = first_message.get("tool_calls")
+
+            # Append the first completion as the assistant turn so that
+            # existing evaluators that inspect ``last_assistant_message`` keep
+            # working without modification.
+            row.messages = list(messages_for_request) + [
+                Message(
+                    role="assistant",
+                    content=completions_text[0] if completions_text else "",
+                    tool_calls=first_tool_calls,
+                    logprobs=inference_logprobs[0] if inference_logprobs else None,
+                )
+            ]
+
+            row.execution_metadata.finish_reason = finish_reasons[0] if finish_reasons else None
+            row.execution_metadata.tool_call_count = len(first_tool_calls) if first_tool_calls else 0
+            row.execution_metadata.usage = CompletionUsage(
+                prompt_tokens=len(prompt_ids),
+                completion_tokens=sum(len(ids) for ids in completion_ids),
+                total_tokens=len(prompt_ids) + sum(len(ids) for ids in completion_ids),
+            )
+            row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time
+
+            _append_extra(
+                row,
+                {
+                    "prompt_ids": prompt_ids,
+                    "completion_ids": completion_ids,
+                    "inference_logprobs": inference_logprobs,
+                    "completions_text": completions_text,
+                    "truncated": truncated,
+                    "finish_reasons": finish_reasons,
+                },
+            )
+
+            default_logger.log(row)
+            return row
+
+        semaphore = config.semaphore
+
+        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
+            async with semaphore:
+                return await process_row(r)
+
+        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]
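
The fan-out at the heart of `process_row` can be tried in isolation with a stub client and no network access. This is a minimal sketch under stated assumptions: `StubClient` and `sample_n` are hypothetical names, the canned token ids and logprobs are made up, and the stub's method mirrors only the `prompt_token_ids` argument of the call used in the diff.

```python
import asyncio
from typing import Any


class StubClient:
    """Hypothetical stand-in for a completions client; returns canned data."""

    def __init__(self) -> None:
        self.calls = 0

    async def create_completion_from_prompt_ids(self, prompt_token_ids: list[int]) -> dict[str, Any]:
        self.calls += 1
        call_index = self.calls
        await asyncio.sleep(0)  # yield to the event loop, as a real request would
        return {
            "completion_ids": [100 + call_index],
            "completion_logprobs": [-0.1 * call_index],
            "finish_reason": "length" if call_index == 2 else "stop",
        }


async def sample_n(
    client: StubClient, prompt_token_ids: list[int], n: int, semaphore: asyncio.Semaphore
) -> dict[str, Any]:
    """Issue n single-completion calls for one prompt and collect per-completion lists."""
    async with semaphore:  # one slot per row, as in _sem_wrapper
        results = await asyncio.gather(
            *[client.create_completion_from_prompt_ids(prompt_token_ids) for _ in range(n)]
        )
    finish_reasons = [str(r.get("finish_reason") or "unknown") for r in results]
    return {
        "prompt_ids": list(prompt_token_ids),
        "completion_ids": [list(r.get("completion_ids") or []) for r in results],
        "inference_logprobs": [list(r.get("completion_logprobs") or []) for r in results],
        "finish_reasons": finish_reasons,
        "truncated": [reason == "length" for reason in finish_reasons],
    }


async def main() -> dict[str, Any]:
    semaphore = asyncio.Semaphore(2)  # created inside the running event loop
    return await sample_n(StubClient(), [1, 2, 3], n=3, semaphore=semaphore)


extra = asyncio.run(main())
```

`asyncio.gather` returns results in the order the calls were passed in, so index `i` in every list refers to the same completion. The semaphore bounds how many rows are in flight, not how many completions a single row requests.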
