Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void beforeEach(ExtensionContext context) throws Exception {
"spark.sql.extensions",
("org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions,"
+ "com.linkedin.openhouse.spark.extensions.OpenhouseSparkSessionExtensions"))
.config("spark.sql.catalog.openhouse", "org.apache.iceberg.spark.SparkCatalog")
.config("spark.sql.catalog.openhouse", "com.linkedin.openhouse.spark.OHSparkCatalog")
.config(
"spark.sql.catalog.openhouse.catalog-impl",
"com.linkedin.openhouse.spark.OpenHouseCatalog")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ public void testCTASPreservesNonNull() throws Exception {

// Verify spark catalogs have correct classes configured
assertEquals(
"org.apache.iceberg.spark.SparkCatalog", spark.conf().get("spark.sql.catalog.openhouse"));
"com.linkedin.openhouse.spark.OHSparkCatalog",
spark.conf().get("spark.sql.catalog.openhouse"));

// Verify id column is preserved in good catalog, not preserved in bad catalog
assertFalse(sourceSchema.apply("id").nullable(), "Source table id column should be required");
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package com.linkedin.openhouse.spark;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import org.apache.iceberg.spark.SparkCatalog;
import org.apache.iceberg.spark.source.SparkTable;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.TableCapability;

/**
* Spark catalog wrapper for OpenHouse that extends {@link SparkCatalog} and annotates every loaded
* table with {@link TableCapability#ACCEPT_ANY_SCHEMA}.
*
* <p>Why {@code ACCEPT_ANY_SCHEMA}? Spark's {@code ResolveOutputRelation} analyzer rule uses the
* session resolver (case-sensitive when {@code spark.sql.caseSensitive=true}). If a client
* DataFrame has column {@code "id"} but the OH table stores {@code "ID"}, the write command never
* resolves — Spark throws "Cannot find data for output column 'ID'" before OH's own server-side
* case-insensitive schema validation runs.
*
* <p>Advertising {@code ACCEPT_ANY_SCHEMA} causes {@code DataSourceV2Relation.skipSchemaResolution}
* to return {@code true}, which makes {@code V2WriteCommand.outputResolved} return {@code true}
* immediately. {@code ResolveOutputRelation} therefore skips OH write commands, allowing {@link
* OHWriteSchemaNormalizationRule} (a post-hoc resolution rule) to insert the necessary
* column-renaming {@code Project} before execution.
*
* <p>For reads and DDL the capability has no effect; it is only consulted during write analysis.
*
* <p>Configuration:
*
* <pre>
* spark.sql.catalog.openhouse=com.linkedin.openhouse.spark.OHSparkCatalog
* spark.sql.catalog.openhouse.catalog-impl=com.linkedin.openhouse.spark.OpenHouseCatalog
* </pre>
*/
public class OHSparkCatalog extends SparkCatalog {

@Override
public SparkTable loadTable(Identifier ident) throws NoSuchTableException {
Comment on lines +36 to +40
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we already have the V2 SparkCatalog interface implemented in the internal fork, or we did and needed to deramp it because of an unrelated failure.
We would need to coordinate the two sets of changes and make them compatible.

SparkTable original = super.loadTable(ident);
return withAcceptAnySchema(original);
}

/**
 * Re-wraps a loaded {@link SparkTable} so that it additionally reports {@link
 * TableCapability#ACCEPT_ANY_SCHEMA}.
 *
 * <p>A new anonymous {@code SparkTable} is built around the same Iceberg table object that backs
 * {@code original}; only {@code capabilities()} is overridden, every other method is inherited
 * from {@code SparkTable}. The wrapper is always constructed with {@code snapshotId=null} and
 * {@code refreshEagerly=false}.
 *
 * <p>NOTE(review): because {@code null} is passed as the snapshot id, a snapshot selected on
 * {@code original} (time-travel load) would not be carried over to the wrapper — confirm that
 * this catalog is only used for plain table loads.
 *
 * @param original the table returned by {@link SparkCatalog#loadTable(Identifier)}
 * @return a table over the same Iceberg table whose capability set also contains {@code
 *     ACCEPT_ANY_SCHEMA}
 */
private SparkTable withAcceptAnySchema(SparkTable original) {
  return new SparkTable(original.table(), null /* snapshotId */, false /* refreshEagerly */) {
    @Override
    public Set<TableCapability> capabilities() {
      // Recomputed on every call, always starting from what the original table reports.
      return plusAcceptAnySchema(original.capabilities());
    }
  };
}

/** Returns an unmodifiable copy of {@code base} with {@code ACCEPT_ANY_SCHEMA} added. */
private static Set<TableCapability> plusAcceptAnySchema(Set<TableCapability> base) {
  Set<TableCapability> widened = new HashSet<>(base);
  widened.add(TableCapability.ACCEPT_ANY_SCHEMA);
  return Collections.unmodifiableSet(widened);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package com.linkedin.openhouse.spark.extensions

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, V2WriteCommand}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

/**
 * Post-hoc resolution rule that normalizes the query side of OpenHouse V2 write commands so that
 * its output columns carry the names (and, for by-position writes, the types) of the target
 * table.
 *
 * <p>Background: OH tables advertise
 * {@link org.apache.spark.sql.connector.catalog.TableCapability#ACCEPT_ANY_SCHEMA} (via
 * {@code OHSparkCatalog}), which causes Spark's {@code ResolveOutputRelation} to skip them. That
 * avoids the analysis-time failure when {@code caseSensitive=true} and a client column (e.g.
 * {@code "id"}) differs only in casing from the stored column ({@code "ID"}), but it also means
 * nobody inserts the usual renaming {@code Project}. This rule inserts it instead.
 *
 * <p>The rule handles both write modes:
 * <ul>
 *   <li><b>By-name writes</b> ({@code isByName=true}, e.g. {@code df.writeTo().append()}): each
 *       target column is matched to the source column whose name is equal ignoring case; the
 *       projection is emitted in target column order and renames where the casing differs. No
 *       cast is inserted in this mode. If either side has columns that collide when compared
 *       case-insensitively, or a target column has no match, the write is left untouched.</li>
 *   <li><b>By-position writes</b> ({@code isByName=false}, e.g. {@code INSERT INTO … VALUES …}):
 *       source and target columns are zipped positionally; a source column whose name, type or
 *       metadata differs is wrapped in {@code Alias} (plus {@code Cast} when the type differs).</li>
 * </ul>
 *
 * <p>In both modes the plan is returned unchanged when nothing needs rewriting, and when the
 * column counts differ (the mismatch is left for Iceberg or the OH server to report).
 *
 * <p>Name comparison always folds case with {@code Locale.ROOT}, so the result does not depend on
 * the JVM default locale (with a Turkish default locale {@code "ID".toLowerCase} is not
 * {@code "id"}).
 *
 * @param spark session whose runtime conf is consulted for the catalog implementation and the
 *              session time zone
 */
class OHWriteSchemaNormalizationRule(spark: SparkSession) extends Rule[LogicalPlan] {

