diff --git a/superset/mcp_service/chart/tool/get_chart_sql.py b/superset/mcp_service/chart/tool/get_chart_sql.py
index d586817d2612..1792b928831b 100644
--- a/superset/mcp_service/chart/tool/get_chart_sql.py
+++ b/superset/mcp_service/chart/tool/get_chart_sql.py
@@ -141,6 +141,22 @@ def _resolve_metrics_and_groupby(
     return _resolve_metrics(form_data, viz_type), _resolve_groupby(form_data)
 
 
+def _extract_x_axis_col(form_data: dict[str, Any]) -> str | None:
+    """Return the x_axis column name from form_data, or None if unusable.
+
+    ``x_axis`` may be a plain column-name string or an adhoc column dict
+    (``{"column_name": "...", ...}``). Empty strings and dicts without a
+    non-empty string ``column_name`` yield None, as does any other type.
+    """
+    x_axis = form_data.get("x_axis")
+    if isinstance(x_axis, str) and x_axis:
+        return x_axis
+    if isinstance(x_axis, dict):
+        col_name = x_axis.get("column_name")
+        return col_name if isinstance(col_name, str) and col_name else None
+    return None
+
+
 def _resolve_engine(
     datasource_id: Any,
     datasource_type: str,
@@ -162,6 +178,58 @@ def _resolve_engine(
     return "base"
 
 
+def _build_single_query_dict(
+    form_data: dict[str, Any],
+    columns: list[Any],
+    metrics: list[Any],
+) -> dict[str, Any]:
+    """Build one query entry for QueryContextFactory from form_data fields."""
+    qd: dict[str, Any] = {"columns": columns, "metrics": metrics}
+    if time_range := form_data.get("time_range"):
+        qd["time_range"] = time_range
+    if filters := form_data.get("filters"):
+        qd["filters"] = filters
+    if (row_limit := form_data.get("row_limit")) is not None:
+        qd["row_limit"] = row_limit
+    return qd
+
+
+def _build_mixed_timeseries_secondary(
+    form_data: dict[str, Any],
+    x_axis_col: str | None,
+    engine: str = "base",
+) -> dict[str, Any]:
+    """Build the secondary query dict for the ``mixed_timeseries`` viz type.
+
+    The secondary layer reads ``metrics_b`` / ``groupby_b`` (with
+    ``x_axis_col`` prepended when absent) and starts from the primary
+    ``time_range`` / ``filters`` / ``row_limit``. ``time_range_b``,
+    ``row_limit_b`` and a non-empty ``adhoc_filters_b`` then replace the
+    corresponding primary values; ``engine`` is passed to the filter split.
+    """
+    metrics_b: list[Any] = list(form_data.get("metrics_b") or [])
+    raw_b = form_data.get("groupby_b") or []
+    groupby_b: list[Any] = [raw_b] if isinstance(raw_b, str) else list(raw_b)
+    if x_axis_col and x_axis_col not in groupby_b:
+        groupby_b = [x_axis_col] + groupby_b
+    qd = _build_single_query_dict(form_data, groupby_b, metrics_b)
+    if time_range_b := form_data.get("time_range_b"):
+        qd["time_range"] = time_range_b
+    if (row_limit_b := form_data.get("row_limit_b")) is not None:
+        qd["row_limit"] = row_limit_b
+    # A non-empty adhoc_filters_b replaces the primary filters outright: drop
+    # them first so they cannot leak through when the split yields no filters.
+    if adhoc_filters_b := form_data.get("adhoc_filters_b"):
+        from superset.utils.core import split_adhoc_filters_into_base_filters
+
+        secondary_fd: dict[str, Any] = {"adhoc_filters": adhoc_filters_b}
+        split_adhoc_filters_into_base_filters(secondary_fd, engine)
+        qd.pop("filters", None)
+        if secondary_filters := secondary_fd.get("filters"):
+            qd["filters"] = secondary_filters
+    return qd
+
+
 def _build_query_context_from_form_data(
     form_data: dict[str, Any],
     chart: "Slice | None" = None,
@@ -209,22 +277,32 @@ def _build_query_context_from_form_data(
     merge_extra_filters(form_data)
     split_adhoc_filters_into_base_filters(form_data, engine)
 
-    # Build query dict with temporal and filter fields.
-    # QueryObjectFactory.create() accepts time_range as a top-level kwarg
-    # and converts it to from_dttm/to_dttm for the QueryObject.
-    query_dict: dict[str, Any] = {
-        "columns": groupby,
-        "metrics": metrics,
-    }
-
-    if time_range := form_data.get("time_range"):
-        query_dict["time_range"] = time_range
-
-    if filters := form_data.get("filters"):
-        query_dict["filters"] = filters
+    viz_type: str = (
+        form_data.get("viz_type")
+        or (getattr(chart, "viz_type", "") if chart else "")
+        or ""
+    )
+    is_timeseries = viz_type.startswith("echarts_timeseries") or viz_type in (
+        "echarts_area",
+        "mixed_timeseries",
+    )
 
-    if (row_limit := form_data.get("row_limit")) is not None:
-        query_dict["row_limit"] = row_limit
+    # echarts_timeseries_*, echarts_area (the Area chart) and mixed_timeseries
+    # charts store the temporal column in x_axis rather than groupby. Prepend
+    # it so the generated SQL includes the time axis.
+    x_axis_col = _extract_x_axis_col(form_data) if is_timeseries else None
+    if x_axis_col and x_axis_col not in groupby:
+        groupby = [x_axis_col] + groupby
+
+    queries: list[dict[str, Any]] = [
+        _build_single_query_dict(form_data, groupby, metrics)
+    ]
+
+    # mixed_timeseries exposes two independent query layers (primary and
+    # secondary). Build the second query from metrics_b / groupby_b so
+    # that get_chart_sql returns SQL for both and neither is silently lost.
+    if viz_type == "mixed_timeseries":
+        queries.append(_build_mixed_timeseries_secondary(form_data, x_axis_col, engine))
 
     # Ensure datasource fields satisfy DatasourceDict typing requirements.
     # datasource_id must be int | str; datasource_type must be str.
@@ -238,7 +316,7 @@ def _build_query_context_from_form_data(
 
     return factory.create(
         datasource={"id": resolved_id, "type": resolved_type_str},
-        queries=[query_dict],
+        queries=queries,
         form_data=form_data,
         result_type=ChartDataResultType.QUERY,
         force=False,
diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py
index f752beba8b9d..3e6d588fa6c6 100644
--- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py
+++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py
@@ -33,6 +33,7 @@
 from superset.mcp_service.chart.tool.get_chart_sql import (
     _build_query_context_from_form_data,
     _extract_sql_from_result,
+    _extract_x_axis_col,
     _find_chart_by_identifier,
     _resolve_datasource_name,
     _resolve_effective_form_data,
@@ -468,6 +469,219 @@ def test_metrics_and_groupby_in_queries(self, mock_factory_cls):
         assert queries[0]["columns"] == ["product"]
 
 
+class TestExtractXAxisCol:
+    """Tests for the _extract_x_axis_col helper."""
+
+    def test_string_x_axis(self):
+        """Plain string x_axis returns the string directly."""
+        assert _extract_x_axis_col({"x_axis": "order_date"}) == "order_date"
+
+    def test_dict_x_axis(self):
+        """Adhoc column dict x_axis returns column_name."""
+        x_axis = {"column_name": "ds", "label": "ds", "expressionType": "SIMPLE"}
+        assert _extract_x_axis_col({"x_axis": x_axis}) == "ds"
+
+    def test_missing_x_axis_returns_none(self):
+        """Missing x_axis key returns None."""
+        assert _extract_x_axis_col({}) is None
+
+    def test_none_x_axis_returns_none(self):
+        """Explicit None x_axis returns None."""
+        assert _extract_x_axis_col({"x_axis": None}) is None
+
+    def test_empty_string_x_axis_returns_none(self):
+        """Empty string x_axis returns None."""
+        assert _extract_x_axis_col({"x_axis": ""}) is None
+
+    def test_dict_missing_column_name_returns_none(self):
+        """Adhoc column dict without column_name returns None."""
+        assert _extract_x_axis_col({"x_axis": {"label": "ds"}}) is None
+
+    def test_sql_expression_x_axis_returns_none(self):
+        """SQL expression adhoc columns have no column_name; returns None."""
+        x_axis = {
+            "expressionType": "SQL",
+            "sqlExpression": "DATE_TRUNC('day', created_at)",
+            "label": "day",
+        }
+        assert _extract_x_axis_col({"x_axis": x_axis}) is None
+
+
+class TestBuildQueryContextTimeseriesAndMixed:
+    """Regression tests for x_axis and mixed_timeseries query-context fixes.
+
+    Guards against two bugs: x_axis column dropped for echarts_timeseries_*
+    charts, and only one query rendered for mixed_timeseries charts.
+    """
+
+    @staticmethod
+    @patch("superset.common.query_context_factory.QueryContextFactory")
+    @patch("superset.daos.datasource.DatasourceDAO.get_datasource")
+    def _queries_for(mock_get_ds, mock_factory_cls, **overrides):
+        """Return the ``queries`` kwarg the builder hands to the factory.
+
+        ``overrides`` are merged over the shared datasource fields; the
+        datasource lookup, factory and result-type enum are all mocked.
+        """
+        datasource = Mock()
+        datasource.database.db_engine_spec.engine = "postgresql"
+        mock_get_ds.return_value = datasource
+
+        factory = Mock()
+        factory.create.return_value = Mock()
+        mock_factory_cls.return_value = factory
+
+        form_data = {"datasource_id": 1, "datasource_type": "table", **overrides}
+        with patch("superset.common.chart_data.ChartDataResultType") as mock_rt:
+            mock_rt.QUERY = "QUERY"
+            _build_query_context_from_form_data(form_data, chart=None)
+
+        return factory.create.call_args[1]["queries"]
+
+    def test_echarts_timeseries_x_axis_included_in_columns(self):
+        """x_axis column is prepended to query columns for echarts_timeseries charts."""
+        queries = self._queries_for(
+            viz_type="echarts_timeseries_line",
+            x_axis="ds",
+            metrics=["sum__sales"],
+            groupby=["region"],
+        )
+        assert len(queries) == 1
+        assert queries[0]["columns"][0] == "ds"
+        assert "region" in queries[0]["columns"]
+
+    def test_echarts_timeseries_dict_x_axis_included_in_columns(self):
+        """Adhoc-column x_axis dict is resolved and prepended for echarts_timeseries."""
+        queries = self._queries_for(
+            viz_type="echarts_timeseries_bar",
+            x_axis={"column_name": "order_date", "expressionType": "SIMPLE"},
+            metrics=["count"],
+            groupby=[],
+        )
+        assert queries[0]["columns"][0] == "order_date"
+
+    def test_echarts_timeseries_x_axis_not_duplicated_if_already_in_groupby(self):
+        """x_axis is not duplicated if it is already in groupby."""
+        queries = self._queries_for(
+            viz_type="echarts_timeseries_line",
+            x_axis="ds",
+            metrics=["count"],
+            groupby=["ds"],  # already present
+        )
+        assert queries[0]["columns"].count("ds") == 1
+
+    def test_non_timeseries_x_axis_not_added(self):
+        """x_axis is not added for non-timeseries chart types (e.g. table)."""
+        queries = self._queries_for(
+            viz_type="table",
+            x_axis="ds",
+            metrics=["count"],
+            groupby=["region"],
+        )
+        assert "ds" not in queries[0]["columns"]
+
+    def test_mixed_timeseries_produces_two_queries(self):
+        """mixed_timeseries builds two query dicts — one per series layer."""
+        queries = self._queries_for(
+            viz_type="mixed_timeseries",
+            x_axis="ds",
+            metrics=["sum__revenue"],
+            groupby=["country"],
+            metrics_b=["count"],
+            groupby_b=["channel"],
+            time_range="Last 30 days",
+        )
+        assert len(queries) == 2
+        primary, secondary = queries
+
+        assert "ds" in primary["columns"]
+        assert "country" in primary["columns"]
+        assert primary["metrics"] == ["sum__revenue"]
+        assert primary["time_range"] == "Last 30 days"
+
+        assert "ds" in secondary["columns"]
+        assert "channel" in secondary["columns"]
+        assert secondary["metrics"] == ["count"]
+        assert secondary["time_range"] == "Last 30 days"
+
+    def test_mixed_timeseries_x_axis_not_duplicated_in_secondary(self):
+        """x_axis is not duplicated in the secondary query if already in groupby_b."""
+        queries = self._queries_for(
+            viz_type="mixed_timeseries",
+            x_axis="ds",
+            metrics=["count"],
+            groupby=[],
+            metrics_b=["sum__sales"],
+            groupby_b=["ds"],  # x_axis already present
+        )
+        assert queries[1]["columns"].count("ds") == 1
+
+    def test_mixed_timeseries_empty_secondary(self):
+        """mixed_timeseries with no metrics_b/groupby_b still produces two queries."""
+        queries = self._queries_for(
+            viz_type="mixed_timeseries",
+            x_axis="ds",
+            metrics=["count"],
+            groupby=[],
+        )
+        assert len(queries) == 2
+        assert queries[1]["metrics"] == []
+
+    def test_mixed_timeseries_time_range_b_overrides_secondary(self):
+        """time_range_b overrides the primary time_range for the secondary query."""
+        queries = self._queries_for(
+            viz_type="mixed_timeseries",
+            x_axis="ds",
+            metrics=["sum__revenue"],
+            groupby=[],
+            metrics_b=["count"],
+            groupby_b=[],
+            time_range="Last 30 days",
+            time_range_b="Last 7 days",
+        )
+        assert len(queries) == 2
+        assert queries[0]["time_range"] == "Last 30 days"
+        assert queries[1]["time_range"] == "Last 7 days"
+
+    def test_mixed_timeseries_row_limit_b_overrides_secondary(self):
+        """row_limit_b overrides the primary row_limit for the secondary query."""
+        queries = self._queries_for(
+            viz_type="mixed_timeseries",
+            x_axis="ds",
+            metrics=["sum__revenue"],
+            groupby=[],
+            metrics_b=["count"],
+            groupby_b=[],
+            row_limit=100,
+            row_limit_b=50,
+        )
+        assert len(queries) == 2
+        assert queries[0]["row_limit"] == 100
+        assert queries[1]["row_limit"] == 50
+
+    def test_mixed_timeseries_adhoc_filters_b_applied_to_secondary(self):
+        """adhoc_filters_b is processed and applied to the secondary query filters."""
+        organic_only = {
+            "clause": "WHERE",
+            "expressionType": "SIMPLE",
+            "subject": "channel",
+            "operator": "==",
+            "comparator": "organic",
+        }
+        queries = self._queries_for(
+            viz_type="mixed_timeseries",
+            x_axis="ds",
+            metrics=["sum__revenue"],
+            groupby=[],
+            metrics_b=["count"],
+            groupby_b=[],
+            adhoc_filters_b=[organic_only],
+        )
+        assert len(queries) == 2
+        secondary_filters = queries[1].get("filters", [])
+        assert {"col": "channel", "op": "==", "val": "organic"} in secondary_filters
+
+
 class TestResolveDatasourceName:
     """Tests for _resolve_datasource_name helper."""
 