-
Notifications
You must be signed in to change notification settings - Fork 17.2k
fix(recommendation): Fix chart recommendation #39886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -52,10 +52,177 @@ | |
| build_oauth2_redirect_message, | ||
| OAUTH2_CONFIG_ERROR_MESSAGE, | ||
| ) | ||
| from superset.utils.core import merge_extra_filters | ||
| from superset.utils.core import GenericDataType, merge_extra_filters | ||
|
|
||
logger = logging.getLogger(__name__)

# Translates GenericDataType members into the lowercase type labels stored
# on column metadata. The column-metadata code that reads this table falls
# back to "string" when a type is not listed here.
_GENERIC_TYPE_MAP: dict[int, str] = {
    GenericDataType.NUMERIC: "numeric",
    GenericDataType.STRING: "string",
    GenericDataType.TEMPORAL: "temporal",
    GenericDataType.BOOLEAN: "boolean",
}
|
|
||
# Maps Superset viz_type strings to canonical categories so we can
# avoid recommending a chart type the user already has.
# Several viz types share one category (for example, every line-style
# time series maps to "line"). A viz_type that is missing from this table
# is used as its own category by the lookup in _recommend_visualizations.
_VIZ_CATEGORY: dict[str, str] = {
    "echarts_timeseries_line": "line",
    "echarts_timeseries_smooth": "line",
    "echarts_timeseries_step": "line",
    "echarts_timeseries": "line",
    "echarts_timeseries_bar": "bar",
    "echarts_area": "area",
    "echarts_timeseries_scatter": "scatter",
    "mixed_timeseries": "line",
    "table": "table",
    "pie": "pie",
    "big_number": "kpi",
    "big_number_total": "kpi",
    "pop_kpi": "kpi",
    "dist_bar": "bar",
    "line": "line",
    "area": "area",
    "scatter": "scatter",
    "bubble": "bubble",
    "treemap_v2": "treemap",
    "sunburst_v2": "treemap",
    "heatmap_v2": "heatmap",
    "gauge_chart": "gauge",
    "funnel": "funnel",
    "histogram": "histogram",
    "histogram_v2": "histogram",
    "box_plot": "box_plot",
    "world_map": "map",
    "pivot_table_v2": "table",
}
|
|
||
# Upper bound on how many recommendations _filter_candidates returns.
_MAX_RECOMMENDATIONS = 4
|
|
||
|
|
||
def _recommend_visualizations(
    viz_type: str,
    columns: list[DataColumn],
    row_count: int,
) -> list[str]:
    """Recommend chart types for a result set.

    Args:
        viz_type: The chart's current viz_type; its category is left out
            of the recommendations.
        columns: Column metadata for the result set.
        row_count: Number of rows in the result set.

    Returns:
        Human-readable chart names. With no columns the answer is always
        ``["table"]``; otherwise the candidates are filtered and capped by
        ``_filter_candidates``.
    """
    # Nothing to analyse: a table is the only sensible suggestion.
    if not columns:
        return ["table"]

    # An unmapped viz_type acts as its own category.
    active_category = _VIZ_CATEGORY.get(viz_type, viz_type)

    # Generic fallback when no column pattern matched.
    suggestions = _build_candidates(columns, row_count) or ["table", "bar chart"]

    return _filter_candidates(suggestions, active_category)
