From cdaa5851a0c0e338312f6f1a9d52589741b605e4 Mon Sep 17 00:00:00 2001 From: kould Date: Mon, 4 May 2026 13:48:34 +0800 Subject: [PATCH 1/2] feat: support quantified subqueries --- src/binder/expr.rs | 186 ++++++++----- src/binder/mod.rs | 4 +- src/binder/select.rs | 97 +++++-- src/db.rs | 2 +- src/execution/dql/mark_apply.rs | 106 +++++--- .../rule/normalization/parameterized_index.rs | 11 +- src/planner/operator/mark_apply.rs | 45 +++- tests/slt/any_all_quantifier.slt | 50 ++++ tests/slt/crdb/delete.slt | 255 +++++++++--------- tests/slt/crdb/update.slt | 225 ++++++++-------- tests/slt/sql_2016/E061_07.slt | 128 +++++---- tests/slt/sql_2016/E061_12.slt | 128 +++++---- 12 files changed, 747 insertions(+), 490 deletions(-) create mode 100644 tests/slt/any_all_quantifier.slt diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 0e894165..72d96a83 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -33,6 +33,7 @@ use crate::expression::function::scala::{ArcScalarFunctionImpl, ScalarFunction}; use crate::expression::function::table::{ArcTableFunctionImpl, TableFunction}; use crate::expression::function::FunctionSummary; use crate::expression::{AliasType, ScalarExpression}; +use crate::planner::operator::mark_apply::MarkApplyQuantifier; use crate::planner::operator::scalar_subquery::ScalarSubqueryOperator; use crate::planner::{LogicalPlan, SchemaOutput}; use crate::storage::Transaction; @@ -245,7 +246,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T }) } Expr::Exists { subquery, negated } => { - let (sub_query, _column, correlated) = self.bind_subquery(None, subquery)?; + let (sub_query, correlated) = self.bind_subquery(subquery)?; let (_, marker_ref) = self .bind_temp_table_alias(ScalarExpression::Constant(DataValue::Boolean(true)), 0); self.context.sub_query(SubQueryType::ExistsSubQuery { @@ -265,7 +266,8 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T } } Expr::Subquery(subquery) => { - let (sub_query, column, correlated) = self.bind_subquery(None, subquery)?; + let (sub_query, column, correlated) = + self.bind_subquery_with_output(None, subquery)?; let sub_query = ScalarSubqueryOperator::build(sub_query); let (expr, sub_query) = if !self.context.is_step(&QueryBindStep::Where) { self.bind_temp_table(column, sub_query)? @@ -282,46 +284,13 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T expr, subquery, negated, - } => { - let left_expr = self.bind_expr(expr)?; - let (sub_query, column, correlated) = - self.bind_subquery(Some(left_expr.return_type().as_ref()), subquery)?; - - if !self.context.is_step(&QueryBindStep::Where) { - return Err(DatabaseError::UnsupportedStmt( - "'IN (SUBQUERY)' can only appear in `WHERE`".to_string(), - )); - } - - let (alias_expr, sub_query) = self.bind_temp_table(column, sub_query)?; - let predicate = ScalarExpression::Binary { - op: expression::BinaryOperator::Eq, - left_expr: Box::new(left_expr), - right_expr: Box::new(alias_expr), - evaluator: None, - ty: LogicalType::Boolean, - }; - let (_, marker_ref) = self - .bind_temp_table_alias(ScalarExpression::Constant(DataValue::Boolean(true)), 0); - self.context.sub_query(SubQueryType::InSubQuery { - negated: *negated, - plan: sub_query, - correlated, - output_column: marker_ref.output_column(), - predicate, - }); - - if *negated { - Ok(ScalarExpression::Unary { - op: expression::UnaryOperator::Not, - expr: Box::new(marker_ref), - evaluator: None, - ty: LogicalType::Boolean, - }) - } else { - Ok(marker_ref) - } - } + } => self.bind_quantified_subquery( + MarkApplyQuantifier::Any, + *negated, + expr, + &BinaryOperator::Eq, + subquery, + ), Expr::Tuple(exprs) => { let mut bond_exprs = Vec::with_capacity(exprs.len()); @@ -377,10 +346,86 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T ty, }) } + Expr::AnyOp { + left, + compare_op, + right, + .. + } => self.bind_quantified_op(MarkApplyQuantifier::Any, left, compare_op, right), + Expr::AllOp { + left, + compare_op, + right, + } => self.bind_quantified_op(MarkApplyQuantifier::All, left, compare_op, right), expr => Err(DatabaseError::UnsupportedStmt(expr.to_string())), } } + fn bind_quantified_op( + &mut self, + quantifier: MarkApplyQuantifier, + left: &Expr, + compare_op: &BinaryOperator, + right: &Expr, + ) -> Result { + let Expr::Subquery(subquery) = right else { + return Err(DatabaseError::UnsupportedStmt(format!( + "{quantifier:?} only supports subquery operands" + ))); + }; + + self.bind_quantified_subquery(quantifier, false, left, compare_op, subquery) + } + + fn bind_quantified_subquery( + &mut self, + quantifier: MarkApplyQuantifier, + negated: bool, + expr: &Expr, + compare_op: &BinaryOperator, + subquery: &Query, + ) -> Result { + let left_expr = self.bind_expr(expr)?; + let (sub_query, column, correlated) = + self.bind_subquery_with_output(Some(left_expr.return_type().as_ref()), subquery)?; + + if !self.context.is_step(&QueryBindStep::Where) { + return Err(DatabaseError::UnsupportedStmt( + "quantified subqueries can only appear in `WHERE`".to_string(), + )); + } + + let (alias_expr, sub_query) = self.bind_temp_table(column, sub_query)?; + let predicate = ScalarExpression::Binary { + op: (*compare_op).clone().try_into()?, + left_expr: Box::new(left_expr), + right_expr: Box::new(alias_expr), + evaluator: None, + ty: LogicalType::Boolean, + }; + let (_, marker_ref) = + self.bind_temp_table_alias(ScalarExpression::Constant(DataValue::Boolean(true)), 0); + self.context.sub_query(SubQueryType::QuantifiedSubQuery { + quantifier, + negated, + plan: sub_query, + correlated, + output_column: marker_ref.output_column(), + predicate, + }); + + if negated { + Ok(ScalarExpression::Unary { + op: expression::UnaryOperator::Not, + expr: Box::new(marker_ref), + evaluator: None, + ty: LogicalType::Boolean, + }) + } else { + Ok(marker_ref) + } + } + fn bind_temp_table( &mut self, expr: ScalarExpression, @@ -425,34 +470,12 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T ) } - fn bind_subquery( + fn bind_subquery_with_output( &mut self, - in_ty: Option<&LogicalType>, + value_ty: Option<&LogicalType>, subquery: &Query, ) -> Result<(LogicalPlan, ScalarExpression, bool), DatabaseError> { - let BinderContext { - table_cache, - view_cache, - transaction, - scala_functions, - table_functions, - temp_table_id, - .. - } = &self.context; - let mut binder = Binder::new( - BinderContext::new( - table_cache, - view_cache, - *transaction, - scala_functions, - table_functions, - temp_table_id.clone(), - ), - self.args, - Some(self), - ); - let mut sub_query = binder.bind_query(subquery)?; - let correlated = binder.context.has_outer_refs(); + let (mut sub_query, correlated) = self.bind_subquery(subquery)?; let sub_query_schema = sub_query.output_schema(); let fn_check = |len: usize| { @@ -465,7 +488,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T Ok(()) }; - let expr = if let Some(LogicalType::Tuple(tys)) = in_ty { + let expr = if let Some(LogicalType::Tuple(tys)) = value_ty { fn_check(tys.len())?; let columns = sub_query_schema @@ -482,6 +505,33 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T Ok((sub_query, expr, correlated)) } + fn bind_subquery(&mut self, subquery: &Query) -> Result<(LogicalPlan, bool), DatabaseError> { + let BinderContext { + table_cache, + view_cache, + transaction, + scala_functions, + table_functions, + temp_table_id, + .. + } = &self.context; + let mut binder = Binder::new( + BinderContext::new( + table_cache, + view_cache, + *transaction, + scala_functions, + table_functions, + temp_table_id.clone(), + ), + self.args, + Some(self), + ); + let sub_query = binder.bind_query(subquery)?; + let correlated = binder.context.has_outer_refs(); + Ok((sub_query, correlated)) + } + pub fn bind_like( &mut self, negated: bool, diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 2b1210e6..ec8697ea 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -49,6 +49,7 @@ use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::{DatabaseError, SqlErrorSpan}; use crate::expression::ScalarExpression; use crate::planner::operator::join::JoinType; +use crate::planner::operator::mark_apply::MarkApplyQuantifier; use crate::planner::{LogicalPlan, SchemaOutput}; use crate::storage::{TableCache, Transaction, ViewCache}; use crate::types::tuple::SchemaRef; @@ -156,7 +157,8 @@ pub enum SubQueryType { correlated: bool, output_column: ColumnRef, }, - InSubQuery { + QuantifiedSubQuery { + quantifier: MarkApplyQuantifier, negated: bool, plan: LogicalPlan, correlated: bool, diff --git a/src/binder/select.rs b/src/binder/select.rs index 8f114c0c..bd9d2178 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -1314,11 +1314,12 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' predicates, ); } - SubQueryType::InSubQuery { + SubQueryType::QuantifiedSubQuery { + quantifier, plan, correlated, output_column, - predicate: mut in_predicate, + predicate: mut quantified_predicate, .. } => { if matches!(uses_mark_apply, Some(false)) { @@ -1329,7 +1330,8 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } uses_mark_apply = Some(true); if correlated { - in_predicate = Self::rewrite_correlated_in_predicate(in_predicate); + quantified_predicate = + Self::rewrite_correlated_quantified_predicate(quantified_predicate); } let (plan, predicates) = Self::prepare_mark_apply( &mut predicate, @@ -1338,10 +1340,15 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' plan, correlated, true, - vec![in_predicate], + vec![quantified_predicate], )?; - children = - MarkApplyOperator::build_in(children, plan, output_column, predicates); + children = MarkApplyOperator::build_quantified( + children, + plan, + quantifier, + output_column, + predicates, + ); } SubQueryType::SubQuery { plan, correlated } => { if matches!(uses_mark_apply, Some(true)) { @@ -1480,7 +1487,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' Ok((plan, apply_predicates)) } - fn rewrite_correlated_in_predicate(predicate: ScalarExpression) -> ScalarExpression { + fn rewrite_correlated_quantified_predicate(predicate: ScalarExpression) -> ScalarExpression { let strip_projection_alias = |expr: Box| match *expr { ScalarExpression::Alias { expr, @@ -1496,7 +1503,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' right_expr, ty, .. - } if op == BinaryOperator::Eq => ScalarExpression::Binary { + } => ScalarExpression::Binary { op, left_expr: strip_projection_alias(left_expr), right_expr: strip_projection_alias(right_expr), @@ -2115,7 +2122,9 @@ mod tests { use crate::expression::visitor_mut::VisitorMut; use crate::expression::{AliasType, ScalarExpression}; use crate::planner::operator::join::{JoinCondition, JoinType}; - use crate::planner::operator::mark_apply::{MarkApplyKind, MarkApplyOperator}; + use crate::planner::operator::mark_apply::{ + MarkApplyKind, MarkApplyOperator, MarkApplyQuantifier, + }; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::types::LogicalType; @@ -2262,6 +2271,19 @@ mod tests { } } + fn assert_quantified_mark_apply( + plan: &LogicalPlan, + quantifier: MarkApplyQuantifier, + predicate_len: usize, + ) { + let Some(mark_apply) = find_mark_apply(plan) else { + panic!("expected quantified subquery to introduce a mark apply") + }; + + assert_eq!(mark_apply.kind, MarkApplyKind::Quantified(quantifier)); + assert_eq!(mark_apply.predicates().len(), predicate_len); + } + #[test] fn test_scalar_subquery_in_where_binds_as_inner_join() -> Result<(), DatabaseError> { let table_states = build_t1_table()?; @@ -2280,12 +2302,34 @@ mod tests { fn test_in_subquery_in_where_binds_as_mark_apply() -> Result<(), DatabaseError> { let table_states = build_t1_table()?; let plan = table_states.plan("select * from t1 where c1 in (select c3 from t2)")?; - let Some(mark_apply) = find_mark_apply(&plan) else { - panic!("expected IN subquery to introduce a mark apply") - }; + assert_quantified_mark_apply(&plan, MarkApplyQuantifier::Any, 1); - assert_eq!(mark_apply.kind, MarkApplyKind::In); - assert_eq!(mark_apply.predicates().len(), 1); + Ok(()) + } + + #[test] + fn test_any_subquery_in_where_binds_as_mark_apply() -> Result<(), DatabaseError> { + let table_states = build_t1_table()?; + let plan = table_states.plan("select * from t1 where c1 < any(select c3 from t2)")?; + assert_quantified_mark_apply(&plan, MarkApplyQuantifier::Any, 1); + + Ok(()) + } + + #[test] + fn test_some_subquery_in_where_binds_as_mark_apply() -> Result<(), DatabaseError> { + let table_states = build_t1_table()?; + let plan = table_states.plan("select * from t1 where c1 = some(select c3 from t2)")?; + assert_quantified_mark_apply(&plan, MarkApplyQuantifier::Any, 1); + + Ok(()) + } + + #[test] + fn test_all_subquery_in_where_binds_as_mark_apply() -> Result<(), DatabaseError> { + let table_states = build_t1_table()?; + let plan = table_states.plan("select * from t1 where c1 > all(select c3 from t2)")?; + assert_quantified_mark_apply(&plan, MarkApplyQuantifier::All, 1); Ok(()) } @@ -2295,12 +2339,27 @@ mod tests { let table_states = build_t1_table()?; let plan = table_states.plan("select * from t1 where c1 in (select c3 from t2 where c4 = c2)")?; - let Some(mark_apply) = find_mark_apply(&plan) else { - panic!("expected correlated IN subquery to introduce a mark apply") - }; + assert_quantified_mark_apply(&plan, MarkApplyQuantifier::Any, 2); - assert_eq!(mark_apply.kind, MarkApplyKind::In); - assert_eq!(mark_apply.predicates().len(), 2); + Ok(()) + } + + #[test] + fn test_correlated_any_subquery_in_where_binds_as_mark_apply() -> Result<(), DatabaseError> { + let table_states = build_t1_table()?; + let plan = table_states + .plan("select * from t1 where c1 < any(select c3 from t2 where c4 = c2)")?; + assert_quantified_mark_apply(&plan, MarkApplyQuantifier::Any, 2); + + Ok(()) + } + + #[test] + fn test_correlated_all_subquery_in_where_binds_as_mark_apply() -> Result<(), DatabaseError> { + let table_states = build_t1_table()?; + let plan = table_states + .plan("select * from t1 where c1 > all(select c3 from t2 where c4 = c2)")?; + assert_quantified_mark_apply(&plan, MarkApplyQuantifier::All, 2); Ok(()) } diff --git a/src/db.rs b/src/db.rs index 92be6d27..d4975118 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1658,7 +1658,7 @@ pub(crate) mod test { |sql: &str, index_name: &str| -> Result<(), DatabaseError> { let explain_plan = collect_plan(sql)?; assert!( - explain_plan.contains("MarkInApply"), + explain_plan.contains("MarkAnyApply"), "unexpected explain plan: {explain_plan}" ); assert!( diff --git a/src/execution/dql/mark_apply.rs b/src/execution/dql/mark_apply.rs index 6d54514e..4ffdf494 100644 --- a/src/execution/dql/mark_apply.rs +++ b/src/execution/dql/mark_apply.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; -use crate::planner::operator::mark_apply::{MarkApplyKind, MarkApplyOperator}; +use crate::planner::operator::mark_apply::{MarkApplyKind, MarkApplyOperator, MarkApplyQuantifier}; use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::index::RuntimeIndexProbe; @@ -23,10 +23,11 @@ use crate::types::value::DataValue; use std::mem; #[derive(PartialEq, Eq)] -enum InPredicateOutcome { - Match, +enum QuantifiedPredicateOutcome { + True, + False, Null, - Continue, + Skip, } pub struct MarkApply { @@ -76,10 +77,16 @@ impl MarkApply { match param_value { Some(value) => Some(RuntimeIndexProbe::Eq(value)), - None if matches!(self.op.kind, MarkApplyKind::In) => Some(RuntimeIndexProbe::Scope { - min: std::collections::Bound::Unbounded, - max: std::collections::Bound::Unbounded, - }), + None if matches!( + self.op.kind, + MarkApplyKind::Quantified(MarkApplyQuantifier::Any) + ) => + { + Some(RuntimeIndexProbe::Scope { + min: std::collections::Bound::Unbounded, + max: std::collections::Bound::Unbounded, + }) + } None => None, } } @@ -141,7 +148,7 @@ impl MarkApply { Ok(DataValue::Boolean(false)) }, ), - MarkApplyKind::In => { + MarkApplyKind::Quantified(MarkApplyQuantifier::Any) => { if let Some(probe_value) = self.parameterized_probe_value()? { if !probe_value.is_null() { if self.with_right_input( @@ -150,8 +157,10 @@ impl MarkApply { |arena, right_input| { while arena.next_tuple(right_input)? { let right_tuple = arena.result_tuple(); - if self.in_predicate_outcome(&self.left_tuple, right_tuple)? - == InPredicateOutcome::Match + if self.quantified_predicate_outcome( + &self.left_tuple, + right_tuple, + )? == QuantifiedPredicateOutcome::True { return Ok(true); } @@ -169,8 +178,10 @@ impl MarkApply { |arena, right_input| { while arena.next_tuple(right_input)? { let right_tuple = arena.result_tuple(); - if self.in_predicate_outcome(&self.left_tuple, right_tuple)? - == InPredicateOutcome::Null + if self.quantified_predicate_outcome( + &self.left_tuple, + right_tuple, + )? == QuantifiedPredicateOutcome::Null { return Ok(true); } @@ -187,25 +198,51 @@ impl MarkApply { } self.with_right_input(arena, None, |arena, right_input| { - let mut saw_null = false; + self.scan_quantified_right_input(arena, right_input, MarkApplyQuantifier::Any) + }) + } + MarkApplyKind::Quantified(MarkApplyQuantifier::All) => { + self.with_right_input(arena, None, |arena, right_input| { + self.scan_quantified_right_input(arena, right_input, MarkApplyQuantifier::All) + }) + } + } + } - while arena.next_tuple(right_input)? { - let right_tuple = arena.result_tuple(); - match self.in_predicate_outcome(&self.left_tuple, right_tuple)? { - InPredicateOutcome::Match => return Ok(DataValue::Boolean(true)), - InPredicateOutcome::Null => saw_null = true, - InPredicateOutcome::Continue => {} - } + fn scan_quantified_right_input<'a, T: Transaction + 'a>( + &self, + arena: &mut ExecArena<'a, T>, + right_input: ExecId, + quantifier: MarkApplyQuantifier, + ) -> Result { + let mut saw_null = false; + + while arena.next_tuple(right_input)? { + let right_tuple = arena.result_tuple(); + match self.quantified_predicate_outcome(&self.left_tuple, right_tuple)? { + QuantifiedPredicateOutcome::True => { + if matches!(quantifier, MarkApplyQuantifier::Any) { + return Ok(DataValue::Boolean(true)); } - - if saw_null { - Ok(DataValue::Null) - } else { - Ok(DataValue::Boolean(false)) + } + QuantifiedPredicateOutcome::False => { + if matches!(quantifier, MarkApplyQuantifier::All) { + return Ok(DataValue::Boolean(false)); } - }) + } + QuantifiedPredicateOutcome::Null => saw_null = true, + QuantifiedPredicateOutcome::Skip => {} } } + + if saw_null { + Ok(DataValue::Null) + } else { + Ok(DataValue::Boolean(matches!( + quantifier, + MarkApplyQuantifier::All + ))) + } } fn exists_predicate_matched( @@ -226,20 +263,21 @@ impl MarkApply { Ok(true) } - fn in_predicate_outcome( + fn quantified_predicate_outcome( &self, left_tuple: &Tuple, right_tuple: &Tuple, - ) -> Result { - match self.in_predicate_value(left_tuple, right_tuple)? { - Some(DataValue::Boolean(true)) => Ok(InPredicateOutcome::Match), - Some(DataValue::Null) => Ok(InPredicateOutcome::Null), - Some(DataValue::Boolean(false)) | None => Ok(InPredicateOutcome::Continue), + ) -> Result { + match self.eval_predicates(left_tuple, right_tuple)? { + Some(DataValue::Boolean(true)) => Ok(QuantifiedPredicateOutcome::True), + Some(DataValue::Boolean(false)) => Ok(QuantifiedPredicateOutcome::False), + Some(DataValue::Null) => Ok(QuantifiedPredicateOutcome::Null), + None => Ok(QuantifiedPredicateOutcome::Skip), Some(_) => Err(DatabaseError::InvalidType), } } - fn in_predicate_value( + fn eval_predicates( &self, left_tuple: &Tuple, right_tuple: &Tuple, diff --git a/src/optimizer/rule/normalization/parameterized_index.rs b/src/optimizer/rule/normalization/parameterized_index.rs index 0aa114bf..9bb8b160 100644 --- a/src/optimizer/rule/normalization/parameterized_index.rs +++ b/src/optimizer/rule/normalization/parameterized_index.rs @@ -16,7 +16,7 @@ use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::core::rule::NormalizationRule; -use crate::planner::operator::mark_apply::MarkApplyKind; +use crate::planner::operator::mark_apply::{MarkApplyKind, MarkApplyQuantifier}; use crate::planner::operator::table_scan::TableScanOperator; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl}; use crate::planner::{Childrens, LogicalPlan}; @@ -59,9 +59,12 @@ fn find_parameterized_probe( MarkApplyKind::Exists => predicates.iter().find_map(|predicate| { extract_parameterized_probe(predicate, left_schema, right_schema) }), - MarkApplyKind::In => predicates.first().and_then(|predicate| { - extract_parameterized_probe(predicate, left_schema, right_schema) - }), + MarkApplyKind::Quantified(MarkApplyQuantifier::Any) => { + predicates.first().and_then(|predicate| { + extract_parameterized_probe(predicate, left_schema, right_schema) + }) + } + MarkApplyKind::Quantified(MarkApplyQuantifier::All) => None, } } diff --git a/src/planner/operator/mark_apply.rs b/src/planner/operator/mark_apply.rs index ed379cbf..bbe5c32d 100644 --- a/src/planner/operator/mark_apply.rs +++ b/src/planner/operator/mark_apply.rs @@ -20,10 +20,16 @@ use kite_sql_serde_macros::ReferenceSerialization; use std::fmt; use std::fmt::Formatter; -#[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, ReferenceSerialization)] +pub enum MarkApplyQuantifier { + Any, + All, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, ReferenceSerialization)] pub enum MarkApplyKind { Exists, - In, + Quantified(MarkApplyQuantifier), } #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] @@ -60,8 +66,16 @@ impl MarkApplyOperator { } pub fn new_in(output_column: ColumnRef, predicates: Vec) -> Self { + Self::new_quantified(MarkApplyQuantifier::Any, output_column, predicates) + } + + pub fn new_quantified( + quantifier: MarkApplyQuantifier, + output_column: ColumnRef, + predicates: Vec, + ) -> Self { Self { - kind: MarkApplyKind::In, + kind: MarkApplyKind::Quantified(quantifier), predicates, output_column, parameterized_probe: None, @@ -73,9 +87,29 @@ impl MarkApplyOperator { right: LogicalPlan, output_column: ColumnRef, predicates: Vec, + ) -> LogicalPlan { + Self::build_quantified( + left, + right, + MarkApplyQuantifier::Any, + output_column, + predicates, + ) + } + + pub fn build_quantified( + left: LogicalPlan, + right: LogicalPlan, + quantifier: MarkApplyQuantifier, + output_column: ColumnRef, + predicates: Vec, ) -> LogicalPlan { LogicalPlan::new( - Operator::MarkApply(MarkApplyOperator::new_in(output_column, predicates)), + Operator::MarkApply(MarkApplyOperator::new_quantified( + quantifier, + output_column, + predicates, + )), Childrens::Twins { left: Box::new(left), right: Box::new(right), @@ -108,7 +142,8 @@ impl fmt::Display for MarkApplyOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self.kind { MarkApplyKind::Exists => write!(f, "MarkExistsApply"), - MarkApplyKind::In => write!(f, "MarkInApply"), + MarkApplyKind::Quantified(MarkApplyQuantifier::Any) => write!(f, "MarkAnyApply"), + MarkApplyKind::Quantified(MarkApplyQuantifier::All) => write!(f, "MarkAllApply"), } } } diff --git a/tests/slt/any_all_quantifier.slt b/tests/slt/any_all_quantifier.slt new file mode 100644 index 00000000..caab46ec --- /dev/null +++ b/tests/slt/any_all_quantifier.slt @@ -0,0 +1,50 @@ +statement ok +create table any_all_outer(id int primary key, a int not null); + +statement ok +create table any_all_inner(id int primary key, b int not null); + +statement ok +insert into any_all_outer values (1, 1), (2, 2), (3, 3), (4, 4); + +statement ok +insert into any_all_inner values (1, 2), (2, 3); + +query I rowsort +select id from any_all_outer +where a = any(select b from any_all_inner); +---- +2 +3 + +query I rowsort +select id from any_all_outer +where a = some(select b from any_all_inner); +---- +2 +3 + +query I rowsort +select id from any_all_outer +where a < any(select b from any_all_inner); +---- +1 +2 + +query I rowsort +select id from any_all_outer +where a > all(select b from any_all_inner); +---- +4 + +query I rowsort +select id from any_all_outer +where a = all(select b from any_all_inner where b = 2); +---- +2 + +statement ok +drop table any_all_outer; + +statement ok +drop table any_all_inner; diff --git a/tests/slt/crdb/delete.slt b/tests/slt/crdb/delete.slt index a8acb851..560ed4b0 100644 --- a/tests/slt/crdb/delete.slt +++ b/tests/slt/crdb/delete.slt @@ -205,131 +205,130 @@ select * from t1 order by a; statement ok insert into t1 values (2),(3); -# TODO: support `ALL/ANY/SOME` on `WHERE` -# statement ok -# delete from t1 where exists (select * from t2 where t1.a = t2.b); - -# query I -# select * from t1 order by a; -# ---- -# 1 -# 8 - -# statement ok -# insert into t1 values(2), (3); - -# statement ok -# delete from t1 where a < any(select b from t2); - -# query I -# select * from t1 order by a; -# ---- -# 3 -# 8 - -# statement ok -# insert into t1 values(2), (1); - -# statement ok -# delete from t1 where a = all(select b from t2); - -# query I -# select * from t1 order by a; -# ---- -# 1 -# 2 -# 3 -# 8 - -# statement ok -# delete from t1 where a in (select b from t2 where a > b); - -# query I -# select * from t1 order by a; -# ---- -# 1 -# 2 -# 3 -# 8 - -# statement ok -# delete from t1 where a = any(select b from t2 where t1.a = t2.b) ; - -# query I -# select * from t1 order by a; -# ---- -# 1 -# 8 - -# statement ok -# insert into t1 values(2), (3); - -# statement ok -# delete from t1 where exists(select b from t2 where b > 2); - -# query I -# select * from t1; -# ---- - -# statement ok -# insert into t1 values(1), (2), (3), (8); - -# statement ok -# delete from t1 where not exists(select b from t2 where b > 2); - -# query I -# select * from t1 order by a; -# ---- -# 1 -# 2 -# 3 -# 8 - -# statement ok -# delete from t1 where a = any(select b from t2 where t1.a = t2.b) or a != any(select b from t2 where t1.a = t2.b); - -# query I -# select * from t1 order by a; -# ---- -# 1 -# 8 - -# statement ok -# insert into t1 values(2), (3); - -# statement ok -# delete from t1 where a = any(select b from t2 where t1.a = t2.b) or a > 1; - -# query I -# select * from t1; -# ---- -# 1 - -# statement ok -# insert into t1 values(2), (3), (8); - -# statement ok -# delete from t1 where a = any(select b from t2 where t1.a = t2.b) or a < any(select b from t2); - -# query I -# select * from t1 order by a; -# ---- -# 8 - -# statement ok -# insert into t1 values(1), (2), (3); - -# statement ok -# delete from t1 where exists(select b from t2 where a = b); - -# query I -# select * from t1 order by a; -# ---- -# 1 -# 8 - -# statement ok -# drop table t1; - -# statement ok -# drop table t2; \ No newline at end of file +statement ok +delete from t1 where exists (select * from t2 where t1.a = t2.b); + +query I +select * from t1 order by a; +---- +1 +8 + +statement ok +insert into t1 values(2), (3); + +statement ok +delete from t1 where a < any(select b from t2); + +query I +select * from t1 order by a; +---- +3 +8 + +statement ok +insert into t1 values(2), (1); + +statement ok +delete from t1 where a = all(select b from t2); + +query I +select * from t1 order by a; +---- +1 +2 +3 +8 + +statement ok +delete from t1 where a in (select b from t2 where a > b); + +query I +select * from t1 order by a; +---- +1 +2 +3 +8 + +statement ok +delete from t1 where a = any(select b from t2 where t1.a = t2.b) ; + +query I +select * from t1 order by a; +---- +1 +8 + +statement ok +insert into t1 values(2), (3); + +statement ok +delete from t1 where exists(select b from t2 where b > 2); + +query I +select * from t1; +---- + +statement ok +insert into t1 values(1), (2), (3), (8); + +statement ok +delete from t1 where not exists(select b from t2 where b > 2); + +query I +select * from t1 order by a; +---- +1 +2 +3 +8 + +statement ok +delete from t1 where a = any(select b from t2 where t1.a = t2.b) or a != any(select b from t2 where t1.a = t2.b); + +query I +select * from t1 order by a; +---- +1 +8 + +statement ok +insert into t1 values(2), (3); + +statement ok +delete from t1 where a = any(select b from t2 where t1.a = t2.b) or a > 1; + +query I +select * from t1; +---- +1 + +statement ok +insert into t1 values(2), (3), (8); + +statement ok +delete from t1 where a = any(select b from t2 where t1.a = t2.b) or a < any(select b from t2); + +query I +select * from t1 order by a; +---- +8 + +statement ok +insert into t1 values(1), (2), (3); + +statement ok +delete from t1 where exists(select b from t2 where a = b); + +query I +select * from t1 order by a; +---- +1 +8 + +statement ok +drop table t1; + +statement ok +drop table t2; diff --git a/tests/slt/crdb/update.slt b/tests/slt/crdb/update.slt index 71f95661..cb4eeccb 100644 --- a/tests/slt/crdb/update.slt +++ b/tests/slt/crdb/update.slt @@ -33,17 +33,16 @@ truncate table t1; statement ok insert into t1 values(0, 1), (1, 2), (2, 3), (3, 8); -# TODO: Exists -# statement ok -# update t1 set a = a + 1 where exists (select * from t2 where t1.a = t2.b); +statement ok +update t1 set a = a + 1 where exists (select * from t2 where t1.a = t2.b); -# query I -# select * from t1 order by a; -# ---- -# 1 -# 3 -# 4 -# 8 +query I +select * from t1 order by a; +---- +0 1 +1 3 +2 4 +3 8 statement ok truncate table t1; @@ -51,17 +50,16 @@ truncate table t1; statement ok insert into t1 values(0, 1), (1, 2), (2, 3), (3, 8); -# sqlparser-rs not support -# statement ok -# update t1 set a = a + 1 where a < any(select b from t2); +statement ok +update t1 set a = a + 1 where a < any(select b from t2); -# query I -# select * from t1 order by a; -# ---- -# 2 -# 3 -# 3 -# 8 +query I +select * from t1 order by a; +---- +0 2 +1 3 +2 3 +3 8 statement ok truncate table t1; @@ -69,41 +67,38 @@ truncate table t1; statement ok insert into t1 values(0, 1), (1, 2), (2, 3), (3, 8); -# sqlparser-rs not support -# statement ok -# update t1 set a = a + 1 where a = all(select b from t2); - -# query I -# select * from t1 order by a; -# ---- -# 1 -# 2 -# 3 -# 8 - -# TODO: Correlated Subquery -# statement ok -# update t1 set a = a + 1 where a in (select b from t2 where a > b); - -# query I -# select a from t1 order by a; -# ---- -# 1 -# 2 -# 3 -# 8 - -# sqlparser-rs not support -# statement ok -# update t1 set a = a + 1 where a = any(select b from t2 where t1.a = t2.b); - -# query I -# select * from t1 order by a; -# ---- -# 1 -# 3 -# 4 -# 8 +statement ok +update t1 set a = a + 1 where a = all(select b from t2); + +query I +select * from t1 order by a; +---- +0 1 +1 2 +2 3 +3 8 + +statement ok +update t1 set a = a + 1 where a in (select b from t2 where a > b); + +query I +select a from t1 order by a; +---- +1 +2 +3 +8 + +statement ok +update t1 set a = a + 1 where a = any(select b from t2 where t1.a = t2.b); + +query I +select * from t1 order by a; +---- +0 1 +1 3 +2 4 +3 8 statement ok truncate table t1; @@ -111,17 +106,16 @@ truncate table t1; statement ok insert into t1 values(0, 1), (1, 2), (2, 3), (3, 8); -# TODO: Exists -# statement ok -# update t1 set a = a + 1 where exists(select b from t2 where b > 2); +statement ok +update t1 set a = a + 1 where exists(select b from t2 where b > 2); -# query I -# select * from t1 order by a; -# ---- -# 2 -# 3 -# 4 -# 9 +query I +select * from t1 order by a; +---- +0 2 +1 3 +2 4 +3 9 statement ok truncate table t1; @@ -129,35 +123,33 @@ truncate table t1; statement ok insert into t1 values(0, 1), (1, 2), (2, 3), (3, 8); -# TODO: Exists -# statement ok -# update t1 set a = a + 1 where not exists(select b from t2 where b > 2); +statement ok +update t1 set a = a + 1 where not exists(select b from t2 where b > 2); -# query I -# select * from t1 order by a; -# ---- -# 1 -# 2 -# 3 -# 8 +query I +select * from t1 order by a; +---- +0 1 +1 2 +2 3 +3 8 statement ok truncate table t1; statement ok -insert into t1 values(1), (2), (3), (8); +insert into t1 values(0, 1), (1, 2), (2, 3), (3, 8); -# sqlparser-rs not support -# statement ok -# update t1 set a = a + 1 where a = any(select b from t2 where t1.a = t2.b) or a != any(select b from t2 where t1.a = t2.b); +statement ok +update t1 set a = a + 1 where a = any(select b from t2 where t1.a = t2.b) or a != any(select b from t2 where t1.a = t2.b); -# query I -# select * from t1 order by a; -# ---- -# 1 -# 3 -# 4 -# 8 +query I +select * from t1 order by a; +---- +0 1 +1 3 +2 4 +3 8 statement ok truncate table t1; @@ -165,53 +157,50 @@ truncate table t1; statement ok insert into t1 values(0, 1), (1, 2), (2, 3), (3, 8); -# sqlparser-rs not support -# statement ok -# update t1 set a = a + 1 where a = any(select b from t2 where t1.a = t2.b) or a > 1; +statement ok +update t1 set a = a + 1 where a = any(select b from t2 where t1.a = t2.b) or a > 1; -# query I -# select * from t1 order by a; -# ---- -# 1 -# 3 -# 4 -# 9 +query I +select * from t1 order by a; +---- +0 1 +1 3 +2 4 +3 9 statement ok truncate table t1; statement ok -insert into t1 values(1), (2), (3), (8); +insert into t1 values(0, 1), (1, 2), (2, 3), (3, 8); -# sqlparser-rs not support -# statement ok -# update t1 set a = a + 1 where a = any(select b from t2 where t1.a = t2.b) or a < any(select b from t2); +statement ok +update t1 set a = a + 1 where a = any(select b from t2 where t1.a = t2.b) or a < any(select b from t2); -# query I -# select * from t1 order by a; -# ---- -# 2 -# 3 -# 4 -# 8 +query I +select * from t1 order by a; +---- +0 2 +1 3 +2 4 +3 8 statement ok truncate table t1; statement ok -insert into t1 values(1), (2), (3), (8); +insert into t1 values(0, 1), (1, 2), (2, 3), (3, 8); -# sqlparser-rs not support -# statement ok -# update t1 set a = a + 1 where exists(select b from t2 where a = b); +statement ok +update t1 set a = a + 1 where exists(select b from t2 where a = b); -# query I -# select * from t1 order by a; -# ---- -# 1 -# 3 -# 4 -# 8 +query I +select * from t1 order by a; +---- +0 1 +1 3 +2 4 +3 8 statement ok truncate table t1; diff --git a/tests/slt/sql_2016/E061_07.slt b/tests/slt/sql_2016/E061_07.slt index 61d7faab..f558cfc4 100644 --- a/tests/slt/sql_2016/E061_07.slt +++ b/tests/slt/sql_2016/E061_07.slt @@ -1,93 +1,109 @@ # E061-07: Quantified comparison predicate -# TODO: support `ALL/ANY/SOME` on `WHERE` +statement ok +CREATE TABLE TABLE_E061_07_01_01 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_01 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_01 WHERE A < ALL ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_01 WHERE A < ALL ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_02 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_02 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_02 WHERE A < ANY ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_02 WHERE A < ANY ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_03 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_03 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_03 WHERE A < SOME ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_03 WHERE A < SOME ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_04 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_04 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_04 WHERE A <= ALL ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_04 WHERE A <= ALL ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_05 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_05 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_05 WHERE A <= ANY ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_05 WHERE A <= ANY ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_06 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_06 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_06 WHERE A <= SOME ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_06 WHERE A <= SOME ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_07 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_07 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_07 WHERE A <> ALL ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_07 WHERE A <> ALL ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_08 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_08 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_08 WHERE A <> ANY ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_08 WHERE A <> ANY ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_09 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_09 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_09 WHERE A <> SOME ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_09 WHERE A <> SOME ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_10 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_10 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_10 WHERE A = ALL ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_10 WHERE A = ALL ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_11 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_11 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_11 WHERE A = ANY ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_11 WHERE A = ANY ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_12 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_12 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_12 WHERE A = SOME ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_12 WHERE A = SOME ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_13 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_13 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_13 WHERE A > ALL ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_13 WHERE A > ALL ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_14 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_14 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_14 WHERE A > ANY ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_14 WHERE A > ANY ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_15 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_15 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_15 WHERE A > SOME ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_15 WHERE A > SOME ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_16 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_16 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_16 WHERE A >= ALL ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_16 WHERE A >= ALL ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_17 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_17 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_07_01_17 WHERE A >= ANY ( SELECT 1 ) -# SELECT A FROM TABLE_E061_07_01_17 WHERE A >= ANY ( SELECT 1 ) +statement ok +CREATE TABLE TABLE_E061_07_01_18 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_07_01_18 ( ID INT PRIMARY KEY, A INT ); - -# SELECT A FROM TABLE_E061_07_01_18 WHERE A >= SOME ( SELECT 1 ) +query I +SELECT A FROM TABLE_E061_07_01_18 WHERE A >= SOME ( SELECT 1 ) diff --git a/tests/slt/sql_2016/E061_12.slt b/tests/slt/sql_2016/E061_12.slt index d1f78e87..142e4bcf 100644 --- a/tests/slt/sql_2016/E061_12.slt +++ b/tests/slt/sql_2016/E061_12.slt @@ -1,93 +1,109 @@ # E061-12: Subqueries in quantified comparison predicate -# TODO: Support Subquery on `WHERE` with `ALL/ANY/SOME` +statement ok +CREATE TABLE TABLE_E061_12_01_01 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_01 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_01 WHERE A < ALL ( SELECT A FROM TABLE_E061_12_01_01 ) -# SELECT A FROM TABLE_E061_12_01_01 WHERE A < ALL ( SELECT A FROM TABLE_E061_12_01_01 ) +statement ok +CREATE TABLE TABLE_E061_12_01_02 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_02 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_02 WHERE A < ANY ( SELECT A FROM TABLE_E061_12_01_02 ) -# SELECT A FROM TABLE_E061_12_01_02 WHERE A < ANY ( SELECT A FROM TABLE_E061_12_01_02 ) +statement ok +CREATE TABLE TABLE_E061_12_01_03 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_03 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_03 WHERE A < SOME ( SELECT A FROM TABLE_E061_12_01_03 ) -# SELECT A FROM TABLE_E061_12_01_03 WHERE A < SOME ( SELECT A FROM TABLE_E061_12_01_03 ) +statement ok +CREATE TABLE TABLE_E061_12_01_04 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_04 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_04 WHERE A <= ALL ( SELECT A FROM TABLE_E061_12_01_04 ) -# SELECT A FROM TABLE_E061_12_01_04 WHERE A <= ALL ( SELECT A FROM TABLE_E061_12_01_04 ) +statement ok +CREATE TABLE TABLE_E061_12_01_05 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_05 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_05 WHERE A <= ANY ( SELECT A FROM TABLE_E061_12_01_05 ) -# SELECT A FROM TABLE_E061_12_01_05 WHERE A <= ANY ( SELECT A FROM TABLE_E061_12_01_05 ) +statement ok +CREATE TABLE TABLE_E061_12_01_06 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_06 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_06 WHERE A <= SOME ( SELECT A FROM TABLE_E061_12_01_06 ) -# SELECT A FROM TABLE_E061_12_01_06 WHERE A <= SOME ( SELECT A FROM TABLE_E061_12_01_06 ) +statement ok +CREATE TABLE TABLE_E061_12_01_07 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_07 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_07 WHERE A <> ALL ( SELECT A FROM TABLE_E061_12_01_07 ) -# SELECT A FROM TABLE_E061_12_01_07 WHERE A <> ALL ( SELECT A FROM TABLE_E061_12_01_07 ) +statement ok +CREATE TABLE TABLE_E061_12_01_08 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_08 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_08 WHERE A <> ANY ( SELECT A FROM TABLE_E061_12_01_08 ) -# SELECT A FROM TABLE_E061_12_01_08 WHERE A <> ANY ( SELECT A FROM TABLE_E061_12_01_08 ) +statement ok +CREATE TABLE TABLE_E061_12_01_09 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_09 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_09 WHERE A <> SOME ( SELECT A FROM TABLE_E061_12_01_09 ) -# SELECT A FROM TABLE_E061_12_01_09 WHERE A <> SOME ( SELECT A FROM TABLE_E061_12_01_09 ) +statement ok +CREATE TABLE TABLE_E061_12_01_10 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_10 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_10 WHERE A = ALL ( SELECT A FROM TABLE_E061_12_01_10 ) -# SELECT A FROM TABLE_E061_12_01_10 WHERE A = ALL ( SELECT A FROM TABLE_E061_12_01_10 ) +statement ok +CREATE TABLE TABLE_E061_12_01_11 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_11 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_11 WHERE A = ANY ( SELECT A FROM TABLE_E061_12_01_11 ) -# SELECT A FROM TABLE_E061_12_01_11 WHERE A = ANY ( SELECT A FROM TABLE_E061_12_01_11 ) +statement ok +CREATE TABLE TABLE_E061_12_01_12 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_12 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_12 WHERE A = SOME ( SELECT A FROM TABLE_E061_12_01_12 ) -# SELECT A FROM TABLE_E061_12_01_12 WHERE A = SOME ( SELECT A FROM TABLE_E061_12_01_12 ) +statement ok +CREATE TABLE TABLE_E061_12_01_13 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_13 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_13 WHERE A > ALL ( SELECT A FROM TABLE_E061_12_01_13 ) -# SELECT A FROM TABLE_E061_12_01_13 WHERE A > ALL ( SELECT A FROM TABLE_E061_12_01_13 ) +statement ok +CREATE TABLE TABLE_E061_12_01_14 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_14 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_14 WHERE A > ANY ( SELECT A FROM TABLE_E061_12_01_14 ) -# SELECT A FROM TABLE_E061_12_01_14 WHERE A > ANY ( SELECT A FROM TABLE_E061_12_01_14 ) +statement ok +CREATE TABLE TABLE_E061_12_01_15 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_15 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_15 WHERE A > SOME ( SELECT A FROM TABLE_E061_12_01_15 ) -# SELECT A FROM TABLE_E061_12_01_15 WHERE A > SOME ( SELECT A FROM TABLE_E061_12_01_15 ) +statement ok +CREATE TABLE TABLE_E061_12_01_16 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_16 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_16 WHERE A >= ALL ( SELECT A FROM TABLE_E061_12_01_16 ) -# SELECT A FROM TABLE_E061_12_01_16 WHERE A >= ALL ( SELECT A FROM TABLE_E061_12_01_16 ) +statement ok +CREATE TABLE TABLE_E061_12_01_17 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_17 ( ID INT PRIMARY KEY, A INT ); +query I +SELECT A FROM TABLE_E061_12_01_17 WHERE A >= ANY ( SELECT A FROM TABLE_E061_12_01_17 ) -# SELECT A FROM TABLE_E061_12_01_17 WHERE A >= ANY ( SELECT A FROM TABLE_E061_12_01_17 ) +statement ok +CREATE TABLE TABLE_E061_12_01_18 ( ID INT PRIMARY KEY, A INT ); -# statement ok -# CREATE TABLE TABLE_E061_12_01_18 ( ID INT PRIMARY KEY, A INT ); - -# SELECT A FROM TABLE_E061_12_01_18 WHERE A >= SOME ( SELECT A FROM TABLE_E061_12_01_18 ) +query I +SELECT A FROM TABLE_E061_12_01_18 WHERE A >= SOME ( SELECT A FROM TABLE_E061_12_01_18 ) From ceec291ffb17936571837c39e398e8ad5bee1266 Mon Sep 17 00:00:00 2001 From: kould Date: Mon, 4 May 2026 14:22:17 +0800 Subject: [PATCH 2/2] feat: add ORM quantified subquery helpers --- src/orm/mod.rs | 83 +++++++++++++++++++++++++++++++++++ tests/macros-test/src/main.rs | 72 ++++++++++++++++++++++++++++++ 2 files changed, 155 insertions(+) diff --git a/src/orm/mod.rs b/src/orm/mod.rs index 14ee4095..52df9d11 100644 --- a/src/orm/mod.rs +++ b/src/orm/mod.rs @@ -304,6 +304,17 @@ trait ValueExpressionOps: Sized { }) } + fn quantified_subquery_expr( + self, + compare_op: CompareOp, + quantifier: QuantifiedSubquery, + subquery: S, + ) -> QueryExpr { + let left = self.into_query_value().into_expr(); + let right = Expr::Subquery(Box::new(subquery.into_subquery())); + QueryExpr::from_expr(quantifier.into_ast(left, compare_op.as_ast(), right)) + } + #[allow(clippy::wrong_self_convention)] fn is_null_expr(self) -> QueryExpr { QueryExpr::from_expr(Expr::IsNull(Box::new(self.into_query_value().into_expr()))) @@ -653,6 +664,58 @@ enum CompareOp { Lte, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum QuantifiedSubquery { + Any, + Some, + All, +} + +macro_rules! quantified_value_methods { + ($(($method:ident, $op:ident, $quantifier:ident, $symbol:literal, $keyword:literal)),+ $(,)?) => { + $( + #[doc = concat!("Builds `expr ", $symbol, " ", $keyword, " (subquery)`.")] + pub fn $method(self, subquery: S) -> QueryExpr { + ValueExpressionOps::quantified_subquery_expr( + self, + CompareOp::$op, + QuantifiedSubquery::$quantifier, + subquery, + ) + } + )+ + }; +} + +macro_rules! quantified_methods { + () => { + quantified_value_methods!( + (eq_any, Eq, Any, "=", "ANY"), + (ne_any, Ne, Any, "<>", "ANY"), + (gt_any, Gt, Any, ">", "ANY"), + (gte_any, Gte, Any, ">=", "ANY"), + (lt_any, Lt, Any, "<", "ANY"), + (lte_any, Lte, Any, "<=", "ANY"), + (eq_some, Eq, Some, "=", "SOME"), + (ne_some, Ne, Some, "<>", "SOME"), + (gt_some, Gt, Some, ">", "SOME"), + (gte_some, Gte, Some, ">=", "SOME"), + (lt_some, Lt, Some, "<", "SOME"), + (lte_some, Lte, Some, "<=", "SOME"), + (eq_all, Eq, All, "=", "ALL"), + (ne_all, Ne, All, "<>", "ALL"), + (gt_all, Gt, All, ">", "ALL"), + (gte_all, Gte, All, ">=", "ALL"), + (lt_all, Lt, All, "<", "ALL"), + (lte_all, Lte, All, "<=", "ALL"), + ); + }; +} + +impl Field { + quantified_methods!(); +} + #[derive(Debug, Clone, PartialEq)] /// A lightweight ORM expression wrapper for predicate-oriented SQL AST nodes. /// @@ -1113,6 +1176,8 @@ impl QueryValue { ValueExpressionOps::lte_expr(self, value) } + quantified_methods!(); + /// Builds `expr IS NULL`. pub fn is_null(self) -> QueryExpr { ValueExpressionOps::is_null_expr(self) @@ -1354,6 +1419,24 @@ impl CompareOp { } } +impl QuantifiedSubquery { + fn into_ast(self, left: Expr, compare_op: SqlBinaryOperator, right: Expr) -> Expr { + match self { + QuantifiedSubquery::Any | QuantifiedSubquery::Some => Expr::AnyOp { + left: Box::new(left), + compare_op, + right: Box::new(right), + is_some: matches!(self, QuantifiedSubquery::Some), + }, + QuantifiedSubquery::All => Expr::AllOp { + left: Box::new(left), + compare_op, + right: Box::new(right), + }, + } + } +} + #[doc(hidden)] pub trait StatementSource { type Iter: ResultIter; diff --git a/tests/macros-test/src/main.rs b/tests/macros-test/src/main.rs index ada09c87..d0fb32a1 100644 --- a/tests/macros-test/src/main.rs +++ b/tests/macros-test/src/main.rs @@ -1037,6 +1037,78 @@ mod test { amount: 300, })?; + let eq_any_users = database + .from::() + .filter( + User::id().eq_any( + database + .from::() + .project_value(Order::user_id()) + .eq(Order::amount(), 300), + ), + ) + .fetch()? + .collect::, _>>()?; + assert_eq!( + eq_any_users.iter().map(|user| user.id).collect::>(), + vec![2] + ); + + let eq_some_users = database + .from::() + .filter( + User::id().eq_some( + database + .from::() + .project_value(Order::user_id()) + .eq(Order::amount(), 100), + ), + ) + .fetch()? + .collect::, _>>()?; + assert_eq!( + eq_some_users.iter().map(|user| user.id).collect::>(), + vec![1] + ); + + let gt_all_users = database + .from::() + .filter(User::id().gt_all(database.from::().project_value(Order::user_id()))) + .fetch()? + .collect::, _>>()?; + assert_eq!( + gt_all_users.iter().map(|user| user.id).collect::>(), + vec![3] + ); + + let lt_any_users = database + .from::() + .filter(User::id().lt_any(database.from::().project_value(Order::user_id()))) + .fetch()? + .collect::, _>>()?; + assert_eq!( + lt_any_users.iter().map(|user| user.id).collect::>(), + vec![1] + ); + + let query_value_gt_all_users = database + .from::() + .filter( + User::id() + .add(1) + .gt_all(database.from::().project_value(Order::user_id())), + ) + .asc(User::id()) + .fetch()? + .collect::, _>>()?; + assert_eq!( + query_value_gt_all_users + .iter() + .map(|user| user.id) + .collect::>(), + vec![2, 3] + ); + let exists_count = database .from::() .filter(kite_sql::orm::QueryExpr::exists(