Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# v0.4.5

- Added support for polymorphic foreign keys
- Removed Python 3.8, 3.9 support and added 3.13, 3.14 support
- Updated dependencies

# v0.4.4

- Improved query performance following foreign key relationships
Expand Down
2 changes: 1 addition & 1 deletion subsetter/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.4"
__version__ = "0.4.5"
2 changes: 1 addition & 1 deletion subsetter/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _push(key: Any, value: Any):
data = stack.pop()
if isinstance(data, BaseModel):
yield data
for field, _ in data.model_fields.items():
for field, _ in data.__class__.model_fields.items():
_push(field, getattr(data, field))

if isinstance(data, list):
Expand Down
27 changes: 27 additions & 0 deletions subsetter/config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,32 @@ def check_columns_match(self):
raise ValueError("each column in src_columns must be unique")
return self

class PolymorphicFKConfig(ForbidBaseModel):
    """Configuration for a polymorphic foreign key.

    A polymorphic foreign key maps a set of source key columns onto one of
    several possible destination tables; the value of ``discriminator_column``
    on the source row selects which entry of ``destinations`` applies.
    """

    class KeyDestination(ForbidBaseModel):
        # Destination table name and the columns the key references there.
        table: str
        columns: List[str]

    # Source table containing the polymorphic key.
    table: str
    # Source key columns; each destination must declare the same number.
    columns: List[str]
    # Column on the source table whose value selects the destination.
    discriminator_column: str
    # Mapping of discriminator value -> destination table/columns.
    destinations: dict[str, KeyDestination]

    @model_validator(mode="after")
    def check_columns_match(self):
        """Validate that key columns are non-empty and unique, and that every
        destination's column list has the same length as ``columns``.

        Raises:
            ValueError: if any of the above constraints is violated.
        """
        col_count = len(self.columns)
        if not col_count:
            raise ValueError("columns cannot be empty")
        if len(set(self.columns)) != col_count:
            raise ValueError("each column in columns must be unique")
        for key_dest in self.destinations.values():
            # Error messages refer to this model's own fields (`columns` and
            # each destination's `columns`), not src_columns/dst_columns.
            if len(key_dest.columns) != col_count:
                raise ValueError(
                    "columns and destination columns must be the same length"
                )
            if len(set(key_dest.columns)) != col_count:
                raise ValueError("each column in destination columns must be unique")
        return self

class ColumnConstraint(ForbidBaseModel):
column: str
operator: SQLKnownOperator
Expand All @@ -55,6 +81,7 @@ class ColumnConstraint(ForbidBaseModel):
passthrough: List[str] = []
ignore_fks: List[IgnoreFKConfig] = []
extra_fks: List[ExtraFKConfig] = []
polymorphic_fks: List[PolymorphicFKConfig] = []
infer_foreign_keys: Literal["none", "schema", "all"] = "none"
include_dependencies: bool = True

Expand Down
4 changes: 4 additions & 0 deletions subsetter/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class ForeignKey:
dst_schema: str
dst_table: str
dst_columns: Tuple[str, ...]
src_discriminator: Optional[Tuple[str, str]] = None
dst_discriminator: Optional[Tuple[str, str]] = None

@classmethod
def from_schema(cls, fk: sa.ForeignKeyConstraint) -> "ForeignKey":
Expand Down Expand Up @@ -188,6 +190,8 @@ def compute_reverse_keys(self) -> None:
dst_schema=table.schema,
dst_table=table.name,
dst_columns=fk.columns,
src_discriminator=fk.dst_discriminator,
dst_discriminator=fk.src_discriminator,
)
)

Expand Down
75 changes: 37 additions & 38 deletions subsetter/plan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ class SQLLeftJoin(BaseModel):
left_columns: List[str]
right_columns: List[str]
half_unique: bool = True
left_discriminator: List[str] = []
right_discriminator: List[str] = []


class SQLStatementSelect(BaseModel):
Expand All @@ -245,8 +247,9 @@ class SQLStatementSelect(BaseModel):
from_: SQLTableIdentifier = Field(..., alias="from")
where: Optional[SQLWhereClause] = None
limit: Optional[int] = None
joins: Optional[List[SQLLeftJoin]] = None
joins_outer: bool = False

# Joins are combined in CNF format - inner lists of joins must have one matching joined row
joins: List[List[SQLLeftJoin]] = []

model_config = ConfigDict(populate_by_name=True)

Expand All @@ -259,50 +262,47 @@ def build(self, context: SQLBuildContext):
else:
stmt = sa.select(table_obj)

