Skip to content

Commit 7df87ea

Browse files
Merge pull request #61 from yannrichet-tmp/fix/fzd-deduplicate-batch-inputs
fix(fzd): deduplicate batch inputs before calling fzr
2 parents 906ac18 + a16733b commit 7df87ea

1 file changed

Lines changed: 27 additions & 1 deletion

File tree

fz/core.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1365,6 +1365,15 @@ def fzr(
13651365
if not isinstance(input_variables, (dict, pd.DataFrame)):
13661366
raise TypeError(f"input_variables must be a dictionary or DataFrame, got {type(input_variables).__name__}")
13671367

1368+
# Reject duplicate rows in a DataFrame design: each row must be a distinct case
1369+
# (duplicate rows would map to the same temp directory and silently overwrite results)
1370+
if isinstance(input_variables, pd.DataFrame) and input_variables.duplicated().any():
1371+
dup_idx = input_variables[input_variables.duplicated(keep=False)].index.tolist()
1372+
raise ValueError(
1373+
f"input_variables DataFrame contains duplicate rows (indices {dup_idx}). "
1374+
"Each case must have a unique combination of input values."
1375+
)
1376+
13681377
if not isinstance(results_dir, (str, Path)):
13691378
raise TypeError(f"results_dir must be a string or Path, got {type(results_dir).__name__}")
13701379

@@ -1854,14 +1863,31 @@ def fzd(
18541863
# Also check renamed directory for cached results from previous runs
18551864
cache_paths.extend([f"cache://{renamed_results_dir / f'iter{j:03d}'}" for j in range(1, 100)]) # Check up to 99 iterations
18561865

1866+
# Deduplicate current_design: run each unique point only once
1867+
seen_keys: dict = {}
1868+
unique_design = []
1869+
index_map = [] # index_map[i] = row in unique_design for current_design[i]
1870+
for point in current_design:
1871+
key = tuple(sorted(point.items()))
1872+
if key not in seen_keys:
1873+
seen_keys[key] = len(unique_design)
1874+
unique_design.append(point)
1875+
index_map.append(seen_keys[key])
1876+
n_dupes = len(current_design) - len(unique_design)
1877+
if n_dupes:
1878+
log_info(f" ({n_dupes} duplicate point(s) removed from batch; results will be reused)")
1879+
18571880
result_df = fzr(
18581881
str(input_dir),
1859-
pd.DataFrame(current_design, columns=all_var_names),# All points in batch
1882+
pd.DataFrame(unique_design, columns=all_var_names),
18601883
model,
18611884
results_dir=str(iteration_result_dir),
18621885
calculators=[*cache_paths, *calculators] # Cache paths first, then actual calculators
18631886
)
18641887

1888+
# Expand result_df back to full current_design length (re-map duplicates)
1889+
result_df = result_df.iloc[index_map].reset_index(drop=True)
1890+
18651891
# Extract output values for each point
18661892
iteration_inputs = []
18671893
iteration_outputs = []

0 commit comments

Comments (0)