|
18 | 18 | import types |
19 | 19 | import uuid |
20 | 20 | from abc import ABC, abstractmethod |
| 21 | +from collections import OrderedDict |
21 | 22 | from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence |
22 | 23 | from functools import partial, wraps |
23 | 24 | from multiprocessing import cpu_count |
@@ -119,6 +120,8 @@ class Logs3OverflowInputRow: |
119 | 120 | class LogItemWithMeta: |
120 | 121 | str_value: str |
121 | 122 | overflow_meta: Logs3OverflowInputRow |
| 123 | + root_span_id: str | None = None |
| 124 | + object_ids: dict[str, Any] = dataclasses.field(default_factory=dict) |
122 | 125 |
|
123 | 126 |
|
124 | 127 | class DatasetRef(TypedDict, total=False): |
@@ -419,7 +422,11 @@ def default_get_api_conn(): |
419 | 422 | # We lazily-initialize the logger so that it does any initialization |
420 | 423 | # (including reading env variables) upon the first actual usage. |
421 | 424 | self._global_bg_logger = LazyValue( |
422 | | - lambda: _HTTPBackgroundLogger(LazyValue(default_get_api_conn, use_mutex=True)), use_mutex=True |
| 425 | + lambda: _HTTPBackgroundLogger( |
| 426 | + LazyValue(default_get_api_conn, use_mutex=True), |
| 427 | + record_write_xact_id=self.record_trace_write_xact_id, |
| 428 | + ), |
| 429 | + use_mutex=True, |
423 | 430 | ) |
424 | 431 |
|
425 | 432 | self._id_generator = None |
@@ -462,6 +469,9 @@ def default_get_api_conn(): |
462 | 469 | from braintrust.span_cache import SpanCache |
463 | 470 |
|
464 | 471 | self.span_cache = SpanCache() |
| 472 | + self._trace_write_xact_ids: OrderedDict[tuple[str, str], str] = OrderedDict() |
| 473 | + self._trace_write_xact_ids_max_size = int(os.environ.get("BRAINTRUST_TRACE_WRITE_XACT_IDS_MAX_SIZE", "10000")) |
| 474 | + self._trace_write_xact_ids_lock = threading.Lock() |
465 | 475 | self._otel_flush_callback: Any | None = None |
466 | 476 |
|
467 | 477 | def reset_login_info(self): |
@@ -521,6 +531,23 @@ def context_manager(self): |
521 | 531 |
|
522 | 532 | return self._context_manager |
523 | 533 |
|
def record_trace_write_xact_id(self, object_id: str, root_span_id: str, xact_id: str) -> None:
    """Record the highest ingestion xact ID observed for a trace.

    Entries are keyed by ``(object_id, root_span_id)`` and stored in a bounded
    ``OrderedDict``. An entry is (re)written, and moved to the newest position,
    only when ``xact_id`` is numerically greater than the stored value. Once the
    table exceeds ``_trace_write_xact_ids_max_size``, entries are evicted from
    the oldest end.

    Args:
        object_id: Object ID component of the trace key.
        root_span_id: Root span ID component of the trace key.
        xact_id: Transaction ID as an integer string; stored verbatim.

    Raises:
        ValueError: If ``xact_id`` cannot be parsed by ``int()``.
    """
    # Parse before taking the lock so a malformed ID fails without holding it.
    parsed_xact_id = int(xact_id)
    key = (object_id, root_span_id)
    with self._trace_write_xact_ids_lock:
        cache = self._trace_write_xact_ids
        current_xact_id = cache.get(key)
        # Compare numerically: as strings, "10" would order below "9".
        if current_xact_id is None or parsed_xact_id > int(current_xact_id):
            cache[key] = xact_id
            cache.move_to_end(key)
        max_size = self._trace_write_xact_ids_max_size
        # The emptiness check matters when the configured max size is negative:
        # `len(cache) > max_size` stays true for an empty table, and
        # `popitem()` on an empty OrderedDict raises KeyError.
        while cache and len(cache) > max_size:
            cache.popitem(last=False)
