Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ dev =
ase
pytest
pre-commit
binvox
38 changes: 38 additions & 0 deletions vne/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

import abc
import enum
from typing import Iterator, List, Tuple

import numpy as np


@enum.unique
class SpatialDims(enum.IntEnum):
    """Number of spatial dimensions a model can operate on.

    Being an ``IntEnum``, members compare equal to the plain integers 2 and 3.
    """

    TWO = 2  # planar data (images)
    THREE = 3  # volumetric data


class Datasource(abc.ABC):
    """Abstract datasource.

    A datasource maps string model identifiers to arrays. Subclasses must
    implement ``__call__`` and populate ``self._keys`` with the list of
    available identifiers; iteration, ``len()`` and ``keys()`` all read it.
    """

    @abc.abstractmethod
    def __call__(self, model_id: str) -> np.ndarray:
        """Return the data associated with ``model_id``."""
        raise NotImplementedError

    def __iter__(self) -> Iterator[Tuple[str, np.ndarray]]:
        """Return a fresh iterator over ``(model_id, data)`` pairs.

        Every call produces an independent iterator, so nested or overlapping
        loops over the same datasource do not disturb one another.
        """
        # Reset the instance-level cursor as well, so the older pattern of
        # calling `iter(ds)` and then `next(ds)` still restarts from the
        # first key.
        self._iter = 0
        return ((model_id, self(model_id)) for model_id in self._keys)

    def __next__(self) -> Tuple[str, np.ndarray]:
        """Advance the instance-level cursor and return the next pair.

        Kept for callers that invoke ``next()`` directly on the datasource.
        The cursor starts at the first key if ``iter()`` was never called.

        Raises
        ------
        StopIteration
            When the cursor has moved past the last key.
        """
        cursor = getattr(self, "_iter", 0)
        if cursor >= len(self):
            raise StopIteration

        model_id = self._keys[cursor]
        self._iter = cursor + 1
        return model_id, self(model_id)

    def __len__(self) -> int:
        """Return the number of available model identifiers."""
        return len(self._keys)

    def keys(self) -> List[str]:
        """Return the list of model identifiers.

        This is the internal list itself, not a copy.
        """
        return self._keys
59 changes: 59 additions & 0 deletions vne/special/shapenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

import enum
from pathlib import Path
from typing import Union

import numpy as np

from .. import base
from .voxels import load_binvox

# Absolute location of a ShapeNetCore.v2 directory.
# NOTE(review): this is a machine-specific path and nothing in the visible
# code reads it - presumably a convenience default for callers; confirm.
SHAPENET_PATH = Path(
    "/media/quantumjot/DataIII/Data/Turing/ShapeNet/ShapeNetCore.v2"
)


class ShapeNetTaxonomy(str, enum.Enum):
    """Named ShapeNet synset identifiers.

    Members also subclass ``str``, so each one compares equal to its raw
    identifier string; use ``.value`` when the plain string is required.
    """

    CHAIRS = "03001627"


class ShapeNetDataset(base.Datasource):
    """ShapeNET dataset.

    Model identifiers are the names of the sub-directories found under
    ``filepath / synsetId``. Calling the dataset with an identifier loads the
    corresponding ``models/model_normalized.solid.binvox`` file and caches the
    result in memory.

    Parameters
    ----------
    filepath : str, path
        A path to the ShapeNET core library.
    synsetId : str
        A string identifier for the object class.

    Raises
    ------
    FileNotFoundError
        If the synset directory does not exist.
    """

    def __init__(
        self,
        filepath: Union[str, Path],
        synsetId: Union[str, ShapeNetTaxonomy],
    ):
        self.filepath = Path(filepath)
        self.synsetId = synsetId

        # Joining an enum member onto a path can use ``str(member)`` (the
        # qualified member name) instead of the identifier on older Python
        # versions, so always join with the raw string value.
        synset = synsetId.value if isinstance(synsetId, enum.Enum) else synsetId
        self._synset_path = self.filepath / str(synset)

        # Directory listing order is arbitrary; sort so that iteration order
        # is reproducible.
        self._keys = sorted(
            path.name for path in self._synset_path.iterdir() if path.is_dir()
        )
        self._cache = {}

    def __call__(self, model_id: str) -> np.ndarray:
        """Return the voxel array for ``model_id``, loading it on first use.

        The cached array object itself is returned, so in-place modification
        by the caller will be visible to later calls.

        Raises
        ------
        KeyError
            If ``model_id`` is not one of the dataset keys.
        """
        # Only valid keys are ever cached, so a hit needs no further checks.
        if model_id in self._cache:
            return self._cache[model_id]

        if model_id not in self._keys:
            raise KeyError(f"Model {model_id} not found.")

        model_filename = (
            self._synset_path
            / model_id
            / "models"
            / "model_normalized.solid.binvox"
        )

        voxels = load_binvox(model_filename)
        self._cache[model_id] = voxels
        return voxels
64 changes: 64 additions & 0 deletions vne/special/voxels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import itertools
import os
from typing import Tuple

import binvox
import numpy as np
from scipy.ndimage import zoom


