Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions eval_protocol/adapters/fireworks_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
from .base import BaseAdapter
from .lp_deserializer import decompress_and_parse_lp
from .pti_deserializer import decompress_and_parse_pti
from .r3_deserializer import decompress_and_parse_r3
from .utils import extract_messages_from_data
from ..common_utils import get_user_agent
Expand Down Expand Up @@ -142,6 +143,21 @@ def convert_trace_dict_to_evaluation_row(
e,
)

prompt_ids_payload = payloads.get("prompt_token_ids")
if isinstance(prompt_ids_payload, dict) and prompt_ids_payload.get("data"):
try:
prompt_token_ids, pti_meta = decompress_and_parse_pti(prompt_ids_payload["data"])
if execution_metadata.extra is None:
execution_metadata.extra = {}
execution_metadata.extra["prompt_token_ids"] = prompt_token_ids
execution_metadata.extra["prompt_token_ids_metadata"] = pti_meta
except Exception as e:
logger.warning(
"Failed to decompress prompt token IDs payload for trace %s: %s",
trace.get("id"),
e,
)

return EvaluationRow(
messages=messages,
tools=tools,
Expand Down
98 changes: 98 additions & 0 deletions eval_protocol/adapters/pti_deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""PTI/v1 binary deserializer for prompt token ID payloads.

Implements the inverse of the tracing gateway's
``prompt_token_ids_serializer.serialize_prompt_token_ids``.
"""

from __future__ import annotations

import base64
import struct
from typing import Any, Dict, List, Tuple

import zstandard as zstd

MAGIC = b"PTI1"
HEADER_VERSION = 1
ENTRY_FORMAT = "<i"
ENTRY_SIZE = struct.calcsize(ENTRY_FORMAT) # 4 bytes
HEADER_FORMAT = "<4sBBHIIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 24 bytes


def _parse_header(raw: bytes) -> Dict[str, Any]:
if len(raw) < HEADER_SIZE:
raise ValueError(f"Payload too short for PTI/v1 header: {len(raw)} < {HEADER_SIZE}")

(
magic,
version,
flags,
reserved_u16,
token_count,
body_byte_length,
reserved_u64,
) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE])

if magic != MAGIC:
raise ValueError(f"Bad PTI/v1 magic: {magic!r}")
if version != HEADER_VERSION:
raise ValueError(f"Unsupported PTI/v1 header version: {version}")

return {
"flags": flags,
"reserved_u16": reserved_u16,
"token_count": token_count,
"body_byte_length": body_byte_length,
"reserved_u64": reserved_u64,
}


def parse_prompt_token_ids(raw: bytes) -> Tuple[List[int], Dict[str, Any]]:
"""Parse uncompressed PTI/v1 bytes into prompt token IDs and metadata."""
header = _parse_header(raw)
token_count = header["token_count"]
body_byte_length = header["body_byte_length"]

if token_count == 0:
raise ValueError("PTI/v1 token_count must be > 0")
if body_byte_length != token_count * ENTRY_SIZE:
raise ValueError(
f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} "
f"({token_count * ENTRY_SIZE})"
)

expected_len = HEADER_SIZE + body_byte_length
if len(raw) != expected_len:
raise ValueError(f"PTI/v1 payload length mismatch: {len(raw)} != {expected_len}")

token_ids: List[int] = []
offset = HEADER_SIZE
for _ in range(token_count):
(token_id,) = struct.unpack(ENTRY_FORMAT, raw[offset : offset + ENTRY_SIZE])
offset += ENTRY_SIZE
token_ids.append(token_id)

metadata: Dict[str, Any] = {
"scope": "prompt_only",
"token_count": token_count,
}
header.update(metadata)
return token_ids, header


def decompress_and_parse_pti(data_b64: str) -> Tuple[List[int], Dict[str, Any]]:
"""Decompress and unpack a PTI/v1 prompt token ID payload.

Args:
data_b64: Base64-encoded zstd-compressed PTI binary blob from
``payloads.prompt_token_ids.data``.

Returns:
``(token_ids, metadata)`` where ``token_ids`` is the prompt token ID
sequence and ``metadata`` includes ``token_count``.
"""
compressed = base64.b64decode(data_b64)
decompressor = zstd.ZstdDecompressor()
raw = decompressor.decompress(compressed)
return parse_prompt_token_ids(raw)
2 changes: 2 additions & 0 deletions eval_protocol/pytest/tracing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[Eval
for key in (
"completion_logprobs",
"completion_token_ids",
"prompt_token_ids",
"prompt_token_ids_metadata",
"logprobs_metadata",
"routing_matrices",
"routing_metadata",
Expand Down
209 changes: 209 additions & 0 deletions scripts/test_remote_rollout_prompt_token_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
#!/usr/bin/env python3
"""E2E check: RemoteRolloutProcessor reads prompt_token_ids trace payloads.

This starts a tiny local `/init` server, sends one chat completion through the
Fireworks tracing gateway with `return_token_ids`, and verifies that
RemoteRolloutProcessor hydrates `assistant_turn_payloads[*].prompt_token_ids`.
"""

from __future__ import annotations

import argparse
import asyncio
import logging
import os
import sys
import socket
import threading
import time
from pathlib import Path
from typing import Any

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

import uvicorn
from fastapi import FastAPI
from openai import OpenAI

from eval_protocol import FireworksTracingHttpHandler, InitRequest, RolloutIdFilter, Status
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig

logger = logging.getLogger("remote_rollout_prompt_token_ids")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")


def _free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return int(sock.getsockname()[1])


def _message_to_dict(message: Message | dict[str, Any]) -> dict[str, Any]:
if isinstance(message, Message):
return message.dump_mdoel_for_chat_completion_request()
return {k: v for k, v in dict(message).items() if v is not None}


def _make_app(gateway_url: str) -> FastAPI:
app = FastAPI()
app_logger = logging.getLogger(f"{__name__}.server")
app_logger.setLevel(logging.INFO)

@app.get("/")
def health() -> dict[str, str]:
return {"status": "ok"}

@app.post("/init")
def init(req: InitRequest) -> dict[str, str]:
rollout_logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}")
rollout_logger.addFilter(RolloutIdFilter(req.metadata.rollout_id))
if not any(isinstance(handler, FireworksTracingHttpHandler) for handler in rollout_logger.handlers):
rollout_logger.addHandler(FireworksTracingHttpHandler(gateway_base_url=gateway_url))
rollout_logger.setLevel(logging.INFO)

def _worker() -> None:
try:
conversation = [_message_to_dict(message) for message in (req.messages or [])]
params = dict(req.completion_params or {})
params.pop("base_url", None)
params["extra_body"] = {
**dict(params.get("extra_body") or {}),
"return_token_ids": True,
}
params.setdefault("temperature", 0)
params.setdefault("max_tokens", 8)

if not req.model_base_url:
raise ValueError("model_base_url is required")
if not params.get("model"):
raise ValueError("completion_params.model is required")

client = OpenAI(base_url=req.model_base_url, api_key=req.api_key)
response = client.chat.completions.create(messages=conversation, **params)
content = response.choices[0].message.content or ""
logger.info("remote server generated content=%r", content)

rollout_logger.info(
"rollout %s finished",
req.metadata.rollout_id,
extra={"status": Status.rollout_finished()},
)
except Exception as exc:
rollout_logger.exception(
"rollout %s failed",
req.metadata.rollout_id,
extra={"status": Status.rollout_unknown_error(str(exc))},
)

threading.Thread(target=_worker, daemon=True).start()
return {"status": "started"}

return app


def _wait_ready(url: str, timeout_seconds: float = 30.0) -> None:
import requests

deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url, timeout=2)
if resp.status_code == 200:
return
except Exception:
pass
time.sleep(0.2)
raise TimeoutError(f"server not ready: {url}")


async def _run(args: argparse.Namespace) -> None:
api_key = args.api_key or os.getenv("FIREWORKS_DEV_API_KEY") or os.getenv("FIREWORKS_API_KEY")
if not api_key:
raise ValueError("Set FIREWORKS_DEV_API_KEY or FIREWORKS_API_KEY")

# FireworksTracingHttpHandler reads FIREWORKS_API_KEY.
os.environ["FIREWORKS_API_KEY"] = api_key
os.environ["EP_REMOTE_API_KEY"] = api_key

port = args.port or _free_port()
remote_base_url = f"http://127.0.0.1:{port}"
app = _make_app(args.gateway_url)
config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning")
server = uvicorn.Server(config)
thread = threading.Thread(target=server.run, daemon=True)
thread.start()
_wait_ready(f"{remote_base_url}/")

rollout_id = f"rrp-prompt-ids-{int(time.time())}"
row = EvaluationRow(
messages=[Message(role="user", content="Reply with exactly: ok")],
)
row.input_metadata.row_id = "row-0"
row.input_metadata.completion_params = {
"model": args.model,
"base_url": args.api_base_url,
"temperature": 0,
"max_tokens": 8,
}
row.execution_metadata.rollout_id = rollout_id
row.execution_metadata.invocation_id = "inv-0"
row.execution_metadata.experiment_id = "fir2-1747-rrp-e2e"
row.execution_metadata.run_id = "run-0"

processor = RemoteRolloutProcessor(
remote_base_url=remote_base_url,
model_base_url=args.gateway_url,
include_payloads=True,
timeout_seconds=args.timeout_seconds,
poll_interval=args.poll_interval,
)
try:
task = processor(
[row],
RolloutProcessorConfig(
completion_params=row.input_metadata.completion_params,
mcp_config_path="",
semaphore=asyncio.Semaphore(1),
steps=1,
),
)[0]
completed = await task
finally:
await processor.acleanup()
server.should_exit = True
thread.join(timeout=5)

extra = completed.execution_metadata.extra or {}
turn_payloads = extra.get("assistant_turn_payloads") or []
prompt_ids = None
if turn_payloads:
prompt_ids = turn_payloads[0].get("prompt_token_ids")
if prompt_ids is None:
prompt_ids = extra.get("prompt_token_ids")

print(f"rollout_id={rollout_id}")
print(f"messages={len(completed.messages)}")
print(f"assistant_turn_payloads={turn_payloads}")
print(f"prompt_token_ids_len={len(prompt_ids) if isinstance(prompt_ids, list) else None}")
print(f"prompt_token_ids_head={prompt_ids[:8] if isinstance(prompt_ids, list) else None}")

if not isinstance(prompt_ids, list) or not prompt_ids:
raise AssertionError("RemoteRolloutProcessor did not hydrate prompt_token_ids")


def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--gateway-url", default=os.getenv("EP_MODEL_BASE_URL", "https://litellm-gateway-dev-j4kzagdteq-uc.a.run.app"))
parser.add_argument("--api-base-url", default=os.getenv("FIREWORKS_API_BASE_URL", "https://dev.api.fireworks.ai/inference/v1"))
parser.add_argument("--model", default=os.getenv("TRACING_E2E_MODEL", "accounts/pyroworks-dev/deployments/malaysia2-intended-butterfly"))
parser.add_argument("--api-key", default=None)
parser.add_argument("--port", type=int, default=0)
parser.add_argument("--timeout-seconds", type=float, default=180.0)
parser.add_argument("--poll-interval", type=float, default=2.0)
asyncio.run(_run(parser.parse_args()))


if __name__ == "__main__":
main()
Loading
Loading