diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 1fdb4f43ab6d..b09b6dc0556e 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -669,7 +669,10 @@ class ColumnRef(BaseModel): ..., min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", + # No regex pattern: sanitize_name() already blocks XSS/SQL injection; + # many valid column names (digit-prefixed, locale chars, etc.) would + # be rejected by a strict pattern while posing no security risk. + # Use get_dataset_info to find exact column names. validation_alias=AliasChoices("name", "column_name"), ) label: str | None = Field(None, max_length=500) @@ -743,7 +746,10 @@ class FilterConfig(BaseModel): ..., min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", + # No regex pattern: sanitize_name() already blocks XSS/SQL injection; + # many valid column names (digit-prefixed, locale chars, etc.) would + # be rejected by a strict pattern while posing no security risk. + # Use get_dataset_info to find exact column names. validation_alias=AliasChoices("column", "col"), ) op: Literal[ @@ -775,7 +781,9 @@ def sanitize_column(cls, v: str) -> str: """Sanitize filter column name to prevent injection attacks.""" # sanitize_user_input raises ValueError when allow_empty=False (default) # so the return value is guaranteed to be a non-None str - return sanitize_user_input(v, "Filter column", max_length=255) # type: ignore[return-value] + return sanitize_user_input( # type: ignore[return-value] + v, "Filter column", max_length=255, check_sql_keywords=True + ) @field_validator("value") @classmethod @@ -1082,8 +1090,19 @@ class BigNumberChartConfig(UnknownFieldCheckMixin): ), min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", + # No regex pattern — see field description above. 
) + + @field_validator("temporal_column") + @classmethod + def sanitize_temporal_column(cls, v: str | None) -> str | None: + """Sanitize temporal column name to prevent XSS and SQL injection.""" + if v is None: + return None + return sanitize_user_input( + v, "Temporal column", max_length=255, check_sql_keywords=True + ) + time_grain: TimeGrain | None = Field( None, description=( diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index 646ac4d4c2a7..697698bbda03 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -175,7 +175,8 @@ async def generate_chart( # noqa: C901 - Set save_chart=True to permanently save the chart - LLM clients MUST display returned chart URL to users - Use numeric dataset ID or UUID (NOT schema.table_name format) - - MUST include chart_type in config (either 'xy' or 'table') + - MUST include chart_type in config (one of: 'xy', 'table', 'pie', + 'big_number', 'pivot_table', 'mixed_timeseries', 'handlebars') IMPORTANT: The 'chart_type' field in the config is a DISCRIMINATOR that determines which chart configuration schema to use. 
It MUST be included and MUST match the @@ -200,6 +201,86 @@ async def generate_chart( # noqa: C901 } ``` + + Example usage for Pie chart: + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "pie", + "dimension": {"name": "product_category"}, + "metric": {"name": "revenue", "aggregate": "SUM"}, + "donut": false + } + } + ``` + + Example usage for Big Number (no trendline): + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "big_number", + "metric": {"name": "total_sales", "aggregate": "SUM"} + } + } + ``` + + Example usage for Big Number with trendline: + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "big_number", + "metric": {"name": "revenue", "aggregate": "SUM"}, + "temporal_column": "order_date", + "time_grain": "P1M", + "show_trendline": true + } + } + ``` + + Example usage for Pivot Table: + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "pivot_table", + "rows": [{"name": "region"}], + "columns": [{"name": "product_category"}], + "metrics": [{"name": "revenue", "aggregate": "SUM"}] + } + } + ``` + + Example usage for Mixed Timeseries: + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "mixed_timeseries", + "x": {"name": "order_date"}, + "y": [{"name": "revenue", "aggregate": "SUM"}], + "primary_kind": "line", + "y_secondary": [{"name": "order_count", "aggregate": "COUNT"}], + "secondary_kind": "bar" + } + } + ``` + + Example usage for Handlebars: + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "handlebars", + "handlebars_template": "{{#each data}}{{this.name}}{{/each}}", + "groupby": [{"name": "product"}], + "metrics": [{"name": "revenue", "aggregate": "SUM"}] + } + } + ``` + Example usage for Table chart: ```json { diff --git a/superset/mcp_service/chart/validation/schema_validator.py b/superset/mcp_service/chart/validation/schema_validator.py index 7cae450ff599..5fc5af709385 100644 --- a/superset/mcp_service/chart/validation/schema_validator.py +++ 
b/superset/mcp_service/chart/validation/schema_validator.py @@ -525,6 +525,38 @@ def _pre_validate_mixed_timeseries_config( return True, None + @staticmethod + def _format_single_error(err: Dict[str, Any]) -> tuple[str, str]: + """Return (detail_message, optional_suggestion) for one pydantic error.""" + loc_parts = [str(p) for p in err.get("loc", [])] + loc = " -> ".join(loc_parts) + msg = err.get("msg", "Validation failed") + err_type = err.get("type", "") + field = loc_parts[-1] if loc_parts else "field" + + if err_type == "string_pattern_mismatch": + return ( + f"'{field}' value contains disallowed characters. " + "Column names must not contain HTML, script tags, or SQL " + "injection patterns. Use the exact column name from your dataset.", + "Use get_dataset_info to find exact column names", + ) + if err_type == "literal_error": + # Preserve the pydantic message ("Input should be ...") which is + # already human-readable; just prefix with the field name for context. + return f"'{field}': {msg}", "" + if err_type == "missing": + return ( + f"Required field '{field}' is missing", + "Check the chart_type examples in the tool description", + ) + if err_type == "value_error": + return ( + f"{loc}: {msg}", + "Use get_dataset_info to verify column names and types", + ) + return f"{loc}: {msg}", "" + @staticmethod def _enhance_validation_error( error: PydanticValidationError, request_data: Dict[str, Any] @@ -609,22 +641,29 @@ def _enhance_validation_error( error_code="BIG_NUMBER_VALIDATION_ERROR", ) - # Default enhanced error + # Default enhanced error: build actionable per-field messages error_details = [] - for err in errors[:3]: # Show first 3 errors - loc = " -> ".join(str(location) for location in err.get("loc", [])) - msg = err.get("msg", "Validation failed") - error_details.append(f"{loc}: {msg}") + extra_suggestions: list[str] = [] + for err in errors[:5]: # Surface up to 5 errors + detail, suggestion = SchemaValidator._format_single_error(err) + 
error_details.append(detail) + if suggestion: + extra_suggestions.append(suggestion) return ChartGenerationError( error_type="validation_error", message="Chart configuration validation failed", details="; ".join(error_details), - suggestions=[ - "Check that all required fields are present", - "Ensure field types match the schema", - "Use get_dataset_info to verify column names", - "Refer to the API documentation for field requirements", - ], + suggestions=list( + dict.fromkeys( + [ + "Check that all required fields are present", + "Ensure field types match the schema", + "Use get_dataset_info to verify column names", + "Refer to the API documentation for field requirements", + ] + + extra_suggestions + ) + ), error_code="VALIDATION_ERROR", ) diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py index 1da9d1bc3a74..19d06601912f 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py @@ -778,3 +778,99 @@ def test_client_warnings_discarded_even_when_server_also_warns(self) -> None: assert len(req.sanitization_warnings) == 1 assert "chart_name" in req.sanitization_warnings[0] assert "injected" not in req.sanitization_warnings[0] + + +class TestColumnRefNameRelaxedPattern: + """ColumnRef.name no longer enforces a strict regex pattern. + + Many valid database column names were previously rejected: + - Names starting with a digit (e.g. "1Q_revenue") + - Names with locale-specific characters + The field_validator sanitize_name() still blocks XSS and SQL injection. 
+ """ + + def test_digit_prefixed_name_accepted(self) -> None: + """Column names starting with a digit must now be accepted.""" + col = ColumnRef(name="1Q_revenue") + assert col.name == "1Q_revenue" + + def test_name_with_hyphen_accepted(self) -> None: + col = ColumnRef(name="order-date") + assert col.name == "order-date" + + def test_name_with_dot_accepted(self) -> None: + col = ColumnRef(name="schema.column") + assert col.name == "schema.column" + + def test_name_with_spaces_accepted(self) -> None: + col = ColumnRef(name="Total Revenue") + assert col.name == "Total Revenue" + + def test_script_tag_blocked(self) -> None: + """sanitize_name() blocks script-tag XSS: nh3 strips the entire script + element (tag + content) leaving an empty string, which the empty-value + guard then rejects.""" + with pytest.raises(ValidationError): + ColumnRef(name="<script>alert(1)</script>") + + def test_event_handler_injection_blocked(self) -> None: + """sanitize_name() rejects event-handler injection patterns (on...=).""" + with pytest.raises(ValidationError): + ColumnRef(name="col onclick=alert(1)") + + def test_sql_keyword_blocked(self) -> None: + """check_sql_keywords=True still blocks pure SQL statements.""" + with pytest.raises(ValidationError): + ColumnRef(name="1; DROP TABLE users; --") + + def test_empty_name_blocked(self) -> None: + with pytest.raises(ValidationError): + ColumnRef(name="") + + def test_table_chart_with_digit_prefixed_column(self) -> None: + """End-to-end: digit-prefixed column passes through GenerateChartRequest.""" + req = GenerateChartRequest( + dataset_id=1, + config={ + "chart_type": "table", + "columns": [ + {"name": "1Q_revenue"}, + {"name": "product_name"}, + ], + }, + ) + assert req.config.chart_type == "table" + + def test_xy_chart_with_hyphenated_column(self) -> None: + req = GenerateChartRequest( + dataset_id=1, + config={ + "chart_type": "xy", + "x": {"name": "order-date"}, + "y": [{"name": "1Q-revenue", "aggregate": "SUM"}], + }, + ) + assert req.config.chart_type == 
"xy" + + +class TestFilterConfigColumnRelaxedPattern: + """FilterConfig.column no longer enforces a strict regex pattern.""" + + def test_digit_prefixed_filter_column_accepted(self) -> None: + from superset.mcp_service.chart.schemas import FilterConfig + + f = FilterConfig(column="1Q_flag", op="=", value="active") + assert f.column == "1Q_flag" + + def test_hyphenated_filter_column_accepted(self) -> None: + from superset.mcp_service.chart.schemas import FilterConfig + + f = FilterConfig(column="order-status", op="=", value="shipped") + assert f.column == "order-status" + + def test_sql_injection_in_filter_column_blocked(self) -> None: + """FilterConfig.sanitize_column uses check_sql_keywords=True.""" + from superset.mcp_service.chart.schemas import FilterConfig + + with pytest.raises(ValidationError): + FilterConfig(column="col; DROP TABLE users; --", op="=", value="x")