Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions superset/mcp_service/chart/chart_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ChartCapabilities,
ChartSemantics,
ColumnRef,
CurrencyFormat,
FilterConfig,
HandlebarsChartConfig,
MixedTimeseriesChartConfig,
Expand Down Expand Up @@ -469,6 +470,7 @@ def map_table_config(config: TableChartConfig) -> Dict[str, Any]:
form_data["order_by_cols"] = config.sort_by

form_data["row_limit"] = config.row_limit
add_color_scheme(form_data, config.color_scheme)

return form_data

Expand Down Expand Up @@ -539,7 +541,35 @@ def add_legend_config(form_data: Dict[str, Any], config: XYChartConfig) -> None:
if not config.legend.show:
form_data["show_legend"] = False
if config.legend.position:
form_data["legend_orientation"] = config.legend.position
# Canonical form_data key is camelCase; the echarts plugins read
# `legendOrientation` directly off form_data.
form_data["legendOrientation"] = config.legend.position


def add_color_scheme(form_data: Dict[str, Any], color_scheme: str | None) -> None:
"""Add color scheme to form_data when set."""
if color_scheme:
form_data["color_scheme"] = color_scheme


def add_currency_format(
form_data: Dict[str, Any],
currency_format: CurrencyFormat | None,
key: str = "currency_format",
) -> None:
"""Add currency format to form_data under the given key when set."""
if currency_format:
form_data[key] = currency_format.to_form_data()


def add_xy_data_label_options(
form_data: Dict[str, Any], config: XYChartConfig, x_is_temporal: bool
) -> None:
"""Apply XY-specific data-label and time-format options when set."""
if config.x_axis_time_format and x_is_temporal:
form_data["x_axis_time_format"] = config.x_axis_time_format
if config.show_value:
form_data["show_value"] = True


def add_orientation_config(form_data: Dict[str, Any], config: XYChartConfig) -> None:
Expand Down Expand Up @@ -714,6 +744,9 @@ def map_xy_config(
add_axis_config(form_data, config)
add_legend_config(form_data, config)
add_orientation_config(form_data, config)
add_color_scheme(form_data, config.color_scheme)
add_currency_format(form_data, config.currency_format)
add_xy_data_label_options(form_data, config, x_is_temporal)

return form_data

Expand All @@ -726,21 +759,23 @@ def map_pie_config(config: PieChartConfig) -> Dict[str, Any]:
"viz_type": "pie",
"groupby": [config.dimension.name],
"metric": metric,
"color_scheme": "supersetColors",
"color_scheme": config.color_scheme or "supersetColors",
"show_labels": config.show_labels,
"show_legend": config.show_legend,
"legendOrientation": config.legend_orientation,
"label_type": config.label_type,
"number_format": config.number_format,
"date_format": config.date_format,
"sort_by_metric": config.sort_by_metric,
"row_limit": config.row_limit,
"donut": config.donut,
"show_total": config.show_total,
"labels_outside": config.labels_outside,
"outerRadius": config.outer_radius,
"innerRadius": config.inner_radius,
"date_format": "smart_date",
}

add_currency_format(form_data, config.currency_format)
_add_adhoc_filters(form_data, config.filters)

return form_data
Expand All @@ -766,6 +801,9 @@ def map_big_number_config(config: BigNumberChartConfig) -> Dict[str, Any]:
if config.y_axis_format:
form_data["y_axis_format"] = config.y_axis_format

add_color_scheme(form_data, config.color_scheme)
add_currency_format(form_data, config.currency_format)

# Trendline-specific fields
if viz_type == "big_number":
# Big Number with trendline uses granularity_sqla for the temporal column
Expand All @@ -781,6 +819,9 @@ def map_big_number_config(config: BigNumberChartConfig) -> Dict[str, Any]:
if config.compare_lag is not None:
form_data["compare_lag"] = config.compare_lag

if config.time_format:
form_data["time_format"] = config.time_format

_add_adhoc_filters(form_data, config.filters)

return form_data
Expand Down Expand Up @@ -852,6 +893,10 @@ def map_pivot_table_config(config: PivotTableChartConfig) -> Dict[str, Any]:
"row_limit": config.row_limit,
}

if config.date_format:
form_data["date_format"] = config.date_format

add_currency_format(form_data, config.currency_format)
_add_adhoc_filters(form_data, config.filters)

return form_data
Expand Down Expand Up @@ -931,10 +976,20 @@ def map_mixed_timeseries_config(
"yAxisIndexB": 1,
# Display
"show_legend": config.show_legend,
"legendOrientation": config.legend_orientation,
"zoomable": True,
"rich_tooltip": True,
}

if config.show_value:
form_data["show_value"] = True

add_color_scheme(form_data, config.color_scheme)
add_currency_format(form_data, config.currency_format)
add_currency_format(
form_data, config.currency_format_secondary, key="currency_format_secondary"
)

# Configure temporal handling
configure_temporal_handling(form_data, x_is_temporal, config.time_grain)

Expand Down
124 changes: 124 additions & 0 deletions superset/mcp_service/chart/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,29 @@ class LegendConfig(BaseModel):
position: Literal["top", "bottom", "left", "right"] | None = "right"


class CurrencyFormat(BaseModel):
"""Currency symbol and placement applied to numeric values."""

model_config = ConfigDict(populate_by_name=True)

