diff --git a/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/SparkTestBase.java b/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/SparkTestBase.java index ec5ee9c67..0e5080ee9 100644 --- a/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/SparkTestBase.java +++ b/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/SparkTestBase.java @@ -39,7 +39,7 @@ public void beforeEach(ExtensionContext context) throws Exception { "spark.sql.extensions", ("org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions," + "com.linkedin.openhouse.spark.extensions.OpenhouseSparkSessionExtensions")) - .config("spark.sql.catalog.openhouse", "org.apache.iceberg.spark.SparkCatalog") + .config("spark.sql.catalog.openhouse", "com.linkedin.openhouse.spark.OHSparkCatalog") .config( "spark.sql.catalog.openhouse.catalog-impl", "com.linkedin.openhouse.spark.OpenHouseCatalog") diff --git a/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CTASNonNullTest.java b/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CTASNonNullTest.java index beb812daf..d969ae706 100644 --- a/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CTASNonNullTest.java +++ b/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CTASNonNullTest.java @@ -24,7 +24,8 @@ public void testCTASPreservesNonNull() throws Exception { // Verify spark catalogs have correct classes configured assertEquals( - "org.apache.iceberg.spark.SparkCatalog", spark.conf().get("spark.sql.catalog.openhouse")); + "com.linkedin.openhouse.spark.OHSparkCatalog", + spark.conf().get("spark.sql.catalog.openhouse")); // Verify id column is preserved in good 
catalog, not preserved in bad catalog assertFalse(sourceSchema.apply("id").nullable(), "Source table id column should be required"); diff --git a/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CatalogOperationTest.java b/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CatalogOperationTest.java index 4d5641124..d946c1cd1 100644 --- a/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CatalogOperationTest.java +++ b/integrations/spark/spark-3.1/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CatalogOperationTest.java @@ -11,9 +11,9 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; -import org.apache.iceberg.CatalogUtil; import org.apache.iceberg.DataFile; import org.apache.iceberg.DataFiles; +import org.apache.iceberg.NullOrder; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Schema; import org.apache.iceberg.SchemaParser; @@ -31,23 +31,31 @@ import org.apache.spark.sql.types.StructType; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import scala.collection.JavaConverters; public class CatalogOperationTest extends OpenHouseSparkITest { + + private static final String DATABASE = "d1_catalog"; + @Test public void testCasingWithCTAS() throws Exception { try (SparkSession spark = getSparkSession()) { // creating a casing preserving table using backtick - spark.sql("CREATE TABLE openhouse.d1.`tT1` (name string)"); + spark.sql("CREATE TABLE openhouse." + DATABASE + ".`tT1` (name string)"); // testing writing behavior, note the casing of tt1 is intentionally changed. - spark.sql("INSERT INTO openhouse.d1.Tt1 VALUES ('foo')"); + spark.sql("INSERT INTO openhouse." 
+ DATABASE + ".Tt1 VALUES ('foo')"); // Verifying by querying with all lower-cased name Assertions.assertEquals( - 1, spark.sql("SELECT * from openhouse.d1.tt1").collectAsList().size()); + 1, spark.sql("SELECT * from openhouse." + DATABASE + ".tt1").collectAsList().size()); // ctas but referring with lower-cased name - spark.sql("CREATE TABLE openhouse.d1.t2 AS SELECT * from openhouse.d1.tt1"); - Assertions.assertEquals(1, spark.sql("SELECT * FROM openhouse.d1.t2").collectAsList().size()); + spark.sql( + "CREATE TABLE openhouse." + + DATABASE + + ".t2 AS SELECT * from openhouse." + + DATABASE + + ".tt1"); + Assertions.assertEquals( + 1, spark.sql("SELECT * FROM openhouse." + DATABASE + ".t2").collectAsList().size()); } } @@ -55,7 +63,7 @@ public void testCasingWithCTAS() throws Exception { public void testCreateTablePartitionedByDate() throws Exception { try (SparkSession spark = getSparkSession()) { // creating a casing preserving table using backtick - String quotedFqtn = "openhouse.d1.tpartionedbydate"; + String quotedFqtn = "openhouse." + DATABASE + ".tpartionedbydate"; spark.sql( String.format( "CREATE TABLE %s (data string) PARTITIONED BY (datefield DATE)", quotedFqtn)); @@ -140,13 +148,17 @@ public void testCreateReplicaSkipFieldIdReassignmentUnPartitionedTable() throws @Test public void testAlterTableUnsetReplicationPolicy() throws Exception { try (SparkSession spark = getSparkSession()) { - spark.sql("CREATE TABLE openhouse.d1.`ttt1` (name string)"); - spark.sql("INSERT INTO openhouse.d1.ttt1 VALUES ('foo')"); + spark.sql("CREATE TABLE openhouse." + DATABASE + ".`ttt1` (name string)"); + spark.sql("INSERT INTO openhouse." + DATABASE + ".ttt1 VALUES ('foo')"); spark.sql( - "ALTER TABLE openhouse.d1.ttt1 SET POLICY (REPLICATION=({destination:'WAR', interval:12h}))"); + "ALTER TABLE openhouse." 
+ + DATABASE + + ".ttt1 SET POLICY (REPLICATION=({destination:'WAR', interval:12h}))"); spark.sql( - "ALTER TABLE openhouse.d1.ttt1 SET POLICY (RETENTION= 30d on column name where pattern='yyyy-MM-dd')"); - Policies policies = getPoliciesObj("openhouse.d1.ttt1", spark); + "ALTER TABLE openhouse." + + DATABASE + + ".ttt1 SET POLICY (RETENTION= 30d on column name where pattern='yyyy-MM-dd')"); + Policies policies = getPoliciesObj("openhouse." + DATABASE + ".ttt1", spark); Assertions.assertNotNull(policies); Assertions.assertEquals( "'WAR'", policies.getReplication().getConfig().get(0).getDestination()); @@ -155,8 +167,8 @@ public void testAlterTableUnsetReplicationPolicy() throws Exception { "'yyyy-MM-dd'", policies.getRetention().getColumnPattern().getPattern()); // unset replication policy - spark.sql("ALTER TABLE openhouse.d1.ttt1 UNSET POLICY (REPLICATION)"); - Policies updatedPolicy = getPoliciesObj("openhouse.d1.ttt1", spark); + spark.sql("ALTER TABLE openhouse." + DATABASE + ".ttt1 UNSET POLICY (REPLICATION)"); + Policies updatedPolicy = getPoliciesObj("openhouse." + DATABASE + ".ttt1", spark); Assertions.assertEquals(updatedPolicy.getReplication().getConfig().size(), 0); // assert that other policies, retention is not modified after unsetting replication Assertions.assertNotNull(updatedPolicy.getRetention()); @@ -165,8 +177,10 @@ public void testAlterTableUnsetReplicationPolicy() throws Exception { // assert retention can be set after unsetting replication spark.sql( - "ALTER TABLE openhouse.d1.ttt1 SET POLICY (RETENTION = 30D on COLUMN name WHERE pattern = 'yyyy')"); - Policies policyWithRetention = getPoliciesObj("openhouse.d1.ttt1", spark); + "ALTER TABLE openhouse." + + DATABASE + + ".ttt1 SET POLICY (RETENTION = 30D on COLUMN name WHERE pattern = 'yyyy')"); + Policies policyWithRetention = getPoliciesObj("openhouse." 
+ DATABASE + ".ttt1", spark); Assertions.assertNotNull(policyWithRetention); Assertions.assertEquals( "'yyyy'", policyWithRetention.getRetention().getColumnPattern().getPattern()); @@ -174,17 +188,19 @@ public void testAlterTableUnsetReplicationPolicy() throws Exception { // assert replication can be set again after retention policy spark.sql( - "ALTER TABLE openhouse.d1.ttt1 SET POLICY (REPLICATION=({destination:'WAR', interval:12h}))"); - Policies policyWithReplication = getPoliciesObj("openhouse.d1.ttt1", spark); + "ALTER TABLE openhouse." + + DATABASE + + ".ttt1 SET POLICY (REPLICATION=({destination:'WAR', interval:12h}))"); + Policies policyWithReplication = getPoliciesObj("openhouse." + DATABASE + ".ttt1", spark); Assertions.assertNotNull(policyWithReplication); Assertions.assertEquals( "'WAR'", policyWithReplication.getReplication().getConfig().get(0).getDestination()); // UNSET policy for table without replication - spark.sql("CREATE TABLE openhouse.d1.`tttest1` (name string)"); - spark.sql("INSERT INTO openhouse.d1.tttest1 VALUES ('foo')"); - spark.sql("ALTER TABLE openhouse.d1.tttest1 UNSET POLICY (REPLICATION)"); - Policies policytttest1 = getPoliciesObj("openhouse.d1.tttest1", spark); + spark.sql("CREATE TABLE openhouse." + DATABASE + ".`tttest1` (name string)"); + spark.sql("INSERT INTO openhouse." + DATABASE + ".tttest1 VALUES ('foo')"); + spark.sql("ALTER TABLE openhouse." + DATABASE + ".tttest1 UNSET POLICY (REPLICATION)"); + Policies policytttest1 = getPoliciesObj("openhouse." + DATABASE + ".tttest1", spark); Assertions.assertEquals(0, policytttest1.getReplication().getConfig().size()); } } @@ -225,29 +241,6 @@ public void testCreateReplicaSkipFieldIdReassignmentPartitionedTable() throws Ex } } - /** - * This is a copy of com.linkedin.openhouse.jobs.spark.Operations#getCatalog() temporarily. - * Refactoring these pieces require deployment coordination, thus we shall create an artifact - * module that can be pulled by :apps module. 
- */ - private Catalog getOpenHouseCatalog(SparkSession spark) { - final Map catalogProperties = new HashMap<>(); - final String catalogPropertyPrefix = String.format("spark.sql.catalog.openhouse."); - final Map sparkProperties = JavaConverters.mapAsJavaMap(spark.conf().getAll()); - for (Map.Entry entry : sparkProperties.entrySet()) { - if (entry.getKey().startsWith(catalogPropertyPrefix)) { - catalogProperties.put( - entry.getKey().substring(catalogPropertyPrefix.length()), entry.getValue()); - } - } - // this initializes the catalog based on runtime Catalog class passed in catalog-impl conf. - return CatalogUtil.loadCatalog( - sparkProperties.get("spark.sql.catalog.openhouse.catalog-impl"), - "openhouse", - catalogProperties, - spark.sparkContext().hadoopConfiguration()); - } - private Policies getPoliciesObj(String tableName, SparkSession spark) { List props = spark.sql(String.format("show tblProperties %s", tableName)).collectAsList(); Map collect = @@ -411,4 +404,154 @@ public void testAlterTableSortOrderCTAS() throws Exception { Assertions.assertEquals(SortOrder.unsorted(), newSqlTable.sortOrder()); } } + + @Test + public void testWriteWithCaseMismatch_succeedsWithCaseSensitiveTrue() throws Exception { + try (SparkSession spark = getSparkSession()) { + // Create a table with uppercase column "ID" — the common case for tables originally created + // by Hive or engines that preserve user-specified casing. + Catalog catalog = getOpenHouseCatalog(spark); + Schema schema = new Schema(Types.NestedField.required(1, "ID", Types.StringType.get())); + catalog.createTable(TableIdentifier.of("d1", "write_case_test"), schema); + + // With caseSensitive=true, Spark's ResolveOutputRelation uses a case-sensitive resolver and + // cannot find source column "id" in the target schema column "ID". Vanilla Spark would throw + // "Cannot find data for output column 'ID'" at analysis time. 
+ // + // OHSparkCatalog advertises ACCEPT_ANY_SCHEMA so outputResolved=true and + // ResolveOutputRelation skips OH writes. OHWriteSchemaNormalizationRule (post-hoc) then + // inserts a Project(Alias("id" -> "ID")) so Iceberg sees the correct stored casing. + spark.conf().set("spark.sql.caseSensitive", "true"); + try { + Assertions.assertDoesNotThrow( + () -> spark.sql("SELECT 'row1' AS id").writeTo("openhouse.d1.write_case_test").append(), + "writeTo().append() must succeed when source has lowercase 'id' and OH table has 'ID'"); + + // Verify the row was written with the correct stored casing. + // Use the exact stored column name "ID" (not lowercase "id") for the read since this + // branch does not include the read-side case-insensitive resolution rule. + List rows = spark.sql("SELECT ID FROM openhouse.d1.write_case_test").collectAsList(); + Assertions.assertEquals(1, rows.size()); + Assertions.assertEquals("row1", rows.get(0).getString(0)); + + // The rule must NOT mutate spark.sql.caseSensitive. + Assertions.assertEquals( + "true", + spark.conf().get("spark.sql.caseSensitive"), + "OHWriteSchemaNormalizationRule must not modify spark.sql.caseSensitive"); + } finally { + spark.conf().set("spark.sql.caseSensitive", "false"); + spark.sql("DROP TABLE openhouse.d1.write_case_test"); + } + } + } + + /** + * Verifies that {@code OHWriteSchemaNormalizationRule} fixes writes from a temporary Spark view + * into an OH table when the view's column casing differs from the stored table casing. + * + *

<p>Without the fix: {@code INSERT INTO oh_tbl SELECT colA FROM tempView} fails at analysis time + * because Spark's {@code ResolveOutputRelation} performs a case-sensitive name comparison between + * the view output column (e.g. {@code "colA"}) and the stored table column (e.g. {@code "COLA"}). + * This failure happens regardless of {@code spark.sql.caseSensitive}, because the temporary view + * introduces an intermediate resolved relation whose output attribute names are locked to the + * casing in the view body. + * + *

With the fix: {@code OHSparkCatalog} advertises {@code ACCEPT_ANY_SCHEMA}, causing {@code + * ResolveOutputRelation} to skip OH write commands entirely. {@code + * OHWriteSchemaNormalizationRule} then fires post-hoc and inserts a {@code Project} that renames + * the view output column to match the stored OH casing, regardless of whether the source is a + * temp view, a direct table, or any other resolved query. + */ + @Test + public void testWriteFromTempView_caseMismatch_succeeds() throws Exception { + try (SparkSession spark = getSparkSession()) { + // Create an OH table with uppercase column "ID". + Catalog catalog = getOpenHouseCatalog(spark); + Schema schema = new Schema(Types.NestedField.required(1, "ID", Types.StringType.get())); + catalog.createTable(TableIdentifier.of("d1", "write_view_case_test"), schema); + + spark.conf().set("spark.sql.caseSensitive", "true"); + try { + // Create a temp view that produces a lowercase "id" column. + // This simulates the real-world pattern: a view created over a Hive/external source + // where the engine lowercases or camelCases the identifier. + spark.sql("CREATE OR REPLACE TEMP VIEW v_write_src AS SELECT 'row1' AS id"); + + // INSERT INTO from the temp view must succeed — the rule must rename "id" → "ID" + // in the Project inserted between the view output and the Iceberg writer. + Assertions.assertDoesNotThrow( + () -> + spark.sql( + "INSERT INTO openhouse.d1.write_view_case_test SELECT id FROM v_write_src"), + "INSERT INTO from temp view must succeed when view has 'id' and OH table stores 'ID'"); + + // Confirm the row landed with the correct stored casing. 
+ List rows = + spark.sql("SELECT ID FROM openhouse.d1.write_view_case_test").collectAsList(); + Assertions.assertEquals(1, rows.size()); + Assertions.assertEquals("row1", rows.get(0).getString(0)); + } finally { + spark.conf().set("spark.sql.caseSensitive", "false"); + spark.sql("DROP VIEW IF EXISTS v_write_src"); + spark.sql("DROP TABLE openhouse.d1.write_view_case_test"); + } + } + } + + @Test + public void testWriteOrderedByPersistsMultiColumnSortOrder() throws Exception { + try (SparkSession spark = getSparkSession()) { + Catalog catalog = getOpenHouseCatalog(spark); + spark.sql( + "CREATE TABLE openhouse.db.write_ordered_multi (id INT, category STRING, data STRING)"); + spark.sql("ALTER TABLE openhouse.db.write_ordered_multi WRITE ORDERED BY category, id"); + + Table table = catalog.loadTable(TableIdentifier.of("db", "write_ordered_multi")); + Assertions.assertEquals( + SortOrder.builderFor(table.schema()).asc("category").asc("id").build(), + table.sortOrder()); + } + } + + @Test + public void testWriteOrderedByRespectsDirectionAndNullOrder() throws Exception { + try (SparkSession spark = getSparkSession()) { + Catalog catalog = getOpenHouseCatalog(spark); + spark.sql("CREATE TABLE openhouse.db.write_ordered_desc (id INT, category STRING)"); + // DESC defaults to NULLS LAST in Iceberg; override to NULLS FIRST to verify both + // direction and null-order are propagated end-to-end. 
+ spark.sql( + "ALTER TABLE openhouse.db.write_ordered_desc WRITE ORDERED BY category DESC NULLS FIRST"); + + Table table = catalog.loadTable(TableIdentifier.of("db", "write_ordered_desc")); + Assertions.assertEquals( + SortOrder.builderFor(table.schema()).desc("category", NullOrder.NULLS_FIRST).build(), + table.sortOrder()); + } + } + + @Test + public void testWriteOrderedByRoundTripsThroughInsert() throws Exception { + try (SparkSession spark = getSparkSession()) { + Catalog catalog = getOpenHouseCatalog(spark); + spark.sql("CREATE TABLE openhouse.db.write_ordered_insert (id INT, category STRING)"); + spark.sql("ALTER TABLE openhouse.db.write_ordered_insert WRITE ORDERED BY id"); + + spark.sql( + "INSERT INTO openhouse.db.write_ordered_insert VALUES (3, 'C'), (1, 'A'), (2, 'B')"); + + Table table = catalog.loadTable(TableIdentifier.of("db", "write_ordered_insert")); + // Sort order metadata is preserved across an INSERT (no implicit reset). + Assertions.assertEquals( + SortOrder.builderFor(table.schema()).asc("id").build(), table.sortOrder()); + + List rows = + spark.sql("SELECT id FROM openhouse.db.write_ordered_insert ORDER BY id").collectAsList(); + Assertions.assertEquals(3, rows.size()); + Assertions.assertEquals(1, rows.get(0).getInt(0)); + Assertions.assertEquals(2, rows.get(1).getInt(0)); + Assertions.assertEquals(3, rows.get(2).getInt(0)); + } + } } diff --git a/integrations/spark/spark-3.1/openhouse-spark-runtime/src/main/java/com/linkedin/openhouse/spark/OHSparkCatalog.java b/integrations/spark/spark-3.1/openhouse-spark-runtime/src/main/java/com/linkedin/openhouse/spark/OHSparkCatalog.java new file mode 100644 index 000000000..8dd7aeca9 --- /dev/null +++ b/integrations/spark/spark-3.1/openhouse-spark-runtime/src/main/java/com/linkedin/openhouse/spark/OHSparkCatalog.java @@ -0,0 +1,65 @@ +package com.linkedin.openhouse.spark; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import org.apache.iceberg.spark.SparkCatalog; 
+import org.apache.iceberg.spark.source.SparkTable; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.TableCapability; + +/** + * Spark catalog wrapper for OpenHouse that extends {@link SparkCatalog} and annotates every loaded + * table with {@link TableCapability#ACCEPT_ANY_SCHEMA}. + * + *

<p>Why {@code ACCEPT_ANY_SCHEMA}? Spark's {@code ResolveOutputRelation} analyzer rule uses the + * session resolver (case-sensitive when {@code spark.sql.caseSensitive=true}). If a client + * DataFrame has column {@code "id"} but the OH table stores {@code "ID"}, the write command never + * resolves — Spark throws "Cannot find data for output column 'ID'" before OH's own server-side + * case-insensitive schema validation runs. + * + *

<p>Advertising {@code ACCEPT_ANY_SCHEMA} causes {@code DataSourceV2Relation.skipSchemaResolution} + * to return {@code true}, which makes {@code V2WriteCommand.outputResolved} return {@code true} + * immediately. {@code ResolveOutputRelation} therefore skips OH write commands, allowing {@link + * com.linkedin.openhouse.spark.extensions.OHWriteSchemaNormalizationRule} (a post-hoc resolution + * rule) to insert the necessary column-renaming {@code Project} before execution. + * + *

<p>For reads and DDL the capability has no effect; it is only consulted during write analysis. + * + *

<p>Configuration: + * + * <pre>
+ *   spark.sql.catalog.openhouse=com.linkedin.openhouse.spark.OHSparkCatalog
+ *   spark.sql.catalog.openhouse.catalog-impl=com.linkedin.openhouse.spark.OpenHouseCatalog
+ * </pre>
+ */ +public class OHSparkCatalog extends SparkCatalog { + + @Override + public SparkTable loadTable(Identifier ident) throws NoSuchTableException { + SparkTable original = super.loadTable(ident); + return withAcceptAnySchema(original); + } + + /** + * Wraps a {@link SparkTable} in an anonymous subclass that adds {@link + * TableCapability#ACCEPT_ANY_SCHEMA} to the table's capabilities. + * + *

The anonymous class delegates all other behaviour to the original table by invoking {@code + * super} (which delegates to the underlying Iceberg table object). {@code snapshotId=null} and + * {@code refreshEagerly=false} are the correct defaults for a standard (non-time-travel) table + * load; the original table's Iceberg {@code Table} object is passed unchanged so all reads and + * writes continue to use the real table state. + */ + private SparkTable withAcceptAnySchema(SparkTable original) { + return new SparkTable(original.table(), null /* snapshotId */, false /* refreshEagerly */) { + @Override + public Set capabilities() { + Set caps = new HashSet<>(original.capabilities()); + caps.add(TableCapability.ACCEPT_ANY_SCHEMA); + return Collections.unmodifiableSet(caps); + } + }; + } +} diff --git a/integrations/spark/spark-3.1/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OHWriteSchemaNormalizationRule.scala b/integrations/spark/spark-3.1/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OHWriteSchemaNormalizationRule.scala new file mode 100644 index 000000000..75f03e072 --- /dev/null +++ b/integrations/spark/spark-3.1/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OHWriteSchemaNormalizationRule.scala @@ -0,0 +1,163 @@ +package com.linkedin.openhouse.spark.extensions + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, V2WriteCommand} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +/** + * Post-hoc resolution rule that replicates the column-name and type normalization that Spark's + * {@code ResolveOutputRelation} would have applied to OpenHouse write commands, compensating for + * the fact that OH tables advertise {@link 
org.apache.spark.sql.connector.catalog.TableCapability#ACCEPT_ANY_SCHEMA} + * (via {@link OHSparkCatalog}) which causes {@code ResolveOutputRelation} to skip them entirely. + * + *

<p>Why {@code ACCEPT_ANY_SCHEMA}? Spark's {@code ResolveOutputRelation} throws at analysis time + * when {@code caseSensitive=true} and a client DataFrame column (e.g. {@code "id"}) does not match + * the OH table column name exactly ({@code "ID"}). Advertising {@code ACCEPT_ANY_SCHEMA} prevents + * that failure. This rule then runs as a {@code Post-Hoc Resolution} rule and does the work that + * {@code ResolveOutputRelation} would have done: wrapping the source query in a {@code Project} + * that renames (and if necessary casts) each source column to the stored OH casing and type. + * + *

<p>The rule handles both write modes: + * <ul>
+ *   <li>By-name writes ({@code isByName=true}, e.g. {@code df.writeTo().append()}): each + * source column is matched to the target column whose name it equals case-insensitively. + * Tables with case-duplicate columns are skipped (ambiguous target).</li>
+ *   <li>By-position writes ({@code isByName=false}, e.g. {@code INSERT INTO … VALUES …}): + * source and target columns are zipped positionally and each source column is renamed (and + * if the types differ, cast) to match the target. This replicates the {@code Alias} + + * {@code Cast} that {@code ResolveOutputRelation} would have inserted.</li>
+ * </ul> + * + *

In both modes, if source and target already match in name and type the rule returns the plan + * unchanged. If the column count differs the rule is a no-op (the mismatch is left for Iceberg or + * the OH server to report). + */ +class OHWriteSchemaNormalizationRule(spark: SparkSession) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.transformDown { + case write: V2WriteCommand + if write.table.resolved && write.query.resolved && isOHWrite(write) => + normalizeColumnNames(write).getOrElse(write) + } + } + + private def isOHWrite(write: V2WriteCommand): Boolean = { + write.table match { + case rel: DataSourceV2Relation => isOHRelation(rel) + case _ => false + } + } + + private def normalizeColumnNames(write: V2WriteCommand): Option[V2WriteCommand] = { + val ohRelation = write.table match { + case rel: DataSourceV2Relation => rel + case _ => return None + } + + val targetCols = ohRelation.output + val sourceCols = write.query.output + + // If column counts differ, leave it to Iceberg / the OH server to report the mismatch. + if (sourceCols.size != targetCols.size) return None + + val projections = + if (write.isByName) projectByName(sourceCols, targetCols) + else projectByPosition(sourceCols, targetCols) + + projections match { + case None => None + case Some(exprs) => Some(write.withNewQuery(Project(exprs, write.query))) + } + } + + /** + * By-name mode: replicate what {@code ResolveOutputRelation} does for by-name writes — produce a + * projection in target column order that renames (and if necessary casts) each source + * column to the stored OH casing. This also handles the case where the source DataFrame has + * columns in a different order than the stored schema (e.g. when the source is built from a bean + * whose fields are introspected alphabetically). + * + *

<p>Tables with case-duplicate columns are skipped (the target is ambiguous). Column names are
+   * compared after lower-casing with {@code Locale.ROOT}, so matching does not depend on the JVM
+   * default locale. This path only renames columns; it never inserts a {@code Cast}.
+   *
+   * @param sourceCols output attributes of the query being written
+   * @param targetCols output attributes of the target OH relation
+   * @return the projection list in target column order, or {@code None} when the rule should
+   *         leave the plan untouched (case-duplicate names on either side, a target column
+   *         with no matching source column, or nothing to rename or reorder)
+   */
+  private def projectByName(
+      sourceCols: Seq[org.apache.spark.sql.catalyst.expressions.Attribute],
+      targetCols: Seq[org.apache.spark.sql.catalyst.expressions.Attribute])
+      : Option[Seq[org.apache.spark.sql.catalyst.expressions.NamedExpression]] = {
+
+    // Locale-independent matching key. With a default locale such as tr-TR, "ID".toLowerCase
+    // yields "ıd" (dotless i), which would never match a source column named "id".
+    def key(name: String): String = name.toLowerCase(java.util.Locale.ROOT)
+
+    // Case-duplicate target: skip normalization to avoid silently misdirecting the write.
+    val targetGrouped = targetCols.groupBy(col => key(col.name))
+    if (targetGrouped.values.exists(_.size > 1)) return None
+
+    // Case-duplicate source: skip to avoid ambiguous lookup.
+    val srcGrouped = sourceCols.groupBy(col => key(col.name))
+    if (srcGrouped.values.exists(_.size > 1)) return None
+    val srcByLower: Map[String, org.apache.spark.sql.catalyst.expressions.Attribute] =
+      srcGrouped.map { case (lower, attrs) => lower -> attrs.head }
+
+    // Produce expressions in TARGET column order (replicating ResolveOutputRelation).
+    // For each target column find the matching source column by case-insensitive name.
+    val exprs: Seq[org.apache.spark.sql.catalyst.expressions.NamedExpression] =
+      targetCols.map { tgt =>
+        srcByLower.get(key(tgt.name)) match {
+          case Some(src) if src.name == tgt.name => src // correct casing, keep as-is
+          case Some(src) => Alias(src, tgt.name)() // rename to stored casing
+          case None => return None // unmatched column
+        }
+      }
+
+    // No-op if the result is identical to the source (same column order, same names).
+    val unchanged = exprs.zip(sourceCols).forall {
+      case (expr: org.apache.spark.sql.catalyst.expressions.Attribute, src) =>
+        expr.exprId == src.exprId
+      case _ => false
+    }
+    if (unchanged) None else Some(exprs)
+  }
+
+  /**
+   * By-position mode (e.g. {@code INSERT INTO … VALUES …}): zip source and target by position.
+   * For each pair, replicate what {@code ResolveOutputRelation} would have done:
+   *

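<ul>
+ *   <li>If name, type and metadata already match, keep the source attribute as-is.</li>
+ *   <li>Otherwise, wrap the source in {@code Alias(Cast(src, targetType), targetName)} to + * rename the column and coerce the type to the stored schema; the {@code Cast} is only + * inserted when the types differ, and the alias carries the target column's metadata.</li>
+ * </ul>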
+ */ + private def projectByPosition( + sourceCols: Seq[org.apache.spark.sql.catalyst.expressions.Attribute], + targetCols: Seq[org.apache.spark.sql.catalyst.expressions.Attribute]) + : Option[Seq[org.apache.spark.sql.catalyst.expressions.NamedExpression]] = { + + val pairsNeedingChange = sourceCols.zip(targetCols).filter { + case (src, tgt) => + src.name != tgt.name || + src.dataType != tgt.dataType || + src.metadata != tgt.metadata + } + if (pairsNeedingChange.isEmpty) return None + + val exprs = sourceCols.zip(targetCols).map { + case (src, tgt) + if src.name == tgt.name && src.dataType == tgt.dataType && src.metadata == tgt.metadata => + src + case (src, tgt) => + val castExpr = if (src.dataType == tgt.dataType) src + else Cast(src, tgt.dataType, Option(spark.conf.get("spark.sql.session.timeZone"))) + Alias(castExpr, tgt.name)(explicitMetadata = Some(tgt.metadata)) + } + Some(exprs) + } + + private def isOHRelation(rel: DataSourceV2Relation): Boolean = { + rel.catalog match { + case Some(c) => + val key = s"spark.sql.catalog.${c.name()}.catalog-impl" + spark.conf.getOption(key).exists(_.toLowerCase.contains("openhouse")) + case None => + false + } + } +} diff --git a/integrations/spark/spark-3.1/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OpenhouseSparkSessionExtensions.scala b/integrations/spark/spark-3.1/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OpenhouseSparkSessionExtensions.scala index c8d911dc2..e7c29181b 100644 --- a/integrations/spark/spark-3.1/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OpenhouseSparkSessionExtensions.scala +++ b/integrations/spark/spark-3.1/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OpenhouseSparkSessionExtensions.scala @@ -7,6 +7,7 @@ import org.apache.spark.sql.SparkSessionExtensions class OpenhouseSparkSessionExtensions extends (SparkSessionExtensions => Unit) { override def 
apply(extensions: SparkSessionExtensions): Unit = { extensions.injectParser { case (_, parser) => new OpenhouseSparkSqlExtensionsParser(parser) } - extensions.injectPlannerStrategy( spark => OpenhouseDataSourceV2Strategy(spark)) + extensions.injectPlannerStrategy(spark => OpenhouseDataSourceV2Strategy(spark)) + extensions.injectPostHocResolutionRule(spark => new OHWriteSchemaNormalizationRule(spark)) } } diff --git a/integrations/spark/spark-3.5/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CTASNonNullTestSpark3_5.java b/integrations/spark/spark-3.5/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CTASNonNullTestSpark3_5.java index 46b47c9b8..6c7dd0898 100644 --- a/integrations/spark/spark-3.5/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CTASNonNullTestSpark3_5.java +++ b/integrations/spark/spark-3.5/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/CTASNonNullTestSpark3_5.java @@ -24,7 +24,8 @@ public void testCTASPreservesNonNull() throws Exception { // Verify spark catalogs have correct classes configured assertEquals( - "org.apache.iceberg.spark.SparkCatalog", spark.conf().get("spark.sql.catalog.openhouse")); + "com.linkedin.openhouse.spark.OHSparkCatalog", + spark.conf().get("spark.sql.catalog.openhouse")); // Verify id column is preserved in good catalog, not preserved in bad catalog assertFalse(sourceSchema.apply("id").nullable(), "Source table id column should be required"); diff --git a/integrations/spark/spark-3.5/openhouse-spark-runtime/src/main/java/com/linkedin/openhouse/spark/OHSparkCatalog.java b/integrations/spark/spark-3.5/openhouse-spark-runtime/src/main/java/com/linkedin/openhouse/spark/OHSparkCatalog.java new file mode 100644 index 000000000..34ab2d33f --- /dev/null +++ b/integrations/spark/spark-3.5/openhouse-spark-runtime/src/main/java/com/linkedin/openhouse/spark/OHSparkCatalog.java @@ -0,0 +1,60 @@ +package 
com.linkedin.openhouse.spark; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.source.SparkTable; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.Identifier; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; + +/** + * OpenHouse catalog extension for Spark 3.5 / Iceberg 1.5. Overrides {@link + * SparkCatalog#loadTable(Identifier)} to advertise {@link TableCapability#ACCEPT_ANY_SCHEMA} on + * every OpenHouse table. This prevents Spark's {@code ResolveOutputRelation} from throwing at + * analysis time when {@code caseSensitive=true} and the source column casing differs from the + * stored column name. The companion rule {@link + * com.linkedin.openhouse.spark.extensions.OHWriteSchemaNormalizationRule} runs as a post-hoc + * resolution rule and applies the necessary column renaming / casting that {@code + * ResolveOutputRelation} would otherwise have done. + */ +public class OHSparkCatalog extends SparkCatalog { + + @Override + public Table loadTable(Identifier ident) throws NoSuchTableException { + Table original = super.loadTable(ident); + if (original instanceof SparkTable) { + return withAcceptAnySchema((SparkTable) original); + } + return original; + } + + private SparkTable withAcceptAnySchema(SparkTable original) { + // SparkTable carries a branch field (set when loading branch-qualified identifiers like + // "table.branch_feature_a"). We must use the SparkTable(Table, String, boolean) constructor + // when a branch is present so that newWriteBuilder() targets the correct branch. + // Using SparkTable(Table, Long, boolean) with snapshotId=null would silently drop the branch + // and cause all branch writes to land on the main table instead. 
+ String branch = original.branch(); + if (branch != null) { + return new SparkTable(original.table(), branch, false /* refreshEagerly */) { + @Override + public Set capabilities() { + Set caps = new HashSet<>(original.capabilities()); + caps.add(TableCapability.ACCEPT_ANY_SCHEMA); + return Collections.unmodifiableSet(caps); + } + }; + } + return new SparkTable(original.table(), original.snapshotId(), false /* refreshEagerly */) { + @Override + public Set capabilities() { + Set caps = new HashSet<>(original.capabilities()); + caps.add(TableCapability.ACCEPT_ANY_SCHEMA); + return Collections.unmodifiableSet(caps); + } + }; + } +} diff --git a/integrations/spark/spark-3.5/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OHWriteSchemaNormalizationRule.scala b/integrations/spark/spark-3.5/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OHWriteSchemaNormalizationRule.scala new file mode 100644 index 000000000..75f03e072 --- /dev/null +++ b/integrations/spark/spark-3.5/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OHWriteSchemaNormalizationRule.scala @@ -0,0 +1,163 @@ +package com.linkedin.openhouse.spark.extensions + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, V2WriteCommand} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +/** + * Post-hoc resolution rule that replicates the column-name and type normalization that Spark's + * {@code ResolveOutputRelation} would have applied to OpenHouse write commands, compensating for + * the fact that OH tables advertise {@link org.apache.spark.sql.connector.catalog.TableCapability#ACCEPT_ANY_SCHEMA} + * (via {@link OHSparkCatalog}) which causes {@code ResolveOutputRelation} to skip them entirely. + * + *

 Why {@code ACCEPT_ANY_SCHEMA}? Spark's {@code ResolveOutputRelation} throws at analysis time + * when {@code caseSensitive=true} and a client DataFrame column (e.g. {@code "id"}) does not match + * the OH table column name exactly ({@code "ID"}). Advertising {@code ACCEPT_ANY_SCHEMA} prevents + * the throw. This rule then runs as a {@code Post-Hoc Resolution} rule and does the work that + * {@code ResolveOutputRelation} would have done: wrapping the source query in a {@code Project} + * that renames each source column to the stored casing (by-position writes also cast the type). + * + *

The rule handles both write modes: + *

    + *
  • By-name writes ({@code isByName=true}, e.g. {@code df.writeTo().append()}): each + * source column is matched to the target column whose name it equals case-insensitively. + * Writes whose target or source has case-duplicate columns are skipped (ambiguous match).
  • + *
  • By-position writes ({@code isByName=false}, e.g. {@code INSERT INTO … VALUES …}): + * source and target columns are zipped positionally and each source column is renamed (and + * if the types differ, cast) to match the target. This replicates the {@code Alias} + + * {@code Cast} that {@code ResolveOutputRelation} would have inserted.
  • + *
+ * + *

In both modes, if source and target already match in name and type the rule returns the plan + * unchanged. If the column count differs the rule is a no-op (the mismatch is left for Iceberg or + * the OH server to report). + */ +class OHWriteSchemaNormalizationRule(spark: SparkSession) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.transformDown { + case write: V2WriteCommand + if write.table.resolved && write.query.resolved && isOHWrite(write) => + normalizeColumnNames(write).getOrElse(write) + } + } + + private def isOHWrite(write: V2WriteCommand): Boolean = { + write.table match { + case rel: DataSourceV2Relation => isOHRelation(rel) + case _ => false + } + } + + private def normalizeColumnNames(write: V2WriteCommand): Option[V2WriteCommand] = { + val ohRelation = write.table match { + case rel: DataSourceV2Relation => rel + case _ => return None + } + + val targetCols = ohRelation.output + val sourceCols = write.query.output + + // If column counts differ, leave it to Iceberg / the OH server to report the mismatch. + if (sourceCols.size != targetCols.size) return None + + val projections = + if (write.isByName) projectByName(sourceCols, targetCols) + else projectByPosition(sourceCols, targetCols) + + projections match { + case None => None + case Some(exprs) => Some(write.withNewQuery(Project(exprs, write.query))) + } + } + + /** + * By-name mode: replicate what {@code ResolveOutputRelation} does for by-name writes — produce a + * projection in target column order that renames (and if necessary casts) each source + * column to the stored OH casing. This also handles the case where the source DataFrame has + * columns in a different order than the stored schema (e.g. when the source is built from a bean + * whose fields are introspected alphabetically). + * + *

Tables with case-duplicate columns are skipped (the target is ambiguous). + */ + private def projectByName( + sourceCols: Seq[org.apache.spark.sql.catalyst.expressions.Attribute], + targetCols: Seq[org.apache.spark.sql.catalyst.expressions.Attribute]) + : Option[Seq[org.apache.spark.sql.catalyst.expressions.NamedExpression]] = { + + // Case-duplicate target: skip normalization to avoid silently misdirecting the write. + val targetGrouped = targetCols.groupBy(_.name.toLowerCase) + if (targetGrouped.values.exists(_.size > 1)) return None + + // Case-duplicate source: skip to avoid ambiguous lookup. + val srcGrouped = sourceCols.groupBy(_.name.toLowerCase) + if (srcGrouped.values.exists(_.size > 1)) return None + val srcByLower: Map[String, org.apache.spark.sql.catalyst.expressions.Attribute] = + srcGrouped.map { case (lower, attrs) => lower -> attrs.head } + + // Produce expressions in TARGET column order (replicating ResolveOutputRelation). + // For each target column find the matching source column by case-insensitive name. + val exprs: Seq[org.apache.spark.sql.catalyst.expressions.NamedExpression] = + targetCols.map { tgt => + srcByLower.get(tgt.name.toLowerCase) match { + case Some(src) if src.name == tgt.name => src // correct casing, keep as-is + case Some(src) => Alias(src, tgt.name)() // rename to stored casing + case None => return None // unmatched column + } + } + + // No-op if the result is identical to the source (same column order, same names). + val unchanged = exprs.zip(sourceCols).forall { + case (expr: org.apache.spark.sql.catalyst.expressions.Attribute, src) => + expr.exprId == src.exprId + case _ => false + } + if (unchanged) None else Some(exprs) + } + + /** + * By-position mode (e.g. {@code INSERT INTO … VALUES …}): zip source and target by position. + * For each pair, replicate what {@code ResolveOutputRelation} would have done: + *

    + *
  • If name, data type, and metadata already match, keep the source attribute as-is.
  • + *
  • Otherwise, emit {@code Alias(expr, targetName)} carrying the target metadata; {@code expr} + * is {@code Cast(src, targetType)} when the types differ, otherwise {@code src} unchanged.
  • + *
+ */ + private def projectByPosition( + sourceCols: Seq[org.apache.spark.sql.catalyst.expressions.Attribute], + targetCols: Seq[org.apache.spark.sql.catalyst.expressions.Attribute]) + : Option[Seq[org.apache.spark.sql.catalyst.expressions.NamedExpression]] = { + + val pairsNeedingChange = sourceCols.zip(targetCols).filter { + case (src, tgt) => + src.name != tgt.name || + src.dataType != tgt.dataType || + src.metadata != tgt.metadata + } + if (pairsNeedingChange.isEmpty) return None + + val exprs = sourceCols.zip(targetCols).map { + case (src, tgt) + if src.name == tgt.name && src.dataType == tgt.dataType && src.metadata == tgt.metadata => + src + case (src, tgt) => + val castExpr = if (src.dataType == tgt.dataType) src + else Cast(src, tgt.dataType, Option(spark.conf.get("spark.sql.session.timeZone"))) + Alias(castExpr, tgt.name)(explicitMetadata = Some(tgt.metadata)) + } + Some(exprs) + } + + private def isOHRelation(rel: DataSourceV2Relation): Boolean = { + rel.catalog match { + case Some(c) => + val key = s"spark.sql.catalog.${c.name()}.catalog-impl" + spark.conf.getOption(key).exists(_.toLowerCase.contains("openhouse")) + case None => + false + } + } +} diff --git a/integrations/spark/spark-3.5/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OpenhouseSparkSessionExtensions.scala b/integrations/spark/spark-3.5/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OpenhouseSparkSessionExtensions.scala new file mode 100644 index 000000000..e7c29181b --- /dev/null +++ b/integrations/spark/spark-3.5/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/extensions/OpenhouseSparkSessionExtensions.scala @@ -0,0 +1,13 @@ +package com.linkedin.openhouse.spark.extensions + +import com.linkedin.openhouse.spark.sql.catalyst.parser.extensions.OpenhouseSparkSqlExtensionsParser +import com.linkedin.openhouse.spark.sql.execution.datasources.v2.OpenhouseDataSourceV2Strategy +import 
org.apache.spark.sql.SparkSessionExtensions + +class OpenhouseSparkSessionExtensions extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectParser { case (_, parser) => new OpenhouseSparkSqlExtensionsParser(parser) } + extensions.injectPlannerStrategy(spark => OpenhouseDataSourceV2Strategy(spark)) + extensions.injectPostHocResolutionRule(spark => new OHWriteSchemaNormalizationRule(spark)) + } +} diff --git a/tables-test-fixtures/tables-test-fixtures-iceberg-1.2/src/main/java/com/linkedin/openhouse/tablestest/TestSparkSessionUtil.java b/tables-test-fixtures/tables-test-fixtures-iceberg-1.2/src/main/java/com/linkedin/openhouse/tablestest/TestSparkSessionUtil.java index 5404c3248..d78278b8b 100644 --- a/tables-test-fixtures/tables-test-fixtures-iceberg-1.2/src/main/java/com/linkedin/openhouse/tablestest/TestSparkSessionUtil.java +++ b/tables-test-fixtures/tables-test-fixtures-iceberg-1.2/src/main/java/com/linkedin/openhouse/tablestest/TestSparkSessionUtil.java @@ -43,7 +43,7 @@ public static void configureCatalogs( builder .config( String.format("spark.sql.catalog.%s", catalogName), - "org.apache.iceberg.spark.SparkCatalog") + "com.linkedin.openhouse.spark.OHSparkCatalog") .config( String.format("spark.sql.catalog.%s.catalog-impl", catalogName), "com.linkedin.openhouse.spark.OpenHouseCatalog")