diff --git a/src/dlt_iceberg/destination_client.py b/src/dlt_iceberg/destination_client.py index 8f8919d..ba861f2 100644 --- a/src/dlt_iceberg/destination_client.py +++ b/src/dlt_iceberg/destination_client.py @@ -23,8 +23,11 @@ RunnableLoadJob, DestinationClientConfiguration, SupportsOpenTables, + WithStateSync, + StorageSchemaInfo, + StateInfo, ) -from dlt.common.schema.typing import TTableFormat +from dlt.common.schema.typing import TTableFormat, TSchemaTables from dlt.destinations.sql_client import WithSqlClient, SqlClientBase from dlt.common.schema import Schema, TTableSchema from dlt.common.schema.typing import TTableSchema as PreparedTableSchema @@ -159,12 +162,13 @@ def run(self) -> None: raise -class IcebergRestClient(JobClientBase, WithSqlClient, SupportsOpenTables): +class IcebergRestClient(JobClientBase, WithSqlClient, SupportsOpenTables, WithStateSync): """ Class-based Iceberg REST destination with atomic multi-file commits. Accumulates files during load and commits them atomically in complete_load(). Implements WithSqlClient and SupportsOpenTables for pipeline.dataset() support. + Implements WithStateSync for schema restoration from destination. 
""" def __init__( @@ -250,6 +254,375 @@ def is_open_table(self, table_format: TTableFormat, table_name: str) -> bool: # All tables in this destination are Iceberg tables return table_format == "iceberg" + # ---- WithStateSync interface ---- + + def _get_newest_schema(self, schema_name: str) -> Optional[StorageSchemaInfo]: + """Get newest schema version by schema name using predicate pushdown.""" + try: + catalog = self._get_catalog() + identifier = f"{self.config.namespace}.{self.schema.version_table_name}" + iceberg_table = catalog.load_table(identifier) + + # Use row_filter for predicate pushdown - only scan matching rows + table = iceberg_table.scan( + row_filter=f"schema_name = '{schema_name}'" + ).to_arrow() + + if len(table) == 0: + return None + + # Find row with max version + versions = table.column("version").to_pylist() + row_idx = versions.index(max(versions)) + + return StorageSchemaInfo( + version_hash=table.column("version_hash")[row_idx].as_py(), + schema_name=table.column("schema_name")[row_idx].as_py(), + version=table.column("version")[row_idx].as_py(), + engine_version=table.column("engine_version")[row_idx].as_py(), + inserted_at=table.column("inserted_at")[row_idx].as_py(), + schema=table.column("schema")[row_idx].as_py(), + ) + except NoSuchTableError: + return None + except Exception as e: + logger.warning(f"Failed to get schema for {schema_name}: {e}") + return None + + def _get_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: + """Get schema by exact version hash using predicate pushdown.""" + try: + catalog = self._get_catalog() + identifier = f"{self.config.namespace}.{self.schema.version_table_name}" + iceberg_table = catalog.load_table(identifier) + + # Use row_filter for predicate pushdown + table = iceberg_table.scan( + row_filter=f"version_hash = '{version_hash}'" + ).to_arrow() + + if len(table) == 0: + return None + + return StorageSchemaInfo( + version_hash=table.column("version_hash")[0].as_py(), + 
schema_name=table.column("schema_name")[0].as_py(), + version=table.column("version")[0].as_py(), + engine_version=table.column("engine_version")[0].as_py(), + inserted_at=table.column("inserted_at")[0].as_py(), + schema=table.column("schema")[0].as_py(), + ) + except NoSuchTableError: + return None + except Exception as e: + logger.warning(f"Failed to get schema by hash {version_hash}: {e}") + return None + + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: + """Retrieves newest schema from the _dlt_version table. + + Falls back to deriving schema from existing Iceberg tables if no stored + schema exists. This handles scenarios where _dlt_version is empty/corrupted + but Iceberg tables already exist with columns that should be preserved. + """ + # First try to get from _dlt_version + stored = self._get_newest_schema(schema_name or self.schema.name) + if stored: + return stored + + # Fallback: derive schema from existing Iceberg tables + return self._derive_schema_from_iceberg_tables(schema_name or self.schema.name) + + def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: + """Retrieves stored schema by its version hash.""" + return self._get_schema_by_hash(version_hash) + + def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: + """Loads pipeline state from the _dlt_pipeline_state table using predicate pushdown.""" + try: + catalog = self._get_catalog() + identifier = f"{self.config.namespace}.{self.schema.state_table_name}" + iceberg_table = catalog.load_table(identifier) + + # Use row_filter for predicate pushdown + table = iceberg_table.scan( + row_filter=f"pipeline_name = '{pipeline_name}'" + ).to_arrow() + + if len(table) == 0: + return None + + # Find row with max created_at + timestamps = table.column("created_at").to_pylist() + row_idx = timestamps.index(max(timestamps)) + + version_hash = None + if "version_hash" in table.column_names: + version_hash = 
table.column("version_hash")[row_idx].as_py() + + dlt_load_id = None + if "_dlt_load_id" in table.column_names: + dlt_load_id = table.column("_dlt_load_id")[row_idx].as_py() + + return StateInfo( + version=table.column("version")[row_idx].as_py(), + engine_version=table.column("engine_version")[row_idx].as_py(), + pipeline_name=table.column("pipeline_name")[row_idx].as_py(), + state=table.column("state")[row_idx].as_py(), + created_at=table.column("created_at")[row_idx].as_py(), + version_hash=version_hash, + _dlt_load_id=dlt_load_id, + ) + except NoSuchTableError: + return None + except Exception as e: + logger.warning(f"Failed to get state for pipeline {pipeline_name}: {e}") + return None + + def _derive_schema_from_iceberg_tables(self, schema_name: str) -> Optional[StorageSchemaInfo]: + """Derive a dlt schema from existing Iceberg table schemas in the catalog. + + This is a fallback when _dlt_version has no stored schema but Iceberg + tables already exist. It reads table metadata from the catalog and + constructs a dlt schema that includes all existing columns. 
+ + Args: + schema_name: Name of the schema to derive + + Returns: + StorageSchemaInfo with derived schema, or None if no tables exist + """ + try: + catalog = self._get_catalog() + + try: + tables = catalog.list_tables(self.config.namespace) + except NoSuchNamespaceError: + return None + + if not tables: + return None + + # Build a dlt schema from Iceberg table metadata + derived_tables = {} + for table_id in tables: + table_name = table_id[1] + if table_name.startswith('_dlt_'): + continue # Skip dlt metadata tables + + try: + iceberg_table = catalog.load_table(f"{self.config.namespace}.{table_name}") + iceberg_schema = iceberg_table.schema() + + # Convert Iceberg schema to dlt table schema + columns = {} + for field in iceberg_schema.fields: + columns[field.name] = { + "name": field.name, + "data_type": self._iceberg_type_to_dlt_type(field.field_type), + "nullable": not field.required, + } + + derived_tables[table_name] = { + "name": table_name, + "columns": columns, + } + logger.info( + f"Derived schema for table {table_name} with " + f"{len(columns)} columns from Iceberg metadata" + ) + except Exception as e: + logger.warning(f"Failed to derive schema for table {table_name}: {e}") + continue + + if not derived_tables: + return None + + # Create a StorageSchemaInfo from derived schema + import json + from dlt.common.pendulum import pendulum + from dlt.common.schema import Schema as DltSchema + + # Start with a fresh dlt schema to get proper structure and system tables + base_schema = DltSchema(schema_name) + schema_dict = base_schema.to_dict() + + # Merge derived tables into the schema + for table_name, table_def in derived_tables.items(): + schema_dict["tables"][table_name] = table_def + + # Update version hash to indicate it was derived + schema_dict["version_hash"] = "derived_from_iceberg" + + logger.info( + f"Derived schema '{schema_name}' from Iceberg catalog with " + f"{len(derived_tables)} tables" + ) + + return StorageSchemaInfo( + 
version_hash="derived_from_iceberg", + schema_name=schema_name, + version=1, + engine_version=self.schema.ENGINE_VERSION, + inserted_at=pendulum.now("UTC"), + schema=json.dumps(schema_dict), + ) + except Exception as e: + logger.warning(f"Failed to derive schema from Iceberg tables: {e}") + return None + + def _iceberg_type_to_dlt_type(self, iceberg_type) -> str: + """Convert PyIceberg type to dlt data type string. + + Args: + iceberg_type: PyIceberg type object + + Returns: + dlt data type string + """ + from pyiceberg.types import ( + BooleanType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + BinaryType, + DateType, + TimeType, + TimestampType, + TimestamptzType, + DecimalType, + ListType, + MapType, + StructType, + ) + + type_mapping = { + BooleanType: "bool", + IntegerType: "bigint", # dlt uses bigint for integers + LongType: "bigint", + FloatType: "double", + DoubleType: "double", + StringType: "text", + BinaryType: "binary", + DateType: "date", + TimeType: "time", + TimestampType: "timestamp", + TimestamptzType: "timestamp", + } + + for iceberg_cls, dlt_type in type_mapping.items(): + if isinstance(iceberg_type, iceberg_cls): + return dlt_type + + if isinstance(iceberg_type, DecimalType): + return f"decimal({iceberg_type.precision},{iceberg_type.scale})" + + if isinstance(iceberg_type, ListType): + return "complex" # dlt complex type for nested structures + + if isinstance(iceberg_type, (MapType, StructType)): + return "complex" + + return "text" # Default fallback + + def update_stored_schema( + self, + only_tables: Iterable[str] = None, + expected_update: TSchemaTables = None, + ) -> Optional[TSchemaTables]: + """ + Updates storage to the current schema. + + Writes schema to _dlt_version table if it doesn't already exist. 
+ """ + applied_update = super().update_stored_schema(only_tables, expected_update) + + # Check if schema with this hash already exists + current_hash = self.schema.stored_version_hash + if self._get_schema_by_hash(current_hash): + return applied_update + + # Write schema to _dlt_version table + try: + self._write_schema_to_storage() + except Exception as e: + logger.warning(f"Failed to write schema to storage: {e}") + + return applied_update + + def _write_schema_to_storage(self) -> None: + """Write current schema to _dlt_version Iceberg table.""" + from dlt.common.pendulum import pendulum + import json + + catalog = self._get_catalog() + version_table_name = self.schema.version_table_name + identifier = f"{self.config.namespace}.{version_table_name}" + + # Schema data to write + # Use naive datetime (no timezone) to match Iceberg TimestampType + inserted_at = pendulum.now("UTC").naive() + schema_data = { + "version_hash": self.schema.stored_version_hash, + "schema_name": self.schema.name, + "version": self.schema.version, + "engine_version": self.schema.ENGINE_VERSION, + "inserted_at": inserted_at, + "schema": json.dumps(self.schema.to_dict()), + } + + # Create Arrow table with non-nullable schema to match Iceberg required fields + # Use timestamp without timezone to match Iceberg TimestampType() + arrow_schema = pa.schema([ + pa.field("version_hash", pa.string(), nullable=False), + pa.field("schema_name", pa.string(), nullable=False), + pa.field("version", pa.int64(), nullable=False), + pa.field("engine_version", pa.int64(), nullable=False), + pa.field("inserted_at", pa.timestamp("us"), nullable=False), + pa.field("schema", pa.string(), nullable=False), + ]) + arrow_table = pa.table({ + "version_hash": [schema_data["version_hash"]], + "schema_name": [schema_data["schema_name"]], + "version": [schema_data["version"]], + "engine_version": [schema_data["engine_version"]], + "inserted_at": [schema_data["inserted_at"]], + "schema": [schema_data["schema"]], + }, 
"""
End-to-end tests for state synchronization functionality.

Tests that verify schema restoration from destination when running
pipelines from different execution contexts.
"""

import pytest
import tempfile
import shutil
import uuid

import dlt
from pyiceberg.catalog import load_catalog


def _records(start, stop, include_d):
    """Yield test rows with columns [a, b, c] and, optionally, 'd'."""
    for i in range(start, stop):
        row = {"a": i, "b": i * 10, "c": f"value_{i}"}
        if include_d:
            row["d"] = f"optional_{i}"
        yield row


def _build_pipeline(pipeline_name, catalog_path, warehouse_path):
    """Create a dlt pipeline targeting the local sqlite-backed Iceberg catalog."""
    from dlt_iceberg import iceberg_rest

    return dlt.pipeline(
        pipeline_name=pipeline_name,
        destination=iceberg_rest(
            catalog_uri=f"sqlite:///{catalog_path}",
            warehouse=f"file://{warehouse_path}",
            namespace="test",
        ),
        dataset_name="test_dataset",
    )


def _open_catalog(catalog_path, warehouse_path):
    """Open the sqlite-backed Iceberg catalog used by the tests."""
    return load_catalog(
        "dlt_catalog",
        type="sql",
        uri=f"sqlite:///{catalog_path}",
        warehouse=f"file://{warehouse_path}",
    )


def _assert_split_rows(table):
    """Assert 10 rows total: first 5 with 'd' populated, last 5 with 'd' NULL."""
    result = table.scan().to_arrow()
    assert len(result) == 10, f"Should have 10 total rows, got {len(result)}"

    df = result.to_pandas()
    old_rows = df[df["a"] <= 5]
    new_rows = df[df["a"] > 5]

    assert len(old_rows) == 5, "Should have 5 old rows"
    assert len(new_rows) == 5, "Should have 5 new rows"

    # Old rows should have values for 'd'; new rows should have NULL.
    assert not old_rows["d"].isna().any(), "Old rows should have values for 'd'"
    assert new_rows["d"].isna().all(), "New rows should have NULL for 'd'"


def test_fresh_pipeline_restores_schema_from_destination():
    """
    Test that fresh pipeline instances can restore schemas from the destination.

    This requires IcebergRestClient to implement WithStateSync. With it, a
    fresh pipeline can restore the schema from the destination, allowing it
    to handle incoming data that is missing columns.

    Scenario (simulating different execution contexts):
    - Pipeline 1 (context A): Creates table with columns [a, b, c, d]
    - Pipeline 2 (context B): Fresh instance, loads data with only [a, b, c]

    Expected behavior: Pipeline 2 restores schema from destination,
    recognizes column 'd' exists, and writes NULL for missing values.
    """
    temp_dir = tempfile.mkdtemp()
    warehouse_path = f"{temp_dir}/warehouse"
    catalog_path = f"{temp_dir}/catalog.db"

    try:
        pipeline_name = f"test_state_sync_{uuid.uuid4().hex[:8]}"

        # PIPELINE 1 (Context A): create table with columns a, b, c, d.
        @dlt.resource(name="test_table", write_disposition="append")
        def generate_data_v1():
            yield from _records(1, 6, include_d=True)

        pipeline1 = _build_pipeline(pipeline_name, catalog_path, warehouse_path)
        load_info1 = pipeline1.run(generate_data_v1())
        assert not load_info1.has_failed_jobs, "Pipeline 1 load should succeed"

        # Verify the table was created with all columns.
        catalog = _open_catalog(catalog_path, warehouse_path)
        table = catalog.load_table("test.test_table")
        table_columns = [f.name for f in table.schema().fields]
        assert "d" in table_columns, "Column 'd' should exist in table"

        # Drop local state to simulate running from a different context.
        pipeline1.drop()

        # PIPELINE 2 (Context B): fresh instance, data missing column 'd'.
        @dlt.resource(name="test_table", write_disposition="append")
        def generate_data_v2():
            yield from _records(6, 11, include_d=False)

        pipeline2 = _build_pipeline(pipeline_name, catalog_path, warehouse_path)

        # With WithStateSync implemented, Pipeline 2 restores the schema from
        # the destination and succeeds with 'd' as NULL for new rows.
        load_info2 = pipeline2.run(generate_data_v2())
        assert not load_info2.has_failed_jobs, (
            "Pipeline 2 should succeed after restoring schema from destination"
        )

        _assert_split_rows(catalog.load_table("test.test_table"))
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)


def test_derive_schema_from_iceberg_when_dlt_version_empty():
    """
    Test that schema is derived from Iceberg tables when _dlt_version is empty.

    When _dlt_version has no stored schema but the Iceberg table already
    exists, the destination should derive the schema from existing Iceberg
    table metadata to avoid treating existing columns as "dropped".

    Scenario:
    1. Pipeline 1 creates table with columns [a, b, c, d]
    2. _dlt_version table is deleted (simulating corrupted/empty state)
    3. Pipeline 2 runs with data containing only [a, b, c] (missing 'd')
    4. Without schema derivation: fails with "columns dropped: d"
    5. With schema derivation: succeeds, 'd' is NULL for new rows
    """
    temp_dir = tempfile.mkdtemp()
    warehouse_path = f"{temp_dir}/warehouse"
    catalog_path = f"{temp_dir}/catalog.db"

    try:
        pipeline_name = f"test_derive_schema_{uuid.uuid4().hex[:8]}"

        # PIPELINE 1: create table with columns a, b, c, d.
        @dlt.resource(name="test_table", write_disposition="append")
        def generate_data_v1():
            yield from _records(1, 6, include_d=True)

        pipeline1 = _build_pipeline(pipeline_name, catalog_path, warehouse_path)
        load_info1 = pipeline1.run(generate_data_v1())
        assert not load_info1.has_failed_jobs, "Pipeline 1 load should succeed"

        catalog = _open_catalog(catalog_path, warehouse_path)
        table = catalog.load_table("test.test_table")
        table_columns = [f.name for f in table.schema().fields]
        assert "d" in table_columns, "Column 'd' should exist in table"

        # Verify _dlt_version exists and has data before deleting it.
        version_table = catalog.load_table("test._dlt_version")
        version_data = version_table.scan().to_arrow()
        assert len(version_data) > 0, "_dlt_version should have data"

        # Delete _dlt_version to simulate corrupted/empty metadata state.
        catalog.drop_table("test._dlt_version")

        # Drop local state to simulate a different execution context.
        pipeline1.drop()

        # PIPELINE 2: fresh instance, data missing column 'd'.
        @dlt.resource(name="test_table", write_disposition="append")
        def generate_data_v2():
            yield from _records(6, 11, include_d=False)

        pipeline2 = _build_pipeline(pipeline_name, catalog_path, warehouse_path)

        # Should succeed by deriving the schema from the existing Iceberg table.
        load_info2 = pipeline2.run(generate_data_v2())
        assert not load_info2.has_failed_jobs, (
            "Pipeline 2 should succeed by deriving schema from Iceberg table"
        )

        _assert_split_rows(catalog.load_table("test.test_table"))
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
"""
Unit tests for WithStateSync implementation.

Tests the schema and state storage/retrieval methods directly.
"""

import pytest
import tempfile
import shutil
import json
from datetime import datetime

import pyarrow as pa
from pyiceberg.catalog import load_catalog
from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.types import NestedField, StringType, LongType, TimestampType

import dlt
from dlt_iceberg import iceberg_rest


def _version_table_schema():
    """Iceberg schema for the _dlt_version metadata table (all fields required)."""
    return IcebergSchema(
        NestedField(1, "version_hash", StringType(), required=True),
        NestedField(2, "schema_name", StringType(), required=True),
        NestedField(3, "version", LongType(), required=True),
        NestedField(4, "engine_version", LongType(), required=True),
        NestedField(5, "inserted_at", TimestampType(), required=True),
        NestedField(6, "schema", StringType(), required=True),
    )


# Arrow schema matching _version_table_schema(): non-nullable fields,
# microsecond timestamps without timezone.
_VERSION_ARROW_SCHEMA = pa.schema([
    pa.field("version_hash", pa.string(), nullable=False),
    pa.field("schema_name", pa.string(), nullable=False),
    pa.field("version", pa.int64(), nullable=False),
    pa.field("engine_version", pa.int64(), nullable=False),
    pa.field("inserted_at", pa.timestamp("us"), nullable=False),
    pa.field("schema", pa.string(), nullable=False),
])


def _create_version_table(catalog, rows):
    """Create test._dlt_version and append *rows* (dict of column -> list of values)."""
    version_table = catalog.create_table(
        identifier="test._dlt_version",
        schema=_version_table_schema(),
    )
    version_table.append(pa.table(rows, schema=_VERSION_ARROW_SCHEMA))
    return version_table


class TestWithStateSyncMethods:
    """Unit tests for WithStateSync interface methods."""

    @pytest.fixture
    def catalog_setup(self):
        """Create a temporary sqlite-backed catalog for testing."""
        temp_dir = tempfile.mkdtemp()
        warehouse_path = f"{temp_dir}/warehouse"
        catalog_path = f"{temp_dir}/catalog.db"

        # Use "dlt_catalog" name to match what IcebergRestClient uses internally.
        catalog = load_catalog(
            "dlt_catalog",
            type="sql",
            uri=f"sqlite:///{catalog_path}",
            warehouse=f"file://{warehouse_path}",
        )
        catalog.create_namespace("test")

        yield {
            "temp_dir": temp_dir,
            "warehouse_path": warehouse_path,
            "catalog_path": catalog_path,
            "catalog": catalog,
        }

        shutil.rmtree(temp_dir, ignore_errors=True)

    @pytest.fixture
    def client(self, catalog_setup):
        """Create an IcebergRestClient bound to the temporary catalog."""
        from dlt_iceberg.destination_client import (
            IcebergRestClient,
            IcebergRestConfiguration,
            iceberg_rest_class_based,
        )
        from dlt.common.schema import Schema

        config = IcebergRestConfiguration(
            catalog_uri=f"sqlite:///{catalog_setup['catalog_path']}",
            warehouse=f"file://{catalog_setup['warehouse_path']}",
            namespace="test",
        )

        schema = Schema("test_schema")
        # Get capabilities from the destination class.
        dest = iceberg_rest_class_based()
        capabilities = dest._raw_capabilities()

        return IcebergRestClient(schema, config, capabilities)

    def test_get_stored_schema_returns_none_when_table_missing(self, client):
        """get_stored_schema returns None when _dlt_version doesn't exist."""
        result = client.get_stored_schema()
        assert result is None, "Should return None when _dlt_version table doesn't exist"

    def test_get_stored_schema_by_hash_returns_none_when_table_missing(self, client):
        """get_stored_schema_by_hash returns None when _dlt_version doesn't exist."""
        result = client.get_stored_schema_by_hash("nonexistent_hash")
        assert result is None, "Should return None when _dlt_version table doesn't exist"

    def test_get_stored_state_returns_none_when_table_missing(self, client):
        """get_stored_state returns None when _dlt_pipeline_state doesn't exist."""
        result = client.get_stored_state("nonexistent_pipeline")
        assert result is None, "Should return None when _dlt_pipeline_state table doesn't exist"

    def test_get_stored_schema_retrieves_written_schema(self, catalog_setup, client):
        """get_stored_schema retrieves a schema that was written."""
        test_schema_dict = {"name": "test_schema", "tables": {}}
        _create_version_table(catalog_setup["catalog"], {
            "version_hash": ["hash123"],
            "schema_name": ["test_schema"],
            "version": [1],
            "engine_version": [1],
            "inserted_at": [datetime.now()],
            "schema": [json.dumps(test_schema_dict)],
        })

        result = client.get_stored_schema("test_schema")
        assert result is not None, "Should retrieve stored schema"
        assert result.version_hash == "hash123", "Should have correct version_hash"
        assert result.schema_name == "test_schema", "Should have correct schema_name"
        assert result.version == 1, "Should have correct version"

    def test_get_stored_schema_returns_newest_version(self, catalog_setup, client):
        """get_stored_schema returns the newest version when multiple exist."""
        _create_version_table(catalog_setup["catalog"], {
            "version_hash": ["hash_v1", "hash_v2", "hash_v3"],
            "schema_name": ["test_schema", "test_schema", "test_schema"],
            "version": [1, 2, 3],
            "engine_version": [1, 1, 1],
            "inserted_at": [datetime.now(), datetime.now(), datetime.now()],
            "schema": ['{"v": 1}', '{"v": 2}', '{"v": 3}'],
        })

        # Should return the newest (version 3).
        result = client.get_stored_schema("test_schema")
        assert result is not None, "Should retrieve stored schema"
        assert result.version == 3, f"Should return newest version (3), got {result.version}"
        assert result.version_hash == "hash_v3", "Should return hash for version 3"

    def test_get_stored_schema_by_hash_retrieves_exact_match(self, catalog_setup, client):
        """get_stored_schema_by_hash retrieves an exact hash match."""
        _create_version_table(catalog_setup["catalog"], {
            "version_hash": ["hash_v1", "hash_v2"],
            "schema_name": ["test_schema", "test_schema"],
            "version": [1, 2],
            "engine_version": [1, 1],
            "inserted_at": [datetime.now(), datetime.now()],
            "schema": ['{"v": 1}', '{"v": 2}'],
        })

        # Get by specific hash.
        result = client.get_stored_schema_by_hash("hash_v1")
        assert result is not None, "Should retrieve schema by hash"
        assert result.version == 1, "Should return version 1 for hash_v1"

        # Non-existent hash.
        result_none = client.get_stored_schema_by_hash("nonexistent")
        assert result_none is None, "Should return None for non-existent hash"

    def test_get_stored_state_retrieves_newest_state(self, catalog_setup, client):
        """get_stored_state retrieves the newest state for a pipeline."""
        catalog = catalog_setup["catalog"]

        # Create _dlt_pipeline_state table (its schema differs from _dlt_version:
        # version_hash and _dlt_load_id are optional).
        iceberg_schema = IcebergSchema(
            NestedField(1, "version", LongType(), required=True),
            NestedField(2, "engine_version", LongType(), required=True),
            NestedField(3, "pipeline_name", StringType(), required=True),
            NestedField(4, "state", StringType(), required=True),
            NestedField(5, "created_at", TimestampType(), required=True),
            NestedField(6, "version_hash", StringType(), required=False),
            NestedField(7, "_dlt_load_id", StringType(), required=False),
        )
        state_table = catalog.create_table(
            identifier="test._dlt_pipeline_state",
            schema=iceberg_schema,
        )

        arrow_schema = pa.schema([
            pa.field("version", pa.int64(), nullable=False),
            pa.field("engine_version", pa.int64(), nullable=False),
            pa.field("pipeline_name", pa.string(), nullable=False),
            pa.field("state", pa.string(), nullable=False),
            pa.field("created_at", pa.timestamp("us"), nullable=False),
            pa.field("version_hash", pa.string(), nullable=True),
            pa.field("_dlt_load_id", pa.string(), nullable=True),
        ])
        state_table.append(pa.table({
            "version": [1, 2],
            "engine_version": [1, 1],
            "pipeline_name": ["my_pipeline", "my_pipeline"],
            "state": ['{"old": true}', '{"new": true}'],
            "created_at": [
                datetime(2024, 1, 1, 10, 0, 0),
                datetime(2024, 1, 2, 10, 0, 0),  # Newer
            ],
            "version_hash": ["hash1", "hash2"],
            "_dlt_load_id": ["load1", "load2"],
        }, schema=arrow_schema))

        # Should return the newest state by created_at.
        result = client.get_stored_state("my_pipeline")
        assert result is not None, "Should retrieve stored state"
        assert result.version == 2, f"Should return newest version, got {result.version}"
        assert '{"new": true}' in result.state, "Should return newest state data"

        # Non-existent pipeline.
        result_none = client.get_stored_state("other_pipeline")
        assert result_none is None, "Should return None for non-existent pipeline"

    def test_derive_schema_from_iceberg_tables(self, catalog_setup, client):
        """Schema can be derived from existing Iceberg tables."""
        catalog = catalog_setup["catalog"]

        # Create an Iceberg table directly (simulating an existing table).
        iceberg_schema = IcebergSchema(
            NestedField(1, "id", LongType(), required=True),
            NestedField(2, "name", StringType(), required=False),
            NestedField(3, "value", LongType(), required=False),
        )
        catalog.create_table(
            identifier="test.user_data",
            schema=iceberg_schema,
        )

        # _dlt_version does NOT exist - fallback should derive from Iceberg.
        result = client.get_stored_schema("test_schema")
        assert result is not None, "Should derive schema from Iceberg tables"
        assert result.version_hash == "derived_from_iceberg", "Should mark as derived"

        # Parse the schema and verify the table was derived.
        schema_dict = json.loads(result.schema)
        assert "user_data" in schema_dict["tables"], "Should include user_data table"

        user_data = schema_dict["tables"]["user_data"]
        assert "id" in user_data["columns"], "Should have 'id' column"
        assert "name" in user_data["columns"], "Should have 'name' column"
        assert "value" in user_data["columns"], "Should have 'value' column"

        # Check column properties.
        assert user_data["columns"]["id"]["nullable"] is False, "id should be required"
        assert user_data["columns"]["name"]["nullable"] is True, "name should be nullable"

    def test_get_stored_schema_prefers_dlt_version_over_derivation(self, catalog_setup, client):
        """Stored schema in _dlt_version takes precedence over derivation."""
        catalog = catalog_setup["catalog"]

        # Create _dlt_version with a stored schema.
        test_schema_dict = {"name": "test_schema", "tables": {"stored_table": {}}}
        _create_version_table(catalog, {
            "version_hash": ["stored_hash"],
            "schema_name": ["test_schema"],
            "version": [5],
            "engine_version": [1],
            "inserted_at": [datetime.now()],
            "schema": [json.dumps(test_schema_dict)],
        })

        # Also create an Iceberg table (which should NOT be used).
        user_schema = IcebergSchema(
            NestedField(1, "id", LongType(), required=True),
        )
        catalog.create_table(identifier="test.user_data", schema=user_schema)

        # get_stored_schema should return the stored schema, not derived.
        result = client.get_stored_schema("test_schema")
        assert result is not None, "Should return stored schema"
        assert result.version_hash == "stored_hash", "Should return stored schema, not derived"
derived" + assert result.version == 5, "Should return stored version" + + print(" Correctly returned stored schema (not derived)") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"])