From 2cb406e6bfc16235aeab2a76d5e3f08229e6fd87 Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Mon, 24 Nov 2025 16:51:57 +0800 Subject: [PATCH 1/7] [Refactor] Remove MCMC strategy --- docs/3dgut.md | 10 +- docs/source/apis/strategy.rst | 2 +- docs/source/index.rst | 2 +- docs/source/tests/eval.rst | 12 +- examples/benchmarks/3dgut/mcmc.sh | 59 ------ examples/benchmarks/3dgut/mcmc_zipnerf.sh | 60 ------ examples/benchmarks/bilarf/mcmc_bilarf.sh | 21 -- examples/benchmarks/compression/mcmc.sh | 55 ------ examples/benchmarks/compression/mcmc_tt.sh | 48 ----- examples/benchmarks/fisheye/mcmc_zipnerf.sh | 58 ------ .../fisheye/mcmc_zipnerf_undistorted.sh | 58 ------ examples/benchmarks/mcmc.sh | 57 ------ examples/benchmarks/mcmc_4gpus.sh | 47 ----- examples/extended_trainer.py | 43 +--- examples/simple_trainer.py | 43 +--- gsplat/__init__.py | 3 +- gsplat/strategy/__init__.py | 1 - gsplat/strategy/mcmc.py | 187 ------------------ tests/test_strategy.py | 20 +- 19 files changed, 19 insertions(+), 767 deletions(-) delete mode 100644 examples/benchmarks/3dgut/mcmc.sh delete mode 100644 examples/benchmarks/3dgut/mcmc_zipnerf.sh delete mode 100644 examples/benchmarks/bilarf/mcmc_bilarf.sh delete mode 100644 examples/benchmarks/compression/mcmc.sh delete mode 100644 examples/benchmarks/compression/mcmc_tt.sh delete mode 100644 examples/benchmarks/fisheye/mcmc_zipnerf.sh delete mode 100644 examples/benchmarks/fisheye/mcmc_zipnerf_undistorted.sh delete mode 100644 examples/benchmarks/mcmc.sh delete mode 100644 examples/benchmarks/mcmc_4gpus.sh delete mode 100644 gsplat/strategy/mcmc.py diff --git a/docs/3dgut.md b/docs/3dgut.md index 34a5e9a..462bd0c 100644 --- a/docs/3dgut.md +++ b/docs/3dgut.md @@ -12,14 +12,12 @@ Here are the instructions on how to use this feature. #### Training -Simplly passing in `--with_ut --with_eval3d` to the `simple_trainer.py` arg list will enable training with 3DGUT! And note in gsplat we only support MCMC densification strategy for 3DGUT: +Simply passing in `--with_ut --with_eval3d` to the `simple_trainer.py` arg list will enable training with 3DGUT: ``` -python examples/simple_trainer.py mcmc --with_ut --with_eval3d ... +python examples/simple_trainer.py default --with_ut --with_eval3d ... ``` -For benchmarking on MipNeRF360 Dataset, please checkout `examples/benchmarks/3dgut/mcmc.sh` - Note if you are not familiar with how to get started with `simple_trainer.py`, please checkout [README.md](README.md) first! #### Rendering @@ -27,12 +25,12 @@ Note if you are not familiar with how to get started with `simple_trainer.py`, p Once trained, you could view the 3DGS and play with the distortion effect supported through 3DGUT via our viewer: ```bash -CUDA_VISIBLE_DEVICES=0 python simple_viewer_3dgut.py --ckpt results/benchmark_mcmc_1M_3dgut/garden/ckpt_29999_rank0.pt +CUDA_VISIBLE_DEVICES=0 python simple_viewer_3dgut.py --ckpt /garden/ckpt_29999_rank0.pt ``` Or a more comprehensive nerfstudio-style viewer to export videos. (note changing distortion is not yet supported in this comprehensive viewer!) ```bash -CUDA_VISIBLE_DEVICES=0 python simple_viewer.py --with_ut --with_eval3d --ckpt results/benchmark_mcmc_1M_3dgut/garden/ckpt_29999_rank0.pt +CUDA_VISIBLE_DEVICES=0 python simple_viewer.py --with_ut --with_eval3d --ckpt /garden/ckpt_29999_rank0.pt ``` ### For users using gsplat' API: diff --git a/docs/source/apis/strategy.rst b/docs/source/apis/strategy.rst index 8f6c5db..fabdc82 100644 --- a/docs/source/apis/strategy.rst +++ b/docs/source/apis/strategy.rst @@ -72,5 +72,5 @@ Below are the strategies that are currently implemented in `gsplat`: .. autoclass:: DefaultStrategy :members: -.. autoclass:: MCMCStrategy +.. autoclass:: ImprovedStrategy :members: diff --git a/docs/source/index.rst b/docs/source/index.rst index c123ef1..c82c1e8 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -27,7 +27,7 @@ faster, more memory efficient, and with a growing list of new features! * *gsplat* is equipped with the **latest and greatest** 3D Gaussian Splatting techniques, including `absgrad `_, `anti-aliasing `_, - `3DGS-MCMC `_ etc. And more to come! + and improved densification heuristics. And more to come! .. raw:: html diff --git a/docs/source/tests/eval.rst b/docs/source/tests/eval.rst index 49c1352..e4fbc32 100644 --- a/docs/source/tests/eval.rst +++ b/docs/source/tests/eval.rst @@ -28,7 +28,7 @@ Powered by `gsplat`'s efficient CUDA implementation, the training takes up to Feature Ablation ---------------------------------------------- -Evaluation of features provided in `gsplat` on Mip-NeRF (averaged over 7 scenes). We ablate `gsplat` with default settings, with absgrad and mcmc densification strategies, and antialiased mode. +Evaluation of features provided in `gsplat` on Mip-NeRF (averaged over 7 scenes). We ablate `gsplat` with default settings, the absgrad densification heuristic, and antialiased mode. Absgrad method uses `--grow_grad2d 0.0006` config. These results are obtained with an A100. +-----------------------------+-------+-------+-------+----------+---------+------------+ @@ -40,16 +40,8 @@ Absgrad method uses `--grow_grad2d 0.0006` config. These results are obtained wi +-----------------------------+-------+-------+-------+----------+---------+------------+ | antialiased | 29.03 | 0.87 | 0.14 | 3377807 | 5.87 | 19.52 | +-----------------------------+-------+-------+-------+----------+---------+------------+ -| mcmc (1 mill) | 29.18 | 0.87 | 0.14 | 1000000 | 1.98 | 15.42 | -+-----------------------------+-------+-------+-------+----------+---------+------------+ -| mcmc (2 mill) | 29.53 | 0.88 | 0.13 | 2000000 | 3.43 | 21.79 | -+-----------------------------+-------+-------+-------+----------+---------+------------+ -| mcmc (3 mill) | 29.65 | 0.89 | 0.12 | 3000000 | 4.99 | 27.63 | -+-----------------------------+-------+-------+-------+----------+---------+------------+ | absgrad & antialiased | 29.14 | 0.88 | 0.13 | 2563156 | 4.57 | 18.43 | +-----------------------------+-------+-------+-------+----------+---------+------------+ -| mcmc & antialiased | 29.23 | 0.87 | 0.14 | 1000000 | 2.00 | 15.75 | -+-----------------------------+-------+-------+-------+----------+---------+------------+ Trains Faster with Less GPU Memory @@ -239,4 +231,4 @@ The evaluation of `inria-X` can be reproduced with our forked wersion of the official implementation at `here `_; you need to change the :code:`--model_type 2dgs` to :code:`--model_type 2dgs-inria` in -:code:`benchmars/basic_2dgs` and run command :code:`cd examples; bash benchmarks/basic_2dgs.sh` (commit 28c928a). \ No newline at end of file +:code:`benchmars/basic_2dgs` and run command :code:`cd examples; bash benchmarks/basic_2dgs.sh` (commit 28c928a). diff --git a/examples/benchmarks/3dgut/mcmc.sh b/examples/benchmarks/3dgut/mcmc.sh deleted file mode 100644 index bbbe0c0..0000000 --- a/examples/benchmarks/3dgut/mcmc.sh +++ /dev/null @@ -1,59 +0,0 @@ -SCENE_DIR="data/360_v2" -RESULT_DIR="results/benchmark_mcmc_1M_3dgut" -SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers -RENDER_TRAJ_PATH="ellipse" - -CAP_MAX=1000000 - -for SCENE in $SCENE_LIST; -do - if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then - DATA_FACTOR=2 - else - DATA_FACTOR=4 - fi - - echo "Running $SCENE" - - # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ - --with_eval3d --with_ut \ - --strategy.cap-max $CAP_MAX \ - --render_traj_path $RENDER_TRAJ_PATH \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ - - # run eval and render - for CKPT in $RESULT_DIR/$SCENE/ckpts/*; - do - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - --with_eval3d --with_ut \ - --strategy.cap-max $CAP_MAX \ - --render_traj_path $RENDER_TRAJ_PATH \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ \ - --ckpt $CKPT - done -done - - -for SCENE in $SCENE_LIST; -do - echo "=== Eval Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/val*.json; - do - echo $STATS - cat $STATS; - echo - done - - echo "=== Train Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json; - do - echo $STATS - cat $STATS; - echo - done -done \ No newline at end of file diff --git a/examples/benchmarks/3dgut/mcmc_zipnerf.sh b/examples/benchmarks/3dgut/mcmc_zipnerf.sh deleted file mode 100644 index 934820b..0000000 --- a/examples/benchmarks/3dgut/mcmc_zipnerf.sh +++ /dev/null @@ -1,60 +0,0 @@ -SCENE_DIR="data/zipnerf" -SCENE_LIST="nyc alameda berlin london" -DATA_FACTOR=4 -RENDER_TRAJ_PATH="ellipse" - -RESULT_DIR="results/benchmark_mcmc_2M_zipnerf_3dgut" -CAP_MAX=2000000 - -for SCENE in $SCENE_LIST; -do - echo "Running $SCENE" - - # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ - --with_eval3d --with_ut \ - --strategy.cap-max $CAP_MAX \ - --opacity_reg 0.001 \ - --init_scale 0.5 \ - --use_bilateral_grid \ - --render_traj_path $RENDER_TRAJ_PATH \ - --camera_model fisheye \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ - - for CKPT in $RESULT_DIR/$SCENE/ckpts/*; - do - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - --with_eval3d --with_ut \ - --strategy.cap-max $CAP_MAX \ - --opacity_reg 0.001 \ - --init_scale 0.5 \ - --use_bilateral_grid \ - --render_traj_path $RENDER_TRAJ_PATH \ - --camera_model fisheye \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ \ - --ckpt $CKPT - done -done - -for SCENE in $SCENE_LIST; -do - echo "=== Eval Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/val*.json; - do - echo $STATS - cat $STATS; - echo - done - - echo "=== Train Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json; - do - echo $STATS - cat $STATS; - echo - done -done \ No newline at end of file diff --git a/examples/benchmarks/bilarf/mcmc_bilarf.sh b/examples/benchmarks/bilarf/mcmc_bilarf.sh deleted file mode 100644 index 0446280..0000000 --- a/examples/benchmarks/bilarf/mcmc_bilarf.sh +++ /dev/null @@ -1,21 +0,0 @@ -SCENE_DIR="data/bilarf/bilarf_data/editscenes" -SCENE_LIST="rawnerf_windowlegovary rawnerf_sharpshadow scibldg" - -# SCENE_DIR="data/bilarf/bilarf_data/testscenes" -# SCENE_LIST="chinesearch lionpavilion pondbike statue strat building" - -RESULT_DIR="results/benchmark_bilarf" -RENDER_TRAJ_PATH="spiral" -DATA_FACTOR=4 - -for SCENE in $SCENE_LIST; -do - echo "Running $SCENE" - - # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - --use_bilateral_grid \ - --render_traj_path $RENDER_TRAJ_PATH \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ -done diff --git a/examples/benchmarks/compression/mcmc.sh b/examples/benchmarks/compression/mcmc.sh deleted file mode 100644 index 2741858..0000000 --- a/examples/benchmarks/compression/mcmc.sh +++ /dev/null @@ -1,55 +0,0 @@ -SCENE_DIR="data/360_v2" -# eval all 9 scenes for benchmarking -SCENE_LIST="garden bicycle stump bonsai counter kitchen room treehill flowers" - -# # 0.36M GSs -# RESULT_DIR="results/benchmark_mcmc_0_36M_png_compression" -# CAP_MAX=360000 - -# # 0.49M GSs -# RESULT_DIR="results/benchmark_mcmc_0_49M_png_compression" -# CAP_MAX=490000 - -# 1M GSs -RESULT_DIR="results/benchmark_mcmc_1M_png_compression" -CAP_MAX=1000000 - -# # 4M GSs -# RESULT_DIR="results/benchmark_mcmc_4M_png_compression" -# CAP_MAX=4000000 - - -for SCENE in $SCENE_LIST; -do - if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then - DATA_FACTOR=2 - else - DATA_FACTOR=4 - fi - - echo "Running $SCENE" - - # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ - --strategy.cap-max $CAP_MAX \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ - - # eval: use vgg for lpips to align with other benchmarks - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - --strategy.cap-max $CAP_MAX \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ \ - --lpips_net vgg \ - --compression png \ - --ckpt $RESULT_DIR/$SCENE/ckpts/ckpt_29999_rank0.pt -done - -# Zip the compressed files and summarize the stats -if command -v zip &> /dev/null -then - echo "Zipping results" - python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST -else - echo "zip command not found, skipping zipping" -fi \ No newline at end of file diff --git a/examples/benchmarks/compression/mcmc_tt.sh b/examples/benchmarks/compression/mcmc_tt.sh deleted file mode 100644 index 34d9f5e..0000000 --- a/examples/benchmarks/compression/mcmc_tt.sh +++ /dev/null @@ -1,48 +0,0 @@ -SCENE_DIR="data/tandt" -# eval all 9 scenes for benchmarking -SCENE_LIST="train truck" - -# # 0.36M GSs -# RESULT_DIR="results/benchmark_tt_mcmc_0_36M_png_compression" -# CAP_MAX=360000 - -# # 0.49M GSs -# RESULT_DIR="results/benchmark_tt_mcmc_0_49M_png_compression" -# CAP_MAX=490000 - -# 1M GSs -RESULT_DIR="results/benchmark_tt_mcmc_1M_png_compression" -CAP_MAX=1000000 - -# # 4M GSs -# RESULT_DIR="results/benchmark_tt_mcmc_4M_png_compression" -# CAP_MAX=4000000 - -for SCENE in $SCENE_LIST; -do - echo "Running $SCENE" - - # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor 1 \ - --strategy.cap-max $CAP_MAX \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ - - # eval: use vgg for lpips to align with other benchmarks - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor 1 \ - --strategy.cap-max $CAP_MAX \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ \ - --lpips_net vgg \ - --compression png \ - --ckpt $RESULT_DIR/$SCENE/ckpts/ckpt_29999_rank0.pt -done - -# Zip the compressed files and summarize the stats -if command -v zip &> /dev/null -then - echo "Zipping results" - python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST -else - echo "zip command not found, skipping zipping" -fi \ No newline at end of file diff --git a/examples/benchmarks/fisheye/mcmc_zipnerf.sh b/examples/benchmarks/fisheye/mcmc_zipnerf.sh deleted file mode 100644 index 97202fd..0000000 --- a/examples/benchmarks/fisheye/mcmc_zipnerf.sh +++ /dev/null @@ -1,58 +0,0 @@ -SCENE_DIR="data/zipnerf" -SCENE_LIST="berlin london nyc alameda" -DATA_FACTOR=4 -RENDER_TRAJ_PATH="ellipse" - -RESULT_DIR="results/benchmark_mcmc_2M_zipnerf" -CAP_MAX=2000000 - -for SCENE in $SCENE_LIST; -do - echo "Running $SCENE" - - # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ - --strategy.cap-max $CAP_MAX \ - --opacity_reg 0.001 \ - --init_scale 0.5 \ - --use_bilateral_grid \ - --render_traj_path $RENDER_TRAJ_PATH \ - --camera_model fisheye \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ - - for CKPT in $RESULT_DIR/$SCENE/ckpts/*; - do - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - --strategy.cap-max $CAP_MAX \ - --opacity_reg 0.001 \ - --init_scale 0.5 \ - --use_bilateral_grid \ - --render_traj_path $RENDER_TRAJ_PATH \ - --camera_model fisheye \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ \ - --ckpt $CKPT - done -done - -for SCENE in $SCENE_LIST; -do - echo "=== Eval Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/val*.json; - do - echo $STATS - cat $STATS; - echo - done - - echo "=== Train Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json; - do - echo $STATS - cat $STATS; - echo - done -done \ No newline at end of file diff --git a/examples/benchmarks/fisheye/mcmc_zipnerf_undistorted.sh b/examples/benchmarks/fisheye/mcmc_zipnerf_undistorted.sh deleted file mode 100644 index 92aded9..0000000 --- a/examples/benchmarks/fisheye/mcmc_zipnerf_undistorted.sh +++ /dev/null @@ -1,58 +0,0 @@ -SCENE_DIR="data/zipnerf_undistorted" -SCENE_LIST="berlin london nyc alameda" -DATA_FACTOR=4 -RENDER_TRAJ_PATH="ellipse" - -RESULT_DIR="results/benchmark_mcmc_2M_zipnerf_undistorted" -CAP_MAX=2000000 - -for SCENE in $SCENE_LIST; -do - echo "Running $SCENE" - - # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - --strategy.cap-max $CAP_MAX \ - --opacity_reg 0.001 \ - --init_scale 0.5 \ - --use_bilateral_grid \ - --render_traj_path $RENDER_TRAJ_PATH \ - --camera_model pinhole \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ - - for CKPT in $RESULT_DIR/$SCENE/ckpts/*; - do - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - --strategy.cap-max $CAP_MAX \ - --opacity_reg 0.001 \ - --init_scale 0.5 \ - --use_bilateral_grid \ - --render_traj_path $RENDER_TRAJ_PATH \ - --camera_model pinhole \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ \ - --ckpt $CKPT - done -done - -for SCENE in $SCENE_LIST; -do - echo "=== Eval Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/val*.json; - do - echo $STATS - cat $STATS; - echo - done - - echo "=== Train Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json; - do - echo $STATS - cat $STATS; - echo - done -done \ No newline at end of file diff --git a/examples/benchmarks/mcmc.sh b/examples/benchmarks/mcmc.sh deleted file mode 100644 index 0eaa5c8..0000000 --- a/examples/benchmarks/mcmc.sh +++ /dev/null @@ -1,57 +0,0 @@ -SCENE_DIR="data/360_v2" -RESULT_DIR="results/benchmark_mcmc_1M" -SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers -RENDER_TRAJ_PATH="ellipse" - -CAP_MAX=1000000 - -for SCENE in $SCENE_LIST; -do - if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then - DATA_FACTOR=2 - else - DATA_FACTOR=4 - fi - - echo "Running $SCENE" - - # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ - --strategy.cap-max $CAP_MAX \ - --render_traj_path $RENDER_TRAJ_PATH \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ - - # run eval and render - for CKPT in $RESULT_DIR/$SCENE/ckpts/*; - do - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - --strategy.cap-max $CAP_MAX \ - --render_traj_path $RENDER_TRAJ_PATH \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ \ - --ckpt $CKPT - done -done - - -for SCENE in $SCENE_LIST; -do - echo "=== Eval Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/val*.json; - do - echo $STATS - cat $STATS; - echo - done - - echo "=== Train Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json; - do - echo $STATS - cat $STATS; - echo - done -done \ No newline at end of file diff --git a/examples/benchmarks/mcmc_4gpus.sh b/examples/benchmarks/mcmc_4gpus.sh deleted file mode 100644 index 877871a..0000000 --- a/examples/benchmarks/mcmc_4gpus.sh +++ /dev/null @@ -1,47 +0,0 @@ -SCENE_DIR="data/360_v2" -RESULT_DIR="results/benchmark_mcmc_1M_4gpus" -SCENE_LIST="bonsai" # treehill flowers -RENDER_TRAJ_PATH="ellipse" - -CAP_MAX=250000 - -for SCENE in $SCENE_LIST; -do - if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then - DATA_FACTOR=2 - else - DATA_FACTOR=4 - fi - - echo "Running $SCENE" - - # train and eval at the last step (30000) - CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py mcmc --eval_steps 30000 --disable_viewer --data_factor $DATA_FACTOR \ - --steps_scaler 0.25 --packed \ - --strategy.cap-max $CAP_MAX \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ - -done - - -for SCENE in $SCENE_LIST; -do - echo "=== Eval Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/val_step7499.json; - do - echo $STATS - cat $STATS; - echo - done - - echo "=== Train Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/train_step7499_rank0.json; - do - echo $STATS - cat $STATS; - echo - done -done \ No newline at end of file diff --git a/examples/extended_trainer.py b/examples/extended_trainer.py index d157048..a404781 100644 --- a/examples/extended_trainer.py +++ b/examples/extended_trainer.py @@ -36,7 +36,7 @@ from gsplat.distributed import cli from gsplat.optimizers import SelectiveAdam from gsplat.rendering import rasterization -from gsplat.strategy import DefaultStrategy, MCMCStrategy, ImprovedStrategy +from gsplat.strategy import DefaultStrategy, ImprovedStrategy from gsplat_viewer import GsplatViewer, GsplatRenderTabState from nerfview import CameraState, RenderTabState, apply_float_colormap from utils_depth import get_implied_normal_from_depth @@ -113,8 +113,8 @@ class Config: # Far plane clipping distance far_plane: float = 1e10 - # Densification strategy selection (default / mcmc / improved) - strategy_type: Literal["default", "mcmc", "improved"] = "default" + # Densification strategy selection (default / improved) + strategy_type: Literal["default", "improved"] = "default" # Verbosity for densification logs strategy_verbose: bool = True # Densification hyper-parameters (see notes below; shared unless marked otherwise) @@ -135,13 +135,9 @@ class Config: revised_opacity: bool = False # ImprovedStrategy-only hyper-parameter budget: int = 2_000_000 - # MCMCStrategy-only hyper-parameters - mcmc_cap_max: int = 1_000_000 - mcmc_noise_lr: float = 5e5 - mcmc_min_opacity: float = 0.005 # Strategy instance (constructed from the type/params above) - strategy: Union[DefaultStrategy, MCMCStrategy, ImprovedStrategy] = field( + strategy: Union[DefaultStrategy, ImprovedStrategy] = field( init=False, repr=False ) # Use packed mode for rasterization, this leads to less memory usage but slightly slower. @@ -169,9 +165,6 @@ class Config: # LR for higher-order SH (detail) shN_lr: float = 2.5e-3 / 20 - # Opacity regularization - opacity_reg: float = 0.0 - ### Scale regularization """Weight of the regularisation loss encouraging gaussians to be flat, i.e. set their minimum scale to be small""" @@ -273,10 +266,6 @@ def adjust_steps(self, factor: float): strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) strategy.reset_every = int(strategy.reset_every * factor) strategy.refine_every = int(strategy.refine_every * factor) - elif isinstance(strategy, MCMCStrategy): - strategy.refine_start_iter = int(strategy.refine_start_iter * factor) - strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) - strategy.refine_every = int(strategy.refine_every * factor) elif isinstance(strategy, ImprovedStrategy): strategy.refine_start_iter = int(strategy.refine_start_iter * factor) strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) @@ -304,16 +293,6 @@ def rebuild_strategy(self): revised_opacity=self.revised_opacity, verbose=self.strategy_verbose, ) - elif self.strategy_type == "mcmc": - self.strategy = MCMCStrategy( - cap_max=self.mcmc_cap_max, - noise_lr=self.mcmc_noise_lr, - refine_start_iter=self.refine_start_iter, - refine_stop_iter=self.refine_stop_iter, - refine_every=self.refine_every, - min_opacity=self.mcmc_min_opacity, - verbose=self.strategy_verbose, - ) elif self.strategy_type == "improved": self.strategy = ImprovedStrategy( prune_opa=self.prune_opa, @@ -507,8 +486,6 @@ def __init__( self.strategy_state = self.cfg.strategy.initialize_state( scene_scale=self.scene_scale ) - elif isinstance(self.cfg.strategy, MCMCStrategy): - self.strategy_state = self.cfg.strategy.initialize_state() elif isinstance(self.cfg.strategy, ImprovedStrategy): self.strategy_state = self.cfg.strategy.initialize_state( scene_scale=self.scene_scale @@ -953,9 +930,6 @@ def train(self): loss += cfg.render_normal_loss_weight * render_normal_loss # regularizations - if cfg.opacity_reg > 0.0: - loss += cfg.opacity_reg * torch.sigmoid(self.splats["opacities"]).mean() - # the smallest scale is always near 0 if cfg.flat_reg > 0.0: loss += cfg.flat_reg * self.compute_flat_loss() @@ -1168,15 +1142,6 @@ def train(self): info=info, packed=cfg.packed, ) - elif isinstance(self.cfg.strategy, MCMCStrategy): - self.cfg.strategy.step_post_backward( - params=self.splats, - optimizers=self.optimizers, - state=self.strategy_state, - step=step, - info=info, - lr=schedulers[0].get_last_lr()[0], - ) else: assert_never(self.cfg.strategy) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 9f2a082..7c662e6 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -5,7 +5,7 @@ from collections import defaultdict from dataclasses import dataclass, field from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple import imageio import numpy as np @@ -35,7 +35,7 @@ from gsplat.distributed import cli from gsplat.optimizers import SelectiveAdam from gsplat.rendering import rasterization -from gsplat.strategy import DefaultStrategy, MCMCStrategy +from gsplat.strategy import DefaultStrategy from gsplat_viewer import GsplatViewer, GsplatRenderTabState from nerfview import CameraState, RenderTabState, apply_float_colormap @@ -112,9 +112,7 @@ class Config: far_plane: float = 1e10 # Strategy for GS densification - strategy: Union[DefaultStrategy, MCMCStrategy] = field( - default_factory=DefaultStrategy - ) + strategy: DefaultStrategy = field(default_factory=DefaultStrategy) # Use packed mode for rasterization, this leads to less memory usage but slightly slower. packed: bool = False # Use sparse gradients for optimization. (experimental) @@ -140,8 +138,6 @@ class Config: # LR for higher-order SH (detail) shN_lr: float = 2.5e-3 / 20 - # Opacity regularization - opacity_reg: float = 0.0 # Scale regularization scale_reg: float = 0.0 @@ -197,10 +193,6 @@ def adjust_steps(self, factor: float): strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) strategy.reset_every = int(strategy.reset_every * factor) strategy.refine_every = int(strategy.refine_every * factor) - elif isinstance(strategy, MCMCStrategy): - strategy.refine_start_iter = int(strategy.refine_start_iter * factor) - strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) - strategy.refine_every = int(strategy.refine_every * factor) else: assert_never(strategy) @@ -377,8 +369,6 @@ def __init__( self.strategy_state = self.cfg.strategy.initialize_state( scene_scale=self.scene_scale ) - elif isinstance(self.cfg.strategy, MCMCStrategy): - self.strategy_state = self.cfg.strategy.initialize_state() else: assert_never(self.cfg.strategy) @@ -521,11 +511,7 @@ def rasterize_splats( width=width, height=height, packed=self.cfg.packed, - absgrad=( - self.cfg.strategy.absgrad - if isinstance(self.cfg.strategy, DefaultStrategy) - else False - ), + absgrad=getattr(self.cfg.strategy, "absgrad", False), sparse_grad=self.cfg.sparse_grad, rasterize_mode=rasterize_mode, distributed=self.world_size > 1, @@ -683,8 +669,6 @@ def train(self): loss += tvloss # regularizations - if cfg.opacity_reg > 0.0: - loss += cfg.opacity_reg * torch.sigmoid(self.splats["opacities"]).mean() if cfg.scale_reg > 0.0: loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean() @@ -838,15 +822,6 @@ def train(self): info=info, packed=cfg.packed, ) - elif isinstance(self.cfg.strategy, MCMCStrategy): - self.cfg.strategy.step_post_backward( - params=self.splats, - optimizers=self.optimizers, - state=self.strategy_state, - step=step, - info=info, - lr=schedulers[0].get_last_lr()[0], - ) else: assert_never(self.cfg.strategy) @@ -1180,16 +1155,6 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config): strategy=DefaultStrategy(verbose=True), ), ), - "mcmc": ( - "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.", - Config( - init_opa=0.5, - init_scale=0.1, - opacity_reg=0.01, - scale_reg=0.01, - strategy=MCMCStrategy(verbose=True), - ), - ), } cfg = tyro.extras.overridable_config_cli(configs) cfg.adjust_steps(cfg.steps_scaler) diff --git a/gsplat/__init__.py b/gsplat/__init__.py index 27c7b0a..319caae 100644 --- a/gsplat/__init__.py +++ b/gsplat/__init__.py @@ -28,13 +28,12 @@ rasterization_2dgs_inria_wrapper, rasterization_inria_wrapper, ) -from .strategy import DefaultStrategy, MCMCStrategy, Strategy +from .strategy import DefaultStrategy, Strategy from .version import __version__ all = [ "PngCompression", "DefaultStrategy", - "MCMCStrategy", "Strategy", "rasterization", "rasterization_2dgs", diff --git a/gsplat/strategy/__init__.py b/gsplat/strategy/__init__.py index 4c9aa86..3da3d84 100644 --- a/gsplat/strategy/__init__.py +++ b/gsplat/strategy/__init__.py @@ -1,4 +1,3 @@ from .base import Strategy from .default import DefaultStrategy -from .mcmc import MCMCStrategy from .improved import ImprovedStrategy diff --git a/gsplat/strategy/mcmc.py b/gsplat/strategy/mcmc.py deleted file mode 100644 index c07e173..0000000 --- a/gsplat/strategy/mcmc.py +++ /dev/null @@ -1,187 +0,0 @@ -import math -from dataclasses import dataclass -from typing import Any, Dict, Union - -import torch -from torch import Tensor - -from .base import Strategy -from .ops import inject_noise_to_position, relocate, sample_add - - -@dataclass -class MCMCStrategy(Strategy): - """Strategy that follows the paper: - - `3D Gaussian Splatting as Markov Chain Monte Carlo `_ - - This strategy will: - - - Periodically teleport GSs with low opacity to a place that has high opacity. - - Periodically introduce new GSs sampled based on the opacity distribution. - - Periodically perturb the GSs locations. - - Args: - cap_max (int): Maximum number of GSs. Default to 1_000_000. - noise_lr (float): MCMC samping noise learning rate. Default to 5e5. - refine_start_iter (int): Start refining GSs after this iteration. Default to 500. - refine_stop_iter (int): Stop refining GSs after this iteration. Default to 25_000. - refine_every (int): Refine GSs every this steps. Default to 100. - min_opacity (float): GSs with opacity below this value will be pruned. Default to 0.005. - verbose (bool): Whether to print verbose information. Default to False. - - Examples: - - >>> from gsplat import MCMCStrategy, rasterization - >>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ... - >>> optimizers: Dict[str, torch.optim.Optimizer] = ... - >>> strategy = MCMCStrategy() - >>> strategy.check_sanity(params, optimizers) - >>> strategy_state = strategy.initialize_state() - >>> for step in range(1000): - ... render_image, render_alpha, info = rasterization(...) - ... loss = ... - ... loss.backward() - ... strategy.step_post_backward(params, optimizers, strategy_state, step, info, lr=1e-3) - - """ - - cap_max: int = 1_000_000 - noise_lr: float = 5e5 - refine_start_iter: int = 500 - refine_stop_iter: int = 25_000 - refine_every: int = 100 - min_opacity: float = 0.005 - verbose: bool = False - - def initialize_state(self) -> Dict[str, Any]: - """Initialize and return the running state for this strategy.""" - n_max = 51 - binoms = torch.zeros((n_max, n_max)) - for n in range(n_max): - for k in range(n + 1): - binoms[n, k] = math.comb(n, k) - return {"binoms": binoms} - - def check_sanity( - self, - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - optimizers: Dict[str, torch.optim.Optimizer], - ): - """Sanity check for the parameters and optimizers. - - Check if: - * `params` and `optimizers` have the same keys. - * Each optimizer has exactly one param_group, corresponding to each parameter. - * The following keys are present: {"means", "scales", "quats", "opacities"}. - - Raises: - AssertionError: If any of the above conditions is not met. - - .. note:: - It is not required but highly recommended for the user to call this function - after initializing the strategy to ensure the convention of the parameters - and optimizers is as expected. - """ - - super().check_sanity(params, optimizers) - # The following keys are required for this strategy. - for key in ["means", "scales", "quats", "opacities"]: - assert key in params, f"{key} is required in params but missing." - - # def step_pre_backward( - # self, - # params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - # optimizers: Dict[str, torch.optim.Optimizer], - # # state: Dict[str, Any], - # step: int, - # info: Dict[str, Any], - # ): - # """Callback function to be executed before the `loss.backward()` call.""" - # pass - - def step_post_backward( - self, - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - optimizers: Dict[str, torch.optim.Optimizer], - state: Dict[str, Any], - step: int, - info: Dict[str, Any], - lr: float, - ): - """Callback function to be executed after the `loss.backward()` call. - - Args: - lr (float): Learning rate for "means" attribute of the GS. - """ - # move to the correct device - state["binoms"] = state["binoms"].to(params["means"].device) - - binoms = state["binoms"] - - if ( - step < self.refine_stop_iter - and step > self.refine_start_iter - and step % self.refine_every == 0 - ): - # teleport GSs - n_relocated_gs = self._relocate_gs(params, optimizers, binoms) - if self.verbose: - print(f"Step {step}: Relocated {n_relocated_gs} GSs.") - - # add new GSs - n_new_gs = self._add_new_gs(params, optimizers, binoms) - if self.verbose: - print( - f"Step {step}: Added {n_new_gs} GSs. " - f"Now having {len(params['means'])} GSs." - ) - - torch.cuda.empty_cache() - - # add noise to GSs - inject_noise_to_position( - params=params, optimizers=optimizers, state={}, scaler=lr * self.noise_lr - ) - - @torch.no_grad() - def _relocate_gs( - self, - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - optimizers: Dict[str, torch.optim.Optimizer], - binoms: Tensor, - ) -> int: - opacities = torch.sigmoid(params["opacities"].flatten()) - dead_mask = opacities <= self.min_opacity - n_gs = dead_mask.sum().item() - if n_gs > 0: - relocate( - params=params, - optimizers=optimizers, - state={}, - mask=dead_mask, - binoms=binoms, - min_opacity=self.min_opacity, - ) - return n_gs - - @torch.no_grad() - def _add_new_gs( - self, - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - optimizers: Dict[str, torch.optim.Optimizer], - binoms: Tensor, - ) -> int: - current_n_points = len(params["means"]) - n_target = min(self.cap_max, int(1.05 * current_n_points)) - n_gs = max(0, n_target - current_n_points) - if n_gs > 0: - sample_add( - params=params, - optimizers=optimizers, - state={}, - n=n_gs, - binoms=binoms, - min_opacity=self.min_opacity, - ) - return n_gs diff --git a/tests/test_strategy.py b/tests/test_strategy.py index 0311541..66e3531 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -15,7 +15,7 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") def test_strategy(): from gsplat.rendering import rasterization - from gsplat.strategy import DefaultStrategy, MCMCStrategy + from gsplat.strategy import DefaultStrategy torch.manual_seed(42) @@ -54,18 +54,12 @@ def test_strategy(): render_colors.mean().backward(retain_graph=True) strategy.step_post_backward(params, optimizers, state, step=600, info=info) - # Test MCMCStrategy - strategy = MCMCStrategy(verbose=True) - strategy.check_sanity(params, optimizers) - state = strategy.initialize_state() - render_colors.mean().backward(retain_graph=True) - strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3) @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") def test_strategy_requires_grad(): from gsplat.rendering import rasterization - from gsplat.strategy import DefaultStrategy, MCMCStrategy + from gsplat.strategy import DefaultStrategy def assert_consistent_sizes(params): sizes = [v.shape[0] for v in params.values()] @@ -110,22 +104,12 @@ def assert_consistent_sizes(params): strategy.check_sanity(params, optimizers) state = strategy.initialize_state() strategy.step_pre_backward(params, optimizers, state, step=600, info=info) - render_colors.mean().backward(retain_graph=True) strategy.step_post_backward(params, optimizers, state, step=600, info=info) for k, v in params.items(): assert v.requires_grad == requires_grad_map[k] assert params["non_trainable_features"].grad is None assert_consistent_sizes(params) - # Test MCMCStrategy - strategy = MCMCStrategy(verbose=True) - strategy.check_sanity(params, optimizers) - state = strategy.initialize_state() render_colors.mean().backward(retain_graph=True) - strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3) - assert params["non_trainable_features"].grad is None - for k, v in params.items(): - assert v.requires_grad == requires_grad_map[k] - assert_consistent_sizes(params) if __name__ == "__main__": From d959aea898f3cb0e32de4144a8f0bbd3957c31d7 Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Mon, 24 Nov 2025 19:43:32 +0800 Subject: [PATCH 2/7] [Feat] Implement pruning function for Gradient-Driven Natural Selection (GNS) --- gsplat/strategy/improved.py | 95 +++++++++++++++++++++++++++++++++---- 1 file changed, 87 insertions(+), 8 deletions(-) diff --git a/gsplat/strategy/improved.py b/gsplat/strategy/improved.py index 7b3bb22..3f646b3 100644 --- a/gsplat/strategy/improved.py +++ b/gsplat/strategy/improved.py @@ -199,7 +199,7 @@ def step_post_backward( self.reset_count += 1 # After the first two resets, perform quantile pruning 300 steps after reset - # 看起来似乎有没有这个机制区别不是很大? + # (kept disabled because it showed negligible effect in practice) # if self.reset_count <= 2 and step % self.reset_every == 300 and step > 300: # n_quantile_prune = self._quantile_prune_gs( # params=params, @@ -289,18 +289,18 @@ def _grow_gs( startI = self.refine_start_iter endI = self.refine_stop_iter - 500 den = endI - startI - # 计算 rate,防止除以 0,并确保为 float + # compute rate while avoiding division-by-zero and keeping float precision if den == 0: rate = 1.0 else: rate = float((step - startI) / den) - # clamp 到 [0, 1],避免负值或 >1 的奇异情况 + # clamp to [0, 1] to avoid negative or >1 edge cases rate = max(0.0, min(1.0, rate)) if rate >= 1.0: budget = int(self.budget) else: - # 使用 math.sqrt 对 float 做开方 + # use math.sqrt on the float before scaling with the budget budget = int(math.sqrt(rate) * float(self.budget)) total_qualified = int(torch.sum(is_grad_high).item()) @@ -309,12 +309,12 @@ def _grow_gs( final_budget = min(budget, theoretical_max) new_points_needed = final_budget - curr_points - # 初始化is_split掩码,全为False + # initialize split mask with False is_split = torch.zeros_like(is_grad_high, dtype=torch.bool, device=device) - # 创建重要性分数向量,只考虑高梯度区域 + # create importance scores restricted to high-gradient candidates importance_scores = grads.clone() - importance_scores[~is_grad_high] = 0.0 # 屏蔽非高梯度区域 - # 确保所有分数非负且至少有一个有效候选 + importance_scores[~is_grad_high] = 0.0 # zero scores outside candidate set + # ensure non-negative scores and that at least one candidate exists if torch.any(importance_scores > 0): num_available = (importance_scores > 0).sum().item() actual_split_count = min(max(new_points_needed, 0), num_available) @@ -399,3 +399,82 @@ def _quantile_prune_gs( remove(params=params, optimizers=optimizers, state=state, mask=is_prune) return n_prune + + @torch.no_grad() + def _opacity_prune_gs( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + min_opacity: float, + ) -> int: + """Prune Gaussians with opacities below the given absolute threshold. + + Args: + params: The parameters dictionary containing "opacities". + optimizers: The optimizers for the parameters. + state: The running state dictionary. + min_opacity: The absolute opacity threshold (e.g., 0.005). + Gaussians with opacities strictly lower than this value will be pruned. + + Returns: + Number of Gaussians pruned. + """ + opacities = torch.sigmoid(params["opacities"].flatten()) + is_prune = opacities < min_opacity + + n_prune = is_prune.sum().item() + if n_prune > 0: + remove(params=params, optimizers=optimizers, state=state, mask=is_prune) + + return n_prune + + @torch.no_grad() + def _final_prune_gs( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + target_budget: int, + ) -> int: + """Enforce a strict budget by probabilistic pruning based on opacity. + + This implements the 'Natural Selection' final pruning mechanism from the paper: + "Gradient-Driven Natural Selection for Compact 3D Gaussian Splatting". + + Gaussians are retained based on their fitness (opacity) using Multinomial sampling, + simulating the survival of the fittest under a strict resource constraint. + + Args: + params: The parameters dictionary containing "opacities". + optimizers: The optimizers for the parameters. + state: The running state dictionary. + target_budget: The maximum number of Gaussians to keep. + + Returns: + Number of Gaussians pruned. + """ + # 1. Fetch opacities (sigmoid) to serve as sampling weights. + # gsplat stores logits in params["opacities"], so activation is required. + opacities = torch.sigmoid(params["opacities"].flatten()) + n_curr = opacities.shape[0] + + # 2. If already under budget, nothing to prune. + if n_curr <= target_budget: + return 0 + + # 3. Sample indices to keep via multinomial; higher opacity increases survival chance. + # Use replacement=False to prevent duplicate selections. + keep_indices = torch.multinomial(opacities, target_budget, replacement=False) + + # 4. Build the prune mask (True = delete) starting from all True + is_prune = torch.ones(n_curr, dtype=torch.bool, device=opacities.device) + # flip survivors to False so they are kept + is_prune[keep_indices] = False + + # 5. Perform the actual removal + n_prune = is_prune.sum().item() + if n_prune > 0: + remove(params=params, optimizers=optimizers, state=state, mask=is_prune) + + return n_prune From e962964196defdd89f20328dea6788286381a329 Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Tue, 25 Nov 2025 16:52:45 +0800 Subject: [PATCH 3/7] [Feat] Add a pruning mechanism based on Gradient-Driven Natural Selection --- examples/extended_trainer.py | 127 ++++++++++++++++++++++++++++++++--- gsplat/strategy/improved.py | 111 ++++++++++++++++++++++++++++-- 2 files changed, 224 insertions(+), 14 deletions(-) diff --git a/examples/extended_trainer.py b/examples/extended_trainer.py index a404781..9f6d7bb 100644 --- a/examples/extended_trainer.py +++ b/examples/extended_trainer.py @@ -118,7 +118,7 @@ class Config: # Verbosity for densification logs strategy_verbose: bool = True # Densification hyper-parameters (see notes below; shared unless marked otherwise) - prune_opa: float = 0.05 + prune_opa: float = 0.005 grow_grad2d: float = 0.0002 prune_scale3d: float = 0.08 prune_scale2d: float = 0.15 @@ -134,7 +134,7 @@ class Config: pause_refine_after_reset: int = 0 revised_opacity: bool = False # ImprovedStrategy-only hyper-parameter - budget: int = 2_000_000 + budget: Optional[int] = None # Strategy instance (constructed from the type/params above) strategy: Union[DefaultStrategy, ImprovedStrategy] = field( @@ -165,6 +165,20 @@ class Config: # LR for higher-order SH (detail) shN_lr: float = 2.5e-3 / 20 + # --- Natural Selection Pruning Params --- + # Whether to enable the Natural Selection pruning phase + enable_natural_selection: bool = True + # Iteration to start Natural Selection (usually post-densification, e.g., after 15000) + reg_start: int = 15_000 + # Iteration to end Natural Selection + reg_end: int = 23_000 + # Base regularization strength during Natural Selection (will be adjusted dynamically) + opacity_reg_lr: float = 2e-5 + # Interval for Natural Selection reg updates + reg_interval: int = 50 + # Final target Gaussian count (budget) + final_budget: int = 1000000 + ### Scale regularization """Weight of the regularisation loss encouraging gaussians to be flat, i.e. set their minimum scale to be small""" @@ -251,6 +265,8 @@ class Config: with_eval3d: bool = False def __post_init__(self): + if self.budget is None: + self.budget = self.final_budget * 2.5 self.rebuild_strategy() def adjust_steps(self, factor: float): @@ -307,6 +323,11 @@ def rebuild_strategy(self): absgrad=self.absgrad, verbose=self.strategy_verbose, budget=self.budget, + enable_natural_selection=self.enable_natural_selection, + reg_start=self.reg_start, + reg_end=self.reg_end, + reg_interval=self.reg_interval, + final_budget=self.final_budget, ) else: assert_never(self.strategy_type) @@ -662,6 +683,7 @@ def train(self): max_steps = cfg.max_steps init_step = 0 + ns_stop_step: Optional[int] = None schedulers = [ # means has a learning rate schedule, that end at 0.01 of the initial value @@ -856,6 +878,14 @@ def train(self): bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) + # --- [GNS] Opacity Learning Rate Scaling --- + if cfg.enable_natural_selection and step == cfg.reg_start: + print( + f"[GNS] Starting Natural Selection: Scaling Opacity LR by 4x at step {step}" + ) + for param_group in self.optimizers["opacities"].param_groups: + param_group["lr"] *= 4.0 + self.cfg.strategy.step_pre_backward( params=self.splats, optimizers=self.optimizers, @@ -929,14 +959,93 @@ def train(self): ) loss += cfg.render_normal_loss_weight * render_normal_loss - # regularizations - # the smallest scale is always near 0 + # regularizations the smallest scale is always near 0 if cfg.flat_reg > 0.0: loss += cfg.flat_reg * self.compute_flat_loss() # We follow the original SplatFacto implementation here and only apply this loss every 10 steps if cfg.scale_reg > 0.0 and step % 10 == 0: loss += cfg.scale_reg * self.compute_scale_regularisation_loss_median() + # --- [GNS] Regularization Loss & Early Stop Handling --- + strategy_gns_finished = getattr(self.cfg.strategy, "gns_finished", True) + if ( + cfg.enable_natural_selection + and cfg.reg_end >= step >= cfg.reg_start + and not strategy_gns_finished + ): + current_gs_count = len(self.splats["means"]) + if step > cfg.reg_start and current_gs_count < cfg.final_budget * 1.05: + print( + f"[GNS] Count {current_gs_count} < 1.05 * Budget. " + f"Stopping Natural Selection early at step {step}." + ) + if hasattr(self.cfg.strategy, "force_stop_natural_selection"): + self.cfg.strategy.force_stop_natural_selection( + self.splats, + self.optimizers, + self.strategy_state, + cfg.final_budget, + ) + ns_stop_step = step + elif (step - 1) % cfg.reg_interval == 0: + opacities_logits = self.splats["opacities"].flatten() + + if (step - 1) % 100 == 0: + if not hasattr(self, "gns_start_count"): + self.gns_start_count = len(self.splats["means"]) + if self.gns_start_count < cfg.final_budget: + self.gns_start_count = cfg.final_budget + 1000 + + progress = (step - cfg.reg_start) / (cfg.reg_end - cfg.reg_start) + progress = max(0.0, min(1.0, progress)) + expected_count = self.gns_start_count - ( + self.gns_start_count - cfg.final_budget + ) * progress + current_count = len(self.splats["means"]) + + if current_count > expected_count * 1.05: + cfg.opacity_reg_lr = cfg.opacity_reg_lr * 1.2 + elif current_count < expected_count * 0.95: + cfg.opacity_reg_lr = cfg.opacity_reg_lr * 0.8 + + cfg.opacity_reg_lr = max(1e-7, min(cfg.opacity_reg_lr, 1e-2)) + + if self.cfg.strategy_verbose: + print( + f"[GNS] Step {step}: Count={current_count}, " + f"Target={int(expected_count)}, LR={cfg.opacity_reg_lr:.2e}" + ) + + if step < cfg.reg_start + 1000: + current_opacities = torch.sigmoid(opacities_logits) + rate_l = torch.max( + torch.ones_like(current_opacities) * 0.05, + 1 - current_opacities, + ) + term = (opacities_logits + 20) / rate_l + gns_loss = cfg.opacity_reg_lr * (torch.mean(term) ** 2) + else: + mean_val = torch.mean(opacities_logits) + gns_loss = 3 * cfg.opacity_reg_lr * ((mean_val + 20) ** 2) + + loss += gns_loss + if self.cfg.strategy_verbose: + print( + f"[GNS] Step {step}: opacity_reg_lr={cfg.opacity_reg_lr:.6e}, " + f"loss contribution={gns_loss.item():.6e}" + ) + + if cfg.enable_natural_selection and step == cfg.reg_end and ns_stop_step is None: + ns_stop_step = step + + if ns_stop_step is not None and step == ns_stop_step + 1000: + print( + f"[GNS] Restoring Opacity LR (1000 steps after stop) at step {step}" + ) + for param_group in self.optimizers["opacities"].param_groups: + param_group["lr"] /= 4.0 + ns_stop_step = None + loss.backward() desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " @@ -1093,16 +1202,16 @@ def train(self): # Implementation of gradient accumulation trick from: # "Improving Densification in 3D Gaussian Splatting for High-Fidelity Rendering" # https://arxiv.org/abs/2508.12313v1 - if step <= 15000: - # Every step update for first 15000 iterations + if step <= 20000: + # Every step update for first 20000 iterations for optimizer in self.optimizers.values(): if cfg.visible_adam: optimizer.step(visibility_mask) else: optimizer.step() optimizer.zero_grad(set_to_none=True) - elif step <= 22500: - # Accumulate 5 steps, update every 5 steps for 15000-22500 iterations + elif step <= 24000: + # Accumulate 5 steps, update every 5 steps for 20000-24000 iterations if step % 5 == 0: for optimizer in self.optimizers.values(): if cfg.visible_adam: @@ -1111,7 +1220,7 @@ def train(self): optimizer.step() optimizer.zero_grad(set_to_none=True) else: - # Accumulate 20 steps, update every 20 steps after 22500 iterations + # Accumulate 20 steps, update every 20 steps after 24000 iterations if step % 20 == 0: for optimizer in self.optimizers.values(): if cfg.visible_adam: diff --git a/gsplat/strategy/improved.py b/gsplat/strategy/improved.py index 3f646b3..284a791 100644 --- a/gsplat/strategy/improved.py +++ b/gsplat/strategy/improved.py @@ -13,9 +13,11 @@ class ImprovedStrategy(Strategy): """An improved strategy with budget-based Gaussian splitting. - This strategy is based on the paper: + This strategy is based on the papers: "Improving Densification in 3D Gaussian Splatting for High-Fidelity Rendering" https://arxiv.org/abs/2508.12313v1 + and "Gradient-Driven Natural Selection for Compact 3D Gaussian Splatting" + https://arxiv.org/abs/2511.16980 The strategy will: @@ -23,6 +25,11 @@ class ImprovedStrategy(Strategy): - Periodically prune GSs with low opacity. - Periodically reset GSs to a lower opacity. - Perform quantile-based pruning after the first two resets. + - Optionally run the Natural Selection phase inspired by the paper above: + * enter a dedicated window (`reg_start`→`reg_end`) where low-opacity points are trimmed at interval `reg_interval`; + * dynamically adjust opacity regularization strength to encourage the population to shrink toward `final_budget`; + * optionally early-stop and force a probabilistic final prune once the target count is reached; + * finally restore the learning rate after a short delay, mirroring the workflow in "Gradient-Driven Natural Selection for Compact 3D Gaussian Splatting". If `absgrad=True`, it will use the absolute gradients instead of average gradients for GS splitting, following the AbsGS paper: @@ -53,6 +60,12 @@ class ImprovedStrategy(Strategy): 3DGS uses "means2d" gradient and 2DGS uses a similar gradient which stores in variable "gradient_2dgs". budget (int): Maximum number of Gaussians allowed. Default is 1000000. + enable_natural_selection (bool): Enable the Natural Selection pruning phase + inspired by "Gradient-Driven Natural Selection for Compact 3D Gaussian Splatting". + reg_start (int): Iteration to start Natural Selection. + reg_end (int): Iteration to stop Natural Selection (or when finished early). + reg_interval (int): Interval between opacity pruning steps during Natural Selection. + final_budget (int): Target number of Gaussians after Natural Selection finishes. Examples: @@ -72,22 +85,30 @@ class ImprovedStrategy(Strategy): """ prune_opa: float = 0.005 - grow_grad2d: float = 0.0003 - prune_scale3d: float = 0.1 + grow_grad2d: float = 0.0002 + prune_scale3d: float = 0.08 prune_scale2d: float = 0.15 refine_scale2d_stop_iter: int = 4000 refine_start_iter: int = 500 - refine_stop_iter: int = 15_000 + refine_stop_iter: int = 15000 reset_every: int = 3000 refine_every: int = 100 absgrad: bool = True verbose: bool = True key_for_gradient: Literal["means2d", "gradient_2dgs"] = "means2d" - budget: int = 2000000 + budget: int = 2500000 + + # --- [GNS] Additional params (mirror extended_trainer.Config) --- + enable_natural_selection: bool = False + reg_start: int = 15000 + reg_end: int = 23000 + reg_interval: int = 50 + final_budget: int = 1000000 def __post_init__(self): """Initialize instance variables after dataclass initialization.""" self.reset_count = 0 + self.gns_finished = False def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: """Initialize and return the running state for this strategy. @@ -130,6 +151,39 @@ def check_sanity( for key in ["means", "scales", "quats", "opacities"]: assert key in params, f"{key} is required in params but missing." + def force_stop_natural_selection( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + target_budget: int, + ) -> int: + """Force a final prune and mark natural selection as finished.""" + if self.gns_finished or not self.enable_natural_selection: + return 0 + + if self.verbose: + print( + f"[GNS] Early Stopping triggered! Force pruning to {target_budget}..." + ) + + n_pruned = self._final_prune_gs( + params=params, + optimizers=optimizers, + state=state, + target_budget=target_budget, + ) + + if self.verbose: + print( + f"[GNS] Early stop pruned {n_pruned} gaussians. " + f"Now having {len(params['means'])} GSs." + ) + + self.gns_finished = True + torch.cuda.empty_cache() + return n_pruned + def step_pre_backward( self, params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], @@ -154,6 +208,48 @@ def step_post_backward( packed: bool = False, ): """Callback function to be executed after the `loss.backward()` call.""" + # --- [GNS] Natural Selection Pruning Logic (must run first) --- + # Needs to happen before refine_stop_iter since GNS usually runs post-densification + if self.enable_natural_selection and not self.gns_finished: + # 1. Continuous pruning to remove very transparent Gaussians during the window + if self.reg_start <= step < self.reg_end and step % self.reg_interval == 0: + n_curr = len(params["means"]) + if n_curr > self.final_budget: + n_pruned = self._opacity_prune_gs( + params=params, + optimizers=optimizers, + state=state, + min_opacity=0.001, + ) + if self.verbose and n_pruned > 0: + print( + f"[GNS] Step {step}: Removed {n_pruned} GSs " + f"below opacity threshold. Now having {len(params['means'])} GSs." + ) + + # 2. Final budget prune that enforces the probabilistic cap at the end + if step == self.reg_end: + if self.verbose: + print(f"[GNS] Step {step}: Running Final Budget Prune to {self.final_budget}...") + + n_pruned = self._final_prune_gs( + params=params, + optimizers=optimizers, + state=state, + target_budget=self.final_budget + ) + + if self.verbose: + print( + f"[GNS] Final Prune removed {n_pruned} gaussians. " + f"Now having {len(params['means'])} GSs." + ) + + # Clean up memory after large-scale pruning + torch.cuda.empty_cache() + self.gns_finished = True + # ---------------------------------------------------------- + if step >= self.refine_stop_iter: return @@ -196,6 +292,11 @@ def step_post_backward( state=state, value=self.prune_opa * 10.0, ) + if self.verbose: + print( + f"Step {step}: reset opacities to {self.prune_opa * 10.0}. " + f"Now having {len(params['means'])} GSs." + ) self.reset_count += 1 # After the first two resets, perform quantile pruning 300 steps after reset From 3730b9929bfef332a36d93a2a2ade30d48751477 Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Tue, 25 Nov 2025 20:28:54 +0800 Subject: [PATCH 4/7] [Feat] Remove steps_scaler, Refine comments, Set default adam to fused --- examples/extended_trainer.py | 74 ++++++++++++------------------------ gsplat/strategy/improved.py | 2 +- 2 files changed, 26 insertions(+), 50 deletions(-) diff --git a/examples/extended_trainer.py b/examples/extended_trainer.py index 9f6d7bb..4549ac6 100644 --- a/examples/extended_trainer.py +++ b/examples/extended_trainer.py @@ -75,9 +75,6 @@ class Config: # Batch size for training. Learning rates are scaled automatically batch_size: int = 1 - # A global factor to scale the number of training steps - steps_scaler: float = 1.0 - # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model @@ -165,18 +162,18 @@ class Config: # LR for higher-order SH (detail) shN_lr: float = 2.5e-3 / 20 - # --- Natural Selection Pruning Params --- - # Whether to enable the Natural Selection pruning phase + ### [GNS] Natural Selection Pruning Params + """ Whether to enable the Natural Selection pruning phase """ enable_natural_selection: bool = True - # Iteration to start Natural Selection (usually post-densification, e.g., after 15000) + """ Iteration to start Natural Selection (usually post-densification, e.g., after 15000) """ reg_start: int = 15_000 - # Iteration to end Natural Selection + """ Iteration to end Natural Selection """ reg_end: int = 23_000 - # Base regularization strength during Natural Selection (will be adjusted dynamically) + """ Base regularization strength during Natural Selection (will be adjusted dynamically) """ opacity_reg_lr: float = 2e-5 - # Interval for Natural Selection reg updates + """ Interval for Natural Selection reg updates """ reg_interval: int = 50 - # Final target Gaussian count (budget) + """ Final target Gaussian count (budget) """ final_budget: int = 1000000 ### Scale regularization @@ -269,27 +266,6 @@ def __post_init__(self): self.budget = self.final_budget * 2.5 self.rebuild_strategy() - def adjust_steps(self, factor: float): - self.eval_steps = [int(i * factor) for i in self.eval_steps] - self.save_steps = [int(i * factor) for i in self.save_steps] - self.ply_steps = [int(i * factor) for i in self.ply_steps] - self.max_steps = int(self.max_steps * factor) - self.sh_degree_interval = int(self.sh_degree_interval * factor) - - strategy = self.strategy - if isinstance(strategy, DefaultStrategy): - strategy.refine_start_iter = int(strategy.refine_start_iter * factor) - strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) - strategy.reset_every = int(strategy.reset_every * factor) - strategy.refine_every = int(strategy.refine_every * factor) - elif isinstance(strategy, ImprovedStrategy): - strategy.refine_start_iter = int(strategy.refine_start_iter * factor) - strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) - strategy.reset_every = int(strategy.reset_every * factor) - strategy.refine_every = int(strategy.refine_every * factor) - else: - assert_never(strategy) - def rebuild_strategy(self): if self.strategy_type == "default": self.strategy = DefaultStrategy( @@ -419,6 +395,7 @@ def create_splats_with_optimizers( eps=1e-15 / math.sqrt(BS), # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + fused=True, ) for name, _, lr in params } @@ -878,7 +855,7 @@ def train(self): bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) - # --- [GNS] Opacity Learning Rate Scaling --- + # [GNS] Opacity Learning Rate Scaling if cfg.enable_natural_selection and step == cfg.reg_start: print( f"[GNS] Starting Natural Selection: Scaling Opacity LR by 4x at step {step}" @@ -894,7 +871,7 @@ def train(self): info=info, ) - # loss + # color loss l1loss = F.l1_loss(colors, pixels) ssimloss = 1.0 - fused_ssim( colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" @@ -918,7 +895,7 @@ def train(self): if (need_consistency_normal or need_normal_prior) and depths is not None: surf_normals_from_depth = get_implied_normal_from_depth(depths, Ks) - # consistency normal loss + # normal consistency loss if ( need_consistency_normal and surf_normals_from_depth is not None @@ -932,13 +909,13 @@ def train(self): ) loss += cfg.consistency_normal_loss_weight * consistency_norm_loss - # normal loss + # normal prior loss if need_normal_prior and surf_normals_from_depth is not None: mask = torch.ones_like(depths).float() # [B,H,W,1] if normal_prior_mask is not None: mask = mask * normal_prior_mask - # Surface normal loss (from depth) + # surface normal loss (from depth) if ( cfg.surf_normal_loss_weight > 0 and step >= cfg.surf_normal_loss_activation_step @@ -948,7 +925,7 @@ def train(self): ) loss += cfg.surf_normal_loss_weight * surf_normal_loss - # Rendered normal loss + # rendered normal loss if ( cfg.render_normal_loss_weight > 0 and step >= cfg.render_normal_loss_activation_step @@ -966,7 +943,7 @@ def train(self): if cfg.scale_reg > 0.0 and step % 10 == 0: loss += cfg.scale_reg * self.compute_scale_regularisation_loss_median() - # --- [GNS] Regularization Loss & Early Stop Handling --- + # [GNS] Regularization Loss & Early Stop Handling strategy_gns_finished = getattr(self.cfg.strategy, "gns_finished", True) if ( cfg.enable_natural_selection @@ -1010,11 +987,11 @@ def train(self): cfg.opacity_reg_lr = max(1e-7, min(cfg.opacity_reg_lr, 1e-2)) - if self.cfg.strategy_verbose: - print( - f"[GNS] Step {step}: Count={current_count}, " - f"Target={int(expected_count)}, LR={cfg.opacity_reg_lr:.2e}" - ) + # if self.cfg.strategy_verbose: + # print( + # f"[GNS] Step {step}: Count={current_count}, " + # f"Target={int(expected_count)}, LR={cfg.opacity_reg_lr:.2e}" + # ) if step < cfg.reg_start + 1000: current_opacities = torch.sigmoid(opacities_logits) @@ -1029,11 +1006,11 @@ def train(self): gns_loss = 3 * cfg.opacity_reg_lr * ((mean_val + 20) ** 2) loss += gns_loss - if self.cfg.strategy_verbose: - print( - f"[GNS] Step {step}: opacity_reg_lr={cfg.opacity_reg_lr:.6e}, " - f"loss contribution={gns_loss.item():.6e}" - ) + # if self.cfg.strategy_verbose: + # print( + # f"[GNS] Step {step}: opacity_reg_lr={cfg.opacity_reg_lr:.6e}, " + # f"loss contribution={gns_loss.item():.6e}" + # ) if cfg.enable_natural_selection and step == cfg.reg_end and ns_stop_step is None: ns_stop_step = step @@ -1738,7 +1715,6 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config): cfg = tyro.cli(Config) cfg.rebuild_strategy() - cfg.adjust_steps(cfg.steps_scaler) # Import BilateralGrid and related functions based on configuration if cfg.use_bilateral_grid or cfg.use_fused_bilagrid: diff --git a/gsplat/strategy/improved.py b/gsplat/strategy/improved.py index 284a791..1f35536 100644 --- a/gsplat/strategy/improved.py +++ b/gsplat/strategy/improved.py @@ -98,7 +98,7 @@ class ImprovedStrategy(Strategy): key_for_gradient: Literal["means2d", "gradient_2dgs"] = "means2d" budget: int = 2500000 - # --- [GNS] Additional params (mirror extended_trainer.Config) --- + # [GNS] Additional params (mirror extended_trainer.Config) enable_natural_selection: bool = False reg_start: int = 15000 reg_end: int = 23000 From df3735f3e5670d11175e68dd619dcbe44e773db6 Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Tue, 25 Nov 2025 20:40:50 +0800 Subject: [PATCH 5/7] [Refactor] Make variable names consistent --- examples/extended_trainer.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/extended_trainer.py b/examples/extended_trainer.py index 4549ac6..fa42148 100644 --- a/examples/extended_trainer.py +++ b/examples/extended_trainer.py @@ -170,7 +170,7 @@ class Config: """ Iteration to end Natural Selection """ reg_end: int = 23_000 """ Base regularization strength during Natural Selection (will be adjusted dynamically) """ - opacity_reg_lr: float = 2e-5 + opacity_reg_weight: float = 2e-5 """ Interval for Natural Selection reg updates """ reg_interval: int = 50 """ Final target Gaussian count (budget) """ @@ -179,14 +179,14 @@ class Config: ### Scale regularization """Weight of the regularisation loss encouraging gaussians to be flat, i.e. set their minimum scale to be small""" - flat_reg: float = 1.0 + flat_reg_weight: float = 1.0 """If scale regularization is enabled, a scale regularization introduced in PhysGauss (https://xpandora.github.io/PhysGaussian/) is used for reducing huge spikey gaussians. This implementation adapts the PhysGauss loss to use the ratio of max to median scale instead of max to min, as implemented in mvsanywhere/regsplatfacto. This modification has been found to work better at encouraging Gaussians to be disks. """ - scale_reg: float = 1.0 + scale_reg_weight: float = 1.0 """Threshold of ratio of Gaussian's max to median scale before applying regularization loss. This is adapted from the PhysGauss paper (there they used ratio of max to min). """ @@ -937,11 +937,11 @@ def train(self): loss += cfg.render_normal_loss_weight * render_normal_loss # regularizations the smallest scale is always near 0 - if cfg.flat_reg > 0.0: - loss += cfg.flat_reg * self.compute_flat_loss() + if cfg.flat_reg_weight > 0.0: + loss += cfg.flat_reg_weight * self.compute_flat_loss() # We follow the original SplatFacto implementation here and only apply this loss every 10 steps - if cfg.scale_reg > 0.0 and step % 10 == 0: - loss += cfg.scale_reg * self.compute_scale_regularisation_loss_median() + if cfg.scale_reg_weight > 0.0 and step % 10 == 0: + loss += cfg.scale_reg_weight * self.compute_scale_regularisation_loss_median() # [GNS] Regularization Loss & Early Stop Handling strategy_gns_finished = getattr(self.cfg.strategy, "gns_finished", True) @@ -981,16 +981,16 @@ def train(self): current_count = len(self.splats["means"]) if current_count > expected_count * 1.05: - cfg.opacity_reg_lr = cfg.opacity_reg_lr * 1.2 + cfg.opacity_reg_weight = cfg.opacity_reg_weight * 1.2 elif current_count < expected_count * 0.95: - cfg.opacity_reg_lr = cfg.opacity_reg_lr * 0.8 + cfg.opacity_reg_weight = cfg.opacity_reg_weight * 0.8 - cfg.opacity_reg_lr = max(1e-7, min(cfg.opacity_reg_lr, 1e-2)) + cfg.opacity_reg_weight = max(1e-7, min(cfg.opacity_reg_weight, 1e-2)) # if self.cfg.strategy_verbose: # print( # f"[GNS] Step {step}: Count={current_count}, " - # f"Target={int(expected_count)}, LR={cfg.opacity_reg_lr:.2e}" + # f"Target={int(expected_count)}, LR={cfg.opacity_reg_weight:.2e}" # ) if step < cfg.reg_start + 1000: @@ -1000,15 +1000,15 @@ def train(self): 1 - current_opacities, ) term = (opacities_logits + 20) / rate_l - gns_loss = cfg.opacity_reg_lr * (torch.mean(term) ** 2) + gns_loss = cfg.opacity_reg_weight * (torch.mean(term) ** 2) else: mean_val = torch.mean(opacities_logits) - gns_loss = 3 * cfg.opacity_reg_lr * ((mean_val + 20) ** 2) + gns_loss = 3 * cfg.opacity_reg_weight * ((mean_val + 20) ** 2) loss += gns_loss # if self.cfg.strategy_verbose: # print( - # f"[GNS] Step {step}: opacity_reg_lr={cfg.opacity_reg_lr:.6e}, " + # f"[GNS] Step {step}: opacity_reg_weight={cfg.opacity_reg_weight:.6e}, " # f"loss contribution={gns_loss.item():.6e}" # ) From d4f85c3656da8bd204dc1dfa157448d55f970638 Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Tue, 25 Nov 2025 22:33:43 +0800 Subject: [PATCH 6/7] [Refactor] Move utils_depth.py to utils.py --- examples/extended_trainer.py | 40 ++++-- examples/extended_trainer_2dgs.py | 2 +- examples/simple_viewer.py | 2 +- examples/simple_viewer_2dgs.py | 2 +- examples/utils.py | 134 +++++++++++++++++++- examples/utils_depth.py | 199 ------------------------------ gsplat/strategy/improved.py | 2 +- 7 files changed, 166 insertions(+), 215 deletions(-) delete mode 100644 examples/utils_depth.py diff --git a/examples/extended_trainer.py b/examples/extended_trainer.py index fa42148..85338d3 100644 --- a/examples/extended_trainer.py +++ b/examples/extended_trainer.py @@ -29,7 +29,15 @@ from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from typing_extensions import Literal, assert_never -from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed + +from utils import ( + AppearanceOptModule, + CameraOptModule, + get_implied_normal_from_depth, + knn, + rgb_to_sh, + set_random_seed, +) from gsplat import export_splats from gsplat.compression import PngCompression @@ -39,7 +47,6 @@ from gsplat.strategy import DefaultStrategy, ImprovedStrategy from gsplat_viewer import GsplatViewer, GsplatRenderTabState from nerfview import CameraState, RenderTabState, apply_float_colormap -from utils_depth import get_implied_normal_from_depth @dataclass @@ -704,6 +711,9 @@ def train(self): # Training loop. global_tic = time.time() + cached_grid_xy = None + cached_h, cached_w = -1, -1 + pbar = tqdm.tqdm(range(init_step, max_steps)) for step in pbar: if not cfg.disable_viewer: @@ -838,15 +848,22 @@ def train(self): colors, depths, render_normals = renders, None, None if cfg.use_bilateral_grid: - grid_y, grid_x = torch.meshgrid( - (torch.arange(height, device=self.device) + 0.5) / height, - (torch.arange(width, device=self.device) + 0.5) / width, - indexing="ij", - ) - grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + if cached_grid_xy is None or cached_h != height or cached_w != width: + cached_h, cached_w = height, width + grid_y, grid_x = torch.meshgrid( + (torch.arange(height, device=self.device) + 0.5) / height, + (torch.arange(width, device=self.device) + 0.5) / width, + indexing="ij", + ) + cached_grid_xy = ( + torch.stack([grid_x, grid_y], dim=-1) + .unsqueeze(0) + .detach() + ) + batch_grid_xy = cached_grid_xy.expand(colors.shape[0], -1, -1, -1) colors = slice( self.bil_grids, - grid_xy.expand(colors.shape[0], -1, -1, -1), + batch_grid_xy, colors, image_ids.unsqueeze(-1), )["rgb"] @@ -877,8 +894,11 @@ def train(self): colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" ) loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda + # tv loss if cfg.use_bilateral_grid: - tvloss = 10 * total_variation_loss(self.bil_grids.grids) + unique_grid_ids = torch.unique(image_ids) + active_grids = self.bil_grids.grids[unique_grid_ids] + tvloss = 10 * total_variation_loss(active_grids) loss += tvloss # depth loss diff --git a/examples/extended_trainer_2dgs.py b/examples/extended_trainer_2dgs.py index 8af2e6e..b749a8b 100644 --- a/examples/extended_trainer_2dgs.py +++ b/examples/extended_trainer_2dgs.py @@ -30,7 +30,7 @@ rgb_to_sh, set_random_seed, ) -from utils_depth import get_implied_normal_from_depth +from utils import get_implied_normal_from_depth from gsplat_viewer_2dgs import GsplatViewer, GsplatRenderTabState from gsplat.rendering import rasterization_2dgs, rasterization_2dgs_inria_wrapper from gsplat.strategy import DefaultStrategy, ImprovedStrategy diff --git a/examples/simple_viewer.py b/examples/simple_viewer.py index 678875b..8a1dad6 100644 --- a/examples/simple_viewer.py +++ b/examples/simple_viewer.py @@ -17,7 +17,7 @@ from nerfview import CameraState, RenderTabState, apply_float_colormap from gsplat_viewer import GsplatViewer, GsplatRenderTabState -from utils_depth import get_implied_normal_from_depth +from utils import get_implied_normal_from_depth def main(local_rank: int, world_rank, world_size: int, args): diff --git a/examples/simple_viewer_2dgs.py b/examples/simple_viewer_2dgs.py index 4b169ce..34f2584 100644 --- a/examples/simple_viewer_2dgs.py +++ b/examples/simple_viewer_2dgs.py @@ -12,7 +12,7 @@ from nerfview import CameraState, RenderTabState, apply_float_colormap from gsplat_viewer_2dgs import GsplatViewer, GsplatRenderTabState -from utils_depth import get_implied_normal_from_depth +from utils import get_implied_normal_from_depth def main(local_rank: int, world_rank, world_size: int, args): diff --git a/examples/utils.py b/examples/utils.py index 80f8e35..c1cc317 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -1,12 +1,14 @@ import random +from functools import lru_cache +import kornia import numpy as np import torch -from sklearn.neighbors import NearestNeighbors -from torch import Tensor import torch.nn.functional as F import matplotlib.pyplot as plt from matplotlib import colormaps +from sklearn.neighbors import NearestNeighbors +from torch import Tensor class CameraOptModule(torch.nn.Module): @@ -222,3 +224,131 @@ def apply_depth_colormap( if acc is not None: img = img * acc + (1.0 - acc) return img + + +# --- Depth/normal utilities (migrated from utils_depth.py) --- +OPENGL_TO_OPENCV = np.array( + [ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ] +).astype(float) + + +def to_homogeneous(input_tensor: torch.Tensor, dim: int = 0) -> torch.Tensor: + """Convert tensor to homogeneous coordinates by adding ones along a dimension.""" + ones = torch.ones_like(input_tensor.select(dim, 0).unsqueeze(dim)) + output_bkN = torch.cat([input_tensor, ones], dim=dim) + return output_bkN + + +class BackprojectDepth(torch.nn.Module): + """Project depth pixels to 3D space in homogeneous coordinates.""" + + def __init__(self, height: int, width: int) -> None: + super().__init__() + self.height = height + self.width = width + + xx, yy = torch.meshgrid( + torch.arange(self.width), + torch.arange(self.height), + indexing="xy", + ) + pix_coords_2hw = torch.stack((xx, yy), dim=0) + 0.5 + pix_coords_13N = to_homogeneous(pix_coords_2hw, dim=0).flatten(1).unsqueeze(0) + self.register_buffer("pix_coords_13N", pix_coords_13N) + + def forward(self, depth_b1hw: torch.Tensor, invK_b44: torch.Tensor) -> torch.Tensor: + cam_points_b3N = torch.matmul(invK_b44[:, :3, :3], self.pix_coords_13N) + cam_points_b3N = depth_b1hw.flatten(start_dim=2) * cam_points_b3N + cam_points_b4N = to_homogeneous(cam_points_b3N, dim=1) + return cam_points_b4N + + +class Project3D(torch.nn.Module): + """Project 3D points into the 2D camera plane.""" + + def __init__(self, eps: float = 1e-8): + super().__init__() + self.register_buffer("eps", torch.tensor(eps).view(1, 1, 1)) + + def forward( + self, + points_b4N: torch.Tensor, + K_b44: torch.Tensor, + cam_T_world_b44: torch.Tensor, + ) -> torch.Tensor: + P_b44 = K_b44 @ cam_T_world_b44 + cam_points_b3N = P_b44[:, :3] @ points_b4N + + mask = torch.abs(cam_points_b3N[:, 2:]) > self.eps + depth_b1N = cam_points_b3N[:, 2:] + self.eps + scale = torch.where( + mask, 1.0 / depth_b1N, torch.tensor(1.0, device=depth_b1N.device) + ) + + pix_coords_b2N = cam_points_b3N[:, :2] * scale + return torch.cat([pix_coords_b2N, depth_b1N], dim=1) + + +class NormalGenerator(torch.nn.Module): + """Estimate normals from depth maps.""" + + def __init__(self, height: int, width: int): + super().__init__() + self.height = height + self.width = width + self.backproject = BackprojectDepth(self.height, self.width) + + def forward(self, depth_b1hw: torch.Tensor, invK_b44: torch.Tensor) -> torch.Tensor: + cam_points_b4N = self.backproject(depth_b1hw, invK_b44) + cam_points_b3hw = cam_points_b4N[:, :3].view(-1, 3, self.height, self.width) + gradients_b32hw = kornia.filters.spatial_gradient(cam_points_b3hw) + normal_b3hw = -torch.cross( + gradients_b32hw[:, :, 0], + gradients_b32hw[:, :, 1], + dim=1, + ) + normal_b3hw = F.normalize(normal_b3hw, dim=1) + return normal_b3hw + + +def get_implied_normal_from_depth( + depths_bhw1: torch.Tensor, Ks_b33: torch.Tensor +) -> torch.Tensor: + """Compute surface normals from depth maps and camera intrinsics.""" + if depths_bhw1.dim() == 3: + depths_bhw1 = depths_bhw1.unsqueeze(0) + elif depths_bhw1.dim() != 4: + raise ValueError( + f"depths_bhw1 must be [H, W, 1] or [B, H, W, 1], got {depths_bhw1.shape}" + ) + + B, H, W, _ = depths_bhw1.shape + device = depths_bhw1.device + depths_b1hw = depths_bhw1.permute(0, 3, 1, 2) + + if Ks_b33.dim() == 2: + Ks_b33 = Ks_b33.unsqueeze(0).repeat(B, 1, 1) + elif Ks_b33.shape[0] != B: + raise ValueError( + f"Batch size mismatch: depth batch={B}, K batch={Ks_b33.shape[0]}" + ) + + K_b44 = torch.eye(4, device=device).unsqueeze(0).repeat(B, 1, 1) + K_b44[:, :3, :3] = Ks_b33 + invK_b44 = torch.inverse(K_b44) + + normal_generator = get_normal_generator(height=H, width=W) + normals_b3hw = normal_generator(depths_b1hw, invK_b44) + normals_bhw3 = normals_b3hw.permute(0, 2, 3, 1).contiguous() + return normals_bhw3 + + +@lru_cache(maxsize=None) +def get_normal_generator(height: int, width: int) -> NormalGenerator: + """Cached factory for NormalGenerator to reuse per resolution.""" + return NormalGenerator(height=height, width=width).cuda() diff --git a/examples/utils_depth.py b/examples/utils_depth.py deleted file mode 100644 index 20fe9bc..0000000 --- a/examples/utils_depth.py +++ /dev/null @@ -1,199 +0,0 @@ -""" Taken from SimpleRecon""" - -import kornia -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn -from functools import lru_cache - -OPENGL_TO_OPENCV = np.array( - [ - [1, 0, 0, 0], - [0, -1, 0, 0], - [0, 0, -1, 0], - [0, 0, 0, 1], - ] -).astype(float) - - -def to_homogeneous(input_tensor: torch.Tensor, dim: int = 0) -> torch.Tensor: - """ - Converts tensor to homogeneous coordinates by adding ones to the specified - dimension - """ - ones = torch.ones_like(input_tensor.select(dim, 0).unsqueeze(dim)) - output_bkN = torch.cat([input_tensor, ones], dim=dim) - return output_bkN - - -class BackprojectDepth(nn.Module): - """ - Layer that projects points from 2D camera to 3D space. The 3D points are - represented in homogeneous coordinates. - """ - - def __init__(self, height: int, width: int) -> None: - super().__init__() - - self.height = height - self.width = width - - xx, yy = torch.meshgrid( - torch.arange(self.width), - torch.arange(self.height), - indexing="xy", - ) - pix_coords_2hw = torch.stack((xx, yy), dim=0) + 0.5 - - pix_coords_13N = to_homogeneous(pix_coords_2hw, dim=0).flatten(1).unsqueeze(0) - - # make these tensors into buffers so they are put on the correct GPU - # automatically - self.register_buffer("pix_coords_13N", pix_coords_13N) - - def forward(self, depth_b1hw: torch.Tensor, invK_b44: torch.Tensor) -> torch.Tensor: - """ - Backprojects spatial points in 2D image space to world space using - invK_b44 at the depths defined in depth_b1hw. - """ - - cam_points_b3N = torch.matmul(invK_b44[:, :3, :3], self.pix_coords_13N) - cam_points_b3N = depth_b1hw.flatten(start_dim=2) * cam_points_b3N - cam_points_b4N = to_homogeneous(cam_points_b3N, dim=1) - return cam_points_b4N - - -class Project3D(nn.Module): - """ - Layer that projects 3D points into the 2D camera - """ - - def __init__(self, eps: float = 1e-8): - super().__init__() - - self.register_buffer("eps", torch.tensor(eps).view(1, 1, 1)) - - def forward( - self, - points_b4N: torch.Tensor, - K_b44: torch.Tensor, - cam_T_world_b44: torch.Tensor, - ) -> torch.Tensor: - """ - Projects spatial points in 3D world space to camera image space using - the extrinsics matrix cam_T_world_b44 and intrinsics K_b44. - """ - P_b44 = K_b44 @ cam_T_world_b44 - - cam_points_b3N = P_b44[:, :3] @ points_b4N - - # from Kornia and OpenCV, https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/conversions.html#convert_points_from_homogeneous - mask = torch.abs(cam_points_b3N[:, 2:]) > self.eps - depth_b1N = cam_points_b3N[:, 2:] + self.eps - scale = torch.where( - mask, 1.0 / depth_b1N, torch.tensor(1.0, device=depth_b1N.device) - ) - - pix_coords_b2N = cam_points_b3N[:, :2] * scale - - return torch.cat([pix_coords_b2N, depth_b1N], dim=1) - - -class NormalGenerator(nn.Module): - def __init__( - self, - height: int, - width: int, - ): - """ - Estimates normals from depth maps. - """ - super().__init__() - self.height = height - self.width = width - - self.backproject = BackprojectDepth(self.height, self.width) - - def forward(self, depth_b1hw: torch.Tensor, invK_b44: torch.Tensor) -> torch.Tensor: - cam_points_b4N = self.backproject(depth_b1hw, invK_b44) - cam_points_b3hw = cam_points_b4N[:, :3].view(-1, 3, self.height, self.width) - - gradients_b32hw = kornia.filters.spatial_gradient(cam_points_b3hw) - - # 反转方向以匹配OpenCV坐标定义 - normal_b3hw = -torch.cross( - gradients_b32hw[:, :, 0], # ∂P/∂x - gradients_b32hw[:, :, 1], # ∂P/∂y - dim=1, - ) - normal_b3hw = F.normalize(normal_b3hw, dim=1) - return normal_b3hw - - -def get_implied_normal_from_depth( - depths_bhw1: torch.Tensor, Ks_b33: torch.Tensor -) -> torch.Tensor: - """ - Computes surface normal maps from a batch of depth maps using camera intrinsics. - - Args: - depths_bhw1 (torch.Tensor): Depth maps of shape [B, H, W, 1] or [H, W, 1]. - Ks_b33 (torch.Tensor): Camera intrinsics [B, 3, 3] or [3, 3]. - - Returns: - torch.Tensor: Normal maps of shape [B, H, W, 3], with values in [-1, 1]. - """ - # Handle input dimensions and batch size - if depths_bhw1.dim() == 3: - # Single image case -> add batch dim - depths_bhw1 = depths_bhw1.unsqueeze(0) - elif depths_bhw1.dim() != 4: - raise ValueError( - f"depths_bhw1 must be [H, W, 1] or [B, H, W, 1], got {depths_bhw1.shape}" - ) - - B, H, W, _ = depths_bhw1.shape - device = depths_bhw1.device - - # Convert to [B, 1, H, W] for NormalGenerator - depths_b1hw = depths_bhw1.permute(0, 3, 1, 2) - - # Prepare intrinsics K[3x3] → K[4x4] and invert - if Ks_b33.dim() == 2: - Ks_b33 = Ks_b33.unsqueeze(0).repeat(B, 1, 1) - elif Ks_b33.shape[0] != B: - raise ValueError( - f"Batch size mismatch: depth batch={B}, K batch={Ks_b33.shape[0]}" - ) - - K_b44 = torch.eye(4, device=device).unsqueeze(0).repeat(B, 1, 1) - K_b44[:, :3, :3] = Ks_b33 - invK_b44 = torch.inverse(K_b44) - - # Compute normals with NormalGenerator - normal_generator = get_normal_generator(height=H, width=W) - normals_b3hw = normal_generator(depths_b1hw, invK_b44) # [B, 3, H, W] - - # Convert to [B, H, W, 3] - normals_bhw3 = normals_b3hw.permute(0, 2, 3, 1).contiguous() - - return normals_bhw3 - - -@lru_cache(maxsize=None) -def get_normal_generator(height: int, width: int) -> NormalGenerator: - """ - Gets a normal generator object. - - This is wrapped in lru_cache so for a given height and width, we only create one instance - of the normal generator during the whole lifetime of this class instance. - - Args: - height (int): The height of the depth map. - width (int): The width of the depth map. - - Returns: - NormalGenerator: The normal generator object. - """ - return NormalGenerator(height=height, width=width).cuda() diff --git a/gsplat/strategy/improved.py b/gsplat/strategy/improved.py index 1f35536..33887d5 100644 --- a/gsplat/strategy/improved.py +++ b/gsplat/strategy/improved.py @@ -283,7 +283,7 @@ def step_post_backward( state["count"].zero_() if self.refine_scale2d_stop_iter > 0: state["radii"].zero_() - torch.cuda.empty_cache() + torch.cuda.empty_cache() # it is useful if step % self.reset_every == 0 and step > 0: reset_opa( From 7e13e8086da31b11350c4aec741fd85ce250cb11 Mon Sep 17 00:00:00 2001 From: x299 <1678807086@qq.com> Date: Wed, 26 Nov 2025 10:52:54 +0800 Subject: [PATCH 7/7] [Style] format --- examples/extended_trainer.py | 36 ++++++++++++++++++++++-------------- gsplat/strategy/improved.py | 24 +++++++++++++----------- tests/test_strategy.py | 1 - 3 files changed, 35 insertions(+), 26 deletions(-) diff --git a/examples/extended_trainer.py b/examples/extended_trainer.py index 85338d3..a9124da 100644 --- a/examples/extended_trainer.py +++ b/examples/extended_trainer.py @@ -141,9 +141,7 @@ class Config: budget: Optional[int] = None # Strategy instance (constructed from the type/params above) - strategy: Union[DefaultStrategy, ImprovedStrategy] = field( - init=False, repr=False - ) + strategy: Union[DefaultStrategy, ImprovedStrategy] = field(init=False, repr=False) # Use packed mode for rasterization, this leads to less memory usage but slightly slower. packed: bool = False # Use sparse gradients for optimization. (experimental) @@ -177,7 +175,7 @@ class Config: """ Iteration to end Natural Selection """ reg_end: int = 23_000 """ Base regularization strength during Natural Selection (will be adjusted dynamically) """ - opacity_reg_weight: float = 2e-5 + opacity_reg_weight: float = 2e-5 """ Interval for Natural Selection reg updates """ reg_interval: int = 50 """ Final target Gaussian count (budget) """ @@ -856,9 +854,7 @@ def train(self): indexing="ij", ) cached_grid_xy = ( - torch.stack([grid_x, grid_y], dim=-1) - .unsqueeze(0) - .detach() + torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0).detach() ) batch_grid_xy = cached_grid_xy.expand(colors.shape[0], -1, -1, -1) colors = slice( @@ -961,7 +957,10 @@ def train(self): loss += cfg.flat_reg_weight * self.compute_flat_loss() # We follow the original SplatFacto implementation here and only apply this loss every 10 steps if cfg.scale_reg_weight > 0.0 and step % 10 == 0: - loss += cfg.scale_reg_weight * self.compute_scale_regularisation_loss_median() + loss += ( + cfg.scale_reg_weight + * self.compute_scale_regularisation_loss_median() + ) # [GNS] Regularization Loss & Early Stop Handling strategy_gns_finished = getattr(self.cfg.strategy, "gns_finished", True) @@ -993,11 +992,14 @@ def train(self): if self.gns_start_count < cfg.final_budget: self.gns_start_count = cfg.final_budget + 1000 - progress = (step - cfg.reg_start) / (cfg.reg_end - cfg.reg_start) + progress = (step - cfg.reg_start) / ( + cfg.reg_end - cfg.reg_start + ) progress = max(0.0, min(1.0, progress)) - expected_count = self.gns_start_count - ( - self.gns_start_count - cfg.final_budget - ) * progress + expected_count = ( + self.gns_start_count + - (self.gns_start_count - cfg.final_budget) * progress + ) current_count = len(self.splats["means"]) if current_count > expected_count * 1.05: @@ -1005,7 +1007,9 @@ def train(self): elif current_count < expected_count * 0.95: cfg.opacity_reg_weight = cfg.opacity_reg_weight * 0.8 - cfg.opacity_reg_weight = max(1e-7, min(cfg.opacity_reg_weight, 1e-2)) + cfg.opacity_reg_weight = max( + 1e-7, min(cfg.opacity_reg_weight, 1e-2) + ) # if self.cfg.strategy_verbose: # print( @@ -1032,7 +1036,11 @@ def train(self): # f"loss contribution={gns_loss.item():.6e}" # ) - if cfg.enable_natural_selection and step == cfg.reg_end and ns_stop_step is None: + if ( + cfg.enable_natural_selection + and step == cfg.reg_end + and ns_stop_step is None + ): ns_stop_step = step if ns_stop_step is not None and step == ns_stop_step + 1000: diff --git a/gsplat/strategy/improved.py b/gsplat/strategy/improved.py index 33887d5..fbf85fa 100644 --- a/gsplat/strategy/improved.py +++ b/gsplat/strategy/improved.py @@ -230,21 +230,23 @@ def step_post_backward( # 2. Final budget prune that enforces the probabilistic cap at the end if step == self.reg_end: if self.verbose: - print(f"[GNS] Step {step}: Running Final Budget Prune to {self.final_budget}...") - + print( + f"[GNS] Step {step}: Running Final Budget Prune to {self.final_budget}..." + ) + n_pruned = self._final_prune_gs( - params=params, - optimizers=optimizers, - state=state, - target_budget=self.final_budget + params=params, + optimizers=optimizers, + state=state, + target_budget=self.final_budget, ) - + if self.verbose: print( f"[GNS] Final Prune removed {n_pruned} gaussians. " f"Now having {len(params['means'])} GSs." ) - + # Clean up memory after large-scale pruning torch.cuda.empty_cache() self.gns_finished = True @@ -283,7 +285,7 @@ def step_post_backward( state["count"].zero_() if self.refine_scale2d_stop_iter > 0: state["radii"].zero_() - torch.cuda.empty_cache() # it is useful + torch.cuda.empty_cache() # it is useful if step % self.reset_every == 0 and step > 0: reset_opa( @@ -500,7 +502,7 @@ def _quantile_prune_gs( remove(params=params, optimizers=optimizers, state=state, mask=is_prune) return n_prune - + @torch.no_grad() def _opacity_prune_gs( self, @@ -529,7 +531,7 @@ def _opacity_prune_gs( remove(params=params, optimizers=optimizers, state=state, mask=is_prune) return n_prune - + @torch.no_grad() def _final_prune_gs( self, diff --git a/tests/test_strategy.py b/tests/test_strategy.py index 66e3531..6cdb9fa 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -55,7 +55,6 @@ def test_strategy(): strategy.step_post_backward(params, optimizers, state, step=600, info=info) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") def test_strategy_requires_grad(): from gsplat.rendering import rasterization