Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 183 additions & 14 deletions superset/mcp_service/chart/tool/get_chart_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,177 @@
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.utils.core import merge_extra_filters
from superset.utils.core import GenericDataType, merge_extra_filters

logger = logging.getLogger(__name__)

# Translates GenericDataType members into the data_type strings
# ("numeric", "string", "temporal", "boolean") that the recommendation
# helpers match on when they bucket columns.
_GENERIC_TYPE_MAP: dict[int, str] = {
    GenericDataType.NUMERIC: "numeric",
    GenericDataType.STRING: "string",
    GenericDataType.TEMPORAL: "temporal",
    GenericDataType.BOOLEAN: "boolean",
}

# Maps Superset viz_type strings to canonical categories so we can
# avoid recommending a chart type the user already has.
# Several viz_types share one category (for example every timeseries line
# variant maps to "line", and "sunburst_v2" is grouped with "treemap").
# A viz_type missing from this table is used as its own category by
# _recommend_visualizations.
_VIZ_CATEGORY: dict[str, str] = {
    "echarts_timeseries_line": "line",
    "echarts_timeseries_smooth": "line",
    "echarts_timeseries_step": "line",
    "echarts_timeseries": "line",
    "echarts_timeseries_bar": "bar",
    "echarts_area": "area",
    "echarts_timeseries_scatter": "scatter",
    "mixed_timeseries": "line",
    "table": "table",
    "pie": "pie",
    "big_number": "kpi",
    "big_number_total": "kpi",
    "pop_kpi": "kpi",
    "dist_bar": "bar",
    "line": "line",
    "area": "area",
    "scatter": "scatter",
    "bubble": "bubble",
    "treemap_v2": "treemap",
    "sunburst_v2": "treemap",
    "heatmap_v2": "heatmap",
    "gauge_chart": "gauge",
    "funnel": "funnel",
    "histogram": "histogram",
    "histogram_v2": "histogram",
    "box_plot": "box_plot",
    "world_map": "map",
    "pivot_table_v2": "table",
}
Comment thread
alexandrusoare marked this conversation as resolved.

# Upper bound on how many suggestions _filter_candidates collects before
# it stops.
_MAX_RECOMMENDATIONS = 4


def _recommend_visualizations(
    viz_type: str,
    columns: list[DataColumn],
    row_count: int,
) -> list[str]:
    """Suggest visualization types based on column types,
    cardinality, and the chart's current viz_type.

    ``viz_type`` is translated to a canonical category through
    ``_VIZ_CATEGORY`` (an unknown viz_type is used as its own category) and
    suggestions in that category are dropped. ``columns`` and ``row_count``
    are passed to ``_build_candidates`` to produce the data-driven
    candidates.

    The returned list of human-readable chart names is never empty: with no
    columns it is ``["table"]``; otherwise, if nothing data-driven applies
    or everything data-driven is filtered out, a generic table / bar chart
    fallback is used.
    """
    if not columns:
        return ["table"]

    generic_fallback = ["table", "bar chart"]
    current_category = _VIZ_CATEGORY.get(viz_type, viz_type)
    candidates = _build_candidates(columns, row_count) or generic_fallback

    recommendations = _filter_candidates(candidates, current_category)
    if recommendations:
        return recommendations

    # Every data-driven candidate was in the chart's current category
    # (e.g. a scatter chart over two numeric columns only yields
    # "scatter plot"). Fall back to the generic pair rather than returning
    # nothing: its two entries belong to different categories, so at least
    # one of them always survives the filter.
    return _filter_candidates(generic_fallback, current_category)


def _build_candidates(
columns: list[DataColumn],
row_count: int,
) -> list[str]:
"""Build candidate visualization list from column metadata."""
temporal = [c for c in columns if c.data_type == "temporal"]
numeric = [c for c in columns if c.data_type == "numeric"]
categorical = [c for c in columns if c.data_type in ("string", "boolean")]
Comment on lines +127 to +129
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could probably be a single loop, right? Would that reduce the performance overhead?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These iterate over the chart's columns — typically 2-10 items. Three passes over fewer than 10 elements is negligible; a single loop with conditional appends would be harder to read for no meaningful gain. Let me know your thoughts.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, my concern was about charts that have many more columns than that, but I'm not sure whether that happens often.


