diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SparkUtils.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SparkUtils.scala index fb180378..99f3a45d 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SparkUtils.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SparkUtils.scala @@ -32,7 +32,7 @@ import za.co.absa.pramen.core.utils.SparkMaster.Databricks import java.io.ByteArrayOutputStream import java.time.format.DateTimeFormatter import java.time.{Instant, LocalDate} -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.reflect.runtime.universe._ import scala.util.{Failure, Success, Try} @@ -155,45 +155,84 @@ object SparkUtils { } /** - * Compares 2 schemas. + * Compares two schemas represented as `StructType` and identifies the differences + * between them, such as newly added fields, deleted fields, or fields with changed types. + * + * @param schemaA the first schema to compare + * @param schemaB the second schema to compare + * @return a list of `FieldChange` that represents the differences between the two schemas */ - def compareSchemas(schema1: StructType, schema2: StructType): List[FieldChange] = { + def compareSchemas(schemaA: StructType, schemaB: StructType): List[FieldChange] = { + val newFields = new ListBuffer[FieldChange] + val deletedFields = new ListBuffer[FieldChange] + val changedFields = new ListBuffer[FieldChange] + def dataTypeToString(dt: DataType, metadata: Metadata): String = { val maxLength = getLengthFromMetadata(metadata).getOrElse(0) dt match { - case _: StructType | _: ArrayType => dt.simpleString - case _: StringType if maxLength > 0 => s"varchar($maxLength)" - case _ => dt.typeName + case a: ArrayType if a.elementType.isInstanceOf[StructType] => "array<struct<...>>" + case a: ArrayType => s"array<${a.elementType.typeName}>" + case _: StructType => "struct<...>" + case _: StringType if maxLength > 0 => 
s"varchar($maxLength)" + case _ => dt.typeName } } - val fields1 = schema1.fields.map(f => (f.name, f)).toMap - val fields2 = schema2.fields.map(f => (f.name, f)).toMap - - val newColumns: Array[FieldChange] = schema2.fields - .filter(f => !fields1.contains(f.name)) - .map(f => FieldChange.NewField(f.name, dataTypeToString(f.dataType, f.metadata))) - - val deletedColumns: Array[FieldChange] = schema1.fields - .filter(f => !fields2.contains(f.name)) - .map(f => FieldChange.DeletedField(f.name, dataTypeToString(f.dataType, f.metadata))) + def processStruct(schema1: StructType, schema2: StructType, path: String = ""): Unit = { + val fields1 = schema1.fields.map(f => (f.name, f)).toMap + val fields2 = schema2.fields.map(f => (f.name, f)).toMap + + val newColumns: Array[FieldChange] = schema2.fields + .filter(f => !fields1.contains(f.name)) + .map(f => FieldChange.NewField(s"$path${f.name}", dataTypeToString(f.dataType, f.metadata))) + newFields ++= newColumns + + val deletedColumns: Array[FieldChange] = schema1.fields + .filter(f => !fields2.contains(f.name)) + .map(f => FieldChange.DeletedField(s"$path${f.name}", dataTypeToString(f.dataType, f.metadata))) + deletedFields ++= deletedColumns + + schema1.fields + .filter(f => fields2.contains(f.name)) + .foreach(f1 => { + val f2 = fields2(f1.name) + + (f1.dataType, f2.dataType) match { + case (st1: StructType, st2: StructType) => + processStruct(st1, st2, s"$path${f1.name}.") + case (ar1: ArrayType, ar2: ArrayType) => + processArray(ar1, ar2, f1.metadata, f2.metadata, s"$path${f1.name}") + case _ => + val dt1 = dataTypeToString(f1.dataType, f1.metadata) + val dt2 = dataTypeToString(f2.dataType, f2.metadata) + + if (dt1 != dt2) { + changedFields += FieldChange.ChangedType(s"$path${f1.name}", dt1, dt2) + } + } + }) + } - val changedType: Array[FieldChange] = schema1.fields - .filter(f => fields2.contains(f.name)) - .flatMap(f1 => { - val dt1 = dataTypeToString(f1.dataType, f1.metadata) - val f2 = fields2(f1.name) - val dt2 
= dataTypeToString(f2.dataType, f2.metadata) + def processArray(array1: ArrayType, array2: ArrayType, metadata1: Metadata, metadata2: Metadata, path: String = ""): Unit = { + (array1.elementType, array2.elementType) match { + case (st1: StructType, st2: StructType) => + processStruct(st1, st2, s"$path[].") + case (ar1: ArrayType, ar2: ArrayType) => + processArray(ar1, ar2, metadata1, metadata2, s"$path[]") + case _ => + val dt1 = dataTypeToString(array1, metadata1) + val dt2 = dataTypeToString(array2, metadata2) - if (dt1 == dt2) { - Seq.empty[FieldChange] - } else { - Seq(FieldChange.ChangedType(f1.name, dt1, dt2)) - } - }) + if (dt1 != dt2) { + changedFields += FieldChange.ChangedType(path, dt1, dt2) + } + } + } - (newColumns ++ deletedColumns ++ changedType).toList + processStruct(schemaA, schemaB) + val allChanges = newFields ++ deletedFields ++ changedFields + allChanges.toList } /** diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/SparkUtilsSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/SparkUtilsSuite.scala index 46d49104..9f28a67a 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/SparkUtilsSuite.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/SparkUtilsSuite.scala @@ -287,6 +287,68 @@ class SparkUtilsSuite extends AnyWordSpec with SparkTestBase with TempDirFixture assert(diff.head.asInstanceOf[ChangedType].oldType == "varchar(10)") assert(diff.head.asInstanceOf[ChangedType].newType == "varchar(15)") } + + "detect nested type changes" in { + val schema1 = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("name", StringType, nullable = true), + StructField("address", StructType(Seq( + StructField("street", StringType, nullable = true), + StructField("city", StringType, nullable = true) + ))), + StructField("tags", ArrayType(StringType, containsNull = true), nullable = true), + StructField("phones", ArrayType(StructType(Seq( 
+ StructField("type", StringType, nullable = true), + StructField("number", IntegerType, nullable = true) + )), containsNull = true), nullable = true), + StructField("error_info", StructType(Seq( + StructField("reason", StringType, nullable = true), + StructField("value", StringType, nullable = true) + ))) + )) + + val schema2 = StructType(Seq( + StructField("Id", IntegerType, nullable = false), + StructField("name", StringType, nullable = true), + StructField("address", StructType(Seq( + StructField("street", StringType, nullable = true), + StructField("city", LongType, nullable = true), + StructField("state", StringType, nullable = true) + ))), + StructField("tags", ArrayType(IntegerType, containsNull = true), nullable = true), + StructField("phones", ArrayType(StructType(Seq( + StructField("type", StringType, nullable = true), + StructField("number", StringType, nullable = true), + StructField("country", StringType, nullable = true) + )), containsNull = true), nullable = true), + StructField("additional_properties", ArrayType(StructType(Seq( + StructField("key", StringType, nullable = true), + StructField("value", StringType, nullable = true) + )), containsNull = true), nullable = true) + )) + + val diff = compareSchemas(schema1, schema2) + + assert(diff.length == 9) + assert(diff.count(_.isInstanceOf[ChangedType]) == 3) + assert(diff.count(_.isInstanceOf[NewField]) == 4) + assert(diff.count(_.isInstanceOf[DeletedField]) == 2) + + val changedTypes = diff.collect { case ct: ChangedType => ct } + assert(changedTypes.exists(c => c.columnName == "address.city" && c.oldType == "string" && c.newType == "long")) + assert(changedTypes.exists(c => c.columnName == "tags" && c.oldType == "array<string>" && c.newType == "array<integer>")) + assert(changedTypes.exists(c => c.columnName == "phones[].number" && c.oldType == "integer" && c.newType == "string")) + + val newFields = diff.collect { case nf: NewField => nf } + assert(newFields.exists(n => n.columnName == "Id" && n.dataType == 
"integer")) + assert(newFields.exists(n => n.columnName == "address.state" && n.dataType == "string")) + assert(newFields.exists(n => n.columnName == "phones[].country" && n.dataType == "string")) + assert(newFields.exists(n => n.columnName == "additional_properties" && n.dataType == "array<struct<...>>")) + + val deletedFields = diff.collect { case df: DeletedField => df } + assert(deletedFields.exists(d => d.columnName == "id" && d.dataType == "integer")) + assert(deletedFields.exists(d => d.columnName == "error_info" && d.dataType == "struct<...>")) + } } "applyTransformations" should {