diff --git a/sycl/include/sycl/detail/builtins/builtin_helpers.hpp b/sycl/include/sycl/detail/builtins/builtin_helpers.hpp new file mode 100644 index 0000000000000..8a36aeccabc1b --- /dev/null +++ b/sycl/include/sycl/detail/builtins/builtin_helpers.hpp @@ -0,0 +1,330 @@ +//==------------------- builtin_helpers.hpp -------------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// Shared support for SYCL builtin function declarations and implementations. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace sycl { +inline namespace _V1 { +namespace detail { +#ifdef __FAST_MATH__ +template +struct use_fast_math + : std::is_same>, float> {}; +#else +template struct use_fast_math : std::false_type {}; +#endif +template constexpr bool use_fast_math_v = use_fast_math::value; + +// Utility trait for checking if a multi_ptr has a "writable" address space, +// i.e. global, local, private or generic. +template struct has_writeable_addr_space : std::false_type {}; +template +struct has_writeable_addr_space> + : std::bool_constant {}; + +template +constexpr bool has_writeable_addr_space_v = has_writeable_addr_space::value; + +// Classification of pointer-like types used by builtin pointer helpers. +enum class builtin_ptr_kind { raw, multi_ptr }; + +// Maps a pointer-like type to the corresponding builtin_ptr_kind tag. +template +using builtin_ptr_kind_tag_t = std::integral_constant< + builtin_ptr_kind, + is_multi_ptr_v>> + ? builtin_ptr_kind::multi_ptr + : builtin_ptr_kind::raw>; + +// Returns Ptr unchanged for raw pointer-like types. +template +decltype(auto) builtin_raw_ptr( + PtrTy &&Ptr, + std::integral_constant) { + return std::forward(Ptr); +} + +// Extracts the underlying raw pointer from a multi_ptr. +template +auto builtin_raw_ptr( + PtrTy &&Ptr, + std::integral_constant) { + return get_raw_pointer(std::forward(Ptr)); +} + +// Returns a raw pointer representation for raw pointers and multi_ptrs. +template auto builtin_raw_ptr(PtrTy &&Ptr) { + return builtin_raw_ptr(std::forward(Ptr), + builtin_ptr_kind_tag_t{}); +} + +// Returns a pointer to the first element for raw pointer-like types. +template +decltype(auto) builtin_element_pointer( + PtrTy &&Ptr, + std::integral_constant) { + return &(*std::forward(Ptr))[0]; +} + +// Returns a pointer to the first element while preserving multi_ptr semantics. +template +auto builtin_element_pointer( + PtrTy &&Ptr, + std::integral_constant) { + return detail::get_element_pointer(std::forward(Ptr)); +} + +// Returns an element pointer for raw pointers and multi_ptrs. +template auto builtin_element_pointer(PtrTy &&Ptr) { + return builtin_element_pointer(std::forward(Ptr), + builtin_ptr_kind_tag_t{}); +} + +// Utility trait for changing the element type of a type T. If T is a scalar, +// the new type replaces T completely. +template +struct change_elements { + using type = NewElemT; +}; +template +struct change_elements>> { + using type = + marray::type, + T::size()>; +}; +template +struct change_elements>> { + using type = + vec::type, + T::size()>; +}; + +template +using change_elements_t = typename change_elements::type; + +template +inline constexpr bool builtin_same_shape_v = + ((... && is_scalar_arithmetic_v) || (... && is_marray_v) || + (... && is_vec_or_swizzle_v)) && + (... && (num_elements::value == + num_elements::type>::value)); + +template +inline constexpr bool builtin_same_or_swizzle_v = + // Use builtin_same_shape_v to filter out types unrelated to builtins. + builtin_same_shape_v && all_same_v...>; + +// Utility functions for converting to/from vec/marray. +template vec to_vec2(marray X, size_t Start) { + return {X[Start], X[Start + 1]}; +} +template vec to_vec(marray X) { + vec Vec; + for (size_t I = 0; I < N; I++) { + Vec[I] = X[I]; + } + return Vec; +} +template marray to_marray(vec X) { + marray Marray; + for (size_t I = 0; I < N; I++) { + Marray[I] = X[I]; + } + return Marray; +} + +// Relation builtins widen signed-char masks to the required integer element +// type. Keep that conversion local here so builtin headers do not need to pull +// in vector_convert.hpp just for vec::convert. +template +vec relational_mask_widen(vec X) { + static_assert(std::is_integral_v && + !std::is_same_v); + +#if defined(__SYCL_DEVICE_ONLY__) && !defined(__NVPTX__) + // Keep NVPTX on the scalar fallback for consistency with vec::convert. + // TODO: Likely unnecessary as https://github.com/intel/llvm/issues/11840 + // has been closed already. + if constexpr (N > 1) { + using src_vector_t = signed char __attribute__((ext_vector_type(N))); + using dst_vector_t = NewElemT __attribute__((ext_vector_type(N))); + auto OpenCLVec = bit_cast(X); + return bit_cast>( + __builtin_convertvector(OpenCLVec, dst_vector_t)); + } +#endif // defined(__SYCL_DEVICE_ONLY__) && !defined(__NVPTX__) + + vec Result{}; + loop([&](auto idx) { Result[idx] = static_cast(X[idx]); }); + return Result; +} + +namespace builtins { +#ifdef __SYCL_DEVICE_ONLY__ +template auto convert_arg(T &&x) { + using no_cv_ref = std::remove_cv_t>; + if constexpr (is_vec_v) { + using elem_type = get_elem_type_t; + using converted_elem_type = + decltype(convert_arg(std::declval())); + + constexpr auto N = no_cv_ref::size(); + using result_type = std::conditional_t; + return bit_cast(x); + } else if constexpr (is_swizzle_v) { + return convert_arg(simplify_if_swizzle_t{x}); + } else { + static_assert(is_scalar_arithmetic_v || + is_multi_ptr_v || std::is_pointer_v || + std::is_same_v); + return convertToOpenCLType(std::forward(x)); + } +} +#endif +} // namespace builtins + +template +auto builtin_marray_impl(FuncTy F, const Ts &...x) { + using ret_elem_type = decltype(F(x[0]...)); + using T = typename first_type::type; + marray Res; + constexpr auto N = T::size(); + for (size_t I = 0; I < N / 2; ++I) { + auto PartialRes = [&]() { + using elem_ty = get_elem_type_t; + if constexpr (std::is_integral_v) { + return F( + to_vec2(x, I * 2) + .template as, + fixed_width_signed, + fixed_width_unsigned>, + 2>>()...); + } else { + return F(to_vec2(x, I * 2)...); + } + }(); + sycl::detail::memcpy_no_adl(&Res[I * 2], &PartialRes, + sizeof(decltype(PartialRes))); + } + if (N % 2) { + Res[N - 1] = F(x[N - 1]...); + } + return Res; +} + +template +auto builtin_default_host_impl(FuncTy F, const Ts &...x) { + // We implement support for marray/swizzle in the headers and export symbols + // for scalars/vector from the library binary. The reason is that scalar + // implementations mostly depend on which pollutes global namespace, + // so we can't unconditionally include it from the SYCL headers. Vector + // overloads have to be implemented in the library next to scalar overloads in + // order to be vectorizable. + if constexpr ((... || is_marray_v)) { + return builtin_marray_impl(F, x...); + } else { + return F(simplify_if_swizzle_t{x}...); + } +} + +template +auto builtin_delegate_to_scalar(FuncTy F, const Ts &...x) { + using T = typename first_type::type; + static_assert(is_vec_or_swizzle_v || is_marray_v); + + constexpr auto Size = T::size(); + using ret_elem_type = decltype(F(x[0]...)); + std::conditional_t, marray, + vec> + r{}; + + if constexpr (is_marray_v) { + for (size_t i = 0; i < Size; ++i) { + r[i] = F(x[i]...); + } + } else { + loop([&](auto idx) { r[idx] = F(x[idx]...); }); + } + + return r; +} + +template +struct fp_elem_type + : std::bool_constant< + check_type_in_v, float, double, half>> {}; +template +struct float_elem_type + : std::bool_constant, float>> {}; + +template +struct same_basic_shape : std::bool_constant> {}; + +template +struct same_elem_type : std::bool_constant::value && + all_same_v...>> { +}; + +template struct any_shape : std::true_type {}; + +template +struct scalar_only : std::bool_constant> {}; + +template +struct non_scalar_only : std::bool_constant> {}; + +template struct default_ret_type { + using type = T; +}; + +template struct scalar_ret_type { + using type = get_elem_type_t; +}; + +template