-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path bootstrap_SCV.py
More file actions
43 lines (38 loc) · 1.97 KB
/
bootstrap_SCV.py
File metadata and controls
43 lines (38 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse
import numpy as np
from utils.evaluation import eval_SCV, bootstrap_SCV
def _flatten_control_variates(control_variates):
    """Key every control-variate entry by its row-major flat index.

    Entry ``control_variates[i][j]`` is stored under ``j + n_cols * i`` with
    ``n_cols = len(control_variates[0])``, i.e. C-order flattening.
    Presumably this lines the keys up with the ``.flatten()``-ed
    ``crash_NADE`` / ``control_step`` arrays -- confirm against ``eval_SCV``.
    NOTE(review): assumes every row has ``n_cols`` entries; a shorter row
    raises IndexError.

    Args:
        control_variates: 2-D indexable (list of rows or numpy array).

    Returns:
        dict: flat index -> control-variate entry, in row-major order.
    """
    if len(control_variates) == 0:
        return {}
    n_cols = len(control_variates[0])
    return {
        n_cols * i + j: row[j]
        for i, row in enumerate(control_variates)
        for j in range(n_cols)
    }


def _scatter_by_step(values_by_step, control_step, template):
    """Assign each per-step value to every sample that has that control step.

    Args:
        values_by_step: mapping ``step -> value``; every sample whose
            ``control_step`` equals ``step`` receives ``value``.
        control_step: array of per-sample control steps, same shape as
            ``template``.
        template: array whose shape the result takes. Its dtype is kept when
            it is already floating/complex.

    Returns:
        numpy.ndarray: shaped like ``template``; samples whose step is not a
        key of ``values_by_step`` stay 0.
    """
    # Never inherit an integer/bool dtype here: the values coming from
    # eval_SCV may be fractional, and assigning them into an integer array
    # would silently truncate them.
    if np.issubdtype(template.dtype, np.inexact):
        dtype = template.dtype
    else:
        dtype = np.float64
    result = np.zeros_like(template, dtype=dtype)
    for step, value in values_by_step.items():
        result[control_step == step] = value
    return result


def main():
    """Evaluate SCV for one AV model, bootstrap it, and save the results.

    Reads ``results/crash_NADE_AV_{AV}.npy``, ``results/control_step_AV_{AV}.npy``
    and ``results/control_variates_AV_{AV}.npy``. Writes (``np.save`` appends
    ``.npy``) ``results/crash_SCV_AV_{AV}``, ``results/rhw_SCV_total_AV_{AV}``
    and ``results/RNoT_SCV_bootstrap_{num_bootstrap}_AV_{AV}``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-a', '--AV', type=int, choices=[0, 1, 2],
                        default=0, help='The AV model under test')
    parser.add_argument('-b', '--num_bootstrap', type=int,
                        default=1000, help='Number of bootstrap')
    parser.add_argument('-s', '--num_separation', type=int,
                        default=100, help='Number of separation')
    parser.add_argument('-t', '--RHW_threshold', type=float,
                        default=0.3, help='The threshold for relative half-width')
    args = parser.parse_args()
    AV = args.AV
    num_bootstrap = args.num_bootstrap
    num_separation = args.num_separation
    RHW_threshold = args.RHW_threshold
    print(f'Bootstrapping SCV, AV = {AV}, num_bootstrap = {num_bootstrap}')

    crash_NADE = np.load(f'results/crash_NADE_AV_{AV}.npy').flatten()
    control_step = np.load(f'results/control_step_AV_{AV}.npy').flatten()
    # allow_pickle=True is needed because this file presumably holds an object
    # array; unpickling can execute code, so only load trusted result files.
    control_variates = np.load(f'results/control_variates_AV_{AV}.npy', allow_pickle=True)
    CV = _flatten_control_variates(control_variates)

    crash_SCV_dict, rhw_SCV = eval_SCV(crash_NADE, control_step, CV, sep=num_separation)
    crash_SCV = _scatter_by_step(crash_SCV_dict, control_step, crash_NADE)
    np.save(f"results/crash_SCV_AV_{AV}", crash_SCV)
    np.save(f"results/rhw_SCV_total_AV_{AV}", rhw_SCV)

    RNoT_SCV = bootstrap_SCV(crash_NADE, control_step, CV,
                             n_bootstrap=num_bootstrap,
                             sep=num_separation,
                             RHW_threshold=RHW_threshold)
    np.save(f'results/RNoT_SCV_bootstrap_{num_bootstrap}_AV_{AV}', RNoT_SCV)


if __name__ == '__main__':
    main()