Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions examples/datasets/colmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(
test_every: int = 8,
depth_dir_name: Optional[str] = None,
normal_dir_name: Optional[str] = None,
dynamic_mask_dir_name: Optional[str] = None,
sky_mask_dir_name: Optional[str] = None,
):
self.data_dir = data_dir
self.factor = factor
Expand Down Expand Up @@ -402,6 +404,51 @@ def __init__(
path = os.path.join(normal_dir, base_name + ".png")
self.normal_paths.append(path)

# Process dynamic mask paths.
# We primarily match the image filename (same basename + extension), but also
# fall back to {basename}.png which is a common convention.
if dynamic_mask_dir_name is None:
print(
"[Parser] No dynamic mask directory name provided. Skipping dynamic masks."
)
self.dynamic_mask_paths = None
else:
dynamic_mask_dir = os.path.join(self.data_dir, dynamic_mask_dir_name)
self.dynamic_mask_paths = []
print(f"[Parser] Building dynamic mask paths from: {dynamic_mask_dir}")
for img_name in self.image_names:
base_name, _ext = os.path.splitext(img_name)
candidate_same_ext = os.path.join(dynamic_mask_dir, img_name)
candidate_png = os.path.join(dynamic_mask_dir, base_name + ".png")
if os.path.exists(candidate_same_ext):
self.dynamic_mask_paths.append(candidate_same_ext)
elif os.path.exists(candidate_png):
self.dynamic_mask_paths.append(candidate_png)
else:
self.dynamic_mask_paths.append(candidate_same_ext)

# Process sky mask paths.
# We primarily match the image filename (same basename + extension), but also
# fall back to {basename}.png which is a common convention.
if sky_mask_dir_name is None:
print("[Parser] No sky mask directory name provided. Skipping sky masks.")
self.sky_mask_paths = None
else:
sky_mask_dir = os.path.join(self.data_dir, sky_mask_dir_name)
self.sky_mask_paths = []
print(f"[Parser] Building sky mask paths from: {sky_mask_dir}")
for img_name in self.image_names:
base_name, ext = os.path.splitext(img_name)
candidate_same_ext = os.path.join(sky_mask_dir, img_name)
candidate_png = os.path.join(sky_mask_dir, base_name + ".png")
# Prefer same extension if it exists, else try png, else keep same-ext.
if os.path.exists(candidate_same_ext):
self.sky_mask_paths.append(candidate_same_ext)
elif os.path.exists(candidate_png):
self.sky_mask_paths.append(candidate_png)
else:
self.sky_mask_paths.append(candidate_same_ext)


class Dataset:
"""A simple dataset class with optional data preloading."""
Expand Down Expand Up @@ -469,6 +516,8 @@ def _load_base_sample(self, index: int) -> Dict[str, Any]:

depth_data = None
normal_data = None
dynamic_mask_data = None
sky_mask_data = None

if self.parser.depth_paths is not None:
depth_path = self.parser.depth_paths[index]
Expand All @@ -483,6 +532,24 @@ def _load_base_sample(self, index: int) -> Dict[str, Any]:
normal_data = imageio.imread(normal_path)[..., :3] # Known to be .png
except Exception as e:
print(f"Warning: Could not load normal {normal_path}: {e}")
if getattr(self.parser, "dynamic_mask_paths", None) is not None:
dynamic_mask_path = self.parser.dynamic_mask_paths[index]
try:
dyn_img = cv2.imread(dynamic_mask_path, cv2.IMREAD_GRAYSCALE)
if dyn_img is None:
raise RuntimeError("cv2.imread returned None")
dynamic_mask_data = dyn_img > 0
except Exception as e:
print(f"Warning: Could not load dynamic mask {dynamic_mask_path}: {e}")
if getattr(self.parser, "sky_mask_paths", None) is not None:
sky_mask_path = self.parser.sky_mask_paths[index]
try:
sky_img = cv2.imread(sky_mask_path, cv2.IMREAD_GRAYSCALE)
if sky_img is None:
raise RuntimeError("cv2.imread returned None")
sky_mask_data = sky_img > 0
except Exception as e:
print(f"Warning: Could not load sky mask {sky_mask_path}: {e}")

if len(params) > 0:
# Images are distorted. Undistort them.
Expand All @@ -502,12 +569,28 @@ def _load_base_sample(self, index: int) -> Dict[str, Any]:
if normal_data is not None:
normal_data = cv2.remap(normal_data, mapx, mapy, cv2.INTER_LINEAR)
normal_data = normal_data[y : y + h, x : x + w]
# Apply to dynamic mask (Note: INTER_NEAREST)
if dynamic_mask_data is not None:
dynamic_mask_u8 = dynamic_mask_data.astype(np.uint8) * 255
dynamic_mask_u8 = cv2.remap(
dynamic_mask_u8, mapx, mapy, cv2.INTER_NEAREST
)
dynamic_mask_u8 = dynamic_mask_u8[y : y + h, x : x + w]
dynamic_mask_data = dynamic_mask_u8 > 0
# Apply to sky mask (Note: INTER_NEAREST)
if sky_mask_data is not None:
sky_mask_u8 = sky_mask_data.astype(np.uint8) * 255
sky_mask_u8 = cv2.remap(sky_mask_u8, mapx, mapy, cv2.INTER_NEAREST)
sky_mask_u8 = sky_mask_u8[y : y + h, x : x + w]
sky_mask_data = sky_mask_u8 > 0

return {
"image": image,
"depth": depth_data,
"normal": normal_data,
"mask": mask,
"dynamic_mask": dynamic_mask_data,
"sky_mask": sky_mask_data,
"K": K,
"camtoworld": camtoworlds,
}
Expand All @@ -532,6 +615,16 @@ def _convert_to_tensors(
if sample["mask"] is not None
else None
)
tensor_sample["dynamic_mask"] = (
torch.from_numpy(sample["dynamic_mask"]).bool().to(device)
if sample.get("dynamic_mask") is not None
else None
)
tensor_sample["sky_mask"] = (
torch.from_numpy(sample["sky_mask"]).bool().to(device)
if sample.get("sky_mask") is not None
else None
)
tensor_sample["K"] = torch.from_numpy(sample["K"]).float().to(device)
tensor_sample["camtoworld"] = (
torch.from_numpy(sample["camtoworld"]).float().to(device)
Expand All @@ -547,6 +640,8 @@ def _prepare_sample(
depth = sample["depth"]
normal = sample["normal"]
mask = sample["mask"]
dynamic_mask = sample.get("dynamic_mask")
sky_mask = sample.get("sky_mask")
K = sample["K"]
camtoworlds = sample["camtoworld"]

Expand All @@ -568,6 +663,10 @@ def _prepare_sample(
normal = normal[y_slice, x_slice]
if mask is not None:
mask = mask[y_slice, x_slice]
if dynamic_mask is not None:
dynamic_mask = dynamic_mask[y_slice, x_slice]
if sky_mask is not None:
sky_mask = sky_mask[y_slice, x_slice]
K = K.clone()
K[0, 2] -= x
K[1, 2] -= y
Expand All @@ -592,6 +691,10 @@ def _prepare_sample(
}
if mask is not None:
data["mask"] = mask
if dynamic_mask is not None:
data["dynamic_mask"] = dynamic_mask
if sky_mask is not None:
data["sky_mask"] = sky_mask

return data

Expand Down
Loading