|
|
||
|
|
||
def _build_candidates(
    columns: list[DataColumn],
    row_count: int,
) -> list[str]:
    """Choose the candidate list that fits the mix of column types.

    Columns are grouped by ``data_type``: "temporal", "numeric", and
    categorical ("string" or "boolean"). Any other ``data_type`` is ignored.
    Every pattern needs at least one numeric column; precedence is
    temporal + numeric, then categorical + numeric, then two or more
    numeric columns, then a single numeric column.

    Returns:
        The candidate chart names, or ``[]`` when no numeric column exists.
    """
    # Bucket the columns in one pass. The review thread on this change
    # noted that a chart typically has 2-10 columns, so the grouping cost
    # is negligible whichever way it is written.
    temporal_cols: list[DataColumn] = []
    numeric_cols: list[DataColumn] = []
    categorical_cols: list[DataColumn] = []
    for column in columns:
        kind = column.data_type
        if kind == "temporal":
            temporal_cols.append(column)
        elif kind == "numeric":
            numeric_cols.append(column)
        elif kind in ("string", "boolean"):
            categorical_cols.append(column)

    # Every recommendation pattern requires a numeric column.
    if not numeric_cols:
        return []

    if temporal_cols:
        return _candidates_temporal_numeric(numeric_cols, row_count)
    if categorical_cols:
        return _candidates_categorical_numeric(numeric_cols, categorical_cols)
    if len(numeric_cols) >= 2:
        return _candidates_multi_numeric(numeric_cols, categorical_cols)

    # Only case left: exactly one numeric column and nothing else usable.
    return _candidates_single_numeric(numeric_cols[0], row_count)
|
msyavuz marked this conversation as resolved.
|
||
|
|
||
|
|
||
def _candidates_temporal_numeric(
    numeric: list[DataColumn], row_count: int
) -> list[str]:
    """Candidates for a time column combined with numeric column(s).

    Args:
        numeric: The numeric columns; only their count is inspected.
        row_count: Number of rows in the result set.

    Returns:
        Bar chart and table for fewer than five rows; otherwise line, area
        and bar charts, plus a multi-line chart when there is more than one
        numeric column.
    """
    # Few data points are better as a bar chart than a line
    if row_count < 5:
        return ["bar chart", "table"]

    multi_series = ["multi-line chart"] if len(numeric) > 1 else []
    return ["line chart", "area chart", "bar chart", *multi_series]
|
|
||
|
|
||
def _candidates_categorical_numeric(
    numeric: list[DataColumn],
    categorical: list[DataColumn],
) -> list[str]:
    """Candidates for categorical column(s) combined with numeric column(s).

    Args:
        numeric: The numeric columns; only their count is inspected.
        categorical: The categorical columns; their ``unique_count`` is read.

    Returns:
        Always starts with a bar chart. A pie chart is added for a single
        numeric column when the first categorical column has at most 10
        distinct values. Scatter plot and heatmap are added for two or more
        numeric columns. A treemap is added when any categorical column has
        more than 5 distinct values.
    """
    metric_count = len(numeric)
    suggestions = ["bar chart"]

    if metric_count == 1:
        # Only the first categorical column decides whether a pie is readable.
        if categorical[0].unique_count <= 10:
            suggestions.append("pie chart")
    elif metric_count >= 2:
        suggestions.extend(("scatter plot", "heatmap"))

    for dimension in categorical:
        if dimension.unique_count > 5:
            suggestions.append("treemap")
            break

    return suggestions
|
|
||
|
|
||
def _candidates_single_numeric(col: DataColumn, row_count: int) -> list[str]:
    """Candidates for a result set with a single numeric column.

    Args:
        col: The numeric column; its ``unique_count`` is read only when
            ``row_count`` exceeds 20.
        row_count: Number of rows in the result set.

    Returns:
        KPI and gauge suggestions, preceded by a histogram when there are
        more than 20 rows and more than 10 distinct values.
    """
    worth_binning = row_count > 20 and col.unique_count > 10
    leading = ["histogram"] if worth_binning else []
    return [*leading, "big number / KPI", "gauge chart"]
|
|
||
|
|
||
def _candidates_multi_numeric(
    numeric: list[DataColumn],
    categorical: list[DataColumn],
) -> list[str]:
    """Candidates for a result set with several numeric columns.

    Args:
        numeric: The numeric columns; only their count is inspected.
        categorical: The categorical columns; only emptiness is inspected.

    Returns:
        A scatter plot, then a bubble chart when there are at least three
        numeric columns, then a heatmap when any categorical column exists.
    """
    optional = (
        ("bubble chart", len(numeric) >= 3),
        ("heatmap", bool(categorical)),
    )
    return ["scatter plot", *(label for label, wanted in optional if wanted)]
