diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/AbstractStrptimeFunctionMapper.java b/isthmus/src/main/java/io/substrait/isthmus/expression/AbstractStrptimeFunctionMapper.java new file mode 100644 index 000000000..eefd8aab2 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/AbstractStrptimeFunctionMapper.java @@ -0,0 +1,71 @@ +package io.substrait.isthmus.expression; + +import io.substrait.expression.Expression.ScalarFunctionInvocation; +import io.substrait.expression.FunctionArg; +import io.substrait.extension.SimpleExtension.ScalarFunctionVariant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; + +/** + * Abstract base class for custom mappings between Calcite PARSE_* functions and Substrait + * strptime_* functions. + * + *

Calcite PARSE_* functions have format followed by string parameters, while + * Substrait strptime_* functions have string followed by format. When mapping + * between Calcite and Substrait, the parameters need to be reversed. + */ +abstract class AbstractStrptimeFunctionMapper implements ScalarFunctionMapper { + private final String substraitFunctionName; + private final SqlOperator calciteOperator; + private final List strptimeFunctions; + + /** + * Constructs an abstract strptime function mapper. + * + * @param substraitFunctionName the name of the Substrait function (e.g., "strptime_date") + * @param calciteOperator the Calcite operator to map from (e.g., SqlLibraryOperators.PARSE_DATE) + * @param functions the list of all available scalar function variants + */ + protected AbstractStrptimeFunctionMapper( + String substraitFunctionName, + SqlOperator calciteOperator, + List functions) { + this.substraitFunctionName = substraitFunctionName; + this.calciteOperator = calciteOperator; + this.strptimeFunctions = + functions.stream() + .filter(f -> substraitFunctionName.equals(f.name())) + .collect(Collectors.toUnmodifiableList()); + } + + @Override + public Optional toSubstrait(RexCall call) { + if (!calciteOperator.equals(call.op)) { + return Optional.empty(); + } + + List operands = new ArrayList<>(call.getOperands()); + Collections.swap(operands, 0, 1); + + return Optional.of( + new SubstraitFunctionMapping(substraitFunctionName, operands, strptimeFunctions)); + } + + @Override + public Optional> getExpressionArguments(ScalarFunctionInvocation expression) { + if (!substraitFunctionName.equals(expression.declaration().name())) { + return Optional.empty(); + } + + List arguments = new ArrayList<>(expression.arguments()); + Collections.swap(arguments, 0, 1); + + return Optional.of(arguments); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeDateFunctionMapper.java b/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeDateFunctionMapper.java index 3aff3205c..66cb5778a 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeDateFunctionMapper.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeDateFunctionMapper.java @@ -1,60 +1,18 @@ package io.substrait.isthmus.expression; -import io.substrait.expression.Expression.ScalarFunctionInvocation; -import io.substrait.expression.FunctionArg; import io.substrait.extension.SimpleExtension.ScalarFunctionVariant; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; -import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.fun.SqlLibraryOperators; /** * Custom mapping for the Calcite {@code PARSE_DATE} function to the Substrait {@code strptime_date} - * function. - * - *

Calcite {@code PARSE_DATE} has format followed by date_string parameters, - * while Substrait {@code strptime_date} has date_string followed by format. When - * mapping between Calcite and Substrait, the parameters need to be reversed. - * - *

{@code PARSE_DATE(format, date_string)} maps to {@code strptime_date(date_string, format)}. + * function. {@code PARSE_DATE(format, date_string)} maps to {@code strptime_date(date_string, + * format)}. */ -public final class StrptimeDateFunctionMapper implements ScalarFunctionMapper { +public final class StrptimeDateFunctionMapper extends AbstractStrptimeFunctionMapper { private static final String STRPTIME_DATE_FUNCTION_NAME = "strptime_date"; - private final List strptimeDateFunctions; public StrptimeDateFunctionMapper(List functions) { - strptimeDateFunctions = - functions.stream() - .filter(f -> STRPTIME_DATE_FUNCTION_NAME.equals(f.name())) - .collect(Collectors.toUnmodifiableList()); - } - - @Override - public Optional toSubstrait(RexCall call) { - if (!SqlLibraryOperators.PARSE_DATE.equals(call.op)) { - return Optional.empty(); - } - - List operands = new ArrayList<>(call.getOperands()); - Collections.swap(operands, 0, 1); - - return Optional.of( - new SubstraitFunctionMapping(STRPTIME_DATE_FUNCTION_NAME, operands, strptimeDateFunctions)); - } - - @Override - public Optional> getExpressionArguments(ScalarFunctionInvocation expression) { - if (!STRPTIME_DATE_FUNCTION_NAME.equals(expression.declaration().name())) { - return Optional.empty(); - } - - List arguments = new ArrayList<>(expression.arguments()); - Collections.swap(arguments, 0, 1); - - return Optional.of(arguments); + super(STRPTIME_DATE_FUNCTION_NAME, SqlLibraryOperators.PARSE_DATE, functions); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeTimeFunctionMapper.java b/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeTimeFunctionMapper.java index a1f360d1f..28625cb7b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeTimeFunctionMapper.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeTimeFunctionMapper.java @@ -1,60 +1,18 @@ package io.substrait.isthmus.expression; -import io.substrait.expression.Expression.ScalarFunctionInvocation; -import io.substrait.expression.FunctionArg; import io.substrait.extension.SimpleExtension.ScalarFunctionVariant; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; -import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.fun.SqlLibraryOperators; /** * Custom mapping for the Calcite {@code PARSE_TIME} function to the Substrait {@code strptime_time} - * function. - * - *

Calcite {@code PARSE_TIME} has format followed by time_string parameters, - * while Substrait {@code strptime_time} has time_string followed by format. When - * mapping between Calcite and Substrait, the parameters need to be reversed. - * - *

{@code PARSE_TIME(format, time_string)} maps to {@code strptime_time(time_string, format)}. + * function. {@code PARSE_TIME(format, time_string)} maps to {@code strptime_time(time_string, + * format)}. */ -public final class StrptimeTimeFunctionMapper implements ScalarFunctionMapper { +public final class StrptimeTimeFunctionMapper extends AbstractStrptimeFunctionMapper { private static final String STRPTIME_TIME_FUNCTION_NAME = "strptime_time"; - private final List strptimeTimeFunctions; public StrptimeTimeFunctionMapper(List functions) { - strptimeTimeFunctions = - functions.stream() - .filter(f -> STRPTIME_TIME_FUNCTION_NAME.equals(f.name())) - .collect(Collectors.toUnmodifiableList()); - } - - @Override - public Optional toSubstrait(RexCall call) { - if (!SqlLibraryOperators.PARSE_TIME.equals(call.op)) { - return Optional.empty(); - } - - List operands = new ArrayList<>(call.getOperands()); - Collections.swap(operands, 0, 1); - - return Optional.of( - new SubstraitFunctionMapping(STRPTIME_TIME_FUNCTION_NAME, operands, strptimeTimeFunctions)); - } - - @Override - public Optional> getExpressionArguments(ScalarFunctionInvocation expression) { - if (!STRPTIME_TIME_FUNCTION_NAME.equals(expression.declaration().name())) { - return Optional.empty(); - } - - List arguments = new ArrayList<>(expression.arguments()); - Collections.swap(arguments, 0, 1); - - return Optional.of(arguments); + super(STRPTIME_TIME_FUNCTION_NAME, SqlLibraryOperators.PARSE_TIME, functions); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeTimestampFunctionMapper.java b/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeTimestampFunctionMapper.java index f0c1ca5f9..bfe733774 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeTimestampFunctionMapper.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/StrptimeTimestampFunctionMapper.java @@ -1,62 +1,18 @@ package io.substrait.isthmus.expression; -import io.substrait.expression.Expression.ScalarFunctionInvocation; -import io.substrait.expression.FunctionArg; import io.substrait.extension.SimpleExtension.ScalarFunctionVariant; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; -import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.fun.SqlLibraryOperators; /** * Custom mapping for the Calcite {@code PARSE_TIMESTAMP} function to the Substrait {@code - * strptime_timestamp} function. - * - *

Calcite {@code PARSE_TIMESTAMP} has format followed by timestamp_string - * parameters, while Substrait {@code strptime_timestamp} has timestamp_string followed by - * format. When mapping between Calcite and Substrait, the parameters need to be reversed. - * - *

{@code PARSE_TIMESTAMP(format, timestamp_string)} maps to {@code + * strptime_timestamp} function. {@code PARSE_TIMESTAMP(format, timestamp_string)} maps to {@code * strptime_timestamp(timestamp_string, format)}. */ -public final class StrptimeTimestampFunctionMapper implements ScalarFunctionMapper { +public final class StrptimeTimestampFunctionMapper extends AbstractStrptimeFunctionMapper { private static final String STRPTIME_TIMESTAMP_FUNCTION_NAME = "strptime_timestamp"; - private final List strptimeTimestampFunctions; public StrptimeTimestampFunctionMapper(List functions) { - strptimeTimestampFunctions = - functions.stream() - .filter(f -> STRPTIME_TIMESTAMP_FUNCTION_NAME.equals(f.name())) - .collect(Collectors.toUnmodifiableList()); - } - - @Override - public Optional toSubstrait(RexCall call) { - if (!SqlLibraryOperators.PARSE_TIMESTAMP.equals(call.op)) { - return Optional.empty(); - } - - List operands = new ArrayList<>(call.getOperands()); - Collections.swap(operands, 0, 1); - - return Optional.of( - new SubstraitFunctionMapping( - STRPTIME_TIMESTAMP_FUNCTION_NAME, operands, strptimeTimestampFunctions)); - } - - @Override - public Optional> getExpressionArguments(ScalarFunctionInvocation expression) { - if (!STRPTIME_TIMESTAMP_FUNCTION_NAME.equals(expression.declaration().name())) { - return Optional.empty(); - } - - List arguments = new ArrayList<>(expression.arguments()); - Collections.swap(arguments, 0, 1); - - return Optional.of(arguments); + super(STRPTIME_TIMESTAMP_FUNCTION_NAME, SqlLibraryOperators.PARSE_TIMESTAMP, functions); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java b/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java index fda3aca43..f0b761967 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java @@ -15,12 +15,16 @@ import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.isthmus.expression.WindowFunctionConverter; +import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.stream.Stream; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlKind; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; /** * Verify that "problematic" Substrait functions can be converted to Calcite and back successfully @@ -284,80 +288,64 @@ void concatStringLiteralAndChar() throws Exception { assertProtoPlanRoundrip("select 'brand_'||P_BRAND from PART"); } - @Test - void testStrptimeTime() { - Expression.StrLiteral timeString = Expression.StrLiteral.builder().value("12:34:56").build(); - Expression.StrLiteral formatString = Expression.StrLiteral.builder().value("%H:%M:%S").build(); - ScalarFunctionInvocation strptimeTimeFn = - sb.scalarFn( - DefaultExtensionCatalog.FUNCTIONS_DATETIME, + /** + * Provides test cases for strptime function tests. + * + * @return Stream of test arguments containing: function name, input string value, format string, + * output type, and expected Calcite function name + */ + private static Stream strptimeTestCases() { + return Stream.of( + Arguments.of( "strptime_time:str_str", + "12:34:56", + "%H:%M:%S", TypeCreator.REQUIRED.TIME, - timeString, - formatString); - - // tests Substrait -> Calcite - RexNode calciteExpr = strptimeTimeFn.accept(expressionRexConverter, Context.newContext()); - assertEquals(SqlKind.OTHER_FUNCTION, calciteExpr.getKind()); - assertInstanceOf(RexCall.class, calciteExpr); - - assertEquals("PARSE_TIME('%H:%M:%S':VARCHAR, '12:34:56':VARCHAR)", calciteExpr.toString()); - - // tests the reverse Calcite -> Substrait - Expression reverse = calciteExpr.accept(rexExpressionConverter); - assertEquals(strptimeTimeFn, reverse); - } - - @Test - void testStrptimeTimestamp() { - Expression.StrLiteral timestampString = - Expression.StrLiteral.builder().value("2026-01-29T12:34:56").build(); - Expression.StrLiteral formatString = - Expression.StrLiteral.builder().value("%Y:%m:%dT%H:%M:%S").build(); - ScalarFunctionInvocation strptimeTimestampFn = - sb.scalarFn( - DefaultExtensionCatalog.FUNCTIONS_DATETIME, + "PARSE_TIME"), + Arguments.of( "strptime_timestamp:str_str", - // using precision 6 here to be compatible with Calcite + "2026-01-29T12:34:56", + "%Y:%m:%dT%H:%M:%S", TypeCreator.REQUIRED.precisionTimestamp(6), - timestampString, - formatString); - - // tests Substrait -> Calcite - RexNode calciteExpr = strptimeTimestampFn.accept(expressionRexConverter, Context.newContext()); - assertEquals(SqlKind.OTHER_FUNCTION, calciteExpr.getKind()); - assertInstanceOf(RexCall.class, calciteExpr); - - assertEquals( - "PARSE_TIMESTAMP('%Y:%m:%dT%H:%M:%S':VARCHAR, '2026-01-29T12:34:56':VARCHAR)", - calciteExpr.toString()); - - // tests the reverse Calcite -> Substrait - Expression reverse = calciteExpr.accept(rexExpressionConverter); - assertEquals(strptimeTimestampFn, reverse); + "PARSE_TIMESTAMP"), + Arguments.of( + "strptime_date:str_str", + "2026-01-29", + "%Y:%m:%d", + TypeCreator.REQUIRED.DATE, + "PARSE_DATE")); } - @Test - void testStrptimeDate() { - Expression.StrLiteral dateString = Expression.StrLiteral.builder().value("2026-01-29").build(); - Expression.StrLiteral formatString = Expression.StrLiteral.builder().value("%Y:%m:%d").build(); - ScalarFunctionInvocation strptimeDateFn = + @ParameterizedTest + @MethodSource("strptimeTestCases") + void testStrptimeFunctions( + String functionSignature, + String inputValue, + String formatValue, + Type outputType, + String expectedCalciteFunctionName) { + Expression.StrLiteral inputString = Expression.StrLiteral.builder().value(inputValue).build(); + Expression.StrLiteral formatString = Expression.StrLiteral.builder().value(formatValue).build(); + ScalarFunctionInvocation strptimeFn = sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, - "strptime_date:str_str", - TypeCreator.REQUIRED.DATE, - dateString, + functionSignature, + outputType, + inputString, formatString); // tests Substrait -> Calcite - RexNode calciteExpr = strptimeDateFn.accept(expressionRexConverter, Context.newContext()); + RexNode calciteExpr = strptimeFn.accept(expressionRexConverter, Context.newContext()); assertEquals(SqlKind.OTHER_FUNCTION, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); - assertEquals("PARSE_DATE('%Y:%m:%d':VARCHAR, '2026-01-29':VARCHAR)", calciteExpr.toString()); + String expectedCallString = + String.format( + "%s('%s':VARCHAR, '%s':VARCHAR)", expectedCalciteFunctionName, formatValue, inputValue); + assertEquals(expectedCallString, calciteExpr.toString()); // tests the reverse Calcite -> Substrait Expression reverse = calciteExpr.accept(rexExpressionConverter); - assertEquals(strptimeDateFn, reverse); + assertEquals(strptimeFn, reverse); } }