diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index b9f440bd545..a909cde6e13 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -42,7 +43,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors @@ -804,6 +804,10 @@ def test_kernel_video(self): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.needs_cvcuda, + ), ], ) def test_functional(self, size, make_input): @@ -827,9 +831,16 @@ def test_functional(self, size, make_input): (F.resize_mask, tv_tensors.Mask), (F.resize_video, tv_tensors.Video), (F.resize_keypoints, tv_tensors.KeyPoints), + pytest.param( + F._geometry._resize_image_cvcuda, + None, + marks=pytest.mark.needs_cvcuda, + ), ], ) def test_functional_signature(self, kernel, input_type): + if kernel is F._geometry._resize_image_cvcuda: + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type) @pytest.mark.parametrize("size", OUTPUT_SIZES) @@ -845,6 +856,10 @@ def test_functional_signature(self, kernel, input_type): make_detection_masks, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.needs_cvcuda, + ), ], ) def test_transform(self, size, device, make_input): @@ -862,23 +877,77 @@ def _check_output_size(self, input, output, *, size, max_size): input_size=F.get_size(input), size=size, max_size=max_size ) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.needs_cvcuda, + ), + ], + ) @pytest.mark.parametrize("size", OUTPUT_SIZES) # `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2. # The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT` @pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST}) @pytest.mark.parametrize("use_max_size", [True, False]) @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)]) - def test_image_correctness(self, size, interpolation, use_max_size, fn): + def test_image_correctness(self, make_input, size, interpolation, use_max_size, fn): if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): return - image = make_image(self.INPUT_SIZE, dtype=torch.uint8) + image = make_input(self.INPUT_SIZE, dtype=torch.uint8) actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=True) + + if make_input is make_image_cvcuda: + image = F.cvcuda_to_tensor(image)[0].cpu() + expected = F.to_image(F.resize(F.to_pil_image(image), size=size, interpolation=interpolation, **max_size_kwarg)) self._check_output_size(image, actual, size=size, **max_size_kwarg) - torch.testing.assert_close(actual, expected, atol=1, rtol=0) + + atol = 1 + # when using antialias, CV-CUDA is different for BICUBIC and BILINEAR, since antialias requires hq_resize + # hq_resize using interpolation will have differences on the edge boundaries + # no noticable visual difference + if make_input is make_image_cvcuda and ( + interpolation is transforms.InterpolationMode.BILINEAR + or interpolation is transforms.InterpolationMode.BICUBIC + ): + atol = 9 + assert_close(actual, expected, atol=atol, rtol=0) + + @needs_cvcuda + @pytest.mark.parametrize("size", OUTPUT_SIZES) + @pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST}) + @pytest.mark.parametrize("use_max_size", [True, False]) + @pytest.mark.parametrize("antialias", [True, False]) + @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)]) + def test_image_correctness_cvcuda(self, size, interpolation, use_max_size, antialias, fn): + if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): + return + + image = make_image_cvcuda(self.INPUT_SIZE, dtype=torch.uint8) + actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=antialias) + expected = fn( + F.cvcuda_to_tensor(image), size=size, interpolation=interpolation, **max_size_kwarg, antialias=antialias + ) + + # assert_close will squeeze the batch dimension off the CV-CUDA tensor so we convert ahead of time + actual = F.cvcuda_to_tensor(actual) + + atol = 1 + if antialias: + # cvcuda.hq_resize is accurate within 9 for the tests + atol = 9 + elif interpolation == transforms.InterpolationMode.BICUBIC: + # the CV-CUDA bicubic interpolation differs significantly + # importantly, this is only the edge boundaries + # visually, there is no noticable difference + atol = 91 + assert_close(actual, expected, atol=atol, rtol=0) def _reference_resize_bounding_boxes(self, bounding_boxes, format, *, size, max_size=None): old_height, old_width = bounding_boxes.canvas_size @@ -964,11 +1033,26 @@ def test_keypoints_correctness(self, size, use_max_size, fn): @pytest.mark.parametrize("interpolation", set(transforms.InterpolationMode) - set(INTERPOLATION_MODES)) @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image_pil, make_image, make_video], + [ + make_image_tensor, + make_image_pil, + make_image, + make_video, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.needs_cvcuda, + ), + ], ) def test_pil_interpolation_compat_smoke(self, interpolation, make_input): input = make_input(self.INPUT_SIZE) + if make_input is make_image_cvcuda and interpolation in { + transforms.InterpolationMode.BOX, + transforms.InterpolationMode.LANCZOS, + }: + pytest.skip("CV-CUDA may support box and lanczos for certain configurations of resize") + with ( contextlib.nullcontext() if isinstance(input, PIL.Image.Image) @@ -997,6 +1081,10 @@ def test_functional_pil_antialias_warning(self): make_detection_masks, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.needs_cvcuda, + ), ], ) def test_max_size_error(self, size, make_input): @@ -1040,6 +1128,10 @@ def test_max_size_error(self, size, make_input): make_detection_masks, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.needs_cvcuda, + ), ], ) def test_resize_size_none(self, input_size, max_size, expected_size, make_input): @@ -1050,7 +1142,16 @@ def test_resize_size_none(self, input_size, max_size, expected_size, make_input) @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES) @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image_pil, make_image, make_video], + [ + make_image_tensor, + make_image_pil, + make_image, + make_video, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.needs_cvcuda, + ), + ], ) def test_interpolation_int(self, interpolation, make_input): input = make_input(self.INPUT_SIZE) @@ -1114,6 +1215,10 @@ def test_noop(self, size, make_input): make_detection_masks, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.needs_cvcuda, + ), ], ) def test_no_regression_5405(self, make_input): diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 96166e05e9a..9283955b224 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -139,6 +139,8 @@ class Resize(Transform): _v1_transform_cls = _transforms.Resize + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) + def __init__( self, size: Union[int, Sequence[int], None], @@ -1258,6 +1260,8 @@ class ScaleJitter(Transform): v0.17, for the PIL and Tensor backends to be consistent. """ + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) + def __init__( self, target_size: tuple[int, int], @@ -1323,6 +1327,8 @@ class RandomShortestSize(Transform): v0.17, for the PIL and Tensor backends to be consistent. """ + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) + def __init__( self, min_size: Union[list[int], tuple[int], int], @@ -1402,6 +1408,8 @@ class RandomResize(Transform): v0.17, for the PIL and Tensor backends to be consistent. """ + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) + def __init__( self, min_size: int, diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index bb6051b4e61..e803aa49c60 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -16,7 +16,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor -from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]: @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, + _is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 0e27218bc89..7032881c7a2 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -28,6 +28,7 @@ from ._utils import ( _FillTypeJIT, + _get_cvcuda_interp, _get_kernel, _import_cvcuda, _is_cvcuda_available, @@ -401,6 +402,82 @@ def __resize_image_pil_dispatch( return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size) +_dtype_to_format_cvcuda: dict["cvcuda.Type", "cvcuda.Format"] = {} + + +def _resize_image_cvcuda( + image: "cvcuda.Tensor", + size: Optional[list[int]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, +) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + if not _dtype_to_format_cvcuda: + _dtype_to_format_cvcuda[cvcuda.Type.U8] = cvcuda.Format.U8 + _dtype_to_format_cvcuda[cvcuda.Type.U16] = cvcuda.Format.U16 + _dtype_to_format_cvcuda[cvcuda.Type.U32] = cvcuda.Format.U32 + _dtype_to_format_cvcuda[cvcuda.Type.S8] = cvcuda.Format.S8 + _dtype_to_format_cvcuda[cvcuda.Type.S16] = cvcuda.Format.S16 + _dtype_to_format_cvcuda[cvcuda.Type.S32] = cvcuda.Format.S32 + _dtype_to_format_cvcuda[cvcuda.Type.F32] = cvcuda.Format.F32 + _dtype_to_format_cvcuda[cvcuda.Type.F64] = cvcuda.Format.F64 + + interp = _get_cvcuda_interp(interpolation) + # hamming error for parity to resize_image + if interp == cvcuda.Interp.HAMMING: + raise NotImplementedError("Unsupported interpolation for CV-CUDA resize, got hamming.") + + # match the antialias behavior of resize_image + if not (interp == cvcuda.Interp.LINEAR or interp == cvcuda.Interp.CUBIC): + antialias = False + + old_height, old_width = image.shape[1], image.shape[2] + new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size) + + # No resize needed if dimensions match + if new_height == old_height and new_width == old_width: + return image + + # antialias is only supported for cvcuda.hq_resize, if set to true (which is also default) + # we will fast-track to use hq_resize (also matchs the size parameter) + if antialias: + return cvcuda.hq_resize( + image, + out_size=(new_height, new_width), + interpolation=interp, + antialias=antialias, + ) + + # if not using antialias, we will use cvcuda.resize/pillowresize instead + # resize requires that the shape has the same dimensions as the input + # CV-CUDA tensors are already in NHWC format so we can do a simple tuple creation + shape = image.shape + new_shape = (shape[0], new_height, new_width, shape[3]) + + # bicubic mode is not accurate when using cvcuda.resize + # cvcuda.pillowresize resolves some of the errors + if interp == cvcuda.Interp.CUBIC: + return cvcuda.pillowresize( + image, + shape=new_shape, + format=_dtype_to_format_cvcuda[image.dtype], + interp=interp, + ) + + # otherwise we will use cvcuda.resize + return cvcuda.resize( + image, + shape=new_shape, + interp=interp, + ) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(resize, _import_cvcuda().Tensor)(_resize_image_cvcuda) + + def resize_mask(mask: torch.Tensor, size: Optional[list[int]], max_size: Optional[int] = None) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 11480b30ef9..4ad08006b00 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,9 +1,13 @@ import functools from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch from torchvision import tv_tensors +from torchvision.transforms.functional import InterpolationMode + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] _FillType = Union[int, float, Sequence[int], Sequence[float], None] _FillTypeJIT = Optional[list[float]] @@ -177,3 +181,37 @@ def _is_cvcuda_tensor(inpt: Any) -> bool: return isinstance(inpt, cvcuda.Tensor) except ImportError: return False + + +_interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {} + + +def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp": + if not _interpolation_mode_to_cvcuda_interp: + cvcuda = _import_cvcuda() + _interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST_EXACT] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp[InterpolationMode.BILINEAR] = cvcuda.Interp.LINEAR + _interpolation_mode_to_cvcuda_interp[InterpolationMode.BICUBIC] = cvcuda.Interp.CUBIC + _interpolation_mode_to_cvcuda_interp[InterpolationMode.BOX] = cvcuda.Interp.BOX + _interpolation_mode_to_cvcuda_interp[InterpolationMode.HAMMING] = cvcuda.Interp.HAMMING + _interpolation_mode_to_cvcuda_interp[InterpolationMode.LANCZOS] = cvcuda.Interp.LANCZOS + _interpolation_mode_to_cvcuda_interp["nearest"] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp["nearest-exact"] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp["bilinear"] = cvcuda.Interp.LINEAR + _interpolation_mode_to_cvcuda_interp["bicubic"] = cvcuda.Interp.CUBIC + _interpolation_mode_to_cvcuda_interp["box"] = cvcuda.Interp.BOX + _interpolation_mode_to_cvcuda_interp["hamming"] = cvcuda.Interp.HAMMING + _interpolation_mode_to_cvcuda_interp["lanczos"] = cvcuda.Interp.LANCZOS + _interpolation_mode_to_cvcuda_interp[0] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp[2] = cvcuda.Interp.LINEAR + _interpolation_mode_to_cvcuda_interp[3] = cvcuda.Interp.CUBIC + _interpolation_mode_to_cvcuda_interp[4] = cvcuda.Interp.BOX + _interpolation_mode_to_cvcuda_interp[5] = cvcuda.Interp.HAMMING + _interpolation_mode_to_cvcuda_interp[1] = cvcuda.Interp.LANCZOS + + interp = _interpolation_mode_to_cvcuda_interp.get(interpolation) + if interp is None: + raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA") + + return interp