diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index 5b865c1f0dd6..eed78d1899af 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -32,6 +32,7 @@ ChartCapabilities, ChartSemantics, ColumnRef, + CurrencyFormat, FilterConfig, HandlebarsChartConfig, MixedTimeseriesChartConfig, @@ -469,6 +470,7 @@ def map_table_config(config: TableChartConfig) -> Dict[str, Any]: form_data["order_by_cols"] = config.sort_by form_data["row_limit"] = config.row_limit + add_color_scheme(form_data, config.color_scheme) return form_data @@ -539,7 +541,35 @@ def add_legend_config(form_data: Dict[str, Any], config: XYChartConfig) -> None: if not config.legend.show: form_data["show_legend"] = False if config.legend.position: - form_data["legend_orientation"] = config.legend.position + # Canonical form_data key is camelCase; the echarts plugins read + # `legendOrientation` directly off form_data. + form_data["legendOrientation"] = config.legend.position + + +def add_color_scheme(form_data: Dict[str, Any], color_scheme: str | None) -> None: + """Add color scheme to form_data when set.""" + if color_scheme: + form_data["color_scheme"] = color_scheme + + +def add_currency_format( + form_data: Dict[str, Any], + currency_format: CurrencyFormat | None, + key: str = "currency_format", +) -> None: + """Add currency format to form_data under the given key when set.""" + if currency_format: + form_data[key] = currency_format.to_form_data() + + +def add_xy_data_label_options( + form_data: Dict[str, Any], config: XYChartConfig, x_is_temporal: bool +) -> None: + """Apply XY-specific data-label and time-format options when set.""" + if config.x_axis_time_format and x_is_temporal: + form_data["x_axis_time_format"] = config.x_axis_time_format + if config.show_value: + form_data["show_value"] = True def add_orientation_config(form_data: Dict[str, Any], config: XYChartConfig) -> None: @@ -714,6 +744,9 @@ def map_xy_config( add_axis_config(form_data, config) add_legend_config(form_data, config) add_orientation_config(form_data, config) + add_color_scheme(form_data, config.color_scheme) + add_currency_format(form_data, config.currency_format) + add_xy_data_label_options(form_data, config, x_is_temporal) return form_data @@ -726,11 +759,13 @@ def map_pie_config(config: PieChartConfig) -> Dict[str, Any]: "viz_type": "pie", "groupby": [config.dimension.name], "metric": metric, - "color_scheme": "supersetColors", + "color_scheme": config.color_scheme or "supersetColors", "show_labels": config.show_labels, "show_legend": config.show_legend, + "legendOrientation": config.legend_orientation, "label_type": config.label_type, "number_format": config.number_format, + "date_format": config.date_format, "sort_by_metric": config.sort_by_metric, "row_limit": config.row_limit, "donut": config.donut, @@ -738,9 +773,9 @@ def map_pie_config(config: PieChartConfig) -> Dict[str, Any]: "labels_outside": config.labels_outside, "outerRadius": config.outer_radius, "innerRadius": config.inner_radius, - "date_format": "smart_date", } + add_currency_format(form_data, config.currency_format) _add_adhoc_filters(form_data, config.filters) return form_data @@ -766,6 +801,9 @@ def map_big_number_config(config: BigNumberChartConfig) -> Dict[str, Any]: if config.y_axis_format: form_data["y_axis_format"] = config.y_axis_format + add_color_scheme(form_data, config.color_scheme) + add_currency_format(form_data, config.currency_format) + # Trendline-specific fields if viz_type == "big_number": # Big Number with trendline uses granularity_sqla for the temporal column @@ -781,6 +819,9 @@ def map_big_number_config(config: BigNumberChartConfig) -> Dict[str, Any]: if config.compare_lag is not None: form_data["compare_lag"] = config.compare_lag + if config.time_format: + form_data["time_format"] = config.time_format + _add_adhoc_filters(form_data, config.filters) return form_data @@ -852,6 +893,10 @@ def map_pivot_table_config(config: PivotTableChartConfig) -> Dict[str, Any]: "row_limit": config.row_limit, } + if config.date_format: + form_data["date_format"] = config.date_format + + add_currency_format(form_data, config.currency_format) _add_adhoc_filters(form_data, config.filters) return form_data @@ -931,10 +976,20 @@ def map_mixed_timeseries_config( "yAxisIndexB": 1, # Display "show_legend": config.show_legend, + "legendOrientation": config.legend_orientation, "zoomable": True, "rich_tooltip": True, } + if config.show_value: + form_data["show_value"] = True + + add_color_scheme(form_data, config.color_scheme) + add_currency_format(form_data, config.currency_format) + add_currency_format( + form_data, config.currency_format_secondary, key="currency_format_secondary" + ) + # Configure temporal handling configure_temporal_handling(form_data, x_is_temporal, config.time_grain) diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 1fdb4f43ab6d..604b94af7762 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -736,6 +736,29 @@ class LegendConfig(BaseModel): position: Literal["top", "bottom", "left", "right"] | None = "right" +class CurrencyFormat(BaseModel): + """Currency symbol and placement applied to numeric values.""" + + model_config = ConfigDict(populate_by_name=True) + + symbol: str = Field( + ..., + description="Currency code or symbol (e.g. 'USD', 'EUR', '$', '€')", + max_length=20, + ) + symbol_position: Literal["prefix", "suffix"] = Field( + "prefix", + description="Whether to render the symbol before or after the value", + validation_alias=AliasChoices("symbol_position", "symbolPosition"), + ) + + def to_form_data(self) -> Dict[str, str]: + return {"symbol": self.symbol, "symbolPosition": self.symbol_position} + + +LEGEND_POSITION_LITERAL = Literal["top", "bottom", "left", "right"] + + class FilterConfig(BaseModel): model_config = ConfigDict(populate_by_name=True) @@ -838,6 +861,27 @@ class PieChartConfig(UnknownFieldCheckMixin): ) row_limit: int = Field(100, description="Max slices", ge=1, le=10000) number_format: str = Field("SMART_NUMBER", max_length=50) + date_format: str = Field( + "smart_date", + description="Date format for date dimension labels (e.g. 'smart_date', " + "'%Y-%m-%d')", + max_length=50, + ) + currency_format: CurrencyFormat | None = Field( + None, + description="Currency symbol applied to the metric value", + ) + color_scheme: str | None = Field( + None, + description=( + "Superset color scheme ID (e.g. 'supersetColors', 'lyftColors', " + "'googleCategory10c', 'd3Category10'). Defaults to 'supersetColors'." + ), + max_length=100, + ) + legend_orientation: LEGEND_POSITION_LITERAL = Field( + "top", description="Legend placement around the chart" + ) show_total: bool = Field(False, description="Show total in center") labels_outside: bool = True outer_radius: int = Field(70, description="Outer radius % (1-100)", ge=1, le=100) @@ -889,6 +933,15 @@ class PivotTableChartConfig(UnknownFieldCheckMixin): ) row_limit: int = Field(10000, description="Max cells", ge=1, le=50000) value_format: str = Field("SMART_NUMBER", max_length=50) + date_format: str | None = Field( + None, + description="Date format for date columns (e.g. 'smart_date', '%Y-%m-%d')", + max_length=50, + ) + currency_format: CurrencyFormat | None = Field( + None, + description="Currency symbol applied to numeric metric values", + ) class MixedTimeseriesChartConfig(UnknownFieldCheckMixin): @@ -935,9 +988,29 @@ class MixedTimeseriesChartConfig(UnknownFieldCheckMixin): ) # Display options show_legend: bool = True + legend_orientation: LEGEND_POSITION_LITERAL = Field( + "top", description="Legend placement around the chart" + ) + show_value: bool = Field(False, description="Show data labels on each data point") x_axis: AxisConfig | None = None y_axis: AxisConfig | None = None y_axis_secondary: AxisConfig | None = None + color_scheme: str | None = Field( + None, + description=( + "Superset color scheme ID (e.g. 'supersetColors', 'lyftColors'). " + "When omitted, Superset's default scheme is used." + ), + max_length=100, + ) + currency_format: CurrencyFormat | None = Field( + None, + description="Currency symbol applied to primary metric values", + ) + currency_format_secondary: CurrencyFormat | None = Field( + None, + description="Currency symbol applied to secondary metric values", + ) filters: List[FilterConfig] | None = Field( None, description="Structured filters (column/op/value). " @@ -1113,6 +1186,27 @@ class BigNumberChartConfig(UnknownFieldCheckMixin): ), max_length=50, ) + time_format: str | None = Field( + None, + description=( + "Date format string for trendline x-axis labels " + "(e.g. 'smart_date', '%Y-%m-%d'). Only applies when " + "show_trendline=True." + ), + max_length=50, + ) + currency_format: CurrencyFormat | None = Field( + None, + description="Currency symbol applied to the metric value", + ) + color_scheme: str | None = Field( + None, + description=( + "Superset color scheme ID for the trendline (e.g. 'supersetColors'). " + "When omitted, Superset's default scheme is used." + ), + max_length=100, + ) start_y_axis_at_zero: bool = Field( True, description="Anchor trendline y-axis at zero", @@ -1194,6 +1288,14 @@ class TableChartConfig(UnknownFieldCheckMixin): validation_alias=AliasChoices("sort_by", "order_by_cols", "order_by"), ) row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000) + color_scheme: str | None = Field( + None, + description=( + "Superset color scheme ID applied to conditional/cell formatting " + "(e.g. 'supersetColors')." + ), + max_length=100, + ) @model_validator(mode="after") def validate_unique_column_labels(self) -> "TableChartConfig": @@ -1275,6 +1377,28 @@ class XYChartConfig(UnknownFieldCheckMixin): x_axis: AxisConfig | None = None y_axis: AxisConfig | None = None legend: LegendConfig | None = None + x_axis_time_format: str | None = Field( + None, + description=( + "Date format for temporal x-axis labels (e.g. 'smart_date', " + "'%Y-%m-%d'). Only applies when the x-axis column is temporal." + ), + max_length=50, + ) + show_value: bool = Field(False, description="Show data labels on each data point") + currency_format: CurrencyFormat | None = Field( + None, + description="Currency symbol applied to metric values", + ) + color_scheme: str | None = Field( + None, + description=( + "Superset color scheme ID (e.g. 'supersetColors', 'lyftColors', " + "'googleCategory10c', 'd3Category10'). When omitted, Superset's " + "default scheme is used." + ), + max_length=100, + ) filters: List[FilterConfig] | None = Field( None, description="Structured filters (column/op/value). " diff --git a/tests/unit_tests/mcp_service/chart/test_chart_utils.py b/tests/unit_tests/mcp_service/chart/test_chart_utils.py index 7c2da4e5de7d..3fbbe0fac1ce 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_utils.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_utils.py @@ -552,7 +552,34 @@ def test_map_xy_config_with_legend(self) -> None: assert result["viz_type"] == "echarts_timeseries_scatter" assert result["show_legend"] is False - assert result["legend_orientation"] == "top" + assert result["legendOrientation"] == "top" + + def test_map_xy_config_with_color_scheme(self) -> None: + """color_scheme propagates to form_data when set.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue")], + kind="line", + color_scheme="lyftColors", + ) + + result = map_xy_config(config) + + assert result["color_scheme"] == "lyftColors" + + def test_map_xy_config_without_color_scheme(self) -> None: + """color_scheme key omitted when not set, leaving Superset default.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue")], + kind="line", + ) + + result = map_xy_config(config) + + assert "color_scheme" not in result def test_map_xy_config_with_time_grain_month(self) -> None: """Test XY config mapping with monthly time grain""" diff --git a/tests/unit_tests/mcp_service/chart/test_new_chart_types.py b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py index 9469f63b39a9..48b1631568a9 100644 --- a/tests/unit_tests/mcp_service/chart/test_new_chart_types.py +++ b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py @@ -29,18 +29,23 @@ from superset.mcp_service.chart.chart_utils import ( generate_chart_name, + map_big_number_config, map_config_to_form_data, map_mixed_timeseries_config, map_pie_config, map_pivot_table_config, + map_table_config, ) from superset.mcp_service.chart.schemas import ( AxisConfig, + BigNumberChartConfig, ColumnRef, + CurrencyFormat, FilterConfig, MixedTimeseriesChartConfig, PieChartConfig, PivotTableChartConfig, + TableChartConfig, ) from superset.mcp_service.chart.validation.schema_validator import SchemaValidator @@ -212,6 +217,18 @@ def test_pie_form_data_with_filters(self) -> None: assert result["adhoc_filters"][0]["operator"] == "==" assert result["adhoc_filters"][0]["comparator"] == "US" + def test_pie_form_data_color_scheme_override(self) -> None: + """Explicit color_scheme overrides the supersetColors default.""" + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + color_scheme="googleCategory10c", + ) + result = map_pie_config(config) + + assert result["color_scheme"] == "googleCategory10c" + def test_pie_form_data_custom_options(self) -> None: config = PieChartConfig( chart_type="pie", @@ -975,3 +992,272 @@ def test_non_string_chart_type_rejected_gracefully( assert is_valid is False assert error is not None assert error.error_code == "INVALID_CHART_TYPE" + + +# ============================================================ +# Chart Formatting Options Tests (sc-102806 follow-up) +# ============================================================ + + +class TestPieFormattingOptions: + """number/date/currency format, color scheme, legend orientation on Pie.""" + + def test_currency_format_in_form_data(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + currency_format=CurrencyFormat(symbol="USD", symbol_position="prefix"), + ) + result = map_pie_config(config) + + assert result["currency_format"] == { + "symbol": "USD", + "symbolPosition": "prefix", + } + + def test_currency_format_omitted_when_unset(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + ) + result = map_pie_config(config) + + assert "currency_format" not in result + + def test_legend_orientation_in_form_data(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + legend_orientation="bottom", + ) + result = map_pie_config(config) + + assert result["legendOrientation"] == "bottom" + + def test_default_legend_orientation_is_top(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + ) + result = map_pie_config(config) + + assert result["legendOrientation"] == "top" + + def test_date_format_overridable(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="ds"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + date_format="%Y-%m-%d", + ) + result = map_pie_config(config) + + assert result["date_format"] == "%Y-%m-%d" + + +class TestPivotTableFormattingOptions: + """date/currency format on PivotTable.""" + + def test_currency_format_in_form_data(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="region")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + currency_format=CurrencyFormat(symbol="EUR", symbol_position="suffix"), + ) + result = map_pivot_table_config(config) + + assert result["currency_format"] == { + "symbol": "EUR", + "symbolPosition": "suffix", + } + + def test_date_format_in_form_data(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="ds")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + date_format="%Y-%m", + ) + result = map_pivot_table_config(config) + + assert result["date_format"] == "%Y-%m" + + def test_formatting_omitted_when_unset(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="region")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + ) + result = map_pivot_table_config(config) + + assert "currency_format" not in result + assert "date_format" not in result + + +class TestMixedTimeseriesFormattingOptions: + """color scheme, currency format, legend orientation, data labels on Mixed.""" + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_color_scheme_in_form_data(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="ds"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + color_scheme="lyftColors", + ) + result = map_mixed_timeseries_config(config) + + assert result["color_scheme"] == "lyftColors" + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_currency_format_primary_and_secondary(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="ds"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + currency_format=CurrencyFormat(symbol="USD"), + currency_format_secondary=CurrencyFormat(symbol="GBP"), + ) + result = map_mixed_timeseries_config(config) + + assert result["currency_format"] == { + "symbol": "USD", + "symbolPosition": "prefix", + } + assert result["currency_format_secondary"] == { + "symbol": "GBP", + "symbolPosition": "prefix", + } + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_legend_orientation_in_form_data(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="ds"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + legend_orientation="left", + ) + result = map_mixed_timeseries_config(config) + + assert result["legendOrientation"] == "left" + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_show_value_data_labels(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="ds"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + show_value=True, + ) + result = map_mixed_timeseries_config(config) + + assert result["show_value"] is True + + +class TestBigNumberFormattingOptions: + """color scheme, currency format, time format on BigNumber.""" + + def test_currency_format_in_form_data(self) -> None: + config = BigNumberChartConfig( + chart_type="big_number", + metric=ColumnRef(name="revenue", aggregate="SUM"), + currency_format=CurrencyFormat(symbol="JPY", symbol_position="prefix"), + ) + result = map_big_number_config(config) + + assert result["currency_format"] == { + "symbol": "JPY", + "symbolPosition": "prefix", + } + + def test_color_scheme_in_form_data(self) -> None: + config = BigNumberChartConfig( + chart_type="big_number", + metric=ColumnRef(name="revenue", aggregate="SUM"), + color_scheme="d3Category10", + ) + result = map_big_number_config(config) + + assert result["color_scheme"] == "d3Category10" + + def test_time_format_only_for_trendline(self) -> None: + # Without trendline, time_format is dropped because the trendline + # x-axis doesn't render. + config = BigNumberChartConfig( + chart_type="big_number", + metric=ColumnRef(name="revenue", aggregate="SUM"), + time_format="%Y-%m-%d", + ) + result = map_big_number_config(config) + + assert "time_format" not in result + + def test_time_format_with_trendline(self) -> None: + config = BigNumberChartConfig( + chart_type="big_number", + metric=ColumnRef(name="revenue", aggregate="SUM"), + temporal_column="ds", + show_trendline=True, + time_format="%Y-%m-%d", + ) + result = map_big_number_config(config) + + assert result["time_format"] == "%Y-%m-%d" + + +class TestTableFormattingOptions: + """color scheme on Table.""" + + def test_color_scheme_in_form_data(self) -> None: + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product"), ColumnRef(name="revenue")], + color_scheme="lyftColors", + ) + result = map_table_config(config) + + assert result["color_scheme"] == "lyftColors" + + def test_color_scheme_omitted_when_unset(self) -> None: + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product"), ColumnRef(name="revenue")], + ) + result = map_table_config(config) + + assert "color_scheme" not in result + + +class TestCurrencyFormatModel: + """CurrencyFormat schema validation.""" + + def test_default_symbol_position_is_prefix(self) -> None: + cf = CurrencyFormat(symbol="USD") + assert cf.symbol_position == "prefix" + + def test_camel_case_alias_accepted(self) -> None: + cf = CurrencyFormat.model_validate( + {"symbol": "USD", "symbolPosition": "suffix"} + ) + assert cf.symbol_position == "suffix" + + def test_invalid_position_rejected(self) -> None: + with pytest.raises(ValidationError): + CurrencyFormat(symbol="USD", symbol_position="middle") + + def test_to_form_data_shape(self) -> None: + cf = CurrencyFormat(symbol="EUR", symbol_position="suffix") + assert cf.to_form_data() == {"symbol": "EUR", "symbolPosition": "suffix"}