diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala index 44772f70ea1..e6c0898aa42 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala @@ -323,11 +323,10 @@ case class BroadcastIndexJoinExec( }) case Some(distanceExpression) => streamResultsRaw.map(row => { - val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]] - if (geom == null) { + val geometry = TraitJoinQueryBase.shapeToGeometry(boundStreamShape, row) + if (geometry == null) { (null, row) } else { - val geometry = GeometrySerializer.deserialize(geom) val radius = BindReferences .bindReference(distanceExpression, streamed.output) .eval(row) @@ -351,23 +350,21 @@ case class BroadcastIndexJoinExec( }) case _ => streamResultsRaw.map(row => { - val serializedObject = boundStreamShape.eval(row).asInstanceOf[Array[Byte]] - if (serializedObject == null) { - (null, row) - } else { - val shape = if (isRasterPredicate) { - if (boundStreamShape.dataType.isInstanceOf[RasterUDT]) { - val raster = RasterSerializer.deserialize(serializedObject) - JoinedGeometryRaster.rasterToWGS84Envelope(raster) - } else { - val geom = GeometrySerializer.deserialize(serializedObject) - JoinedGeometryRaster.geometryToWGS84Envelope(geom) - } + val shape = if (isRasterPredicate) { + // Raster path keeps the legacy bytes-only handling — Box2D doesn't apply here. 
+ val serializedObject = boundStreamShape.eval(row).asInstanceOf[Array[Byte]] + if (serializedObject == null) null + else if (boundStreamShape.dataType.isInstanceOf[RasterUDT]) { + val raster = RasterSerializer.deserialize(serializedObject) + JoinedGeometryRaster.rasterToWGS84Envelope(raster) } else { - GeometrySerializer.deserialize(serializedObject) + val geom = GeometrySerializer.deserialize(serializedObject) + JoinedGeometryRaster.geometryToWGS84Envelope(geom) } - (shape, row) + } else { + TraitJoinQueryBase.shapeToGeometry(boundStreamShape, row) } + (shape, row) }) } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala index e9db76cffda..939cff108c0 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala @@ -291,6 +291,31 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { rightShape, SpatialPredicate.EQUALS, extraCondition) + // Box2D predicates. Both shape expressions resolve to Box2DUDT; the executors + // materialise each Box2D as a rectangular Polygon so the existing partitioner / + // R-tree / refine machinery applies unchanged. ST_BoxContains is closed-interval + // containment, so it maps to SpatialPredicate.COVERS. NOTE(review): JTS `contains` only + // rejects a box lying wholly on the other's boundary, not every edge-touching case; confirm. 
+ case ST_BoxIntersects(Seq(leftShape, rightShape)) => + Some( + JoinQueryDetection( + left, + right, + leftShape, + rightShape, + SpatialPredicate.INTERSECTS, + isGeography = false, + extraCondition)) + case ST_BoxContains(Seq(leftShape, rightShape)) => + Some( + JoinQueryDetection( + left, + right, + leftShape, + rightShape, + SpatialPredicate.COVERS, + isGeography = false, + extraCondition)) case pred: ST_Predicate => getJoinDetection(left, right, pred, extraCondition) case pred: RS_Predicate => diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/OptimizableJoinCondition.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/OptimizableJoinCondition.scala index c81fafa4c6b..c9e9b0e0a1b 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/OptimizableJoinCondition.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/OptimizableJoinCondition.scala @@ -63,7 +63,7 @@ case class OptimizableJoinCondition(left: LogicalPlan, right: LogicalPlan) { expression match { case _: ST_Intersects | _: ST_Contains | _: ST_Covers | _: ST_Within | _: ST_CoveredBy | _: ST_Overlaps | _: ST_Touches | _: ST_Equals | _: ST_Crosses | _: ST_KNN | - _: RS_Predicate => + _: ST_BoxIntersects | _: ST_BoxContains | _: RS_Predicate => val leftShape = expression.children.head val rightShape = expression.children(1) ExpressionUtils.matchExpressionsToPlans(leftShape, rightShape, left, right).isDefined diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala index 2c449056102..5b86940b1c5 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala @@ -18,15 +18,17 @@ 
*/
 package org.apache.spark.sql.sedona_sql.strategy.join
 
+import org.apache.sedona.common.Constructors
 import org.apache.sedona.common.S2Geography.GeographyWKBSerializer
 import org.apache.sedona.core.spatialRDD.SpatialRDD
 import org.apache.sedona.core.utils.SedonaConf
 import org.apache.sedona.sql.utils.{GeometrySerializer, RasterSerializer}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeRow}
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
-import org.locationtech.jts.geom.Geometry
+import org.apache.spark.sql.sedona_sql.UDT.{Box2DUDT, RasterUDT}
+import org.locationtech.jts.geom.{Geometry, GeometryFactory}
 
 trait TraitJoinQueryBase {
   self: SparkPlan =>
@@ -49,8 +51,10 @@ trait TraitJoinQueryBase {
     spatialRdd.setRawSpatialRDD(
       rdd
         .map { x =>
-          val shape =
-            GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
+          // Null shape rows materialise as an empty geometry collection so they carry the row
+          // payload through the partitioner / index without participating in any spatial match
+          // — mirrors the pre-existing `GeometrySerializer.deserialize(null)` fallback.
+          val shape = TraitJoinQueryBase.shapeToGeometryOrEmpty(shapeExpression, x)
           shape.setUserData(x.copy)
           shape
         }
@@ -123,8 +127,7 @@ trait TraitJoinQueryBase {
     spatialRdd.setRawSpatialRDD(
       rdd
         .map { x =>
-          val shape =
-            GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
+          val shape = TraitJoinQueryBase.shapeToGeometryOrEmpty(shapeExpression, x)
           val distance = boundRadius.eval(x).asInstanceOf[Double]
           val expandedEnvelope =
             JoinedGeometry.geometryToExpandedEnvelope(shape, distance, isGeography)
@@ -178,3 +181,83 @@ trait TraitJoinQueryBase {
     }
   }
 }
+
+object TraitJoinQueryBase {
+
+  /**
+   * Factory shared by every placeholder geometry built in [[shapeToGeometryOrEmpty]], so a NULL
+   * shape does not allocate a new factory per row. Only the factory is reused: each call still
+   * returns a new geometry, because callers such as the partitioned-RDD path attach a row copy
+   * to the result via `setUserData`. NOTE(review): assumes a default `GeometryFactory` is safe
+   * to share across task threads — confirm against the JTS documentation.
+   */
+  private val emptyShapeFactory = new GeometryFactory()
+
+  /**
+   * Materialise a shape column value as a JTS [[Geometry]]. Box2D-typed columns are turned into
+   * the closed rectangular polygon implied by their `(xmin, ymin, xmax, ymax)` bounds; all other
+   * shape columns are deserialised from the Sedona geometry binary form.
+   *
+   * Producing a JTS rectangle here lets the rest of the join machinery — partitioner, R-tree
+   * `IndexBuilder`, refine evaluator — stay shape-agnostic. JTS already short-circuits
+   * rectangle-rectangle predicates (`Polygon.isRectangle` triggers `RectangleIntersects` /
+   * `RectangleContains`), so a `ST_BoxIntersects` join naturally pays only the four-double
+   * envelope comparison at refine time.
+   *
+   * Inverted Box2D bounds (`xmin > xmax` / `ymin > ymax`) are rejected with the same
+   * `IllegalArgumentException` raised by `Predicates.boxIntersects` / `boxContains`. Inverted
+   * bounds have no defined planar meaning today (they are reserved for future
+   * antimeridian-wraparound semantics on Geography bboxes) and would silently mis-prune the
+   * R-tree if accepted here.
+   *
+   * NOTE(review): a NaN bound slips past the inverted-bounds check because every ordered
+   * comparison against NaN is false; confirm whether such boxes should be rejected too.
+   *
+   * @param shapeExpression
+   *   expression producing the shape value; its `dataType` selects the Box2D or binary path
+   * @param row
+   *   input row the expression is evaluated against
+   * @return
+   *   the materialised geometry, or `null` when the shape column evaluates to NULL; the caller
+   *   is expected to either skip the row or substitute an empty geometry
+   * @throws IllegalArgumentException
+   *   if a Box2D value has `xmin > xmax` or `ymin > ymax`
+   */
+  def shapeToGeometry(shapeExpression: Expression, row: InternalRow): Geometry = {
+    val evaluated = shapeExpression.eval(row)
+    if (evaluated == null) {
+      null
+    } else
+      shapeExpression.dataType match {
+        case _: Box2DUDT =>
+          val box = evaluated.asInstanceOf[InternalRow]
+          val xmin = box.getDouble(0)
+          val ymin = box.getDouble(1)
+          val xmax = box.getDouble(2)
+          val ymax = box.getDouble(3)
+          if (xmin > xmax || ymin > ymax) {
+            throw new IllegalArgumentException(
+              "Box2D join input has inverted bounds (xmin > xmax or ymin > ymax). " +
+                "Planar Box2D predicates require ordered intervals; inverted bounds are " +
+                "reserved for future antimeridian wraparound semantics.")
+          }
+          Constructors.polygonFromEnvelope(xmin, ymin, xmax, ymax)
+        case _ =>
+          GeometrySerializer.deserialize(evaluated.asInstanceOf[Array[Byte]])
+      }
+  }
+
+  /**
+   * Convenience wrapper that substitutes an empty geometry collection for NULL shapes. Used by
+   * the partitioned-RDD path where each row must carry a non-null geometry so the original
+   * `UnsafeRow` survives to outer-join output; spatial predicates against the empty geometry
+   * produce no matches, matching the legacy `GeometrySerializer.deserialize(null)` behaviour.
+   *
+   * @return
+   *   the geometry from [[shapeToGeometry]], or a newly created empty collection when it is null
+   */
+  def shapeToGeometryOrEmpty(shapeExpression: Expression, row: InternalRow): Geometry = {
+    val shape = shapeToGeometry(shapeExpression, row)
+    if (shape == null) emptyShapeFactory.createGeometryCollection() else shape
+  }
+}
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/Box2DJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/Box2DJoinSuite.scala
new file mode 100644
index 00000000000..69d1b50ec8a
--- /dev/null
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/Box2DJoinSuite.scala
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{broadcast, expr} +import org.apache.spark.sql.sedona_sql.strategy.join.{BroadcastIndexJoinExec, RangeJoinExec} + +class Box2DJoinSuite extends TestBaseScala { + + import Box2DJoinSuite.TestBox + + /** + * Left boxes L1-L3 (ids 1-3) and right boxes R1-R3 (ids 11-13), wired for exact result sizes: + * + * - L1=(0,0,10,10) R1=(5,5,15,15) — overlapping + * - L1=(0,0,10,10) R2=(2,2,8,8) — R2 fully inside L1 + * - L2=(0,0,10,10) R1=(5,5,15,15) — overlapping + * - L2=(0,0,10,10) R2=(2,2,8,8) — R2 fully inside L2 + * - L3=(20,20,30,30) R3=(40,40,50,50) — disjoint from every other box and from each other + * + * Intersection-pair count: 4. Containment-pair count: 2 (L1⊇R2, L2⊇R2). + */ + private def leftBoxes: DataFrame = { + import sparkSession.implicits._ + Seq(TestBox(1, 0, 0, 10, 10), TestBox(2, 0, 0, 10, 10), TestBox(3, 20, 20, 30, 30)) + .toDF("id", "xmin", "ymin", "xmax", "ymax") + .selectExpr("id", "ST_MakeBox2D(ST_Point(xmin, ymin), ST_Point(xmax, ymax)) AS box") + } + + private def rightBoxes: DataFrame = { + import sparkSession.implicits._ + Seq(TestBox(11, 5, 5, 15, 15), TestBox(12, 2, 2, 8, 8), TestBox(13, 40, 40, 50, 50)) + .toDF("id", "xmin", "ymin", "xmax", "ymax") + .selectExpr("id", "ST_MakeBox2D(ST_Point(xmin, ymin), ST_Point(xmax, ymax)) AS box") + } + + describe("Box2D spatial join") { + + it("ST_BoxIntersects: broadcast index join produces correct pairs") { + val df = leftBoxes + .alias("L") + .join(broadcast(rightBoxes.alias("R")), expr("ST_BoxIntersects(L.box, R.box)")) + val plan = df.queryExecution.sparkPlan + assert( + plan.collect { case b: BroadcastIndexJoinExec => b }.size == 1, + "Expected BroadcastIndexJoinExec in the plan") + assert(df.count() == 4) + } + + it("ST_BoxIntersects: argument order is symmetric") { + val swapped = 
leftBoxes + .alias("L") + .join(broadcast(rightBoxes.alias("R")), expr("ST_BoxIntersects(R.box, L.box)")) + assert(swapped.count() == 4) + assert(swapped.queryExecution.sparkPlan.collect { case b: BroadcastIndexJoinExec => + b + }.size == 1) + } + + it("ST_BoxContains: broadcast index join uses COVERS semantics") { + val df = leftBoxes + .alias("L") + .join(broadcast(rightBoxes.alias("R")), expr("ST_BoxContains(L.box, R.box)")) + assert(df.queryExecution.sparkPlan.collect { case b: BroadcastIndexJoinExec => + b + }.size == 1) + assert(df.count() == 2) + } + + it("ST_BoxContains: edge-touching boxes count (closed-interval semantics)") { + // R contained in L sharing an edge: ST_BoxContains is closed-interval, so this matches. + // NOTE(review): JTS `contains` appears to accept this pair too (the interiors overlap); it + // differs from `covers` only when R lies wholly on L's boundary — confirm what this pins. + import sparkSession.implicits._ + val outer = Seq(TestBox(1, 0, 0, 10, 10)) + .toDF("id", "xmin", "ymin", "xmax", "ymax") + .selectExpr("id", "ST_MakeBox2D(ST_Point(xmin, ymin), ST_Point(xmax, ymax)) AS box") + // edge-sharing box: same xmax and ymax, so it touches outer's right and top edges. 
+ val inner = Seq(TestBox(11, 5, 5, 10, 10)) + .toDF("id", "xmin", "ymin", "xmax", "ymax") + .selectExpr("id", "ST_MakeBox2D(ST_Point(xmin, ymin), ST_Point(xmax, ymax)) AS box") + val df = outer + .alias("O") + .join(broadcast(inner.alias("I")), expr("ST_BoxContains(O.box, I.box)")) + assert(df.count() == 1, "Closed-interval containment must include edge-touching boxes") + } + + it("ST_BoxIntersects: non-broadcast range join produces the same count") { + val df = leftBoxes + .alias("L") + .join(rightBoxes.alias("R"), expr("ST_BoxIntersects(L.box, R.box)")) + assert( + df.queryExecution.sparkPlan.collect { case r: RangeJoinExec => r }.size == 1, + "Expected RangeJoinExec in the plan") + assert(df.count() == 4) + } + + it("Null Box2D rows are safe and produce no matches") { + // A null shape on either side must not crash the executor and must not contribute matches + // (mirrors the existing GeometrySerializer.deserialize(null) → empty-collection fallback). + import sparkSession.implicits._ + val withNullLeft = leftBoxes + .selectExpr("id", "box AS box") + .union(Seq((99, null.asInstanceOf[org.apache.sedona.common.geometryObjects.Box2D])) + .toDF("id", "box")) + val df = withNullLeft + .alias("L") + .join(broadcast(rightBoxes.alias("R")), expr("ST_BoxIntersects(L.box, R.box)")) + assert(df.count() == 4) // unchanged from the non-null fixture + // Range join path (no broadcast) also tolerates nulls. + val rangeDf = withNullLeft + .alias("L") + .join(rightBoxes.alias("R"), expr("ST_BoxIntersects(L.box, R.box)")) + assert(rangeDf.count() == 4) + } + + it("Inverted Box2D bounds in a join throw IllegalArgumentException") { + import sparkSession.implicits._ + // Construct an inverted Box2D directly via the Java constructor (the SQL ST_MakeBox2D + // doesn't validate, so this is how a stored column with inverted bounds would look). 
+ val invertedLeft = + Seq((1, new org.apache.sedona.common.geometryObjects.Box2D(10.0, 0.0, 0.0, 10.0))) + .toDF("id", "box") + val ex = intercept[org.apache.spark.SparkException] { + invertedLeft + .alias("L") + .join(broadcast(rightBoxes.alias("R")), expr("ST_BoxIntersects(L.box, R.box)")) + .collect() + } + val cause = Iterator + .iterate(ex: Throwable)(_.getCause) + .takeWhile(_ != null) + .find(_.isInstanceOf[IllegalArgumentException]) + assert(cause.isDefined, s"Expected IllegalArgumentException in cause chain, got: $ex") + assert(cause.get.getMessage.contains("inverted bounds")) + } + + it("Result is equivalent to ST_Intersects on the Box2D-as-polygon envelopes") { + val viaBox = leftBoxes + .alias("L") + .join(broadcast(rightBoxes.alias("R")), expr("ST_BoxIntersects(L.box, R.box)")) + .selectExpr("L.id AS l", "R.id AS r") + .orderBy("l", "r") + .collect() + .toSeq + + // ST_GeomFromBox2D is the function-form equivalent of `CAST(box AS geometry)`. The cast + // syntax requires the Sedona SQL parser extension; this suite runs under the common test + // base, which doesn't wire that extension, so we go through the function form here. + val asPolygons = leftBoxes + .selectExpr("id", "ST_GeomFromBox2D(box) AS g") + .alias("L") + .join( + broadcast(rightBoxes.selectExpr("id", "ST_GeomFromBox2D(box) AS g").alias("R")), + expr("ST_Intersects(L.g, R.g)")) + .selectExpr("L.id AS l", "R.id AS r") + .orderBy("l", "r") + .collect() + .toSeq + + assert(viaBox == asPolygons) + } + } + +} + +object Box2DJoinSuite { + // Declared in the companion object so Spark's encoder doesn't need an outer-class reference. + case class TestBox(id: Int, xmin: Double, ymin: Double, xmax: Double, ymax: Double) +}