if self.joins:
joined_cols: List[sa.ColumnElement] = []
joined: sa.FromClause = table_obj
exists_constraints: List[sa.ColumnExpressionArgument] = []
for join in self.joins: # pylint: disable=not-an-iterable
joined: sa.FromClause = table_obj
join_and_conditions = []
for join_list in self.joins:
join_or_conditions: List[sa.ColumnExpressionArgument] = []
for join in join_list:
right = join.right.build(context).alias()

join_on = [
table_obj.c[lft_col] == right.c[rht_col]
for lft_col, rht_col in zip(join.left_columns, join.right_columns)
]
if join.left_discriminator:
disc_col, disc_val = join.left_discriminator
join_on.append(table_obj.c[disc_col] == disc_val)
if join.right_discriminator:
disc_col, disc_val = join.right_discriminator
join_on.append(right.c[disc_col] == disc_val)

if join.half_unique and table_obj.primary_key:
joined = joined.join(
right,
onclause=sa.and_(
*(
table_obj.c[lft_col] == right.c[rht_col]
for lft_col, rht_col in zip(
join.left_columns, join.right_columns
)
)
),
isouter=self.joins_outer,
)
joined_cols.extend(
right.c[rht_col] for rht_col in join.right_columns
onclause=sa.and_(*join_on),
isouter=len(join_list) > 1,
)
else:
exists_constraints.append(
sa.exists().where(
*(
table_obj.c[lft_col] == right.c[rht_col]
for lft_col, rht_col in zip(
join.left_columns, join.right_columns
)
)
if len(join_list) > 1:
join_or_conditions.extend(
right.c[rht_col].is_not(None)
for rht_col in join.right_columns
)
)
else:
join_or_conditions.append(sa.exists().where(*join_on))

if join_or_conditions:
join_and_conditions.append(sa.or_(*join_or_conditions))

stmt = stmt.select_from(joined)
if joined is not table_obj:
stmt = stmt.group_by(*table_obj.primary_key.columns)
stmt = stmt.select_from(joined)
if joined is not table_obj:
stmt = stmt.group_by(*table_obj.primary_key.columns)

if self.joins_outer:
exists_constraints.extend(col.is_not(None) for col in joined_cols)
stmt = stmt.where(sa.or_(*exists_constraints))
elif exists_constraints:
stmt = stmt.where(sa.and_(*exists_constraints))
if join_and_conditions:
stmt = stmt.where(sa.and_(*join_and_conditions))

if self.where:
stmt = stmt.where(self.where.build(context, table_obj))
Expand All @@ -329,7 +329,6 @@ def simplify(self) -> "SQLStatementSelect":
kwargs["limit"] = self.limit
if self.joins:
kwargs["joins"] = self.joins
kwargs["joins_outer"] = self.joins_outer

return SQLStatementSelect(**kwargs) # type: ignore

Expand Down
139 changes: 123 additions & 16 deletions subsetter/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _plan_internal(self) -> SubsetPlan:
)
self._remove_ignore_fks()
self._add_extra_fks()
self._add_polymorphic_fks()
if self.config.include_dependencies:
self._check_ignore_tables()
self._check_passthrough_tables()
Expand Down Expand Up @@ -202,6 +203,85 @@ def _add_extra_fks(self) -> None:
),
)

