From aee752b681259e26ccf37308654805b610e849e8 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Fri, 12 Jun 2026 14:01:55 -0700 Subject: [PATCH] Decode prompt token trace payloads Hydrate prompt_token_ids from Fireworks tracing payloads so RemoteRolloutProcessor can pass token-native prompt IDs through assistant turn metadata. Co-authored-by: Cursor --- eval_protocol/adapters/fireworks_tracing.py | 16 ++ eval_protocol/adapters/pti_deserializer.py | 98 ++++++++ eval_protocol/pytest/tracing_utils.py | 2 + .../test_remote_rollout_prompt_token_ids.py | 209 ++++++++++++++++++ ...test_fireworks_tracing_prompt_token_ids.py | 73 ++++++ tests/pytest/test_tracing_utils.py | 4 + 6 files changed, 402 insertions(+) create mode 100644 eval_protocol/adapters/pti_deserializer.py create mode 100644 scripts/test_remote_rollout_prompt_token_ids.py create mode 100644 tests/adapters/test_fireworks_tracing_prompt_token_ids.py diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 62a632e6..380dc876 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -17,6 +17,7 @@ from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message from .base import BaseAdapter from .lp_deserializer import decompress_and_parse_lp +from .pti_deserializer import decompress_and_parse_pti from .r3_deserializer import decompress_and_parse_r3 from .utils import extract_messages_from_data from ..common_utils import get_user_agent @@ -142,6 +143,21 @@ def convert_trace_dict_to_evaluation_row( e, ) + prompt_ids_payload = payloads.get("prompt_token_ids") + if isinstance(prompt_ids_payload, dict) and prompt_ids_payload.get("data"): + try: + prompt_token_ids, pti_meta = decompress_and_parse_pti(prompt_ids_payload["data"]) + if execution_metadata.extra is None: + execution_metadata.extra = {} + execution_metadata.extra["prompt_token_ids"] = prompt_token_ids + execution_metadata.extra["prompt_token_ids_metadata"] = pti_meta + except Exception as e: + logger.warning( + "Failed to decompress prompt token IDs payload for trace %s: %s", + trace.get("id"), + e, + ) + return EvaluationRow( messages=messages, tools=tools, diff --git a/eval_protocol/adapters/pti_deserializer.py b/eval_protocol/adapters/pti_deserializer.py new file mode 100644 index 00000000..cae74cc0 --- /dev/null +++ b/eval_protocol/adapters/pti_deserializer.py @@ -0,0 +1,98 @@ +"""PTI/v1 binary deserializer for prompt token ID payloads. + +Implements the inverse of the tracing gateway's +``prompt_token_ids_serializer.serialize_prompt_token_ids``. +""" + +from __future__ import annotations + +import base64 +import struct +from typing import Any, Dict, List, Tuple + +import zstandard as zstd + +MAGIC = b"PTI1" +HEADER_VERSION = 1 +ENTRY_FORMAT = " Dict[str, Any]: + if len(raw) < HEADER_SIZE: + raise ValueError(f"Payload too short for PTI/v1 header: {len(raw)} < {HEADER_SIZE}") + + ( + magic, + version, + flags, + reserved_u16, + token_count, + body_byte_length, + reserved_u64, + ) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE]) + + if magic != MAGIC: + raise ValueError(f"Bad PTI/v1 magic: {magic!r}") + if version != HEADER_VERSION: + raise ValueError(f"Unsupported PTI/v1 header version: {version}") + + return { + "flags": flags, + "reserved_u16": reserved_u16, + "token_count": token_count, + "body_byte_length": body_byte_length, + "reserved_u64": reserved_u64, + } + + +def parse_prompt_token_ids(raw: bytes) -> Tuple[List[int], Dict[str, Any]]: + """Parse uncompressed PTI/v1 bytes into prompt token IDs and metadata.""" + header = _parse_header(raw) + token_count = header["token_count"] + body_byte_length = header["body_byte_length"] + + if token_count == 0: + raise ValueError("PTI/v1 token_count must be > 0") + if body_byte_length != token_count * ENTRY_SIZE: + raise ValueError( + f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} " + f"({token_count * ENTRY_SIZE})" + ) + + expected_len = HEADER_SIZE + body_byte_length + if len(raw) != expected_len: + raise ValueError(f"PTI/v1 payload length mismatch: {len(raw)} != {expected_len}") + + token_ids: List[int] = [] + offset = HEADER_SIZE + for _ in range(token_count): + (token_id,) = struct.unpack(ENTRY_FORMAT, raw[offset : offset + ENTRY_SIZE]) + offset += ENTRY_SIZE + token_ids.append(token_id) + + metadata: Dict[str, Any] = { + "scope": "prompt_only", + "token_count": token_count, + } + header.update(metadata) + return token_ids, header + + +def decompress_and_parse_pti(data_b64: str) -> Tuple[List[int], Dict[str, Any]]: + """Decompress and unpack a PTI/v1 prompt token ID payload. + + Args: + data_b64: Base64-encoded zstd-compressed PTI binary blob from + ``payloads.prompt_token_ids.data``. + + Returns: + ``(token_ids, metadata)`` where ``token_ids`` is the prompt token ID + sequence and ``metadata`` includes ``token_count``. + """ + compressed = base64.b64decode(data_b64) + decompressor = zstd.ZstdDecompressor() + raw = decompressor.decompress(compressed) + return parse_prompt_token_ids(raw) diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index 279d1055..e0d5db73 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -63,6 +63,8 @@ def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[Eval for key in ( "completion_logprobs", "completion_token_ids", + "prompt_token_ids", + "prompt_token_ids_metadata", "logprobs_metadata", "routing_matrices", "routing_metadata", diff --git a/scripts/test_remote_rollout_prompt_token_ids.py b/scripts/test_remote_rollout_prompt_token_ids.py new file mode 100644 index 00000000..a143772f --- /dev/null +++ b/scripts/test_remote_rollout_prompt_token_ids.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +"""E2E check: RemoteRolloutProcessor reads prompt_token_ids trace payloads. + +This starts a tiny local `/init` server, sends one chat completion through the +Fireworks tracing gateway with `return_token_ids`, and verifies that +RemoteRolloutProcessor hydrates `assistant_turn_payloads[*].prompt_token_ids`. +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import sys +import socket +import threading +import time +from pathlib import Path +from typing import Any + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +import uvicorn +from fastapi import FastAPI +from openai import OpenAI + +from eval_protocol import FireworksTracingHttpHandler, InitRequest, RolloutIdFilter, Status +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig + +logger = logging.getLogger("remote_rollout_prompt_token_ids") +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _message_to_dict(message: Message | dict[str, Any]) -> dict[str, Any]: + if isinstance(message, Message): + return message.dump_mdoel_for_chat_completion_request() + return {k: v for k, v in dict(message).items() if v is not None} + + +def _make_app(gateway_url: str) -> FastAPI: + app = FastAPI() + app_logger = logging.getLogger(f"{__name__}.server") + app_logger.setLevel(logging.INFO) + + @app.get("/") + def health() -> dict[str, str]: + return {"status": "ok"} + + @app.post("/init") + def init(req: InitRequest) -> dict[str, str]: + rollout_logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") + rollout_logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) + if not any(isinstance(handler, FireworksTracingHttpHandler) for handler in rollout_logger.handlers): + rollout_logger.addHandler(FireworksTracingHttpHandler(gateway_base_url=gateway_url)) + rollout_logger.setLevel(logging.INFO) + + def _worker() -> None: + try: + conversation = [_message_to_dict(message) for message in (req.messages or [])] + params = dict(req.completion_params or {}) + params.pop("base_url", None) + params["extra_body"] = { + **dict(params.get("extra_body") or {}), + "return_token_ids": True, + } + params.setdefault("temperature", 0) + params.setdefault("max_tokens", 8) + + if not req.model_base_url: + raise ValueError("model_base_url is required") + if not params.get("model"): + raise ValueError("completion_params.model is required") + + client = OpenAI(base_url=req.model_base_url, api_key=req.api_key) + response = client.chat.completions.create(messages=conversation, **params) + content = response.choices[0].message.content or "" + logger.info("remote server generated content=%r", content) + + rollout_logger.info( + "rollout %s finished", + req.metadata.rollout_id, + extra={"status": Status.rollout_finished()}, + ) + except Exception as exc: + rollout_logger.exception( + "rollout %s failed", + req.metadata.rollout_id, + extra={"status": Status.rollout_unknown_error(str(exc))}, + ) + + threading.Thread(target=_worker, daemon=True).start() + return {"status": "started"} + + return app + + +def _wait_ready(url: str, timeout_seconds: float = 30.0) -> None: + import requests + + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url, timeout=2) + if resp.status_code == 200: + return + except Exception: + pass + time.sleep(0.2) + raise TimeoutError(f"server not ready: {url}") + + +async def _run(args: argparse.Namespace) -> None: + api_key = args.api_key or os.getenv("FIREWORKS_DEV_API_KEY") or os.getenv("FIREWORKS_API_KEY") + if not api_key: + raise ValueError("Set FIREWORKS_DEV_API_KEY or FIREWORKS_API_KEY") + + # FireworksTracingHttpHandler reads FIREWORKS_API_KEY. + os.environ["FIREWORKS_API_KEY"] = api_key + os.environ["EP_REMOTE_API_KEY"] = api_key + + port = args.port or _free_port() + remote_base_url = f"http://127.0.0.1:{port}" + app = _make_app(args.gateway_url) + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + _wait_ready(f"{remote_base_url}/") + + rollout_id = f"rrp-prompt-ids-{int(time.time())}" + row = EvaluationRow( + messages=[Message(role="user", content="Reply with exactly: ok")], + ) + row.input_metadata.row_id = "row-0" + row.input_metadata.completion_params = { + "model": args.model, + "base_url": args.api_base_url, + "temperature": 0, + "max_tokens": 8, + } + row.execution_metadata.rollout_id = rollout_id + row.execution_metadata.invocation_id = "inv-0" + row.execution_metadata.experiment_id = "fir2-1747-rrp-e2e" + row.execution_metadata.run_id = "run-0" + + processor = RemoteRolloutProcessor( + remote_base_url=remote_base_url, + model_base_url=args.gateway_url, + include_payloads=True, + timeout_seconds=args.timeout_seconds, + poll_interval=args.poll_interval, + ) + try: + task = processor( + [row], + RolloutProcessorConfig( + completion_params=row.input_metadata.completion_params, + mcp_config_path="", + semaphore=asyncio.Semaphore(1), + steps=1, + ), + )[0] + completed = await task + finally: + await processor.acleanup() + server.should_exit = True + thread.join(timeout=5) + + extra = completed.execution_metadata.extra or {} + turn_payloads = extra.get("assistant_turn_payloads") or [] + prompt_ids = None + if turn_payloads: + prompt_ids = turn_payloads[0].get("prompt_token_ids") + if prompt_ids is None: + prompt_ids = extra.get("prompt_token_ids") + + print(f"rollout_id={rollout_id}") + print(f"messages={len(completed.messages)}") + print(f"assistant_turn_payloads={turn_payloads}") + print(f"prompt_token_ids_len={len(prompt_ids) if isinstance(prompt_ids, list) else None}") + print(f"prompt_token_ids_head={prompt_ids[:8] if isinstance(prompt_ids, list) else None}") + + if not isinstance(prompt_ids, list) or not prompt_ids: + raise AssertionError("RemoteRolloutProcessor did not hydrate prompt_token_ids") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--gateway-url", default=os.getenv("EP_MODEL_BASE_URL", "https://litellm-gateway-dev-j4kzagdteq-uc.a.run.app")) + parser.add_argument("--api-base-url", default=os.getenv("FIREWORKS_API_BASE_URL", "https://dev.api.fireworks.ai/inference/v1")) + parser.add_argument("--model", default=os.getenv("TRACING_E2E_MODEL", "accounts/pyroworks-dev/deployments/malaysia2-intended-butterfly")) + parser.add_argument("--api-key", default=None) + parser.add_argument("--port", type=int, default=0) + parser.add_argument("--timeout-seconds", type=float, default=180.0) + parser.add_argument("--poll-interval", type=float, default=2.0) + asyncio.run(_run(parser.parse_args())) + + +if __name__ == "__main__": + main() diff --git a/tests/adapters/test_fireworks_tracing_prompt_token_ids.py b/tests/adapters/test_fireworks_tracing_prompt_token_ids.py new file mode 100644 index 00000000..869bac29 --- /dev/null +++ b/tests/adapters/test_fireworks_tracing_prompt_token_ids.py @@ -0,0 +1,73 @@ +"""Tests for prompt token ID payload handling in fireworks_tracing adapter.""" + +from __future__ import annotations + +import base64 +import struct + +import pytest +import zstandard as zstd + +pytest.importorskip("mcp") + +from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row +from eval_protocol.adapters.pti_deserializer import ( + ENTRY_FORMAT, + ENTRY_SIZE, + HEADER_FORMAT, + MAGIC, + decompress_and_parse_pti, +) + + +def _pti_b64(token_ids: list[int]) -> str: + token_count = len(token_ids) + body_byte_length = token_count * ENTRY_SIZE + header = struct.pack( + HEADER_FORMAT, + MAGIC, + 1, + 0, + 0, + token_count, + body_byte_length, + 0, + ) + body = b"".join(struct.pack(ENTRY_FORMAT, token_id) for token_id in token_ids) + compressed = zstd.ZstdCompressor().compress(header + body) + return base64.b64encode(compressed).decode("ascii") + + +def test_decompress_and_parse_pti_round_trip(): + token_ids, metadata = decompress_and_parse_pti(_pti_b64([101, 102, 103])) + + assert token_ids == [101, 102, 103] + assert metadata["scope"] == "prompt_only" + assert metadata["token_count"] == 3 + + +def test_trace_adapter_attaches_prompt_token_ids_metadata(): + trace = { + "id": "trace-pti", + "input": { + "messages": [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ], + }, + "output": {"role": "assistant", "content": "hello"}, + "payloads": { + "prompt_token_ids": { + "data": _pti_b64([201, 202, 203]), + "manifest": {"PayloadVersion": "pti/v1"}, + }, + }, + } + + row = convert_trace_dict_to_evaluation_row(trace) + + assert row is not None + extra = row.execution_metadata.extra + assert extra is not None + assert extra["prompt_token_ids"] == [201, 202, 203] + assert extra["prompt_token_ids_metadata"]["token_count"] == 3 diff --git a/tests/pytest/test_tracing_utils.py b/tests/pytest/test_tracing_utils.py index 58ec55c1..98246ab1 100644 --- a/tests/pytest/test_tracing_utils.py +++ b/tests/pytest/test_tracing_utils.py @@ -13,6 +13,7 @@ def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): execution_metadata=ExecutionMetadata( extra={ "completion_logprobs": [-0.1, -0.2], + "prompt_token_ids": [101, 102], "routing_matrices": ["first-matrix"], "routing_metadata": {"total_token_count": 1}, }, @@ -28,6 +29,7 @@ def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): execution_metadata=ExecutionMetadata( extra={ "completion_logprobs": [-0.3], + "prompt_token_ids": [101, 102, 103, 104], "routing_matrices": ["second-matrix"], "routing_metadata": {"total_token_count": 1}, }, @@ -45,12 +47,14 @@ def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): { "assistant_turn_index": 0, "completion_logprobs": [-0.1, -0.2], + "prompt_token_ids": [101, 102], "routing_matrices": ["first-matrix"], "routing_metadata": {"total_token_count": 1}, }, { "assistant_turn_index": 1, "completion_logprobs": [-0.3], + "prompt_token_ids": [101, 102, 103, 104], "routing_matrices": ["second-matrix"], "routing_metadata": {"total_token_count": 1}, },