Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions include/gpu_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2816,26 +2816,13 @@ namespace gpu_array
template <Stride StrideType>
struct stride_adapter
{
template <std::ranges::random_access_range Range>
requires std::ranges::sized_range<Range>
[[nodiscard]] constexpr auto operator()(const Range& r) const noexcept
{
return stride_view<StrideType, const Range&>(r);
}
template <std::ranges::random_access_range Range>
requires std::ranges::sized_range<Range>
[[nodiscard]] constexpr auto operator()(Range& r) const noexcept
{
return stride_view<StrideType, Range&>(r);
}

template <std::ranges::random_access_range Range>
requires std::ranges::sized_range<Range>
[[nodiscard]] friend constexpr std::ranges::view auto operator|(const Range& range,
const stride_adapter& self) noexcept
{
return self(range);
}
template <std::ranges::random_access_range Range>
requires std::ranges::sized_range<Range>
[[nodiscard]] friend constexpr std::ranges::view auto operator|(Range& range,
Expand All @@ -2846,13 +2833,16 @@ namespace gpu_array
};
} // namespace detail

#if !defined(ENABLE_HIP)
// The following three alias templates are also disabled in HIP because HIP does not support alias template argument
// deduction.
template <std::ranges::random_access_range Range>
using block_thread_stride_view = detail::stride_view<detail::Stride::BlockThread, Range>;
template <std::ranges::random_access_range Range>
using grid_thread_stride_view = detail::stride_view<detail::Stride::GridThread, Range>;
template <std::ranges::random_access_range Range>
using grid_block_stride_view = detail::stride_view<detail::Stride::GridBlock, Range>;
#if !defined(ENABLE_HIP)

template <std::ranges::random_access_range Range>
using cluster_thread_stride_view = detail::stride_view<detail::Stride::ClusterThread, Range>;
template <std::ranges::random_access_range Range>
Expand Down
24 changes: 16 additions & 8 deletions test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2070,6 +2070,7 @@ TEST(JaggedArray, MemoryManagement)
}
}

#if !defined(ENABLE_HIP)
template <std::ranges::input_range T>
requires std::ranges::input_range<std::ranges::range_value_t<T>>
__global__ void kernel_stride(T array)
Expand All @@ -2086,14 +2087,6 @@ __global__ void kernel_stride2(T array)
for (auto& v : views::block_thread_stride(a)) v = 2;
}

template <std::ranges::input_range T>
requires std::ranges::input_range<std::ranges::range_value_t<T>>
__global__ void kernel_stride3(T array)
{
for (auto& a : grid_block_stride_view(array))
for (auto& v : block_thread_stride_view(a)) v = 3;
}

TEST(StrideView, HowToUse)
{
auto vec_vec = std::vector(32, std::vector<int>(64, 0));
Expand All @@ -2108,10 +2101,25 @@ TEST(StrideView, HowToUse)
api::gpuDeviceSynchronize();
for (const auto& inner_array : nested_array)
for (const auto& v : inner_array) EXPECT_EQ(v, 2);
}

template <std::ranges::input_range T>
requires std::ranges::input_range<std::ranges::range_value_t<T>>
__global__ void kernel_stride3(T array)
{
for (auto& a : grid_block_stride_view(array))
for (auto& v : block_thread_stride_view(a)) v = 3;
}

TEST(StrideView, AliasTemplate)
{
auto vec_vec = std::vector(32, std::vector<int>(64, 0));
auto nested_array = managed_array(vec_vec);

kernel_stride3<<<32, 64>>>(nested_array);
api::gpuDeviceSynchronize();
for (const auto& inner_array : nested_array)
for (const auto& v : inner_array) EXPECT_EQ(v, 3);
}
#endif
// NOLINTEND