From 18c341da6602de6add7482f1ed1733420d8fd89c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filipe=20Caba=C3=A7o?= Date: Thu, 30 Apr 2026 18:10:33 +0100 Subject: [PATCH] feat: AI agent extension MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds first-class AI agent support to Realtime. Clients can chat with LLM-backed agents through the same WebSocket they already use for Broadcast, Presence, and Postgres Changes. The provider API key never leaves the server — clients only see streamed broadcasts. Supports OpenAI, Anthropic, Groq, OpenRouter, Mistral, Ollama, and any OpenAI-compatible endpoint. Conversations are persisted in realtime.messages and can be resumed across reconnects via session_id. Authorization uses the existing RLS pattern with a new ai_agent extension type. --- .github/workflows/tests.yml | 61 +++ compose.ollama.yml | 28 ++ compose.tests.yml | 6 +- compose.yml | 2 + config/config.exs | 7 + lib/extensions/ai_agent/adapter.ex | 31 ++ .../ai_agent/adapter/anthropic_messages.ex | 118 +++++ .../ai_agent/adapter/chat_completions.ex | 105 ++++ lib/extensions/ai_agent/adapter/sse_stream.ex | 77 +++ lib/extensions/ai_agent/db_settings.ex | 21 + lib/extensions/ai_agent/session.ex | 452 ++++++++++++++++++ lib/extensions/ai_agent/session_supervisor.ex | 24 + lib/extensions/ai_agent/supervisor.ex | 31 ++ lib/extensions/ai_agent/types/event.ex | 28 ++ .../ai_agent/types/tool_call_buffer.ex | 71 +++ lib/realtime/api.ex | 19 +- lib/realtime/api/extensions.ex | 136 ++++-- lib/realtime/api/message.ex | 2 +- lib/realtime/api/tenant.ex | 8 +- lib/realtime/dns.ex | 4 + lib/realtime/messages.ex | 19 +- lib/realtime/tenants.ex | 36 ++ lib/realtime/tenants/authorization.ex | 72 ++- .../tenants/authorization/policies.ex | 13 +- .../authorization/policies/ai_policies.ex | 12 + .../tenants/replication_connection.ex | 9 + lib/realtime/tenants/repo.ex | 30 +- lib/realtime_web/channels/realtime_channel.ex | 94 ++-- .../realtime_channel/ai_agent_handler.ex | 163 +++++++ .../realtime_channel/broadcast_handler.ex | 172 +++---- lib/realtime_web/tenant_broadcaster.ex | 2 +- mise.toml | 5 + priv/repo/dev_seeds.exs | 41 ++ .../20260430000001_add_name_to_extensions.exs | 21 + .../20260430000002_add_ai_to_tenants.exs | 11 + .../adapter/anthropic_messages_test.exs | 239 +++++++++ .../adapter/chat_completions_test.exs | 197 ++++++++ test/extensions/ai_agent/adapter_test.exs | 13 + test/extensions/ai_agent/ai_policies_test.exs | 25 + .../ai_agent/broadcast_handler_ai_test.exs | 147 ++++++ test/extensions/ai_agent/db_settings_test.exs | 38 ++ test/extensions/ai_agent/event_test.exs | 15 + .../ai_agent/session_persistence_test.exs | 253 ++++++++++ .../ai_agent/session_rate_limit_test.exs | 154 ++++++ .../ai_agent/session_supervisor_test.exs | 33 ++ test/extensions/ai_agent/session_test.exs | 183 +++++++ test/extensions/extensions_test.exs | 11 + test/integration/ai_agent/live_smoke_test.exs | 211 ++++++++ test/realtime/api/extensions_test.exs | 142 ++++++ test/realtime/messages_test.exs | 98 +++- .../tenants/authorization_remote_test.exs | 13 +- test/realtime/tenants/authorization_test.exs | 10 +- .../tenants/replication_connection_test.exs | 30 +- .../broadcast_handler_test.exs | 48 ++ .../channels/realtime_channel_test.exs | 113 ++++- test/support/generators.ex | 46 +- test/support/ollama.ex | 141 ++++++ test/test_helper.exs | 10 +- 58 files changed, 3834 insertions(+), 267 deletions(-) create mode 100644 compose.ollama.yml create mode 100644 lib/extensions/ai_agent/adapter.ex create mode 100644 lib/extensions/ai_agent/adapter/anthropic_messages.ex create mode 100644 lib/extensions/ai_agent/adapter/chat_completions.ex create mode 100644 lib/extensions/ai_agent/adapter/sse_stream.ex create mode 100644 lib/extensions/ai_agent/db_settings.ex create mode 100644 lib/extensions/ai_agent/session.ex create mode 100644 lib/extensions/ai_agent/session_supervisor.ex create mode 100644 lib/extensions/ai_agent/supervisor.ex create mode 100644 lib/extensions/ai_agent/types/event.ex create mode 100644 lib/extensions/ai_agent/types/tool_call_buffer.ex create mode 100644 lib/realtime/dns.ex create mode 100644 lib/realtime/tenants/authorization/policies/ai_policies.ex create mode 100644 lib/realtime_web/channels/realtime_channel/ai_agent_handler.ex create mode 100644 priv/repo/migrations/20260430000001_add_name_to_extensions.exs create mode 100644 priv/repo/migrations/20260430000002_add_ai_to_tenants.exs create mode 100644 test/extensions/ai_agent/adapter/anthropic_messages_test.exs create mode 100644 test/extensions/ai_agent/adapter/chat_completions_test.exs create mode 100644 test/extensions/ai_agent/adapter_test.exs create mode 100644 test/extensions/ai_agent/ai_policies_test.exs create mode 100644 test/extensions/ai_agent/broadcast_handler_ai_test.exs create mode 100644 test/extensions/ai_agent/db_settings_test.exs create mode 100644 test/extensions/ai_agent/event_test.exs create mode 100644 test/extensions/ai_agent/session_persistence_test.exs create mode 100644 test/extensions/ai_agent/session_rate_limit_test.exs create mode 100644 test/extensions/ai_agent/session_supervisor_test.exs create mode 100644 test/extensions/ai_agent/session_test.exs create mode 100644 test/integration/ai_agent/live_smoke_test.exs create mode 100644 test/support/ollama.ex diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 83d908560..9745b8940 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,6 +26,8 @@ concurrency: env: MIX_ENV: test POSTGRES_IMAGE: supabase/postgres:17.6.1.074 + OLLAMA_IMAGE: ollama/ollama + OLLAMA_MODEL: qwen2:0.5b jobs: tests: @@ -85,6 +87,65 @@ jobs: name: coverage-partition-${{ matrix.partition }} path: cover/lcov.info + live-llm-tests: + name: Tests (AI Agent / Live LLM) + runs-on: blacksmith-8vcpu-ubuntu-2404 + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Setup elixir + id: beam + uses: erlef/setup-beam@ee09b1e59bb240681c382eb1f0abc6a04af72764 # v1.23.0 + with: + otp-version: 27.x + elixir-version: 1.18.x + - name: Cache Mix + uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: | + deps + _build + priv/native + key: ${{ github.workflow }}-${{ runner.os }}-mix-${{ env.elixir }}-${{ env.otp }}-${{ hashFiles('**/mix.lock') }} + restore-keys: | + ${{ github.workflow }}-${{ runner.os }}-mix-${{ env.elixir }}-${{ env.otp }}- + - name: Cache Docker images + uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + id: docker-cache + with: + path: /tmp/docker-images + key: docker-images-live-llm-zstd-${{ env.POSTGRES_IMAGE }}-${{ env.OLLAMA_IMAGE }}-${{ env.OLLAMA_MODEL }} + - name: Load Docker images from cache + if: steps.docker-cache.outputs.cache-hit == 'true' + run: | + zstd -d --stdout /tmp/docker-images/postgres.tar.zst | docker image load + zstd -d --stdout /tmp/docker-images/ollama.tar.zst | docker image load + - name: Pull and save Docker images + if: steps.docker-cache.outputs.cache-hit != 'true' + run: | + docker pull ${{ env.POSTGRES_IMAGE }} & + PID1=$! + docker pull ${{ env.OLLAMA_IMAGE }} & + PID2=$! + wait $PID1 || exit $? + wait $PID2 || exit $? + mkdir -p /tmp/docker-images + docker image save ${{ env.POSTGRES_IMAGE }} | zstd -T0 -o /tmp/docker-images/postgres.tar.zst + docker image save ${{ env.OLLAMA_IMAGE }} | zstd -T0 -o /tmp/docker-images/ollama.tar.zst + - name: Install dependencies + run: mix deps.get + - name: Set up Postgres + run: docker compose -f compose.dbs.yml up -d --wait + - name: Set up Ollama + run: docker compose -f compose.ollama.yml up -d --wait + - name: Start epmd + run: epmd -daemon + - name: Run AI agent live tests + env: + OLLAMA_HOST: http://localhost:11434 + OLLAMA_MODEL: ${{ env.OLLAMA_MODEL }} + run: mix test --include live_llm test/integration/ai_agent/ + coverage: name: Merge Coverage needs: tests diff --git a/compose.ollama.yml b/compose.ollama.yml new file mode 100644 index 000000000..8b8a0c28a --- /dev/null +++ b/compose.ollama.yml @@ -0,0 +1,28 @@ +services: + ollama: + image: ollama/ollama + ports: + - "11434:11434" + volumes: + - ollama_data:/root/.ollama + healthcheck: + test: ["CMD-SHELL", "ollama list || exit 1"] + interval: 10s + timeout: 5s + retries: 10 + start_period: 10s + + ollama-model-init: + image: ollama/ollama + depends_on: + ollama: + condition: service_healthy + volumes: + - ollama_data:/root/.ollama + environment: + OLLAMA_HOST: http://ollama:11434 + entrypoint: ["/bin/sh", "-c", "ollama pull ${OLLAMA_MODEL:-qwen2:0.5b}"] + restart: "no" + +volumes: + ollama_data: diff --git a/compose.tests.yml b/compose.tests.yml index 16a01607f..cc49855b4 100644 --- a/compose.tests.yml +++ b/compose.tests.yml @@ -1,5 +1,4 @@ services: - # Supabase Realtime service test_db: image: supabase/postgres:17.6.1.074 container_name: test-realtime-db @@ -16,9 +15,11 @@ services: interval: 10s timeout: 5s retries: 5 + test_realtime: depends_on: - - test_db + test_db: + condition: service_healthy build: . container_name: test-realtime-server ports: @@ -58,7 +59,6 @@ services: retries: 5 start_period: 5s - # Deno test runner test-runner: image: denoland/deno:alpine-2.5.6 container_name: deno-test-runner diff --git a/compose.yml b/compose.yml index 3b2e65919..245de81ab 100644 --- a/compose.yml +++ b/compose.yml @@ -1,11 +1,13 @@ include: - compose.dbs.yml + - compose.ollama.yml services: realtime: depends_on: - db - tenant_db + - ollama-model-init build: . container_name: realtime-server environment: diff --git a/config/config.exs b/config/config.exs index 8550e0d72..29d5d36d6 100644 --- a/config/config.exs +++ b/config/config.exs @@ -29,6 +29,13 @@ config :realtime, :extensions, driver: Extensions.PostgresCdcRls, supervisor: Extensions.PostgresCdcRls.Supervisor, db_settings: Extensions.PostgresCdcRls.DbSettings + }, + ai_agent: %{ + type: :ai_agent, + key: "ai_agent", + driver: Extensions.AiAgent, + supervisor: Extensions.AiAgent.Supervisor, + db_settings: Extensions.AiAgent.DbSettings } config :esbuild, diff --git a/lib/extensions/ai_agent/adapter.ex b/lib/extensions/ai_agent/adapter.ex new file mode 100644 index 000000000..9c52427b3 --- /dev/null +++ b/lib/extensions/ai_agent/adapter.ex @@ -0,0 +1,31 @@ +defmodule Extensions.AiAgent.Adapter do + @moduledoc """ + Behaviour for AI provider adapters. + + An adapter receives resolved settings and a message history, makes a + streaming HTTP request to the provider, and sends `Extensions.AiAgent.Event` + structs to the caller process as `{:ai_event, event}` messages. + + The caller is expected to be a `Extensions.AiAgent.Session` GenServer that + runs the adapter in a `Task` so it can handle cancellation via `Task.shutdown`. + """ + + alias Extensions.AiAgent.Types.Event + + @callback stream(settings :: map(), messages :: list(map()), caller :: pid()) :: + :ok | {:error, term()} + + @spec emit(pid(), Event.t()) :: :ok + def emit(caller, %Event{} = event) do + send(caller, {:ai_event, event}) + :ok + end + + @spec maybe_put(map(), String.t(), term()) :: map() + def maybe_put(map, _key, nil), do: map + def maybe_put(map, key, value), do: Map.put(map, key, value) + + @spec maybe_prepend(list(), String.t(), term()) :: list() + def maybe_prepend(list, _item, nil), do: list + def maybe_prepend(list, item, value), do: [{item, value} | list] +end diff --git a/lib/extensions/ai_agent/adapter/anthropic_messages.ex b/lib/extensions/ai_agent/adapter/anthropic_messages.ex new file mode 100644 index 000000000..fdab076e5 --- /dev/null +++ b/lib/extensions/ai_agent/adapter/anthropic_messages.ex @@ -0,0 +1,118 @@ +defmodule Extensions.AiAgent.Adapter.AnthropicMessages do + @moduledoc """ + Adapter for the Anthropic `/v1/messages` SSE protocol. + """ + + @behaviour Extensions.AiAgent.Adapter + + alias Extensions.AiAgent.Adapter + alias Extensions.AiAgent.Adapter.SSEStream + alias Extensions.AiAgent.Types.Event + alias Extensions.AiAgent.Types.ToolCallBuffer + + @default_max_tokens 4096 + @default_anthropic_version "2023-06-01" + @default_anthropic_beta "interleaved-thinking-2025-05-14" + + @impl true + def stream(settings, messages, caller) do + url = settings["base_url"] <> "/v1/messages" + request = Finch.build(:post, url, headers(settings), Jason.encode!(build_body(settings, messages))) + SSEStream.run(request, &process_event/3, caller) + end + + defp process_event(%{"type" => "content_block_delta"} = message, buffer, caller) do + %{"index" => idx, "delta" => delta} = message + + case delta do + %{"type" => "text_delta", "text" => text} -> + Adapter.emit(caller, %Event{type: :text_delta, payload: %{delta: text}}) + buffer + + %{"type" => "thinking_delta", "thinking" => text} -> + Adapter.emit(caller, %Event{type: :thinking_delta, payload: %{delta: text}}) + buffer + + %{"type" => "input_json_delta", "partial_json" => partial} -> + ToolCallBuffer.append_args(buffer, idx, partial, caller) + + _ -> + buffer + end + end + + defp process_event( + %{"type" => "content_block_start", "content_block" => %{"type" => "tool_use"}} = message, + buffer, + _caller + ) do + %{"index" => idx, "content_block" => block} = message + ToolCallBuffer.start(buffer, idx, block["id"], block["name"]) + end + + defp process_event(%{"type" => "content_block_stop"} = message, buffer, caller) do + %{"index" => idx} = message + ToolCallBuffer.finish(buffer, idx, caller) + end + + defp process_event(%{"type" => "message_delta"} = message, buffer, caller) do + %{"delta" => delta, "usage" => usage} = message + Adapter.emit(caller, %Event{type: :usage, payload: %{output_tokens: usage["output_tokens"]}}) + + if stop_reason = delta["stop_reason"] do + Adapter.emit(caller, %Event{type: :done, payload: %{stop_reason: stop_reason}}) + end + + buffer + end + + defp process_event(%{"type" => "message_start"} = message, buffer, caller) do + %{"message" => %{"usage" => usage}} = message + Adapter.emit(caller, %Event{type: :usage, payload: %{input_tokens: usage["input_tokens"]}}) + buffer + end + + defp process_event(%{"type" => "error"} = message, buffer, caller) do + %{"error" => error} = message + Adapter.emit(caller, %Event{type: :error, payload: %{reason: error["message"]}}) + buffer + end + + defp process_event(event, buffer, _caller) do + require Logger + Logger.debug("AnthropicUnknownEvent type=#{inspect(event["type"])}") + buffer + end + + defp build_body(settings, messages) do + %{ + "model" => settings["model"], + "messages" => Enum.reject(messages, &(&1["role"] == "system")), + "max_tokens" => settings["max_tokens"] || @default_max_tokens, + "stream" => true + } + |> Adapter.maybe_put("system", settings["system_prompt"]) + |> Adapter.maybe_put("tools", anthropic_tools(settings["tools"])) + end + + defp anthropic_tools(nil), do: nil + + defp anthropic_tools(tools) when is_list(tools) do + Enum.map(tools, fn tool -> + %{ + "name" => tool["function"]["name"], + "description" => tool["function"]["description"], + "input_schema" => tool["function"]["parameters"] + } + end) + end + + defp headers(settings) do + [ + {"content-type", "application/json"}, + {"x-api-key", settings["api_key"]}, + {"anthropic-version", settings["anthropic_version"] || @default_anthropic_version}, + {"anthropic-beta", settings["anthropic_beta"] || @default_anthropic_beta} + ] + end +end diff --git a/lib/extensions/ai_agent/adapter/chat_completions.ex b/lib/extensions/ai_agent/adapter/chat_completions.ex new file mode 100644 index 000000000..e8109ace2 --- /dev/null +++ b/lib/extensions/ai_agent/adapter/chat_completions.ex @@ -0,0 +1,105 @@ +defmodule Extensions.AiAgent.Adapter.ChatCompletions do + @moduledoc """ + Adapter for the OpenAI `/v1/chat/completions` SSE protocol. + + Compatible with: OpenAI, OpenRouter, Groq, Together, Fireworks, DeepSeek, + Mistral, Cerebras, Perplexity, Ollama, vLLM, LM Studio, and custom endpoints. + """ + + @behaviour Extensions.AiAgent.Adapter + + alias Extensions.AiAgent.Adapter + alias Extensions.AiAgent.Adapter.SSEStream + alias Extensions.AiAgent.Types.Event + alias Extensions.AiAgent.Types.ToolCallBuffer + + @impl true + def stream(settings, messages, caller) do + url = settings["base_url"] <> "/chat/completions" + request = Finch.build(:post, url, headers(settings), Jason.encode!(build_body(settings, messages))) + SSEStream.run(request, &process_delta/3, caller) + end + + defp process_delta(%{"choices" => [_ | _]} = message, buffer, caller) do + %{"choices" => [choice | _]} = message + %{"delta" => delta, "finish_reason" => finish_reason} = choice + buffer = emit_delta(delta, buffer, caller) + buffer = accumulate_tool_calls(delta, buffer, caller) + emit_done(finish_reason, buffer, caller) + end + + defp process_delta(%{"usage" => usage}, buffer, caller) when not is_nil(usage) do + Adapter.emit(caller, %Event{ + type: :usage, + payload: %{input_tokens: usage["prompt_tokens"], output_tokens: usage["completion_tokens"]} + }) + + buffer + end + + defp process_delta(data, buffer, _caller) do + require Logger + Logger.debug("ChatCompletionsUnknownDelta keys=#{inspect(Map.keys(data))}") + buffer + end + + defp emit_delta(%{"content" => content}, buffer, caller) when is_binary(content) and content != "" do + Adapter.emit(caller, %Event{type: :text_delta, payload: %{delta: content}}) + buffer + end + + defp emit_delta(%{"reasoning_content" => content}, buffer, caller) when is_binary(content) and content != "" do + Adapter.emit(caller, %Event{type: :thinking_delta, payload: %{delta: content}}) + buffer + end + + defp emit_delta(%{"reasoning" => content}, buffer, caller) when is_binary(content) and content != "" do + Adapter.emit(caller, %Event{type: :thinking_delta, payload: %{delta: content}}) + buffer + end + + defp emit_delta(_delta, buffer, _caller), do: buffer + + defp accumulate_tool_calls(%{"tool_calls" => chunks}, buffer, caller) when is_list(chunks) do + Enum.reduce(chunks, buffer, fn chunk, acc -> + idx = chunk["index"] + acc = ToolCallBuffer.start(acc, idx, chunk["id"], get_in(chunk, ["function", "name"])) + ToolCallBuffer.append_args(acc, idx, get_in(chunk, ["function", "arguments"]) || "", caller) + end) + end + + defp accumulate_tool_calls(_delta, buffer, _caller), do: buffer + + defp emit_done("tool_calls", buffer, caller) do + buffer = ToolCallBuffer.finish_all(buffer, caller) + Adapter.emit(caller, %Event{type: :done, payload: %{stop_reason: "tool_calls"}}) + buffer + end + + defp emit_done(reason, buffer, caller) when is_binary(reason) do + Adapter.emit(caller, %Event{type: :done, payload: %{stop_reason: reason}}) + buffer + end + + defp emit_done(_reason, buffer, _caller), do: buffer + + defp build_body(settings, messages) do + %{ + "model" => settings["model"], + "messages" => messages, + "stream" => true, + "stream_options" => %{"include_usage" => true} + } + |> Adapter.maybe_put("tools", settings["tools"]) + |> Adapter.maybe_put("temperature", settings["temperature"]) + end + + defp headers(settings) do + [ + {"content-type", "application/json"}, + {"authorization", "Bearer #{settings["api_key"]}"} + ] + |> Adapter.maybe_prepend("HTTP-Referer", settings["http_referer"]) + |> Adapter.maybe_prepend("X-Title", settings["x_title"]) + end +end diff --git a/lib/extensions/ai_agent/adapter/sse_stream.ex b/lib/extensions/ai_agent/adapter/sse_stream.ex new file mode 100644 index 000000000..5fe9f4b1e --- /dev/null +++ b/lib/extensions/ai_agent/adapter/sse_stream.ex @@ -0,0 +1,77 @@ +defmodule Extensions.AiAgent.Adapter.SSEStream do + @moduledoc """ + Shared SSE streaming scaffold for AI provider adapters. + + Handles the Finch request, HTTP status/header chunks, SSE line parsing, and + JSON decoding. Each adapter supplies a `process_event/3` function that + translates provider-specific event shapes into `Extensions.AiAgent.Event` + structs and updates the `ToolCallBuffer` accumulator. + """ + + alias Extensions.AiAgent.Adapter + alias Extensions.AiAgent.Types.Event + alias Extensions.AiAgent.Types.ToolCallBuffer + + @receive_timeout :timer.minutes(5) + + @type process_event_fn :: (map(), ToolCallBuffer.t(), pid() -> ToolCallBuffer.t()) + + @spec run(Finch.Request.t(), process_event_fn(), pid()) :: :ok | {:error, term()} + def run(request, process_event_fn, caller) do + acc = {"", ToolCallBuffer.new()} + handler = &handle_chunk(&1, &2, caller, process_event_fn) + + case Finch.stream(request, AiAgent.Finch, acc, handler, receive_timeout: @receive_timeout) do + {:ok, _} -> :ok + {:error, reason, _acc} -> {:error, reason} + end + end + + defp handle_chunk({:status, status}, acc, caller, _process_event_fn) when status >= 400 do + Adapter.emit(caller, %Event{type: :error, payload: %{reason: "HTTP #{status}"}}) + acc + end + + defp handle_chunk({:status, _}, acc, _, _), do: acc + defp handle_chunk({:headers, _}, acc, _, _), do: acc + + defp handle_chunk({:data, chunk}, {buffer, tool_calls}, caller, process_event_fn) do + {lines, remaining} = parse_sse_lines(buffer <> chunk) + + tool_calls = + Enum.reduce(lines, tool_calls, fn line, acc -> + case Jason.decode(line) do + {:ok, data} -> + process_event_fn.(data, acc, caller) + + {:error, reason} -> + require Logger + Logger.debug("SseJsonDecodeError reason=#{inspect(reason)} line=#{inspect(line)}") + acc + end + end) + + {remaining, tool_calls} + end + + defp parse_sse_lines(buffer) do + case String.split(buffer, "\n\n") do + [incomplete] -> + {[], incomplete} + + parts -> + {complete, [incomplete]} = Enum.split(parts, -1) + + lines = + complete + |> Enum.flat_map(&String.split(&1, "\n")) + |> Enum.flat_map(fn + "data: [DONE]" -> [] + "data: " <> data -> [data] + _ -> [] + end) + + {lines, incomplete} + end + end +end diff --git a/lib/extensions/ai_agent/db_settings.ex b/lib/extensions/ai_agent/db_settings.ex new file mode 100644 index 000000000..107451cb0 --- /dev/null +++ b/lib/extensions/ai_agent/db_settings.ex @@ -0,0 +1,21 @@ +defmodule Extensions.AiAgent.DbSettings do + @moduledoc """ + Schema callbacks for the AI Agent extension. + """ + + def default do + %{ + "max_concurrent_sessions" => 10 + } + end + + # Each tuple: {field_name, validator_fn, encrypted?} + def required do + [ + {"protocol", &is_binary/1, false}, + {"base_url", &is_binary/1, false}, + {"model", &is_binary/1, false}, + {"api_key", &is_binary/1, true} + ] + end +end diff --git a/lib/extensions/ai_agent/session.ex b/lib/extensions/ai_agent/session.ex new file mode 100644 index 000000000..98d714f28 --- /dev/null +++ b/lib/extensions/ai_agent/session.ex @@ -0,0 +1,452 @@ +defmodule Extensions.AiAgent.Session do + @moduledoc """ + GenServer that manages one AI conversation session per channel join. + + Lifecycle: + - Started by `Extensions.AiAgent.SessionSupervisor` when a channel joins + with `config.ai.enabled = true`. + - Receives `agent_input` and `agent_cancel` from the channel process. + - Runs the adapter in a `Task.Supervisor.async_nolink` task so an adapter + crash broadcasts an error event rather than crashing the session. + - Broadcasts `Extensions.AiAgent.Event` structs to the channel's PubSub + topic as `:ai_events`. + - Persists conversation turns to `realtime.messages` (extension: ai_agent) + in the tenant's own database. Pass `session_id` in `config.ai` to + continue a prior session; omit to start fresh. + - Terminates when the channel process terminates. + """ + + use GenServer + use Realtime.Logs + + import Ecto.Query, only: [from: 2] + + alias Extensions.AiAgent.Adapter.AnthropicMessages + alias Extensions.AiAgent.Adapter.ChatCompletions + alias Extensions.AiAgent.Types.Event + alias Realtime.Api.Message + alias Realtime.Crypto + alias Realtime.GenCounter + alias Realtime.RateCounter + alias Realtime.Tenants + alias Realtime.Tenants.Connect + alias Realtime.Tenants.Repo + alias RealtimeWeb.RealtimeChannel.MessageDispatcher + alias RealtimeWeb.TenantBroadcaster + + @enforce_keys [ + :tenant_id, + :tenant_topic, + :session_id, + :channel_pid, + :settings, + :adapter, + :messages, + :channel_ref, + :events_rate_counter + ] + + defstruct [ + :tenant_id, + :tenant_topic, + :session_id, + :channel_pid, + :settings, + :adapter, + :messages, + :stream_task, + :channel_ref, + :events_rate_counter, + assistant_buffer: [], + token_usage: 0, + max_ai_tokens_per_minute: 0 + ] + + @type t :: %__MODULE__{ + tenant_id: String.t(), + tenant_topic: String.t(), + session_id: String.t(), + channel_pid: pid(), + settings: map(), + adapter: module(), + messages: list(map()), + stream_task: Task.t() | nil, + channel_ref: reference(), + events_rate_counter: RateCounter.Args.t(), + assistant_buffer: iodata(), + token_usage: non_neg_integer(), + max_ai_tokens_per_minute: non_neg_integer() + } + + @task_supervisor Extensions.AiAgent.TaskSupervisor + @shutdown_grace_ms 500 + + @spec start_link(keyword()) :: GenServer.on_start() + def start_link(opts) do + GenServer.start_link(__MODULE__, opts) + end + + @spec handle_input(pid(), map()) :: :ok + def handle_input(pid, input), do: GenServer.cast(pid, {:input, input}) + + @spec cancel(pid()) :: :ok + def cancel(pid), do: GenServer.cast(pid, :cancel) + + @max_heap_words 200_000 + + @impl true + def init(opts) do + Process.flag(:max_heap_size, @max_heap_words) + + tenant_id = Keyword.fetch!(opts, :tenant_id) + tenant_topic = Keyword.fetch!(opts, :tenant_topic) + raw_settings = Keyword.fetch!(opts, :settings) + channel_pid = Keyword.fetch!(opts, :channel_pid) + client_session_id = Keyword.get(opts, :session_id) + max_ai_events_per_second = Keyword.get(opts, :max_ai_events_per_second, 100) + max_ai_tokens_per_minute = Keyword.get(opts, :max_ai_tokens_per_minute, 60_000) + system_prompt = raw_settings["system_prompt"] + + settings = + Map.update(raw_settings, "api_key", nil, fn + nil -> nil + key -> Crypto.decrypt!(key) + end) + + with {:ok, adapter} <- resolve_adapter(settings) do + session_id = client_session_id || UUID.uuid4() + ref = Process.monitor(channel_pid) + + events_rate_counter = Tenants.ai_events_per_second_rate(tenant_id, max_ai_events_per_second) + RateCounter.new(events_rate_counter) + + Process.send_after(self(), :reset_token_window, :timer.minutes(1)) + + state = %__MODULE__{ + tenant_id: tenant_id, + tenant_topic: tenant_topic, + session_id: session_id, + channel_pid: channel_pid, + settings: settings, + adapter: adapter, + messages: system_messages(system_prompt), + channel_ref: ref, + events_rate_counter: events_rate_counter, + max_ai_tokens_per_minute: max_ai_tokens_per_minute + } + + init_result(state, client_session_id) + else + {:error, reason} -> {:stop, reason} + end + end + + defp system_messages(prompt) when is_binary(prompt) and prompt != "", do: [%{"role" => "system", "content" => prompt}] + defp system_messages(_), do: [] + + defp init_result(state, nil), do: {:ok, state} + defp init_result(state, _session_id), do: {:ok, state, {:continue, :load_history}} + + @impl true + def handle_continue(:load_history, state) do + prior = load_history(state.tenant_id, state.tenant_topic, state.session_id) + # Messages are stored newest-first and reversed once in start_stream. + # Prepend reversed history so Enum.reverse produces [system_msg | hist_asc...]. + state = %{state | messages: Enum.reverse(prior) ++ state.messages} + {:noreply, state} + end + + @max_input_bytes 64_000 + + @impl true + def handle_cast({:input, _} = msg, state) do + case check_rate_limits(state) do + :ok -> + dispatch_input(msg, state) + + {:error, reason} -> + broadcast_event(state, %Event{type: :error, payload: %{reason: reason}}) + {:noreply, state} + end + end + + def handle_cast(:cancel, state) do + {:noreply, cancel_stream(state)} + end + + def handle_cast(:emit_session_started, state) do + notify_session_started(state) + {:noreply, state} + end + + def handle_cast(msg, state) do + log_warning("UnhandledCast", inspect(msg, limit: 3, printable_limit: 80)) + {:noreply, state} + end + + @impl true + def handle_info({:ai_event, %Event{type: :text_delta} = event}, state) do + %Event{payload: %{delta: delta}} = event + broadcast_event(state, event) + {:noreply, %{state | assistant_buffer: [delta | state.assistant_buffer]}} + end + + def handle_info({:ai_event, %Event{type: :done} = event}, state) do + maybe_persist_assistant_turn(state) + broadcast_event(state, event) + {:noreply, %{state | stream_task: nil, assistant_buffer: []}} + end + + def handle_info({:ai_event, %Event{type: :error} = event}, state) do + broadcast_event(state, event) + {:noreply, %{state | stream_task: nil, assistant_buffer: []}} + end + + def handle_info({:ai_event, %Event{type: :tool_call_done} = event}, state) do + tool_call = event.payload + + assistant_tool_call = %{ + "role" => "assistant", + "tool_calls" => [ + %{ + "id" => tool_call.tool_call_id, + "type" => "function", + "function" => %{"name" => tool_call.name, "arguments" => tool_call.arguments} + } + ] + } + + persist_messages(state, [assistant_tool_call]) + broadcast_event(state, event) + {:noreply, %{state | messages: [assistant_tool_call | state.messages]}} + end + + def handle_info({:ai_event, %Event{type: :usage} = event}, state) do + %Event{payload: payload} = event + tokens = Map.get(payload, :input_tokens, 0) + Map.get(payload, :output_tokens, 0) + broadcast_event(state, event) + {:noreply, %{state | token_usage: state.token_usage + tokens}} + end + + def handle_info(:reset_token_window, state) do + Process.send_after(self(), :reset_token_window, :timer.minutes(1)) + {:noreply, %{state | token_usage: 0}} + end + + def handle_info({:ai_event, %Event{} = event}, state) do + broadcast_event(state, event) + {:noreply, state} + end + + def handle_info({ref, result}, %{stream_task: %Task{ref: ref}} = state) do + Process.demonitor(ref, [:flush]) + + case result do + {:error, reason} -> + log_error("AiStreamError", reason) + broadcast_event(state, %Event{type: :error, payload: %{reason: "stream_failed"}}) + + _ -> + :ok + end + + {:noreply, %{state | stream_task: nil, assistant_buffer: []}} + end + + def handle_info({ref, _result}, %{stream_task: nil} = state) when is_reference(ref) do + Process.demonitor(ref, [:flush]) + {:noreply, state} + end + + def handle_info({:DOWN, ref, :process, _pid, reason}, %{stream_task: %Task{ref: ref}} = state) do + log_error("AiStreamCrash", reason) + broadcast_event(state, %Event{type: :error, payload: %{reason: "stream_failed"}}) + {:noreply, %{state | stream_task: nil, assistant_buffer: []}} + end + + def handle_info({:DOWN, ref, :process, _pid, _reason}, %{channel_ref: ref} = state) do + {:stop, :normal, state} + end + + def handle_info(msg, state) do + log_warning("UnhandledInfo", inspect(msg, limit: 3, printable_limit: 80)) + {:noreply, state} + end + + @impl true + def terminate(_reason, state) do + cancel_stream(state) + :ok + end + + defp start_stream(state) do + state = cancel_stream(state) + caller = self() + adapter = state.adapter + settings = state.settings + messages = Enum.reverse(state.messages) + + task = + Task.Supervisor.async_nolink(@task_supervisor, fn -> + adapter.stream(settings, messages, caller) + end) + + %{state | stream_task: task} + end + + defp cancel_stream(%{stream_task: nil} = state), do: state + + defp cancel_stream(%{stream_task: task} = state) do + Task.shutdown(task, @shutdown_grace_ms) + %{state | stream_task: nil, assistant_buffer: []} + end + + defp resolve_adapter(%{"protocol" => "openai_compatible", "base_url" => url}) when is_binary(url) and url != "", + do: {:ok, ChatCompletions} + + defp resolve_adapter(%{"protocol" => "anthropic", "base_url" => url}) when is_binary(url) and url != "", + do: {:ok, AnthropicMessages} + + defp resolve_adapter(%{"protocol" => protocol, "base_url" => url}) when is_binary(url) and url != "", + do: {:error, "unknown protocol: #{protocol}"} + + defp resolve_adapter(%{"protocol" => _}), do: {:error, "missing base_url in settings"} + defp resolve_adapter(_), do: {:error, "missing protocol in settings"} + + defp check_rate_limits(state) do + GenCounter.add(state.events_rate_counter.id) + + case RateCounter.get(state.events_rate_counter) do + {:ok, %{limit: %{triggered: true}}} -> {:error, "rate_limit_exceeded"} + _ -> check_token_rate(state) + end + end + + defp check_token_rate(%{max_ai_tokens_per_minute: max, token_usage: usage}) when max > 0 and usage >= max do + {:error, "token_limit_exceeded"} + end + + defp check_token_rate(_state), do: :ok + + defp dispatch_input({:input, %{"text" => text}}, state) when is_binary(text) do + stream_with_message(%{"role" => "user", "content" => text}, text, state) + end + + defp dispatch_input( + {:input, %{"tool_result" => %{"tool_call_id" => id, "content" => content}}}, + state + ) + when is_binary(id) and is_binary(content) do + stream_with_message(%{"role" => "tool", "tool_call_id" => id, "content" => content}, content, state) + end + + defp dispatch_input(_msg, state), do: {:noreply, state} + + defp stream_with_message(_msg, body, state) when byte_size(body) > @max_input_bytes do + broadcast_event(state, %Event{type: :error, payload: %{reason: "input_too_large"}}) + {:noreply, state} + end + + defp stream_with_message(%{"role" => "user"} = msg, _body, state) do + %{"content" => text} = msg + persist_turn(state, msg, "agent_input", %{"text" => text}) + {:noreply, start_stream(%{state | messages: [msg | state.messages]})} + end + + defp stream_with_message(msg, _body, state) do + persist_messages(state, [msg]) + {:noreply, start_stream(%{state | messages: [msg | state.messages]})} + end + + defp maybe_persist_assistant_turn(%__MODULE__{assistant_buffer: []}), do: :ok + + defp maybe_persist_assistant_turn(%__MODULE__{} = state) do + %__MODULE__{assistant_buffer: buffer} = state + text = buffer |> Enum.reverse() |> IO.iodata_to_binary() + persist_turn(state, %{"role" => "assistant", "content" => text}, "agent_done", %{"text" => text}) + end + + defp notify_session_started(%__MODULE__{} = state) do + %__MODULE__{session_id: session_id} = state + broadcast_event(state, %Event{type: :session_started, payload: %{session_id: session_id}}) + end + + defp broadcast_event(%__MODULE__{tenant_id: tenant_id, tenant_topic: topic}, %Event{} = event) do + payload = %{ + "event" => Event.broadcast_event(event), + "payload" => event.payload, + "type" => "ai_agent" + } + + message = %Phoenix.Socket.Broadcast{ + topic: topic, + event: "ai_event", + payload: payload + } + + TenantBroadcaster.pubsub_broadcast(tenant_id, topic, message, MessageDispatcher, :ai_events) + end + + @history_limit 100 + + defp load_history(tenant_id, topic, session_id) do + query = + from(m in Message, + where: m.topic == ^topic, + where: m.extension == :ai_agent, + where: fragment("(?)->>'session_id' = ?", m.payload, ^session_id), + order_by: [asc: m.inserted_at], + limit: @history_limit + ) + + with {:ok, db_conn} <- Connect.lookup_or_start_connection(tenant_id), + {:ok, rows} <- Repo.all(db_conn, query, Message) do + Enum.map(rows, &Map.drop(&1.payload, ["session_id"])) + else + error -> + log_error("LoadHistoryError", error) + [] + end + rescue + exception -> + log_error("LoadHistoryError", exception) + [] + end + + defp persist_messages(%__MODULE__{tenant_id: tid, tenant_topic: topic, session_id: sid}, messages) do + changesets = Enum.map(messages, &llm_message_changeset(topic, sid, &1)) + persist(tid, changesets) + end + + defp persist_turn(%__MODULE__{tenant_id: tid, tenant_topic: topic, session_id: sid}, llm_msg, event, event_payload) do + persist(tid, [ + llm_message_changeset(topic, sid, llm_msg), + Message.changeset(%Message{}, %{ + topic: topic, + extension: :ai_agent_event, + event: event, + payload: event_payload, + private: true + }) + ]) + end + + defp llm_message_changeset(topic, sid, msg) do + Message.changeset(%Message{}, %{ + topic: topic, + extension: :ai_agent, + payload: Map.put(msg, "session_id", sid), + private: true + }) + end + + defp persist(tenant_id, changesets) do + with {:ok, db_conn} <- Connect.lookup_or_start_connection(tenant_id) do + Repo.insert_all_entries(db_conn, changesets, Message) + else + error -> log_error("PersistFailed", error) + end + rescue + exception -> log_error("PersistFailed", exception) + end +end diff --git a/lib/extensions/ai_agent/session_supervisor.ex b/lib/extensions/ai_agent/session_supervisor.ex new file mode 100644 index 000000000..f5034d6b9 --- /dev/null +++ b/lib/extensions/ai_agent/session_supervisor.ex @@ -0,0 +1,24 @@ +defmodule Extensions.AiAgent.SessionSupervisor do + @moduledoc """ + DynamicSupervisor that manages AI agent sessions. + One session per channel join with AI enabled. + """ + + use DynamicSupervisor + + alias Extensions.AiAgent.Session + + def start_link(_opts) do + DynamicSupervisor.start_link(__MODULE__, [], name: __MODULE__) + end + + @impl true + def init(_opts) do + DynamicSupervisor.init(strategy: :one_for_one) + end + + @spec start_session(keyword()) :: {:ok, pid()} | {:error, term()} + def start_session(opts) do + DynamicSupervisor.start_child(__MODULE__, {Session, opts}) + end +end diff --git a/lib/extensions/ai_agent/supervisor.ex b/lib/extensions/ai_agent/supervisor.ex new file mode 100644 index 000000000..f698bfc9b --- /dev/null +++ b/lib/extensions/ai_agent/supervisor.ex @@ -0,0 +1,31 @@ +defmodule Extensions.AiAgent.Supervisor do + @moduledoc """ + Top-level supervisor for the AI Agent extension. + Starts the shared Finch HTTP pool, task supervisor, and session DynamicSupervisor. + + max_restarts is set high so transient Finch pool crashes (e.g. abrupt connection + closures during tests or sudden provider outages) do not cascade and kill the + SessionSupervisor. In practice Finch restarts within milliseconds; 100/60s + gives ~1.6 restarts/sec headroom before the supervisor itself shuts down. + """ + + use Supervisor + + def start_link do + Supervisor.start_link(__MODULE__, [], name: __MODULE__) + end + + @impl true + def init(_opts) do + finch_config = Application.get_env(:realtime, AiAgent.Finch, []) + finch_pools = Keyword.get(finch_config, :pools, %{}) + + children = [ + {Finch, name: AiAgent.Finch, pools: finch_pools}, + {Task.Supervisor, name: Extensions.AiAgent.TaskSupervisor}, + Extensions.AiAgent.SessionSupervisor + ] + + Supervisor.init(children, strategy: :one_for_one, max_restarts: 100, max_seconds: 60) + end +end diff --git a/lib/extensions/ai_agent/types/event.ex b/lib/extensions/ai_agent/types/event.ex new file mode 100644 index 000000000..9e74d7fec --- /dev/null +++ b/lib/extensions/ai_agent/types/event.ex @@ -0,0 +1,28 @@ +defmodule Extensions.AiAgent.Types.Event do + @moduledoc """ + Internal envelope for events streamed from an AI provider to a Session. + Each event maps 1:1 to a broadcast sent to the client channel topic. + """ + + @type event_type :: + :session_started + | :text_delta + | :thinking_delta + | :tool_call_delta + | :tool_call_done + | :usage + | :done + | :error + | :rate_limit + + @type t :: %__MODULE__{type: event_type(), payload: map()} + + @enforce_keys [:type, :payload] + defstruct [:type, :payload] + + @doc "Returns the broadcast event name for a given event type." + @spec broadcast_event(t()) :: String.t() + def broadcast_event(%__MODULE__{type: type}) do + "agent_" <> Atom.to_string(type) + end +end diff --git a/lib/extensions/ai_agent/types/tool_call_buffer.ex b/lib/extensions/ai_agent/types/tool_call_buffer.ex new file mode 100644 index 000000000..f9871227a --- /dev/null +++ b/lib/extensions/ai_agent/types/tool_call_buffer.ex @@ -0,0 +1,71 @@ +defmodule Extensions.AiAgent.Types.ToolCallBuffer do + @moduledoc """ + Accumulates streaming tool call chunks from AI provider SSE streams. + + Arguments are stored as an iolist and joined only at `finish`/`finish_all` + to avoid O(n²) binary concatenation across streaming chunks. + + `start/4` upserts an entry without overwriting an already-set id/name, which + handles both Anthropic (id+name on `content_block_start`) and OpenAI + (id+name on the first delta chunk). + """ + + alias Extensions.AiAgent.Adapter + alias Extensions.AiAgent.Types.Event + + @type entry :: %{id: String.t() | nil, name: String.t() | nil, arguments: iodata()} + @type t :: %{non_neg_integer() => entry()} + + @spec new() :: t() + def new, do: %{} + + @spec start(t(), non_neg_integer(), String.t() | nil, String.t() | nil) :: t() + def start(buffer, idx, id, name) do + Map.update(buffer, idx, %{id: id, name: name, arguments: []}, fn entry -> + %{entry | id: entry.id || id, name: entry.name || name} + end) + end + + @spec append_args(t(), non_neg_integer(), String.t(), pid()) :: t() + def append_args(buffer, _idx, "", _caller), do: buffer + + def append_args(buffer, idx, args_delta, caller) do + entry = Map.get(buffer, idx, %{id: nil, name: nil, arguments: []}) + updated = %{entry | arguments: [entry.arguments | [args_delta]]} + + Adapter.emit(caller, %Event{ + type: :tool_call_delta, + payload: %{tool_call_id: updated.id, name: updated.name, arguments_delta: args_delta} + }) + + Map.put(buffer, idx, updated) + end + + @spec finish(t(), non_neg_integer(), pid()) :: t() + def finish(buffer, idx, caller) do + case Map.pop(buffer, idx) do + {nil, buffer} -> + buffer + + {tc, buffer} -> + Adapter.emit(caller, %Event{ + type: :tool_call_done, + payload: %{tool_call_id: tc.id, name: tc.name, arguments: IO.iodata_to_binary(tc.arguments)} + }) + + buffer + end + end + + @spec finish_all(t(), pid()) :: t() + def finish_all(buffer, caller) do + Enum.each(buffer, fn {_idx, tc} -> + Adapter.emit(caller, %Event{ + type: :tool_call_done, + payload: %{tool_call_id: tc.id, name: tc.name, arguments: IO.iodata_to_binary(tc.arguments)} + }) + end) + + new() + end +end diff --git a/lib/realtime/api.ex b/lib/realtime/api.ex index b8d0d2b12..cbf60b451 100644 --- a/lib/realtime/api.ex +++ b/lib/realtime/api.ex @@ -8,7 +8,6 @@ defmodule Realtime.Api do alias Ecto.Changeset alias Extensions.PostgresCdcRls - alias Realtime.Api.Extensions alias Realtime.Api.FeatureFlag alias Realtime.Api.Tenant alias Realtime.FeatureFlags @@ -233,21 +232,15 @@ defmodule Realtime.Api do end end - defp list_extensions(type) do - query = from(e in Extensions, where: e.type == ^type, select: e) - replica = Replica.replica() - replica.all(query) - end - def rename_settings_field(from, to) do if master_region?() do - for extension <- list_extensions("postgres_cdc_rls") do - {value, settings} = Map.pop(extension.settings, from) - new_settings = Map.put(settings, to, value) + {:ok, %{rows: rows}} = + Repo.query("SELECT id, settings FROM extensions WHERE type = $1", ["postgres_cdc_rls"]) - extension - |> Changeset.cast(%{settings: new_settings}, [:settings]) - |> Repo.update() + for [id, settings] <- rows do + {value, new_settings} = Map.pop(settings, from) + updated = Map.put(new_settings, to, value) + Repo.query!("UPDATE extensions SET settings = $1 WHERE id = $2", [updated, id]) end else call(:rename_settings_field, [from, to]) diff --git a/lib/realtime/api/extensions.ex b/lib/realtime/api/extensions.ex index 4ecb1a0f0..096d16ec4 100644 --- a/lib/realtime/api/extensions.ex +++ b/lib/realtime/api/extensions.ex @@ -5,71 +5,147 @@ defmodule Realtime.Api.Extensions do use Ecto.Schema import Ecto.Changeset + import Bitwise alias Realtime.Crypto @primary_key {:id, :binary_id, autogenerate: true} @foreign_key_type :binary_id - @derive {Jason.Encoder, only: [:type, :inserted_at, :updated_at, :settings]} + @derive {Jason.Encoder, only: [:type, :name, :inserted_at, :updated_at, :settings]} schema "extensions" do field(:type, :string) + field(:name, :string) field(:settings, :map) belongs_to(:tenant, Realtime.Api.Tenant, foreign_key: :tenant_external_id, type: :string) timestamps() end def changeset(extension, attrs) do - {attrs1, required_settings} = + {attrs, required_settings} = case attrs["type"] do nil -> {attrs, []} type -> %{default: default, required: required} = Realtime.Extensions.db_settings(type) - - { - %{attrs | "settings" => Map.merge(default, attrs["settings"])}, - required - } + {%{attrs | "settings" => Map.merge(default, attrs["settings"])}, required} end extension - |> cast(attrs1, [:type, :tenant_external_id, :settings]) + |> cast(attrs, [:type, :name, :tenant_external_id, :settings]) |> validate_required([:type, :settings]) - |> unique_constraint([:tenant_external_id, :type]) + |> unique_constraint([:tenant_external_id, :type, :name]) + |> validate_ai_agent() |> validate_required_settings(required_settings) |> encrypt_settings(required_settings) end + @blocked_cidrs [ + {{10, 0, 0, 0}, 8}, + {{172, 16, 0, 0}, 12}, + {{192, 168, 0, 0}, 16}, + {{127, 0, 0, 0}, 8}, + {{169, 254, 0, 0}, 16}, + {{0, 0, 0, 0}, 8}, + {{100, 64, 0, 0}, 10} + ] + + defp validate_ai_agent(changeset) do + if get_field(changeset, :type) == "ai_agent" do + changeset + |> validate_required([:name]) + |> validate_change(:settings, &validate_base_url_settings/2) + |> validate_change(:settings, &validate_header_settings/2) + |> validate_change(:settings, &validate_system_prompt_settings/2) + else + changeset + end + end + + @loopback_hosts ~w(localhost 127.0.0.1 ::1) + @dns_timeout_ms 1_000 + + defp validate_base_url_settings(_, %{"base_url" => url}) when is_binary(url) do + case URI.parse(url) do + %URI{scheme: "https", host: host} when is_binary(host) -> + check_https_host(host) + + %URI{scheme: "http", host: host} when host in @loopback_hosts -> + [] + + %URI{scheme: "http"} -> + [{:settings, "base_url with http scheme is only permitted for loopback hosts (localhost, 127.0.0.1)"}] + + _ -> + [{:settings, "base_url must use https scheme"}] + end + end + + defp validate_base_url_settings(_, %{"base_url" => nil}), do: [] + defp validate_base_url_settings(_, %{"base_url" => _}), do: [{:settings, "base_url must be a string"}] + defp validate_base_url_settings(_, _), do: [] + + defp validate_system_prompt_settings(_, %{"system_prompt" => v}) when is_binary(v), do: [] + defp validate_system_prompt_settings(_, %{"system_prompt" => nil}), do: [] + defp validate_system_prompt_settings(_, %{"system_prompt" => _}), do: [{:settings, "system_prompt must be a string"}] + defp validate_system_prompt_settings(_, _), do: [] + + defp validate_header_settings(_, settings) do + Enum.flat_map(["http_referer", "x_title"], fn key -> + case settings[key] do + nil -> + [] + + v when is_binary(v) -> + if String.contains?(v, ["\r", "\n"]), do: [{:settings, "#{key} contains invalid characters"}], else: [] + + _ -> + [{:settings, "#{key} must be a string"}] + end + end) + end + + defp check_https_host(host) do + case Realtime.DNS.getaddrs(String.to_charlist(host), :inet, @dns_timeout_ms) do + {:ok, ips} -> + if Enum.any?(ips, &private_ip?/1), + do: [{:settings, "base_url resolves to a private or reserved address"}], + else: [] + + {:error, _} -> + [{:settings, "base_url host cannot be resolved"}] + end + end + + defp private_ip?(ip) do + ip_int = ip_to_int(ip) + Enum.any?(@blocked_cidrs, fn {network, prefix_len} -> in_cidr?(ip_int, ip_to_int(network), prefix_len) end) + end + + defp ip_to_int({a, b, c, d}), do: a <<< 24 ||| b <<< 16 ||| c <<< 8 ||| d + + defp in_cidr?(ip_int, network_int, prefix_len) do + mask = 0xFFFFFFFF <<< (32 - prefix_len) &&& 0xFFFFFFFF + (ip_int &&& mask) == (network_int &&& mask) + end + def encrypt_settings(changeset, required) do update_change(changeset, :settings, fn settings -> Enum.reduce(required, settings, fn - {field, _, true}, acc -> - encrypted = Crypto.encrypt!(settings[field]) - %{acc | field => encrypted} - - _, acc -> - acc + {field, _, true}, acc -> %{acc | field => Crypto.encrypt!(settings[field])} + _, acc -> acc end) end) end def validate_required_settings(changeset, required) do - validate_change(changeset, :settings, fn - _, value -> - Enum.reduce(required, [], fn {field, checker, _}, acc -> - case value[field] do - nil -> - [{:settings, "#{field} can't be blank"} | acc] - - data -> - if checker.(data) do - acc - else - [{:settings, "#{field} is invalid"} | acc] - end - end - end) + validate_change(changeset, :settings, fn _, value -> + Enum.reduce(required, [], fn {field, checker, _}, acc -> + case value[field] do + nil -> [{:settings, "#{field} can't be blank"} | acc] + data -> if checker.(data), do: acc, else: [{:settings, "#{field} is invalid"} | acc] + end + end) end) end end diff --git a/lib/realtime/api/message.ex b/lib/realtime/api/message.ex index 1c7bb5b63..2d0e61e49 100644 --- a/lib/realtime/api/message.ex +++ b/lib/realtime/api/message.ex @@ -12,7 +12,7 @@ defmodule Realtime.Api.Message do @timestamps_opts [type: :naive_datetime_usec] schema "messages" do field(:topic, :string) - field(:extension, Ecto.Enum, values: [:broadcast, :presence]) + field(:extension, Ecto.Enum, values: [:broadcast, :presence, :ai_agent, :ai_agent_event]) field(:payload, :map) field(:event, :string) field(:private, :boolean) diff --git a/lib/realtime/api/tenant.ex b/lib/realtime/api/tenant.ex index 03755a1c5..29c2d60ea 100644 --- a/lib/realtime/api/tenant.ex +++ b/lib/realtime/api/tenant.ex @@ -34,6 +34,9 @@ defmodule Realtime.Api.Tenant do field(:client_presence_window_ms, :integer) field(:presence_enabled, :boolean, default: false) field(:feature_flags, :map, default: %{}) + field(:ai_enabled, :boolean, default: false) + field(:max_ai_events_per_second, :integer, default: 100) + field(:max_ai_tokens_per_minute, :integer, default: 60_000) has_many(:extensions, Realtime.Api.Extensions, foreign_key: :tenant_external_id, @@ -84,7 +87,10 @@ defmodule Realtime.Api.Tenant do :max_client_presence_events_per_window, :client_presence_window_ms, :presence_enabled, - :feature_flags + :feature_flags, + :ai_enabled, + :max_ai_events_per_second, + :max_ai_tokens_per_minute ]) |> validate_required([:external_id]) |> check_constraint(:jwt_secret, diff --git a/lib/realtime/dns.ex b/lib/realtime/dns.ex new file mode 100644 index 000000000..468f93b26 --- /dev/null +++ b/lib/realtime/dns.ex @@ -0,0 +1,4 @@ +defmodule Realtime.DNS do + @moduledoc false + def getaddrs(host, family, timeout), do: :inet.getaddrs(host, family, timeout) +end diff --git a/lib/realtime/messages.ex b/lib/realtime/messages.ex index e209461a2..d6946fcaa 100644 --- a/lib/realtime/messages.ex +++ b/lib/realtime/messages.ex @@ -17,14 +17,16 @@ defmodule Realtime.Messages do Only allowed for private channels """ - @spec replay(pid, String.t(), String.t(), non_neg_integer, non_neg_integer) :: + @spec replay(pid, String.t(), String.t(), non_neg_integer, non_neg_integer, list(atom)) :: {:ok, Message.t(), [String.t()]} | {:error, term} | {:error, :rpc_error, term} - def replay(conn, tenant_id, topic, since, limit) + def replay(conn, tenant_id, topic, since, limit, extensions \\ [:broadcast]) + + def replay(conn, tenant_id, topic, since, limit, extensions) when node(conn) == node() and is_integer(since) and is_integer(limit) do limit = max(min(limit, @hard_limit), 1) with {:ok, since} <- DateTime.from_unix(since, :millisecond), - {:ok, messages} <- messages(conn, tenant_id, topic, since, limit) do + {:ok, messages} <- messages(conn, tenant_id, topic, since, limit, extensions) do {:ok, Enum.reverse(messages), MapSet.new(messages, & &1.id)} else {:error, :postgrex_exception} -> {:error, :failed_to_replay_messages} @@ -33,16 +35,17 @@ defmodule Realtime.Messages do end end - def replay(conn, tenant_id, topic, since, limit) when is_integer(since) and is_integer(limit) do - Realtime.GenRpc.call(node(conn), __MODULE__, :replay, [conn, tenant_id, topic, since, limit], + def replay(conn, tenant_id, topic, since, limit, extensions) + when is_integer(since) and is_integer(limit) do + Realtime.GenRpc.call(node(conn), __MODULE__, :replay, [conn, tenant_id, topic, since, limit, extensions], key: topic, tenant_id: tenant_id ) end - def replay(_, _, _, _, _), do: {:error, :invalid_replay_params} + def replay(_, _, _, _, _, _), do: {:error, :invalid_replay_params} - defp messages(conn, tenant_id, topic, since, limit) do + defp messages(conn, tenant_id, topic, since, limit, extensions) do since = DateTime.to_naive(since) # We want to avoid searching partitions in the future as they should be empty # so we limit to 1 minute in the future to account for any potential drift @@ -53,7 +56,7 @@ defmodule Realtime.Messages do where: m.topic == ^topic and m.private == true and - m.extension == :broadcast and + m.extension in ^extensions and m.inserted_at >= ^since and m.inserted_at < ^now, limit: ^limit, diff --git a/lib/realtime/tenants.ex b/lib/realtime/tenants.ex index 90f236297..299ed3192 100644 --- a/lib/realtime/tenants.ex +++ b/lib/realtime/tenants.ex @@ -334,6 +334,42 @@ defmodule Realtime.Tenants do {:channel, :presence_events, tenant.external_id} end + @doc "RateCounter arguments for counting AI agent inputs per second." + @spec ai_events_per_second_rate(Tenant.t()) :: RateCounter.Args.t() + def ai_events_per_second_rate(%Tenant{} = tenant) do + ai_events_per_second_rate(tenant.external_id, tenant.max_ai_events_per_second) + end + + @spec ai_events_per_second_rate(String.t(), non_neg_integer) :: RateCounter.Args.t() + def ai_events_per_second_rate(tenant_id, max_ai_events_per_second) do + opts = [ + tick: :timer.seconds(5), + max_bucket_len: 12, + telemetry: %{ + event_name: [:channel, :ai_events], + measurements: %{limit: max_ai_events_per_second}, + metadata: %{tenant: tenant_id} + }, + limit: [ + value: max_ai_events_per_second, + measurement: :avg, + log: true, + log_fn: fn -> + Logger.error("AiEventsPerSecondRateLimitReached", + external_id: tenant_id, + project: tenant_id + ) + end + ] + ] + + %RateCounter.Args{id: ai_events_per_second_key(tenant_id), opts: opts} + end + + @spec ai_events_per_second_key(Tenant.t() | String.t()) :: {:channel, :ai_events, String.t()} + def ai_events_per_second_key(tenant) when is_binary(tenant), do: {:channel, :ai_events, tenant} + def ai_events_per_second_key(%Tenant{} = tenant), do: {:channel, :ai_events, tenant.external_id} + @spec authorization_errors_per_second_rate(Tenant.t()) :: RateCounter.Args.t() def authorization_errors_per_second_rate(%Tenant{external_id: external_id} = tenant) do opts = [ diff --git a/lib/realtime/tenants/authorization.ex b/lib/realtime/tenants/authorization.ex index 06e0a5cb9..73b4cde8d 100644 --- a/lib/realtime/tenants/authorization.ex +++ b/lib/realtime/tenants/authorization.ex @@ -70,22 +70,16 @@ defmodule Realtime.Tenants.Authorization do def get_read_authorizations(policies, db_conn, authorization_context, opts \\ []) def get_read_authorizations(policies, db_conn, authorization_context, opts) when node() == node(db_conn) do - rate_counter = rate_counter(authorization_context.tenant_id) - - if rate_counter.limit.triggered == false do + with_rate_check(authorization_context.tenant_id, fn rate_counter -> db_conn |> get_read_policies_for_connection(authorization_context, policies, opts) |> handle_policies_result(rate_counter) - else - {:error, :increase_connection_pool} - end + end) end # Remote call def get_read_authorizations(policies, db_conn, authorization_context, opts) do - rate_counter = rate_counter(authorization_context.tenant_id) - - if rate_counter.limit.triggered == false do + with_rate_check(authorization_context.tenant_id, fn rate_counter -> case GenRpc.call( node(db_conn), __MODULE__, @@ -104,9 +98,7 @@ defmodule Realtime.Tenants.Authorization do response -> response end - else - {:error, :increase_connection_pool} - end + end) end @doc """ @@ -125,22 +117,16 @@ defmodule Realtime.Tenants.Authorization do def get_write_authorizations(policies, db_conn, authorization_context, opts \\ []) def get_write_authorizations(policies, db_conn, authorization_context, opts) when node() == node(db_conn) do - rate_counter = rate_counter(authorization_context.tenant_id) - - if rate_counter.limit.triggered == false do + with_rate_check(authorization_context.tenant_id, fn rate_counter -> db_conn |> get_write_policies_for_connection(authorization_context, policies, opts) |> handle_policies_result(rate_counter) - else - {:error, :increase_connection_pool} - end + end) end # Remote call def get_write_authorizations(policies, db_conn, authorization_context, opts) do - rate_counter = rate_counter(authorization_context.tenant_id) - - if rate_counter.limit.triggered == false do + with_rate_check(authorization_context.tenant_id, fn rate_counter -> case GenRpc.call( node(db_conn), __MODULE__, @@ -159,9 +145,7 @@ defmodule Realtime.Tenants.Authorization do response -> response end - else - {:error, :increase_connection_pool} - end + end) end def get_write_authorizations(db_conn, authorization_context), @@ -286,53 +270,63 @@ defmodule Realtime.Tenants.Authorization do ) end - @all_extensions [:broadcast, :presence] + @all_extensions [:broadcast, :presence, :ai_agent] defp extensions_to_check(opts) do - if Keyword.get(opts, :presence_enabled?, true), - do: @all_extensions, - else: [:broadcast] + for {ext, enabled} <- [ + broadcast: true, + presence: Keyword.get(opts, :presence_enabled?, true), + ai_agent: Keyword.get(opts, :ai_enabled?, false) + ], + enabled, + do: ext end defp check_read_policies(conn, authorization_context, messages_by_extension, policies) do ids = Map.values(messages_by_extension) - query = from(m in Message, where: m.topic == ^authorization_context.topic and m.id in ^ids) with {:ok, res} <- Repo.all(conn, query, Message) do returned_ids = MapSet.new(res, & &1.id) - Enum.reduce(@all_extensions, policies, fn extension, acc -> - can? = - Map.has_key?(messages_by_extension, extension) and - MapSet.member?(returned_ids, messages_by_extension[extension]) + Enum.reduce(@all_extensions, policies, fn ext, acc -> + readable = + case Map.get(messages_by_extension, ext) do + nil -> false + msg_id -> MapSet.member?(returned_ids, msg_id) + end - Policies.update_policies(acc, extension, :read, can?) + Policies.update_policies(acc, ext, :read, readable) end) end end defp check_write_policies(conn, authorization_context, extensions, policies) do - Enum.reduce(@all_extensions, policies, fn extension, acc -> + Enum.reduce_while(@all_extensions, policies, fn extension, acc -> if extension in extensions do changeset = Message.changeset(%Message{}, %{topic: authorization_context.topic, extension: extension}) case Repo.insert(conn, changeset, Message, mode: :savepoint) do {:ok, _} -> - Policies.update_policies(acc, extension, :write, true) + {:cont, Policies.update_policies(acc, extension, :write, true)} {:error, %Postgrex.Error{postgres: %{code: :insufficient_privilege}}} -> - Policies.update_policies(acc, extension, :write, false) + {:cont, Policies.update_policies(acc, extension, :write, false)} e -> - e + {:halt, e} end else - Policies.update_policies(acc, extension, :write, false) + {:cont, Policies.update_policies(acc, extension, :write, false)} end end) end + defp with_rate_check(tenant_id, fun) do + rate_counter = rate_counter(tenant_id) + if rate_counter.limit.triggered, do: {:error, :increase_connection_pool}, else: fun.(rate_counter) + end + defp rate_counter(tenant_id) do %Tenant{} = tenant = Realtime.Tenants.Cache.get_tenant_by_external_id(tenant_id) rate_counter = Realtime.Tenants.authorization_errors_per_second_rate(tenant) diff --git a/lib/realtime/tenants/authorization/policies.ex b/lib/realtime/tenants/authorization/policies.ex index 57a4df401..52e364236 100644 --- a/lib/realtime/tenants/authorization/policies.ex +++ b/lib/realtime/tenants/authorization/policies.ex @@ -2,20 +2,23 @@ defmodule Realtime.Tenants.Authorization.Policies do @moduledoc """ Policies structure that holds the required authorization information for a given connection. - Currently there are two types of policies: - * Realtime.Tenants.Authorization.Policies.BroadcastPolicies - Used to store the access to Broadcast feature on a given Topic - * Realtime.Tenants.Authorization.Policies.PresencePolicies - Used to store the access to Presence feature on a given Topic + * Realtime.Tenants.Authorization.Policies.BroadcastPolicies - Broadcast feature access + * Realtime.Tenants.Authorization.Policies.PresencePolicies - Presence feature access + * Realtime.Tenants.Authorization.Policies.AiPolicies - AI agent feature access """ + alias Realtime.Tenants.Authorization.Policies.AiPolicies alias Realtime.Tenants.Authorization.Policies.BroadcastPolicies alias Realtime.Tenants.Authorization.Policies.PresencePolicies defstruct broadcast: %BroadcastPolicies{}, - presence: %PresencePolicies{} + presence: %PresencePolicies{}, + ai_agent: %AiPolicies{} @type t :: %__MODULE__{ broadcast: BroadcastPolicies.t(), - presence: PresencePolicies.t() + presence: PresencePolicies.t(), + ai_agent: AiPolicies.t() } @doc """ diff --git a/lib/realtime/tenants/authorization/policies/ai_policies.ex b/lib/realtime/tenants/authorization/policies/ai_policies.ex new file mode 100644 index 000000000..fb3a6973b --- /dev/null +++ b/lib/realtime/tenants/authorization/policies/ai_policies.ex @@ -0,0 +1,12 @@ +defmodule Realtime.Tenants.Authorization.Policies.AiPolicies do + @moduledoc """ + AiPolicies structure that holds the required authorization information for a given connection + within the scope of sending inputs to and receiving broadcasts from an AI agent. + """ + defstruct read: nil, write: nil + + @type t :: %__MODULE__{ + read: boolean() | nil, + write: boolean() | nil + } +end diff --git a/lib/realtime/tenants/replication_connection.ex b/lib/realtime/tenants/replication_connection.ex index 623753bb3..a5649e31e 100644 --- a/lib/realtime/tenants/replication_connection.ex +++ b/lib/realtime/tenants/replication_connection.ex @@ -370,6 +370,7 @@ defmodule Realtime.Tenants.ReplicationConnection do with %{columns: columns} <- Map.get(relations, relation_id), to_broadcast = tuple_to_map(tuple_data, columns), + {:ok, "broadcast"} <- get_or_error(to_broadcast, "extension", :not_broadcast), {:ok, payload} <- get_or_error(to_broadcast, "payload", :payload_missing), {:ok, inserted_at} <- get_or_error(to_broadcast, "inserted_at", :inserted_at_missing), {:ok, event} <- get_or_error(to_broadcast, "event", :event_missing), @@ -413,6 +414,14 @@ defmodule Realtime.Tenants.ReplicationConnection do {:noreply, state} else + {:error, :not_broadcast} -> + {:noreply, state} + + {:error, %Ecto.Changeset{valid?: false} = changeset} -> + error = Ecto.Changeset.traverse_errors(changeset, &elem(&1, 0)) + log_error("UnableToBroadcastChanges", error) + {:noreply, state} + {:error, error} -> log_error("UnableToBroadcastChanges", error) {:noreply, state} diff --git a/lib/realtime/tenants/repo.ex b/lib/realtime/tenants/repo.ex index 18c9c893f..4a9f06e2d 100644 --- a/lib/realtime/tenants/repo.ex +++ b/lib/realtime/tenants/repo.ex @@ -169,20 +169,22 @@ defmodule Realtime.Tenants.Repo do prefix = schema.__schema__(:prefix) changes = Enum.map(changesets, & &1.changes) - %{header: header, rows: rows} = - Enum.reduce(changes, %{header: [], rows: []}, fn v, changes_acc -> - Enum.reduce(v, changes_acc, fn {field, row}, %{header: header, rows: rows} -> - row = - case row do - row when is_boolean(row) -> row - row when is_atom(row) -> Atom.to_string(row) - _ -> row - end - - %{ - header: Enum.uniq([Atom.to_string(field) | header]), - rows: [row | rows] - } + header = + changes + |> Enum.flat_map(&Map.keys/1) + |> Enum.uniq() + |> Enum.map(&Atom.to_string/1) + + rows = + Enum.flat_map(changes, fn row -> + Enum.map(header, fn field -> + value = Map.get(row, String.to_atom(field)) + + case value do + v when is_boolean(v) -> v + v when is_atom(v) and not is_nil(v) -> Atom.to_string(v) + v -> v + end end) end) diff --git a/lib/realtime_web/channels/realtime_channel.ex b/lib/realtime_web/channels/realtime_channel.ex index b1797ed4f..42bc0f5ff 100644 --- a/lib/realtime_web/channels/realtime_channel.ex +++ b/lib/realtime_web/channels/realtime_channel.ex @@ -24,6 +24,7 @@ defmodule RealtimeWeb.RealtimeChannel do alias RealtimeWeb.Channels.Payloads.Join alias RealtimeWeb.ChannelsAuthorization + alias RealtimeWeb.RealtimeChannel.AiAgentHandler alias RealtimeWeb.RealtimeChannel.BroadcastHandler alias RealtimeWeb.RealtimeChannel.MessageDispatcher alias RealtimeWeb.RealtimeChannel.PresenceHandler @@ -59,12 +60,15 @@ defmodule RealtimeWeb.RealtimeChannel do _ -> false end + ai_config = get_in(params, ["config", "ai"]) || %{} + socket = socket |> assign_access_token(params) |> assign(:private?, !!params["config"]["private"]) |> assign(:policies, nil) |> assign(:presence_enabled?, presence_enabled?) + |> assign(:ai_config, ai_config) case Join.validate(params) do {:ok, _join} -> @@ -87,11 +91,22 @@ defmodule RealtimeWeb.RealtimeChannel do socket = assign_authorization_context(socket, sub_topic, claims), {:ok, db_conn} <- Connect.lookup_or_start_connection(tenant_id), {:ok, socket} <- maybe_assign_policies(sub_topic, db_conn, socket), + tenant_topic = Tenants.tenant_topic(tenant_id, sub_topic, !socket.assigns.private?), {:ok, replayed_message_ids} <- - maybe_replay_messages(params["config"], sub_topic, db_conn, tenant_id, socket.assigns.private?) do - tenant_topic = Tenants.tenant_topic(tenant_id, sub_topic, !socket.assigns.private?) + maybe_replay_messages(params["config"], sub_topic, db_conn, tenant_id, socket.assigns.private?), + {:ok, ai_replayed_message_ids} <- + AiAgentHandler.replay(params["config"], tenant_topic, db_conn, tenant_id, socket.assigns.private?), + {:ok, ai_session} <- + AiAgentHandler.start_session( + socket.assigns.ai_config, + tenant, + tenant_topic, + tenant_id, + self(), + socket.assigns.private? + ) do + all_replayed_ids = MapSet.union(replayed_message_ids, ai_replayed_message_ids) - # fastlane subscription metadata = MessageDispatcher.fastlane_metadata( transport_pid, @@ -99,12 +114,14 @@ defmodule RealtimeWeb.RealtimeChannel do topic, log_level, tenant_id, - replayed_message_ids + all_replayed_ids ) RealtimeWeb.Endpoint.subscribe(tenant_topic, metadata: metadata) RealtimeWeb.Endpoint.subscribe("realtime:operations:" <> tenant_id, metadata: metadata) + AiAgentHandler.notify_session_started(ai_session) + is_new_api = new_api?(params) presence_enabled? = socket.assigns.presence_enabled? @@ -133,7 +150,8 @@ defmodule RealtimeWeb.RealtimeChannel do self_broadcast: !!params["config"]["broadcast"]["self"], tenant_topic: tenant_topic, channel_name: sub_topic, - presence_enabled?: presence_enabled? + presence_enabled?: presence_enabled?, + ai_session: ai_session } socket = @@ -142,7 +160,6 @@ defmodule RealtimeWeb.RealtimeChannel do |> assign_presence_counter(tenant) |> assign_client_presence_rate_limit(tenant) - # Start presence and add user if presence is enabled if presence_enabled?, do: send(self(), :sync_presence) UsersCounter.add(transport_pid, tenant_id) @@ -239,6 +256,18 @@ defmodule RealtimeWeb.RealtimeChannel do {:error, :invalid_replay_channel} -> log_error(socket, "UnableToReplayMessages", "Replay is not allowed for public channels") + {:error, :ai_requires_private_channel} -> + log_error(socket, "AiAgentRequiresPrivateChannel", "AI agent is only supported on private channels") + + {:error, :ai_agent_feature_disabled} -> + log_error(socket, "AiAgentFeatureDisabled", "AI agent feature is not enabled for this tenant") + + {:error, :no_ai_agent_configured} -> + log_error(socket, "AiAgentNotConfigured", "No AI agent configured for this tenant") + + {:error, :ai_session_start_failed} -> + log_error(socket, "AiSessionStartFailed", "Failed to start AI agent session") + {:error, :error_generating_signer} -> log_error( socket, @@ -261,9 +290,10 @@ defmodule RealtimeWeb.RealtimeChannel do def handle_info({:replay, messages}, socket) do for message <- messages do meta = %{"replayed" => true, "id" => message.id} - payload = %{"payload" => message.payload, "event" => message.event, "type" => "broadcast", "meta" => meta} + {channel_event, type} = replay_channel_event(message.extension, message.event) + payload = %{"payload" => message.payload, "event" => message.event, "type" => type, "meta" => meta} - push(socket, "broadcast", payload) + push(socket, channel_event, payload) end {:noreply, socket} @@ -394,10 +424,10 @@ defmodule RealtimeWeb.RealtimeChannel do @impl true def handle_in("broadcast", payload, %{assigns: %{private?: true}} = socket) do - %{tenant: tenant_id} = socket.assigns - - with {:ok, db_conn} <- Connect.lookup_or_start_connection(tenant_id) do - BroadcastHandler.handle(payload, db_conn, socket) + with {:ok, db_conn} <- Connect.lookup_or_start_connection(socket.assigns.tenant) do + if AiAgentHandler.ai_event?(payload), + do: AiAgentHandler.handle(payload, db_conn, socket), + else: BroadcastHandler.handle(payload, db_conn, socket) else {:error, error} -> log_error(socket, "UnableToHandleBroadcast", error) @@ -406,7 +436,9 @@ defmodule RealtimeWeb.RealtimeChannel do end def handle_in("broadcast", payload, %{assigns: %{private?: false}} = socket) do - BroadcastHandler.handle(payload, socket) + if AiAgentHandler.ai_event?(payload), + do: {:noreply, socket}, + else: BroadcastHandler.handle(payload, socket) end def handle_in("presence", payload, %{assigns: %{private?: true}} = socket) do @@ -484,7 +516,6 @@ defmodule RealtimeWeb.RealtimeChannel do } } = socket - # Update token and reset policies socket = assign(socket, %{access_token: refresh_token, policies: nil}) with {:ok, claims, confirm_token_ref} <- confirm_token(socket), @@ -847,10 +878,13 @@ defmodule RealtimeWeb.RealtimeChannel do authorization_context = socket.assigns.authorization_context policies = socket.assigns.policies || %Policies{} presence_enabled? = socket.assigns.presence_enabled? + ai_config = socket.assigns[:ai_config] || %{} + ai_enabled? = ai_config["enabled"] == true and is_binary(ai_config["agent"]) with {:ok, policies} <- Authorization.get_read_authorizations(policies, db_conn, authorization_context, - presence_enabled?: presence_enabled? + presence_enabled?: presence_enabled?, + ai_enabled?: ai_enabled? ) do socket = assign(socket, :policies, policies) @@ -884,33 +918,23 @@ defmodule RealtimeWeb.RealtimeChannel do end end - defp maybe_replay_messages(%{"broadcast" => %{"replay" => _}}, _sub_topic, _db_conn, _tenant_id, false = _private?) do - {:error, :invalid_replay_channel} - end + defp maybe_replay_messages(%{"broadcast" => %{"replay" => _}}, _, _, _, false), do: {:error, :invalid_replay_channel} + + defp maybe_replay_messages(%{"broadcast" => %{"replay" => params}}, topic, conn, tid, true) when is_map(params), + do: do_replay(params, [:broadcast], topic, conn, tid) + + defp maybe_replay_messages(_, _, _, _, _), do: {:ok, MapSet.new()} - defp maybe_replay_messages( - %{"broadcast" => %{"replay" => replay_params}}, - sub_topic, - db_conn, - tenant_id, - true = _private? - ) - when is_map(replay_params) do + defp do_replay(params, extensions, sub_topic, db_conn, tenant_id) do with {:ok, messages, message_ids} <- - Realtime.Messages.replay( - db_conn, - tenant_id, - sub_topic, - replay_params["since"], - replay_params["limit"] || 25 - ) do - # Send to self because we can't write to the socket before finishing the join process + Realtime.Messages.replay(db_conn, tenant_id, sub_topic, params["since"], params["limit"] || 25, extensions) do send(self(), {:replay, messages}) {:ok, message_ids} end end - defp maybe_replay_messages(_, _, _, _, _), do: {:ok, MapSet.new()} + defp replay_channel_event(:ai_agent_event, _event), do: {"ai_event", "ai_agent"} + defp replay_channel_event(_, _), do: {"broadcast", "broadcast"} defp presence_enabled?(client_enabled?, %Tenant{presence_enabled: tenant_enabled}) do client_enabled? || tenant_enabled diff --git a/lib/realtime_web/channels/realtime_channel/ai_agent_handler.ex b/lib/realtime_web/channels/realtime_channel/ai_agent_handler.ex new file mode 100644 index 000000000..63b5931ff --- /dev/null +++ b/lib/realtime_web/channels/realtime_channel/ai_agent_handler.ex @@ -0,0 +1,163 @@ +defmodule RealtimeWeb.RealtimeChannel.AiAgentHandler do + @moduledoc """ + Handles the AI Agent feature from Realtime. + """ + use Realtime.Logs + + import Phoenix.Socket, only: [assign: 3] + import Phoenix.Channel, only: [push: 3] + + alias Extensions.AiAgent.Session + alias Extensions.AiAgent.SessionSupervisor + alias Phoenix.Socket + alias Realtime.Api.Tenant + alias Realtime.FeatureFlags + alias Realtime.Messages + alias Realtime.Tenants.Authorization + alias Realtime.Tenants.Authorization.Policies + alias Realtime.Tenants.Authorization.Policies.AiPolicies + + @spec notify_session_started(pid() | nil) :: :ok + def notify_session_started(pid) when is_pid(pid), do: GenServer.cast(pid, :emit_session_started) + def notify_session_started(_), do: :ok + + @ai_events ["agent_input", "agent_cancel"] + + @spec ai_event?(map() | tuple()) :: boolean() + def ai_event?(%{"event" => event}) when event in @ai_events, do: true + def ai_event?({event, _, _, _}) when event in @ai_events, do: true + def ai_event?(_), do: false + + @spec handle(map() | tuple(), pid() | nil, Socket.t()) :: {:noreply, Socket.t()} + def handle(%{"event" => event} = payload, db_conn, %{assigns: %{ai_session: pid}} = socket) + when event in @ai_events and is_pid(pid) do + do_handle_ai_event(event, payload, db_conn, socket) + end + + def handle({event, :json, payload_binary, _metadata}, db_conn, %{assigns: %{ai_session: pid}} = socket) + when event in @ai_events and is_pid(pid) do + payload = %{"event" => event, "payload" => Phoenix.json_library().decode!(payload_binary)} + do_handle_ai_event(event, payload, db_conn, socket) + end + + def handle(_payload, _db_conn, socket), do: {:noreply, socket} + + @dialyzer {:nowarn_function, start_session: 6} + @spec start_session(map(), Tenant.t(), String.t(), String.t(), pid(), boolean()) :: + {:ok, pid() | nil} | {:error, term()} + def start_session(%{"enabled" => true}, _tenant, _tenant_topic, _tenant_id, _channel_pid, false) do + {:error, :ai_requires_private_channel} + end + + def start_session( + %{"enabled" => true, "agent" => agent_name} = ai_config, + %{ai_enabled: true} = tenant, + tenant_topic, + tenant_id, + channel_pid, + true + ) + when is_binary(agent_name) do + with true <- FeatureFlags.enabled?("ai_agent", tenant_id), + extension when not is_nil(extension) <- + Enum.find(tenant.extensions, &(&1.type == "ai_agent" and &1.name == agent_name)) do + session_id = if is_binary(ai_config["session_id"]), do: ai_config["session_id"] + + opts = [ + tenant_id: tenant_id, + tenant_topic: tenant_topic, + settings: extension.settings, + channel_pid: channel_pid, + session_id: session_id, + max_ai_events_per_second: tenant.max_ai_events_per_second, + max_ai_tokens_per_minute: tenant.max_ai_tokens_per_minute + ] + + do_start_session(opts, tenant_id, agent_name) + else + false -> + Logger.error("AiAgentFeatureFlagDisabled agent=#{agent_name} tenant=#{tenant_id}") + {:error, :ai_agent_feature_disabled} + + nil -> + Logger.error("AiAgentNotFound agent=#{agent_name} tenant=#{tenant_id}") + {:error, :no_ai_agent_configured} + end + end + + def start_session( + %{"enabled" => true, "agent" => agent_name}, + _tenant, + _tenant_topic, + tenant_id, + _channel_pid, + _private? + ) + when is_binary(agent_name) do + Logger.error("AiNotEnabledForTenant agent=#{agent_name} tenant=#{tenant_id}") + {:error, :no_ai_agent_configured} + end + + def start_session(_ai_config, _tenant, _tenant_topic, _tenant_id, _channel_pid, _private?), do: {:ok, nil} + + @spec replay(map(), String.t(), pid(), String.t(), boolean()) :: {:ok, MapSet.t()} | {:error, term()} + def replay(%{"ai" => %{"replay" => _}}, _topic, _conn, _tid, false), do: {:error, :invalid_replay_channel} + + def replay(%{"ai" => %{"replay" => params}}, topic, conn, tid, true) when is_map(params) do + with {:ok, messages, message_ids} <- + Messages.replay(conn, tid, topic, params["since"], params["limit"] || 25, [:ai_agent_event]) do + send(self(), {:replay, messages}) + {:ok, message_ids} + end + end + + def replay(_config, _topic, _conn, _tid, _private?), do: {:ok, MapSet.new()} + + defp do_handle_ai_event(event, payload, db_conn, socket) do + %{authorization_context: authorization_context, policies: policies, ai_session: pid} = socket.assigns + + case check_ai_authorization(policies || %Policies{}, db_conn, authorization_context) do + {:ok, %Policies{ai_agent: %AiPolicies{write: true}} = policies} -> + socket = assign(socket, :policies, policies) + route_ai_event(event, payload["payload"] || %{}, pid) + {:noreply, socket} + + {:ok, _policies} -> + push(socket, "ai_event", %{"event" => "agent_error", "payload" => %{"reason" => "unauthorized"}}) + {:noreply, socket} + + {:error, :rls_policy_error, error} -> + log_error("RlsPolicyError", error) + push(socket, "ai_event", %{"event" => "agent_error", "payload" => %{"reason" => "unauthorized"}}) + {:noreply, socket} + + {:error, error} -> + log_error("UnableToSetPolicies", error) + {:noreply, socket} + end + end + + defp check_ai_authorization(%Policies{ai_agent: %AiPolicies{write: nil}} = policies, db_conn, ctx) do + Authorization.get_write_authorizations(policies, db_conn, ctx, ai_enabled?: true) + end + + defp check_ai_authorization(policies, _db_conn, _ctx), do: {:ok, policies} + + defp route_ai_event("agent_input", payload, pid), do: Session.handle_input(pid, payload) + defp route_ai_event("agent_cancel", _payload, pid), do: Session.cancel(pid) + + defp do_start_session(opts, tenant_id, agent_name) do + case SessionSupervisor.start_session(opts) do + {:ok, pid} -> + {:ok, pid} + + {:error, reason} -> + Logger.error("AiSessionStartFailed reason=#{inspect(reason)} tenant=#{tenant_id} agent=#{agent_name}") + {:error, :ai_session_start_failed} + end + catch + :exit, reason -> + Logger.error("AiSessionStartFailed reason=#{inspect(reason)} tenant=#{tenant_id} agent=#{agent_name}") + {:error, :ai_session_start_failed} + end +end diff --git a/lib/realtime_web/channels/realtime_channel/broadcast_handler.ex b/lib/realtime_web/channels/realtime_channel/broadcast_handler.ex index 8aad778af..e185ada25 100644 --- a/lib/realtime_web/channels/realtime_channel/broadcast_handler.ex +++ b/lib/realtime_web/channels/realtime_channel/broadcast_handler.ex @@ -6,58 +6,62 @@ defmodule RealtimeWeb.RealtimeChannel.BroadcastHandler do import Phoenix.Socket, only: [assign: 3] - alias Realtime.Tenants - alias RealtimeWeb.RealtimeChannel - alias RealtimeWeb.TenantBroadcaster alias Phoenix.Socket alias Realtime.GenCounter + alias Realtime.Tenants alias Realtime.Tenants.Authorization alias Realtime.Tenants.Authorization.Policies alias Realtime.Tenants.Authorization.Policies.BroadcastPolicies + alias RealtimeWeb.RealtimeChannel + alias RealtimeWeb.TenantBroadcaster - @type payload :: map | {String.t(), :json | :binary, binary} + @type payload :: map | {String.t(), :json | :binary, binary, map()} @event_type "broadcast" + @spec handle(payload, Socket.t()) :: {:reply, :ok, Socket.t()} | {:noreply, Socket.t()} def handle(payload, %{assigns: %{private?: false}} = socket), do: handle(payload, nil, socket) @spec handle(payload, pid() | nil, Socket.t()) :: {:reply, :ok, Socket.t()} | {:noreply, Socket.t()} + def handle({_, _, _, _} = payload, db_conn, %{assigns: %{private?: true}} = socket) do + broadcast = build_user_broadcast(socket.assigns.tenant_topic, payload) + broadcast_authorized(broadcast, :ok, db_conn, socket) + end + def handle(payload, db_conn, %{assigns: %{private?: true}} = socket) do - %{ - assigns: %{ - self_broadcast: self_broadcast, - tenant_topic: tenant_topic, - authorization_context: authorization_context, - policies: policies, - tenant: tenant_id - } - } = socket - - case run_authorization_check(policies || %Policies{}, db_conn, authorization_context) do - {:ok, %Policies{broadcast: %BroadcastPolicies{write: true}} = policies} -> - socket = - socket - |> assign(:policies, policies) - |> increment_rate_counter() + %{assigns: %{tenant_topic: topic}} = socket + broadcast = %Phoenix.Socket.Broadcast{topic: topic, event: @event_type, payload: payload} + broadcast_authorized(broadcast, payload, db_conn, socket) + end - %{ack_broadcast: ack_broadcast} = socket.assigns + def handle({_, _, _, _} = payload, _db_conn, %{assigns: %{private?: false}} = socket) do + broadcast = build_user_broadcast(socket.assigns.tenant_topic, payload) + broadcast_public(broadcast, :ok, socket) + end - res = - case Tenants.validate_payload_size(tenant_id, payload) do - :ok -> send_message(tenant_id, self_broadcast, tenant_topic, payload) - {:error, error} -> {:error, error} - end + def handle(payload, _db_conn, %{assigns: %{private?: false}} = socket) do + %{assigns: %{tenant_topic: topic}} = socket + broadcast = %Phoenix.Socket.Broadcast{topic: topic, event: @event_type, payload: payload} + broadcast_public(broadcast, payload, socket) + end - cond do - ack_broadcast && match?({:error, :payload_size_exceeded}, res) -> - {:reply, {:error, :payload_size_exceeded}, socket} + defp build_user_broadcast(topic, {user_event, user_payload_encoding, user_payload, _metadata}) do + %RealtimeWeb.Socket.UserBroadcast{ + topic: topic, + user_event: user_event, + user_payload_encoding: user_payload_encoding, + user_payload: user_payload + } + end - ack_broadcast -> - {:reply, :ok, socket} + defp broadcast_authorized(broadcast, payload_or_size_check, db_conn, socket) do + %{assigns: %{authorization_context: authorization_context, policies: policies}} = socket - true -> - {:noreply, socket} - end + case check_broadcast_authorization(policies || %Policies{}, db_conn, authorization_context) do + {:ok, %Policies{broadcast: %BroadcastPolicies{write: true}} = policies} -> + socket = socket |> assign(:policies, policies) |> increment_rate_counter() + res = do_send(broadcast, payload_or_size_check, socket) + reply_for_result(res, socket.assigns.ack_broadcast, socket) {:ok, policies} -> {:noreply, assign(socket, :policies, policies)} @@ -87,24 +91,30 @@ defmodule RealtimeWeb.RealtimeChannel.BroadcastHandler do end end - def handle(payload, _db_conn, %{assigns: %{private?: false}} = socket) do - %{ - assigns: %{ - tenant_topic: tenant_topic, - self_broadcast: self_broadcast, - ack_broadcast: ack_broadcast, - tenant: tenant_id - } - } = socket - + defp broadcast_public(broadcast, payload_or_size_check, socket) do socket = increment_rate_counter(socket) + res = do_send(broadcast, payload_or_size_check, socket) + reply_for_result(res, socket.assigns.ack_broadcast, socket) + end - res = - case Tenants.validate_payload_size(tenant_id, payload) do - :ok -> send_message(tenant_id, self_broadcast, tenant_topic, payload) - error -> error + defp do_send( + broadcast, + payload_or_size_check, + %{assigns: %{tenant: tenant_id, self_broadcast: self_broadcast, tenant_topic: tenant_topic}} = _socket + ) do + size_check = + case payload_or_size_check do + :ok -> :ok + payload -> Tenants.validate_payload_size(tenant_id, payload) end + case size_check do + :ok -> pubsub_send(tenant_id, self_broadcast, tenant_topic, broadcast) + error -> error + end + end + + defp reply_for_result(res, ack_broadcast, socket) do cond do ack_broadcast && match?({:error, :payload_size_exceeded}, res) -> {:reply, {:error, :payload_size_exceeded}, socket} @@ -117,45 +127,25 @@ defmodule RealtimeWeb.RealtimeChannel.BroadcastHandler do end end - defp send_message(tenant_id, self_broadcast, tenant_topic, payload) do - broadcast = build_broadcast(tenant_topic, payload) - - if self_broadcast do - TenantBroadcaster.pubsub_broadcast( - tenant_id, - tenant_topic, - broadcast, - RealtimeChannel.MessageDispatcher, - :broadcast - ) - else - TenantBroadcaster.pubsub_broadcast_from( - tenant_id, - self(), - tenant_topic, - broadcast, - RealtimeChannel.MessageDispatcher, - :broadcast - ) - end - end - - # No idea why Dialyzer is complaining here - @dialyzer {:nowarn_function, build_broadcast: 2} - - # Message payload was built by V2 Serializer which was originally UserBroadcastPush - # We are not using the metadata for anything just yet. - defp build_broadcast(topic, {user_event, user_payload_encoding, user_payload, _metadata}) do - %RealtimeWeb.Socket.UserBroadcast{ - topic: topic, - user_event: user_event, - user_payload_encoding: user_payload_encoding, - user_payload: user_payload - } + defp pubsub_send(tenant_id, true, tenant_topic, broadcast) do + TenantBroadcaster.pubsub_broadcast( + tenant_id, + tenant_topic, + broadcast, + RealtimeChannel.MessageDispatcher, + :broadcast + ) end - defp build_broadcast(topic, payload) do - %Phoenix.Socket.Broadcast{topic: topic, event: @event_type, payload: payload} + defp pubsub_send(tenant_id, false, tenant_topic, broadcast) do + TenantBroadcaster.pubsub_broadcast_from( + tenant_id, + self(), + tenant_topic, + broadcast, + RealtimeChannel.MessageDispatcher, + :broadcast + ) end defp increment_rate_counter(%{assigns: %{policies: %Policies{broadcast: %BroadcastPolicies{write: false}}}} = socket) do @@ -167,15 +157,9 @@ defmodule RealtimeWeb.RealtimeChannel.BroadcastHandler do socket end - defp run_authorization_check( - %Policies{broadcast: %BroadcastPolicies{write: nil}} = policies, - db_conn, - authorization_context - ) do - Authorization.get_write_authorizations(policies, db_conn, authorization_context) + defp check_broadcast_authorization(%Policies{broadcast: %BroadcastPolicies{write: nil}} = policies, db_conn, ctx) do + Authorization.get_write_authorizations(policies, db_conn, ctx) end - defp run_authorization_check(socket, _db_conn, _authorization_context) do - {:ok, socket} - end + defp check_broadcast_authorization(policies, _db_conn, _ctx), do: {:ok, policies} end diff --git a/lib/realtime_web/tenant_broadcaster.ex b/lib/realtime_web/tenant_broadcaster.ex index b0a95d679..ff7dab27c 100644 --- a/lib/realtime_web/tenant_broadcaster.ex +++ b/lib/realtime_web/tenant_broadcaster.ex @@ -5,7 +5,7 @@ defmodule RealtimeWeb.TenantBroadcaster do alias Phoenix.PubSub - @type message_type :: :broadcast | :presence | :postgres_changes + @type message_type :: :broadcast | :presence | :postgres_changes | :ai_events @spec pubsub_direct_broadcast( node :: node(), diff --git a/mise.toml b/mise.toml index 3a92bd5e1..490f226e2 100644 --- a/mise.toml +++ b/mise.toml @@ -20,6 +20,11 @@ description = "Start another dev server (orange)" run = "iex --name ${NAME}@127.0.0.1 --cookie cookie -S mix phx.server" env = { NAME = "orange", PORT = "4001", REGION = "eu-west-1", GEN_RPC_TCP_SERVER_PORT = "5469", GEN_RPC_TCP_CLIENT_PORT = "5369" } +[tasks.seed] +description = "Seed the dev database (safe to run while dev server is up)" +run = "mix ecto.create && mix ecto.migrate && mix run priv/repo/dev_seeds.exs" +env = { GEN_RPC_TCP_SERVER_PORT = "5569", GEN_RPC_TCP_CLIENT_PORT = "5669" } + [tasks.db-start] description = "Start all dev databases" run = "docker compose -f compose.dbs.yml up -d --wait" diff --git a/priv/repo/dev_seeds.exs b/priv/repo/dev_seeds.exs index a1ca7c05a..91c76a909 100644 --- a/priv/repo/dev_seeds.exs +++ b/priv/repo/dev_seeds.exs @@ -1,4 +1,7 @@ +alias Realtime.Api +alias Realtime.Api.Extensions alias Realtime.Api.Tenant +alias Realtime.Crypto alias Realtime.Database alias Realtime.Repo alias Realtime.Tenants @@ -71,3 +74,41 @@ case Tenants.Migrations.run_migrations(tenant) do end Tenants.Migrations.run_migrations(tenant) + +Postgrex.transaction(tenant_conn, fn db_conn -> + [ + """ + CREATE POLICY "authenticated_all_topic_read" + ON realtime.messages FOR SELECT + TO authenticated + USING ( true ); + """, + """ + CREATE POLICY "authenticated_all_topic_write" + ON realtime.messages FOR INSERT + TO authenticated + WITH CHECK ( true ); + """ + ] + |> Enum.each(&Postgrex.query!(db_conn, &1)) +end) + +ollama_host = System.get_env("OLLAMA_HOST", "http://localhost:11434") +ollama_model = System.get_env("OLLAMA_MODEL", "qwen2:0.5b") + +%Extensions{} +|> Extensions.changeset(%{ + "type" => "ai_agent", + "name" => "local-agent", + "tenant_external_id" => tenant.external_id, + "settings" => %{ + "protocol" => "openai_compatible", + "model" => ollama_model, + "api_key" => Crypto.encrypt!("ollama"), + "base_url" => ollama_host <> "/v1", + "topic_pattern" => "agent:*" + } +}) +|> Repo.insert!() + +Api.update_tenant_by_external_id(tenant.external_id, %{"ai_enabled" => true}) diff --git a/priv/repo/migrations/20260430000001_add_name_to_extensions.exs b/priv/repo/migrations/20260430000001_add_name_to_extensions.exs new file mode 100644 index 000000000..00c1828c3 --- /dev/null +++ b/priv/repo/migrations/20260430000001_add_name_to_extensions.exs @@ -0,0 +1,21 @@ +defmodule Realtime.Repo.Migrations.AddNameToExtensions do + use Ecto.Migration + + def up do + alter table(:extensions) do + add_if_not_exists :name, :string + end + + drop_if_exists unique_index(:extensions, [:tenant_external_id, :type]) + create_if_not_exists unique_index(:extensions, [:tenant_external_id, :type, :name]) + end + + def down do + drop_if_exists unique_index(:extensions, [:tenant_external_id, :type, :name]) + create_if_not_exists unique_index(:extensions, [:tenant_external_id, :type]) + + alter table(:extensions) do + remove_if_exists :name, :string + end + end +end diff --git a/priv/repo/migrations/20260430000002_add_ai_to_tenants.exs b/priv/repo/migrations/20260430000002_add_ai_to_tenants.exs new file mode 100644 index 000000000..50c3740d7 --- /dev/null +++ b/priv/repo/migrations/20260430000002_add_ai_to_tenants.exs @@ -0,0 +1,11 @@ +defmodule Realtime.Repo.Migrations.AddAiToTenants do + use Ecto.Migration + + def change do + alter table(:tenants) do + add :ai_enabled, :boolean, default: false, null: false + add :max_ai_events_per_second, :integer, default: 100, null: false + add :max_ai_tokens_per_minute, :integer, default: 60_000, null: false + end + end +end diff --git a/test/extensions/ai_agent/adapter/anthropic_messages_test.exs b/test/extensions/ai_agent/adapter/anthropic_messages_test.exs new file mode 100644 index 000000000..4b1586e0b --- /dev/null +++ b/test/extensions/ai_agent/adapter/anthropic_messages_test.exs @@ -0,0 +1,239 @@ +defmodule Extensions.AiAgent.Adapter.AnthropicMessagesTest do + use ExUnit.Case, async: true + use Mimic + + alias Extensions.AiAgent.Adapter.AnthropicMessages + alias Extensions.AiAgent.Types.Event + + @settings %{ + "model" => "claude-opus-4-7", + "api_key" => "sk-ant-test", + "base_url" => "https://api.anthropic.com" + } + + @messages [%{"role" => "user", "content" => "hello"}] + + defp sse(type, data), do: "event: #{type}\ndata: #{Jason.encode!(data)}\n\n" + + describe "stream/3" do + test "emits text_delta events from content_block_delta" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + start_evt = + sse("message_start", %{"type" => "message_start", "message" => %{"usage" => %{"input_tokens" => 5}}}) + + block_start = + sse("content_block_start", %{ + "type" => "content_block_start", + "index" => 0, + "content_block" => %{"type" => "text"} + }) + + delta1 = + sse("content_block_delta", %{ + "type" => "content_block_delta", + "index" => 0, + "delta" => %{"type" => "text_delta", "text" => "Hello"} + }) + + delta2 = + sse("content_block_delta", %{ + "type" => "content_block_delta", + "index" => 0, + "delta" => %{"type" => "text_delta", "text" => " world"} + }) + + msg_delta = + sse("message_delta", %{ + "type" => "message_delta", + "delta" => %{"stop_reason" => "end_turn"}, + "usage" => %{"output_tokens" => 10} + }) + + acc = callback.({:status, 200}, acc) + acc = callback.({:data, start_evt <> block_start <> delta1 <> delta2 <> msg_delta}, acc) + {:ok, acc} + end) + + assert :ok = AnthropicMessages.stream(@settings, @messages, caller) + + assert_receive {:ai_event, %Event{type: :usage, payload: %{input_tokens: 5}}} + assert_receive {:ai_event, %Event{type: :text_delta, payload: %{delta: "Hello"}}} + assert_receive {:ai_event, %Event{type: :text_delta, payload: %{delta: " world"}}} + assert_receive {:ai_event, %Event{type: :usage, payload: %{output_tokens: 10}}} + assert_receive {:ai_event, %Event{type: :done, payload: %{stop_reason: "end_turn"}}} + end + + test "emits tool_call events for tool_use content blocks" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + start_evt = + sse("message_start", %{"type" => "message_start", "message" => %{"usage" => %{"input_tokens" => 5}}}) + + tool_start = + sse("content_block_start", %{ + "type" => "content_block_start", + "index" => 0, + "content_block" => %{"type" => "tool_use", "id" => "toolu_01", "name" => "get_weather", "input" => %{}} + }) + + arg_delta = + sse("content_block_delta", %{ + "type" => "content_block_delta", + "index" => 0, + "delta" => %{"type" => "input_json_delta", "partial_json" => "{\"city\":\"NYC\"}"} + }) + + tool_stop = sse("content_block_stop", %{"type" => "content_block_stop", "index" => 0}) + + msg_delta = + sse("message_delta", %{ + "type" => "message_delta", + "delta" => %{"stop_reason" => "tool_use"}, + "usage" => %{"output_tokens" => 15} + }) + + acc = callback.({:status, 200}, acc) + acc = callback.({:data, start_evt <> tool_start <> arg_delta <> tool_stop <> msg_delta}, acc) + {:ok, acc} + end) + + assert :ok = AnthropicMessages.stream(@settings, @messages, caller) + + assert_receive {:ai_event, + %Event{type: :tool_call_delta, payload: %{tool_call_id: "toolu_01", name: "get_weather"}}} + + assert_receive {:ai_event, + %Event{ + type: :tool_call_done, + payload: %{tool_call_id: "toolu_01", name: "get_weather", arguments: "{\"city\":\"NYC\"}"} + }} + end + + test "emits error event on HTTP error status" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + acc = callback.({:status, 401}, acc) + {:ok, acc} + end) + + AnthropicMessages.stream(@settings, @messages, caller) + + assert_receive {:ai_event, %Event{type: :error, payload: %{reason: "HTTP 401"}}} + end + + test "returns error tuple when Finch fails" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, _callback, _opts -> + {:error, %Mint.TransportError{reason: :closed}, acc} + end) + + assert {:error, _} = AnthropicMessages.stream(@settings, @messages, caller) + end + + test "emits error event on HTTP 5xx" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + acc = callback.({:status, 529}, acc) + {:ok, acc} + end) + + assert :ok = AnthropicMessages.stream(@settings, @messages, caller) + assert_receive {:ai_event, %Event{type: :error, payload: %{reason: "HTTP 529"}}} + end + + test "skips malformed SSE lines without crashing" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + bad_data = "data: not-valid-json\n\n" + + valid = + sse("message_delta", %{ + "type" => "message_delta", + "delta" => %{"stop_reason" => "end_turn"}, + "usage" => %{"output_tokens" => 1} + }) + + acc = callback.({:status, 200}, acc) + acc = callback.({:data, bad_data <> valid}, acc) + {:ok, acc} + end) + + assert :ok = AnthropicMessages.stream(@settings, @messages, caller) + assert_receive {:ai_event, %Event{type: :done}} + refute_receive {:ai_event, %Event{type: :error}}, 50 + end + end + + describe "system_prompt" do + defp minimal_sse do + sse("message_start", %{"type" => "message_start", "message" => %{"usage" => %{"input_tokens" => 1}}}) <> + sse("message_delta", %{ + "type" => "message_delta", + "delta" => %{"stop_reason" => "end_turn"}, + "usage" => %{"output_tokens" => 1} + }) + end + + test "sends system_prompt as top-level system field" do + caller = self() + settings = Map.put(@settings, "system_prompt", "You are a helpful assistant.") + + stub(Finch, :stream, fn req, _name, acc, callback, _opts -> + send(caller, {:captured_body, Jason.decode!(req.body)}) + acc = callback.({:status, 200}, acc) + acc = callback.({:data, minimal_sse()}, acc) + {:ok, acc} + end) + + AnthropicMessages.stream(settings, @messages, caller) + + assert_receive {:captured_body, body} + assert body["system"] == "You are a helpful assistant." + end + + test "omits system field when system_prompt is absent" do + caller = self() + + stub(Finch, :stream, fn req, _name, acc, callback, _opts -> + send(caller, {:captured_body, Jason.decode!(req.body)}) + acc = callback.({:status, 200}, acc) + acc = callback.({:data, minimal_sse()}, acc) + {:ok, acc} + end) + + AnthropicMessages.stream(@settings, @messages, caller) + + assert_receive {:captured_body, body} + refute Map.has_key?(body, "system") + end + + test "filters system-role messages from the messages array" do + caller = self() + settings = Map.put(@settings, "system_prompt", "Be helpful.") + + messages = [ + %{"role" => "system", "content" => "Be helpful."}, + %{"role" => "user", "content" => "hello"} + ] + + stub(Finch, :stream, fn req, _name, acc, callback, _opts -> + send(caller, {:captured_body, Jason.decode!(req.body)}) + acc = callback.({:status, 200}, acc) + acc = callback.({:data, minimal_sse()}, acc) + {:ok, acc} + end) + + AnthropicMessages.stream(settings, messages, caller) + + assert_receive {:captured_body, body} + assert [%{"role" => "user"}] = body["messages"] + end + end +end diff --git a/test/extensions/ai_agent/adapter/chat_completions_test.exs b/test/extensions/ai_agent/adapter/chat_completions_test.exs new file mode 100644 index 000000000..02d29b4c4 --- /dev/null +++ b/test/extensions/ai_agent/adapter/chat_completions_test.exs @@ -0,0 +1,197 @@ +defmodule Extensions.AiAgent.Adapter.ChatCompletionsTest do + use ExUnit.Case, async: true + use Mimic + + alias Extensions.AiAgent.Adapter.ChatCompletions + alias Extensions.AiAgent.Types.Event + + @settings %{ + "model" => "gpt-4o", + "api_key" => "sk-test", + "base_url" => "https://api.openai.com/v1" + } + + @messages [%{"role" => "user", "content" => "hello"}] + + defp sse(json), do: "data: #{Jason.encode!(json)}\n\n" + defp done_chunk, do: "data: [DONE]\n\n" + + describe "stream/3" do + test "emits text_delta events from streamed chunks" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + chunk1 = sse(%{"choices" => [%{"delta" => %{"content" => "Hello"}, "finish_reason" => nil}]}) + chunk2 = sse(%{"choices" => [%{"delta" => %{"content" => " world"}, "finish_reason" => nil}]}) + finish = sse(%{"choices" => [%{"delta" => %{}, "finish_reason" => "stop"}]}) + + acc = callback.({:status, 200}, acc) + acc = callback.({:data, chunk1}, acc) + acc = callback.({:data, chunk2}, acc) + acc = callback.({:data, finish <> done_chunk()}, acc) + {:ok, acc} + end) + + assert :ok = ChatCompletions.stream(@settings, @messages, caller) + + assert_receive {:ai_event, %Event{type: :text_delta, payload: %{delta: "Hello"}}} + assert_receive {:ai_event, %Event{type: :text_delta, payload: %{delta: " world"}}} + assert_receive {:ai_event, %Event{type: :done, payload: %{stop_reason: "stop"}}} + end + + test "emits usage event when included in stream" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + finish = sse(%{"choices" => [%{"delta" => %{}, "finish_reason" => "stop"}]}) + usage = sse(%{"usage" => %{"prompt_tokens" => 10, "completion_tokens" => 20}}) + + acc = callback.({:status, 200}, acc) + acc = callback.({:data, finish <> usage <> done_chunk()}, acc) + {:ok, acc} + end) + + assert :ok = ChatCompletions.stream(@settings, @messages, caller) + + assert_receive {:ai_event, %Event{type: :usage, payload: %{input_tokens: 10, output_tokens: 20}}} + end + + test "emits error event on HTTP error status" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + acc = callback.({:status, 429}, acc) + {:ok, acc} + end) + + ChatCompletions.stream(@settings, @messages, caller) + + assert_receive {:ai_event, %Event{type: :error, payload: %{reason: "HTTP 429"}}} + end + + test "emits tool_call_delta and tool_call_done for tool calls" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + tc1 = + sse(%{ + "choices" => [ + %{ + "delta" => %{ + "tool_calls" => [ + %{ + "index" => 0, + "id" => "call_1", + "type" => "function", + "function" => %{"name" => "get_weather", "arguments" => ""} + } + ] + }, + "finish_reason" => nil + } + ] + }) + + tc2 = + sse(%{ + "choices" => [ + %{ + "delta" => %{"tool_calls" => [%{"index" => 0, "function" => %{"arguments" => "{\"city\":"}}]}, + "finish_reason" => nil + } + ] + }) + + tc3 = + sse(%{ + "choices" => [ + %{ + "delta" => %{"tool_calls" => [%{"index" => 0, "function" => %{"arguments" => "\"NYC\"}"}}]}, + "finish_reason" => nil + } + ] + }) + + finish = sse(%{"choices" => [%{"delta" => %{}, "finish_reason" => "tool_calls"}]}) + + acc = callback.({:status, 200}, acc) + acc = callback.({:data, tc1 <> tc2 <> tc3 <> finish <> done_chunk()}, acc) + {:ok, acc} + end) + + assert :ok = ChatCompletions.stream(@settings, @messages, caller) + + assert_receive {:ai_event, + %Event{type: :tool_call_delta, payload: %{tool_call_id: "call_1", name: "get_weather"}}} + + assert_receive {:ai_event, + %Event{ + type: :tool_call_done, + payload: %{tool_call_id: "call_1", name: "get_weather", arguments: "{\"city\":\"NYC\"}"} + }} + + assert_receive {:ai_event, %Event{type: :done, payload: %{stop_reason: "tool_calls"}}} + end + + test "returns error tuple when Finch fails" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, _callback, _opts -> + {:error, %Mint.TransportError{reason: :timeout}, acc} + end) + + assert {:error, _} = ChatCompletions.stream(@settings, @messages, caller) + end + + test "emits error event on HTTP 5xx" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + acc = callback.({:status, 500}, acc) + {:ok, acc} + end) + + assert :ok = ChatCompletions.stream(@settings, @messages, caller) + assert_receive {:ai_event, %Event{type: :error, payload: %{reason: "HTTP 500"}}} + end + + test "skips malformed SSE lines without crashing" do + caller = self() + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + bad_data = "data: {broken json\n\n" + finish = sse(%{"choices" => [%{"delta" => %{}, "finish_reason" => "stop"}]}) + + acc = callback.({:status, 200}, acc) + acc = callback.({:data, bad_data <> finish <> done_chunk()}, acc) + {:ok, acc} + end) + + assert :ok = ChatCompletions.stream(@settings, @messages, caller) + assert_receive {:ai_event, %Event{type: :done}} + refute_receive {:ai_event, %Event{type: :error}}, 50 + end + + test "handles chunks split across multiple data deliveries" do + caller = self() + + full = sse(%{"choices" => [%{"delta" => %{"content" => "split"}, "finish_reason" => nil}]}) + half1 = binary_part(full, 0, div(byte_size(full), 2)) + half2 = binary_part(full, div(byte_size(full), 2), byte_size(full) - div(byte_size(full), 2)) + finish = sse(%{"choices" => [%{"delta" => %{}, "finish_reason" => "stop"}]}) + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + acc = callback.({:status, 200}, acc) + acc = callback.({:data, half1}, acc) + acc = callback.({:data, half2}, acc) + acc = callback.({:data, finish <> done_chunk()}, acc) + {:ok, acc} + end) + + assert :ok = ChatCompletions.stream(@settings, @messages, caller) + + assert_receive {:ai_event, %Event{type: :text_delta, payload: %{delta: "split"}}} + assert_receive {:ai_event, %Event{type: :done, payload: %{stop_reason: "stop"}}} + end + end +end diff --git a/test/extensions/ai_agent/adapter_test.exs b/test/extensions/ai_agent/adapter_test.exs new file mode 100644 index 000000000..c8a14d42f --- /dev/null +++ b/test/extensions/ai_agent/adapter_test.exs @@ -0,0 +1,13 @@ +defmodule Extensions.AiAgent.AdapterTest do + use ExUnit.Case, async: true + + alias Extensions.AiAgent.Adapter + + describe "emit/2" do + test "sends ai_event message to caller" do + event = %Extensions.AiAgent.Types.Event{type: :text_delta, payload: %{delta: "hello"}} + Adapter.emit(self(), event) + assert_receive {:ai_event, ^event} + end + end +end diff --git a/test/extensions/ai_agent/ai_policies_test.exs b/test/extensions/ai_agent/ai_policies_test.exs new file mode 100644 index 000000000..645b2d411 --- /dev/null +++ b/test/extensions/ai_agent/ai_policies_test.exs @@ -0,0 +1,25 @@ +defmodule Extensions.AiAgent.AiPoliciesTest do + use ExUnit.Case, async: true + + alias Realtime.Tenants.Authorization.Policies + alias Realtime.Tenants.Authorization.Policies.AiPolicies + + describe "AiPolicies struct" do + test "new Policies has no AI write access — events are rejected until RLS grants it" do + assert %Policies{ai_agent: %AiPolicies{write: nil}} = %Policies{} + refute match?(%Policies{ai_agent: %AiPolicies{write: true}}, %Policies{}) + end + end + + describe "Policies struct includes ai_agent" do + test "update_policies sets ai_agent write" do + policies = Policies.update_policies(%Policies{}, :ai_agent, :write, true) + assert %Policies{ai_agent: %AiPolicies{write: true}} = policies + end + + test "update_policies sets ai_agent read" do + policies = Policies.update_policies(%Policies{}, :ai_agent, :read, false) + assert %Policies{ai_agent: %AiPolicies{read: false}} = policies + end + end +end diff --git a/test/extensions/ai_agent/broadcast_handler_ai_test.exs b/test/extensions/ai_agent/broadcast_handler_ai_test.exs new file mode 100644 index 000000000..ed8b18e88 --- /dev/null +++ b/test/extensions/ai_agent/broadcast_handler_ai_test.exs @@ -0,0 +1,147 @@ +defmodule Extensions.AiAgent.BroadcastHandlerAiTest do + use ExUnit.Case, async: true + use Mimic + + alias Extensions.AiAgent.Session + alias Realtime.Tenants.Authorization + alias Realtime.Tenants.Authorization.Policies + alias Realtime.Tenants.Authorization.Policies.AiPolicies + alias Realtime.Tenants.Connect + alias RealtimeWeb.RealtimeChannel.AiAgentHandler + + defp private_socket_with_ai_session(pid) do + %Phoenix.Socket{ + joined: true, + topic: "realtime:test", + transport_pid: self(), + serializer: Phoenix.Socket.V1.JSONSerializer, + join_ref: "1", + assigns: %{ + ai_session: pid, + private?: true, + tenant: "tenant-id", + tenant_topic: "tenant:topic", + self_broadcast: true, + ack_broadcast: false, + authorization_context: %Authorization{tenant_id: "tenant-id", topic: "topic"}, + policies: nil + } + } + end + + describe "handle/2 with AI events on private channel" do + setup do + stub(Connect, :lookup_or_start_connection, fn _id -> {:ok, self()} end) + :ok + end + + test "routes agent_input when ai_agent write policy is granted" do + test_pid = self() + session_pid = spawn(fn -> Process.sleep(:infinity) end) + + stub(Authorization, :get_write_authorizations, fn _policies, _conn, _ctx, _opts -> + {:ok, %Policies{ai_agent: %AiPolicies{write: true}}} + end) + + stub(Session, :handle_input, fn _pid, input -> + send(test_pid, {:routed_input, input}) + :ok + end) + + payload = %{"event" => "agent_input", "payload" => %{"text" => "hello"}} + {:noreply, _} = AiAgentHandler.handle(payload, :fake_conn, private_socket_with_ai_session(session_pid)) + + assert_receive {:routed_input, %{"text" => "hello"}} + end + + test "does not route agent_input when ai_agent write policy is denied" do + session_pid = spawn(fn -> Process.sleep(:infinity) end) + + stub(Authorization, :get_write_authorizations, fn _policies, _conn, _ctx, _opts -> + {:ok, %Policies{ai_agent: %AiPolicies{write: false}}} + end) + + reject(&Session.handle_input/2) + + payload = %{"event" => "agent_input", "payload" => %{"text" => "hello"}} + {:noreply, _} = AiAgentHandler.handle(payload, :fake_conn, private_socket_with_ai_session(session_pid)) + end + + test "routes agent_cancel when ai_agent write policy is granted" do + test_pid = self() + session_pid = spawn(fn -> Process.sleep(:infinity) end) + + stub(Authorization, :get_write_authorizations, fn _policies, _conn, _ctx, _opts -> + {:ok, %Policies{ai_agent: %AiPolicies{write: true}}} + end) + + stub(Session, :cancel, fn _pid -> + send(test_pid, :cancelled) + :ok + end) + + payload = %{"event" => "agent_cancel", "payload" => %{}} + {:noreply, _} = AiAgentHandler.handle(payload, :fake_conn, private_socket_with_ai_session(session_pid)) + + assert_receive :cancelled + end + end + + describe "handle/3 with binary-encoded AI events (V2 kind=3) on private channel" do + setup do + stub(Connect, :lookup_or_start_connection, fn _id -> {:ok, self()} end) + :ok + end + + test "routes agent_input from binary tuple when ai_agent write policy is granted" do + test_pid = self() + session_pid = spawn(fn -> Process.sleep(:infinity) end) + + stub(Authorization, :get_write_authorizations, fn _policies, _conn, _ctx, _opts -> + {:ok, %Policies{ai_agent: %AiPolicies{write: true}}} + end) + + stub(Session, :handle_input, fn _pid, input -> + send(test_pid, {:routed_input, input}) + :ok + end) + + payload = {"agent_input", :json, Jason.encode!(%{"text" => "hello"}), %{}} + {:noreply, _} = AiAgentHandler.handle(payload, :fake_conn, private_socket_with_ai_session(session_pid)) + + assert_receive {:routed_input, %{"text" => "hello"}} + end + + test "does not route agent_input from binary tuple when ai_agent write policy is denied" do + session_pid = spawn(fn -> Process.sleep(:infinity) end) + + stub(Authorization, :get_write_authorizations, fn _policies, _conn, _ctx, _opts -> + {:ok, %Policies{ai_agent: %AiPolicies{write: false}}} + end) + + reject(&Session.handle_input/2) + + payload = {"agent_input", :json, Jason.encode!(%{"text" => "hello"}), %{}} + {:noreply, _} = AiAgentHandler.handle(payload, :fake_conn, private_socket_with_ai_session(session_pid)) + end + + test "routes agent_cancel from binary tuple when ai_agent write policy is granted" do + test_pid = self() + session_pid = spawn(fn -> Process.sleep(:infinity) end) + + stub(Authorization, :get_write_authorizations, fn _policies, _conn, _ctx, _opts -> + {:ok, %Policies{ai_agent: %AiPolicies{write: true}}} + end) + + stub(Session, :cancel, fn _pid -> + send(test_pid, :cancelled) + :ok + end) + + payload = {"agent_cancel", :json, Jason.encode!(%{}), %{}} + {:noreply, _} = AiAgentHandler.handle(payload, :fake_conn, private_socket_with_ai_session(session_pid)) + + assert_receive :cancelled + end + end +end diff --git a/test/extensions/ai_agent/db_settings_test.exs b/test/extensions/ai_agent/db_settings_test.exs new file mode 100644 index 000000000..61e7f72dd --- /dev/null +++ b/test/extensions/ai_agent/db_settings_test.exs @@ -0,0 +1,38 @@ +defmodule Extensions.AiAgent.DbSettingsTest do + use ExUnit.Case, async: true + + alias Extensions.AiAgent.DbSettings + + describe "default/0" do + test "fills max_concurrent_sessions when absent from settings" do + default = DbSettings.default() + assert default["max_concurrent_sessions"] == 10 + end + end + + describe "required/0" do + test "api_key is required and will be encrypted" do + required = DbSettings.required() + assert {"api_key", _, true} = List.keyfind!(required, "api_key", 0) + end + + test "model, protocol, and base_url are required but not encrypted" do + required = DbSettings.required() + assert {"model", _, false} = List.keyfind!(required, "model", 0) + assert {"protocol", _, false} = List.keyfind!(required, "protocol", 0) + assert {"base_url", _, false} = List.keyfind!(required, "base_url", 0) + end + + test "validators accept binary values" do + for {_name, validator, _flag} <- DbSettings.required() do + assert validator.("value") == true + end + end + + test "validators reject non-binary values" do + for {_name, validator, _flag} <- DbSettings.required() do + assert validator.(123) == false + end + end + end +end diff --git a/test/extensions/ai_agent/event_test.exs b/test/extensions/ai_agent/event_test.exs new file mode 100644 index 000000000..1e8859982 --- /dev/null +++ b/test/extensions/ai_agent/event_test.exs @@ -0,0 +1,15 @@ +defmodule Extensions.AiAgent.Types.EventTest do + use ExUnit.Case, async: true + + alias Extensions.AiAgent.Types.Event + + describe "broadcast_event/1" do + test "prefixes type with agent_" do + assert Event.broadcast_event(%Event{type: :text_delta, payload: %{}}) == "agent_text_delta" + assert Event.broadcast_event(%Event{type: :done, payload: %{}}) == "agent_done" + assert Event.broadcast_event(%Event{type: :error, payload: %{}}) == "agent_error" + assert Event.broadcast_event(%Event{type: :tool_call_done, payload: %{}}) == "agent_tool_call_done" + assert Event.broadcast_event(%Event{type: :session_started, payload: %{}}) == "agent_session_started" + end + end +end diff --git a/test/extensions/ai_agent/session_persistence_test.exs b/test/extensions/ai_agent/session_persistence_test.exs new file mode 100644 index 000000000..9815635ac --- /dev/null +++ b/test/extensions/ai_agent/session_persistence_test.exs @@ -0,0 +1,253 @@ +defmodule Extensions.AiAgent.SessionPersistenceTest do + use ExUnit.Case, async: true + use Mimic + + alias Extensions.AiAgent.Session + alias Realtime.Api.Message + alias Realtime.Tenants.Connect + alias Realtime.Tenants.Repo + + @encrypted_key Realtime.Crypto.encrypt!("sk-test") + + @settings %{ + "protocol" => "openai_compatible", + "base_url" => "https://api.openai.com/v1", + "model" => "gpt-4o", + "api_key" => @encrypted_key + } + + defp start_session(overrides \\ []) do + topic = "test-tenant:private:agent:" <> UUID.uuid4() + Phoenix.PubSub.subscribe(Realtime.PubSub, topic) + + opts = + Keyword.merge( + [ + tenant_id: "test-tenant", + tenant_topic: topic, + settings: @settings, + channel_pid: self() + ], + overrides + ) + + pid = start_supervised!({Session, opts}) + Mimic.allow(Finch, self(), pid) + Mimic.allow(Connect, self(), pid) + Mimic.allow(Repo, self(), pid) + GenServer.cast(pid, :emit_session_started) + pid + end + + defp sse_text(text) do + data = Jason.encode!(%{"choices" => [%{"delta" => %{"content" => text}, "finish_reason" => nil}]}) + "data: #{data}\n\n" + end + + defp sse_done do + data = Jason.encode!(%{"choices" => [%{"delta" => %{}, "finish_reason" => "stop"}]}) + "data: #{data}\n\ndata: [DONE]\n\n" + end + + describe "session_id in session_started event" do + test "broadcasts a fresh UUID when no client session_id provided" do + start_session() + + assert_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{ + "event" => "agent_session_started", + "payload" => %{session_id: <<_::binary-size(36)>>} + } + }, + 500 + end + + test "uses and broadcasts the client-provided session_id" do + client_id = UUID.uuid4() + stub(Connect, :lookup_or_start_connection, fn _ -> {:error, :not_found} end) + + start_session(session_id: client_id) + + assert_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_session_started", "payload" => %{session_id: ^client_id}} + }, + 500 + end + end + + describe "persist_async: user messages" do + test "persists user message to tenant DB on text input" do + test_pid = self() + + stub(Connect, :lookup_or_start_connection, fn _ -> {:ok, :fake_conn} end) + + stub(Repo, :insert_all_entries, fn _conn, changesets, _struct -> + messages = Enum.map(changesets, & &1.changes) + send(test_pid, {:persisted, messages}) + {:ok, []} + end) + + pid = start_session() + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + Mimic.allow(Connect, self(), pid) + Mimic.allow(Repo, self(), pid) + + Session.handle_input(pid, %{"text" => "Hello agent"}) + + assert_receive {:persisted, + [ + %{ + extension: :ai_agent, + payload: %{"role" => "user", "content" => "Hello agent", "session_id" => session_id} + }, + %{ + extension: :ai_agent_event, + event: "agent_input", + payload: %{"text" => "Hello agent"} + } + ]}, + 500 + + assert is_binary(session_id) + end + + test "persists tool_result message to tenant DB" do + test_pid = self() + + stub(Connect, :lookup_or_start_connection, fn _ -> {:ok, :fake_conn} end) + + stub(Repo, :insert_all_entries, fn _conn, changesets, _struct -> + messages = Enum.map(changesets, & &1.changes) + send(test_pid, {:persisted, messages}) + {:ok, []} + end) + + pid = start_session() + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + Mimic.allow(Connect, self(), pid) + Mimic.allow(Repo, self(), pid) + + Session.handle_input(pid, %{"tool_result" => %{"tool_call_id" => "call_1", "content" => "42 degrees"}}) + + assert_receive {:persisted, [%{payload: %{"role" => "tool", "content" => "42 degrees"}}]}, 500 + end + end + + describe "persist_async: assistant messages" do + test "accumulates text deltas and persists full assistant message on done" do + test_pid = self() + + stub(Connect, :lookup_or_start_connection, fn _ -> {:ok, :fake_conn} end) + + stub(Repo, :insert_all_entries, fn _conn, changesets, _struct -> + messages = Enum.map(changesets, & &1.changes) + send(test_pid, {:persisted, messages}) + {:ok, []} + end) + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + acc = callback.({:status, 200}, acc) + acc = callback.({:data, sse_text("Hello") <> sse_text(" world") <> sse_done()}, acc) + {:ok, acc} + end) + + pid = start_session() + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + Mimic.allow(Connect, self(), pid) + Mimic.allow(Repo, self(), pid) + + Session.handle_input(pid, %{"text" => "Say hi"}) + + assert_receive {:persisted, [%{payload: %{"role" => "user"}}, %{extension: :ai_agent_event}]}, 500 + + assert_receive {:persisted, + [%{payload: %{"role" => "assistant", "content" => "Hello world"}}, %{extension: :ai_agent_event}]}, + 1000 + end + + test "does not persist assistant message when stream is empty (tool-call only turn)" do + test_pid = self() + + stub(Connect, :lookup_or_start_connection, fn _ -> {:ok, :fake_conn} end) + + stub(Repo, :insert_all_entries, fn _conn, changesets, _struct -> + messages = Enum.map(changesets, & &1.changes) + send(test_pid, {:persisted, messages}) + {:ok, []} + end) + + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + finish = Jason.encode!(%{"choices" => [%{"delta" => %{}, "finish_reason" => "tool_calls"}]}) + acc = callback.({:status, 200}, acc) + acc = callback.({:data, "data: #{finish}\n\ndata: [DONE]\n\n"}, acc) + {:ok, acc} + end) + + pid = start_session() + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + Mimic.allow(Connect, self(), pid) + Mimic.allow(Repo, self(), pid) + + Session.handle_input(pid, %{"text" => "Run a tool"}) + + assert_receive {:persisted, [%{payload: %{"role" => "user"}}, %{extension: :ai_agent_event}]}, 500 + refute_receive {:persisted, [%{payload: %{"role" => "assistant", "content" => _}} | _]}, 200 + end + end + + describe "load_history: resuming a session" do + test "loads prior messages when client session_id is provided" do + client_id = UUID.uuid4() + + prior_message = %Message{ + topic: "test-tenant:private:agent:some-topic", + extension: :ai_agent, + payload: %{"role" => "user", "content" => "Prior question", "session_id" => client_id}, + inserted_at: NaiveDateTime.utc_now() + } + + stub(Connect, :lookup_or_start_connection, fn _ -> {:ok, :fake_conn} end) + stub(Repo, :all, fn _conn, _query, Message -> {:ok, [prior_message]} end) + + pid = start_session(session_id: client_id) + + assert_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_session_started", "payload" => %{session_id: ^client_id}} + }, + 500 + + test_pid = self() + + stub(Finch, :stream, fn _req, _name, acc, _callback, _opts -> + send(test_pid, :stream_called) + {:ok, acc} + end) + + Mimic.allow(Finch, self(), pid) + Session.handle_input(pid, %{"text" => "Follow-up"}) + assert_receive :stream_called, 500 + end + + test "starts fresh when no client session_id (does not query DB)" do + reject(&Repo.all/3) + start_session() + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + end + + test "starts fresh gracefully when DB connection is unavailable" do + client_id = UUID.uuid4() + stub(Connect, :lookup_or_start_connection, fn _ -> {:error, :unavailable} end) + + start_session(session_id: client_id) + + assert_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_session_started", "payload" => %{session_id: ^client_id}} + }, + 500 + end + end +end diff --git a/test/extensions/ai_agent/session_rate_limit_test.exs b/test/extensions/ai_agent/session_rate_limit_test.exs new file mode 100644 index 000000000..46872c412 --- /dev/null +++ b/test/extensions/ai_agent/session_rate_limit_test.exs @@ -0,0 +1,154 @@ +defmodule Extensions.AiAgent.SessionRateLimitTest do + use ExUnit.Case, async: true + use Mimic + + alias Extensions.AiAgent.Session + + @encrypted_key Realtime.Crypto.encrypt!("sk-test") + + @settings %{ + "protocol" => "openai_compatible", + "base_url" => "https://api.openai.com/v1", + "model" => "gpt-4o", + "api_key" => @encrypted_key + } + + defp start_session(overrides \\ []) do + topic = "test-tenant:private:agent:" <> UUID.uuid4() + Phoenix.PubSub.subscribe(Realtime.PubSub, topic) + + opts = + Keyword.merge( + [ + tenant_id: "test-tenant", + tenant_topic: topic, + settings: @settings, + channel_pid: self(), + max_ai_events_per_second: 100, + max_ai_tokens_per_minute: 60_000 + ], + overrides + ) + + pid = start_supervised!({Session, opts}) + Mimic.allow(Finch, self(), pid) + Mimic.allow(Realtime.RateCounter, self(), pid) + GenServer.cast(pid, :emit_session_started) + {pid, topic} + end + + describe "max_ai_events_per_second" do + test "allows inputs when rate counter is not triggered" do + stub(Finch, :stream, fn _req, _name, acc, _callback, _opts -> {:ok, acc} end) + stub(Realtime.RateCounter, :get, fn _ -> {:ok, %Realtime.RateCounter{limit: %{triggered: false}}} end) + + {pid, _topic} = start_session(max_ai_events_per_second: 100) + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + Session.handle_input(pid, %{"text" => "hello"}) + + refute_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_error", "payload" => %{reason: "rate_limit_exceeded"}} + }, + 200 + end + + test "rejects inputs when rate counter limit is triggered" do + stub(Finch, :stream, fn _req, _name, acc, _callback, _opts -> {:ok, acc} end) + stub(Realtime.RateCounter, :get, fn _ -> {:ok, %Realtime.RateCounter{limit: %{triggered: true}}} end) + + {pid, _topic} = start_session(max_ai_events_per_second: 1) + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + Session.handle_input(pid, %{"text" => "over limit"}) + + assert_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_error", "payload" => %{reason: "rate_limit_exceeded"}} + }, + 500 + end + end + + describe "max_ai_tokens_per_minute" do + test "allows inputs when token budget is not exhausted" do + stub(Finch, :stream, fn _req, _name, acc, _callback, _opts -> {:ok, acc} end) + {pid, _topic} = start_session(max_ai_tokens_per_minute: 10_000) + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + Session.handle_input(pid, %{"text" => "hello"}) + + refute_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_error", "payload" => %{reason: "token_limit_exceeded"}} + }, + 200 + end + + test "rejects input when token budget is exhausted" do + stub(Finch, :stream, fn _req, _name, acc, _callback, _opts -> {:ok, acc} end) + {pid, _topic} = start_session(max_ai_tokens_per_minute: 5) + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + usage_event = %Extensions.AiAgent.Types.Event{type: :usage, payload: %{input_tokens: 5, output_tokens: 5}} + send(pid, {:ai_event, usage_event}) + Process.sleep(10) + + Session.handle_input(pid, %{"text" => "over budget"}) + + assert_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_error", "payload" => %{reason: "token_limit_exceeded"}} + }, + 500 + end + + test "resets token budget after one minute window" do + stub(Finch, :stream, fn _req, _name, acc, _callback, _opts -> {:ok, acc} end) + {pid, _topic} = start_session(max_ai_tokens_per_minute: 5) + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + usage_event = %Extensions.AiAgent.Types.Event{type: :usage, payload: %{input_tokens: 3, output_tokens: 3}} + send(pid, {:ai_event, usage_event}) + Process.sleep(10) + + Session.handle_input(pid, %{"text" => "over budget"}) + + assert_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_error", "payload" => %{reason: "token_limit_exceeded"}} + }, + 500 + + send(pid, :reset_token_window) + Process.sleep(10) + + Session.handle_input(pid, %{"text" => "after reset"}) + + refute_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_error", "payload" => %{reason: "token_limit_exceeded"}} + }, + 200 + end + + test "zero max_ai_tokens_per_minute disables token limit" do + stub(Finch, :stream, fn _req, _name, acc, _callback, _opts -> {:ok, acc} end) + {pid, _topic} = start_session(max_ai_tokens_per_minute: 0) + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + usage_event = %Extensions.AiAgent.Types.Event{type: :usage, payload: %{input_tokens: 9_999_999, output_tokens: 0}} + send(pid, {:ai_event, usage_event}) + Process.sleep(10) + + Session.handle_input(pid, %{"text" => "should still work"}) + + refute_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_error", "payload" => %{reason: "token_limit_exceeded"}} + }, + 200 + end + end +end diff --git a/test/extensions/ai_agent/session_supervisor_test.exs b/test/extensions/ai_agent/session_supervisor_test.exs new file mode 100644 index 000000000..05ac8a55a --- /dev/null +++ b/test/extensions/ai_agent/session_supervisor_test.exs @@ -0,0 +1,33 @@ +defmodule Extensions.AiAgent.SessionSupervisorTest do + use ExUnit.Case, async: true + + alias Extensions.AiAgent.SessionSupervisor + + @encrypted_key Realtime.Crypto.encrypt!("sk-test") + + @base_opts [ + tenant_id: "test-tenant", + tenant_topic: "test-tenant:private:agent:chat", + settings: %{ + "protocol" => "openai_compatible", + "base_url" => "https://api.openai.com/v1", + "model" => "gpt-4o", + "api_key" => @encrypted_key + } + ] + + defp spawn_channel do + pid = spawn(fn -> Process.sleep(:infinity) end) + on_exit(fn -> Process.exit(pid, :kill) end) + pid + end + + describe "start_session/1" do + test "starts a session and returns pid" do + opts = Keyword.put(@base_opts, :channel_pid, spawn_channel()) + assert {:ok, pid} = SessionSupervisor.start_session(opts) + assert is_pid(pid) + assert Process.alive?(pid) + end + end +end diff --git a/test/extensions/ai_agent/session_test.exs b/test/extensions/ai_agent/session_test.exs new file mode 100644 index 000000000..8719add76 --- /dev/null +++ b/test/extensions/ai_agent/session_test.exs @@ -0,0 +1,183 @@ +defmodule Extensions.AiAgent.SessionTest do + use ExUnit.Case, async: true + use Mimic + + alias Extensions.AiAgent.Session + + @encrypted_key Realtime.Crypto.encrypt!("sk-test") + + @settings %{ + "protocol" => "openai_compatible", + "base_url" => "https://api.openai.com/v1", + "model" => "gpt-4o", + "api_key" => @encrypted_key + } + + defp start_session(overrides \\ []) do + topic = "test-tenant:private:agent:" <> UUID.uuid4() + Phoenix.PubSub.subscribe(Realtime.PubSub, topic) + + opts = + Keyword.merge( + [ + tenant_id: "test-tenant", + tenant_topic: topic, + settings: @settings, + channel_pid: self() + ], + overrides + ) + + pid = start_supervised!({Session, opts}) + Mimic.allow(Finch, self(), pid) + GenServer.cast(pid, :emit_session_started) + pid + end + + defp sse_text(text) do + data = Jason.encode!(%{"choices" => [%{"delta" => %{"content" => text}, "finish_reason" => nil}]}) + "data: #{data}\n\n" + end + + defp sse_done do + data = Jason.encode!(%{"choices" => [%{"delta" => %{}, "finish_reason" => "stop"}]}) + "data: #{data}\n\ndata: [DONE]\n\n" + end + + describe "start_link/1" do + test "starts successfully with valid settings" do + pid = start_session() + assert is_pid(pid) + assert Process.alive?(pid) + end + + test "emits session_started event on init" do + start_session() + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + end + end + + describe "handle_input/2 with text" do + test "sends text_delta directly to channel and done via PubSub" do + stub(Finch, :stream, fn _req, _name, acc, callback, _opts -> + acc = callback.({:status, 200}, acc) + acc = callback.({:data, sse_text("Hello") <> sse_done()}, acc) + {:ok, acc} + end) + + pid = start_session() + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + Session.handle_input(pid, %{"text" => "Say hello"}) + + assert_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_text_delta", "payload" => %{delta: "Hello"}} + }, + 1000 + + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_done"}}, 1000 + end + + test "rejects text larger than 64KB" do + pid = start_session() + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + large_text = String.duplicate("a", 65_000) + Session.handle_input(pid, %{"text" => large_text}) + + assert_receive %Phoenix.Socket.Broadcast{ + event: "ai_event", + payload: %{"event" => "agent_error", "payload" => %{reason: "input_too_large"}} + }, + 500 + end + end + + describe "cancel/1" do + test "cancels an in-flight stream" do + stub(Finch, :stream, fn _req, _name, acc, _callback, _opts -> + Process.sleep(:infinity) + {:ok, acc} + end) + + pid = start_session() + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + Session.handle_input(pid, %{"text" => "This will be cancelled"}) + Session.cancel(pid) + + refute_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_done"}}, 200 + end + end + + describe "adapter crash isolation" do + test "broadcasts error event when adapter raises instead of crashing session" do + stub(Finch, :stream, fn _req, _name, _acc, _callback, _opts -> + raise "simulated provider crash" + end) + + pid = start_session() + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + Session.handle_input(pid, %{"text" => "trigger crash"}) + + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_error"}}, 1000 + assert Process.alive?(pid) + end + end + + describe "channel process termination" do + test "session stops when channel process dies" do + channel_pid = spawn(fn -> Process.sleep(:infinity) end) + pid = start_session(channel_pid: channel_pid) + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + ref = Process.monitor(pid) + Process.exit(channel_pid, :kill) + + assert_receive {:DOWN, ^ref, :process, ^pid, :normal}, 1000 + end + end + + describe "system_prompt initialization" do + test "includes system message first in request when system_prompt is set" do + caller = self() + settings = Map.put(@settings, "system_prompt", "You are a pirate.") + + stub(Finch, :stream, fn req, _name, acc, callback, _opts -> + send(caller, {:captured_body, Jason.decode!(req.body)}) + acc = callback.({:status, 200}, acc) + acc = callback.({:data, sse_done()}, acc) + {:ok, acc} + end) + + pid = start_session(settings: settings) + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + Session.handle_input(pid, %{"text" => "ahoy"}) + + assert_receive {:captured_body, body}, 1000 + assert [%{"role" => "system", "content" => "You are a pirate."} | _] = body["messages"] + end + + test "sends no system message when system_prompt is absent" do + caller = self() + + stub(Finch, :stream, fn req, _name, acc, callback, _opts -> + send(caller, {:captured_body, Jason.decode!(req.body)}) + acc = callback.({:status, 200}, acc) + acc = callback.({:data, sse_done()}, acc) + {:ok, acc} + end) + + pid = start_session() + assert_receive %Phoenix.Socket.Broadcast{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 500 + + Session.handle_input(pid, %{"text" => "hello"}) + + assert_receive {:captured_body, body}, 1000 + refute Enum.any?(body["messages"], &(&1["role"] == "system")) + end + end +end diff --git a/test/extensions/extensions_test.exs b/test/extensions/extensions_test.exs index eb98c508a..4d31e620c 100644 --- a/test/extensions/extensions_test.exs +++ b/test/extensions/extensions_test.exs @@ -38,5 +38,16 @@ defmodule Realtime.ExtensionsTest do result = Extensions.db_settings("unknown_extension") assert %{default: %{}, required: []} = result end + + test "returns default and required for ai_agent" do + result = Extensions.db_settings("ai_agent") + + assert %{default: default, required: required} = result + assert is_map(default) + assert Map.has_key?(default, "max_concurrent_sessions") + assert is_list(required) + field_names = Enum.map(required, fn {name, _, _} -> name end) + assert "api_key" in field_names + end end end diff --git a/test/integration/ai_agent/live_smoke_test.exs b/test/integration/ai_agent/live_smoke_test.exs new file mode 100644 index 000000000..b13e4c25c --- /dev/null +++ b/test/integration/ai_agent/live_smoke_test.exs @@ -0,0 +1,211 @@ +defmodule Realtime.Integration.AiAgent.LiveSmokeTest do + @moduledoc """ + End-to-end integration tests for the AI agent extension using a real Ollama + instance. These tests require Ollama to be running and a model to be available. + + Run with: + mix test --include live_llm + + Or set OLLAMA_HOST and OLLAMA_MODEL env vars to use an external instance: + OLLAMA_HOST=http://my-ollama:11434 OLLAMA_MODEL=llama3.2:1b mix test --include live_llm + """ + + use RealtimeWeb.ConnCase, async: false + + import Generators + import Integrations + + alias Phoenix.Socket.Message + alias Realtime.Database + alias Realtime.Integration.WebsocketClient + + @moduletag :live_llm + @moduletag :capture_log + @moduletag timeout: 120_000 + + @agent_name "smoke-agent" + @agent_topic "agent:smoke:test" + @serializer Phoenix.Socket.V1.JSONSerializer + + setup_all do + case Ollama.ensure_ready() do + :ok -> + :ok + + {:error, reason} -> + raise "Ollama not available for :live_llm tests: #{reason}\n" <> + "Set OLLAMA_HOST env var or ensure Docker is running.\n" <> + "Skipping with `mix test` (no --include live_llm) is fine." + end + end + + setup do + %{tenant: base_tenant} = checkout_tenant_and_connect() + tenant = add_ai_agent_extension(base_tenant, @agent_name) + + {:ok, db_conn} = Database.connect(tenant, "realtime_test", :stop) + clean_table(db_conn, "realtime", "messages") + create_rls_policies(db_conn, [:authenticated_all_topic_read, :authenticated_all_topic_insert], %{}) + + ExUnit.Callbacks.on_exit(fn -> + if Process.alive?(db_conn), do: GenServer.stop(db_conn, :normal, 1_000) + end) + + %{tenant: tenant} + end + + describe "AI agent via WebSocket" do + test "receives streaming text response for a simple prompt", %{tenant: tenant} do + {socket, _} = get_connection(tenant, @serializer, role: "authenticated") + topic = "realtime:#{@agent_topic}" + config = %{private: true, broadcast: %{self: true}, ai: %{enabled: true, agent: @agent_name}} + + WebsocketClient.join(socket, topic, %{config: config}) + assert_receive %Message{event: "phx_reply", payload: %{"status" => "ok"}, topic: ^topic}, 5_000 + + assert_receive %Message{ + event: "ai_event", + payload: %{"event" => "agent_session_started", "payload" => %{"session_id" => session_id}} + }, + 5_000 + + assert is_binary(session_id) + + WebsocketClient.send_event(socket, topic, "broadcast", %{ + "event" => "agent_input", + "type" => "broadcast", + "payload" => %{"text" => "Reply with exactly the word: pong"} + }) + + {text, stop_reason} = collect_response(topic, 30_000) + + assert is_binary(text) and byte_size(text) > 0, + "Expected non-empty text response, got: #{inspect(text)}" + + assert stop_reason in ["stop", "end_turn", "length"], + "Unexpected stop reason: #{inspect(stop_reason)}" + end + + test "session_id is preserved on reconnect", %{tenant: tenant} do + topic = "realtime:#{@agent_topic}" + config = %{private: true, broadcast: %{self: true}, ai: %{enabled: true, agent: @agent_name}} + + {socket, _} = get_connection(tenant, @serializer, role: "authenticated") + WebsocketClient.join(socket, topic, %{config: config}) + assert_receive %Message{event: "phx_reply", payload: %{"status" => "ok"}, topic: ^topic}, 5_000 + + assert_receive %Message{ + event: "ai_event", + payload: %{"event" => "agent_session_started", "payload" => %{"session_id" => session_id}} + }, + 5_000 + + WebsocketClient.send_event(socket, topic, "broadcast", %{ + "event" => "agent_input", + "type" => "broadcast", + "payload" => %{"text" => "Remember the number 42. Just say 'OK'."} + }) + + {_text, _} = collect_response(topic, 30_000) + WebsocketClient.close(socket) + + {socket2, _} = get_connection(tenant, @serializer, role: "authenticated") + + config2 = %{ + private: true, + broadcast: %{self: true}, + ai: %{enabled: true, agent: @agent_name, session_id: session_id} + } + + WebsocketClient.join(socket2, topic, %{config: config2}) + assert_receive %Message{event: "phx_reply", payload: %{"status" => "ok"}, topic: ^topic}, 5_000 + + assert_receive %Message{ + event: "ai_event", + payload: %{"event" => "agent_session_started", "payload" => %{"session_id" => ^session_id}} + }, + 5_000 + + WebsocketClient.send_event(socket2, topic, "broadcast", %{ + "event" => "agent_input", + "type" => "broadcast", + "payload" => %{"text" => "What number did I ask you to remember?"} + }) + + {text2, _} = collect_response(topic, 30_000) + assert is_binary(text2) and byte_size(text2) > 0 + end + + test "cancels an in-flight response", %{tenant: tenant} do + {socket, _} = get_connection(tenant, @serializer, role: "authenticated") + topic = "realtime:#{@agent_topic}" + config = %{private: true, broadcast: %{self: true}, ai: %{enabled: true, agent: @agent_name}} + + WebsocketClient.join(socket, topic, %{config: config}) + assert_receive %Message{event: "phx_reply", payload: %{"status" => "ok"}, topic: ^topic}, 5_000 + assert_receive %Message{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 5_000 + + WebsocketClient.send_event(socket, topic, "broadcast", %{ + "event" => "agent_input", + "type" => "broadcast", + "payload" => %{"text" => "Count from 1 to 1000 slowly."} + }) + + assert_receive %Message{event: "ai_event", payload: %{"event" => "agent_text_delta"}}, 15_000 + + WebsocketClient.send_event(socket, topic, "broadcast", %{ + "event" => "agent_cancel", + "type" => "broadcast", + "payload" => %{} + }) + + refute_receive %Message{event: "ai_event", payload: %{"event" => "agent_done"}}, 1_000 + end + + test "returns error event on invalid model", %{tenant: tenant} do + bad_tenant = add_ai_agent_extension(tenant, "broken-agent", %{model: "nonexistent-model-xyz"}) + + {socket, _} = get_connection(bad_tenant, @serializer, role: "authenticated") + topic = "realtime:#{@agent_topic}" + config = %{private: true, broadcast: %{self: true}, ai: %{enabled: true, agent: "broken-agent"}} + + WebsocketClient.join(socket, topic, %{config: config}) + assert_receive %Message{event: "phx_reply", payload: %{"status" => "ok"}, topic: ^topic}, 5_000 + assert_receive %Message{event: "ai_event", payload: %{"event" => "agent_session_started"}}, 5_000 + + WebsocketClient.send_event(socket, topic, "broadcast", %{ + "event" => "agent_input", + "type" => "broadcast", + "payload" => %{"text" => "Hello"} + }) + + assert_receive %Message{event: "ai_event", payload: %{"event" => "agent_error"}}, 15_000 + end + end + + defp collect_response(topic, timeout), do: collect_response(topic, timeout, "", nil) + + defp collect_response(topic, timeout, text_acc, _stop_reason) do + receive do + %Message{ + event: "ai_event", + payload: %{"event" => "agent_text_delta", "payload" => %{"delta" => delta}}, + topic: ^topic + } -> + collect_response(topic, timeout, text_acc <> delta, nil) + + %Message{ + event: "ai_event", + payload: %{"event" => "agent_done", "payload" => %{"stop_reason" => reason}}, + topic: ^topic + } -> + {text_acc, reason} + + %Message{event: "ai_event", payload: %{"event" => "agent_error", "payload" => %{"reason" => reason}}} -> + raise "Agent returned error: #{inspect(reason)}" + after + timeout -> + raise "Timed out after #{timeout}ms. Collected so far: #{inspect(text_acc)}" + end + end +end diff --git a/test/realtime/api/extensions_test.exs b/test/realtime/api/extensions_test.exs index 3cec96703..7a7fe89e0 100644 --- a/test/realtime/api/extensions_test.exs +++ b/test/realtime/api/extensions_test.exs @@ -1,5 +1,6 @@ defmodule Realtime.Api.ExtensionsTest do use ExUnit.Case, async: true + use Mimic alias Realtime.Api.Extensions @@ -103,4 +104,145 @@ defmodule Realtime.Api.ExtensionsTest do assert settings["region"] == "us-east-1" end end + + defp ai_agent_attrs(settings) do + base = %{ + "protocol" => "anthropic", + "base_url" => "https://api.example.com", + "api_key" => "key", + "model" => "claude-3-5-sonnet-latest" + } + + %{"type" => "ai_agent", "name" => "my_agent", "settings" => Map.merge(base, settings)} + end + + describe "changeset/2 for ai_agent type" do + test "requires name" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{1, 2, 3, 4}]} end) + + attrs = %{ + "type" => "ai_agent", + "settings" => %{"base_url" => "https://api.example.com", "api_key" => "key", "model" => "m"} + } + + changeset = Extensions.changeset(%Extensions{}, attrs) + refute changeset.valid? + assert {"can't be blank", _} = changeset.errors[:name] + end + + test "valid with https base_url resolving to public IP" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{1, 2, 3, 4}]} end) + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{})) + assert changeset.valid? + end + + test "rejects base_url resolving to private IP (10.x.x.x)" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{10, 0, 0, 1}]} end) + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{})) + refute changeset.valid? + assert {"base_url resolves to a private or reserved address", _} = changeset.errors[:settings] + end + + test "rejects base_url resolving to private IP (192.168.x.x)" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{192, 168, 1, 1}]} end) + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{})) + refute changeset.valid? + assert {"base_url resolves to a private or reserved address", _} = changeset.errors[:settings] + end + + test "rejects base_url when host cannot be resolved" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:error, :nxdomain} end) + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{})) + refute changeset.valid? + assert {"base_url host cannot be resolved", _} = changeset.errors[:settings] + end + + test "allows http base_url for localhost" do + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{"base_url" => "http://localhost:4000"})) + assert changeset.valid? + end + + test "allows http base_url for 127.0.0.1" do + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{"base_url" => "http://127.0.0.1:4000"})) + assert changeset.valid? + end + + test "rejects http base_url for non-loopback host" do + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{"base_url" => "http://api.example.com"})) + refute changeset.valid? + + assert {"base_url with http scheme is only permitted for loopback hosts (localhost, 127.0.0.1)", _} = + changeset.errors[:settings] + end + + test "rejects base_url with no scheme" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{1, 2, 3, 4}]} end) + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{"base_url" => "api.example.com"})) + refute changeset.valid? + assert {"base_url must use https scheme", _} = changeset.errors[:settings] + end + + test "rejects base_url that is not a string" do + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{"base_url" => 1234})) + refute changeset.valid? + end + + test "rejects nil base_url as blank" do + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{"base_url" => nil})) + refute changeset.valid? + assert {"base_url can't be blank", _} = changeset.errors[:settings] + end + + test "rejects http_referer containing newline" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{1, 2, 3, 4}]} end) + + changeset = + Extensions.changeset( + %Extensions{}, + ai_agent_attrs(%{"http_referer" => "https://evil.com\r\nX-Injected: header"}) + ) + + refute changeset.valid? + assert {"http_referer contains invalid characters", _} = changeset.errors[:settings] + end + + test "rejects x_title containing carriage return" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{1, 2, 3, 4}]} end) + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{"x_title" => "title\rinjected"})) + refute changeset.valid? + assert {"x_title contains invalid characters", _} = changeset.errors[:settings] + end + + test "rejects http_referer that is not a string" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{1, 2, 3, 4}]} end) + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{"http_referer" => 123})) + refute changeset.valid? + assert {"http_referer must be a string", _} = changeset.errors[:settings] + end + + test "accepts string system_prompt" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{1, 2, 3, 4}]} end) + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{"system_prompt" => "You are helpful."})) + assert changeset.valid? + end + + test "accepts nil system_prompt" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{1, 2, 3, 4}]} end) + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{"system_prompt" => nil})) + assert changeset.valid? + end + + test "accepts absent system_prompt" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{1, 2, 3, 4}]} end) + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{})) + assert changeset.valid? + end + + test "rejects non-string system_prompt" do + stub(Realtime.DNS, :getaddrs, fn _host, :inet, _timeout -> {:ok, [{1, 2, 3, 4}]} end) + changeset = Extensions.changeset(%Extensions{}, ai_agent_attrs(%{"system_prompt" => 123})) + refute changeset.valid? + assert {"system_prompt must be a string", _} = changeset.errors[:settings] + end + end end diff --git a/test/realtime/messages_test.exs b/test/realtime/messages_test.exs index 5590adca9..4cf6904bc 100644 --- a/test/realtime/messages_test.exs +++ b/test/realtime/messages_test.exs @@ -98,27 +98,109 @@ defmodule Realtime.MessagesTest do assert Messages.replay(conn, tenant.external_id, "test", 0, 10) == {:ok, [privatem], MapSet.new([privatem.id])} end - test "replay extension=broadcast", %{conn: conn, tenant: tenant} do - privatem = + test "ai_agent LLM context messages are excluded from broadcast replay", %{conn: conn, tenant: tenant} do + broadcast_msg = message_fixture(tenant, %{ "private" => true, "inserted_at" => NaiveDateTime.utc_now() |> NaiveDateTime.add(-1, :minute), - "event" => "new", + "event" => "INSERT", "extension" => "broadcast", "topic" => "test", - "payload" => %{"value" => "new"} + "payload" => %{"value" => "user message"} }) message_fixture(tenant, %{ "private" => true, "inserted_at" => NaiveDateTime.utc_now() |> NaiveDateTime.add(-2, :minute), - "event" => "old", - "extension" => "presence", + "extension" => "ai_agent", "topic" => "test", - "payload" => %{"value" => "old"} + "payload" => %{"role" => "assistant", "content" => "hello"} }) - assert Messages.replay(conn, tenant.external_id, "test", 0, 10) == {:ok, [privatem], MapSet.new([privatem.id])} + assert Messages.replay(conn, tenant.external_id, "test", 0, 10) == + {:ok, [broadcast_msg], MapSet.new([broadcast_msg.id])} + end + + test "mixed broadcast and ai_agent messages on same topic only replays broadcast", %{conn: conn, tenant: tenant} do + broadcast_msgs = + for i <- 1..3 do + message_fixture(tenant, %{ + "private" => true, + "inserted_at" => NaiveDateTime.utc_now() |> NaiveDateTime.add(-i, :minute), + "event" => "INSERT", + "extension" => "broadcast", + "topic" => "test", + "payload" => %{"seq" => i} + }) + end + + for i <- 1..3 do + message_fixture(tenant, %{ + "private" => true, + "inserted_at" => NaiveDateTime.utc_now() |> NaiveDateTime.add(-i, :minute), + "extension" => "ai_agent", + "topic" => "test", + "payload" => %{"role" => "assistant", "content" => "response #{i}"} + }) + end + + {:ok, replayed, replayed_ids} = Messages.replay(conn, tenant.external_id, "test", 0, 10) + + assert length(replayed) == 3 + assert MapSet.size(replayed_ids) == 3 + assert Enum.all?(replayed, &(&1.extension == :broadcast)) + assert MapSet.equal?(replayed_ids, MapSet.new(broadcast_msgs, & &1.id)) + end + + test "ai_agent_event messages are included when requested", %{conn: conn, tenant: tenant} do + ai_event_msg = + message_fixture(tenant, %{ + "private" => true, + "inserted_at" => NaiveDateTime.utc_now() |> NaiveDateTime.add(-1, :minute), + "event" => "agent_done", + "extension" => "ai_agent_event", + "topic" => "test", + "payload" => %{"text" => "hello world"} + }) + + message_fixture(tenant, %{ + "private" => true, + "inserted_at" => NaiveDateTime.utc_now() |> NaiveDateTime.add(-2, :minute), + "extension" => "ai_agent", + "topic" => "test", + "payload" => %{"role" => "assistant", "content" => "hello world"} + }) + + assert Messages.replay(conn, tenant.external_id, "test", 0, 10, [:ai_agent_event]) == + {:ok, [ai_event_msg], MapSet.new([ai_event_msg.id])} + end + + test "ai_agent_event and broadcast messages replay independently by extension filter", %{conn: conn, tenant: tenant} do + broadcast_msg = + message_fixture(tenant, %{ + "private" => true, + "inserted_at" => NaiveDateTime.utc_now() |> NaiveDateTime.add(-1, :minute), + "event" => "my_event", + "extension" => "broadcast", + "topic" => "test", + "payload" => %{"data" => "broadcast"} + }) + + ai_event_msg = + message_fixture(tenant, %{ + "private" => true, + "inserted_at" => NaiveDateTime.utc_now() |> NaiveDateTime.add(-2, :minute), + "event" => "agent_done", + "extension" => "ai_agent_event", + "topic" => "test", + "payload" => %{"text" => "AI response"} + }) + + {:ok, broadcast_replayed, _} = Messages.replay(conn, tenant.external_id, "test", 0, 10) + {:ok, ai_replayed, _} = Messages.replay(conn, tenant.external_id, "test", 0, 10, [:ai_agent_event]) + + assert broadcast_replayed == [broadcast_msg] + assert ai_replayed == [ai_event_msg] end test "replay respects since", %{conn: conn, tenant: tenant} do diff --git a/test/realtime/tenants/authorization_remote_test.exs b/test/realtime/tenants/authorization_remote_test.exs index e531d50eb..b63890577 100644 --- a/test/realtime/tenants/authorization_remote_test.exs +++ b/test/realtime/tenants/authorization_remote_test.exs @@ -9,6 +9,7 @@ defmodule Realtime.Tenants.AuthorizationRemoteTest do alias Realtime.Tenants alias Realtime.Tenants.Authorization alias Realtime.Tenants.Authorization.Policies + alias Realtime.Tenants.Authorization.Policies.AiPolicies alias Realtime.Tenants.Authorization.Policies.BroadcastPolicies alias Realtime.Tenants.Authorization.Policies.PresencePolicies alias Realtime.Tenants.Connect @@ -28,7 +29,8 @@ defmodule Realtime.Tenants.AuthorizationRemoteTest do assert %Policies{ broadcast: %BroadcastPolicies{read: true, write: nil}, - presence: %PresencePolicies{read: true, write: nil} + presence: %PresencePolicies{read: true, write: nil}, + ai_agent: %AiPolicies{read: false, write: nil} } == policies {:ok, policies} = @@ -40,7 +42,8 @@ defmodule Realtime.Tenants.AuthorizationRemoteTest do assert %Policies{ broadcast: %BroadcastPolicies{read: true, write: true}, - presence: %PresencePolicies{read: true, write: true} + presence: %PresencePolicies{read: true, write: true}, + ai_agent: %AiPolicies{read: false, write: false} } == policies end @@ -56,7 +59,8 @@ defmodule Realtime.Tenants.AuthorizationRemoteTest do assert %Policies{ broadcast: %BroadcastPolicies{read: false, write: nil}, - presence: %PresencePolicies{read: false, write: nil} + presence: %PresencePolicies{read: false, write: nil}, + ai_agent: %AiPolicies{read: false, write: nil} } == policies {:ok, policies} = @@ -68,7 +72,8 @@ defmodule Realtime.Tenants.AuthorizationRemoteTest do assert %Policies{ broadcast: %BroadcastPolicies{read: false, write: false}, - presence: %PresencePolicies{read: false, write: false} + presence: %PresencePolicies{read: false, write: false}, + ai_agent: %AiPolicies{read: false, write: false} } == policies end diff --git a/test/realtime/tenants/authorization_test.exs b/test/realtime/tenants/authorization_test.exs index dbdc61065..105e0b3ce 100644 --- a/test/realtime/tenants/authorization_test.exs +++ b/test/realtime/tenants/authorization_test.exs @@ -9,6 +9,7 @@ defmodule Realtime.Tenants.AuthorizationTest do alias Realtime.Tenants.Repo alias Realtime.Tenants.Authorization alias Realtime.Tenants.Authorization.Policies + alias Realtime.Tenants.Authorization.Policies.AiPolicies alias Realtime.Tenants.Authorization.Policies.BroadcastPolicies alias Realtime.Tenants.Authorization.Policies.PresencePolicies @@ -29,7 +30,8 @@ defmodule Realtime.Tenants.AuthorizationTest do assert %Policies{ broadcast: %BroadcastPolicies{read: true, write: true}, - presence: %PresencePolicies{read: true, write: true} + presence: %PresencePolicies{read: true, write: true}, + ai_agent: %AiPolicies{read: false, write: false} } == policies end @@ -73,7 +75,8 @@ defmodule Realtime.Tenants.AuthorizationTest do assert %Policies{ broadcast: %BroadcastPolicies{read: true, write: true}, - presence: %PresencePolicies{read: false, write: false} + presence: %PresencePolicies{read: false, write: false}, + ai_agent: %AiPolicies{read: false, write: false} } == policies end @@ -91,7 +94,8 @@ defmodule Realtime.Tenants.AuthorizationTest do assert %Policies{ broadcast: %BroadcastPolicies{read: false, write: false}, - presence: %PresencePolicies{read: false, write: false} + presence: %PresencePolicies{read: false, write: false}, + ai_agent: %AiPolicies{read: false, write: false} } == policies end diff --git a/test/realtime/tenants/replication_connection_test.exs b/test/realtime/tenants/replication_connection_test.exs index c644890ca..2480e114b 100644 --- a/test/realtime/tenants/replication_connection_test.exs +++ b/test/realtime/tenants/replication_connection_test.exs @@ -400,7 +400,7 @@ defmodule Realtime.Tenants.ReplicationConnectionTest do assert logs =~ "Disconnecting broadcast changes handler in the step" end - test "message without event logs error", %{tenant: tenant} do + test "broadcast message without event logs error", %{tenant: tenant} do logs = capture_log(fn -> start_supervised!( @@ -414,6 +414,7 @@ defmodule Realtime.Tenants.ReplicationConnectionTest do message_fixture(tenant, %{ "topic" => "some_topic", + "extension" => "broadcast", "private" => true, "payload" => %{"value" => "something"} }) @@ -424,7 +425,32 @@ defmodule Realtime.Tenants.ReplicationConnectionTest do assert logs =~ "UnableToBroadcastChanges" end - test "message that exceeds payload size is not broadcast and logs error", %{tenant: tenant} do + test "ai_agent message is silently skipped without error", %{tenant: tenant} do + logs = + capture_log(fn -> + start_supervised!( + {ReplicationConnection, %ReplicationConnection{tenant_id: tenant.external_id, monitored_pid: self()}}, + restart: :transient + ) + + topic = random_string() + tenant_topic = Tenants.tenant_topic(tenant.external_id, topic, false) + assert :ok = Endpoint.subscribe(tenant_topic) + + message_fixture(tenant, %{ + "topic" => topic, + "extension" => "ai_agent", + "private" => true, + "payload" => %{"role" => "assistant", "content" => "hello"} + }) + + refute_receive %Phoenix.Socket.Broadcast{}, 500 + end) + + refute logs =~ "UnableToBroadcastChanges" + end + + test "message that exceeds payload size logs error", %{tenant: tenant} do logs = capture_log(fn -> start_supervised!( diff --git a/test/realtime_web/channels/realtime_channel/broadcast_handler_test.exs b/test/realtime_web/channels/realtime_channel/broadcast_handler_test.exs index 3b6065d9d..4c0958842 100644 --- a/test/realtime_web/channels/realtime_channel/broadcast_handler_test.exs +++ b/test/realtime_web/channels/realtime_channel/broadcast_handler_test.exs @@ -369,6 +369,54 @@ defmodule RealtimeWeb.RealtimeChannel.BroadcastHandlerTest do refute log =~ "UnableToSetPolicies" end + test "V2 json UserBroadcastPush on private channel with write policy", + %{topic: topic, tenant: tenant, db_conn: db_conn, serializer: serializer} do + socket = socket_fixture(tenant, topic, policies: %Policies{broadcast: %BroadcastPolicies{write: true}}) + + user_broadcast_payload = %{"a" => "b"} + json_encoded_user_broadcast_payload = Jason.encode!(user_broadcast_payload) + + {:reply, :ok, _socket} = + BroadcastHandler.handle({"event123", :json, json_encoded_user_broadcast_payload, %{}}, db_conn, socket) + + topic = "realtime:#{topic}" + assert_receive {:socket_push, code, data} + + if serializer == RealtimeWeb.Socket.V2Serializer do + assert code == :binary + + assert data == + << + 4::size(8), + byte_size(topic), + byte_size("event123"), + 0, + 1::size(8), + topic::binary, + "event123" + >> <> json_encoded_user_broadcast_payload + else + assert code == :text + + assert Jason.decode!(data) == + message(serializer, topic, %{ + "event" => "event123", + "payload" => user_broadcast_payload, + "type" => "broadcast" + }) + end + end + + test "V2 json UserBroadcastPush on private channel with write false policy is dropped", + %{topic: topic, tenant: tenant, db_conn: db_conn} do + socket = socket_fixture(tenant, topic, policies: %Policies{broadcast: %BroadcastPolicies{write: false}}) + + json_encoded = Jason.encode!(%{"a" => "b"}) + {:noreply, _socket} = BroadcastHandler.handle({"event123", :json, json_encoded, %{}}, db_conn, socket) + + refute_receive {:socket_push, _code, _data}, 100 + end + @tag policies: [:broken_write_presence] test "handle failing rls policy", %{topic: topic, tenant: tenant, db_conn: db_conn} do socket = socket_fixture(tenant, topic) diff --git a/test/realtime_web/channels/realtime_channel_test.exs b/test/realtime_web/channels/realtime_channel_test.exs index 78f622aa2..2969720b5 100644 --- a/test/realtime_web/channels/realtime_channel_test.exs +++ b/test/realtime_web/channels/realtime_channel_test.exs @@ -505,6 +505,88 @@ defmodule RealtimeWeb.RealtimeChannelTest do refute_receive %Socket.Message{} end + + @tag policies: [:authenticated_all_topic_read] + test "replay ai_agent_event messages on private topic as ai_event channel event", %{tenant: tenant} do + %{id: message_id} = + message_fixture(tenant, %{ + "private" => true, + "inserted_at" => NaiveDateTime.utc_now() |> NaiveDateTime.add(-1, :minute), + "event" => "agent_done", + "extension" => "ai_agent_event", + "topic" => "#{tenant.external_id}-private:test", + "payload" => %{"text" => "hello from AI"} + }) + + jwt = Generators.generate_jwt_token(tenant) + {:ok, %Socket{} = socket} = connect(UserSocket, %{"log_level" => "warning"}, conn_opts(tenant, jwt)) + + config = %{ + "private" => true, + "ai" => %{"replay" => %{"limit" => 2, "since" => :erlang.system_time(:millisecond) - 5 * 60000}} + } + + assert {:ok, _, %Socket{}} = subscribe_and_join(socket, "realtime:test", %{"config" => config}) + + assert_receive %Socket.Message{ + topic: "realtime:test", + event: "ai_event", + payload: %{ + "event" => "agent_done", + "meta" => %{"id" => ^message_id, "replayed" => true}, + "payload" => %{"text" => "hello from AI"}, + "type" => "ai_agent" + } + } + + refute_receive %Socket.Message{} + end + + @tag policies: [:authenticated_all_topic_read] + test "broadcast and ai replay are independent and both delivered on join", %{tenant: tenant} do + %{id: broadcast_id} = + message_fixture(tenant, %{ + "private" => true, + "inserted_at" => NaiveDateTime.utc_now() |> NaiveDateTime.add(-1, :minute), + "event" => "my_event", + "extension" => "broadcast", + "topic" => "test", + "payload" => %{"data" => "broadcast"} + }) + + %{id: ai_event_id} = + message_fixture(tenant, %{ + "private" => true, + "inserted_at" => NaiveDateTime.utc_now() |> NaiveDateTime.add(-2, :minute), + "event" => "agent_done", + "extension" => "ai_agent_event", + "topic" => "#{tenant.external_id}-private:test", + "payload" => %{"text" => "AI response"} + }) + + jwt = Generators.generate_jwt_token(tenant) + {:ok, %Socket{} = socket} = connect(UserSocket, %{"log_level" => "warning"}, conn_opts(tenant, jwt)) + + config = %{ + "private" => true, + "broadcast" => %{"replay" => %{"limit" => 2, "since" => :erlang.system_time(:millisecond) - 5 * 60000}}, + "ai" => %{"replay" => %{"limit" => 2, "since" => :erlang.system_time(:millisecond) - 5 * 60000}} + } + + assert {:ok, _, %Socket{}} = subscribe_and_join(socket, "realtime:test", %{"config" => config}) + + assert_receive %Socket.Message{ + event: "broadcast", + payload: %{"event" => "my_event", "meta" => %{"id" => ^broadcast_id, "replayed" => true}} + } + + assert_receive %Socket.Message{ + event: "ai_event", + payload: %{"event" => "agent_done", "meta" => %{"id" => ^ai_event_id, "replayed" => true}} + } + + refute_receive %Socket.Message{} + end end describe "presence" do @@ -886,10 +968,10 @@ defmodule RealtimeWeb.RealtimeChannelTest do old_confirm_ref = socket.assigns.confirm_token_ref - assert socket.assigns.policies == %Realtime.Tenants.Authorization.Policies{ + assert %Realtime.Tenants.Authorization.Policies{ broadcast: %Realtime.Tenants.Authorization.Policies.BroadcastPolicies{read: true, write: nil}, presence: %Realtime.Tenants.Authorization.Policies.PresencePolicies{read: true, write: nil} - } + } = socket.assigns.policies new_token = Generators.generate_jwt_token(tenant, %{ @@ -918,10 +1000,10 @@ defmodule RealtimeWeb.RealtimeChannelTest do "config" => %{"private" => true, "presence" => %{"enabled" => true}} }) - assert socket.assigns.policies == %Realtime.Tenants.Authorization.Policies{ + assert %Realtime.Tenants.Authorization.Policies{ broadcast: %Realtime.Tenants.Authorization.Policies.BroadcastPolicies{read: true, write: nil}, presence: %Realtime.Tenants.Authorization.Policies.PresencePolicies{read: true, write: nil} - } + } = socket.assigns.policies new_token = Generators.generate_jwt_token(tenant, %{ @@ -1497,6 +1579,29 @@ defmodule RealtimeWeb.RealtimeChannelTest do end end + describe "AI agent" do + @tag policies: [:authenticated_all_topic_read] + test "returns error when AI is requested but no agent is configured for the tenant", %{tenant: tenant} do + jwt = Generators.generate_jwt_token(tenant) + {:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts(tenant, jwt)) + + config = %{"private" => true, "ai" => %{"enabled" => true, "agent" => "nonexistent-agent"}} + + assert {:error, %{reason: "AiAgentNotConfigured: No AI agent configured for this tenant"}} = + subscribe_and_join(socket, "realtime:test", %{"config" => config}) + end + + test "returns error when AI is enabled on a public channel", %{tenant: tenant} do + jwt = Generators.generate_jwt_token(tenant) + {:ok, %Socket{} = socket} = connect(UserSocket, %{}, conn_opts(tenant, jwt)) + + config = %{"ai" => %{"enabled" => true, "agent" => "some-agent"}} + + assert {:error, %{reason: "AiAgentRequiresPrivateChannel: AI agent is only supported on private channels"}} = + subscribe_and_join(socket, "realtime:test", %{"config" => config}) + end + end + defp conn_opts(tenant, token) do [ connect_info: %{ diff --git a/test/support/generators.ex b/test/support/generators.ex index df73b90a2..c1d5b7e37 100644 --- a/test/support/generators.ex +++ b/test/support/generators.ex @@ -45,6 +45,50 @@ defmodule Generators do tenant end + @spec tenant_fixture_with_ai_agent(map()) :: Realtime.Api.Tenant.t() + def tenant_fixture_with_ai_agent(ai_opts \\ %{}, override \\ %{}) do + agent_name = Map.get(ai_opts, :agent_name, "test-agent") + + base_override = + override + |> Enum.map(fn {k, v} -> {"#{k}", v} end) + |> Map.new() + + tenant = tenant_fixture(base_override) + tenant = add_ai_agent_extension(tenant, agent_name, ai_opts) + Realtime.Tenants.Cache.update_cache(tenant) + tenant + end + + @spec add_ai_agent_extension(Realtime.Api.Tenant.t(), String.t(), map()) :: Realtime.Api.Tenant.t() + def add_ai_agent_extension(tenant, agent_name \\ "test-agent", ai_opts \\ %{}) do + ai_settings = %{ + "protocol" => Map.get(ai_opts, :protocol, "openai_compatible"), + "base_url" => Map.get(ai_opts, :base_url, Ollama.base_url() <> "/v1"), + "model" => Map.get(ai_opts, :model, Ollama.model()), + "api_key" => Map.get(ai_opts, :api_key, "ollama"), + "topic_pattern" => Map.get(ai_opts, :topic_pattern, "agent:*"), + "max_concurrent_sessions" => 5 + } + + %Realtime.Api.Extensions{} + |> Realtime.Api.Extensions.changeset(%{ + "type" => "ai_agent", + "name" => agent_name, + "tenant_external_id" => tenant.external_id, + "settings" => ai_settings + }) + |> Realtime.Repo.insert!() + + Realtime.Api.upsert_feature_flag(%{"name" => "ai_agent", "enabled" => false}) + Realtime.Api.update_tenant_by_external_id(tenant.external_id, %{"ai_enabled" => true}) + + {:ok, updated} = Realtime.FeatureFlags.set_tenant_flag("ai_agent", tenant.external_id, true) + + Realtime.Tenants.Cache.update_cache(updated) + updated + end + @spec message_fixture(Realtime.Api.Tenant.t()) :: any() def message_fixture(tenant, override \\ %{}) do {:ok, db_conn} = Database.connect(tenant, "realtime_test", :stop) @@ -52,7 +96,7 @@ defmodule Generators do create_attrs = %{ "topic" => random_string(), - "extension" => Enum.random([:presence, :broadcast]) + "extension" => :broadcast } override = override |> Enum.map(fn {k, v} -> {"#{k}", v} end) |> Map.new() diff --git a/test/support/ollama.ex b/test/support/ollama.ex new file mode 100644 index 000000000..0aab49e33 --- /dev/null +++ b/test/support/ollama.ex @@ -0,0 +1,141 @@ +defmodule Ollama do + @moduledoc """ + Manages an Ollama Docker container for `:live_llm` integration tests. + + Starts a single shared container, pulls a small model, and exposes the base URL. + Tests gated with `@moduletag :live_llm` call `Ollama.ensure_ready/0` in their + `setup_all` callback. + + Environment variables: + - `OLLAMA_HOST` — override the base URL (e.g. for a pre-existing instance). + Defaults to `http://localhost:11435` (non-standard port to + avoid colliding with a developer's local Ollama). + - `OLLAMA_MODEL` — model to pull and use. Defaults to `qwen2:0.5b` (352 MB, + fast on CPU, valid OpenAI-compatible SSE output). + """ + + require Logger + + @image "ollama/ollama" + @container_name "realtime-test-ollama" + @default_host "http://localhost:11435" + @default_model "qwen2:0.5b" + @host_port 11_435 + + @spec base_url() :: String.t() + def base_url, do: System.get_env("OLLAMA_HOST", @default_host) + + @spec model() :: String.t() + def model, do: System.get_env("OLLAMA_MODEL", @default_model) + + @spec ensure_ready() :: :ok | {:error, String.t()} + def ensure_ready do + with :ok <- ensure_container_running(), + :ok <- wait_for_api(), + :ok <- ensure_model_available() do + :ok + end + end + + @spec stop() :: :ok + def stop do + System.cmd("docker", ["stop", @container_name]) + :ok + end + + defp ensure_container_running do + case System.get_env("OLLAMA_HOST") do + nil -> + case container_running?() do + true -> + :ok + + false -> + pull_image() + start_container() + end + + _url -> + :ok + end + end + + defp container_running? do + {output, 0} = System.cmd("docker", ["ps", "--filter", "name=#{@container_name}", "--format", "{{.Names}}"]) + String.contains?(output, @container_name) + end + + defp pull_image do + case System.cmd("docker", ["image", "inspect", @image]) do + {_, 0} -> + :ok + + _ -> + IO.puts("Pulling #{@image}. This might take a while...") + {_, 0} = System.cmd("docker", ["pull", @image]) + :ok + end + end + + defp start_container do + IO.puts("Starting Ollama container on port #{@host_port}...") + + System.cmd("docker", ["rm", "-f", @container_name]) + + {_, 0} = + System.cmd("docker", [ + "run", + "-d", + "--name", + @container_name, + "-p", + "#{@host_port}:11434", + @image + ]) + + :ok + end + + defp wait_for_api(retries \\ 30) do + url = base_url() <> "/api/tags" + + case Req.get(url, receive_timeout: 2_000, retry: false) do + {:ok, %{status: 200}} -> + :ok + + _ when retries > 0 -> + Process.sleep(1_000) + wait_for_api(retries - 1) + + _ -> + {:error, "Ollama API did not become ready at #{url}"} + end + end + + defp ensure_model_available do + m = model() + url = base_url() <> "/api/tags" + + with {:ok, %{status: 200, body: body}} <- Req.get(url), + models = get_in(body, ["models"]) || [], + names = Enum.map(models, & &1["name"]) do + if Enum.any?(names, &String.starts_with?(&1, m)) do + :ok + else + pull_model(m) + end + else + _ -> {:error, "Could not list Ollama models"} + end + end + + defp pull_model(m) do + IO.puts("Pulling Ollama model #{m}. This will take a while on first run...") + url = base_url() <> "/api/pull" + + case Req.post(url, json: %{"name" => m}, receive_timeout: :timer.minutes(10), retry: false) do + {:ok, %{status: 200}} -> :ok + {:error, reason} -> {:error, "Failed to pull model #{m}: #{inspect(reason)}"} + end + end +end diff --git a/test/test_helper.exs b/test/test_helper.exs index f4f8086c9..09ee3ae6b 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -2,7 +2,7 @@ start_time = :os.system_time(:millisecond) alias Realtime.Api max_cases = String.to_integer(System.get_env("MAX_CASES", "4")) -ExUnit.start(exclude: [:failing], max_cases: max_cases, capture_log: true) +ExUnit.start(exclude: [:failing, :live_llm], max_cases: max_cases, capture_log: true) max_cases = ExUnit.configuration()[:max_cases] @@ -41,6 +41,14 @@ Mimic.copy(RealtimeWeb.Endpoint) Mimic.copy(RealtimeWeb.JwtVerification) Mimic.copy(RealtimeWeb.TenantBroadcaster) Mimic.copy(NimbleZTA.Cloudflare) +Mimic.copy(Finch) +Mimic.copy(Realtime.Tenants.Repo) +Mimic.copy(Realtime.Tenants.Connect) +Mimic.copy(Extensions.AiAgent.Adapter.ChatCompletions) +Mimic.copy(Extensions.AiAgent.Adapter.AnthropicMessages) +Mimic.copy(Extensions.AiAgent.Session) +Mimic.copy(Extensions.AiAgent.SessionSupervisor) +Mimic.copy(Realtime.DNS) partition = System.get_env("MIX_TEST_PARTITION") node_name = if partition, do: :"main#{partition}@127.0.0.1", else: :"main@127.0.0.1"