Skip to content

Commit 5544d95

Browse files
authored
[Fix] Update NavDP Training Pipeline and Support Latest InternData-N1 Format (#305)
* [fix] fix the navdp training code and support latest interndata-n1 format * [fix] fix the navdp training code and support latest interndata-n1 format * [fix] fix the navdp training code and support latest interndata-n1 format
1 parent 1d8d078 commit 5544d95

4 files changed

Lines changed: 108 additions & 87 deletions

File tree

internnav/dataset/navdp_lerobot_dataset.py

Lines changed: 95 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datetime import datetime
66

77
import cv2
8+
import jsonlines
89
import numpy as np
910
import open3d as o3d
1011
import pandas as pd
@@ -42,6 +43,7 @@ def __init__(
4243
scene_data_scale=1.0,
4344
trajectory_data_scale=1.0,
4445
pixel_channel=7,
46+
action_dim=3,
4547
debug=False,
4648
preload=False,
4749
random_digit=False,
@@ -54,8 +56,9 @@ def __init__(
5456
self.scene_scale_size = scene_data_scale
5557
self.trajectory_data_scale = trajectory_data_scale
5658
self.predict_size = predict_size
59+
self.action_dim = action_dim
5760
self.debug = debug
58-
self.trajectory_dirs = []
61+
5962
self.trajectory_data_dir = []
6063
self.trajectory_rgb_path = []
6164
self.trajectory_depth_path = []
@@ -74,65 +77,84 @@ def __init__(
7477
select_scene_dirs = all_scene_dirs[
7578
np.arange(0, all_scene_dirs.shape[0], 1 / self.scene_scale_size).astype(np.int32)
7679
]
77-
for scene_dir in select_scene_dirs:
78-
all_traj_dirs = np.array([p for p in os.listdir(os.path.join(root_dirs, group_dir, scene_dir))])
79-
select_traj_dirs = all_traj_dirs[
80-
np.arange(0, all_traj_dirs.shape[0], 1 / self.trajectory_data_scale).astype(np.int32)
81-
]
82-
for traj_dir in tqdm(select_traj_dirs):
83-
entire_task_dir = os.path.join(root_dirs, group_dir, scene_dir, traj_dir)
84-
rgb_dir = os.path.join(entire_task_dir, "videos/chunk-000/observation.images.rgb/")
85-
depth_dir = os.path.join(entire_task_dir, "videos/chunk-000/observation.images.depth/")
86-
data_path = os.path.join(
87-
entire_task_dir, 'data/chunk-000/episode_000000.parquet'
88-
) # intrinsic, extrinsic, cam_traj, path
89-
afford_path = os.path.join(entire_task_dir, 'data/chunk-000/path.ply')
90-
rgbs_length = len([p for p in os.listdir(rgb_dir)])
91-
depths_length = len([p for p in os.listdir(depth_dir)])
92-
93-
rgbs_path = []
94-
depths_path = []
95-
if depths_length != rgbs_length:
96-
continue
97-
for i in range(rgbs_length):
98-
rgbs_path.append(os.path.join(rgb_dir, "%d.jpg" % i))
99-
depths_path.append(os.path.join(depth_dir, "%d.png" % i))
100-
if os.path.exists(data_path) is False:
101-
continue
102-
self.trajectory_dirs.append(entire_task_dir)
103-
self.trajectory_data_dir.append(data_path)
104-
self.trajectory_rgb_path.append(rgbs_path)
105-
self.trajectory_depth_path.append(depths_path)
106-
self.trajectory_afford_path.append(afford_path)
80+
81+
for scene_dir in tqdm(select_scene_dirs):
82+
chunk_name = os.listdir(os.path.join(root_dirs, group_dir, scene_dir, 'data'))[0]
83+
data_dir = os.path.join(root_dirs, group_dir, scene_dir, f'data/{chunk_name}')
84+
afford_dir = os.path.join(root_dirs, group_dir, scene_dir, 'meta/pointcloud.ply')
85+
with jsonlines.open(
86+
os.path.join(root_dirs, group_dir, scene_dir, 'meta/episodes_stats.jsonl'), 'r'
87+
) as reader:
88+
episode_info = list(reader)
89+
rgb_dir = os.path.join(
90+
root_dirs, group_dir, scene_dir, f"videos/{chunk_name}/observation.images.rgb/"
91+
)
92+
rgb_paths = [os.path.join(rgb_dir, p) for p in sorted(os.listdir(rgb_dir))]
93+
94+
depth_dir = os.path.join(
95+
root_dirs, group_dir, scene_dir, f"videos/{chunk_name}/observation.images.depth/"
96+
)
97+
depth_paths = [os.path.join(depth_dir, p) for p in sorted(os.listdir(depth_dir))]
98+
99+
data_paths = [os.path.join(data_dir, p) for p in sorted(os.listdir(data_dir))]
100+
101+
for episode_idx, episode in enumerate(episode_info):
102+
image_start_index = episode['image_index']['min']
103+
image_end_index = episode['image_index']['max']
104+
episode_rgb_path = np.array(rgb_paths)[image_start_index : image_end_index + 1].tolist()
105+
episode_depth_path = np.array(depth_paths)[image_start_index : image_end_index + 1].tolist()
106+
107+
try:
108+
self.trajectory_data_dir.append(data_paths[episode_idx])
109+
self.trajectory_rgb_path.append(episode_rgb_path)
110+
self.trajectory_depth_path.append(episode_depth_path)
111+
self.trajectory_afford_path.append(afford_dir)
112+
except Exception as e:
113+
import pdb
114+
115+
print(f"Error processing episode {episode_idx}: {e}")
116+
pdb.set_trace()
107117

108118
save_dict = {
109-
'trajectory_dirs': self.trajectory_dirs,
110119
'trajectory_data_dir': self.trajectory_data_dir,
111120
'trajectory_rgb_path': self.trajectory_rgb_path,
112121
'trajectory_depth_path': self.trajectory_depth_path,
113122
'trajectory_afford_path': self.trajectory_afford_path,
114123
}
115124
with open(preload_path, 'w') as f:
116125
json.dump(save_dict, f, indent=4)
126+
127+
# replicate the data 50 times
128+
self.trajectory_data_dir = self.trajectory_data_dir * 50
129+
self.trajectory_rgb_path = self.trajectory_rgb_path * 50
130+
self.trajectory_depth_path = self.trajectory_depth_path * 50
131+
self.trajectory_afford_path = self.trajectory_afford_path * 50
117132
else:
118133
load_dict = json.load(open(preload_path, 'r'))
119-
self.trajectory_dirs = load_dict['trajectory_dirs'] * 50
120134
self.trajectory_data_dir = load_dict['trajectory_data_dir'] * 50
121135
self.trajectory_rgb_path = load_dict['trajectory_rgb_path'] * 50
122136
self.trajectory_depth_path = load_dict['trajectory_depth_path'] * 50
123137
self.trajectory_afford_path = load_dict['trajectory_afford_path'] * 50
124138

125139
def __len__(self):
126-
return len(self.trajectory_dirs)
140+
return len(self.trajectory_data_dir)
127141

128142
def load_image(self, image_url):
129-
image = Image.open(image_url)
130-
image = np.array(image, np.uint8)
143+
try:
144+
image = Image.open(image_url)
145+
image = np.array(image, np.uint8)
146+
except Exception as e:
147+
print(f"Error loading image {image_url}: {e}")
148+
image = np.zeros((self.image_size, self.image_size, 3), dtype=np.uint8)
131149
return image
132150

133151
def load_depth(self, depth_url):
134-
depth = Image.open(depth_url)
135-
depth = np.array(depth, np.uint16)
152+
try:
153+
depth = Image.open(depth_url)
154+
depth = np.array(depth, np.uint16)
155+
except Exception as e:
156+
print(f"Error loading depth {depth_url}: {e}")
157+
depth = np.zeros((self.image_size, self.image_size), dtype=np.uint16)
136158
return depth
137159

138160
def load_pointcloud(self, pcd_url):
@@ -176,39 +198,19 @@ def process_data_parquet(self, index):
176198
camera_intrinsic = np.vstack(np.array(df['observation.camera_intrinsic'].tolist()[0])).reshape(3, 3)
177199
camera_extrinsic = np.vstack(np.array(df['observation.camera_extrinsic'].tolist()[0])).reshape(4, 4)
178200
trajectory_length = len(df['action'].tolist())
179-
camera_trajectory = np.array([np.stack(frame) for frame in df['action']], dtype=np.float64)
201+
camera_trajectory = np.array([np.stack(frame) for frame in df['action']], dtype=np.float64).reshape(-1, 4, 4)
180202
return camera_intrinsic, camera_extrinsic, camera_trajectory, trajectory_length
181203

182-
def process_path_points(self, index):
183-
trajectory_pcd = self.load_pointcloud(self.trajectory_afford_path[index])
184-
trajectory_color = np.array(trajectory_pcd.colors)
185-
color_distance = np.abs(trajectory_color - np.array([0, 0, 0])).sum(
186-
axis=-1
187-
) # sometimes, the path are saved as black points
204+
def process_obstacle_points(self, index):
205+
scene_pcd = self.load_pointcloud(self.trajectory_afford_path[index])
206+
scene_color = np.array(scene_pcd.colors)
207+
scene_points = np.array(scene_pcd.points)
208+
color_distance = np.abs(scene_color - np.array([0, 0, 0.5])).sum(axis=-1)
188209
select_index = np.where(color_distance < 0.05)[0]
189-
trajectory_path = o3d.geometry.PointCloud()
190-
trajectory_path.points = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.points)[select_index])
191-
trajectory_path.colors = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.colors)[select_index])
192-
return np.array(trajectory_path.points), trajectory_path
193-
194-
def process_obstacle_points(self, index, path_points):
195-
trajectory_pcd = self.load_pointcloud(self.trajectory_afford_path[index])
196-
trajectory_color = np.array(trajectory_pcd.colors)
197-
trajectory_points = np.array(trajectory_pcd.points)
198-
color_distance = np.abs(trajectory_color - np.array([0, 0, 0.5])).sum(axis=-1) # the obstacles are save in blue
199-
path_lower_bound = path_points.min(axis=0)
200-
path_upper_bound = path_points.max(axis=0)
201-
condition_x = (trajectory_points[:, 0] >= path_lower_bound[0] - 2.0) & (
202-
trajectory_points[:, 0] <= path_upper_bound[0] + 2.0
203-
)
204-
condition_y = (trajectory_points[:, 1] >= path_lower_bound[1] - 2.0) & (
205-
trajectory_points[:, 1] <= path_upper_bound[1] + 2.0
206-
)
207-
select_index = np.where((color_distance < 0.05) & condition_x & condition_y)[0]
208-
trajectory_obstacle = o3d.geometry.PointCloud()
209-
trajectory_obstacle.points = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.points)[select_index])
210-
trajectory_obstacle.colors = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.colors)[select_index])
211-
return np.array(trajectory_obstacle.points), trajectory_obstacle
210+
scene_obstacle = o3d.geometry.PointCloud()
211+
scene_obstacle.points = o3d.utility.Vector3dVector(scene_points[select_index])
212+
scene_obstacle.colors = o3d.utility.Vector3dVector(scene_color[select_index])
213+
return np.array(scene_obstacle.points), scene_obstacle
212214

