# Maps Superset viz_type strings to canonical categories so we can
# avoid recommending a chart type the user already has.
_VIZ_CATEGORY: dict[str, str] = {
    "echarts_timeseries_line": "line",
    "echarts_timeseries_smooth": "line",
    "echarts_timeseries_step": "line",
    "echarts_timeseries": "line",
    "echarts_timeseries_bar": "bar",
    "echarts_area": "area",
    "echarts_timeseries_scatter": "scatter",
    "mixed_timeseries": "line",
    "table": "table",
    "pie": "pie",
    "big_number": "kpi",
    "big_number_total": "kpi",
    "pop_kpi": "kpi",
    "dist_bar": "bar",
    "line": "line",
    "area": "area",
    "scatter": "scatter",
    "bubble": "bubble",
    "treemap_v2": "treemap",
    "sunburst_v2": "treemap",
    "heatmap_v2": "heatmap",
    "gauge_chart": "gauge",
    "funnel": "funnel",
    "histogram": "histogram",
    "histogram_v2": "histogram",
    "box_plot": "box_plot",
    "world_map": "map",
    "pivot_table_v2": "table",
}

# Upper bound on the number of suggestions returned for one chart.
_MAX_RECOMMENDATIONS = 4


def _recommend_visualizations(
    viz_type: str,
    columns: list[DataColumn],
    row_count: int,
) -> list[str]:
    """Suggest visualization types for a chart's result set.

    Suggestions are derived from each column's ``data_type`` and
    ``unique_count`` plus the number of rows. Any suggestion in the same
    category as the chart's current ``viz_type`` is dropped.

    Args:
        viz_type: The chart's current viz_type. Values missing from
            ``_VIZ_CATEGORY`` are compared as-is, so they exclude nothing
            unless they happen to equal a category name.
        columns: Column metadata for the result set.
        row_count: Number of rows in the result set.

    Returns:
        At most ``_MAX_RECOMMENDATIONS`` chart names in priority order.
        ``["table"]`` (unfiltered) when ``columns`` is empty.
    """
    if not columns:
        return ["table"]

    current_category = _VIZ_CATEGORY.get(viz_type, viz_type)
    # Generic fallback when no column-shape rule matches (e.g. no numeric
    # column at all).
    candidates = _build_candidates(columns, row_count) or ["table", "bar chart"]
    return _filter_candidates(candidates, current_category)


def _build_candidates(
    columns: list[DataColumn],
    row_count: int,
) -> list[str]:
    """Build the candidate visualization list from column metadata.

    Columns are bucketed by ``data_type``; columns with any other type are
    ignored. Every rule requires at least one numeric column, so an empty
    list is returned when there is none. Temporal columns take precedence
    over categorical ones.
    """
    temporal = [c for c in columns if c.data_type == "temporal"]
    numeric = [c for c in columns if c.data_type == "numeric"]
    categorical = [c for c in columns if c.data_type in ("string", "boolean")]

    if not numeric:
        return []
    if temporal:
        return _candidates_temporal_numeric(numeric, row_count)
    if categorical:
        return _candidates_categorical_numeric(numeric, categorical)
    # From here on the result set is purely numeric: ``categorical`` is
    # empty, so the multi-numeric helper never sees a categorical column.
    if len(numeric) >= 2:
        return _candidates_multi_numeric(numeric, categorical)
    return _candidates_single_numeric(numeric[0], row_count)


def _candidates_temporal_numeric(
    numeric: list[DataColumn], row_count: int
) -> list[str]:
    """Candidates for a time dimension with one or more metrics."""
    # Few data points are better as a bar chart than a line. That applies
    # to a multi-line chart as well, so no line variant is offered here.
    if row_count < 5:
        return ["bar chart", "table"]

    candidates = ["line chart", "area chart", "bar chart"]
    if len(numeric) > 1:
        candidates.append("multi-line chart")
    return candidates


def _candidates_categorical_numeric(
    numeric: list[DataColumn],
    categorical: list[DataColumn],
) -> list[str]:
    """Candidates for categorical dimension(s) with one or more metrics.

    ``categorical`` must be non-empty; only its first column is consulted
    for the pie-chart cardinality check.
    """
    candidates = ["bar chart"]
    # A pie is only readable with a single metric and a handful of slices.
    if len(numeric) == 1 and categorical[0].unique_count <= 10:
        candidates.append("pie chart")
    if len(numeric) >= 2:
        candidates.append("scatter plot")
        candidates.append("heatmap")
    if any(c.unique_count > 5 for c in categorical):
        candidates.append("treemap")
    return candidates