| 545 | + |
| 546 | + def get_trace_write_xact_id(self, object_id: str, root_span_id: str) -> str | None: |
| 547 | + """Return the highest ingestion xact ID recorded for a trace.""" |
| 548 | + with self._trace_write_xact_ids_lock: |
| 549 | + return self._trace_write_xact_ids.get((object_id, root_span_id)) |
| 550 | + |
524 | 551 | def register_otel_flush(self, callback: Any) -> None: |
525 | 552 | """ |
526 | 553 | Register an OTEL flush callback. This is called by the OTEL integration |
@@ -554,6 +581,9 @@ def copy_state(self, other: "BraintrustState"): |
554 | 581 | "_context_manager", |
555 | 582 | "_last_otel_setting", |
556 | 583 | "_context_manager_lock", |
| 584 | + "_trace_write_xact_ids", |
| 585 | + "_trace_write_xact_ids_max_size", |
| 586 | + "_trace_write_xact_ids_lock", |
557 | 587 | ) |
558 | 588 | } |
559 | 589 | ) |
@@ -864,14 +894,17 @@ def pick_logs3_overflow_object_ids(row: Mapping[str, Any]) -> dict[str, Any]: |
864 | 894 |
|
def stringify_with_overflow_meta(item: dict[str, Any]) -> LogItemWithMeta:
    """Serialize ``item`` with ``bt_dumps`` and attach its overflow metadata.

    The returned ``LogItemWithMeta`` carries the serialized string, a
    ``Logs3OverflowInputRow`` describing it, the row's ``root_span_id`` (only
    when that field is a string, otherwise ``None``), and the object IDs picked
    by ``pick_logs3_overflow_object_ids``.
    """
    serialized = bt_dumps(item)
    # The same dict is handed to both the overflow row and the wrapper.
    picked_ids = pick_logs3_overflow_object_ids(item)

    overflow_row = Logs3OverflowInputRow(
        object_ids=picked_ids,
        has_comment="comment" in item,
        is_delete=item.get(OBJECT_DELETE_FIELD) is True,
        byte_size=utf8_byte_length(serialized),
    )

    # Anything other than a string root span ID is treated as absent.
    root_span_candidate = item.get("root_span_id")
    if isinstance(root_span_candidate, str):
        root_span_id = root_span_candidate
    else:
        root_span_id = None

    return LogItemWithMeta(
        str_value=serialized,
        overflow_meta=overflow_row,
        root_span_id=root_span_id,
        object_ids=picked_ids,
    )
876 | 909 |
|
877 | 910 |
|
@@ -1004,8 +1037,13 @@ def pop(self): |
1004 | 1037 | # instances of this class, because concurrent _BackgroundLoggers will not log to |
1005 | 1038 | # the backend in a deterministic order. |
1006 | 1039 | class _HTTPBackgroundLogger: |
1007 | | - def __init__(self, api_conn: LazyValue[HTTPConnection]): |
| 1040 | + def __init__( |
| 1041 | + self, |
| 1042 | + api_conn: LazyValue[HTTPConnection], |
| 1043 | + record_write_xact_id: Callable[[str, str, str], None] | None = None, |
| 1044 | + ): |
1008 | 1045 | self.api_conn = api_conn |
| 1046 | + self._record_write_xact_id = record_write_xact_id |
1009 | 1047 | self.masking_function: Callable[[Any], Any] | None = None |
1010 | 1048 | self.outfile = sys.stderr |
1011 | 1049 | self.flush_lock = threading.RLock() |
@@ -1383,6 +1421,7 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz |
1383 | 1421 | if error is None and resp is not None and resp.ok: |
1384 | 1422 | if overflow_rows: |
1385 | 1423 | self._overflow_upload_count += 1 |
| 1424 | + self._record_batch_write_xact_id(items, resp.headers.get("x-bt-write-xact-id")) |
1386 | 1425 | return |
1387 | 1426 | if error is None and resp is not None: |
1388 | 1427 | resp_errmsg = f"{resp.status_code}: {resp.text}" |
@@ -1410,6 +1449,16 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz |
1410 | 1449 |
|
1411 | 1450 | print(f"log request failed after {self.num_tries} retries. Dropping batch", file=self.outfile) |
1412 | 1451 |
|
| 1452 | + def _record_batch_write_xact_id(self, items: Sequence[LogItemWithMeta], xact_id: str | None) -> None: |
| 1453 | + if not xact_id or self._record_write_xact_id is None: |
| 1454 | + return |
| 1455 | + for item in items: |
| 1456 | + if not item.root_span_id: |
| 1457 | + continue |
| 1458 | + for object_id in item.object_ids.values(): |
| 1459 | + if isinstance(object_id, str): |
| 1460 | + self._record_write_xact_id(object_id, item.root_span_id, xact_id) |
| 1461 | + |
1413 | 1462 | def _dump_dropped_events(self, wrapped_items): |
1414 | 1463 | publish_payloads_dir = [x for x in [self.all_publish_payloads_dir, self.failed_publish_payloads_dir] if x] |
1415 | 1464 | if not (wrapped_items and publish_payloads_dir): |
@@ -1480,7 +1529,9 @@ def _internal_get_global_state() -> BraintrustState: |
1480 | 1529 |
|
1481 | 1530 | @contextlib.contextmanager |
1482 | 1531 | def _internal_with_custom_background_logger(): |
1483 | | - custom_logger = _HTTPBackgroundLogger(LazyValue(lambda: _state.api_conn(), use_mutex=True)) |
| 1532 | + custom_logger = _HTTPBackgroundLogger( |
| 1533 | + LazyValue(lambda: _state.api_conn(), use_mutex=True), record_write_xact_id=_state.record_trace_write_xact_id |
| 1534 | + ) |
1484 | 1535 | _state._override_bg_logger.logger = custom_logger |
1485 | 1536 | try: |
1486 | 1537 | yield custom_logger |
|
0 commit comments