87 changes: 65 additions & 22 deletions benchmarks/python/test_transpose.py
@@ -12,14 +12,14 @@
 def transpose_fusion(
     fd: FusionDefinition,
     dtype: DataType,
+    is_copy_transpose: bool,
     axes: list,
+    rank: int,
 ):
-    T0 = fd.define_tensor(
-        shape=[-1, -1, -1], contiguity=[True, True, True], dtype=dtype, is_cpu=False
-    )
-    T1 = fd.define_tensor(
-        shape=[-1, -1, -1], contiguity=[True, True, True], dtype=dtype, is_cpu=False
-    )
+    shape = [-1] * rank
+    contiguity = [True] * rank
+    T0 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False)
+    T1 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False)
 
     if dtype in PROMOTE_DTYPES:
         T0 = fd.ops.cast(T0, dtype=DataType.Float)
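Note: the rank-parameterized definitions above replace the previously hard-coded 3D tensors. A minimal sketch of what they expand to, for illustration only (rank=2 assumed):

    rank = 2
    shape = [-1] * rank          # [-1, -1]; -1 marks a symbolic extent bound at execution time
    contiguity = [True] * rank   # [True, True]; inputs are fully contiguous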
@@ -34,25 +34,55 @@ def transpose_fusion(
     S6 = fd.define_scalar(0.00000, dtype=DataType.Double)
     T7 = fd.ops.gt(T5, S6)
     T9 = fd.ops.where(T7, T5, S6)
 
-    fd.add_output(T9)
-
-
-def transpose_fwd_fn(inputs: list):  # [input1, input2, dim0, dim1]
-    return torch.nn.functional.relu(
+    # Add a segment_set to stop presegmentation passes from aliasing the output as a view of the input (no data movement), which would select the pointwise scheduler instead of the transpose scheduler.
+    # Alternatively, we could expose OptimizationPassGuard to the Python frontend and disable presegmentation passes; that would force the output to be contiguous, so the transpose scheduler would be used.
+    if is_copy_transpose:
+        T10 = fd.ops.segment_set(T9)
+        fd.add_output(T10)
+    else:
+        fd.add_output(T9)
+
+
+# Without contiguous(), transpose returns a view with swapped strides.
+# contiguous() materializes a contiguous copy of the result.
+# When compiled with Thunder, the contiguous version uses nvFuser's transpose scheduler; otherwise the pointwise scheduler is used.
+def transpose_fwd_fn(inputs: list):  # [input1, input2, dim0, dim1, is_copy_transpose]
+    relu_transpose_result = torch.nn.functional.relu(
         torch.transpose(inputs[0] + inputs[1], inputs[2], inputs[3])
     )
+    is_copy_transpose = inputs[4]
+    if is_copy_transpose:
+        return relu_transpose_result.contiguous()
+    else:
+        return relu_transpose_result
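Note: the view/copy distinction described in the comments above can be checked directly in PyTorch. A minimal sketch with a hypothetical input (not part of the diff):

    import torch

    x = torch.randn(4, 8)
    t = torch.transpose(x, 0, 1)   # view: strides swapped, shares storage with x
    assert not t.is_contiguous() and t.data_ptr() == x.data_ptr()
    c = t.contiguous()             # copy: materializes the transposed data
    assert c.is_contiguous() and c.data_ptr() != x.data_ptr()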


@pytest.mark.parametrize("size", generate_input_sizes(dims=3))
+def _generate_transpose_params():
+    params = []
+    for dims in (2, 3):
+        sizes = generate_input_sizes(dims=dims)
+        axes_list = [(0, 1)] if dims == 2 else [(0, 1), (0, 2), (1, 2)]
+        for size in sizes:
+            for axes in axes_list:
+                params.append((size, axes, dims))
+    return params
+
+
+@pytest.mark.parametrize("size,axes,dims", _generate_transpose_params())
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("axes", [(0, 1), (0, 2), (1, 2)])
@pytest.mark.parametrize(

Priya2698 (Collaborator), Feb 3, 2026:
Do we need to benchmark view transpose? Should we remove it instead?

Author (Collaborator):
I don't know; it's not an expensive benchmark, so I'll just leave it as is in this PR.

Collaborator:
Got it. Please work with @xwang233 for dashboard integration.

"is_copy_transpose",
[True, False],
ids=["copy_transpose", "view_transpose"],
)
 @pytest.mark.pointwise
 def test_transpose_nvf_benchmark(
     benchmark,
     size: tuple,
+    is_copy_transpose: bool,
     dtype: torch.dtype,
-    axes: list,
+    axes: tuple,
+    dims: int,
     disable_validation: bool,
     disable_benchmarking: bool,
 ):
@@ -65,26 +95,39 @@ def test_transpose_nvf_benchmark(
     )
 
     with FusionDefinition() as fd:
-        transpose_fusion(fd, torch_dtype_to_nvfuser_dtype(dtype), permute_axes)
+        transpose_fusion(
+            fd,
+            torch_dtype_to_nvfuser_dtype(dtype),
+            is_copy_transpose,
+            permute_axes,
+            rank=dims,
+        )

     if not disable_validation:
-        eager_output = transpose_fwd_fn([input1, input2, axes[0], axes[1]])
+        eager_output = transpose_fwd_fn(
+            [input1, input2, axes[0], axes[1], is_copy_transpose]
+        )
         fd.validate([input1, input2], [eager_output])
 
     if not disable_benchmarking:
         run_benchmark(benchmark, fd.execute, [input1, input2])
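Note: a quick sketch of what the new parametrization yields; the concrete sizes come from generate_input_sizes, so the shapes here are illustrative assumptions only:

    # 2D sizes pair with the single axis permutation (0, 1);
    # 3D sizes pair with each of (0, 1), (0, 2), and (1, 2).
    for size, axes, dims in _generate_transpose_params():
        expected = {(0, 1)} if dims == 2 else {(0, 1), (0, 2), (1, 2)}
        assert len(size) == dims and axes in expected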


@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@pytest.mark.parametrize("size", generate_input_sizes(dims=3))
@pytest.mark.parametrize("size,axes,dims", _generate_transpose_params())

Collaborator:
IIRC, I used 3D inputs to match the C++ benchmark. If 2D inputs are sufficient for benchmarking, we should remove the 3D benchmarking. This should also simplify the dashboard for this benchmark.

Author (Collaborator):
We should keep 3D to cover the different transpose axes (2D only has the (0, 1) pair).

@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("axes", [(0, 1), (0, 2), (1, 2)])
@pytest.mark.pointwise
@pytest.mark.parametrize(
"is_copy_transpose",
[True, False],
ids=["copy_transpose", "view_transpose"],
)
 def test_transpose_baseline_benchmark(
     benchmark,
     size: tuple,
     dtype: torch.dtype,
-    axes: list,
+    is_copy_transpose: bool,
+    axes: tuple,
+    dims: int,
     executor: str,
 ):
     if executor == "torchcompile":
@@ -98,5 +141,5 @@ def test_transpose_baseline_benchmark(
     run_benchmark(
         benchmark,
         benchmark_fn,
-        [input1, input2, axes[0], axes[1]],
+        [input1, input2, axes[0], axes[1], is_copy_transpose],
     )
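Note: with the ids above, each variant can be selected on its own; e.g. a hypothetical invocation pytest benchmarks/python/test_transpose.py -k copy_transpose runs only the copy-transpose benchmarks, and -k view_transpose only the view variant.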