Skip to content

Commit c11f150

Browse files
committed
Added utils_minkowski to support MinkowskiEngine operation
1 parent 0d0445f commit c11f150

1 file changed

Lines changed: 175 additions & 0 deletions

File tree

src/py_utils/utils_minkowski.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import torch
2+
import MinkowskiEngine as ME
3+
4+
####################
5+
# KERNEL GENERATOR #
6+
####################
7+
8+
9+
def get_cube_kernel_generator(kernel_size, stride=1, dilation=1, dimension=3):
10+
"""for kernel_size = 3, the kernel region is a 3x3x3 cube"""
11+
12+
return ME.KernelGenerator(
13+
kernel_size=kernel_size,
14+
stride=stride,
15+
dilation=dilation,
16+
region_type=ME.RegionType.HYPER_CUBE,
17+
dimension=dimension,
18+
)
19+
20+
21+
def get_cross_kernel_generator(kernel_size, stride=1, dilation=1, dimension=3):
22+
"""for kernel_size = 3, the kernel region is a 3x3x3 cross"""
23+
24+
return ME.KernelGenerator(
25+
kernel_size=kernel_size,
26+
stride=stride,
27+
dilation=dilation,
28+
region_type=ME.RegionType.HYPER_CROSS,
29+
dimension=dimension,
30+
)
31+
32+
33+
########################
34+
# NEIGHBORHOOD MAPPING #
35+
########################
36+
37+
38+
@torch.no_grad()
39+
def _sparse_tensor_key_map(
40+
A: ME.CoordinateMapKey,
41+
B: ME.CoordinateMapKey,
42+
kernel_generator: ME.KernelGenerator,
43+
coordinate_manager: ME.CoordinateManager,
44+
device="cuda",
45+
):
46+
47+
kg = kernel_generator
48+
km = coordinate_manager.kernel_map(
49+
A,
50+
B,
51+
kernel_size=kg.kernel_size,
52+
stride=kg.kernel_stride,
53+
dilation=kg.kernel_dilation,
54+
region_type=kg.region_type,
55+
region_offset=kg.region_offsets,
56+
)
57+
58+
a_keys, b_keys = [], []
59+
for _, pair in km.items():
60+
a, b = pair
61+
a_keys.append(a.long())
62+
b_keys.append(b.long())
63+
64+
if len(a_keys) == 0 or len(b_keys) == 0:
65+
a_keys = torch.empty(0, dtype=torch.long, device=device)
66+
b_keys = torch.empty(0, dtype=torch.long, device=device)
67+
else:
68+
a_keys = torch.cat(a_keys)
69+
b_keys = torch.cat(b_keys)
70+
71+
return a_keys, b_keys
72+
73+
74+
@torch.no_grad()
75+
def sparse_tensor_map(
76+
A: ME.SparseTensor,
77+
B: ME.SparseTensor,
78+
kernel_generator=get_cube_kernel_generator(1),
79+
):
80+
81+
if A.coordinate_manager is not B.coordinate_manager:
82+
raise ValueError("A and B must share the same coordinate_manager.")
83+
84+
# shorthanded
85+
cm = A.coordinate_manager
86+
ak = A.coordinate_map_key
87+
bk = B.coordinate_map_key
88+
kg = kernel_generator
89+
90+
exp_stride = [b // a for a, b in zip(A.tensor_stride, B.tensor_stride)]
91+
ker_stride = list(kg.kernel_stride)
92+
93+
if ker_stride != exp_stride:
94+
msg = f"kernel_generator stride {ker_stride} does not match: "
95+
msg += f"A.tensor_stride {A.tensor_stride} "
96+
msg += f"B.tensor_stride {B.tensor_stride} "
97+
msg += f"expected stride {exp_stride})."
98+
raise ValueError(msg)
99+
100+
return _sparse_tensor_key_map(ak, bk, kg, cm, device=A.device)
101+
102+
103+
@torch.no_grad()
104+
def A_occupied_by_B(
105+
A: ME.SparseTensor,
106+
B: ME.SparseTensor | ME.CoordinateMapKey,
107+
):
108+
cm = A.coordinate_manager
109+
110+
if isinstance(B, ME.SparseTensor):
111+
strided_B_key = cm.stride(B.coordinate_map_key, A.tensor_stride)
112+
elif isinstance(B, ME.CoordinateMapKey):
113+
strided_B_key = cm.stride(B, A.tensor_stride)
114+
else:
115+
msg = "B must be either a SparseTensor or CoordinateMapKey."
116+
raise ValueError(msg)
117+
118+
mask = torch.zeros(len(A), dtype=torch.bool, device=A.device)
119+
120+
if cm.size(strided_B_key) == 0:
121+
return mask
122+
123+
# only the exact match (kernel_size=1) is needed to determine occupancy
124+
kg = get_cube_kernel_generator(kernel_size=1)
125+
a_idx, _ = _sparse_tensor_key_map(
126+
A.coordinate_map_key,
127+
strided_B_key,
128+
kg,
129+
cm,
130+
device=A.device,
131+
)
132+
mask[a_idx] = True
133+
134+
return mask
135+
136+
137+
##################
138+
# SET OPERATIONS #
139+
##################
140+
141+
142+
@torch.no_grad()
143+
def set_difference(A: ME.SparseTensor, B: ME.SparseTensor):
144+
"""A - B"""
145+
146+
assert A.tensor_stride == B.tensor_stride, "tensor_stride mismatch"
147+
148+
occupied = A_occupied_by_B(A, B)
149+
if not torch.any(occupied):
150+
return A
151+
152+
keep = ~occupied
153+
154+
out = ME.SparseTensor(
155+
features=A.F[keep],
156+
coordinates=A.C[keep],
157+
tensor_stride=A.tensor_stride,
158+
coordinate_manager=A.coordinate_manager,
159+
)
160+
return out
161+
162+
163+
@torch.no_grad()
164+
def set_disjoint_union(A: ME.SparseTensor, B: ME.SparseTensor):
165+
"""A U B, assume A and B don't have intersection"""
166+
167+
assert A.tensor_stride == B.tensor_stride, "tensor_stride mismatch"
168+
169+
out = ME.SparseTensor(
170+
features=torch.cat([A.F, B.F], dim=0),
171+
coordinates=torch.cat([A.C, B.C], dim=0),
172+
tensor_stride=A.tensor_stride,
173+
coordinate_manager=A.coordinate_manager,
174+
)
175+
return out

0 commit comments

Comments
 (0)