Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions superset/mcp_service/dataset/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,29 @@ class TableColumnInfo(BaseModel):
filterable: bool | None = Field(None, description="Is filterable")
description: str | None = Field(None, description="Column description")

@model_serializer(mode="wrap")
def _filter_column_fields_by_context(
self, serializer: Any, info: Any
) -> Dict[str, Any]:
"""Filter column fields based on serialization context.

If context contains 'column_fields', only include those fields.
Otherwise, include all fields. This trims wide datasets so a
50-column dataset doesn't ship 50 long descriptions when the
caller only needs column_name + type.
"""
data = serializer(self)

if info.context and isinstance(info.context, dict):
column_fields = info.context.get("column_fields")
if column_fields:
requested = set(column_fields)
# Always preserve column_name as the only required field
requested.add("column_name")
return {k: v for k, v in data.items() if k in requested}
Comment thread
aminghadersohi marked this conversation as resolved.

return data
Comment thread
aminghadersohi marked this conversation as resolved.


class SqlMetricInfo(BaseModel):
metric_name: str = Field(
Expand Down Expand Up @@ -311,13 +334,89 @@ def create(cls, error: str, error_type: str) -> "DatasetError":
)


# Default top-level dataset fields returned by get_dataset_info. Per the
# request-schema description below, this lean set deliberately excludes
# verbose fields (params, template_params, extra, tags, certification_details)
# to keep responses small; callers pass an explicit select_columns list to
# override.
DEFAULT_GET_DATASET_INFO_COLUMNS: List[str] = [
    "id",
    "table_name",
    "schema",
    "database_name",
    "database_id",
    "uuid",
    "is_virtual",
    "description",
    "main_dttm_col",
    "sql",
    "url",
    "columns",
    "metrics",
]

# Default per-column fields included for entries in 'columns'. Kept minimal
# (identity + type + datetime flag); callers opt in to wider fields such as
# verbose_name, groupby, filterable, description via column_fields.
DEFAULT_GET_DATASET_INFO_COLUMN_FIELDS: List[str] = [
    "column_name",
    "type",
    "is_dttm",
]


class GetDatasetInfoRequest(MetadataCacheControl):
    """Request schema for get_dataset_info with support for ID or UUID.

    Both ``select_columns`` and ``column_fields`` default to lean field sets
    so responses stay small; the "before" validators below also map None and
    empty lists back to those defaults so callers cannot accidentally opt out
    of size reduction.
    """

    # Numeric ID or UUID string identifying the dataset.
    identifier: Annotated[
        int | str,
        Field(description="Dataset identifier - can be numeric ID or UUID string"),
    ]
    select_columns: Annotated[
        List[str],
        Field(
            default_factory=lambda: list(DEFAULT_GET_DATASET_INFO_COLUMNS),
            description=(
                "Top-level fields to include in the response. Defaults to a lean "
                "set that excludes verbose fields like params, template_params, "
                "extra, tags, certification_details. Pass an explicit list to "
                "override (e.g. ['id','table_name','columns'] for minimal output)."
            ),
        ),
    ]
    column_fields: Annotated[
        List[str],
        Field(
            default_factory=lambda: list(DEFAULT_GET_DATASET_INFO_COLUMN_FIELDS),
            description=(
                "Per-column fields to include for entries in 'columns'. Defaults "
                "to ['column_name','type','is_dttm']. Pass a wider list to "
                "include 'verbose_name','groupby','filterable','description' "
                "when needed. Trimming per-column fields keeps responses small "
                "for wide datasets."
            ),
        ),
    ]

    @field_validator("select_columns", mode="before")
    @classmethod
    def _parse_select_columns(cls, value: Any) -> Any:
        """Normalize select_columns input, falling back to the lean defaults.

        Accepts None, a JSON-encoded string, or a list (whatever
        parse_json_or_list supports); None and empty results both yield a
        fresh copy of DEFAULT_GET_DATASET_INFO_COLUMNS.
        """
        # Local import — presumably to avoid a circular import; confirm
        # against superset.mcp_service.utils.schema_utils.
        from superset.mcp_service.utils.schema_utils import parse_json_or_list

        if value is None:
            return list(DEFAULT_GET_DATASET_INFO_COLUMNS)
        parsed = parse_json_or_list(value, "select_columns")
        # Treat empty list as "use defaults" so callers cannot accidentally
        # opt out of size reduction by passing []. Without this, an empty
        # list disables filtering downstream and reintroduces oversized
        # responses.
        if not parsed:
            return list(DEFAULT_GET_DATASET_INFO_COLUMNS)
        return parsed

    @field_validator("column_fields", mode="before")
    @classmethod
    def _parse_column_fields(cls, value: Any) -> Any:
        """Normalize column_fields input, falling back to the lean defaults.

        Mirrors _parse_select_columns: None, unparsable-to-empty, and empty
        list all resolve to a fresh copy of
        DEFAULT_GET_DATASET_INFO_COLUMN_FIELDS.
        """
        # Local import — presumably to avoid a circular import; confirm
        # against superset.mcp_service.utils.schema_utils.
        from superset.mcp_service.utils.schema_utils import parse_json_or_list

        if value is None:
            return list(DEFAULT_GET_DATASET_INFO_COLUMN_FIELDS)
        parsed = parse_json_or_list(value, "column_fields")
        if not parsed:
            return list(DEFAULT_GET_DATASET_INFO_COLUMN_FIELDS)
        return parsed


class CreateVirtualDatasetRequest(BaseModel):
Expand Down
29 changes: 24 additions & 5 deletions superset/mcp_service/dataset/tool/get_dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import logging
from datetime import datetime, timezone
from typing import Any

from fastmcp import Context
from sqlalchemy.orm import joinedload, subqueryload
Expand Down Expand Up @@ -58,7 +59,7 @@
@requires_data_model_metadata_access
async def get_dataset_info(
request: GetDatasetInfoRequest, ctx: Context
) -> DatasetInfo | DatasetError:
) -> dict[str, Any] | DatasetError:
"""Get dataset metadata by ID or UUID.

Returns columns, metrics, and schema details.
Expand All @@ -68,6 +69,12 @@ async def get_dataset_info(
- DO NOT use schema.table_name format (e.g., "public.customers")
- To find a dataset ID, use the list_datasets tool first

Response size control (use these to keep responses small):
- select_columns: top-level fields to include (default: lean set)
- column_fields: per-column fields for entries in 'columns' (default:
column_name, type, is_dttm). Pass a wider list to opt in to
verbose_name, groupby, filterable, description.

IMPORTANT - Saved Metrics vs Columns:
The response includes both 'columns' (raw database columns) and 'metrics'
(pre-defined saved metrics). When building chart configs, use saved_metric=true
Expand Down Expand Up @@ -144,12 +151,24 @@ async def get_dataset_info(
len(result.metrics) if result.metrics else 0,
)
)
else:
await ctx.warning(
"Dataset retrieval failed: error_type=%s, error=%s"
% (result.error_type, result.error)
await ctx.debug(
"Filtering response: select_columns=%s, column_fields=%s"
% (request.select_columns, request.column_fields)
)
with event_logger.log_context(action="mcp.get_dataset_info.serialization"):
return result.model_dump(
mode="json",
by_alias=True,
context={
"select_columns": request.select_columns,
"column_fields": request.column_fields,
},
)

await ctx.warning(
"Dataset retrieval failed: error_type=%s, error=%s"
% (result.error_type, result.error)
)
return result

except Exception as e:
Expand Down
Loading
Loading