Skip to content

Commit 282264f

Browse files
committed
feat: add compacted trace xact id reads
Capture x-bt-write-xact-id from logs3 flush responses and record the latest write xact id per trace in SDK state. Pass trace_min_xact_id on get_thread preprocessor invokes so compacted reads can wait for ingestion visibility, with unit coverage for logging, invoke serialization, and LocalTrace wiring.
1 parent ec87b5a commit 282264f

6 files changed

Lines changed: 183 additions & 5 deletions

File tree

py/src/braintrust/functions/invoke.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def invoke(
5757
api_key: str | None = None,
5858
app_url: str | None = None,
5959
force_login: bool = False,
60+
trace_min_xact_id: str | None = None,
6061
) -> T: ...
6162

6263

@@ -85,6 +86,7 @@ def invoke(
8586
api_key: str | None = None,
8687
app_url: str | None = None,
8788
force_login: bool = False,
89+
trace_min_xact_id: str | None = None,
8890
) -> BraintrustStream: ...
8991

9092

@@ -112,6 +114,7 @@ def invoke(
112114
api_key: str | None = None,
113115
app_url: str | None = None,
114116
force_login: bool = False,
117+
trace_min_xact_id: str | None = None,
115118
) -> BraintrustStream | T:
116119
"""
117120
Invoke a Braintrust function, returning a `BraintrustStream` or the value as a plain
@@ -151,6 +154,7 @@ def invoke(
151154
global_function: The name of the global function to invoke.
152155
function_type: The type of the global function to invoke. If unspecified, defaults to 'scorer'
153156
for backward compatibility.
157+
trace_min_xact_id: Optional minimum ingestion xact ID for compacted trace-ref reads.
154158
155159
Returns:
156160
The output of the function. If `stream` is True, returns a `BraintrustStream`,
@@ -198,6 +202,8 @@ def invoke(
198202
request["mode"] = mode
199203
if strict is not None:
200204
request["strict"] = strict
205+
if trace_min_xact_id is not None:
206+
request["trace_min_xact_id"] = trace_min_xact_id
201207

202208
headers = {
203209
"Accept": "text/event-stream" if stream else "application/json",

py/src/braintrust/functions/test_invoke.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,32 @@ def test_invoke_serializes_google_messages():
118118
assert isinstance(parsed, dict) and parsed
119119

120120

121+
def test_invoke_serializes_trace_min_xact_id():
    """`trace_min_xact_id` passed to `invoke` appears as a top-level key of the posted JSON body."""
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = {}
    conn = MagicMock()
    conn.post.return_value = response

    trace_ref = {"object_id": "exp-123", "root_span_id": "root-456"}

    with (
        patch("braintrust.functions.invoke.login"),
        patch("braintrust.functions.invoke.get_span_parent_object") as parent_patch,
        patch("braintrust.functions.invoke.proxy_conn", return_value=conn),
    ):
        parent_patch.return_value.export.return_value = "span-export"
        invoke(
            global_function="project_default",
            function_type="preprocessor",
            input={"trace_ref": trace_ref},
            trace_min_xact_id="12345",
        )

    # The request body is posted as UTF-8 encoded JSON bytes under `data`.
    posted_bytes = conn.post.call_args.kwargs["data"]
    body = json.loads(posted_bytes.decode("utf-8"))
    assert body["trace_min_xact_id"] == "12345"
    assert "trace_read" not in body
145+
146+
121147
@pytest.mark.vcr
122148
def test_invoke_encodes_body_as_utf8_bytes(monkeypatch):
123149
"""Regression test for BT-4620: non-Latin-1 Unicode must not be corrupted.

py/src/braintrust/logger.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import types
1919
import uuid
2020
from abc import ABC, abstractmethod
21+
from collections import OrderedDict
2122
from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence
2223
from functools import partial, wraps
2324
from multiprocessing import cpu_count
@@ -119,6 +120,8 @@ class Logs3OverflowInputRow:
119120
class LogItemWithMeta:
120121
str_value: str
121122
overflow_meta: Logs3OverflowInputRow
123+
root_span_id: str | None = None
124+
object_ids: dict[str, Any] = dataclasses.field(default_factory=dict)
122125

123126

124127
class DatasetRef(TypedDict, total=False):
@@ -419,7 +422,11 @@ def default_get_api_conn():
419422
# We lazily-initialize the logger so that it does any initialization
420423
# (including reading env variables) upon the first actual usage.
421424
self._global_bg_logger = LazyValue(
422-
lambda: _HTTPBackgroundLogger(LazyValue(default_get_api_conn, use_mutex=True)), use_mutex=True
425+
lambda: _HTTPBackgroundLogger(
426+
LazyValue(default_get_api_conn, use_mutex=True),
427+
record_write_xact_id=self.record_trace_write_xact_id,
428+
),
429+
use_mutex=True,
423430
)
424431

