From 0db3eba9a1282032c6db4edf48e3557f61686aa2 Mon Sep 17 00:00:00 2001 From: quantumjot Date: Mon, 25 Jul 2022 13:47:41 +0100 Subject: [PATCH 1/5] add binvox --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index d3ef016..e515461 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,3 +29,4 @@ dev = ase pytest pre-commit + binvox From 0668da8b961be55200faa724ad4ad5b305d02653 Mon Sep 17 00:00:00 2001 From: quantumjot Date: Mon, 25 Jul 2022 13:48:20 +0100 Subject: [PATCH 2/5] improve vae setup --- vne/vae.py | 47 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/vne/vae.py b/vne/vae.py index caf452f..54fab03 100644 --- a/vne/vae.py +++ b/vne/vae.py @@ -4,6 +4,8 @@ import torch from torch import nn +from .base import SpatialDims + class ShapeSimilarityLoss: """Shape similarity loss based on pre-calculated shape similarity. @@ -61,42 +63,59 @@ def __call__( return loss +def dims_after_pooling(start: int, n_pools: int) -> int: + """Calculate the size of a layer after n pooling ops.""" + return start // (2**n_pools) + + class ShapeVAE(nn.Module): """Shape regularized variational autoencoder. Parameters ---------- + input_shape : tuple + A tuple representing the input shape of the data, e.g. (1, 64, 64) for + images or (1, 64, 64, 64) for a volume with 1 channel. latent_dims : int The size of the latent representation. pose_dims : int The size of the pose representation. - spatial_dims : int (2 or 3) - Planar of volumetric data. - """ def __init__( - self, latent_dims: int = 8, pose_dims: int = 1, spatial_dims: int = 2 + self, + input_shape: Tuple[int] = (1, 64, 64), + latent_dims: int = 8, + pose_dims: int = 1, ): super(ShapeVAE, self).__init__() - if spatial_dims == 2: + channels = input_shape[0] + spatial_dims = input_shape[1:] + ndim = len(spatial_dims) + + if ndim not in SpatialDims: + raise ValueError( + f"`input_shape` must be have 2 or 3 dimensions, got: {ndim}." + ) + + if ndim == SpatialDims.TWO: conv = nn.Conv2d conv_T = nn.ConvTranspose2d - unflat_shape = (64, 4, 4) - elif spatial_dims == 3: + elif ndim == SpatialDims.THREE: conv = nn.Conv3d conv_T = nn.ConvTranspose3d - unflat_shape = (64, 4, 4, 4) - else: - raise ValueError( - f"`spatial_dims` must be in (2, 3), got: {spatial_dims}." - ) + unflat_shape = tuple( + [ + 64, + ] + + [dims_after_pooling(ax) for ax in spatial_dims] + ) flat_shape = np.prod(unflat_shape) self.encoder = nn.Sequential( - conv(1, 8, 3, stride=2, padding=1), + conv(channels, 8, 3, stride=2, padding=1), nn.ReLU(True), conv(8, 16, 3, stride=2, padding=1), nn.ReLU(True), @@ -116,7 +135,7 @@ def __init__( nn.ReLU(True), conv_T(16, 8, 3, stride=2, padding=1), nn.ReLU(True), - conv_T(8, 1, 2, stride=2, padding=1), + conv_T(8, channels, 2, stride=2, padding=1), ) self.mu = nn.Linear(flat_shape, latent_dims) From 309e4007c30da6db40d3619b1e56147e3f920640 Mon Sep 17 00:00:00 2001 From: quantumjot Date: Mon, 25 Jul 2022 13:49:33 +0100 Subject: [PATCH 3/5] add base --- vne/base.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 vne/base.py diff --git a/vne/base.py b/vne/base.py new file mode 100644 index 0000000..5fae533 --- /dev/null +++ b/vne/base.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import abc +import enum +from typing import List, Tuple + +import numpy as np + + +class SpatialDims(enum.IntEnum): + TWO = 2 + THREE = 3 + + +class Datasource(abc.ABC): + """Abstract datasource.""" + + @abc.abstractmethod + def __call__(self, model_id: str) -> np.ndarray: + raise NotImplementedError + + def __iter__(self) -> Datasource: + self._iter = 0 + return self + + def __next__(self) -> Tuple[str, np.ndarray]: + if self._iter < len(self): + model_id = self._keys[self._iter] + self._iter += 1 + return model_id, self(model_id) + else: + raise StopIteration + + def __len__(self) -> int: + return len(self._keys) + + def keys(self) -> List[str]: + return self._keys From cdbec6685da64c925cd56bef248f135b4c02b6bb Mon Sep 17 00:00:00 2001 From: quantumjot Date: Mon, 25 Jul 2022 13:50:03 +0100 Subject: [PATCH 4/5] add voxels --- vne/special/voxels.py | 64 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 vne/special/voxels.py diff --git a/vne/special/voxels.py b/vne/special/voxels.py new file mode 100644 index 0000000..cd0bcb4 --- /dev/null +++ b/vne/special/voxels.py @@ -0,0 +1,64 @@ +import itertools +import os +from typing import Tuple + +import binvox +import numpy as np +from scipy.ndimage import zoom + + +def bounding_box(img: np.ndarray) -> Tuple[int]: + """Calculate the bounding box for a volume.""" + dims = img.ndim + out = [] + for ax in itertools.combinations(reversed(range(dims)), dims - 1): + nonzero = np.any(img, axis=ax) + out.extend(np.where(nonzero)[0][[0, -1]]) + return tuple(out) + + +def load_binvox( + filename: os.PathLike, *, size: int = 64, centre: bool = True +) -> np.ndarray: + """Load a binvox file. + + Parameters + ---------- + filename : str, path + A filename for the binvox file. + size : int + The size of the output. + centre : bool + Centre the object based on the calculated bounding box. + + Returns + ------- + voxels : array + A numpy array representing the voxels. + """ + bv = binvox.Binvox.read(filename, "dense") + voxels = bv.numpy() + + if centre: + bb = bounding_box(voxels) + crop = voxels[ + slice(bb[0], bb[1], 1), + slice(bb[2], bb[3], 1), + slice(bb[4], bb[5], 1), + ] + + dx = (voxels.shape[0] // 2) - ((bb[1] - bb[0]) // 2) + dy = (voxels.shape[1] // 2) - ((bb[3] - bb[2]) // 2) + dz = (voxels.shape[2] // 2) - ((bb[5] - bb[4]) // 2) + + centred = np.zeros_like(voxels) + centred[ + slice(dx, dx + crop.shape[0], 1), + slice(dy, dy + crop.shape[1], 1), + slice(dz, dz + crop.shape[2], 1), + ] = crop + voxels = centred + + # TODO(arl): fix this to resample the voxels to the correct size + voxels = zoom(voxels, 0.5, order=0) + return voxels From 77e9f17d0640fa8c3169942cc66c5574cbf6d526 Mon Sep 17 00:00:00 2001 From: quantumjot Date: Mon, 25 Jul 2022 13:51:22 +0100 Subject: [PATCH 5/5] add shapenet loader --- vne/special/shapenet.py | 59 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 vne/special/shapenet.py diff --git a/vne/special/shapenet.py b/vne/special/shapenet.py new file mode 100644 index 0000000..30185fb --- /dev/null +++ b/vne/special/shapenet.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import enum +from pathlib import Path +from typing import Union + +import numpy as np + +from .. import base +from .voxels import load_binvox + +SHAPENET_PATH = Path( + "/media/quantumjot/DataIII/Data/Turing/ShapeNet/ShapeNetCore.v2" +) + + +class ShapeNetTaxonomy(str, enum.Enum): + CHAIRS = "03001627" + + +class ShapeNetDataset(base.Datasource): + """ShapeNET dataset. + + Parameters + ---------- + filepath : path + A path to the ShapeNET core library. + synsetId : str + A string identifier for the object class. + """ + + def __init__(self, filepath: Path, synsetId: Union[str, ShapeNetTaxonomy]): + self.filepath = filepath + self.synsetId = synsetId + self._keys = [ + path.name + for path in (self.filepath / self.synsetId).iterdir() + if path.is_dir() + ] + self._cache = {} + + def __call__(self, model_id: str) -> np.ndarray: + if model_id not in self._keys: + raise KeyError(f"Model {model_id} not found.") + + if model_id in self._cache: + return self._cache[model_id] + + model_filename = ( + self.filepath + / self.synsetId + / model_id + / "models" + / "model_normalized.solid.binvox" + ) + + voxels = load_binvox(model_filename) + self._cache[model_id] = voxels + return voxels