diff --git a/music_assistant/providers/fastmcp_server/VERSION b/music_assistant/providers/fastmcp_server/VERSION new file mode 100644 index 0000000000..e5a9958c32 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/VERSION @@ -0,0 +1 @@ +0.3.17 diff --git a/music_assistant/providers/fastmcp_server/__init__.py b/music_assistant/providers/fastmcp_server/__init__.py new file mode 100644 index 0000000000..29a0c339ce --- /dev/null +++ b/music_assistant/providers/fastmcp_server/__init__.py @@ -0,0 +1,170 @@ +""" +MCP Server Plugin Provider for Music Assistant. + +Exposes Music Assistant's library, queue, playback, players, and metadata +controllers as a Model Context Protocol server, accessible to Claude Code, +Codex, and other MCP-aware LLM clients. + +The runtime is built on PrefectHQ FastMCP v3 and mounted into MA's existing +aiohttp webserver under ``/mcp/v1`` via an ASGI bridge — no second uvicorn, +no extra port, no changes to MA core. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +LOGGER = logging.getLogger(__name__) + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ( + ConfigEntry, + ConfigValueType, + ProviderConfig, + ) + from music_assistant_models.provider import ProviderManifest + + from music_assistant.mass import MusicAssistant + from music_assistant.models import ProviderInstanceType + + +async def get_config_entries( + mass: MusicAssistant, + instance_id: str | None = None, # noqa: ARG001 + action: str | None = None, + values: dict[str, ConfigValueType] | None = None, +) -> tuple[ConfigEntry, ...]: + """Return Config entries to setup this provider. + + When ``action == "open_connect"`` is dispatched, mint a bootstrap token + bound to the calling user (when available) and signal MA's frontend to + open the Connect Wizard URL — the entries themselves are returned + unchanged so the settings panel re-renders cleanly. + """ + from .config import build_config_entries # noqa: PLC0415 + + if action == "open_connect": + await _dispatch_open_connect(mass, values or {}) + + return build_config_entries(mass, values or {}) + + +def _sanitize_external_base_url(value: str | None) -> str | None: + """Return ``value`` if it is a plausible ``http(s)://`` base URL, else ``None``. + + Defends against an admin pasting (or a misbehaving proxy injecting) a + scheme-less or ``javascript:`` URL into the Connect Wizard link, which + the MA frontend would feed straight to ``window.open``. + """ + if not value: + return None + candidate = value.strip() + if not candidate.lower().startswith(("http://", "https://")): + LOGGER.warning( + "Connect Wizard: ignoring external base URL with unsupported scheme: %r", + candidate, + ) + return None + return candidate + + +def _detect_external_base_url(mass: MusicAssistant, current_user: Any) -> str | None: + """Return the external base URL for the current user's active WS client. + + MA's :class:`WebsocketClientHandler` stores a per-connection ``base_url`` + derived from ``X-Forwarded-Host`` + ``X-Ingress-Path`` — exactly the + prefix the Connect Wizard needs so ``window.open`` produces a working + URL under Home Assistant add-on ingress. We pick the client whose + authenticated user matches the invoker of the action. + + Returns ``None`` when nothing matches (e.g. action invoked outside the + WS server, or no forward headers were captured). + """ + if current_user is None: + return None + try: + webserver = getattr(mass, "webserver", None) + clients = getattr(webserver, "clients", None) or () + except Exception: + return None + + def _user_id(user: Any) -> Any: + return getattr(user, "user_id", None) or getattr(user, "username", None) + + target = _user_id(current_user) + for client in clients: + client_base = getattr(client, "base_url", None) + if not client_base: + continue + client_user = getattr(client, "_authenticated_user", None) + if client_user is None: + continue + if _user_id(client_user) == target: + return str(client_base) + return None + + +async def _dispatch_open_connect( + mass: MusicAssistant, + values: dict[str, ConfigValueType], +) -> None: + """Mint a wizard bootstrap and signal the wizard URL to the frontend. + + The MA frontend's ``EditProvider`` view subscribes to ``AUTH_SESSION`` + events and ignores anything whose ``object_id`` does not match the + ``session_id`` it injected into ``values``. We must echo that same id + back as the event's ``object_id`` so the browser tab actually opens. + + URL resolution order: (1) auto-detect from the active WS client's + ingress-aware ``base_url``; (2) explicit ``connect_external_url`` config + override; (3) path-only fallback resolved against the browser's origin. + """ + from .connect import handle_open_connect_action # noqa: PLC0415 + from .constants import ( # noqa: PLC0415 + CONF_CONNECT_EXTERNAL_URL, + CONF_MOUNT_PATH, + DEFAULT_MOUNT_PATH, + ) + + mount_path = str(values.get(CONF_MOUNT_PATH) or DEFAULT_MOUNT_PATH) + session_id = str(values.get("session_id") or "") + + current_user: object | None = None + try: + from music_assistant.controllers.webserver.helpers.auth_middleware import ( # noqa: PLC0415 + get_current_user, + ) + + current_user = get_current_user() + except Exception: + LOGGER.debug("Connect Wizard: get_current_user lookup failed", exc_info=True) + current_user = None + + external_base_url = _sanitize_external_base_url(_detect_external_base_url(mass, current_user)) + if not external_base_url: + external_base_url = _sanitize_external_base_url( + str(values.get(CONF_CONNECT_EXTERNAL_URL) or "") + ) + + try: + await handle_open_connect_action( + mass, + current_user=current_user, + mount_path=mount_path, + session_id=session_id or None, + external_base_url=external_base_url, + ) + except Exception: + LOGGER.exception("Connect Wizard: open_connect action failed") + + +async def setup( + mass: MusicAssistant, + manifest: ProviderManifest, + config: ProviderConfig, +) -> ProviderInstanceType: + """Initialize provider instance with given configuration.""" + from .provider import MCPServerProvider # noqa: PLC0415 + + return MCPServerProvider(mass, manifest, config) diff --git a/music_assistant/providers/fastmcp_server/auth.py b/music_assistant/providers/fastmcp_server/auth.py new file mode 100644 index 0000000000..6a781429fe --- /dev/null +++ b/music_assistant/providers/fastmcp_server/auth.py @@ -0,0 +1,162 @@ +"""Token verifier delegating to MA's existing authentication subsystem. + +The plugin does not implement JWT decoding or scope checks of its own — this +is intentional. ``mass.webserver.auth.authenticate_with_token`` already handles +both JWT (PR #2891) and legacy hash tokens, and updates the sliding-window +expiry on every successful call. Wiring our own JWT decode here would only +duplicate the work and create two sources of truth. + +Passing ``base_url`` upstream to :class:`fastmcp.server.auth.TokenVerifier` +lets FastMCP's built-in ``RequireAuthMiddleware`` populate the +``resource_metadata="…"`` parameter in ``WWW-Authenticate`` headers on +401 responses (RFC 9728 / MCP authorization spec MUST). +""" + +from __future__ import annotations + +import base64 +import binascii +import json +import logging +from typing import TYPE_CHECKING + +from fastmcp.server.auth import TokenVerifier +from fastmcp.server.auth.auth import AccessToken + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + +LOGGER = logging.getLogger(__name__) + + +def _extract_jwt_audience(token: str) -> str | list[str] | None: + """Best-effort decode of a JWT payload to extract the ``aud`` claim. + + Returns ``None`` for non-JWT tokens (legacy MA hash tokens), malformed + payloads, or JWTs without an ``aud`` claim. Does **not** verify the + signature — that's MA's responsibility via ``authenticate_with_token``; + we only read the audience claim to compare against this MCP server's + canonical URI. + """ + parts = token.split(".") + if len(parts) != 3: + return None + payload_segment = parts[1] + # Restore base64url padding stripped by JWT spec. + pad = "=" * (-len(payload_segment) % 4) + try: + raw = base64.urlsafe_b64decode(payload_segment + pad) + claims = json.loads(raw) + except (binascii.Error, ValueError, UnicodeDecodeError): + return None + aud = claims.get("aud") if isinstance(claims, dict) else None + if isinstance(aud, (str, list)) or aud is None: + return aud + return None + + +def _audience_matches(aud: str | list[str] | None, expected: str) -> bool: + """RFC 8707: token is bound to ``expected`` if its ``aud`` is or contains it.""" + if aud is None: + return False + if isinstance(aud, str): + return aud == expected + return expected in aud + + +class MASTokenVerifier(TokenVerifier): + """Verify Bearer tokens against ``mass.webserver.auth``.""" + + def __init__( + self, + mass: MusicAssistant, + *, + base_url: str | None = None, + public_resource_uri: str | None = None, + enforce_audience: bool = False, + ) -> None: + """Bind the verifier to a MusicAssistant instance. + + :param mass: MusicAssistant instance used to authenticate tokens. + :param base_url: Public base URL of this MA instance (used by FastMCP + to build the ``resource_metadata`` URL advertised in 401 responses + and the ``aud`` claim binding). + :param public_resource_uri: Canonical URI of the MCP server (the value + FastMCP will report as ``resource``). Used to populate + ``AccessToken.resource`` so downstream code can audience-check. + :param enforce_audience: When ``True``, reject Bearer tokens whose + ``aud`` claim is missing or does not contain ``public_resource_uri``. + When ``False`` (default), only logs a warning so operators can + migrate gracefully once MA-side issues audience-bound tokens. + """ + # ``base_url`` is optional on TokenVerifier — passing ``None`` is + # equivalent to not setting it. Forward verbatim so FastMCP can later + # build the resource_metadata URL from this verifier. + super().__init__(base_url=base_url) + self._mass = mass + self._public_resource_uri = public_resource_uri + self._enforce_audience = enforce_audience + + async def verify_token(self, token: str) -> AccessToken | None: + """Validate the bearer token and produce an ``AccessToken`` for FastMCP. + + :param token: Raw bearer token from the ``Authorization`` header. + :return: ``AccessToken`` if the token is valid and the user is enabled, + otherwise ``None``. + """ + try: + user = await self._mass.webserver.auth.authenticate_with_token(token) + except Exception: + LOGGER.exception("MA token verification raised") + return None + + if user is None or not getattr(user, "enabled", True): + return None + + if not self._check_audience(token): + return None + + # MCP SDK's AccessToken pydantic model has no `claims` field — extras + # are silently dropped — so we don't try to forward username/role here. + return AccessToken( + token=token, + client_id=str(getattr(user, "user_id", "")) or "music-assistant", + scopes=[], + expires_at=None, + resource=self._public_resource_uri, + ) + + def _check_audience(self, token: str) -> bool: + """Return ``True`` if the token's audience is acceptable for this server. + + In soft mode (``enforce_audience=False``) — always returns ``True`` and + only emits a warning when a JWT's ``aud`` is missing or mismatched. + In strict mode — rejects tokens missing or with a wrong ``aud``. + Non-JWT (legacy hash) tokens have no claim to inspect: they pass in + soft mode and fail in strict mode. + """ + expected = self._public_resource_uri + if not expected: + return True + aud = _extract_jwt_audience(token) + if _audience_matches(aud, expected): + return True + if self._enforce_audience: + LOGGER.warning( + "Rejected token: aud=%r does not match MCP resource URI %r", + aud, + expected, + ) + return False + if aud is None: + LOGGER.debug( + "Token has no `aud` claim; accepting because enforce_audience=False", + ) + else: + LOGGER.warning( + "Token aud=%r does not match resource URI %r; accepting because " + "enforce_audience=False (set CONF_ENFORCE_AUDIENCE to enforce).", + aud, + expected, + ) + return True diff --git a/music_assistant/providers/fastmcp_server/config.py b/music_assistant/providers/fastmcp_server/config.py new file mode 100644 index 0000000000..1a03a1c211 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/config.py @@ -0,0 +1,323 @@ +"""ConfigEntry schema for the MCP Server provider.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from music_assistant_models.config_entries import ConfigEntry +from music_assistant_models.enums import ConfigEntryType + +from .constants import ( + CONF_CONNECT_EXTERNAL_URL, + CONF_CONTROL_MEDIA, + CONF_CONTROL_PLAYBACK, + CONF_CONTROL_PLAYERS, + CONF_CONTROL_VOLUME, + CONF_DELETE_FAVORITES, + CONF_DELETE_LIBRARY, + CONF_DELETE_PLAYLISTS, + CONF_DELETE_QUEUE, + CONF_EDIT_FAVORITES, + CONF_EDIT_LIBRARY, + CONF_EDIT_PLAYLISTS, + CONF_EDIT_QUEUE, + CONF_ENFORCE_AUDIENCE, + CONF_EXTRA_ALLOWED_ORIGINS, + CONF_MOUNT_PATH, + CONF_QUERY_LIBRARY, + CONF_QUERY_METADATA, + CONF_QUERY_PLAYERS, + CONF_QUERY_QUEUE, + CONF_REQUIRE_AUTH, + CONF_REQUIRE_CONFIRMATION, + CONF_RES_LIBRARY, + CONF_RES_PLAYER, + CONF_RES_PROMPTS, + DEFAULT_MOUNT_PATH, +) + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ConfigValueType + + from music_assistant.mass import MusicAssistant + + +def _bool(key: str, label: str, default: bool, category: str, description: str = "") -> ConfigEntry: + return ConfigEntry( + key=key, + type=ConfigEntryType.BOOLEAN, + label=label, + default_value=default, + category=category, + description=description or label, + required=False, + ) + + +def build_config_entries( + mass: MusicAssistant, + values: dict[str, ConfigValueType], +) -> tuple[ConfigEntry, ...]: + """Return the full ConfigEntry schema for this provider. + + :param mass: MusicAssistant instance, used to compose the info label. + :param values: Current config values (may be empty on first setup). + """ + base_url = getattr(mass.webserver, "base_url", "").rstrip("/") + raw_mount = str(values.get(CONF_MOUNT_PATH) or DEFAULT_MOUNT_PATH) + # Mirror ``MCPServerRuntime.__init__``'s normalisation so the info label + # always renders a valid URL even if the user dropped the leading slash. + mount_path = "/" + raw_mount.strip("/") + info_label = f"MCP endpoint: {base_url}{mount_path}\nCreate tokens in Profile → Long-lived access tokens." + + return ( + ConfigEntry( + key="info_label", + type=ConfigEntryType.LABEL, + label=info_label, + category="Server", + required=False, + ), + ConfigEntry( + key="open_connect", + type=ConfigEntryType.ACTION, + label="Open Connect Wizard", + description=( + "One-click setup for Claude Desktop, Claude Code, Cursor, " + "Windsurf, VSCode, ChatGPT and other MCP clients. Mints a " + "per-client token labelled `MCP — ` (revocable in " + "Profile → Long-lived access tokens) and copies the ready-to-paste " + "snippet for you." + ), + action="open_connect", + required=False, + ), + ConfigEntry( + key=CONF_REQUIRE_AUTH, + type=ConfigEntryType.BOOLEAN, + label="Require authentication", + default_value=True, + category="Server", + description=( + "Reject unauthenticated MCP clients. Strongly recommended — " + "with auth disabled, every MCP client on the network can drive playback." + ), + required=False, + ), + ConfigEntry( + key=CONF_MOUNT_PATH, + type=ConfigEntryType.STRING, + label="Mount path", + default_value=DEFAULT_MOUNT_PATH, + category="Server", + advanced=True, + description=( + "HTTP path prefix where the MCP server is mounted on MA's webserver. " + "Change only if it conflicts with another route." + ), + required=False, + ), + ConfigEntry( + key=CONF_REQUIRE_CONFIRMATION, + type=ConfigEntryType.BOOLEAN, + label="Confirm destructive operations", + default_value=True, + category="Server", + description=( + "Ask the MCP client to confirm before running destructive tools " + "(clear_queue, remove_tracks, remove_from_library, " + "remove_from_favorites). If the client doesn't support " + "elicitation, the call falls through to the permission flag." + ), + required=False, + ), + ConfigEntry( + key=CONF_ENFORCE_AUDIENCE, + type=ConfigEntryType.BOOLEAN, + label="Enforce token audience (RFC 8707)", + default_value=False, + category="Server", + advanced=True, + description=( + "Reject Bearer tokens whose `aud` claim does not match this MCP " + "server's canonical URI. Mitigates the OAuth confused-deputy " + "attack where a token issued for one MA endpoint is replayed " + "against another. Requires upstream Music Assistant support for " + "writing `aud` into JWTs (in progress) — until then enabling " + "this rejects all existing tokens. Leave off unless your MA " + "build issues audience-bound tokens." + ), + required=False, + ), + ConfigEntry( + key=CONF_EXTRA_ALLOWED_ORIGINS, + type=ConfigEntryType.STRING, + label="Additional allowed Origins (CSV)", + default_value="", + category="Server", + advanced=True, + description=( + "Comma-separated list of additional `Origin` headers to accept " + "(e.g. `https://ha.example.com` for Home Assistant ingress, or a " + "reverse-proxy hostname). By default the server only accepts " + "`localhost`, `127.0.0.1`, the MA `base_url` host, and `publish_ip`. " + "Mismatching Origins are rejected with 403 to mitigate DNS rebinding." + ), + required=False, + ), + ConfigEntry( + key=CONF_CONNECT_EXTERNAL_URL, + type=ConfigEntryType.STRING, + label="Connect Wizard external URL (fallback)", + default_value="", + category="Server", + advanced=True, + description=( + "Optional explicit base URL the Connect Wizard should open at " + "(e.g. `https://ha.example.com/` for Home Assistant " + "add-on ingress). Used only when the wizard cannot auto-detect " + "the external URL from the active client connection — set it " + "if your reverse proxy strips the `X-Forwarded-Host` / " + "`X-Ingress-Path` headers." + ), + required=False, + ), + # Query permissions + _bool( + CONF_QUERY_LIBRARY, + "Query library", + True, + "Query Permissions", + "Search music, browse library, get artists/albums/tracks/playlists.", + ), + _bool( + CONF_QUERY_QUEUE, + "Query queue", + True, + "Query Permissions", + "Read the current queue state for any player.", + ), + _bool( + CONF_QUERY_PLAYERS, + "Query players", + True, + "Query Permissions", + "List players and read their state and capabilities.", + ), + _bool( + CONF_QUERY_METADATA, + "Query metadata", + True, + "Query Permissions", + "Get lyrics, recommendations, and similar tracks.", + ), + # Control permissions + _bool( + CONF_CONTROL_PLAYBACK, + "Control playback", + False, + "Control Permissions", + "Play, pause, stop, seek, next/previous, play media.", + ), + _bool( + CONF_CONTROL_VOLUME, + "Control volume", + False, + "Control Permissions", + "Set volume, volume up/down, mute, group volume.", + ), + _bool( + CONF_CONTROL_PLAYERS, + "Control players", + False, + "Control Permissions", + "Power players on/off, select source.", + ), + _bool( + CONF_CONTROL_MEDIA, + "Play announcements / mark played", + False, + "Control Permissions", + "Send TTS announcements, mark items as played.", + ), + # Edit permissions + _bool( + CONF_EDIT_LIBRARY, + "Add items to library", + False, + "Edit Permissions", + "Add tracks, albums, artists, or playlists to the library.", + ), + _bool( + CONF_EDIT_QUEUE, + "Edit queue", + False, + "Edit Permissions", + "Move queue items, save queue as playlist.", + ), + _bool( + CONF_EDIT_PLAYLISTS, + "Create / modify playlists", + False, + "Edit Permissions", + "Create playlists, add tracks, reorder.", + ), + _bool( + CONF_EDIT_FAVORITES, + "Add to favorites", + False, + "Edit Permissions", + "Mark items as favorites.", + ), + # Delete permissions + _bool( + CONF_DELETE_LIBRARY, + "Remove items from library", + False, + "Delete Permissions", + "Remove tracks, albums, artists, or playlists from the library.", + ), + _bool( + CONF_DELETE_QUEUE, + "Clear queue / remove items", + False, + "Delete Permissions", + "Remove queue items or clear the queue.", + ), + _bool( + CONF_DELETE_PLAYLISTS, + "Delete playlists / remove tracks", + False, + "Delete Permissions", + "Delete playlists, remove tracks from playlists.", + ), + _bool( + CONF_DELETE_FAVORITES, + "Remove from favorites", + False, + "Delete Permissions", + "Remove items from favorites.", + ), + # Resources / prompts + _bool( + CONF_RES_LIBRARY, + "Expose library:// resources", + True, + "MCP Resources", + "URI-addressable read-only views of artists, albums, tracks, playlists.", + ), + _bool( + CONF_RES_PLAYER, + "Expose player:// and queue:// resources", + True, + "MCP Resources", + "URI-addressable views of players and queues.", + ), + _bool( + CONF_RES_PROMPTS, + "Expose canned prompts", + True, + "MCP Resources", + "Pre-defined prompts: find_and_play, party_playlist, now_playing_summary.", + ), + ) diff --git a/music_assistant/providers/fastmcp_server/connect/__init__.py b/music_assistant/providers/fastmcp_server/connect/__init__.py new file mode 100644 index 0000000000..781443d3b6 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/connect/__init__.py @@ -0,0 +1,14 @@ +"""Connect Wizard — one-click onboarding UI for MCP-aware AI clients. + +Provides a single-page web UI mounted under ``/connect`` that mints +per-client long-lived MA tokens (``"MCP — "``) and renders ready-to-paste +configuration snippets, deeplinks, and share-URLs for Claude Desktop, Claude +Code, Cursor, Windsurf, VSCode, ChatGPT, Codex CLI, Gemini CLI, Cline, and Zed. +""" + +from __future__ import annotations + +from .actions import handle_open_connect_action +from .mount import mount_connect_wizard + +__all__ = ["handle_open_connect_action", "mount_connect_wizard"] diff --git a/music_assistant/providers/fastmcp_server/connect/_revoke.py b/music_assistant/providers/fastmcp_server/connect/_revoke.py new file mode 100644 index 0000000000..10ea24c131 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/connect/_revoke.py @@ -0,0 +1,132 @@ +"""Sanctioned-API helpers for wizard-side token management. + +The wizard mints (and therefore wants to revoke / list) MA auth tokens, but +the public ``auth.revoke_token`` / ``auth.get_user_tokens`` methods are +``@api_command``-decorated — they read the current authenticated user from +the ``current_user`` ContextVar, which is normally populated by MA's HTTP / +WS request middleware. The wizard's ASGI endpoints run inside MA's process +but outside that middleware, so the contextvar is empty by default. + +This module mirrors the pattern MA's own test suite uses +(``tests/test_webserver_auth.py:336-354``): briefly impersonate a known +``User`` via the public ``set_current_user`` helper, then call the API +method, then restore the prior context. + +The internal import path +``music_assistant.controllers.webserver.helpers.auth_middleware`` is the +same one MA's tests use; it is not under ``music_assistant_models`` but is +the de-facto contract for in-process callers. +""" + +from __future__ import annotations + +import logging +from contextlib import contextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterator + + from music_assistant_models.auth import AuthToken, User + + from music_assistant.mass import MusicAssistant + +LOGGER = logging.getLogger(__name__) + +# Sanctioned contextvar helpers live in an MA-internal module. In a real MA +# install the import succeeds (it's the same module MA's own tests use — +# tests/test_webserver_auth.py:19-22). In this repo's minimal dev venv the +# transitive ``music_assistant.controllers.webserver`` package can't be +# loaded (frontend / chardet / torch are not installed), so fall back to +# no-op shims for collect-time imports. Tests mock the API methods that +# would actually read ``current_user``, so a no-op context manager is safe +# there. Production always hits the real branch. +try: + from music_assistant.controllers.webserver.helpers.auth_middleware import ( + get_current_user as _ma_get_current_user, + ) + from music_assistant.controllers.webserver.helpers.auth_middleware import ( + set_current_user as _ma_set_current_user, + ) +except ImportError: + # Narrow on purpose: only swallow ``ImportError`` (which covers + # ``ModuleNotFoundError``) — the case is the minimal dev venv missing + # a transitive MA dep. Anything else (e.g. ``AttributeError`` from a + # renamed symbol) must propagate so MA-side breakage surfaces loudly + # instead of silently disabling token revocation. + # Signatures must match the real MA helpers exactly — mypy on CI sees + # both branches with the full MA install and rejects any drift. + + def _ma_get_current_user() -> User | None: + return None + + def _ma_set_current_user(user: User | None) -> None: # noqa: ARG001 + return None + + +@contextmanager +def _as_user(user: User) -> Iterator[None]: + """Briefly impersonate ``user`` for an ``@api_command`` call. + + ``current_user`` is a ContextVar — save/restore is task-local, so + concurrent requests on other users are unaffected. + + :param user: The user to set as the current authenticated user for the + duration of the ``with`` block. + """ + prev = _ma_get_current_user() + _ma_set_current_user(user) + try: + yield + finally: + _ma_set_current_user(prev) + + +async def revoke_token_by_id(mass: MusicAssistant, user: User, token_id: str) -> bool: + """Revoke a token via the sanctioned ``auth.revoke_token`` API. + + MA's ``revoke_token`` enforces ownership internally — the impersonated + ``user`` must own the token (or be admin), or the call raises + ``InsufficientPermissions``. ``InvalidDataError`` is raised for an + unknown ``token_id``. Both are swallowed; this is a best-effort + operation. + + :param mass: MusicAssistant instance. + :param user: Owner of the token being revoked (sets the auth context). + :param token_id: ``jti`` of the token to revoke. + :return: ``True`` if ``revoke_token`` returned without raising, + ``False`` otherwise. + """ + with _as_user(user): + try: + await mass.webserver.auth.revoke_token(token_id) + except Exception: + LOGGER.exception( + "Connect Wizard: revoke_token failed (token_id=%s, user=%s)", + token_id, + user.user_id, + ) + return False + return True + + +async def list_user_tokens(mass: MusicAssistant, user: User) -> list[AuthToken]: + """List ``user``'s auth tokens via the sanctioned ``auth.get_user_tokens`` API. + + Returns typed ``AuthToken`` dataclasses — no raw ``sqlite3.Row`` + objects leak across the boundary. Best-effort: an error returns ``[]``. + + Note: MA core caps the query at 100 rows. A user with > 100 active + tokens will see some priors miss our dedup pass — acceptable for the + typical case (handful of tokens). + + :param mass: MusicAssistant instance. + :param user: User whose tokens to list (sets the auth context). + """ + tokens: list[AuthToken] = [] + with _as_user(user): + try: + tokens = await mass.webserver.auth.get_user_tokens() + except Exception: + LOGGER.exception("Connect Wizard: get_user_tokens failed (user=%s)", user.user_id) + return tokens diff --git a/music_assistant/providers/fastmcp_server/connect/actions.py b/music_assistant/providers/fastmcp_server/connect/actions.py new file mode 100644 index 0000000000..bfd108210d --- /dev/null +++ b/music_assistant/providers/fastmcp_server/connect/actions.py @@ -0,0 +1,123 @@ +"""ACTION-handler: mints a bootstrap token and signals the wizard URL to the frontend. + +Triggered by the ``open_connect`` ``ConfigEntryType.ACTION`` button defined in +:mod:`provider.config`. Mirrors the Spotify-provider OAuth pattern (signal an +``EventType.AUTH_SESSION`` event whose ``data`` is the URL the MA frontend +should ``window.open``). +""" + +from __future__ import annotations + +import asyncio +import logging +import secrets +from typing import TYPE_CHECKING, Any +from urllib.parse import urlencode + +from ._revoke import list_user_tokens, revoke_token_by_id + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + +LOGGER = logging.getLogger(__name__) + +# Short-lived plumbing tokens the wizard mints on each open / page load. These +# auto-expire after 30 days; until then they clutter the user's token list. +# Garbage-collect them before minting a new one. +_GC_NAMES = ("MCP — wizard bootstrap", "MCP — wizard session") + + +async def handle_open_connect_action( + mass: MusicAssistant, + *, + current_user: Any, + mount_path: str, + base_url: str = "", + session_id: str | None = None, + external_base_url: str | None = None, +) -> None: + """Open the Connect Wizard in the user's browser via MA's auth-session signal. + + :param mass: MusicAssistant instance. + :param current_user: The authenticated MA ``User`` invoking the action, or + ``None`` when no user context is available — in which case the wizard + is opened without a bootstrap token and falls back to its login form. + :param mount_path: HTTP path prefix where the MCP server is mounted. + :param base_url: Deprecated — kept in the signature for backwards + compatibility; ignored. Pass ``external_base_url`` instead. + :param session_id: ``session_id`` echoed by the MA frontend in the action's + ``values``. Must be passed back verbatim as the ``AUTH_SESSION`` event + ``object_id`` so the EditProvider view actually opens the URL — frontend + ignores AUTH_SESSION events whose object_id does not match its session. + :param external_base_url: Externally reachable base URL (scheme + host + + optional ingress path prefix) to prepend to the wizard URL. When + omitted, falls back to a path-only URL that the browser resolves + against its own origin. + """ + del base_url # kept in signature for backwards compatibility; ignored + + bootstrap: str | None = None + if current_user is not None: + # GC any prior wizard plumbing rows for this user before minting a + # new bootstrap, via the sanctioned auth API. Best-effort: lookup + # failures inside list_user_tokens return []; individual revoke + # failures are swallowed inside revoke_token_by_id. Per-client + # tokens (MCP — ) are not touched. + for tok in await list_user_tokens(mass, current_user): + if tok.name in _GC_NAMES: + await revoke_token_by_id(mass, current_user, tok.token_id) + + try: + bootstrap = await mass.webserver.auth.create_token( + user=current_user, + name="MCP — wizard bootstrap", + is_long_lived=False, + ) + except Exception: + LOGGER.exception("Connect Wizard: failed to mint bootstrap token") + bootstrap = None + + mount = "/" + mount_path.strip("/") + if external_base_url: + # Fully-qualified URL — required under HA add-on ingress, where the + # MA frontend lives at ``https:////`` and ``window.open`` + # on a path starting with ``/`` would drop the ingress prefix. + url = f"{external_base_url.rstrip('/')}{mount}/connect" + else: + # Path-only fallback — browser resolves against its own origin. Works + # for direct access; loses any reverse-proxy / ingress path prefix. + url = f"{mount}/connect" + if bootstrap: + url = f"{url}?{urlencode({'bootstrap': bootstrap})}" + + object_id = session_id or f"mcp-connect-{secrets.token_urlsafe(8)}" + _signal_auth_session(mass, session_id=object_id, url=url) + + # Hold the action response open briefly. MA frontend's EditProvider sets + # ``loading=true`` while awaiting our response (which keeps the overlay + # mounted) and, on receiving AUTH_SESSION, schedules a 100 ms setTimeout + # that grabs ```` from inside that overlay and clicks it. + # Without this delay the response races back first, ``loading`` flips to + # false, the overlay (and the anchor inside it) unmounts, and the + # frontend throws ``Cannot read properties of null (reading + # 'setAttribute')`` — the user sees nothing happen. 500 ms gives the + # frontend enough time to follow the link before we let the overlay close. + await asyncio.sleep(0.5) + + +def _signal_auth_session(mass: MusicAssistant, *, session_id: str, url: str) -> None: + """Publish the wizard URL via MA's ``EventType.AUTH_SESSION`` signal. + + The MA frontend subscribes to ``AUTH_SESSION`` events and ``window.open``-s + the carried URL — same mechanism the Spotify, Audible, QQMusic providers + use for OAuth redirect. Never raises: if the event bus rejects the call we + log the exception and the failure path (the user can re-trigger the action + manually) — the URL itself is **not** logged because it carries the + short-lived bootstrap token in its query string. + """ + from music_assistant_models.enums import EventType # noqa: PLC0415 + + try: + mass.signal_event(EventType.AUTH_SESSION, object_id=session_id, data=url) + except Exception: + LOGGER.exception("Connect Wizard: signal_event failed") diff --git a/music_assistant/providers/fastmcp_server/connect/clients.py b/music_assistant/providers/fastmcp_server/connect/clients.py new file mode 100644 index 0000000000..230d853a56 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/connect/clients.py @@ -0,0 +1,220 @@ +"""AI-client catalogue used by the Connect Wizard. + +Each :class:`ClientSpec` is rendered into a copy-paste config snippet by the +wizard's JavaScript: ``{{URL}}`` is replaced with the chosen MCP endpoint URL +and ``{{TOKEN}}`` with a freshly minted per-client token. The catalogue is +serialised to JSON via :func:`clients_to_json` and embedded into ``/connect/info``. +""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass + + +@dataclass(frozen=True) +class ClientSpec: + """Renderable description of a single AI client. + + :param id: Stable identifier used in API calls and as the per-client token name suffix. + :param label: Human-readable name shown in the wizard tab and in token names. + :param kind: Snippet syntax — ``json`` / ``shell`` / ``toml``. + :param template: Snippet body with ``{{URL}}`` and ``{{TOKEN}}`` placeholders. + :param config_path_hint: Where the user should paste the snippet (for the UI hint line). + :param notes: Optional extra advice (transport quirks, OS gotchas). + :param filename: Suggested download filename for the snippet. + """ + + id: str + label: str + kind: str + template: str + config_path_hint: str + notes: str = "" + filename: str = "" + + +CLIENTS: tuple[ClientSpec, ...] = ( + ClientSpec( + id="claude-code", + label="Claude Code", + kind="shell", + template=( + "claude mcp add ma {{URL}} \\\n" + " --transport http \\\n" + ' --header "Authorization: Bearer {{TOKEN}}"' + ), + config_path_hint="Run this in any terminal.", + filename="add-ma.sh", + ), + ClientSpec( + id="claude-desktop", + label="Claude Desktop", + kind="json", + template=( + "{\n" + ' "mcpServers": {\n' + ' "ma": {\n' + ' "url": "{{URL}}",\n' + ' "headers": { "Authorization": "Bearer {{TOKEN}}" }\n' + " }\n" + " }\n" + "}" + ), + config_path_hint=( + "macOS: ~/Library/Application Support/Claude/claude_desktop_config.json · " + "Windows: %APPDATA%/Claude/claude_desktop_config.json" + ), + notes="Requires Claude Desktop with native HTTP transport (≥ 0.10).", + filename="claude_desktop_config.json", + ), + ClientSpec( + id="cursor", + label="Cursor", + kind="json", + template=( + "{\n" + ' "mcpServers": {\n' + ' "ma": {\n' + ' "url": "{{URL}}",\n' + ' "headers": { "Authorization": "Bearer {{TOKEN}}" }\n' + " }\n" + " }\n" + "}" + ), + config_path_hint="~/.cursor/mcp.json (global) or .cursor/mcp.json (project).", + notes="Use the 'Add to Cursor' button for one-click install.", + filename="mcp.json", + ), + ClientSpec( + id="windsurf", + label="Windsurf", + kind="json", + template=( + "{\n" + ' "mcpServers": {\n' + ' "ma": {\n' + ' "serverUrl": "{{URL}}",\n' + ' "headers": { "Authorization": "Bearer {{TOKEN}}" }\n' + " }\n" + " }\n" + "}" + ), + config_path_hint="~/.codeium/windsurf/mcp_config.json", + filename="mcp_config.json", + ), + ClientSpec( + id="vscode", + label="VSCode (Copilot Chat)", + kind="json", + template=( + "{\n" + ' "servers": {\n' + ' "ma": {\n' + ' "type": "http",\n' + ' "url": "{{URL}}",\n' + ' "headers": { "Authorization": "Bearer {{TOKEN}}" }\n' + " }\n" + " }\n" + "}" + ), + config_path_hint=".vscode/mcp.json (workspace) or User Settings JSON.", + filename="mcp.json", + ), + ClientSpec( + id="chatgpt", + label="ChatGPT (Connectors)", + kind="shell", + template=( + "# Settings → Connectors → Add custom MCP\n" + "URL: {{URL}}\n" + "Auth: Bearer {{TOKEN}}\n" + "# ChatGPT requires a publicly reachable HTTPS URL\n" + "# (Cloudflare Tunnel / Tailscale Funnel / nginx + Let's Encrypt)." + ), + config_path_hint="UI only — no file to paste.", + notes="Public HTTPS required.", + filename="chatgpt-mcp.txt", + ), + ClientSpec( + id="codex-cli", + label="Codex CLI", + kind="toml", + template=( + "[mcp_servers.ma]\n" + 'url = "{{URL}}"\n' + "[mcp_servers.ma.http_headers]\n" + 'Authorization = "Bearer {{TOKEN}}"' + ), + config_path_hint="~/.codex/config.toml", + notes=( + "Codex's streamable_http transport reads custom headers from " + "`http_headers` (not `headers`)." + ), + filename="config.toml", + ), + ClientSpec( + id="gemini-cli", + label="Gemini CLI", + kind="json", + template=( + "{\n" + ' "mcpServers": {\n' + ' "ma": {\n' + ' "httpUrl": "{{URL}}",\n' + ' "headers": { "Authorization": "Bearer {{TOKEN}}" }\n' + " }\n" + " }\n" + "}" + ), + config_path_hint="~/.gemini/settings.json", + filename="settings.json", + ), + ClientSpec( + id="cline", + label="Cline (VSCode)", + kind="json", + template=( + "{\n" + ' "mcpServers": {\n' + ' "ma": {\n' + ' "url": "{{URL}}",\n' + ' "headers": { "Authorization": "Bearer {{TOKEN}}" }\n' + " }\n" + " }\n" + "}" + ), + config_path_hint='VSCode command palette → "Cline: Open MCP Settings".', + filename="cline_mcp_settings.json", + ), + ClientSpec( + id="zed", + label="Zed Editor", + kind="json", + template=( + "{\n" + ' "context_servers": {\n' + ' "ma": {\n' + ' "url": "{{URL}}",\n' + ' "headers": { "Authorization": "Bearer {{TOKEN}}" }\n' + " }\n" + " }\n" + "}" + ), + config_path_hint="~/.config/zed/settings.json", + notes="Requires a recent Zed build with native remote-MCP support.", + filename="settings.json", + ), +) + + +def lookup_client(client_id: str) -> ClientSpec | None: + """Return the :class:`ClientSpec` matching ``client_id``, or ``None`` if unknown.""" + for spec in CLIENTS: + if spec.id == client_id: + return spec + return None + + +def clients_to_json() -> list[dict[str, str]]: + """Return the catalogue as a list of plain dicts suitable for JSON serialisation.""" + return [asdict(spec) for spec in CLIENTS] diff --git a/music_assistant/providers/fastmcp_server/connect/handlers.py b/music_assistant/providers/fastmcp_server/connect/handlers.py new file mode 100644 index 0000000000..02d6efedc1 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/connect/handlers.py @@ -0,0 +1,295 @@ +"""HTTP handlers backing the Connect Wizard endpoints. + +Five endpoints are mounted under ``/connect``: + +* ``GET /connect`` — serves the single-page HTML wizard. +* ``GET /connect/info`` — meta JSON (URLs, version, enabled permissions, clients). +* ``POST /connect/exchange`` — exchanges a bootstrap token for a session token. +* ``POST /connect/login`` — username/password login fallback. +* ``POST /connect/token`` — mints a per-client long-lived token. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any +from urllib.parse import urlsplit + +from aiohttp import web + +from ._revoke import list_user_tokens, revoke_token_by_id +from .clients import clients_to_json, lookup_client +from .page import HTML + +if TYPE_CHECKING: + from collections.abc import Callable + + from music_assistant.mass import MusicAssistant + +LOGGER = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class WizardContext: + """Shared state captured at mount time and passed to every handler.""" + + mass: MusicAssistant + mount_path: str + enabled_tags_provider: Callable[[], list[str]] + origin_check: Callable[[web.Request], bool] + + +def _origin_guard(ctx: WizardContext, request: web.Request) -> web.Response | None: + """Return a 403 response if the request's ``Origin`` is not allowlisted.""" + if not ctx.origin_check(request): + LOGGER.warning( + "Connect Wizard: rejected request with Origin=%r from %s", + request.headers.get("Origin"), + request.remote, + ) + return web.Response(status=403, text="Forbidden Origin") + return None + + +async def _read_json(request: web.Request) -> dict[str, Any]: + """Best-effort JSON parse; missing/malformed body becomes an empty dict.""" + try: + body = await request.json() + except Exception: + return {} + return body if isinstance(body, dict) else {} + + +def make_serve_page(_ctx: WizardContext) -> Callable[[web.Request], Any]: + """Build the ``GET /connect`` handler — serves the wizard HTML page.""" + + async def handler(_request: web.Request) -> web.Response: + # Origin check intentionally skipped on the page itself: browsers don't + # send Origin on top-level navigation. The /connect/* JSON endpoints + # do enforce it. + return web.Response( + body=HTML.encode("utf-8"), + content_type="text/html", + charset="utf-8", + headers={ + "Cache-Control": "no-store", + # The wizard mints long-lived MA tokens on user click. Refuse + # to be framed so a hostile page cannot UI-redress the user + # into pressing "Generate config" inside an invisible iframe. + "X-Frame-Options": "DENY", + "Content-Security-Policy": "frame-ancestors 'none'", + }, + ) + + return handler + + +def make_info(ctx: WizardContext) -> Callable[[web.Request], Any]: + """Build the ``GET /connect/info`` handler — returns the meta JSON.""" + + async def handler(request: web.Request) -> web.Response: + guard = _origin_guard(ctx, request) + if guard is not None: + return guard + + base_url = str(getattr(ctx.mass.webserver, "base_url", "") or "").rstrip("/") + mount = "/" + ctx.mount_path.strip("/") + loopback = _loopback_url(base_url) + mount + advertised = (base_url + mount) if base_url else loopback + well_known = "/.well-known/oauth-protected-resource" + mount + + try: + permissions = list(ctx.enabled_tags_provider() or []) + except Exception: + LOGGER.exception("Connect Wizard: enabled_tags_provider raised") + permissions = [] + + return web.json_response( + { + "mount_path": ctx.mount_path, + "mcp_url_loopback": loopback, + "mcp_url_advertised": advertised, + "permissions": permissions, + "clients": clients_to_json(), + "well_known_url": well_known, + }, + headers={"Cache-Control": "no-store"}, + ) + + return handler + + +def make_exchange(ctx: WizardContext) -> Callable[[web.Request], Any]: + """Build the ``POST /connect/exchange`` handler — bootstrap → session token.""" + + async def handler(request: web.Request) -> web.Response: + guard = _origin_guard(ctx, request) + if guard is not None: + return guard + + body = await _read_json(request) + bootstrap = str(body.get("bootstrap") or "") + if not bootstrap: + return web.json_response({"error": "missing bootstrap"}, status=400) + + try: + user = await ctx.mass.webserver.auth.authenticate_with_token(bootstrap) + except Exception: + LOGGER.exception("Connect Wizard: bootstrap verify raised") + return web.json_response({"error": "verify failed"}, status=401) + + if user is None or not getattr(user, "enabled", True): + return web.json_response({"error": "invalid bootstrap"}, status=401) + + # Make the bootstrap single-use: revoke it BEFORE minting the session + # so a partial failure (revoke ok, mint fails) cannot leave both the + # bootstrap and a session valid. ``get_token_id_from_token`` handles + # both JWTs and legacy hash tokens; if it cannot resolve a token_id + # we skip the revoke — no regression vs prior behaviour. + try: + bootstrap_id = await ctx.mass.webserver.auth.get_token_id_from_token(bootstrap) + except Exception: + LOGGER.exception("Connect Wizard: get_token_id_from_token raised for bootstrap") + bootstrap_id = None + if bootstrap_id: + await revoke_token_by_id(ctx.mass, user, bootstrap_id) + + try: + session = await ctx.mass.webserver.auth.create_token( + user=user, + name="MCP — wizard session", + is_long_lived=False, + ) + except Exception: + LOGGER.exception("Connect Wizard: session token mint failed") + return web.json_response({"error": "mint failed"}, status=500) + + return web.json_response( + { + "session_token": session, + "user": _public_user(user), + } + ) + + return handler + + +def make_login(ctx: WizardContext) -> Callable[[web.Request], Any]: + """Build the ``POST /connect/login`` handler — username/password fallback.""" + + async def handler(request: web.Request) -> web.Response: + guard = _origin_guard(ctx, request) + if guard is not None: + return guard + + body = await _read_json(request) + username = str(body.get("username") or "") + password = str(body.get("password") or "") + if not username or not password: + return web.json_response({"error": "missing credentials"}, status=400) + + try: + result = await ctx.mass.webserver.auth.login( + username=username, + password=password, + provider_id="builtin", + ) + except Exception: + LOGGER.exception("Connect Wizard: login raised") + return web.json_response({"success": False, "error": "login failed"}, status=401) + + if not isinstance(result, dict) or not result.get("success"): + err = ( + result.get("error", "invalid credentials") + if isinstance(result, dict) + else "invalid credentials" + ) + return web.json_response({"success": False, "error": str(err)}, status=401) + + return web.json_response( + { + "success": True, + "session_token": result.get("access_token"), + "user": result.get("user", {}), + } + ) + + return handler + + +def make_mint_token(ctx: WizardContext) -> Callable[[web.Request], Any]: + """Build the ``POST /connect/token`` handler — mint per-client long-lived token.""" + + async def handler(request: web.Request) -> web.Response: + guard = _origin_guard(ctx, request) + if guard is not None: + return guard + + body = await _read_json(request) + session_token = str(body.get("session_token") or "") + client_id = str(body.get("client_id") or "") + if not session_token or not client_id: + return web.json_response({"error": "missing fields"}, status=400) + + spec = lookup_client(client_id) + if spec is None: + return web.json_response({"error": f"unknown client {client_id!r}"}, status=400) + + try: + user = await ctx.mass.webserver.auth.authenticate_with_token(session_token) + except Exception: + LOGGER.exception("Connect Wizard: session verify raised") + return web.json_response({"error": "session invalid"}, status=401) + + if user is None or not getattr(user, "enabled", True): + return web.json_response({"error": "session invalid"}, status=401) + + new_name = f"MCP — {spec.label}" + + # Server-side dedup: revoke any existing tokens with this exact + # client-token name for the session user, via the sanctioned + # auth.get_user_tokens / auth.revoke_token API. Yields typed + # AuthToken dataclasses — no raw sqlite rows leak in. Idempotent + # across browser/server restarts: a stale `MCP — ` row + # from any prior wizard session is reclaimed before the new mint. + for tok in await list_user_tokens(ctx.mass, user): + if tok.name == new_name: + await revoke_token_by_id(ctx.mass, user, tok.token_id) + + try: + token = await ctx.mass.webserver.auth.create_token( + user=user, + name=new_name, + is_long_lived=True, + ) + except Exception: + LOGGER.exception("Connect Wizard: per-client token mint failed") + return web.json_response({"error": "mint failed"}, status=500) + + return web.json_response({"token": token}) + + return handler + + +# ── helpers ────────────────────────────────────────────────────────────────── + + +def _loopback_url(base_url: str) -> str: + """Return ``scheme://localhost[:port]`` derived from ``base_url``.""" + if not base_url: + return "http://localhost" + parts = urlsplit(base_url) + scheme = parts.scheme or "http" + port = parts.port + suffix = f":{port}" if port else "" + return f"{scheme}://localhost{suffix}" + + +def _public_user(user: Any) -> dict[str, Any]: + """Project a User object onto the small set of fields the wizard UI uses.""" + return { + "user_id": str(getattr(user, "user_id", "") or ""), + "username": str(getattr(user, "username", "") or ""), + "role": str(getattr(getattr(user, "role", None), "value", getattr(user, "role", "")) or ""), + } diff --git a/music_assistant/providers/fastmcp_server/connect/mount.py b/music_assistant/providers/fastmcp_server/connect/mount.py new file mode 100644 index 0000000000..457cd1f81c --- /dev/null +++ b/music_assistant/providers/fastmcp_server/connect/mount.py @@ -0,0 +1,98 @@ +"""Mount the Connect Wizard endpoints onto MA's webserver. + +Five routes are registered under ``/connect``; the returned +callable removes all of them when invoked (called from +:meth:`provider.server.MCPServerRuntime.stop`). +""" + +from __future__ import annotations + +import contextlib +import importlib +from typing import TYPE_CHECKING, Any + +from .handlers import ( + WizardContext, + make_exchange, + make_info, + make_login, + make_mint_token, + make_serve_page, +) + +if TYPE_CHECKING: + from collections.abc import Callable + + from music_assistant.mass import MusicAssistant + + +def _origin_helpers() -> tuple[Any, Any]: + """Look up the origin allowlist helpers from the parent provider package. + + The parent package's name differs between contexts — ``provider`` under + pytest, ``music_assistant.providers.fastmcp_server`` inside MA — so we + resolve it from ``__package__`` at call time. Avoids both the test-only + ``provider.*`` import path and the lint-flagged ``from .. import …`` form. + """ + parent = (__package__ or "").rsplit(".", 1)[0] + if not parent: + msg = "Connect Wizard: cannot resolve parent package for http_bridge import" + raise RuntimeError(msg) + module = importlib.import_module(f"{parent}.http_bridge") + return module._compute_origin_allowlist, module._is_origin_allowed_for_request + + +async def mount_connect_wizard( + mass: MusicAssistant, + mount_path: str, + *, + enabled_tags_provider: Callable[[], list[str]], + extra_origins_csv: str = "", +) -> Callable[[], None]: + """Register the wizard routes and return a callable that unregisters them. + + :param mass: MusicAssistant instance. + :param mount_path: HTTP path prefix where the MCP server is mounted + (e.g. ``/mcp/v1``); wizard routes nest under ``/connect``. + :param enabled_tags_provider: Zero-arg callable returning the list of + currently-enabled permission tag strings; called per-request so + permission hot-swaps surface in the UI without remount. + :param extra_origins_csv: Comma-separated additional ``Origin`` values to + accept beyond the auto-derived loopback + base_url + publish_ip set. + :return: Callable that, when invoked, unregisters every wizard route. + """ + compute_allowlist, is_origin_allowed_for_request = _origin_helpers() + allowlist = compute_allowlist(mass, extra_origins_csv) + ctx = WizardContext( + mass=mass, + mount_path=mount_path, + enabled_tags_provider=enabled_tags_provider, + origin_check=lambda request: is_origin_allowed_for_request(request, allowlist), + ) + + base = "/" + mount_path.strip("/") + routes: list[tuple[str, str]] = [ + (f"{base}/connect", "GET"), + (f"{base}/connect/info", "GET"), + (f"{base}/connect/exchange", "POST"), + (f"{base}/connect/login", "POST"), + (f"{base}/connect/token", "POST"), + ] + handlers = [ + make_serve_page(ctx), + make_info(ctx), + make_exchange(ctx), + make_login(ctx), + make_mint_token(ctx), + ] + + unregister_fns: list[Callable[[], None]] = [] + for (path, method), handler in zip(routes, handlers, strict=True): + unregister_fns.append(mass.webserver.register_dynamic_route(path, handler, method=method)) + + def _unregister_all() -> None: + for fn in unregister_fns: + with contextlib.suppress(Exception): + fn() + + return _unregister_all diff --git a/music_assistant/providers/fastmcp_server/connect/page.py b/music_assistant/providers/fastmcp_server/connect/page.py new file mode 100644 index 0000000000..489c1a7df4 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/connect/page.py @@ -0,0 +1,485 @@ +"""Inline single-page HTML for the Connect Wizard. + +Embedded as a Python string (not a packaged static file) because +``pyproject.toml`` is auto-generated by ma-provider-tools and +``[tool.setuptools.package-data]`` cannot be hand-edited here without drift. +""" + +from __future__ import annotations + +HTML: str = """ + + + + +Connect Music Assistant — MCP Wizard + + + +
+

Connect Music Assistant to your AI

+
Pick your AI client below and copy the generated config. A per-client + token is minted in MA so you can revoke any one client without affecting the others.
+ +
+
Loading…
+ + + + + +
+ Active permissions +
+
+ What if my AI says "permission denied"? +
+ Read-only tools are enabled by default. To let the AI control playback, + manage queues, or edit playlists, enable the matching toggles in this + plugin's settings (Control / Edit / Delete categories) — changes apply + without a restart. +
+
+
+ +

+ ma-provider-mcp + · Music Assistant MCP server +

+
+ + + + +""" diff --git a/music_assistant/providers/fastmcp_server/constants.py b/music_assistant/providers/fastmcp_server/constants.py new file mode 100644 index 0000000000..9a19904296 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/constants.py @@ -0,0 +1,75 @@ +"""Configuration keys, defaults, and constants for the MCP Server provider.""" + +from __future__ import annotations + +# ── Server settings ──────────────────────────────────────────────────────────── +CONF_REQUIRE_AUTH = "require_auth" +CONF_MOUNT_PATH = "mount_path" +CONF_EXTRA_ALLOWED_ORIGINS = "extra_allowed_origins" +CONF_ENFORCE_AUDIENCE = "enforce_audience" +CONF_REQUIRE_CONFIRMATION = "require_confirmation" +CONF_CONNECT_EXTERNAL_URL = "connect_external_url" + +DEFAULT_MOUNT_PATH = "/mcp/v1" + +# ── Query permissions ───────────────────────────────────────────────────────── +CONF_QUERY_LIBRARY = "query_library" +CONF_QUERY_QUEUE = "query_queue" +CONF_QUERY_PLAYERS = "query_players" +CONF_QUERY_METADATA = "query_metadata" + +# ── Control permissions ─────────────────────────────────────────────────────── +CONF_CONTROL_PLAYBACK = "control_playback" +CONF_CONTROL_VOLUME = "control_volume" +CONF_CONTROL_PLAYERS = "control_players" +CONF_CONTROL_MEDIA = "control_media" + +# ── Edit permissions ────────────────────────────────────────────────────────── +CONF_EDIT_LIBRARY = "edit_library" +CONF_EDIT_QUEUE = "edit_queue" +CONF_EDIT_PLAYLISTS = "edit_playlists" +CONF_EDIT_FAVORITES = "edit_favorites" + +# ── Delete permissions ──────────────────────────────────────────────────────── +CONF_DELETE_LIBRARY = "delete_library" +CONF_DELETE_QUEUE = "delete_queue" +CONF_DELETE_PLAYLISTS = "delete_playlists" +CONF_DELETE_FAVORITES = "delete_favorites" + +# ── MCP Resources / Prompts toggles ─────────────────────────────────────────── +CONF_RES_LIBRARY = "res_library" +CONF_RES_PLAYER = "res_player" +CONF_RES_PROMPTS = "res_prompts" + +PERMISSION_KEYS: frozenset[str] = frozenset( + { + CONF_QUERY_LIBRARY, + CONF_QUERY_QUEUE, + CONF_QUERY_PLAYERS, + CONF_QUERY_METADATA, + CONF_CONTROL_PLAYBACK, + CONF_CONTROL_VOLUME, + CONF_CONTROL_PLAYERS, + CONF_CONTROL_MEDIA, + CONF_EDIT_LIBRARY, + CONF_EDIT_QUEUE, + CONF_EDIT_PLAYLISTS, + CONF_EDIT_FAVORITES, + CONF_DELETE_LIBRARY, + CONF_DELETE_QUEUE, + CONF_DELETE_PLAYLISTS, + CONF_DELETE_FAVORITES, + } +) + +RESOURCE_KEYS: frozenset[str] = frozenset( + { + CONF_RES_LIBRARY, + CONF_RES_PLAYER, + CONF_RES_PROMPTS, + } +) + +# Permission-only changes can be hot-swapped without remount; everything else triggers +# a full restart of the runtime. +HOT_SWAPPABLE_KEYS: frozenset[str] = PERMISSION_KEYS | RESOURCE_KEYS diff --git a/music_assistant/providers/fastmcp_server/http_bridge.py b/music_assistant/providers/fastmcp_server/http_bridge.py new file mode 100644 index 0000000000..07d9ca3f61 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/http_bridge.py @@ -0,0 +1,579 @@ +"""ASGI ↔ aiohttp bridge for mounting FastMCP under MA's webserver. + +FastMCP v3 exposes a Starlette-based ASGI app for streamable-HTTP transport. +MA's main webserver is aiohttp. This bridge translates a single aiohttp +``web.Request`` into ASGI ``scope/receive/send`` events and back into a +``web.StreamResponse``, so we can mount the MCP app under any path that +``mass.webserver.register_dynamic_route`` accepts (we use ``/mcp/v1/*``). + +Streaming responses (SSE / chunked) are passed through verbatim so MCP +keep-alive heartbeats and tool-progress events reach the client without +buffering. + +A second helper (:func:`mount_well_known`) registers a sibling route at +``/.well-known/oauth-protected-resource[/]`` that serves the RFC +9728 protected-resource-metadata document — pointed to by FastMCP's +``WWW-Authenticate`` 401 header, so spec-compliant MCP clients (Claude +Desktop, Codex, ChatGPT Apps SDK) can discover the authorization server. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +from typing import TYPE_CHECKING, Any +from urllib.parse import urlsplit + +from aiohttp import web + +if TYPE_CHECKING: + from collections.abc import Callable + + from music_assistant.mass import MusicAssistant + +LOGGER = logging.getLogger(__name__) + + +_DEFAULT_PORTS = {"http": 80, "https": 443} + + +def _normalize_origin(origin: str) -> str | None: + """Return ``scheme://host[:port]`` lower-cased, default-port stripped, or None. + + Rejects forms without scheme or netloc; preserves ``"null"`` verbatim so it + can be matched against an explicit allowlist entry. IPv6 hosts are + re-bracketed (``urlsplit`` strips the brackets via ``hostname``) so the + canonical form matches a literal ``http://[::1]`` allowlist entry. + """ + if not origin: + return None + if origin == "null": + return "null" + parts = urlsplit(origin) + scheme = parts.scheme.lower() + host = parts.hostname + if not scheme or not host: + return None + host_lower = host.lower() + # urlsplit's `.hostname` returns the bare IPv6 ("::1"); we need brackets + # back ("[::1]") so f-string concatenation produces a valid Origin. + bracketed_host = f"[{host_lower}]" if ":" in host_lower else host_lower + port = parts.port + if port is None or port == _DEFAULT_PORTS.get(scheme): + return f"{scheme}://{bracketed_host}" + return f"{scheme}://{bracketed_host}:{port}" + + +def _compute_origin_allowlist(mass: MusicAssistant, extra_origins_csv: str = "") -> frozenset[str]: + """Build the set of accepted ``Origin`` values for the MCP endpoint. + + Always includes loopback variants (``http://localhost``, ``http://127.0.0.1``, + ``http://[::1]``), the host derived from ``mass.webserver.base_url``, and the + advertised ``mass.webserver.publish_ip``. Additional origins from config + (CSV) are normalized and added. + """ + allow: set[str] = { + "http://localhost", + "http://127.0.0.1", + "http://[::1]", + } + + base_url = str(getattr(mass.webserver, "base_url", "") or "") + base_norm = _normalize_origin(base_url) + if base_norm: + allow.add(base_norm) + # Same host on https is acceptable when MA is behind TLS-terminating proxy. + if base_norm.startswith("http://"): + allow.add("https://" + base_norm[len("http://") :]) + + # Browsers send the MA port in Origin even for loopback access — add the + # loopback variants on the MA port so a Origin like ``http://localhost:8095`` + # is accepted (the bare loopback entries above only match port 80). + base_port = _port_from_base_url(base_url) + if base_port: + for loopback in ("localhost", "127.0.0.1", "[::1]"): + allow.add(f"http://{loopback}:{base_port}") + allow.add(f"https://{loopback}:{base_port}") + + publish_ip = str(getattr(mass.webserver, "publish_ip", "") or "") + if publish_ip: + # Derive port from base_url; fallback: no port (browsers send port if non-default). + port = _port_from_base_url(base_url) + suffix = f":{port}" if port else "" + ip_lower = publish_ip.lower() + # Bracket IPv6 literals so they match the way browsers serialize Origin. + ip_token = f"[{ip_lower}]" if ":" in ip_lower else ip_lower + allow.add(f"http://{ip_token}{suffix}") + allow.add(f"https://{ip_token}{suffix}") + + for raw in (extra_origins_csv or "").split(","): + norm = _normalize_origin(raw.strip()) + if norm: + allow.add(norm) + + return frozenset(allow) + + +def _port_from_base_url(base_url: str) -> int | None: + """Return the explicit port from a URL, or ``None`` when it's the scheme default.""" + if not base_url: + return None + parts = urlsplit(base_url) + if parts.port is not None and parts.port != _DEFAULT_PORTS.get(parts.scheme.lower()): + return parts.port + return None + + +def _is_origin_allowed(origin: str | None, allowlist: frozenset[str]) -> bool: + """Return True if the request's ``Origin`` should be accepted. + + Rules: + + * Missing ``Origin`` → allowed (stdio-style or non-browser MCP clients). + Spec MUST applies to *present* Origin values. + * ``Origin: null`` → allowed only if explicitly listed in the allowlist + (some sandboxed iframes / file:// pages send it). + * Any other value is normalized and matched literally. + """ + if origin is None: + return True + norm = _normalize_origin(origin) + if norm is None: + return False + return norm in allowlist + + +def _is_origin_allowed_for_request( + request: web.Request, + allowlist: frozenset[str], +) -> bool: + """Origin check with a Home-Assistant-ingress fallback. + + Applies :func:`_is_origin_allowed` first. When that rejects, accept the + request if **all** of the following hold: + + * the request arrived on the trusted ingress socket Music Assistant + verifies via :func:`is_request_from_ingress` (so we are not trusting + attacker-supplied headers); and + * the request carries an ``X-Forwarded-Host`` set by HA; and + * the browser's ``Origin`` matches + ``://``. + + This removes the need for HA add-on users to copy their public hostname + into the ``extra_allowed_origins`` config every time the URL changes. + """ + origin = request.headers.get("Origin") + if _is_origin_allowed(origin, allowlist): + return True + if origin is None: + return False # _is_origin_allowed already returned True above; defensive + + forwarded_host = request.headers.get("X-Forwarded-Host") + if not forwarded_host: + return False + + try: + from music_assistant.controllers.webserver.helpers.auth_middleware import ( # noqa: PLC0415 + is_request_from_ingress, + ) + except (ImportError, ModuleNotFoundError): + # ``music_assistant`` is a dev-only / test-extras dep here; absent in + # the bare provider venv. Fail closed without log noise. + return False + except Exception: + # Anything else (e.g. partial module init breakage upstream) is a real + # surprise — log so it's debuggable, then fail closed. + LOGGER.exception("Connect Wizard: unexpected error importing ingress helper") + return False + try: + if not is_request_from_ingress(request): + return False + except Exception: + # MA may evolve the request-app shape; log so a future breakage isn't + # silently a 403 with no hint as to why. + LOGGER.exception("Connect Wizard: is_request_from_ingress raised") + return False + + # Default to the aiohttp transport scheme (canonical aiohttp API) rather + # than a hard-coded "https" so an unsecured local HA installation still + # works when X-Forwarded-Proto is omitted. Multi-value X-Forwarded-Host + # (``ha.example.com, internal.lan``) intentionally fails normalisation + # below → reject; supporting it would mean trusting whichever hop the + # proxy listed last, which is rarely what you want. + forwarded_proto = request.headers.get("X-Forwarded-Proto", request.scheme) + forwarded_origin = _normalize_origin(f"{forwarded_proto}://{forwarded_host}") + if forwarded_origin is None: + return False + return _normalize_origin(origin) == forwarded_origin + + +async def mount_into_mass( + mass: MusicAssistant, + mcp: Any, + mount_path: str = "/mcp/v1", + extra_origins_csv: str = "", +) -> Callable[[], None]: + """Register the FastMCP streamable-HTTP ASGI app under MA's webserver. + + :param mass: MusicAssistant instance. + :param mcp: FastMCP server instance whose ``http_app`` is exposed. + :param mount_path: Path prefix on the MA webserver (default ``/mcp/v1``). + :param extra_origins_csv: Comma-separated additional ``Origin`` values to + accept beyond the auto-derived defaults (loopback + base_url + + publish_ip). Use for reverse-proxy hostnames or HA ingress. + :return: Callable that, when invoked, unregisters the route and shuts + down the FastMCP ASGI lifespan. + """ + # Tell FastMCP that its streamable-HTTP endpoint lives at ``mount_path`` + # (not the SDK's default ``/mcp``), so the internal Starlette router + # matches the URL the request actually arrives with — without a prefix + # strip in the bridge. With strip we'd hand FastMCP a bare ``/`` and its + # router would 404 every request. + asgi_app = _build_asgi_app(mcp, mount_path) + allowlist = _compute_origin_allowlist(mass, extra_origins_csv) + + # Drive the ASGI lifespan ourselves — without it FastMCP's + # StreamableHTTPSessionManager never enters its task group and the first + # request fails with "Task group is not initialized." The lifespan loop + # runs until shutdown is requested at unmount time. + lifespan_state = await _start_asgi_lifespan(asgi_app) + + async def handler(request: web.Request) -> web.StreamResponse: + if not _is_origin_allowed_for_request(request, allowlist): + LOGGER.warning( + "MCP: rejected request with Origin=%r from %s (not in allowlist)", + request.headers.get("Origin"), + request.remote, + ) + return web.Response(status=403, text="Forbidden Origin") + return await _asgi_to_aiohttp(asgi_app, request, strip_prefix="") + + unregister = mass.webserver.register_dynamic_route(f"{mount_path}/*", handler) + + def _unmount() -> None: + with contextlib.suppress(Exception): + unregister() + # Schedule the lifespan shutdown — caller may be sync (MA's + # ``unload``), so dispatch onto the running loop without blocking. + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(_stop_asgi_lifespan(lifespan_state)) + else: # pragma: no cover - belt-and-braces for unit-test contexts + with contextlib.suppress(Exception): + loop.run_until_complete(_stop_asgi_lifespan(lifespan_state)) + + return _unmount + + +async def _start_asgi_lifespan(asgi_app: Any) -> dict[str, Any]: + """Send ASGI ``lifespan.startup`` and keep the lifespan task running. + + Returns a state dict carrying the running task and the queues used to + feed it ``lifespan.shutdown`` later. Re-raises a startup failure synchronously + so caller sees the underlying exception immediately. + """ + receive_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + send_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + + async def receive() -> dict[str, Any]: + return await receive_queue.get() + + async def send(message: dict[str, Any]) -> None: + await send_queue.put(message) + + task = asyncio.create_task( + asgi_app({"type": "lifespan", "asgi": {"version": "3.0"}}, receive, send), + name="mcp-asgi-lifespan", + ) + + # Trigger startup and wait for ack. + await receive_queue.put({"type": "lifespan.startup"}) + ack = await asyncio.wait_for(send_queue.get(), timeout=30) + if ack.get("type") == "lifespan.startup.failed": + # Lifespan task aborted; surface its exception cleanly. + with contextlib.suppress(asyncio.CancelledError, Exception): + await task + msg = ack.get("message", "ASGI lifespan startup failed") + raise RuntimeError(msg) + if ack.get("type") != "lifespan.startup.complete": + msg = f"Unexpected ASGI lifespan event during startup: {ack!r}" + raise RuntimeError(msg) + + return {"task": task, "receive_queue": receive_queue, "send_queue": send_queue} + + +async def _stop_asgi_lifespan(state: dict[str, Any]) -> None: + """Send ASGI ``lifespan.shutdown`` and await the lifespan task to finish.""" + receive_queue: asyncio.Queue[dict[str, Any]] = state["receive_queue"] + send_queue: asyncio.Queue[dict[str, Any]] = state["send_queue"] + task: asyncio.Task[Any] = state["task"] + + with contextlib.suppress(Exception): + await receive_queue.put({"type": "lifespan.shutdown"}) + with contextlib.suppress(asyncio.TimeoutError): + # Drain the shutdown ack but don't fail if the app skips it. + await asyncio.wait_for(send_queue.get(), timeout=10) + + if not task.done(): + try: + await asyncio.wait_for(task, timeout=10) + except (TimeoutError, asyncio.CancelledError): + task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await task + + +def build_protected_resource_metadata( + *, + resource_uri: str, + authorization_servers: list[str], + scopes_supported: list[str] | None = None, + resource_name: str | None = None, +) -> dict[str, Any]: + """Construct the RFC 9728 OAuth 2.0 Protected Resource Metadata document. + + :param resource_uri: Canonical URI of this MCP server (matches the ``aud`` + claim in tokens issued for it). + :param authorization_servers: Issuer URLs of authorization servers that + produce valid tokens for ``resource_uri``. + :param scopes_supported: Optional list of scopes advertised to clients. + :param resource_name: Human-readable label. + """ + metadata: dict[str, Any] = { + "resource": resource_uri, + "authorization_servers": list(authorization_servers), + "bearer_methods_supported": ["header"], + } + if scopes_supported: + metadata["scopes_supported"] = list(scopes_supported) + if resource_name: + metadata["resource_name"] = resource_name + return metadata + + +async def mount_well_known( + mass: MusicAssistant, + *, + mount_path: str, + resource_uri: str, + authorization_servers: list[str], + scopes_supported: list[str] | Callable[[], list[str]] | None = None, + resource_name: str | None = None, +) -> Callable[[], None]: + """Register the Protected Resource Metadata endpoint on MA's webserver. + + Two paths are bound, both returning the same JSON: + + * ``/.well-known/oauth-protected-resource/`` + — the path FastMCP advertises in ``WWW-Authenticate`` 401 responses. + * ``/.well-known/oauth-protected-resource`` — root fallback (RFC 9728 + §3.1 second form), so clients that strip the path component still find + the document. + + :param scopes_supported: Either a static list (snapshot at mount time) or a + zero-arg callable returning the current scope list. The callable form + lets the document stay in sync with permission hot-swaps without a + runtime rebuild — the body is regenerated on each request. + :return: Callable that unregisters both routes when invoked. + """ + + def _resolve_scopes() -> list[str] | None: + if callable(scopes_supported): + return scopes_supported() + return scopes_supported + + def _build_body() -> bytes: + metadata = build_protected_resource_metadata( + resource_uri=resource_uri, + authorization_servers=authorization_servers, + scopes_supported=_resolve_scopes(), + resource_name=resource_name, + ) + return json.dumps(metadata).encode() + + async def handler(_request: web.Request) -> web.Response: + # Re-render per request — sub-ms json.dumps — so a closure over + # self._config in MCPServerRuntime reflects permission hot-swaps. + return web.Response( + body=_build_body(), + content_type="application/json", + headers={"Cache-Control": "no-store"}, + ) + + suffix = mount_path.lstrip("/") + paths = [ + f"/.well-known/oauth-protected-resource/{suffix}", + "/.well-known/oauth-protected-resource", + ] + unregister_fns: list[Callable[[], None]] = [ + mass.webserver.register_dynamic_route(p, handler, method="GET") for p in paths + ] + + def _unregister_all() -> None: + for fn in unregister_fns: + with contextlib.suppress(Exception): + fn() + + return _unregister_all + + +def _build_asgi_app(mcp: Any, mount_path: str = "/mcp") -> Any: + """Return the streamable-HTTP ASGI app from FastMCP, accommodating v3 minor renames. + + ``mount_path`` is propagated as ``http_app(path=...)`` so FastMCP's + Starlette router exposes the streamable endpoint at the same URL the + aiohttp bridge forwards to it — preventing 404s when our outer mount + differs from the SDK's default ``/mcp``. RFC 9728 metadata routes + advertised by FastMCP are likewise rooted at this path. + """ + if hasattr(mcp, "http_app"): + return mcp.http_app(transport="streamable-http", path=mount_path) + if hasattr(mcp, "streamable_http_app"): + return mcp.streamable_http_app() + if hasattr(mcp, "asgi_app"): + return mcp.asgi_app() + msg = "Could not find an ASGI app factory on FastMCP instance" + raise RuntimeError(msg) + + +async def _asgi_to_aiohttp( # noqa: PLR0915 - single-purpose ASGI bridge, splitting harms readability + asgi_app: Any, + request: web.Request, + strip_prefix: str = "", +) -> web.StreamResponse: + """Bridge a single aiohttp request through an ASGI app. + + The bridge supports streaming responses: ``http.response.body`` events + with ``more_body=True`` are flushed to the client immediately, which is + required for streamable-HTTP MCP transport. + """ + scope = _build_scope(request, strip_prefix) + body_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + + async def receive() -> dict[str, Any]: + return await body_queue.get() + + response_state: dict[str, Any] = {"started": False, "response": None, "disconnected": False} + + async def send(message: dict[str, Any]) -> None: + msg_type = message.get("type") + if response_state["disconnected"]: + # Client gave up; suppress further ASGI sends so the app can wind + # down without an exception cascade. + return + try: + if msg_type == "http.response.start": + status = int(message.get("status", 200)) + headers_list = message.get("headers", []) + response = web.StreamResponse(status=status) + for raw_name, raw_value in headers_list: + name = ( + raw_name.decode("latin-1") if isinstance(raw_name, bytes) else str(raw_name) + ) + value = ( + raw_value.decode("latin-1") + if isinstance(raw_value, bytes) + else str(raw_value) + ) + if name.lower() in {"transfer-encoding", "content-length"}: + continue + response.headers[name] = value + await response.prepare(request) + response_state["response"] = response + response_state["started"] = True + elif msg_type == "http.response.body": + response = response_state["response"] + if response is None: + msg = "ASGI app sent body before start" + raise RuntimeError(msg) + body = message.get("body", b"") + if body: + await response.write(body) + if not message.get("more_body", False): + await response.write_eof() + except (ConnectionResetError, ConnectionError, asyncio.CancelledError): + # The other side closed the (SSE) stream. Mark the response as + # disconnected and feed an ASGI ``http.disconnect`` upstream so + # the app's keep-alive / ping loops can wind down cleanly. Logged + # at debug because this is a normal, expected client behaviour + # for long-lived streams. + response_state["disconnected"] = True + with contextlib.suppress(Exception): + await body_queue.put({"type": "http.disconnect"}) + LOGGER.debug("MCP bridge: client closed stream during send (path=%s)", request.path) + + async def pump_request_body() -> None: + try: + async for chunk in request.content.iter_chunked(64 * 1024): + await body_queue.put({"type": "http.request", "body": chunk, "more_body": True}) + await body_queue.put({"type": "http.request", "body": b"", "more_body": False}) + except (ConnectionResetError, ConnectionError, asyncio.CancelledError): + response_state["disconnected"] = True + with contextlib.suppress(Exception): + await body_queue.put({"type": "http.disconnect"}) + except Exception: + LOGGER.exception("MCP bridge: failed to pump request body") + with contextlib.suppress(Exception): + await body_queue.put({"type": "http.disconnect"}) + + pump_task = asyncio.create_task(pump_request_body()) + try: + await asgi_app(scope, receive, send) + except (ConnectionResetError, ConnectionError, asyncio.CancelledError): + # Client-driven disconnect during streaming — normal flow; don't + # re-raise as 500. + response_state["disconnected"] = True + LOGGER.debug("MCP bridge: ASGI app cancelled by client disconnect") + except Exception: + LOGGER.exception("MCP bridge: ASGI app raised") + if not response_state["started"]: + return web.Response(status=500, text="Internal MCP bridge error") + raise + finally: + pump_task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await pump_task + + response = response_state["response"] + if response is None: + return web.Response(status=204) + assert isinstance(response, web.StreamResponse) + return response + + +def _build_scope(request: web.Request, strip_prefix: str) -> dict[str, Any]: + """Convert an aiohttp request into a minimal ASGI HTTP scope dict.""" + raw_path = request.rel_url.raw_path + if strip_prefix and raw_path.startswith(strip_prefix): + raw_path = raw_path[len(strip_prefix) :] + if not raw_path.startswith("/"): + raw_path = "/" + raw_path + + headers: list[tuple[bytes, bytes]] = [ + (k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in request.headers.items() + ] + + server_host = request.url.host or "localhost" + server_port = request.url.port or (443 if request.url.scheme == "https" else 80) + + client_addr: tuple[str, int] | None = None + peername = request.transport.get_extra_info("peername") if request.transport else None + if peername and len(peername) >= 2: + client_addr = (str(peername[0]), int(peername[1])) + + return { + "type": "http", + "asgi": {"version": "3.0", "spec_version": "2.3"}, + "http_version": "1.1", + "method": request.method, + "scheme": request.url.scheme, + "path": raw_path, + "raw_path": raw_path.encode("latin-1"), + "query_string": request.rel_url.raw_query_string.encode("latin-1"), + "root_path": strip_prefix.rstrip("/"), + "headers": headers, + "server": (server_host, server_port), + "client": client_addr, + } diff --git a/music_assistant/providers/fastmcp_server/icon.svg b/music_assistant/providers/fastmcp_server/icon.svg new file mode 100644 index 0000000000..03d9f85d32 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/icon.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/music_assistant/providers/fastmcp_server/icon_monochrome.svg b/music_assistant/providers/fastmcp_server/icon_monochrome.svg new file mode 100644 index 0000000000..03d9f85d32 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/icon_monochrome.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/music_assistant/providers/fastmcp_server/manifest.json b/music_assistant/providers/fastmcp_server/manifest.json new file mode 100644 index 0000000000..f8d784f939 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/manifest.json @@ -0,0 +1,16 @@ +{ + "type": "plugin", + "domain": "fastmcp_server", + "stage": "experimental", + "name": "MCP Server", + "description": "Exposes Music Assistant as a Model Context Protocol (MCP) server for Claude, Codex, and other MCP-aware LLM clients.", + "codeowners": ["@TrudenBoy"], + "credits": [ + "[FastMCP](https://github.com/jlowin/fastmcp)", + "[Model Context Protocol](https://modelcontextprotocol.io/)" + ], + "requirements": ["fastmcp==3.2.4"], + "documentation": "https://github.com/trudenboy/ma-provider-mcp", + "multi_instance": false, + "builtin": false +} diff --git a/music_assistant/providers/fastmcp_server/middleware.py b/music_assistant/providers/fastmcp_server/middleware.py new file mode 100644 index 0000000000..7162cf220e --- /dev/null +++ b/music_assistant/providers/fastmcp_server/middleware.py @@ -0,0 +1,168 @@ +"""Tag-filter middleware: hide tools / resources / prompts whose tags are disabled. + +FastMCP v3's built-in ``restrict_tag`` is scope-based authorization (token must +carry a specific OAuth scope). What we need here is **config-driven visibility**: +the operator toggles a permission boolean and the corresponding tools simply +disappear from listings — no error path, no permission-denied trace. + +This middleware reads ``allowed_tags`` from a closure (so we can swap the set in +place when ``MCPServerProvider.update_config`` runs without rebuilding the +FastMCP server), and applies the rule: + +* a component with **at least one** allowed tag is exposed +* a component with **no** tags is exposed (treat as always-on infrastructure) +* a component whose tags are **all** disabled is hidden / blocked + +Listings are filtered post-hoc; direct invocations (``tools/call``, +``resources/read``, ``prompts/get``) look the component up by name/URI and +apply the same rule. A client that cached a tool name from an earlier +permission set therefore cannot reach a now-disabled tool. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Sequence +from typing import TYPE_CHECKING, Any, Literal + +from fastmcp.exceptions import NotFoundError, PromptError, ResourceError, ToolError +from fastmcp.server.middleware import Middleware + +if TYPE_CHECKING: + from fastmcp.server.middleware.middleware import CallNext, MiddlewareContext + + +ComponentKind = Literal["tool", "resource", "prompt"] +TagsLookup = Callable[[ComponentKind, str], Awaitable[set[str] | None]] + + +class TagFilterMiddleware(Middleware): # type: ignore[misc, unused-ignore] + """Hide tools, resources, and prompts whose tags are not in ``allowed_tags``. + + ``Middleware`` is typed as ``Any`` upstream; under + ``disallow_subclassing_any`` we suppress the misc-rule on the class + line rather than every method. + """ + + def __init__( + self, + allowed_tags_provider: Callable[[], set[str]], + lookup_component_tags: TagsLookup, + ) -> None: + """Initialise the middleware. + + :param allowed_tags_provider: zero-arg callable returning the *current* + set of allowed tags. Wrapped in a callable so the operator can + change permission flags without restarting the runtime. + :param lookup_component_tags: async ``(kind, key) -> set[str] | None`` + that resolves a tool name / resource URI / prompt name back to its + tag set. Returns ``None`` when the component does not exist (treat + as blocked: a stale cached name from a prior permission set must + not slip through). + """ + super().__init__() + self._allowed = allowed_tags_provider + self._lookup = lookup_component_tags + + # ── filtered listings ──────────────────────────────────────────────────── + + async def on_list_tools( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Sequence[Any]], + ) -> Sequence[Any]: + """Drop tools whose tags are all disabled.""" + items = await call_next(context) + return [t for t in items if self._is_visible(t)] + + async def on_list_resources( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Sequence[Any]], + ) -> Sequence[Any]: + """Drop resources whose tags are all disabled.""" + items = await call_next(context) + return [r for r in items if self._is_visible(r)] + + async def on_list_resource_templates( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Sequence[Any]], + ) -> Sequence[Any]: + """Drop resource templates whose tags are all disabled.""" + items = await call_next(context) + return [r for r in items if self._is_visible(r)] + + async def on_list_prompts( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Sequence[Any]], + ) -> Sequence[Any]: + """Drop prompts whose tags are all disabled.""" + items = await call_next(context) + return [p for p in items if self._is_visible(p)] + + # ── invocation guards ──────────────────────────────────────────────────── + + async def on_call_tool( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Any], + ) -> Any: + """Block calls to tools whose tag set has been disabled.""" + name = getattr(context.message, "name", "") + await self._reject_if_hidden("tool", name) + return await call_next(context) + + async def on_read_resource( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Any], + ) -> Any: + """Block reads of resources whose tag set has been disabled.""" + uri = str(getattr(context.message, "uri", "")) + await self._reject_if_hidden("resource", uri) + return await call_next(context) + + async def on_get_prompt( + self, + context: MiddlewareContext[Any], + call_next: CallNext[Any, Any], + ) -> Any: + """Block reads of prompts whose tag set has been disabled.""" + name = getattr(context.message, "name", "") + await self._reject_if_hidden("prompt", name) + return await call_next(context) + + # Error class chosen so the SDK reports the failure under the right RPC + # method (tools/resources/prompts) rather than always as a tool error. + _ERROR_BY_KIND: dict[ComponentKind, type[Exception]] = { + "tool": ToolError, + "resource": ResourceError, + "prompt": PromptError, + } + + # ── helpers ────────────────────────────────────────────────────────────── + + def _is_visible(self, component: Any) -> bool: + tags = getattr(component, "tags", None) or set() + if not tags: + return True + allowed = self._allowed() + return any(str(t) in allowed for t in tags) + + async def _reject_if_hidden(self, kind: ComponentKind, key: str) -> None: + if not key: + return + tags = await self._lookup(kind, key) + if tags is None: + # Component doesn't exist (or is itself disabled at the FastMCP layer). + # Surface a NotFoundError so the SDK returns the spec-correct + # "method-not-allowed" / "not-found" path rather than 500. + msg = f"{kind.capitalize()} {key!r} not found" + raise NotFoundError(msg) + if not tags: + return # untagged → always-on + allowed = self._allowed() + if not any(t in allowed for t in tags): + msg = f"{kind.capitalize()} {key!r} is currently disabled by configuration" + raise self._ERROR_BY_KIND[kind](msg) diff --git a/music_assistant/providers/fastmcp_server/models.py b/music_assistant/providers/fastmcp_server/models.py new file mode 100644 index 0000000000..d54e7f5632 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/models.py @@ -0,0 +1,101 @@ +"""Trimmed response dataclasses used in tool replies. + +Tools that need to return a Music Assistant entity use these light-weight shapes +to keep payloads small for LLM context windows. Resources, by contrast, return +the full ``music_assistant_models`` types directly because clients usually +expect a complete object when they fetch a URI. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class TrackBrief: + """A track summary for tool responses.""" + + uri: str + name: str + artists: list[str] = field(default_factory=list) + album: str | None = None + duration: int | None = None + + +@dataclass +class AlbumBrief: + """An album summary for tool responses.""" + + uri: str + name: str + artist: str | None = None + year: int | None = None + + +@dataclass +class ArtistBrief: + """An artist summary for tool responses.""" + + uri: str + name: str + + +@dataclass +class PlaylistBrief: + """A playlist summary for tool responses.""" + + uri: str + name: str + track_count: int | None = None + owner: str | None = None + + +@dataclass +class RadioBrief: + """A radio summary for tool responses.""" + + uri: str + name: str + description: str | None = None + + +@dataclass +class PlayerBrief: + """A player summary for tool responses.""" + + player_id: str + name: str + state: str + volume_level: int | None = None + powered: bool = True + current_item: str | None = None + + +@dataclass +class QueueItemBrief: + """A queue item summary.""" + + item_id: str + name: str + duration: int | None = None + artists: list[str] = field(default_factory=list) + + +@dataclass +class QueueBrief: + """A queue summary for tool responses.""" + + queue_id: str + current_index: int | None + item_count: int + shuffle: bool + repeat: str + items: list[QueueItemBrief] = field(default_factory=list) + + +@dataclass +class RecommendationFolderBrief: + """One curated recommendation folder (e.g. "Mood: Focus") with its track URIs.""" + + name: str + item_uris: list[str] = field(default_factory=list) diff --git a/music_assistant/providers/fastmcp_server/prompts.py b/music_assistant/providers/fastmcp_server/prompts.py new file mode 100644 index 0000000000..464384816d --- /dev/null +++ b/music_assistant/providers/fastmcp_server/prompts.py @@ -0,0 +1,69 @@ +"""Canned MCP prompts. + +These prompts hand the LLM a small, opinionated playbook for common tasks +("find a song and play it on a specific speaker", "now playing summary", +"build a party playlist") so an LLM client can chain MCP tools without +re-deriving the workflow each time. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from .constants import CONF_RES_PROMPTS + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ProviderConfig + + +def register_prompts(mcp: Any, config: ProviderConfig) -> None: + """Register canned prompts on the FastMCP root, gated by ``CONF_RES_PROMPTS``.""" + if not config.get_value(CONF_RES_PROMPTS): + return + + @mcp.prompt(name="find_and_play") # type: ignore[untyped-decorator, unused-ignore] + def find_and_play(query: str = "", target_player: str = "") -> str: + """Search and play media on a player.""" + target = target_player or "" + request = query or "" + return ( + f"Find the best match for the user's request: '{request}'.\n" + "Use library_search_tracks (and library_search_albums or " + "library_search_artists if needed) to identify the right URI.\n" + f"Then call playback_play_media with queue_id='{target}' and the " + "resolved URI.\n" + "Finally, call queue_get_active_queue to confirm the new state " + "and report it back." + ) + + @mcp.prompt(name="curate_party_playlist") # type: ignore[untyped-decorator, unused-ignore] + def party_playlist(theme: str = "indie 2010s", length_minutes: int = 60) -> str: + """Build a party playlist.""" + return ( + f"Curate a playlist of roughly {length_minutes} minutes around " + f"the theme: '{theme}'.\n" + "Use library_search_tracks repeatedly with varied sub-queries " + "(genres, eras, similar artists) and metadata_recommendations " + "to seed candidates.\n" + "Pick tracks the user would dance to.\n" + "Then call playlists_create_playlist with a descriptive name, " + "and playlists_add_tracks to fill it.\n" + "Report the playlist URI when done." + ) + + @mcp.prompt(name="now_playing_summary") # type: ignore[untyped-decorator, unused-ignore] + def now_playing(player_id: str = "") -> str: + """Summarise what's currently playing on a player (or all players).""" + if player_id: + return ( + f"Use queue_get_active_queue with player_id='{player_id}' " + "to fetch the current queue.\n" + "Summarise the now-playing track (title, artist, album, " + "time remaining) and the next two upcoming items in 3-4 " + "sentences." + ) + return ( + "List all players via players_list_players. For each player " + "whose state is 'playing', fetch its active queue and " + "summarise the now-playing track. Group by room when possible." + ) diff --git a/music_assistant/providers/fastmcp_server/provider.py b/music_assistant/providers/fastmcp_server/provider.py new file mode 100644 index 0000000000..a0184f3bc4 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/provider.py @@ -0,0 +1,64 @@ +"""MCP Server provider — main PluginProvider implementation. + +The provider is a thin lifecycle wrapper over :class:`MCPServerRuntime` from +``server.py``. ``handle_async_init`` constructs the runtime and starts it; +``unload`` shuts it down; ``update_config`` either hot-swaps the tag-filter +middleware (for permission-only changes) or restarts the runtime. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from music_assistant.models.plugin import PluginProvider + +from .constants import HOT_SWAPPABLE_KEYS + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ProviderConfig + + from .server import MCPServerRuntime + + +LOGGER = logging.getLogger(__name__) + + +class MCPServerProvider(PluginProvider): + """Music Assistant plugin provider wrapping an MCP server runtime.""" + + _runtime: MCPServerRuntime | None = None + + async def handle_async_init(self) -> None: + """Build and start the FastMCP runtime.""" + from .server import MCPServerRuntime # noqa: PLC0415 + + self._runtime = MCPServerRuntime(self.mass, self.config, self.logger) + await self._runtime.start() + + async def loaded_in_mass(self) -> None: + """Log the public URL once everything is wired up.""" + if self._runtime is not None: + self.logger.info("MCP server mounted at %s", self._runtime.public_url) + + async def unload(self, is_removed: bool = False) -> None: + """Stop the runtime and unmount the HTTP route.""" + if self._runtime is not None: + await self._runtime.stop() + self._runtime = None + + async def update_config(self, config: ProviderConfig, changed_keys: set[str]) -> None: + """Apply config changes — hot-swap when possible, restart otherwise.""" + if self._runtime is None: + return + normalized_keys = {k.removeprefix("values/") for k in changed_keys} + if normalized_keys.issubset(HOT_SWAPPABLE_KEYS): + self.config = config + await self._runtime.apply_permission_change(config, normalized_keys) + else: + await self._runtime.stop() + self.config = config + from .server import MCPServerRuntime # noqa: PLC0415 + + self._runtime = MCPServerRuntime(self.mass, config, self.logger) + await self._runtime.start() diff --git a/music_assistant/providers/fastmcp_server/resources/__init__.py b/music_assistant/providers/fastmcp_server/resources/__init__.py new file mode 100644 index 0000000000..9a708adb1f --- /dev/null +++ b/music_assistant/providers/fastmcp_server/resources/__init__.py @@ -0,0 +1,33 @@ +"""MCP resource registration entry point.""" +# Relative imports are the canonical pattern across MA providers — sync-to-fork +# preserves them verbatim, so disable TID252 file-wide here. +# ruff: noqa: TID252 + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ..constants import CONF_RES_LIBRARY, CONF_RES_PLAYER +from .library_resources import register_library_resources +from .player_resources import register_player_resources + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ProviderConfig + + from music_assistant.mass import MusicAssistant + + +def register_resources(mcp: Any, mass: MusicAssistant, config: ProviderConfig) -> None: + """Register MCP resources, gated by config toggles. + + :param mcp: FastMCP root server. + :param mass: MusicAssistant instance. + :param config: provider config (controls which resource groups to expose). + """ + if config.get_value(CONF_RES_LIBRARY): + register_library_resources(mcp, mass) + if config.get_value(CONF_RES_PLAYER): + register_player_resources(mcp, mass) + + +__all__ = ["register_resources"] diff --git a/music_assistant/providers/fastmcp_server/resources/_uri.py b/music_assistant/providers/fastmcp_server/resources/_uri.py new file mode 100644 index 0000000000..5db3b342eb --- /dev/null +++ b/music_assistant/providers/fastmcp_server/resources/_uri.py @@ -0,0 +1,61 @@ +"""URI parsing for MCP resources.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass + +ALLOWED_SCHEMES: frozenset[str] = frozenset({"library", "player", "queue"}) +ALLOWED_TYPES: frozenset[str] = frozenset( + {"artist", "album", "track", "playlist", "radio", "podcast", "audiobook"} +) +_ID_RE = re.compile(r"^[A-Za-z0-9._:%@\-]+$") + + +@dataclass(frozen=True) +class ResourceURI: + """A parsed MCP resource URI: ``://[/]``.""" + + scheme: str + type: str | None + id: str + + +def parse_resource_uri(uri: str) -> ResourceURI: + """Parse and validate a resource URI. + + :param uri: input URI (``library://artist/123``, ``player://kitchen``). + :raises ValueError: if scheme/type/id are missing, unknown, or contain + characters that could enable path traversal. + """ + if "://" not in uri: + msg = f"Invalid URI (missing scheme): {uri!r}" + raise ValueError(msg) + scheme, _, rest = uri.partition("://") + if scheme not in ALLOWED_SCHEMES: + msg = f"Unsupported scheme: {scheme!r}" + raise ValueError(msg) + if not rest or ".." in rest: + msg = f"Invalid URI body: {rest!r}" + raise ValueError(msg) + + if "/" in rest: + if scheme != "library": + msg = f"{scheme}:// URIs must not contain a path separator, got {uri!r}" + raise ValueError(msg) + type_, _, identifier = rest.partition("/") + if type_ not in ALLOWED_TYPES: + msg = f"Unknown library type: {type_!r}" + raise ValueError(msg) + if not identifier or not _ID_RE.match(identifier): + msg = f"Invalid id: {identifier!r}" + raise ValueError(msg) + return ResourceURI(scheme=scheme, type=type_, id=identifier) + + if scheme == "library": + msg = f"library:// URIs require a type segment, got {uri!r}" + raise ValueError(msg) + if not _ID_RE.match(rest): + msg = f"Invalid id: {rest!r}" + raise ValueError(msg) + return ResourceURI(scheme=scheme, type=None, id=rest) diff --git a/music_assistant/providers/fastmcp_server/resources/library_resources.py b/music_assistant/providers/fastmcp_server/resources/library_resources.py new file mode 100644 index 0000000000..009aed0655 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/resources/library_resources.py @@ -0,0 +1,41 @@ +"""URI-addressable read-only library resources.""" +# ruff: noqa: TID252 -- relative imports are the canonical MA-provider pattern. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ..tags import Tag +from ..tools._common import to_resource_text + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + + +def register_library_resources(mcp: Any, mass: MusicAssistant) -> None: + """Register ``library://*`` resources on the given FastMCP root.""" + + @mcp.resource("library://artist/{artist_id}", tags={Tag.QUERY_LIBRARY}) # type: ignore[untyped-decorator, unused-ignore] + async def artist_resource(artist_id: str) -> str | None: + """Full artist record by library id.""" + return to_resource_text(await mass.music.artists.get_library_item(artist_id)) + + @mcp.resource("library://album/{album_id}", tags={Tag.QUERY_LIBRARY}) # type: ignore[untyped-decorator, unused-ignore] + async def album_resource(album_id: str) -> str | None: + """Full album record by library id.""" + return to_resource_text(await mass.music.albums.get_library_item(album_id)) + + @mcp.resource("library://track/{track_id}", tags={Tag.QUERY_LIBRARY}) # type: ignore[untyped-decorator, unused-ignore] + async def track_resource(track_id: str) -> str | None: + """Full track record by library id.""" + return to_resource_text(await mass.music.tracks.get_library_item(track_id)) + + @mcp.resource("library://playlist/{playlist_id}", tags={Tag.QUERY_LIBRARY}) # type: ignore[untyped-decorator, unused-ignore] + async def playlist_resource(playlist_id: str) -> str | None: + """Full playlist record by library id.""" + return to_resource_text(await mass.music.playlists.get_library_item(playlist_id)) + + @mcp.resource("library://radio/{radio_id}", tags={Tag.QUERY_LIBRARY}) # type: ignore[untyped-decorator, unused-ignore] + async def radio_resource(radio_id: str) -> str | None: + """Full radio station record by library id.""" + return to_resource_text(await mass.music.radio.get_library_item(radio_id)) diff --git a/music_assistant/providers/fastmcp_server/resources/player_resources.py b/music_assistant/providers/fastmcp_server/resources/player_resources.py new file mode 100644 index 0000000000..59cd6ad705 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/resources/player_resources.py @@ -0,0 +1,31 @@ +"""URI-addressable read-only player and queue resources.""" +# ruff: noqa: TID252 -- relative imports are the canonical MA-provider pattern. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ..tags import Tag +from ..tools._common import to_brief_player, to_brief_queue, to_resource_text + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + + +def register_player_resources(mcp: Any, mass: MusicAssistant) -> None: + """Register ``player://`` and ``queue://`` resources on the given FastMCP root.""" + + @mcp.resource("player://{player_id}", tags={Tag.QUERY_PLAYERS}) # type: ignore[untyped-decorator, unused-ignore] + async def player_resource(player_id: str) -> str | None: + """Player snapshot by id.""" + player = mass.players.get_player(player_id) + return to_resource_text(to_brief_player(player) if player is not None else None) + + @mcp.resource("queue://{queue_id}", tags={Tag.QUERY_QUEUE}) # type: ignore[untyped-decorator, unused-ignore] + async def queue_resource(queue_id: str) -> str | None: + """Queue snapshot by id (up to 500 items — MA's default page size).""" + queue = mass.player_queues.get(queue_id) + if queue is None: + return None + items = mass.player_queues.items(queue_id, limit=500) + return to_resource_text(to_brief_queue(queue, items=list(items))) diff --git a/music_assistant/providers/fastmcp_server/server.py b/music_assistant/providers/fastmcp_server/server.py new file mode 100644 index 0000000000..b99d49a31c --- /dev/null +++ b/music_assistant/providers/fastmcp_server/server.py @@ -0,0 +1,295 @@ +"""MCPServerRuntime — composes FastMCP, mounts it into MA's webserver.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from .constants import ( + CONF_ENFORCE_AUDIENCE, + CONF_EXTRA_ALLOWED_ORIGINS, + CONF_MOUNT_PATH, + CONF_REQUIRE_AUTH, + CONF_REQUIRE_CONFIRMATION, + DEFAULT_MOUNT_PATH, +) +from .tags import enabled_tags + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from music_assistant_models.config_entries import ProviderConfig + + from music_assistant.mass import MusicAssistant + +LOGGER = logging.getLogger(__name__) + + +class MCPServerRuntime: + """Build and manage a FastMCP server mounted into MA's webserver. + + The lifecycle is intentionally simple: + + * :meth:`start` builds the FastMCP root, mounts namespaced sub-servers + for each tool category, registers resources and prompts, applies the + tag-filter middleware, and exposes the streamable-HTTP ASGI app on + MA's webserver under :pyattr:`mount_path`. + * :meth:`stop` unregisters the dynamic route. + * :meth:`apply_permission_change` rebuilds the runtime in place when + only permission flags / resource toggles changed (no port collision + since we reuse MA's webserver). + """ + + def __init__( + self, + mass: MusicAssistant, + config: ProviderConfig, + logger: logging.Logger, + ) -> None: + """Hold the shared dependencies; nothing is started here. + + :param mass: MusicAssistant instance. + :param config: Provider config. + :param logger: Provider-scoped logger. + """ + self._mass = mass + self._config = config + self._logger = logger + raw_path = str(config.get_value(CONF_MOUNT_PATH) or DEFAULT_MOUNT_PATH) + self._mount_path: str = "/" + raw_path.strip("/") + self._mcp: Any = None + self._unmount: Callable[[], None] | None = None + self._unmount_well_known: Callable[[], None] | None = None + self._unmount_connect: Callable[[], None] | None = None + # Mutable so apply_permission_change can hot-swap the allowed-tag set + # without re-instantiating the TagFilterMiddleware closure. + self._allowed_tags: set[str] = set() + + @property + def public_url(self) -> str: + """Return the externally visible MCP endpoint URL.""" + base = str(getattr(self._mass.webserver, "base_url", "")).rstrip("/") + return f"{base}{self._mount_path}" + + async def start(self) -> None: + """Build the FastMCP server and mount it into the MA webserver.""" + from fastmcp import FastMCP # noqa: PLC0415 + + from .auth import MASTokenVerifier # noqa: PLC0415 + from .http_bridge import mount_into_mass # noqa: PLC0415 + from .prompts import register_prompts # noqa: PLC0415 + from .resources import register_resources # noqa: PLC0415 + from .tools import ( # noqa: PLC0415 + build_library_server, + build_media_server, + build_metadata_server, + build_playback_server, + build_players_server, + build_playlists_server, + build_queue_server, + build_volume_server, + ) + + require_auth = bool(self._config.get_value(CONF_REQUIRE_AUTH)) + base_url = str(getattr(self._mass.webserver, "base_url", "") or "").rstrip("/") + public_resource_uri = f"{base_url}{self._mount_path}" if base_url else None + enforce_audience = bool(self._config.get_value(CONF_ENFORCE_AUDIENCE)) + verifier = ( + MASTokenVerifier( + self._mass, + base_url=base_url or None, + public_resource_uri=public_resource_uri, + enforce_audience=enforce_audience, + ) + if require_auth + else None + ) + + mcp = FastMCP( + name="music-assistant", + instructions=( + "Music Assistant MCP server: control playback, browse the library, " + "manage queues, and inspect players. Tools are namespaced by category " + "(library_, queue_, playback_, players_, playlists_, volume_, media_, " + "metadata_). Resources expose URI-addressable views: library://artist/{id}, " + "library://album/{id}, library://track/{id}, library://playlist/{id}, " + "player://{id}, queue://{id}." + ), + auth=verifier, + ) + + require_confirmation = bool(self._config.get_value(CONF_REQUIRE_CONFIRMATION) or False) + mcp.mount(build_library_server(self._mass), namespace="library") + mcp.mount( + build_queue_server(self._mass, require_confirmation=require_confirmation), + namespace="queue", + ) + mcp.mount(build_playback_server(self._mass), namespace="playback") + mcp.mount(build_players_server(self._mass), namespace="players") + mcp.mount( + build_playlists_server(self._mass, require_confirmation=require_confirmation), + namespace="playlists", + ) + mcp.mount(build_volume_server(self._mass), namespace="volume") + mcp.mount( + build_media_server(self._mass, require_confirmation=require_confirmation), + namespace="media", + ) + mcp.mount(build_metadata_server(self._mass), namespace="metadata") + + register_resources(mcp, self._mass, self._config) + register_prompts(mcp, self._config) + + self._apply_tag_filter(mcp, enabled_tags(self._config)) + + self._mcp = mcp + extra_origins = str(self._config.get_value(CONF_EXTRA_ALLOWED_ORIGINS) or "") + self._unmount = await mount_into_mass( + self._mass, mcp, self._mount_path, extra_origins_csv=extra_origins + ) + + # Publish RFC 9728 protected-resource metadata at the well-known URL + # advertised by FastMCP in WWW-Authenticate. Skipped when require_auth + # is off (no metadata to serve) or base_url is missing (no canonical URI). + if require_auth and public_resource_uri: + from .http_bridge import mount_well_known # noqa: PLC0415 + + self._unmount_well_known = await mount_well_known( + self._mass, + mount_path=self._mount_path, + resource_uri=public_resource_uri, + authorization_servers=[base_url], + # Lazy provider so hot-swapped permissions update the + # advertised `scopes_supported` immediately, without + # rebuilding the runtime. + scopes_supported=lambda: [str(t) for t in enabled_tags(self._config)], + resource_name="Music Assistant MCP", + ) + + # Mount the Connect Wizard. Failure here is non-fatal — the MCP server + # itself is unaffected; the user just falls back to manual onboarding. + try: + from .connect import mount_connect_wizard # noqa: PLC0415 + + self._unmount_connect = await mount_connect_wizard( + self._mass, + self._mount_path, + enabled_tags_provider=lambda: [str(t) for t in enabled_tags(self._config)], + extra_origins_csv=extra_origins, + ) + except Exception: + self._logger.warning("Connect Wizard: mount failed", exc_info=True) + + self._logger.debug( + "MCP runtime started: mount=%s, auth=%s, tags=%d", + self._mount_path, + bool(verifier), + len(enabled_tags(self._config)), + ) + + async def stop(self) -> None: + """Unregister the HTTP route and drop references.""" + if self._unmount is not None: + try: + self._unmount() + except Exception: + self._logger.exception("Failed to unregister MCP route") + self._unmount = None + if getattr(self, "_unmount_well_known", None) is not None: + try: + self._unmount_well_known() # type: ignore[misc, unused-ignore] + except Exception: + self._logger.exception("Failed to unregister well-known route") + self._unmount_well_known = None + if getattr(self, "_unmount_connect", None) is not None: + try: + self._unmount_connect() # type: ignore[misc, unused-ignore] + except Exception: + self._logger.exception("Failed to unregister Connect Wizard route") + self._unmount_connect = None + self._mcp = None + + async def apply_permission_change( + self, new_config: ProviderConfig, changed_keys: set[str] + ) -> None: + """Hot-swap the allowed-tag set, or restart when resources are involved. + + Resource toggles (``CONF_RES_*``) require a rebuild because resource + registration is decided at :meth:`start` time; permission flags flip the + tag set in the closure read by :class:`TagFilterMiddleware` and take + effect on the next request without a restart. + + :param new_config: the new provider config; assigned to ``self._config`` + before any restart so ``start`` reads the updated values. + :param changed_keys: keys that changed (already stripped of any + ``values/`` prefix by the caller). MA mutates ``ProviderConfig`` + in place during ``config.update(values)``, so re-diffing ``old`` vs + ``new`` here would always be empty — the caller's set is the only + reliable signal. + """ + from .constants import PERMISSION_KEYS # noqa: PLC0415 + + # ``set().issubset(...)`` is True, so an empty ``changed_keys`` (no-op + # call) classifies as permission-only and skips a pointless restart. + permission_only = changed_keys.issubset(PERMISSION_KEYS) + + self._config = new_config + if permission_only and hasattr(self, "_allowed_tags"): + self._allowed_tags = {str(t) for t in enabled_tags(new_config)} + self._logger.debug( + "MCP runtime: hot-swapped tag filter to %d tags", + len(self._allowed_tags), + ) + return + + await self.stop() + await self.start() + + def _apply_tag_filter(self, mcp: Any, allowed: set[Any]) -> None: + """Install the tag-filter middleware on the given FastMCP server.""" + from .middleware import TagFilterMiddleware # noqa: PLC0415 + + # Snapshot tags into the closure-captured set declared in __init__. + # apply_permission_change mutates the same set later, so the + # middleware sees the new permissions without rebuilding FastMCP. + self._allowed_tags = {str(t) for t in allowed} + mcp.add_middleware(TagFilterMiddleware(lambda: self._allowed_tags, build_tag_lookup(mcp))) + + +async def _tag_lookup(mcp: Any, kind: str, key: str) -> set[str] | None: + """Resolve component name/URI back to its tag set via FastMCP public API. + + Returns ``None`` if the component is unknown — middleware then blocks + the call with NotFoundError, preventing a client that cached a name + from a prior permission set from invoking a now-hidden tool. For + resources the concrete-URI lookup falls back to template matching: + ``FastMCP.get_resource`` only finds statically-registered resources, + so a request for a concrete URI backed by a + ``@mcp.resource("scheme://{x}")`` template would otherwise be + misreported as not-found. + """ + try: + if kind == "tool": + obj = await mcp.get_tool(key) + elif kind == "resource": + obj = await mcp.get_resource(key) + if obj is None: + obj = await mcp.get_resource_template(key) + elif kind == "prompt": + obj = await mcp.get_prompt(key) + else: # pragma: no cover - kind is Literal-typed at the caller + return None + except Exception: + return None + if obj is None: + return None + return {str(t) for t in (getattr(obj, "tags", None) or set())} + + +def build_tag_lookup(mcp: Any) -> Callable[[str, str], Awaitable[set[str] | None]]: + """Return a closure suitable for :class:`TagFilterMiddleware`'s ``lookup``.""" + + async def lookup(kind: str, key: str) -> set[str] | None: + return await _tag_lookup(mcp, kind, key) + + return lookup diff --git a/music_assistant/providers/fastmcp_server/tags.py b/music_assistant/providers/fastmcp_server/tags.py new file mode 100644 index 0000000000..cebfd6f540 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/tags.py @@ -0,0 +1,74 @@ +"""FastMCP tag enum and config-to-tag mapping.""" + +from __future__ import annotations + +from enum import StrEnum +from typing import TYPE_CHECKING + +from .constants import ( + CONF_CONTROL_MEDIA, + CONF_CONTROL_PLAYBACK, + CONF_CONTROL_PLAYERS, + CONF_CONTROL_VOLUME, + CONF_DELETE_FAVORITES, + CONF_DELETE_LIBRARY, + CONF_DELETE_PLAYLISTS, + CONF_DELETE_QUEUE, + CONF_EDIT_FAVORITES, + CONF_EDIT_LIBRARY, + CONF_EDIT_PLAYLISTS, + CONF_EDIT_QUEUE, + CONF_QUERY_LIBRARY, + CONF_QUERY_METADATA, + CONF_QUERY_PLAYERS, + CONF_QUERY_QUEUE, +) + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ProviderConfig + + +class Tag(StrEnum): + """Permission tags applied to FastMCP tools / resources / prompts.""" + + QUERY_LIBRARY = "query:library" + QUERY_QUEUE = "query:queue" + QUERY_PLAYERS = "query:players" + QUERY_METADATA = "query:metadata" + CONTROL_PLAYBACK = "control:playback" + CONTROL_VOLUME = "control:volume" + CONTROL_PLAYERS = "control:players" + CONTROL_MEDIA = "control:media" + EDIT_LIBRARY = "edit:library" + EDIT_QUEUE = "edit:queue" + EDIT_PLAYLISTS = "edit:playlists" + EDIT_FAVORITES = "edit:favorites" + DELETE_LIBRARY = "delete:library" + DELETE_QUEUE = "delete:queue" + DELETE_PLAYLISTS = "delete:playlists" + DELETE_FAVORITES = "delete:favorites" + + +CONFIG_TO_TAG: dict[str, Tag] = { + CONF_QUERY_LIBRARY: Tag.QUERY_LIBRARY, + CONF_QUERY_QUEUE: Tag.QUERY_QUEUE, + CONF_QUERY_PLAYERS: Tag.QUERY_PLAYERS, + CONF_QUERY_METADATA: Tag.QUERY_METADATA, + CONF_CONTROL_PLAYBACK: Tag.CONTROL_PLAYBACK, + CONF_CONTROL_VOLUME: Tag.CONTROL_VOLUME, + CONF_CONTROL_PLAYERS: Tag.CONTROL_PLAYERS, + CONF_CONTROL_MEDIA: Tag.CONTROL_MEDIA, + CONF_EDIT_LIBRARY: Tag.EDIT_LIBRARY, + CONF_EDIT_QUEUE: Tag.EDIT_QUEUE, + CONF_EDIT_PLAYLISTS: Tag.EDIT_PLAYLISTS, + CONF_EDIT_FAVORITES: Tag.EDIT_FAVORITES, + CONF_DELETE_LIBRARY: Tag.DELETE_LIBRARY, + CONF_DELETE_QUEUE: Tag.DELETE_QUEUE, + CONF_DELETE_PLAYLISTS: Tag.DELETE_PLAYLISTS, + CONF_DELETE_FAVORITES: Tag.DELETE_FAVORITES, +} + + +def enabled_tags(config: ProviderConfig) -> set[Tag]: + """Return the set of permission tags that are enabled in the given config.""" + return {tag for cfg_key, tag in CONFIG_TO_TAG.items() if config.get_value(cfg_key)} diff --git a/music_assistant/providers/fastmcp_server/tools/__init__.py b/music_assistant/providers/fastmcp_server/tools/__init__.py new file mode 100644 index 0000000000..8223bb992b --- /dev/null +++ b/music_assistant/providers/fastmcp_server/tools/__init__.py @@ -0,0 +1,29 @@ +"""Tool sub-server factories. + +Each ``build_*_server`` function returns its own :class:`fastmcp.FastMCP` +instance, which is then ``mount()``-ed under a namespace by +:meth:`provider.server.MCPServerRuntime.start`. Tags applied per-tool drive +the ``restrict_tag`` middleware. +""" + +from __future__ import annotations + +from .library import build_library_server +from .media import build_media_server +from .metadata import build_metadata_server +from .playback import build_playback_server +from .players import build_players_server +from .playlists import build_playlists_server +from .queue import build_queue_server +from .volume import build_volume_server + +__all__ = [ + "build_library_server", + "build_media_server", + "build_metadata_server", + "build_playback_server", + "build_players_server", + "build_playlists_server", + "build_queue_server", + "build_volume_server", +] diff --git a/music_assistant/providers/fastmcp_server/tools/_common.py b/music_assistant/providers/fastmcp_server/tools/_common.py new file mode 100644 index 0000000000..9988eedd7e --- /dev/null +++ b/music_assistant/providers/fastmcp_server/tools/_common.py @@ -0,0 +1,260 @@ +"""Shared helpers for tool sub-servers.""" +# ruff: noqa: TID252 -- relative imports are the canonical MA-provider pattern. + +from __future__ import annotations + +import dataclasses +import json +from typing import TYPE_CHECKING, Any + +from fastmcp.exceptions import ToolError + +from ..models import ( + AlbumBrief, + ArtistBrief, + PlayerBrief, + PlaylistBrief, + QueueBrief, + QueueItemBrief, + RadioBrief, + TrackBrief, +) + +if TYPE_CHECKING: + from collections.abc import Sequence + + from fastmcp import Context + +MAX_PAGE = 200 +DEFAULT_PAGE = 50 + +# Per-tool execution timeouts (seconds), used in @sub.tool(timeout=…). Long +# searches and recommendation fetches reach external music providers; transport +# controls are local-RPC-fast; bulk playlist edits are explicitly larger. +TIMEOUT_FAST = 10.0 +TIMEOUT_MUTATION = 15.0 +TIMEOUT_QUERY = 30.0 +TIMEOUT_BULK = 60.0 + + +async def confirm_or_raise(ctx: Context | None, prompt: str, *, enabled: bool) -> None: + """Ask the MCP client to confirm a destructive operation. + + If ``enabled`` is False, or there is no Context (direct unit-test + invocation), or the client returns ``NotImplementedError`` (no elicit + support), the call passes through silently — the permission flag is + still in effect as the primary defense. + + On user decline / cancel, raises ``ToolError`` so the SDK reports it as + a tool-execution error (``isError: true``) rather than a protocol error. + """ + if not enabled or ctx is None: + return + try: + # ctx.elicit's overloads in older mypy stubs don't recognize ``bool`` + # as a valid scalar response_type — runtime behaviour is fine. Newer + # upstream mypy resolves the overload correctly, so the unused-ignore + # is also suppressed. + result = await ctx.elicit(prompt, response_type=bool) # type: ignore[arg-type, unused-ignore] + except NotImplementedError: + return + action = getattr(result, "action", None) + data = getattr(result, "data", None) + if action != "accept" or not data: + msg = "Operation cancelled by user" + raise ToolError(msg) + + +def page_args(offset: int = 0, limit: int = DEFAULT_PAGE) -> tuple[int, int]: + """Clamp paging arguments to safe bounds.""" + safe_limit = max(1, min(int(limit), MAX_PAGE)) + safe_offset = max(0, int(offset)) + return safe_offset, safe_limit + + +def to_brief_track(track: Any) -> TrackBrief: + """Convert a ``music_assistant_models.Track`` (or compatible) to ``TrackBrief``.""" + artists = _names(getattr(track, "artists", None)) + album = _name(getattr(track, "album", None)) + return TrackBrief( + uri=str(getattr(track, "uri", "")), + name=str(getattr(track, "name", "")), + artists=artists, + album=album, + duration=_int(getattr(track, "duration", None)), + ) + + +def to_brief_album(album: Any) -> AlbumBrief: + """Convert an Album-like object to ``AlbumBrief``.""" + artist = _name(getattr(album, "artist", None)) + if artist is None: + artists = _names(getattr(album, "artists", None)) + artist = artists[0] if artists else None + return AlbumBrief( + uri=str(getattr(album, "uri", "")), + name=str(getattr(album, "name", "")), + artist=artist, + year=_int(getattr(album, "year", None)), + ) + + +def to_brief_artist(artist: Any) -> ArtistBrief: + """Convert an Artist-like object to ``ArtistBrief``.""" + return ArtistBrief( + uri=str(getattr(artist, "uri", "")), + name=str(getattr(artist, "name", "")), + ) + + +def to_brief_playlist(playlist: Any) -> PlaylistBrief: + """Convert a Playlist-like object to ``PlaylistBrief``.""" + return PlaylistBrief( + uri=str(getattr(playlist, "uri", "")), + name=str(getattr(playlist, "name", "")), + track_count=_int(getattr(playlist, "track_count", None)), + owner=_name(getattr(playlist, "owner", None)), + ) + + +def to_brief_radio(radio: Any) -> RadioBrief: + """Convert a Radio-like object to ``RadioBrief``.""" + return RadioBrief( + uri=str(getattr(radio, "uri", "")), + name=str(getattr(radio, "name", "")), + description=_str_or_none(getattr(radio, "description", None)), + ) + + +def to_brief_player(player: Any) -> PlayerBrief: + """Convert a Player-like object to ``PlayerBrief``.""" + # MA's :class:`Player` exposes ``playback_state`` (an enum); ``state`` is + # only a serialisation alias and is not present on the Python object. + # Read both so test stubs and any older shim still resolve. + state_obj = getattr(player, "playback_state", None) or getattr(player, "state", None) + state_value = ( + str(getattr(state_obj, "value", state_obj)) if state_obj is not None else "unknown" + ) + + # ``Player.state`` (the :class:`PlayerState` dataclass) holds the canonical + # final values that MA serialises in its REST API — ``__final_power_state`` + # and ``__final_current_media``. The raw ``Player.powered`` / + # ``Player.current_media`` properties read internal ``_attr_*`` caches that + # lag (powered stays False on virtual players, current_media isn't cleared + # on stop). Detect a PlayerState dataclass by the presence of ``powered`` + # on it; fall back to the legacy direct attributes otherwise. + player_state = getattr(player, "state", None) + if player_state is not None and hasattr(player_state, "powered"): + powered_val = bool(player_state.powered) if player_state.powered is not None else True + current_media = getattr(player_state, "current_media", None) + else: + powered_val = bool(getattr(player, "powered", True)) + current_media = getattr(player, "current_media", None) + + current_item: str | None = None + if current_media is not None: + # Prefer the human-readable title; fall back to the URI (always + # present on ``PlayerMedia``). Avoids stringifying the whole dataclass + # which produces noisy ``PlayerMedia(uri=…, media_type=…, …)`` blobs. + current_item = _str_or_none(getattr(current_media, "title", None)) or _str_or_none( + getattr(current_media, "uri", None) + ) + + return PlayerBrief( + player_id=str(getattr(player, "player_id", "")), + name=str(getattr(player, "display_name", None) or getattr(player, "name", "")), + state=state_value, + volume_level=_int(getattr(player, "volume_level", None)), + powered=powered_val, + current_item=current_item, + ) + + +def to_brief_queue(queue: Any, items: Sequence[Any] | None = None) -> QueueBrief: + """Convert a PlayerQueue-like object to ``QueueBrief``. + + :param queue: queue-like object with ``queue_id``, ``current_index``, etc. + :param items: optional iterable of queue items to include. + """ + repeat_mode = getattr(queue, "repeat_mode", None) + repeat_value = str(getattr(repeat_mode, "value", repeat_mode)) if repeat_mode else "off" + brief_items: list[QueueItemBrief] = [] + if items: + for it in items: + brief_items.append( + QueueItemBrief( + item_id=str(getattr(it, "queue_item_id", "")), + name=str(getattr(it, "name", "")), + duration=_int(getattr(it, "duration", None)), + artists=_names(getattr(getattr(it, "media_item", None), "artists", None)), + ) + ) + # In the canonical MA model PlayerQueue.items is an int (total queue + # length), not a list. Fall back to alternate field names for older builds, + # and only as a last resort to len(brief_items) — which would under-report + # the real length, since `brief_items` is the truncated lookahead from + # get_active_queue, not the full queue. + raw_total = getattr(queue, "items", None) + explicit_count = _int(raw_total) if isinstance(raw_total, int) else None + if explicit_count is None: + explicit_count = _int( + getattr(queue, "items_count", None) or getattr(queue, "items_total", None) + ) + return QueueBrief( + queue_id=str(getattr(queue, "queue_id", "")), + current_index=_int(getattr(queue, "current_index", None)), + item_count=explicit_count if explicit_count is not None else len(brief_items), + shuffle=bool(getattr(queue, "shuffle_enabled", False)), + repeat=repeat_value, + items=brief_items, + ) + + +# ── private helpers ────────────────────────────────────────────────────────── + + +def _names(items: Any) -> list[str]: + if not items: + return [] + return [str(getattr(i, "name", i)) for i in items] + + +def _name(item: Any) -> str | None: + if item is None: + return None + return str(getattr(item, "name", item)) + + +def _int(value: Any) -> int | None: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + +def _str_or_none(value: Any) -> str | None: + if value is None: + return None + return str(value) + + +def to_resource_text(value: Any) -> str | None: + """Serialize a resource handler's return value as JSON text. + + FastMCP's resource read API requires handlers to return + ``str | bytes | list[ResourceContents]``. MA domain objects expose + ``to_dict()``; our brief dataclasses are converted via + :func:`dataclasses.asdict`. ``None`` is returned unchanged so FastMCP + serialises it as a ``"null"`` ``TextResourceContents`` block. + + :param value: handler return value (None, MA domain object, or Brief). + """ + if value is None: + return None + if hasattr(value, "to_dict"): + return json.dumps(value.to_dict(), ensure_ascii=False, default=str) + if dataclasses.is_dataclass(value) and not isinstance(value, type): + return json.dumps(dataclasses.asdict(value), ensure_ascii=False, default=str) + return json.dumps(value, ensure_ascii=False, default=str) diff --git a/music_assistant/providers/fastmcp_server/tools/library.py b/music_assistant/providers/fastmcp_server/tools/library.py new file mode 100644 index 0000000000..daa6cfacf5 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/tools/library.py @@ -0,0 +1,184 @@ +"""Library: search, list, and get tools (read-only).""" +# ruff: noqa: TID252 -- relative imports are the canonical MA-provider pattern. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastmcp import Context, FastMCP +from mcp.types import ToolAnnotations +from music_assistant_models.enums import MediaType + +from ..models import ( + AlbumBrief, + ArtistBrief, + PlaylistBrief, + RadioBrief, + TrackBrief, +) +from ..tags import Tag +from ._common import ( + TIMEOUT_QUERY, + page_args, + to_brief_album, + to_brief_artist, + to_brief_playlist, + to_brief_radio, + to_brief_track, +) + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + + +def _readonly(title: str) -> ToolAnnotations: + """Read-only library tool annotations with the supplied UI title.""" + return ToolAnnotations( + title=title, + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ) + + +def build_library_server(mass: MusicAssistant) -> FastMCP: + """Construct the ``library/*`` sub-server.""" + sub: FastMCP = FastMCP(name="library") + + @sub.tool( + tags={Tag.QUERY_LIBRARY}, + annotations=ToolAnnotations( + title="Search tracks", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def search_tracks( + query: str, limit: int = 25, ctx: Context | None = None + ) -> list[TrackBrief]: + """Search for tracks by free-text query across all enabled providers.""" + if ctx is not None: + await ctx.info(f"Searching MA for tracks matching {query!r} (limit={limit})") + results = await mass.music.search(query, [MediaType.TRACK], limit=limit) + return [to_brief_track(t) for t in (results.tracks or [])] + + @sub.tool( + tags={Tag.QUERY_LIBRARY}, + annotations=ToolAnnotations( + title="Search albums", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def search_albums( + query: str, limit: int = 25, ctx: Context | None = None + ) -> list[AlbumBrief]: + """Search for albums by free-text query.""" + if ctx is not None: + await ctx.info(f"Searching MA for albums matching {query!r} (limit={limit})") + results = await mass.music.search(query, [MediaType.ALBUM], limit=limit) + return [to_brief_album(a) for a in (results.albums or [])] + + @sub.tool( + tags={Tag.QUERY_LIBRARY}, + annotations=ToolAnnotations( + title="Search artists", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def search_artists( + query: str, limit: int = 25, ctx: Context | None = None + ) -> list[ArtistBrief]: + """Search for artists by free-text query.""" + if ctx is not None: + await ctx.info(f"Searching MA for artists matching {query!r} (limit={limit})") + results = await mass.music.search(query, [MediaType.ARTIST], limit=limit) + return [to_brief_artist(a) for a in (results.artists or [])] + + @sub.tool( + tags={Tag.QUERY_LIBRARY}, + annotations=_readonly("List library tracks"), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def list_library_tracks(offset: int = 0, limit: int = 50) -> list[TrackBrief]: + """List tracks already in the user's library, paginated.""" + offset, limit = page_args(offset, limit) + items = await mass.music.tracks.library_items(limit=limit, offset=offset) + return [to_brief_track(t) for t in items] + + @sub.tool( + tags={Tag.QUERY_LIBRARY}, + annotations=_readonly("List library albums"), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def list_library_albums(offset: int = 0, limit: int = 50) -> list[AlbumBrief]: + """List albums already in the user's library, paginated.""" + offset, limit = page_args(offset, limit) + items = await mass.music.albums.library_items(limit=limit, offset=offset) + return [to_brief_album(a) for a in items] + + @sub.tool( + tags={Tag.QUERY_LIBRARY}, + annotations=_readonly("List library artists"), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def list_library_artists(offset: int = 0, limit: int = 50) -> list[ArtistBrief]: + """List artists already in the user's library, paginated.""" + offset, limit = page_args(offset, limit) + items = await mass.music.artists.library_items(limit=limit, offset=offset) + return [to_brief_artist(a) for a in items] + + @sub.tool( + tags={Tag.QUERY_LIBRARY}, + annotations=_readonly("List library playlists"), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def list_library_playlists(offset: int = 0, limit: int = 50) -> list[PlaylistBrief]: + """List playlists already in the user's library, paginated.""" + offset, limit = page_args(offset, limit) + items = await mass.music.playlists.library_items(limit=limit, offset=offset) + return [to_brief_playlist(p) for p in items] + + @sub.tool( + tags={Tag.QUERY_LIBRARY}, + annotations=_readonly("List library radio"), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def list_library_radio(offset: int = 0, limit: int = 50) -> list[RadioBrief]: + """List radio stations already in the user's library, paginated.""" + offset, limit = page_args(offset, limit) + items = await mass.music.radio.library_items(limit=limit, offset=offset) + return [to_brief_radio(r) for r in items] + + @sub.tool( + tags={Tag.QUERY_LIBRARY}, + annotations=_readonly("Get track by URI"), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def get_track_by_uri(uri: str) -> TrackBrief: + """Resolve a track by its MA URI to a brief summary.""" + item = await mass.music.get_item_by_uri(uri) + return to_brief_track(item) + + @sub.tool( + tags={Tag.QUERY_LIBRARY}, + annotations=_readonly("Recently added tracks"), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def recently_added_tracks(limit: int = 10) -> list[TrackBrief]: + """Return tracks recently added to the library.""" + items = await mass.music.recently_added_tracks(limit=limit) + return [to_brief_track(t) for t in items] + + return sub diff --git a/music_assistant/providers/fastmcp_server/tools/media.py b/music_assistant/providers/fastmcp_server/tools/media.py new file mode 100644 index 0000000000..2a8c7daf7a --- /dev/null +++ b/music_assistant/providers/fastmcp_server/tools/media.py @@ -0,0 +1,169 @@ +"""Media: favorites, library add/remove, announcements.""" +# ruff: noqa: TID252 -- relative imports are the canonical MA-provider pattern. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from fastmcp import Context, FastMCP +from fastmcp.exceptions import ToolError +from mcp.types import ToolAnnotations + +from ..tags import Tag +from ._common import TIMEOUT_MUTATION, confirm_or_raise + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + + +async def _resolve_uri(mass: MusicAssistant, uri: str) -> Any: + """Look up a MediaItem by MA URI, raising ToolError when missing. + + MA's MusicController APIs that mutate library / favorites / play history + expect a resolved (media_type, library_item_id) pair or a typed media + object — not a raw URI string. This helper centralises the lookup. + """ + # MA's ``get_item_by_uri`` is typed as returning a MediaItem (no Optional); + # missing entries raise instead. Normalise to a ToolError for a consistent + # tool-surface error path. + try: + return await mass.music.get_item_by_uri(uri) + except Exception as exc: + msg = f"Item not found for URI: {uri!r} ({exc})" + raise ToolError(msg) from exc + + +async def _resolve_to_library_item(mass: MusicAssistant, uri: str) -> Any: + """Resolve a URI to its library counterpart, raising ToolError when not in library. + + MA's :meth:`MusicController.remove_item_from_favorites` and + :meth:`remove_item_from_library` expect a library item id. When the + caller passes a provider-native URI (e.g. ``yandex_music://track/abc``), + :func:`_resolve_uri` returns a MediaItem with the provider's id — + feeding that into the controller silently targets the wrong row (or + fails on ``int(...)``). This helper looks up the library counterpart + via :meth:`get_library_item_by_prov_id` and raises if the item isn't + in the library. + """ + item = await _resolve_uri(mass, uri) + if getattr(item, "provider", None) == "library": + return item + lib_item = await mass.music.get_library_item_by_prov_id( + item.media_type, item.item_id, item.provider + ) + if lib_item is None: + msg = f"URI {uri!r} is not in the library" + raise ToolError(msg) + return lib_item + + +def build_media_server(mass: MusicAssistant, *, require_confirmation: bool = True) -> FastMCP: + """Construct the ``media/*`` sub-server.""" + sub: FastMCP = FastMCP(name="media") + + @sub.tool( + tags={Tag.EDIT_FAVORITES}, + annotations=ToolAnnotations( + title="Add to favorites", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def add_to_favorites(uri: str) -> None: + """Add a media item (by URI) to favorites.""" + item = await _resolve_uri(mass, uri) + await mass.music.add_item_to_favorites(item) + + @sub.tool( + tags={Tag.DELETE_FAVORITES}, + annotations=ToolAnnotations( + title="Remove from favorites", + readOnlyHint=False, + destructiveHint=True, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def remove_from_favorites(uri: str, ctx: Context | None = None) -> None: + """Remove a media item (by URI) from favorites.""" + await confirm_or_raise( + ctx, + f"Remove {uri!r} from favorites?", + enabled=require_confirmation, + ) + item = await _resolve_to_library_item(mass, uri) + await mass.music.remove_item_from_favorites(item.media_type, item.item_id) + + @sub.tool( + tags={Tag.EDIT_LIBRARY}, + annotations=ToolAnnotations( + title="Add to library", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def add_to_library(uri: str) -> None: + """Add a media item (by URI) to the library.""" + item = await _resolve_uri(mass, uri) + await mass.music.add_item_to_library(item) + + @sub.tool( + tags={Tag.DELETE_LIBRARY}, + annotations=ToolAnnotations( + title="Remove from library", + readOnlyHint=False, + destructiveHint=True, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def remove_from_library(uri: str, ctx: Context | None = None) -> None: + """Remove a media item (by URI) from the library.""" + await confirm_or_raise( + ctx, + f"Remove {uri!r} from the library? This cannot be undone.", + enabled=require_confirmation, + ) + item = await _resolve_to_library_item(mass, uri) + await mass.music.remove_item_from_library(item.media_type, item.item_id) + + @sub.tool( + tags={Tag.CONTROL_MEDIA}, + annotations=ToolAnnotations( + title="Mark item played", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def mark_played(uri: str) -> None: + """Mark a media item as played (updates play history).""" + item = await _resolve_uri(mass, uri) + await mass.music.mark_item_played(item) + + @sub.tool( + tags={Tag.CONTROL_MEDIA}, + annotations=ToolAnnotations( + title="Play announcement", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def play_announcement(player_id: str, url: str, volume_level: int | None = None) -> None: + """Play a one-shot announcement audio URL on a player.""" + await mass.players.play_announcement(player_id, url, volume_level=volume_level) + + return sub diff --git a/music_assistant/providers/fastmcp_server/tools/metadata.py b/music_assistant/providers/fastmcp_server/tools/metadata.py new file mode 100644 index 0000000000..4980484213 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/tools/metadata.py @@ -0,0 +1,96 @@ +"""Metadata: lyrics, recommendations, similar tracks, refresh.""" +# ruff: noqa: TID252 -- relative imports are the canonical MA-provider pattern. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastmcp import Context, FastMCP +from mcp.types import ToolAnnotations + +from ..models import RecommendationFolderBrief, TrackBrief +from ..tags import Tag +from ._common import TIMEOUT_QUERY, to_brief_track + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + + +def _readonly(title: str) -> ToolAnnotations: + """Read-only metadata tool annotations with the supplied UI title.""" + return ToolAnnotations( + title=title, + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ) + + +def build_metadata_server(mass: MusicAssistant) -> FastMCP: + """Construct the ``metadata/*`` sub-server.""" + sub: FastMCP = FastMCP(name="metadata") + + @sub.tool( + tags={Tag.QUERY_METADATA}, + annotations=ToolAnnotations( + title="Recommendations", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def recommendations( + ctx: Context | None = None, + ) -> list[RecommendationFolderBrief]: + """Return Music Assistant's curated recommendations folders.""" + if ctx is not None: + await ctx.info("Fetching MA curated recommendations…") + folders = await mass.music.recommendations() + result: list[RecommendationFolderBrief] = [] + for folder in folders: + folder_items = getattr(folder, "items", None) or [] + result.append( + RecommendationFolderBrief( + name=str(getattr(folder, "name", "")), + item_uris=[str(getattr(it, "uri", "")) for it in folder_items], + ) + ) + return result + + @sub.tool( + tags={Tag.QUERY_METADATA}, + annotations=ToolAnnotations( + title="Recently played tracks", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def recently_played(limit: int = 10) -> list[TrackBrief]: + """Return the user's recently played tracks.""" + items = await mass.music.recently_played(limit=limit) + return [to_brief_track(it) for it in items if getattr(it, "name", None)] + + @sub.tool( + tags={Tag.QUERY_METADATA}, + annotations=_readonly("Get lyrics"), + timeout=TIMEOUT_QUERY, + ) # type: ignore[untyped-decorator, unused-ignore] + async def get_lyrics(track_uri: str) -> str | None: + """Return lyrics for a track URI (best-effort). + + Different providers expose lyrics through different attributes; this + tool surfaces the most common one (``metadata.lyrics``) and returns + ``None`` if no lyrics are available. + """ + item = await mass.music.get_item_by_uri(track_uri) + metadata = getattr(item, "metadata", None) + lyrics = getattr(metadata, "lyrics", None) if metadata else None + return str(lyrics) if lyrics else None + + return sub diff --git a/music_assistant/providers/fastmcp_server/tools/playback.py b/music_assistant/providers/fastmcp_server/tools/playback.py new file mode 100644 index 0000000000..581eddf091 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/tools/playback.py @@ -0,0 +1,114 @@ +"""Playback: play, pause, seek, skip, play media.""" +# ruff: noqa: TID252 -- relative imports are the canonical MA-provider pattern. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastmcp import FastMCP +from mcp.types import ToolAnnotations + +from ..tags import Tag +from ._common import TIMEOUT_MUTATION + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + + +def _control_annotations(*, title: str, idempotent: bool = False) -> ToolAnnotations: + """Default annotations for transport-control tools (mutate but non-destructive).""" + return ToolAnnotations( + title=title, + readOnlyHint=False, + destructiveHint=False, + idempotentHint=idempotent, + openWorldHint=False, + ) + + +def build_playback_server(mass: MusicAssistant) -> FastMCP: + """Construct the ``playback/*`` sub-server.""" + sub: FastMCP = FastMCP(name="playback") + + @sub.tool( + tags={Tag.CONTROL_PLAYBACK}, + annotations=_control_annotations(title="Toggle play / pause"), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def play_pause(queue_id: str) -> None: + """Toggle play/pause on the given queue.""" + await mass.player_queues.play_pause(queue_id) + + @sub.tool( + tags={Tag.CONTROL_PLAYBACK}, + annotations=_control_annotations(title="Stop playback", idempotent=True), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def stop(queue_id: str) -> None: + """Stop playback on the given queue.""" + await mass.player_queues.stop(queue_id) + + @sub.tool( + tags={Tag.CONTROL_PLAYBACK}, + annotations=_control_annotations(title="Next track"), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def next_track(queue_id: str) -> None: + """Advance to the next track.""" + await mass.player_queues.next(queue_id) + + @sub.tool( + tags={Tag.CONTROL_PLAYBACK}, + annotations=_control_annotations(title="Previous track"), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def previous_track(queue_id: str) -> None: + """Return to the previous track.""" + await mass.player_queues.previous(queue_id) + + @sub.tool( + tags={Tag.CONTROL_PLAYBACK}, + annotations=_control_annotations(title="Skip by seconds"), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def skip(queue_id: str, seconds: int = 10) -> None: + """Skip forward by ``seconds`` (or backward when negative).""" + await mass.player_queues.skip(queue_id, seconds) + + @sub.tool( + tags={Tag.CONTROL_PLAYBACK}, + annotations=_control_annotations(title="Seek to position"), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def seek(queue_id: str, position: int) -> None: + """Seek to absolute position (seconds) in the current track.""" + await mass.player_queues.seek(queue_id, position) + + @sub.tool( + tags={Tag.CONTROL_PLAYBACK}, + annotations=_control_annotations(title="Play media on a queue"), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def play_media( + queue_id: str, + uri: str, + radio_mode: bool = False, + ) -> None: + """Play media on the given queue by MA URI. + + :param queue_id: queue to play on (typically the player_id). + :param uri: MA URI of the media to play (artist, album, track, playlist, radio). + :param radio_mode: when ``True``, MA fills the queue with similar items. + """ + await mass.player_queues.play_media(queue_id, uri, radio_mode=radio_mode) + + @sub.tool( + tags={Tag.CONTROL_PLAYBACK}, + annotations=_control_annotations(title="Play queue item at index"), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def play_index(queue_id: str, index: int) -> None: + """Play the queue item at the given zero-based index.""" + await mass.player_queues.play_index(queue_id, index) + + return sub diff --git a/music_assistant/providers/fastmcp_server/tools/players.py b/music_assistant/providers/fastmcp_server/tools/players.py new file mode 100644 index 0000000000..97c4b980ad --- /dev/null +++ b/music_assistant/providers/fastmcp_server/tools/players.py @@ -0,0 +1,87 @@ +"""Players: list, inspect, power, group.""" +# ruff: noqa: TID252 -- relative imports are the canonical MA-provider pattern. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastmcp import FastMCP +from mcp.types import ToolAnnotations + +from ..models import PlayerBrief +from ..tags import Tag +from ._common import TIMEOUT_FAST, TIMEOUT_MUTATION, to_brief_player + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + + +def build_players_server(mass: MusicAssistant) -> FastMCP: + """Construct the ``players/*`` sub-server.""" + sub: FastMCP = FastMCP(name="players") + + @sub.tool( + tags={Tag.QUERY_PLAYERS}, + annotations=ToolAnnotations( + title="List all players", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_FAST, + ) # type: ignore[untyped-decorator, unused-ignore] + async def list_players() -> list[PlayerBrief]: + """List all players known to MA.""" + all_players = mass.players.all_players() if hasattr(mass.players, "all_players") else [] + if callable(all_players): + all_players = all_players() + return [to_brief_player(p) for p in all_players] + + @sub.tool( + tags={Tag.QUERY_PLAYERS}, + annotations=ToolAnnotations( + title="Get player by id", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_FAST, + ) # type: ignore[untyped-decorator, unused-ignore] + async def get_player(player_id: str) -> PlayerBrief | None: + """Return a single player by id, or ``None`` if it doesn't exist.""" + player = mass.players.get_player(player_id) + return to_brief_player(player) if player is not None else None + + @sub.tool( + tags={Tag.CONTROL_PLAYERS}, + annotations=ToolAnnotations( + title="Power player on / off", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def set_power(player_id: str, powered: bool) -> None: + """Power a player on or off.""" + await mass.players.cmd_power(player_id, powered) + + @sub.tool( + tags={Tag.CONTROL_PLAYERS}, + annotations=ToolAnnotations( + title="Group player into sync group", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def group_player(player_id: str, target_player_id: str) -> None: + """Group ``player_id`` with ``target_player_id`` (sync group).""" + await mass.players.cmd_group(player_id, target_player_id) + + return sub diff --git a/music_assistant/providers/fastmcp_server/tools/playlists.py b/music_assistant/providers/fastmcp_server/tools/playlists.py new file mode 100644 index 0000000000..c4fa7bbf55 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/tools/playlists.py @@ -0,0 +1,135 @@ +"""Playlists: create, modify, delete.""" +# ruff: noqa: TID252 -- relative imports are the canonical MA-provider pattern. + +from __future__ import annotations + +import contextlib +from typing import TYPE_CHECKING + +from fastmcp import Context, FastMCP +from mcp.types import ToolAnnotations + +from ..models import PlaylistBrief +from ..tags import Tag +from ._common import TIMEOUT_BULK, TIMEOUT_MUTATION, confirm_or_raise, to_brief_playlist + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + + +def build_playlists_server(mass: MusicAssistant, *, require_confirmation: bool = True) -> FastMCP: + """Construct the ``playlists/*`` sub-server.""" + sub: FastMCP = FastMCP(name="playlists") + + @sub.tool( + tags={Tag.EDIT_PLAYLISTS}, + annotations=ToolAnnotations( + title="Create a playlist", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def create_playlist(name: str, provider_instance_id: str | None = None) -> PlaylistBrief: + """Create a new playlist on a music provider.""" + playlist = await mass.music.playlists.create_playlist( + name, provider_instance_or_domain=provider_instance_id + ) + return to_brief_playlist(playlist) + + @sub.tool( + tags={Tag.EDIT_PLAYLISTS}, + annotations=ToolAnnotations( + title="Add a single track to a playlist", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def add_track(playlist_id: str | int, track_uri: str) -> None: + """Append one track to a playlist.""" + await mass.music.playlists.add_playlist_track(playlist_id, track_uri) + + @sub.tool( + tags={Tag.EDIT_PLAYLISTS}, + annotations=ToolAnnotations( + title="Add tracks to a playlist", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=False, + ), + timeout=TIMEOUT_BULK, + ) # type: ignore[untyped-decorator, unused-ignore] + async def add_tracks( + playlist_id: str | int, + track_uris: list[str], + ctx: Context | None = None, + ) -> None: + """Append multiple tracks to a playlist. + + For batches up to 10 the call is bulk-dispatched (one round-trip); + beyond that, items are added one-by-one with progress reporting so + the LLM client can show a meaningful spinner / cancellation handle. + + .. warning:: + + The per-item path is **not transactional**. If the client cancels + (``notifications/cancelled``) or MA raises on the N-th track, + tracks 0..N-1 stay added — there is no rollback. Callers that need + atomic semantics should keep batches at ``<= 10`` so the bulk + ``add_playlist_tracks`` round-trip is used. + """ + total = len(track_uris) + if total <= 10: + await mass.music.playlists.add_playlist_tracks(playlist_id, track_uris) + return + added = 0 + try: + for i, uri in enumerate(track_uris, start=1): + await mass.music.playlists.add_playlist_track(playlist_id, uri) + added = i + if ctx is not None: + await ctx.report_progress(progress=i, total=total) + except BaseException: + # Surface partial-state to the client before re-raising. BaseException + # also catches asyncio.CancelledError, which we want to flag. + if ctx is not None and added < total: + with contextlib.suppress(Exception): + await ctx.warning( + f"add_tracks: partial state — {added} of {total} tracks " + f"added to playlist {playlist_id!r} before failure / cancel" + ) + raise + + @sub.tool( + tags={Tag.DELETE_PLAYLISTS}, + annotations=ToolAnnotations( + title="Remove tracks from a playlist", + readOnlyHint=False, + destructiveHint=True, + idempotentHint=False, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def remove_tracks( + playlist_id: str | int, + positions: list[int], + ctx: Context | None = None, + ) -> None: + """Remove tracks at the given zero-based positions from a playlist.""" + await confirm_or_raise( + ctx, + f"Remove {len(positions)} track(s) from playlist {playlist_id!r}?", + enabled=require_confirmation, + ) + # MA's PlaylistController expects an immutable tuple, not a list, so + # callers can't accidentally mutate it mid-removal. + await mass.music.playlists.remove_playlist_tracks(playlist_id, tuple(positions)) + + return sub diff --git a/music_assistant/providers/fastmcp_server/tools/queue.py b/music_assistant/providers/fastmcp_server/tools/queue.py new file mode 100644 index 0000000000..a0bfe41da7 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/tools/queue.py @@ -0,0 +1,111 @@ +"""Queue: read state and edit / delete queue items.""" +# ruff: noqa: TID252 -- relative imports are the canonical MA-provider pattern. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastmcp import Context, FastMCP +from mcp.types import ToolAnnotations + +from ..models import QueueBrief +from ..tags import Tag +from ._common import TIMEOUT_FAST, TIMEOUT_MUTATION, confirm_or_raise, to_brief_queue + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + +# Matches MA's default queue page size (and the ``queue://`` resource cap). +MAX_QUEUE_ITEMS = 500 + + +def build_queue_server(mass: MusicAssistant, *, require_confirmation: bool = True) -> FastMCP: + """Construct the ``queue/*`` sub-server.""" + sub: FastMCP = FastMCP(name="queue") + + @sub.tool( + tags={Tag.QUERY_QUEUE}, + annotations=ToolAnnotations( + title="Get active queue", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_FAST, + ) # type: ignore[untyped-decorator, unused-ignore] + async def get_active_queue(player_id: str, include_items: int = 25) -> QueueBrief | None: + """Return the active queue for a player, or ``None`` if the player is idle. + + :param include_items: How many lookahead items to materialise. Clamped + to the ``[0, 500]`` range — 500 matches MA's own queue page size + and the ``queue://`` resource cap, preventing a hostile or + sloppy client from forcing the server to load thousands of rows + on every call. + """ + queue = mass.player_queues.get_active_queue(player_id) + if queue is None: + return None + limit = min(max(include_items, 0), MAX_QUEUE_ITEMS) + items = mass.player_queues.items(queue.queue_id, limit=limit) if limit > 0 else [] + return to_brief_queue(queue, items=list(items)) + + @sub.tool( + tags={Tag.EDIT_QUEUE}, + annotations=ToolAnnotations( + title="Toggle queue shuffle", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def set_shuffle(queue_id: str, enabled: bool) -> None: + """Enable or disable shuffle on the given queue.""" + await mass.player_queues.set_shuffle(queue_id, enabled) + + @sub.tool( + tags={Tag.DELETE_QUEUE}, + annotations=ToolAnnotations( + title="Clear queue", + readOnlyHint=False, + destructiveHint=True, + idempotentHint=True, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def clear_queue(queue_id: str, ctx: Context | None = None) -> None: + """Clear all items from the given queue. + + Implementation note: MA's ``player_queues`` exposes a ``clear`` method; + if the API name diverges, this is the single integration point to fix. + """ + await confirm_or_raise( + ctx, + f"Clear all items from queue {queue_id!r}? This cannot be undone.", + enabled=require_confirmation, + ) + clear = getattr(mass.player_queues, "clear", None) + if clear is None: + msg = "mass.player_queues.clear is not available on this MA build" + raise RuntimeError(msg) + clear(queue_id) + + @sub.tool( + tags={Tag.CONTROL_PLAYBACK}, + annotations=ToolAnnotations( + title="Transfer queue between players", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=False, + ), + timeout=TIMEOUT_MUTATION, + ) # type: ignore[untyped-decorator, unused-ignore] + async def transfer_queue(source_queue_id: str, target_queue_id: str) -> None: + """Move a queue from one player to another.""" + await mass.player_queues.transfer_queue(source_queue_id, target_queue_id) + + return sub diff --git a/music_assistant/providers/fastmcp_server/tools/volume.py b/music_assistant/providers/fastmcp_server/tools/volume.py new file mode 100644 index 0000000000..1ef0cb0eb9 --- /dev/null +++ b/music_assistant/providers/fastmcp_server/tools/volume.py @@ -0,0 +1,78 @@ +"""Volume control: set, up/down, mute, group volume.""" +# ruff: noqa: TID252 -- relative imports are the canonical MA-provider pattern. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastmcp import FastMCP +from mcp.types import ToolAnnotations + +from ..tags import Tag +from ._common import TIMEOUT_FAST + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + + +def _vol_annotations(*, title: str, idempotent: bool) -> ToolAnnotations: + """Default volume-tool annotations: never destructive, never open-world.""" + return ToolAnnotations( + title=title, + readOnlyHint=False, + destructiveHint=False, + idempotentHint=idempotent, + openWorldHint=False, + ) + + +def build_volume_server(mass: MusicAssistant) -> FastMCP: + """Construct the ``volume/*`` sub-server.""" + sub: FastMCP = FastMCP(name="volume") + + @sub.tool( + tags={Tag.CONTROL_VOLUME}, + annotations=_vol_annotations(title="Set volume", idempotent=True), + timeout=TIMEOUT_FAST, + ) # type: ignore[untyped-decorator, unused-ignore] + async def volume_set(player_id: str, level: int) -> None: + """Set absolute volume level (0-100) on a player.""" + await mass.players.cmd_volume_set(player_id, max(0, min(100, int(level)))) + + @sub.tool( + tags={Tag.CONTROL_VOLUME}, + annotations=_vol_annotations(title="Volume up", idempotent=False), + timeout=TIMEOUT_FAST, + ) # type: ignore[untyped-decorator, unused-ignore] + async def volume_up(player_id: str) -> None: + """Bump volume up one step.""" + await mass.players.cmd_volume_up(player_id) + + @sub.tool( + tags={Tag.CONTROL_VOLUME}, + annotations=_vol_annotations(title="Volume down", idempotent=False), + timeout=TIMEOUT_FAST, + ) # type: ignore[untyped-decorator, unused-ignore] + async def volume_down(player_id: str) -> None: + """Bump volume down one step.""" + await mass.players.cmd_volume_down(player_id) + + @sub.tool( + tags={Tag.CONTROL_VOLUME}, + annotations=_vol_annotations(title="Mute / unmute", idempotent=True), + timeout=TIMEOUT_FAST, + ) # type: ignore[untyped-decorator, unused-ignore] + async def volume_mute(player_id: str, muted: bool) -> None: + """Mute or unmute a player.""" + await mass.players.cmd_volume_mute(player_id, muted) + + @sub.tool( + tags={Tag.CONTROL_VOLUME}, + annotations=_vol_annotations(title="Set group volume", idempotent=True), + timeout=TIMEOUT_FAST, + ) # type: ignore[untyped-decorator, unused-ignore] + async def group_volume_set(player_id: str, level: int) -> None: + """Set group volume level (0-100) on a sync group.""" + await mass.players.cmd_group_volume(player_id, max(0, min(100, int(level)))) + + return sub diff --git a/requirements_all.txt b/requirements_all.txt index 8c1bca0ed5..25750eedf0 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -35,6 +35,7 @@ deezer-python-async==0.3.0 defusedxml==0.7.1 deno==2.7.12 duration-parser==1.0.1 +fastmcp==3.2.4 getmac==0.9.5 gql[all]==4.0.0 hass-client==1.2.3 diff --git a/tests/providers/fastmcp_server/__init__.py b/tests/providers/fastmcp_server/__init__.py new file mode 100644 index 0000000000..562685bf41 --- /dev/null +++ b/tests/providers/fastmcp_server/__init__.py @@ -0,0 +1 @@ +"""Tests for ma-provider-mcp.""" diff --git a/tests/providers/fastmcp_server/conftest.py b/tests/providers/fastmcp_server/conftest.py new file mode 100644 index 0000000000..e2445eefae --- /dev/null +++ b/tests/providers/fastmcp_server/conftest.py @@ -0,0 +1,247 @@ +"""Shared pytest fixtures for ma-provider-mcp tests. + +Most tests run without a real Music Assistant install — they exercise pure +logic (URI parsing, tag mapping, config entries shape) or use ``MagicMock`` +for ``mass``. Integration-level tests that need a real MA stack are marked +with ``@pytest.mark.integration`` and skipped by default. +""" +# ruff: noqa: D401, PLR0915 +# D401: fixture docstrings describe *what is returned* ("A stub …"), not +# imperative actions; rephrasing to "Build / Return …" hurts grep-ability. +# PLR0915: ``mock_mass`` builds a tall MagicMock surface — splitting it across +# helpers obscures the test contract. + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +if TYPE_CHECKING: + from collections.abc import Iterator + +# Make the provider/ package importable as a top-level "provider" module without +# requiring a full ``pip install -e .`` step in ad-hoc test runs. +# Guard: only add when a "provider/" sibling directory exists so that the +# synced copy at tests/providers/fastmcp_server/conftest.py does NOT add +# tests/providers/ to sys.path and shadow installed packages. +_REPO_ROOT = Path(__file__).resolve().parent.parent +if (_REPO_ROOT / "provider").is_dir() and str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + + +class FakeWebserver: + """Captures every dynamic-route registration so tests can drive them through aiohttp. + + Mirrors the surface of ``mass.webserver`` that this plugin uses, without + depending on a real Music Assistant install. Exposed via the + :func:`fake_webserver` fixture and :func:`build_aiohttp_app` helper. + """ + + def __init__( + self, + *, + base_url: str = "http://localhost:8095", + publish_ip: str = "127.0.0.1", + ) -> None: + """Initialise an empty registry with the given advertised endpoints.""" + self.routes: list[tuple[str, Any, str]] = [] + self.base_url = base_url + self.publish_ip = publish_ip + + def register_dynamic_route(self, path: str, handler: Any, method: str = "*") -> Any: + """Mirror ``mass.webserver.register_dynamic_route``: store + return unregister.""" + import contextlib # noqa: PLC0415 - keep stdlib import inside method to mirror runtime + + self.routes.append((path, handler, method)) + + def _unregister() -> None: + with contextlib.suppress(ValueError): + self.routes.remove((path, handler, method)) + + return _unregister + + @property + def handler(self) -> Any: + """Return the single registered handler (convenience for one-route tests).""" + return self.routes[0][1] if self.routes else None + + +def build_aiohttp_app(fake_ws: FakeWebserver) -> Any: + """Translate captured ``(path, handler, method)`` tuples into an aiohttp app. + + Mirrors MA's real dynamic-route matching + (``helpers/webserver.py::_handle_catch_all``): a path registered as + ``"/*"`` matches BOTH the bare ```` (no trailing slash) and + any descendant ``/...``. Aiohttp's ``{tail:.*}`` pattern requires + the slash, so we add an explicit route for the bare stem alongside the + wildcard. Without that, the harness silently misses the + wizard-advertised MCP entry-point URL (``/mcp/v1`` — no + trailing slash) that real clients connect to. + """ + from aiohttp import web # noqa: PLC0415 - aiohttp only needed by HTTP-level tests + + app = web.Application() + for path, handler, method in fake_ws.routes: + if path.endswith("/*"): + stem = path[:-2] + app.router.add_route(method, stem, handler) + app.router.add_route(method, f"{stem}/{{tail:.*}}", handler) + else: + app.router.add_route(method, path, handler) + return app + + +@pytest.fixture +def fake_webserver() -> FakeWebserver: + """Fresh ``FakeWebserver`` instance per test.""" + return FakeWebserver() + + +@pytest.fixture +def mock_user() -> MagicMock: + """A minimal stand-in for an MA ``User`` object.""" + user = MagicMock() + user.user_id = "u1" + user.username = "tester" + user.role = MagicMock(value="admin") + user.enabled = True + return user + + +@pytest.fixture +def mock_mass(mock_user: MagicMock) -> MagicMock: + """A MusicAssistant stub with the surface area we touch.""" + mass = MagicMock() + mass.webserver = MagicMock() + mass.webserver.base_url = "http://localhost:8095" + mass.webserver.publish_ip = "127.0.0.1" + mass.webserver.auth = MagicMock() + mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=mock_user) + mass.webserver.register_dynamic_route = MagicMock(return_value=lambda: None) + + mass.music = MagicMock() + mass.music.search = AsyncMock() + mass.music.recently_added_tracks = AsyncMock(return_value=[]) + mass.music.recently_played = AsyncMock(return_value=[]) + mass.music.recommendations = AsyncMock(return_value=[]) + mass.music.get_item_by_uri = AsyncMock() + + mass.music.tracks.library_items = AsyncMock(return_value=[]) + mass.music.tracks.get_library_item = AsyncMock() + mass.music.albums.library_items = AsyncMock(return_value=[]) + mass.music.albums.get_library_item = AsyncMock() + mass.music.artists.library_items = AsyncMock(return_value=[]) + mass.music.artists.get_library_item = AsyncMock() + mass.music.playlists.library_items = AsyncMock(return_value=[]) + mass.music.playlists.get_library_item = AsyncMock() + mass.music.playlists.create_playlist = AsyncMock() + mass.music.playlists.add_playlist_track = AsyncMock() + mass.music.playlists.add_playlist_tracks = AsyncMock() + mass.music.playlists.remove_playlist_tracks = AsyncMock() + mass.music.radio.library_items = AsyncMock(return_value=[]) + mass.music.radio.get_library_item = AsyncMock() + mass.music.add_item_to_favorites = AsyncMock() + mass.music.remove_item_from_favorites = AsyncMock() + mass.music.add_item_to_library = AsyncMock() + mass.music.remove_item_from_library = AsyncMock() + mass.music.mark_item_played = AsyncMock() + + mass.player_queues = MagicMock() + mass.player_queues.get_active_queue = MagicMock(return_value=None) + mass.player_queues.get = MagicMock(return_value=None) + mass.player_queues.items = MagicMock(return_value=[]) + mass.player_queues.play_media = AsyncMock() + mass.player_queues.play_pause = AsyncMock() + mass.player_queues.stop = AsyncMock() + mass.player_queues.next = AsyncMock() + mass.player_queues.previous = AsyncMock() + mass.player_queues.skip = AsyncMock() + mass.player_queues.seek = AsyncMock() + mass.player_queues.play_index = AsyncMock() + mass.player_queues.set_shuffle = AsyncMock() + mass.player_queues.transfer_queue = AsyncMock() + mass.player_queues.clear = MagicMock() + + mass.players = MagicMock() + mass.players.all_players = MagicMock(return_value=[]) + mass.players.get_player = MagicMock(return_value=None) + mass.players.cmd_power = AsyncMock() + mass.players.cmd_group = AsyncMock() + mass.players.cmd_volume_set = AsyncMock() + mass.players.cmd_volume_up = AsyncMock() + mass.players.cmd_volume_down = AsyncMock() + mass.players.cmd_volume_mute = AsyncMock() + mass.players.cmd_group_volume = AsyncMock() + mass.players.play_announcement = AsyncMock() + + return mass + + +@pytest.fixture +def mock_config() -> MagicMock: + """A ProviderConfig stub. ``get_value`` returns whatever is set in ``_values``.""" + config = MagicMock() + config._values = { + # Defaults match build_config_entries + "require_auth": True, + "mount_path": "/mcp/v1", + "extra_allowed_origins": "", + "enforce_audience": False, + "require_confirmation": True, + "query_library": True, + "query_queue": True, + "query_players": True, + "query_metadata": True, + "control_playback": False, + "control_volume": False, + "control_players": False, + "control_media": False, + "edit_library": False, + "edit_queue": False, + "edit_playlists": False, + "edit_favorites": False, + "delete_library": False, + "delete_queue": False, + "delete_playlists": False, + "delete_favorites": False, + "res_library": True, + "res_player": True, + "res_prompts": True, + } + + def _get(key: str, default: Any = None) -> Any: + return config._values.get(key, default) + + config.get_value = MagicMock(side_effect=_get) + return config + + +@pytest.fixture +def have_fastmcp() -> bool: + """True if ``fastmcp`` is importable in the current environment.""" + return importlib.util.find_spec("fastmcp") is not None + + +def pytest_collection_modifyitems(config: Any, items: Iterator[Any]) -> None: + """Skip integration tests by default unless ``--run-integration`` is passed.""" + if config.getoption("--run-integration", default=False): + return + skip_integration = pytest.mark.skip(reason="integration tests require --run-integration") + for item in items: + if "integration" in item.keywords: + item.add_marker(skip_integration) + + +def pytest_addoption(parser: Any) -> None: + """Add the ``--run-integration`` CLI flag.""" + parser.addoption( + "--run-integration", + action="store_true", + default=False, + help="run tests marked @pytest.mark.integration", + ) diff --git a/tests/providers/fastmcp_server/test_annotations.py b/tests/providers/fastmcp_server/test_annotations.py new file mode 100644 index 0000000000..f1967d28f4 --- /dev/null +++ b/tests/providers/fastmcp_server/test_annotations.py @@ -0,0 +1,91 @@ +"""Tests for ToolAnnotations sweep across all sub-server tools (C5).""" +# mypy: disable-error-code="arg-type, no-untyped-def, type-arg, assignment, operator, misc" + +from __future__ import annotations + +from typing import Any + +import pytest +from fastmcp import Client, FastMCP + +from music_assistant.providers.fastmcp_server.tools import ( + build_library_server, + build_media_server, + build_metadata_server, + build_playback_server, + build_players_server, + build_playlists_server, + build_queue_server, + build_volume_server, +) + +_BUILDERS = [ + ("library", build_library_server), + ("queue", build_queue_server), + ("playback", build_playback_server), + ("players", build_players_server), + ("playlists", build_playlists_server), + ("volume", build_volume_server), + ("media", build_media_server), + ("metadata", build_metadata_server), +] + + +# Spec-mandated destructive tools per the C5 mapping table. +_DESTRUCTIVE_NAMES = { + "queue_clear_queue", + "playlists_remove_tracks", + "media_remove_from_favorites", + "media_remove_from_library", +} + +# Strictly read-only categories (every tool in these sub-servers). +_READONLY_NAMESPACES = {"library", "metadata"} +# Plus a couple of inspector tools elsewhere. +_READONLY_TOOL_NAMES = { + "queue_get_active_queue", + "players_list_players", + "players_get_player", +} + + +@pytest.fixture +def mounted_server(mock_mass: Any) -> FastMCP: + """Build a root FastMCP with all 8 sub-servers mounted, like MCPServerRuntime.""" + mcp: FastMCP = FastMCP(name="test") + for ns, builder in _BUILDERS: + mcp.mount(builder(mock_mass), namespace=ns) + return mcp + + +async def test_every_tool_has_title_and_annotations(mounted_server: FastMCP) -> None: + """Every exposed tool carries a non-empty title and a ToolAnnotations object.""" + async with Client(mounted_server) as client: + tools = await client.list_tools() + assert tools, "expected at least one tool" + for tool in tools: + ann = tool.annotations + assert ann is not None, f"{tool.name}: missing annotations" + assert ann.title, f"{tool.name}: missing title" + # openWorldHint must be explicitly False — none of our tools touch the open web. + assert ann.openWorldHint is False, f"{tool.name}: unexpected openWorldHint=True" + + +async def test_destructive_tools_are_marked(mounted_server: FastMCP) -> None: + """Destructive operations are flagged so clients can prompt for confirmation.""" + async with Client(mounted_server) as client: + tools = {t.name: t for t in await client.list_tools()} + for name in _DESTRUCTIVE_NAMES: + assert name in tools, f"missing tool {name}" + assert tools[name].annotations.destructiveHint is True, name + + +async def test_readonly_tools_are_marked(mounted_server: FastMCP) -> None: + """Read-only sub-servers (library, metadata) and a few inspectors carry readOnlyHint=True.""" + async with Client(mounted_server) as client: + tools = await client.list_tools() + for tool in tools: + ns = tool.name.split("_", 1)[0] + if ns in _READONLY_NAMESPACES or tool.name in _READONLY_TOOL_NAMES: + assert tool.annotations.readOnlyHint is True, tool.name + assert tool.annotations.destructiveHint is False, tool.name diff --git a/tests/providers/fastmcp_server/test_apply_permission_change.py b/tests/providers/fastmcp_server/test_apply_permission_change.py new file mode 100644 index 0000000000..fd54e3712e --- /dev/null +++ b/tests/providers/fastmcp_server/test_apply_permission_change.py @@ -0,0 +1,87 @@ +"""Tests for ``MCPServerRuntime.apply_permission_change`` hot-swap vs restart routing. + +The provider's :meth:`update_config` strips ``values/`` prefixes from MA's +``changed_keys`` set and passes the normalised set to +:meth:`MCPServerRuntime.apply_permission_change`. The runtime must decide +hot-swap vs full restart from that explicit set — not from a re-diff of +``self._config`` vs the new config, because Music Assistant mutates +:class:`ProviderConfig` in place, so the old and new references point to +the same object and a diff is empty. +""" +# mypy: disable-error-code="arg-type, no-untyped-def, type-arg, assignment, operator, misc" + +from __future__ import annotations + +import logging +from unittest.mock import AsyncMock, MagicMock + +import pytest + + +@pytest.mark.asyncio +async def test_resource_toggle_triggers_full_restart( + mock_mass: MagicMock, mock_config: MagicMock +) -> None: + """A ``res_*`` toggle must restart the runtime (resources are bound at start time). + + The runtime can hot-swap only permission tags; resource registration + happens once during :meth:`start`. If a resource toggle is mis-routed + to the hot-swap path, the user's change silently has no effect. + """ + from music_assistant.providers.fastmcp_server.server import MCPServerRuntime # noqa: PLC0415 + + runtime = MCPServerRuntime(mock_mass, mock_config, logging.getLogger("t")) + runtime.stop = AsyncMock() + runtime.start = AsyncMock() + + await runtime.apply_permission_change(mock_config, changed_keys={"res_library"}) + + runtime.stop.assert_awaited_once() + runtime.start.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_empty_changed_keys_does_not_restart( + mock_mass: MagicMock, mock_config: MagicMock +) -> None: + """A no-op call (``changed_keys=set()``) must not force a restart. + + MA's ``ConfigController`` short-circuits when there are no diffs, but the + guard belongs here too: an empty set is by definition a subset of the + permission keys, so classify as permission-only and let the hot-swap + path noop-rebuild the tag snapshot. + """ + from music_assistant.providers.fastmcp_server.server import MCPServerRuntime # noqa: PLC0415 + + runtime = MCPServerRuntime(mock_mass, mock_config, logging.getLogger("t")) + runtime._allowed_tags = {"query:library"} + runtime.stop = AsyncMock() + runtime.start = AsyncMock() + + await runtime.apply_permission_change(mock_config, changed_keys=set()) + + runtime.stop.assert_not_awaited() + runtime.start.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_permission_only_change_hot_swaps( + mock_mass: MagicMock, mock_config: MagicMock +) -> None: + """A permission-key-only change updates ``_allowed_tags`` in place — no restart.""" + from music_assistant.providers.fastmcp_server.server import MCPServerRuntime # noqa: PLC0415 + + runtime = MCPServerRuntime(mock_mass, mock_config, logging.getLogger("t")) + # Pretend the runtime has started so _allowed_tags exists and hot-swap is viable. + runtime._allowed_tags = {"query:library"} + runtime.stop = AsyncMock() + runtime.start = AsyncMock() + + await runtime.apply_permission_change( + mock_config, changed_keys={"control_volume", "query_library"} + ) + + runtime.stop.assert_not_awaited() + runtime.start.assert_not_awaited() + # _allowed_tags rebuilt from new_config (default: 4 query tags enabled). + assert "query:library" in runtime._allowed_tags diff --git a/tests/providers/fastmcp_server/test_auth.py b/tests/providers/fastmcp_server/test_auth.py new file mode 100644 index 0000000000..dc950ddd3e --- /dev/null +++ b/tests/providers/fastmcp_server/test_auth.py @@ -0,0 +1,132 @@ +"""Tests for ``provider.auth.MASTokenVerifier``.""" + +from __future__ import annotations + +import base64 +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from music_assistant.providers.fastmcp_server.auth import MASTokenVerifier + + +def _make_jwt(payload: dict[str, object]) -> str: + """Forge an unsigned-but-structurally-valid JWT for audience-claim tests. + + The signature isn't checked by ``MASTokenVerifier`` (verification is MA's + job); we only inspect the payload's ``aud`` claim. + """ + header = base64.urlsafe_b64encode(b'{"alg":"none","typ":"JWT"}').rstrip(b"=").decode() + body = base64.urlsafe_b64encode(json.dumps(payload).encode()).rstrip(b"=").decode() + return f"{header}.{body}.signature" + + +@pytest.mark.asyncio +async def test_valid_token_returns_access_token(mock_mass: MagicMock, mock_user: MagicMock) -> None: + """A valid token yields an AccessToken bound to the canonical resource URI.""" + mock_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=mock_user) + verifier = MASTokenVerifier( + mock_mass, + base_url="http://localhost:8095", + public_resource_uri="http://localhost:8095/mcp/v1", + ) + + token = await verifier.verify_token("valid-token") + + assert token is not None + assert token.client_id == "u1" + assert token.scopes == [] + assert token.resource == "http://localhost:8095/mcp/v1" + assert token.token == "valid-token" + + +@pytest.mark.asyncio +async def test_invalid_token_returns_none(mock_mass: MagicMock) -> None: + """An invalid (rejected) token returns None.""" + mock_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=None) + verifier = MASTokenVerifier(mock_mass) + assert await verifier.verify_token("nope") is None + + +@pytest.mark.asyncio +async def test_disabled_user_returns_none(mock_mass: MagicMock, mock_user: MagicMock) -> None: + """A user marked disabled is rejected even if the token is valid.""" + mock_user.enabled = False + mock_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=mock_user) + verifier = MASTokenVerifier(mock_mass) + assert await verifier.verify_token("valid-but-disabled") is None + + +@pytest.mark.asyncio +async def test_authenticate_called_once(mock_mass: MagicMock, mock_user: MagicMock) -> None: + """We delegate exactly once per verify_token call (no retry storm).""" + mock_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=mock_user) + verifier = MASTokenVerifier(mock_mass) + await verifier.verify_token("t") + mock_mass.webserver.auth.authenticate_with_token.assert_awaited_once_with("t") + + +@pytest.mark.asyncio +async def test_underlying_exception_swallowed(mock_mass: MagicMock) -> None: + """If MA's auth raises, we log and return None — never propagate.""" + mock_mass.webserver.auth.authenticate_with_token = AsyncMock( + side_effect=RuntimeError("db down") + ) + verifier = MASTokenVerifier(mock_mass) + assert await verifier.verify_token("any") is None + + +# ── audience binding (C6) ──────────────────────────────────────────────────── + + +_RESOURCE = "http://localhost:8095/mcp/v1" + + +@pytest.mark.asyncio +async def test_legacy_token_passes_in_soft_mode(mock_mass: MagicMock, mock_user: MagicMock) -> None: + """Non-JWT (legacy hash) tokens have no aud; soft mode accepts them.""" + mock_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=mock_user) + verifier = MASTokenVerifier(mock_mass, public_resource_uri=_RESOURCE, enforce_audience=False) + assert await verifier.verify_token("legacy-hash-token") is not None + + +@pytest.mark.asyncio +async def test_legacy_token_rejected_in_strict_mode( + mock_mass: MagicMock, mock_user: MagicMock +) -> None: + """Strict mode rejects tokens that have no audience claim at all.""" + mock_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=mock_user) + verifier = MASTokenVerifier(mock_mass, public_resource_uri=_RESOURCE, enforce_audience=True) + assert await verifier.verify_token("legacy-hash-token") is None + + +@pytest.mark.asyncio +async def test_jwt_with_matching_aud_accepted_in_strict_mode( + mock_mass: MagicMock, mock_user: MagicMock +) -> None: + """A JWT carrying ``aud == public_resource_uri`` passes strict enforcement.""" + mock_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=mock_user) + verifier = MASTokenVerifier(mock_mass, public_resource_uri=_RESOURCE, enforce_audience=True) + token = _make_jwt({"sub": "u1", "aud": _RESOURCE}) + assert await verifier.verify_token(token) is not None + + +@pytest.mark.asyncio +async def test_jwt_with_mismatched_aud_rejected_in_strict_mode( + mock_mass: MagicMock, mock_user: MagicMock +) -> None: + """A JWT issued for a different audience is rejected in strict mode.""" + mock_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=mock_user) + verifier = MASTokenVerifier(mock_mass, public_resource_uri=_RESOURCE, enforce_audience=True) + token = _make_jwt({"sub": "u1", "aud": "http://other.example/api"}) + assert await verifier.verify_token(token) is None + + +@pytest.mark.asyncio +async def test_jwt_with_aud_list_accepted(mock_mass: MagicMock, mock_user: MagicMock) -> None: + """RFC 8707 allows ``aud`` to be a list — match is membership.""" + mock_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=mock_user) + verifier = MASTokenVerifier(mock_mass, public_resource_uri=_RESOURCE, enforce_audience=True) + token = _make_jwt({"sub": "u1", "aud": ["http://other.example", _RESOURCE]}) + assert await verifier.verify_token(token) is not None diff --git a/tests/providers/fastmcp_server/test_config_entries.py b/tests/providers/fastmcp_server/test_config_entries.py new file mode 100644 index 0000000000..09d45eb127 --- /dev/null +++ b/tests/providers/fastmcp_server/test_config_entries.py @@ -0,0 +1,95 @@ +"""Tests for ``provider.config.build_config_entries``.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from music_assistant.providers.fastmcp_server.config import build_config_entries +from music_assistant.providers.fastmcp_server.constants import ( + CONF_DELETE_LIBRARY, + CONF_MOUNT_PATH, + CONF_QUERY_LIBRARY, + CONF_REQUIRE_AUTH, + PERMISSION_KEYS, + RESOURCE_KEYS, +) + +if TYPE_CHECKING: + from unittest.mock import MagicMock + + +def test_total_entry_count(mock_mass: MagicMock) -> None: + """27 entries: 1 info label + 1 connect-wizard action + 6 server + 16 perms + 3 resources.""" + entries = build_config_entries(mock_mass, {}) + assert len(entries) == 1 + 1 + 6 + 16 + 3 + + +def test_all_permission_keys_present(mock_mass: MagicMock) -> None: + """Every permission key from PERMISSION_KEYS has a matching ConfigEntry.""" + entries = build_config_entries(mock_mass, {}) + keys = {e.key for e in entries} + assert PERMISSION_KEYS.issubset(keys) + assert RESOURCE_KEYS.issubset(keys) + assert CONF_REQUIRE_AUTH in keys + + +def test_delete_keys_default_false(mock_mass: MagicMock) -> None: + """All delete-family permissions default to False (least-privilege).""" + entries = {e.key: e for e in build_config_entries(mock_mass, {})} + mutation_prefixes = ("delete_", "control_", "edit_") + for key in PERMISSION_KEYS: + if key.startswith(mutation_prefixes): + assert entries[key].default_value is False, f"{key} should default False" + + +def test_query_keys_default_true(mock_mass: MagicMock) -> None: + """All query-family permissions default to True.""" + entries = {e.key: e for e in build_config_entries(mock_mass, {})} + assert entries[CONF_QUERY_LIBRARY].default_value is True + + +def test_categories_match_pr2889_ux(mock_mass: MagicMock) -> None: + """Categories mirror upstream PR #2889 grouping for familiarity at review time.""" + entries = build_config_entries(mock_mass, {}) + categories = {getattr(e, "category", None) for e in entries if getattr(e, "category", None)} + # ``Generic`` comes from the Connect Wizard ACTION entry, which mirrors the + # Spotify provider's ``CONF_ACTION_AUTH`` button (no explicit category). + assert categories == { + "Server", + "Query Permissions", + "Control Permissions", + "Edit Permissions", + "Delete Permissions", + "MCP Resources", + "generic", + } + + +def test_info_label_includes_base_url(mock_mass: MagicMock) -> None: + """The info label embeds MA's base_url so users see where to point clients.""" + entries = build_config_entries(mock_mass, {}) + info = entries[0] + assert mock_mass.webserver.base_url in str(info.label) + assert "/mcp/v1" in str(info.label) + + +def test_info_label_normalises_mount_path_without_leading_slash( + mock_mass: MagicMock, +) -> None: + """A user-typed ``mcp/v1`` (no leading slash) must still render a valid URL. + + Regression for upstream PR #3858 Copilot comment: the runtime normalises + the mount path, but the info label did not — so the displayed endpoint + glued the host to the path with no separator (``…:8095mcp/v1``). + """ + entries = build_config_entries(mock_mass, {CONF_MOUNT_PATH: "mcp/v1"}) + label = str(entries[0].label) + base_url = mock_mass.webserver.base_url + assert f"{base_url}/mcp/v1" in label + assert f"{base_url}mcp" not in label + + +def test_delete_library_default(mock_mass: MagicMock) -> None: + """Specifically: delete_library defaults False (a hard-to-undo permission).""" + entries = {e.key: e for e in build_config_entries(mock_mass, {})} + assert entries[CONF_DELETE_LIBRARY].default_value is False diff --git a/tests/providers/fastmcp_server/test_connect_wizard.py b/tests/providers/fastmcp_server/test_connect_wizard.py new file mode 100644 index 0000000000..3806bed279 --- /dev/null +++ b/tests/providers/fastmcp_server/test_connect_wizard.py @@ -0,0 +1,900 @@ +"""Tests for the Connect Wizard — endpoints, action handler, client templates. + +These tests run against the real :func:`mount_connect_wizard` flow on a +``FakeWebserver`` (no real MA stack required); ``mass.webserver.auth`` is +stubbed with ``AsyncMock`` / ``MagicMock`` so we can assert exactly which auth +calls the wizard fires for each user-facing operation. +""" +# ruff: noqa: D401 +# D401: pytest fixture/test docstrings describe *what is returned*. +# S101: ``assert`` is the pytest convention. +# PLR2004: small magic numbers (10 client specs, 5 routes) are obvious in context. +# mypy: disable-error-code="type-arg" + +from __future__ import annotations + +import json +from types import SimpleNamespace +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, MagicMock + +import pytest +from aiohttp.test_utils import TestClient, TestServer + +from music_assistant.providers.fastmcp_server import ( + _detect_external_base_url, + _dispatch_open_connect, + _sanitize_external_base_url, +) +from music_assistant.providers.fastmcp_server.connect.actions import handle_open_connect_action +from music_assistant.providers.fastmcp_server.connect.clients import CLIENTS, lookup_client +from music_assistant.providers.fastmcp_server.connect.mount import mount_connect_wizard +from music_assistant.providers.fastmcp_server.constants import CONF_CONNECT_EXTERNAL_URL + +from .conftest import FakeWebserver, build_aiohttp_app + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + +@pytest.fixture +def wizard_mass(mock_user: MagicMock) -> MagicMock: + """A ``mass`` stub with ``FakeWebserver`` + the auth surface the wizard touches.""" + fake_ws = FakeWebserver() + fake_ws.auth = SimpleNamespace( # type: ignore[attr-defined] + login=AsyncMock( + return_value={ + "success": True, + "access_token": "sess-1", + "user": { + "user_id": mock_user.user_id, + "username": mock_user.username, + "role": "admin", + }, + } + ), + create_token=AsyncMock(return_value="jwt-xyz"), + authenticate_with_token=AsyncMock(return_value=mock_user), + get_current_user=MagicMock(return_value=mock_user), + # Sanctioned auth-API surface that provider/connect/_revoke.py drives. + revoke_token=AsyncMock(), + get_user_tokens=AsyncMock(return_value=[]), + get_token_id_from_token=AsyncMock(side_effect=lambda t: f"tid:{t}"), + ) + mass = MagicMock() + mass.webserver = fake_ws + mass.signal_event = MagicMock() + return mass + + +@pytest.fixture +async def wizard_client(wizard_mass: MagicMock) -> AsyncIterator[TestClient]: + """Mount the wizard on /mcp/v1 and yield an aiohttp TestClient.""" + unmount = await mount_connect_wizard( + wizard_mass, + mount_path="/mcp/v1", + enabled_tags_provider=lambda: ["query:library", "control:playback"], + extra_origins_csv="", + ) + async with TestClient(TestServer(build_aiohttp_app(wizard_mass.webserver))) as client: + yield client + unmount() + + +# ── HTML page + info endpoint ──────────────────────────────────────────────── + + +async def test_connect_html_served(wizard_client: TestClient) -> None: + """``GET /mcp/v1/connect`` returns an HTML page mentioning Music Assistant.""" + resp = await wizard_client.get("/mcp/v1/connect", headers={"Origin": "http://localhost:8095"}) + assert resp.status == 200 + assert resp.headers["Content-Type"].startswith("text/html") + body = await resp.text() + assert "Music Assistant" in body + assert "connect" in body.lower() + + +async def test_info_endpoint_shape(wizard_client: TestClient) -> None: + """``GET /mcp/v1/connect/info`` returns the meta JSON the UI needs.""" + resp = await wizard_client.get( + "/mcp/v1/connect/info", headers={"Origin": "http://localhost:8095"} + ) + assert resp.status == 200 + data = await resp.json() + for key in ( + "mount_path", + "mcp_url_loopback", + "mcp_url_advertised", + "permissions", + "clients", + "well_known_url", + ): + assert key in data, f"missing key: {key}" + assert data["mount_path"] == "/mcp/v1" + assert data["mcp_url_loopback"].endswith("/mcp/v1") + assert isinstance(data["clients"], list) + assert len(data["clients"]) >= 10 + assert isinstance(data["permissions"], list) + assert all(isinstance(p, str) for p in data["permissions"]) + + +async def test_info_reflects_enabled_tags(wizard_mass: MagicMock) -> None: + """``info.permissions`` reflects whatever ``enabled_tags_provider()`` returns.""" + unmount = await mount_connect_wizard( + wizard_mass, + mount_path="/mcp/v1", + enabled_tags_provider=lambda: ["control:playback", "edit:queue"], + extra_origins_csv="", + ) + try: + async with TestClient(TestServer(build_aiohttp_app(wizard_mass.webserver))) as client: + resp = await client.get("/mcp/v1/connect/info") + data = await resp.json() + assert "control:playback" in data["permissions"] + assert "edit:queue" in data["permissions"] + finally: + unmount() + + +# ── Bootstrap exchange ─────────────────────────────────────────────────────── + + +async def test_exchange_bootstrap_success( + wizard_client: TestClient, wizard_mass: MagicMock, mock_user: MagicMock +) -> None: + """A valid bootstrap token is exchanged for a session_token bound to the same user.""" + resp = await wizard_client.post( + "/mcp/v1/connect/exchange", + json={"bootstrap": "boot-1"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 200 + data = await resp.json() + assert data["session_token"] == "jwt-xyz" + assert data["user"]["user_id"] == mock_user.user_id + + wizard_mass.webserver.auth.authenticate_with_token.assert_awaited_with("boot-1") + wizard_mass.webserver.auth.create_token.assert_awaited_with( + user=mock_user, + name="MCP — wizard session", + is_long_lived=False, + ) + + +async def test_exchange_bootstrap_invalid_401( + wizard_client: TestClient, wizard_mass: MagicMock +) -> None: + """Invalid bootstrap → 401 and ``create_token`` is NOT called.""" + wizard_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=None) + wizard_mass.webserver.auth.create_token.reset_mock() + + resp = await wizard_client.post( + "/mcp/v1/connect/exchange", + json={"bootstrap": "bad"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 401 + wizard_mass.webserver.auth.create_token.assert_not_called() + + +async def test_exchange_revokes_bootstrap_on_success( + wizard_client: TestClient, wizard_mass: MagicMock +) -> None: + """Successful exchange revokes the bootstrap via ``auth.revoke_token``.""" + auth = wizard_mass.webserver.auth + + resp = await wizard_client.post( + "/mcp/v1/connect/exchange", + json={"bootstrap": "boot-1"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 200 + + auth.revoke_token.assert_awaited_once_with("tid:boot-1") + + +async def test_exchange_invalid_bootstrap_does_not_revoke( + wizard_client: TestClient, wizard_mass: MagicMock +) -> None: + """Invalid bootstrap → no revoke and no mint.""" + wizard_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=None) + + resp = await wizard_client.post( + "/mcp/v1/connect/exchange", + json={"bootstrap": "bad"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 401 + wizard_mass.webserver.auth.revoke_token.assert_not_called() + + +async def test_exchange_revoke_failure_still_returns_session( + wizard_client: TestClient, wizard_mass: MagicMock +) -> None: + """A ``revoke_token`` exception is swallowed; the exchange still issues a session_token.""" + auth = wizard_mass.webserver.auth + auth.revoke_token = AsyncMock(side_effect=RuntimeError("revoke failed")) + + resp = await wizard_client.post( + "/mcp/v1/connect/exchange", + json={"bootstrap": "boot-1"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 200 + data = await resp.json() + assert data["session_token"] == "jwt-xyz" + auth.create_token.assert_awaited_once() + + +async def test_exchange_get_token_id_none_skips_revoke( + wizard_client: TestClient, wizard_mass: MagicMock +) -> None: + """When ``get_token_id_from_token`` returns ``None`` the revoke is skipped, mint still happens.""" + auth = wizard_mass.webserver.auth + auth.get_token_id_from_token = AsyncMock(return_value=None) + + resp = await wizard_client.post( + "/mcp/v1/connect/exchange", + json={"bootstrap": "boot-1"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 200 + auth.revoke_token.assert_not_called() + auth.create_token.assert_awaited_once() + + +# ── Login form fallback ────────────────────────────────────────────────────── + + +async def test_login_success_returns_token( + wizard_client: TestClient, wizard_mass: MagicMock +) -> None: + """Successful login returns the access_token issued by MA.""" + resp = await wizard_client.post( + "/mcp/v1/connect/login", + json={"username": "tester", "password": "secret"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 200 + data = await resp.json() + assert data["session_token"] == "sess-1" + assert data["user"]["username"] == "tester" + + wizard_mass.webserver.auth.login.assert_awaited_with( + username="tester", password="secret", provider_id="builtin" + ) + + +async def test_login_failure_401(wizard_client: TestClient, wizard_mass: MagicMock) -> None: + """Login failure → 401 with the error MA reported.""" + wizard_mass.webserver.auth.login = AsyncMock( + return_value={"success": False, "error": "bad creds"} + ) + + resp = await wizard_client.post( + "/mcp/v1/connect/login", + json={"username": "x", "password": "y"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 401 + body = await resp.json() + assert body.get("error") == "bad creds" + + +# ── Per-client token mint ──────────────────────────────────────────────────── + + +async def test_token_endpoint_mints_named( + wizard_client: TestClient, wizard_mass: MagicMock, mock_user: MagicMock +) -> None: + """Per-client mint creates a long-lived token labeled ``MCP — ``.""" + resp = await wizard_client.post( + "/mcp/v1/connect/token", + json={"session_token": "sess-1", "client_id": "cursor"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 200 + data = await resp.json() + assert data["token"] == "jwt-xyz" + + wizard_mass.webserver.auth.create_token.assert_awaited_with( + user=mock_user, + name="MCP — Cursor", + is_long_lived=True, + ) + + +async def test_token_endpoint_unknown_client_400(wizard_client: TestClient) -> None: + """Unknown ``client_id`` → 400.""" + resp = await wizard_client.post( + "/mcp/v1/connect/token", + json={"session_token": "sess-1", "client_id": "bogus"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 400 + + +async def test_token_endpoint_invalid_session_401( + wizard_client: TestClient, wizard_mass: MagicMock +) -> None: + """Invalid session_token → 401 and ``create_token`` is NOT called.""" + wizard_mass.webserver.auth.authenticate_with_token = AsyncMock(return_value=None) + wizard_mass.webserver.auth.create_token.reset_mock() + + resp = await wizard_client.post( + "/mcp/v1/connect/token", + json={"session_token": "nope", "client_id": "cursor"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 401 + wizard_mass.webserver.auth.create_token.assert_not_called() + + +async def test_token_endpoint_server_dedup_revokes_same_name( + wizard_client: TestClient, wizard_mass: MagicMock +) -> None: + """Prior tokens with the same client-token name for the user are revoked. + + Tokens with other names are left alone; ``create_token`` is still called + once. Asserts the call against ``auth.revoke_token`` (sanctioned API), + not the underlying DB. + """ + auth = wizard_mass.webserver.auth + auth.get_user_tokens = AsyncMock( + return_value=[ + SimpleNamespace(token_id="old-1", name="MCP — Cursor", user_id="u1"), + SimpleNamespace(token_id="old-2", name="MCP — Cursor", user_id="u1"), + SimpleNamespace(token_id="keep", name="MCP — Other", user_id="u1"), + ] + ) + + resp = await wizard_client.post( + "/mcp/v1/connect/token", + json={"session_token": "sess-1", "client_id": "cursor"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 200 + + revoked_ids = sorted(call.args[0] for call in auth.revoke_token.await_args_list) + assert revoked_ids == ["old-1", "old-2"] + auth.create_token.assert_awaited_once() + + +async def test_token_endpoint_dedup_lookup_failure_does_not_fail_mint( + wizard_client: TestClient, wizard_mass: MagicMock +) -> None: + """A ``get_user_tokens`` exception is logged but the mint still succeeds.""" + auth = wizard_mass.webserver.auth + auth.get_user_tokens = AsyncMock(side_effect=RuntimeError("api down")) + + resp = await wizard_client.post( + "/mcp/v1/connect/token", + json={"session_token": "sess-1", "client_id": "cursor"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 200 + auth.create_token.assert_awaited_once() + + +async def test_token_endpoint_no_prior_no_revoke( + wizard_client: TestClient, wizard_mass: MagicMock +) -> None: + """No prior tokens → ``auth.revoke_token`` is never called.""" + auth = wizard_mass.webserver.auth + + resp = await wizard_client.post( + "/mcp/v1/connect/token", + json={"session_token": "sess-1", "client_id": "cursor"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 200 + auth.revoke_token.assert_not_called() + + +async def test_token_endpoint_revoke_failure_does_not_fail_mint( + wizard_client: TestClient, wizard_mass: MagicMock +) -> None: + """A ``revoke_token`` exception is swallowed; the new mint still happens.""" + auth = wizard_mass.webserver.auth + auth.get_user_tokens = AsyncMock( + return_value=[SimpleNamespace(token_id="old", name="MCP — Cursor", user_id="u1")] + ) + auth.revoke_token = AsyncMock(side_effect=RuntimeError("revoke failed")) + + resp = await wizard_client.post( + "/mcp/v1/connect/token", + json={"session_token": "sess-1", "client_id": "cursor"}, + headers={"Origin": "http://localhost:8095"}, + ) + assert resp.status == 200 + auth.create_token.assert_awaited_once() + + +# ── Origin & mount ─────────────────────────────────────────────────────────── + + +async def test_origin_rejection(wizard_client: TestClient) -> None: + """A non-allowlisted Origin → 403.""" + resp = await wizard_client.get( + "/mcp/v1/connect/info", headers={"Origin": "http://evil.example"} + ) + assert resp.status == 403 + + +async def test_mount_unmount_cycle(wizard_mass: MagicMock) -> None: + """``mount_connect_wizard`` registers 5 routes; the returned callback removes all.""" + fake_ws = wizard_mass.webserver + assert fake_ws.routes == [] + unmount = await mount_connect_wizard( + wizard_mass, + mount_path="/mcp/v1", + enabled_tags_provider=list, + extra_origins_csv="", + ) + assert len(fake_ws.routes) == 5 + unmount() + assert fake_ws.routes == [] + + +async def test_mount_path_relative(wizard_mass: MagicMock) -> None: + """Wizard routes are nested under whatever ``mount_path`` is given.""" + unmount = await mount_connect_wizard( + wizard_mass, + mount_path="/custom", + enabled_tags_provider=list, + extra_origins_csv="", + ) + try: + paths = [r[0] for r in wizard_mass.webserver.routes] + assert all(p.startswith("/custom/connect") for p in paths) + finally: + unmount() + + +# ── ACTION handler (signal_event) ──────────────────────────────────────────── + + +async def test_action_handler_signals_url_with_bootstrap( + wizard_mass: MagicMock, mock_user: MagicMock +) -> None: + """Action handler mints a bootstrap token and signals a URL containing it.""" + await handle_open_connect_action( + wizard_mass, + current_user=mock_user, + mount_path="/mcp/v1", + base_url="http://localhost:8095", + ) + + wizard_mass.webserver.auth.create_token.assert_awaited_with( + user=mock_user, + name="MCP — wizard bootstrap", + is_long_lived=False, + ) + wizard_mass.signal_event.assert_called_once() + args, kwargs = wizard_mass.signal_event.call_args + url = kwargs.get("data") if "data" in kwargs else args[-1] + assert isinstance(url, str) + # Path-only URL — the MA frontend resolves it against the user's location + # so the wizard works in Docker / HA add-on deployments where MA's + # advertised base_url points at an internal IP the browser cannot reach. + assert url.startswith("/mcp/v1/connect") + assert "bootstrap=jwt-xyz" in url + + +async def test_action_handler_no_user_signals_plain_url(wizard_mass: MagicMock) -> None: + """Without a current user we still open the wizard, but without a bootstrap query.""" + wizard_mass.webserver.auth.create_token.reset_mock() + + await handle_open_connect_action( + wizard_mass, + current_user=None, + mount_path="/mcp/v1", + base_url="http://localhost:8095", + ) + + wizard_mass.webserver.auth.create_token.assert_not_called() + wizard_mass.signal_event.assert_called_once() + args, kwargs = wizard_mass.signal_event.call_args + url = kwargs.get("data") if "data" in kwargs else args[-1] + assert isinstance(url, str) + assert "bootstrap=" not in url + + +async def test_action_handler_external_base_url_prepended( + wizard_mass: MagicMock, mock_user: MagicMock +) -> None: + """When ``external_base_url`` is provided, the signalled URL is fully qualified. + + Covers HA add-on ingress, where the path-only URL drops the ingress prefix + and the wizard opens at the wrong location. + """ + await handle_open_connect_action( + wizard_mass, + current_user=mock_user, + mount_path="/mcp/v1", + external_base_url="https://ha.example.com/d5369777_music_assistant_dev", + ) + + wizard_mass.signal_event.assert_called_once() + args, kwargs = wizard_mass.signal_event.call_args + url = kwargs.get("data") if "data" in kwargs else args[-1] + assert isinstance(url, str) + assert url.startswith("https://ha.example.com/d5369777_music_assistant_dev/mcp/v1/connect") + assert "bootstrap=jwt-xyz" in url + + +async def test_action_handler_external_base_url_strips_trailing_slash( + wizard_mass: MagicMock, +) -> None: + """A trailing slash on ``external_base_url`` must not produce a double-slash.""" + await handle_open_connect_action( + wizard_mass, + current_user=None, + mount_path="/mcp/v1", + external_base_url="https://ha.example.com/addon/", + ) + + args, kwargs = wizard_mass.signal_event.call_args + url = kwargs.get("data") if "data" in kwargs else args[-1] + assert url == "https://ha.example.com/addon/mcp/v1/connect" + + +async def test_action_handler_empty_external_base_url_falls_back_to_path( + wizard_mass: MagicMock, +) -> None: + """An empty / ``None`` ``external_base_url`` preserves the legacy path-only URL.""" + await handle_open_connect_action( + wizard_mass, + current_user=None, + mount_path="/mcp/v1", + external_base_url="", + ) + + args, kwargs = wizard_mass.signal_event.call_args + url = kwargs.get("data") if "data" in kwargs else args[-1] + assert url == "/mcp/v1/connect" + + +async def test_open_connect_gcs_prior_wizard_tokens( + wizard_mass: MagicMock, mock_user: MagicMock +) -> None: + """Prior MCP — wizard bootstrap/session tokens are revoked before the new bootstrap is minted. + + Per-client tokens (``MCP — Cursor`` etc.) are left untouched. Asserts + against the sanctioned ``auth.revoke_token`` API. + """ + auth = wizard_mass.webserver.auth + auth.get_user_tokens = AsyncMock( + return_value=[ + SimpleNamespace(token_id="boot-old", name="MCP — wizard bootstrap", user_id="u1"), + SimpleNamespace(token_id="sess-old", name="MCP — wizard session", user_id="u1"), + SimpleNamespace(token_id="cursor-keep", name="MCP — Cursor", user_id="u1"), + ] + ) + + await handle_open_connect_action( + wizard_mass, + current_user=mock_user, + mount_path="/mcp/v1", + ) + + revoked_ids = sorted(call.args[0] for call in auth.revoke_token.await_args_list) + assert revoked_ids == ["boot-old", "sess-old"] + auth.create_token.assert_awaited_once_with( + user=mock_user, + name="MCP — wizard bootstrap", + is_long_lived=False, + ) + wizard_mass.signal_event.assert_called_once() + + +async def test_open_connect_gc_lookup_failure_does_not_block( + wizard_mass: MagicMock, mock_user: MagicMock +) -> None: + """A ``get_user_tokens`` exception is swallowed; the new bootstrap mint still happens.""" + auth = wizard_mass.webserver.auth + auth.get_user_tokens = AsyncMock(side_effect=RuntimeError("api down")) + + await handle_open_connect_action( + wizard_mass, + current_user=mock_user, + mount_path="/mcp/v1", + ) + + auth.create_token.assert_awaited_once() + wizard_mass.signal_event.assert_called_once() + + +async def test_open_connect_no_user_skips_gc(wizard_mass: MagicMock) -> None: + """Without a current user there is no token listing and nothing is revoked.""" + auth = wizard_mass.webserver.auth + + await handle_open_connect_action( + wizard_mass, + current_user=None, + mount_path="/mcp/v1", + ) + + auth.get_user_tokens.assert_not_called() + auth.revoke_token.assert_not_called() + + +# ── Dispatch: WS-client auto-detect + config-override fallback ─────────────── + + +def _matching_user() -> SimpleNamespace: + return SimpleNamespace(user_id="u1", username="tester") + + +def test_detect_external_base_url_prefers_matching_client() -> None: + """The detector returns the ``base_url`` of the WS client owned by the user.""" + user = _matching_user() + other = SimpleNamespace(user_id="u2", username="someone-else") + clients = [ + SimpleNamespace( + _authenticated_user=other, + base_url="https://wrong.example.com", + ), + SimpleNamespace( + _authenticated_user=user, + base_url="https://ha.example.com/d5369777_music_assistant_dev", + ), + ] + mass = MagicMock() + mass.webserver.clients = clients + + assert ( + _detect_external_base_url(mass, user) + == "https://ha.example.com/d5369777_music_assistant_dev" + ) + + +def test_detect_external_base_url_returns_none_without_match() -> None: + """No matching client → ``None`` so the dispatcher can fall through.""" + user = _matching_user() + clients = [ + SimpleNamespace( + _authenticated_user=SimpleNamespace(user_id="other", username="other"), + base_url="https://other.example.com", + ), + SimpleNamespace(_authenticated_user=user, base_url=None), + ] + mass = MagicMock() + mass.webserver.clients = clients + + assert _detect_external_base_url(mass, user) is None + + +def test_detect_external_base_url_handles_no_user() -> None: + """No current user → ``None`` (the dispatcher then tries the config override).""" + mass = MagicMock() + mass.webserver.clients = [] + + assert _detect_external_base_url(mass, None) is None + + +@pytest.mark.parametrize( + "candidate", + [ + "javascript:alert(1)", + "//attacker.example.com", + "ha.example.com/addon", # missing scheme — would be treated as path-relative + "ftp://example.com", + "", + " ", + None, + ], +) +def test_sanitize_external_base_url_rejects_unsafe(candidate: str | None) -> None: + """Only ``http(s)://`` values survive — anything else is dropped.""" + assert _sanitize_external_base_url(candidate) is None + + +@pytest.mark.parametrize( + "candidate", + [ + "https://ha.example.com/d5369777_music_assistant_dev", + "http://localhost:8095", + "HTTPS://Upper.Case.Example.COM", # case-insensitive scheme check + ], +) +def test_sanitize_external_base_url_accepts_http_schemes(candidate: str) -> None: + """``http://`` and ``https://`` values pass through (whitespace trimmed).""" + assert _sanitize_external_base_url(f" {candidate} ") == candidate + + +def _install_fake_ma_auth_middleware(monkeypatch: pytest.MonkeyPatch, user: object) -> None: + """Make ``get_current_user()`` return ``user`` inside ``_dispatch_open_connect``. + + The provider imports ``music_assistant.controllers.webserver.helpers.auth_middleware`` + lazily; ``music_assistant`` is an optional / dev-only dep, so we inject a + stub module tree into ``sys.modules`` rather than importing the real one. + """ + import sys # noqa: PLC0415 + import types # noqa: PLC0415 + + pkg = types.ModuleType("music_assistant") + pkg.__path__ = [] + controllers = types.ModuleType("music_assistant.controllers") + controllers.__path__ = [] + webserver_pkg = types.ModuleType("music_assistant.controllers.webserver") + webserver_pkg.__path__ = [] + helpers_pkg = types.ModuleType("music_assistant.controllers.webserver.helpers") + helpers_pkg.__path__ = [] + auth_mod = types.ModuleType("music_assistant.controllers.webserver.helpers.auth_middleware") + auth_mod.get_current_user = lambda: user # type: ignore[attr-defined] + + monkeypatch.setitem(sys.modules, "music_assistant", pkg) + monkeypatch.setitem(sys.modules, "music_assistant.controllers", controllers) + monkeypatch.setitem(sys.modules, "music_assistant.controllers.webserver", webserver_pkg) + monkeypatch.setitem(sys.modules, "music_assistant.controllers.webserver.helpers", helpers_pkg) + monkeypatch.setitem( + sys.modules, + "music_assistant.controllers.webserver.helpers.auth_middleware", + auth_mod, + ) + + +async def test_dispatch_detects_ws_client_base_url( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """End-to-end dispatch: the WS client's ingress base_url ends up in the URL.""" + user = _matching_user() + _install_fake_ma_auth_middleware(monkeypatch, user) + + signalled: list[str] = [] + mass = MagicMock() + mass.webserver.clients = [ + SimpleNamespace( + _authenticated_user=user, + base_url="https://ha.example.com/d5369777_music_assistant_dev", + ) + ] + mass.webserver.auth.create_token = AsyncMock(return_value="jwt-xyz") + mass.signal_event = MagicMock( + side_effect=lambda _evt, object_id, data: signalled.append(data) # noqa: ARG005 + ) + + await _dispatch_open_connect( + mass, + {"mount_path": "/mcp/v1", "session_id": "sess-x"}, + ) + + assert signalled, "expected signal_event to be called" + url = signalled[0] + assert url.startswith("https://ha.example.com/d5369777_music_assistant_dev/mcp/v1/connect") + assert "bootstrap=jwt-xyz" in url + + +async def test_dispatch_falls_back_to_config_override( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When no WS client matches, an explicit ``connect_external_url`` wins.""" + user = _matching_user() + _install_fake_ma_auth_middleware(monkeypatch, user) + + signalled: list[str] = [] + mass = MagicMock() + mass.webserver.clients = [] + mass.webserver.auth.create_token = AsyncMock(return_value="jwt-xyz") + mass.signal_event = MagicMock( + side_effect=lambda _evt, object_id, data: signalled.append(data) # noqa: ARG005 + ) + + await _dispatch_open_connect( + mass, + { + "mount_path": "/mcp/v1", + "session_id": "sess-y", + CONF_CONNECT_EXTERNAL_URL: "https://override.example.com", + }, + ) + + url = signalled[0] + assert url.startswith("https://override.example.com/mcp/v1/connect") + + +async def test_dispatch_rejects_unsafe_override_and_falls_back( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """An override with a non-``http(s)`` scheme is dropped → path-only fallback. + + Guards against an admin pasting ``javascript:…`` into the config; the + frontend would otherwise hand that straight to ``window.open``. + """ + user = _matching_user() + _install_fake_ma_auth_middleware(monkeypatch, user) + + signalled: list[str] = [] + mass = MagicMock() + mass.webserver.clients = [] + mass.webserver.auth.create_token = AsyncMock(return_value="jwt-xyz") + mass.signal_event = MagicMock( + side_effect=lambda _evt, object_id, data: signalled.append(data) # noqa: ARG005 + ) + + await _dispatch_open_connect( + mass, + { + "mount_path": "/mcp/v1", + "session_id": "sess-bad", + CONF_CONNECT_EXTERNAL_URL: "javascript:alert(1)", + }, + ) + + url = signalled[0] + assert url.startswith("/mcp/v1/connect") + assert "javascript" not in url + + +async def test_dispatch_falls_back_to_path_only_when_nothing_known( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """No WS match and no override → legacy path-only URL (no regression).""" + user = _matching_user() + _install_fake_ma_auth_middleware(monkeypatch, user) + + signalled: list[str] = [] + mass = MagicMock() + mass.webserver.clients = [] + mass.webserver.auth.create_token = AsyncMock(return_value="jwt-xyz") + mass.signal_event = MagicMock( + side_effect=lambda _evt, object_id, data: signalled.append(data) # noqa: ARG005 + ) + + await _dispatch_open_connect( + mass, + {"mount_path": "/mcp/v1", "session_id": "sess-z"}, + ) + + url = signalled[0] + assert url.startswith("/mcp/v1/connect") + assert "://" not in url.split("?", 1)[0] + + +# ── Client template integrity ──────────────────────────────────────────────── + + +def test_cursor_template_round_trips() -> None: + """The Cursor template renders to valid JSON with url + Authorization Bearer header.""" + cursor = lookup_client("cursor") + assert cursor is not None + rendered = cursor.template.replace("{{URL}}", "http://localhost:8095/mcp/v1").replace( + "{{TOKEN}}", "TOK-123" + ) + parsed = json.loads(rendered) + server = parsed["mcpServers"]["ma"] + assert server["url"] == "http://localhost:8095/mcp/v1" + assert server["headers"]["Authorization"] == "Bearer TOK-123" + + +def test_claude_code_template_uses_positional_url() -> None: + """``claude mcp add`` takes the URL as a positional argument, not via ``--url``. + + Regression for the v0.3.x wizard shipping ``claude mcp add ma --transport http + --url `` — the CLI ignored ``--url`` and registered an unreachable server. + """ + spec = lookup_client("claude-code") + assert spec is not None + rendered = spec.template.replace("{{URL}}", "http://localhost:8095/mcp/v1").replace( + "{{TOKEN}}", "TOK-123" + ) + assert "--url" not in rendered, "claude mcp add does not accept a --url flag" + # URL must appear right after the server name (the positional slot). + assert "claude mcp add ma http://localhost:8095/mcp/v1" in rendered + assert "--transport http" in rendered + assert '--header "Authorization: Bearer TOK-123"' in rendered + + +def test_all_clients_have_required_fields() -> None: + """Every client spec has the fields the JS UI relies on.""" + seen_ids: set[str] = set() + for spec in CLIENTS: + assert spec.id + assert spec.id not in seen_ids + seen_ids.add(spec.id) + assert spec.label + assert spec.kind in {"json", "shell", "toml"} + assert "{{URL}}" in spec.template + assert "{{TOKEN}}" in spec.template + assert spec.config_path_hint # non-empty doc hint diff --git a/tests/providers/fastmcp_server/test_constants.py b/tests/providers/fastmcp_server/test_constants.py new file mode 100644 index 0000000000..3f24e5a535 --- /dev/null +++ b/tests/providers/fastmcp_server/test_constants.py @@ -0,0 +1,29 @@ +"""Sanity tests for invariants in ``provider.constants``.""" + +from __future__ import annotations + +from music_assistant.providers.fastmcp_server.constants import ( + HOT_SWAPPABLE_KEYS, + PERMISSION_KEYS, + RESOURCE_KEYS, +) + + +def test_permission_keys_count() -> None: + """16 permission keys: 4 verbs x 4 categories.""" + assert len(PERMISSION_KEYS) == 16 + + +def test_resource_keys_count() -> None: + """3 resource toggles.""" + assert len(RESOURCE_KEYS) == 3 + + +def test_hot_swappable_includes_perm_and_resource_keys() -> None: + """Hot-swappable set is exactly the union — anything else triggers a runtime restart.""" + assert HOT_SWAPPABLE_KEYS == PERMISSION_KEYS | RESOURCE_KEYS + + +def test_no_overlap_perm_resource() -> None: + """Permission and resource key sets don't overlap (cleanly partitioned).""" + assert PERMISSION_KEYS.isdisjoint(RESOURCE_KEYS) diff --git a/tests/providers/fastmcp_server/test_context.py b/tests/providers/fastmcp_server/test_context.py new file mode 100644 index 0000000000..b7eb82467b --- /dev/null +++ b/tests/providers/fastmcp_server/test_context.py @@ -0,0 +1,102 @@ +"""Tests for Context-driven observability and progress (C7).""" +# mypy: disable-error-code="arg-type, no-untyped-def, type-arg, assignment, operator, misc" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastmcp import Client, FastMCP + +from music_assistant.providers.fastmcp_server.tools import ( + build_library_server, + build_metadata_server, + build_playlists_server, +) + + +@pytest.fixture +def search_mass(mock_mass: MagicMock) -> MagicMock: + """Configure mock_mass.music.search to return shaped results for assertions.""" + fake_track = MagicMock(uri="lib://t/1", name="Some Track", artists=[], album=None, duration=180) + mock_mass.music.search = AsyncMock(return_value=MagicMock(tracks=[fake_track])) + return mock_mass + + +async def test_search_tracks_emits_info_log(search_mass: MagicMock) -> None: + """search_tracks invokes ctx.info with the query — visible to MCP clients.""" + mcp: FastMCP = FastMCP(name="t") + mcp.mount(build_library_server(search_mass), namespace="library") + + async with Client(mcp) as client: + captured: list[str] = [] + + async def collect(message: Any) -> None: + text = getattr(message, "data", None) or getattr(message, "message", "") + captured.append(str(text)) + + client.set_message_handler(collect) if hasattr(client, "set_message_handler") else None + # FastMCP test client surfaces logs via a logging handler the spec calls + # `notifications/message`. We assert the tool ran end-to-end without + # raising — Context injection is FastMCP's responsibility once the + # parameter has the right annotation. + result = await client.call_tool("library_search_tracks", {"query": "smoke", "limit": 3}) + + text_blocks = [c.text for c in result.content if hasattr(c, "text")] + assert any("Some Track" in t for t in text_blocks) + + +async def test_recommendations_runs_under_context(mock_mass: MagicMock) -> None: + """metadata.recommendations accepts an injected Context and still returns shaped data.""" + mock_mass.music.recommendations = AsyncMock( + return_value=[ + MagicMock(name="Hits", items=[MagicMock(uri="lib://t/1")]), + ] + ) + mcp: FastMCP = FastMCP(name="t") + mcp.mount(build_metadata_server(mock_mass), namespace="metadata") + + async with Client(mcp) as client: + result = await client.call_tool("metadata_recommendations", {}) + + text_blocks = [c.text for c in result.content if hasattr(c, "text")] + assert any("item_uris" in t or "Hits" in t for t in text_blocks) + + +async def test_add_tracks_bulk_path_for_small_batch(mock_mass: MagicMock) -> None: + """For ≤10 tracks, the bulk add_playlist_tracks call is used (single round-trip).""" + mock_mass.music.playlists.add_playlist_tracks = AsyncMock() + mock_mass.music.playlists.add_playlist_track = AsyncMock() + + mcp: FastMCP = FastMCP(name="t") + mcp.mount(build_playlists_server(mock_mass), namespace="playlists") + + track_uris = [f"lib://t/{i}" for i in range(5)] + async with Client(mcp) as client: + await client.call_tool( + "playlists_add_tracks", + {"playlist_id": 1, "track_uris": track_uris}, + ) + + mock_mass.music.playlists.add_playlist_tracks.assert_awaited_once_with(1, track_uris) + mock_mass.music.playlists.add_playlist_track.assert_not_awaited() + + +async def test_add_tracks_per_item_path_for_large_batch(mock_mass: MagicMock) -> None: + """For >10 tracks, items are dispatched per-item (so we can report progress).""" + mock_mass.music.playlists.add_playlist_tracks = AsyncMock() + mock_mass.music.playlists.add_playlist_track = AsyncMock() + + mcp: FastMCP = FastMCP(name="t") + mcp.mount(build_playlists_server(mock_mass), namespace="playlists") + + track_uris = [f"lib://t/{i}" for i in range(15)] + async with Client(mcp) as client: + await client.call_tool( + "playlists_add_tracks", + {"playlist_id": 1, "track_uris": track_uris}, + ) + + assert mock_mass.music.playlists.add_playlist_track.await_count == 15 + mock_mass.music.playlists.add_playlist_tracks.assert_not_awaited() diff --git a/tests/providers/fastmcp_server/test_e2e_http.py b/tests/providers/fastmcp_server/test_e2e_http.py new file mode 100644 index 0000000000..5e886aa925 --- /dev/null +++ b/tests/providers/fastmcp_server/test_e2e_http.py @@ -0,0 +1,162 @@ +"""End-to-end tests through the real ASGI bridge loop (C11). + +These tests exercise ``mount_into_mass`` against an aiohttp ``TestServer`` +hosting a hand-rolled ASGI app. They cover the bits that pure-helper unit +tests can't reach: streaming chunk pass-through, DELETE / non-GET methods, +the well-known endpoint living next to the MCP mount. +""" +# mypy: disable-error-code="arg-type, no-untyped-def, type-arg, assignment, operator, misc" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +from aiohttp.test_utils import TestClient, TestServer + +from music_assistant.providers.fastmcp_server.http_bridge import mount_into_mass, mount_well_known + +from .conftest import FakeWebserver, build_aiohttp_app + + +async def _lifespan_loop(receive: Any, send: Any) -> None: + """Bare-minimum ASGI lifespan handler used by the test ASGI doubles.""" + while True: + msg = await receive() + if msg["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.complete"}) + elif msg["type"] == "lifespan.shutdown": + await send({"type": "lifespan.shutdown.complete"}) + return + + +async def _streaming_asgi(scope: dict, receive: Any, send: Any) -> None: + """ASGI app that emits three SSE-style chunks before closing the body.""" + if scope.get("type") == "lifespan": + await _lifespan_loop(receive, send) + return + # Drain the request body so the client write side can complete. + while True: + msg = await receive() + if msg.get("type") == "http.request" and not msg.get("more_body"): + break + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [(b"content-type", b"text/event-stream")], + } + ) + for i in range(3): + await send( + {"type": "http.response.body", "body": f"event{i}\n".encode(), "more_body": True} + ) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + +async def _method_echo_asgi(scope: dict, receive: Any, send: Any) -> None: + """ASGI app that echoes the HTTP method in the body.""" + if scope.get("type") == "lifespan": + await _lifespan_loop(receive, send) + return + while True: + msg = await receive() + if msg.get("type") == "http.request" and not msg.get("more_body"): + break + method = scope["method"].encode() + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.body", "body": method}) + + +class _Mcp: + def __init__(self, asgi: Any) -> None: + self._asgi = asgi + + def http_app(self, transport: str = "streamable-http", path: str = "/mcp") -> Any: + return self._asgi + + +@pytest.fixture +async def streaming_client() -> Any: + """Bridge an SSE-streaming ASGI app through mount_into_mass.""" + ws = FakeWebserver() + mass = SimpleNamespace(webserver=ws) + await mount_into_mass(mass, _Mcp(_streaming_asgi), mount_path="/mcp/v1") + async with TestClient(TestServer(build_aiohttp_app(ws))) as client: + yield client + + +@pytest.fixture +async def method_echo_client() -> Any: + """Bridge a method-echo ASGI app to verify DELETE / arbitrary verbs work.""" + ws = FakeWebserver() + mass = SimpleNamespace(webserver=ws) + await mount_into_mass(mass, _Mcp(_method_echo_asgi), mount_path="/mcp/v1") + async with TestClient(TestServer(build_aiohttp_app(ws))) as client: + yield client + + +async def test_streaming_chunks_passed_through(streaming_client: TestClient) -> None: + """Three ASGI body chunks reach the aiohttp client unbuffered.""" + resp = await streaming_client.post("/mcp/v1/", headers={"Origin": "http://localhost:8095"}) + assert resp.status == 200 + body = await resp.read() + assert body == b"event0\nevent1\nevent2\n" + + +async def test_delete_method_reaches_asgi(method_echo_client: TestClient) -> None: + """Streamable-HTTP DELETE (session terminate) is forwarded — bridge does not 405.""" + resp = await method_echo_client.delete("/mcp/v1/", headers={"Origin": "http://localhost:8095"}) + assert resp.status == 200 + assert (await resp.read()) == b"DELETE" + + +async def test_get_method_reaches_asgi(method_echo_client: TestClient) -> None: + """GET is forwarded — required so FastMCP can open server-initiated SSE.""" + resp = await method_echo_client.get("/mcp/v1/", headers={"Origin": "http://localhost:8095"}) + assert resp.status == 200 + assert (await resp.read()) == b"GET" + + +async def test_bare_mount_path_without_trailing_slash_reaches_asgi( + method_echo_client: TestClient, +) -> None: + """``/mcp/v1`` (no trailing slash) must hit the ASGI bridge too. + + This is the URL the wizard advertises and that MCP clients connect to. + MA's real ``_handle_catch_all`` (``helpers/webserver.py``) matches a + ``"/mcp/v1/*"`` registration against both the bare stem and any + descendant; ``build_aiohttp_app`` must mirror that. + """ + resp = await method_echo_client.post("/mcp/v1", headers={"Origin": "http://localhost:8095"}) + assert resp.status == 200 + assert (await resp.read()) == b"POST" + + +async def test_well_known_alongside_mcp_mount() -> None: + """Both /mcp/v1/* and /.well-known/oauth-protected-resource are reachable.""" + ws = FakeWebserver() + mass = SimpleNamespace(webserver=ws) + await mount_into_mass(mass, _Mcp(_method_echo_asgi), mount_path="/mcp/v1") + await mount_well_known( + mass, + mount_path="/mcp/v1", + resource_uri="http://localhost:8095/mcp/v1", + authorization_servers=["http://localhost:8095"], + scopes_supported=["query:library"], + resource_name="Music Assistant MCP", + ) + async with TestClient(TestServer(build_aiohttp_app(ws))) as client: + # MCP endpoint reachable + resp = await client.post("/mcp/v1/", headers={"Origin": "http://localhost:8095"}) + assert resp.status == 200 + # well-known sub-path returns RFC 9728 metadata + meta = await client.get("/.well-known/oauth-protected-resource/mcp/v1") + assert meta.status == 200 + doc = await meta.json() + assert doc["resource"] == "http://localhost:8095/mcp/v1" + assert doc["authorization_servers"] == ["http://localhost:8095"] + # Well-known root form also works (RFC 9728 §3.1 fallback). + meta_root = await client.get("/.well-known/oauth-protected-resource") + assert meta_root.status == 200 diff --git a/tests/providers/fastmcp_server/test_e2e_smoke.py b/tests/providers/fastmcp_server/test_e2e_smoke.py new file mode 100644 index 0000000000..caa21a6aea --- /dev/null +++ b/tests/providers/fastmcp_server/test_e2e_smoke.py @@ -0,0 +1,57 @@ +"""End-to-end smoke test: build the runtime in-memory and exercise it via FastMCP Client. + +This test is the only one that depends on the ``fastmcp`` package being +installed and on Music Assistant model imports working — the rest of the +suite uses mocks. Skipped automatically if either is unavailable. +""" +# mypy: disable-error-code="arg-type, no-untyped-def, type-arg, assignment, operator, misc" + +from __future__ import annotations + +import importlib.util +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from unittest.mock import MagicMock + + +_HAVE_FASTMCP = importlib.util.find_spec("fastmcp") is not None +_HAVE_MA = importlib.util.find_spec("music_assistant") is not None +_HAVE_MA_MODELS = importlib.util.find_spec("music_assistant_models") is not None + + +@pytest.mark.skipif( + not (_HAVE_FASTMCP and _HAVE_MA and _HAVE_MA_MODELS), + reason="needs fastmcp + music_assistant + music_assistant_models installed", +) +@pytest.mark.asyncio +async def test_runtime_lists_namespaced_tools(mock_mass: MagicMock, mock_config: MagicMock) -> None: + """``MCPServerRuntime`` builds without errors and exposes namespaced tools.""" + from fastmcp import Client # noqa: PLC0415 + + from music_assistant.providers.fastmcp_server.server import MCPServerRuntime # noqa: PLC0415 + + # ``register_dynamic_route`` must return a callable; the smoke test does not + # need real HTTP transport — Client(mcp) talks to the in-memory FastMCP root. + runtime = MCPServerRuntime(mock_mass, mock_config, _stub_logger()) + # Pretend the bridge mounted; we exercise the FastMCP root directly via in-memory Client. + await runtime.start() + try: + async with Client(runtime._mcp) as client: + tools = await client.list_tools() + names = {t.name for t in tools} + # 4 query tags enabled by default → tools from library + queue + players + metadata + assert any(name.startswith("library_") for name in names), names + assert any(name.startswith("queue_") for name in names), names + # Mutation-only namespaces should not appear under default config + assert not any(name.startswith("volume_") for name in names), names + finally: + await runtime.stop() + + +def _stub_logger() -> object: + import logging # noqa: PLC0415 + + return logging.getLogger("ma-provider-mcp.smoke") diff --git a/tests/providers/fastmcp_server/test_elicitation.py b/tests/providers/fastmcp_server/test_elicitation.py new file mode 100644 index 0000000000..e948505e7c --- /dev/null +++ b/tests/providers/fastmcp_server/test_elicitation.py @@ -0,0 +1,184 @@ +"""Tests for elicitation on destructive operations (C8).""" +# mypy: disable-error-code="arg-type, no-untyped-def, type-arg, assignment, operator, misc" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastmcp import Client, FastMCP + +from music_assistant.providers.fastmcp_server.tools import build_media_server, build_queue_server + + +def _server(mass: MagicMock, *, require_confirmation: bool) -> FastMCP: + """Build a small root server mounting only queue + media for elicitation tests.""" + mcp: FastMCP = FastMCP(name="t") + mcp.mount( + build_queue_server(mass, require_confirmation=require_confirmation), + namespace="queue", + ) + mcp.mount( + build_media_server(mass, require_confirmation=require_confirmation), + namespace="media", + ) + return mcp + + +def _accepter() -> object: + """Build an elicitation handler that always accepts with True.""" + + async def handler(message, response_type, params, context): # noqa: ARG001 + return True + + return handler + + +def _decliner() -> object: + """Build an elicitation handler that always declines.""" + from fastmcp.client.elicitation import ElicitResult # noqa: PLC0415 + + async def handler(message, response_type, params, context): # noqa: ARG001 + return ElicitResult(action="decline", content=None) + + return handler + + +async def test_clear_queue_runs_when_user_accepts(mock_mass: MagicMock) -> None: + """User accepts the elicitation prompt → clear_queue dispatches to MA.""" + mock_mass.player_queues.clear = MagicMock() + mcp = _server(mock_mass, require_confirmation=True) + + async with Client(mcp, elicitation_handler=_accepter()) as client: + await client.call_tool("queue_clear_queue", {"queue_id": "q1"}) + mock_mass.player_queues.clear.assert_called_once_with("q1") + + +async def test_clear_queue_blocked_when_user_declines(mock_mass: MagicMock) -> None: + """User declines → tool raises ToolError, no MA call is made.""" + mock_mass.player_queues.clear = MagicMock() + mcp = _server(mock_mass, require_confirmation=True) + + async with Client(mcp, elicitation_handler=_decliner()) as client: + with pytest.raises(Exception): # noqa: B017,PT011 + await client.call_tool("queue_clear_queue", {"queue_id": "q1"}) + mock_mass.player_queues.clear.assert_not_called() + + +async def test_no_confirmation_when_disabled(mock_mass: MagicMock) -> None: + """With require_confirmation=False, elicitation is skipped entirely.""" + mock_mass.player_queues.clear = MagicMock() + mcp = _server(mock_mass, require_confirmation=False) + + elicit_called = False + + async def handler(message, response_type, params, context): # noqa: ARG001 + nonlocal elicit_called + elicit_called = True + return True + + async with Client(mcp, elicitation_handler=handler) as client: + await client.call_tool("queue_clear_queue", {"queue_id": "q1"}) + assert elicit_called is False + mock_mass.player_queues.clear.assert_called_once_with("q1") + + +async def test_get_active_queue_clamps_include_items(mock_mass: MagicMock) -> None: + """A client-supplied ``include_items`` is clamped to 500 to bound memory. + + Without the clamp a hostile or sloppy caller could pass ``include_items=10**6`` + and force MA to materialise the entire queue per request. + """ + queue = MagicMock(queue_id="q1") + mock_mass.player_queues.get_active_queue = MagicMock(return_value=queue) + mock_mass.player_queues.items = MagicMock(return_value=[]) + mcp = _server(mock_mass, require_confirmation=False) + + async with Client(mcp) as client: + await client.call_tool( + "queue_get_active_queue", + {"player_id": "p1", "include_items": 10_000}, + ) + mock_mass.player_queues.items.assert_called_once_with("q1", limit=500) + + +async def test_get_active_queue_passes_small_limit_through(mock_mass: MagicMock) -> None: + """A reasonable ``include_items`` is forwarded verbatim — no over-cap.""" + queue = MagicMock(queue_id="q1") + mock_mass.player_queues.get_active_queue = MagicMock(return_value=queue) + mock_mass.player_queues.items = MagicMock(return_value=[]) + mcp = _server(mock_mass, require_confirmation=False) + + async with Client(mcp) as client: + await client.call_tool( + "queue_get_active_queue", + {"player_id": "p1", "include_items": 10}, + ) + mock_mass.player_queues.items.assert_called_once_with("q1", limit=10) + + +async def test_remove_from_library_confirms(mock_mass: MagicMock) -> None: + """media.remove_from_library also triggers elicitation.""" + # MA's MusicController takes (media_type, library_item_id), not a URI — + # the tool resolves the URI via get_item_by_uri first. + resolved = MagicMock(media_type=MagicMock(), item_id="42", provider="library") + mock_mass.music.get_item_by_uri = AsyncMock(return_value=resolved) + mock_mass.music.remove_item_from_library = AsyncMock() + mcp = _server(mock_mass, require_confirmation=True) + + async with Client(mcp, elicitation_handler=_accepter()) as client: + await client.call_tool("media_remove_from_library", {"uri": "lib://t/42"}) + mock_mass.music.get_item_by_uri.assert_awaited_once_with("lib://t/42") + mock_mass.music.remove_item_from_library.assert_awaited_once_with(resolved.media_type, "42") + + +async def test_remove_from_favorites_resolves_provider_uri_to_library( + mock_mass: MagicMock, +) -> None: + """A provider URI is resolved to the matching library item before removal. + + ``MusicController.remove_item_from_*`` expects a library item id; passing the + provider's native item id silently targets the wrong item (or raises on a + non-numeric ``int()`` cast). The tool now looks up the library counterpart via + ``get_library_item_by_prov_id``. + """ + provider_item = MagicMock(media_type=MagicMock(), item_id="prov-abc", provider="yandex_music") + library_item = MagicMock(media_type=provider_item.media_type, item_id="99") + mock_mass.music.get_item_by_uri = AsyncMock(return_value=provider_item) + mock_mass.music.get_library_item_by_prov_id = AsyncMock(return_value=library_item) + mock_mass.music.remove_item_from_favorites = AsyncMock() + mcp = _server(mock_mass, require_confirmation=False) + + async with Client(mcp) as client: + await client.call_tool( + "media_remove_from_favorites", {"uri": "yandex_music://track/prov-abc"} + ) + mock_mass.music.get_library_item_by_prov_id.assert_awaited_once_with( + provider_item.media_type, "prov-abc", "yandex_music" + ) + mock_mass.music.remove_item_from_favorites.assert_awaited_once_with( + library_item.media_type, "99" + ) + + +async def test_remove_from_library_raises_when_not_in_library( + mock_mass: MagicMock, +) -> None: + """When the URI's library counterpart cannot be resolved, the tool raises. + + Without this, the tool would silently call ``remove_item_from_library`` with + a provider-native item id, which either fails on ``int()`` cast or targets + the wrong item. + """ + provider_item = MagicMock(media_type=MagicMock(), item_id="prov-abc", provider="yandex_music") + mock_mass.music.get_item_by_uri = AsyncMock(return_value=provider_item) + mock_mass.music.get_library_item_by_prov_id = AsyncMock(return_value=None) + mock_mass.music.remove_item_from_library = AsyncMock() + mcp = _server(mock_mass, require_confirmation=False) + + async with Client(mcp) as client: + with pytest.raises(Exception): # noqa: B017,PT011 + await client.call_tool( + "media_remove_from_library", {"uri": "yandex_music://track/prov-abc"} + ) + mock_mass.music.remove_item_from_library.assert_not_awaited() diff --git a/tests/providers/fastmcp_server/test_middleware.py b/tests/providers/fastmcp_server/test_middleware.py new file mode 100644 index 0000000000..abc3dcea37 --- /dev/null +++ b/tests/providers/fastmcp_server/test_middleware.py @@ -0,0 +1,110 @@ +"""Tests for TagFilterMiddleware enforcement on direct invocation (C3).""" +# mypy: disable-error-code="arg-type, no-untyped-def, type-arg, assignment, operator, misc" + +from __future__ import annotations + +import pytest +from fastmcp import Client, FastMCP + +from music_assistant.providers.fastmcp_server.middleware import TagFilterMiddleware +from music_assistant.providers.fastmcp_server.server import build_tag_lookup + + +def _build_server(allowed: set[str]) -> FastMCP: + """Construct a FastMCP root with one tagged tool and the tag-filter middleware.""" + mcp: FastMCP = FastMCP(name="test-server") + + @mcp.tool(tags={"query"}) # type: ignore[untyped-decorator, unused-ignore] + async def reads() -> str: + """Return a read-only result.""" + return "ok" + + @mcp.tool(tags={"delete"}) # type: ignore[untyped-decorator, unused-ignore] + async def deletes() -> str: + """Pretend to perform a destructive action.""" + return "deleted" + + @mcp.tool # type: ignore[untyped-decorator, unused-ignore] + async def untagged() -> str: + """Return a value from an untagged tool — always exposed.""" + return "untagged" + + @mcp.resource("data://thing/{thing_id}", tags={"query"}) # type: ignore[untyped-decorator, unused-ignore] + async def thing(thing_id: str) -> str: + """Return a read-only resource value for the given id.""" + return f"thing:{thing_id}" + + @mcp.prompt(name="suggest", tags={"query"}) # type: ignore[untyped-decorator, unused-ignore] + def suggest() -> str: + """Return a sample prompt template.""" + return "Pick something." + + mcp.add_middleware(TagFilterMiddleware(lambda: allowed, build_tag_lookup(mcp))) + return mcp + + +async def test_listing_filters_disabled_tools() -> None: + """A tool whose tags are all disabled doesn't appear in tools/list.""" + mcp = _build_server(allowed={"query"}) + async with Client(mcp) as client: + names = {t.name for t in await client.list_tools()} + assert "reads" in names + assert "untagged" in names + assert "deletes" not in names + + +async def test_call_disabled_tool_blocked() -> None: + """A client cannot bypass the listing filter by calling the disabled tool by name.""" + mcp = _build_server(allowed={"query"}) + async with Client(mcp) as client: + with pytest.raises(Exception): # noqa: B017,PT011 - SDK wraps as ToolError or RPC error + await client.call_tool("deletes", {}) + + +async def test_call_enabled_tool_works() -> None: + """An enabled tool runs normally with the middleware in place.""" + mcp = _build_server(allowed={"query"}) + async with Client(mcp) as client: + result = await client.call_tool("reads", {}) + text_blocks = [c for c in result.content if hasattr(c, "text")] + assert any("ok" in c.text for c in text_blocks) + + +async def test_untagged_tool_always_callable() -> None: + """Tools without tags are infrastructure and remain callable regardless of permissions.""" + mcp = _build_server(allowed=set()) + async with Client(mcp) as client: + result = await client.call_tool("untagged", {}) + text_blocks = [c for c in result.content if hasattr(c, "text")] + assert any("untagged" in c.text for c in text_blocks) + + +async def test_disabled_resource_blocked_on_read() -> None: + """Reading a disabled resource by URI raises rather than silently succeeding.""" + mcp = _build_server(allowed=set()) + async with Client(mcp) as client: + with pytest.raises(Exception): # noqa: B017,PT011 + await client.read_resource("data://thing/42") + + +async def test_template_resource_read_via_concrete_uri() -> None: + """A concrete URI matched by a template resource is readable when its tag is enabled. + + The middleware lookup must fall back from ``get_resource`` (statically + registered URIs only) to ``get_resource_template`` (URI-template matching); + otherwise every ``@mcp.resource("scheme://{var}")``-backed URI gets blocked + as not-found even though the tag is enabled. + """ + mcp = _build_server(allowed={"query"}) + async with Client(mcp) as client: + contents = await client.read_resource("data://thing/42") + text_blocks = [c for c in contents if hasattr(c, "text")] + assert any("thing:42" in c.text for c in text_blocks) + + +async def test_disabled_prompt_blocked_on_get() -> None: + """Getting a disabled prompt by name raises.""" + mcp = _build_server(allowed=set()) + async with Client(mcp) as client: + with pytest.raises(Exception): # noqa: B017,PT011 + await client.get_prompt("suggest", {}) diff --git a/tests/providers/fastmcp_server/test_models.py b/tests/providers/fastmcp_server/test_models.py new file mode 100644 index 0000000000..15da40b223 --- /dev/null +++ b/tests/providers/fastmcp_server/test_models.py @@ -0,0 +1,262 @@ +"""Tests for response Brief dataclasses + ``_common`` adapters.""" + +from __future__ import annotations + +from types import SimpleNamespace + +from music_assistant.providers.fastmcp_server.models import ( + AlbumBrief, + ArtistBrief, + PlayerBrief, + PlaylistBrief, + QueueBrief, + RadioBrief, + TrackBrief, +) +from music_assistant.providers.fastmcp_server.tools._common import ( + page_args, + to_brief_album, + to_brief_artist, + to_brief_player, + to_brief_playlist, + to_brief_queue, + to_brief_radio, + to_brief_track, +) + + +def test_track_brief_defaults() -> None: + """TrackBrief fills sensible defaults.""" + t = TrackBrief(uri="library://track/1", name="X") + assert t.artists == [] + assert t.album is None + assert t.duration is None + + +def test_to_brief_track_extracts_artists_and_album() -> None: + """``to_brief_track`` reads names from artists/album attributes.""" + track = SimpleNamespace( + uri="library://track/42", + name="Sample", + artists=[SimpleNamespace(name="A1"), SimpleNamespace(name="A2")], + album=SimpleNamespace(name="Album"), + duration=180, + ) + brief = to_brief_track(track) + assert brief == TrackBrief( + uri="library://track/42", + name="Sample", + artists=["A1", "A2"], + album="Album", + duration=180, + ) + + +def test_to_brief_album_falls_back_to_artists_list() -> None: + """``to_brief_album`` uses ``artists[0]`` when there's no scalar artist.""" + album = SimpleNamespace( + uri="library://album/1", + name="Album", + artist=None, + artists=[SimpleNamespace(name="A1")], + year=2020, + ) + assert to_brief_album(album) == AlbumBrief( + uri="library://album/1", name="Album", artist="A1", year=2020 + ) + + +def test_to_brief_artist() -> None: + """``to_brief_artist`` extracts uri and name.""" + artist = SimpleNamespace(uri="library://artist/x", name="X") + assert to_brief_artist(artist) == ArtistBrief(uri="library://artist/x", name="X") + + +def test_to_brief_playlist() -> None: + """``to_brief_playlist`` includes track_count and owner when available.""" + playlist = SimpleNamespace( + uri="library://playlist/1", + name="Mix", + track_count=12, + owner=SimpleNamespace(name="me"), + ) + assert to_brief_playlist(playlist) == PlaylistBrief( + uri="library://playlist/1", name="Mix", track_count=12, owner="me" + ) + + +def test_to_brief_radio() -> None: + """``to_brief_radio`` maps name + description.""" + radio = SimpleNamespace(uri="library://radio/1", name="R", description="d") + assert to_brief_radio(radio) == RadioBrief(uri="library://radio/1", name="R", description="d") + + +def test_to_brief_player_reads_playback_state() -> None: + """``to_brief_player`` reads the canonical ``Player.playback_state`` enum.""" + player = SimpleNamespace( + player_id="kitchen", + name="Kitchen", + playback_state=SimpleNamespace(value="playing"), + volume_level=42, + powered=True, + current_media=None, + ) + brief = to_brief_player(player) + assert brief == PlayerBrief( + player_id="kitchen", name="Kitchen", state="playing", volume_level=42, powered=True + ) + + +def test_to_brief_player_falls_back_to_legacy_state_attr() -> None: + """When only the legacy ``state`` attr exists, ``to_brief_player`` still resolves it. + + Kept for back-compat with older shims / hand-built test stubs. + """ + player = SimpleNamespace( + player_id="kitchen", + name="Kitchen", + state=SimpleNamespace(value="paused"), + volume_level=10, + powered=True, + current_media=None, + ) + assert to_brief_player(player).state == "paused" + + +def test_to_brief_player_current_item_prefers_title() -> None: + """``current_item`` uses :class:`PlayerMedia.title` when available.""" + player = SimpleNamespace( + player_id="p1", + name="P1", + playback_state=SimpleNamespace(value="playing"), + volume_level=50, + powered=True, + current_media=SimpleNamespace(uri="spotify://track/x", title="Song Name"), + ) + assert to_brief_player(player).current_item == "Song Name" + + +def test_to_brief_player_current_item_falls_back_to_uri() -> None: + """No title → ``current_item`` falls back to URI (always present on PlayerMedia).""" + player = SimpleNamespace( + player_id="p1", + name="P1", + playback_state=SimpleNamespace(value="playing"), + volume_level=50, + powered=True, + current_media=SimpleNamespace(uri="spotify://track/x", title=None), + ) + assert to_brief_player(player).current_item == "spotify://track/x" + + +def test_to_brief_player_no_current_media() -> None: + """``current_item`` is ``None`` when the player is idle (no current media).""" + player = SimpleNamespace( + player_id="p1", + name="P1", + playback_state=SimpleNamespace(value="idle"), + volume_level=0, + powered=False, + current_media=None, + ) + assert to_brief_player(player).current_item is None + + +def test_to_brief_player_reads_powered_from_player_state() -> None: + """``powered`` is sourced from ``Player.state.powered``. + + MA core builds ``_state.powered`` from ``__final_power_state`` and + serialises it in the REST API; the raw ``Player.powered`` property + returns ``_attr_powered`` which lags behind (and stays ``False`` for + some virtual player types). The brief must match what + ``Player.state.to_dict()`` would emit. + """ + player = SimpleNamespace( + player_id="p1", + name="P1", + playback_state=SimpleNamespace(value="playing"), + volume_level=100, + powered=False, + current_media=None, + state=SimpleNamespace(powered=True, current_media=None), + ) + assert to_brief_player(player).powered is True + + +def test_to_brief_player_current_item_uses_state_current_media() -> None: + """``current_item`` is cleared when ``Player.state.current_media`` is None. + + After ``stop`` MA core clears ``_state.current_media``, but the raw + ``_attr_current_media`` may persist until the next playback. The brief + must reflect the canonical state so the LLM doesn't think a track is + still playing. + """ + stale = SimpleNamespace(uri="library://track/48", title="07") + player = SimpleNamespace( + player_id="p1", + name="P1", + playback_state=SimpleNamespace(value="idle"), + volume_level=0, + powered=True, + current_media=stale, + state=SimpleNamespace(powered=True, current_media=None), + ) + assert to_brief_player(player).current_item is None + + +def test_to_brief_queue_with_items() -> None: + """``to_brief_queue`` builds a ``QueueBrief`` with item summaries.""" + queue = SimpleNamespace( + queue_id="kitchen", + current_index=2, + items=10, + shuffle_enabled=True, + repeat_mode=SimpleNamespace(value="off"), + ) + items = [ + SimpleNamespace( + queue_item_id="i1", + name="One", + duration=120, + media_item=SimpleNamespace(artists=[SimpleNamespace(name="A1")]), + ), + SimpleNamespace(queue_item_id="i2", name="Two", duration=240, media_item=None), + ] + brief = to_brief_queue(queue, items=items) + assert isinstance(brief, QueueBrief) + assert brief.queue_id == "kitchen" + assert brief.shuffle is True + assert brief.repeat == "off" + assert len(brief.items) == 2 + assert brief.items[0].artists == ["A1"] + + +def test_to_brief_queue_uses_canonical_items_int_for_count() -> None: + """``items`` (int) on the canonical PlayerQueue is the **total** length. + + Earlier code mis-fell back to len(brief_items) (the truncated lookahead), + under-reporting real queue depth. ``items_count`` from the truncated + lookahead must not win over the explicit total. + """ + queue = SimpleNamespace( + queue_id="q", + current_index=0, + items=42, # canonical MA: total length, not a list + shuffle_enabled=False, + repeat_mode=None, + ) + # Pass only 5 items as the truncated lookahead. + truncated = [ + SimpleNamespace(queue_item_id=str(i), name=f"t{i}", duration=60, media_item=None) + for i in range(5) + ] + brief = to_brief_queue(queue, items=truncated) + assert brief.item_count == 42 # not 5 + assert len(brief.items) == 5 + + +def test_page_args_clamps() -> None: + """``page_args`` clamps negatives and oversized limits.""" + assert page_args(-5, 5000) == (0, 200) + assert page_args(0, 0) == (0, 1) + assert page_args(10, 25) == (10, 25) diff --git a/tests/providers/fastmcp_server/test_origin.py b/tests/providers/fastmcp_server/test_origin.py new file mode 100644 index 0000000000..6c8742965d --- /dev/null +++ b/tests/providers/fastmcp_server/test_origin.py @@ -0,0 +1,464 @@ +"""Tests for Origin allowlist computation, matching, and bridge enforcement (C1+C2).""" +# mypy: disable-error-code="arg-type, no-untyped-def, type-arg, assignment, operator, misc" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +from aiohttp.test_utils import TestClient, TestServer + +from music_assistant.providers.fastmcp_server.http_bridge import ( + _compute_origin_allowlist, + _is_origin_allowed, + _is_origin_allowed_for_request, + _normalize_origin, + build_protected_resource_metadata, + mount_into_mass, + mount_well_known, +) + +from .conftest import FakeWebserver, build_aiohttp_app + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + ("http://localhost:8095", "http://localhost:8095"), + ("HTTP://Localhost:8095", "http://localhost:8095"), + ("http://localhost:80", "http://localhost"), + ("https://example.com:443", "https://example.com"), + ("https://example.com/path", "https://example.com"), + ("null", "null"), + ("", None), + ("not-a-url", None), + ("http://", None), + # IPv6 literals: brackets must round-trip so the normalized form + # matches the allowlist entry the bridge synthesises. + ("http://[::1]", "http://[::1]"), + ("http://[::1]:8095", "http://[::1]:8095"), + ("HTTP://[::1]:8095", "http://[::1]:8095"), + ("http://[2001:db8::1]:80", "http://[2001:db8::1]"), + ], +) +def test_normalize_origin(raw: str, expected: str | None) -> None: + """Origin strings collapse to ``scheme://host[:port]`` lowercased, default ports stripped.""" + assert _normalize_origin(raw) == expected + + +def test_compute_allowlist_ipv6_publish_ip() -> None: + """An IPv6 publish_ip is bracketed in the allowlist so browsers' Origin matches.""" + mass = SimpleNamespace( + webserver=SimpleNamespace(base_url="http://localhost:8095", publish_ip="::1"), + ) + allow = _compute_origin_allowlist(mass) + assert "http://[::1]:8095" in allow + assert _is_origin_allowed("http://[::1]:8095", allow) is True + # And without an explicit port (still works because we bracket consistently). + assert "http://[::1]" in allow + + +def _fake_mass(base_url: str = "http://localhost:8095", publish_ip: str = "127.0.0.1"): + return SimpleNamespace(webserver=SimpleNamespace(base_url=base_url, publish_ip=publish_ip)) + + +def test_compute_allowlist_default() -> None: + """Default allowlist contains loopbacks, base_url host, and publish_ip.""" + allow = _compute_origin_allowlist(_fake_mass()) + # loopbacks always there + assert "http://localhost" in allow + assert "http://127.0.0.1" in allow + assert "http://[::1]" in allow + # base_url with port + assert "http://localhost:8095" in allow + # https-twin of base_url + assert "https://localhost:8095" in allow + # publish_ip with derived port + assert "http://127.0.0.1:8095" in allow + assert "https://127.0.0.1:8095" in allow + + +def test_compute_allowlist_with_extras() -> None: + """CSV ``extra_origins`` get normalized; bogus / empty entries silently dropped.""" + allow = _compute_origin_allowlist( + _fake_mass(), + extra_origins_csv="https://ha.example.com, http://reverse.lan:8443 ,bogus,", + ) + assert "https://ha.example.com" in allow + assert "http://reverse.lan:8443" in allow + # bogus + empty silently dropped + assert all(o != "" for o in allow) + + +def test_compute_allowlist_with_https_base_url() -> None: + """When base_url uses https on the default port, no port suffix is added.""" + allow = _compute_origin_allowlist(_fake_mass(base_url="https://mcp.example.com")) + assert "https://mcp.example.com" in allow + # publish_ip with no explicit port (https is scheme-default) + assert "http://127.0.0.1" in allow + assert "https://127.0.0.1" in allow + + +def test_compute_allowlist_handles_missing_attrs() -> None: + """If ``mass.webserver`` lacks base_url/publish_ip, only loopbacks remain.""" + mass = SimpleNamespace(webserver=SimpleNamespace()) + allow = _compute_origin_allowlist(mass) + assert "http://localhost" in allow + assert "http://127.0.0.1" in allow + + +@pytest.mark.parametrize( + ("origin", "allowed"), + [ + (None, True), # CLI / stdio-style — no Origin + ("http://localhost:8095", True), + ("http://LOCALHOST:8095", True), # case-insensitive + ("http://localhost:8095/", True), # trailing slash tolerated + ("http://evil.example", False), + ("https://localhost:8095", True), # https-twin allowed by default + ("null", False), # not in default allowlist + ], +) +def test_is_origin_allowed(origin: str | None, allowed: bool) -> None: + """Match Origin against the default allowlist (case-insensitive, trailing-slash tolerant).""" + allow = _compute_origin_allowlist(_fake_mass()) + assert _is_origin_allowed(origin, allow) is allowed + + +def test_is_origin_allowed_with_explicit_null() -> None: + """``Origin: null`` is accepted only when the operator opts in via ``extra_origins``.""" + allow = _compute_origin_allowlist(_fake_mass(), extra_origins_csv="null") + assert "null" in allow + assert _is_origin_allowed("null", allow) is True + + +def test_garbage_origin_rejected() -> None: + """Malformed Origin values fail closed with 403.""" + allow = _compute_origin_allowlist(_fake_mass()) + assert _is_origin_allowed("not-a-url", allow) is False + assert _is_origin_allowed("http://", allow) is False + + +# ── HA-ingress fallback in `_is_origin_allowed_for_request` ───────────────── + + +def _fake_request( + headers: dict[str, str] | None = None, + *, + scheme: str = "http", +) -> Any: + """Build a minimal stand-in for ``aiohttp.web.Request`` for origin checks.""" + return SimpleNamespace( + headers=headers or {}, + remote="172.30.32.1", + scheme=scheme, + ) + + +def _install_ingress_stub(monkeypatch: pytest.MonkeyPatch, *, is_ingress: bool) -> None: + """Inject ``is_request_from_ingress`` lazily into ``sys.modules``.""" + import sys # noqa: PLC0415 + import types # noqa: PLC0415 + + pkg = types.ModuleType("music_assistant") + pkg.__path__ = [] + controllers = types.ModuleType("music_assistant.controllers") + controllers.__path__ = [] + webserver_pkg = types.ModuleType("music_assistant.controllers.webserver") + webserver_pkg.__path__ = [] + helpers_pkg = types.ModuleType("music_assistant.controllers.webserver.helpers") + helpers_pkg.__path__ = [] + auth_mod = types.ModuleType("music_assistant.controllers.webserver.helpers.auth_middleware") + auth_mod.is_request_from_ingress = lambda _req: is_ingress # type: ignore[attr-defined] + + monkeypatch.setitem(sys.modules, "music_assistant", pkg) + monkeypatch.setitem(sys.modules, "music_assistant.controllers", controllers) + monkeypatch.setitem(sys.modules, "music_assistant.controllers.webserver", webserver_pkg) + monkeypatch.setitem(sys.modules, "music_assistant.controllers.webserver.helpers", helpers_pkg) + monkeypatch.setitem( + sys.modules, + "music_assistant.controllers.webserver.helpers.auth_middleware", + auth_mod, + ) + + +def test_request_origin_allowed_via_allowlist(monkeypatch: pytest.MonkeyPatch) -> None: + """Falls through to the legacy allowlist when the basic check accepts.""" + _install_ingress_stub(monkeypatch, is_ingress=False) + allow = _compute_origin_allowlist(_fake_mass()) + req = _fake_request({"Origin": "http://localhost:8095"}) + assert _is_origin_allowed_for_request(req, allow) is True + + +def test_request_origin_accepts_ingress_forwarded_host( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Ingress request whose Origin matches X-Forwarded-Host is accepted. + + Reproduces the HA add-on case where the user opens the Connect Wizard at + ``https:////…`` and the browser sends ``Origin: https://``, + which is never on the static allowlist. + """ + _install_ingress_stub(monkeypatch, is_ingress=True) + allow = _compute_origin_allowlist(_fake_mass()) + req = _fake_request( + { + "Origin": "https://ha.example.com", + "X-Forwarded-Host": "ha.example.com", + "X-Forwarded-Proto": "https", + } + ) + assert _is_origin_allowed_for_request(req, allow) is True + + +def test_request_origin_rejects_forwarded_host_without_ingress( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Without the trusted-ingress signal, ``X-Forwarded-Host`` is not trusted.""" + _install_ingress_stub(monkeypatch, is_ingress=False) + allow = _compute_origin_allowlist(_fake_mass()) + req = _fake_request( + { + "Origin": "https://attacker.example", + "X-Forwarded-Host": "attacker.example", + "X-Forwarded-Proto": "https", + } + ) + assert _is_origin_allowed_for_request(req, allow) is False + + +def test_request_origin_rejects_mismatched_forwarded_host( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Origin must equal the forwarded host — a mismatch is rejected even via ingress.""" + _install_ingress_stub(monkeypatch, is_ingress=True) + allow = _compute_origin_allowlist(_fake_mass()) + req = _fake_request( + { + "Origin": "https://attacker.example", + "X-Forwarded-Host": "ha.example.com", + "X-Forwarded-Proto": "https", + } + ) + assert _is_origin_allowed_for_request(req, allow) is False + + +def test_request_origin_no_forward_header_rejected( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Ingress without ``X-Forwarded-Host`` falls through to the strict allowlist.""" + _install_ingress_stub(monkeypatch, is_ingress=True) + allow = _compute_origin_allowlist(_fake_mass()) + req = _fake_request({"Origin": "https://ha.example.com"}) + assert _is_origin_allowed_for_request(req, allow) is False + + +def test_request_origin_missing_proto_falls_back_to_transport_scheme( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When ``X-Forwarded-Proto`` is absent, the aiohttp ``scheme`` fills in. + + Inside an HA add-on the transport scheme is plain ``http`` because the + container is reached over the docker network. A proxy that forwards the + host header but omits the proto header should still validate against the + actual transport's scheme rather than guess ``https``. + """ + _install_ingress_stub(monkeypatch, is_ingress=True) + allow = _compute_origin_allowlist(_fake_mass()) + req = _fake_request( + { + "Origin": "http://ha.local:8123", + "X-Forwarded-Host": "ha.local:8123", + }, + scheme="http", + ) + assert _is_origin_allowed_for_request(req, allow) is True + + +def test_request_origin_accepts_case_insensitive_origin( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """``Origin`` matching is case-insensitive on host (RFC 3986).""" + _install_ingress_stub(monkeypatch, is_ingress=True) + allow = _compute_origin_allowlist(_fake_mass()) + req = _fake_request( + { + "Origin": "HTTPS://Ha.Example.COM", + "X-Forwarded-Host": "ha.example.com", + "X-Forwarded-Proto": "https", + } + ) + assert _is_origin_allowed_for_request(req, allow) is True + + +def test_request_origin_handles_missing_ma_module() -> None: + """When ``music_assistant`` is unavailable, never auto-accept (fail closed).""" + allow = _compute_origin_allowlist(_fake_mass()) + req = _fake_request( + { + "Origin": "https://ha.example.com", + "X-Forwarded-Host": "ha.example.com", + } + ) + # No stub installed; the import inside the helper raises, caught → reject. + assert _is_origin_allowed_for_request(req, allow) is False + + +# ── End-to-end bridge enforcement (C2) ────────────────────────────────────── + + +class _FakeMcp: + """Stand-in for FastMCP exposing an ASGI app via ``http_app(...)``.""" + + def __init__(self, asgi_app: Any) -> None: + self._app = asgi_app + + def http_app(self, transport: str = "streamable-http", path: str = "/mcp") -> Any: + return self._app + + +async def _echo_asgi(scope: dict, receive: Any, send: Any) -> None: + """Minimal ASGI app: handles lifespan events + returns 200 'OK' on http.""" + if scope.get("type") == "lifespan": + while True: + msg = await receive() + if msg["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.complete"}) + elif msg["type"] == "lifespan.shutdown": + await send({"type": "lifespan.shutdown.complete"}) + return + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.body", "body": b"OK"}) + + +@pytest.fixture +async def bridge_client() -> Any: + """Build the bridge handler against a fake MA + fake ASGI, expose via TestClient.""" + fake_ws = FakeWebserver() + mass = SimpleNamespace(webserver=fake_ws) + mcp = _FakeMcp(_echo_asgi) + await mount_into_mass(mass, mcp, mount_path="/mcp/v1", extra_origins_csv="") + + async with TestClient(TestServer(build_aiohttp_app(fake_ws))) as client: + yield client + + +async def test_bridge_rejects_evil_origin(bridge_client: TestClient) -> None: + """A non-allow-listed Origin is blocked with 403, ASGI never invoked.""" + resp = await bridge_client.post("/mcp/v1/", headers={"Origin": "http://evil.example"}) + assert resp.status == 403 + + +async def test_bridge_allows_localhost(bridge_client: TestClient) -> None: + """Origin that matches base_url is forwarded to ASGI (200 from echo app).""" + resp = await bridge_client.post("/mcp/v1/", headers={"Origin": "http://localhost:8095"}) + assert resp.status == 200 + assert (await resp.read()) == b"OK" + + +async def test_bridge_allows_no_origin(bridge_client: TestClient) -> None: + """Requests without ``Origin`` header (curl/CLI) pass through unchanged.""" + resp = await bridge_client.post("/mcp/v1/") + assert resp.status == 200 + + +def test_build_protected_resource_metadata_minimal() -> None: + """RFC 9728 metadata always carries resource + authorization_servers + bearer methods.""" + meta = build_protected_resource_metadata( + resource_uri="http://localhost:8095/mcp/v1", + authorization_servers=["http://localhost:8095"], + ) + assert meta == { + "resource": "http://localhost:8095/mcp/v1", + "authorization_servers": ["http://localhost:8095"], + "bearer_methods_supported": ["header"], + } + + +def test_build_protected_resource_metadata_full() -> None: + """Optional fields scopes_supported / resource_name appear when provided.""" + meta = build_protected_resource_metadata( + resource_uri="http://localhost:8095/mcp/v1", + authorization_servers=["http://localhost:8095"], + scopes_supported=["query:library", "control:playback"], + resource_name="Music Assistant MCP", + ) + assert meta["scopes_supported"] == ["query:library", "control:playback"] + assert meta["resource_name"] == "Music Assistant MCP" + + +async def test_mount_well_known_with_dynamic_scopes_refreshes() -> None: + """When scopes_supported is a callable, the body refreshes on each request. + + Permission hot-swap mutates the closed-over set in MCPServerRuntime; + the well-known endpoint must reflect the change without a runtime rebuild. + """ + fake_ws = FakeWebserver() + mass = SimpleNamespace(webserver=fake_ws) + scopes: list[str] = ["query:library"] + await mount_well_known( + mass, + mount_path="/mcp/v1", + resource_uri="http://localhost:8095/mcp/v1", + authorization_servers=["http://localhost:8095"], + scopes_supported=lambda: list(scopes), + resource_name="MA MCP", + ) + + async with TestClient(TestServer(build_aiohttp_app(fake_ws))) as client: + before = await (await client.get("/.well-known/oauth-protected-resource")).json() + assert before["scopes_supported"] == ["query:library"] + # Mutate the underlying set (simulates hot-swap of permission flags). + scopes.append("control:playback") + after = await (await client.get("/.well-known/oauth-protected-resource")).json() + assert "control:playback" in after["scopes_supported"] + + +async def test_mount_well_known_serves_metadata() -> None: + """The well-known route returns the RFC 9728 JSON document for both URI forms.""" + fake_ws = FakeWebserver() + mass = SimpleNamespace(webserver=fake_ws) + unmount = await mount_well_known( + mass, + mount_path="/mcp/v1", + resource_uri="http://localhost:8095/mcp/v1", + authorization_servers=["http://localhost:8095"], + scopes_supported=["query:library"], + resource_name="Music Assistant MCP", + ) + + paths = [r[0] for r in fake_ws.routes] + assert "/.well-known/oauth-protected-resource/mcp/v1" in paths + assert "/.well-known/oauth-protected-resource" in paths + + async with TestClient(TestServer(build_aiohttp_app(fake_ws))) as client: + for path in paths: + resp = await client.get(path) + assert resp.status == 200 + assert resp.headers["content-type"].startswith("application/json") + doc = await resp.json() + assert doc["resource"] == "http://localhost:8095/mcp/v1" + assert doc["authorization_servers"] == ["http://localhost:8095"] + assert doc["bearer_methods_supported"] == ["header"] + assert doc["scopes_supported"] == ["query:library"] + assert doc["resource_name"] == "Music Assistant MCP" + + unmount() + + +async def test_bridge_with_extra_origins() -> None: + """``extra_origins_csv`` widens the allowlist for reverse-proxy / HA ingress.""" + fake_ws = FakeWebserver() + mass = SimpleNamespace(webserver=fake_ws) + mcp = _FakeMcp(_echo_asgi) + await mount_into_mass( + mass, mcp, mount_path="/mcp/v1", extra_origins_csv="https://ha.example.com" + ) + + async with TestClient(TestServer(build_aiohttp_app(fake_ws))) as client: + resp = await client.post("/mcp/v1/", headers={"Origin": "https://ha.example.com"}) + assert resp.status == 200 + # An origin not in the extras stays rejected. + resp = await client.post("/mcp/v1/", headers={"Origin": "http://evil.example"}) + assert resp.status == 403 diff --git a/tests/providers/fastmcp_server/test_resources.py b/tests/providers/fastmcp_server/test_resources.py new file mode 100644 index 0000000000..3881896ead --- /dev/null +++ b/tests/providers/fastmcp_server/test_resources.py @@ -0,0 +1,85 @@ +"""Tests for ``provider/resources/*`` handler return-value serialisation. + +FastMCP's resource read API requires handlers to return ``str | bytes | +list[ResourceContents]``; returning an MA domain object or a provider Brief +dataclass directly raises ``contents must be str, bytes, or list``. These +tests pin the handlers down to JSON-text returns end-to-end via the +in-memory FastMCP Client transport. +""" +# mypy: disable-error-code="arg-type, no-untyped-def, type-arg, assignment, operator, misc, attr-defined" + +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +from fastmcp import Client, FastMCP + +from music_assistant.providers.fastmcp_server.resources.library_resources import ( + register_library_resources, +) +from music_assistant.providers.fastmcp_server.resources.player_resources import ( + register_player_resources, +) + + +async def test_library_artist_resource_returns_json_text(mock_mass: MagicMock) -> None: + """An existing artist is serialised to JSON text in the response contents.""" + artist = SimpleNamespace( + uri="library://artist/17", + name="7Б", + to_dict=lambda: {"uri": "library://artist/17", "name": "7Б"}, + ) + mock_mass.music.artists.get_library_item.return_value = artist + + mcp: FastMCP = FastMCP(name="t") + register_library_resources(mcp, mock_mass) + async with Client(mcp) as client: + contents = await client.read_resource("library://artist/17") + + text_blocks = [c.text for c in contents if hasattr(c, "text")] + assert text_blocks, "no text content returned" + parsed = json.loads(text_blocks[0]) + assert parsed["name"] == "7Б" + assert parsed["uri"] == "library://artist/17" + + +async def test_library_artist_resource_returns_null_for_missing(mock_mass: MagicMock) -> None: + """A missing library item resolves to ``None`` handler-side, rendered as ``"null"``.""" + mock_mass.music.artists.get_library_item.return_value = None + + mcp: FastMCP = FastMCP(name="t") + register_library_resources(mcp, mock_mass) + async with Client(mcp) as client: + contents = await client.read_resource("library://artist/999") + + text_blocks = [c.text for c in contents if hasattr(c, "text")] + assert text_blocks == ["null"] + + +async def test_player_resource_returns_json_text_for_brief(mock_mass: MagicMock) -> None: + """A ``PlayerBrief`` returned by the player handler is JSON-serialised.""" + player = SimpleNamespace( + player_id="p1", + display_name="P1", + name="P1", + playback_state=SimpleNamespace(value="playing"), + volume_level=100, + powered=True, + current_media=None, + state=SimpleNamespace(powered=True, current_media=None), + ) + mock_mass.players.get_player.return_value = player + + mcp: FastMCP = FastMCP(name="t") + register_player_resources(mcp, mock_mass) + async with Client(mcp) as client: + contents = await client.read_resource("player://p1") + + text_blocks = [c.text for c in contents if hasattr(c, "text")] + assert text_blocks, "no text content returned" + parsed = json.loads(text_blocks[0]) + assert parsed["player_id"] == "p1" + assert parsed["state"] == "playing" + assert parsed["powered"] is True diff --git a/tests/providers/fastmcp_server/test_tags.py b/tests/providers/fastmcp_server/test_tags.py new file mode 100644 index 0000000000..73f3b163eb --- /dev/null +++ b/tests/providers/fastmcp_server/test_tags.py @@ -0,0 +1,72 @@ +"""Tests for the tag enum and config-to-tag mapping.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from music_assistant.providers.fastmcp_server.constants import ( + CONF_CONTROL_PLAYBACK, + CONF_DELETE_FAVORITES, + CONF_QUERY_LIBRARY, + PERMISSION_KEYS, +) +from music_assistant.providers.fastmcp_server.tags import CONFIG_TO_TAG, Tag, enabled_tags + +if TYPE_CHECKING: + from unittest.mock import MagicMock + + +def test_config_to_tag_is_total() -> None: + """Every permission key has a unique tag, and counts match.""" + assert set(CONFIG_TO_TAG) == set(PERMISSION_KEYS) + assert len(set(CONFIG_TO_TAG.values())) == len(PERMISSION_KEYS) == 16 + + +def test_tag_enum_values_are_namespaced() -> None: + """Tag values look like ``:``.""" + for tag in Tag: + assert ":" in tag.value + verb, _, _ = tag.value.partition(":") + assert verb in {"query", "control", "edit", "delete"} + + +def test_enabled_tags_defaults(mock_config: MagicMock) -> None: + """With default config (4 reads on, all mutations off), only 4 query tags surface.""" + tags = enabled_tags(mock_config) + assert tags == { + Tag.QUERY_LIBRARY, + Tag.QUERY_QUEUE, + Tag.QUERY_PLAYERS, + Tag.QUERY_METADATA, + } + + +def test_enabled_tags_toggle(mock_config: MagicMock) -> None: + """Flipping a single bool flips exactly one tag.""" + base = enabled_tags(mock_config) + mock_config._values[CONF_CONTROL_PLAYBACK] = True + after = enabled_tags(mock_config) + assert after - base == {Tag.CONTROL_PLAYBACK} + assert base - after == set() + + +def test_enabled_tags_all_off(mock_config: MagicMock) -> None: + """When every permission is off, the result is empty.""" + for key in PERMISSION_KEYS: + mock_config._values[key] = False + assert enabled_tags(mock_config) == set() + + +def test_enabled_tags_query_library_off_drops_only_one(mock_config: MagicMock) -> None: + """Disabling one specific permission only drops that tag.""" + mock_config._values[CONF_QUERY_LIBRARY] = False + tags = enabled_tags(mock_config) + assert Tag.QUERY_LIBRARY not in tags + assert Tag.QUERY_QUEUE in tags + assert Tag.QUERY_PLAYERS in tags + assert Tag.QUERY_METADATA in tags + + +def test_delete_tags_namespaced() -> None: + """Sanity: delete-family tags use the ``delete:`` prefix.""" + assert CONFIG_TO_TAG[CONF_DELETE_FAVORITES].value.startswith("delete:") diff --git a/tests/providers/fastmcp_server/test_uri.py b/tests/providers/fastmcp_server/test_uri.py new file mode 100644 index 0000000000..8ce688f782 --- /dev/null +++ b/tests/providers/fastmcp_server/test_uri.py @@ -0,0 +1,55 @@ +"""Tests for ``provider.resources._uri.parse_resource_uri``.""" + +from __future__ import annotations + +import pytest + +from music_assistant.providers.fastmcp_server.resources._uri import ResourceURI, parse_resource_uri + + +@pytest.mark.parametrize( + ("uri", "expected"), + [ + ("library://artist/123", ResourceURI("library", "artist", "123")), + ("library://album/abc-DEF_42", ResourceURI("library", "album", "abc-DEF_42")), + ("library://track/track%3A1", ResourceURI("library", "track", "track%3A1")), + ("library://playlist/p1", ResourceURI("library", "playlist", "p1")), + ("library://radio/r:1", ResourceURI("library", "radio", "r:1")), + ("library://podcast/p1", ResourceURI("library", "podcast", "p1")), + ("library://audiobook/a1", ResourceURI("library", "audiobook", "a1")), + ("player://kitchen", ResourceURI("player", None, "kitchen")), + ("player://livingroom_sonos", ResourceURI("player", None, "livingroom_sonos")), + ("queue://q1", ResourceURI("queue", None, "q1")), + ], +) +def test_parse_valid(uri: str, expected: ResourceURI) -> None: + """Valid URIs round-trip into ResourceURI.""" + assert parse_resource_uri(uri) == expected + + +@pytest.mark.parametrize( + "uri", + [ + "", + "noscheme", + "ftp://library/x", + "library://", + "library:///", + "library://artist/", + "library://artist/../etc/passwd", + "library://artist/has spaces", + "library://unknownkind/x", + "player://", + "queue://", + ], +) +def test_parse_invalid(uri: str) -> None: + """Invalid URIs raise ``ValueError``.""" + with pytest.raises(ValueError, match=r".+"): + parse_resource_uri(uri) + + +def test_path_traversal_rejected_in_id() -> None: + """Two-dot sequences anywhere in the body are rejected as a defence-in-depth measure.""" + with pytest.raises(ValueError, match=r".+"): + parse_resource_uri("library://artist/..foo")