if temporal and numeric:
return _candidates_temporal_numeric(numeric, row_count)
if categorical and numeric:
return _candidates_categorical_numeric(numeric, categorical)
if len(numeric) >= 2:
return _candidates_multi_numeric(numeric, categorical)
if len(numeric) == 1 and not temporal and not categorical:
return _candidates_single_numeric(numeric[0], row_count)
return []
Comment thread
msyavuz marked this conversation as resolved.


def _candidates_temporal_numeric(
    numeric: list[DataColumn], row_count: int
) -> list[str]:
    """Build suggestions for data with a time column and numeric columns.

    Only the number of ``numeric`` columns is used. The result is ordered
    from most to least preferred.
    """
    # Few data points are better as a bar chart than a line
    if row_count < 5:
        return ["bar chart", "table"]

    candidates = ["line chart", "area chart", "bar chart"]
    # A multi-line chart is only offered alongside the plain line chart,
    # i.e. never for the sparse series handled by the early return above.
    if len(numeric) > 1:
        candidates.append("multi-line chart")
    return candidates


def _candidates_categorical_numeric(
    numeric: list[DataColumn],
    categorical: list[DataColumn],
) -> list[str]:
    """Build suggestions for data with categorical and numeric columns.

    A bar chart is always offered. A pie chart is added for a single metric
    whose first categorical column has at most 10 distinct values; a scatter
    plot and a heatmap are added for two or more metrics; a treemap is added
    when any categorical column has more than 5 distinct values.
    """
    suggestions = ["bar chart"]
    metric_count = len(numeric)

    if metric_count == 1:
        # Only the first categorical column decides whether a pie is readable.
        if categorical[0].unique_count <= 10:
            suggestions.append("pie chart")
    elif metric_count >= 2:
        suggestions += ["scatter plot", "heatmap"]

    for dimension in categorical:
        if dimension.unique_count > 5:
            suggestions.append("treemap")
            break

    return suggestions


def _candidates_single_numeric(col: DataColumn, row_count: int) -> list[str]:
    """Build suggestions for data that consists of one numeric column.

    KPI and gauge are always offered; a histogram is placed ahead of them
    when there are more than 20 rows and the column has more than 10
    distinct values.
    """
    # row_count is tested first, so unique_count is only read for larger
    # result sets.
    show_distribution = row_count > 20 and col.unique_count > 10
    leading = ["histogram"] if show_distribution else []
    return leading + ["big number / KPI", "gauge chart"]


def _candidates_multi_numeric(
    numeric: list[DataColumn],
    categorical: list[DataColumn],
) -> list[str]:
    """Build suggestions for data with several numeric columns.

    A scatter plot is always offered; a bubble chart is added for three or
    more numeric columns, and a heatmap when any categorical column exists.
    """
    optional = (
        ("bubble chart", len(numeric) >= 3),
        ("heatmap", bool(categorical)),
    )
    return ["scatter plot", *(label for label, wanted in optional if wanted)]
Comment thread
msyavuz marked this conversation as resolved.


# Maps each candidate string to a canonical category for dedup
# against the current viz_type.
# The categories use the same vocabulary as the values of _VIZ_CATEGORY,
# which is what allows _filter_candidates to compare the two. Lookups use
# .get(), so a candidate missing from this table yields None.
_CANDIDATE_CATEGORY: dict[str, str] = {
    "line chart": "line",
    "multi-line chart": "line",
    "area chart": "area",
    "bar chart": "bar",
    "scatter plot": "scatter",
    "bubble chart": "bubble",
    "pie chart": "pie",
    "treemap": "treemap",
    "heatmap": "heatmap",
    "big number / KPI": "kpi",
    "gauge chart": "gauge",
    "histogram": "histogram",
    "table": "table",
}


def _filter_candidates(
    candidates: list[str],
    current_category: str,
) -> list[str]:
    """Return the first distinct candidates outside ``current_category``.

    Input order is preserved, a repeated name is kept once, names whose
    canonical category equals ``current_category`` are dropped, and
    collection stops as soon as ``_MAX_RECOMMENDATIONS`` names are kept.
    """
    kept: list[str] = []
    for name in candidates:
        already_kept = name in kept
        if already_kept or _CANDIDATE_CATEGORY.get(name) == current_category:
            continue
        kept.append(name)
        if len(kept) >= _MAX_RECOMMENDATIONS:
            break
    return kept


