Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions superset/mcp_service/chart/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,10 @@ class ColumnRef(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
# No regex pattern: sanitize_name() already blocks XSS/SQL injection;
# many valid column names (digit-prefixed, locale chars, etc.) would
# be rejected by a strict pattern while posing no security risk.
# Use get_dataset_info to find exact column names.
validation_alias=AliasChoices("name", "column_name"),
)
label: str | None = Field(None, max_length=500)
Expand Down Expand Up @@ -743,7 +746,10 @@ class FilterConfig(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
# No regex pattern: sanitize_name() already blocks XSS/SQL injection;
# many valid column names (digit-prefixed, locale chars, etc.) would
# be rejected by a strict pattern while posing no security risk.
# Use get_dataset_info to find exact column names.
validation_alias=AliasChoices("column", "col"),
)
op: Literal[
Expand Down Expand Up @@ -775,7 +781,9 @@ def sanitize_column(cls, v: str) -> str:
"""Sanitize filter column name to prevent injection attacks."""
# sanitize_user_input raises ValueError when allow_empty=False (default)
# so the return value is guaranteed to be a non-None str
return sanitize_user_input(v, "Filter column", max_length=255) # type: ignore[return-value]
return sanitize_user_input( # type: ignore[return-value]
v, "Filter column", max_length=255, check_sql_keywords=True
)

@field_validator("value")
@classmethod
Expand Down Expand Up @@ -1082,8 +1090,19 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
),
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
# No regex pattern — see field description above.
)
Comment thread
aminghadersohi marked this conversation as resolved.

@field_validator("temporal_column")
@classmethod
def sanitize_temporal_column(cls, v: str | None) -> str | None:
    """Run a provided temporal column name through ``sanitize_user_input``.

    ``None`` (no temporal column supplied) is passed through untouched.
    Any other value is delegated to ``sanitize_user_input`` with a
    255-character limit and SQL keyword checking switched on; whatever
    that helper returns or raises is propagated as-is.
    """
    if v is not None:
        return sanitize_user_input(
            v,
            "Temporal column",
            max_length=255,
            check_sql_keywords=True,
        )
    return None

time_grain: TimeGrain | None = Field(
None,
description=(
Expand Down
83 changes: 82 additions & 1 deletion superset/mcp_service/chart/tool/generate_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ async def generate_chart( # noqa: C901
- Set save_chart=True to permanently save the chart
- LLM clients MUST display returned chart URL to users
- Use numeric dataset ID or UUID (NOT schema.table_name format)
- MUST include chart_type in config (either 'xy' or 'table')
- MUST include chart_type in config (one of: 'xy', 'table', 'pie',
'big_number', 'pivot_table', 'mixed_timeseries', 'handlebars')

IMPORTANT: The 'chart_type' field in the config is a DISCRIMINATOR that determines
which chart configuration schema to use. It MUST be included and MUST match the
Expand All @@ -200,6 +201,86 @@ async def generate_chart( # noqa: C901
}
```


Example usage for Pie chart:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "pie",
"dimension": {"name": "product_category"},
Comment thread
aminghadersohi marked this conversation as resolved.
"metric": {"name": "revenue", "aggregate": "SUM"},
"donut": false
}
}
```

Example usage for Big Number (no trendline):
```json
{
"dataset_id": 123,
"config": {
"chart_type": "big_number",
"metric": {"name": "total_sales", "aggregate": "SUM"}
}
}
```

Example usage for Big Number with trendline:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "big_number",
"metric": {"name": "revenue", "aggregate": "SUM"},
"temporal_column": "order_date",
"time_grain": "P1M",
"show_trendline": true
}
}
```

Example usage for Pivot Table:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "pivot_table",
"rows": [{"name": "region"}],
"columns": [{"name": "product_category"}],
"metrics": [{"name": "revenue", "aggregate": "SUM"}]
}
}
```

Example usage for Mixed Timeseries:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "mixed_timeseries",
"x": {"name": "order_date"},
"y": [{"name": "revenue", "aggregate": "SUM"}],
"primary_kind": "line",
"y_secondary": [{"name": "order_count", "aggregate": "COUNT"}],
"secondary_kind": "bar"
}
}
```

Example usage for Handlebars:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "handlebars",
"handlebars_template": "{{#each data}}{{this.name}}{{/each}}",
"groupby": [{"name": "product"}],
"metrics": [{"name": "revenue", "aggregate": "SUM"}]
}
}
```

