diff --git a/veadk/agent.py b/veadk/agent.py index ffd24204..f44536f2 100644 --- a/veadk/agent.py +++ b/veadk/agent.py @@ -114,13 +114,8 @@ def model_post_init(self, __context: Any) -> None: self.tools.append(load_memory) if self.tracers: - self.before_model_callback = [] - self.after_model_callback = [] - self.after_tool_callback = [] for tracer in self.tracers: - self.before_model_callback.append(tracer.tracer_hook_before_model) - self.after_model_callback.append(tracer.tracer_hook_after_model) - self.after_tool_callback.append(tracer.tracer_hook_after_tool) + tracer.do_hooks(self) logger.info(f"Agent `{self.name}` init done.") logger.debug( diff --git a/veadk/runner.py b/veadk/runner.py index 829d3376..58629b20 100644 --- a/veadk/runner.py +++ b/veadk/runner.py @@ -15,6 +15,7 @@ from google.adk.agents import RunConfig from google.adk.agents.run_config import StreamingMode +from google.adk.plugins.base_plugin import BasePlugin from google.adk.runners import Runner as ADKRunner from google.genai import types from google.genai.types import Blob @@ -23,6 +24,7 @@ from veadk.agent import Agent from veadk.evaluation import EvalSetRecorder from veadk.memory.short_term_memory import ShortTermMemory +from veadk.tracing.base_tracer import UserMessagePlugin from veadk.types import MediaMessage from veadk.utils.logger import get_logger from veadk.utils.misc import read_png_to_bytes @@ -44,6 +46,7 @@ def __init__( self, agent: Agent | RemoteVeAgent, short_term_memory: ShortTermMemory, + plugins: list[BasePlugin] = [], app_name: str = "veadk_default_app", user_id: str = "veadk_default_user", ): @@ -65,11 +68,20 @@ def __init__( else: self.long_term_memory = None + # process plugins + try: + # try to detect tracer + _ = self.agent.tracers[0] + plugins.extend([UserMessagePlugin(name="user_message_plugin")]) + except Exception: + logger.debug("Agent has no tracers, telemetry plugin not added.") + self.runner = ADKRunner( app_name=self.app_name, agent=self.agent, session_service=self.session_service, memory_service=self.long_term_memory, + plugins=plugins, ) def _convert_messages(self, messages) -> list: @@ -163,8 +175,30 @@ async def run( if save_tracing_data: self.save_tracing_file(session_id) + self._print_trace_id() + return final_output + def _print_trace_id(self): + if not isinstance(self.agent, Agent): + logger.warning( + ("The agent is not an instance of VeADK Agent, no trace id provided.") + ) + return + + if not self.agent.tracers: + logger.warning( + "No tracer is configured in the agent, no trace id provided." + ) + return + + try: + trace_id = self.agent.tracers[0].get_trace_id() # type: ignore + logger.info(f"Trace id: {trace_id}") + except Exception as e: + logger.warning(f"Get tracer id failed as {e}") + return + def save_tracing_file(self, session_id: str) -> str: if not isinstance(self.agent, Agent): logger.warning( diff --git a/veadk/tracing/base_tracer.py b/veadk/tracing/base_tracer.py index 1b632285..c112e2ef 100644 --- a/veadk/tracing/base_tracer.py +++ b/veadk/tracing/base_tracer.py @@ -17,9 +17,12 @@ from typing import Any, Optional from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin from google.adk.tools import BaseTool, ToolContext +from google.genai import types from opentelemetry import trace from veadk.utils.logger import get_logger @@ -27,6 +30,56 @@ logger = get_logger(__name__) +class UserMessagePlugin(BasePlugin): + def __init__(self, name: str): + super().__init__(name) + + async def on_user_message_callback( + self, + *, + invocation_context: InvocationContext, + user_message: types.Content, + ) -> Optional[types.Content]: + """Callback executed when a user message is received before an invocation starts. + + This callback helps logging and modifying the user message before the + runner starts the invocation. + + Args: + invocation_context: The context for the entire invocation. + user_message: The message content input by user. + + Returns: + An optional `types.Content` to be returned to the ADK. Returning a + value to replace the user message. Returning `None` to proceed + normally. + """ + trace.get_tracer("gcp.vertex.agent") + span = trace.get_current_span() + + logger.debug(f"User message plugin works, catch {span}") + span_name = getattr(span, "name", None) + if span_name and span_name.startswith("invocation"): + agent_name = invocation_context.agent.name + invoke_branch = ( + invocation_context.branch if invocation_context.branch else agent_name + ) + current_session = invocation_context.session + + span.set_attribute("app_name", current_session.app_name) + span.set_attribute("user_id", current_session.user_id) + span.set_attribute("session_id", current_session.id) + + span.set_attribute("agent_name", agent_name) + span.set_attribute("invoke_branch", invoke_branch) + + logger.debug( + f"Add attributes to {span_name}: app_name={current_session.app_name}, user_id={current_session.user_id}, session_id={current_session.id}, agent_name={agent_name}, invoke_branch={invoke_branch}" + ) + + return None + + def replace_bytes_with_empty(data): """ Recursively traverse the data structure and replace all bytes types with empty strings. diff --git a/veadk/tracing/telemetry/opentelemetry_tracer.py b/veadk/tracing/telemetry/opentelemetry_tracer.py index d2fb0291..4159df75 100644 --- a/veadk/tracing/telemetry/opentelemetry_tracer.py +++ b/veadk/tracing/telemetry/opentelemetry_tracer.py @@ -51,11 +51,11 @@ class OpentelemetryTracer(BaseModel, BaseTracer): description="The exporters to export spans.", ) name: str = Field( - DEFAULT_VEADK_TRACER_NAME, description="The identifier of tracer." + default=DEFAULT_VEADK_TRACER_NAME, description="The identifier of tracer." ) app_name: str = Field( - "veadk_app", + default="veadk_app", description="The identifier of app.", ) @@ -127,6 +127,16 @@ def _init_tracer_provider(self) -> None: self._processors.append(processor) logger.debug(f"Init OpentelemetryTracer with {len(self.exporters)} exporters.") + def get_trace_id(self) -> str: + if not self._inmemory_exporter: + return "" + try: + trace_id = hex(int(self._inmemory_exporter._real_exporter.trace_id))[2:] + except Exception: + return "" + + return trace_id + @override def dump( self,