From c17c9531fe3f7b8dc863e6475751d82f61fbea40 Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Wed, 28 Jan 2026 21:36:01 +0800 Subject: [PATCH 1/6] [Feat] Add a loss term to remove sky-related Gaussians. Continue pruning overly large Gaussians periodically after densification ends --- examples/datasets/colmap.py | 50 ++++++++++++++++++++++++++++++++++++ examples/extended_trainer.py | 29 +++++++++++++++++++++ gsplat/strategy/improved.py | 33 ++++++++++++++++++++++++ 3 files changed, 112 insertions(+) diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 7650f4a..32a2573 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -64,6 +64,7 @@ def __init__( test_every: int = 8, depth_dir_name: Optional[str] = None, normal_dir_name: Optional[str] = None, + sky_mask_dir_name: Optional[str] = None, ): self.data_dir = data_dir self.factor = factor @@ -402,6 +403,28 @@ def __init__( path = os.path.join(normal_dir, base_name + ".png") self.normal_paths.append(path) + # Process sky mask paths. + # We primarily match the image filename (same basename + extension), but also + # fall back to {basename}.png which is a common convention. + if sky_mask_dir_name is None: + print("[Parser] No sky mask directory name provided. Skipping sky masks.") + self.sky_mask_paths = None + else: + sky_mask_dir = os.path.join(self.data_dir, sky_mask_dir_name) + self.sky_mask_paths = [] + print(f"[Parser] Building sky mask paths from: {sky_mask_dir}") + for img_name in self.image_names: + base_name, ext = os.path.splitext(img_name) + candidate_same_ext = os.path.join(sky_mask_dir, img_name) + candidate_png = os.path.join(sky_mask_dir, base_name + ".png") + # Prefer same extension if it exists, else try png, else keep same-ext. + if os.path.exists(candidate_same_ext): + self.sky_mask_paths.append(candidate_same_ext) + elif os.path.exists(candidate_png): + self.sky_mask_paths.append(candidate_png) + else: + self.sky_mask_paths.append(candidate_same_ext) + class Dataset: """A simple dataset class with optional data preloading.""" @@ -469,6 +492,7 @@ def _load_base_sample(self, index: int) -> Dict[str, Any]: depth_data = None normal_data = None + sky_mask_data = None if self.parser.depth_paths is not None: depth_path = self.parser.depth_paths[index] @@ -483,6 +507,15 @@ def _load_base_sample(self, index: int) -> Dict[str, Any]: normal_data = imageio.imread(normal_path)[..., :3] # Known to be .png except Exception as e: print(f"Warning: Could not load normal {normal_path}: {e}") + if getattr(self.parser, "sky_mask_paths", None) is not None: + sky_mask_path = self.parser.sky_mask_paths[index] + try: + sky_img = cv2.imread(sky_mask_path, cv2.IMREAD_GRAYSCALE) + if sky_img is None: + raise RuntimeError("cv2.imread returned None") + sky_mask_data = sky_img > 0 + except Exception as e: + print(f"Warning: Could not load sky mask {sky_mask_path}: {e}") if len(params) > 0: # Images are distorted. Undistort them. @@ -502,12 +535,19 @@ def _load_base_sample(self, index: int) -> Dict[str, Any]: if normal_data is not None: normal_data = cv2.remap(normal_data, mapx, mapy, cv2.INTER_LINEAR) normal_data = normal_data[y : y + h, x : x + w] + # Apply to sky mask (Note: INTER_NEAREST) + if sky_mask_data is not None: + sky_mask_u8 = sky_mask_data.astype(np.uint8) * 255 + sky_mask_u8 = cv2.remap(sky_mask_u8, mapx, mapy, cv2.INTER_NEAREST) + sky_mask_u8 = sky_mask_u8[y : y + h, x : x + w] + sky_mask_data = sky_mask_u8 > 0 return { "image": image, "depth": depth_data, "normal": normal_data, "mask": mask, + "sky_mask": sky_mask_data, "K": K, "camtoworld": camtoworlds, } @@ -532,6 +572,11 @@ def _convert_to_tensors( if sample["mask"] is not None else None ) + tensor_sample["sky_mask"] = ( + torch.from_numpy(sample["sky_mask"]).bool().to(device) + if sample.get("sky_mask") is not None + else None + ) tensor_sample["K"] = torch.from_numpy(sample["K"]).float().to(device) tensor_sample["camtoworld"] = ( torch.from_numpy(sample["camtoworld"]).float().to(device) @@ -547,6 +592,7 @@ def _prepare_sample( depth = sample["depth"] normal = sample["normal"] mask = sample["mask"] + sky_mask = sample.get("sky_mask") K = sample["K"] camtoworlds = sample["camtoworld"] @@ -568,6 +614,8 @@ def _prepare_sample( normal = normal[y_slice, x_slice] if mask is not None: mask = mask[y_slice, x_slice] + if sky_mask is not None: + sky_mask = sky_mask[y_slice, x_slice] K = K.clone() K[0, 2] -= x K[1, 2] -= y @@ -592,6 +640,8 @@ def _prepare_sample( } if mask is not None: data["mask"] = mask + if sky_mask is not None: + data["sky_mask"] = sky_mask return data diff --git a/examples/extended_trainer.py b/examples/extended_trainer.py index 7feb90a..545fdc5 100644 --- a/examples/extended_trainer.py +++ b/examples/extended_trainer.py @@ -232,6 +232,11 @@ class Config: """Starting iteration for normal consistency regularization""" consistency_normal_loss_activation_step: int = 7000 + # Sky supervision (optional): if set, loads per-image `sky_mask` and adds a loss that + # encourages low alpha in sky pixels (do not occlude sky). + sky_mask_dir_name: Optional[str] = None + sky_loss_weight: float = 0.05 + # Enable camera optimization. pose_opt: bool = False # Learning rate for camera optimization @@ -307,6 +312,7 @@ def rebuild_strategy(self): refine_stop_iter=self.refine_stop_iter, reset_every=self.reset_every, refine_every=self.refine_every, + max_steps=self.max_steps, absgrad=self.absgrad, verbose=self.strategy_verbose, budget=self.budget, @@ -451,6 +457,7 @@ def __init__( test_every=cfg.test_every, depth_dir_name=cfg.depth_dir_name, normal_dir_name=cfg.normal_dir_name, + sky_mask_dir_name=cfg.sky_mask_dir_name, ) self.trainset = Dataset( self.parser, @@ -806,6 +813,12 @@ def train(self): ) image_ids = data["image_id"].to(device) masks = data["mask"].to(device) if "mask" in data else None # [1, H, W] + sky_mask = data.get("sky_mask") + if sky_mask is not None: + sky_mask = sky_mask.to(device) + # Expected shapes: [B,H,W] or [H,W]; keep [B,H,W] + if sky_mask.dim() == 2: + sky_mask = sky_mask.unsqueeze(0) # Optional priors depth_prior = None @@ -1000,6 +1013,20 @@ def train(self): ppisp_reg_loss = self.ppisp_module.get_regularization_loss() loss += ppisp_reg_loss + # sky supervision loss (optional) + sky_loss = None + if cfg.sky_mask_dir_name is not None and cfg.sky_loss_weight > 0.0: + if sky_mask is not None: + # Use accumulated alpha (opacity) to encourage sky pixels to be transparent. + acc = alphas[..., 0].clamp(min=1e-6, max=1.0 - 1e-6) # [B,H,W] + # sky_mask True means "sky/invalid", so we want acc -> 0 there. + sky_pixels = sky_mask + if masks is not None: + sky_pixels = sky_pixels & masks + if sky_pixels.any(): + sky_loss = (-torch.log1p(-acc))[sky_pixels].mean() + loss += cfg.sky_loss_weight * sky_loss + # depth loss if need_depth_prior: median_depths = info.get("render_median") @@ -1185,6 +1212,8 @@ def train(self): self.writer.add_scalar( "train/ppisp_reg_loss", ppisp_reg_loss.item(), step ) + if sky_loss is not None: + self.writer.add_scalar("train/sky_loss", sky_loss.item(), step) if need_depth_prior: self.writer.add_scalar("train/depthloss", depth_loss.item(), step) if ( diff --git a/gsplat/strategy/improved.py b/gsplat/strategy/improved.py index fbf85fa..11269c8 100644 --- a/gsplat/strategy/improved.py +++ b/gsplat/strategy/improved.py @@ -93,6 +93,11 @@ class ImprovedStrategy(Strategy): refine_stop_iter: int = 15000 reset_every: int = 3000 refine_every: int = 100 + # Continue pruning after densification stops. Set to 0 to disable. + post_refine_prune_every: int = 1000 + # Total number of training steps (used to trigger one extra prune at the final step). + # Set to 0 if unknown / not used. + max_steps: int = 30000 absgrad: bool = True verbose: bool = True key_for_gradient: Literal["means2d", "gradient_2dgs"] = "means2d" @@ -253,6 +258,34 @@ def step_post_backward( # ---------------------------------------------------------- if step >= self.refine_stop_iter: + did_prune = False + if self.post_refine_prune_every > 0 and step % self.post_refine_prune_every == 0: + n_prune = self._prune_gs(params, optimizers, state, step) + did_prune = True + if self.verbose and n_prune > 0: + print( + f"[Post-Refine] Step {step}: Pruned {n_prune} GSs. " + f"Now having {len(params['means'])} GSs." + ) + torch.cuda.empty_cache() + + # Prune one extra time at the penultimate step so the final saved checkpoint/ply + # (typically written before the last optimizer/strategy step) already reflects pruning. + is_penultimate_step = self.max_steps > 1 and step == self.max_steps - 2 + if ( + is_penultimate_step + and self.post_refine_prune_every > 0 + and not did_prune + ): + n_prune = self._prune_gs( + params, optimizers, state, step=self.max_steps - 1 + ) + if self.verbose and n_prune > 0: + print( + f"[Post-Refine] Penultimate step {step}: Pruned {n_prune} GSs. " + f"Now having {len(params['means'])} GSs." + ) + torch.cuda.empty_cache() return self._update_state(params, state, info, packed=packed) From 7a67867e88f8f434d428a2cf899d661cf9dd8bdf Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Thu, 29 Jan 2026 10:56:13 +0800 Subject: [PATCH 2/6] [Feat] Mask out sky pixels when computing the color loss --- examples/extended_trainer.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/examples/extended_trainer.py b/examples/extended_trainer.py index 545fdc5..488dbd2 100644 --- a/examples/extended_trainer.py +++ b/examples/extended_trainer.py @@ -997,9 +997,26 @@ def train(self): ) # color loss - l1loss = F.l1_loss(colors, pixels) + colors_for_ssim = colors + if sky_mask is not None: + non_sky = (~sky_mask.bool()).clone() + if non_sky.dim() == 4 and non_sky.shape[-1] == 1: + non_sky = non_sky[..., 0] + if non_sky.any(): + diff = (colors - pixels).abs() + l1loss = diff[non_sky].mean() + colors_for_ssim = torch.where( + non_sky.unsqueeze(-1), colors, pixels + ) + else: + l1loss = F.l1_loss(colors, pixels) + else: + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - fused_ssim( - colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" + colors_for_ssim.permute(0, 3, 1, 2), + pixels.permute(0, 3, 1, 2), + padding="valid", ) loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda # tv loss From d18368cb713a846a45a3c11442ff9d21af9d5d06 Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Thu, 29 Jan 2026 13:49:22 +0800 Subject: [PATCH 3/6] [Feat] Support dynamic masks --- examples/datasets/colmap.py | 53 +++++++++++++++++++++++++++ examples/extended_trainer.py | 70 +++++++++++++++++++++++++++--------- 2 files changed, 107 insertions(+), 16 deletions(-) diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 32a2573..2d805b9 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -64,6 +64,7 @@ def __init__( test_every: int = 8, depth_dir_name: Optional[str] = None, normal_dir_name: Optional[str] = None, + dynamic_mask_dir_name: Optional[str] = None, sky_mask_dir_name: Optional[str] = None, ): self.data_dir = data_dir @@ -403,6 +404,29 @@ def __init__( path = os.path.join(normal_dir, base_name + ".png") self.normal_paths.append(path) + # Process dynamic mask paths. + # We primarily match the image filename (same basename + extension), but also + # fall back to {basename}.png which is a common convention. + if dynamic_mask_dir_name is None: + print( + "[Parser] No dynamic mask directory name provided. Skipping dynamic masks." + ) + self.dynamic_mask_paths = None + else: + dynamic_mask_dir = os.path.join(self.data_dir, dynamic_mask_dir_name) + self.dynamic_mask_paths = [] + print(f"[Parser] Building dynamic mask paths from: {dynamic_mask_dir}") + for img_name in self.image_names: + base_name, _ext = os.path.splitext(img_name) + candidate_same_ext = os.path.join(dynamic_mask_dir, img_name) + candidate_png = os.path.join(dynamic_mask_dir, base_name + ".png") + if os.path.exists(candidate_same_ext): + self.dynamic_mask_paths.append(candidate_same_ext) + elif os.path.exists(candidate_png): + self.dynamic_mask_paths.append(candidate_png) + else: + self.dynamic_mask_paths.append(candidate_same_ext) + # Process sky mask paths. # We primarily match the image filename (same basename + extension), but also # fall back to {basename}.png which is a common convention. @@ -492,6 +516,7 @@ def _load_base_sample(self, index: int) -> Dict[str, Any]: depth_data = None normal_data = None + dynamic_mask_data = None sky_mask_data = None if self.parser.depth_paths is not None: @@ -507,6 +532,15 @@ def _load_base_sample(self, index: int) -> Dict[str, Any]: normal_data = imageio.imread(normal_path)[..., :3] # Known to be .png except Exception as e: print(f"Warning: Could not load normal {normal_path}: {e}") + if getattr(self.parser, "dynamic_mask_paths", None) is not None: + dynamic_mask_path = self.parser.dynamic_mask_paths[index] + try: + dyn_img = cv2.imread(dynamic_mask_path, cv2.IMREAD_GRAYSCALE) + if dyn_img is None: + raise RuntimeError("cv2.imread returned None") + dynamic_mask_data = dyn_img > 0 + except Exception as e: + print(f"Warning: Could not load dynamic mask {dynamic_mask_path}: {e}") if getattr(self.parser, "sky_mask_paths", None) is not None: sky_mask_path = self.parser.sky_mask_paths[index] try: @@ -535,6 +569,14 @@ def _load_base_sample(self, index: int) -> Dict[str, Any]: if normal_data is not None: normal_data = cv2.remap(normal_data, mapx, mapy, cv2.INTER_LINEAR) normal_data = normal_data[y : y + h, x : x + w] + # Apply to dynamic mask (Note: INTER_NEAREST) + if dynamic_mask_data is not None: + dynamic_mask_u8 = dynamic_mask_data.astype(np.uint8) * 255 + dynamic_mask_u8 = cv2.remap( + dynamic_mask_u8, mapx, mapy, cv2.INTER_NEAREST + ) + dynamic_mask_u8 = dynamic_mask_u8[y : y + h, x : x + w] + dynamic_mask_data = dynamic_mask_u8 > 0 # Apply to sky mask (Note: INTER_NEAREST) if sky_mask_data is not None: sky_mask_u8 = sky_mask_data.astype(np.uint8) * 255 @@ -547,6 +589,7 @@ def _load_base_sample(self, index: int) -> Dict[str, Any]: "depth": depth_data, "normal": normal_data, "mask": mask, + "dynamic_mask": dynamic_mask_data, "sky_mask": sky_mask_data, "K": K, "camtoworld": camtoworlds, @@ -572,6 +615,11 @@ def _convert_to_tensors( if sample["mask"] is not None else None ) + tensor_sample["dynamic_mask"] = ( + torch.from_numpy(sample["dynamic_mask"]).bool().to(device) + if sample.get("dynamic_mask") is not None + else None + ) tensor_sample["sky_mask"] = ( torch.from_numpy(sample["sky_mask"]).bool().to(device) if sample.get("sky_mask") is not None @@ -592,6 +640,7 @@ def _prepare_sample( depth = sample["depth"] normal = sample["normal"] mask = sample["mask"] + dynamic_mask = sample.get("dynamic_mask") sky_mask = sample.get("sky_mask") K = sample["K"] camtoworlds = sample["camtoworld"] @@ -614,6 +663,8 @@ def _prepare_sample( normal = normal[y_slice, x_slice] if mask is not None: mask = mask[y_slice, x_slice] + if dynamic_mask is not None: + dynamic_mask = dynamic_mask[y_slice, x_slice] if sky_mask is not None: sky_mask = sky_mask[y_slice, x_slice] K = K.clone() @@ -640,6 +691,8 @@ def _prepare_sample( } if mask is not None: data["mask"] = mask + if dynamic_mask is not None: + data["dynamic_mask"] = dynamic_mask if sky_mask is not None: data["sky_mask"] = sky_mask diff --git a/examples/extended_trainer.py b/examples/extended_trainer.py index 488dbd2..57a4353 100644 --- a/examples/extended_trainer.py +++ b/examples/extended_trainer.py @@ -232,6 +232,10 @@ class Config: """Starting iteration for normal consistency regularization""" consistency_normal_loss_activation_step: int = 7000 + # Dynamic mask (optional): if set, loads per-image `dynamic_mask` and excludes those + # pixels from color/depth/normal losses. + dynamic_mask_dir_name: Optional[str] = None + # Sky supervision (optional): if set, loads per-image `sky_mask` and adds a loss that # encourages low alpha in sky pixels (do not occlude sky). sky_mask_dir_name: Optional[str] = None @@ -457,6 +461,7 @@ def __init__( test_every=cfg.test_every, depth_dir_name=cfg.depth_dir_name, normal_dir_name=cfg.normal_dir_name, + dynamic_mask_dir_name=cfg.dynamic_mask_dir_name, sky_mask_dir_name=cfg.sky_mask_dir_name, ) self.trainset = Dataset( @@ -819,6 +824,12 @@ def train(self): # Expected shapes: [B,H,W] or [H,W]; keep [B,H,W] if sky_mask.dim() == 2: sky_mask = sky_mask.unsqueeze(0) + dynamic_mask = data.get("dynamic_mask") + if dynamic_mask is not None: + dynamic_mask = dynamic_mask.to(device) + # Expected shapes: [B,H,W] or [H,W]; keep [B,H,W] + if dynamic_mask.dim() == 2: + dynamic_mask = dynamic_mask.unsqueeze(0) # Optional priors depth_prior = None @@ -998,20 +1009,23 @@ def train(self): # color loss colors_for_ssim = colors + valid_color = torch.ones( + (pixels.shape[0], pixels.shape[1], pixels.shape[2]), + dtype=torch.bool, + device=device, + ) if sky_mask is not None: - non_sky = (~sky_mask.bool()).clone() - if non_sky.dim() == 4 and non_sky.shape[-1] == 1: - non_sky = non_sky[..., 0] - if non_sky.any(): - diff = (colors - pixels).abs() - l1loss = diff[non_sky].mean() - colors_for_ssim = torch.where( - non_sky.unsqueeze(-1), colors, pixels - ) - else: - l1loss = F.l1_loss(colors, pixels) + valid_color &= ~sky_mask.bool() + if dynamic_mask is not None: + valid_color &= ~dynamic_mask.bool() + + if valid_color.any(): + diff = (colors - pixels).abs() + l1loss = diff[valid_color].mean() + colors_for_ssim = torch.where(valid_color.unsqueeze(-1), colors, pixels) else: - l1loss = F.l1_loss(colors, pixels) + l1loss = torch.tensor(0.0, device=device) + colors_for_ssim = pixels ssimloss = 1.0 - fused_ssim( colors_for_ssim.permute(0, 3, 1, 2), @@ -1047,7 +1061,12 @@ def train(self): # depth loss if need_depth_prior: median_depths = info.get("render_median") - depth_loss = self.compute_depth_loss(depths, median_depths, depth_prior) + depth_loss = self.compute_depth_loss( + depths, + median_depths, + depth_prior, + valid_mask=(~dynamic_mask) if dynamic_mask is not None else None, + ) loss += cfg.depth_loss_weight * depth_loss consistency_norm_loss = None @@ -1065,6 +1084,10 @@ def train(self): and render_normals is not None ): mask_consistency = torch.ones_like(alphas) + if dynamic_mask is not None: + mask_consistency = mask_consistency * ( + (~dynamic_mask).float().unsqueeze(-1) + ) consistency_norm_loss = self.compute_normal_loss( F.normalize(render_normals, dim=-1), surf_normals_from_depth, @@ -1077,6 +1100,8 @@ def train(self): mask = torch.ones_like(depths).float() # [B,H,W,1] if normal_prior_mask is not None: mask = mask * normal_prior_mask + if dynamic_mask is not None: + mask = mask * ((~dynamic_mask).float().unsqueeze(-1)) # surface normal loss (from depth) if ( @@ -1730,6 +1755,7 @@ def compute_depth_loss( expected_depth: Optional[torch.Tensor], median_depth: Optional[torch.Tensor], gt_depth: torch.Tensor, + valid_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Computes a weighted L1 loss between predicted depths and the ground-truth depth map. @@ -1744,6 +1770,12 @@ def compute_depth_loss( return torch.tensor(0.0, device=device) valid_pix = gt_depth > 0.0 + if valid_mask is not None: + if valid_mask.dim() == 2: + valid_mask = valid_mask.unsqueeze(0) + if valid_mask.dim() == 3: + valid_mask = valid_mask.unsqueeze(-1) + valid_pix = valid_pix & valid_mask.bool() if not valid_pix.any(): return torch.tensor(0.0, device=device) @@ -1799,11 +1831,17 @@ def compute_normal_loss( # Avoid sparse access pattern from masked_select by zeroing invalid pixels masked_dot = dot * mask # Zero out invalid pixels (faster than masked_select) - # Count valid pixels per image; clamp to avoid division by zero - valid_counts = mask.sum(dim=(1, 2)).clamp(min=1) + # Count valid pixels per image + valid_counts = mask.sum(dim=(1, 2)) + has_valid = valid_counts > 0 # Compute per-image cosine loss: 1 - mean(cos(theta)) - per_batch_loss = 1 - (masked_dot.sum(dim=(1, 2)) / valid_counts) + per_batch_loss = torch.zeros( + pred_normals_bhw3.shape[0], device=pred_normals_bhw3.device + ) + per_batch_loss[has_valid] = 1 - ( + masked_dot.sum(dim=(1, 2))[has_valid] / valid_counts[has_valid] + ) # Return mean loss across the batch return per_batch_loss.mean() From 832553506a404aa6f2244e298043571d5389c629 Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Thu, 29 Jan 2026 20:00:10 +0800 Subject: [PATCH 4/6] [Fix] Fix use_single_camera_mode parameter failure --- examples/preprocess/run_hloc_sfm.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/examples/preprocess/run_hloc_sfm.py b/examples/preprocess/run_hloc_sfm.py index f2f0ef7..a4a51f8 100644 --- a/examples/preprocess/run_hloc_sfm.py +++ b/examples/preprocess/run_hloc_sfm.py @@ -34,6 +34,17 @@ from hloc_utils import run_hloc, CameraModel, PANO_CONFIG +def _str2bool(v): + if isinstance(v, bool): + return v + s = str(v).strip().lower() + if s in {"true", "1", "yes", "y", "t", "on"}: + return True + if s in {"false", "0", "no", "n", "f", "off"}: + return False + raise argparse.ArgumentTypeError(f"Expected a boolean value, got: {v!r}") + + def copy_images_fast( image_dir: Path, output_root: Path, image_prefix: str = "frame_" ) -> Path: @@ -207,7 +218,16 @@ def main(): choices=["netvlad", "megaloc", "dir", "openibl"], help="Global descriptor used for image retrieval.", ) - parser.add_argument("--use_single_camera_mode", type=bool, default=True) + # NOTE: avoid `type=bool` (bool("False") == True). This accepts: + # --use_single_camera_mode (True) + # --use_single_camera_mode True/False/0/1/... + parser.add_argument( + "--use_single_camera_mode", + type=_str2bool, + nargs="?", + const=True, + default=True, + ) parser.add_argument( "--is_panorama", action="store_true", From d981a4898edb516109074dcb3beff195e8351350 Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Mon, 2 Feb 2026 17:09:53 +0800 Subject: [PATCH 5/6] [Feat] Add data prefetching to VRAM; use InfiniteRandomSampler instead of shuffle to prevent iterator reset --- examples/extended_trainer.py | 180 ++++++++++++++++++++++++++++++----- 1 file changed, 155 insertions(+), 25 deletions(-) diff --git a/examples/extended_trainer.py b/examples/extended_trainer.py index 57a4353..d1e653e 100644 --- a/examples/extended_trainer.py +++ b/examples/extended_trainer.py @@ -49,6 +49,100 @@ from nerfview import CameraState, RenderTabState, apply_float_colormap +def _record_stream_recursive(obj: object, stream: torch.cuda.Stream) -> None: + if isinstance(obj, torch.Tensor) and obj.is_cuda: + obj.record_stream(stream) + return + if isinstance(obj, dict): + for v in obj.values(): + _record_stream_recursive(v, stream) + return + if isinstance(obj, (list, tuple)): + for v in obj: + _record_stream_recursive(v, stream) + return + + +def _to_device_recursive( + obj: object, + device: Union[str, torch.device], + *, + non_blocking: bool, +) -> object: + if isinstance(obj, torch.Tensor): + if obj.device == torch.device(device): + return obj + return obj.to(device, non_blocking=non_blocking) + if isinstance(obj, dict): + return { + k: _to_device_recursive(v, device, non_blocking=non_blocking) + for k, v in obj.items() + } + if isinstance(obj, list): + return [_to_device_recursive(v, device, non_blocking=non_blocking) for v in obj] + if isinstance(obj, tuple): + return tuple( + _to_device_recursive(v, device, non_blocking=non_blocking) for v in obj + ) + return obj + + +class _InfiniteRandomSampler(torch.utils.data.Sampler[int]): + """An infinite sampler that yields shuffled indices forever. + + This avoids DataLoader epoch boundaries (StopIteration), which can otherwise cause + periodic stalls while the iterator/workers refill the prefetch queue. + """ + + def __init__(self, data_source: torch.utils.data.Dataset, seed: int) -> None: + self.data_source = data_source + self.seed = seed + + def __iter__(self): + g = torch.Generator() + g.manual_seed(self.seed) + n = len(self.data_source) + while True: + # Generate a fresh permutation each cycle to mimic epoch-wise shuffle. + yield from torch.randperm(n, generator=g).tolist() + + def __len__(self) -> int: + # A large value to satisfy APIs that may query length. + return 2**31 + + +class _CudaPrefetcher: + """Prefetches DataLoader batches to GPU on a dedicated CUDA stream.""" + + def __init__(self, loader_iter, device: Union[str, torch.device]) -> None: + self.loader_iter = loader_iter + self.device = torch.device(device) + if self.device.type != "cuda": + raise ValueError(f"_CudaPrefetcher requires a CUDA device, got {self.device}") + self.stream = torch.cuda.Stream(device=self.device) + self._next_batch = None + self._preload() + + def _preload(self) -> None: + try: + batch = next(self.loader_iter) + except StopIteration: + self._next_batch = None + return + with torch.cuda.stream(self.stream): + batch = _to_device_recursive(batch, self.device, non_blocking=True) + self._next_batch = batch + + def next(self): + if self._next_batch is None: + return None + torch.cuda.current_stream(self.device).wait_stream(self.stream) + batch = self._next_batch + _record_stream_recursive(batch, torch.cuda.current_stream(self.device)) + self._preload() + return batch + + @dataclass class Config: # Disable viewer @@ -78,6 +172,10 @@ class Config: camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole" # Dataset preload mode: none, cpu, or cuda dataset_preload: Literal["none", "cpu", "cuda"] = "none" + # Asynchronously prefetch the next batch to GPU (hides H2D latency for none/cpu preload). + prefetch_to_gpu: bool = False + # Use an infinite sampler to avoid DataLoader epoch-boundary stalls. + infinite_sampler: bool = False # Port for the viewer server port: int = 8080 @@ -206,7 +304,7 @@ class Config: the 'images' directory, load the dense depth maps from it, and use their depth values for regularization. """ - depth_dir_name: Optional[str] = "moge_depth" # "pi3_depth" + depth_dir_name: Optional[str] = None # "pi3_depth" """Weight of the depth loss""" depth_loss_weight: float = 0.25 """Starting iteration for depth regularization""" @@ -218,7 +316,7 @@ class Config: the 'images' directory, load the dense normal maps from it, and use their normal values for regularization. """ - normal_dir_name: Optional[str] = "moge_normal" # "moge_normal" + normal_dir_name: Optional[str] = None # "moge_normal" """Weight of the render_normal_loss""" render_normal_loss_weight: float = 0.1 """Starting iteration for render_normal regularization""" @@ -780,14 +878,28 @@ def train(self): pin_memory = cfg.dataset_preload != "cuda" train_num_workers = 0 if cfg.dataset_preload == "cuda" else 8 - trainloader = torch.utils.data.DataLoader( - self.trainset, - batch_size=cfg.batch_size, - shuffle=True, - num_workers=train_num_workers, - persistent_workers=train_num_workers > 0, - pin_memory=pin_memory, - ) + if cfg.infinite_sampler: + sampler = _InfiniteRandomSampler( + self.trainset, seed=42 + int(self.local_rank) + ) + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=False, + sampler=sampler, + num_workers=train_num_workers, + persistent_workers=train_num_workers > 0, + pin_memory=pin_memory, + ) + else: + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=train_num_workers, + persistent_workers=train_num_workers > 0, + pin_memory=pin_memory, + ) trainloader_iter = iter(trainloader) # Training loop. @@ -797,6 +909,14 @@ def train(self): cached_h, cached_w = -1, -1 pbar = tqdm.tqdm(range(init_step, max_steps)) + use_prefetch = ( + cfg.prefetch_to_gpu + and cfg.dataset_preload != "cuda" + and torch.cuda.is_available() + ) + prefetcher = _CudaPrefetcher(trainloader_iter, device) if use_prefetch else None + data = prefetcher.next() if prefetcher is not None else None + for step in pbar: if not cfg.disable_viewer: while self.viewer.state == "paused": @@ -804,29 +924,36 @@ def train(self): self.viewer.lock.acquire() tic = time.time() - try: - data = next(trainloader_iter) - except StopIteration: - trainloader_iter = iter(trainloader) - data = next(trainloader_iter) - - camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] - Ks = data["K"].to(device) # [1, 3, 3] - pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + if prefetcher is None: + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + else: + if data is None: + trainloader_iter = iter(trainloader) + prefetcher = _CudaPrefetcher(trainloader_iter, device) + data = prefetcher.next() + + assert data is not None + if prefetcher is None: + data = _to_device_recursive(data, device, non_blocking=True) # type: ignore[assignment] + camtoworlds = camtoworlds_gt = data["camtoworld"] # [B, 4, 4] + Ks = data["K"] # [B, 3, 3] + pixels = data["image"] / 255.0 # [B, H, W, 3] num_train_rays_per_step = ( pixels.shape[0] * pixels.shape[1] * pixels.shape[2] ) - image_ids = data["image_id"].to(device) - masks = data["mask"].to(device) if "mask" in data else None # [1, H, W] + image_ids = data["image_id"] + masks = data["mask"] if "mask" in data else None # [B, H, W] sky_mask = data.get("sky_mask") if sky_mask is not None: - sky_mask = sky_mask.to(device) # Expected shapes: [B,H,W] or [H,W]; keep [B,H,W] if sky_mask.dim() == 2: sky_mask = sky_mask.unsqueeze(0) dynamic_mask = data.get("dynamic_mask") if dynamic_mask is not None: - dynamic_mask = dynamic_mask.to(device) # Expected shapes: [B,H,W] or [H,W]; keep [B,H,W] if dynamic_mask.dim() == 2: dynamic_mask = dynamic_mask.unsqueeze(0) @@ -839,7 +966,7 @@ def train(self): and isinstance(raw_depth_prior, torch.Tensor) and raw_depth_prior.numel() > 0 ): - depth_prior = raw_depth_prior.to(device) + depth_prior = raw_depth_prior if depth_prior.dim() == 3: depth_prior = depth_prior.unsqueeze(0) @@ -851,7 +978,7 @@ def train(self): and isinstance(raw_normal_prior, torch.Tensor) and raw_normal_prior.numel() > 0 ): - normal_prior = raw_normal_prior.to(device) + normal_prior = raw_normal_prior if normal_prior.dim() == 3: normal_prior = normal_prior.unsqueeze(0) ones_like = torch.ones_like(normal_prior) @@ -1464,6 +1591,9 @@ def train(self): # Update the scene. self.viewer.update(step, num_train_rays_per_step) + if prefetcher is not None: + data = prefetcher.next() + @torch.no_grad() def eval(self, step: int, stage: str = "val"): """Entry for evaluation.""" From 3593994563109d17564cabb8ecd4ae7d83f008ec Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Mon, 2 Feb 2026 19:59:06 +0800 Subject: [PATCH 6/6] [Style] format --- examples/extended_trainer.py | 4 +++- gsplat/strategy/improved.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/extended_trainer.py b/examples/extended_trainer.py index d1e653e..f5e4bb8 100644 --- a/examples/extended_trainer.py +++ b/examples/extended_trainer.py @@ -118,7 +118,9 @@ def __init__(self, loader_iter, device: Union[str, torch.device]) -> None: self.loader_iter = loader_iter self.device = torch.device(device) if self.device.type != "cuda": - raise ValueError(f"_CudaPrefetcher requires a CUDA device, got {self.device}") + raise ValueError( + f"_CudaPrefetcher requires a CUDA device, got {self.device}" + ) self.stream = torch.cuda.Stream(device=self.device) self._next_batch = None self._preload() diff --git a/gsplat/strategy/improved.py b/gsplat/strategy/improved.py index 11269c8..93d37ac 100644 --- a/gsplat/strategy/improved.py +++ b/gsplat/strategy/improved.py @@ -259,7 +259,10 @@ def step_post_backward( if step >= self.refine_stop_iter: did_prune = False - if self.post_refine_prune_every > 0 and step % self.post_refine_prune_every == 0: + if ( + self.post_refine_prune_every > 0 + and step % self.post_refine_prune_every == 0 + ): n_prune = self._prune_gs(params, optimizers, state, step) did_prune = True if self.verbose and n_prune > 0: