From 5d259ff687397e9c0e11c812a94380a7ed490cdc Mon Sep 17 00:00:00 2001
From: wszhdshys <1925792291@qq.com>
Date: Tue, 15 Jul 2025 13:56:54 +0800
Subject: [PATCH] feat: Implement the octet_length function, which calculates
 the number of bytes of the input value, and fix the char_length function so
 that it correctly outputs the number of characters instead of bytes

---
 src/db.rs                      |  2 +
 src/function/char_length.rs    |  2 +-
 src/function/mod.rs            |  1 +
 src/function/octet_length.rs   | 78 ++++++++++++++++++++++++++++++++++
 tests/slt/sql_2016/E021_04.slt | 11 ++++-
 tests/slt/sql_2016/E021_05.slt | 15 ++++---
 6 files changed, 101 insertions(+), 8 deletions(-)
 create mode 100644 src/function/octet_length.rs

diff --git a/src/db.rs b/src/db.rs
index ab457b6b..5605fd19 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -8,6 +8,7 @@ use crate::function::char_length::CharLength;
 use crate::function::current_date::CurrentDate;
 use crate::function::lower::Lower;
 use crate::function::numbers::Numbers;
+use crate::function::octet_length::OctetLength;
 use crate::function::upper::Upper;
 use crate::optimizer::heuristic::batch::HepBatchStrategy;
 use crate::optimizer::heuristic::optimizer::HepOptimizer;
@@ -61,6 +62,7 @@ impl DataBaseBuilder {
             builder.register_scala_function(CharLength::new("character_length".to_lowercase()));
         builder = builder.register_scala_function(CurrentDate::new());
         builder = builder.register_scala_function(Lower::new());
+        builder = builder.register_scala_function(OctetLength::new());
         builder = builder.register_scala_function(Upper::new());
         builder = builder.register_table_function(Numbers::new());
         builder
diff --git a/src/function/char_length.rs b/src/function/char_length.rs
index 6cf0c100..817f591a 100644
--- a/src/function/char_length.rs
+++ b/src/function/char_length.rs
@@ -43,7 +43,7 @@ impl ScalarFunctionImpl for CharLength {
         }
         let mut length: u64 = 0;
         if let DataValue::Utf8 { value, ty, unit } = &mut value {
-            length = value.len() as u64;
+            length = value.chars().count() as u64;
         }
         Ok(DataValue::UInt64(length))
     }
diff --git a/src/function/mod.rs b/src/function/mod.rs
index 6c660c57..7930e807 100644
--- a/src/function/mod.rs
+++ b/src/function/mod.rs
@@ -2,4 +2,5 @@ pub(crate) mod char_length;
 pub(crate) mod current_date;
 pub(crate) mod lower;
 pub(crate) mod numbers;
+pub(crate) mod octet_length;
 pub(crate) mod upper;
diff --git a/src/function/octet_length.rs b/src/function/octet_length.rs
new file mode 100644
index 00000000..b712ee3c
--- /dev/null
+++ b/src/function/octet_length.rs
@@ -0,0 +1,78 @@
+use crate::catalog::ColumnRef;
+use crate::errors::DatabaseError;
+use crate::expression::function::scala::FuncMonotonicity;
+use crate::expression::function::scala::ScalarFunctionImpl;
+use crate::expression::function::FunctionSummary;
+use crate::expression::ScalarExpression;
+use crate::types::tuple::Tuple;
+use crate::types::value::DataValue;
+use crate::types::LogicalType;
+use serde::Deserialize;
+use serde::Serialize;
+use sqlparser::ast::CharLengthUnits;
+use std::sync::Arc;
+
+/// Scalar function `octet_length`: the length in bytes of its string argument.
+#[derive(Debug, Serialize, Deserialize)]
+pub(crate) struct OctetLength {
+    summary: FunctionSummary,
+}
+
+impl OctetLength {
+    /// Creates the function under the name `octet_length`, declared with a single
+    /// `Varchar` argument.
+    pub(crate) fn new() -> Arc<Self> {
+        let arg_types = vec![LogicalType::Varchar(None, CharLengthUnits::Characters)];
+        Arc::new(Self {
+            summary: FunctionSummary {
+                name: "octet_length".to_string(),
+                arg_types,
+            },
+        })
+    }
+}
+
+#[typetag::serde]
+impl ScalarFunctionImpl for OctetLength {
+    /// Evaluates the first argument, casts it to `Varchar` if it is not one already,
+    /// and returns the byte length of the resulting string as `DataValue::UInt64`.
+    ///
+    /// # Errors
+    /// Propagates any error raised while evaluating or casting the argument.
+    fn eval(
+        &self,
+        exprs: &[ScalarExpression],
+        tuples: Option<(&Tuple, &[ColumnRef])>,
+    ) -> Result<DataValue, DatabaseError> {
+        let mut value = exprs[0].eval(tuples)?;
+        if !matches!(value.logical_type(), LogicalType::Varchar(_, _)) {
+            value = value.cast(&LogicalType::Varchar(None, CharLengthUnits::Characters))?;
+        }
+        // `len()` counts bytes rather than characters, which is what OCTET_LENGTH
+        // reports. A value that is still not `Utf8` after the cast yields 0.
+        // NOTE(review): SQL expects OCTET_LENGTH(NULL) to be NULL rather than 0;
+        // confirm how a NULL argument should be represented here.
+        let length = match &value {
+            DataValue::Utf8 { value, .. } => value.len() as u64,
+            _ => 0,
+        };
+        Ok(DataValue::UInt64(length))
+    }
+
+    /// No monotonicity is claimed for this function.
+    fn monotonicity(&self) -> Option<FuncMonotonicity> {
+        None
+    }
+
+    // NOTE(review): `eval` produces `DataValue::UInt64`, yet the declared return
+    // type is `Varchar`; this should presumably be the unsigned 64-bit integer
+    // type — confirm against `LogicalType`.
+    fn return_type(&self) -> &LogicalType {
+        &LogicalType::Varchar(None, CharLengthUnits::Characters)
+    }
+
+    /// Returns the function's name and argument types.
+    fn summary(&self) -> &FunctionSummary {
+        &self.summary
+    }
+}
diff --git a/tests/slt/sql_2016/E021_04.slt b/tests/slt/sql_2016/E021_04.slt
index ac401405..8d4842b5 100644
--- a/tests/slt/sql_2016/E021_04.slt
+++ b/tests/slt/sql_2016/E021_04.slt
@@ -5,8 +5,17 @@ SELECT CHARACTER_LENGTH ( 'foo' )
 ----
 3
 
-
 query I
 SELECT CHAR_LENGTH ( 'foo' )
 ----
 3
+
+query I
+SELECT CHARACTER_LENGTH ( '测试' )
+----
+2
+
+query I
+SELECT CHAR_LENGTH ( '测试' )
+----
+2
diff --git a/tests/slt/sql_2016/E021_05.slt b/tests/slt/sql_2016/E021_05.slt
index e29fffa7..a40733ff 100644
--- a/tests/slt/sql_2016/E021_05.slt
+++ b/tests/slt/sql_2016/E021_05.slt
@@ -1,8 +1,11 @@
-# E021-05: OCTET_LENGTH function
+#E021-05: OCTET_LENGTH function
 
-# TODO: OCTET_LENGTH()
+query I
+SELECT OCTET_LENGTH ( 'foo' )
+----
+3
 
-# query I
-# SELECT OCTET_LENGTH ( 'foo' )
-# ----
-# 3
+query I
+SELECT OCTET_LENGTH ( '测试' )
+----
+6