diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java index 872845f0..9593d875 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java @@ -126,6 +126,8 @@ public void generate( writer.addImport("static com.hedera.pbj.runtime.ProtoConstants.*"); writer.addImport("static com.hedera.pbj.runtime.Utf8Tools.*"); + var sbFunc = new StringBuilder(); + // spotless:off writer.append(""" /** @@ -164,7 +166,7 @@ public void generate( .replace("$codecClass", codecClassName) .replace("$cacheableSupport", cacheableSupport) .replace("$unsetOneOfConstants", CodecParseMethodGenerator.generateUnsetOneOfConstants(fields)) - .replace("$parseMethod", CodecParseMethodGenerator.generateParseMethod(modelClassName, schemaClassName, fields, !cacheableSupport.isBlank())) + .replace("$parseMethod", CodecParseMethodGenerator.generateParseMethod(sbFunc, modelClassName, schemaClassName, fields, !cacheableSupport.isBlank())) .replace("$writeMethod", writeMethod) .replace("$writeByteArrayMethod", writeByteArrayMethod) .replace("$measureDataMethod", CodecMeasureDataMethodGenerator.generateMeasureMethod(modelClassName, fields)) @@ -172,6 +174,7 @@ public void generate( .replace("$fastEqualsMethod", CodecFastEqualsMethodGenerator.generateFastEqualsMethod(modelClassName, fields)) .replace("$getDefaultInstanceMethod", generateGetDefaultInstanceMethod(modelClassName)) ); + writer.append(sbFunc.toString()); // spotless:on for (final var item : msgDef.messageBody().messageElement()) { diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java 
b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java index e1ae6e20..5b3b7a68 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java @@ -10,6 +10,7 @@ import com.hedera.pbj.compiler.impl.PbjCompilerException; import com.hedera.pbj.compiler.impl.generators.ModelGenerator; import edu.umd.cs.findbugs.annotations.NonNull; +import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -47,10 +48,14 @@ static String generateUnsetOneOfConstants(final List fields) { } static String generateParseMethod( + StringBuilder sbFunc, final String modelClassName, final String schemaClassName, final List fields, final boolean isCacheable) { + + ParseAndDefaultBody parseAndDefaultBodyPair = + generateParseLoop(generateCaseStatements(sbFunc, fields, schemaClassName), "", schemaClassName); // spotless:off return """ /** @@ -105,6 +110,11 @@ static String generateParseMethod( throw new ParseException(anyException); } } + + private List defaultCase(int tag, int field, FieldDefinition f, boolean strictMode, boolean parseUnknownFields, List $unknownFields, ReadableSequentialData input, int maxSize) throws ParseException, IOException { + $defaultCaseBody + return $unknownFields; + } """ .replace("$cacheableSupport", isCacheable ? generateCacheableSupport(modelClassName, fields) : "return new $modelClassName($fieldsList);") .replace("$modelClassName",modelClassName) @@ -117,7 +127,8 @@ static String generateParseMethod( fields.stream().map(field -> "temp_"+field.name()).collect(Collectors.joining(", ")) + (fields.isEmpty() ? 
"" : ", ") + "$unknownFields" ) - .replace("$parseLoop", generateParseLoop(generateCaseStatements(fields, schemaClassName), "", schemaClassName)) + .replace("$parseLoop", parseAndDefaultBodyPair.parseBody()) + .replace("$defaultCaseBody", parseAndDefaultBodyPair.defaultBody()) .replace("$listFieldsWriteProtection", fields.stream() .filter(Field::repeated) .map(field -> "if (temp_" + field.name() + " instanceof UnmodifiableArrayList ual) ual.makeReadOnly();") @@ -159,11 +170,15 @@ static String generateCacheableSupport(String modelClassName, final List .indent(DEFAULT_INDENT * 2); } + public record ParseAndDefaultBody(String parseBody, String defaultBody) {} + // prefix is pre-pended to variable names to support a nested parsing loop. - static String generateParseLoop( + // The list returned is [$parseLoop, $defaultCaseBody] + static ParseAndDefaultBody generateParseLoop( final String caseStatements, @NonNull final String prefix, @NonNull final String schemaClassName) { // spotless:off - return """ + List list = new ArrayList<>(); + list.add(""" // -- PARSE LOOP --------------------------------------------- // Continue to parse bytes out of the input stream until we get to the end. while (input.hasRemaining()) { @@ -191,77 +206,74 @@ static String generateParseLoop( // Given the wire type and the field type, parse the field switch ($prefixtag) { $caseStatements - default -> { - // The wire type is the bottom 3 bits of the byte. Read that off - final int wireType = $prefixtag & TAG_WIRE_TYPE_MASK; - // handle error cases here, so we do not do if statements in normal loop - // Validate the field number is valid (must be > 0) - if ($prefixfield == 0) { - throw new IOException("Bad protobuf encoding. We read a field value of " - + $prefixfield); - } - // Validate the wire type is valid (must be >=0 && <= 5). - // Otherwise we cannot parse this. - // Note: it is always >= 0 at this point (see code above where it is defined). 
- if (wireType > 5) { - throw new IOException("Cannot understand wire_type of " + wireType); - } - // It may be that the parser subclass doesn't know about this field - if ($prefixf == null) { - if (strictMode) { - // Since we are parsing is strict mode, this is an exceptional condition. - throw new UnknownFieldException($prefixfield); - } else if (parseUnknownFields) { - if ($unknownFields == null) { - $unknownFields = new ArrayList<>($initialSizeOfUnknownFieldsArray); - } - $unknownFields.add(new UnknownField( - field, - ProtoConstants.get(wireType), - extractField(input, ProtoConstants.get(wireType), $skipMaxSize) - )); - } else { - // We just need to read off the bytes for this field to skip it - // and move on to the next one. - skipField(input, ProtoConstants.get(wireType), $skipMaxSize); - } - } else { - throw new IOException("Bad tag [" + $prefixtag + "], field [" + $prefixfield - + "] wireType [" + wireType + "]"); - } - } + default -> $unknownFields = defaultCase($prefixtag, $prefixfield, $prefixf, strictMode, parseUnknownFields, $unknownFields, input, maxSize); } } - """ - .replace("$caseStatements",caseStatements) - .replace("$prefix",prefix) - .replace("$schemaClassName",schemaClassName) +"""); // list[0] above is the parse loop; list[1] below is the body of the generated defaultCase(...) helper and uses that helper's fixed parameter names (tag, field, f) + list.add(""" + // The wire type is the bottom 3 bits of the byte. Read that off + final int wireType = tag & TAG_WIRE_TYPE_MASK; + // handle error cases here, so we do not do if statements in normal loop + // Validate the field number is valid (must be > 0) + if (field == 0) { + throw new IOException("Bad protobuf encoding. We read a field value of " + + field); + } + // Validate the wire type is valid (must be >=0 && <= 5). + // Otherwise we cannot parse this. + // Note: it is always >= 0 at this point (see code above where it is defined). 
+ if (wireType > 5) { + throw new IOException("Cannot understand wire_type of " + wireType); + } + // It may be that the parser subclass doesn't know about this field + if (f == null) { + if (strictMode) { + // Since we are parsing is strict mode, this is an exceptional condition. + throw new UnknownFieldException(field); + } else if (parseUnknownFields) { + if ($unknownFields == null) { + $unknownFields = new ArrayList<>($initialSizeOfUnknownFieldsArray); + } + $unknownFields.add(new UnknownField( + field, + ProtoConstants.get(wireType), + extractField(input, ProtoConstants.get(wireType), $skipMaxSize) + )); + } else { + // We just need to read off the bytes for this field to skip it + // and move on to the next one. + skipField(input, ProtoConstants.get(wireType), $skipMaxSize); + } + } else { + throw new IOException("Bad tag [" + tag + "], field [" + field + + "] wireType [" + wireType + "]"); + }"""); + for (int i = 0; i < list.size(); i++) { + list.set(i, list.get(i) + .replace("$caseStatements", caseStatements) + .replace("$prefix", prefix) + .replace("$schemaClassName", schemaClassName) .replace("$skipMaxSize", "maxSize") - .indent(DEFAULT_INDENT); + .indent(DEFAULT_INDENT)); + } // spotless:on + return new ParseAndDefaultBody(list.get(0), list.get(1)); } - /** - * Generate switch case statements for each tag (field & wire type pair). For repeated numeric value types we - * generate 2 case statements for packed and unpacked encoding. 
- * - * @param fields list of all fields in record - * @return string of case statement code - */ - private static String generateCaseStatements(final List fields, final String schemaClassName) { + private static String generateCaseStatements(StringBuilder sbFunc, List fields, String schemaClassName) { StringBuilder sb = new StringBuilder(); for (Field field : fields) { if (field instanceof final OneOfField oneOfField) { for (final Field subField : oneOfField.fields()) { - generateFieldCaseStatement(sb, subField, schemaClassName); + generateFieldCaseStatement(sb, sbFunc, subField, schemaClassName); } } else if (field.repeated() && field.type().wireType() != Common.TYPE_LENGTH_DELIMITED) { // for repeated fields that are not length encoded there are 2 forms they can be stored in file. // "packed" and repeated primitive fields - generateFieldCaseStatement(sb, field, schemaClassName); - generateFieldCaseStatementPacked(sb, field); + generateFieldCaseStatement(sb, sbFunc, field, schemaClassName); + generateFieldCaseStatementPacked(sb, sbFunc, field); } else { - generateFieldCaseStatement(sb, field, schemaClassName); + generateFieldCaseStatement(sb, sbFunc, field, schemaClassName); } } return sb.toString().indent(DEFAULT_INDENT * 4); @@ -271,16 +283,24 @@ private static String generateCaseStatements(final List fields, final Str * Generate switch case statement for a repeated numeric value type in packed encoding. 
* * @param field field to generate case statement for - * @param sb StringBuilder to append code to + * @param sbCase code written in case statement + * @param sbFunc code written in class scope, used to create functions */ @SuppressWarnings("StringConcatenationInsideStringBufferAppend") - private static void generateFieldCaseStatementPacked(final StringBuilder sb, final Field field) { + private static void generateFieldCaseStatementPacked( + StringBuilder sbCase, StringBuilder sbFunc, final Field field) { final int wireType = Common.TYPE_LENGTH_DELIMITED; final int fieldNum = field.fieldNumber(); final int tag = Common.getTag(wireType, fieldNum); + var fieldType = + field.type() == Field.FieldType.ENUM ? field.repeated() ? "List" : "Object" : field.javaFieldType(); + var tempFieldName = "temp_" + field.name(); // spotless:off - sb.append("case %d /* type=%d [%s] packed-repeated field=%d [%s] */ -> {%n" + sbCase.append("case %d /* type=%d [%s] packed-repeated field=%d [%s] */ -> {%n" .formatted(tag, wireType, field.type(), fieldNum, field.name())); + sbCase.append("%s = case%d(input, maxSize, %s);%n".formatted(tempFieldName, tag, tempFieldName)); + sbFunc.append(""" +%s case%d(ReadableSequentialData input, int maxSize, %s %s) throws ParseException, IOException {""".formatted(fieldType, tag, fieldType, tempFieldName)); final String preRead; if (field.type() == Field.FieldType.ENUM) { // spotless:off @@ -299,7 +319,7 @@ private static void generateFieldCaseStatementPacked(final StringBuilder sb, fin } else { preRead = ""; } - sb.append(""" + sbFunc.append(""" // Read the length of packed repeated field data final long length = input.readVarInt(false); if (length > $maxSize) { @@ -317,13 +337,14 @@ private static void generateFieldCaseStatementPacked(final StringBuilder sb, fin input.limit(beforeLimit); if (input.position() != beforePosition + length) { throw new BufferUnderflowException(); - }""".replace("$tempFieldName", "temp_" + field.name()) + 
}""".replace("$tempFieldName", tempFieldName) .replace("$preRead", preRead) .replace("$readMethod", field.type() == Field.FieldType.ENUM ? "value" : readMethod(field)) .replace("$maxSize", field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize") .replace("$fieldName", field.name()) .indent(DEFAULT_INDENT)); - sb.append("\n}\n"); + sbCase.append("\n}\n"); + sbFunc.append(" return %s;\n }\n".formatted(tempFieldName)); // spotless:on } @@ -331,20 +352,21 @@ private static void generateFieldCaseStatementPacked(final StringBuilder sb, fin * Generate switch case statement for a field. * * @param field field to generate case statement for - * @param sb StringBuilder to append code to + * @param sbCase code written in case statement + * @param sbFunc code written in class scope, used to create functions */ private static void generateFieldCaseStatement( - final StringBuilder sb, final Field field, final String schemaClassName) { + StringBuilder sbCase, StringBuilder sbFunc, final Field field, final String schemaClassName) { final int wireType = field.optionalValueType() ? 
Common.TYPE_LENGTH_DELIMITED : field.type().wireType(); final int fieldNum = field.fieldNumber(); final int tag = Common.getTag(wireType, fieldNum); // spotless:off - sb.append("case %d /* type=%d [%s] field=%d [%s] */ -> {%n" + sbCase.append("case %d /* type=%d [%s] field=%d [%s] */ -> {%n" .formatted(tag, wireType, field.type(), fieldNum, field.name())); if (field.optionalValueType()) { - sb.append(""" + sbCase.append(""" // Read the message size, it is not needed final var valueTypeMessageSize = input.readVarInt(false); final $fieldType value; @@ -386,11 +408,11 @@ private static void generateFieldCaseStatement( })) .indent(DEFAULT_INDENT) ); - sb.append('\n'); + sbCase.append('\n'); // spotless:on } else if (field.type() == Field.FieldType.MESSAGE) { // spotless:off - sb.append(""" + sbCase.append(""" final var messageLength = input.readVarInt(false); final $fieldType value; if (messageLength == 0) { @@ -432,8 +454,10 @@ private static void generateFieldCaseStatement( // However(!), we read the key and value fields explicitly to avoid creating temporary entry objects. 
final MapField mapField = (MapField) field; final List mapEntryFields = List.of(mapField.keyField(), mapField.valueField()); + ParseAndDefaultBody parseAndDefaultBodyPair = generateParseLoop( + generateCaseStatements(sbFunc, mapEntryFields, schemaClassName), "map_entry_", schemaClassName); // spotless:off - sb.append(""" + sbCase.append(""" final var __map_messageLength = input.readVarInt(false); $fieldDefs @@ -466,14 +490,14 @@ private static void generateFieldCaseStatement( .replace("$fieldDefs",mapEntryFields.stream().map(mapEntryField -> "%s temp_%s = %s;".formatted(mapEntryField.javaFieldType(), mapEntryField.name(), mapEntryField.javaDefault())).collect(Collectors.joining("\n"))) - .replace("$mapParseLoop", generateParseLoop(generateCaseStatements(mapEntryFields, schemaClassName), "map_entry_", schemaClassName) + .replace("$mapParseLoop", parseAndDefaultBodyPair.parseBody() .indent(-DEFAULT_INDENT)) .replace("$maxSize", field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize") ); // spotless:on } else if (field.type() == Field.FieldType.ENUM) { // spotless:off - sb.append(""" + sbCase.append(""" final int enumOrdinal = readEnum(input); final var value = $enumName.fromProtobufOrdinal(enumOrdinal); """ @@ -483,20 +507,20 @@ private static void generateFieldCaseStatement( ); // spotless:on } else { - sb.append(("final var value = " + readMethod(field) + ";\n").indent(DEFAULT_INDENT)); + sbCase.append(("final var value = " + readMethod(field) + ";\n").indent(DEFAULT_INDENT)); } // set value to temp var // spotless:off - sb.append(Common.FIELD_INDENT); + sbCase.append(Common.FIELD_INDENT); if (field.parent() != null && field.repeated()) { throw new PbjCompilerException("Fields can not be oneof and repeated ["+field+"]"); } else if (field.parent() != null) { final var oneOfField = field.parent(); - sb.append("temp_%s = new %s<>(%s.%s, value);%n" + sbCase.append("temp_%s = new %s<>(%s.%s, value);%n" .formatted(oneOfField.name(), oneOfField.className(), 
oneOfField.getEnumClassRef(), Common.camelToUpperSnake(field.name()))); } else if (field.repeated()) { - sb.append( + sbCase.append( """ if (temp_%s.size() >= %s) { throw new ParseException("%1$s size %%d is greater than max %2$s".formatted(temp_%1$s.size())); @@ -505,7 +529,7 @@ private static void generateFieldCaseStatement( """.formatted(field.name(), field.maxSize() >= 0 ? String.valueOf(field.maxSize()) : "maxSize")); } else if (field.type() == Field.FieldType.MAP) { final MapField mapField = (MapField) field; - sb.append( + sbCase.append( """ if (__map_messageLength != 0) { if (temp_%s.size() >= %s) { @@ -517,7 +541,7 @@ private static void generateFieldCaseStatement( mapField.keyField().name(), mapField.valueField().name())); } else if (field.type() == Field.FieldType.ENUM) { // spotless:off - sb.append(""" + sbCase.append(""" temp_$fieldName = value != $enumName.UNRECOGNIZED ? value : Integer.valueOf(enumOrdinal); """ .replace("$enumName", Common.snakeToCamel(field.messageType(), true)) @@ -526,9 +550,9 @@ private static void generateFieldCaseStatement( ); // spotless:on } else { - sb.append("temp_%s = value;\n".formatted(field.name())); + sbCase.append("temp_%s = value;\n".formatted(field.name())); } - sb.append("}\n"); + sbCase.append("}\n"); // spotless:on } diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/GenericParserQuickBench.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/GenericParserQuickBench.java new file mode 100644 index 00000000..d4c72fc2 --- /dev/null +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/GenericParserQuickBench.java @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.integration.jmh; + +import com.hedera.pbj.runtime.Codec; +import com.hedera.pbj.runtime.ParseException; +import com.hedera.pbj.runtime.io.buffer.BufferedData; +import com.hedera.pbj.runtime.io.buffer.Bytes; +import 
com.hedera.pbj.test.proto.pbj.NotCacheableAccountID;
+import java.io.IOException;
+import java.util.Random;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OperationsPerInvocation;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+import org.openjdk.jmh.runner.Runner;
+import org.openjdk.jmh.runner.options.Options;
+import org.openjdk.jmh.runner.options.OptionsBuilder;
+
+/**
+ * Quick JMH throughput benchmark for parsing PBJ messages through the generic {@link Codec} API.
+ *
+ * <p>During setup, {@code INVOCATIONS} random messages are serialized back to back into a single buffer
+ * and the end offset of every message is recorded. Each benchmark invocation rewinds the buffer and
+ * parses the messages one at a time, narrowing the buffer limit to the current message so that every
+ * {@code parse} call decodes exactly one message.
+ */
+@SuppressWarnings("unused")
+@State(Scope.Benchmark)
+@Fork(3)
+@Warmup(iterations = 3)
+@Measurement(iterations = 5)
+@OutputTimeUnit(TimeUnit.MICROSECONDS)
+@BenchmarkMode(Mode.Throughput)
+public class GenericParserQuickBench {
+    /** Number of messages written during setup and parsed by one benchmark invocation. */
+    private static final int INVOCATIONS = 1024;
+
+    /** Per-thread benchmark state: the selected model, its serialized messages and their boundaries. */
+    @State(Scope.Thread)
+    public static class BenchState {
+        /**
+         * Describes one benchmarked message type.
+         *
+         * @param maxSize upper bound, in bytes, reserved in the buffer for one serialized message
+         * @param factory creates a message from the supplied random source
+         * @param codec codec used to write and parse the message
+         * @param <T> the message type
+         */
+        record Model<T>(int maxSize, Function<Random, T> factory, Codec<T> codec) {
+            /** Creates one random message and writes it to {@code out} at its current position. */
+            void writeRandom(final Random random, final BufferedData out) throws IOException {
+                codec.write(factory.apply(random), out);
+            }
+        }
+
+        /** Message types selectable through the JMH {@code type} parameter. */
+        public enum Type {
+            NotCacheableAccountIDType(new Model<NotCacheableAccountID>(
+                    256,
+                    random -> {
+                        final NotCacheableAccountID.Builder builder = NotCacheableAccountID.newBuilder()
+                                .shardNum(random.nextLong())
+                                .realmNum(random.nextLong());
+                        if (random.nextBoolean()) {
+                            builder.accountNum(random.nextLong());
+                        } else {
+                            byte[] arr = new byte[32];
+                            random.nextBytes(arr);
+                            builder.alias(Bytes.wrap(arr));
+                        }
+                        return builder.build();
+                    },
+                    NotCacheableAccountID.PROTOBUF));
+
+            private final Model<?> model;
+
+            Type(Model<?> model) {
+                this.model = model;
+            }
+        }
+
+        @Param
+        Type type;
+
+        Model<?> model;
+        byte[] array;
+        BufferedData bd;
+        /** {@code messageEnds[i]} is the buffer offset just past the i-th serialized message. */
+        long[] messageEnds;
+
+        /**
+         * Serializes {@code INVOCATIONS} random messages into {@link #bd} and records where each one ends.
+         *
+         * @throws IOException if writing a message fails
+         */
+        @Setup(Level.Trial)
+        public void setup() throws IOException {
+            model = type.model;
+
+            array = new byte[INVOCATIONS * model.maxSize()];
+            messageEnds = new long[INVOCATIONS];
+            // For determinism:
+            final Random random = new Random(723049435);
+            bd = BufferedData.wrap(array);
+            for (int i = 0; i < INVOCATIONS; i++) {
+                model.writeRandom(random, bd);
+                messageEnds[i] = bd.position();
+            }
+            bd.flip();
+        }
+
+        /** Nothing to release; kept so the state life cycle stays explicit. */
+        @TearDown(Level.Trial)
+        public void tearDown() {}
+    }
+
+    /**
+     * Parses every message prepared by {@link BenchState#setup()} once.
+     *
+     * @param state per-thread state holding the serialized messages
+     * @param blackhole consumes the parsed objects so the JIT cannot eliminate the parsing
+     * @throws ParseException if a message cannot be parsed
+     */
+    @Benchmark
+    @OperationsPerInvocation(INVOCATIONS)
+    public void bench(final BenchState state, final Blackhole blackhole) throws ParseException {
+        final BufferedData input = state.bd;
+        // The previous invocation left the position at the end of the data; start over.
+        input.resetPosition();
+        for (int invocation = 0; invocation < INVOCATIONS; invocation++) {
+            // Protobuf messages are not self-delimiting and parse() reads until the input has no
+            // remaining bytes, so expose exactly one message per call. After the call the position
+            // sits on this limit, which is where the next message starts.
+            input.limit(state.messageEnds[invocation]);
+            blackhole.consume(state.model.codec().parse(input));
+        }
+    }
+
+    public static void main(String[] args) throws Exception {
+        Options opt = new OptionsBuilder()
+                .include(GenericParserQuickBench.class.getSimpleName())
+                .build();
+
+        new Runner(opt).run();
+    }
+}