symbol: str = Field(
...,
description="Currency code or symbol (e.g. 'USD', 'EUR', '$', '€')",
max_length=20,
)
symbol_position: Literal["prefix", "suffix"] = Field(
"prefix",
description="Whether to render the symbol before or after the value",
validation_alias=AliasChoices("symbol_position", "symbolPosition"),
)

def to_form_data(self) -> Dict[str, str]:
return {"symbol": self.symbol, "symbolPosition": self.symbol_position}


LEGEND_POSITION_LITERAL = Literal["top", "bottom", "left", "right"]


class FilterConfig(BaseModel):
model_config = ConfigDict(populate_by_name=True)

Expand Down Expand Up @@ -838,6 +861,27 @@ class PieChartConfig(UnknownFieldCheckMixin):
)
row_limit: int = Field(100, description="Max slices", ge=1, le=10000)
number_format: str = Field("SMART_NUMBER", max_length=50)
date_format: str = Field(
"smart_date",
description="Date format for date dimension labels (e.g. 'smart_date', "
"'%Y-%m-%d')",
max_length=50,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to the metric value",
)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors', "
"'googleCategory10c', 'd3Category10'). Defaults to 'supersetColors'."
),
max_length=100,
)
legend_orientation: LEGEND_POSITION_LITERAL = Field(
"top", description="Legend placement around the chart"
)
show_total: bool = Field(False, description="Show total in center")
labels_outside: bool = True
outer_radius: int = Field(70, description="Outer radius % (1-100)", ge=1, le=100)
Expand Down Expand Up @@ -889,6 +933,15 @@ class PivotTableChartConfig(UnknownFieldCheckMixin):
)
row_limit: int = Field(10000, description="Max cells", ge=1, le=50000)
value_format: str = Field("SMART_NUMBER", max_length=50)
date_format: str | None = Field(
None,
description="Date format for date columns (e.g. 'smart_date', '%Y-%m-%d')",
max_length=50,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to numeric metric values",
)


class MixedTimeseriesChartConfig(UnknownFieldCheckMixin):
Expand Down Expand Up @@ -935,9 +988,29 @@ class MixedTimeseriesChartConfig(UnknownFieldCheckMixin):
)
# Display options
show_legend: bool = True
legend_orientation: LEGEND_POSITION_LITERAL = Field(
"top", description="Legend placement around the chart"
)
show_value: bool = Field(False, description="Show data labels on each data point")
x_axis: AxisConfig | None = None
y_axis: AxisConfig | None = None
y_axis_secondary: AxisConfig | None = None
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors'). "
"When omitted, Superset's default scheme is used."
),
max_length=100,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to primary metric values",
)
currency_format_secondary: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to secondary metric values",
)
filters: List[FilterConfig] | None = Field(
None,
description="Structured filters (column/op/value). "
Expand Down Expand Up @@ -1113,6 +1186,27 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
),
max_length=50,
)
time_format: str | None = Field(
None,
description=(
"Date format string for trendline x-axis labels "
"(e.g. 'smart_date', '%Y-%m-%d'). Only applies when "
"show_trendline=True."
),
max_length=50,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to the metric value",
)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID for the trendline (e.g. 'supersetColors'). "
"When omitted, Superset's default scheme is used."
),
max_length=100,
)
start_y_axis_at_zero: bool = Field(
True,
description="Anchor trendline y-axis at zero",
Expand Down Expand Up @@ -1194,6 +1288,14 @@ class TableChartConfig(UnknownFieldCheckMixin):
validation_alias=AliasChoices("sort_by", "order_by_cols", "order_by"),
)
row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID applied to conditional/cell formatting "
"(e.g. 'supersetColors')."
),
max_length=100,
)

@model_validator(mode="after")
def validate_unique_column_labels(self) -> "TableChartConfig":
Expand Down Expand Up @@ -1275,6 +1377,28 @@ class XYChartConfig(UnknownFieldCheckMixin):
x_axis: AxisConfig | None = None
y_axis: AxisConfig | None = None
legend: LegendConfig | None = None
x_axis_time_format: str | None = Field(
None,
description=(
"Date format for temporal x-axis labels (e.g. 'smart_date', "
"'%Y-%m-%d'). Only applies when the x-axis column is temporal."
),
max_length=50,
)
show_value: bool = Field(False, description="Show data labels on each data point")
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to metric values",
)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors', "
"'googleCategory10c', 'd3Category10'). When omitted, Superset's "
"default scheme is used."
),
max_length=100,
)
filters: List[FilterConfig] | None = Field(
None,
description="Structured filters (column/op/value). "
Expand Down
29 changes: 28 additions & 1 deletion tests/unit_tests/mcp_service/chart/test_chart_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,34 @@ def test_map_xy_config_with_legend(self) -> None:

assert result["viz_type"] == "echarts_timeseries_scatter"
assert result["show_legend"] is False
assert result["legend_orientation"] == "top"
assert result["legendOrientation"] == "top"

def test_map_xy_config_with_color_scheme(self) -> None:
"""color_scheme propagates to form_data when set."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="line",
color_scheme="lyftColors",
)

result = map_xy_config(config)

assert result["color_scheme"] == "lyftColors"

def test_map_xy_config_without_color_scheme(self) -> None:
"""color_scheme key omitted when not set, leaving Superset default."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="line",
)

result = map_xy_config(config)

assert "color_scheme" not in result

def test_map_xy_config_with_time_grain_month(self) -> None:
"""Test XY config mapping with monthly time grain"""
Expand Down
Loading
Loading