Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/rocm-wheels-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ jobs:
3rdparty/aotriton \
3rdparty/aiter \
3rdparty/QoLA \
3rdparty/ck_jit \
3rdparty/hipify_torch

- name: Derive Docker image tag
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@
[submodule "3rdparty/QoLA"]
path = 3rdparty/QoLA
url = https://github.com/Micky774/QoLA.git
[submodule "3rdparty/ck_jit"]
path = 3rdparty/ck_jit
url = https://github.com/ipanfilo/ck_jit.git
1 change: 1 addition & 0 deletions 3rdparty/ck_jit
Submodule ck_jit added at 83f602
43 changes: 41 additions & 2 deletions ci/_utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,15 @@ start_message() {
python --version
}

configure_omp_threads() {
get_cpu_count() {
n_vcpus=$(lscpu | grep "^CPU(s):" | awk '{print $2}')
cpus_per_core=$(lscpu | grep "Thread(s) per core:" | awk '{print $NF}')

n_physical_cores=$((n_vcpus / cpus_per_core))
echo $((n_vcpus / cpus_per_core))
}

configure_omp_threads() {
n_physical_cores=`get_cpu_count`
n_parallel_jobs=$1

if [ -z ${OMP_NUM_THREADS} ]; then
Expand Down Expand Up @@ -269,3 +273,38 @@ pytest_run() {
pytest -v -rfEs `get_pytest_junitxml $_test_name_tag` $TEST_PYTEST_ARGS "$TEST_DIR/$@" || test_run_error "[$_test_variant_tag] $1"
echo "Done [$_test_variant_tag] $1 in `time_elapsed $_start_ts`"
}

PYTHON_TE_IMPORT="import sys; sys.path[:] = [p for p in sys.path if p not in ['', '.']]; import transformer_engine"
ck_jit_prebuild() {
_prebuild_list="${TE_PATH}ci/ck_jit_prebuild.txt"
if [ ! -f "$_prebuild_list" ]; then
echo "ck_jit_prebuild: blob list not found: $_prebuild_list" >&2
return 1
fi
_gpu_arch=$(rocminfo | grep -E "^ *Name: *gfx" | head -1 | sed "s/.*gfx/gfx/;s/ .*//" 2>/dev/null)
if [ -n "$_gpu_arch" ]; then
_arch_arg="--arch $_gpu_arch"
else
echo "ck_jit_prebuild: GPU architecture not detected, omitting --arch" >&2
_arch_arg=""
fi
_te_install_dir=$(python -c "${PYTHON_TE_IMPORT}; import os; print(os.path.dirname(transformer_engine.__file__))" 2>/dev/null)
if [ -z "$_te_install_dir" ]; then
echo "ck_jit_prebuild: failed to determine transformer_engine installation directory" >&2
return 1
fi
_prebuild_py="$_te_install_dir/lib/ck_jit/ck_jit_prebuild.py"
if [ ! -f "$_prebuild_py" ]; then
echo "ck_jit_prebuild: prebuild script not found: $_prebuild_py" >&2
return 1
fi
_cpu_count=$(get_cpu_count)
if [ -n "$_cpu_count" -a "$_cpu_count" != "0" ]; then
_jobs_arg="--jobs $((_cpu_count/2))"
fi
if [ "$1" = "build" ]; then
echo "Building CK JIT cache for arch=${_gpu_arch:-<not detected>}..."
python "$_prebuild_py" build --blob-list "$_prebuild_list" $_arch_arg $_jobs_arg > /dev/null
fi
python "$_prebuild_py" cache | grep Cache
}
531 changes: 531 additions & 0 deletions ci/ck_jit_prebuild.txt

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions ci/jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ install_prerequisites
pip list | egrep "flax|fidle|jax|ml_dtypes|numpy|transformer_e|typing_ext"
#check_test_jobs_requested
#test $? -eq 0 && init_test_jobs `python -c "import jax; print(len([d for d in jax.devices() if 'rocm' in d.client.platform_version]))"`
ck_jit_prebuild build

for _fus_attn in auto ck aotriton; do
configure_fused_attn_env $_fus_attn || continue
Expand Down Expand Up @@ -139,4 +140,6 @@ if [ -n "$TEST_JOBS_MODE" -a -n "$TEST_MGPU" ]; then
configure_fused_attn_env $_fus_attn && run_test_config_mgpu
done
fi

ck_jit_prebuild list
return_run_results
16 changes: 16 additions & 0 deletions ci/pytorch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,22 @@ if [ -n "$SINGLE_CONFIG" ]; then
exit $?
fi

check_flash_attn_installed() {
_result=$(python -c "${PYTHON_TE_IMPORT}; from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils; print(FlashAttentionUtils.is_installed)" 2>/dev/null)
if [ "$_result" = "True" ]; then
return 0
else
echo "Flash attention is not installed" >&2
return 1
fi
}

#Master script mode: prepare testing prerequisites first
start_message
install_prerequisites
pip list | egrep "flash|ml_dtypes|numpy|torch|transformer_e|typing_ext"
#check_test_jobs_requested && init_test_jobs `python -c "import torch; print(torch.cuda.device_count())"`
ck_jit_prebuild build

for _fus_attn in auto flash ck aotriton unfused; do
configure_fused_attn_env $_fus_attn || continue
Expand All @@ -160,6 +171,10 @@ for _fus_attn in auto flash ck aotriton unfused; do
_DEFAULT_FUSED_ATTN="auto"
fi

if [ $_fus_attn = flash ]; then
check_flash_attn_installed || continue
fi

if [ -n "$TEST_JOBS_MODE" ]; then
test -n "$TEST_SGPU" && run_test_job "$_fus_attn"
else
Expand All @@ -182,4 +197,5 @@ if [ $TEST_LEVEL -ge 3 ]; then
fi
fi

ck_jit_prebuild list
return_run_results
12 changes: 4 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,12 @@ def setup_common_extension() -> CMakeExtension:
os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

if rocm_build():
cmake_flags.append("-DUSE_ROCM=ON")
if os.getenv("NVTE_AOTRITON_PATH"):
aotriton_path = Path(os.getenv("NVTE_AOTRITON_PATH"))
cmake_flags.append(f"-DAOTRITON_PATH={aotriton_path}")
cmake_flags.append(f"-DCK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT={os.getenv('NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT', 3)}")
if os.getenv("NVTE_CK_FUSED_ATTN_PATH"):
ck_path = Path(os.getenv("NVTE_CK_FUSED_ATTN_PATH"))
cmake_flags.append(f"-DAITER_MHA_PATH={ck_path}")
cmake_flags.append(
f"-DCK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT={os.getenv('NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT', '3')}"
)

if int(os.getenv("NVTE_FUSED_ATTN_AOTRITON", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0:
cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF")
Expand Down
88 changes: 64 additions & 24 deletions transformer_engine/common/ck_fused_attn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ set(CMAKE_CXX_STANDARD 17)
project(ck_fused_attn LANGUAGES HIP CXX)


set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE")
set(AITER_MHA_INSTALL_DIR "${CMAKE_INSTALL_PREFIX}/transformer_engine/lib")

#Corresponding runtime check is in nvte_get_fused_attn_backend()
list(FIND CMAKE_HIP_ARCHITECTURES "gfx1250" _gfx1250_idx)
Expand Down Expand Up @@ -67,22 +67,54 @@ else()
message(WARNING "Python interpreter not found; skipping AITER API validation.")
endif()

if(DEFINED AITER_MHA_PATH)
message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=${AITER_MHA_PATH}")
# use pre-built te_libmha_fwd.so te_libmha_bwd.so
set(__AITER_MHA_PATH ${AITER_MHA_PATH})
set(__AITER_CACHE_DIR "")
set(__AITER_MHA_PATH "")
set(__QOLA_INCLUDE_DIR "")
if(NOT "$ENV{NVTE_CK_JIT}" STREQUAL "0")
set(__USE_CK_JIT TRUE)
else()
set(__AITER_MHA_PATH "")
set(__USE_CK_JIT FALSE)
endif()
if(DEFINED ENV{AITER_MHA_PATH})
message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=$ENV{AITER_MHA_PATH}")
# use pre-built libraries and includes from a location specified by the user
set(__AITER_CACHE_DIR $ENV{AITER_MHA_PATH})
elseif(NOT __USE_CK_JIT) #disable for CK_JIT for now
# use pre-built cache
include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake")
get_prebuilt_aiter(__AITER_MHA_PATH)
get_prebuilt_aiter(__AITER_CACHE_DIR)
endif()

if(__AITER_MHA_PATH STREQUAL "")
# If not available, fallback: Build from source via QoLA
list(JOIN CMAKE_HIP_ARCHITECTURES ";" GPU_ARCHS_STR)
message(STATUS "[AITER-BUILD] Building AITER kernels for ${GPU_ARCHS_STR} via QoLA.")
set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA")
if(__AITER_CACHE_DIR STREQUAL "")
# If not available, fallback: Build from source via QoLA
list(JOIN CMAKE_HIP_ARCHITECTURES ";" GPU_ARCHS_STR)
message(STATUS "[AITER-BUILD] Building AITER kernels for ${GPU_ARCHS_STR} via QoLA.")
set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA")
set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml")
if(__USE_CK_JIT)
message(STATUS "[AITER-BUILD] CK_JIT is enabled; will build AITER kernels via CK_JIT.")
set(__CK_JIT_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/ck_jit")
set(__QOLA_BUILD_DIR "${__CK_JIT_BUILD_DIR}/qola") #Need it under ck_jit to clean on full build
if(DEFINED ENV{NVTE_CK_JIT_DIR})
set(__CK_JIT_SOURCE_DIR $ENV{NVTE_CK_JIT_DIR})
else()
set(__CK_JIT_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/ck_jit")
endif()
execute_process(
COMMAND ${Python_EXECUTABLE} "${__CK_JIT_SOURCE_DIR}/ck_jit_build.py" full
--with-qola
--qola-dir ${__QOLA_DIR}
--qola-manifest ${__QOLA_MANIFEST}
--qola-output "${__QOLA_BUILD_DIR}"
--gpu-archs "${GPU_ARCHS_STR}"
--aiter-dir ${__AITER_SOURCE_DIR}
--tmp-dir "${__CK_JIT_BUILD_DIR}"
--install-dir ${AITER_MHA_INSTALL_DIR}
--jit-name "te_ck_jit"
RESULT_VARIABLE QOLA_BUILD_RESULT
)
else()
set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build")
set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml")
execute_process(
COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}"
${Python_EXECUTABLE} -m qola.cli build
Expand All @@ -92,22 +124,29 @@ else()
--arch "${GPU_ARCHS_STR}"
RESULT_VARIABLE QOLA_BUILD_RESULT
)
if(NOT QOLA_BUILD_RESULT EQUAL 0)
message(FATAL_ERROR "[AITER-BUILD] QoLA build failed.")
endif()
endif()
if(NOT QOLA_BUILD_RESULT EQUAL 0)
message(FATAL_ERROR "[AITER-BUILD] QoLA build failed.")
endif()

if(__USE_CK_JIT)
set(__AITER_MHA_PATH ${AITER_MHA_INSTALL_DIR})
set(__QOLA_INCLUDE_DIR "${__QOLA_BUILD_DIR}/include")
else()
# Copy the final .so libs and exported public headers into the aiter
# prebuilt cache so downstream consumers see a self-contained tree.
get_default_aiter_cache_dir(__QOLA_CACHE_DIR)
set(__QOLA_CACHE_LIB "${__QOLA_CACHE_DIR}/lib")
get_default_aiter_cache_dir(__AITER_CACHE_DIR)
set(__QOLA_CACHE_LIB "${__AITER_CACHE_DIR}/lib")
file(MAKE_DIRECTORY ${__QOLA_CACHE_LIB})
file(GLOB __QOLA_BUILT_LIBS "${__QOLA_BUILD_DIR}/lib/*.so")
file(COPY ${__QOLA_BUILT_LIBS} DESTINATION ${__QOLA_CACHE_LIB})
file(COPY "${__QOLA_BUILD_DIR}/include" DESTINATION "${__QOLA_CACHE_DIR}")
file(COPY "${__QOLA_BUILD_DIR}/include" DESTINATION "${__AITER_CACHE_DIR}")
set(__AITER_MHA_PATH "${__QOLA_CACHE_LIB}")
else()
message(STATUS "[AITER-BUILD] Using pre-built AITER from ${__AITER_MHA_PATH}")
set(__QOLA_INCLUDE_DIR "${__AITER_CACHE_DIR}/include")
endif()
else()
set(__AITER_MHA_PATH "${__AITER_CACHE_DIR}/lib")
set(__QOLA_INCLUDE_DIR "${__AITER_CACHE_DIR}/include")
endif()

set(ck_fused_attn_SOURCES)
Expand All @@ -129,7 +168,6 @@ list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS
# Public QoLA headers ship alongside the .so libs in ${__AITER_MHA_PATH}/../include
# (emitted by qola.cli build, or copied from the QoLA build dir above for the
# source-build path).
set(__QOLA_INCLUDE_DIR "${__AITER_MHA_PATH}/../include")
if(NOT EXISTS "${__QOLA_INCLUDE_DIR}/qola_config.h")
message(FATAL_ERROR "Could not find QoLA public headers at ${__QOLA_INCLUDE_DIR}.")
endif()
Expand All @@ -146,5 +184,7 @@ target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS})
target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS})
set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN")

install(FILES ${__AITER_MHA_PATH}/te_libmha_fwd.so ${__AITER_MHA_PATH}/te_libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
if (NOT "${__AITER_MHA_PATH}" STREQUAL "${AITER_MHA_INSTALL_DIR}")
install(FILES ${__AITER_MHA_PATH}/te_libmha_fwd.so ${__AITER_MHA_PATH}/te_libmha_bwd.so DESTINATION ${AITER_MHA_INSTALL_DIR})
endif()
install(TARGETS ck_fused_attn DESTINATION ${AITER_MHA_INSTALL_DIR})
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function(get_prebuilt_aiter PREBUILT_DIR_VAR)
is_aiter_cache_valid("${ROCM_VER_PARAM}" RESULT)
if(RESULT)
get_aiter_cache_key("${ROCM_VER_PARAM}" _UNUSED CACHE_DIR)
set(${PREBUILT_DIR_VAR} "${CACHE_DIR}/lib" PARENT_SCOPE)
set(${PREBUILT_DIR_VAR} "${CACHE_DIR}" PARENT_SCOPE)
return()
endif()
endforeach()
Expand All @@ -62,7 +62,7 @@ function(get_prebuilt_aiter PREBUILT_DIR_VAR)
download_aiter_prebuilt("${ROCM_VER_PARAM}" RESULT)
if(RESULT)
get_aiter_cache_key("${ROCM_VER_PARAM}" _UNUSED CACHE_DIR)
set(${PREBUILT_DIR_VAR} "${CACHE_DIR}/lib" PARENT_SCOPE)
set(${PREBUILT_DIR_VAR} "${CACHE_DIR}" PARENT_SCOPE)
return()
endif()
endforeach()
Expand Down
Loading