|
msyavuz marked this conversation as resolved.
|
||
|
|
||
|
|
||
# Maps each candidate string to a canonical category for dedup
# against the current viz_type.
# Keys are the display names produced by the _candidates_* helpers and the
# fallback lists; values use the same category vocabulary as _VIZ_CATEGORY.
# _filter_candidates looks names up with .get(), so a name missing from this
# table yields None for the category comparison.
_CANDIDATE_CATEGORY: dict[str, str] = {
    "line chart": "line",
    "multi-line chart": "line",
    "area chart": "area",
    "bar chart": "bar",
    "scatter plot": "scatter",
    "bubble chart": "bubble",
    "pie chart": "pie",
    "treemap": "treemap",
    "heatmap": "heatmap",
    "big number / KPI": "kpi",
    "gauge chart": "gauge",
    "histogram": "histogram",
    "table": "table",
}
|
|
||
|
|
||
def _filter_candidates(
    candidates: list[str],
    current_category: str,
) -> list[str]:
    """Drop duplicates and the current chart's category, then cap the list.

    Args:
        candidates: Candidate chart names in priority order.
        current_category: Canonical category of the chart's current
            viz_type; candidates in this category are skipped.

    Returns:
        The surviving candidates in their original order, stopping once
        ``_MAX_RECOMMENDATIONS`` entries have been collected.
    """
    # A dict keeps insertion order, so it serves as an ordered set here.
    picked: dict[str, None] = {}
    for label in candidates:
        if label in picked or _CANDIDATE_CATEGORY.get(label) == current_category:
            continue
        picked[label] = None
        if len(picked) >= _MAX_RECOMMENDATIONS:
            break
    return list(picked)
|
|
||
|
|
||
| def _sanitize_chart_data_for_llm_context(chart_data: ChartData) -> ChartData: | ||
| """Wrap chart data read-path descriptive fields before LLM exposure.""" | ||
|
|
@@ -620,22 +787,26 @@ async def get_chart_data( # noqa: C901 | |
| ) | ||
|
|
||
| # Create rich column metadata | ||
| coltypes = query_result.get("coltypes", []) | ||
| columns = [] | ||
| for col_name in raw_columns: | ||
| for idx, col_name in enumerate(raw_columns): | ||
| # Sample some values for metadata | ||
| sample_values = [ | ||
| row.get(col_name) | ||
| for row in data[:3] | ||
| if row.get(col_name) is not None | ||
| ] | ||
|
|
||
| # Infer data type | ||
| # Use SQL-derived GenericDataType when available, | ||
| # fall back to Python isinstance heuristic | ||
| data_type = "string" | ||
| if sample_values: | ||
| if all(isinstance(v, (int, float)) for v in sample_values): | ||
| data_type = "numeric" | ||
| elif all(isinstance(v, bool) for v in sample_values): | ||
| if coltypes: | ||
| data_type = _GENERIC_TYPE_MAP.get(coltypes[idx], "string") | ||
| elif sample_values: | ||
| if all(isinstance(v, bool) for v in sample_values): | ||
| data_type = "boolean" | ||
| elif all(isinstance(v, (int, float)) for v in sample_values): | ||
| data_type = "numeric" | ||
|
|
||
| columns.append( | ||
| DataColumn( | ||
|
|
@@ -678,13 +849,11 @@ async def get_chart_data( # noqa: C901 | |
| else: | ||
| insights.append("Fresh data retrieved from database") | ||
|
|
||
| recommended_visualizations = [] | ||
| if any( | ||
| "time" in col.lower() or "date" in col.lower() for col in raw_columns | ||
| ): | ||
| recommended_visualizations.extend(["line chart", "time series"]) | ||
| if len(raw_columns) <= 3: | ||
| recommended_visualizations.extend(["bar chart", "scatter plot"]) | ||
| recommended_visualizations = _recommend_visualizations( | ||
| viz_type=chart.viz_type or "unknown", | ||
| columns=columns, | ||
| row_count=len(data), | ||
| ) | ||
|
|
||
| # Performance metadata with cache awareness | ||
| execution_time = int((time.time() - start_time) * 1000) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.