diff --git a/clean_data.py b/clean_data.py index 9a90676..70d7f4d 100644 --- a/clean_data.py +++ b/clean_data.py @@ -5,7 +5,7 @@ import argparse from datetime import datetime -from typing import List, Optional +from typing import List, Optional, Dict import polars as pl # ---------------- Defaults ---------------- @@ -17,7 +17,7 @@ "%Y-%m-%d %H:%M:%S", # 2024-11-30 19:50:30 "%Y-%m-%dT%H:%M:%S", # ISO no fractional seconds "%Y-%m-%dT%H:%M:%S%.f", # ISO with millis - # Add AM/PM variant if needed: + # Add AM/PM if needed: # "%m/%d/%Y %I:%M %p", ] @@ -56,9 +56,12 @@ def blanks_to_null_and_trim(lf: pl.LazyFrame) -> pl.LazyFrame: str_cols = [c for c, dt in zip(lf.columns, lf.dtypes) if dt == pl.Utf8] if not str_cols: return lf - # Trim - lf = lf.with_columns([pl.col(c).cast(pl.Utf8, strict=False).str.strip_chars().alias(c) for c in str_cols]) - # Empty -> null + # 1. trim whitespace + lf = lf.with_columns([ + pl.col(c).cast(pl.Utf8, strict=False).str.strip_chars().alias(c) + for c in str_cols + ]) + # 2. "" -> null lf = lf.with_columns([ pl.when((pl.col(c) == "") | (pl.col(c).str.len_chars() == 0)) .then(pl.lit(None, dtype=pl.Utf8)) @@ -70,13 +73,17 @@ def blanks_to_null_and_trim(lf: pl.LazyFrame) -> pl.LazyFrame: def parse_strptime_multi(col: str, fmts: List[str], dtype): """ - Try parsing a string column with several formats; take the first that succeeds. - Unparseable values become null (strict=False), then coalesce. + Try several datetime/date formats on a string column. + Take first non-null parse result via coalesce(). + strict=False -> unparseable values become null instead of throwing. """ exprs = [ - (pl.col(col).cast(pl.Utf8, strict=False) - .str.strip_chars() - .str.strptime(dtype, format=f, strict=False, exact=True)) + ( + pl.col(col) + .cast(pl.Utf8, strict=False) + .str.strip_chars() + .str.strptime(dtype, format=f, strict=False, exact=True) + ) for f in fmts ] return pl.coalesce(exprs).alias(col) @@ -85,15 +92,26 @@ def parse_and_normalize( lf: pl.LazyFrame, dt_formats: List[str], date_formats: List[str], + col_dtypes: Dict[str, pl.datatypes.DataType], ) -> pl.LazyFrame: cols = set(lf.columns) # --- Date/time --- + # Only parse sale_date_time if it's present AND currently Utf8 if "sale_date_time" in cols: - lf = lf.with_columns(parse_strptime_multi("sale_date_time", dt_formats, pl.Datetime)) + if col_dtypes.get("sale_date_time") == pl.Utf8: + lf = lf.with_columns( + parse_strptime_multi("sale_date_time", dt_formats, pl.Datetime) + ) + # else: assume it's already a proper Datetime-like column, leave it + # Only parse sale_date if it's present AND currently Utf8 if "sale_date" in cols: - lf = lf.with_columns(parse_strptime_multi("sale_date", date_formats, pl.Date)) + if col_dtypes.get("sale_date") == pl.Utf8: + lf = lf.with_columns( + parse_strptime_multi("sale_date", date_formats, pl.Date) + ) + # else: assume already a Date column # --- Numeric coercions --- for c, dtype in [ @@ -109,7 +127,7 @@ def parse_and_normalize( if "item_id" in cols: lf = lf.with_columns(pl.col("item_id").cast(pl.Int64, strict=False).alias("item_id")) - # --- Canonicalize text (UPPERCASE for stable categories) --- + # --- Canonicalize text (UPPERCASE for consistency) --- for c in ["store_format", "command_name", "return_ind", "price_status", "site_name", "item_desc"]: if c in cols: lf = lf.with_columns( @@ -130,13 +148,17 @@ def impute_numerics(lf: pl.LazyFrame, strategy: Optional[str], candidates: List[ return lf.with_columns([pl.col(c).fill_null(0).alias(c) for c in cols]) if strategy in {"mean", "median"}: + # compute stats once (collect just the aggs, tiny in memory) aggs = [ (pl.col(c).mean() if strategy == "mean" else pl.col(c).median()).alias(c) for c in cols ] s = lf.select(aggs).collect() fills = {c: s[c][0] for c in cols} - return lf.with_columns([pl.col(c).fill_null(fills[c]).alias(c) for c in cols]) + return lf.with_columns([ + pl.col(c).fill_null(fills[c]).alias(c) + for c in cols + ]) return lf @@ -148,49 +170,104 @@ def drop_required_nulls(lf: pl.LazyFrame, required: List[str]) -> pl.LazyFrame: def dedupe(lf: pl.LazyFrame, keys: List[str]) -> pl.LazyFrame: keys = [k for k in keys if k in lf.columns] - return lf.unique(subset=keys, keep="first", maintain_order=True) if keys else lf.unique(keep="first", maintain_order=True) + return ( + lf.unique(subset=keys, keep="first", maintain_order=True) + if keys + else lf.unique(keep="first", maintain_order=True) + ) # ---------------- Main ---------------- def main(): - ap = argparse.ArgumentParser(description="Clean a large Parquet dataset (missing values, duplicates, formats).") - ap.add_argument("--input", required=True, help="Parquet file or glob (local path or cloud URI).") - ap.add_argument("--output", required=True, help="Output Parquet file or directory.") - ap.add_argument("--compression", default="zstd", - choices=["zstd", "snappy", "gzip", "lz4", "uncompressed"]) - ap.add_argument("--rechunk", action="store_true", help="Collect and rechunk before writing (uses more RAM).") + ap = argparse.ArgumentParser( + description="Clean a large Parquet dataset (missing values, duplicates, formats)." + ) + ap.add_argument( + "--input", + required=True, + help="Parquet file or glob (local path or cloud URI).", + ) + ap.add_argument( + "--output", + required=True, + help="Output Parquet file or directory.", + ) + ap.add_argument( + "--compression", + default="zstd", + choices=["zstd", "snappy", "gzip", "lz4", "uncompressed"], + ) + ap.add_argument( + "--rechunk", + action="store_true", + help="Collect and rechunk before writing (uses more RAM).", + ) # Handling knobs - ap.add_argument("--impute-numerics", choices=["zero", "mean", "median"], default="median", - help="Impute numeric nulls after type casting.") - ap.add_argument("--required", - nargs="*", default=["sale_date_time","site_id","slip_no","line","extension_amount","qty"], - help="Rows missing any of these are dropped (after imputation).") - ap.add_argument("--dedupe-by", - nargs="*", default=["sale_date_time","site_id","slip_no","line","item_id","qty","extension_amount"], - help="Columns that define duplicate rows.") - - # Date window - ap.add_argument("--min-date", default=None, help="Drop rows with sale_date_time before this (YYYY-MM-DD).") - ap.add_argument("--max-date", default=None, help="Drop rows with sale_date_time after this (YYYY-MM-DD).") + ap.add_argument( + "--impute-numerics", + choices=["zero", "mean", "median"], + default="median", + help="Impute numeric nulls after type casting.", + ) + ap.add_argument( + "--required", + nargs="*", + default=["sale_date_time", "site_id", "slip_no", "line", "extension_amount", "qty"], + help="Rows missing any of these are dropped (after imputation).", + ) + ap.add_argument( + "--dedupe-by", + nargs="*", + default=["sale_date_time", "site_id", "slip_no", "line", "item_id", "qty", "extension_amount"], + help="Columns that define duplicate rows.", + ) + + # Date window filters + ap.add_argument( + "--min-date", + default=None, + help="Drop rows with sale_date_time before this (YYYY-MM-DD).", + ) + ap.add_argument( + "--max-date", + default=None, + help="Drop rows with sale_date_time after this (YYYY-MM-DD).", + ) # Soft sanity constraints - ap.add_argument("--nonnegative-qty", action="store_true", help="Make QTY nonnegative (abs).") - ap.add_argument("--nonnegative-amount", action="store_true", help="Make EXTENSION_AMOUNT nonnegative (abs).") + ap.add_argument( + "--nonnegative-qty", + action="store_true", + help="Make QTY nonnegative (abs).", + ) + ap.add_argument( + "--nonnegative-amount", + action="store_true", + help="Make EXTENSION_AMOUNT nonnegative (abs).", + ) - # Format overrides - ap.add_argument("--dt-format", action="append", default=[], - help="Extra datetime format(s) to try (can repeat).") - ap.add_argument("--date-format", action="append", default=[], - help="Extra date format(s) to try (can repeat).") + # Optional extra datetime formats from CLI + ap.add_argument( + "--dt-format", + action="append", + default=[], + help="Extra datetime format(s) to try for sale_date_time (can repeat).", + ) + ap.add_argument( + "--date-format", + action="append", + default=[], + help="Extra date format(s) to try for sale_date (can repeat).", + ) args = ap.parse_args() - # Build final format lists (user-provided take precedence) + # Merge default + custom formats (keep order, dedupe by dict trick) dt_formats = list(dict.fromkeys(args.dt_format + DEFAULT_DT_FORMATS)) date_formats = list(dict.fromkeys(args.date_format + DEFAULT_DATE_FORMATS)) - # 1) Lazy scan (version-safe args) + # 1) Lazy scan lf = pl.scan_parquet( args.input, low_memory=True, @@ -198,41 +275,71 @@ def main(): hive_partitioning=True, ) - # 2) Column names → snake_case + # 2) Standardize column names -> snake_case lf = standardize_names(lf) - # 3) Clean strings and blank→null + # 3) Clean strings and convert "" -> null lf = blanks_to_null_and_trim(lf) - # 4) Parse & normalize columns - lf = parse_and_normalize(lf, dt_formats=dt_formats, date_formats=date_formats) + # Build a dtype map after standardization/blank cleanup + # This tells us which columns are Utf8, Int32, etc *right now* + col_dtypes = dict(zip(lf.columns, lf.dtypes)) + + # 4) Parse and normalize (only parse sale_date_time / sale_date if they are still strings) + lf = parse_and_normalize( + lf, + dt_formats=dt_formats, + date_formats=date_formats, + col_dtypes=col_dtypes, + ) - # 5) Date window filters (use Python datetime for compatibility) + # 5) Date window filters (use Python datetime for broad Polars compatibility) if args.min_date and "sale_date_time" in lf.columns: dt_min = datetime.fromisoformat(args.min_date) - lf = lf.filter(pl.col("sale_date_time").is_null() | (pl.col("sale_date_time") >= pl.lit(dt_min))) + lf = lf.filter( + pl.col("sale_date_time").is_null() + | (pl.col("sale_date_time") >= pl.lit(dt_min)) + ) if args.max_date and "sale_date_time" in lf.columns: dt_max = datetime.fromisoformat(args.max_date) - lf = lf.filter(pl.col("sale_date_time").is_null() | (pl.col("sale_date_time") <= pl.lit(dt_max))) + lf = lf.filter( + pl.col("sale_date_time").is_null() + | (pl.col("sale_date_time") <= pl.lit(dt_max)) + ) # 6) Impute numeric nulls numeric_candidates = [ - c for c, dt in zip(lf.columns, lf.dtypes) - if dt in (pl.Int8, pl.Int16, pl.Int32, pl.Int64, - pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64, - pl.Float32, pl.Float64) + c + for c, dt in zip(lf.columns, lf.dtypes) + if dt + in ( + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Float32, + pl.Float64, + ) ] lf = impute_numerics(lf, args.impute_numerics, numeric_candidates) - # 7) Drop rows missing required fields + # 7) Drop rows missing required columns lf = drop_required_nulls(lf, args.required) # 8) Soft sanity constraints if args.nonnegative_qty and "qty" in lf.columns: lf = lf.with_columns( - pl.when(pl.col("qty") < 0).then(-pl.col("qty")).otherwise(pl.col("qty")).alias("qty") + pl.when(pl.col("qty") < 0) + .then(-pl.col("qty")) + .otherwise(pl.col("qty")) + .alias("qty") ) + if args.nonnegative_amount and "extension_amount" in lf.columns: lf = lf.with_columns( pl.when(pl.col("extension_amount") < 0) @@ -241,24 +348,29 @@ def main(): .alias("extension_amount") ) - # 9) Dedupe + # 9) Dedupe rows lf = dedupe(lf, args.dedupe_by) - # 10) Execute & write + # 10) Execute and write if args.rechunk: print("⚙️ Rechunking: collecting to DataFrame (uses RAM) ...") df = lf.collect().rechunk() - # Polars versions differ on write_parquet kwargs; try with statistics then fallback + # Polars' write_parquet() signature can vary across versions, + # so try with statistics first, then fallback. try: df.write_parquet(args.output, compression=args.compression, statistics=True) except TypeError: df.write_parquet(args.output, compression=args.compression) else: - # Keep it lazy; write directly if available, else collect->write + # Prefer lazy sink_parquet if available, else fallback to collect+write. try: - lf.sink_parquet(args.output, compression=args.compression, maintain_order=True, statistics=True) + lf.sink_parquet( + args.output, + compression=args.compression, + maintain_order=True, + statistics=True, + ) except AttributeError: - # Older Polars may not have sink_parquet; fallback df = lf.collect() try: df.write_parquet(args.output, compression=args.compression, statistics=True)