def _candidates_single_numeric(col: DataColumn, row_count: int) -> list[str]:
    """Candidates for a result set that is a single numeric column."""
    candidates = ["big number / KPI", "gauge chart"]
    # Enough rows and distinct values to show a distribution: lead with it.
    if row_count > 20 and col.unique_count > 10:
        candidates.insert(0, "histogram")
    return candidates


def _candidates_multi_numeric(
    numeric: list[DataColumn],
    categorical: list[DataColumn],
) -> list[str]:
    """Candidates for two or more numeric columns.

    The heatmap suggestion only applies when ``categorical`` is non-empty.
    """
    candidates = ["scatter plot"]
    if len(numeric) >= 3:
        candidates.append("bubble chart")
    if categorical:
        candidates.append("heatmap")
    return candidates


# Maps each candidate string to a canonical category for dedup
# against the current viz_type.
_CANDIDATE_CATEGORY: dict[str, str] = {
    "line chart": "line",
    "multi-line chart": "line",
    "area chart": "area",
    "bar chart": "bar",
    "scatter plot": "scatter",
    "bubble chart": "bubble",
    "pie chart": "pie",
    "treemap": "treemap",
    "heatmap": "heatmap",
    "big number / KPI": "kpi",
    "gauge chart": "gauge",
    "histogram": "histogram",
    "table": "table",
}


def _filter_candidates(
    candidates: list[str],
    current_category: str,
) -> list[str]:
    """Deduplicate, exclude the current viz category, and cap.

    Order is preserved. The result may be empty when every candidate is in
    ``current_category``.
    """
    seen: set[str] = set()
    result: list[str] = []
    for candidate in candidates:
        if candidate in seen:
            continue
        if _CANDIDATE_CATEGORY.get(candidate) == current_category:
            continue
        seen.add(candidate)
        result.append(candidate)
        if len(result) >= _MAX_RECOMMENDATIONS:
            break
    return result
