From d1e427e4f4f7cfbf3c5de1ab488468d5dddd4c46 Mon Sep 17 00:00:00 2001 From: kould Date: Sun, 3 May 2026 21:27:26 +0800 Subject: [PATCH] fix: support joined update and delete --- src/binder/delete.rs | 31 +--- src/binder/mod.rs | 12 +- src/binder/update.rs | 51 +++++ src/execution/dml/update.rs | 4 +- src/execution/dql/join/hash/full_join.rs | 5 +- src/execution/dql/join/hash/inner_join.rs | 5 +- src/execution/dql/join/hash/left_join.rs | 5 +- src/execution/dql/join/hash/right_join.rs | 5 +- src/execution/dql/join/hash_join.rs | 2 +- src/execution/dql/join/nested_loop_join.rs | 2 +- .../rule/normalization/pushdown_predicates.rs | 30 ++- tests/slt/joined_dml.slt | 175 ++++++++++++++++++ 12 files changed, 279 insertions(+), 48 deletions(-) create mode 100644 tests/slt/joined_dml.slt diff --git a/src/binder/delete.rs b/src/binder/delete.rs index 42703c20..49026f89 100644 --- a/src/binder/delete.rs +++ b/src/binder/delete.rs @@ -12,16 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_case_name, Binder, Source}; +use crate::binder::{lower_case_name, Binder}; use crate::errors::DatabaseError; use crate::planner::operator::delete::DeleteOperator; -use crate::planner::operator::table_scan::TableScanOperator; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; use itertools::Itertools; -use sqlparser::ast::{Expr, TableAlias, TableFactor, TableWithJoins}; +use sqlparser::ast::{Expr, TableFactor, TableWithJoins}; use std::sync::Arc; impl> Binder<'_, '_, T, A> { @@ -30,33 +29,19 @@ impl> Binder<'_, '_, T, A> from: &TableWithJoins, selection: &Option, ) -> Result { - if let TableFactor::Table { name, alias, .. } = &from.relation { + if let TableFactor::Table { name, .. } = &from.relation { let table_name: Arc = lower_case_name(name)?.into(); - let mut table_alias = None; - let mut alias_idents = None; - - if let Some(TableAlias { name, columns, .. }) = alias { - table_alias = Some(name.value.to_lowercase().into()); - alias_idents = Some(columns); - } - let Source::Table(table) = self + let table = self .context - .source_and_bind(table_name.clone(), table_alias.as_ref(), None, true)? - .ok_or(DatabaseError::TableNotFound)? - else { - unreachable!() - }; + .table(table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; let primary_keys = table .primary_keys() .iter() .map(|(_, column)| column.clone()) .collect_vec(); - let mut plan = TableScanOperator::build(table_name.clone(), table, true)?; - - if let Some(alias_idents) = alias_idents { - plan = - self.bind_alias(plan, alias_idents, table_alias.unwrap(), table_name.clone())?; - } + self.with_pk(table_name.clone()); + let mut plan = self.bind_table_ref(from)?; if let Some(predicate) = selection { plan = self.bind_where(plan, predicate)?; diff --git a/src/binder/mod.rs b/src/binder/mod.rs index df04b1ea..2b1210e6 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -573,11 +573,7 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' } Statement::Update(update) => { let table = &update.table; - if !table.joins.is_empty() { - unimplemented!() - } else { - self.bind_update(table, &update.selection, &update.assignments)? - } + self.bind_update(table, &update.selection, &update.assignments)? } Statement::Delete(delete) => { let from = match &delete.from { @@ -585,11 +581,7 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' }; let table = &from[0]; - if !table.joins.is_empty() { - unimplemented!() - } else { - self.bind_delete(table, &delete.selection)? - } + self.bind_delete(table, &delete.selection)? } Statement::Analyze(analyze) => { let table_name = analyze.table_name.as_ref().ok_or_else(|| { diff --git a/src/binder/update.rs b/src/binder/update.rs index 141349d0..3b99fb5e 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -16,7 +16,9 @@ use crate::binder::{ attach_span_from_sqlparser_span_if_absent, attach_span_if_absent, lower_case_name, Binder, }; use crate::errors::DatabaseError; +use crate::expression::visitor_mut::VisitorMut; use crate::expression::ScalarExpression; +use crate::planner::operator::project::ProjectOperator; use crate::planner::operator::update::UpdateOperator; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; @@ -29,6 +31,30 @@ use std::borrow::Cow; use std::slice; use std::sync::Arc; +struct UpdateExprTargetRemapper<'a> { + target_schema: &'a [crate::catalog::ColumnRef], +} + +impl VisitorMut<'_> for UpdateExprTargetRemapper<'_> { + fn visit_column_ref( + &mut self, + column: &mut crate::catalog::ColumnRef, + position: &mut usize, + ) -> Result<(), DatabaseError> { + let Some(target_position) = self + .target_schema + .iter() + .position(|target_column| target_column.same_column(column)) + else { + return Err(DatabaseError::UnsupportedStmt( + "joined UPDATE SET expressions can only reference target table columns".to_string(), + )); + }; + *position = target_position; + Ok(()) + } +} + impl> Binder<'_, '_, T, A> { fn single_ident_from_object_name(name: &ObjectName) -> Result<&Ident, DatabaseError> { if name.0.len() != 1 { @@ -51,10 +77,16 @@ impl> Binder<'_, '_, T, A> // FIXME: Make it better to detect the current BindStep self.context.allow_default = true; if let TableFactor::Table { name, .. } = &to.relation { + let is_joined_update = !to.joins.is_empty(); let table_name: Arc = lower_case_name(name)?.into(); self.with_pk(table_name.clone()); let mut plan = self.bind_table_ref(to)?; + let (target_schema, target_offset) = Self::resolve_source_columns_in_scope( + &self.context, + &mut self.table_schema_buf, + &table_name, + )?; if let Some(predicate) = selection { plan = self.bind_where(plan, predicate)?; @@ -96,6 +128,12 @@ impl> Binder<'_, '_, T, A> expr, Cow::Borrowed(column.datatype()), )?; + if is_joined_update { + UpdateExprTargetRemapper { + target_schema: &target_schema, + } + .visit(&mut expr)?; + } value_exprs.push((column, expr)); } _ => { @@ -108,6 +146,19 @@ impl> Binder<'_, '_, T, A> } } self.context.allow_default = false; + if is_joined_update { + let exprs = target_schema + .iter() + .enumerate() + .map(|(index, column)| { + ScalarExpression::column_expr(column.clone(), target_offset + index) + }) + .collect(); + plan = LogicalPlan::new( + Operator::Project(ProjectOperator { exprs }), + Childrens::Only(Box::new(plan)), + ); + } Ok(LogicalPlan::new( Operator::Update(UpdateOperator { table_name, diff --git a/src/execution/dml/update.rs b/src/execution/dml/update.rs index 9eba12fd..45824c0d 100644 --- a/src/execution/dml/update.rs +++ b/src/execution/dml/update.rs @@ -110,7 +110,9 @@ impl Update { let mut tuple = arena.result_tuple().clone(); let mut is_overwrite = true; - let old_pk = tuple.pk.clone().ok_or(DatabaseError::PrimaryKeyNotFound)?; + let Some(old_pk) = tuple.pk.clone() else { + continue; + }; for (index_meta, exprs) in index_metas.iter() { let values = Projection::projection(&tuple, exprs)?; let Some(value) = DataValue::values_to_tuple(values) else { diff --git a/src/execution/dql/join/hash/full_join.rs b/src/execution/dql/join/hash/full_join.rs index 561c5883..172d6b9a 100644 --- a/src/execution/dql/join/hash/full_join.rs +++ b/src/execution/dql/join/hash/full_join.rs @@ -85,7 +85,10 @@ impl JoinProbeState for FullJoinState { ); build_state.is_used = true; build_state.has_filted = probe_state.has_filtered; - return Ok(Some(Tuple::new(pk.clone(), full_values))); + return Ok(Some(Tuple::new( + pk.as_ref().or(probe_state.probe_tuple.pk.as_ref()).cloned(), + full_values, + ))); } build_state.is_used = !probe_state.has_filtered; diff --git a/src/execution/dql/join/hash/inner_join.rs b/src/execution/dql/join/hash/inner_join.rs index 810dc083..4eb62931 100644 --- a/src/execution/dql/join/hash/inner_join.rs +++ b/src/execution/dql/join/hash/inner_join.rs @@ -55,7 +55,10 @@ impl JoinProbeState for InnerJoinState { .chain(probe_state.probe_tuple.values.iter()) .cloned(), ); - return Ok(Some(Tuple::new(pk.clone(), full_values))); + return Ok(Some(Tuple::new( + pk.as_ref().or(probe_state.probe_tuple.pk.as_ref()).cloned(), + full_values, + ))); } probe_state.finished = true; diff --git a/src/execution/dql/join/hash/left_join.rs b/src/execution/dql/join/hash/left_join.rs index 68caff68..109516b8 100644 --- a/src/execution/dql/join/hash/left_join.rs +++ b/src/execution/dql/join/hash/left_join.rs @@ -65,7 +65,10 @@ impl JoinProbeState for LeftJoinState { .cloned(), ); build_state.is_used = true; - return Ok(Some(Tuple::new(pk.clone(), full_values))); + return Ok(Some(Tuple::new( + pk.as_ref().or(probe_state.probe_tuple.pk.as_ref()).cloned(), + full_values, + ))); } build_state.is_used = !probe_state.has_filtered; diff --git a/src/execution/dql/join/hash/right_join.rs b/src/execution/dql/join/hash/right_join.rs index 74578603..c226feb3 100644 --- a/src/execution/dql/join/hash/right_join.rs +++ b/src/execution/dql/join/hash/right_join.rs @@ -77,7 +77,10 @@ impl JoinProbeState for RightJoinState { probe_state.produced = true; build_state.is_used = true; build_state.has_filted = probe_state.has_filtered; - return Ok(Some(Tuple::new(pk.clone(), full_values))); + return Ok(Some(Tuple::new( + pk.as_ref().or(probe_state.probe_tuple.pk.as_ref()).cloned(), + full_values, + ))); } build_state.is_used = probe_state.produced; diff --git a/src/execution/dql/join/hash_join.rs b/src/execution/dql/join/hash_join.rs index 00f6608e..065e2d57 100644 --- a/src/execution/dql/join/hash_join.rs +++ b/src/execution/dql/join/hash_join.rs @@ -70,7 +70,7 @@ enum HashJoinState { impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for HashJoin { fn from( - (JoinOperator { on, join_type, .. }, mut left_input, mut right_input): ( + (JoinOperator { on, join_type }, mut left_input, mut right_input): ( JoinOperator, LogicalPlan, LogicalPlan, diff --git a/src/execution/dql/join/nested_loop_join.rs b/src/execution/dql/join/nested_loop_join.rs index 087a94e5..45e862c5 100644 --- a/src/execution/dql/join/nested_loop_join.rs +++ b/src/execution/dql/join/nested_loop_join.rs @@ -120,7 +120,7 @@ struct ActiveLeftState { impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for NestedLoopJoin { fn from( - (JoinOperator { on, join_type, .. }, left_input, right_input): ( + (JoinOperator { on, join_type }, left_input, right_input): ( JoinOperator, LogicalPlan, LogicalPlan, diff --git a/src/optimizer/rule/normalization/pushdown_predicates.rs b/src/optimizer/rule/normalization/pushdown_predicates.rs index 95cc4bb6..5dd477ab 100644 --- a/src/optimizer/rule/normalization/pushdown_predicates.rs +++ b/src/optimizer/rule/normalization/pushdown_predicates.rs @@ -73,6 +73,23 @@ fn plan_output_columns(plan: &LogicalPlan) -> Vec { } } +fn localize_right_filters( + filters: &mut [ScalarExpression], + left_len: usize, +) -> Result<(), DatabaseError> { + if filters.is_empty() { + return Ok(()); + } + + let mut localizer = PositionShift { + delta: -(left_len as isize), + }; + for expr in filters { + localizer.visit(expr)?; + } + Ok(()) +} + /// Comments copied from Spark Catalyst PushPredicateThroughJoin /// /// Pushes down `Filter` operators where the `condition` can be @@ -138,6 +155,8 @@ impl NormalizationRule for PushPredicateThroughJoin { new_ops.0 = Some(Operator::Filter(left_filter_op)); } + let mut right_filters = right_filters; + localize_right_filters(&mut right_filters, left_columns.len())?; if let Some(right_filter_op) = reduce_filters(right_filters, filter_op.having) { new_ops.1 = Some(Operator::Filter(right_filter_op)); } @@ -155,6 +174,8 @@ impl NormalizationRule for PushPredicateThroughJoin { .collect_vec() } JoinType::RightOuter => { + let mut right_filters = right_filters; + localize_right_filters(&mut right_filters, left_columns.len())?; if let Some(right_filter_op) = reduce_filters(right_filters, filter_op.having) { new_ops.1 = Some(Operator::Filter(right_filter_op)); } @@ -415,14 +436,7 @@ impl NormalizationRule for PushJoinPredicateIntoScan { } else { (Vec::new(), right_filters) }; - if !right_push.is_empty() { - let mut localizer = PositionShift { - delta: -(left_columns.len() as isize), - }; - for expr in &mut right_push { - localizer.visit(expr)?; - } - } + localize_right_filters(&mut right_push, left_columns.len())?; if let Some(filter_op) = reduce_filters(right_push, false) { new_ops.1 = Some(Operator::Filter(filter_op)); } else { diff --git a/tests/slt/joined_dml.slt b/tests/slt/joined_dml.slt new file mode 100644 index 00000000..769e41d1 --- /dev/null +++ b/tests/slt/joined_dml.slt @@ -0,0 +1,175 @@ +statement ok +create table joined_dml_target(id int primary key, v int) + +statement ok +create table joined_dml_source(id int primary key, v int) + +statement ok +insert into joined_dml_target values + (1, 10), (2, 20), + (3, 30), (4, 40), + (5, 50), (6, 60), + (7, 70) + +statement ok +insert into joined_dml_source values + (1, 100), (2, 200), + (5, 500), (6, 600), + (8, 800), (9, 900) + +# Regression for issue 321: joined UPDATE/DELETE should affect only target rows. +statement ok +update joined_dml_target +inner join joined_dml_source on joined_dml_target.id = joined_dml_source.id +set v = 101 +where joined_dml_source.id = 1 + +query II rowsort +select * from joined_dml_target +---- +1 101 +2 20 +3 30 +4 40 +5 50 +6 60 +7 70 + +statement ok +delete from joined_dml_target +inner join joined_dml_source on joined_dml_target.id = joined_dml_source.id +where joined_dml_source.id = 2 + +query II rowsort +select * from joined_dml_target +---- +1 101 +3 30 +4 40 +5 50 +6 60 +7 70 + +statement ok +update joined_dml_target +left join joined_dml_source on joined_dml_target.id = joined_dml_source.id +set v = 303 +where joined_dml_source.id is null and joined_dml_target.id = 3 + +query II rowsort +select * from joined_dml_target +---- +1 101 +3 303 +4 40 +5 50 +6 60 +7 70 + +statement ok +delete from joined_dml_target +left join joined_dml_source on joined_dml_target.id = joined_dml_source.id +where joined_dml_source.id is null and joined_dml_target.id = 4 + +query II rowsort +select * from joined_dml_target +---- +1 101 +3 303 +5 50 +6 60 +7 70 + +statement ok +update joined_dml_target +right join joined_dml_source on joined_dml_target.id = joined_dml_source.id +set v = 505 +where joined_dml_source.id in (5, 8) + +query II rowsort +select * from joined_dml_target +---- +1 101 +3 303 +5 505 +6 60 +7 70 + +statement ok +delete from joined_dml_target +right join joined_dml_source on joined_dml_target.id = joined_dml_source.id +where joined_dml_source.id in (6, 9) + +query II rowsort +select * from joined_dml_target +---- +1 101 +3 303 +5 505 +7 70 + +statement ok +update joined_dml_target +full join joined_dml_source on joined_dml_target.id = joined_dml_source.id +set v = 707 +where joined_dml_source.id is null and joined_dml_target.id = 7 + +query II rowsort +select * from joined_dml_target +---- +1 101 +3 303 +5 505 +7 707 + +statement ok +delete from joined_dml_target +full join joined_dml_source on joined_dml_target.id = joined_dml_source.id +where joined_dml_source.id in (1, 8) + +query II rowsort +select * from joined_dml_target +---- +3 303 +5 505 +7 707 + +statement ok +update joined_dml_target +cross join joined_dml_source +set v = 909 +where joined_dml_target.id = 7 and joined_dml_source.id = 9 + +query II rowsort +select * from joined_dml_target +---- +3 303 +5 505 +7 909 + +statement ok +delete from joined_dml_target +cross join joined_dml_source +where joined_dml_target.id = 3 and joined_dml_source.id = 8 + +query II rowsort +select * from joined_dml_target +---- +5 505 +7 909 + +query II rowsort +select * from joined_dml_source +---- +1 100 +2 200 +5 500 +6 600 +8 800 +9 900 + +statement ok +drop table joined_dml_target + +statement ok +drop table joined_dml_source