/**
 * Replaces the type of the first (shallowest) field named "geometry" in `schema`
 * with `datatype`, and returns the rewritten schema.
 *
 * Recursion stops at the first level that contains a direct "geometry" child: any
 * deeper fields that also happen to be named "geometry" (e.g. inside a feature's
 * `properties`) are deliberately left untouched, so only the GeoJSON geometry
 * member of a Feature is rewritten.
 *
 * @param schema   the struct to rewrite (typically a Feature or FeatureCollection schema)
 * @param datatype the new type for the geometry field (e.g. the geometry UDT, or
 *                 StringType when serializing back to GeoJSON text)
 * @return a schema identical to `schema` except for the first geometry field found
 */
def updateGeometrySchema(schema: StructType, datatype: DataType): StructType = {
  // If this struct already has a geometry field, only update that field and stop.
  val hasGeometry = hasGeometryField(schema)
  if (hasGeometry) {
    // Geometry level reached: rewrite only the "geometry" field; siblings (including
    // nested structs such as "properties") are kept exactly as they are.
    StructType(schema.fields.map {
      case StructField("geometry", _, nullable, metadata) =>
        StructField("geometry", datatype, nullable, metadata)
      case other =>
        other
    })
  } else {
    // Otherwise keep searching deeper for the first geometry field.
    // NOTE(review): only StructType and ArrayType(StructType) children are descended
    // into; arrays-of-arrays of structs and MapType values are passed through
    // unchanged — presumably such shapes do not occur in GeoJSON input; confirm.
    StructType(schema.fields.map {
      case StructField(name, st: StructType, nullable, metadata) =>
        StructField(name, updateGeometrySchema(st, datatype), nullable, metadata)

      case StructField(name, ArrayType(elem: StructType, containsNull), nullable, metadata) =>
        val updatedElem = updateGeometrySchema(elem, datatype)
        StructField(name, ArrayType(updatedElem, containsNull), nullable, metadata)

      case other => other
    })
  }
}

/** True when `st` has a direct (non-recursive) child field named "geometry". */
private def hasGeometryField(st: StructType): Boolean =
  st.fields.exists(_.name == "geometry")

/**
 * Walks `row` (shaped by `schema`) and parses the GeoJSON text held in the first
 * (shallowest) "geometry" string field into geometry bytes via [[geoJsonToGeometry]].
 *
 * Mirrors [[updateGeometrySchema]]: once the level containing a "geometry" field is
 * reached, sibling structs/arrays are copied through verbatim so that any deeper
 * field named "geometry" is NOT converted. Every field is null-checked before
 * access, so null structs, arrays, strings and primitives survive the copy.
 *
 * @param row    the input row whose geometry field holds a GeoJSON string
 * @param schema the schema describing `row`; its "geometry" field must be StringType
 *               at the level where conversion happens
 * @return a new InternalRow with the geometry string replaced by geometry bytes
 */
def convertGeoJsonToGeometry(row: InternalRow, schema: StructType): InternalRow = {
  val newValues = new Array[Any](schema.fields.length)

  // This struct is the geometry level if it has a geometry field at this level.
  val geometryLevel = hasGeometryField(schema)

  schema.fields.zipWithIndex.foreach {

    // Convert geometry ONLY at the first geometry level.
    case (StructField("geometry", StringType, _, _), index) if geometryLevel =>
      newValues(index) =
        if (row.isNullAt(index)) null
        else geoJsonToGeometry(row.getString(index))

    // If we've reached the geometry level, do NOT recurse further:
    // nested structs are copied through as-is.
    case (sf @ StructField(_, _: StructType, _, _), index) if geometryLevel =>
      newValues(index) =
        if (row.isNullAt(index)) null
        else row.get(index, sf.dataType)

    // Same for arrays at the geometry level: verbatim copy, no conversion.
    case (sf @ StructField(_, _: ArrayType, _, _), index) if geometryLevel =>
      newValues(index) =
        if (row.isNullAt(index)) null
        else row.get(index, sf.dataType)

    // Otherwise, recurse until the first geometry level is reached.
    case (StructField(_, structType: StructType, _, _), index) =>
      newValues(index) =
        if (row.isNullAt(index)) null
        else {
          val nestedRow = row.getStruct(index, structType.fields.length)
          convertGeoJsonToGeometry(nestedRow, structType)
        }

    // Arrays above the geometry level are handled element-wise by handleArray
    // (defined elsewhere in this object; `toGeometry = true` requests conversion).
    case (StructField(_, arrayType: ArrayType, _, _), index) =>
      newValues(index) =
        if (row.isNullAt(index)) null
        else handleArray(row, index, arrayType.elementType, toGeometry = true)

    // Primitives: copied through unchanged (null-safe).
    case (_, index) =>
      newValues(index) =
        if (row.isNullAt(index)) null
        else row.get(index, schema.fields(index).dataType)
  }

  InternalRow.fromSeq(newValues)
}
it("updateGeometrySchema should update first geometry struct only") {
  // Schema of a raw GeoJSON geometry object: { "type": ..., "coordinates": [[...]] }.
  val rawGeometryStruct = StructType(
    Seq(
      StructField("type", StringType),
      StructField("coordinates", ArrayType(ArrayType(DoubleType)))))

  // Builds the FeatureCollection schema. The caller supplies the top-level geometry
  // field; a decoy field also named "geometry" is buried under properties.prop1 and
  // must survive the update untouched.
  def featureCollectionSchema(topLevelGeometry: StructField): StructType =
    StructType(
      Seq(
        StructField("type", StringType),
        StructField(
          "features",
          ArrayType(StructType(Seq(
            StructField("type", StringType),
            topLevelGeometry,
            StructField(
              "properties",
              StructType(Seq(
                StructField("prop0", StringType),
                StructField(
                  "prop1",
                  StructType(Seq(
                    StructField("prop1_0", StringType),
                    // Decoy nested geometry — deeper than the real one.
                    StructField("geometry", rawGeometryStruct)))))))))))))

  val inputSchema =
    featureCollectionSchema(StructField("geometry", rawGeometryStruct))

  val updatedSchema = updateGeometrySchema(inputSchema, StringType)

  // Only the first (shallowest) geometry field becomes a GeoJSON string;
  // the nested decoy keeps its struct type.
  val expectedSchema =
    featureCollectionSchema(StructField("geometry", StringType))

  assert(updatedSchema == expectedSchema)
}