diff --git a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java index 4a375980c..1366f10b2 100644 --- a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java @@ -28,6 +28,7 @@ import com.google.adk.agents.Callbacks; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; +import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; @@ -35,7 +36,6 @@ import com.google.genai.types.Part; import io.a2a.client.Client; import io.a2a.client.ClientEvent; -import io.a2a.client.MessageEvent; import io.a2a.client.TaskEvent; import io.a2a.client.TaskUpdateEvent; import io.a2a.spec.A2AClientException; @@ -541,6 +541,11 @@ protected Flowable runLiveImpl(InvocationContext invocationContext) { "runLiveImpl for " + getClass() + " via A2A is not implemented."); } + @Override + public AgentOrigin toolOrigin() { + return AgentOrigin.A2A; + } + /** Exception thrown when the agent card cannot be resolved. */ public static class AgentCardResolutionError extends RuntimeException { public AgentCardResolutionError(String message) { diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 95fe838cc..cbceceed2 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -25,6 +25,7 @@ import com.google.adk.events.Event; import com.google.adk.plugins.Plugin; import com.google.adk.telemetry.Tracing; +import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; @@ -256,6 +257,15 @@ public ImmutableList afterAgentCallback() { return afterAgentCallback; } + /** + * Returns the origin of the tool when this agent is used as a tool. + * + * @return the tool origin, defaults to "BASE_AGENT". + */ + public AgentOrigin toolOrigin() { + return AgentOrigin.BASE_AGENT; + } + /** * The resolved beforeAgentCallback field as a list. * diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java index ef826fb56..924ad228e 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java @@ -147,7 +147,12 @@ public void flush() { } else { logger.fine("Successfully wrote " + batch.size() + " rows to BigQuery."); } - } catch (AppendSerializationError ase) { + } + } catch (Exception e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + if (e.getCause() instanceof AppendSerializationError ase) { logger.log( Level.SEVERE, "Failed to write batch to BigQuery due to serialization error", ase); Map rowIndexToErrorMessage = ase.getRowIndexToErrorMessage(); @@ -161,12 +166,9 @@ public void flush() { logger.severe( "AppendSerializationError occurred, but no row-specific errors were provided."); } + } else { + logger.log(Level.SEVERE, "Failed to write batch to BigQuery", e); } - } catch (Exception e) { - if (e instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - logger.log(Level.SEVERE, "Failed to write batch to BigQuery", e); } finally { // Clear the vectors to release the memory. root.clear(); @@ -185,7 +187,12 @@ private void populateVector(FieldVector vector, int index, Object value) { return; } if (vector instanceof VarCharVector varCharVector) { - String strValue = (value instanceof JsonNode jsonNode) ? jsonNode.asText() : value.toString(); + String strValue; + if (value instanceof JsonNode jsonNode) { + strValue = jsonNode.isTextual() ? jsonNode.asText() : jsonNode.toString(); + } else { + strValue = value.toString(); + } varCharVector.setSafe(index, strValue.getBytes(UTF_8)); } else if (vector instanceof BigIntVector bigIntVector) { long longValue; diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java index 68b5fb5a1..c1b86f469 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -16,6 +16,9 @@ package com.google.adk.plugins.agentanalytics; +import static com.google.adk.plugins.agentanalytics.JsonFormatter.convertToJsonNode; +import static com.google.adk.plugins.agentanalytics.JsonFormatter.smartTruncate; +import static com.google.adk.plugins.agentanalytics.JsonFormatter.toJavaObject; import static java.util.concurrent.TimeUnit.MILLISECONDS; import com.google.adk.agents.BaseAgent; @@ -25,8 +28,17 @@ import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; +import com.google.adk.plugins.agentanalytics.JsonFormatter.ParsedContent; +import com.google.adk.plugins.agentanalytics.JsonFormatter.TruncationResult; +import com.google.adk.plugins.agentanalytics.TraceManager.RecordData; +import com.google.adk.plugins.agentanalytics.TraceManager.SpanIds; +import com.google.adk.sessions.Session; +import com.google.adk.tools.AgentTool; import com.google.adk.tools.BaseTool; +import com.google.adk.tools.FunctionTool; import com.google.adk.tools.ToolContext; +import com.google.adk.tools.mcp.AbstractMcpTool; +import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.api.gax.core.FixedCredentialsProvider; import com.google.api.gax.retrying.RetrySettings; import com.google.auth.oauth2.GoogleCredentials; @@ -45,12 +57,13 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; -import io.opentelemetry.api.trace.Span; -import io.opentelemetry.api.trace.SpanContext; +import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import java.io.IOException; +import java.time.Duration; import java.time.Instant; import java.util.HashMap; import java.util.Map; @@ -61,7 +74,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; -import org.threeten.bp.Duration; +import org.jspecify.annotations.Nullable; /** * BigQuery Agent Analytics Plugin for Java. @@ -74,6 +87,14 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin { private static final ImmutableList DEFAULT_AUTH_SCOPES = ImmutableList.of("https://www.googleapis.com/auth/cloud-platform"); private static final AtomicLong threadCounter = new AtomicLong(0); + private static final ImmutableMap HITL_EVENT_TYPES = + ImmutableMap.of( + "adk_request_credential", + "HITL_CREDENTIAL_REQUEST", + "adk_request_confirmation", + "HITL_CONFIRMATION_REQUEST", + "adk_request_input", + "HITL_INPUT_REQUEST"); private final BigQueryLoggerConfig config; private final BigQuery bigQuery; @@ -81,6 +102,7 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin { private final ScheduledExecutorService executor; private final Object tableEnsuredLock = new Object(); @VisibleForTesting final BatchProcessor batchProcessor; + @VisibleForTesting final TraceManager traceManager; private volatile boolean tableEnsured = false; public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException { @@ -96,6 +118,7 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQue r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); this.executor = Executors.newScheduledThreadPool(1, threadFactory); this.writeClient = createWriteClient(config); + this.traceManager = createTraceManager(); if (config.enabled()) { StreamWriter writer = createWriter(config); @@ -194,9 +217,10 @@ protected StreamWriter createWriter(BigQueryLoggerConfig config) { RetrySettings retrySettings = RetrySettings.newBuilder() .setMaxAttempts(retryConfig.maxRetries()) - .setInitialRetryDelay(Duration.ofMillis(retryConfig.initialDelay().toMillis())) + .setInitialRetryDelay( + org.threeten.bp.Duration.ofMillis(retryConfig.initialDelay().toMillis())) .setRetryDelayMultiplier(retryConfig.multiplier()) - .setMaxRetryDelay(Duration.ofMillis(retryConfig.maxDelay().toMillis())) + .setMaxRetryDelay(org.threeten.bp.Duration.ofMillis(retryConfig.maxDelay().toMillis())) .build(); String streamName = getStreamName(config); @@ -210,58 +234,130 @@ protected StreamWriter createWriter(BigQueryLoggerConfig config) { } } + protected TraceManager createTraceManager() { + return new TraceManager(); + } + + private void logEvent( + String eventType, + InvocationContext invocationContext, + Object content, + Optional eventData) { + logEvent(eventType, invocationContext, content, false, eventData); + } + private void logEvent( String eventType, InvocationContext invocationContext, - Optional callbackContext, Object content, - Map extraAttributes) { - if (batchProcessor == null) { + boolean isContentTruncated, + Optional eventData) { + if (!config.enabled() || batchProcessor == null) { return; } - + if (!config.eventAllowlist().isEmpty() && !config.eventAllowlist().contains(eventType)) { + return; + } + if (config.eventDenylist().contains(eventType)) { + return; + } + // Ensure table exists before logging. ensureTableExistsOnce(); - + // Log common fields Map row = new HashMap<>(); row.put("timestamp", Instant.now()); row.put("event_type", eventType); - row.put( - "agent", - callbackContext.map(CallbackContext::agentName).orElse(invocationContext.agent().name())); + row.put("agent", invocationContext.agent().name()); row.put("session_id", invocationContext.session().id()); row.put("invocation_id", invocationContext.invocationId()); row.put("user_id", invocationContext.userId()); - - if (content instanceof Content contentParts) { - row.put( - "content_parts", - JsonFormatter.formatContentParts(Optional.of(contentParts), config.maxContentLength())); - row.put( - "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); - } else if (content != null) { - row.put( - "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + // Parse and log content + ParsedContent parsedContent = JsonFormatter.parse(content, config.maxContentLength()); + row.put("content_parts", parsedContent.parts()); + row.put("content", parsedContent.content()); + row.put("is_truncated", isContentTruncated || parsedContent.isTruncated()); + + EventData data = eventData.orElse(EventData.builder().build()); + row.put("status", data.status()); + data.errorMessage().ifPresent(msg -> row.put("error_message", msg)); + + Map latencyMap = extractLatency(data); + if (latencyMap != null) { + row.put("latency_ms", convertToJsonNode(latencyMap)); } + row.put("attributes", convertToJsonNode(getAttributes(data, invocationContext))); - Map attributes = new HashMap<>(config.customTags()); - if (extraAttributes != null) { - attributes.putAll(extraAttributes); - } + addTraceDetails(row, invocationContext, eventData); + batchProcessor.append(row); + } + + private void addTraceDetails( + Map row, InvocationContext invocationContext, Optional eventData) { + String traceId = + eventData + .flatMap(EventData::traceIdOverride) + .orElseGet(() -> traceManager.getTraceId(invocationContext)); + Optional ambientSpanIds = traceManager.getAmbientSpanAndParent(); + SpanIds spanIds = ambientSpanIds.orElse(traceManager.getCurrentSpanAndParent()); + + row.put("trace_id", traceId); + row.put( + "span_id", + eventData.flatMap(EventData::spanIdOverride).orElse(spanIds.spanId().orElse(null))); row.put( - "attributes", - JsonFormatter.smartTruncate(attributes, config.maxContentLength()).toString()); + "parent_span_id", + eventData + .flatMap(EventData::parentSpanIdOverride) + .orElse(spanIds.parentSpanId().orElse(null))); + } - addTraceDetails(row); - batchProcessor.append(row); + private @Nullable Map extractLatency(EventData eventData) { + Map latencyMap = new HashMap<>(); + eventData.latency().ifPresent(v -> latencyMap.put("total_ms", v.toMillis())); + eventData + .timeToFirstToken() + .ifPresent(v -> latencyMap.put("time_to_first_token_ms", v.toMillis())); + return latencyMap.isEmpty() ? null : latencyMap; } - // TODO(b/491849911): Implement own trace management functionality. - private void addTraceDetails(Map row) { - SpanContext spanContext = Span.current().getSpanContext(); - if (spanContext.isValid()) { - row.put("trace_id", spanContext.getTraceId()); - row.put("span_id", spanContext.getSpanId()); + private Map getAttributes( + EventData eventData, InvocationContext invocationContext) { + Map attributes = new HashMap<>(eventData.extraAttributes()); + + attributes.put("root_agent_name", traceManager.getRootAgentName()); + eventData.model().ifPresent(m -> attributes.put("model", m)); + eventData.modelVersion().ifPresent(mv -> attributes.put("model_version", mv)); + eventData + .usageMetadata() + .ifPresent( + um -> { + TruncationResult result = smartTruncate(um, config.maxContentLength()); + attributes.put("usage_metadata", toJavaObject(result.node())); + }); + + if (config.logSessionMetadata()) { + try { + Session session = invocationContext.session(); + Map sessionMeta = new HashMap<>(); + sessionMeta.put("session_id", session.id()); + sessionMeta.put("app_name", session.appName()); + sessionMeta.put("user_id", session.userId()); + + if (!session.state().isEmpty()) { + TruncationResult result = smartTruncate(session.state(), config.maxContentLength()); + sessionMeta.put("state", toJavaObject(result.node())); + } + attributes.put("session_metadata", sessionMeta); + } catch (RuntimeException e) { + // Ignore session enrichment errors as in Python. + } } + + if (!config.customTags().isEmpty()) { + attributes.put("custom_tags", config.customTags()); + } + + return attributes; } @Override @@ -284,77 +380,237 @@ public Completable close() { return Completable.complete(); } - @Override - public Maybe onUserMessageCallback( - InvocationContext invocationContext, Content userMessage) { - return Maybe.fromAction( - () -> logEvent("USER_MESSAGE", invocationContext, Optional.empty(), userMessage, null)); + private Optional getCompletedEventData(InvocationContext invocationContext) { + String traceId = traceManager.getTraceId(invocationContext); + // Pop the invocation span from the trace manager. + Optional popped = traceManager.popSpan(); + if (popped.isEmpty()) { + // No invocation span to pop. + logger.info("No invocation span to pop."); + return Optional.empty(); + } + String parentSpanId = traceManager.getCurrentSpanId(); + + EventData.Builder eventDataBuilder = EventData.builder(); + eventDataBuilder.setTraceIdOverride(traceId); + eventDataBuilder.setLatency(popped.get().duration()); + // Only override span IDs when no ambient OTel span exists. + // Keep STARTING/COMPLETED pairs consistent. + if (!traceManager.hasAmbientSpan()) { + if (parentSpanId != null) { + eventDataBuilder.setParentSpanIdOverride(parentSpanId); + } + if (popped.get().spanId() != null) { + eventDataBuilder.setSpanIdOverride(popped.get().spanId()); + } + } + return Optional.of(eventDataBuilder.build()); } + // --- Plugin callbacks --- @Override - public Maybe beforeRunCallback(InvocationContext invocationContext) { + public Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { return Maybe.fromAction( - () -> logEvent("INVOCATION_START", invocationContext, Optional.empty(), null, null)); + () -> { + traceManager.ensureInvocationSpan(invocationContext); + logEvent("USER_MESSAGE_RECEIVED", invocationContext, userMessage, Optional.empty()); + if (userMessage.parts().isPresent()) { + for (Part part : userMessage.parts().get()) { + if (part.functionCall().isPresent() + && HITL_EVENT_TYPES.containsKey(part.functionCall().get().name().orElse(""))) { + String hitlEvent = HITL_EVENT_TYPES.get(part.functionCall().get().name().get()); + TruncationResult truncatedResult = smartTruncate(part, config.maxContentLength()); + logEvent( + hitlEvent + "_COMPLETED", + invocationContext, + ImmutableMap.of( + "tool", + part.functionCall().get().name().get(), + "result", + truncatedResult.node()), + truncatedResult.isTruncated(), + Optional.empty()); + } + } + } + }); } @Override public Maybe onEventCallback(InvocationContext invocationContext, Event event) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - attrs.put("event_author", event.author()); + EventData.Builder eventDataBuilder = + EventData.builder() + .setExtraAttributes( + ImmutableMap.builder() + .put("state_delta", event.actions().stateDelta()) + .put("author", event.author()) + .buildOrThrow()); logEvent( - "EVENT", invocationContext, Optional.empty(), event.content().orElse(null), attrs); + "STATE_DELTA", + invocationContext, + event.content().orElse(null), + Optional.of(eventDataBuilder.build())); + + if (event.content().isPresent() && event.content().get().parts().isPresent()) { + for (Part part : event.content().get().parts().get()) { + if (part.functionCall().isPresent() + && HITL_EVENT_TYPES.containsKey(part.functionCall().get().name().orElse(""))) { + String hitlEvent = HITL_EVENT_TYPES.get(part.functionCall().get().name().get()); + TruncationResult truncatedResult = + smartTruncate(part.functionCall().get().args(), config.maxContentLength()); + logEvent( + hitlEvent + "_COMPLETED", + invocationContext, + ImmutableMap.of( + "tool", + part.functionCall().get().name().get(), + "args", + truncatedResult.node()), + truncatedResult.isTruncated(), + Optional.empty()); + } + if (part.functionResponse().isPresent() + && HITL_EVENT_TYPES.containsKey( + part.functionResponse().get().name().orElse(""))) { + String hitlEvent = HITL_EVENT_TYPES.get(part.functionResponse().get().name().get()); + TruncationResult truncatedResult = + smartTruncate( + part.functionResponse().get().response(), config.maxContentLength()); + logEvent( + hitlEvent + "_COMPLETED", + invocationContext, + ImmutableMap.of( + "tool", + part.functionResponse().get().name().get(), + "response", + truncatedResult.node()), + truncatedResult.isTruncated(), + Optional.empty()); + } + } + } }); } + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + traceManager.ensureInvocationSpan(invocationContext); + return Maybe.fromAction( + () -> logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty())); + } + @Override public Completable afterRunCallback(InvocationContext invocationContext) { return Completable.fromAction( () -> { - logEvent("INVOCATION_END", invocationContext, Optional.empty(), null, null); + logEvent( + "INVOCATION_COMPLETED", + invocationContext, + null, + getCompletedEventData(invocationContext)); batchProcessor.flush(); + traceManager.clearStack(); }); } @Override public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return Maybe.fromAction( - () -> - logEvent( - "AGENT_START", - callbackContext.invocationContext(), - Optional.of(callbackContext), - null, - null)); + () -> { + traceManager.pushSpan("agent:" + agent.name()); + logEvent("AGENT_STARTING", callbackContext.invocationContext(), null, Optional.empty()); + }); } @Override public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return Maybe.fromAction( - () -> - logEvent( - "AGENT_END", - callbackContext.invocationContext(), - Optional.of(callbackContext), - null, - null)); + () -> { + logEvent( + "AGENT_COMPLETED", + callbackContext.invocationContext(), + null, + getCompletedEventData(callbackContext.invocationContext())); + }); } + /** + * Callback before LLM call. + * + *

Logs the LLM request details including: 1. Prompt content 2. System instruction (if + * available) + * + *

The content is formatted as 'Prompt: {prompt} | System Prompt: {system_prompt}'. + */ @Override public Maybe beforeModelCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); + Map attributes = new HashMap<>(); + Map llmConfig = new HashMap<>(); LlmRequest req = llmRequest.build(); - attrs.put("model", req.model().orElse("unknown")); - logEvent( - "MODEL_REQUEST", - callbackContext.invocationContext(), - Optional.of(callbackContext), - req, - attrs); + if (req.config().isPresent()) { + if (req.config().get().temperature().isPresent()) { + llmConfig.put("temperature", req.config().get().temperature().get()); + } + if (req.config().get().topP().isPresent()) { + llmConfig.put("top_p", req.config().get().topP().get()); + } + if (req.config().get().topK().isPresent()) { + llmConfig.put("top_k", req.config().get().topK().get()); + } + if (req.config().get().candidateCount().isPresent()) { + llmConfig.put("candidate_count", req.config().get().candidateCount().get()); + } + if (req.config().get().maxOutputTokens().isPresent()) { + llmConfig.put("max_output_tokens", req.config().get().maxOutputTokens().get()); + } + if (req.config().get().stopSequences().isPresent()) { + llmConfig.put("stop_sequences", req.config().get().stopSequences().get()); + } + if (req.config().get().presencePenalty().isPresent()) { + llmConfig.put("presence_penalty", req.config().get().presencePenalty().get()); + } + if (req.config().get().frequencyPenalty().isPresent()) { + llmConfig.put("frequency_penalty", req.config().get().frequencyPenalty().get()); + } + if (req.config().get().responseMimeType().isPresent()) { + llmConfig.put("response_mime_type", req.config().get().responseMimeType().get()); + } + if (req.config().get().responseSchema().isPresent()) { + llmConfig.put("response_schema", req.config().get().responseSchema().get()); + } + if (req.config().get().seed().isPresent()) { + llmConfig.put("seed", req.config().get().seed().get()); + } + if (req.config().get().responseLogprobs().isPresent()) { + llmConfig.put("response_logprobs", req.config().get().responseLogprobs().get()); + } + if (req.config().get().logprobs().isPresent()) { + llmConfig.put("logprobs", req.config().get().logprobs().get()); + } + // Put labels in attributes instead of LLM config. + if (req.config().get().labels().isPresent()) { + attributes.put("labels", req.config().get().labels().get()); + } + } + if (!llmConfig.isEmpty()) { + attributes.put("llm_config", llmConfig); + } + if (!req.tools().isEmpty()) { + attributes.put("tools", req.tools().keySet()); + } + EventData eventData = + EventData.builder() + .setModel(req.model().orElse("")) + .setExtraAttributes(attributes) + .build(); + traceManager.pushSpan("llm_request"); + logEvent("LLM_REQUEST", callbackContext.invocationContext(), req, Optional.of(eventData)); }); } @@ -363,14 +619,99 @@ public Maybe afterModelCallback( CallbackContext callbackContext, LlmResponse llmResponse) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - llmResponse.usageMetadata().ifPresent(u -> attrs.put("usage_metadata", u)); + // TODO(b/495809488): Add formatting of the content + ParsedContent parsedContent = + JsonFormatter.parse(llmResponse.content().orElse(null), config.maxContentLength()); + + Map usageDict = new HashMap<>(); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage.promptTokenCount().ifPresent(c -> usageDict.put("prompt", c)); + usage.candidatesTokenCount().ifPresent(c -> usageDict.put("completion", c)); + usage.totalTokenCount().ifPresent(c -> usageDict.put("total", c)); + }); + + Map contentMap = new HashMap<>(); + if (parsedContent.content() != null && !parsedContent.content().isNull()) { + contentMap.put("response", parsedContent.content()); + } + if (!usageDict.isEmpty()) { + contentMap.put("usage", usageDict); + } + + InvocationContext invocationContext = callbackContext.invocationContext(); + String spanId = traceManager.getCurrentSpanId(); + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + String parentSpanId = spanIds.parentSpanId().orElse(null); + + boolean isPopped = false; + Duration duration = Duration.ZERO; + Duration tfft = null; + + if (llmResponse.partial().orElse(false)) { + // Streaming chunk - do NOT pop span yet + if (spanId != null) { + traceManager.recordFirstToken(spanId); + Optional startTime = traceManager.getStartTime(spanId); + Optional firstTokenTime = traceManager.getFirstTokenTime(spanId); + if (startTime.isPresent()) { + duration = Duration.between(startTime.get(), Instant.now()); + } + if (startTime.isPresent() && firstTokenTime.isPresent()) { + tfft = Duration.between(startTime.get(), firstTokenTime.get()); + } + } + } else { + // Final response - pop span + if (spanId != null) { + traceManager.recordFirstToken(spanId); + Optional startTime = traceManager.getStartTime(spanId); + Optional firstTokenTime = traceManager.getFirstTokenTime(spanId); + if (startTime.isPresent() && firstTokenTime.isPresent()) { + tfft = Duration.between(startTime.get(), firstTokenTime.get()); + } + } + Optional popped = traceManager.popSpan(); + if (popped.isPresent()) { + spanId = popped.get().spanId(); + duration = popped.get().duration(); + isPopped = true; + } + } + + boolean hasAmbient = traceManager.hasAmbientSpan(); + boolean useOverride = isPopped && !hasAmbient; + + EventData.Builder eventDataBuilder = EventData.builder(); + if (!duration.isZero()) { + eventDataBuilder.setLatency(duration); + } + if (tfft != null) { + eventDataBuilder.setTimeToFirstToken(tfft); + } + llmResponse.modelVersion().ifPresent(eventDataBuilder::setModelVersion); + + if (!usageDict.isEmpty()) { + eventDataBuilder.setUsageMetadata(usageDict); + } + + if (useOverride) { + if (spanId != null) { + eventDataBuilder.setSpanIdOverride(spanId); + } + if (parentSpanId != null) { + eventDataBuilder.setParentSpanIdOverride(parentSpanId); + } + } + logEvent( - "MODEL_RESPONSE", - callbackContext.invocationContext(), - Optional.of(callbackContext), - llmResponse, - attrs); + "LLM_RESPONSE", + invocationContext, + contentMap.isEmpty() ? null : contentMap, + parsedContent.isTruncated(), + Optional.of(eventDataBuilder.build())); }); } @@ -379,14 +720,28 @@ public Maybe onModelErrorCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - attrs.put("error_message", error.getMessage()); - logEvent( - "MODEL_ERROR", - callbackContext.invocationContext(), - Optional.of(callbackContext), - null, - attrs); + InvocationContext invocationContext = callbackContext.invocationContext(); + Optional popped = traceManager.popSpan(); + String spanId = popped.map(RecordData::spanId).orElse(null); + + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + String parentSpanId = spanIds.spanId().orElse(null); + + boolean hasAmbient = traceManager.hasAmbientSpan(); + EventData.Builder eventDataBuilder = + EventData.builder() + .setStatus("ERROR") + .setLatency(popped.get().duration()) + .setErrorMessage(error.getMessage()); + if (!hasAmbient) { + if (spanId != null) { + eventDataBuilder.setSpanIdOverride(spanId); + } + if (parentSpanId != null) { + eventDataBuilder.setParentSpanIdOverride(parentSpanId); + } + } + logEvent("LLM_ERROR", invocationContext, null, Optional.of(eventDataBuilder.build())); }); } @@ -395,14 +750,12 @@ public Maybe> beforeToolCallback( BaseTool tool, Map toolArgs, ToolContext toolContext) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - attrs.put("tool_name", tool.name()); - logEvent( - "TOOL_START", - toolContext.invocationContext(), - Optional.of(toolContext), - toolArgs, - attrs); + TruncationResult res = smartTruncate(toolArgs, config.maxContentLength()); + ImmutableMap contentMap = + ImmutableMap.of( + "tool_origin", getToolOrigin(tool), "tool", tool.name(), "args", res.node()); + traceManager.pushSpan("tool"); + logEvent("TOOL_STARTING", toolContext.invocationContext(), contentMap, Optional.empty()); }); } @@ -414,10 +767,35 @@ public Maybe> afterToolCallback( Map result) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - attrs.put("tool_name", tool.name()); + Optional popped = traceManager.popSpan(); + TruncationResult truncationResult = smartTruncate(result, config.maxContentLength()); + ImmutableMap contentMap = + ImmutableMap.of( + "tool", + tool.name(), + "result", + truncationResult.node(), + "tool_origin", + getToolOrigin(tool)); + + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + boolean hasAmbient = traceManager.hasAmbientSpan(); + + EventData.Builder eventDataBuilder = EventData.builder(); + if (popped.isPresent()) { + eventDataBuilder.setLatency(popped.get().duration()); + } + if (!hasAmbient) { + popped.ifPresent(p -> eventDataBuilder.setSpanIdOverride(p.spanId())); + spanIds.spanId().ifPresent(eventDataBuilder::setParentSpanIdOverride); + } + logEvent( - "TOOL_END", toolContext.invocationContext(), Optional.of(toolContext), result, attrs); + "TOOL_COMPLETED", + toolContext.invocationContext(), + contentMap, + truncationResult.isTruncated(), + Optional.of(eventDataBuilder.build())); }); } @@ -426,11 +804,51 @@ public Maybe> onToolErrorCallback( BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - attrs.put("tool_name", tool.name()); - attrs.put("error_message", error.getMessage()); + Optional popped = traceManager.popSpan(); + TruncationResult truncationResult = smartTruncate(toolArgs, config.maxContentLength()); + String toolOrigin = getToolOrigin(tool); + ImmutableMap contentMap = + ImmutableMap.of( + "tool", tool.name(), "args", truncationResult.node(), "tool_origin", toolOrigin); + + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + boolean hasAmbient = traceManager.hasAmbientSpan(); + + EventData.Builder eventDataBuilder = + EventData.builder().setStatus("ERROR").setErrorMessage(error.getMessage()); + + if (popped.isPresent()) { + eventDataBuilder.setLatency(popped.get().duration()); + } + if (!hasAmbient) { + popped.ifPresent(p -> eventDataBuilder.setSpanIdOverride(p.spanId())); + spanIds.spanId().ifPresent(eventDataBuilder::setParentSpanIdOverride); + } + logEvent( - "TOOL_ERROR", toolContext.invocationContext(), Optional.of(toolContext), null, attrs); + "TOOL_ERROR", + toolContext.invocationContext(), + contentMap, + truncationResult.isTruncated(), + Optional.of(eventDataBuilder.build())); }); } + + private String getToolOrigin(BaseTool tool) { + if (tool instanceof AbstractMcpTool) { + return "MCP"; + } + if (tool instanceof AgentTool agentTool) { + return agentTool.getAgent().toolOrigin().equals(AgentOrigin.BASE_AGENT) + ? AgentOrigin.SUB_AGENT.toString() + : agentTool.getAgent().toolOrigin().toString(); + } + if (tool.name().equals("transfer_to_agent")) { + return "TRANSFER_AGENT"; + } + if (tool instanceof FunctionTool) { + return "LOCAL"; + } + return "UNKNOWN"; + } } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java index aa5bf37de..34a29e72b 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java @@ -25,7 +25,7 @@ import java.util.Map; import java.util.Optional; import java.util.function.BiFunction; -import javax.annotation.Nullable; +import org.jspecify.annotations.Nullable; /** Configuration for the BigQueryAgentAnalyticsPlugin. */ @AutoValue @@ -33,13 +33,11 @@ public abstract class BigQueryLoggerConfig { // Whether the plugin is enabled. public abstract boolean enabled(); - // List of event types to log. If None, all are allowed - // TODO(b/491852782): Implement allowlist/denylist for event types. + // List of event types to log. If None, all are allowed. @Nullable public abstract ImmutableList eventAllowlist(); // List of event types to ignore. - // TODO(b/491852782): Implement allowlist/denylist for event types. @Nullable public abstract ImmutableList eventDenylist(); @@ -102,6 +100,8 @@ public abstract class BigQueryLoggerConfig { @Nullable public abstract Credentials credentials(); + public abstract Builder toBuilder(); + public static Builder builder() { return new AutoValue_BigQueryLoggerConfig.Builder() .setEnabled(true) @@ -117,6 +117,8 @@ public static Builder builder() { .setQueueMaxSize(10000) .setLogSessionMetadata(true) .setCustomTags(ImmutableMap.of()) + .setEventAllowlist(ImmutableList.of()) + .setEventDenylist(ImmutableList.of()) // TODO(b/491851868): Enable auto-schema upgrade once implemented. .setAutoSchemaUpgrade(false); } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/EventData.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/EventData.java new file mode 100644 index 000000000..8fd95a070 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/EventData.java @@ -0,0 +1,64 @@ +package com.google.adk.plugins.agentanalytics; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableMap; +import java.time.Duration; +import java.util.Map; +import java.util.Optional; + +/** Typed container for structured fields passed to _log_event. */ +@AutoValue +abstract class EventData { + abstract Optional spanIdOverride(); + + abstract Optional parentSpanIdOverride(); + + abstract Optional latency(); + + abstract Optional timeToFirstToken(); + + abstract Optional model(); + + abstract Optional modelVersion(); + + abstract Optional usageMetadata(); + + abstract String status(); + + abstract Optional errorMessage(); + + abstract ImmutableMap extraAttributes(); + + abstract Optional traceIdOverride(); + + static Builder builder() { + return new AutoValue_EventData.Builder().setStatus("OK").setExtraAttributes(ImmutableMap.of()); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setSpanIdOverride(String value); + + abstract Builder setParentSpanIdOverride(String value); + + abstract Builder setLatency(Duration value); + + abstract Builder setTimeToFirstToken(Duration value); + + abstract Builder setModel(String value); + + abstract Builder setModelVersion(String value); + + abstract Builder setUsageMetadata(Object value); + + abstract Builder setStatus(String value); + + abstract Builder setErrorMessage(String value); + + abstract Builder setExtraAttributes(Map value); + + abstract Builder setTraceIdOverride(String value); + + abstract EventData build(); + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java index b4b4a1049..9a8184c70 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java @@ -20,22 +20,202 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.adk.models.LlmRequest; +import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FileData; +import com.google.genai.types.FunctionCall; import com.google.genai.types.Part; +import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; +import org.jspecify.annotations.Nullable; -/** Utility for formatting and truncating content for BigQuery logging. */ +/** Utility for parsing, formatting and truncating content for BigQuery logging. */ final class JsonFormatter { private static final ObjectMapper mapper = new ObjectMapper().findAndRegisterModules(); - private JsonFormatter() {} + @AutoValue + abstract static class TruncationResult { + abstract JsonNode node(); + + abstract boolean isTruncated(); + + static TruncationResult create(JsonNode node, boolean isTruncated) { + return new AutoValue_JsonFormatter_TruncationResult(node, isTruncated); + } + } + + @AutoValue + abstract static class ParsedContent { + abstract ImmutableList parts(); + + abstract JsonNode content(); + + abstract boolean isTruncated(); + + static ParsedContent create( + ImmutableList parts, JsonNode content, boolean isTruncated) { + return new AutoValue_JsonFormatter_ParsedContent(parts, content, isTruncated); + } + } + + @AutoValue + abstract static class ParsedContentObject { + abstract ArrayNode parts(); + + abstract String summary(); + + abstract boolean isTruncated(); + + static ParsedContentObject create(ArrayNode parts, String summary, boolean isTruncated) { + return new AutoValue_JsonFormatter_ParsedContentObject(parts, summary, isTruncated); + } + } + + /** + * Parses content into JSON payload and content parts, matching Python implementation. + * + * @param content the content to parse + * @param maxLength the maximum length for text fields + * @return a ParsedContent object + */ + static ParsedContent parse(Object content, int maxLength) { + JsonNode contentNode = mapper.nullNode(); + ArrayNode contentParts = mapper.createArrayNode(); + boolean isTruncated = false; + + if (content instanceof LlmRequest llmRequest) { + ObjectNode jsonPayload = mapper.createObjectNode(); + // Handle prompt + ArrayNode messages = mapper.createArrayNode(); + List contents = llmRequest.contents(); + for (Content c : contents) { + String role = c.role().orElse("unknown"); + ParsedContentObject parsedContentObject = parseContentObject(c, maxLength); + isTruncated = isTruncated || parsedContentObject.isTruncated(); + contentParts.addAll(parsedContentObject.parts()); + + ObjectNode message = mapper.createObjectNode(); + message.put("role", role); + message.put("content", parsedContentObject.summary()); + messages.add(message); + } + if (!messages.isEmpty()) { + jsonPayload.set("prompt", messages); + } + // Handle system instruction + if (llmRequest.config().isPresent() + && llmRequest.config().get().systemInstruction().isPresent()) { + Content systemInstruction = llmRequest.config().get().systemInstruction().get(); + ParsedContentObject parsedSystemInstruction = + parseContentObject(systemInstruction, maxLength); + isTruncated = isTruncated || parsedSystemInstruction.isTruncated(); + contentParts.addAll(parsedSystemInstruction.parts()); + jsonPayload.put("system_prompt", parsedSystemInstruction.summary()); + } + contentNode = jsonPayload; + } else if (content instanceof Content || content instanceof Part) { + ParsedContentObject parsedContentObject = parseContentObject(content, maxLength); + ObjectNode summaryNode = mapper.createObjectNode(); + summaryNode.put("text_summary", parsedContentObject.summary()); + return ParsedContent.create( + ImmutableList.copyOf(parsedContentObject.parts()), + summaryNode, + parsedContentObject.isTruncated()); + } else if (content instanceof String s) { + TruncationResult result = truncateWithStatus(s, maxLength); + contentNode = result.node(); + isTruncated = result.isTruncated(); + } else { + TruncationResult result = smartTruncate(content, maxLength); + contentNode = result.node(); + isTruncated = result.isTruncated(); + } + return ParsedContent.create(ImmutableList.copyOf(contentParts), contentNode, isTruncated); + } + + /** + * Parses a Content or Part object into summary text and content parts. + * + * @param content the Content or Part object to parse + * @param maxLength the maximum length of text fields before truncation + * @return a ParsedContentObject containing parts, summary, and truncation flag + */ + private static ParsedContentObject parseContentObject(Object content, int maxLength) { + ArrayNode contentParts = mapper.createArrayNode(); + boolean isTruncated = false; + List summaryText = new ArrayList<>(); + + List parts; + if (content instanceof Content c) { + parts = c.parts().orElse(ImmutableList.of()); + } else if (content instanceof Part p) { + parts = ImmutableList.of(p); + } else { + return ParsedContentObject.create(contentParts, "", false); + } + + for (int i = 0; i < parts.size(); i++) { + Part part = parts.get(i); + ObjectNode partData = mapper.createObjectNode(); + partData.put("part_index", i); + partData.put("mime_type", "text/plain"); + partData.putNull("uri"); + partData.putNull("text"); + partData.put("part_attributes", "{}"); + partData.put("storage_mode", "INLINE"); + partData.putNull("object_ref"); + + // CASE A: It is already a URI (e.g. from user input) + if (part.fileData().isPresent()) { + FileData fileData = part.fileData().get(); + partData.put("storage_mode", "EXTERNAL_URI"); + partData.put("uri", fileData.fileUri().orElse(null)); + partData.put("mime_type", fileData.mimeType().orElse(null)); + } + // CASE B: It is Binary/Inline Data (Image/Blob) + else if (part.inlineData().isPresent()) { + // TODO: (b/485571635) Implement GCS offloading here. + partData.put("text", "[BINARY DATA]"); + partData.put("mime_type", part.inlineData().get().mimeType().orElse("")); + } + // CASE C: Text + else if (part.text().isPresent()) { + String text = part.text().get(); + // TODO: (b/485571635) Implement GCS offloading if text length exceeds maxLength. + if (text.length() > maxLength) { + text = truncate(text, maxLength); + isTruncated = true; + } + partData.put("text", text); + summaryText.add(text); + } else if (part.functionCall().isPresent()) { + FunctionCall fc = part.functionCall().get(); + partData.put("mime_type", "application/json"); + partData.put("text", "Function: " + fc.name().orElse("unknown")); + ObjectNode partAttributes = mapper.createObjectNode(); + partAttributes.put("function_name", fc.name().orElse("unknown")); + partData.put("part_attributes", partAttributes.toString()); + } + contentParts.add(partData); + } + + String summaryResult = String.join(" | ", summaryText); + if (summaryResult.length() > maxLength) { + summaryResult = truncate(summaryResult, maxLength); + isTruncated = true; + } + + return ParsedContentObject.create(contentParts, summaryResult, isTruncated); + } /** Formats Content parts into an ArrayNode for BigQuery logging. */ - public static ArrayNode formatContentParts(Optional content, int maxLength) { + static ArrayNode formatContentParts(Optional content, int maxLength) { ArrayNode partsArray = mapper.createArrayNode(); if (content.isEmpty() || content.get().parts() == null) { return partsArray; @@ -51,7 +231,7 @@ public static ArrayNode formatContentParts(Optional content, int maxLen if (part.text().isPresent()) { partObj.put("mime_type", "text/plain"); - partObj.put("text", truncateString(part.text().get(), maxLength)); + partObj.put("text", truncate(part.text().get(), maxLength)); } else if (part.inlineData().isPresent()) { Blob blob = part.inlineData().get(); partObj.put("mime_type", blob.mimeType().orElse("")); @@ -67,45 +247,84 @@ public static ArrayNode formatContentParts(Optional content, int maxLen return partsArray; } - /** Recursively truncates long strings inside an object and returns a Jackson JsonNode. */ - public static JsonNode smartTruncate(Object obj, int maxLength) { + /** Recursively truncates long strings inside an object and returns a TruncationResult. */ + static TruncationResult smartTruncate(Object obj, int maxLength) { if (obj == null) { - return mapper.nullNode(); + return TruncationResult.create(mapper.nullNode(), false); } try { return recursiveSmartTruncate(mapper.valueToTree(obj), maxLength); + } catch (IllegalArgumentException e) { + // Fallback for types that mapper can't handle directly as a tree + return truncateWithStatus(String.valueOf(obj), maxLength); + } + } + + static JsonNode convertToJsonNode(Object obj) { + if (obj == null) { + return mapper.nullNode(); + } + try { + return mapper.valueToTree(obj); } catch (IllegalArgumentException e) { // Fallback for types that mapper can't handle directly as a tree return mapper.valueToTree(String.valueOf(obj)); } } - private static JsonNode recursiveSmartTruncate(JsonNode node, int maxLength) { + private static TruncationResult recursiveSmartTruncate(JsonNode node, int maxLength) { + boolean isTruncated = false; if (node.isTextual()) { - return mapper.valueToTree(truncateString(node.asText(), maxLength)); + String text = node.asText(); + if (text.length() > maxLength) { + return TruncationResult.create(mapper.valueToTree(truncate(text, maxLength)), true); + } + return TruncationResult.create(node, false); } else if (node.isObject()) { ObjectNode newNode = mapper.createObjectNode(); - node.properties() - .iterator() - .forEachRemaining( - entry -> { - newNode.set(entry.getKey(), recursiveSmartTruncate(entry.getValue(), maxLength)); - }); - return newNode; + Set> properties = node.properties(); + for (Map.Entry entry : properties) { + TruncationResult res = recursiveSmartTruncate(entry.getValue(), maxLength); + newNode.set(entry.getKey(), res.node()); + isTruncated = isTruncated || res.isTruncated(); + } + return TruncationResult.create(newNode, isTruncated); } else if (node.isArray()) { ArrayNode newNode = mapper.createArrayNode(); for (JsonNode element : node) { - newNode.add(recursiveSmartTruncate(element, maxLength)); + TruncationResult res = recursiveSmartTruncate(element, maxLength); + newNode.add(res.node()); + isTruncated = isTruncated || res.isTruncated(); } - return newNode; + return TruncationResult.create(newNode, isTruncated); } - return node; + return TruncationResult.create(node, false); } - private static String truncateString(String s, int maxLength) { + private static TruncationResult truncateWithStatus(String s, int maxLength) { + if (s == null) { + return TruncationResult.create(mapper.nullNode(), false); + } + if (s.length() <= maxLength) { + return TruncationResult.create(mapper.valueToTree(s), false); + } + return TruncationResult.create(mapper.valueToTree(truncate(s, maxLength)), true); + } + + private static String truncate(String s, int maxLength) { if (s == null || s.length() <= maxLength) { return s; } return s.substring(0, maxLength) + "...[truncated]"; } + + /** Converts a JsonNode to a standard Java object (Map, List, etc.). */ + public static @Nullable Object toJavaObject(JsonNode node) { + if (node == null || node.isNull()) { + return null; + } + return mapper.convertValue(node, Object.class); + } + + private JsonFormatter() {} } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/TraceManager.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/TraceManager.java new file mode 100644 index 000000000..581c4731e --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/TraceManager.java @@ -0,0 +1,284 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.google.adk.agents.InvocationContext; +import com.google.auto.value.AutoValue; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.sdk.trace.ReadableSpan; +import java.time.Duration; +import java.time.Instant; +import java.util.Iterator; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Logger; +import org.jspecify.annotations.Nullable; + +/** + * Manages OpenTelemetry-style trace and span context using InvocationContext callback data. + * + *

Uses a stack of SpanRecord objects to keep span, ID, ownership, and timing in sync. + */ +public final class TraceManager { + private static final Logger logger = Logger.getLogger(TraceManager.class.getName()); + + private final ConcurrentLinkedDeque records = new ConcurrentLinkedDeque<>(); + private String rootAgentName = "_bq_analytics_root_agent_name"; + private String activeInvocationId = "_bq_analytics_active_invocation_id"; + + private final Tracer tracer; + + TraceManager() { + this(GlobalOpenTelemetry.getTracer("google.adk.plugins.bigquery_agent_analytics")); + } + + TraceManager(Tracer tracer) { + this.tracer = tracer; + } + + @AutoValue + abstract static class SpanRecord { + abstract Span span(); + + abstract String spanId(); + + abstract boolean ownsSpan(); + + abstract Instant startTime(); + + abstract AtomicReference firstTokenTime(); + + static SpanRecord create(Span span, String spanId, boolean ownsSpan, Instant startTime) { + return new AutoValue_TraceManager_SpanRecord( + span, spanId, ownsSpan, startTime, new AtomicReference<>()); + } + } + + @AutoValue + abstract static class RecordData { + abstract String spanId(); + + abstract Duration duration(); + + static RecordData create(String spanId, Duration duration) { + return new AutoValue_TraceManager_RecordData(spanId, duration); + } + } + + @AutoValue + abstract static class SpanIds { + abstract Optional spanId(); + + abstract Optional parentSpanId(); + + static SpanIds create(String spanId, String parentSpanId) { + return new AutoValue_TraceManager_SpanIds( + Optional.ofNullable(spanId), Optional.ofNullable(parentSpanId)); + } + } + + public String getRootAgentName() { + return rootAgentName; + } + + public void initTrace(InvocationContext context) { + String rootAgentName = context.agent().rootAgent().name(); + this.rootAgentName = rootAgentName; + } + + public String getTraceId(InvocationContext context) { + if (!records.isEmpty()) { + Span currentSpan = records.peekLast().span(); + if (currentSpan.getSpanContext().isValid()) { + return currentSpan.getSpanContext().getTraceId(); + } + } + // Fallback to the ambient span. + Span ambient = Span.current(); + if (ambient.getSpanContext().isValid()) { + return ambient.getSpanContext().getTraceId(); + } + // Fallback to the invocation ID. + return context.invocationId(); + } + + public boolean hasAmbientSpan() { + return Span.current().getSpanContext().isValid(); + } + + @CanIgnoreReturnValue + public String pushSpan(String spanName) { + Context parentContext = Context.current(); + if (!records.isEmpty()) { + Span parentSpan = records.peekLast().span(); + if (parentSpan.getSpanContext().isValid()) { + parentContext = parentContext.with(parentSpan); + } + } + + Span span = tracer.spanBuilder(spanName).setParent(parentContext).startSpan(); + String spanIdStr; + if (span.getSpanContext().isValid()) { + spanIdStr = span.getSpanContext().getSpanId(); + } else { + // This span id aligns with the OpenTelemetry Span ID format. + spanIdStr = UUID.randomUUID().toString().replace("-", "").substring(0, 16); + } + + SpanRecord record = SpanRecord.create(span, spanIdStr, true, Instant.now()); + records.add(record); + return spanIdStr; + } + + @CanIgnoreReturnValue + public String attachCurrentSpan() { + Span span = Span.current(); + String spanIdStr; + if (span.getSpanContext().isValid()) { + spanIdStr = span.getSpanContext().getSpanId(); + } else { + spanIdStr = UUID.randomUUID().toString().replace("-", "").substring(0, 16); + } + + SpanRecord record = SpanRecord.create(span, spanIdStr, false, Instant.now()); + records.add(record); + return spanIdStr; + } + + public void ensureInvocationSpan(InvocationContext context) { + String currentInv = context.invocationId(); + + if (!records.isEmpty()) { + if (currentInv.equals(activeInvocationId)) { + return; + } + logger.info("Clearing stale span records from previous invocation."); + clearStack(); + } + + activeInvocationId = currentInv; + + Span ambient = Span.current(); + if (ambient.getSpanContext().isValid()) { + attachCurrentSpan(); + } else { + pushSpan("invocation"); + } + } + + @CanIgnoreReturnValue + public Optional popSpan() { + if (records.isEmpty()) { + return Optional.empty(); + } + SpanRecord record = records.pollLast(); + if (record == null) { + return Optional.empty(); + } + Duration duration = Duration.between(record.startTime(), Instant.now()); + if (record.ownsSpan()) { + record.span().end(); + } + return Optional.of(RecordData.create(record.spanId(), duration)); + } + + public void clearStack() { + if (!records.isEmpty()) { + for (SpanRecord record : records) { + if (record.ownsSpan()) { + record.span().end(); + } + } + records.clear(); + } + } + + public SpanIds getCurrentSpanAndParent() { + if (records.isEmpty()) { + return SpanIds.create(null, null); + } + + String spanId = records.peekLast().spanId(); + String parentId = null; + Iterator descIterator = records.descendingIterator(); + while (descIterator.hasNext()) { + SpanRecord record = descIterator.next(); + if (!record.spanId().equals(spanId)) { + parentId = record.spanId(); + break; + } + } + return SpanIds.create(spanId, parentId); + } + + Optional getAmbientSpanAndParent() { + Span ambient = Span.current(); + if (!ambient.getSpanContext().isValid()) { + return Optional.empty(); + } + String spanId = ambient.getSpanContext().getSpanId(); + String parentSpanId = null; + if (ambient instanceof ReadableSpan readableSpan) { + SpanContext parentCtx = readableSpan.getParentSpanContext(); + if (parentCtx != null && parentCtx.isValid()) { + parentSpanId = parentCtx.getSpanId(); + } + } + return Optional.of(SpanIds.create(spanId, parentSpanId)); + } + + public @Nullable String getCurrentSpanId() { + if (records.isEmpty()) { + return null; + } + return records.peekLast().spanId(); + } + + public void recordFirstToken(String spanId) { + for (SpanRecord record : records) { + if (record.spanId().equals(spanId)) { + record.firstTokenTime().compareAndSet(null, Instant.now()); + return; + } + } + } + + public Optional getStartTime(String spanId) { + for (SpanRecord record : records) { + if (record.spanId().equals(spanId)) { + return Optional.of(record.startTime()); + } + } + return Optional.empty(); + } + + public Optional getFirstTokenTime(String spanId) { + for (SpanRecord record : records) { + if (record.spanId().equals(spanId)) { + return Optional.ofNullable(record.firstTokenTime().get()); + } + } + return Optional.empty(); + } +} diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index 7eabc48c4..956a8eb51 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -80,7 +80,7 @@ protected AgentTool(BaseAgent agent, boolean skipSummarization) { } @VisibleForTesting - BaseAgent getAgent() { + public BaseAgent getAgent() { return agent; } diff --git a/core/src/main/java/com/google/adk/utils/AgentEnums.java b/core/src/main/java/com/google/adk/utils/AgentEnums.java new file mode 100644 index 000000000..05460b540 --- /dev/null +++ b/core/src/main/java/com/google/adk/utils/AgentEnums.java @@ -0,0 +1,13 @@ +package com.google.adk.utils; + +/** Enums for agents. */ +public final class AgentEnums { + /** Origin of the agent. */ + public static enum AgentOrigin { + BASE_AGENT, + SUB_AGENT, + A2A, + } + + private AgentEnums() {} +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java new file mode 100644 index 000000000..890af1719 --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java @@ -0,0 +1,240 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.runner.Runner; +import com.google.adk.sessions.Session; +import com.google.api.core.ApiFutures; +import com.google.auth.Credentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class BigQueryAgentAnalyticsPluginE2ETest { + private BigQuery mockBigQuery; + private StreamWriter mockWriter; + private BigQueryWriteClient mockWriteClient; + private BigQueryLoggerConfig config; + private BigQueryAgentAnalyticsPlugin plugin; + private Runner runner; + private BaseAgent fakeAgent; + private final List> capturedRows = + Collections.synchronizedList(new ArrayList<>()); + + @Before + public void setUp() throws Exception { + mockBigQuery = mock(BigQuery.class); + mockWriter = mock(StreamWriter.class); + mockWriteClient = mock(BigQueryWriteClient.class); + + config = + BigQueryLoggerConfig.builder() + .setEnabled(true) + .setProjectId("project") + .setDatasetId("dataset") + .setTableName("table") + .setBatchSize(10) + .setBatchFlushInterval(Duration.ofSeconds(10)) + .setCredentials(mock(Credentials.class)) + .build(); + + when(mockBigQuery.getOptions()) + .thenReturn(BigQueryOptions.newBuilder().setProjectId("test-project").build()); + when(mockBigQuery.getTable(any(TableId.class))).thenReturn(mock(Table.class)); + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); + + plugin = + new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + return mockWriter; + } + }; + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create( + BigQuerySchema.getArrowSchema(), plugin.batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + for (int i = 0; i < root.getRowCount(); i++) { + Map row = new HashMap<>(); + row.put( + "event_type", Objects.toString(root.getVector("event_type").getObject(i))); + row.put("agent", Objects.toString(root.getVector("agent").getObject(i))); + row.put( + "session_id", Objects.toString(root.getVector("session_id").getObject(i))); + row.put( + "invocation_id", + Objects.toString(root.getVector("invocation_id").getObject(i))); + row.put("user_id", Objects.toString(root.getVector("user_id").getObject(i))); + row.put( + "timestamp", ((TimeStampMicroTZVector) root.getVector("timestamp")).get(i)); + row.put("is_truncated", root.getVector("is_truncated").getObject(i)); + row.put("content", Objects.toString(root.getVector("content").getObject(i))); + capturedRows.add(row); + } + } catch (RuntimeException e) { + throw new RuntimeException("Error in thenAnswer", e); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + fakeAgent = new FakeAgent("test_agent"); + runner = Runner.builder().agent(fakeAgent).appName("test_app").plugins(plugin).build(); + } + + @Test + public void runAgent_logsAgentStartingAndCompleted() throws Exception { + Session session = runner.sessionService().createSession("test_app", "user").blockingGet(); + String sessionId = session.id(); + + runner + .runAsync("user", sessionId, Content.fromParts(Part.fromText("hello"))) + .blockingSubscribe(); + + // Ensure everything is flushed. The BatchProcessor flushes asynchronously sometimes, + // but the direct flush() call should help. We wait up to 2 seconds for all 5 expected events. + for (int i = 0; i < 20 && capturedRows.size() < 5; i++) { + plugin.batchProcessor.flush(); + if (capturedRows.size() < 5) { + Thread.sleep(100); + } + } + + // Verify presence of expected events + List eventTypes = + capturedRows.stream().map(row -> (String) row.get("event_type")).toList(); + + assertFalse("capturedRows should not be empty", capturedRows.isEmpty()); + assertTrue( + "Events should contain AGENT_STARTING. Actual: " + eventTypes, + eventTypes.contains("AGENT_STARTING")); + assertTrue( + "Events should contain AGENT_COMPLETED. Actual: " + eventTypes, + eventTypes.contains("AGENT_COMPLETED")); + assertTrue( + "Events should contain USER_MESSAGE_RECEIVED. Actual: " + eventTypes, + eventTypes.contains("USER_MESSAGE_RECEIVED")); + assertTrue( + "Events should contain INVOCATION_STARTING. Actual: " + eventTypes, + eventTypes.contains("INVOCATION_STARTING")); + assertTrue( + "Events should contain INVOCATION_COMPLETED. Actual: " + eventTypes, + eventTypes.contains("INVOCATION_COMPLETED")); + + // Verify common fields for one of the rows + Map agentStartingRow = + capturedRows.stream() + .filter(row -> Objects.equals(row.get("event_type"), "AGENT_STARTING")) + .findFirst() + .orElseThrow(); + + assertEquals("test_agent", agentStartingRow.get("agent")); + assertEquals(sessionId, agentStartingRow.get("session_id")); + assertEquals("user", agentStartingRow.get("user_id")); + assertNotNull("invocation_id should be populated", agentStartingRow.get("invocation_id")); + assertTrue("timestamp should be positive", (Long) agentStartingRow.get("timestamp") > 0); + assertEquals(false, agentStartingRow.get("is_truncated")); + + // Verify content for USER_MESSAGE_RECEIVED + Map userMessageRow = + capturedRows.stream() + .filter(row -> Objects.equals(row.get("event_type"), "USER_MESSAGE_RECEIVED")) + .findFirst() + .orElseThrow(); + String contentJson = (String) userMessageRow.get("content"); + assertTrue("Content should contain 'hello'", contentJson.contains("hello")); + + // Verify order + int userMessageIdx = eventTypes.indexOf("USER_MESSAGE_RECEIVED"); + int invocationStartIdx = eventTypes.indexOf("INVOCATION_STARTING"); + int agentStartIdx = eventTypes.indexOf("AGENT_STARTING"); + int agentCompletedIdx = eventTypes.indexOf("AGENT_COMPLETED"); + int invocationCompletedIdx = eventTypes.indexOf("INVOCATION_COMPLETED"); + + assertTrue( + "USER_MESSAGE_RECEIVED should be first by Runner implementation", + userMessageIdx < invocationStartIdx); + assertTrue( + "INVOCATION_STARTING should be before AGENT_STARTING", invocationStartIdx < agentStartIdx); + assertTrue( + "AGENT_STARTING should be before AGENT_COMPLETED", agentStartIdx < agentCompletedIdx); + assertTrue( + "AGENT_COMPLETED should be before INVOCATION_COMPLETED", + agentCompletedIdx < invocationCompletedIdx); + } + + private static class FakeAgent extends BaseAgent { + FakeAgent(String name) { + super(name, "description", null, null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java index 8147c5cc6..224aa38f0 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -33,7 +33,12 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; import com.google.adk.sessions.Session; +import com.google.adk.tools.AgentTool; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.api.core.ApiFutures; import com.google.auth.Credentials; import com.google.cloud.bigquery.BigQuery; @@ -43,18 +48,27 @@ import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Candidate; import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentResponse; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.reactivex.rxjava3.core.Flowable; import java.time.Duration; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.LogRecord; @@ -67,6 +81,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -80,6 +95,7 @@ @RunWith(JUnit4.class) public class BigQueryAgentAnalyticsPluginTest { @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + @Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); @Mock private BigQuery mockBigQuery; @Mock private StreamWriter mockWriter; @@ -90,9 +106,11 @@ public class BigQueryAgentAnalyticsPluginTest { private BigQueryLoggerConfig config; private BigQueryAgentAnalyticsPlugin plugin; private Handler mockHandler; + private Tracer tracer; @Before public void setUp() throws Exception { + tracer = openTelemetryRule.getOpenTelemetry().getTracer("test-plugin"); fakeAgent = new FakeAgent("agent_name"); config = BigQueryLoggerConfig.builder() @@ -124,12 +142,18 @@ protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { protected StreamWriter createWriter(BigQueryLoggerConfig config) { return mockWriter; } + + @Override + protected TraceManager createTraceManager() { + return new TraceManager(tracer); + } }; - Session session = Session.builder("session_id").build(); + Session session = Session.builder("session_id").appName("test_app").userId("test_user").build(); when(mockInvocationContext.session()).thenReturn(session); when(mockInvocationContext.invocationId()).thenReturn("invocation_id"); when(mockInvocationContext.agent()).thenReturn(fakeAgent); + when(mockInvocationContext.callbackContextData()).thenReturn(new ConcurrentHashMap<>()); when(mockInvocationContext.userId()).thenReturn("user_id"); Logger logger = Logger.getLogger(BatchProcessor.class.getName()); @@ -137,6 +161,14 @@ protected StreamWriter createWriter(BigQueryLoggerConfig config) { logger.addHandler(mockHandler); } + @After + public void tearDown() { + Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + if (mockHandler != null) { + logger.removeHandler(mockHandler); + } + } + @Test public void onUserMessageCallback_appendsToWriter() throws Exception { Content content = Content.builder().build(); @@ -216,12 +248,15 @@ public void onUserMessageCallback_handlesTableCreationFailure() throws Exception ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); verify(mockHandler, atLeastOnce()).publish(captor.capture()); - assertTrue( - captor - .getValue() - .getMessage() - .contains("Failed to check or create/upgrade BigQuery table")); - assertEquals(Level.WARNING, captor.getValue().getLevel()); + boolean found = + captor.getAllValues().stream() + .anyMatch( + record -> + record + .getMessage() + .contains("Failed to check or create/upgrade BigQuery table") + && Objects.equals(record.getLevel(), Level.WARNING)); + assertTrue("Should have logged table creation failure warning", found); } finally { logger.removeHandler(mockHandler); } @@ -313,7 +348,8 @@ public void logEvent_populatesCommonFields() throws Exception { if (root.getRowCount() != 1) { failureMessage[0] = "Expected 1 row, got " + root.getRowCount(); } else if (!Objects.equals( - root.getVector("event_type").getObject(0).toString(), "USER_MESSAGE")) { + root.getVector("event_type").getObject(0).toString(), + "USER_MESSAGE_RECEIVED")) { failureMessage[0] = "Wrong event_type: " + root.getVector("event_type").getObject(0); } else if (!root.getVector("agent").getObject(0).toString().equals("agent_name")) { @@ -334,6 +370,9 @@ public void logEvent_populatesCommonFields() throws Exception { failureMessage[0] = "Wrong user_id: " + root.getVector("user_id").getObject(0); } else if (((TimeStampMicroTZVector) root.getVector("timestamp")).get(0) <= 0) { failureMessage[0] = "Timestamp not populated"; + } else if (!Objects.equals(root.getVector("is_truncated").getObject(0), false)) { + failureMessage[0] = + "Wrong is_truncated: " + root.getVector("is_truncated").getObject(0); } else { // Check content and content_parts String contentJson = root.getVector("content").getObject(0).toString(); @@ -381,6 +420,8 @@ public void logEvent_populatesTraceDetails() throws Exception { Span mockSpan = Span.wrap(mockSpanContext); try (Scope scope = mockSpan.makeCurrent()) { + plugin.traceManager.attachCurrentSpan(); + Content content = Content.builder().build(); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); @@ -414,29 +455,190 @@ public void onEventCallback_populatesCorrectFields() throws Exception { Map row = plugin.batchProcessor.queue.poll(); assertNotNull("Row not found in queue", row); - assertEquals("EVENT", row.get("event_type")); + assertEquals("STATE_DELTA", row.get("event_type")); assertEquals("agent_name", row.get("agent")); - assertTrue(row.get("attributes").toString().contains("agent_author")); + ObjectNode attributes = (ObjectNode) row.get("attributes"); + assertEquals("agent_author", attributes.get("author").asText()); assertTrue(row.get("content").toString().contains("event content")); + assertEquals(false, row.get("is_truncated")); } @Test public void onModelErrorCallback_populatesCorrectFields() throws Exception { CallbackContext mockCallbackContext = mock(CallbackContext.class); when(mockCallbackContext.invocationContext()).thenReturn(mockInvocationContext); - when(mockCallbackContext.agentName()).thenReturn("agent_in_context"); LlmRequest.Builder mockLlmRequestBuilder = mock(LlmRequest.Builder.class); Throwable error = new RuntimeException("model error message"); + plugin.traceManager.pushSpan("llm_request"); plugin .onModelErrorCallback(mockCallbackContext, mockLlmRequestBuilder, error) .blockingSubscribe(); Map row = plugin.batchProcessor.queue.poll(); assertNotNull("Row not found in queue", row); - assertEquals("MODEL_ERROR", row.get("event_type")); - assertEquals("agent_in_context", row.get("agent")); - assertTrue(row.get("attributes").toString().contains("model error message")); + assertEquals("LLM_ERROR", row.get("event_type")); + assertEquals("agent_name", row.get("agent")); + assertEquals("ERROR", row.get("status")); + assertEquals("model error message", row.get("error_message")); + assertNotNull(row.get("latency_ms")); + assertEquals(false, row.get("is_truncated")); + } + + @Test + public void afterModelCallback_populatesCorrectFields() throws Exception { + CallbackContext mockCallbackContext = mock(CallbackContext.class); + when(mockCallbackContext.invocationContext()).thenReturn(mockInvocationContext); + + GenerateContentResponseUsageMetadata usage = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + + GenerateContentResponse response = + GenerateContentResponse.builder() + .modelVersion("v1") + .usageMetadata(usage) + .candidates( + ImmutableList.of( + Candidate.builder() + .content(Content.fromParts(Part.fromText("llm response"))) + .build())) + .build(); + + LlmResponse adkResponse = LlmResponse.create(response); + + Span parentSpan = tracer.spanBuilder("parent_request").startSpan(); + Span ambientSpan = + tracer.spanBuilder("ambient").setParent(Context.current().with(parentSpan)).startSpan(); + // Set valid ambient span context + try (Scope scope = ambientSpan.makeCurrent()) { + plugin.traceManager.pushSpan("parent_request"); + plugin.traceManager.pushSpan("llm_request"); + plugin.afterModelCallback(mockCallbackContext, adkResponse).blockingSubscribe(); + } finally { + ambientSpan.end(); + } + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals("LLM_RESPONSE", row.get("event_type")); + ObjectNode contentMap = (ObjectNode) row.get("content"); + assertNotNull(contentMap.get("response")); + ObjectNode usageMap = (ObjectNode) contentMap.get("usage"); + assertEquals(10, usageMap.get("prompt").asInt()); + + ObjectNode attributes = (ObjectNode) row.get("attributes"); + assertEquals("v1", attributes.get("model_version").asText()); + ObjectNode usageAttr = (ObjectNode) attributes.get("usage_metadata"); + assertEquals(10, usageAttr.get("prompt").asInt()); + assertEquals(false, row.get("is_truncated")); + assertNotNull(row.get("parent_span_id")); + ObjectNode latencyMs = (ObjectNode) row.get("latency_ms"); + assertNotNull("latency_ms should not be null", latencyMs); + assertTrue( + "latency_ms should contain time_to_first_token_ms", + latencyMs.has("time_to_first_token_ms")); + } + + @Test + public void afterToolCallback_populatesCorrectFields() throws Exception { + ToolContext mockToolContext = mock(ToolContext.class); + when(mockToolContext.invocationContext()).thenReturn(mockInvocationContext); + + BaseTool mockTool = mock(BaseTool.class); + when(mockTool.name()).thenReturn("test_tool"); + + ImmutableMap toolArgs = ImmutableMap.of("arg1", "value1"); + ImmutableMap result = ImmutableMap.of("res1", "value2"); + + plugin.traceManager.pushSpan("tool_request"); + plugin.afterToolCallback(mockTool, toolArgs, mockToolContext, result).blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals("TOOL_COMPLETED", row.get("event_type")); + assertEquals("agent_name", row.get("agent")); + ObjectNode contentMap = (ObjectNode) row.get("content"); + assertEquals("test_tool", contentMap.get("tool").asText()); + assertNotNull(contentMap.get("result")); + assertEquals("UNKNOWN", contentMap.get("tool_origin").asText()); + assertEquals(false, row.get("is_truncated")); + assertNotNull(row.get("latency_ms")); + } + + @Test + public void afterToolCallback_identifiesA2AOrigin() throws Exception { + ToolContext mockToolContext = mock(ToolContext.class); + when(mockToolContext.invocationContext()).thenReturn(mockInvocationContext); + + BaseAgent a2aAgent = + new FakeAgent("a2a_agent") { + @Override + public AgentOrigin toolOrigin() { + return AgentOrigin.A2A; + } + }; + + AgentTool a2aTool = AgentTool.create(a2aAgent); + + plugin.traceManager.pushSpan("tool_request"); + plugin + .afterToolCallback(a2aTool, ImmutableMap.of(), mockToolContext, ImmutableMap.of()) + .blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull(row); + ObjectNode contentMap = (ObjectNode) row.get("content"); + assertEquals("A2A", contentMap.get("tool_origin").asText()); + } + + @Test + public void logEvent_includesSessionMetadata_whenEnabled() throws Exception { + // Config default has logSessionMetadata(true) + Content content = Content.fromParts(Part.fromText("test message")); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull(row); + ObjectNode attributes = (ObjectNode) row.get("attributes"); + assertTrue("attributes should contain session_metadata", attributes.has("session_metadata")); + ObjectNode sessionMeta = (ObjectNode) attributes.get("session_metadata"); + assertEquals("session_id", sessionMeta.get("session_id").asText()); + assertEquals("test_user", sessionMeta.get("user_id").asText()); + assertEquals("test_app", sessionMeta.get("app_name").asText()); + } + + @Test + public void logEvent_excludesSessionMetadata_whenDisabled() throws Exception { + BigQueryLoggerConfig disabledConfig = config.toBuilder().setLogSessionMetadata(false).build(); + BigQueryAgentAnalyticsPlugin disabledPlugin = + new BigQueryAgentAnalyticsPlugin(disabledConfig, mockBigQuery) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + return mockWriter; + } + + @Override + protected TraceManager createTraceManager() { + return new TraceManager(GlobalOpenTelemetry.getTracer("test-plugin-disabled")); + } + }; + + Content content = Content.fromParts(Part.fromText("test message")); + disabledPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = disabledPlugin.batchProcessor.queue.poll(); + assertNotNull(row); + ObjectNode attributes = (ObjectNode) row.get("attributes"); + assertFalse( + "attributes should not contain session_metadata", attributes.has("session_metadata")); } private static class FakeAgent extends BaseAgent { diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java new file mode 100644 index 000000000..739f3a7c3 --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java @@ -0,0 +1,145 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.google.adk.models.LlmRequest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import java.util.Arrays; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class JsonFormatterTest { + + @Test + public void parse_llmRequest_populatesPrompt() { + LlmRequest request = + LlmRequest.builder() + .contents( + ImmutableList.of( + Content.fromParts(Part.fromText("hello")).toBuilder().role("user").build())) + .build(); + + JsonFormatter.ParsedContent result = JsonFormatter.parse(request, 100); + + assertTrue(result.content().has("prompt")); + ArrayNode prompt = (ArrayNode) result.content().get("prompt"); + assertEquals(1, prompt.size()); + assertEquals("user", prompt.get(0).get("role").asText()); + assertEquals("hello", prompt.get(0).get("content").asText()); + assertFalse(result.isTruncated()); + } + + @Test + public void parse_llmRequest_populatesSystemPrompt() { + LlmRequest request = + LlmRequest.builder() + .config( + GenerateContentConfig.builder() + .systemInstruction(Content.fromParts(Part.fromText("be helpful"))) + .build()) + .build(); + + JsonFormatter.ParsedContent result = JsonFormatter.parse(request, 100); + + assertTrue(result.content().has("system_prompt")); + assertEquals("be helpful", result.content().get("system_prompt").asText()); + assertFalse(result.isTruncated()); + } + + @Test + public void parse_string_truncates() { + String longString = "this is a very long string that should be truncated"; + JsonFormatter.ParsedContent result = JsonFormatter.parse(longString, 10); + + assertTrue(result.isTruncated()); + assertEquals("this is a ...[truncated]", result.content().asText()); + } + + @Test + public void parse_map_truncatesNested() { + ImmutableMap map = ImmutableMap.of("key", "this is a long value"); + JsonFormatter.ParsedContent result = JsonFormatter.parse(map, 10); + + assertTrue(result.isTruncated()); + assertEquals("this is a ...[truncated]", result.content().get("key").asText()); + } + + @Test + public void parse_content_returnsSummary() { + Content content = Content.fromParts(Part.fromText("part 1"), Part.fromText("part 2")); + JsonFormatter.ParsedContent result = JsonFormatter.parse(content, 100); + + assertEquals("part 1 | part 2", result.content().get("text_summary").asText()); + assertEquals(2, result.parts().size()); + } + + @Test + public void parse_content_withFileData() { + FileData fileData = + FileData.builder().fileUri("gs://bucket/file.txt").mimeType("text/plain").build(); + Content content = Content.fromParts(Part.builder().fileData(fileData).build()); + JsonFormatter.ParsedContent result = JsonFormatter.parse(content, 100); + + assertEquals(1, result.parts().size()); + JsonNode partData = result.parts().get(0); + assertEquals("EXTERNAL_URI", partData.get("storage_mode").asText()); + assertEquals("gs://bucket/file.txt", partData.get("uri").asText()); + assertEquals("text/plain", partData.get("mime_type").asText()); + } + + @Test + public void parse_content_withFunctionCall() { + FunctionCall fc = FunctionCall.builder().name("myFunction").build(); + Content content = Content.fromParts(Part.builder().functionCall(fc).build()); + JsonFormatter.ParsedContent result = JsonFormatter.parse(content, 100); + + assertEquals(1, result.parts().size()); + JsonNode partData = result.parts().get(0); + assertEquals("application/json", partData.get("mime_type").asText()); + assertEquals("Function: myFunction", partData.get("text").asText()); + assertTrue(partData.get("part_attributes").asText().contains("myFunction")); + } + + @Test + public void parse_list_truncatesElements() { + List list = + Arrays.asList("short", "this is a very long string that should be truncated"); + JsonFormatter.ParsedContent result = JsonFormatter.parse(list, 10); + + assertTrue(result.isTruncated()); + JsonNode arrayNode = result.content(); + assertTrue(arrayNode.isArray()); + assertEquals(2, arrayNode.size()); + assertEquals("short", arrayNode.get(0).asText()); + assertEquals("this is a ...[truncated]", arrayNode.get(1).asText()); + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/TraceManagerTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/TraceManagerTest.java new file mode 100644 index 000000000..6cefebf2b --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/TraceManagerTest.java @@ -0,0 +1,230 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Scope; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import io.reactivex.rxjava3.core.Flowable; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class TraceManagerTest { + @Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); + private InvocationContext mockContext; + private BaseAgent mockAgent; + private Map callbackData; + private TraceManager traceManager; + private Tracer tracer; + + @Before + public void setUp() { + tracer = openTelemetryRule.getOpenTelemetry().getTracer("test"); + callbackData = new ConcurrentHashMap<>(); + mockContext = mock(InvocationContext.class); + when(mockContext.callbackContextData()).thenReturn(callbackData); + when(mockContext.invocationId()).thenReturn("test-invocation-id"); + mockAgent = + new BaseAgent("test-agent", "desc", null, null, null) { + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + }; + when(mockContext.agent()).thenReturn(mockAgent); + traceManager = new TraceManager(tracer); + } + + @Test + public void pushSpan_createsValidSpanId() { + String spanId = traceManager.pushSpan("test-span"); + assertNotNull(spanId); + assertTrue(spanId.length() >= 16); + } + + @Test + public void pushSpan_maintainsParentChildRelationship() { + String parentId = traceManager.pushSpan("parent"); + String childId = traceManager.pushSpan("child"); + + TraceManager.SpanIds ids = traceManager.getCurrentSpanAndParent(); + assertEquals(childId, ids.spanId().orElse(null)); + assertEquals(parentId, ids.parentSpanId().orElse(null)); + } + + @Test + public void popSpan_removesFromStack() { + String parentId = traceManager.pushSpan("parent"); + traceManager.pushSpan("child"); + + Optional popped = traceManager.popSpan(); + assertTrue(popped.isPresent()); + assertFalse(popped.get().duration().isNegative()); + + String currentId = traceManager.getCurrentSpanId(); + assertEquals(parentId, currentId); + + TraceManager.SpanIds ids = traceManager.getCurrentSpanAndParent(); + assertEquals(parentId, ids.spanId().orElse(null)); + assertFalse(ids.parentSpanId().isPresent()); + } + + @Test + public void ensureInvocationSpan_isIdempotent() { + traceManager.ensureInvocationSpan(mockContext); + String id1 = traceManager.getCurrentSpanId(); + + traceManager.ensureInvocationSpan(mockContext); + String id2 = traceManager.getCurrentSpanId(); + + assertEquals(id1, id2); + } + + @Test + public void ensureInvocationSpan_clearsStaleRecords() { + Span ambientSpan = tracer.spanBuilder("ambient").startSpan(); + try (Scope scope = ambientSpan.makeCurrent()) { + traceManager.ensureInvocationSpan(mockContext); + } finally { + ambientSpan.end(); + } + String id1 = traceManager.getCurrentSpanId(); + // Create a new context with same callback data but different invocation ID + InvocationContext mockContext2 = mock(InvocationContext.class); + when(mockContext2.callbackContextData()).thenReturn(callbackData); + when(mockContext2.invocationId()).thenReturn("new-invocation-id"); + when(mockContext2.agent()).thenReturn(mockAgent); + Span ambientSpan2 = tracer.spanBuilder("ambient2").startSpan(); + try (Scope scope = ambientSpan2.makeCurrent()) { + traceManager.ensureInvocationSpan(mockContext2); + } finally { + ambientSpan2.end(); + } + String id2 = traceManager.getCurrentSpanId(); + + assertNotEquals(id1, id2); + // Should only have 1 record now + TraceManager.SpanIds ids = traceManager.getCurrentSpanAndParent(); + assertFalse(ids.parentSpanId().isPresent()); + } + + @Test + public void attachCurrentSpan_usesAmbientSpan() { + Span ambientSpan = tracer.spanBuilder("ambient").startSpan(); + try (Scope scope = ambientSpan.makeCurrent()) { + String attachedId = traceManager.attachCurrentSpan(); + String expectedId = ambientSpan.getSpanContext().getSpanId(); + assertEquals(expectedId, attachedId); + } finally { + ambientSpan.end(); + } + } + + @Test + public void getTraceId_returnsCurrentTraceId() { + traceManager.pushSpan("test"); + String traceId = traceManager.getTraceId(mockContext); + assertNotNull(traceId); + if (traceId.equals("test-invocation-id")) { + assertEquals("test-invocation-id", traceId); + } else { + assertTrue(traceId.matches("[0-9a-f]{32}")); + } + } + + @Test + public void getTraceId_returnsInvocationId_whenRecordsIsEmpty() { + String traceId = traceManager.getTraceId(mockContext); + if (traceManager.hasAmbientSpan()) { + assertTrue(traceId.matches("[0-9a-f]{32}")); + } else { + assertEquals("test-invocation-id", traceId); + } + } + + @Test + public void getTraceId_returnsAmbientTraceId_whenRecordsIsEmpty_butAmbientIsPresent() { + Span ambientSpan = tracer.spanBuilder("ambient").startSpan(); + try (Scope scope = ambientSpan.makeCurrent()) { + String expectedTraceId = ambientSpan.getSpanContext().getTraceId(); + String traceId = traceManager.getTraceId(mockContext); + assertEquals(expectedTraceId, traceId); + } finally { + ambientSpan.end(); + } + } + + @Test + public void attachCurrentSpan_worksWithoutAmbientSpan() { + // Ensure no ambient span + String attachedId = traceManager.attachCurrentSpan(); + assertNotNull(attachedId); + assertEquals(16, attachedId.length()); + + // Verify it's in records + assertEquals(attachedId, traceManager.getCurrentSpanId()); + } + + @Test + public void getTraceId_fallsBackToInvocationId_whenRecordSpanIsInvalid() { + // attachCurrentSpan when no ambient context exists creates an invalid span record + traceManager.attachCurrentSpan(); + + String traceId = traceManager.getTraceId(mockContext); + if (traceManager.hasAmbientSpan()) { + assertTrue(traceId.matches("[0-9a-f]{32}")); + } else { + assertEquals("test-invocation-id", traceId); + } + } + + @Test + public void popSpan_returnsEmpty_whenRecordsIsEmpty() { + Optional popped = traceManager.popSpan(); + assertFalse(popped.isPresent()); + } + + @Test + public void clearStack_doesNothing_whenRecordsIsEmpty() { + traceManager.clearStack(); + assertTrue(traceManager.getCurrentSpanAndParent().spanId().isEmpty()); + } +}