Example usage for Table chart:
```json
{
Expand Down
61 changes: 50 additions & 11 deletions superset/mcp_service/chart/validation/schema_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,38 @@ def _pre_validate_mixed_timeseries_config(

return True, None

@staticmethod
def _format_single_error(err: Dict[str, Any]) -> tuple[str, str]:
"""Return (detail_message, optional_suggestion) for one pydantic error."""
loc_parts = [str(p) for p in err.get("loc", [])]
loc = " -> ".join(loc_parts)
msg = err.get("msg", "Validation failed")
err_type = err.get("type", "")
field = loc_parts[-1] if loc_parts else "field"

if err_type == "string_pattern_mismatch":
return (
f"'{field}' value contains disallowed characters. "
"Column names must not contain HTML, script tags, or SQL "
"injection patterns. Use the exact column name from your dataset.",
"Use get_dataset_info to find exact column names",
)
if err_type == "literal_error":
# Preserve the pydantic message ("Input should be ...") which is
# already human-readable; just prefix with the field name for context.
return f"'{field}': {msg}", ""
if err_type == "missing":
return (
f"Required field '{field}' is missing",
"Check the chart_type examples in the tool description",
)
if err_type == "value_error":
return (
f"{loc}: {msg}",
"Use get_dataset_info to verify column names and types",
)
return f"{loc}: {msg}", ""

@staticmethod
def _enhance_validation_error(
error: PydanticValidationError, request_data: Dict[str, Any]
Expand Down Expand Up @@ -609,22 +641,29 @@ def _enhance_validation_error(
error_code="BIG_NUMBER_VALIDATION_ERROR",
)

# Default enhanced error
# Default enhanced error: build actionable per-field messages
error_details = []
for err in errors[:3]: # Show first 3 errors
loc = " -> ".join(str(location) for location in err.get("loc", []))
msg = err.get("msg", "Validation failed")
error_details.append(f"{loc}: {msg}")
extra_suggestions: list[str] = []
for err in errors[:5]: # Surface up to 5 errors
detail, suggestion = SchemaValidator._format_single_error(err)
error_details.append(detail)
if suggestion:
extra_suggestions.append(suggestion)

return ChartGenerationError(
error_type="validation_error",
message="Chart configuration validation failed",
details="; ".join(error_details),
suggestions=[
"Check that all required fields are present",
"Ensure field types match the schema",
"Use get_dataset_info to verify column names",
"Refer to the API documentation for field requirements",
],
suggestions=list(
dict.fromkeys(
[
"Check that all required fields are present",
"Ensure field types match the schema",
"Use get_dataset_info to verify column names",
"Refer to the API documentation for field requirements",
]
+ extra_suggestions
)
),
error_code="VALIDATION_ERROR",
)
96 changes: 96 additions & 0 deletions tests/unit_tests/mcp_service/chart/test_chart_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,99 @@ def test_client_warnings_discarded_even_when_server_also_warns(self) -> None:
assert len(req.sanitization_warnings) == 1
assert "chart_name" in req.sanitization_warnings[0]
assert "injected" not in req.sanitization_warnings[0]


class TestColumnRefNameRelaxedPattern:
    """Checks for ColumnRef.name now that no strict regex pattern applies.

    Names that the former pattern turned away (for example digit-prefixed
    ones such as "1Q_revenue") are expected to validate, while script tags,
    event-handler payloads, SQL statements and empty strings are still
    expected to raise ValidationError.
    """

    def test_digit_prefixed_name_accepted(self) -> None:
        """A leading digit no longer disqualifies a column name."""
        assert ColumnRef(name="1Q_revenue").name == "1Q_revenue"

    def test_name_with_hyphen_accepted(self) -> None:
        """A hyphen inside the name is kept verbatim."""
        assert ColumnRef(name="order-date").name == "order-date"

    def test_name_with_dot_accepted(self) -> None:
        """A dotted name is kept verbatim."""
        assert ColumnRef(name="schema.column").name == "schema.column"

    def test_name_with_spaces_accepted(self) -> None:
        """A name containing a space is kept verbatim."""
        assert ColumnRef(name="Total Revenue").name == "Total Revenue"

    def test_script_tag_blocked(self) -> None:
        """A script-tag payload must fail validation."""
        payload = "<script>alert(1)</script>"
        with pytest.raises(ValidationError):
            ColumnRef(name=payload)

    def test_event_handler_injection_blocked(self) -> None:
        """An ``on...=`` event-handler payload must fail validation."""
        payload = "col onclick=alert(1)"
        with pytest.raises(ValidationError):
            ColumnRef(name=payload)

    def test_sql_keyword_blocked(self) -> None:
        """A stacked SQL statement must fail validation."""
        payload = "1; DROP TABLE users; --"
        with pytest.raises(ValidationError):
            ColumnRef(name=payload)

    def test_empty_name_blocked(self) -> None:
        """An empty string must fail validation."""
        with pytest.raises(ValidationError):
            ColumnRef(name="")

    def test_table_chart_with_digit_prefixed_column(self) -> None:
        """A table config with a digit-prefixed column builds a request."""
        table_config = {
            "chart_type": "table",
            "columns": [
                {"name": "1Q_revenue"},
                {"name": "product_name"},
            ],
        }
        request = GenerateChartRequest(dataset_id=1, config=table_config)
        assert request.config.chart_type == "table"

    def test_xy_chart_with_hyphenated_column(self) -> None:
        """An xy config with hyphenated column names builds a request."""
        xy_config = {
            "chart_type": "xy",
            "x": {"name": "order-date"},
            "y": [{"name": "1Q-revenue", "aggregate": "SUM"}],
        }
        request = GenerateChartRequest(dataset_id=1, config=xy_config)
        assert request.config.chart_type == "xy"


class TestFilterConfigColumnRelaxedPattern:
    """Checks for FilterConfig.column now that no strict regex pattern applies."""

    def test_digit_prefixed_filter_column_accepted(self) -> None:
        """A filter column starting with a digit is kept verbatim."""
        from superset.mcp_service.chart.schemas import FilterConfig

        filter_config = FilterConfig(column="1Q_flag", op="=", value="active")
        assert filter_config.column == "1Q_flag"

    def test_hyphenated_filter_column_accepted(self) -> None:
        """A filter column containing a hyphen is kept verbatim."""
        from superset.mcp_service.chart.schemas import FilterConfig

        filter_config = FilterConfig(
            column="order-status", op="=", value="shipped"
        )
        assert filter_config.column == "order-status"

    def test_sql_injection_in_filter_column_blocked(self) -> None:
        """A stacked SQL statement as the filter column must fail validation."""
        from superset.mcp_service.chart.schemas import FilterConfig

        payload = "col; DROP TABLE users; --"
        with pytest.raises(ValidationError):
            FilterConfig(column=payload, op="=", value="x")
Loading