Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 122 additions & 46 deletions src/ezmsg/core/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,34 @@
# Must return monotonic nanoseconds so *_ns metrics remain unit-consistent.
PROFILE_TIME_TYPE: TypeAlias = Callable[[], int]
PROFILE_TIME: PROFILE_TIME_TYPE = time.perf_counter_ns
# Compact positional trace record stored in the per-endpoint deques instead of
# a ProfilingTraceSample; expanded back by _trace_sample_from_record.
TraceRecord: TypeAlias = tuple[
int,  # [0] timestamp_ns
str,  # [1] endpoint_id
str,  # [2] topic
str,  # [3] metric
float,  # [4] value
ProfileChannelType | None,  # [5] channel_kind
int | None,  # [6] sample_seq
]


def _endpoint_id(topic: str, id: UUID) -> str:
    """Return the endpoint key ``"<topic>:<id>"`` for a topic and UUID."""
    # NOTE: ``id`` shadows the builtin, but the parameter name is part of the
    # signature and is kept so keyword callers keep working.
    return f"{topic}:{id}"


def _trace_sample_from_record(record: TraceRecord) -> ProfilingTraceSample:
    """Expand a compact ``TraceRecord`` tuple into a ``ProfilingTraceSample``.

    The record is unpacked positionally as ``(timestamp_ns, endpoint_id,
    topic, metric, value, channel_kind, sample_seq)``. The integer timestamp
    is converted to ``float`` for the sample's ``timestamp`` field; every
    other element is passed through unchanged.
    """
    timestamp_ns, endpoint_id, topic, metric, value, channel_kind, sample_seq = record
    return ProfilingTraceSample(
        timestamp=float(timestamp_ns),
        endpoint_id=endpoint_id,
        topic=topic,
        metric=metric,
        value=value,
        channel_kind=channel_kind,
        sample_seq=sample_seq,
    )


@dataclass
class _PublisherMetrics:
topic: str
Expand All @@ -44,7 +66,7 @@ class _PublisherMetrics:
_trace_counter: int = 0
_trace_publish_delta_enabled: bool = False
_trace_backpressure_wait_enabled: bool = False
trace_samples: deque[ProfilingTraceSample] = field(
trace_samples: deque[TraceRecord] = field(
default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES)
)

Expand All @@ -54,23 +76,26 @@ def record_publish(self, inflight: int, msg_seq: int | None = None) -> None:

if not self._trace_publish_delta_enabled:
return
self._trace_counter += 1
if self._trace_counter % max(1, self.trace_sample_mod) != 0:
return
sample_mod = self.trace_sample_mod
if sample_mod != 1:
self._trace_counter += 1
if self._trace_counter % sample_mod != 0:
return