  // Local import keeps the signatures below readable without touching the file-level imports.
  import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression}

  /**
   * Rewrites every resolved V2 write command that targets an OpenHouse relation; all other nodes
   * are left as they are.
   */
  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.transformDown {
      case write: V2WriteCommand
          if write.table.resolved && write.query.resolved && isOHWrite(write) =>
        normalizeColumnNames(write).getOrElse(write)
    }
  }

  /** Locale-independent lower-casing used for every case-insensitive comparison in this rule. */
  private def caseFold(s: String): String = s.toLowerCase(java.util.Locale.ROOT)

  /** True when the write's target is a {@code DataSourceV2Relation} served by an OH catalog. */
  private def isOHWrite(write: V2WriteCommand): Boolean = write.table match {
    case rel: DataSourceV2Relation => isOHRelation(rel)
    case _ => false
  }

  /**
   * Builds the rewritten write command, or {@code None} when the write should be left alone
   * (non-V2 target, differing column counts, nothing to change, or an ambiguous/unmatched
   * by-name mapping).
   */
  private def normalizeColumnNames(write: V2WriteCommand): Option[V2WriteCommand] = {
    write.table match {
      case ohRelation: DataSourceV2Relation =>
        val targetCols = ohRelation.output
        val sourceCols = write.query.output

        if (sourceCols.size != targetCols.size) {
          // Leave it to Iceberg / the OH server to report the mismatch.
          None
        } else {
          val projections =
            if (write.isByName) projectByName(sourceCols, targetCols)
            else projectByPosition(sourceCols, targetCols)
          projections.map(exprs => write.withNewQuery(Project(exprs, write.query)))
        }
      case _ =>
        None
    }
  }

  /**
   * By-name mode: produce a projection in <em>target column order</em> that renames each source
   * column to the stored OH casing. This also reorders a source whose columns come in a different
   * order than the stored schema. Types are not touched here.
   *
   * @return {@code None} if either side has case-insensitive duplicate names, if some target
   *         column has no source counterpart, or if the projection would be identical to the
   *         source output; otherwise the projection list
   */
  private def projectByName(
      sourceCols: Seq[Attribute],
      targetCols: Seq[Attribute]): Option[Seq[NamedExpression]] = {

    def hasCaseDuplicates(cols: Seq[Attribute]): Boolean =
      cols.map(c => caseFold(c.name)).distinct.size != cols.size

    // Ambiguous on either side: skip normalization rather than silently misdirect the write.
    if (hasCaseDuplicates(targetCols) || hasCaseDuplicates(sourceCols)) {
      None
    } else {
      val srcByFolded: Map[String, Attribute] =
        sourceCols.map(c => caseFold(c.name) -> c).toMap

      // One entry per target column, in target order; None marks an unmatched target column.
      val matched: Seq[Option[NamedExpression]] = targetCols.map { tgt =>
        srcByFolded.get(caseFold(tgt.name)).map { src =>
          val expr: NamedExpression =
            if (src.name == tgt.name) src // correct casing, keep as-is
            else Alias(src, tgt.name)() // rename to stored casing
          expr
        }
      }

      if (matched.exists(_.isEmpty)) {
        None
      } else {
        val exprs: Seq[NamedExpression] = matched.flatten
        // No-op if the result is the source output itself (same attributes, same order).
        val unchanged = exprs.zip(sourceCols).forall {
          case (expr: Attribute, src) => expr.exprId == src.exprId
          case _ => false
        }
        if (unchanged) None else Some(exprs)
      }
    }
  }

  /**
   * By-position mode (e.g. {@code INSERT INTO … VALUES …}): zip source and target by position.
   * <ul>
   *   <li>If name, type and metadata already match, the source attribute is kept as-is.</li>
   *   <li>Otherwise the source is wrapped in {@code Alias(..., targetName)} carrying the target's
   *       metadata, with a {@code Cast} to the target type when the types differ.</li>
   * </ul>
   *
   * @return {@code None} when every pair already matches; otherwise the projection list
   */
  private def projectByPosition(
      sourceCols: Seq[Attribute],
      targetCols: Seq[Attribute]): Option[Seq[NamedExpression]] = {

    def alreadyMatches(src: Attribute, tgt: Attribute): Boolean =
      src.name == tgt.name && src.dataType == tgt.dataType && src.metadata == tgt.metadata

    val pairs = sourceCols.zip(targetCols)

    if (pairs.forall { case (src, tgt) => alreadyMatches(src, tgt) }) {
      None
    } else {
      // Looked up at most once, and only if some column actually needs a cast.
      lazy val timeZoneId = Option(spark.conf.get("spark.sql.session.timeZone"))

      val exprs: Seq[NamedExpression] = pairs.map { case (src, tgt) =>
        val expr: NamedExpression =
          if (alreadyMatches(src, tgt)) {
            src
          } else {
            val value =
              if (src.dataType == tgt.dataType) src
              else Cast(src, tgt.dataType, timeZoneId)
            Alias(value, tgt.name)(explicitMetadata = Some(tgt.metadata))
          }
        expr
      }
      Some(exprs)
    }
  }

  /**
   * A relation counts as OpenHouse when its catalog's {@code catalog-impl} setting
   * ({@code spark.sql.catalog.<name>.catalog-impl}) contains "openhouse", ignoring case.
   * Relations without a catalog, or whose catalog has no such setting, are not OH relations.
   */
  private def isOHRelation(rel: DataSourceV2Relation): Boolean = {
    rel.catalog.exists { c =>
      val key = s"spark.sql.catalog.${c.name()}.catalog-impl"
      spark.conf.getOption(key).exists(impl => caseFold(impl).contains("openhouse"))
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.apache.spark.sql.SparkSessionExtensions
/**
 * Spark session extension entry point for OpenHouse. Registers, in this order, the OpenHouse SQL
 * parser wrapper, the OpenHouse DataSource V2 planner strategy, and the post-hoc resolution rule
 * that normalizes write schemas.
 */
class OpenhouseSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Wrap whatever parser Spark hands us so OpenHouse-specific SQL is parsed first.
    extensions.injectParser { case (_, parser) => new OpenhouseSparkSqlExtensionsParser(parser) }
    // Registered exactly once; a second registration would inject the same strategy twice.
    extensions.injectPlannerStrategy(spark => OpenhouseDataSourceV2Strategy(spark))
    extensions.injectPostHocResolutionRule(spark => new OHWriteSchemaNormalizationRule(spark))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ public void testCTASPreservesNonNull() throws Exception {

// Verify spark catalogs have correct classes configured
assertEquals(
"org.apache.iceberg.spark.SparkCatalog", spark.conf().get("spark.sql.catalog.openhouse"));
"com.linkedin.openhouse.spark.OHSparkCatalog",
spark.conf().get("spark.sql.catalog.openhouse"));

// Verify id column is preserved in good catalog, not preserved in bad catalog
assertFalse(sourceSchema.apply("id").nullable(), "Source table id column should be required");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.linkedin.openhouse.spark;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import org.apache.iceberg.spark.SparkCatalog;
import org.apache.iceberg.spark.source.SparkTable;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCapability;

/**
 * OpenHouse catalog extension for Spark 3.5 / Iceberg 1.5.
 *
 * <p>Every {@link SparkTable} returned by {@link SparkCatalog#loadTable(Identifier)} is re-wrapped
 * so that it also advertises {@link TableCapability#ACCEPT_ANY_SCHEMA}. This prevents Spark's
 * {@code ResolveOutputRelation} from throwing at analysis time when {@code caseSensitive=true} and
 * the source column casing differs from the stored column name. The companion rule {@link
 * com.linkedin.openhouse.spark.extensions.OHWriteSchemaNormalizationRule} runs as a post-hoc
 * resolution rule and applies the column renaming / casting that {@code ResolveOutputRelation}
 * would otherwise have done.
 */
public class OHSparkCatalog extends SparkCatalog {

  /**
   * Loads the table through {@link SparkCatalog} and, when the result is a {@link SparkTable},
   * returns a wrapper that adds {@code ACCEPT_ANY_SCHEMA}; any other table type is passed through
   * untouched.
   *
   * @param ident identifier of the table to load
   * @return the loaded table, wrapped if it is a {@code SparkTable}
   * @throws NoSuchTableException propagated from the parent catalog
   */
  @Override
  public Table loadTable(Identifier ident) throws NoSuchTableException {
    final Table loaded = super.loadTable(ident);
    return loaded instanceof SparkTable ? withAcceptAnySchema((SparkTable) loaded) : loaded;
  }

  /**
   * Builds a new {@link SparkTable} over the same Iceberg table as {@code original}, overriding
   * only {@code capabilities()}.
   *
   * <p>SparkTable carries a branch field (set when loading branch-qualified identifiers like
   * "table.branch_feature_a"). When a branch is present the {@code SparkTable(Table, String,
   * boolean)} constructor must be used so that newWriteBuilder() targets that branch; going
   * through {@code SparkTable(Table, Long, boolean)} would silently drop the branch and send
   * branch writes to the main table. Without a branch, the original's snapshot id is forwarded.
   */
  private SparkTable withAcceptAnySchema(SparkTable original) {
    final String branch = original.branch();
    if (branch == null) {
      return new SparkTable(original.table(), original.snapshotId(), false /* refreshEagerly */) {
        @Override
        public Set<TableCapability> capabilities() {
          return widenedCapabilities(original);
        }
      };
    }
    return new SparkTable(original.table(), branch, false /* refreshEagerly */) {
      @Override
      public Set<TableCapability> capabilities() {
        return widenedCapabilities(original);
      }
    };
  }

  /**
   * Returns an unmodifiable set holding everything {@code delegate} currently reports plus {@code
   * ACCEPT_ANY_SCHEMA}. Evaluated on each call rather than cached.
   */
  private static Set<TableCapability> widenedCapabilities(SparkTable delegate) {
    Set<TableCapability> widened = new HashSet<>(delegate.capabilities());
    widened.add(TableCapability.ACCEPT_ANY_SCHEMA);
    return Collections.unmodifiableSet(widened);
  }
}
Loading