forked from xupei0610/SocialVAE
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
362 lines (340 loc) · 17.2 KB
/
data.py
File metadata and controls
362 lines (340 loc) · 17.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
from typing import Optional, Sequence, List
import os, sys
import torch
import numpy as np
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, as_completed, ThreadPoolExecutor
class Dataloader(torch.utils.data.Dataset):
    """Trajectory dataset with observation/prediction splitting.

    Reads plain-text trajectory files (one ``t agent_id x y [group]`` row per
    line), extends each position with finite-difference velocity and
    acceleration, and slices the result into samples of

        history:  ob_horizon x 6                 (x, y, vx, vy, ax, ay)
        future:   pred_horizon x 2               (x, y)
        neighbor: (ob+pred horizon) x N x 6

    Missing/absent neighbors are padded with the large sentinel value 1e9.
    ``collate_fn`` optionally applies flip/rotate/scale augmentation and
    stacks samples into batch tensors on ``self.device``.
    """

    class FixedNumberBatchSampler(torch.utils.data.sampler.BatchSampler):
        """BatchSampler variant yielding exactly ``n_batches`` batches per
        epoch, transparently restarting the wrapped sampler when exhausted.
        """
        def __init__(self, n_batches, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.n_batches = n_batches
            # Lazily (re)created in __iter__ so state survives across epochs.
            self.sampler_iter = None

        def __iter__(self):
            # Same as BatchSampler.__iter__, but stop after n_batches batches
            # and re-create the sampler iterator whenever it runs dry.
            counter = 0
            batch = []
            while counter < self.n_batches:
                if self.sampler_iter is None:
                    self.sampler_iter = iter(self.sampler)
                try:
                    idx = next(self.sampler_iter)
                except StopIteration:
                    self.sampler_iter = None
                    if self.drop_last: batch = []
                    continue
                batch.append(idx)
                if len(batch) == self.batch_size:
                    counter += 1
                    yield batch
                    batch = []

    def __init__(self,
        files: List[str], ob_horizon: int, pred_horizon: int,
        batch_size: int, drop_last: bool=False, shuffle: bool=False, batches_per_epoch=None,
        frameskip: int=1, max_overlap: int=1, inclusive_groups: Optional[Sequence]=None,
        batch_first: bool=False, seed: Optional[int]=None,
        device: Optional[torch.device]=None,
        flip: bool=False, rotate: bool=False, scale: bool=False
    ):
        """Scan ``files`` (files or directories of ``.txt``), load them in a
        spawn-context process pool, and build the batch sampler.

        Args:
            files: list of data files or directories to scan recursively.
            ob_horizon: observed horizon (length of history).
            pred_horizon: prediction horizon (length of future to predict).
            batch_size, drop_last, shuffle, batches_per_epoch: batching setup.
            frameskip: sample every ``frameskip``-th frame (controls framerate).
            max_overlap: step between the start frames of consecutive samples.
            inclusive_groups: per-file group filters; empty list = no filter.
            batch_first: stack batches on dim 0 instead of dim 1.
            seed: optional seed for the augmentation RNG.
            device: target device for batch tensors (auto-detected if None).
            flip, rotate, scale: enable the respective augmentations.
        """
        super().__init__()
        self.ob_horizon = ob_horizon      # observed horizon (history length)
        self.pred_horizon = pred_horizon  # prediction horizon (future length)
        self.horizon = self.ob_horizon + self.pred_horizon
        # frameskip < 2 (or falsy) degenerates to no skipping.
        self.frameskip = int(frameskip) if frameskip and int(frameskip) > 1 else 1
        self.max_overlap = max_overlap
        self.batch_first = batch_first
        self.flip = flip
        self.rotate = rotate
        self.scale = scale
        if device is None:
            # BUGFIX: is_available must be *called* -- the bound method object
            # is always truthy, which previously selected CUDA even on
            # CPU-only machines and crashed at tensor creation time.
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        if inclusive_groups is None:
            inclusive_groups = [[] for _ in range(len(files))]
        assert len(inclusive_groups) == len(files)
        print("--- Scanning files...")
        files_ = []
        for path, incl_g in zip(files, inclusive_groups):
            if os.path.isdir(path):
                # Recursively pick up every .txt under a directory.
                files_.extend([(os.path.join(root, f), incl_g)
                    for root, _, fs in os.walk(path)
                    for f in fs if f.endswith(".txt")])
            elif os.path.exists(path):
                files_.append((path, incl_g))
        data_files = sorted(files_, key=lambda _: _[0])
        data = []
        done = 0
        # Too many workers inflates memory usage; also guard against an empty
        # file list (ProcessPoolExecutor rejects max_workers=0).
        max_workers = max(1, min(len(data_files), torch.get_num_threads(), 20))
        with ProcessPoolExecutor(mp_context=multiprocessing.get_context("spawn"), max_workers=max_workers) as p:
            futures = [p.submit(self.__class__.load, self, f, incl_g) for f, incl_g in data_files]
            # First pass: progress reporting as futures complete (any order);
            # \r\033[K rewrites the same terminal line each time.
            for fut in as_completed(futures):
                done += 1
                sys.stdout.write("\r\033[K--- Loading data files...{}/{}".format(
                    done, len(data_files)
                ))
            # Second pass: collect results in *submission* order so the
            # dataset content is deterministic regardless of completion order.
            for fut in futures:
                item = fut.result()
                if item is not None:
                    data.extend(item)
            sys.stdout.write("\r\033[K--- Loading data files...{}/{} ".format(
                done, len(data_files)
            ))
        self.data = np.array(data, dtype=object)
        del data  # release the intermediate list before training starts
        print("\n--- {} trajectories loaded.".format(len(self.data)))
        self.rng = np.random.RandomState()
        if seed: self.rng.seed(seed)
        if shuffle:
            sampler = torch.utils.data.sampler.RandomSampler(self)
        else:
            sampler = torch.utils.data.sampler.SequentialSampler(self)
        if batches_per_epoch is None:
            self.batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size, drop_last)
            self.batches_per_epoch = len(self.batch_sampler)
        else:
            self.batch_sampler = self.__class__.FixedNumberBatchSampler(batches_per_epoch, sampler, batch_size, drop_last)
            self.batches_per_epoch = batches_per_epoch

    def collate_fn(self, batch):
        """Collate (hist, future, neighbor) items into batched float tensors.

        Applies the configured random augmentations (flip/rotate/scale), pads
        neighbor tensors to a common neighbor count with the 1e9 sentinel, and
        stacks along dim 0 (batch_first) or dim 1 (time-first).

        Returns:
            (x, y, neighbor) float32 tensors on ``self.device``.
        """
        X, Y, NEIGHBORS = [], [], []
        for item in batch:
            hist, future, neighbor = item[0], item[1], item[2]
            hist_shape = hist.shape
            neighbor_shape = neighbor.shape
            # Temporarily flatten to L x 2 so one 2-D transform covers all
            # feature pairs (position, velocity, acceleration) uniformly.
            hist = np.reshape(hist, (-1, 2))
            neighbor = np.reshape(neighbor, (-1, 2))
            # Augmentation: random mirroring about each axis.
            # NOTE(review): np.reshape returns a view here, so the in-place
            # "*= -1" writes through to the arrays stored in self.data. Flips
            # are involutive so the sample distribution is unaffected, but the
            # originally loaded signs are not preserved -- confirm intended.
            if self.flip:
                if self.rng.randint(2):
                    hist[..., 1] *= -1
                    future[..., 1] *= -1
                    neighbor[..., 1] *= -1
                if self.rng.randint(2):
                    hist[..., 0] *= -1
                    future[..., 0] *= -1
                    neighbor[..., 0] *= -1
            # Augmentation: one random rotation in [0, 2*pi) shared by
            # history, future and neighbors so geometry stays consistent.
            if self.rotate:
                rot = self.rng.random() * (np.pi + np.pi)
                s, c = np.sin(rot), np.cos(rot)
                r = np.asarray([
                    [c, -s],
                    [s,  c]
                ])
                hist = (r @ np.expand_dims(hist, -1)).squeeze(-1)
                future = (r @ np.expand_dims(future, -1)).squeeze(-1)
                neighbor = (r @ np.expand_dims(neighbor, -1)).squeeze(-1)
            # Augmentation: random scaling, factor ~ N(1, 0.05) (close to 1).
            if self.scale:
                s = self.rng.randn() * 0.05 + 1
                hist = s * hist
                future = s * future
                neighbor = s * neighbor
            hist = np.reshape(hist, hist_shape)
            neighbor = np.reshape(neighbor, neighbor_shape)
            X.append(hist)
            Y.append(future)
            NEIGHBORS.append(neighbor)
        # Pad every sample to the max neighbor count in this batch so the
        # neighbor tensors can be stacked; padding uses the 1e9 sentinel.
        n_neighbors = [n.shape[1] for n in NEIGHBORS]
        max_neighbors = max(n_neighbors)
        if max_neighbors != min(n_neighbors):
            NEIGHBORS = [
                np.pad(neighbor, ((0, 0), (0, max_neighbors-n), (0, 0)),
                    "constant", constant_values=1e9)
                for neighbor, n in zip(NEIGHBORS, n_neighbors)
            ]
        stack_dim = 0 if self.batch_first else 1
        x = np.stack(X, stack_dim)
        y = np.stack(Y, stack_dim)
        neighbor = np.stack(NEIGHBORS, stack_dim)
        x = torch.tensor(x, dtype=torch.float32, device=self.device)
        y = torch.tensor(y, dtype=torch.float32, device=self.device)
        neighbor = torch.tensor(neighbor, dtype=torch.float32, device=self.device)
        return x, y, neighbor

    def __len__(self):
        # Number of (hist, future, neighbor) samples.
        return len(self.data)

    def __getitem__(self, idx):
        # Raw sample tuple; batching/augmentation happens in collate_fn.
        return self.data[idx]

    @staticmethod
    def load(self, filename, inclusive_groups):
        """Worker entry point: load one file and cut it into samples.

        Declared as a @staticmethod that receives the Dataloader instance
        explicitly as its first argument, because the process pool submits
        ``self.__class__.load`` plus the pickled instance (spawn context).

        Returns a list of float32 (hist, future, neighbor) tuples, or None
        when the path is a directory or the file is too short.
        """
        if os.path.isdir(filename): return None
        # Number of raw frames spanned by one sample window.
        horizon = (self.horizon-1)*self.frameskip
        with open(filename, "r") as record:
            data = self.load_traj(record)
        # Filter lonely agents and add velocity/acceleration features.
        data = self.extend(data, self.frameskip)
        times = np.sort(list(data.keys()))
        if len(times) < horizon+1: return None
        valid_horizon = self.ob_horizon + self.pred_horizon
        traj = []
        e = len(times)
        tid0 = 0
        while tid0 < e-horizon:
            # Window covers frame indices [tid0, tid1] stepped by frameskip.
            tid1 = tid0+horizon
            t0 = times[tid0]
            # Candidate target agents at t0, subject to the group filter.
            idx = [aid for aid, d in data[t0].items() if not inclusive_groups or any(g in inclusive_groups for g in d[-1])]
            if idx:
                idx_all = list(data[t0].keys())
                # Keep only targets that appear in *every* frame of the
                # window; accumulate everyone seen for the neighbor set.
                for tid in range(tid0+self.frameskip, tid1+1, self.frameskip):
                    t = times[tid]
                    idx_cur = [aid for aid, d in data[t].items() if not inclusive_groups or any(g in inclusive_groups for g in d[-1])]
                    if not idx_cur:  # empty frame: skip ahead past it
                        tid0 = tid
                        idx = []
                        break
                    idx = np.intersect1d(idx, idx_cur)
                    if len(idx) == 0: break
                    idx_all.extend(data[t].keys())
                if len(idx):
                    data_dim = 6  # x, y, vx, vy, ax, ay
                    # Neighbors = anyone seen in the window who is not a
                    # target; they may be present in only part of the window.
                    neighbor_idx = np.setdiff1d(idx_all, idx)
                    if len(idx) == 1 and len(neighbor_idx) == 0:
                        # Single agent, no neighbors: add one dummy neighbor
                        # filled with the 1e9 sentinel so shapes stay L x 2 x 6.
                        agents = np.array([
                            [data[times[tid]][idx[0]][:data_dim]] + [[1e9]*data_dim]
                            for tid in range(tid0, tid1+1, self.frameskip)
                        ])
                    else:
                        # L x N x 6: targets first, then neighbors (sentinel
                        # rows where a neighbor is absent from a frame).
                        agents = np.array([
                            [data[times[tid]][i][:data_dim] for i in idx] +
                            [data[times[tid]][j][:data_dim] if j in data[times[tid]] else [1e9]*data_dim for j in neighbor_idx]
                            for tid in range(tid0, tid1+1, self.frameskip)
                        ])
                    # One sample per target; everyone else becomes a neighbor.
                    for i in range(len(idx)):
                        hist = agents[:self.ob_horizon, i]                    # ob_horizon x 6
                        future = agents[self.ob_horizon:valid_horizon, i, :2] # pred_horizon x 2
                        neighbor = agents[:valid_horizon, [d for d in range(agents.shape[1]) if d != i]]  # L x (N-1) x 6
                        traj.append((hist, future, neighbor))
            # Advance the window start; consecutive samples overlap by at
            # most horizon - max_overlap frames.
            tid0 += self.max_overlap
        items = []
        # Cast everything to float32 before shipping back across processes.
        for hist, future, neighbor in traj:
            hist = np.float32(hist)
            future = np.float32(future)
            neighbor = np.float32(neighbor)
            items.append((hist, future, neighbor))
        return items

    # Extend raw (x, y, group) records with finite-difference velocity and
    # acceleration, insert placeholder frames for gaps in the timeline, and
    # drop agents that appear in only a single frame.
    def extend(self, data, frameskip):
        """Return ``data`` with each agent record grown from
        ``[x, y, group]`` to ``[x, y, vx, vy, ax, ay, group]``.

        Raises:
            ValueError: if frame intervals are not multiples of the smallest
                observed interval (inconsistent framerate).
        """
        times = np.sort(list(data.keys()))
        # All distinct frame intervals; the smallest is the base timestep.
        dts = np.unique(times[1:] - times[:-1])
        dt = dts.min()
        if np.any(dts % dt != 0):
            raise ValueError("Inconsistent frame interval:", dts)
        i = 0
        while i < len(times)-1:
            if times[i+1] - times[i] != dt:
                # Fill timeline gaps with (initially empty) frames.
                times = np.insert(times, i+1, times[i]+dt)
            i += 1
        # Drop agents that appear in neither the previous nor the next
        # frameskip-step frame (they exist for only one sampled frame).
        for tid, t in enumerate(times):
            removed = []
            if t not in data: data[t] = {}
            for idx in data[t].keys():
                t0 = times[tid-frameskip] if tid >= frameskip else None
                t1 = times[tid+frameskip] if tid+frameskip < len(times) else None
                if (t0 is None or t0 not in data or idx not in data[t0]) and \
                   (t1 is None or t1 not in data or idx not in data[t1]):
                    removed.append(idx)
            for idx in removed:
                data[t].pop(idx)
        # Velocity: backward finite difference over frameskip steps, so no
        # future information leaks into the feature.
        for tid in range(len(times)-frameskip):
            t0 = times[tid]
            t1 = times[tid+frameskip]
            if t1 not in data or t0 not in data: continue
            for i, __ in data[t1].items():
                if i not in data[t0]: continue
                x0 = data[t0][i][0]
                y0 = data[t0][i][1]
                x1 = data[t1][i][0]
                y1 = data[t1][i][1]
                vx, vy = x1-x0, y1-y0
                data[t1][i].insert(2, vx)
                data[t1][i].insert(3, vy)
                # First frame an agent appears in copies the next velocity.
                # NOTE(review): checks times[tid-1] rather than
                # times[tid-frameskip]; with frameskip > 1 these differ --
                # confirm which frame is the intended "previous" one.
                if tid < frameskip or i not in data[times[tid-1]]:
                    data[t0][i].insert(2, vx)
                    data[t0][i].insert(3, vy)
        # Acceleration: difference of consecutive (frameskip-step) velocities.
        for tid in range(len(times)-frameskip):
            t_1 = None if tid < frameskip else times[tid-frameskip]
            t0 = times[tid]
            t1 = times[tid+frameskip]
            if t1 not in data or t0 not in data: continue
            for i, item in data[t1].items():
                if i not in data[t0]: continue
                vx0 = data[t0][i][2]
                vy0 = data[t0][i][3]
                vx1 = data[t1][i][2]
                vy1 = data[t1][i][3]
                ax, ay = vx1-vx0, vy1-vy0
                data[t1][i].insert(4, ax)
                data[t1][i].insert(5, ay)
                if t_1 is None or i not in data[t_1]:
                    # First appearing frame copies the next acceleration.
                    data[t0][i].insert(4, ax)
                    data[t0][i].insert(5, ay)
        return data

    # Parse whitespace-separated rows "t agent_id x y [group]" into
    # { time: {agent_id: [x, y, group]} }; group is split on "/" to allow
    # multiple group tags per agent, or None when the column is absent.
    def load_traj(self, file):
        """Read raw trajectory rows from an open file-like object."""
        data = {}
        for row in file.readlines():
            item = row.split()
            if not item: continue
            t = int(float(item[0]))
            idx = int(float(item[1]))
            x = float(item[2])
            y = float(item[3])
            group = item[4].split("/") if len(item) > 4 else None
            if t not in data:
                data[t] = {}
            data[t][idx] = [x, y, group]
        return data