diff --git a/beacon-functions/src/util/mask_if_not_null.rs b/beacon-functions/src/util/mask_if_not_null.rs new file mode 100644 index 00000000..08cfcd73 --- /dev/null +++ b/beacon-functions/src/util/mask_if_not_null.rs @@ -0,0 +1,99 @@ +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::{ + common::{exec_err, internal_err, ExprSchema}, + logical_expr::{ + conditional_expressions::CaseBuilder, + simplify::{ExprSimplifyResult, SimplifyInfo}, + ColumnarValue, ExprSchemable, ReturnFieldArgs, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, + }, + physical_plan::expressions::CaseExpr, + prelude::{is_null, Expr}, +}; + +pub fn mask_if_not_null() -> ScalarUDF { + ScalarUDF::new_from_impl(MaskIfNotNullFunc::new()) +} + +#[derive(Debug, Clone)] +pub struct MaskIfNotNullFunc { + signature: Signature, +} + +impl MaskIfNotNullFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for MaskIfNotNullFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "mask_if_not_null" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { + args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true)) + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion::error::Result { + if arg_types.len() != 2 { + return exec_err!( + "mask_if_not_null requires exactly two arguments, got {}", + arg_types.len() + ); + } + Ok(arg_types[1].clone()) + } + + fn simplify( + &self, + mut args: Vec, + info: &dyn SimplifyInfo, + ) -> datafusion::error::Result { + if args.len() != 2 { + return exec_err!( + "mask_if_not_null requires exactly two arguments, got {}", + args.len() + ); + } + let left = args.remove(0); + let right = args.remove(0); + + // If the first argument is known to be non-null, we can simplify to the second argument + if let Ok(false) = info.nullable(&left) { + return Ok(ExprSimplifyResult::Simplified(right)); + } + + let new_expr = CaseBuilder::new( + None, + vec![left.is_not_null()], + vec![right], + Some(Box::new(Expr::Literal( + datafusion::scalar::ScalarValue::Null, + None, + ))), + ) + .end()?; + + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn invoke_with_args( + &self, + args: datafusion::logical_expr::ScalarFunctionArgs, + ) -> datafusion::error::Result { + internal_err!("invoke_with_args should not be called for mask_if_not_null") + } +} diff --git a/beacon-functions/src/util/mask_if_null.rs b/beacon-functions/src/util/mask_if_null.rs new file mode 100644 index 00000000..e313ec7b --- /dev/null +++ b/beacon-functions/src/util/mask_if_null.rs @@ -0,0 +1,99 @@ +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::{ + common::{exec_err, internal_err, ExprSchema}, + logical_expr::{ + conditional_expressions::CaseBuilder, + simplify::{ExprSimplifyResult, SimplifyInfo}, + ColumnarValue, ExprSchemable, ReturnFieldArgs, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, + }, + physical_plan::expressions::CaseExpr, + prelude::{is_null, Expr}, +}; + +pub fn mask_if_null() -> ScalarUDF { + ScalarUDF::new_from_impl(MaskIfNullFunc::new()) +} + +#[derive(Debug, Clone)] +pub struct MaskIfNullFunc { + signature: Signature, +} + +impl MaskIfNullFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for MaskIfNullFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "mask_if_null" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { + args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true)) + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion::error::Result { + if arg_types.len() != 2 { + return exec_err!( + "mask_if_null requires exactly two arguments, got {}", + arg_types.len() + ); + } + Ok(arg_types[1].clone()) + } + + fn simplify( + &self, + mut args: Vec, + info: &dyn SimplifyInfo, + ) -> datafusion::error::Result { + if args.len() != 2 { + return exec_err!( + "mask_if_null requires exactly two arguments, got {}", + args.len() + ); + } + let left = args.remove(0); + let right = args.remove(0); + + // If the first argument is known to be non-null, we can simplify to the second argument + if let Ok(false) = info.nullable(&left) { + return Ok(ExprSimplifyResult::Simplified(right)); + } + + let new_expr = CaseBuilder::new( + None, + vec![left.is_null()], + vec![right], + Some(Box::new(Expr::Literal( + datafusion::scalar::ScalarValue::Null, + None, + ))), + ) + .end()?; + + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn invoke_with_args( + &self, + args: datafusion::logical_expr::ScalarFunctionArgs, + ) -> datafusion::error::Result { + internal_err!("invoke_with_args should not be called for mask_if_null") + } +} diff --git a/beacon-functions/src/util/mod.rs b/beacon-functions/src/util/mod.rs index 8b90e4f2..21ce4abf 100644 --- a/beacon-functions/src/util/mod.rs +++ b/beacon-functions/src/util/mod.rs @@ -2,6 +2,8 @@ use datafusion::logical_expr::ScalarUDF; pub mod cast_int8_as_char; pub mod coalesce_label; +pub mod mask_if_not_null; +pub mod mask_if_null; pub mod try_arrow_cast; pub fn util_udfs() -> Vec { @@ -9,5 +11,7 @@ pub fn util_udfs() -> Vec { cast_int8_as_char::cast_int8_as_char(), try_arrow_cast::try_arrow_cast(), coalesce_label::coalesce_label(), + mask_if_null::mask_if_null(), + mask_if_not_null::mask_if_not_null(), ] }