diff --git a/example.py b/example.py index f25410e..10a22c1 100644 --- a/example.py +++ b/example.py @@ -15,7 +15,10 @@ insecure=True) # grpc # Configure span processors -partial_span_processor = PartialSpanProcessor(log_exporter, 5000) +partial_span_processor = PartialSpanProcessor(log_exporter=log_exporter, + heartbeat_interval_millis=1000, + initial_heartbeat_delay_millis=6000, + process_interval_millis=1000) batch_span_processor = BatchSpanProcessor(span_exporter) # Create a TracerProvider @@ -29,6 +32,11 @@ tracer = trace.get_tracer(__name__) # Start a span (logs heartbeat and stop events) -with tracer.start_as_current_span("partial_span_1"): - print("partial_span_1 is running") - sleep(10) +with tracer.start_as_current_span("span 1"): + with tracer.start_as_current_span("span 2"): + print("sleeping inside span 2") + sleep(2) + print("sleeping inside span 1") + sleep(5) +print("sleeping outside spans") +sleep(5) diff --git a/pyproject.toml b/pyproject.toml index ecfcf4d..c93ea71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "partial_span_processor" -version = "0.0.8" +version = "0.0.9" authors = [ { name = "Mladjan Gadzic", email = "gadzic.mladjan@gmail.com" } ] diff --git a/src/partial_span_processor/__init__.py b/src/partial_span_processor/__init__.py index d18b175..c50c415 100644 --- a/src/partial_span_processor/__init__.py +++ b/src/partial_span_processor/__init__.py @@ -14,10 +14,10 @@ from __future__ import annotations +import datetime import json import threading import time -from queue import Queue from typing import TYPE_CHECKING from google.protobuf import json_format @@ -28,12 +28,36 @@ from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor from opentelemetry.trace import TraceFlags +from partial_span_processor.peekable_queue import PeekableQueue + if TYPE_CHECKING: from opentelemetry import context as context_api from opentelemetry.sdk._logs.export import LogExporter from opentelemetry.sdk.resources import Resource WORKER_THREAD_NAME = "OtelPartialSpanProcessor" +DEFAULT_HEARTBEAT_INTERVAL_MILLIS = 5000 +DEFAULT_INITIAL_HEARTBEAT_DELAY_MILLIS = 5000 +DEFAULT_PROCESS_INTERVAL_MILLIS = 5000 + + +def validate_parameters(log_exporter, heartbeat_interval_millis, + initial_heartbeat_delay_millis, process_interval_millis): + if log_exporter is None: + msg = "log_exporter must not be None" + raise ValueError(msg) + + if heartbeat_interval_millis <= 0: + msg = "heartbeat_interval_millis must be greater than 0" + raise ValueError(msg) + + if initial_heartbeat_delay_millis < 0: + msg = "initial_heartbeat_delay_millis must be greater or equal to 0" + raise ValueError(msg) + + if process_interval_millis <= 0: + msg = "process_interval_millis must be greater than 0" + raise ValueError(msg) class PartialSpanProcessor(SpanProcessor): @@ -41,18 +65,26 @@ class PartialSpanProcessor(SpanProcessor): def __init__( self, log_exporter: LogExporter, - heartbeat_interval_millis: int, + heartbeat_interval_millis: int = DEFAULT_HEARTBEAT_INTERVAL_MILLIS, + initial_heartbeat_delay_millis: int = DEFAULT_INITIAL_HEARTBEAT_DELAY_MILLIS, + process_interval_millis: int = DEFAULT_PROCESS_INTERVAL_MILLIS, resource: Resource | None = None, ) -> None: - if heartbeat_interval_millis <= 0: - msg = "heartbeat_interval_ms must be greater than 0" - raise ValueError(msg) + validate_parameters(log_exporter, heartbeat_interval_millis, + initial_heartbeat_delay_millis, process_interval_millis) + self.log_exporter = log_exporter self.heartbeat_interval_millis = heartbeat_interval_millis + self.initial_heartbeat_delay_millis = initial_heartbeat_delay_millis + self.process_interval_millis = process_interval_millis self.resource = resource self.active_spans = {} - self.ended_spans = Queue() + self.delayed_heartbeat_spans: PeekableQueue[tuple[int, datetime.datetime]] = \ + PeekableQueue() + self.delayed_heartbeat_spans_lookup: set[int] = set() + self.ready_heartbeat_spans: PeekableQueue[ + tuple[int, datetime.datetime]] = PeekableQueue() self.lock = threading.Lock() self.done = False @@ -65,44 +97,42 @@ def __init__( def worker(self) -> None: while not self.done: with self.condition: - self.condition.wait(self.heartbeat_interval_millis / 1000) + self.condition.wait(self.process_interval_millis / 1000) if self.done: break - # Remove ended spans from active spans - with self.lock: - while not self.ended_spans.empty(): - span_key, span = self.ended_spans.get() - if span_key in self.active_spans: - del self.active_spans[span_key] - - self.heartbeat() - - def heartbeat(self) -> None: - with self.lock: - for span in list(self.active_spans.values()): - attributes = self.get_heartbeat_attributes() - log_data = self.get_log_data(span, attributes) - self.log_exporter.export([log_data]) + self.process_delayed_heartbeat_spans() + self.process_ready_heartbeat_spans() def on_start(self, span: Span, parent_context: context_api.Context | None = None) -> None: - attributes = self.get_heartbeat_attributes() - log_data = self.get_log_data(span, attributes) - self.log_exporter.export([log_data]) - - span_key = (span.context.trace_id, span.context.span_id) with self.lock: - self.active_spans[span_key] = span + self.active_spans[span.context.span_id] = span + self.delayed_heartbeat_spans_lookup.add(span.context.span_id) + + next_heartbeat_time = datetime.datetime.now() + datetime.timedelta( + milliseconds=self.initial_heartbeat_delay_millis) + self.delayed_heartbeat_spans.put( + (span.context.span_id, next_heartbeat_time)) def on_end(self, span: ReadableSpan) -> None: - attributes = get_stop_attributes() + is_delayed_heartbeat_pending = False + with self.lock: + self.active_spans.pop(span.context.span_id) + + if span.context.span_id in self.delayed_heartbeat_spans_lookup: + is_delayed_heartbeat_pending = True + self.delayed_heartbeat_spans_lookup.remove(span.context.span_id) + + if is_delayed_heartbeat_pending: + return + + self.export_log(span, get_stop_attributes()) + + def export_log(self, span, attributes: dict[str, str]) -> None: log_data = self.get_log_data(span, attributes) self.log_exporter.export([log_data]) - span_key = (span.context.trace_id, span.context.span_id) - self.ended_spans.put((span_key, span)) - def shutdown(self) -> None: # signal the worker thread to finish and then wait for it self.done = True @@ -161,6 +191,57 @@ def get_log_data(self, span: Span, attributes: dict[str, str]) -> LogData: log_record=log_record, instrumentation_scope=instrumentation_scope, ) + def process_delayed_heartbeat_spans(self) -> None: + spans_to_be_logged = [] + with (self.lock): + now = datetime.datetime.now() + while True: + if self.delayed_heartbeat_spans.empty(): + break + + (span_id, next_heartbeat_time) = self.delayed_heartbeat_spans.peek() + if next_heartbeat_time > now: + break + + self.delayed_heartbeat_spans_lookup.discard(span_id) + self.delayed_heartbeat_spans.get() + + span = self.active_spans.get(span_id) + if span: + spans_to_be_logged.append(span) + + next_heartbeat_time = now + datetime.timedelta( + milliseconds=self.heartbeat_interval_millis) + self.ready_heartbeat_spans.put((span_id, next_heartbeat_time)) + + for span in spans_to_be_logged: + self.export_log(span, self.get_heartbeat_attributes()) + + def process_ready_heartbeat_spans(self) -> None: + spans_to_be_logged = [] + now = datetime.datetime.now() + with self.lock: + while True: + if self.ready_heartbeat_spans.empty(): + break + + (span_id, next_heartbeat_time) = self.ready_heartbeat_spans.peek() + if next_heartbeat_time > now: + break + + self.ready_heartbeat_spans.get() + + span = self.active_spans.get(span_id) + if span: + spans_to_be_logged.append(span) + + next_heartbeat_time = now + datetime.timedelta( + milliseconds=self.heartbeat_interval_millis) + self.ready_heartbeat_spans.put((span_id, next_heartbeat_time)) + + for span in spans_to_be_logged: + self.export_log(span, self.get_heartbeat_attributes()) + def get_stop_attributes() -> dict[str, str]: return { diff --git a/src/partial_span_processor/peekable_queue.py b/src/partial_span_processor/peekable_queue.py new file mode 100644 index 0000000..6c6e72e --- /dev/null +++ b/src/partial_span_processor/peekable_queue.py @@ -0,0 +1,8 @@ +import queue + +class PeekableQueue(queue.Queue): + def peek(self): + with self.mutex: + if self._qsize() > 0: + return self.queue[0] + return None \ No newline at end of file diff --git a/tests/partial_span_processor/test_partial_span_processor.py b/tests/partial_span_processor/test_partial_span_processor.py index 64bb9cf..730f341 100644 --- a/tests/partial_span_processor/test_partial_span_processor.py +++ b/tests/partial_span_processor/test_partial_span_processor.py @@ -12,111 +12,193 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import unittest -from time import sleep +from unittest import mock +from unittest.mock import patch from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.trace import Span, SpanContext, TraceFlags from src.partial_span_processor import PartialSpanProcessor -from tests.partial_span_processor.in_memory_log_exporter import InMemoryLogExporter +from tests.partial_span_processor.in_memory_log_exporter import \ + InMemoryLogExporter class TestPartialSpanProcessor(unittest.TestCase): def setUp(self) -> None: - # Set up an in-memory log exporter and processor self.log_exporter = InMemoryLogExporter() self.processor = PartialSpanProcessor( log_exporter=self.log_exporter, - heartbeat_interval_millis=1000, # 1 second + heartbeat_interval_millis=1000, + initial_heartbeat_delay_millis=1000, + process_interval_millis=1000, resource=Resource(attributes={"service.name": "test"}), ) def tearDown(self) -> None: - # Shut down the processor self.processor.shutdown() - def create_mock_span(self, trace_id: int = 1, span_id: int = 1) -> Span: - # Create a mock tracer + @staticmethod + def create_mock_span(trace_id: int = 1, span_id: int = 1) -> Span: tracer_provider = TracerProvider(resource=Resource.create({})) tracer = tracer_provider.get_tracer("test_tracer") - # Start a span using the tracer with tracer.start_as_current_span("test_span") as span: - # Set the span context manually for testing purposes span_context = SpanContext( trace_id=trace_id, span_id=span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED), ) - span._context = span_context # Modify the span's context for testing + span._context = span_context return span - def test_on_start(self) -> None: - # Test the on_start method - span = self.create_mock_span() - self.processor.on_start(span) + def test_shutdown(self) -> None: + self.processor.shutdown() - # Verify the span is added to active_spans - span_key = (span.context.trace_id, span.context.span_id) - assert span_key in self.processor.active_spans + self.assertTrue(self.processor.done) - # Verify a log is emitted - logs = self.log_exporter.get_finished_logs() - assert len(logs) == 1 - assert logs[0].log_record.attributes["partial.event"] == "heartbeat" - assert logs[0].log_record.resource.attributes["service.name"] == "test" + def test_invalid_log_exporter(self): + with self.assertRaises(ValueError) as context: + PartialSpanProcessor( + log_exporter=None, + heartbeat_interval_millis=1000, + initial_heartbeat_delay_millis=1000, + process_interval_millis=1000, + ) + self.assertEqual(str(context.exception), "log_exporter must not be None") + + def test_invalid_heartbeat_interval(self): + with self.assertRaises(ValueError) as context: + PartialSpanProcessor( + log_exporter=InMemoryLogExporter(), + heartbeat_interval_millis=0, + initial_heartbeat_delay_millis=1000, + process_interval_millis=1000, + ) + self.assertEqual(str(context.exception), + "heartbeat_interval_millis must be greater than 0") + + def test_invalid_initial_heartbeat_delay(self): + with self.assertRaises(ValueError) as context: + PartialSpanProcessor( + log_exporter=InMemoryLogExporter(), + heartbeat_interval_millis=1000, + initial_heartbeat_delay_millis=-1, + process_interval_millis=1000, + ) + self.assertEqual(str(context.exception), + "initial_heartbeat_delay_millis must be greater or equal to 0") + + def test_invalid_process_interval(self): + with self.assertRaises(ValueError) as context: + PartialSpanProcessor( + log_exporter=InMemoryLogExporter(), + heartbeat_interval_millis=1000, + initial_heartbeat_delay_millis=1000, + process_interval_millis=0, + ) + self.assertEqual(str(context.exception), + "process_interval_millis must be greater than 0") - def test_on_end(self) -> None: - # Test the on_end method - span = self.create_mock_span() + def test_on_start(self): + span = TestPartialSpanProcessor.create_mock_span() + expected_span_id = span.get_span_context().span_id + now = datetime.datetime.now() self.processor.on_start(span) - self.processor.on_end(span) - # Verify the span is added to ended_spans - assert not self.processor.ended_spans.empty() - - # Verify a log is emitted - logs = self.log_exporter.get_finished_logs() - assert len(logs) == 2 - assert logs[1].log_record.attributes["partial.event"] == "stop" - assert logs[0].log_record.resource.attributes["service.name"] == "test" + self.assertIn(expected_span_id, self.processor.active_spans) + self.assertIn(expected_span_id, + self.processor.delayed_heartbeat_spans_lookup) + self.assertEqual(self.processor.delayed_heartbeat_spans.qsize(), 1) + ( + span_id, + next_heartbeat_time) = self.processor.delayed_heartbeat_spans.get() + self.assertEqual(expected_span_id, span_id) + self.assertGreater(next_heartbeat_time, now) + self.assertEqual(self.log_exporter.get_finished_logs(), ()) + + def test_on_end_when_initial_heartbeat_not_sent(self): + span = TestPartialSpanProcessor.create_mock_span() + span_id = span.get_span_context().span_id + + self.processor.active_spans[span_id] = span + self.processor.delayed_heartbeat_spans_lookup.add(span_id) + self.processor.delayed_heartbeat_spans.put((span_id, unittest.mock.ANY)) - def test_heartbeat(self) -> None: - # Test the heartbeat method - span = self.create_mock_span() - self.processor.on_start(span) - - # Wait for the heartbeat interval - sleep(1.5) - logs = self.log_exporter.get_finished_logs() + self.processor.on_end(span) - # Verify heartbeat logs are emitted - assert len(logs) >= 2 - assert logs[1].log_record.attributes["partial.event"] == "heartbeat" - assert logs[0].log_record.resource.attributes["service.name"] == "test" + self.assertNotIn(span_id, self.processor.active_spans) + self.assertNotIn(span_id, + self.processor.delayed_heartbeat_spans_lookup) + self.assertFalse(self.processor.delayed_heartbeat_spans.empty()) + self.assertEqual(self.log_exporter.get_finished_logs(), ()) - def test_shutdown(self) -> None: - # Test the shutdown method - self.processor.shutdown() + def test_on_end_when_initial_heartbeat_sent(self): + span = TestPartialSpanProcessor.create_mock_span() + span_id = span.get_span_context().span_id - # Verify the worker thread is stopped - assert self.processor.done + self.processor.active_spans[span_id] = span - def test_worker_thread(self) -> None: - # Test the worker thread processes ended spans - span = self.create_mock_span() - self.processor.on_start(span) self.processor.on_end(span) - # Wait for the worker thread to process the ended span - sleep(1.5) + self.assertNotIn(span_id, self.processor.active_spans) - # Verify the span is removed from active_spans - span_key = (span.context.trace_id, span.context.span_id) - assert span_key not in self.processor.active_spans + logs = self.log_exporter.get_finished_logs() + self.assertEqual(len(logs), 1) + self.assertEqual(logs[0].log_record.attributes["partial.event"], "stop") + self.assertEqual(logs[0].log_record.attributes["partial.body.type"], + "json/v1") + self.assertEqual(logs[0].log_record.resource.attributes["service.name"], + "test") + + def test_process_delayed_heartbeat_spans(self): + span = TestPartialSpanProcessor.create_mock_span() + span_id = span.get_span_context().span_id + + self.processor.active_spans[span_id] = span + now = datetime.datetime.now() + self.processor.delayed_heartbeat_spans.put((span_id, now)) + self.processor.delayed_heartbeat_spans_lookup.add(span_id) + + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = now + self.processor.process_delayed_heartbeat_spans() + + self.assertNotIn(span_id, self.processor.delayed_heartbeat_spans_lookup) + self.assertTrue(self.processor.delayed_heartbeat_spans.empty()) + + next_heartbeat_time = now + datetime.timedelta( + milliseconds=self.processor.heartbeat_interval_millis) + self.assertFalse(self.processor.ready_heartbeat_spans.empty()) + self.assertEqual(self.processor.ready_heartbeat_spans.get(), + (span_id, next_heartbeat_time)) + + def test_process_ready_heartbeat_spans(self): + span = TestPartialSpanProcessor.create_mock_span() + span_id = span.get_span_context().span_id + + self.processor.active_spans[span_id] = span + now = datetime.datetime.now() + next_heartbeat_time = now + self.processor.ready_heartbeat_spans.put((span_id, next_heartbeat_time)) + + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = now + self.processor.process_ready_heartbeat_spans() + + updated_next_heartbeat_time = now + datetime.timedelta( + milliseconds=self.processor.heartbeat_interval_millis) + self.assertTrue(self.processor.ready_heartbeat_spans.qsize() == 1) + self.assertEqual(self.processor.ready_heartbeat_spans.get(), + (span_id, updated_next_heartbeat_time)) + + logs = self.log_exporter.get_finished_logs() + self.assertEqual(len(logs), 1) + self.assertEqual(logs[0].log_record.attributes["partial.event"], + "heartbeat") if __name__ == "__main__":