213215
def process_memory(self, rgb_paths, depth_paths, start_step, memory_digit=1):
214216
memory_index = np.arange(start_step - (self.memory_size - 1) * memory_digit, start_step + 1, memory_digit)
@@ -220,8 +222,12 @@ def process_memory(self, rgb_paths, depth_paths, start_step, memory_digit=1):
220222
return context_image, context_depth, memory_index
221223

222224
def process_pixel_goal(self, image_url, target_point, camera_intrinsic, camera_extrinsic):
223-
image = Image.open(image_url)
224-
image = np.array(image, np.uint8)
225+
try:
226+
image = Image.open(image_url)
227+
image = np.array(image, np.uint8)
228+
except Exception as e:
229+
print(f"Error loading image {image_url}: {e}")
230+
image = np.zeros((self.image_size, self.image_size, 3), dtype=np.uint8)
225231
resize_image = self.process_image(image_url)
226232

227233
coordinate = np.array([-target_point[1], target_point[0], camera_extrinsic[2, 3] * 0.8])
@@ -422,10 +428,7 @@ def __getitem__(self, index):
422428
trajectory_length,
423429
) = self.process_data_parquet(index)
424430

425-
trajectory_path_points, trajectory_path_pcd = self.process_path_points(index)
426-
trajectory_obstacle_points, trajectory_obstacle_pcd = self.process_obstacle_points(
427-
index, trajectory_path_points
428-
)
431+
trajectory_obstacle_points, trajectory_obstacle_pcd = self.process_obstacle_points(index)
429432

