diff --git a/tests/test_adk_compat.py b/tests/test_adk_compat.py new file mode 100644 index 00000000..8df1c4be --- /dev/null +++ b/tests/test_adk_compat.py @@ -0,0 +1,173 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import SimpleNamespace + +from veadk.memory.short_term_memory_backends.mysql_backend import MysqlSTMBackend +from veadk.memory.short_term_memory_backends.postgresql_backend import PostgreSqlSTMBackend +from veadk.memory.short_term_memory_backends.sqlite_backend import SQLiteSTMBackend +from veadk.tracing.telemetry.attributes.extractors.tool_attributes_extractors import ( + tool_gen_ai_tool_output, +) +from veadk.tracing.telemetry.attributes.extractors.types import ToolAttributesParams +import veadk.utils.adk_compat as adk_compat + + +def test_get_previous_interaction_id_missing_field(): + llm_request = SimpleNamespace() + assert adk_compat.get_previous_interaction_id(llm_request) is None + + +def test_get_previous_interaction_id_with_field(): + llm_request = SimpleNamespace(previous_interaction_id="interaction_123") + assert adk_compat.get_previous_interaction_id(llm_request) == "interaction_123" + + +def test_get_event_function_calls_from_getter(): + expected_calls = [SimpleNamespace(name="tool_a")] + + class Event: + def get_function_calls(self): + return expected_calls + + calls = adk_compat.get_event_function_calls(Event()) + assert calls == expected_calls + + +def test_get_event_function_calls_fallback_to_parts(): + part1 = SimpleNamespace(function_call=SimpleNamespace(name="tool_1")) + part2 = SimpleNamespace(function_call=None) + event = SimpleNamespace(content=SimpleNamespace(parts=[part1, part2])) + + calls = adk_compat.get_event_function_calls(event) + assert len(calls) == 1 + assert calls[0].name == "tool_1" + + +def test_get_event_function_calls_getter_error_fallback_to_parts(): + class Event: + content = SimpleNamespace(parts=[SimpleNamespace(function_call="fallback_call")]) + + def get_function_calls(self): + raise RuntimeError("broken getter") + + calls = adk_compat.get_event_function_calls(Event()) + assert calls == ["fallback_call"] + + +def test_get_event_function_responses_fallback_to_parts(): + part = SimpleNamespace(function_response=SimpleNamespace(name="tool_resp")) + event = SimpleNamespace(content=SimpleNamespace(parts=[part])) + + responses = adk_compat.get_event_function_responses(event) + assert len(responses) == 1 + assert responses[0].name == "tool_resp" + + +def test_mysql_backend_url_respects_async_driver_flag(monkeypatch): + monkeypatch.setattr( + "veadk.memory.short_term_memory_backends.mysql_backend.should_use_async_db_drivers", + lambda: True, + ) + backend = MysqlSTMBackend() + assert backend._db_url.startswith("mysql+aiomysql://") + + monkeypatch.setattr( + "veadk.memory.short_term_memory_backends.mysql_backend.should_use_async_db_drivers", + lambda: False, + ) + backend = MysqlSTMBackend() + assert backend._db_url.startswith("mysql+pymysql://") + + +def test_postgresql_backend_url_respects_async_driver_flag(monkeypatch): + monkeypatch.setattr( + "veadk.memory.short_term_memory_backends.postgresql_backend.should_use_async_db_drivers", + lambda: True, + ) + backend = PostgreSqlSTMBackend() + assert backend._db_url.startswith("postgresql+asyncpg://") + + monkeypatch.setattr( + "veadk.memory.short_term_memory_backends.postgresql_backend.should_use_async_db_drivers", + lambda: False, + ) + backend = PostgreSqlSTMBackend() + assert backend._db_url.startswith("postgresql://") + + +def test_sqlite_backend_url_respects_async_driver_flag(monkeypatch, tmp_path): + db_file = tmp_path / "compat-test.db" + + monkeypatch.setattr( + "veadk.memory.short_term_memory_backends.sqlite_backend.should_use_async_db_drivers", + lambda: True, + ) + backend = SQLiteSTMBackend(local_path=str(db_file)) + assert backend._db_url.startswith("sqlite+aiosqlite:///") + + monkeypatch.setattr( + "veadk.memory.short_term_memory_backends.sqlite_backend.should_use_async_db_drivers", + lambda: False, + ) + backend = SQLiteSTMBackend(local_path=str(db_file)) + assert backend._db_url.startswith("sqlite:///") + + +def test_tool_output_extractor_accepts_dict_response(): + function_response_event = SimpleNamespace( + content=SimpleNamespace( + parts=[ + SimpleNamespace( + function_response={ + "id": "id_1", + "name": "tool_name", + "response": {"ok": True}, + } + ) + ] + ) + ) + params = ToolAttributesParams( + tool=SimpleNamespace(name="tool_name"), + args={}, + function_response_event=function_response_event, + ) + + response = tool_gen_ai_tool_output(params) + assert '"name": "tool_name"' in response.content + + +def test_tool_output_extractor_accepts_object_response(): + function_response_event = SimpleNamespace( + content=SimpleNamespace( + parts=[ + SimpleNamespace( + function_response=SimpleNamespace( + id="id_2", + name="tool_obj", + response={"status": "done"}, + ) + ) + ] + ) + ) + params = ToolAttributesParams( + tool=SimpleNamespace(name="tool_obj"), + args={}, + function_response_event=function_response_event, + ) + + response = tool_gen_ai_tool_output(params) + assert '"name": "tool_obj"' in response.content diff --git a/tests/test_adk_compat_regression.py b/tests/test_adk_compat_regression.py new file mode 100644 index 00000000..1e87899e --- /dev/null +++ b/tests/test_adk_compat_regression.py @@ -0,0 +1,662 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Regression tests targeting the seams where VeADK consumes Google ADK APIs +that differ across versions (1.19 vs 2.x). + +Scope: +* Direct coverage of every helper in ``veadk.utils.adk_compat`` that the + pre-existing ``test_adk_compat.py`` did not exercise. +* Edge cases (missing content, empty parts, broken getters) for the event + function-call/response extraction helpers. +* Consumer-side behavior: every modified call site in runner.py / agent.py / + ark_llm.py / tool_attributes_extractors.py — including the new defensive + guards introduced as part of the smooth-upgrade work. +* Real-ADK contract checks: verify the helpers' output matches what the live + ADK objects produce, so we notice if ADK changes its public surface. + +Tests are credential-free: ADK / VeADK constructors that need API keys are +either monkeypatched or constructed with explicit dummy values. +""" + +from __future__ import annotations + +import importlib +import importlib.metadata as im +from types import SimpleNamespace + +import pytest +from packaging.version import Version + +import veadk.utils.adk_compat as adk_compat +from veadk.tracing.telemetry.attributes.extractors.tool_attributes_extractors import ( + tool_gen_ai_tool_output, +) +from veadk.tracing.telemetry.attributes.extractors.types import ToolAttributesParams + + +# --------------------------------------------------------------------------- +# Section A: adk_compat module — direct helpers (15 tests) +# +# These exercise the public surface that ``test_adk_compat.py`` does not yet +# cover: version detection, ``is_adk_gte``, ``should_use_async_db_drivers``, +# ``llm_request_has_field``, and the getter path for +# ``get_event_function_responses`` (the existing file only tests its fallback +# path). +# --------------------------------------------------------------------------- + + +def test_get_adk_version_returns_version_object(): + """``get_adk_version`` returns a ``packaging.version.Version`` instance.""" + v = adk_compat.get_adk_version() + assert isinstance(v, Version), f"expected Version, got {type(v).__name__}" + + +def test_get_adk_version_matches_installed_metadata(): + """``get_adk_version`` agrees with ``importlib.metadata.version``.""" + installed = im.version("google-adk") + assert str(adk_compat.get_adk_version()) == installed + + +def test_get_adk_version_is_cached(): + """``get_adk_version`` uses ``lru_cache`` — same identity on repeat calls.""" + assert adk_compat.get_adk_version() is adk_compat.get_adk_version() + + +def test_is_adk_gte_true_for_lower_target(): + """A version far below current installed should return True.""" + assert adk_compat.is_adk_gte("0.0.1") is True + + +def test_is_adk_gte_false_for_unreachably_high_target(): + """A version far above current installed should return False.""" + assert adk_compat.is_adk_gte("99.0.0") is False + + +def test_is_adk_gte_equal_to_current_returns_true(): + """``is_adk_gte`` is inclusive — current installed version compares True.""" + current = str(adk_compat.get_adk_version()) + assert adk_compat.is_adk_gte(current) is True + + +def test_is_adk_gte_major_only_target_parses(): + """Targets like ``"1"`` or ``"2"`` (major-only) parse and compare cleanly.""" + # Installed is >= 1.19, so "1" must be True. + assert adk_compat.is_adk_gte("1") is True + + +def test_should_use_async_db_drivers_true_on_modern_adk(): + """On any ADK >= 1.19 we must select the async DSN scheme.""" + # The whole compat layer assumes >= 1.19, so this must be True on any + # supported install. + assert adk_compat.should_use_async_db_drivers() is True + + +def test_llm_request_has_field_known_field_present(): + """``model`` is a stable LlmRequest field across ADK 1.19 and 2.0.""" + assert adk_compat.llm_request_has_field("model") is True + + +def test_llm_request_has_field_unknown_field_absent(): + """Unknown field name must return False (no spurious matches).""" + assert adk_compat.llm_request_has_field("definitely_not_a_real_field_xyz") is False + + +def test_llm_request_has_field_is_cached(): + """``llm_request_has_field`` uses ``lru_cache`` — repeated calls hit cache.""" + first = adk_compat.llm_request_has_field("model") + second = adk_compat.llm_request_has_field("model") + info = adk_compat.llm_request_has_field.cache_info() + assert first == second is True + assert info.hits >= 1, "expected at least one cache hit after repeated call" + + +def test_get_event_function_responses_uses_getter_when_present(): + """When the event exposes ``get_function_responses``, the helper uses it.""" + expected = [SimpleNamespace(name="tool_resp_a")] + + class Event: + def get_function_responses(self): + return expected + + assert adk_compat.get_event_function_responses(Event()) == expected + + +def test_get_event_function_responses_getter_error_falls_back_to_parts(): + """If the getter raises, the helper falls back to ``content.parts``.""" + + class Event: + content = SimpleNamespace( + parts=[SimpleNamespace(function_response="fallback_resp")] + ) + + def get_function_responses(self): + raise RuntimeError("getter is broken") + + assert adk_compat.get_event_function_responses(Event()) == ["fallback_resp"] + + +def test_get_event_function_responses_no_content_returns_empty_list(): + """An event without ``content`` produces an empty list, not an exception.""" + assert adk_compat.get_event_function_responses(SimpleNamespace()) == [] + + +def test_get_event_function_calls_no_content_returns_empty_list(): + """Same shape contract for the function-calls extractor.""" + assert adk_compat.get_event_function_calls(SimpleNamespace()) == [] + + +# --------------------------------------------------------------------------- +# Section B: edge cases for the event extraction helpers (6 tests) +# +# These cover the part-traversal fallback path with awkward shapes that ADK +# may emit in practice: ``parts`` set to ``None``, empty list, or mixed where +# some parts carry a function call/response and others don't. +# --------------------------------------------------------------------------- + + +def test_get_event_function_calls_empty_parts_returns_empty(): + event = SimpleNamespace(content=SimpleNamespace(parts=[])) + assert adk_compat.get_event_function_calls(event) == [] + + +def test_get_event_function_calls_parts_none_returns_empty(): + event = SimpleNamespace(content=SimpleNamespace(parts=None)) + assert adk_compat.get_event_function_calls(event) == [] + + +def test_get_event_function_calls_mixed_parts_filters_out_none(): + """Parts without ``function_call`` are skipped; calls are preserved in order.""" + part_a = SimpleNamespace(function_call=SimpleNamespace(name="a")) + part_b = SimpleNamespace(function_call=None) + part_c = SimpleNamespace(function_call=SimpleNamespace(name="c")) + event = SimpleNamespace(content=SimpleNamespace(parts=[part_a, part_b, part_c])) + calls = adk_compat.get_event_function_calls(event) + assert [c.name for c in calls] == ["a", "c"] + + +def test_get_event_function_responses_empty_parts_returns_empty(): + event = SimpleNamespace(content=SimpleNamespace(parts=[])) + assert adk_compat.get_event_function_responses(event) == [] + + +def test_get_event_function_responses_parts_none_returns_empty(): + event = SimpleNamespace(content=SimpleNamespace(parts=None)) + assert adk_compat.get_event_function_responses(event) == [] + + +def test_get_event_function_responses_mixed_parts_filters_out_none(): + part_a = SimpleNamespace(function_response=SimpleNamespace(name="a")) + part_b = SimpleNamespace(function_response=None) + part_c = SimpleNamespace(function_response=SimpleNamespace(name="c")) + event = SimpleNamespace(content=SimpleNamespace(parts=[part_a, part_b, part_c])) + responses = adk_compat.get_event_function_responses(event) + assert [r.name for r in responses] == ["a", "c"] + + +# --------------------------------------------------------------------------- +# Section C: integration with real ADK Event objects (5 tests) +# +# Verifies that the helpers behave correctly on ADK's actual ``Event`` +# instances, not just mocks — guarding against ADK silently changing its +# Event/Content/Part shape across versions. +# --------------------------------------------------------------------------- + + +def _make_real_adk_text_event(text: str = "hello"): + from google.adk.events import Event + from google.genai import types + + return Event( + invocation_id="inv-1", + author="user", + content=types.Content(role="user", parts=[types.Part(text=text)]), + ) + + +def _make_real_adk_function_call_event(): + from google.adk.events import Event + from google.genai import types + + fc = types.FunctionCall(name="my_tool", args={"x": 1}) + return Event( + invocation_id="inv-2", + author="model", + content=types.Content(role="model", parts=[types.Part(function_call=fc)]), + ) + + +def _make_real_adk_function_response_event(): + from google.adk.events import Event + from google.genai import types + + fr = types.FunctionResponse(name="my_tool", response={"ok": True}) + return Event( + invocation_id="inv-3", + author="user", + content=types.Content(role="user", parts=[types.Part(function_response=fr)]), + ) + + +def test_real_adk_event_with_text_returns_no_function_calls(): + event = _make_real_adk_text_event() + assert adk_compat.get_event_function_calls(event) == [] + assert adk_compat.get_event_function_responses(event) == [] + + +def test_real_adk_event_with_function_call_returns_call(): + event = _make_real_adk_function_call_event() + calls = adk_compat.get_event_function_calls(event) + assert len(calls) == 1 + assert calls[0].name == "my_tool" + + +def test_real_adk_event_with_function_response_returns_response(): + event = _make_real_adk_function_response_event() + responses = adk_compat.get_event_function_responses(event) + assert len(responses) == 1 + assert responses[0].name == "my_tool" + + +def test_helper_function_calls_matches_native_getter(): + """Helper output must match ``Event.get_function_calls()`` on real ADK events.""" + event = _make_real_adk_function_call_event() + assert adk_compat.get_event_function_calls(event) == event.get_function_calls() + + +def test_helper_function_responses_matches_native_getter(): + event = _make_real_adk_function_response_event() + assert ( + adk_compat.get_event_function_responses(event) + == event.get_function_responses() + ) + + +# --------------------------------------------------------------------------- +# Section D: Agent.run override gated by ADK version (3 tests) +# +# In ADK 2.0, ``BaseAgent.run`` is a ``@final`` async generator that the +# Workflow engine invokes internally; overriding it breaks NodeRunner. The +# compat fix in ``veadk/agent.py`` only declares the v1 ``NotImplementedError`` +# guard when ``not is_adk_gte("2.0.0")``. These tests assert that gating works +# as advertised. +# --------------------------------------------------------------------------- + + +@pytest.fixture +def _agent_env(monkeypatch): + """Minimum env needed for ``veadk.Agent()`` to construct without secrets.""" + monkeypatch.setenv("MODEL_AGENT_API_KEY", "dummy-key-for-tests") + + +def test_agent_remains_subclass_of_adk_llm_agent(_agent_env): + """VeADK Agent must continue inheriting from ADK's ``LlmAgent``.""" + from google.adk.agents.llm_agent import LlmAgent as ADKLlmAgent + + from veadk import Agent + + assert issubclass(Agent, ADKLlmAgent) + + +def test_agent_run_override_presence_matches_adk_version(_agent_env): + """``Agent.run`` should be locally defined on v1.x only. + + On v1.x ``BaseAgent`` has no ``run``, so VeADK declares its own; on v2.x + the parent ``BaseAgent.run`` is ``@final`` so VeADK must defer to it. + """ + from veadk import Agent + + has_local_run = "run" in Agent.__dict__ + expected_on_v1 = not adk_compat.is_adk_gte("2.0.0") + assert has_local_run is expected_on_v1, ( + f"Agent.__dict__ has 'run'={has_local_run}, but expected " + f"{expected_on_v1} for ADK {adk_compat.get_adk_version()}" + ) + + +def test_agent_run_raises_notimplemented_on_legacy_adk(_agent_env): + """On v1.x the override must still surface the deprecation error. + + Skipped automatically on v2.x where the override is intentionally gone. + """ + if adk_compat.is_adk_gte("2.0.0"): + pytest.skip("Override is removed on ADK >= 2.0 by design") + + import asyncio + + from veadk import Agent + + agent = Agent(name="t") + with pytest.raises(NotImplementedError, match="runner.run_async"): + asyncio.run(agent.run()) + + +# --------------------------------------------------------------------------- +# Section E: ArkLlm version-branching (3 tests) +# +# Covers the version-detection branch in ``ArkLlm.__init__`` (which raises +# ``ImportError`` when ADK lacks ``previous_interaction_id``), plus the +# ``get_previous_interaction_id`` helper applied directly to an LlmRequest. +# --------------------------------------------------------------------------- + + +def test_arkllm_init_raises_when_field_missing(monkeypatch): + """If ADK's LlmRequest lacks ``previous_interaction_id``, ArkLlm refuses.""" + monkeypatch.setattr( + "veadk.models.ark_llm.llm_request_has_field", + lambda field: False, + ) + from veadk.models.ark_llm import ArkLlm + + with pytest.raises(ImportError, match="google-adk"): + ArkLlm(model="ark/test-model") + + +def test_get_previous_interaction_id_with_real_llm_request(): + """The helper reads cleanly from a real ADK ``LlmRequest`` instance. + + Skipped automatically on ADK versions that do not yet expose the + ``previous_interaction_id`` field — the whole point of the helper is + that callers never need to know which version they're on. + """ + if not adk_compat.llm_request_has_field("previous_interaction_id"): + pytest.skip("ADK build lacks 'previous_interaction_id' field") + + from google.adk.models.llm_request import LlmRequest + + req = LlmRequest(model="ark/test", previous_interaction_id="iid-42") # type: ignore[call-arg] + assert adk_compat.get_previous_interaction_id(req) == "iid-42" + + +def test_get_previous_interaction_id_returns_none_when_unset(): + """The helper returns ``None`` whether the field is absent or just unset. + + Works on every supported ADK version because it relies only on + ``getattr(..., default=None)``. + """ + from google.adk.models.llm_request import LlmRequest + + req = LlmRequest(model="ark/test") + assert adk_compat.get_previous_interaction_id(req) is None + + +# --------------------------------------------------------------------------- +# Section F: tool_attributes_extractors fallback variants (4 tests) +# +# After the smooth-upgrade work, ``tool_gen_ai_tool_output`` must accept three +# response object shapes (pydantic-like with ``model_dump``, raw dict, plain +# attribute object) plus an empty list. The existing file tested dict + object +# only; add the model_dump path, the empty-response sentinel, and an id/name +# preservation check. +# --------------------------------------------------------------------------- + + +def _make_extractor_params(function_response): + event = SimpleNamespace( + content=SimpleNamespace( + parts=[SimpleNamespace(function_response=function_response)] + ) + ) + return ToolAttributesParams( + tool=SimpleNamespace(name="my_tool"), + args={}, + function_response_event=event, + ) + + +def test_tool_output_extractor_uses_model_dump_when_available(): + """Pydantic-like objects with ``model_dump`` go through the v1 path.""" + response_obj = SimpleNamespace( + model_dump=lambda: {"id": "fid", "name": "tool_a", "response": {"v": 1}} + ) + params = _make_extractor_params(response_obj) + result = tool_gen_ai_tool_output(params) + assert '"name": "tool_a"' in result.content + assert '"id": "fid"' in result.content + + +def test_tool_output_extractor_returns_sentinel_when_no_responses(): + """Empty ``function_responses`` yields the ```` sentinel.""" + event = SimpleNamespace(content=SimpleNamespace(parts=[])) + params = ToolAttributesParams( + tool=SimpleNamespace(name="my_tool"), + args={}, + function_response_event=event, + ) + result = tool_gen_ai_tool_output(params) + assert "" in result.content + + +def test_tool_output_extractor_preserves_id_and_name_for_attribute_object(): + """Falls back to ``getattr`` for objects lacking ``model_dump`` and not dict.""" + response_obj = SimpleNamespace(id="fid_x", name="tool_x", response={"k": "v"}) + result = tool_gen_ai_tool_output(_make_extractor_params(response_obj)) + assert '"id": "fid_x"' in result.content + assert '"name": "tool_x"' in result.content + + +def test_tool_output_extractor_handles_missing_attributes_gracefully(): + """An object missing ``id``/``name`` shouldn't crash — getattr default kicks in.""" + response_obj = SimpleNamespace() # no id, no name, no response + result = tool_gen_ai_tool_output(_make_extractor_params(response_obj)) + # Should still emit valid JSON content; id/name fall back to empty. + assert '"id": ""' in result.content + assert '"name": ""' in result.content + + +# --------------------------------------------------------------------------- +# Section G: Runner intercept_new_message integration (5 tests) +# +# The runner's ``intercept_new_message`` decorator was updated to (1) route +# through the new ``get_event_function_calls/responses`` helpers and (2) +# tolerate ``part.text is None`` — both are real failure modes when running +# against ADK 2.0's event stream. Drive the wrapper with synthetic event +# streams to confirm. +# --------------------------------------------------------------------------- + + +@pytest.fixture +def _runner_env(monkeypatch, tmp_path): + monkeypatch.setenv("MODEL_AGENT_API_KEY", "dummy") + yield + + +def _make_fake_runner(_runner_env): + """Build a Runner with a never-called FakeLlm so we can inspect plumbing.""" + from typing import AsyncGenerator + + from google.adk.models.base_llm import BaseLlm + from google.adk.models.llm_response import LlmResponse + from google.genai import types + + from veadk import Agent, Runner + from veadk.memory.short_term_memory import ShortTermMemory + + class FakeLlm(BaseLlm): + async def generate_content_async( + self, llm_request, stream=False + ) -> AsyncGenerator[LlmResponse, None]: + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part(text="fake reply")], + ) + ) + + agent = Agent( + name="fake_agent", + description="fake", + instruction="be brief", + model=FakeLlm(model="fake"), + ) + runner = Runner(agent=agent, short_term_memory=ShortTermMemory(backend="local")) + return runner + + +def test_runner_session_service_is_inmemory_for_local_stm(_runner_env): + """``ShortTermMemory(backend='local')`` plumbs through InMemorySessionService.""" + from google.adk.sessions import InMemorySessionService + + runner = _make_fake_runner(_runner_env) + assert isinstance(runner.session_service, InMemorySessionService) + + +@pytest.mark.asyncio +async def test_runner_create_session_returns_adk_session(_runner_env): + """ADK session-service contract: create_session returns a ``Session``.""" + from google.adk.sessions import Session + + runner = _make_fake_runner(_runner_env) + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id="u1", session_id="s1" + ) + assert isinstance(session, Session) + assert session.id == "s1" + + +@pytest.mark.asyncio +async def test_runner_run_async_yields_model_text(_runner_env): + """End-to-end: an Agent + FakeLlm + Runner yields the model's text event.""" + from google.genai import types + + runner = _make_fake_runner(_runner_env) + await runner.session_service.create_session( + app_name=runner.app_name, user_id="u1", session_id="s1" + ) + msg = types.Content(role="user", parts=[types.Part(text="ping")]) + texts = [] + async for ev in runner.run_async(user_id="u1", session_id="s1", new_message=msg): + if ev and ev.content and ev.content.parts: + for p in ev.content.parts: + if getattr(p, "text", None): + texts.append(p.text) + assert "fake reply" in texts + + +def _fake_runner_self(): + """Minimum attributes ``pre_run_process`` reads off ``self``.""" + return SimpleNamespace( + app_name="test_app", + upload_inline_data_to_tos=False, + short_term_memory=None, + ) + + +def _empty_user_message(): + """Empty ``Content`` so ``pre_run_process`` short-circuits cleanly.""" + from google.genai import types + + return types.Content(role="user", parts=[]) + + +async def _no_op_processor(part, app_name, user_id, session_id): + """Stand-in for ``_upload_image_to_tos`` in the runner wrapper tests.""" + return None + + +@pytest.mark.asyncio +async def test_runner_intercept_skips_none_events_safely(_runner_env): + """Synthetic stream containing ``None`` events must not raise.""" + from google.adk.events import Event + from google.genai import types + + from veadk.runner import intercept_new_message + + async def upstream(**kwargs): + # Mix valid events with ``None`` to assert the wrapper filters them. + yield None + yield Event( + invocation_id="i", + author="u", + content=types.Content(role="user", parts=[types.Part(text="hi")]), + ) + + wrapped = intercept_new_message(_no_op_processor)(upstream) + collected = [] + async for ev in wrapped( + _fake_runner_self(), + user_id="u", + session_id="s", + new_message=_empty_user_message(), + ): + collected.append(ev) + # None should be filtered out; the valid event should survive. + assert len(collected) == 1 + assert collected[0].author == "u" + + +@pytest.mark.asyncio +async def test_runner_intercept_tolerates_part_text_none(_runner_env): + """A model event whose ``part.text`` is ``None`` must not raise. + + Regression check for the new ``if part.text and len(part.text.strip()) > 0`` + guard in runner.py — earlier code path called ``.strip()`` on ``None``. + """ + from google.adk.events import Event + from google.genai import types + + from veadk.runner import intercept_new_message + + async def upstream(**kwargs): + yield Event( + invocation_id="i", + author="model", + content=types.Content( + role="model", + parts=[types.Part(text=None)], # the dangerous shape + ), + ) + + wrapped = intercept_new_message(_no_op_processor)(upstream) + collected = [] + async for ev in wrapped( + _fake_runner_self(), + user_id="u", + session_id="s", + new_message=_empty_user_message(), + ): + collected.append(ev) + assert len(collected) == 1 + + +# --------------------------------------------------------------------------- +# Section H: ADK public-surface assumptions (3 tests) +# +# Veadk's compat layer relies on a handful of ADK public modules / classes +# being importable and having stable attributes. These tests guard those +# assumptions so an upstream rename produces a precise, on-topic failure. +# --------------------------------------------------------------------------- + + +def test_adk_version_module_exposes_version_string(): + """``from google.adk import version; version.__version__`` must be a string.""" + mod = importlib.import_module("google.adk.version") + assert isinstance(mod.__version__, str) + assert mod.__version__.split(".")[0].isdigit() + + +def test_adk_llm_request_model_fields_accessible(): + """The ``model_fields`` introspection used by ``llm_request_has_field`` works.""" + from google.adk.models.llm_request import LlmRequest + + fields = getattr(LlmRequest, "model_fields", None) + assert isinstance(fields, dict) + assert "model" in fields, "ADK's LlmRequest must still expose 'model'" + + +def test_adk_events_module_exports_event_class(): + """``Event`` import path used in runner/evaluator must remain stable.""" + from google.adk.events import Event + + assert hasattr(Event, "get_function_calls") + assert hasattr(Event, "get_function_responses") diff --git a/veadk/agent.py b/veadk/agent.py index 32320078..2cde8a14 100644 --- a/veadk/agent.py +++ b/veadk/agent.py @@ -49,6 +49,7 @@ ) from veadk.prompts.prompt_manager import BasePromptManager from veadk.tracing.base_tracer import BaseTracer +from veadk.utils.adk_compat import is_adk_gte from veadk.utils.logger import get_logger from veadk.utils.patches import patch_asyncio, patch_tracer from veadk.version import VERSION @@ -623,7 +624,12 @@ def _llm_flow(self) -> BaseLlmFlow: return SupervisorAutoFlow(supervised_agent=self) return AutoFlow() - async def run(self, **kwargs): - raise NotImplementedError( - "Run method in VeADK agent is deprecated since version 0.5.6. Please use runner.run_async instead. Ref: https://agentkit.gitbook.io/docs/runner/overview" - ) + if not is_adk_gte("2.0.0"): + # On google-adk 1.x, BaseAgent has no `run` method, so override here + # to nudge users toward `runner.run_async`. On google-adk 2.x, + # BaseAgent.run is a @final async generator that the workflow engine + # invokes internally; overriding it would break NodeRunner execution. + async def run(self, **kwargs): + raise NotImplementedError( + "Run method in VeADK agent is deprecated since version 0.5.6. Please use runner.run_async instead. Ref: https://agentkit.gitbook.io/docs/runner/overview" + ) diff --git a/veadk/evaluation/base_evaluator.py b/veadk/evaluation/base_evaluator.py index d0c8e28f..5d872ffd 100644 --- a/veadk/evaluation/base_evaluator.py +++ b/veadk/evaluation/base_evaluator.py @@ -25,6 +25,7 @@ from google.genai import types from pydantic import BaseModel +from veadk.utils.adk_compat import get_event_function_calls from veadk.utils.misc import formatted_timestamp @@ -556,8 +557,8 @@ async def generate_actual_outputs(self): and event.content.parts ): final_response = event.content - elif event.get_function_calls(): - for call in event.get_function_calls(): + else: + for call in get_event_function_calls(event): tool_uses.append(call) tok = time.time() _latency = str((tok - tik) * 1000) diff --git a/veadk/evaluation/eval_set_recorder.py b/veadk/evaluation/eval_set_recorder.py index 9efbe3ab..7d96a65d 100644 --- a/veadk/evaluation/eval_set_recorder.py +++ b/veadk/evaluation/eval_set_recorder.py @@ -15,7 +15,12 @@ import time from pathlib import Path import os -from google.adk.cli.utils import evals +# Note: ``from google.adk.cli.utils import evals`` is imported lazily inside +# ``add_session_to_eval_set`` below. On google-adk 2.0, that submodule pulls +# in ``gcs_eval_set_results_manager`` at top level, which requires +# ``google-cloud-storage`` to be installed. We don't want loading +# ``veadk.evaluation`` (or, transitively, ``veadk.Runner``) to fail for users +# who never touch the eval recorder, so defer the import to the call site. from google.adk.evaluation.eval_case import EvalCase, SessionInput from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager from google.adk.sessions import BaseSessionService @@ -89,6 +94,8 @@ async def add_session_to_eval_set( assert session, "Session not found." # Convert the session data to eval invocations + from google.adk.cli.utils import evals # lazy: see top-of-file note + invocations = evals.convert_session_to_eval_invocations(session) # Populate the session with initial session state. diff --git a/veadk/integrations/ve_identity/auth_processor.py b/veadk/integrations/ve_identity/auth_processor.py index 5f981b68..2dea7cec 100644 --- a/veadk/integrations/ve_identity/auth_processor.py +++ b/veadk/integrations/ve_identity/auth_processor.py @@ -294,8 +294,8 @@ async def event_generator(): new_message=message, run_config=RunConfig(streaming_mode=stream_mode), ): - if event.get_function_calls(): - for function_call in event.get_function_calls(): + if get_event_function_calls(event): + for function_call in get_event_function_calls(event): logger.debug(f"Function call: {function_call}") elif event.content is not None: yield event.content.parts[0].text diff --git a/veadk/memory/short_term_memory_backends/mysql_backend.py b/veadk/memory/short_term_memory_backends/mysql_backend.py index 6f2cc8b5..e3e92d91 100644 --- a/veadk/memory/short_term_memory_backends/mysql_backend.py +++ b/veadk/memory/short_term_memory_backends/mysql_backend.py @@ -15,12 +15,10 @@ from functools import cached_property from typing import Any -from google.adk import version as adk_version from google.adk.sessions import ( BaseSessionService, DatabaseSessionService, ) -from packaging.version import parse as parse_version from pydantic import Field from typing_extensions import override from urllib.parse import quote_plus @@ -30,6 +28,7 @@ from veadk.memory.short_term_memory_backends.base_backend import ( BaseShortTermMemoryBackend, ) +from veadk.utils.adk_compat import should_use_async_db_drivers class MysqlSTMBackend(BaseShortTermMemoryBackend): @@ -39,10 +38,10 @@ class MysqlSTMBackend(BaseShortTermMemoryBackend): def model_post_init(self, context: Any) -> None: encoded_username = quote_plus(self.mysql_config.user) encoded_password = quote_plus(self.mysql_config.password) - if parse_version(adk_version.__version__) < parse_version("1.19.0"): - self._db_url = f"mysql+pymysql://{encoded_username}:{encoded_password}@{self.mysql_config.host}/{self.mysql_config.database}" - else: + if should_use_async_db_drivers(): self._db_url = f"mysql+aiomysql://{encoded_username}:{encoded_password}@{self.mysql_config.host}/{self.mysql_config.database}" + else: + self._db_url = f"mysql+pymysql://{encoded_username}:{encoded_password}@{self.mysql_config.host}/{self.mysql_config.database}" @cached_property @override diff --git a/veadk/memory/short_term_memory_backends/postgresql_backend.py b/veadk/memory/short_term_memory_backends/postgresql_backend.py index 2e37daa9..4c5e5e60 100644 --- a/veadk/memory/short_term_memory_backends/postgresql_backend.py +++ b/veadk/memory/short_term_memory_backends/postgresql_backend.py @@ -16,12 +16,10 @@ from typing import Any from urllib.parse import quote_plus -from google.adk import version as adk_version from google.adk.sessions import ( BaseSessionService, DatabaseSessionService, ) -from packaging.version import parse as parse_version from pydantic import Field from typing_extensions import override @@ -30,6 +28,7 @@ from veadk.memory.short_term_memory_backends.base_backend import ( BaseShortTermMemoryBackend, ) +from veadk.utils.adk_compat import should_use_async_db_drivers class PostgreSqlSTMBackend(BaseShortTermMemoryBackend): @@ -39,10 +38,10 @@ class PostgreSqlSTMBackend(BaseShortTermMemoryBackend): def model_post_init(self, context: Any) -> None: encoded_username = quote_plus(self.postgresql_config.user) encoded_password = quote_plus(self.postgresql_config.password) - if parse_version(adk_version.__version__) < parse_version("1.19.0"): - self._db_url = f"postgresql://{encoded_username}:{encoded_password}@{self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}" - else: + if should_use_async_db_drivers(): self._db_url = f"postgresql+asyncpg://{encoded_username}:{encoded_password}@{self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}" + else: + self._db_url = f"postgresql://{encoded_username}:{encoded_password}@{self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}" @cached_property @override diff --git a/veadk/memory/short_term_memory_backends/sqlite_backend.py b/veadk/memory/short_term_memory_backends/sqlite_backend.py index 1a077690..34e13bd7 100644 --- a/veadk/memory/short_term_memory_backends/sqlite_backend.py +++ b/veadk/memory/short_term_memory_backends/sqlite_backend.py @@ -17,17 +17,16 @@ from functools import cached_property from typing import Any -from google.adk import version as adk_version from google.adk.sessions import ( BaseSessionService, DatabaseSessionService, ) -from packaging.version import parse as parse_version from typing_extensions import override from veadk.memory.short_term_memory_backends.base_backend import ( BaseShortTermMemoryBackend, ) +from veadk.utils.adk_compat import should_use_async_db_drivers class SQLiteSTMBackend(BaseShortTermMemoryBackend): @@ -41,10 +40,10 @@ def model_post_init(self, context: Any) -> None: conn = sqlite3.connect(self.local_path) conn.close() - if parse_version(adk_version.__version__) < parse_version("1.19.0"): - self._db_url = f"sqlite:///{self.local_path}" - else: + if should_use_async_db_drivers(): self._db_url = f"sqlite+aiosqlite:///{self.local_path}" + else: + self._db_url = f"sqlite:///{self.local_path}" @cached_property @override diff --git a/veadk/models/ark_llm.py b/veadk/models/ark_llm.py index d9f6fd08..07cba482 100644 --- a/veadk/models/ark_llm.py +++ b/veadk/models/ark_llm.py @@ -60,6 +60,10 @@ from veadk.config import settings from veadk.consts import DEFAULT_VIDEO_MODEL_API_BASE +from veadk.utils.adk_compat import ( + get_previous_interaction_id, + llm_request_has_field, +) from veadk.utils.logger import get_logger logger = get_logger(__name__) @@ -580,8 +584,8 @@ def record_logs(raw_response: ArkTypeResponse): f"Status: `{raw_response.status}`. " f"{error_message}" ) - except Exception as e: - logger.error(f"Failed to record ark logs: {e}") + except Exception: + logger.exception("Failed to record Ark response logs") def event_to_generate_content_response( @@ -703,7 +707,7 @@ class ArkLlm(Gemini): def __init__(self, **kwargs): # adk version check - if "previous_interaction_id" not in LlmRequest.model_fields: + if not llm_request_has_field("previous_interaction_id"): raise ImportError( "If using the ResponsesAPI, " "please upgrade the version of google-adk to `1.21.0` or higher with the command: " @@ -746,8 +750,8 @@ async def generate_content_async( # ------------------------------------------------------ # # get previous_response_id previous_response_id = None - if self.enable_responses_cache and llm_request.previous_interaction_id: - previous_response_id = llm_request.previous_interaction_id + if self.enable_responses_cache: + previous_response_id = get_previous_interaction_id(llm_request) responses_args = { "model": self.model, "instructions": instructions, @@ -786,15 +790,17 @@ async def generate_content_async( responses_args.copy(), stream=stream ): yield llm_response - except Exception as retry_e: - logger.error(f"Retry failed in generate_content_async: {retry_e}") - raise retry_e + except Exception: + logger.exception( + "Retry without previous_response_id failed in Ark Responses API" + ) + raise else: - logger.error(f"Error in generate_content_async: {e}") - raise e - except Exception as e: - logger.error(f"Error in generate_content_async: {e}") - raise e + logger.exception("Ark Responses API request failed") + raise + except Exception: + logger.exception("Unexpected error in Ark Responses API generation") + raise async def generate_content_via_responses( self, responses_args: dict, stream: bool = False diff --git a/veadk/runner.py b/veadk/runner.py index 6689928f..c9518bfc 100644 --- a/veadk/runner.py +++ b/veadk/runner.py @@ -34,6 +34,10 @@ from veadk.memory.short_term_memory import ShortTermMemory from veadk.processors.base_run_processor import BaseRunProcessor from veadk.types import MediaMessage +from veadk.utils.adk_compat import ( + get_event_function_calls, + get_event_function_responses, +) from veadk.utils.logger import get_logger from veadk.utils.misc import formatted_timestamp, read_file_to_bytes @@ -143,17 +147,19 @@ async def wrapper( yield event event_metadata = f"| agent_name: {event.author} , user_id: {user_id} , session_id: {session_id} , invocation_id: {event.invocation_id}" - if event.get_function_calls(): - for function_call in event.get_function_calls(): + function_calls = get_event_function_calls(event) + function_responses = get_event_function_responses(event) + if function_calls: + for function_call in function_calls: logger.debug(f"Function call: {function_call} {event_metadata}") - elif event.get_function_responses(): - for function_response in event.get_function_responses(): + elif function_responses: + for function_response in function_responses: logger.debug( f"Function response: {function_response} {event_metadata}" ) elif event.content is not None and event.content.parts: for part in event.content.parts: - if len(part.text.strip()) > 0: + if part.text and len(part.text.strip()) > 0: final_output = part.text if part.thought: logger.debug( @@ -286,7 +292,10 @@ async def _upload_image_to_tos( ) part.inline_data.display_name = tos_url except Exception as e: - logger.error(f"Upload to TOS failed: {e}") + logger.exception( + "Upload inline data to TOS failed" + f" | app_name={app_name}, user_id={user_id}, session_id={session_id}, error={e}" + ) class Runner(ADKRunner): diff --git a/veadk/tracing/telemetry/attributes/extractors/tool_attributes_extractors.py b/veadk/tracing/telemetry/attributes/extractors/tool_attributes_extractors.py index 5317998f..2f6319ad 100644 --- a/veadk/tracing/telemetry/attributes/extractors/tool_attributes_extractors.py +++ b/veadk/tracing/telemetry/attributes/extractors/tool_attributes_extractors.py @@ -16,6 +16,7 @@ ExtractorResponse, ToolAttributesParams, ) +from veadk.utils.adk_compat import get_event_function_responses from veadk.utils.misc import safe_json_serialize @@ -126,9 +127,20 @@ def tool_gen_ai_tool_output(params: ToolAttributesParams) -> ExtractorResponse: Returns: ExtractorResponse: Response containing JSON serialized tool output data """ - function_response = params.function_response_event.get_function_responses()[ - 0 - ].model_dump() + function_responses = get_event_function_responses(params.function_response_event) + if not function_responses: + return ExtractorResponse(content="") + function_response_obj = function_responses[0] + if hasattr(function_response_obj, "model_dump"): + function_response = function_response_obj.model_dump() + elif isinstance(function_response_obj, dict): + function_response = function_response_obj + else: + function_response = { + "id": getattr(function_response_obj, "id", ""), + "name": getattr(function_response_obj, "name", ""), + "response": getattr(function_response_obj, "response", None), + } tool_output = { "id": function_response["id"], "name": function_response["name"], diff --git a/veadk/utils/adk_compat.py b/veadk/utils/adk_compat.py new file mode 100644 index 00000000..91be73cd --- /dev/null +++ b/veadk/utils/adk_compat.py @@ -0,0 +1,91 @@ +"""Compatibility helpers for Google ADK feature/version checks. + +This module centralizes ADK capability detection to avoid scattering +hard-coded version checks across the codebase. +""" + +from __future__ import annotations + +from functools import lru_cache +from typing import Any, Optional + +from packaging.version import Version, parse as parse_version + + +@lru_cache(maxsize=1) +def get_adk_version() -> Version: + """Return installed Google ADK version (best effort).""" + try: + from google.adk import version as adk_version + + return parse_version(adk_version.__version__) + except Exception: + return parse_version("0.0.0") + + +def is_adk_gte(version: str) -> bool: + """Whether installed ADK version is greater than or equal to target.""" + return get_adk_version() >= parse_version(version) + + +def should_use_async_db_drivers() -> bool: + """Whether ADK expects async SQLAlchemy DSN schemes for DB sessions.""" + return is_adk_gte("1.19.0") + + +@lru_cache(maxsize=32) +def llm_request_has_field(field_name: str) -> bool: + """Check whether ``google.adk.models.LlmRequest`` contains a model field.""" + try: + from google.adk.models import LlmRequest + + return field_name in getattr(LlmRequest, "model_fields", {}) + except Exception: + return False + + +def get_previous_interaction_id(llm_request) -> Optional[str]: + """Safely read ``previous_interaction_id`` from LlmRequest across ADK versions.""" + return getattr(llm_request, "previous_interaction_id", None) + + +def get_event_function_calls(event: Any) -> list[Any]: + """Extract function calls from an ADK event across versions.""" + getter = getattr(event, "get_function_calls", None) + if callable(getter): + try: + calls = getter() + return list(calls or []) + except Exception: + # Fallback to generic part traversal for compatibility. + pass + + calls: list[Any] = [] + content = getattr(event, "content", None) + parts = getattr(content, "parts", None) if content is not None else None + for part in parts or []: + function_call = getattr(part, "function_call", None) + if function_call is not None: + calls.append(function_call) + return calls + + +def get_event_function_responses(event: Any) -> list[Any]: + """Extract function responses from an ADK event across versions.""" + getter = getattr(event, "get_function_responses", None) + if callable(getter): + try: + responses = getter() + return list(responses or []) + except Exception: + # Fallback to generic part traversal for compatibility. + pass + + responses: list[Any] = [] + content = getattr(event, "content", None) + parts = getattr(content, "parts", None) if content is not None else None + for part in parts or []: + function_response = getattr(part, "function_response", None) + if function_response is not None: + responses.append(function_response) + return responses diff --git a/veadk/utils/patches.py b/veadk/utils/patches.py index 86c420c1..0a827946 100644 --- a/veadk/utils/patches.py +++ b/veadk/utils/patches.py @@ -63,6 +63,25 @@ def patched_cancel_scope_exit(self, exc_type, exc_val, exc_tb): CancelScope.__exit__ = patched_cancel_scope_exit +def _iter_loaded_attrs(mod): + """Iterate ``(name, value)`` pairs for already-loaded module attrs. + + Walking ``dir(mod) + getattr(mod, name)`` would trip ``__getattr__`` hooks + that ADK 2.0 uses for lazy loading (e.g. ``google.adk.tools``), which in + turn drags in optional-dep submodules like ``discovery_engine_search_tool`` + that veadk does not need. Reading ``mod.__dict__`` avoids that side effect + — we only see attrs that have actually been imported into the module + namespace, which is exactly the set we want to patch. + """ + namespace = getattr(mod, "__dict__", None) + if not isinstance(namespace, dict): + return + # Snapshot to avoid "dict changed size during iteration" if a setattr + # below mutates the namespace mid-loop. + for name, value in tuple(namespace.items()): + yield name, value + + def patch_google_adk_telemetry() -> None: trace_functions = { "trace_tool_call": trace_tool_call, @@ -70,11 +89,10 @@ def patch_google_adk_telemetry() -> None: "trace_send_data": trace_send_data, } - for mod_name, mod in sys.modules.items(): + for mod_name, mod in tuple(sys.modules.items()): if mod_name.startswith("google.adk"): - for var_name in dir(mod): - var = getattr(mod, var_name, None) - if var_name in trace_functions.keys() and isinstance(var, Callable): + for var_name, var in _iter_loaded_attrs(mod): + if var_name in trace_functions and isinstance(var, Callable): setattr(mod, var_name, trace_functions[var_name]) logger.debug( f"Patch {mod_name} {var_name} with {trace_functions[var_name]}" @@ -86,8 +104,7 @@ def patch_tracer() -> None: for mod_name, mod in tuple(sys.modules.items()): if mod_name.startswith("google.adk"): - for var_name in dir(mod): - var = getattr(mod, var_name, None) + for var_name, var in _iter_loaded_attrs(mod): if var_name == "tracer" and isinstance(var, trace.Tracer): setattr( mod,