diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala index c1f78268acf..f77ff02f648 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala @@ -18,7 +18,7 @@ */ package org.apache.spark.sql.sedona_sql.io.stac -import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} /** * Defines a STAC extension with its schema and property mappings @@ -49,9 +49,14 @@ object StacExtension { StacExtension( name = "grid", // Schema for the grid extension, add all required fields here - schema = StructType(Seq(StructField("grid:code", StringType, nullable = true)))) + schema = StructType(Seq(StructField("grid:code", StringType, nullable = true)))), // Add other extensions here... 
- ) + StacExtension( + name = "eo", + schema = StructType( + Seq( + StructField("eo:cloud_cover", DoubleType, nullable = true), + StructField("eo:snow_cover", DoubleType, nullable = true))))) } } diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala index b26cc6b1f64..1ffef0eb4e7 100644 --- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala +++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala @@ -44,8 +44,18 @@ class StacDataSourceTest extends TestBaseScala { it("basic df load from local file with extensions should work") { val dfStac = sparkSession.read.format("stac").load(STAC_COLLECTION_LOCAL) // Filter rows where grid:code equals "MSIN-2506" - val filteredDf = dfStac.filter(dfStac.col("grid:code") === "MSIN-2506") - val rowCount = filteredDf.count() + var filteredDf = dfStac.filter(dfStac.col("grid:code") === "MSIN-2506") + var rowCount = filteredDf.count() + assert(rowCount > 0) + + // Filter rows where eo:cloud_cover equals 1.2 + filteredDf = dfStac.filter(dfStac.col("eo:cloud_cover") === 1.2) + rowCount = filteredDf.count() + assert(rowCount > 0) + + // Filter rows where eo:snow_cover equals 0.0 + filteredDf = dfStac.filter(dfStac.col("eo:snow_cover") === 0.0) + rowCount = filteredDf.count() assert(rowCount > 0) } @@ -347,9 +357,10 @@ class StacDataSourceTest extends TestBaseScala { // Extension fields that may be present val extensionFields = Seq( // Grid extension fields - StructField("grid:code", StringType, nullable = true) + StructField("grid:code", StringType, nullable = true), // Add other extension fields as needed - ) + StructField("eo:cloud_cover", DoubleType, nullable = true), + StructField("eo:snow_cover", DoubleType, nullable = true)) // Check that all base 
fields are present with correct types baseFields.foreach { expectedField =>