def _sanitize_chart_data_for_llm_context(chart_data: ChartData) -> ChartData:
"""Wrap chart data read-path descriptive fields before LLM exposure."""
Expand Down Expand Up @@ -620,22 +787,26 @@ async def get_chart_data( # noqa: C901
)

# Create rich column metadata
coltypes = query_result.get("coltypes", [])
columns = []
for col_name in raw_columns:
for idx, col_name in enumerate(raw_columns):
# Sample some values for metadata
sample_values = [
row.get(col_name)
for row in data[:3]
if row.get(col_name) is not None
]

# Infer data type
# Use SQL-derived GenericDataType when available,
# fall back to Python isinstance heuristic
data_type = "string"
if sample_values:
if all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
elif all(isinstance(v, bool) for v in sample_values):
if coltypes:
data_type = _GENERIC_TYPE_MAP.get(coltypes[idx], "string")
elif sample_values:
if all(isinstance(v, bool) for v in sample_values):
data_type = "boolean"
elif all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"

columns.append(
DataColumn(
Expand Down Expand Up @@ -678,13 +849,11 @@ async def get_chart_data( # noqa: C901
else:
insights.append("Fresh data retrieved from database")

recommended_visualizations = []
if any(
"time" in col.lower() or "date" in col.lower() for col in raw_columns
):
recommended_visualizations.extend(["line chart", "time series"])
if len(raw_columns) <= 3:
recommended_visualizations.extend(["bar chart", "scatter plot"])
recommended_visualizations = _recommend_visualizations(
viz_type=chart.viz_type or "unknown",
columns=columns,
row_count=len(data),
)

# Performance metadata with cache awareness
execution_time = int((time.time() - start_time) * 1000)
Expand Down
134 changes: 134 additions & 0 deletions tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@
PerformanceMetadata,
)
from superset.mcp_service.chart.tool.get_chart_data import (
_GENERIC_TYPE_MAP,
_MAX_RECOMMENDATIONS,
_recommend_visualizations,
_sanitize_chart_data_for_llm_context,
)
from superset.mcp_service.utils import sanitize_for_llm_context
from superset.mcp_service.utils.sanitization import LLM_CONTEXT_ESCAPED_CLOSE_DELIMITER
from superset.utils.core import GenericDataType


def _collect_groupby_extras(
Expand Down Expand Up @@ -988,3 +992,133 @@ def test_compile_chart_security_exception_from_validate(self):
)

mock_command.run.assert_not_called()


# ---------------------------------------------------------------------------
# Tests for _recommend_visualizations
# ---------------------------------------------------------------------------


def _col(
    name: str,
    data_type: str = "string",
    unique_count: int = 5,
    null_count: int = 0,
) -> DataColumn:
    """Create a DataColumn fixture whose display name equals its name."""
    fields = {
        "name": name,
        "display_name": name,
        "data_type": data_type,
        "sample_values": [],
        "null_count": null_count,
        "unique_count": unique_count,
    }
    return DataColumn(**fields)


def test_recommend_temporal_and_numeric_suggests_line_chart():
    """A time column plus a metric yields line and area suggestions."""
    suggestions = _recommend_visualizations(
        "table",
        [_col("created_at", "temporal"), _col("revenue", "numeric")],
        row_count=50,
    )
    assert {"line chart", "area chart"}.issubset(suggestions)


def test_recommend_categorical_and_numeric_suggests_bar_chart():
    """A dimension plus a metric offers a bar chart to a line-chart user."""
    dimension = _col("region", "string", unique_count=5)
    metric = _col("sales", "numeric")
    suggestions = _recommend_visualizations(
        "echarts_timeseries_line", [dimension, metric], row_count=50
    )
    assert "bar chart" in suggestions


def test_recommend_excludes_current_viz_type():
    """A chart that is already a line chart is not told to use a line chart."""
    time_col = _col("created_at", "temporal")
    metric = _col("revenue", "numeric")
    suggestions = _recommend_visualizations(
        "echarts_timeseries_line", [time_col, metric], row_count=50
    )
    assert "line chart" not in suggestions


def test_recommend_multiple_numeric_suggests_scatter():
    """Purely numeric data with several columns offers a scatter plot."""
    metrics = [_col(label, "numeric") for label in ("height", "weight", "age")]
    suggestions = _recommend_visualizations("table", metrics, row_count=100)
    assert "scatter plot" in suggestions


def test_recommend_single_numeric_suggests_kpi():
    """A lone numeric value is best shown as a big number / KPI."""
    suggestions = _recommend_visualizations(
        "table", [_col("total_revenue", "numeric")], row_count=1
    )
    assert "big number / KPI" in suggestions


def test_recommend_all_strings_falls_back():
    """Without any numeric column the generic fallback is used."""
    text_columns = [_col("name", "string"), _col("address", "string")]
    suggestions = _recommend_visualizations("pie", text_columns, row_count=100)
    assert any(name in suggestions for name in ("table", "bar chart"))


def test_recommend_high_cardinality_no_pie():
    """A dimension with hundreds of distinct values must not yield a pie."""
    identifier = _col("user_id", "string", unique_count=900)
    metric = _col("score", "numeric")
    suggestions = _recommend_visualizations(
        "table", [identifier, metric], row_count=1000
    )
    assert "pie chart" not in suggestions


def test_recommend_caps_at_max():
    """The number of suggestions never exceeds _MAX_RECOMMENDATIONS."""
    columns = [_col("ts", "temporal")]
    columns.extend(_col(label, "numeric") for label in ("a", "b"))
    suggestions = _recommend_visualizations("table", columns, row_count=100)
    assert _MAX_RECOMMENDATIONS >= len(suggestions)


def test_recommend_empty_columns_returns_table():
    """With no column metadata the only suggestion is a table."""
    assert _recommend_visualizations("table", [], row_count=0) == ["table"]


def test_recommend_pie_only_for_low_cardinality():
    """25 distinct categories is already too many for a pie chart."""
    dimension = _col("department", "string", unique_count=25)
    metric = _col("headcount", "numeric")
    suggestions = _recommend_visualizations(
        "table", [dimension, metric], row_count=100
    )
    assert "pie chart" not in suggestions


def test_recommend_temporal_few_rows_prefers_bar():
    """A time series with only three points gets a bar chart, not a line."""
    suggestions = _recommend_visualizations(
        "table",
        [_col("date", "temporal"), _col("revenue", "numeric")],
        row_count=3,
    )
    assert "bar chart" in suggestions
    assert "line chart" not in suggestions


def test_recommend_single_numeric_high_cardinality_suggests_histogram():
    """Many rows of a varied numeric column should offer a histogram."""
    salary = _col("salary", "numeric", unique_count=500)
    suggestions = _recommend_visualizations("table", [salary], row_count=1000)
    assert "histogram" in suggestions


def test_coltypes_populates_data_type():
    """Each GenericDataType member maps to its lowercase data_type label."""
    expected_labels = {
        GenericDataType.NUMERIC: "numeric",
        GenericDataType.STRING: "string",
        GenericDataType.TEMPORAL: "temporal",
        GenericDataType.BOOLEAN: "boolean",
    }
    for generic_type, label in expected_labels.items():
        assert _GENERIC_TYPE_MAP[generic_type] == label


def test_bool_isinstance_check_before_int():
    """bool is a subclass of int; verify bool check takes priority in fallback.

    NOTE(review): this mirrors the isinstance heuristic used when coltypes
    is unavailable rather than calling it, so it documents the required
    ordering but cannot catch a regression in the production code itself.
    """

    def infer(sample_values):
        # Same order as the fallback: bool first, then int/float.
        if all(isinstance(v, bool) for v in sample_values):
            return "boolean"
        if all(isinstance(v, (int, float)) for v in sample_values):
            return "numeric"
        return "string"

    # The premise that makes the ordering matter.
    assert isinstance(True, int)

    assert infer([True, False, True]) == "boolean"
    # Plain numbers, and booleans mixed with numbers, are still numeric.
    assert infer([1, 2.5, 3]) == "numeric"
    assert infer([True, 2]) == "numeric"
    # Anything else keeps the default.
    assert infer(["a", "b"]) == "string"
Loading