From b5c0e15f5591065724d8136d2a01e9e4c32c9a41 Mon Sep 17 00:00:00 2001 From: Ruslan Iushchenko Date: Tue, 10 Mar 2026 11:01:24 +0100 Subject: [PATCH] #833 Add strict schema check option for the EBCDIC writer. --- README.md | 19 ++- .../reader/parameters/CobolParameters.scala | 2 + .../parameters/CobolParametersParser.scala | 3 + .../reader/parameters/ReaderParameters.scala | 5 +- .../cobol/writer/NestedRecordCombiner.scala | 52 ++++--- .../fixtures/TextComparisonFixture.scala | 12 ++ .../cobol/writer/NestedWriterSuite.scala | 140 ++++++++++++++---- 7 files changed, 176 insertions(+), 57 deletions(-) diff --git a/README.md b/README.md index 08a7eb3d2..b868800d9 100644 --- a/README.md +++ b/README.md @@ -1635,6 +1635,12 @@ The output looks like this: | .option("debug_ignore_file_size", "true") | If 'true' no exception will be thrown if record size does not match file size. Useful for debugging copybooks to make them match a data file. | | .option("enable_self_checks", "true") | If 'true' Cobrix will run self-checks to validate internal consistency. Note: Enabling this option may impact performance, especially for large datasets. It is recommended to disable this option in performance-critical environments. The only check implemented so far is custom record extractor indexing compatibility check. | +##### Writer-only options + +| Option (usage example) | Description | +|----------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| .option("strict_schema", "true") | If 'true' (default) Cobrix will throw an exception if a field exists in the copybook but not in the Spark schema. Array count fields (defined in DEPENDING ON clause) are auto-generated and never required to exist in Spark schema. 
| + ##### Currently supported EBCDIC code pages | Option | Code page | Description | @@ -1798,18 +1804,15 @@ df.write ### Current Limitations The writer is still in its early stages and has several limitations: -- Nested GROUPs are not supported. Only flat copybooks can be used, for example: - ```cobol - 01 RECORD. - 05 FIELD_1 PIC X(1). - 05 FIELD_2 PIC X(5). - ``` +- Nested GROUPs, OCCURS, OCCURS DEPENDING ON are supported. +- Variable-size occurs are supported with `variable_size_occurs = true`. +- Writing multi-segment files is not supported. - Supported types: - `PIC X(n)` alphanumeric. - `PIC S9(n)` numeric (integral and decimal) with `DISPLAY`, `COMP`/`COMP-4`/`COMP-5` (big-endian), `COMP-3`, and `COMP-9` (Cobrix little-endian). -- Only fixed record length output is supported (`record_format = F`). -- `REDEFINES` and `OCCURS` are not supported. +- Only fixed record length and variable record length with RDWs are supported (`record_format` is either `F` or `V`). +- `REDEFINES` are ignored. Cobrix writes only the first field in a REDEFINES group. - Partitioning by DataFrame fields is not supported. 
### Implementation details diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParameters.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParameters.scala index 7a6b63432..3e5e2e53c 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParameters.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParameters.scala @@ -49,6 +49,7 @@ import za.co.absa.cobrix.cobol.reader.policies.SchemaRetentionPolicy.SchemaReten * @param generateRecordBytes Generate 'record_bytes' field containing raw bytes of the original record * @param generateCorruptFields Generate '_corrupt_fields' field for fields that haven't converted successfully * @param schemaRetentionPolicy A copybook usually has a root group struct element that acts like a rowtag in XML. This can be retained in Spark schema or can be collapsed + * @param strictSchema If true, when writing files in mainframe format each field in the copybook must exist in the Spark schema. 
* @param stringTrimmingPolicy Specify if and how strings should be trimmed when parsed * @param isDisplayAlwaysString If true, all fields having DISPLAY format will remain strings and won't be converted to numbers * @param allowPartialRecords If true, partial ASCII records can be parsed (in cases when LF character is missing for example) @@ -90,6 +91,7 @@ case class CobolParameters( generateRecordBytes: Boolean, generateCorruptFields: Boolean, schemaRetentionPolicy: SchemaRetentionPolicy, + strictSchema: Boolean, stringTrimmingPolicy: StringTrimmingPolicy, isDisplayAlwaysString: Boolean, allowPartialRecords: Boolean, diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala index 2f37a1a0f..197f182c0 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala @@ -73,6 +73,7 @@ object CobolParametersParser extends Logging { val PARAM_FILLER_NAMING_POLICY = "filler_naming_policy" val PARAM_STRICT_INTEGRAL_PRECISION = "strict_integral_precision" val PARAM_DISPLAY_PIC_ALWAYS_STRING = "display_pic_always_string" + val PARAM_STRICT_SCHEMA = "strict_schema" val PARAM_GROUP_NOT_TERMINALS = "non_terminals" val PARAM_OCCURS_MAPPINGS = "occurs_mappings" @@ -289,6 +290,7 @@ object CobolParametersParser extends Logging { params.getOrElse(PARAM_GENERATE_RECORD_BYTES, "false").toBoolean, params.getOrElse(PARAM_CORRUPT_FIELDS, "false").toBoolean, schemaRetentionPolicy, + params.getOrElse(PARAM_STRICT_SCHEMA, "true").toBoolean, stringTrimmingPolicy, params.getOrElse(PARAM_DISPLAY_PIC_ALWAYS_STRING, "false").toBoolean, params.getOrElse(PARAM_ALLOW_PARTIAL_RECORDS, "false").toBoolean, @@ -445,6 +447,7 @@ object CobolParametersParser extends Logging { generateRecordBytes = 
parameters.generateRecordBytes, corruptFieldsPolicy = corruptFieldsPolicy, schemaPolicy = parameters.schemaRetentionPolicy, + strictSchema = parameters.strictSchema, stringTrimmingPolicy = parameters.stringTrimmingPolicy, isDisplayAlwaysString = parameters.isDisplayAlwaysString, allowPartialRecords = parameters.allowPartialRecords, diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/ReaderParameters.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/ReaderParameters.scala index 9904fb45c..f71282a81 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/ReaderParameters.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/ReaderParameters.scala @@ -58,8 +58,9 @@ import za.co.absa.cobrix.cobol.reader.policies.SchemaRetentionPolicy.SchemaReten * @param fileEndOffset A number of bytes to skip at the end of each file * @param generateRecordId If true, a record id field will be prepended to each record. * @param generateRecordBytes Generate 'record_bytes' field containing raw bytes of the original record - * @param corruptFieldsPolicy Specifies if '_corrupt_fields' field for fields that haven't converted successfully, and the type of raw values. + * @param corruptFieldsPolicy Specifies if '_corrupt_fields' field for fields that haven't converted successfully, and the type of raw values. * @param schemaPolicy Specifies a policy to transform the input schema. The default policy is to keep the schema exactly as it is in the copybook. + * @param strictSchema If true, when writing files in mainframe format each field in the copybook must exist in the Spark schema. * @param stringTrimmingPolicy Specifies if and how strings should be trimmed when parsed. * @param isDisplayAlwaysString If true, all fields having DISPLAY format will remain strings and won't be converted to numbers. 
* @param allowPartialRecords If true, partial ASCII records can be parsed (in cases when LF character is missing for example) @@ -114,6 +115,8 @@ case class ReaderParameters( generateRecordBytes: Boolean = false, corruptFieldsPolicy: CorruptFieldsPolicy = CorruptFieldsPolicy.Disabled, schemaPolicy: SchemaRetentionPolicy = SchemaRetentionPolicy.CollapseRoot, + + strictSchema: Boolean = true, stringTrimmingPolicy: StringTrimmingPolicy = StringTrimmingPolicy.TrimBoth, isDisplayAlwaysString: Boolean = false, allowPartialRecords: Boolean = false, diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/NestedRecordCombiner.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/NestedRecordCombiner.scala index 82c636816..b25fdcf9d 100644 --- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/NestedRecordCombiner.scala +++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/NestedRecordCombiner.scala @@ -72,7 +72,7 @@ class NestedRecordCombiner extends RecordCombiner { s"RDW length $recordLengthLong exceeds ${Int.MaxValue} and cannot be encoded safely." ) } - processRDD(df.rdd, cobolSchema.copybook, df.schema, size, adjustment1 + adjustment2, startOffset, hasRdw, isRdwBigEndian, readerParameters.variableSizeOccurs) + processRDD(df.rdd, cobolSchema.copybook, df.schema, size, adjustment1 + adjustment2, startOffset, hasRdw, isRdwBigEndian, readerParameters.variableSizeOccurs, readerParameters.strictSchema) } } @@ -114,13 +114,14 @@ object NestedRecordCombiner { * The purpose of WriterAst class hierarchy is to provide memory and CPU efficient way of creating binary * records from Spark dataframes. It links Cobol schema and Spark schema in a single tree. 
* - * @param copybook The copybook definition describing the binary record layout and field specifications - * @param schema The Spark StructType schema that corresponds to the structure of the data to be written + * @param copybook The copybook definition describing the binary record layout and field specifications + * @param schema The Spark StructType schema that corresponds to the structure of the data to be written + * @param strictSchema If true, each field in the copybook must exist in the Spark schema. * @return A GroupField representing the root of the writer AST, containing all non-filler, non-redefines * fields with their associated getter functions and position information for binary serialization */ - def constructWriterAst(copybook: Copybook, schema: StructType): GroupField = { - buildGroupField(getAst(copybook), schema, row => row, "", new mutable.HashMap[String, DependingOnField]()) + def constructWriterAst(copybook: Copybook, schema: StructType, strictSchema: Boolean): GroupField = { + buildGroupField(getAst(copybook), schema, row => row, "", new mutable.HashMap[String, DependingOnField](), strictSchema) } /** @@ -142,6 +143,7 @@ object NestedRecordCombiner { * @param hasRdw A flag indicating whether to prepend a Record Descriptor Word header to each output record * @param isRdwBigEndian A flag indicating the byte order for the RDW header, true for big-endian, false for little-endian * @param variableSizeOccurs A flag indicating whether OCCURS DEPENDING ON fields should use actual element counts rather than maximum sizes + * @param strictSchema If true, each field in the copybook must exist in the Spark schema. 
* @return An RDD of byte arrays, where each array represents one record in binary format according to the copybook specification */ private[cobrix] def processRDD(rdd: RDD[Row], @@ -152,8 +154,9 @@ object NestedRecordCombiner { startOffset: Int, hasRdw: Boolean, isRdwBigEndian: Boolean, - variableSizeOccurs: Boolean): RDD[Array[Byte]] = { - val writerAst = constructWriterAst(copybook, schema) + variableSizeOccurs: Boolean, + strictSchema: Boolean): RDD[Array[Byte]] = { + val writerAst = constructWriterAst(copybook, schema, strictSchema) rdd.mapPartitions { rows => rows.map { row => @@ -212,21 +215,22 @@ object NestedRecordCombiner { * Recursively walks the copybook group and the Spark StructType in lockstep, producing * [[WriterAst]] nodes whose getters extract the correct value from a [[org.apache.spark.sql.Row]]. * - * @param group A copybook Group node whose children will be processed. - * @param schema The Spark StructType that corresponds to `group`. - * @param getter A function that, given the "outer" Row, returns the Row that belongs to this group. - * @param path The path to the field - * @param dependeeMap A map of field names to their corresponding DependingOnField specs, used to resolve dependencies for OCCURS DEPENDING ON fields. + * @param group A copybook Group node whose children will be processed. + * @param schema The Spark StructType that corresponds to `group`. + * @param getter A function that, given the "outer" Row, returns the Row that belongs to this group. + * @param path The path to the field + * @param dependeeMap A map of field names to their corresponding DependingOnField specs, used to resolve dependencies for OCCURS DEPENDING ON fields. + * @param strictSchema If true, each field in the copybook must exist in the Spark schema. * @return A [[GroupField]] covering all non-filler, non-redefines children found in both * the copybook and the Spark schema. 
*/ - private def buildGroupField(group: Group, schema: StructType, getter: GroupGetter, path: String, dependeeMap: mutable.HashMap[String, DependingOnField]): GroupField = { + private def buildGroupField(group: Group, schema: StructType, getter: GroupGetter, path: String, dependeeMap: mutable.HashMap[String, DependingOnField], strictSchema: Boolean): GroupField = { val children = group.children.withFilter { stmt => stmt.redefines.isEmpty }.map { case s if s.isFiller => Filler(s.binaryProperties.actualSize) - case p: Primitive => buildPrimitiveNode(p, schema, path, dependeeMap) - case g: Group => buildGroupNode(g, schema, path, dependeeMap) + case p: Primitive => buildPrimitiveNode(p, schema, path, dependeeMap, strictSchema) + case g: Group => buildGroupNode(g, schema, path, dependeeMap, strictSchema) } GroupField(children.toSeq, group, getter) } @@ -237,7 +241,7 @@ object NestedRecordCombiner { * * Returns a filler when the field is absent from the schema (e.g. filtered out during reading). */ - private def buildPrimitiveNode(p: Primitive, schema: StructType, path: String, dependeeMap: mutable.HashMap[String, DependingOnField]): WriterAst = { + private def buildPrimitiveNode(p: Primitive, schema: StructType, path: String, dependeeMap: mutable.HashMap[String, DependingOnField], strictSchema: Boolean): WriterAst = { def addDependee(): DependingOnField = { val spec = DependingOnField(p, p.binaryProperties.offset) val uppercaseName = p.name.toUpperCase() @@ -282,7 +286,10 @@ object NestedRecordCombiner { if (p.isDependee) { PrimitiveDependeeField(addDependee()) } else { - log.error(s"Field '$path${p.name}' is not found in Spark schema. Will be replaced by filler.") + if (strictSchema) + throw new IllegalArgumentException(s"Field '$path${p.name}' is not found in Spark schema.") + else + log.warn(s"Field '$path${p.name}' is not found in Spark schema. 
Will be replaced by filler.") Filler(p.binaryProperties.actualSize) } } @@ -295,7 +302,7 @@ object NestedRecordCombiner { * * Returns a filler when the field is absent from the schema. */ - private def buildGroupNode(g: Group, schema: StructType, path: String, dependeeMap: mutable.HashMap[String, DependingOnField]): WriterAst = { + private def buildGroupNode(g: Group, schema: StructType, path: String, dependeeMap: mutable.HashMap[String, DependingOnField], strictSchema: Boolean): WriterAst = { val fieldName = g.name val fieldIndexOpt = schema.fields.zipWithIndex.find { case (field, _) => field.name.equalsIgnoreCase(fieldName) @@ -311,7 +318,7 @@ object NestedRecordCombiner { s"Array group '${g.name}' depends on '$dependingOn' which is not found among previously processed fields." )) } - val childAst = buildGroupField(g, elementType, row => row, s"$path${g.name}.", dependeeMap) + val childAst = buildGroupField(g, elementType, row => row, s"$path${g.name}.", dependeeMap, strictSchema) GroupArray(childAst, g, row => row.getAs[mutable.WrappedArray[AnyRef]](idx), dependingOnField) case other => throw new IllegalArgumentException( @@ -322,7 +329,7 @@ object NestedRecordCombiner { schema(idx).dataType match { case nestedSchema: StructType => val childGetter: GroupGetter = row => row.getAs[Row](idx) - val childAst = buildGroupField(g, nestedSchema, childGetter, s"$path${g.name}.", dependeeMap) + val childAst = buildGroupField(g, nestedSchema, childGetter, s"$path${g.name}.", dependeeMap, strictSchema) GroupField(childAst.children, g, childGetter) case other => throw new IllegalArgumentException( @@ -330,7 +337,10 @@ object NestedRecordCombiner { } } }.getOrElse { - log.error(s"Field '$path${g.name}' is not found in Spark schema. Will be replaced by filler.") + if (strictSchema) + throw new IllegalArgumentException(s"Field '$path${g.name}' is not found in Spark schema.") + else + log.warn(s"Field '$path${g.name}' is not found in Spark schema. 
Will be replaced by filler.") Filler(g.binaryProperties.actualSize) } } diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/fixtures/TextComparisonFixture.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/fixtures/TextComparisonFixture.scala index 991e521e7..53476dc45 100644 --- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/fixtures/TextComparisonFixture.scala +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/fixtures/TextComparisonFixture.scala @@ -21,6 +21,18 @@ import org.scalatest.{Assertion, Suite} trait TextComparisonFixture { this: Suite => + protected def compareBinary(actual: Array[Byte], expected: Array[Byte], clue: String = "Binary data does not match"): Assertion = { + if (!actual.sameElements(expected)) { + println(s"Expected bytes: ${expected.map("%02X" format _).mkString(" ")}") + println(s"Actual bytes: ${actual.map("%02X" format _).mkString(" ")}") + //println(s"Actual bytes: ${bytes.map("0x%02X" format _).mkString(", ")}") + + assert(actual.sameElements(expected), clue) + } else { + succeed + } + } + protected def compareText(actual: String, expected: String): Assertion = { if (actual.replaceAll("[\r\n]", "") != expected.replaceAll("[\r\n]", "")) { fail(renderTextDifference(actual, expected)) diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/writer/NestedWriterSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/writer/NestedWriterSuite.scala index 49b668457..712ad351e 100644 --- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/writer/NestedWriterSuite.scala +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/writer/NestedWriterSuite.scala @@ -195,13 +195,74 @@ class NestedWriterSuite extends AnyWordSpec with SparkTestBase with BinaryFileFi 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ).map(_.toByte) - if (!bytes.sameElements(expected)) { - 
println(s"Expected bytes: ${expected.map("%02X" format _).mkString(" ")}") - println(s"Actual bytes: ${bytes.map("%02X" format _).mkString(" ")}") - //println(s"Actual bytes: ${bytes.map("0x%02X" format _).mkString(", ")}") + compareBinary(bytes, expected, "Written data should match expected EBCDIC encoding") + } + } + + "write the dataframe with OCCURS without strict schema check" in { + val exampleJsons = Seq( + """{"ID":1,"NUMBERS":[10,20,30],"PLACE":{"COUNTRY_CODE":"US","CITY":"New York"},"PEOPLE":[{"NAME":"John Doe"},{"NAME": "Jane Smith"}]}""" + ) + + import spark.implicits._ + + val df = spark.read.json(exampleJsons.toDS()) + .select("ID", "NUMBERS", "PLACE", "PEOPLE") - assert(bytes.sameElements(expected), "Written data should match expected EBCDIC encoding") + withTempDirectory("cobol_writer1") { tempDir => + val path = new Path(tempDir, "writer1") + + val ex = intercept[IllegalArgumentException] { + df.write + .format("cobol") + .mode(SaveMode.Overwrite) + .option("copybook_contents", copybookWithOccurs) + .save(path.toString) } + + assert(ex.getMessage == "Field 'PEOPLE.PHONE_NUMBER' is not found in Spark schema.") + + df.coalesce(1) + .orderBy("id") + .write + .format("cobol") + .mode(SaveMode.Overwrite) + .option("copybook_contents", copybookWithOccurs) + .option("record_format", "F") + .option("strict_schema", "false") + .option("variable_size_occurs", "true") + .save(path.toString) + + // val df2 = spark.read.format("cobol") + // .option("copybook_contents", copybookWithOccurs) + // .option("variable_size_occurs", "true") + // .load(path.toString) + // + // println(SparkUtils.convertDataFrameToPrettyJSON(df2)) + + val fs = path.getFileSystem(spark.sparkContext.hadoopConfiguration) + + assert(fs.exists(path), "Output directory should exist") + val files = fs.listStatus(path) + .filter(_.getPath.getName.startsWith("part-")) + assert(files.nonEmpty, "Output directory should contain part files") + + val partFile = files.head.getPath + val data = 
fs.open(partFile) + val bytes = new Array[Byte](files.head.getLen.toInt) + data.readFully(bytes) + data.close() + + // Expected EBCDIC data for sample test data + val expected = Array( + 0xF0, 0xF1, 0x00, 0xF1, 0xF0, 0xF2, 0xF0, 0xF3, 0xF0, 0xE4, 0xE2, 0xD5, 0x85, 0xA6, 0x40, 0xE8, 0x96, + 0x99, 0x92, 0x40, 0x40, 0xD1, 0x96, 0x88, 0x95, 0x40, 0xC4, 0x96, 0x85, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xD1, 0x81, 0x95, + 0x85, 0x40, 0xE2, 0x94, 0x89, 0xA3, 0x88, 0x40, 0x40, 0x40, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + ).map(_.toByte) + + compareBinary(bytes, expected, "Written data should match expected EBCDIC encoding") } } @@ -283,13 +344,7 @@ class NestedWriterSuite extends AnyWordSpec with SparkTestBase with BinaryFileFi 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ).map(_.toByte) - if (!bytes.sameElements(expected)) { - println(s"Expected bytes: ${expected.map("%02X" format _).mkString(" ")}") - println(s"Actual bytes: ${bytes.map("%02X" format _).mkString(" ")}") - //println(s"Actual bytes: ${bytes.map("0x%02X" format _).mkString(", ")}") - - assert(bytes.sameElements(expected), "Written data should match expected EBCDIC encoding") - } + compareBinary(bytes, expected, "Written data should match expected EBCDIC encoding") } } @@ -355,13 +410,7 @@ class NestedWriterSuite extends AnyWordSpec with SparkTestBase with BinaryFileFi 0xF1, 0xF2, 0xF3, 0xF5, 0x40, 0x40, 0x40, 0x40 ).map(_.toByte) - if (!bytes.sameElements(expected)) { - println(s"Expected bytes: ${expected.map("%02X" format _).mkString(" ")}") - println(s"Actual bytes: ${bytes.map("%02X" format _).mkString(" ")}") - //println(s"Actual bytes: ${bytes.map("0x%02X" format _).mkString(", ")}") - - assert(bytes.sameElements(expected), "Written data should match expected EBCDIC encoding") - } + compareBinary(bytes, expected, 
"Written data should match expected EBCDIC encoding") } } @@ -424,13 +473,7 @@ class NestedWriterSuite extends AnyWordSpec with SparkTestBase with BinaryFileFi 0xF1, 0xF2, 0xF3, 0xF5, 0x40, 0x40, 0x40, 0x40 ).map(_.toByte) - if (!bytes.sameElements(expected)) { - println(s"Expected bytes: ${expected.map("%02X" format _).mkString(" ")}") - println(s"Actual bytes: ${bytes.map("%02X" format _).mkString(" ")}") - println(s"Actual bytes: ${bytes.map("0x%02X" format _).mkString(", ")}") - - assert(bytes.sameElements(expected), "Written data should match expected EBCDIC encoding") - } + compareBinary(bytes, expected, "Written data should match expected EBCDIC encoding") } } } @@ -470,11 +513,54 @@ class NestedWriterSuite extends AnyWordSpec with SparkTestBase with BinaryFileFi children(5) = cnt2 val ex = intercept[IllegalArgumentException] { - NestedRecordCombiner.constructWriterAst(parsedCopybook, df.schema) + NestedRecordCombiner.constructWriterAst(parsedCopybook, df.schema, strictSchema = false) } assert(ex.getMessage == "Duplicate field name 'CNT1' found in copybook. Field names must be unique (case-insensitive) when OCCURS DEPENDING ON is used. Already found a dependee field with the same name at line 4, current field line number: 10.") } + + "fail when a field in the copybook does not exist in Spark schema" in { + val copybook = + """ 01 RECORD. + | 05 ID PIC 9(2). + | 05 FILLER PIC 9(1). + | 05 CNT1 PIC 9(1). + | 05 NUMBERS PIC 9(2) + | OCCURS 0 TO 5 DEPENDING ON CNT1. + | 05 PLACE. + | 10 COUNTRY-CODE PIC X(2). + | 10 CITY PIC X(10). + | 05 CNT2 PIC 9(1). + | 05 PEOPLE + | OCCURS 0 TO 3 DEPENDING ON CNT1. + | 10 NAME PIC X(14). + | 10 FILLER PIC X(1). + | 10 PHONE-NUMBER PIC X(12). 
+ |""".stripMargin + val exampleJsons = Seq( + """{"ID":1,"NUMBERS":[10,20,30],"PLACE":{"COUNTRY_CODE":"US","CITY":"New York"},"PEOPLE":[{"NAME":"John Doe"},{"NAME": "Jane Smith"}]}""" + ) + + import spark.implicits._ + + val df = spark.read.json(exampleJsons.toDS()) + .select("ID", "NUMBERS", "PLACE", "PEOPLE") + + val parsedCopybook: Copybook = CopybookParser.parse(copybook) + val ast = parsedCopybook.ast + val children = ast.children.head.asInstanceOf[Group].children + val cnt2 = children(5).asInstanceOf[Primitive].withUpdatedIsDependee(true) + children(5) = cnt2 + + // This should not throw + NestedRecordCombiner.constructWriterAst(parsedCopybook, df.schema, strictSchema = false) + + val ex = intercept[IllegalArgumentException] { + NestedRecordCombiner.constructWriterAst(parsedCopybook, df.schema, strictSchema = true) + } + + assert(ex.getMessage == "Field 'PEOPLE.PHONE_NUMBER' is not found in Spark schema.") + } } }