in sample_values): - data_type = "numeric" - elif all(isinstance(v, bool) for v in sample_values): + if coltypes: + data_type = _GENERIC_TYPE_MAP.get(coltypes[idx], "string") + elif sample_values: + if all(isinstance(v, bool) for v in sample_values): data_type = "boolean" + elif all(isinstance(v, (int, float)) for v in sample_values): + data_type = "numeric" columns.append( DataColumn( @@ -678,13 +849,11 @@ async def get_chart_data( # noqa: C901 else: insights.append("Fresh data retrieved from database") - recommended_visualizations = [] - if any( - "time" in col.lower() or "date" in col.lower() for col in raw_columns - ): - recommended_visualizations.extend(["line chart", "time series"]) - if len(raw_columns) <= 3: - recommended_visualizations.extend(["bar chart", "scatter plot"]) + recommended_visualizations = _recommend_visualizations( + viz_type=chart.viz_type or "unknown", + columns=columns, + row_count=len(data), + ) # Performance metadata with cache awareness execution_time = int((time.time() - start_time) * 1000) diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py index 8d54cacfabdd..c54f42817a17 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py @@ -30,10 +30,14 @@ PerformanceMetadata, ) from superset.mcp_service.chart.tool.get_chart_data import ( + _GENERIC_TYPE_MAP, + _MAX_RECOMMENDATIONS, + _recommend_visualizations, _sanitize_chart_data_for_llm_context, ) from superset.mcp_service.utils import sanitize_for_llm_context from superset.mcp_service.utils.sanitization import LLM_CONTEXT_ESCAPED_CLOSE_DELIMITER +from superset.utils.core import GenericDataType def _collect_groupby_extras( @@ -988,3 +992,133 @@ def test_compile_chart_security_exception_from_validate(self): ) mock_command.run.assert_not_called() + + +# 
# ---------------------------------------------------------------------------
# Tests for _recommend_visualizations
# ---------------------------------------------------------------------------


def _col(
    name: str,
    data_type: str = "string",
    unique_count: int = 5,
    null_count: int = 0,
) -> DataColumn:
    """Build a minimal DataColumn fixture whose display name mirrors ``name``."""
    fields = {
        "name": name,
        "display_name": name,
        "data_type": data_type,
        "sample_values": [],
        "null_count": null_count,
        "unique_count": unique_count,
    }
    return DataColumn(**fields)


def test_recommend_temporal_and_numeric_suggests_line_chart():
    columns = [_col("created_at", "temporal"), _col("revenue", "numeric")]
    suggestions = _recommend_visualizations("table", columns, row_count=50)
    assert "line chart" in suggestions
    assert "area chart" in suggestions


def test_recommend_categorical_and_numeric_suggests_bar_chart():
    region = _col("region", "string", unique_count=5)
    sales = _col("sales", "numeric")
    suggestions = _recommend_visualizations(
        "echarts_timeseries_line", [region, sales], row_count=50
    )
    assert "bar chart" in suggestions


def test_recommend_excludes_current_viz_type():
    columns = [_col("created_at", "temporal"), _col("revenue", "numeric")]
    suggestions = _recommend_visualizations(
        "echarts_timeseries_line", columns, row_count=50
    )
    assert "line chart" not in suggestions


def test_recommend_multiple_numeric_suggests_scatter():
    columns = [_col(name, "numeric") for name in ("height", "weight", "age")]
    suggestions = _recommend_visualizations("table", columns, row_count=100)
    assert "scatter plot" in suggestions


def test_recommend_single_numeric_suggests_kpi():
    suggestions = _recommend_visualizations(
        "table", [_col("total_revenue", "numeric")], row_count=1
    )
    assert "big number / KPI" in suggestions


def test_recommend_all_strings_falls_back():
    columns = [_col("name", "string"), _col("address", "string")]
    suggestions = _recommend_visualizations("pie", columns, row_count=100)
    assert any(name in suggestions for name in ("table", "bar chart"))


def test_recommend_high_cardinality_no_pie():
    user_id = _col("user_id", "string", unique_count=900)
    score = _col("score", "numeric")
    suggestions = _recommend_visualizations(
        "table", [user_id, score], row_count=1000
    )
    assert "pie chart" not in suggestions


def test_recommend_caps_at_max():
    columns = [
        _col("ts", "temporal"),
        _col("a", "numeric"),
        _col("b", "numeric"),
    ]
    suggestions = _recommend_visualizations("table", columns, row_count=100)
    assert len(suggestions) <= _MAX_RECOMMENDATIONS


def test_recommend_empty_columns_returns_table():
    assert _recommend_visualizations("table", [], row_count=0) == ["table"]


def test_recommend_pie_only_for_low_cardinality():
    department = _col("department", "string", unique_count=25)
    headcount = _col("headcount", "numeric")
    suggestions = _recommend_visualizations(
        "table", [department, headcount], row_count=100
    )
    assert "pie chart" not in suggestions


def test_recommend_temporal_few_rows_prefers_bar():
    columns = [_col("date", "temporal"), _col("revenue", "numeric")]
    suggestions = _recommend_visualizations("table", columns, row_count=3)
    assert "bar chart" in suggestions
    assert "line chart" not in suggestions


def test_recommend_single_numeric_high_cardinality_suggests_histogram():
    salary = _col("salary", "numeric", unique_count=500)
    suggestions = _recommend_visualizations("table", [salary], row_count=1000)
    assert "histogram" in suggestions


def test_coltypes_populates_data_type():
    """Each GenericDataType member maps to its lowercase type label."""
    expected_labels = (
        (GenericDataType.NUMERIC, "numeric"),
        (GenericDataType.STRING, "string"),
        (GenericDataType.TEMPORAL, "temporal"),
        (GenericDataType.BOOLEAN, "boolean"),
    )
    for generic_type, label in expected_labels:
        assert _GENERIC_TYPE_MAP[generic_type] == label


def test_bool_isinstance_check_before_int():
    """bool subclasses int, so the fallback must test for bool first."""

    # This mirrors the isinstance fallback used when coltypes is unavailable:
    # because isinstance(True, int) is True, checking int/float first would
    # label an all-boolean column "numeric". The heuristic is reproduced
    # locally here rather than exercised through get_chart_data.
    observed = [True, False, True]
    if all(isinstance(value, bool) for value in observed):
        inferred = "boolean"
    elif all(isinstance(value, (int, float)) for value in observed):
        inferred = "numeric"
    else:
        inferred = "string"
    assert inferred == "boolean"