Skip to content

Commit ac3db7b

Browse files
Add LP/v1 logprobs deserialization for tracing gateway payloads (FIR-21499).
Decode logprobs payloads into completion_logprobs and Message.logprobs on EvaluationRow. Pop base_url from completion_params before OpenAI SDK calls so dev inference API can be encoded in gateway tracing URLs. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 99e49fa commit ac3db7b

5 files changed

Lines changed: 314 additions & 2 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,34 @@ def convert_trace_dict_to_evaluation_row(
116116
except Exception as e:
117117
logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e)
118118

119+
logprobs_payload = payloads.get("logprobs")
120+
if isinstance(logprobs_payload, dict) and logprobs_payload.get("data"):
121+
try:
122+
from .lp_deserializer import decompress_and_parse_lp
123+
124+
logprobs, token_ids, lp_meta = decompress_and_parse_lp(logprobs_payload["data"])
125+
if execution_metadata.extra is None:
126+
execution_metadata.extra = {}
127+
execution_metadata.extra["completion_logprobs"] = logprobs
128+
if token_ids is not None:
129+
execution_metadata.extra["completion_token_ids"] = token_ids
130+
execution_metadata.extra["logprobs_metadata"] = lp_meta
131+
132+
for i in range(len(messages) - 1, -1, -1):
133+
if messages[i].role == "assistant":
134+
content_entries = [{"logprob": lp} for lp in logprobs]
135+
if token_ids is not None:
136+
for entry, tid in zip(content_entries, token_ids):
137+
entry["token_id"] = tid
138+
messages[i].logprobs = {"content": content_entries}
139+
break
140+
except Exception as e:
141+
logger.warning(
142+
"Failed to decompress logprobs payload for trace %s: %s",
143+
trace.get("id"),
144+
e,
145+
)
146+
119147
return EvaluationRow(
120148
messages=messages,
121149
tools=tools,
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""LP/v1 binary deserializer for per-token logprobs payloads.
2+
3+
Implements the inverse of the tracing gateway's ``logprobs_serializer.serialize_logprobs``.
4+
See that module for the full header specification.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import base64
10+
import struct
11+
from typing import Any, Dict, List, Optional, Tuple
12+
13+
import zstandard as zstd
14+
15+
MAGIC = b"LP01"  # 4-byte tag that must open every LP/v1 payload
HEADER_VERSION = 1  # the only header version this parser accepts
MISSING_TOKEN_ID = -1  # wire sentinel meaning "no token id for this entry"
ENTRY_FORMAT = "<if"  # little-endian: int32 token id, float32 logprob
ENTRY_SIZE = struct.calcsize(ENTRY_FORMAT)  # 8 bytes
# little-endian: magic, version, flags, reserved u16, token count, body byte length, reserved u64
HEADER_FORMAT = "<4sBBHIIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)  # 24 bytes


def _parse_header(raw: bytes) -> Dict[str, Any]:
    """Validate and decode the fixed-size LP/v1 header at the start of ``raw``.

    Args:
        raw: Uncompressed LP/v1 bytes; only the first ``HEADER_SIZE`` bytes are read.

    Returns:
        Dict with the keys ``flags``, ``reserved_u16``, ``token_count``,
        ``body_byte_length`` and ``reserved_u64``.

    Raises:
        ValueError: If ``raw`` is shorter than the header, the magic differs from
            ``MAGIC``, or the version differs from ``HEADER_VERSION``.
    """
    if len(raw) < HEADER_SIZE:
        raise ValueError(f"Payload too short for lp/v1 header: {len(raw)} < {HEADER_SIZE}")

    # unpack_from reads the header in place; no intermediate slice copy is needed.
    (
        magic,
        version,
        flags,
        reserved_u16,
        token_count,
        body_byte_length,
        reserved_u64,
    ) = struct.unpack_from(HEADER_FORMAT, raw)

    if magic != MAGIC:
        raise ValueError(f"Bad LP/v1 magic: {magic!r}")
    if version != HEADER_VERSION:
        raise ValueError(f"Unsupported lp/v1 header version: {version}")

    return {
        "flags": flags,
        "reserved_u16": reserved_u16,
        "token_count": token_count,
        "body_byte_length": body_byte_length,
        "reserved_u64": reserved_u64,
    }


def _parse_logprobs_raw(raw: bytes) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
    """Decode a complete uncompressed LP/v1 payload (header followed by entries).

    Args:
        raw: Header plus exactly ``token_count`` entries of ``ENTRY_SIZE`` bytes each.

    Returns:
        ``(logprobs, token_ids, metadata)``. ``token_ids`` is ``None`` when any entry
        carries ``MISSING_TOKEN_ID``. ``metadata`` holds the parsed header fields plus
        ``scope``, ``completion_token_count`` and ``all_token_ids_valid``.

    Raises:
        ValueError: If the header is invalid, ``token_count`` is zero,
            ``body_byte_length`` disagrees with ``token_count``, or the payload
            length is not exactly header plus body.
    """
    header = _parse_header(raw)
    token_count = header["token_count"]
    body_byte_length = header["body_byte_length"]

    if token_count == 0:
        raise ValueError("LP/v1 token_count must be > 0")
    if body_byte_length != token_count * ENTRY_SIZE:
        raise ValueError(
            f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} "
            f"({token_count * ENTRY_SIZE})"
        )

    expected_len = HEADER_SIZE + body_byte_length
    if len(raw) != expected_len:
        raise ValueError(f"LP/v1 payload length mismatch: {len(raw)} != {expected_len}")

    # The length checks above guarantee the body is a whole number of entries, which
    # iter_unpack requires; the memoryview slices the body without copying it.
    entries = list(struct.iter_unpack(ENTRY_FORMAT, memoryview(raw)[HEADER_SIZE:]))
    token_ids: List[int] = [wire_id for wire_id, _ in entries]
    logprobs: List[float] = [logprob for _, logprob in entries]
    all_token_ids_valid = MISSING_TOKEN_ID not in token_ids

    header.update(
        {
            "scope": "completion_only",
            "completion_token_count": token_count,
            "all_token_ids_valid": all_token_ids_valid,
        }
    )
    # A partially known id list is not returned: callers get all ids or None.
    ids_out: Optional[List[int]] = token_ids if all_token_ids_valid else None
    return logprobs, ids_out, header


