Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
package org.apache.spark.sql.sedona_sql.io.stac

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

/**
* Defines a STAC extension with its schema and property mappings
Expand Down Expand Up @@ -49,9 +49,14 @@ object StacExtension {
StacExtension(
name = "grid",
// Schema for the grid extension, add all required fields here
schema = StructType(Seq(StructField("grid:code", StringType, nullable = true))))
schema = StructType(Seq(StructField("grid:code", StringType, nullable = true)))),

// Add other extensions here...
)
StacExtension(
name = "eo",
schema = StructType(Seq(StructField("eo:cloud_cover", DoubleType, nullable = true)))),
StacExtension(
name = "eo",
schema = StructType(Seq(StructField("eo:snow_cover", DoubleType, nullable = true)))))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,18 @@ class StacDataSourceTest extends TestBaseScala {
it("basic df load from local file with extensions should work") {
val dfStac = sparkSession.read.format("stac").load(STAC_COLLECTION_LOCAL)
// Filter rows where grid:code equals "MSIN-2506"
val filteredDf = dfStac.filter(dfStac.col("grid:code") === "MSIN-2506")
val rowCount = filteredDf.count()
var filteredDf = dfStac.filter(dfStac.col("grid:code") === "MSIN-2506")
var rowCount = filteredDf.count()
assert(rowCount > 0)

// Filter rows where eo:cloud_cover equals 1.2
filteredDf = dfStac.filter(dfStac.col("eo:cloud_cover") === 1.2)
rowCount = filteredDf.count()
assert(rowCount > 0)

// Filter rows where eo:snow_cover equals 0.0
filteredDf = dfStac.filter(dfStac.col("eo:snow_cover") === 0.0)
rowCount = filteredDf.count()
assert(rowCount > 0)
}

Expand Down Expand Up @@ -347,9 +357,10 @@ class StacDataSourceTest extends TestBaseScala {
// Extension fields that may be present
val extensionFields = Seq(
// Grid extension fields
StructField("grid:code", StringType, nullable = true)
StructField("grid:code", StringType, nullable = true),
// Add other extension fields as needed
)
StructField("eo:cloud_cover", DoubleType, nullable = true),
StructField("eo:snow_cover", DoubleType, nullable = true))

// Check that all base fields are present with correct types
baseFields.foreach { expectedField =>
Expand Down
Loading