425432
self._id_generator = None
@@ -462,6 +469,9 @@ def default_get_api_conn():
462469
from braintrust.span_cache import SpanCache
463470

464471
self.span_cache = SpanCache()
472+
self._trace_write_xact_ids: OrderedDict[tuple[str, str], str] = OrderedDict()
473+
self._trace_write_xact_ids_max_size = int(os.environ.get("BRAINTRUST_TRACE_WRITE_XACT_IDS_MAX_SIZE", "10000"))
474+
self._trace_write_xact_ids_lock = threading.Lock()
465475
self._otel_flush_callback: Any | None = None
466476

467477
def reset_login_info(self):
@@ -521,6 +531,23 @@ def context_manager(self):
521531

522532
return self._context_manager
523533

534+
def record_trace_write_xact_id(self, object_id: str, root_span_id: str, xact_id: str) -> None:
    """Record the highest ingestion xact ID observed for a trace.

    Entries are keyed by ``(object_id, root_span_id)``. A stored value is only
    replaced when the new one is numerically larger, so responses that arrive
    out of order never lower the recorded high watermark. Once the map grows
    past ``_trace_write_xact_ids_max_size``, entries are evicted from the front
    of the ordered map.

    Args:
        object_id: ID of the object the trace rows were written to.
        root_span_id: Root span ID identifying the trace.
        xact_id: Xact ID as a decimal string; stored verbatim when it wins.

    Raises:
        ValueError: If ``xact_id`` cannot be parsed as an integer.
    """
    # Parse before taking the lock so invalid input never touches shared state.
    parsed_xact_id = int(xact_id)
    key = (object_id, root_span_id)
    with self._trace_write_xact_ids_lock:
        current_xact_id = self._trace_write_xact_ids.get(key)
        if current_xact_id is None or parsed_xact_id > int(current_xact_id):
            self._trace_write_xact_ids[key] = xact_id
            self._trace_write_xact_ids.move_to_end(key)
        # Stop once the map is empty: a zero or negative configured limit must
        # not make popitem() raise KeyError on an empty OrderedDict.
        while (
            self._trace_write_xact_ids
            and len(self._trace_write_xact_ids) > self._trace_write_xact_ids_max_size
        ):
            self._trace_write_xact_ids.popitem(last=False)
