diff --git a/LGHackerton/predict.py b/LGHackerton/predict.py index 259bd3d..c1abd21 100644 --- a/LGHackerton/predict.py +++ b/LGHackerton/predict.py @@ -35,6 +35,11 @@ def _read_table(path: str) -> pd.DataFrame: def convert_to_submission(pred_df: pd.DataFrame, sample_path: str) -> pd.DataFrame: sample_df = _read_table(sample_path) + sample_df.columns = sample_df.columns.str.strip().str.lstrip('\ufeff') + + pred_df = pred_df.copy() + pred_df["series_id"] = pred_df["series_id"].str.replace("::", "_", n=1) + wide = pred_df.pivot(index="date", columns="series_id", values="yhat_ens") wide = wide.reindex(sample_df.iloc[:, 0]).reindex(columns=sample_df.columns[1:], fill_value=0.0)