diff --git a/benchmarks/python/test_transpose.py b/benchmarks/python/test_transpose.py
index 363f32ed8be..11b66774708 100644
--- a/benchmarks/python/test_transpose.py
+++ b/benchmarks/python/test_transpose.py
@@ -12,14 +12,14 @@
 def transpose_fusion(
     fd: FusionDefinition,
     dtype: DataType,
+    is_copy_transpose: bool,
     axes: list,
+    rank: int,
 ):
-    T0 = fd.define_tensor(
-        shape=[-1, -1, -1], contiguity=[True, True, True], dtype=dtype, is_cpu=False
-    )
-    T1 = fd.define_tensor(
-        shape=[-1, -1, -1], contiguity=[True, True, True], dtype=dtype, is_cpu=False
-    )
+    shape = [-1] * rank
+    contiguity = [True] * rank
+    T0 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False)
+    T1 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False)
 
     if dtype in PROMOTE_DTYPES:
         T0 = fd.ops.cast(T0, dtype=DataType.Float)
@@ -34,25 +34,55 @@ def transpose_fusion(
     S6 = fd.define_scalar(0.00000, dtype=DataType.Double)
     T7 = fd.ops.gt(T5, S6)
     T9 = fd.ops.where(T7, T5, S6)
-
-    fd.add_output(T9)
-
-
-def transpose_fwd_fn(inputs: list):  # [input1, input2, dim0, dim1]
-    return torch.nn.functional.relu(
+    # Add a segment_set to keep pre-segmentation passes from aliasing the output as a view of the input (no data movement), which would select the pointwise scheduler instead of the transpose scheduler.
+    # Alternatively, OptimizationPassGuard could be exposed to the Python frontend to disable pre-segmentation passes, forcing a contiguous output so that the transpose scheduler is used.
+    if is_copy_transpose:
+        T10 = fd.ops.segment_set(T9)
+        fd.add_output(T10)
+    else:
+        fd.add_output(T9)
+
+
+# Without contiguous(), torch.transpose returns a view with swapped strides;
+# contiguous() materializes a contiguous copy of the result.
+# When compiled with Thunder, the contiguous version uses nvFuser's transpose scheduler; otherwise the pointwise scheduler is used.
+def transpose_fwd_fn(inputs: list):  # [input1, input2, dim0, dim1, is_copy_transpose]
+    relu_transpose_result = torch.nn.functional.relu(
         torch.transpose(inputs[0] + inputs[1], inputs[2], inputs[3])
     )
+    is_copy_transpose = inputs[4]
+    if is_copy_transpose:
+        return relu_transpose_result.contiguous()
+    else:
+        return relu_transpose_result
 
 
-@pytest.mark.parametrize("size", generate_input_sizes(dims=3))
+def _generate_transpose_params():
+    params = []
+    for dims in (2, 3):
+        sizes = generate_input_sizes(dims=dims)
+        axes_list = [(0, 1)] if dims == 2 else [(0, 1), (0, 2), (1, 2)]
+        for size in sizes:
+            for axes in axes_list:
+                params.append((size, axes, dims))
+    return params
+
+
+@pytest.mark.parametrize("size,axes,dims", _generate_transpose_params())
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
-@pytest.mark.parametrize("axes", [(0, 1), (0, 2), (1, 2)])
+@pytest.mark.parametrize(
+    "is_copy_transpose",
+    [True, False],
+    ids=["copy_transpose", "view_transpose"],
+)
 @pytest.mark.pointwise
 def test_transpose_nvf_benchmark(
     benchmark,
     size: tuple,
+    is_copy_transpose: bool,
     dtype: torch.dtype,
-    axes: list,
+    axes: tuple,
+    dims: int,
     disable_validation: bool,
     disable_benchmarking: bool,
 ):
@@ -65,10 +95,18 @@
     )
 
     with FusionDefinition() as fd:
-        transpose_fusion(fd, torch_dtype_to_nvfuser_dtype(dtype), permute_axes)
+        transpose_fusion(
+            fd,
+            torch_dtype_to_nvfuser_dtype(dtype),
+            is_copy_transpose,
+            permute_axes,
+            rank=dims,
+        )
 
     if not disable_validation:
-        eager_output = transpose_fwd_fn([input1, input2, axes[0], axes[1]])
+        eager_output = transpose_fwd_fn(
+            [input1, input2, axes[0], axes[1], is_copy_transpose]
+        )
         fd.validate([input1, input2], [eager_output])
 
     if not disable_benchmarking:
@@ -76,15 +114,20 @@
 
 
 @pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
-@pytest.mark.parametrize("size", generate_input_sizes(dims=3))
+@pytest.mark.parametrize("size,axes,dims", _generate_transpose_params())
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
-@pytest.mark.parametrize("axes", [(0, 1), (0, 2), (1, 2)])
-@pytest.mark.pointwise
+@pytest.mark.parametrize(
+    "is_copy_transpose",
+    [True, False],
+    ids=["copy_transpose", "view_transpose"],
+)
 def test_transpose_baseline_benchmark(
     benchmark,
     size: tuple,
     dtype: torch.dtype,
-    axes: list,
+    is_copy_transpose: bool,
+    axes: tuple,
+    dims: int,
     executor: str,
 ):
     if executor == "torchcompile":
@@ -98,5 +141,5 @@
     run_benchmark(
         benchmark,
         benchmark_fn,
-        [input1, input2, axes[0], axes[1]],
+        [input1, input2, axes[0], axes[1], is_copy_transpose],
     )
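
For reference, a minimal standalone PyTorch sketch of the view-vs-copy distinction that the is_copy_transpose flag benchmarks above (not part of this diff; the shape and device below are arbitrary assumptions):

import torch

x = torch.randn(16, 32, device="cuda" if torch.cuda.is_available() else "cpu")

# View transpose: only the strides are swapped; no data movement occurs.
view_t = torch.transpose(x, 0, 1)
assert not view_t.is_contiguous()
assert view_t.data_ptr() == x.data_ptr()  # shares storage with x

# Copy transpose: contiguous() materializes the permuted layout in new memory.
copy_t = view_t.contiguous()
assert copy_t.is_contiguous()
assert copy_t.data_ptr() != x.data_ptr()  # separate allocation; data was moved

The copy variant is the one that exercises nvFuser's transpose scheduler, since producing a contiguous output requires actually moving data.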