diff --git a/trellis2/datasets/components.py b/trellis2/datasets/components.py index 6c593cee..fbb108c5 100644 --- a/trellis2/datasets/components.py +++ b/trellis2/datasets/components.py @@ -1,8 +1,8 @@ -from typing import * +from typing import Any, Dict, Tuple import json from abc import abstractmethod import os -import json +import warnings import torch import numpy as np import pandas as pd @@ -25,7 +25,7 @@ def __init__(self, try: self.roots = json.loads(roots) root_type = 'obj' - except: + except json.JSONDecodeError: self.roots = roots.split(',') root_type = 'list' self.instances = [] @@ -67,18 +67,31 @@ def __len__(self): return len(self.instances) def __getitem__(self, index) -> Dict[str, Any]: - try: - root, instance = self.instances[index] - return self.get_instance(root, instance) - except Exception as e: - print(f'Error loading {instance}: {e}') - return self.__getitem__(np.random.randint(0, len(self))) + if len(self) == 0: + raise IndexError('Cannot load from an empty dataset.') + + max_retries = min(10, len(self)) + root, instance = self.instances[index] + last_error = None + + for _ in range(max_retries): + try: + return self.get_instance(root, instance) + except Exception as e: + last_error = e + warnings.warn(f'Error loading {instance}: {e}') + root, instance = self.instances[np.random.randint(0, len(self))] + + raise RuntimeError( + f'Failed to load a valid instance after {max_retries} retries. ' + f'Last attempted instance: {instance}' + ) from last_error def __str__(self): lines = [] lines.append(self.__class__.__name__) lines.append(f' - Total instances: {len(self)}') - lines.append(f' - Sources:') + lines.append(' - Sources:') for key, stats in self._stats.items(): lines.append(f' - {key}:') for k, v in stats.items(): @@ -86,6 +99,38 @@ def __str__(self): return '\n'.join(lines) + + +def _load_condition_image(image_path: str, image_size: int) -> torch.Tensor: + with Image.open(image_path) as image: + if image.mode != 'RGBA': + image = image.convert('RGBA') + + alpha_np = np.asarray(image.getchannel(3)) + nz = alpha_np.nonzero() + if len(nz[0]) == 0: + crop_box = (0, 0, image.width, image.height) + else: + x0, y0 = nz[1].min(), nz[0].min() + x1, y1 = nz[1].max(), nz[0].max() + center_x = (x0 + x1) * 0.5 + center_y = (y0 + y1) * 0.5 + half_size = max(x1 - x0, y1 - y0) * 0.5 + crop_box = ( + int(center_x - half_size), + int(center_y - half_size), + int(center_x + half_size), + int(center_y + half_size), + ) + + image = image.crop(crop_box).resize((image_size, image_size), Image.Resampling.LANCZOS) + + alpha = torch.from_numpy(np.asarray(image.getchannel(3))).float().div_(255.0) + rgb = torch.from_numpy(np.asarray(image.convert('RGB'))).permute(2, 0, 1).float().div_(255.0) + + return rgb.mul_(alpha.unsqueeze(0)) + + class ImageConditionedMixin: def __init__(self, roots, *, image_size=518, **kwargs): self.image_size = image_size @@ -99,36 +144,15 @@ def filter_metadata(self, metadata): def get_instance(self, root, instance): pack = super().get_instance(root, instance) - + image_root = os.path.join(root['render_cond'], instance) with open(os.path.join(image_root, 'transforms.json')) as f: metadata = json.load(f) - n_views = len(metadata['frames']) - view = np.random.randint(n_views) - metadata = metadata['frames'][view] - - image_path = os.path.join(image_root, metadata['file_path']) - image = Image.open(image_path) - - alpha = np.array(image.getchannel(3)) - bbox = np.array(alpha).nonzero() - bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()] - center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] - hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 - aug_hsize = hsize - aug_center_offset = [0, 0] - aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]] - aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)] - image = image.crop(aug_bbox) - - image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS) - alpha = image.getchannel(3) - image = image.convert('RGB') - image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0 - alpha = torch.tensor(np.array(alpha)).float() / 255.0 - image = image * alpha.unsqueeze(0) - pack['cond'] = image - + + view = np.random.randint(len(metadata['frames'])) + image_path = os.path.join(image_root, metadata['frames'][view]['file_path']) + pack['cond'] = _load_condition_image(image_path, self.image_size) + return pack @@ -162,31 +186,7 @@ def get_instance(self, root, instance): for v in sampled_views: frame_info = metadata['frames'][v] image_path = os.path.join(image_root, frame_info['file_path']) - image = Image.open(image_path) - - alpha = np.array(image.getchannel(3)) - bbox = np.array(alpha).nonzero() - bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()] - center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] - hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 - aug_hsize = hsize - aug_center = center - aug_bbox = [ - int(aug_center[0] - aug_hsize), - int(aug_center[1] - aug_hsize), - int(aug_center[0] + aug_hsize), - int(aug_center[1] + aug_hsize), - ] - - img = image.crop(aug_bbox) - img = img.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS) - alpha = img.getchannel(3) - img = img.convert('RGB') - img = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0 - alpha = torch.tensor(np.array(alpha)).float() / 255.0 - img = img * alpha.unsqueeze(0) - - cond_images.append(img) + cond_images.append(_load_condition_image(image_path, self.image_size)) pack['cond'] = [torch.stack(cond_images, dim=0)] # (V,3,H,W) return pack