def bounding_box(img: np.ndarray) -> Tuple[int, ...]:
    """Calculate the bounding box for a volume.

    Parameters
    ----------
    img : array
        An array in which non-zero elements mark the object.

    Returns
    -------
    bbox : tuple
        ``(min_0, max_0, min_1, max_1, ...)``: for each axis, in axis order,
        the first and last index containing a non-zero element. Both bounds
        are inclusive.

    Raises
    ------
    IndexError
        If the array contains no non-zero elements.
    """
    img = np.asarray(img)
    bounds = []
    for axis in range(img.ndim):
        # Collapse every other axis to get the occupancy profile of this one.
        other_axes = tuple(ax for ax in range(img.ndim) if ax != axis)
        occupied = np.flatnonzero(np.any(img, axis=other_axes))
        if occupied.size == 0:
            raise IndexError(
                "Cannot calculate the bounding box of an empty volume."
            )
        bounds.extend((int(occupied[0]), int(occupied[-1])))
    return tuple(bounds)


def _centre_voxels(voxels: np.ndarray) -> np.ndarray:
    """Return a copy of ``voxels`` with its occupied region centred."""
    bb = bounding_box(voxels)

    src = []
    dst = []
    for ax, dim in enumerate(voxels.shape):
        lo, hi = bb[2 * ax], bb[2 * ax + 1]
        # The bounding box maximum is inclusive, hence the +1.
        extent = hi - lo + 1
        start = (dim // 2) - (extent // 2)
        src.append(slice(lo, hi + 1, 1))
        dst.append(slice(start, start + extent, 1))

    centred = np.zeros_like(voxels)
    centred[tuple(dst)] = voxels[tuple(src)]
    return centred


def load_binvox(
    filename: os.PathLike, *, size: int = 64, centre: bool = True
) -> np.ndarray:
    """Load a binvox file.

    Parameters
    ----------
    filename : str, path
        A filename for the binvox file.
    size : int
        The size of the output along every axis.
    centre : bool
        Centre the object based on the calculated bounding box. Ignored when
        the volume contains no occupied voxels.

    Returns
    -------
    voxels : array
        A numpy array representing the voxels, resampled with nearest
        neighbour interpolation so that each axis has length `size`.

    Raises
    ------
    ValueError
        If `size` is not a positive integer.
    """
    if size < 1:
        raise ValueError(f"`size` must be a positive integer, got: {size}.")

    bv = binvox.Binvox.read(filename, "dense")
    voxels = bv.numpy()

    # An empty volume has no bounding box, so there is nothing to centre.
    if centre and np.any(voxels):
        voxels = _centre_voxels(voxels)

    # Per-axis zoom factors so the output has the requested size whatever the
    # resolution of the file (a factor of 0.5 when halving, e.g. 128 -> 64).
    factors = [size / dim for dim in voxels.shape]
    return zoom(voxels, factors, order=0)
47 changes: 33 additions & 14 deletions vne/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import torch
from torch import nn

from .base import SpatialDims


class ShapeSimilarityLoss:
"""Shape similarity loss based on pre-calculated shape similarity.
Expand Down Expand Up @@ -61,42 +63,59 @@ def __call__(
return loss


def dims_after_pooling(start: int, n_pools: int) -> int:
    """Return the size of a layer after ``n_pools`` pooling operations.

    The result is ``start`` floor-divided by ``2**n_pools``.
    """
    shrink_factor = 2**n_pools
    return start // shrink_factor


class ShapeVAE(nn.Module):
"""Shape regularized variational autoencoder.

Parameters
----------
input_shape : tuple
A tuple representing the input shape of the data, e.g. (1, 64, 64) for
images or (1, 64, 64, 64) for a volume with 1 channel.
latent_dims : int
The size of the latent representation.
pose_dims : int
The size of the pose representation.
spatial_dims : int (2 or 3)
Planar or volumetric data.

"""

def __init__(
self, latent_dims: int = 8, pose_dims: int = 1, spatial_dims: int = 2
self,
input_shape: Tuple[int] = (1, 64, 64),
latent_dims: int = 8,
pose_dims: int = 1,
):
super(ShapeVAE, self).__init__()

if spatial_dims == 2:
channels = input_shape[0]
spatial_dims = input_shape[1:]
ndim = len(spatial_dims)

if ndim not in SpatialDims:
raise ValueError(
f"`input_shape` must be have 2 or 3 dimensions, got: {ndim}."
)

if ndim == SpatialDims.TWO:
conv = nn.Conv2d
conv_T = nn.ConvTranspose2d
unflat_shape = (64, 4, 4)
elif spatial_dims == 3:
elif ndim == SpatialDims.THREE:
conv = nn.Conv3d
conv_T = nn.ConvTranspose3d
unflat_shape = (64, 4, 4, 4)
else:
raise ValueError(
f"`spatial_dims` must be in (2, 3), got: {spatial_dims}."
)

unflat_shape = tuple(
[
64,
]
+ [dims_after_pooling(ax) for ax in spatial_dims]
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @marjanfamili - can you just check that this works? This line doesn't look right to me (not least that it only has one argument)

)
flat_shape = np.prod(unflat_shape)

self.encoder = nn.Sequential(
conv(1, 8, 3, stride=2, padding=1),
conv(channels, 8, 3, stride=2, padding=1),
nn.ReLU(True),
conv(8, 16, 3, stride=2, padding=1),
nn.ReLU(True),
Expand All @@ -116,7 +135,7 @@ def __init__(
nn.ReLU(True),
conv_T(16, 8, 3, stride=2, padding=1),
nn.ReLU(True),
conv_T(8, 1, 2, stride=2, padding=1),
conv_T(8, channels, 2, stride=2, padding=1),
)

self.mu = nn.Linear(flat_shape, latent_dims)
Expand Down