430433
if self.prior_sample:
431434
pixel_start_choice, target_choice = self.rank_steps()
@@ -435,7 +438,6 @@ def __getitem__(self, index):
435438
target_choice = np.random.randint(pixel_start_choice + 1, trajectory_length - 1)
436439
memory_start_choice = np.random.randint(pixel_start_choice, target_choice)
437440

438-
# target_extrinsic = trajectory_extrinsics[target_choice]
439441
if self.random_digit:
440442
memory_digit = np.random.randint(2, 8)
441443
pred_digit = memory_digit
@@ -458,6 +460,7 @@ def __getitem__(self, index):
458460
) = self.process_actions(
459461
trajectory_extrinsics, trajectory_base_extrinsic, memory_start_choice, target_choice, pred_digit=pred_digit
460462
)
463+
461464
# convert the xyz points into xy-theta points
462465
init_vector = target_local_points[1] - target_local_points[0]
463466
target_xyt_actions = self.xyz_to_xyt(target_local_points, init_vector)
@@ -521,6 +524,19 @@ def __getitem__(self, index):
521524
pred_actions = (pred_actions[1:] - pred_actions[:-1]) * 4.0
522525
augment_actions = (augment_actions[1:] - augment_actions[:-1]) * 4.0
523526

527+
pred_actions = np.pad(
528+
pred_actions,
529+
((0, 0), (0, self.action_dim - pred_actions.shape[-1])),
530+
mode='constant',
531+
constant_values=(0, 0),
532+
)
533+
augment_actions = np.pad(
534+
augment_actions,
535+
((0, 0), (0, self.action_dim - augment_actions.shape[-1])),
536+
mode='constant',
537+
constant_values=(0, 0),
538+
)
539+
524540
# Summarize avg time of batch
525541
end_time = time.time()
526542
self.item_cnt += 1
@@ -573,13 +589,13 @@ def navdp_collate_fn(batch):
573589
if __name__ == "__main__":
574590
os.makedirs("./navdp_dataset_test/", exist_ok=True)
575591
dataset = NavDP_Base_Datset(
576-
"/shared/smartbot_new/liuyu/vln-n1-minival/",
577-
"./navdp_dataset_test/dataset_lerobot.json",
592+
"/mnt/data/liuyu/InternDate-N1-v05/vln-n1",
593+
"./navdp_dataset_test/dataset_lerobot_v05_with_interiorgs.json",
578594
8,
579595
24,
580596
224,
581-
trajectory_data_scale=0.1,
582-
scene_data_scale=0.1,
597+
trajectory_data_scale=1.0,
598+
scene_data_scale=1.0,
583599
preload=False,
584600
)
585601