def parse_logprobs(raw: bytes) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
    """Parse uncompressed LP/v1 bytes into logprobs, optional token ids, and metadata.

    Raises:
        ValueError: If ``raw`` is not a well-formed LP/v1 payload.
    """
    return _parse_logprobs_raw(raw)
96+
97+
98+
def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
    """Turn a base64 + zstd wrapped LP/v1 blob into completion logprobs and token ids.

    Args:
        data_b64: Base64 text of the zstd-compressed LP binary blob, as found in
            ``payloads.logprobs.data``.

    Returns:
        ``(logprobs, token_ids, metadata)``: one logprob per completion token,
        ``token_ids`` set to ``None`` if any wire id equals ``MISSING_TOKEN_ID``,
        and ``metadata`` containing (among other keys) ``all_token_ids_valid``
        and ``completion_token_count``.
    """
    # Undo the wrapping in reverse order: base64 text -> zstd frame -> LP/v1 bytes.
    zstd_frame = base64.b64decode(data_b64)
    lp_bytes = zstd.ZstdDecompressor().decompress(zstd_frame)
    return _parse_logprobs_raw(lp_bytes)

eval_protocol/pytest/tracing_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ def build_init_request(
103103
if not completion_params_dict.get("model"):
104104
raise ValueError("Model must be provided in completion_params")
105105

106-
# Extract base_url from completion_params
107-
completion_params_base_url: Optional[str] = completion_params_dict.get("base_url")
106+
# Extract base_url from completion_params (encoded into gateway path, not sent to OpenAI SDK)
107+
completion_params_base_url: Optional[str] = completion_params_dict.pop("base_url", None)
108108

109109
# Strip non-OpenAI fields from messages
110110
# Use dump_mdoel_for_chat_completion_request() to automatically exclude unsupported fields (weight, control_plane_step, reasoning_content)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Tests for logprobs payload handling in fireworks_tracing adapter."""
2+
3+
from __future__ import annotations
4+
5+
import base64
6+
import struct
7+
8+
import pytest
9+
import zstandard as zstd
10+
11+
pytest.importorskip("mcp")
12+
13+
from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row
14+
from eval_protocol.adapters.lp_deserializer import (
15+
ENTRY_FORMAT,
16+
ENTRY_SIZE,
17+
HEADER_FORMAT,
18+
MAGIC,
19+
MISSING_TOKEN_ID,
20+
)
21+
22+
23+
def _lp_b64(tokens: list[tuple[int, float]]) -> str:
    """Encode ``(token_id, logprob)`` pairs as a base64, zstd-compressed LP/v1 payload."""
    n_tokens = len(tokens)
    # Header fields in order: magic, version, flags, reserved u16, token count,
    # body byte length, reserved u64.
    packed_header = struct.pack(HEADER_FORMAT, MAGIC, 1, 0, 0, n_tokens, n_tokens * ENTRY_SIZE, 0)
    packed_entries = [struct.pack(ENTRY_FORMAT, token_id, logprob) for token_id, logprob in tokens]
    payload = packed_header + b"".join(packed_entries)
    return base64.b64encode(zstd.ZstdCompressor().compress(payload)).decode("ascii")
40+
41+
42+
def _base_trace(*, with_token_ids: bool = True) -> dict:
    """Build a minimal trace dict carrying a two-token LP/v1 logprobs payload.

    With ``with_token_ids=False`` the first entry uses ``MISSING_TOKEN_ID``.
    """
    if with_token_ids:
        tokens = [(10, -0.1), (11, -0.2)]
    else:
        tokens = [(MISSING_TOKEN_ID, -0.1), (12, -0.2)]

    conversation = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    logprobs_payload = {
        "data": _lp_b64(tokens),
        "manifest": {"PayloadVersion": "lp/v1"},
    }
    return {
        "id": "trace-1",
        "input": {"messages": conversation},
        "output": {"role": "assistant", "content": "hello"},
        "payloads": {"logprobs": logprobs_payload},
    }
60+
61+
62+
class TestConvertTraceLogprobs:
    """Check how a trace's LP/v1 logprobs payload shows up on the converted row."""

    def test_attaches_completion_logprobs_and_message_logprobs(self):
        """All token ids present: logprobs and ids appear in ``extra`` and on the last message."""
        row = convert_trace_dict_to_evaluation_row(_base_trace())
        assert row is not None

        extra = row.execution_metadata.extra
        assert extra is not None
        assert extra["completion_logprobs"] == pytest.approx([-0.1, -0.2])
        assert extra["completion_token_ids"] == [10, 11]

        assistant = row.messages[-1]
        assert assistant.role == "assistant"
        content = assistant.logprobs["content"]
        assert len(content) == len(extra["completion_logprobs"])
        assert content[0]["token_id"] == 10
        assert content[1]["token_id"] == 11
        assert content[0]["logprob"] == pytest.approx(-0.1)
        assert content[1]["logprob"] == pytest.approx(-0.2)

    def test_omits_token_id_keys_when_any_missing(self):
        """One missing token id: no ids in ``extra`` and no ``token_id`` keys per entry."""
        row = convert_trace_dict_to_evaluation_row(_base_trace(with_token_ids=False))
        assert row is not None

        extra = row.execution_metadata.extra
        # Guard first so a missing ``extra`` fails as a clear assertion rather than
        # a TypeError from the ``in`` checks below.
        assert extra is not None
        assert "completion_logprobs" in extra
        assert "completion_token_ids" not in extra

        content = row.messages[-1].logprobs["content"]
        assert len(content) == 2
        assert all("token_id" not in entry for entry in content)
        assert content[0]["logprob"] == pytest.approx(-0.1)
        assert content[1]["logprob"] == pytest.approx(-0.2)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""Tests for LP/v1 binary deserializer (gateway-compatible)."""
2+
3+
from __future__ import annotations
4+
5+
import base64
6+
import struct
7+
8+
import pytest
9+
import zstandard as zstd
10+
11+
from eval_protocol.adapters.lp_deserializer import (
12+
ENTRY_FORMAT,
13+
ENTRY_SIZE,
14+
HEADER_FORMAT,
15+
HEADER_SIZE,
16+
MAGIC,
17+
MISSING_TOKEN_ID,
18+
decompress_and_parse_lp,
19+
parse_logprobs,
20+
)
21+
22+
# Golden raw bytes: two tokens (7, -0.25) and (8, -0.5) — must match gateway serializer.
GOLDEN_RAW_HEX = (
    "4c5030310100000002000000100000000000000000000000"  # 24-byte header: LP01, v1, 2 tokens, 16-byte body
    "07000000000080be"  # entry 0: token id 7, logprob -0.25
    "08000000000000bf"  # entry 1: token id 8, logprob -0.5
)
27+
28+
29+
def _build_raw(tokens: list[tuple[int, float]]) -> bytes:
    """Pack ``(token_id, logprob)`` pairs into uncompressed LP/v1 bytes (header + entries)."""
    n_tokens = len(tokens)
    # Header fields in order: magic, version, flags, reserved u16, token count,
    # body byte length, reserved u64.
    packed_header = struct.pack(HEADER_FORMAT, MAGIC, 1, 0, 0, n_tokens, n_tokens * ENTRY_SIZE, 0)
    packed_entries = [struct.pack(ENTRY_FORMAT, token_id, logprob) for token_id, logprob in tokens]
    return packed_header + b"".join(packed_entries)
44+
45+
46+
def _compress_b64(raw: bytes) -> str:
    """Zstd-compress ``raw`` and return the result as ASCII base64 text."""
    zstd_frame = zstd.ZstdCompressor().compress(raw)
    encoded = base64.b64encode(zstd_frame)
    return encoded.decode("ascii")
48+
49+
50+
class TestParseLogprobs:
    """Exercise the LP/v1 parser on golden bytes, missing ids, the compressed path, and bad magic."""

    def test_golden_bytes_match_gateway(self):
        """The fixed golden payload decodes to the two expected tokens."""
        parsed_logprobs, parsed_ids, metadata = parse_logprobs(bytes.fromhex(GOLDEN_RAW_HEX))
        assert parsed_logprobs == [-0.25, -0.5]
        assert parsed_ids == [7, 8]
        assert metadata["all_token_ids_valid"] is True
        assert metadata["token_count"] == 2

    def test_missing_token_id_omits_token_ids_list(self):
        """A single ``MISSING_TOKEN_ID`` entry makes the whole id list ``None``."""
        payload = _build_raw([(MISSING_TOKEN_ID, -0.3), (42, -0.4)])
        parsed_logprobs, parsed_ids, metadata = parse_logprobs(payload)
        assert parsed_logprobs == pytest.approx([-0.3, -0.4])
        assert parsed_ids is None
        assert metadata["all_token_ids_valid"] is False

    def test_decompress_and_parse_round_trip(self):
        """The base64 + zstd entry point yields the same values as the raw parser."""
        encoded = _compress_b64(bytes.fromhex(GOLDEN_RAW_HEX))
        parsed_logprobs, parsed_ids, metadata = decompress_and_parse_lp(encoded)
        assert parsed_logprobs == [-0.25, -0.5]
        assert parsed_ids == [7, 8]
        assert metadata["scope"] == "completion_only"

    def test_rejects_bad_magic(self):
        """Overwriting the 4-byte magic must raise ``ValueError``."""
        valid = _build_raw([(1, -0.1)])
        corrupted = b"XXXX" + valid[4:]
        with pytest.raises(ValueError, match="Bad LP/v1 magic"):
            parse_logprobs(corrupted)

0 commit comments

Comments
 (0)