545+
546+
def get_trace_write_xact_id(self, object_id: str, root_span_id: str) -> str | None:
547+
"""Return the highest ingestion xact ID recorded for a trace."""
548+
with self._trace_write_xact_ids_lock:
549+
return self._trace_write_xact_ids.get((object_id, root_span_id))
550+
524551
def register_otel_flush(self, callback: Any) -> None:
525552
"""
526553
Register an OTEL flush callback. This is called by the OTEL integration
@@ -554,6 +581,9 @@ def copy_state(self, other: "BraintrustState"):
554581
"_context_manager",
555582
"_last_otel_setting",
556583
"_context_manager_lock",
584+
"_trace_write_xact_ids",
585+
"_trace_write_xact_ids_max_size",
586+
"_trace_write_xact_ids_lock",
557587
)
558588
}
559589
)
@@ -864,14 +894,17 @@ def pick_logs3_overflow_object_ids(row: Mapping[str, Any]) -> dict[str, Any]:
864894

865895
def stringify_with_overflow_meta(item: dict[str, Any]) -> LogItemWithMeta:
    """Serialize a log row and bundle it with its overflow and trace metadata.

    The same object-ID mapping is attached both to the overflow metadata and to
    the returned item. ``root_span_id`` is carried over only when the row holds
    it as a string; otherwise it is ``None``.
    """
    serialized = bt_dumps(item)
    ids = pick_logs3_overflow_object_ids(item)

    overflow_meta = Logs3OverflowInputRow(
        object_ids=ids,
        has_comment="comment" in item,
        is_delete=item.get(OBJECT_DELETE_FIELD) is True,
        byte_size=utf8_byte_length(serialized),
    )

    root_span_id = item.get("root_span_id")
    if not isinstance(root_span_id, str):
        root_span_id = None

    return LogItemWithMeta(
        str_value=serialized,
        overflow_meta=overflow_meta,
        root_span_id=root_span_id,
        object_ids=ids,
    )
876909

877910

@@ -1004,8 +1037,13 @@ def pop(self):
10041037
# instances of this class, because concurrent _BackgroundLoggers will not log to
10051038
# the backend in a deterministic order.
10061039
class _HTTPBackgroundLogger:
1007-
def __init__(self, api_conn: LazyValue[HTTPConnection]):
1040+
def __init__(
1041+
self,
1042+
api_conn: LazyValue[HTTPConnection],
1043+
record_write_xact_id: Callable[[str, str, str], None] | None = None,
1044+
):
10081045
self.api_conn = api_conn
1046+
self._record_write_xact_id = record_write_xact_id
10091047
self.masking_function: Callable[[Any], Any] | None = None
10101048
self.outfile = sys.stderr
10111049
self.flush_lock = threading.RLock()
@@ -1383,6 +1421,7 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz
13831421
if error is None and resp is not None and resp.ok:
13841422
if overflow_rows:
13851423
self._overflow_upload_count += 1
1424+
self._record_batch_write_xact_id(items, resp.headers.get("x-bt-write-xact-id"))
13861425
return
13871426
if error is None and resp is not None:
13881427
resp_errmsg = f"{resp.status_code}: {resp.text}"
@@ -1410,6 +1449,16 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz
14101449

14111450
print(f"log request failed after {self.num_tries} retries. Dropping batch", file=self.outfile)
14121451

1452+
def _record_batch_write_xact_id(self, items: Sequence[LogItemWithMeta], xact_id: str | None) -> None:
    """Report the write xact ID of a flushed batch for every trace in the batch.

    Each distinct ``(object_id, root_span_id)`` pair found in ``items`` is passed
    to the ``record_write_xact_id`` callback once, in first-seen order. Rows
    without a root span ID and non-string object-ID values are skipped.

    Args:
        items: The rows that were sent in the request.
        xact_id: Value of the response's ``x-bt-write-xact-id`` header, if any.
    """
    if not xact_id or self._record_write_xact_id is None:
        return

    # Rows of one batch often share a trace; de-duplicate (order-preserving) so
    # the callback is invoked once per trace rather than once per row.
    trace_keys = dict.fromkeys(
        (object_id, item.root_span_id)
        for item in items
        if item.root_span_id
        for object_id in item.object_ids.values()
        if isinstance(object_id, str)
    )

    try:
        for object_id, root_span_id in trace_keys:
            self._record_write_xact_id(object_id, root_span_id, xact_id)
    except ValueError as e:
        # The batch itself was written successfully; a malformed xact ID must
        # not turn that into a failure, so report it and carry on.
        print(f"ignoring invalid x-bt-write-xact-id {xact_id!r}: {e}", file=self.outfile)
1461+
14131462
def _dump_dropped_events(self, wrapped_items):
14141463
publish_payloads_dir = [x for x in [self.all_publish_payloads_dir, self.failed_publish_payloads_dir] if x]
14151464
if not (wrapped_items and publish_payloads_dir):
@@ -1480,7 +1529,9 @@ def _internal_get_global_state() -> BraintrustState:
14801529

14811530
@contextlib.contextmanager
14821531
def _internal_with_custom_background_logger():
1483-
custom_logger = _HTTPBackgroundLogger(LazyValue(lambda: _state.api_conn(), use_mutex=True))
1532+
custom_logger = _HTTPBackgroundLogger(
1533+
LazyValue(lambda: _state.api_conn(), use_mutex=True), record_write_xact_id=_state.record_trace_write_xact_id
1534+
)
14841535
_state._override_bg_logger.logger = custom_logger
14851536
try:
14861537
yield custom_logger

py/src/braintrust/test_logger.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,69 @@ def test_init_enable_atexit_flush(self):
137137
_HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) # type: ignore
138138
mock_register.assert_called()
139139

140+
def test_records_write_xact_id_from_logs3_response(self):
    """The `x-bt-write-xact-id` response header is reported for the submitted row's trace."""
    from braintrust.logger import _HTTPBackgroundLogger, stringify_with_overflow_meta

    class FakeResponse:
        ok = True
        headers = {"x-bt-write-xact-id": "12345"}

    conn = MagicMock()
    conn.post.return_value = FakeResponse()
    seen = []

    def remember(object_id, root_span_id, xact_id):
        seen.append((object_id, root_span_id, xact_id))

    logger_under_test = _HTTPBackgroundLogger(
        LazyValue(lambda: conn, use_mutex=False),
        record_write_xact_id=remember,
    )

    row = stringify_with_overflow_meta(
        {
            "experiment_id": "exp-123",
            "root_span_id": "root-456",
            "span_id": "span-789",
        }
    )
    limits = {"max_request_size": 1024 * 1024, "can_use_overflow": False}
    logger_under_test._submit_logs_request([row], limits)

    assert seen == [("exp-123", "root-456", "12345")]
