-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocessing.py
More file actions
44 lines (36 loc) · 1.64 KB
/
preprocessing.py
File metadata and controls
44 lines (36 loc) · 1.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import shutil
import tempfile

import chex
import jax
import jax.numpy as jnp
import numpy as np
import ott
from flax import struct
from gymnax.environments import environment
from gymnax.environments import spaces
from jax import lax
from natsort import natsorted

from purejaxrl.utils import find_files
# Attach k-means cluster locations ("km_clusters") to every dataset file in
# DATASET_DIR matching *k{k}_*.npy, for each k in CLUSTER_COUNTS.
# Files that already contain "km_clusters" are skipped, so re-running the
# script only processes files it has not handled yet.

CLUSTER_COUNTS = (3, 5, 8, 10)  # k values whose datasets get clustered
DATASET_DIR = "./datasets/"
KMEANS_SEED = 999  # fixed PRNG seed so the stored clusters are reproducible


def _compute_km_clusters(dataset, k):
    """Return the depot followed by ``k`` k-means centroids of all start/end points.

    Args:
        dataset: dict loaded from a dataset ``.npy`` file; the
            ``"start_points"``, ``"end_points"`` and ``"depot"`` entries are read.
            Presumably the point arrays are (num_objects, dim) and the depot
            is (dim,) — TODO confirm against the dataset generator.
        k: number of k-means clusters to fit.

    Returns:
        A NumPy array whose row 0 is the depot and whose remaining ``k`` rows
        are the k-means centroids. The result is converted to NumPy so that the
        re-saved file does not pickle a JAX device array, which would make
        JAX a requirement just to load the dataset.
    """
    object_starts = jnp.array(dataset["start_points"])
    object_ends = jnp.array(dataset["end_points"])
    depot = jnp.array(dataset["depot"])

    # Start and end points are clustered together as one point cloud.
    start_and_end_points = jnp.concatenate([object_starts, object_ends], axis=0)
    kmout = ott.tools.k_means.k_means(
        start_and_end_points, k, rng=jax.random.PRNGKey(KMEANS_SEED)
    )
    # Prepend the depot so it is always the first cluster location.
    clusters = jnp.concatenate([jnp.array([depot]), kmout.centroids], axis=0)
    return np.asarray(clusters)


def _save_dataset_atomically(path, dataset):
    """Pickle ``dataset`` to ``path`` without ever leaving a truncated file.

    The data is written to a temporary file in the same directory, which is
    then renamed over ``path`` with ``os.replace``. Opening ``path`` with "wb"
    directly would truncate the only copy before the new contents exist, so an
    interrupted write would destroy the dataset.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as fl:
            np.save(fl, dataset, allow_pickle=True)
        if os.path.exists(path):
            # mkstemp creates the file with 0600 permissions; keep the original mode.
            shutil.copymode(path, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temp file; the original dataset is still intact.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


for k in CLUSTER_COUNTS:
    dataset_files = natsorted(find_files(f"*k{k}_*.npy", DATASET_DIR))
    for idx, dataset_file in enumerate(dataset_files, start=1):
        # NOTE: allow_pickle=True executes arbitrary pickled code on load;
        # only run this script on trusted dataset files.
        with open(dataset_file, "rb") as fl:
            dataset = np.load(fl, allow_pickle=True).item()
        print(f"{idx}/{len(dataset_files)}: {dataset_file}")
        if "km_clusters" in dataset:
            continue  # already processed on a previous run

        dataset["km_clusters"] = _compute_km_clusters(dataset, k)
        _save_dataset_atomically(dataset_file, dataset)

        # Read the file back to confirm the new entry was persisted.
        with open(dataset_file, "rb") as fl:
            reloaded = np.load(fl, allow_pickle=True).item()
        if "km_clusters" not in reloaded:
            raise RuntimeError(f"km_clusters missing after saving {dataset_file}")