diff --git a/projects/PTv3/models/point_transformer_v3/point_transformer_v3m1_base.py b/projects/PTv3/models/point_transformer_v3/point_transformer_v3m1_base.py index e91fb76c6..83e43cf34 100644 --- a/projects/PTv3/models/point_transformer_v3/point_transformer_v3m1_base.py +++ b/projects/PTv3/models/point_transformer_v3/point_transformer_v3m1_base.py @@ -21,7 +21,7 @@ from models.builder import MODELS from models.modules import PointModule, PointSequential from models.point_prompt_training import PDNorm -from models.scatter.functional import argsort, segment_csr, unique +from models.scatter.functional import argsort, segment_csr from models.utils.misc import offset2bincount from models.utils.structure import Point @@ -436,6 +436,29 @@ def __init__( if act_layer is not None: self.act = PointSequential(act_layer()) + @staticmethod + def _build_export_cluster( + code: torch.Tensor, + serialized_order: torch.Tensor, + ): + sorted_indices = serialized_order[0] + sorted_code = code[0].index_select(0, sorted_indices) + cluster_starts_mask = torch.cat( + [ + torch.ones_like(sorted_code[:1], dtype=torch.bool), + sorted_code[1:] != sorted_code[:-1], + ] + ) + cluster_starts = torch.nonzero(cluster_starts_mask, as_tuple=False).flatten() + num_points = torch._shape_as_tensor(sorted_indices).to(sorted_indices.device)[:1] + idx_ptr = torch.cat([cluster_starts, num_points], dim=0) + + cluster_sorted = torch.cumsum(cluster_starts_mask.to(dtype=sorted_indices.dtype), dim=0) - 1 + cluster = torch.zeros_like(cluster_sorted) + cluster.scatter_(0, sorted_indices, cluster_sorted) + + return cluster, sorted_indices, idx_ptr + def forward(self, point: Point): pooling_depth = (math.ceil(self.stride) - 1).bit_length() if pooling_depth > point.serialized_depth: @@ -453,22 +476,22 @@ def forward(self, point: Point): code = point.serialized_code >> pooling_depth * 3 if not self.export_mode: - code_, cluster, counts = torch.unique( + _, cluster, counts = torch.unique( code[0], + dim=0, sorted=True, return_inverse=True, return_counts=True, ) _, indices = torch.sort(cluster) + idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)]) else: - code_, cluster, counts, num_unique = unique(code[0]) - indices = argsort(cluster) + cluster, indices, idx_ptr = self._build_export_cluster(code, point.serialized_order) # indices of point sorted by cluster, for torch_scatter.segment_csr # index pointer for sorted point, for torch_scatter.segment_csr - idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)]) # head_indices of each cluster, for reduce attr e.g. code, batch head_indices = indices[idx_ptr[:-1]] # generate down code, order, inverse diff --git a/projects/PTv3/tests/test_serialized_pooling.py b/projects/PTv3/tests/test_serialized_pooling.py new file mode 100644 index 000000000..ac5a0aa9b --- /dev/null +++ b/projects/PTv3/tests/test_serialized_pooling.py @@ -0,0 +1,132 @@ +import sys +import types +import unittest +from pathlib import Path + +import torch + +if not hasattr(torch, "inference_mode"): + torch.inference_mode = torch.no_grad + +try: + import torch_scatter # noqa: F401 +except ModuleNotFoundError: + torch_scatter = types.ModuleType("torch_scatter") + + def _segment_csr(src, indptr, reduce="sum"): + outputs = [] + for start, end in zip(indptr[:-1].tolist(), indptr[1:].tolist()): + segment = src[start:end] + if reduce == "sum": + outputs.append(segment.sum(dim=0)) + elif reduce == "mean": + outputs.append(segment.mean(dim=0)) + elif reduce == "max": + outputs.append(segment.max(dim=0).values) + elif reduce == "min": + outputs.append(segment.min(dim=0).values) + else: + raise NotImplementedError(f"Unsupported reduce mode: {reduce}") + return torch.stack(outputs, dim=0) + + torch_scatter.segment_csr = _segment_csr + sys.modules["torch_scatter"] = torch_scatter + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from models.point_transformer_v3.point_transformer_v3m1_base import SerializedPooling +from models.utils.structure import Point + + +class TestSerializedPooling(unittest.TestCase): + def setUp(self): + self.grid_coord = torch.tensor( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [1, 1, 0], + [2, 2, 1], + [3, 2, 1], + [2, 3, 1], + [3, 3, 1], + [0, 0, 2], + [1, 0, 2], + ], + dtype=torch.int32, + ) + self.batch = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=torch.int64) + self.feat = torch.randn(self.grid_coord.shape[0], 6) + self.coord = self.grid_coord.to(torch.float32) + self.sparse_shape = torch.tensor([16, 16, 16], dtype=torch.int64) + self.depth = 6 + + def _make_point(self): + point = Point( + coord=self.coord.clone(), + grid_coord=self.grid_coord.clone(), + feat=self.feat.clone(), + batch=self.batch.clone(), + sparse_shape=self.sparse_shape.clone(), + ) + point.serialization(order=["z", "z-trans"], depth=self.depth, shuffle_orders=False) + return point + + def test_export_mode_matches_train_time(self): + torch.manual_seed(0) + train_module = SerializedPooling( + 6, + 8, + stride=2, + reduce="max", + shuffle_orders=False, + traceable=True, + export_mode=False, + ) + export_module = SerializedPooling( + 6, + 8, + stride=2, + reduce="max", + shuffle_orders=False, + traceable=True, + export_mode=True, + ) + train_module.norm = None + train_module.act = None + export_module.norm = None + export_module.act = None + export_module.load_state_dict(train_module.state_dict()) + + train_out = train_module(self._make_point()) + export_out = export_module(self._make_point()) + + tensor_keys = [ + "feat", + "coord", + "grid_coord", + "serialized_code", + "serialized_order", + "serialized_inverse", + "batch", + "sparse_shape", + "pooling_inverse", + ] + + for key in tensor_keys: + left = train_out[key] + right = export_out[key] + if left.dtype.is_floating_point: + if hasattr(torch.testing, "assert_close"): + torch.testing.assert_close(left, right, msg=f"Mismatch for {key}") + else: + torch.testing.assert_allclose(left, right, msg=f"Mismatch for {key}") + else: + self.assertTrue(torch.equal(left, right), f"Mismatch for {key}") + + +if __name__ == "__main__": + unittest.main()