def _add_polymorphic_fks(self) -> None:
    """Add in configured polymorphic foreign keys requested.

    For each entry in ``config.polymorphic_fks`` this validates that the
    source table, key columns, and discriminator column exist, then appends
    one ForeignKey (tagged with ``src_discriminator``) per discriminator
    value whose destination table and columns also exist.  Invalid entries
    are logged and skipped rather than raising, matching _add_extra_fks.
    """
    for index, poly_fk in enumerate(self.config.polymorphic_fks):
        src_schema, src_table_name = parse_table_name(poly_fk.table)
        table = self.meta.tables.get((src_schema, src_table_name))
        if table is None:
            LOGGER.warning(
                "Found no source table %s.%s referenced in polymorphic_fks[%d]",
                src_schema,
                src_table_name,
                index,
            )
            continue

        src_missing_cols = {
            col for col in poly_fk.columns if col not in table.table_obj.columns
        }
        if src_missing_cols:
            # NOTE: messages consistently reference the config key
            # `polymorphic_fks` (previously a mix of poly_fks/polymorphic_fks).
            LOGGER.warning(
                "Columns %s do not exist in %s.%s referenced in polymorphic_fks[%d]",
                src_missing_cols,
                src_schema,
                src_table_name,
                index,
            )
            continue

        if poly_fk.discriminator_column not in table.table_obj.columns:
            LOGGER.warning(
                "Column %s does not exist in %s.%s referenced in polymorphic_fks[%d].discriminator_column",
                poly_fk.discriminator_column,
                src_schema,
                src_table_name,
                index,
            )
            continue

        for discriminator_value, key_dest in poly_fk.destinations.items():
            dst_schema, dst_table_name = parse_table_name(key_dest.table)
            dst_table = self.meta.tables.get((dst_schema, dst_table_name))
            if dst_table is None:
                LOGGER.warning(
                    "Found no destination table %s.%s referenced in polymorphic_fks[%d].destinations[%s]",
                    dst_schema,
                    dst_table_name,
                    index,
                    discriminator_value,
                )
                continue

            dst_missing_cols = {
                col
                for col in key_dest.columns
                if col not in dst_table.table_obj.columns
            }
            if dst_missing_cols:
                LOGGER.warning(
                    "Columns %s do not exist in %s.%s referenced in polymorphic_fks[%d].destinations[%s]",
                    dst_missing_cols,
                    dst_schema,
                    dst_table_name,
                    index,
                    discriminator_value,
                )
                continue

            # Tag the FK with (discriminator column, value) so the planner
            # can OR together joins sharing the same discriminator column.
            table.foreign_keys.append(
                ForeignKey(
                    columns=tuple(poly_fk.columns),
                    dst_schema=dst_schema,
                    dst_table=dst_table_name,
                    dst_columns=tuple(key_dest.columns),
                    src_discriminator=(
                        poly_fk.discriminator_column,
                        discriminator_value,
                    ),
                ),
            )

def _remove_ignore_fks(self) -> None:
"""Remove requested foreign keys"""
for ignore_fk in self.config.ignore_fks:
Expand Down Expand Up @@ -322,24 +402,51 @@ def _is_distinct(table_obj: sa.Table, cols: Iterable[str]) -> bool:
return True
return False

# Create joins in conjunctive normal form.
fks_to_join = []

# reverse foreign keys just get OR'ed together
if rev_foreign_keys:
fks_to_join.append(rev_foreign_keys)

# forward foreign keys get AND'ed except for polymorphic fks which OR when using the same
# discriminator column.
fk_disc_index: dict[str, int] = {}
for fk in foreign_keys:
if fk.src_discriminator:
disc_col = fk.src_discriminator[0]
if disc_col in fk_disc_index:
fks_to_join[fk_disc_index[disc_col]].append(fk)
else:
fk_disc_index[disc_col] = len(fks_to_join)
fks_to_join.append([fk])
else:
fks_to_join.append([fk])

fk_joins = []
for fk in foreign_keys or rev_foreign_keys:
dst_table = self.meta.tables[(fk.dst_schema, fk.dst_table)]
half_unique = _is_distinct(table.table_obj, fk.columns) or _is_distinct(
dst_table.table_obj, fk.dst_columns
)
fk_joins.append(
SQLLeftJoin(
right=SQLTableIdentifier(
table_schema=fk.dst_schema,
table_name=fk.dst_table,
sampled=True,
),
left_columns=list(fk.columns),
right_columns=list(fk.dst_columns),
half_unique=half_unique,
for fk_join_list in fks_to_join:
or_joins = []
for fk in fk_join_list:
dst_table = self.meta.tables[(fk.dst_schema, fk.dst_table)]
half_unique = _is_distinct(table.table_obj, fk.columns) or _is_distinct(
dst_table.table_obj, fk.dst_columns
)
)

or_joins.append(
SQLLeftJoin(
right=SQLTableIdentifier(
table_schema=fk.dst_schema,
table_name=fk.dst_table,
sampled=True,
),
left_columns=list(fk.columns),
right_columns=list(fk.dst_columns),
half_unique=half_unique,
left_discriminator=list(fk.src_discriminator or ()),
right_discriminator=list(fk.dst_discriminator or ()),
)
)
fk_joins.append(or_joins)

conf_constraints = self.config.table_constraints.get(
f"{table.schema}.{table.name}", []
Expand Down
19 changes: 9 additions & 10 deletions tests/data/big_join.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,15 @@ expected_plan:
schema: test
table: users
joins:
- half_unique: false
left_columns:
- state
right:
sampled: true
schema: test
table: homes
right_columns:
- state
joins_outer: true
- - half_unique: false
left_columns:
- state
right:
sampled: true
schema: test
table: homes
right_columns:
- state
type: select

expected_sample:
Expand Down
Loading
Loading