internnav/model/basemodel/navdp/navdp_policy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,10 @@ def forward(self, goal_point, goal_image, goal_pixel, input_images, input_depths
266266
noise_pred_mg,
267267
cr_label_pred,
268268
cr_augment_pred,
269-
[ng_noise, mg_noise],
270-
[imagegoal_aux_pred, pixelgoal_aux_pred],
269+
ng_noise,
270+
mg_noise,
271+
imagegoal_aux_pred,
272+
pixelgoal_aux_pred,
271273
)
272274

273275
def _get_device(self):

internnav/trainer/navdp_trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
7777
batch_label_critic = inputs["batch_label_critic"]
7878
batch_augment_critic = inputs["batch_augment_critic"]
7979

80-
pred_ng, pred_mg, critic_pred, augment_pred, noise, aux_pred = model(
80+
pred_ng, pred_mg, critic_pred, augment_pred, ng_noise, mg_noise, imagegoal_aux_pred, pixelgoal_aux_pred = model(
8181
inputs_on_device["batch_pg"],
8282
inputs_on_device["batch_ig"],
8383
inputs_on_device["batch_tg"],
@@ -87,11 +87,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
8787
inputs_on_device["batch_augments"],
8888
)
8989

90-
ng_action_loss = (pred_ng - noise[0]).square().mean()
91-
mg_action_loss = (pred_mg - noise[1]).square().mean()
90+
ng_action_loss = (pred_ng - ng_noise).square().mean()
91+
mg_action_loss = (pred_mg - mg_noise).square().mean()
9292
aux_loss = (
93-
0.5 * (inputs_on_device["batch_pg"] - aux_pred[0]).square().mean()
94-
+ 0.5 * (inputs_on_device["batch_pg"] - aux_pred[1]).square().mean()
93+
0.5 * (inputs_on_device["batch_pg"] - imagegoal_aux_pred).square().mean()
94+
+ 0.5 * (inputs_on_device["batch_pg"] - pixelgoal_aux_pred).square().mean()
9595
)
9696
action_loss = 0.5 * mg_action_loss + 0.5 * ng_action_loss
9797
critic_loss = (critic_pred - batch_label_critic).square().mean() + (
@@ -104,7 +104,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
104104
'pred_mg': pred_mg,
105105
'critic_pred': critic_pred,
106106
'augment_pred': augment_pred,
107-
'noise': noise,
107+
'noise': [ng_noise, mg_noise],
108108
'loss': loss,
109109
'ng_action_loss': ng_action_loss,
110110
'mg_action_loss': mg_action_loss,

scripts/train/base_train/train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ def main(config, model_class, model_config_class):
127127
model = model_class.from_pretrained(pretrained_model_name_or_path=config.il.ckpt_to_load, config=model_cfg)
128128
if config.model_name == "navdp":
129129
model.to(device)
130+
for name, param in model.named_parameters():
131+
if 'mask_token' in name:
132+
param.requires_grad = False
130133
# Check that all parameters and buffers are on the correct device
131134
for name, param in model.named_parameters():
132135
if param.device != device:

0 commit comments

Comments
 (0)