now_ns = PROFILE_TIME()
publish_delta_ns = (
0 if self._last_publish_ts_ns is None else now_ns - self._last_publish_ts_ns
)
self._last_publish_ts_ns = now_ns
self.trace_samples.append(
ProfilingTraceSample(
timestamp=float(now_ns),
endpoint_id=self.endpoint_id,
topic=self.topic,
metric="publish_delta_ns",
value=float(publish_delta_ns),
sample_seq=msg_seq,
(
now_ns,
self.endpoint_id,
self.topic,
"publish_delta_ns",
float(publish_delta_ns),
None,
msg_seq,
)
)

Expand All @@ -80,13 +105,14 @@ def record_backpressure_wait(self, wait_ns: int, msg_seq: int | None = None) ->

now_ns = PROFILE_TIME()
self.trace_samples.append(
ProfilingTraceSample(
timestamp=float(now_ns),
endpoint_id=self.endpoint_id,
topic=self.topic,
metric="backpressure_wait_ns",
value=float(wait_ns),
sample_seq=msg_seq,
(
now_ns,
self.endpoint_id,
self.topic,
"backpressure_wait_ns",
float(wait_ns),
None,
msg_seq,
)
)

Expand Down Expand Up @@ -135,19 +161,31 @@ class _SubscriberMetrics:
_trace_counter: int = 0
_trace_lease_time_enabled: bool = False
_trace_user_span_enabled: bool = False
trace_samples: deque[ProfilingTraceSample] = field(
trace_samples: deque[TraceRecord] = field(
default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES)
)

def begin_message(self, channel_kind: ProfileChannelType) -> bool:
def trace_receive_state(
    self, channel_kind: ProfileChannelType
) -> tuple[bool, bool, bool]:
    """Count a received message and decide whether it should be traced.

    Always increments ``messages_received_total`` and stores ``channel_kind``
    in ``channel_kind_last``.

    Returns:
        ``(sampled, trace_lease, trace_user_span)`` where the last two are
        snapshots of the ``_trace_lease_time_enabled`` and
        ``_trace_user_span_enabled`` flags, and ``sampled`` is True when at
        least one of them is set and this message falls on the sampling
        modulus.
    """
    self.messages_received_total += 1
    self.channel_kind_last = channel_kind

    trace_lease = self._trace_lease_time_enabled
    trace_user_span = self._trace_user_span_enabled
    if not (trace_lease or trace_user_span):
        return False, trace_lease, trace_user_span

    # Clamp the modulus: a zero value would raise ZeroDivisionError below and
    # a negative one would change which messages are picked. Treat both as
    # "sample every message".
    sample_mod = max(1, self.trace_sample_mod)
    if sample_mod == 1:
        # Fast path: every message is sampled, no counter bookkeeping needed.
        return True, trace_lease, trace_user_span

    self._trace_counter += 1
    return (self._trace_counter % sample_mod == 0), trace_lease, trace_user_span

def begin_message(self, channel_kind: ProfileChannelType) -> bool:
    """Register a received message and report whether it is trace-sampled.

    Thin wrapper over ``trace_receive_state`` that returns only the first
    element (``sampled``) of its result.
    """
    return self.trace_receive_state(channel_kind)[0]

def record_receive(
self,
Expand Down Expand Up @@ -176,14 +214,33 @@ def record_lease_time(

now_ns = PROFILE_TIME()
self.trace_samples.append(
ProfilingTraceSample(
timestamp=float(now_ns),
endpoint_id=self.endpoint_id,
topic=self.topic,
metric="lease_time_ns",
value=float(lease_ns),
channel_kind=channel_kind,
sample_seq=msg_seq,
(
now_ns,
self.endpoint_id,
self.topic,
"lease_time_ns",
float(lease_ns),
channel_kind,
msg_seq,
)
)

def append_lease_time(
    self,
    channel_kind: ProfileChannelType,
    lease_ns: int,
    msg_seq: int | None = None,
) -> None:
    """Append a ``lease_time_ns`` trace record without any gating.

    This method performs no enabled-flag or sampling check itself; callers
    are expected to have decided already that the sample should be kept.

    Args:
        channel_kind: Stored in the record's channel-kind slot.
        lease_ns: Lease duration, stored as ``float``.
        msg_seq: Optional message sequence number for the record.
    """
    now_ns = PROFILE_TIME()
    record = (
        now_ns,
        self.endpoint_id,
        self.topic,
        "lease_time_ns",
        float(lease_ns),
        channel_kind,
        msg_seq,
    )
    self.trace_samples.append(record)

Expand All @@ -200,14 +257,33 @@ def record_user_span(

now_ns = PROFILE_TIME()
self.trace_samples.append(
ProfilingTraceSample(
timestamp=float(now_ns),
endpoint_id=self.endpoint_id,
topic=self.topic if label is None else f"{self.topic}:{label}",
metric="user_span_ns",
value=float(span_ns),
channel_kind=self.channel_kind_last,
sample_seq=msg_seq,
(
now_ns,
self.endpoint_id,
self.topic if label is None else f"{self.topic}:{label}",
"user_span_ns",
float(span_ns),
self.channel_kind_last,
msg_seq,
)
)

def append_user_span(
    self,
    span_ns: int,
    label: str | None,
    msg_seq: int | None = None,
) -> None:
    """Append a ``user_span_ns`` trace record without any gating.

    This method performs no enabled-flag or sampling check itself; callers
    are expected to have decided already that the sample should be kept.

    Args:
        span_ns: Span duration, stored as ``float``.
        label: When not None, the record's topic becomes
            ``"<topic>:<label>"``; otherwise the plain topic is used.
        msg_seq: Optional message sequence number for the record.
    """
    now_ns = PROFILE_TIME()
    topic = self.topic if label is None else f"{self.topic}:{label}"
    record = (
        now_ns,
        self.endpoint_id,
        topic,
        "user_span_ns",
        float(span_ns),
        self.channel_kind_last,
        msg_seq,
    )
    self.trace_samples.append(record)

Expand Down Expand Up @@ -334,7 +410,7 @@ def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch:
samples: list[ProfilingTraceSample] = []
limit = max(1, int(max_samples))

queues: list[deque[ProfilingTraceSample]] = []
queues: list[deque[TraceRecord]] = []
for metric in self._publishers.values():
if metric.trace_samples:
queues.append(metric.trace_samples)
Expand All @@ -345,25 +421,25 @@ def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch:
if len(queues) == 1:
queue = queues[0]
while queue and len(samples) < limit:
samples.append(queue.popleft())
samples.append(_trace_sample_from_record(queue.popleft()))
elif len(queues) > 1:
heap: list[tuple[float, int, int]] = []
for idx, queue in enumerate(queues):
sample = queue[0]
seq = sample.sample_seq if sample.sample_seq is not None else -1
heapq.heappush(heap, (sample.timestamp, seq, idx))
seq = sample[6] if sample[6] is not None else -1
heapq.heappush(heap, (float(sample[0]), seq, idx))

while heap and len(samples) < limit:
_timestamp, _seq, queue_idx = heapq.heappop(heap)
queue = queues[queue_idx]
if not queue:
continue
sample = queue.popleft()
samples.append(sample)
samples.append(_trace_sample_from_record(sample))
if queue:
nxt = queue[0]
nxt_seq = nxt.sample_seq if nxt.sample_seq is not None else -1
heapq.heappush(heap, (nxt.timestamp, nxt_seq, queue_idx))
nxt_seq = nxt[6] if nxt[6] is not None else -1
heapq.heappush(heap, (float(nxt[0]), nxt_seq, queue_idx))

return ProcessProfilingTraceBatch(
process_id=self._process_id,
Expand Down
30 changes: 29 additions & 1 deletion src/ezmsg/core/pubclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,35 @@ async def broadcast(self, obj: Any) -> None:
f"Publisher {self.id}: Channel {channel.id} connection fail"
)

self._profile.record_publish(self._backpressure.pressure, msg_seq=self._msg_id)
profile = self._profile
inflight = self._backpressure.pressure
profile.messages_published_total += 1
profile.inflight_messages_current = inflight

if profile._trace_publish_delta_enabled:
sample_mod = profile.trace_sample_mod
sampled = True
if sample_mod != 1:
profile._trace_counter += 1
sampled = (profile._trace_counter % sample_mod) == 0
if sampled:
now_ns = PROFILE_TIME()
last_publish_ts_ns = profile._last_publish_ts_ns
publish_delta_ns = (
0 if last_publish_ts_ns is None else now_ns - last_publish_ts_ns
)
profile._last_publish_ts_ns = now_ns
profile.trace_samples.append(
(
now_ns,
profile.endpoint_id,
profile.topic,
"publish_delta_ns",
float(publish_delta_ns),
None,
self._msg_id,
)
)
self._msg_id += 1

def _should_use_local_fast_path(self) -> bool:
Expand Down
47 changes: 29 additions & 18 deletions src/ezmsg/core/subclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,37 +304,48 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]:
channel = self._channels[pub_id]
channel_kind = channel.channel_kind
self._active_msg_seq = msg_id
self._active_trace_sampled = self._profile.begin_message(channel_kind)
sampled, trace_lease, _trace_user_span = self._profile.trace_receive_state(
channel_kind
)
self._active_trace_sampled = sampled
try:
trace_lease = self._profile._trace_lease_time_enabled
start_ns = PROFILE_TIME() if trace_lease else None
with channel.get(msg_id, self.id) as msg:
yield msg
lease_ns = None
if trace_lease and start_ns is not None:
lease_ns = PROFILE_TIME() - start_ns
self._profile.record_lease_time(
channel_kind,
lease_ns,
msg_seq=msg_id,
sampled=self._active_trace_sampled,
)
if trace_lease and sampled and start_ns is not None:
now_ns = PROFILE_TIME()
self._profile.trace_samples.append(
(
now_ns,
self._profile.endpoint_id,
self._profile.topic,
"lease_time_ns",
float(now_ns - start_ns),
channel_kind,
msg_id,
)
)
finally:
self._active_msg_seq = None
self._active_trace_sampled = False

def begin_profile(self) -> int:
    """Return a start timestamp for a user span, or 0 when it must be skipped.

    A span is only timed when user-span tracing is enabled on the profile
    AND the message currently being received was selected for tracing.
    ``_active_trace_sampled`` alone is not sufficient: it is also True when
    only lease-time tracing is enabled, and user spans must not be recorded
    in that configuration.

    Returns:
        ``PROFILE_TIME()`` when the span should be measured, otherwise 0
        (which ``end_profile`` treats as "nothing to record").
    """
    if not (self._profile._trace_user_span_enabled and self._active_trace_sampled):
        return 0
    return PROFILE_TIME()

def end_profile(self, start_ns: int, label: str | None = None) -> None:
    """Close a user span opened by ``begin_profile`` and record its duration.

    Args:
        start_ns: Value returned by ``begin_profile``; values <= 0 mean the
            span was not started and nothing is recorded.
        label: When not None, the record's topic becomes
            ``"<topic>:<label>"``; otherwise the plain topic is used.
    """
    if start_ns <= 0:
        return
    profile = self._profile
    # The earlier implementation forwarded the sampled flag to
    # record_user_span instead of appending directly. Keep the equivalent
    # guards here so a span is never stored when user-span tracing is off or
    # the active message is not sampled.
    # NOTE(review): confirm these match record_user_span's own checks.
    if not (profile._trace_user_span_enabled and self._active_trace_sampled):
        return

    now_ns = PROFILE_TIME()
    topic = profile.topic if label is None else f"{profile.topic}:{label}"
    profile.trace_samples.append(
        (
            now_ns,
            profile.endpoint_id,
            topic,
            "user_span_ns",
            float(now_ns - start_ns),
            profile.channel_kind_last,
            self._active_msg_seq,
        )
    )
Loading