diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py index 27b9cdf8..1c60a18c 100644 --- a/src/ezmsg/core/profiling.py +++ b/src/ezmsg/core/profiling.py @@ -23,12 +23,34 @@ # Must return monotonic nanoseconds so *_ns metrics remain unit-consistent. PROFILE_TIME_TYPE: TypeAlias = Callable[[], int] PROFILE_TIME: PROFILE_TIME_TYPE = time.perf_counter_ns +TraceRecord: TypeAlias = tuple[ + int, + str, + str, + str, + float, + ProfileChannelType | None, + int | None, +] def _endpoint_id(topic: str, id: UUID) -> str: return f"{topic}:{id}" +def _trace_sample_from_record(record: TraceRecord) -> ProfilingTraceSample: + timestamp_ns, endpoint_id, topic, metric, value, channel_kind, sample_seq = record + return ProfilingTraceSample( + timestamp=float(timestamp_ns), + endpoint_id=endpoint_id, + topic=topic, + metric=metric, + value=value, + channel_kind=channel_kind, + sample_seq=sample_seq, + ) + + @dataclass class _PublisherMetrics: topic: str @@ -44,7 +66,7 @@ class _PublisherMetrics: _trace_counter: int = 0 _trace_publish_delta_enabled: bool = False _trace_backpressure_wait_enabled: bool = False - trace_samples: deque[ProfilingTraceSample] = field( + trace_samples: deque[TraceRecord] = field( default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) ) @@ -54,9 +76,11 @@ def record_publish(self, inflight: int, msg_seq: int | None = None) -> None: if not self._trace_publish_delta_enabled: return - self._trace_counter += 1 - if self._trace_counter % max(1, self.trace_sample_mod) != 0: - return + sample_mod = self.trace_sample_mod + if sample_mod != 1: + self._trace_counter += 1 + if self._trace_counter % sample_mod != 0: + return now_ns = PROFILE_TIME() publish_delta_ns = ( @@ -64,13 +88,14 @@ def record_publish(self, inflight: int, msg_seq: int | None = None) -> None: ) self._last_publish_ts_ns = now_ns self.trace_samples.append( - ProfilingTraceSample( - timestamp=float(now_ns), - endpoint_id=self.endpoint_id, - topic=self.topic, - metric="publish_delta_ns", - value=float(publish_delta_ns), - sample_seq=msg_seq, + ( + now_ns, + self.endpoint_id, + self.topic, + "publish_delta_ns", + float(publish_delta_ns), + None, + msg_seq, ) ) @@ -80,13 +105,14 @@ def record_backpressure_wait(self, wait_ns: int, msg_seq: int | None = None) -> now_ns = PROFILE_TIME() self.trace_samples.append( - ProfilingTraceSample( - timestamp=float(now_ns), - endpoint_id=self.endpoint_id, - topic=self.topic, - metric="backpressure_wait_ns", - value=float(wait_ns), - sample_seq=msg_seq, + ( + now_ns, + self.endpoint_id, + self.topic, + "backpressure_wait_ns", + float(wait_ns), + None, + msg_seq, ) ) @@ -135,19 +161,31 @@ class _SubscriberMetrics: _trace_counter: int = 0 _trace_lease_time_enabled: bool = False _trace_user_span_enabled: bool = False - trace_samples: deque[ProfilingTraceSample] = field( + trace_samples: deque[TraceRecord] = field( default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) ) - def begin_message(self, channel_kind: ProfileChannelType) -> bool: + def trace_receive_state( + self, channel_kind: ProfileChannelType + ) -> tuple[bool, bool, bool]: self.messages_received_total += 1 self.channel_kind_last = channel_kind - if not (self._trace_lease_time_enabled or self._trace_user_span_enabled): - return False + trace_lease = self._trace_lease_time_enabled + trace_user_span = self._trace_user_span_enabled + if not (trace_lease or trace_user_span): + return False, trace_lease, trace_user_span + + sample_mod = self.trace_sample_mod + if sample_mod == 1: + return True, trace_lease, trace_user_span self._trace_counter += 1 - return self._trace_counter % max(1, self.trace_sample_mod) == 0 + return (self._trace_counter % sample_mod == 0), trace_lease, trace_user_span + + def begin_message(self, channel_kind: ProfileChannelType) -> bool: + sampled, _trace_lease, _trace_user_span = self.trace_receive_state(channel_kind) + return sampled def record_receive( self, @@ -176,14 +214,33 @@ def record_lease_time( now_ns = PROFILE_TIME() self.trace_samples.append( - ProfilingTraceSample( - timestamp=float(now_ns), - endpoint_id=self.endpoint_id, - topic=self.topic, - metric="lease_time_ns", - value=float(lease_ns), - channel_kind=channel_kind, - sample_seq=msg_seq, + ( + now_ns, + self.endpoint_id, + self.topic, + "lease_time_ns", + float(lease_ns), + channel_kind, + msg_seq, + ) + ) + + def append_lease_time( + self, + channel_kind: ProfileChannelType, + lease_ns: int, + msg_seq: int | None = None, + ) -> None: + now_ns = PROFILE_TIME() + self.trace_samples.append( + ( + now_ns, + self.endpoint_id, + self.topic, + "lease_time_ns", + float(lease_ns), + channel_kind, + msg_seq, ) ) @@ -200,14 +257,33 @@ def record_user_span( now_ns = PROFILE_TIME() self.trace_samples.append( - ProfilingTraceSample( - timestamp=float(now_ns), - endpoint_id=self.endpoint_id, - topic=self.topic if label is None else f"{self.topic}:{label}", - metric="user_span_ns", - value=float(span_ns), - channel_kind=self.channel_kind_last, - sample_seq=msg_seq, + ( + now_ns, + self.endpoint_id, + self.topic if label is None else f"{self.topic}:{label}", + "user_span_ns", + float(span_ns), + self.channel_kind_last, + msg_seq, + ) + ) + + def append_user_span( + self, + span_ns: int, + label: str | None, + msg_seq: int | None = None, + ) -> None: + now_ns = PROFILE_TIME() + self.trace_samples.append( + ( + now_ns, + self.endpoint_id, + self.topic if label is None else f"{self.topic}:{label}", + "user_span_ns", + float(span_ns), + self.channel_kind_last, + msg_seq, ) ) @@ -334,7 +410,7 @@ def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch: samples: list[ProfilingTraceSample] = [] limit = max(1, int(max_samples)) - queues: list[deque[ProfilingTraceSample]] = [] + queues: list[deque[TraceRecord]] = [] for metric in self._publishers.values(): if metric.trace_samples: queues.append(metric.trace_samples) @@ -345,13 +421,13 @@ def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch: if len(queues) == 1: queue = queues[0] while queue and len(samples) < limit: - samples.append(queue.popleft()) + samples.append(_trace_sample_from_record(queue.popleft())) elif len(queues) > 1: heap: list[tuple[float, int, int]] = [] for idx, queue in enumerate(queues): sample = queue[0] - seq = sample.sample_seq if sample.sample_seq is not None else -1 - heapq.heappush(heap, (sample.timestamp, seq, idx)) + seq = sample[6] if sample[6] is not None else -1 + heapq.heappush(heap, (float(sample[0]), seq, idx)) while heap and len(samples) < limit: _timestamp, _seq, queue_idx = heapq.heappop(heap) @@ -359,11 +435,11 @@ def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch: if not queue: continue sample = queue.popleft() - samples.append(sample) + samples.append(_trace_sample_from_record(sample)) if queue: nxt = queue[0] - nxt_seq = nxt.sample_seq if nxt.sample_seq is not None else -1 - heapq.heappush(heap, (nxt.timestamp, nxt_seq, queue_idx)) + nxt_seq = nxt[6] if nxt[6] is not None else -1 + heapq.heappush(heap, (float(nxt[0]), nxt_seq, queue_idx)) return ProcessProfilingTraceBatch( process_id=self._process_id, diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 706c6c7a..11abbe7c 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -565,7 +565,35 @@ async def broadcast(self, obj: Any) -> None: f"Publisher {self.id}: Channel {channel.id} connection fail" ) - self._profile.record_publish(self._backpressure.pressure, msg_seq=self._msg_id) + profile = self._profile + inflight = self._backpressure.pressure + profile.messages_published_total += 1 + profile.inflight_messages_current = inflight + + if profile._trace_publish_delta_enabled: + sample_mod = profile.trace_sample_mod + sampled = True + if sample_mod != 1: + profile._trace_counter += 1 + sampled = (profile._trace_counter % sample_mod) == 0 + if sampled: + now_ns = PROFILE_TIME() + last_publish_ts_ns = profile._last_publish_ts_ns + publish_delta_ns = ( + 0 if last_publish_ts_ns is None else now_ns - last_publish_ts_ns + ) + profile._last_publish_ts_ns = now_ns + profile.trace_samples.append( + ( + now_ns, + profile.endpoint_id, + profile.topic, + "publish_delta_ns", + float(publish_delta_ns), + None, + self._msg_id, + ) + ) self._msg_id += 1 def _should_use_local_fast_path(self) -> bool: diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index a3ad3246..4ae67f0f 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -304,37 +304,48 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: channel = self._channels[pub_id] channel_kind = channel.channel_kind self._active_msg_seq = msg_id - self._active_trace_sampled = self._profile.begin_message(channel_kind) + sampled, trace_lease, _trace_user_span = self._profile.trace_receive_state( + channel_kind + ) + self._active_trace_sampled = sampled try: - trace_lease = self._profile._trace_lease_time_enabled start_ns = PROFILE_TIME() if trace_lease else None with channel.get(msg_id, self.id) as msg: yield msg - lease_ns = None - if trace_lease and start_ns is not None: - lease_ns = PROFILE_TIME() - start_ns - self._profile.record_lease_time( - channel_kind, - lease_ns, - msg_seq=msg_id, - sampled=self._active_trace_sampled, - ) + if trace_lease and sampled and start_ns is not None: + now_ns = PROFILE_TIME() + self._profile.trace_samples.append( + ( + now_ns, + self._profile.endpoint_id, + self._profile.topic, + "lease_time_ns", + float(now_ns - start_ns), + channel_kind, + msg_id, + ) + ) finally: self._active_msg_seq = None self._active_trace_sampled = False def begin_profile(self) -> int: - if not self._profile._trace_user_span_enabled or not self._active_trace_sampled: + if not self._active_trace_sampled: return 0 return PROFILE_TIME() def end_profile(self, start_ns: int, label: str | None = None) -> None: if start_ns <= 0: return - end_ns = PROFILE_TIME() - self._profile.record_user_span( - end_ns - start_ns, - label, - msg_seq=self._active_msg_seq, - sampled=self._active_trace_sampled, + now_ns = PROFILE_TIME() + self._profile.trace_samples.append( + ( + now_ns, + self._profile.endpoint_id, + self._profile.topic if label is None else f"{self._profile.topic}:{label}", + "user_span_ns", + float(now_ns - start_ns), + self._profile.channel_kind_last, + self._active_msg_seq, + ) )