171+
172+
def test_trace_write_xact_id_keeps_high_watermark(self):
    """A lower xact ID recorded later does not replace a higher one for the same trace."""
    from braintrust.logger import BraintrustState

    tracker = BraintrustState()
    writes = (
        ("root-456", "200"),
        ("root-456", "100"),
        ("root-other", "50"),
    )
    for root_span_id, xact_id in writes:
        tracker.record_trace_write_xact_id("exp-123", root_span_id, xact_id)

    assert tracker.get_trace_write_xact_id("exp-123", "root-456") == "200"
    assert tracker.get_trace_write_xact_id("exp-123", "root-other") == "50"
182+
183+
def test_trace_write_xact_id_rejects_non_numeric_values(self):
    """Recording an xact ID that is not numeric raises `ValueError`."""
    from braintrust.logger import BraintrustState

    tracker = BraintrustState()
    bad_xact_id = "not-numeric"
    with pytest.raises(ValueError):
        tracker.record_trace_write_xact_id("exp-123", "root-456", bad_xact_id)
189+
190+
def test_trace_write_xact_ids_are_bounded(self):
    """With the size limit set to 2, recording a third trace evicts the first one."""
    from braintrust.logger import BraintrustState

    size_limit_env = {"BRAINTRUST_TRACE_WRITE_XACT_IDS_MAX_SIZE": "2"}
    with patch.dict(os.environ, size_limit_env):
        tracker = BraintrustState()
        for n in ("1", "2", "3"):
            tracker.record_trace_write_xact_id("exp-123", "root-" + n, n)

    assert tracker.get_trace_write_xact_id("exp-123", "root-1") is None
    for n in ("2", "3"):
        assert tracker.get_trace_write_xact_id("exp-123", "root-" + n) == n
202+
140203
def test_init_disable_atexit_flush(self):
141204
from braintrust.logger import _HTTPBackgroundLogger
142205

py/src/braintrust/test_trace.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,16 @@ def get_by_root_span_id(self, root_span_id: str):
306306

307307

308308
class _DummyState:
    """Minimal stand-in for the SDK state object, returning a canned write xact ID."""

    def __init__(self, xact_id: str | None = None):
        self.xact_id = xact_id
        self.span_cache = _DummySpanCache()

    def login(self):
        """Do nothing; the tests need no real login."""
        return None

    def get_trace_write_xact_id(self, object_id: str, root_span_id: str):
        """Return the configured xact ID regardless of which trace is asked for."""
        return self.xact_id
318+
315319

316320
class TestLocalTraceGetThread:
317321
@pytest.mark.asyncio
@@ -349,8 +353,33 @@ def fake_invoke(**kwargs):
349353
"root_span_id": "root-456",
350354
}
351355
}
356+
assert calls[0]["trace_min_xact_id"] is None
352357
assert result == mock_thread
353358

359+
@pytest.mark.asyncio
async def test_passes_trace_min_xact_id_with_recorded_xact_id(self, monkeypatch):
    """The xact ID reported by the state is forwarded to `invoke` as `trace_min_xact_id`."""
    invoke_kwargs = []

    def fake_invoke(**kwargs):
        invoke_kwargs.append(kwargs)
        return []

    monkeypatch.setattr("braintrust.trace.invoke", fake_invoke)

    state = _DummyState(xact_id="12345")
    trace = LocalTrace(
        object_type="experiment",
        object_id="exp-123",
        root_span_id="root-456",
        ensure_spans_flushed=None,
        state=state,
    )

    await trace.get_thread()

    first_call = invoke_kwargs[0]
    assert first_call["trace_min_xact_id"] == "12345"
    assert "trace_read" not in first_call
    assert "skip_realtime" not in first_call["input"]["trace_ref"]
382+
354383
@pytest.mark.asyncio
355384
async def test_uses_custom_preprocessor(self, monkeypatch):
356385
calls = []

py/src/braintrust/trace.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,8 @@ async def _fetch_thread(self, options: GetThreadOptions | None = None) -> list[A
407407
await asyncio.get_event_loop().run_in_executor(None, lambda: self._state.login())
408408
preprocessor = options.get("preprocessor") if options and options.get("preprocessor") else None
409409

410+
trace_min_xact_id = self._state.get_trace_write_xact_id(self._object_id, self._root_span_id)
411+
410412
result = await asyncio.get_event_loop().run_in_executor(
411413
None,
412414
lambda: invoke(
@@ -420,6 +422,7 @@ async def _fetch_thread(self, options: GetThreadOptions | None = None) -> list[A
420422
"root_span_id": self._root_span_id,
421423
}
422424
},
425+
trace_min_xact_id=trace_min_xact_id,
423426
),
424427
)
425428

0 commit comments

Comments
 (0)