Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,28 @@ import org.apache.spark.unsafe.types.UTF8String
object GeoJSONUtils {

/**
 * Rewrites the data type of the shallowest `geometry` field(s) reachable from `schema`.
 *
 * If `schema` itself declares a field named "geometry", every such field at this level gets
 * `datatype` and nothing nested below this level is touched. Otherwise the search continues
 * into each struct-typed field and into each array whose element type is a struct; every such
 * branch is searched independently, so each branch stops at its own first geometry level.
 * Fields of any other type are returned unchanged.
 *
 * @param schema   the struct schema to rewrite
 * @param datatype the data type to assign to the matched geometry field(s)
 * @return a new schema with the same field names, nullability and metadata
 */
def updateGeometrySchema(schema: StructType, datatype: DataType): StructType = {
  if (hasGeometryField(schema)) {
    // Geometry level reached: retype the geometry field only and do not descend any further,
    // so deeper fields that merely happen to be named "geometry" keep their original type.
    StructType(schema.fields.map {
      case field if field.name == "geometry" => field.copy(dataType = datatype)
      case other => other
    })
  } else {
    // No geometry at this level: keep searching inside nested structs and arrays of structs.
    StructType(schema.fields.map {
      case StructField(name, nested: StructType, nullable, metadata) =>
        StructField(name, updateGeometrySchema(nested, datatype), nullable, metadata)

      case StructField(name, ArrayType(element: StructType, containsNull), nullable, metadata) =>
        val updatedElement = updateGeometrySchema(element, datatype)
        StructField(name, ArrayType(updatedElement, containsNull), nullable, metadata)

      case other => other
    })
  }
}

def geoJsonToGeometry(geoJson: String): Array[Byte] = {
Expand Down Expand Up @@ -103,20 +111,53 @@ object GeoJSONUtils {
InternalRow.fromSeq(newValues)
}

/**
 * Tells whether `st` declares a field named exactly "geometry" at its own top level.
 * Nested structs are not inspected.
 */
private def hasGeometryField(st: StructType): Boolean =
  st.fieldNames.contains("geometry")

def convertGeoJsonToGeometry(row: InternalRow, schema: StructType): InternalRow = {
val newValues = new Array[Any](schema.fields.length)

// This struct is the geometry level if it has a geometry field at this level
val geometryLevel = hasGeometryField(schema)

schema.fields.zipWithIndex.foreach {
case (StructField("geometry", StringType, _, _), index) =>
val geometryGeoJson = row.getString(index)
newValues(index) = geoJsonToGeometry(geometryGeoJson)

// Convert geometry ONLY at the first geometry level
case (StructField("geometry", StringType, _, _), index) if geometryLevel =>
newValues(index) =
if (row.isNullAt(index)) null
else geoJsonToGeometry(row.getString(index))

// If we've reached the geometry level, do NOT recurse further
case (sf @ StructField(_, _: StructType, _, _), index) if geometryLevel =>
newValues(index) =
if (row.isNullAt(index)) null
else row.get(index, sf.dataType)

case (sf @ StructField(_, _: ArrayType, _, _), index) if geometryLevel =>
newValues(index) =
if (row.isNullAt(index)) null
else row.get(index, sf.dataType)

// Otherwise, recurse until the first geometry level is reached
case (StructField(_, structType: StructType, _, _), index) =>
val nestedRow = row.getStruct(index, structType.fields.length)
newValues(index) = convertGeoJsonToGeometry(nestedRow, structType)
newValues(index) =
if (row.isNullAt(index)) null
else {
val nestedRow = row.getStruct(index, structType.fields.length)
convertGeoJsonToGeometry(nestedRow, structType)
}

case (StructField(_, arrayType: ArrayType, _, _), index) =>
newValues(index) = handleArray(row, index, arrayType.elementType, true)
newValues(index) =
if (row.isNullAt(index)) null
else handleArray(row, index, arrayType.elementType, toGeometry = true)

// Primitives
case (_, index) =>
newValues(index) = row.get(index, schema.fields(index).dataType)
newValues(index) =
if (row.isNullAt(index)) null
else row.get(index, schema.fields(index).dataType)
}

InternalRow.fromSeq(newValues)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
},
"properties": {
"prop0": "value1",
"prop1": 0.0
"prop1": {
"prop1_0": "0",
"prop1_1": "1",
"prop1_2": "2"
}
}
},
{ "type": "Feature",
Expand All @@ -27,7 +31,13 @@
},
"properties": {
"prop0": "value2",
"prop1": {"this": "that"}
"prop1": {
"prop1_0": "0",
"prop1_1": "1",
"prop1_2": {
"this": "that"
}
}
}
},
{
Expand All @@ -41,7 +51,11 @@
},
"properties": {
"prop0": "value3",
"prop1": {"this": "that"}
"prop1": {
"prop1_0": "0",
"prop1_1": "1",
"prop1_2": "2"
}
}
},
{
Expand All @@ -55,7 +69,11 @@
},
"properties": {
"prop0": "value4",
"prop1": {"this": "that"}
"prop1": {
"prop1_0": "0",
"prop1_1": "1",
"prop1_2": "2"
}
}
},
{
Expand All @@ -69,8 +87,12 @@
},
"properties": {
"prop0": "value5",
"prop1": {"this": "that"}
"prop1": {
"prop1_0": "0",
"prop1_1": "1",
"prop1_2": "2"
}
}
}
]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions.{col, explode, expr}
import org.apache.spark.sql.sedona_sql.io.geojson.GeoJSONUtils.updateGeometrySchema
import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}
import org.locationtech.jts.geom.{Geometry, MultiLineString, Point, Polygon}
import org.scalatest.BeforeAndAfterAll

Expand Down Expand Up @@ -455,4 +457,62 @@ class geojsonIOTests extends TestBaseScala with BeforeAndAfterAll {
assert(rowsR.getAs[Polygon]("geometry") == rowsW.getAs[Polygon]("geometry"))
}
}

it("updateGeometrySchema should update first geometry struct only") {
  // GeoJSON-style geometry object: a type tag plus nested coordinate arrays.
  val geometryStruct = StructType(
    Seq(
      StructField("type", StringType),
      StructField("coordinates", ArrayType(ArrayType(DoubleType)))))

  // properties.prop1 carries its own struct-typed "geometry" field; the expected schema
  // reuses this value unchanged, i.e. the nested geometry must keep its struct type.
  val propertiesField = StructField(
    "properties",
    StructType(
      Seq(
        StructField("prop0", StringType),
        StructField(
          "prop1",
          StructType(
            Seq(
              StructField("prop1_0", StringType),
              StructField("geometry", geometryStruct)))))))

  // Builds the feature-collection schema around the given feature-level geometry field.
  def featureCollection(featureGeometry: StructField): StructType =
    StructType(
      Seq(
        StructField("type", StringType),
        StructField(
          "features",
          ArrayType(
            StructType(
              Seq(StructField("type", StringType), featureGeometry, propertiesField))))))

  val inputSchema = featureCollection(StructField("geometry", geometryStruct))
  // Only the feature-level geometry is expected to become a GeoJSON string.
  val expectedSchema = featureCollection(StructField("geometry", StringType))

  val updatedSchema = updateGeometrySchema(inputSchema, StringType)

  assert(updatedSchema == expectedSchema)
}
}
Loading