diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d80564274..0f05dbc40 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,6 +7,10 @@ name: 'Build' on: pull_request: workflow_dispatch: +concurrency: + # Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes) + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true jobs: core: name: 'Core' diff --git a/.github/workflows/deploy_nightly_docs.yml b/.github/workflows/deploy_nightly_docs.yml index b4e015d2d..a8e5ee5ba 100644 --- a/.github/workflows/deploy_nightly_docs.yml +++ b/.github/workflows/deploy_nightly_docs.yml @@ -7,6 +7,7 @@ name: Deploy nightly docs on: push: branches: [ "main" ] + workflow_dispatch: jobs: build: uses: ./.github/workflows/docs.yml @@ -21,9 +22,8 @@ jobs: name: "te_docs" path: "html" - name: Prepare for pages - uses: actions/upload-pages-artifact@v1.0.7 + uses: actions/upload-pages-artifact@v3 with: - name: github-pages path: "html" deploy: needs: prepare @@ -36,4 +36,5 @@ jobs: runs-on: ubuntu-latest steps: - name: Deploy - uses: actions/deploy-pages@v2.0.0 + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 6fde0338a..9d38d709e 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -8,6 +8,10 @@ on: pull_request: workflow_dispatch: workflow_call: +concurrency: + # Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes) + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true jobs: build_docs: name: 'Build' diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c5cb748c2..1d2fb272f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,6 +7,10 @@ name: 'Lint' on: pull_request: workflow_dispatch: +concurrency: + # Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes) + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true jobs: pytorch_cpplint: name: 'PyTorch C++' diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index c56601ae9..de26531a9 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -58,6 +58,8 @@ jobs: || github.actor == 'vthumbe1503' || github.actor == 'shengfangd' || github.actor == 'kainzhong' + || github.actor == 'cspades' + || github.actor == 'jomitchellnv' ) steps: - name: Check if comment is issued by authorized person diff --git a/.gitignore b/.gitignore index 922dbb56b..ef55928a2 100644 --- a/.gitignore +++ b/.gitignore @@ -56,4 +56,8 @@ artifacts/ **/times.csv transformer_engine/build_info.txt transformer_engine/common/util/hip_nvml.* -*.DS_Store +.DS_Store +.rsync-filter +.codex/ +.cline_storage/ +.claude/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5043d6ea2..76f476eb3 100755 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,4 +43,4 @@ repos: rev: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3 hooks: - id: vermin - args: ['-t=3.10', '--violations'] + args: ['-t=3.10-', '--violations'] diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 0258951d4..8d19d3182 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93 +Subproject commit 8d19d3182bfbc304046a15e9236bec9ff31511fc diff --git a/MANIFEST.in b/MANIFEST.in index c34025772..c2309a037 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ recursive-include transformer_engine/common/include *.* +recursive-include build_tools *.py *.txt diff --git a/README.rst b/README.rst index 2a6d88dd9..9e9683c52 100644 --- a/README.rst +++ b/README.rst @@ -458,7 +458,7 @@ Flax for _ in range(10): loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp) -For a more comprehensive tutorial, check out our `Quickstart Notebook `_. +For a more comprehensive tutorial, check out our `Getting Started Guide `_. .. overview-end-marker-do-not-remove @@ -496,15 +496,22 @@ For example to use the NGC PyTorch container interactively, .. code-block:: bash - docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.08-py3 + docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:26.01-py3 For example to use the NGC JAX container interactively, .. code-block:: bash - docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.08-py3 + docker run --gpus all -it --rm nvcr.io/nvidia/jax:26.01-py3 -Where 25.08 (corresponding to August 2025 release) is the container version. +Where 26.01 (corresponding to January 2026 release) is the container version. + +We recommend updating to the latest NGC container available here: + +* https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch +* https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax + +If you run any examples, please ensure you are using a matching version of TransformerEngine. TransformerEngine is pre-built and packaged inside the containers with examples available at ``/opt/transformerengine`` or ``/opt/transformer-engine``. If you would like to use examples from TE main branch and are running into import errors, please try the latest pip package or building from source, although NGC containers are recommended for ease-of-use for most users. **Benefits of using NGC containers:** @@ -628,6 +635,37 @@ Troubleshooting cd transformer_engine pip install -v -v -v --no-build-isolation . +**Problems using UV or Virtual Environments:** + +1. **Import Error:** + + * **Symptoms:** Cannot import ``transformer_engine`` + * **Solution:** Ensure your UV environment is active and that you have used ``uv pip install --no-build-isolation `` instead of a regular pip install to your system environment. + +2. **cuDNN Sublibrary Loading Failed:** + + * **Symptoms:** Errors at runtime with ``CUDNN_STATUS_SUBLIBRARY_LOADING_FAILED`` + * **Solution:** This can occur when TE is built against the container's system installation of cuDNN, but pip packages inside the virtual environment pull in pip packages for ``nvidia-cudnn-cu12/cu13``. To resolve this, when building TE from source please specify the following environment variables to point to the cuDNN in your virtual environment. + + + .. code-block:: bash + + export CUDNN_PATH=$(pwd)/.venv/lib/python3.12/site-packages/nvidia/cudnn + export CUDNN_HOME=$CUDNN_PATH + export LD_LIBRARY_PATH=$CUDNN_PATH/lib:$LD_LIBRARY_PATH + +3. **Building Wheels:** + + * **Symptoms:** Regular TE installs work correctly but UV wheel builds fail at runtime. + * **Solution:** Ensure that ``uv build --wheel --no-build-isolation -v`` is used during the wheel build as well as the pip installation of the wheel. Use ``-v`` for verbose output to verify that TE is not pulling in a mismatching version of PyTorch or JAX that differs from the UV environment's version. + +**JAX-specific Common Issues and Solutions:** + +1. **FFI Issues:** + + * **Symptoms:** ``No registered implementation for custom call to for platform CUDA`` + * **Solution:** Ensure ``--no-build-isolation`` is used during installation. If pre-building wheels, ensure that the wheel is both built and installed with ``--no-build-isolation``. See "Problems using UV or Virtual Environments" above if using UV. + .. troubleshooting-end-marker-do-not-remove Breaking Changes diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index d5e1cb291..c7d530773 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.12.0.dev0 +2.14.0.dev0 diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 5fc3cded0..a14a24c50 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -12,6 +12,7 @@ "__nv_fp8_e4m3" : "te_hip_fp8_e4m3", "cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA", "at::cuda::CUDAGuard" : "at::hip::HIPGuardMasqueradingAsCUDA", + "c10::cuda::" : "c10::hip::", "__nv_fp4_e2m1" : "__hip_fp4_e2m1", "__nv_fp4x2_e2m1" : "__hip_fp4x2_e2m1", "__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1", diff --git a/build_tools/jax.py b/build_tools/jax.py index 62dc4336e..5de2e4c40 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -59,12 +59,7 @@ def xla_path() -> str: Throws FileNotFoundError if XLA source is not found.""" try: - import jax - from packaging import version - if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi - else: - from jax.extend import ffi + from jax import ffi except ImportError: if os.getenv("XLA_HOME"): xla_home = Path(os.getenv("XLA_HOME")) diff --git a/build_tools/utils.py b/build_tools/utils.py index e250238e6..feab50116 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -306,9 +306,10 @@ def nvcc_path() -> Tuple[str, str]: def get_cuda_include_dirs() -> Tuple[str, str]: """Returns the CUDA header directory.""" + force_wheels = bool(int(os.getenv("NVTE_BUILD_USE_NVIDIA_WHEELS", "0"))) # If cuda is installed via toolkit, all necessary headers # are bundled inside the top level cuda directory. - if cuda_toolkit_include_path() is not None: + if not force_wheels and cuda_toolkit_include_path() is not None: return [cuda_toolkit_include_path()] # Use pip wheels to include all headers. @@ -317,7 +318,10 @@ def get_cuda_include_dirs() -> Tuple[str, str]: except ModuleNotFoundError as e: raise RuntimeError("CUDA not found.") - cuda_root = Path(nvidia.__file__).parent + if nvidia.__file__ is not None: + cuda_root = Path(nvidia.__file__).parent + else: + cuda_root = Path(nvidia.__path__[0]) # namespace return [ subdir / "include" for subdir in cuda_root.iterdir() diff --git a/ci/pytorch.sh b/ci/pytorch.sh index d69009f1a..241f1e518 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -12,7 +12,7 @@ TEST_DIR=${TE_PATH}tests/pytorch #: ${TEST_WORKERS:=4} install_prerequisites() { - pip install 'numpy>=1.22.4' pandas + pip install 'numpy>=1.22.4' pandas safetensors rc=$? if [ $rc -ne 0 ]; then script_error "Failed to install test prerequisites" @@ -100,8 +100,11 @@ run_test_config_mgpu(){ run_default_fa 2 distributed/test_numerics.py run_default_fa 1 distributed/test_torch_fsdp2.py run_default_fa 2 distributed/test_torch_fsdp2_fp8.py - run_default_fa_lbl "flash" 3 attention/test_attention_with_cp.py -k "with_flash" - run_default_fa_lbl "fused" 2 attention/test_attention_with_cp.py -k "with_fused" + if [ $_fus_attn = ck ]; then + run 2 attention/test_attention_with_cp.py -k "with_fused" + elif [ $_fus_attn = flash ]; then + run 3 attention/test_attention_with_cp.py -k "with_flash" + fi } run_benchmark() { diff --git a/docs/_static/css/diagram-colors.css b/docs/_static/css/diagram-colors.css new file mode 100644 index 000000000..96a2a8a6d --- /dev/null +++ b/docs/_static/css/diagram-colors.css @@ -0,0 +1,134 @@ +/* Diagram color definitions for Transformer Engine documentation */ + +/* High precision (BF16/FP16) elements */ +.hp { + fill: #ede7f6; + stroke: #673ab7; + stroke-width: 2; +} + +/* FP8 precision elements */ +.fp8 { + fill: #fff8e1; + stroke: #ffa726; + stroke-width: 2; +} + +/* GEMM/computation operations */ +.gemm { + fill: #ffe0b2; + stroke: #fb8c00; + stroke-width: 2.5; +} + +/* Quantization operations */ +.quantize { + fill: #e8f5e9; + stroke: #66bb6a; + stroke-width: 2; +} + +/* Amax computation operations */ +.amax { + fill: #e1f5fe; + stroke: #039be5; + stroke-width: 2; +} + +/* Text styles */ +.text { + font-family: 'Segoe UI', Arial, sans-serif; + font-size: 14px; + text-anchor: middle; + fill: #212121; +} + +.small-text { + font-family: 'Segoe UI', Arial, sans-serif; + font-size: 14px; + text-anchor: middle; + fill: #757575; +} + +.label { + font-family: 'Segoe UI', Arial, sans-serif; + font-size: 14px; + text-anchor: middle; + fill: #424242; +} + +.title { + font-family: 'Segoe UI', Arial, sans-serif; + font-size: 18px; + font-weight: 600; + text-anchor: middle; + fill: #212121; +} + +.section-title { + font-family: 'Segoe UI', Arial, sans-serif; + font-size: 15px; + font-weight: 600; + text-anchor: middle; +} + +/* Arrows */ +/* Note: marker-end references #arrowhead marker which must be defined in each SVG's section */ +.arrow { + stroke: #616161; + stroke-width: 2; + fill: none; + marker-end: url(#arrowhead); +} + +/* Additional box and element styles */ +.box-blue { + fill: #e3f2fd; + stroke: #1976d2; + stroke-width: 2; +} + +.box-orange { + fill: #fff3e0; + stroke: #f57c00; + stroke-width: 2; +} + +.box-green { + fill: #c8e6c9; + stroke: #388e3c; + stroke-width: 2; +} + +.box-dashed { + stroke-dasharray: 5,5; +} + +/* LayerNorm specific */ +.layernorm { + fill: #b3e5fc; + stroke: #0277bd; + stroke-width: 2.5; +} + +/* Fused layers */ +.fused { + fill: #b2dfdb; + stroke: #00695c; + stroke-width: 3; +} + +/* Generic computation blocks */ +.computation { + fill: #f5f5f5; + stroke: #757575; + stroke-width: 2; +} + +/* FP32 precision (alternative red) */ +.fp32 { + fill: #ffcdd2; + stroke: #d32f2f; + stroke-width: 2.5; +} + diff --git a/docs/_static/css/sphinx_tabs.css b/docs/_static/css/sphinx_tabs.css new file mode 100644 index 000000000..c3e524e0e --- /dev/null +++ b/docs/_static/css/sphinx_tabs.css @@ -0,0 +1,45 @@ +/* Custom styling for sphinx-tabs */ + +.sphinx-tabs { + margin-bottom: 1rem; +} + +.sphinx-tabs-tab { + background-color: #f4f4f4; + border: 1px solid #ccc; + border-bottom: none; + padding: 0.5rem 1rem; + margin-right: 0.5rem; + cursor: pointer; + font-weight: 500; + transition: background-color 0.2s; +} + +.sphinx-tabs-tab:hover { + background-color: #e0e0e0; +} + +.sphinx-tabs-tab[aria-selected="true"] { + background-color: #76b900; /* NVIDIA green */ + color: white; + border-color: #76b900; + margin-right: 0.5rem; +} + +.sphinx-tabs-panel { + border: 1px solid #ccc; + padding: 1rem; + background-color: #f9f9f9; +} + +/* Dark mode support for RTD theme */ +.rst-content .sphinx-tabs-tab { + color: #333; +} + +.rst-content .sphinx-tabs-tab[aria-selected="true"] { + color: white; +} + + + diff --git a/docs/_static/css/svg-responsive.css b/docs/_static/css/svg-responsive.css new file mode 100644 index 000000000..3ffe14eb1 --- /dev/null +++ b/docs/_static/css/svg-responsive.css @@ -0,0 +1,72 @@ +/* Responsive styling for SVG images */ + +/* Make all SVG images responsive */ +.document svg, +.document object[type="image/svg+xml"], +.rst-content svg { + max-width: 100%; + height: auto; + display: block; + margin: 1em auto; +} + +/* For raw HTML embedded SVGs */ +.document .raw-html svg { + max-width: 100%; + height: auto; + width: 100%; +} + +/* Ensure container doesn't overflow */ +.document .raw-html { + max-width: 100%; + overflow-x: auto; +} + +/* Figure containers with captions */ +.svg-figure { + text-align: center; + margin: 20px auto; +} + +.svg-figure img { + display: block; + margin: 0 auto; + height: auto; +} + +/* Different width classes for figures */ +.svg-figure.width-70 img { + width: 70%; + max-width: 100%; +} + +.svg-figure.width-80 img { + width: 80%; + max-width: 100%; +} + +.svg-figure.width-90 img { + width: 90%; + max-width: 100%; +} + +.svg-figure.width-100 img { + width: 100%; +} + +/* Figure captions */ +.svg-caption { + font-style: italic; + margin-top: 10px; + color: #555; + font-size: 0.95em; + line-height: 1.4; +} + + + + + + + diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html index f94e526f5..99ae0702a 100644 --- a/docs/_templates/layout.html +++ b/docs/_templates/layout.html @@ -67,6 +67,10 @@ overflow: visible !important; } + .quant { + background-color: yellow !important; + } +
stats:\n",
+       "  enabled: True\n",
+       "  layers:\n",
+       "    layer_name_regex_pattern: .*\n",
+       "  transformer_engine:\n",
+       "    PercentageGreaterThanThreshold:\n",
+       "      enabled: True\n",
+       "      tensors: [activation]\n",
+       "      threshold: 0.1\n",
+       "      freq: 5\n",
+       "    LogTensorStats:\n",
+       "      enabled: True\n",
+       "      tensors: [activation]\n",
+       "      stats: [min]\n",
+       "      freq: 5\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{n}{stats}\\PY{p}{:}\n", + " \\PY{n}{enabled}\\PY{p}{:} \\PY{k+kc}{True}\n", + " \\PY{n}{layers}\\PY{p}{:}\n", + " \\PY{n}{layer\\PYZus{}name\\PYZus{}regex\\PYZus{}pattern}\\PY{p}{:} \\PY{o}{.}\\PY{o}{*}\n", + " \\PY{n}{transformer\\PYZus{}engine}\\PY{p}{:}\n", + " \\PY{n}{PercentageGreaterThanThreshold}\\PY{p}{:}\n", + " \\PY{n}{enabled}\\PY{p}{:} \\PY{k+kc}{True}\n", + " \\PY{n}{tensors}\\PY{p}{:} \\PY{p}{[}\\PY{n}{activation}\\PY{p}{]}\n", + " \\PY{n}{threshold}\\PY{p}{:} \\PY{l+m+mf}{0.1}\n", + " \\PY{n}{freq}\\PY{p}{:} \\PY{l+m+mi}{5}\n", + " \\PY{n}{LogTensorStats}\\PY{p}{:}\n", + " \\PY{n}{enabled}\\PY{p}{:} \\PY{k+kc}{True}\n", + " \\PY{n}{tensors}\\PY{p}{:} \\PY{p}{[}\\PY{n}{activation}\\PY{p}{]}\n", + " \\PY{n}{stats}\\PY{p}{:} \\PY{p}{[}\\PY{n+nb}{min}\\PY{p}{]}\n", + " \\PY{n}{freq}\\PY{p}{:} \\PY{l+m+mi}{5}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "stats:\n", + " enabled: True\n", + " layers:\n", + " layer_name_regex_pattern: .*\n", + " transformer_engine:\n", + " PercentageGreaterThanThreshold:\n", + " enabled: True\n", + " tensors: [activation]\n", + " threshold: 0.1\n", + " freq: 5\n", + " LogTensorStats:\n", + " enabled: True\n", + " tensors: [activation]\n", + " stats: [min]\n", + " freq: 5" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from IPython.display import Code\n", + "Code(filename='./custom_feature_dir/custom_feature_example_config.yaml', language='yaml')" + ] + }, + { + "cell_type": "markdown", + "id": "3929f293-7ac1-48b0-8a4d-23bb6976aa0b", + "metadata": {}, + "source": [ + "To use this feature one needs to add `.../custom_feature_dir` to `debug_api.initialize(feature_dirs=...`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d82f1c82", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "NVDLFW INSPECT - 2025-10-17 14:16:42,204 - WARNING - Reduction group initialized for tensor reduction before logging statistics. If per-rank statistics are required, pass `skip_reduction=True` when invoking the API. To pass another reduction group, use `reduction_group` kwarg when invoking the API.\n" + ] + } + ], + "source": [ + "import os, time\n", + "import torch\n", + "import transformer_engine.pytorch as te\n", + "import nvdlfw_inspect.api as debug_api\n", + "\n", + "te_dir = os.environ[\"TE_PATH\"] # setup TE dir as environment variable to run this script\n", + "log_dir = os.environ.get(\"LOG_PATH\", \"./log\")\n", + "\n", + "debug_api.initialize(\n", + " config_file=te_dir + \"/docs/debug/custom_feature_dir/custom_feature_example_config.yaml\",\n", + " feature_dirs=[\n", + " te_dir + \"/transformer_engine/debug/features\", \n", + " te_dir + \"/docs/debug/custom_feature_dir\" # One needs to add path to the custom feature dir here\n", + " ],\n", + " log_dir=log_dir,\n", + " default_logging_enabled=True)\n", + "\n", + "debug_api.set_tensor_reduction_group(None) # For distributed training one needs to set the reduction group\n", + "\n", + "module = te.Linear(128, 128, name=\"linear_1\")\n", + "inp = torch.randn(128, 128).cuda()\n", + "\n", + "# Simple training loop with measuring the time\n", + "times = []\n", + "for _ in range(100):\n", + " time_start = time.time()\n", + " inp.normal_()\n", + " out = module(inp)\n", + " out.sum().backward()\n", + " torch.cuda.synchronize()\n", + " time_end = time.time()\n", + " times.append(time_end - time_start)\n", + "\n", + " debug_api.step()" + ] + }, + { + "cell_type": "markdown", + "id": "e4f129a9", + "metadata": {}, + "source": [ + "Now, let's plot the gathered stats." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b68a21ea", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAHqCAYAAADVi/1VAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XdYk+f6B/Dvm4QkbJAtIggooOKeuCfuXav2uLXW1tpqbaunQ+342Z7a2lattrZVa4daR6sd7tG6J25QEGQIhL1HSJ7fH8n7SmRDFnB/rsvrnGY+CZC8uXPf34djjDEQQgghhBBCCCGEEGJEIlMvgBBCCCGEEEIIIYQ0PlSUIoQQQgghhBBCCCFGR0UpQgghhBBCCCGEEGJ0VJQihBBCCCGEEEIIIUZHRSlCCCGEEEIIIYQQYnRUlCKEEEIIIYQQQgghRkdFKUIIIYQQQgghhBBidFSUIoQQQgghhBBCCCFGR0UpQgghhBBCCCGEEGJ0VJQihJgcx3FYtWqVqZdRoVmzZsHHx8fUyyCEEEII0SuO47Bo0SJTL6NKMTEx4DgOa9euNfVSABhmPdu2bQPHcYiJianysj4+Ppg1a5be7rs+qMnzQ+oXKkoRQvSCf6PgOA5nzpwpcz5jDF5eXuA4DqNGjTLBCgkhhBDDK/1+yHEc5HI5WrVqhUWLFiE5OdnUy6uzu3fvYtWqVfTBsJZM8fydO3cOq1atQmZmptHus7b++usvs/6isj46cOAAOnXqBLlcjubNm2PlypUoKSmp1nU//PBDjBkzBm5ubmb/JTKpv6goRQjRK7lcjp9//rnM6adPn0Z8fDxkMlmZ8woKCvD2228bY3mEEEKIUbz33nvYsWMHNmzYgJCQEGzatAk9e/ZEfn6+qZdWJ3fv3sXq1aupKFVLpnj+zp07h9WrV9ebotTq1atNvYwG4++//8a4cePg4OCA9evXY9y4cfjggw/w8ssvV+v6b7/9Ni5fvoyOHTsaeKWkMZOYegGEkIZlxIgR+PXXX/Hll19CInnyEvPzzz+jc+fOSE1NLXMduVxuzCUSQgghBjd8+HB06dIFADBv3jw4OTnhs88+w++//46pU6fW6bbz8/NhZWWlj2WSChQWFkIqlUIkqh/f4efl5cHa2trUyzBbjfX5WbZsGdq1a4cjR44Ix+V2dnb4v//7P7zyyisIDAys9PrR0dHw8fFBamoqXFxcjLFk0gjVj1dZQki9MXXqVKSlpeHo0aPCacXFxdizZw+mTZtW7nWebgdetWoVOI5DZGQkZs2aBQcHB9jb22P27NlVfsO8aNEi2NjYlHu5qVOnwt3dHSqVCgDw+++/Y+TIkWjatClkMhn8/Pzw/vvvC+dX5NSpU+A4DqdOndI5nc8X2LZtm87p4eHhmDRpEpo0aQK5XI4uXbrgwIEDOpdRKpVYvXo1WrZsCblcDicnJ/Tu3VvneSSEEFJ/DRw4EIDmQx7vxx9/ROfOnWFpaYkmTZpgypQpiIuL07le//790bZtW1y9ehV9+/aFlZUV/vvf/wLQFE5WrVqFVq1aQS6Xw8PDAxMmTEBUVJRwfbVajc8//xxt2rSBXC6Hm5sbFixYgIyMDJ378fHxwahRo3DmzBl069YNcrkcvr6++OGHH4TLbNu2Dc888wwAYMCAAcKIIv9+WJP31Y0bN8LX1xeWlpbo1q0b/v33X/Tv3x/9+/fXuVxRURFWrlwJf39/yGQyeHl54Y033kBRUVG1nvfq3A//vr5z5068/fbb8PT0hJWVFbKzswEAFy9exLBhw2Bvbw8rKyv069cPZ8+e1bmfR48e4cUXX0RAQAAsLS3h5OSEZ555RqcjqqrnD9B0tvTp0wfW1tawtbXFyJEjcefOHZ37mjVrFmxsbBAVFYURI0bA1tYWzz33XLmPf9WqVXj99dcBAC1atBDu8+lOrd9++w1t27aFTCZDmzZtcOjQoRo/Pv4xchyHs2fPYunSpXBxcYG1tTXGjx+PlJSUctdY+nFt3LgRAHRGYJ/2zTffwM/PDzKZDF27dsXly5er/fxU9+/hypUrCA0NhbOzMywtLdGiRQvMmTOn3HVXtR4AOHHihPBzdXBwwNixY3Hv3r1Knw9AE3/xwQcfoFmzZrCyssKAAQPK/D5U5O7du7h79y6ef/55nS+KX3zxRTDGsGfPnipvo7aZqnv27AHHcTh9+nSZ877++mtwHIfbt28DAG7evIlZs2bB19cXcrkc7u7umDNnDtLS0qq8n4pGCsvL3MrMzMSrr74KLy8vyGQy+Pv74+OPP4Zarda53M6dO9G5c2fY2trCzs4OwcHB+OKLL6r/4EmNUacUIUSvfHx80LNnT/zyyy8YPnw4AM0BVlZWFqZMmYIvv/yy2rc1efJktGjRAmvWrMG1a9fw7bffwtXVFR9//HGF13n22WexceNG/Pnnn8KBH6D5VvngwYOYNWsWxGIxAM2Bk42NDZYuXQobGxucOHEC7777LrKzs/HJJ5/U8hnQdefOHfTq1Quenp5Yvnw5rK2tsXv3bowbNw579+7F+PHjAWgOGtesWYN58+ahW7duyM7OxpUrV3Dt2jUMGTJEL2shhBBiOnyhyMnJCYAmq+Wdd97B5MmTMW/ePKSkpGD9+vXo27cvrl+/DgcHB+G6aWlpGD58OKZMmYL//Oc/cHNzg0qlwqhRo3D8+HFMmTIFr7zyCnJycnD06FHcvn0bfn5+AIAFCxZg27ZtmD17NhYvXozo6Ghs2LAB169fx9mzZ2FhYSHcT2RkJCZNmoS5c+di5syZ+P777zFr1ix07twZbdq0Qd++fbF48WJ8+eWX+O9//4ugoCAAEP63uu+rmzZtwqJFi9CnTx8sWbIEMTExGDduHBwdHdGsWTPhcmq1GmPGjMGZM2fw/PPPIygoCLdu3cK6detw//59/Pbbb5U+59W9H977778PqVSKZcuWoaioCFKpFCdOnMDw4cPRuXNnrFy5EiKRCFu3bsXAgQPx77//olu3bgCAy5cv49y5c5gyZQqaNWuGmJgYbNq0Cf3798fdu3dhZWVV5fO3Y8cOzJw5E6Ghofj444+Rn5+PTZs2oXfv3rh+/bpOgaCkpAShoaHo3bs31q5dW2Hn3IQJE3D//n388ssvWLduHZydnQFAp+vlzJkz2LdvH1588UXY2triyy+/xMSJExEbGyv8vlbn8ZX28ssvw9HREStXrkRMTAw+//xzLFq0CLt27arw57VgwQI8fvwYR48exY4dO8q9zM8//4ycnBwsWLAAHMfhf//7HyZMmICHDx/q/C5X9PxU5+9BoVBg6NChcHFxwfLly+Hg4ICYmBjs27evVus5duwYhg8fDl9fX6xatQoFBQVYv349evXqhWvXrlVa+Hn33XfxwQcfYMSIERgxYgSuXbuGoUOHori4uMLr8K5fvw4AQscmr2nTpmjWrJlwviGMHDkSNjY22L17N/r166dz3q5du9CmTRu0bdsWAHD06FE8fPgQs2fPhru7O+7cuYNvvvkGd+7cwYULF8otTNZUfn4++vXrh4SEBCxYsADNmzfHuXPnsGLFCiQmJuLzzz8X1jJ16lQMGjRI+Lxx7949nD17Fq+88kqd10EqwAghRA+2bt3KALDLly+zDRs2MFtbW5afn88YY+yZZ55hAwYMYIwx5u3tzUaOHKlzXQBs5cqVwn+vXLmSAWBz5szRudz48eOZk5NTpetQq9XM09OTTZw4Uef03bt3MwDsn3/+EU7j11faggULmJWVFSssLBROmzlzJvP29hb+++TJkwwAO3nypM51o6OjGQC2detW4bRBgwax4OBgndtTq9UsJCSEtWzZUjitffv2ZZ4XQggh9Q//fnjs2DGWkpLC4uLi2M6dO5mTkxOztLRk8fHxLCYmhonFYvbhhx/qXPfWrVtMIpHonN6vXz8GgG3evFnnst9//z0DwD777LMya1Cr1Ywxxv79918GgP3000865x86dKjM6d7e3mXeJxUKBZPJZOy1114TTvv111/LfQ9krHrvq0VFRczJyYl17dqVKZVK4XLbtm1jAFi/fv2E03bs2MFEIhH7999/dW5z8+bNDAA7e/Zsmfvj1eR++Pd1X19fncegVqtZy5YtWWhoqPCc8o+zRYsWbMiQIZU+9vPnzzMA7IcffhBOq+j5y8nJYQ4ODmz+/Pk6pyclJTF7e3ud02fOnMkAsOXLl1f4+Ev75JNPGAAWHR1d5jwATCqVssjISOG0GzduMABs/fr1NX58/O//4MGDdZ6zJUuWMLFYzDIzMytd60svvcTK+4jKH2M5OTmx9PR04fTff/+dAWAHDx4UTqvo+anu38P+/fuFY9qK1GQ9HTp0YK6uriwtLU047caNG0wkErEZM2YIp/HPHf9zUigUTCqVspEjR+o8l//9738ZADZz5swK18fYk597bGxsmfO6du3KevToUen1S0tJSSlzvF6VqVOnMldXV1ZSUiKclpiYyEQiEXvvvfeE08r73frll1/KvB49/fwwVvYzBM/b21vn+Xn//feZtbU1u3//vs7lli9fzsRisfAcvfLKK8zOzk5nzcTwaHyPEKJ3kydPRkFBAf744w/k5OTgjz/+qHB0rzIvvPCCzn/36dMHaWlpQjt9eTiOwzPPPIO//voLubm5wum7du2Cp6cnevfuLZxmaWkp/P+cnBykpqaiT58+yM/PR3h4eI3X+7T09HScOHECkydPFm4/NTUVaWlpCA0NxYMHD5CQkAAAcHBwwJ07d/DgwYM63y8hhBDTGzx4MFxcXODl5YUpU6bAxsYG+/fvh6enJ/bt2we1Wo3JkycL7w2pqalwd3dHy5YtcfLkSZ3bkslkmD17ts5pe/fuhbOzc7mBxXxnwa+//gp7e3sMGTJE5346d+4MGxubMvfTunVr9OnTR/hvFxcXBAQE4OHDh9V6zNV5X71y5QrS0tIwf/58nZGi5557Do6Ojjq39+uvvyIoKAiBgYE66+dHIZ9ef2k1uR/ezJkzdR5DWFgYHjx4gGnTpiEtLU24/7y8PAwaNAj//POPMPpT+npKpRJpaWnw9/eHg4MDrl27VuVzd/ToUWRmZmLq1Kk6j1UsFqN79+7lPtaFCxdWebvVMXjwYKGzDgDatWsHOzs7nZ97TR/f888/r9Ph0qdPH6hUKjx69KhOa3322Wd1fn7872t5v6NPPz/V/XvguxT/+OMPKJXKOq0nMTERYWFhmDVrFpo0aSJcrl27dhgyZAj++uuvCm/72LFjKC4uxssvv6zzXL766quVrolXUFAAAOVuMiSXy4XzDeXZZ5+FQqHQGU/ds2cP1Go1nn32WeG00r9bhYWFSE1NRY8ePQCgWn871fHrr7+iT58+cHR01PnZDx48GCqVCv/88w8Azc8+Ly+P4jOMjMb3CCF65+LigsGDB+Pnn39Gfn4+VCoVJk2aVOPbad68uc5/82/6GRkZsLOzq/B6zz77LD7//HMcOHAA06ZNQ25uLv766y+htZp3584dvP322zhx4kSZQldWVlaN1/u0yMhIMMbwzjvv4J133in3MgqFAp6ennjvvfcwduxYtGrVCm3btsWwYcMwffp0tGvXrs7rIIQQYnwbN25Eq1atIJFI4ObmhoCAACE0+8GDB2CMoWXLluVet/QYEgB4enpCKpXqnBYVFYWAgACdgsvTHjx4gKysLLi6upZ7vkKh0Pnvp993Ac1779N5OxWpzvsqX5Tw9/fXOV8ikZQZY3rw4AHu3btXYcDy0+svrSb3w2vRokWZ+wc0xaqKZGVlwdHREQUFBVizZg22bt2KhIQEMMZ0LlMV/r74gtvTnj7ukUgk5Y4g1kZ1fu41fXyVHcPpc60V3W55z091/x769euHiRMnYvXq1Vi3bh369++PcePGYdq0aWUKPFWth/89DAgIKHN/QUFBOHz4cIUh7Px1n36dcHFxqbCwWhpf7Ckvf62wsFCnGGQIfA7brl27MGjQIACaL4k7dOiAVq1aCZdLT0/H6tWrsXPnzjJ/0/o4Hgc0P/ubN29W+Vry4osvYvfu3Rg+fDg8PT0xdOhQTJ48GcOGDdPLOkj5qChFCDGIadOmYf78+UhKSsLw4cN1sjGqi89+elrpA6Hy9OjRAz4+Pti9ezemTZuGgwcPoqCgQOdbmczMTPTr1w92dnZ477334OfnB7lcjmvXruHNN98sE3pYWkWz7U8HufK3sWzZMoSGhpZ7Hf5guW/fvoiKisLvv/+OI0eO4Ntvv8W6deuwefNmzJs3r9LHSwghxPx069atTJYLT61Wg+M4/P333+W+19nY2Oj8d20/PKrVari6uuKnn34q9/ynP6DV9n0XqNv7amXrDw4OxmeffVbu+V5eXjW+zco8/Tzza/7kk0/QoUOHcq/D/6xefvllbN26Fa+++ip69uwJe3t7cByHKVOmVOux85fZsWMH3N3dy5z/dPFRJpPpbWfA6vzca/r46vK7VNe1AuU/P9X9e+A4Dnv27MGFCxdw8OBBHD58GHPmzMGnn36KCxcu6Px9Gupx6oOHhwcATbfW038riYmJQh6aochkMowbNw779+/HV199heTkZJw9exb/93//p3O5yZMn49y5c3j99dfRoUMH2NjYQK1WY9iwYbV63QDKPyYfMmQI3njjjXIvzxfJXF1dERYWhsOHD+Pvv//G33//ja1bt2LGjBnYvn17rdZCqkZFKUKIQYwfPx4LFizAhQsXKg21NJTJkyfjiy++QHZ2Nnbt2gUfHx+hFRjQ7LSTlpaGffv2oW/fvsLppXdFqgj/7VRmZqbO6U+3pPv6+gLQfOM9ePDgKm+3SZMmmD17NmbPno3c3Fz07dsXq1atoqIUIYQ0MH5+fmCMoUWLFjodAzW9jYsXL0KpVJbprCp9mWPHjqFXr15664qo6IuZ6r6vent7A9B0Ew8YMEA4vaSkBDExMTodwn5+frhx4wYGDRpU47DjmtxPRfiRNjs7uyrfx/fs2YOZM2fi008/FU4rLCwsc6xQ0ePg78vV1bVaxww1oY+g6Oo+vrrSx1orUtO/hx49eqBHjx748MMP8fPPP+O5557Dzp07a3Rcxv8eRkRElDkvPDwczs7O5XZJlb7ugwcPhGNKAEhJSalWxxlfSL1y5YpOAerx48eIj4/H888/X+3HUVvPPvsstm/fjuPHj+PevXtgjOl8SZyRkYHjx49j9erVePfdd4XTqxtn4ejoWOZ3sLi4GImJiTqn+fn5ITc3t1p/W1KpFKNHj8bo0aOhVqvx4osv4uuvv8Y777xTpvOS6AdlShFCDMLGxgabNm3CqlWrMHr0aKPf/7PPPouioiJs374dhw4dwuTJk3XO57/ZKv1NVnFxMb766qsqb9vb2xtisViYP+c9fV1XV1f0798fX3/9dZk3RwA6WyM/ve2tjY0N/P39q73lNSGEkPpjwoQJEIvFWL16dZmOCsZYtbZCnzhxIlJTU7Fhw4Yy5/G3OXnyZKhUKrz//vtlLlNSUlKrggL/Afrp61b3fbVLly5wcnLCli1bUFJSIpz+008/lfmgPXnyZCQkJGDLli1l1lFQUIC8vLwK11mT+6lI586d4efnh7Vr1+rkVPJKv4+LxeIyP8v169eX6dio6PkLDQ2FnZ0d/u///q/cHKPS91VTFd1nTVT38dWVPtZaker+PWRkZJR5rHyBp6bHZR4eHujQoQO2b9+u85hu376NI0eOYMSIERVed/DgwbCwsMD69et11sPvFFeVNm3aIDAwEN98843Oz2nTpk3gOE4nWiMrKwvh4eF6G5fjDR48GE2aNMGuXbuwa9cudOvWTWdMtrzXDaD6j9HPz6/M8fjTjxfQ/OzPnz+Pw4cPl7mNzMxM4TXi6ddekUgkFLDpmNxwqFOKEGIwlWUwGFqnTp3g7++Pt956C0VFRTrfygBASEgIHB0dMXPmTCxevBgcx2HHjh3Vare2t7fHM888g/Xr14PjOPj5+eGPP/4oN9ti48aN6N27N4KDgzF//nz4+voiOTkZ58+fR3x8PG7cuAFAEy7bv39/dO7cGU2aNMGVK1ewZ88eLFq0SD9PCCGEELPh5+eHDz74ACtWrEBMTAzGjRsHW1tbREdHY//+/Xj++eexbNmySm9jxowZ+OGHH7B06VJcunQJffr0QV5eHo4dO4YXX3wRY8eORb9+/bBgwQKsWbMGYWFhGDp0KCwsLPDgwQP8+uuv+OKLL2qc+dihQweIxWJ8/PHHyMrKgkwmw8CBA6v9viqVSrFq1Sq8/PLLGDhwICZPnoyYmBhs27YNfn5+Op0y06dPx+7du/HCCy/g5MmT6NWrF1QqFcLDw7F7924cPny4whHJmtxPRUQiEb799lsMHz4cbdq0wezZs+Hp6YmEhAScPHkSdnZ2OHjwIABg1KhR2LFjB+zt7dG6dWucP38ex44dg5OTU7WeP1dXV2zatAnTp09Hp06dMGXKFLi4uCA2NhZ//vknevXqVW4Bsjo6d+4MAHjrrbcwZcoUWFhYYPTo0RV26JSnuo+vrvi1Ll68GKGhoRCLxZgyZYpebru6fw/bt2/HV199hfHjx8PPzw85OTnYsmUL7OzsKi0iVeSTTz7B8OHD0bNnT8ydOxcFBQVYv3497O3tsWrVqgqv5+LigmXLlmHNmjUYNWoURowYgevXr+Pvv/+Gs7Nzte97zJgxGDp0KKZMmYLbt29jw4YNmDdvHoKCgoTL7d+/H7Nnz8bWrVsxa9Ys4fQdO3bg0aNHyM/PBwD8888/+OCDDwBo/j75bq6KWFhYYMKECdi5cyfy8vKwdu1anfPt7OzQt29f/O9//4NSqYSnpyeOHDlSrckFAJg3bx5eeOEFTJw4EUOGDMGNGzdw+PDhMs/P66+/jgMHDmDUqFGYNWsWOnfujLy8PNy6dQt79uxBTEwMnJ2dMW/ePKSnp2PgwIFo1qwZHj16hPXr16NDhw46zxfRM+Nt9EcIacj4bVor2z6XMc0WrSNHjtQ5DU9t57py5UoGgKWkpJR7H+VtaVyet956iwFg/v7+5Z5/9uxZ1qNHD2ZpacmaNm3K3njjDXb48OEyWzXPnDmTeXt761w3JSWFTZw4kVlZWTFHR0e2YMECdvv2bQaAbd26VeeyUVFRbMaMGczd3Z1ZWFgwT09PNmrUKLZnzx7hMh988AHr1q0bc3BwYJaWliwwMJB9+OGHrLi4uFqPlRBCiHmo7vshY4zt3buX9e7dm1lbWzNra2sWGBjIXnrpJRYRESFcpl+/fqxNmzblXj8/P5+99dZbrEWLFszCwoK5u7uzSZMmsaioKJ3LffPNN6xz587M0tKS2drasuDgYPbGG2+wx48fC5cp7/2Zv/9+/frpnLZlyxbm6+vLxGKxzntmdd9XGWPsyy+/ZN7e3kwmk7Fu3bqxs2fPss6dO7Nhw4bpXK64uJh9/PHHrE2bNkwmkzFHR0fWuXNntnr1apaVlVXVU1yt+zl58iQDwH799ddyb+P69etswoQJzMnJiclkMubt7c0mT57Mjh8/LlwmIyODzZ49mzk7OzMbGxsWGhrKwsPDy2xNX9nzx68lNDSU2dvbM7lczvz8/NisWbPYlStXhMvMnDmTWVtbV/nYS3v//feZp6cnE4lEOsdSANhLL71U5vJPr7u6j6+i33/+OX769+BpJSUl7OWXX2YuLi6M4zjGf1yNjo5mANgnn3xS5jpPH0dW9fxU9fdw7do1NnXqVNa8eXMmk8mYq6srGzVqlM7PoCbrYYyxY8eOsV69ejFLS0tmZ2fHRo8eze7evatzmfKOc1UqFVu9ejXz8PBglpaWrH///uz27dvl/l5VZP/+/axDhw5MJpOxZs2asbfffrvM8SV/308fw/br148BKPdfVT9L3tGjRxkAxnEci4uLK3N+fHw8Gz9+PHNwcGD29vbsmWeeYY8fPy7zPFb0/Lz55pvM2dmZWVlZsdDQUBYZGVnu85OTk8NWrFjB/P39mVQqZc7OziwkJIStXbtWeD727NnDhg4dylxdXZlUKmXNmzdnCxYsYImJidV6rKR2OMbMIIWNEEIIIYQQ0mip1Wq4uLhgwoQJ5Y7r1bf7IYQQUj2UKUUIIYQQQggxmsLCwjJjfT/88APS09PRv3//enc/hBBCao86pQghhBBCCCFGc+rUKSxZsgTPPPMMnJyccO3aNXz33XcICgrC1atXIZVK69X9EEIIqT0KOieEEEIIIYQYjY+PD7y8vPDll18iPT0dTZo0wYwZM/DRRx/ptVBkrPshhBBSe9QpRQghhBBCCCGEEEKMjjKlCCGEEEIIIYQQQojRUVGKEEIIIYQQQgghhBhdg8uUUqvVePz4MWxtbcFxnKmXQwghhJB6iDGGnJwcNG3aFCJR4/sOj46nCCGEEFIX1T2WanBFqcePH8PLy8vUyyCEEEJIAxAXF4dmzZqZehlGR8dThBBCCNGHqo6lGlxRytbWFoDmgdvZ2Zl4NYQQQgipj7Kzs+Hl5SUcVzQ2dDxFCCGEkLqo7rFUgytK8S3mdnZ2dBBFCCGEkDpprKNrdDxFCCGEEH2o6liq8YUkEEIIIYQQQgghhBCTo6IUIYQQQgghhBBCCDE6KkoRQgghhBBCCCGEEKNrcJlShBBCSGOnUqmgVCpNvQyzZmFhAbFYbOplEELqGbVajeLiYlMvgxBCTE5fx1JUlCKEEEIaCMYYkpKSkJmZaeql1AsODg5wd3dvtGHmhJCaKS4uRnR0NNRqtamXQgghZkEfx1JUlCKEEEIaCL4g5erqCisrKyq2VIAxhvz8fCgUCgCAh4eHiVdECDF3jDEkJiZCLBbDy8sLIhGloBBCGi99HktRUYoQQghpAFQqlVCQcnJyMvVyzJ6lpSUAQKFQwNXV1WCjfBs3bsQnn3yCpKQktG/fHuvXr0e3bt2qvN7OnTsxdepUjB07Fr/99pvOeffu3cObb76J06dPo6SkBK1bt8bevXvRvHlz4TLnz5/HW2+9hYsXL0IsFqNDhw44fPiw8LgJITVTUlKC/Px8NG3aFFZWVqZeDiGEmJy+jqWoxE8IIYQ0AHyGFH1Yqj7+uTJU/tauXbuwdOlSrFy5EteuXUP79u0RGhoqfKtYkZiYGCxbtgx9+vQpc15UVBR69+6NwMBAnDp1Cjdv3sQ777wDuVwuXOb8+fMYNmwYhg4dikuXLuHy5ctYtGgRdXYQUgcqlQoAIJVKTbwSQggxH/o4lqJOKUIIIaQBoZG96jP0c/XZZ59h/vz5mD17NgBg8+bN+PPPP/H9999j+fLl5V5HpVLhueeew+rVq/Hvv/+WyQd76623MGLECPzvf/8TTvPz89O5zJIlS7B48WKd+wgICNDToyKkcaPXWEIIeUIfr4n0lRkhhBBCiJ4VFxfj6tWrGDx4sHCaSCTC4MGDcf78+Qqv995778HV1RVz584tc55arcaff/6JVq1aITQ0FK6urujevbvOeJ9CocDFixfh6uqKkJAQuLm5oV+/fjhz5kyl6y0qKkJ2drbOP0IIIYQQQ6OiFCGEEELqlVOnToHjOLPeZTA1NRUqlQpubm46p7u5uSEpKanc65w5cwbfffcdtmzZUu75CoUCubm5+OijjzBs2DAcOXIE48ePx4QJE3D69GkAwMOHDwEAq1atwvz583Ho0CF06tQJgwYNwoMHDypc75o1a2Bvby/88/Lyqs3DJoQQ0sD1798fr776qqmXUUZMTAw4jkNYWJhR71dfxyQcx5XJkCzNVI/PGKgoRQghhJB6JSQkBImJibC3tzf1UvQmJycH06dPx5YtW+Ds7FzuZfht6MeOHYslS5agQ4cOWL58OUaNGoXNmzfrXGbBggWYPXs2OnbsiHXr1iEgIADff/99hfe/YsUKZGVlCf/i4uL0/AgJIaRmqvqQ3lgZ63kx5y+AZs2ahXHjxpl6GWYtNjYWI0eOhJWVFVxdXfH666+jpKSk0ut8+OGHCAkJgZWVFRwcHIyzUFCmFCGEEELqGalUCnd3d1Mvo1LOzs4Qi8VITk7WOT05ObnctUdFRSEmJgajR48WTuMLTBKJBBEREfDy8oJEIkHr1q11rhsUFCSM5/FbMpd3mdjY2ArXK5PJIJPJavAICSGNkUqlAsdxtHFCDRUXF5tdSL5SqYSFhYWpl2FSDfX3WaVSYeTIkXB3d8e5c+eQmJiIGTNmwMLCAv/3f/9X4fWKi4vxzDPPoGfPnvjuu++Mtt6G9ewTQgghpN7p378/Xn75Zbz66qtwdHSEm5sbtmzZgry8PMyePRu2trbw9/fH33//DaDst7fbtm2Dg4MDDh8+jKCgINjY2GDYsGFITEw02WOSSqXo3Lkzjh8/LpymVqtx/Phx9OzZs8zlAwMDcevWLYSFhQn/xowZgwEDBiAsLAxeXl6QSqXo2rUrIiIidK57//59eHt7AwB8fHzQtGnTSi9DCGk8+vfvj0WLFmHRokWwt7eHs7Mz3nnnHTDGAGjy5JYtWwZPT09YW1uje/fuOHXqlHB9/vX1wIEDaN26NWQyGWJjY1FUVIQ333wTXl5ekMlk8Pf31/kQe/v2bQwfPhw2NjZwc3PD9OnTkZqaqrOuxYsX44033kCTJk3g7u6OVatWCef7+PgAAMaPHw+O44T/joqKwtixY+Hm5gYbGxt07doVx44d03nMiYmJGDlyJCwtLdGiRQv8/PPP8PHxweeffy5cJjMzE/PmzYOLiwvs7OwwcOBA3Lhxo9rP6wcffABXV1fY2tpi3rx5WL58OTp06CCcz3fyfPjhh2jatKmw2URcXBwmT54MBwcHNGnSBGPHjkVMTIxwvcuXL2PIkCFwdnaGvb09+vXrh2vXrlX5vADA77//jk6dOkEul8PX1xerV6/W6YzhOA6bNm3CmDFjYG1tjQ8//LDCxxcTE4MBAwYAABwdHcFxHGbNmiWcr1arK/zZAZqNPoKDg2FtbQ0vLy+8+OKLyM3NFc6vy/v2qlWrsH37dvz+++/gOA4cx+n8zj58+BADBgyAlZUV2rdvr5PjWNnvc2V/B48ePcLo0aPh6OgIa2trtGnTBn/99ZfOuq5evYouXbrAysoKISEhZd6HN23aBD8/P0ilUgQEBGDHjh2VPs5Lly6hY8eOkMvl6NKlC65fv17lc8M7cuQI7t69ix9//BEdOnTA8OHD8f7772Pjxo0oLi6u8HqrV6/GkiVLEBwcXO370gvWwGRlZTEALCsry9RLIYQQQoymoKCA3b17lxUUFAinqdVqllekNMk/tVpd7bX369eP2drasvfff5/dv3+fvf/++0wsFrPhw4ezb775ht2/f58tXLiQOTk5sby8PHby5EkGgGVkZDDGGNu6dSuzsLBggwcPZpcvX2ZXr15lQUFBbNq0aTV+znj6OJ7YuXMnk8lkbNu2bezu3bvs+eefZw4ODiwpKYkxxtj06dPZ8uXLK7z+zJkz2dixY3VO27dvH7OwsGDffPMNe/DgAVu/fj0Ti8Xs33//FS6zbt06Zmdnx3799Vf24MED9vbbbzO5XM4iIyOrvXY6niJE19OvF/Xp9dXGxoa98sorLDw8nP3444/MysqKffPNN4wxxubNm8dCQkLYP//8wyIjI9knn3zCZDIZu3//PmPsyetrSEgIO3v2LAsPD2d5eXls8uTJzMvLi+3bt49FRUWxY8eOsZ07dzLGGMvIyGAuLi5sxYoV7N69e+zatWtsyJAhbMCAATrrsrOzY6tWrWL3799n27dvZxzHsSNHjjDGGFMoFAwA27p1K0tMTGQKhYIxxlhYWBjbvHkzu3XrFrt//77w+vbo0SPhtgcPHsw6dOjALly4wK5evcr69evHLC0t2bp163QuM3r0aHb58mV2//599tprrzEnJyeWlpZW5XP6448/Mrlczr7//nsWERHBVq9ezezs7Fj79u2Fy8ycOZPZ2Niw6dOns9u3b7Pbt2+z4uJiFhQUxObMmcNu3rzJ7t69y6ZNm8YCAgJYUVERY4yx48ePsx07drB79+6xu3fvsrlz5zI3NzeWnZ1d6fPyzz//MDs7O7Zt2zYWFRXFjhw5wnx8fNiqVauENQFgrq6u7Pvvv2dRUVE6z9nTSkpK2N69exkAFhERwRITE1lmZma1fnaMad6HTpw4waKjo9nx48dZQEAAW7hwoXB+bd+3GWMsJyeHTZ48mQ0bNowlJiayxMREVlRUxKKjoxkAFhgYyP744w8WERHBJk2axLy9vZlSqdS536d/n6v6Oxg5ciQbMmQIu3nzJouKimIHDx5kp0+fZowx4Zike/fu7NSpU+zOnTusT58+LCQkRFgz/969ceNGFhERwT799FMmFovZiRMndH4++/fvFx6ji4sLmzZtGrt9+zY7ePAg8/X1ZQDY9evXq3yO3nnnHZ3fR8YYe/jwIQPArl27VuX1t27dyuzt7au8HGP6OZai8T1CCCHExCIVOcjIV6KrTxO93m6BUoXW7x7W621W1933QmElrf5hRvv27fH2228D0OQbffTRR3B2dsb8+fMBAO+++y42bdqEmzdvlnt9pVKJzZs3w8/PDwCwaNEivPfee3V8FHXz7LPPIiUlBe+++y6SkpLQoUMHHDp0SAg/j42NrfHIwPjx47F582asWbMGixcvRkBAAPbu3YvevXsLl3n11VdRWFiIJUuWID09He3bt8fRo0eF54YQUnf16fXVy8sL69atA8dxCAgIwK1bt7Bu3TqEhoZi69atePToEWwcXWAplWDZsmU4dOgQtm7dKoz5KJVKfPXVV2jfvj0ATefl7t27cfToUWGHUV9fX+H+NmzYgI4dO+qMCX3//ffw8vLC/fv30apVKwBAu3btsHLlSgBAy5YtsWHDBhw/fhxDhgyBi4sLAMDBwUFn5Ll9+/bCOgDg/fffx/79+3HgwAEsWrQI4eHhOHbsGC5fvowuXboAAL799lu0bNlSuM6ZM2dw6dIlKBQKYWx57dq1+O2337Bnzx48//zzlT6f69evx9y5czF79mwAmvenI0eO6HQCAYC1tTW+/fZbYWzvxx9/hFqtxrfffguO4wAAW7duhYODA06dOoWhQ4di4MCBOrfxzTffwMHBAadPn8aoUaMqfF5Wr16N5cuXY+bMmcLP4/3338cbb7whPMcAMG3aNGHdlRGLxWjSRHNM4urqWiZfqLKfHQCdIHQfHx988MEHeOGFF/DVV18Jp9f2fdvGxgaWlpYoKioqdxx+2bJlGDlyJADN89KmTRtERkYiMDBQuN/Sv8+xsbHYunUrYmNj0bRpU+E2Sv8dxMbGYuLEiUIHUenfd96HH36Ifv36AQCWL1+OkSNHorCwEHK5HGvXrsWsWbPw4osvAgCWLl2KCxcuYO3atUJHWmk///wz1Go1vvvuO8jlcrRp0wbx8fFYuHBhlc8PACQlJZW70Qp/nrmhohQhhBBiQowxTP/uEhQ5RTj5Wn80d7Iy9ZJMol27dsL/F4vFcHJy0mkf5w+mFAoF7OzsylzfyspKp+ji4eEBhUJhwBVXDz82U57SowHl2bZtW7mnz5kzB3PmzKn0usuXL8fy5curs0RCSAPXo0cPoQgCAD179sSnn36KW7duQaVSISAgAGoGcAA4TjPS5+TkJFxeKpXqvEaHhYVBLBYLH8CfduPGDZw8eRI2NjZlzouKitIpSpVWndft3NxcrFq1Cn/++ScSExNRUlKCgoICITMvIiICEokEnTp1Eq7j7+8PR0dHnfXl5ubqPEYAKCgoQFRUVKX3z98HX1zgdevWDSdOnNA5LTg4WCdH6saNG4iMjIStra3O5QoLC4X7TU5Oxttvv41Tp05BoVBApVIhPz+/0kxA/rbPnj2rM5KnUqlQWFiI/Px8WFlpji34Ql1dVfWzO3bsGNasWYPw8HBkZ2ejpKSkzFoM9b5dem18zqJCoRCKUk//PvN/B/zvJa/038HixYuxcOFCHDlyBIMHD8bEiRPLPAcV3W/z5s1x7969MsXOXr164Ysvvij3Mdy7dw/t2rWDXC4XTitv9L+hoKIUIYQQYkKP0vKRmFUIALgWm6HXopSlhRh33wvV2+3V9L5r4umwVY7jdE7jP1Dx4d/VuT7TZqYQQoi+1afX14rk5uZCLBbj2D/nkZqnhIWEg5+LpmBSuqBkaWmpU9SytLSs8nZHjx6Njz/+uMx5/Id1oPzX7Ype43nLli3D0aNHsXbtWvj7+8PS0hKTJk2qNCenvPV5eHiU+8WAPnccs7a2LnO/nTt3xk8//VTmsnwH1MyZM5GWloYvvvgC3t7ekMlk6NmzZ5WPLzc3F6tXr8aECRPKnFe6sPH0mmqrsp9dTEwMRo0ahYULF+LDDz9EkyZNcObMGcydOxfFxcVCUcpQ79tVHTs8/fvM/x1cvXoVYrHu3xb/dzBv3jyEhobizz//xJEjR7BmzRp8+umnePnll6t9v8bk7u6OS5cu6ZzGb7xijhvFGLQoNWbMGISFhUGhUMDR0RGDBw/Gxx9/LLTFVYYxhhEjRuDQoUPYv38/bflICCGkQboRnyn8/1sJWRjX0VNvt81xXI1GPAghhFRPfXp9vXjxos5/X7hwAS1btkTHjh2hUqmQlJwM3+Au4DgOfk3tdD6wlyc4OBhqtRqnT58WxvdK69SpE/bu3QsfHx9IJLV/jiwsLKBSqXROO3v2LGbNmoXx48cD0BQUSgeFBwQEoKSkBNevX0fnzp0BAJGRkcjIyNBZX1JSEiQSiU5QeHUFBATg8uXLmDFjhnDa5cuXq7xep06dsGvXLri6upbb8cs/vq+++gojRowAoAlGLx0QD5T/vHTq1AkRERHw9/ev6cOpEN/l9fR9VeXq1atQq9X49NNPhRH13bt3621d/Npquq6K8H8HCoUCffr0qfByXl5eeOGFF/DCCy9gxYoV2LJli05RqjJBQUE4e/asMF4JaH7WT++UW/ryO3bsEMb/AM3fbXX17NkTH374IRQKBVxdXQEAR48ehZ2dXYX3aUoG3X1vwIAB2L17NyIiIrB3715ERUVh0qRJ1bru559/XuULIiGEEFLf3YzPEv7/rYSsSi5JCCGE1FxsbCyWLl2KiIgI/PLLL1i/fj1eeeUVtGrVCs899xwWvzAPx/4+iLhHMTh/4SLWrFmDP//8s8Lb8/HxwcyZMzFnzhz89ttviI6OxqlTp4TCw0svvYT09HRMnToVly9fRlRUFA4fPozZs2fXqJDg4+OD48ePIykpSSgqtWzZEvv27UNYWBhu3LiBadOm6XSjBAYGYvDgwXj++edx6dIlXL9+Hc8//7xOd8zgwYPRs2dPjBs3DkeOHEFMTAzOnTuHt956C1euXKlyXS+//DK+++47bN++HQ8ePMAHH3yAmzdvVvnZ9bnnnoOzszPGjh2Lf//9V3jeFi9ejPj4eOHx7dixA/fu3cPFixfx3HPPlelMK+95effdd/HDDz9g9erVuHPnDu7du4edO3cKWY214e3tDY7j8McffyAlJaVMZlZF/P39oVQqsX79ejx8+BA7duzA5s2ba72O8vj4+ODmzZuIiIhAamoqlEplrW+L/zuYMWMG9u3bh+joaFy6dEnn7+DVV1/F4cOHER0djWvXruHkyZMICgqq9n28/vrr2LZtGzZt2oQHDx7gs88+w759+7Bs2bJyLz9t2jRwHIf58+fj7t27+Ouvv7B27dpq39/QoUPRunVrTJ8+HTdu3MDhw4fx9ttv46WXXhJy1C5duoTAwEAkJCQI14uNjUVYWBhiY2OhUqmE3YCr+7OvLYMWpZYsWYIePXrA29sbISEhWL58OS5cuFDlL01YWBg+/fRTfP/994ZcHiGEEGJyN0t1St1JyIJaTSNnhBBC9GfGjBkoKChAt27d8NJLL+GVV14R8m22bt2Kcc9Mxafvv42x/bti4sQJuHz5Mpo3b17pbW7atAmTJk3Ciy++iMDAQMyfPx95eXkAgKZNm+Ls2bNQqVQYOnQogoOD8eqrr8LBwaFGmzt8+umnOHr0KLy8vNCxY0cAwGeffQZHR0eEhIRg9OjRCA0N1cmPAoAffvgBbm5u6Nu3L8aPH4/58+fD1tZW6DjhOA5//fUX+vbti9mzZ6NVq1aYMmUKHj16VCYcujzPPfccVqxYgWXLlqFTp06Ijo7GrFmzdMbkymNlZYV//vkHzZs3x4QJExAUFIS5c+eisLBQ6Jz67rvvkJGRgU6dOmH69OlYvHix0OlS2fMSGhqKP/74A0eOHEHXrl3Ro0cPrFu3Dt7e3tV7ssvh6ekpBKi7ublVmI/4tPbt2+Ozzz7Dxx9/jLZt2+Knn37CmjVrar2O8syfPx8BAQHo0qULXFxccPbs2Trd3tatWzFjxgy89tprCAgIwLhx43T+DlQqFV566SUEBQVh2LBhaNWqlU5oe1XGjRuHL774AmvXrkWbNm3w9ddfY+vWrejfv3+5l7exscHBgwdx69YtdOzYEW+99Va547AVEYvF+OOPPyAWi9GzZ0/85z//wYwZM3SC5PPz8xEREaFTm3n33XfRsWNHrFy5Erm5uejYsSM6duxYrWJtXXDMSIEL6enpWLhwIRISEnDmzJkKL5efn48uXbpgzZo1GDt2LDiOq3R8r6ioCEVFRcJ/Z2dnw8vLC1lZWRW2RRJCCCHmQKVmaLvyMAqUT745Pra0H/xdy4bDVqWwsBDR0dFo0aJFlQfGRKOy5yw7Oxv29vaN9niisT9+Qp5WX19j+/fvjw4dOuDzzz+v8DL3k3NQqH0f8nGyhp2lRYWXrY/i4+Ph5eWFY8eOYdCgQQa5jyFDhsDd3R07duwwyO0TYq70cSxl0E4pAHjzzTdhbW0NJycnxMbG4vfff6/08kuWLEFISAjGjh1brdtfs2YN7O3thX9eXl76WDYhhBBicJGKXBQoVbCSitHBywEAcJtG+AghhBhRierJ+JtSZZpgZn06ceIEDhw4gOjoaJw7dw5TpkyBj48P+vbtq5fbz8/Px2effYY7d+4gPDwcK1euxLFjx3Tygggh1VfjotTy5cvBcVyl/8LDw4XLv/7667h+/TqOHDkCsViMGTNmVJiqf+DAAZw4caLSSv7TVqxYgaysLOFfXFxcTR8SIYQQYhJ8yHlbT3u0b2YPgHKlCCGEGI+aMZSUGhtXqur/CLlSqcR///tftGnTBuPHj4eLiwtOnTpVZre3irRp0wY2Njbl/vvpp590xv86d+6MgwcPYu/eveWGvpuzF154ocLH+cILL5h0bRWty8bGBv/++69J12YOzPlnVxs13g7htddew6xZsyq9jK+vr/D/nZ2d4ezsjFatWiEoKAheXl64cOECevbsWeZ6J06cQFRUVJmtOCdOnIg+ffqUu22nTCYTwroIIYSQ+oTPk2rfzB6t3DTbcFNRihBCiL6U9/mptJKnOqMaQqdUaGgoQkNDa339v/76q8IMZDc3N1haWuLYsWO1vn1z8d5771UYtG3qse2wsLAKz/P01N8uxfWVOf/saqPGRSkXFxe4uLjU6s74nRFKZ0CVtnz5csybN0/ntODgYKxbtw6jR4+u1X0SQggh5orfea9dMwehKHX3cTbUagaRiHagJYQQYlhPd0Y1hKJUXdUlHLw+cXV1LROibi78/f1NvQSzZs4/u9qocVGqui5evIjLly+jd+/ecHR0RFRUFN555x34+fkJXVIJCQkYNGgQfvjhB3Tr1g3u7u5wd3cvc1vNmzdHixYtDLVUQgghxOiKSlS4l5gNAGjfzAFNHeSQW4iQW1SC6LQ8+LnUPOycEEIIqYkSbdMAB4ChYYzvEULqF4MFnVtZWWHfvn0YNGgQAgICMHfuXLRr1w6nT58Wxu2USiUiIiKQn59vqGUQQgghZik8MQdKFYOjlQW8mlhCIhahtYem5bouYed8VzKpGj1XhJCaMtLG5UbDF6FkFmLtf6sb3GMkhBiOPl4vDNYpFRwcjBMnTlR6GR8fnyofBL0oEkIIaYj4PKngZg7gOM2oXrCnPa7FZuJWfBbGdqhZZoJUKoVIJMLjx4/h4uICqVQq3C7RxRhDcXExUlJSIBKJIJVKTb0kQoiZE4s1RZvi4mJYWlqaeDX6w2dKWVmIUahUQc0YVIxBQu8fhJBq4BuMqruRQHkMVpQihBBCSMVuaPOk+F33AM0ufABwsxadUiKRCC1atEBiYiIeP36sn0U2cFZWVmjevDlEIoM1jhNCGgiJRAIrKyukpKTAwsKiwbxuFBQUgpUowak5cGol1GqGvLwCoXOKEELKwxhDfn4+FAoFHBwchMJ9bVBRihBCCDEBvlOqXTMH4bRgbYGqtmHnUqkUzZs3R0lJCVQqlb6W2iCJxWJIJBLqJiOEVAvHcfDw8EB0dDQePXpk6uXoTWpuEQqVaiitLJBbVAKlikGdLYWcilKEkGpwcHAoNxe8JqgoRQghhBhZXlEJIhW5AHQ7pfxdbOocds5xHCwsLOrURk0IIaQsqVSKli1bori42NRL0ZsPt19GdGoe1kxoh9/vxuNSdDpeG9oKI1o1NfXSCCFmzsLCok4dUjwqShFCCCFGdjshC2oGuNvJ4WonF06XiEUI8rDD9dhM3E7Ioh34CCHEzIhEIsjl8qovWE/cVRQiPU8FV0cbyOVyJOSoEJtV0qAeIyHEvDWMYWhCCCGkHrmpzZNqV6pLiheszZW6FV/7HfgIIYSQqhSXqJGep+n6crWVw91OE+CelFVoymURQhoZKkoRQgghRnZDmyfV3suhzHlCUaoWYeeEEEJIdaXkFgEALMQcHK0s4GGv6Y5KpKIUIcSIqChFCCGEGFmlnVLa0+5ow84JIYQQQ0jO1hSfXG3l4DgO7tqiFHVKEUKMiYpShBBCiBFl5BUjNj0fANDO06HM+aXDzmPS8oy8OkIIIY2FIlvTKeVqJwOAUp1SBSZbEyGk8aGiFCGEEGJE/Fiet5MV7K3K7pDHh52XviwhhBCib4ocTUeUm62mGMV3SmUXliCvqMRk6yKENC5UlCKEEEKM6KY2T6pdM4cKL0Nh54QQQgzt6U4pW7kFbGSazdmTsmmEjxBiHFSUIoQQQozohrbQ1L6cPCleWwo7J4QQYmB8ppSbnVw4zU1boKJcKUKIsVBRihBCCDGimnRKUdg5IYQQQ1HkaDqlXGxlwmke9pYAaAc+QojxUFGKEEIIMZLk7EIkZxdBxAFtPe0qvFxLVxvIJBR2TgghxHDK65Tic6WSaXyPEGIkVJQihBBCjORGXCYAoKWrLaykkgovJxGL0LophZ0TQggxnBRtp5SrTqcU7cBHCDEuKkoRQgghRnJTmyfVrpI8KR4/wnebilKEEEL0rLhEjbS8YgDld0pRphQhxFioKEUIIYQYyQ0+T8rLocrLUtg5IYQQQ0nN1XRJWYg5OFpZCKc/6ZSiohQhxDioKEUIIYQYAWNMKDBVtvMeTwg7T6Cwc0IIIfrFZ0a52srBcZxwurudJuicOqUIIcZCRSlCCCHECGLT85GZr4RULEKge8Uh5zw+7DyHws4JIYToWXJ22Z33gCedUml5xShUqoy+LkJI40NFKUIIIcQIbmjzpII8bCGVVP32KxGLEORBYeeEEEL0LyWH33lPtyjlYGUBmfY9SqEtXBFCiCFRUYoQQggxgpvanffaNXOo9nUo7JwQQogh8J1SrrZyndM5jqMd+AghRkVFKUIIIcQIarLzHi+Yws4JIYQYgKKCTimg1A582fUvV4oxhkhFLlSUxUhIvUFFKUIIIcTAVGqG24+1IefV2HmP15bCzgkhhBhARZ1SAOBuV3934Pv1SjwGf3Yam09HmXophJBqoqIUIYQQYmCRilzkF6tgJRXDz8Wm2tdr6fYk7PxRer4BV0gIIaQxUeRoi1LldkrV3x34rmtH5S/HpJt2IYSQaqOiFCGEEGJgN+IzAWg6n8QirvILl2JBYeeEEEIMQKEdzSuvU4rPlKqPRan4DM0XOFEpuSZeCSGkuqgoRQghhBjYLW2eVPsa5EnxKOycEEKIPilVaqTlFQOoPFMqsR5mSsVpu4rjMwpQqFSZeDWEkOqgohQhhBBiYDe1nVLBNdh5jyeEncdTUYoQQkjdpWhH9yQiDo5W0jLnP+mUql+776nVDAmZmjUzBkSn5pl4RYSQ6qCiFCGEEGJAxSVq3EvMAVC7Tqm2pTqlKOycEEJIXQl5UrYyiMoZKec7pRQ5RVCq1EZdW10k5xRCqXryPhmpoBE+QuoDKkoRQgghBhSelI1ilRoOVhZo3sSqxtdv6WYDKYWdE0II0ZNkPk/KrmyeFAA4W8sgEXFg7ElXVX0Ql67b2UW5UoTUD1SUIoQQQgzohnbsLtjTHhxX/ZBzHoWdE0II0afSnVLlEYk4uGkLVon1KOw87qkvbqJSaHyPkPqAilKEEEKIAd3Ubk/dvhZ5UrxgT01RisLOCSGE1BW/855bBZ1SQP3cgS8+Q9Mp5WyjKbZF0fgeIfUCFaUIIYQQA7qp7ZRqV4s8KV47TwcAFHZOCCGk7hTZlXdKAaV24KtHYedxGZpOqX6tXAAAD1NzKYuRkHqAilKEEEKIgeQXl+CBQhty7uVQ69sRws4fZ4ExOsAmhBBSe8k5DbNTih/fC/FzglQsQqFSLezGRwgxX1SUIoQQQgzkdkI21Axws5NVevBfFSHsvLAEj9Io7JwQQkjt8Z1SLnYVd0oJmVLZ9acoxY/v+ThbwcdZs7EIhZ0TYv6oKEUIIYQYyM34TABAuzrkSQEUdk4IIUR/FHynlG1lnVKWAIDketIppVSphVFDL0cr+LnYAKCwc0LqAypKEUIIIQbC77zXvg55UjwKOyeEEFJXSpUaqbnFAADXSjqlnmRK1Y+iVGJmIdQMkEpEcLaRwd9VU5SKpLBzQsweFaUIIYQQA9FXpxQABGtzpW5S2DkhhJBaSs3VjO5JRByaWEkrvByfKZWcXVgvwsL5kPNmjpYQibhSnVJUlCLE3FFRqhErVKoo/I8QQgwkM79YyH+qy857PAo7J4QQUlfJfJ6UrQwiEVfh5VxsZRBxQImaITWvyFjLq7V4bVHKy1GTJcUXpR5SUYoQs0dFqUbs/T/uovfHJ3A+Ks3USyGEkAaH72jydrKCQyXfRldXKzdbCjsnhBBSJwptcLlrFZtvWIhFcLHVjPfVhx344tI1X7Q3c9RkYfm6WAMAUnOLkZlfbLJ1EUKqRkWpRuyfBylgDPjteoKpl0IIIQ2OPkf3AG3YubstAAo7J4QQUjvJOZquJ1fbivOkeO7asPP6kCvFj+95NdF0SlnLJGiqHUGkET5CzBsVpRqp3KIS4RuFkxGKejErTggh9Yk+Q855wdrborDz+mPjxo3w8fGBXC5H9+7dcenSpWpdb+fOneA4DuPGjStz3r179zBmzBjY29vD2toaXbt2RWxsbJnLMcYwfPhwcByH3377rY6PhBDSEKRoO6XcKgk553lou6nqQ6dUfMaTnfd4ftqw8ygF7cBHiDmjolQNXXyYhnVH7+Piw/o98haRlCP8f0VOEe48zjbhagghpOG5pS1K6atTCngSdk6dUvXDrl27sHTpUqxcuRLXrl1D+/btERoaCoVCUen1YmJisGzZMvTp06fMeVFRUejduzcCAwNx6tQp3Lx5E++88w7k8rKjOJ9//jk4ruLMGEJI48NnSrnaVj6+B9SvHfji0vlOKUvhNAo7J6R+oKJUDR28+RhfHH+AE+GVH1Cau/Ak3SJUfX88hBBiThTZhUjKLoSIA9o0tdPb7Qph5wkUdl4ffPbZZ5g/fz5mz56N1q1bY/PmzbCyssL3339f4XVUKhWee+45rF69Gr6+vmXOf+uttzBixAj873//Q8eOHeHn54cxY8bA1dVV53JhYWH49NNPK70vQkjjo8ipQaeUPd8pZd4bIxUqVVBoxxKble6U0uZKRSqoKEWIOTNYUWrMmDFo3rw55HI5PDw8MH36dDx+/LjK650/fx4DBw6EtbU17Ozs0LdvXxQUmM8LYZCH5sPF3cT63VkUnqjplOLnyU9EUFGKEEL0hR/d83e1gbVMorfb5cPOswtLEJtOYefmrLi4GFevXsXgwYOF00QiEQYPHozz589XeL333nsPrq6umDt3bpnz1Go1/vzzT7Rq1QqhoaFwdXVF9+7dy4zm5efnY9q0adi4cSPc3d319pgIIfVfQ+yU4kf3rKViOFpZCKcL43vUKUWIWTNYUWrAgAHYvXs3IiIisHfvXkRFRWHSpEmVXuf8+fMYNmwYhg4dikuXLuHy5ctYtGgRRCLzaejii1Lhpcbf6iN+fG9u7xYAgBtxmUjJMf/tXgkhpD7Qd8g5j8LO64/U1FSoVCq4ubnpnO7m5oakpKRyr3PmzBl899132LJlS7nnKxQK5Obm4qOPPsKwYcNw5MgRjB8/HhMmTMDp06eFyy1ZsgQhISEYO3ZstddbVFSE7OxsnX+EkIaH7yhyrUanlDufKZVt7kWpJyHnpUeW/bXje7Hp+SgqUZlkbYSQqunv69unLFmyRPj/3t7eWL58OcaNGwelUgkLC4sKr7N48WIsX75cOC0gIMBQS6yVQHdbcByQklOE1NwiONtU/YJubhhjuKcd3+vT0gV/3EzErYQsnIpQ4JkuXiZeHSGE1H+GCDnntfW0x434LNyKz8Kodk31fvvENHJycjB9+nRs2bIFzs7O5V5GrVYDAMaOHSscZ3Xo0AHnzp3D5s2b0a9fPxw4cAAnTpzA9evXa3T/a9aswerVq+v2IAghZq1EpUZaXvU7pTy0u+8lZRWCMWa2GXVx2k6pZo6WOqe72MpgK5Mgp6gEj9Ly0crN1hTLI4RUwSgtSOnp6fjpp58QEhJSYUFKoVDg4sWLcHV1RUhICNzc3NCvXz+cOXPGGEusNiupBD5Omvnke/V0hC8xqxA5hSWQiDj4uVpjYKAmh+IkjfARQkidMcYM1ikFUNh5feHs7AyxWIzk5GSd05OTk8sdqYuKikJMTAxGjx4NiUQCiUSCH374AQcOHIBEIkFUVBScnZ0hkUjQunVrnesGBQUJu++dOHECUVFRcHBwEG4HACZOnIj+/ftXuN4VK1YgKytL+BcXF1fHZ4AQYm5Sc4vBGCAWcXCyllZ5eb6bqqhEjcx8paGXV2vx2nH20nlSAMBxHHy1I3yUK0WI+TJoUerNN9+EtbU1nJycEBsbi99//73Cyz58+BAAsGrVKsyfPx+HDh1Cp06dMGjQIDx48KDC65mi3TxQOzpRX4tSfMi5r4s1ZBKxUJT6534qikvUplwaIYTUe3HpBcjMV8JCzCHQQ//fylLYef0glUrRuXNnHD9+XDhNrVbj+PHj6NmzZ5nLBwYG4tatWwgLCxP+jRkzBgMGDEBYWBi8vLwglUrRtWtXRERE6Fz3/v378Pb2BgAsX74cN2/e1LkdAFi3bh22bt1a4XplMhns7Ox0/hFCGpZk7Rieq60MIlHVXU9yC7FQvDLnXCk+U8qriVWZ8/gRvigqShFitmpUlFq+fDk4jqv0X3h4uHD5119/HdevX8eRI0cgFosxY8aMCg+g+Zb0BQsWYPbs2ejYsSPWrVuHgICASneOWbNmDezt7YV/Xl6GHz/jc6XuJdbPXCk+DyvQXfM4gj3t4WwjQ25RCa7EpJtyaYQQUu/d0HZJBXnYQSYR6/32W7nZQiqmsPP6YOnSpdiyZQu2b9+Oe/fuYeHChcjLy8Ps2bMBADNmzMCKFSsAAHK5HG3bttX55+DgAFtbW7Rt2xZSqeaD4euvv45du3Zhy5YtiIyMxIYNG3Dw4EG8+OKLAAB3d/cytwMAzZs3R4sWLUzwLBBCzIWQJ2Vb/fgRPuw8Kdt8Np56WlwG3yllWeY8P1fNhAuFnRNivmqUKfXaa69h1qxZlV6m9PbFzs7OcHZ2RqtWrRAUFAQvLy9cuHCh3G8IPTw8AKDSlvTyrFixAkuXLhX+Ozs72+CFqSdFqXraKaUtpgVoO75EIg4DAlzw69V4nAhXIMS//CwLQgghVXsyuqf/PCkAkEpECPKw1eRKJWTBWztSTszPs88+i5SUFLz77rtISkpChw4dcOjQISH8PDY2tsabuYwfPx6bN2/GmjVrsHjxYgQEBGDv3r3o3bu3IR4CIaQBETql7KrOk+J52Mtx53G2WXdKxWm/oPFyLNsp5cd3SqXkGXVNhJDqq1FRysXFBS4uLrW6I74Tqqio/B3efHx80LRp03Jb0ocPH17h7cpkMshkxg0bD9KOY0Sl5KK4RA2pxHx2B6wOfnwvqNRYycBAV6Eo9fao1hVdlRBCSBX4kHND5EnxhLDzBAo7N3eLFi3CokWLyj3v1KlTlV5327Zt5Z4+Z84czJkzp9proDFPQghQx04pMy1K5RaVIEObd+XVpJxOKaEolQu1mlVrbJEQYlwGqaZcvHgRGzZsQFhYGB49eoQTJ05g6tSp8PPzE7qkEhISEBgYiEuXLgHQBNG9/vrr+PLLL7Fnzx5ERkbinXfeQXh4OObOnWuIZdaap4Ml7OQSKFWs3oXmFZWo8FD7TQE/vgcAvVs6w0LM4WFqHqJT6ZsEQgipDZWa4XYCv/Oeg8HuJ7hUrhQhhBBSHQptp5RbjTqlNIUec+2UiteO7jlYWcBWXnZDLW8nK0hEHPKLVUjKNs/HQEhjZ5CilJWVFfbt24dBgwYhICAAc+fORbt27XD69Gmhq0mpVCIiIgL5+U/yMF599VWsWLECS5YsQfv27XH8+HEcPXoUfn5+hlhmrXEch8B6OsIXpchDiZrBVi6Bh/2TNyRbuQW6tWgCADgRTrvwEUJIbUSl5CK/WAUrqRj+2h1/DIEPO78VT2HnhBBCqqdWnVJ25t0pFZeuyboqL08KACzEIng7acb6KFeKEPNUo/G96goODsaJEycqvYyPj0+5B9LLly/H8uXLDbEsvQpyt8Wl6PR6V5SKSNaO7rnbgeN021cHBLjibGQaToYrMLc3haESYmwHbzxGZn4x/tPDu8zfJ6kfbsRlAgDaNrWH2IAjAk+HnVOuFCGEkKok16pTSnPZxCzzDDqvLE+K5+dig6iUPEQpctGnZe2iaAghhlO/wpDMiBB2nlS/ilJPh5yXNjDQFQBwMToNuUUlRl0XIY1dcnYhXtl5He/8fgcfH4qo+grELN0U8qQME3LOk0pECNTmAt6iET5CCCHVkJyt6ZRyaUCZUvEZmmKZV5NKilLazuVI6pQixCxRUaqWnuzAl1OvRifuJWmKUoEeZYtSvi42aOFsDaWK4cyDFGMvjZBG7bfrCVBrX0o2n47CxpORpl0QqRVh5z0vB4PflzDCR0UpQgghVShRqZGWpylK1aRTii9K5RWrkFOoNMja6iJOmylV0fgeAPjzYecKys0lxBxRUaqWAtxtIeKA9LxipOSUv6OgOYrQdnaVDjkvbUCApluKcqUIMR7GGPZdSwAAdNdmu31yOAI7zseYcFWkpopL1Lin7UZtb+BOKYDCzgkhhFRfam4xGAPEIg5O1tJqX89KKoGdXJP4Yo7dUtUa33N9sgMfIcT8UFGqluQWYrRw1mR43K0nuVIZecVC225543vAkxG+kxEpUKvrTwcYIfXZ3cRsRCTnQCoR4ZsZXbB4oD8A4J3f72D/9XgTr45UV0RSDopVajhYWaB5JWME+vKkKJVdrzp2CSGEGJ8iR1NQcrGRQVTDzENz3YGPMYYEYXyv4k4pXxfNZzZFThGyzbDbi5DGjopSdVB6hK8+CNeO7nk1sYSNrPyM+24tmsBaKkZKThFuP6Zv3wkxBr5LakiQG+wtLbBkSCvMCvEBACz79SaO3Eky4epIdd3Qju4Fe9obJaieDzvPKlAKuw8RQggh5eG/mHa1q36eFM9cc6WyCpTI0ebgNqukU8pObgE37eOOUlC3FCHmhopSdfCkKFU/OqXCtaN7AW7lj+4BmvBcflcKGuEjxPBKVGr8HvYYADC+oycAgOM4vDuqNSZ2agaVmmHRz9dxNjLVlMsk1SDkSRlhdA+gsHNCCCHVx3dKudpWP0+K92QHPvMqSvFfyDjbyCC3EFd6WT8+VyqFcqUIMTdUlKqDIO2HgXpTlNJ2dAWVE3Je2sAg7QgfFaUIMbh/I1ORmluEJtZS9At4sk2xSMTh44nBCG3jhmKVGvN/uILrsRkmXCmpypOd9xyMdp8Udk4IIaQ69NIplW1eXbl8yHllo3u8J0Up6pQixNxQUaoO+E6ph6l5KFSqTLyaqoUna3feqyDknNdf+8H4RnyW8K0KIcQw9mtH98a0bwoLse5LskQswpdTO6JPS2fkF6swa+tloeORmJf84hLcT+ZDzh2Mdr8Udk4IIaQ6UrTH9G4NqFMqPqPqkHOenzZXisb3CDE/VJSqA3c7ORysLKBSM0Sa+QucWs1wX5spVVHIOc/VVi6Mn5yKSDH42ghprHIKlTiszYua0Mmz3MvIJGJ8Pb0zOjV3QFaBEv/59hJiUqn13NzceZwNNQNcbWXCN8rGEFyqU4rCzgkhhFSkbp1Smk4kc8uU4sf3mjlW3Snl76r5/BNJnVKEmB0qStUBx3EI0nYdmfsOfLHp+ShQqiCTiODjVPW3CcIufDTCR4jB/H07CUUlavi72gjFhfJYSSXYOrsbgjzskJpbhOe+vYjELPNqoW/sbsRlAjDu6B5AYeeEEEKqh59+cKtFUcpcO6WejO9Vo1PKVdMpFZuWD6VKbdB1EUJqhopSdVRfws75kZ+WbjaQiKv+sfNFqX8fpKK4hF64CTGEfdfiAWgCzqvarc3e0gI/zOmGFs7WSMgswH++vYi03CJjLJNUA58n1d5IIec8qUQkdL9SrhQhhJCKCJ1StRjf4zuAswqUyC8u0eu66iI+Q/NlTHXG99zt5LCSilGiZniUlm/opRFCaoCKUnUUWE/Czu8lVi9Pite2qT2cbWTILSrB5Zh0Qy6NkEYpPiMfFx6mg+OAcR3LH917moutDD/O646m9nJEpeRhxveXkF2oNPBKSXUIO+95ORj9vinsnBBCSGVKVGrhi6zajO/ZyiSwlmp2tzOXET7G2JNMqWoEnXMcR2HnhJgpKkrVUWuhUyrHrPM8IpL4olTleVI8kYjDwEBN4PnxezTCR4i+/R72GADQo4UTPB2qPpjieTpY4sd53eFkLcWdx9mYu+0yCorNf6OFhiwrX4kY7beu7SoZwzQUPgOQws4JIYSUJy2vGGoGiDjAybrmRSmO4+Am7MBnHkWplNwiFCrV4DjAw756x1H+rpqilLlnARPS2FBRqo78XW0gFnHIKlCazYt0efjxvep2SgGlcqUiqChFiD4xxrBXO7pXUcB5ZXxdbPDD3G6wlUtwOSYDL/x4lcZsTehmQiYAoHkTKzhaS41+/xR2TgghpDIK7eiei60MYlHlcQEV4XOlzKVTis9R9LCTQyqp3kdaYQc+6pQixKxQUaqO5BZi4QXOXEf48otL8Chd8y0+P25YHb1busBCzCE6NQ8P6cWbEL25GZ+Fhyl5kFuIMDzYo1a30aapPbbO6gpLCzFO30/Bkl1hUKmpIGEKfJ5UOyPnSfFKh53z+RqEEEIILzmbDzmv/e6w7naabiRzCTvnR/eaVSPknPdkfI92MSbEnFBRSg+CSo3wmaP7yblgDHC2kcLZpvotuzYyCbq3cAIAnKBd+AjRGz7gPLSNO2xkklrfThefJvh6emdIxSL8eSsR/913izplTIDfea+9kXfe41HYOSGEkMoocviQ85qP7vHMrVOK/xKmmWP1IxD8tON7DxW5dLxEiBmhopQe8EWpu2baKRWeWPPRPd4AGuEjRK+KS9Q4eDMRADChU7M6317fVi74cmoHiDhg15U4fPDnPTrQMjJTd0oBFHZOCCGkYnynlGtdOqW0RSlz6ZSK006BVGfnPZ63kxXEIg45RSVCoY4QYnpUlNIDPjzcXMf3wmsYcl4anyt1KTodObTLFyF1dvp+CtLziuFiK0MvPye93Oawth74eGI7AMB3Z6Kx/kSkXm6XVE2RXYik7EKIuCeFIVMQcqXiqShFCCFEl147pbLNY0yc75TyqsH4nkwiRnPt5aMo7JwQs0FFKT3gd+CLSc0zy12w+JDzgFoUpVo4W8PX2RpKFcOZB6n6XhohjQ4/ujeuQ1NIxPp7CX6mixdWjm4NAPjs6H18fyZab7dNKnZDWwTyd7WBdR1GMeuKws4JIYRURKGPTCkzG9+L4zOlajC+B1DYOSHmiIpSeuBiK4OTtRRqBtxPNq9cKcYYIrSdUvyYYU3xI3yUK0VI3WTlK3H8nubvSB+je0+b3asFlg5pBQB474+7+PVKnN7vg+i6FZ8JAGhnojwpXit3G1iIOQo7J4QQUkZyjnZ8r06dUpriT2puMYpKTPslvErN8Diz5p1SAIWdE2KOqCilBxzHlQo7N68RPkVOETLylRBxmm/ya2OgkCuVAjXt7kVIrf1x6zGKVWoEutvWukhclZcH+mNe7xYAgDf33sSh24kGuR+iwXdKtTdhnhSgGUngcwMpV4oQQkhpimzN+F5dOqUcrSwglYh0bs9UkrMLoVQxWIg5uNfwMfFh55E0vkeI2aCilJ4EeZhnrhSfJ9XC2RpyC3GtbqOrTxPYyCRIzS2iDzuE1MH+awkAgIkG6JLicRyHt0YG4dkuXlAz4OVfruOf+ykGu7/GjDGGm9pOqWATd0oBFHZOCCGkLJWaITW37plSHMcJuVKmDjvnQ86bOlhCLOJqdN0nnVJUlCLEXFBRSk+edEqZ1/iesPNeHboypBIR+rR0BkAjfITU1qO0PFx5lAERB4zt0NSg98VxHP5vQjBGBntAqWJYsOMqrsSkG/Q+G6P4jAJk5CthIeaELyZMic+Vuk1FKUIIIVppuUVQM0DEAU42tS9KARC6khKzTDsmHqcdU69pnhTwJFMqMasQuUUlel0XIaR2qCilJ/zYxL2kbLMKmRV23nOr2wemgZQrRUid7NN2SfVu6VKnLZmrSyzisO7ZDujXygUFShVmb7uMO4+pWKFPN7RdUoHudpBJateJqk8Udk4IIeRpydpRO2cbWY27ip7Gh50nZ5tHp5SXY83ypADAwUoKZxspACCacqUIMQtUlNITf1dNyGxOYYlZhcwKRak65tf0D9AUpW4lZAk7eBBCqocxhv3XNUWpCR09jXa/UokIm//TGd18miCnsAQzvruEh9Surjc3tXlS7UycJ8Xjw84z8ynsnBBCiIYip+477/HczWR8j3+Pq2nIOY8f4YtMMa8JF0IaKypK6YlUIhJe4PhCkKkpVWpEKrRFKfe6dUq52MqEIN9TEZRPQ0hNXH2Ugdj0fFhLxRjaxs2o920pFePbWV3Q1tMOaXnF+M+3F5GQSQULfbgRlwkAaG8GeVKAJuw8QPtaTyN8hBBCgCedUnXJk+J5aAtbSabOlMrQdErVZnwPeBJ2HqWgTilCzAEVpfSotZntwBedmgelisFGJoGnQ+1etEsbGKj5MH08PLnOt0VIY7JP2yU1PNgDVlKJ0e/fTm6B7bO7wc/FGo+zCvGfby8K35yS2lGpmVD4aedlHp1SwJMRvptUlCKE1IIiu5B2Wm5g+Pd7fUQHuNtrPk+YvFMqnS9K1a1TisLOCTEPVJTSoyAzK0rx62jlZgNRHWfIgSe5UmcepKKoRFXn2yOkMShUqvDHjccAjDu69zQnGxl+nNcdng6WiE7Nw6RN5xGdSt8Q1tbDlFzkFatgaSGGv/bg1hy0pbBzQkgtnY1MRbf/O47/7r9l6qUQPdJrp5S96TulikvUSNJGiXg1qWWnlDbsnIpShJgHKkrpkbkVpfSVJ8Vr09QOLrYy5BWrcDk6Qy+3SUhDdzJcgezCEjS1l6OHr5NJ1+Jhb4lf5veAt5MVYtPzMXHTOYRpR9BIzdzQ5km19bSDRGw+b6UUdk4Iqa2D2i9Qdl6Ow614Kmw3FCl6zJTii1KKnEKUqNR1vr3aSMwqgJoBMokILrXcTdBfO74XnZpnssdBCHnCfI6kG4BA7Zbgj9LzkWcGW4xGaItSQXXMk+KJRBwGagPPaYSPkOrZq911b2xHT710LNZVcycr7F0YgmBPe6TnFWPqNxdwMoJ21aypm9qd99qZSZ4UL8DdlsLOCSG1cjYqVfj/a/6+R4XtBkKfnVJONjJIRBzUDEjJLarz7dVGXLrmva2ZoyU4rnbHVU3tLSG3EEGpYoij90pCTI6KUnrkbCODi60MjAERyaYPOw/XdmwFuOunUwoABmhH+E6G04dYQqqSlluEU9qCjylH957mbCPDzud7oG8rFxQoVZi3/Qp+vRJn6mXVKzfMbOc9HoWdE0JqIzYtH3HpBZCIOEjFIpyLSsPp+7SxTUOgz933xCJOuB1T5UrxIee13XkP0HzR7uvMh53TCB8hpkZFKT0zlxG+rAIlHmvfLAL01CkFAL1bOsNCzCEmLZ+2liekCn/cTESJmqFdM3u0dNPf36E+WMsk+G5mF0zo5AmVmuH1PTex8WQkfTNeDcUlatx7rHmNN5ed90orPcJHCCHVwXdJdWruiBk9vQEAH/0dDhWFntdrKjVDSo62U8qu7p1SAOBu4lypeL4oVcuQc56wAx99niHE5KgopWdB2hE+Uxel+NG9pvZy2Fta6O12bWQSIRfnBHVLEVKpfdfiAQDjzahLqjQLsQifPtMeC/v7AQA+ORyBd3+/Qx9CqhCRlINilRr2lhbwdqrbQbEhtKWiFCGkhs5EaopSvfyd8dIAf9jKJQhPysF+7e6xpH5KyyuCmgEiDnCylurlNvmilMk6pUqN79UFv0lJJHVKEWJyVJTSs9ZCp5Rpx/fCkzRFMX2FnJc2QJsrRUUpQioWqcjFjfgsSEQcRrdvaurlVIjjOLw5LBCrRrcGxwE7LjzCSz9dQ6GSdtisyA0hT8q+1nkWhhRcagc+6nwjhFRFrWY4H5UGAOjl7wRHayleGuAPAPjsSAS9H9RjCm2elLONTG+bcnhox/eSs+vv+B4A+LnSDnyEmAsqSulZoDa/KTwxG2oTdhsIO+/pcXSPN1CbK3UpOh05hUq93z4hDcH+65ouqX6tXOBcy91hjGlWrxbYMLUTpGIRDt1JwozvLiErn/6+y3OzVFHKHPFh5xn5SiRkUoArIaRy95KykZ5XDGupGO29HAAAs0J84GEvx+OsQmw7F2PS9ZHa4/Ok9DW6B5i+U4rfxKPO43su/PheHn2BQ4iJUVFKz3xdrCEVi5BXrBIq+abwJORc/0UpH2dr+LpYo0TN8O+D1KqvQEgjo1Yz/HZds7X2hE7NTLya6hvZzgPb53SDrVyCSzHpeObrc0jMoqLG024KIecOpl1IBWQSMVppM8xoW3dCSFXORWq6pLr7OsFC200jtxDjtaEBAICNJyORkVdssvWR2uN33nOzrXvIOe9JppTxjw8KlSohI6uu43stnK3BcZoc3jT6/SbEpKgopWcWYhFaumkq76Ya4VOrGe4na1pRgwwwvgcAA2mEj5AKXYxOR0JmAWzlEgwKcjX1cmqkp58Tfn2hJ9zsZLifnIsJX53DfTPYTdRcFBSr8ECbP2GOIec8CjsnhFRX6Typ0sZ39ESguy1yCkuw8WSkKZZG6ogfsdNnp5SHCTul+JBzG5kEDlZ1y8yVW4iFbivKlSLEtKgoZQCm3oEvIbMAuUUlsBBzaOFsbZD74Ef4TkUoTDqmSIg54gPOR7XzgNxCbOLV1Fygux32vdgL/q42SMwqxKRN53ApOt3UyzILdx5nQaVmcLGVwU2PB/n6FtyMilKEkKoVl6iF1/de/k4654lFHJYPDwQA/HD+EeLSTTcBQGpHwe+8p9dOKU2HUnJ2odE/A5QOOddHpqOfC+VKEWIOqChlAKYuSvH36+9qK7Rh61sXnyawlUmQmluMm/ShhxBBQbEKf99OAlC/Rvee5ulgiT0v9ERnb0dkF5bgP99dxKHbiaZelsnd0I7DtTfTkHMehZ0TQqrjemwGCpQqONtIEeBWNvKhXysX9PJ3QrFKjU+PRJhghaQuFAbolHK1lYHjAKWKGX3sLV5PIec8IVdKkaeX2yOE1A4VpQwgyEPzpn4vyTRFqQhtyHmQAfKkeFKJCH1aadq8aYSPkCeO3E1CblEJvJpYoou3o6mXUycOVlL8NK87hrZ2Q3GJGgt/uoYd52NMvSyTehJy7mDSdVSFws4JIdVxVrvrXoifc7mFdo7jsHxYEADgt7DHuE1fRNYrfKeUPjOlLMQiuGg3cEky8ghfXMaTTil98HPVFKUiqVOKEJOiopQBBGl34ItLLzDJ7nT8znuGCDkvbWCgGwDgRHiyQe+HkPpk//UEAMD4js3MupOmuuQWYmz6T2dM694cjAHv/H4HnxwOb7TdN09Czs1z5z1e6bBz+hBJCKnIWW2eVO+n8qRKC25mj7EdmgIA1vx9r9G+/tdHhsiUAkrnShn3Sw9+hLSuO+/x/F35TikqShFiSlSUMgBHaync7TQv1nyByJjCtR1agQYKOef1D3ABxwG3E7KF9mBCGjNFTiH+uZ8CQBMQ21CIRRw+HNcWS4e0AgBsPBmFN/bchFKlNvHKjCurQInoVE2Lv7l3SgEUdk4IqVxOoRJhcZkAgJCn8qSetmxoAKRiEc5GpuEf2nm5XlCpGVJzNeN1bnb665QCSu3AZ+Tj/3htp5S+x/cSMgtQUKzSy20SQmqOilIGwo/whRs5V6pQqRI+NAUauFPK2UYmfDA7GUEjfIQcCHsMNQM6NXcw2CYDpsJxHBYPaomPJgRDLOLw69V4PP/DFeQXl5h6aUZzS9sl5dXEEk2spSZeTdXaCkUp04ySE0LM26XodKjUDD5OVmhWReeJVxMrTO/pDQD46O9wqGiTG7OXllcElZqB4wAnPb9neWjDzo29A1+cNlNKX+N7TaylcNTu4vcwlbqlCDEVgxWlxowZg+bNm0Mul8PDwwPTp0/H48ePK71OUlISpk+fDnd3d1hbW6NTp07Yu3evoZZoUHzY+d1E43ZKPUjOhZoBjlYWcLU1/M5Qg7S78B2/R0UpYhq3E7Lw88VYlJhB186+a9rRvXoccF6VKd2a45vpnSG3EOFkRAqmbrmItNwiUy/LKG7UkzwpntApFZ9J4zaEkDLOaEf3QioZ3Stt0QB/2MoluJeYjd+0o+rEfCmyNe/NzjYySPS88ZHQKWXEolROoRKZ+ZpYFH11SgFPuqUiaYSPEJMxWFFqwIAB2L17NyIiIrB3715ERUVh0qRJlV5nxowZiIiIwIEDB3Dr1i1MmDABkydPxvXr1w21TIMx1Q58wuieu51R8mwGaotSZyJTUVRCba/EuFRqhgU7ruK/+2/h1V1hJi1MhSdl425iNizEHEa38zDZOoxhUJAbfp7fA45WFrgRl4lJm88jNq3hbxXOh5y3N/M8KV6Auy0kIgo7J4SU71ykJuS8sjyp0hytpXixvz8A4NMjEShU0nGfOVPkaPOkDPAltYcJilJx6Zr3MUcrC9jIJHq7XSFXKoV24CPEVAxWlFqyZAl69OgBb29vhISEYPny5bhw4QKUyoqDv8+dO4eXX34Z3bp1g6+vL95++204ODjg6tWrhlqmwfBFqYikHKO2OBsr5JzXpqkdXG1lyC9W4VJ0ulHukxDe+ag04cP2HzcTTVqY2q/tkhoU6AYHK/Mf7aqrTs0dsWdhCDwdLBGdmocJm841+EDtJyHnDqZdSDXJLSjsnBBSPkVOISKSc8BxQE/fyvOkSpvdywce9nI8zirE9nMxhlsgqTO+U0rfeVKlb9OYmVLx2tE9fXZJAU86paJoBz5CTMYomVLp6en46aefEBISAgsLiwovFxISgl27diE9PR1qtRo7d+5EYWEh+vfvb4xl6pWPkxVkEhEKlCo8SjNe5T1CW5TiM60MjeM4oVuKRviIse25GgcA6ODlAAsxZ7LClErN8FsYP7rXcALOq+LnYoP9L4YgyMMOqblFePbr8/j3QYqpl2UQipxCJGYVguOeZDXVB/wugRR2Tggp7XyUpkuqTVM7ONYgb0huIS616UUkMvOLDbI+UnfJ2qKUITulErMKjDYeHqcNOddXnhTPz1WTAUo78BFiOgYtSr355puwtraGk5MTYmNj8fvvv1d6+d27d0OpVMLJyQkymQwLFizA/v374e/vX+F1ioqKkJ2drfPPHEjEIqFb6Z4Rc6X48b0Ad8PuvFfaAG1R6mSEgnJLiNHkFCpx6E4SAGDVmDbYOK2TyQpT56JSkZxdBAcrCwwIcDXa/ZoDVzs5di/ogRA/J+QVqzB76+UGmTVyM05T1PF3sdHr2IChUdg5IaQ8Z7Q76PXyq97oXmkTOjVDoLstsgtLsPFkpL6XRvREGN8zYKdUoVKNrIKKp2D0KS5d2ylVRSh/TfGdUg9T8yjAnxATqVFRavny5eA4rtJ/4eHhwuVff/11XL9+HUeOHIFYLMaMGTMqLVq88847yMzMxLFjx3DlyhUsXboUkydPxq1btyq8zpo1a2Bvby/88/LyqslDMqggbWGILxQZWkpOEVJzi8FxQCs3G6PcJ6DJIpCKRXiUlo+HqTSPTYzjr1uJKFSq4edijfbN7DG0jbvJClN8wPnodk0hlTS+TU1t5RbYOrsrRrdvihI1w6u7wvDNP1ENqkh9s56FnPP4sPPbCVkN6udRn2zcuBE+Pj6Qy+Xo3r07Ll26VK3r7dy5ExzHYdy4cWXOu3fvHsaMGQN7e3tYW1uja9euiI2NBaDpTn/55ZcREBAAS0tLNG/eHIsXL0ZWFnXLEQ3GGM5qQ857VTNPqjSxiMObwwMBANvPPRKKBcS8GLJTSm4hFnahNdYOfPF8p5Sex/eaOVpBKhGhuESNhAzKXyTEFGr06em1117DvXv3Kv3n6+srXN7Z2RmtWrXCkCFDsHPnTvz111+4cOFCubcdFRWFDRs24Pvvv8egQYPQvn17rFy5El26dMHGjRsrXNOKFSuQlZUl/IuLi6vJQzIofoTOWGHn/Oiej5M1rKTG+ybfWiZBd98mAIATNMJHjGTP1XgAwKTOXkKo/9OFqSW7bxi8MJVXVIJDtzUdWxMa0eje02QSMb54tgPm9W4BAPi/v8Lx0aHwKq5Vf9zQ5km196o/o3vAk7Dz9LxiPDby1t0E2LVrF5YuXYqVK1fi2rVraN++PUJDQ6FQVP5eGRMTg2XLlqFPnz5lzouKikLv3r0RGBiIU6dO4ebNm3jnnXcgl2s6Fx4/fozHjx9j7dq1uH37NrZt24ZDhw5h7ty5BnmMpP6JScvH46xCSMUidPVpUqvb6N/KBSF+TihWqfHZ0ft6XiHRhxRtp5QhMqUAwN3OuGHnfKaUvsf3xCIOvs7aET7KlSLEJGpUlHJxcUFgYGCl/6TS8ufS1WrNB8OiovK3Ds/P17zQiES6SxKLxcJ1yyOTyWBnZ6fzz1w82YHPOON7wuiem3HypErjc6VOhFNRihheTGoeLsdkQMQB4zvqFoJKF6YO3nhs8MLUodtJKFCq4OtsjQ5eDga7n/pAJOLw9qjWeGtEEADg69MPseN8jGkXpQeMsXrbKVU67PxWPHXKGNtnn32G+fPnY/bs2WjdujU2b94MKysrfP/99xVeR6VS4bnnnsPq1at1vujjvfXWWxgxYgT+97//oWPHjvDz88OYMWPg6qp5H27bti327t2L0aNHw8/PDwMHDsSHH36IgwcPoqSkxGCPldQffJdUJ28HWErFtboNjuOwYrjmtX7/9QTaTMEMGbJTCiidK2X4ohRjzGDjewCFnRNiagaZM7l48SI2bNiAsLAwPHr0CCdOnMDUqVPh5+eHnj17AgASEhIQGBgotLEHBgbC398fCxYswKVLlxAVFYVPP/0UR48eLbd1vT4I1BalEjILkJVv+Hlrfue9QCOFnJfGF6Uux6Qju9A4s+Wk8dp7TdMl1aelC9zty34DaMzC1L7rmrWM7+gpdGw1dvP7+uKNYQEAgFUH7wrZJfVVfEYBMvKVsBBzRttEQp9Kj/AR4ykuLsbVq1cxePBg4TSRSITBgwfj/PnzFV7vvffeg6ura7mdTWq1Gn/++SdatWqF0NBQuLq6onv37vjtt98qXUtWVhbs7OwgkdSfPDRiOHxRqnctRvdKC25mjzHtmwIAPm5AnbENgVrNkJJruN33AAjHX0lZhh95y8xXIq9YBUD/nVIA4Oei6ZSKpLBzQkzCIEUpKysr7Nu3D4MGDUJAQADmzp2Ldu3a4fTp05DJNNV6pVKJiIgIoUPKwsICf/31F1xcXDB69Gi0a9cOP/zwA7Zv344RI0YYYpkGZ29pAU8HzQvnPSPkSvGdUoHuxv/Q5O1kDT8Xa5SoGf69X78/gBLzplYzIcNpUudmFV7OGIWpxKwCnNPuYDSuY+Md3SvPwn5+mNDREyo1w4s/Xa3X3z7e1HYYBbrbQSapXVeBKbXV7sB3k4pSRpWamgqVSgU3Nzed093c3JCUlFTudc6cOYPvvvsOW7ZsKfd8hUKB3NxcfPTRRxg2bBiOHDmC8ePHY8KECTh9+nSF63j//ffx/PPPV7pec904huiXSs1w/qHmfSukjkUpAHg9NAAWYg7/PkjFP/cb5u6r9VFaXjFUagaOA5xtqr+7Yk0Ys1MqTju652org9xC/+/Dfq7UKUWIKRnkK7Pg4GCcOHGi0sv4+PiUCV1t2bIl9u7da4glmUyQhy0SMgtwLzEbPXydDHY/JSo1HiRrXkgDjbjzXmkDA10RlRKNE+EKjGznYZI1kIbvwsM0JGQWwFYuwZDWbpVeli9MvfjTNRy88RgAsG5ye0jE+qnH/3b9MRgDurVoAi89B2/WdxzH4f8mBCMmLQ/XYjMxb/sV7H8xBA5Whjk4NqQno3v1K0+K93TYOXX0maecnBxMnz4dW7ZsgbNz+cUCPs5g7NixWLJkCQCgQ4cOOHfuHDZv3ox+/frpXD47OxsjR45E69atsWrVqkrvf82aNVi9enXdHwgxa3cfZyMzXwlbmQTtPOv+mubVxArTe/jg+7PRWPN3OHr7O0MkotcYU+N33nOyluntmOdp7vaaL96Tso1QlErXhpwboEsKKD2+Rxs2EWIKjW+bKCPjc6XCDZwrFZOWj6ISNSwtxGhuog/HA7QjfKciFFDTlqrEQPiA89Htm1br27Khbdzx1XOdIBFpOqaW6qljijGGfdoxwomNOOC8MnILMb6e3gWeDpaITs3DSz9fg9JIOyLq0416XpQKpLBzk3B2doZYLEZycrLO6cnJyXB3dy9z+aioKMTExGD06NGQSCSQSCT44YcfcODAAUgkEkRFRcHZ2RkSiQStW7fWuW5QUJCw+x4vJycHw4YNg62tLfbv3w8LC4tK12vOG8cQ/Tkbpelm7+7rpLdixcsD/WErl+BeYjZ+C0vQy22SulFk86N7hsmTAp50Shkj6JzvlDLUF4C+2vG99LxipOcVG+Q+CCEVo6KUgQlh5wYe3+NH91q525rsG6quPk1gK5MgLa9Y+BBHiD7lFCrx1+1EAJWP7j2tdGHqgJ4KU3ceZ+OBIhcyiQjDg6kzsCIutjJ8O7MLrKRinI1Mw+qDd0y9pBpRqxluJ2heX+tbyDmPws5NQyqVonPnzjh+/LhwmlqtxvHjx4V8zdICAwNx69YthIWFCf/GjBmDAQMGICwsDF5eXpBKpejatSsiIiJ0rnv//n14e3sL/52dnY2hQ4dCKpXiwIEDws58lTHnjWOI/jzJk9Jf976jtRQL+/sBAD49ch+FSpXebpvUTrK2e8lQIedA6Uwpwxel+J33DBFyDgBWUokQuUIjfIQYHxWlDIwvSkUk5Rh0B7AIbch5kAnypHgWYhH6tnIBAJykXfiIAfx9KwmFSjV8XazRsYY73em7MMWHrQ9p7QY7eeUdCI1dkIcdvpjSERwH/HghFj/Uox35HqbmIreoBHILEVpqMyfqIwo7N42lS5diy5Yt2L59O+7du4eFCxciLy8Ps2fPBgDMmDEDK1asAADI5XK0bdtW55+DgwNsbW3Rtm1bYXfj119/Hbt27cKWLVsQGRmJDRs24ODBg3jxxRcBPClI5eXl4bvvvkN2djaSkpKQlJQElYqKBY1ZoVKFyzHpAIBeesiTKm1OrxbwsJcjIbOgXr3GN1SKHMOGnAOAu/a2c4pKkGPgTY4MPb4HlMqVorBzQoyOilIG5t3ECpYWYhSVqBGTZrg55Xva8cAAExalgCe78B2nohQxAH50b1LnZrXKxdFXYUqpUgsZVRM7Vb9jqzEb0toNb4QGAgBWH7yLfx+YfyBuel4xXvv1JgCgfTMHg+VyGAMfdn6LilJG9eyzz2Lt2rV499130aFDB4SFheHQoUNC+HlsbCwSExNrdJvjx4/H5s2b8b///Q/BwcH49ttvsXfvXvTu3RsAcO3aNVy8eBG3bt2Cv78/PDw8hH80kte4XYvNQKFSDVdbGfz1XGSXW4ixZEgrAMCGE5HIzKcRKFMyRqeUtUwCW7lE5/4MxdDje8CTHfioU4oQ46u/R9j1hEjECYWiuwbMlYpI5nfeM227ff8AF3CcZrTJ0G9QpHF5lJaHSzHpEHHAhI61LwTpozD174MUpOYWw9lGij4t9fttc0P2Qj9fTOjE78h3zay3Xn6cWYBnNp/DjbhMOFhZ4O2Rrau+khl7OuycGM+iRYvw6NEjFBUV4eLFi+jevbtw3qlTp7Bt27YKr7tt2zb89ttvZU6fM2cOHjx4gIKCAoSFhWHs2LHCef379wdjrNx/Pj4+enxkpL45F6nZda+Xv7NBNjyY2KkZAtxskV1Ygq9ORen99kn18Z1SrgbslAKMswMfYwwJGZpOKUON7wEUdk6IKVFRygiEXKlEw+RK5RQqhbbWQBN3SjnZyNBem7tCI3xEn/Ze04Sn9vJ3FnIMamtoG3dsLFWYeu3XmhWm9mnXMqa9Z73unjE2juOwZkIwung7IqewBPO2XzbLb9MjFTmYuOkcolLy4GEvx54XeiK4noac8/iw87S8YqNs300IMT9ntHlS+h7d44lFHJaP0HTEbjsbI+QAEeNTGKFTCniyA58h31dScopQVKKGiAM8HAxXZOOLUub8hRkhDRV9mjKC1h6aQlG4gYpS95M1HVhudjI4Wpt+u/VBNMJH9EytZthbanRPH0JLFaZ+D6t+YSqrQIkjdzW7aU2gXfdqTCYRY/P0zvB0sERMWj5e/Mm8duS7HpuBSZvPIzGrEH4u1tizMAT+rqYt9uuD3EKMlnzYOY3wkUaEOgM1sgqUuKndhKaXHkPOn9a/lQt6+jqhWKXGZ0fuG+x+SOWMkSkFAB52hg8750f3POwtYWHALwL5kda4jHwK6yfEyKgoZQRPOqUMM74Xrg05N/XoHm+Atih1NjIVRSX0ok7q7kJ0GhIyC2ArkyC0Tdmt1GurNoWpv28lorhEjVZuNmjT1Dz+5uobZxvNjnzWUjHORaVh1YE7ZvHB8fT9FEzbchGZ+Uq093LAry+ECLvxNATBnprf18vR6SZeCSGGxxjD2sMR6PLBMVx9lGHq5ZjcxYdpUDPA18UaHvaGe13jOA4rtN1S+8MScOcxFcGNTa1mSBHG9wzdKWX48T1jhJwDgLONFHZyCRiDQXOACSFlUVHKCAK1Ramk7EJk5Ol/VCU8kS9Kmce3+W2a2sHNTob8YhUuPqQPP6Tu+IDzUe2bQm4h1utt17Qwte+6ZnRvQqfaha0TjdI78v10MRY/nH9k0vX8HpaAedsvo0CpQp+Wzvh5Xnc0MYPOU30aHKQJ1/7lUixSc4tMvBpCDIcxhvf+uIsNJyORlleMQ7drFibfEJ3lR/f8DJ+D2K6ZA0a3bwrGgI/+Djf4/RFd6fnFKFEzcJzmSyBD4jOlkrIKDHYf8UYIOQc0BdUnO/BRUYoQY6KilBHYyCRorn0hNUSuVATfKeVhHkUpjuOEXfhO0AgfqaO8ohIcup0EQH+je0+rbmEqLj0fl6LTwXHA2A5NDbKWxmRwazcsH8bvyHcH/9w3zY5828/F4NVdYVCqGEa188B3M7vCWiYxyVoMaUhrN7RvZo+8YhU2now09XIIMQi1muGd329j69kY4TRDdarXJ2ejnoScG8PrQwNgIebw74PUerHbakPCbzTkZC016Lgb0LA6pQDKlSLEVKgoZSSBwg58+i1KMcZwL8k8dt4rbUAAnyuVbBZjOaT++utWIvKLVWjhbI1OzR0Mdj/lFaZUat3f3d+0XVK9/JwNOv7QmDzf1xcTOzWDmgEv/WzcHfkYY/js6H2sPHAHjAEze3rjyykdIZU0zLdGjuPweqimCPjThVgKISYNjlrN8N/9t/DjhVhwHDCjpzcAIDzJMJme9UVSViEiFbkQcUBPX8PlSZXW3MkK/+mhef7X/BUOtZqOBY1F2HnP1rB5UgCEY6EkA+64zWdKGXLnPR6fKxWVQkUpQoypYR55myE+V4rPf9KXx1mFyCksgUTECdV9c9DL3xlSsQhx6QX0wk7qZE+pgHNDj8uFtnHHhmlPClNLd4cJhSnGmDC6N74jBZzrC8dx+L8JbYUd+eZuv2yQMeenqdQMb/92G18efwAAWDqkFVaNaQORqGGPZPZu6YwQP00I8RfHHph6OYTojUrN8Pqem9h5OQ4iDvj0mfZYMTwIHAek5hZDkdN4d508F6UZ3Qv2tIe9lYXR7vflgS1hK5PgbmI2fr+RYLT7beyEnfcMnCcFPOmUysxXGiwcPD5D0yll6PE94EmnFH12IcS4qChlJE/CzvX7bV2E9ts/Pxcbs/p231omQQ8/zbdxNMJHais2LR8XteNyxioEDWtbfmEqLC4T0al5sLQQY1hb/YWtE90d+R6l5WPhT1cNuiNfUYkKL/9yDT9d1HRTfDCuLRYPatloMsJeDw0AAOy9Fo9IBY01kfqvRKXG0t1h2HstHmIRh3XPdsCETs1gKRWjhZM1gCf5m43RGW2eVIiRRvd4TayleKG/HwBg7eH7tKOZkSiytTvvGaFTyk4ugZVUk/VpiB34VGqGx5nGHN/TvF48TMmj7j5CjMh8qhgNXGttUepBcq5eP2zxOQkBZhJyXtrAABcAVJQitbf3mqZLqre/M5oacRe08gpTv2o7toa1dW+QeUOm5mwjw3ezNDvyXXiYjnd/N8yOfLlFJZi99TL+upUEqViEjdM6CSMmjUXH5o4Y2toNagZ8Slu2k3pOqVLjlV1h+D3sMSQiDuundsTYDk++xDDUl4L1BWMM5yI1eVK9jVyUAoA5vVrA3U6OhMwC7DDxhhaNRXKO8TqlOI6Du53hcqUSswpQomawEHNwszN8ka15EytYiDkUKFV4bMDwdkKILipKGUkzR0vYyCQoVqnxMEV/OzqYW8h5aQMDNTs9XY7JQFaB0sSrIfWNWs2EopShAs4r83Rh6ueLsQCACZ1odM9QAt3t8OVUzY58v1yKxbZzMXq9/dTcIkz95gLORaXBWirG1tldMSLYQ6/3UV8sCw0AxwF/307CzfhMUy+HkFopLlFj0c/X8OfNRFiIOWx8rlOZv2k+01Pf8Qn1RVRKHpKyCyGViNDZ29Ho928pFWPp0FYAgA0nI5GVT8eDhsZ3SrkaoYgDPBnhS8rWfxGHDzn3dLCE2Ajj9RKxCD7a7sooPX5eI4RUjopSRiISccKBkT6/rePDO4PMKOSc19zJCv6uNlCpGe28QmrsUkw64jMKYCOTYGhr04zLlS5MAYCbnQwhRthOuzEbFOSGFcM1Ydzv/3EXpyL002kZl56PZzafx62ELDSxluKX53sYbRcqc9TKzVYYif3kcISJV0NIzRWVqPDiT1dx+E4ypGIRvp7eGaFtyr5XNPZOKT5PqquPI+QWYpOsYWKnZghws0VWgRJfnaKdPw0tWQg6N3ynFGDYHfj4DTmMkSfFE3KlaAc+QoyGilJGxHcz6evAqKhEJVTxzXF8DwAGBmp24fvkcATW/H0PR+4kITW3yMSrIvUBH3A+qp0HLKWmOZAGNIWpjc91gqutDC8N8DfKN3WN3fw+vnims2ZHvpd/vl7n3KOIpBxM2nwO0al58HSwxJ4XeqJdMwf9LLYeWzK4lbBl+zlt5gwh9UGhUoXnf7iKY/cUkElE2DKzi9Cd/bSgppqiVKQiF8UlhsuqM1dnHmjzpEz4hYpYxGG59suGredikJBJY1GGlKINOjfGuBsAePCdUgYoSsVlGC9PiufnyndKUVGKEGOhopQRCd/W6amFPEqRB5WawU4uEd4QzM2odh4QccCjtHx8ffohnt9xFV0+OIZ+n5zE0l1h+PHCI9x9nC3scEYIAOQVleCvW4kATDO697TQNu649NZgzOjpY+qlNAocx+GD8W3R1ccROUUlmLv9Sq135LsSk45nNp9DcnYRWrnZYO/CEPia0U6lpuTVxArTujUHAHx8OMIgGV6E6FtBsQrztl/B6fspkFuI8P2srujXyqXCyze1l8NOLkGJmiGykXU+qNQM5x+aLk+qtP4BLujh2wTFJWp8eoS6Mw1FrWZQGL1TSlMwMkinVLqmU6qZo/E7pRrb6wUhpkRFKSPSdws5P7oX6G5ntrtGtWvmgNOvD8Ank9phajcvtHLTvNA/SsvHvusJePu32xjx5b9ov/oI/vPtRXx29D5ORSgog6qR+/t2EvKLVfBxsjJJBgYxPZlEjM3/6Yxmjk925Ktpl8OJ8GT857uLyC4sQWdvR+xe0FMYMyAaiwa2hJVUjBtxmThyN9nUyyGkUnlFJZi97RLORKbCSirGttndqhzD5TgOgdrjL/64qbG4lZCFnMIS2MklaOtpb9K1cByHFcODAAD7ryfg7uPG9bMwloz8YpRov+h1MVJRysPOcJ1S8dpOKWOO7/m7asf3KFOKEKOhLaSMKNDdFhwHpOQUITW3CM42dXuzMOeQ89K8mljBq4kVnuniBQDIKlAiLC4TVx9l4NqjDFyPzUBuUQnORKYK2xZzHNDS1QadvR3RsbkjOns7wtfZ2myLb0S/9lyNA6DpkqKfeePlZCPDdzO7YsJXZ3HhYTpWHriN/xsfXK3fib1X4/HG3ptQqRkGBrpi47ROJh0DNVcutjLM6dUCG05GYu3hCAwOcqMRVWKWNDtnXsLlmAzYyCTYNrsruvg0qdZ1g9xtcSk6vdHlSp3VHlP19HMyi7/r9l4OGNXOA3/cTMSav+9hx9zupl5Sg5OsDTl3tpHCQmyc3gNDZkrFZfCdUsYb3+O7qVNzi5CVr4S9lYXR7puQxoqKUkZkJZXAx8ka0al5uJeYjT4tK243rw5+DDDQDEPOK2NvaYF+rVyEdnuVmuF+co5QpLoam4FHafm4n5yL+8m5+OWSpkDhaGWBTs0d0cnbEZ2aO6K9lz2spPQr3NDEpefjwsN0cBwwvpPpR/eIaQW422L9tI6Yu/0KfrkUh5autpjTu0Wl1/n234f44M97AIAJHT3x8aR2Rjs4r4/m9/XFjguP8ECRi9+uJ2CiGYzMElJadqESM7+/hOuxmbCVS7B9Tjd0al79LtogoVOqce3AxxelTD26V9rroQH461Yi/n2QiuTsQqPlHjUWyTmawpCLrfGeVz5CJDW3CMUlakgl+nm/LSpRIUmbj+VlxPE9G5kE7nZyJGUXIio1t0avNYSQ2qFP9EYW5GGrt6JUuPYbP3MNOa8usYhDkIcdgjzs8J8e3gA0b2x8geraowzcjM9CRr4Sx8MVOB6uKHU9W3TWFqp6+joZbftbYjj7riUAAEL8nODpYLxvxoj5Ghjohv8OD8KHf93DB3/eha+LNfoHuJa5HGMMHx+KwObTUQCAub1b4K0RQRCZQYeAObO3tMDC/n746O9wrDt2H6PbN9XbhwpC6iorX4kZ31/Ejfgs2FtaYMfcbjXeqCCwEe7AV6hU4cqjDABAiBkVpbydrOFuJ8fjrEI8ziygopSepWg7pdzsjDO6BwBNrKWQikUoVqmRnF2ot1G7xMxCMAZYWojhbCPVy21Wl5+rNZKyCxGpoKIUIcZARSkjC3S3w1+3knAvsW7f1qXnFQtBhvW9KFUeZxsZhrZxx1Dt9s7FJWrcTcx+0k31KANJ2YW4nZCN2wnZ2H7+ESzEHF4e2BIL+/tRV0Q9pVYz7Ln2ZHSPEN68Pi3wQJGD3Vfi8fLP17HvxRC0dHvy2leiUuO/+29h9xXNro1vDgvEC/18afyzmmb29MH3Z6IRn1GAXy7FYmaIj6mXRAgy8orxn+8u4s7jbDhaWeDHed3RpmnNs5EC3DTxCam5xUjJKTJa1o4pXYnJQHGJGh72cvg6W5t6OTrc7DVFKX7UjOhPsrazyFgh54AmL8zdXo7Y9Hy9FqVKj+4Z+73c38UGZyPTaAc+QoyEPrkbmb7CzvmwTq8mlrCRNfzaolQiQgcvB8zt3QIbn+uEC/8dhHPLB2L91I6YFeKD1h52UKoYPjt6H2M3nMXthCxTL5nUwuWYdMSlF8BGJkGotiBJCKDdkW9cMLr5NBF25EvX7shXqFRh4U/XsPtKPEQc8PHEYCzs70cFqRqwlIqxeFBLAMD6E5HIKyox8YpIY5eWW4SpWy7gzuNsOFlL8cvzPWpVkAI0v98tnDSFmcbSLXU2SjO6F+LnbHavhW7a0TK+gEL0h//C2tgdaIbIlYpL14ScGzNPiufHh50rKOycEGOgopSRBWlDyaNScmu8k1RpEfU0T0qfmjpYYnT7plg1pg3+XNwbX0zpAAcrC9xNzMa4jWfx6ZEIFJWoTL1MUgN7rmq6XEYGe1BeGClDKhFh0386wauJJWLT87Hwx6tIyy3CjO8v4ejdZO35nfFs1+amXmq99GxXL3g7WSE1twhbz0abejmkEVPkFGLKNxcQnpQDZxsZdj7fo87HO0GNbAc+IU+qpZOJV1IWX8CgopT+maJTCgDcDbADH98pZcyd93h+2rDzh9QpRYhRUFHKyDwdLGEnl0CpYohU1P6FLlw7/hfUAEf3aoPjOIzt4ImjS/phRLA7StQM609EYvT6M7gRl2nq5ZFqyC8uwV+3EgEAk7rQ6B4pH78jn41MgovR6ej7v5O4FJ0OW5kEP8zpRh12dWAhFmHpkFYAgK//eYjM/GITr4g0RsnZmoLUA0Uu3Oxk2LWgh86obm0Fao+X6hqfUB9k5StxS9sxHuJnPnlSPFdt3lESFaX0ju+UMnbGqocBOqXiMzSdUsYMOefxRalH6fl1aiIghFQPFaWMjOM4vQRuhidrDqoCGnGnVHlcbGX46rnO+Oq5TnCyluJ+ci7Gf3UWa/6+h0IldU2Zs0O3k5BXrIK3kxW6eFOoJKlYKzdbrJ/aESIOyCtWaTopFvRAD1/z6wiob0a3a4pAd1vkFJZgkzYwnhBjeZxZgGe/Po+HKXloai/Hrud7Ch8O60pf8Qn1wfmHqWAMaOlqY5ZB4nxXjYIypfROYapOKW1RKim7QG+3GZf+JFPK2NzsZLCRSaBSMzxKoxE+QgyNilImECR8W1e7AyOVmuE+P77nQZ1S5RkR7IGjS/thbIemUDPg69MPMeLLf3H1Ubqpl0YqwI/uTezUzOzyL4j5GRDoinXPdsDIYA/sXdiz1lkzRJdIxOGNYQEAgG1nY2i8hhhNXHo+nv3mPGLS8tHM0RK7FvSEjx4DugP1FJ9QH5zRju71MqNd90rjC2XUKaVfajVDSq5pMqUM0ylluvE9juPg56J5/aGwc0IMj4pSJiB8W1fLXIPY9HwUKFWQSUTwcTKvHVXMSRNrKb6Y0hFbZnSBq60MD1PyMGnzebx38C7yiynE15zEZ+TjXFQaAGBCJ08Tr4bUF2M7eGLjc53gTa+DejUgwBVdvB1RVKLGl8cfmHo5pBGITcvHlG8uIC69AM2bWGHXgp56/yDq6WAJW218QkP/kHkuUvN+au5FKSp661dGfjGUKgZAs4u1Mbnba7qZ9JUpVVCsQmquZoTcFON7wJMRvqgU6pQixNCoKGUCT1rIc8AYq/H1I7TFrFZuthCLqKOkKkNau+Hokn6Y1LkZGAO+PxuN4V/8i/PaIggxvX3XEgAAIX5OaGaigw9CiAbHcXhjWCAAYNflOBpdIAYVnZqHyV+fR0JmAXydrbF7QU94Ouh/XIfjOAS5N/wRvseZBXiYmgcRB3T3bWLq5ZTLTZsplVNYQl8S6hGfJ+VkLYVUYtyPeHynlCKnCCWqunci8l1StjIJ7CxNs/ENvwNfXTKACSHVQ0UpEwhwt4WIA9LzipGSU/N5ej6kM5BCzqvN3soCa59pj22zu8LDXo5HafmYuuUC3vntNnJp63OTYoxh7zXN6N6kzhRwTog56NaiCfoHuKBEzfDZ0fumXg5poCIVuXj26/NIyi6Ev6sNdj7fQ8imMQR+B+SGXJTid91r7+UAO7mFiVdTPlu5BaylYgBAMuVK6Q3feeZi5DwpQNOZJRZxUKmZ0OFUF/zOe82aWJks0uFJpxQVpQgxNCpKmYDcQowW2pyEu7U4MIpI4kPOqShVU/0DXHFkSV9M7abZMn7HhUcIXfcP/n2QYuKVNV5XHmXgUVo+rKViDGtLO6cRYi6WDdVkSx248Rh3HzfcD/HENO4n52DKN+ehyClCgJstfpnfw+A7hvEbzYQnNdwd+PiiVC8z3HWvNCFXSo8ZRI0d3yllinB7sYiDm7YYlphV97DzuHR+5z3jh5zz/F21mVKK3FpNthBCqo+KUiZSeoSvpsK143v8bZCasZVbYM2EYPw0rzuaOVoiIbMA07+7hOV7byK7UGnq5TU6e65ouqRGBHvASmqaFm1CSFltPe0xqp0HGAPWHokw9XJIA6IpSF1Aam4xWnvY4Zfnexilu6Oh78DHGMPZKPPOk+LxhRNFDhWl9MVUO+/x+C5HfWSFmTLknNe8iTXEIg55xSrq6CPEwKgoZSK1PTDKLy7BI+0WqdQpVTe9/J1x+NW+mBXiAwDYeTkOQz/7ByfCk027sEYkv7gEf95KBECje4SYo9eGBkAs4nAiXIErMbR7KdEPV1sZ3OzkCPa0x8/zu6OJtdQo99vKzQYcB6Tm1i4+wdw9UOQiJacIcgsROnk7mHo5leJzpahTSn9M2SkFAB7asHN97MDHd0o1M2GnlFQigreTpihGuVKEGBYVpUyktrkG95NzwZhmdtvYO2s0RNYyCVaNaYPdC3rCx8kKSdmFmLPtCpbuCkNmft1n4knlDt9JQm5RCZo3sUJXH/MMZCWkMWvhbI3JXTQF4/8diqARBqIXDlZS/Di3G36c1x0OVsYpSAGAlVSCFtrdOhtitxQ/utfVpwlkErGJV1M5N6GrpuEVB02F71BytTNtp5Q+Co18ppSpdt7jUa4UIcZBRSkT4TulHqbmoVCpqvb1whP50T3qktKnbi2a4O9X+mJe7xbgOGDf9QQMWfcPDt9JMvXSGrQ9VzWjexM7NYOIdpIkxCwtHtQSUokIl2LSceo+5e8R/XCykcHe0vhB3IHa4yc+CqEhEfKkzHx0DwDcbPU36kU0+E4pV1vTdEq5azu09NEpFZ+hzZQy4fgeQEUpQoyFilIm4m4nh4OVBVRqVqOWUD6cM8CNilL6ZikV4+1RrbF3YQj8XKyRklOEBTuuYtHP15CWS9/k6VtCZgHOabMvJnTyNPFqCCEV8bC3FMacPzkUAbWauqVI/RXkXvtMT3NWolLjwkPNiG3velCU0mf+ENFQaLvO6nunVHahElkFmoxXU47vAYCfizbsnIpShBgUFaVMhOM44cCoJjvw8d/sBVLIucF0au6IPxf3wcL+fhCLOPxxMxFD1v2Dgzce0+iKHu2/Fg/GgB6+TUz+TRghpHIL+/nBVibB3cRsIQeOkPoosIGGnd+Iz0JuUQkcrCzQuh4cIwqZUlSU0gvGmBAab7pMKW2nVHbddt+L02bnNrGWwlpm2g1w/F01nVKUKUWIYVFRyoRqGnbOGBM6pQIp5Nyg5BZivDksEPtfDEGguy3S84rx8i/X8cKPV4Vvb0jtMcaE0b1Jnb1MvBpCSFUcraWY39cXAPDZ0ftQqtQmXhEhtcPHH0Sl5KK4pOH8Hp/Tju6F+DnVi3F4Yfe97CL6wk8PMvKVUKo0z6OLiTJnhe63rKI6ddTyIedeJu6SAgBf7fhecnYRcmiHbkIMhopSJhRYw7BzRU4RMvOVEHFPKvfEsNo1c8CBRb3xyqCWkIg4HL6TjK9ORpp6WfXe1UcZiEnLh5VUjOFt3U29HEJINczp3QJO1lJEp+bh1yvxpl4OIbXi6WCJ/2/vv8Pjqs+88f99pqtLVht1WXKRjRtY2DHd4GCbhLLwbMguX8CEZQMPpGASFj+7CSHJXk42LGEhbMiPhWCy318geUiyCQnVhVAMBhs3sOUqWb1LM2pTz/ePOZ8jyVaZkWbmnDPzfl2XrgtLU8548OjMPff9vjMcFvgCckKN5LxroDwpYDT3yBsIoneIb/ZnS4xBzkmzwWbR5u1dQYYDkhR6TntmsSyoSQk5L9VBF31WihX5GaEi36nOQY2PhihxsSilocVFo7kG4XxKJIpXVfnpcFj1vVUlkdgsJtz/+QX4zhcXAwDq2hMrh0ILokvqmqVFmrdmE1F40u0W3Lt2HgDgP7Yfi2hJB5FejI1PSJQRviGvH5+c6QMAXFxtjKKUzWJCblpo8yJzpWZvNORcu83cNotJ3Qw+m1wpEXKudZ6UIHKlOMJHFDssSmloXkE6zCYJ/cO+sGbq60TIOUf3NFGRG/rEprWPJ0+zMewN4JWDoUya/7WyVOOjIaJI3PK5cpRkp6Dd5cELu+u1PhyiGRndwJcYHzJ9VN8LbyCIkuwU9VzFCAqUET7mSs2eKOwVaJQnJai5UrMoSolMqbIcffy/LKZTEqmzkkhvWJTSkMNqVqvv4XxaJ06eFrEopYni7NAnNq39swtwTHZvfNaGAY8fpTkpWFU5R+vDIaII2C1mfGPdfADAf+46CRczNsiAIs301Lv31dG9XEiS/vOkBKcSdt7BotSsdSqdUoUadkoBoe3iANA2i3Nl0SmllyU41fksShHFWlyKUh6PBytWrIAkSdi/f/+Ulx0ZGcG9996L3NxcpKen46abbkJ7e3s8DlMTi4rCX00sTp4WOvW/VSURiU9/XCN+DHr8Gh+NcYnRvZsuKDVEGCsRjXfj+SWozk9D35AP//XXU1ofDlHExLKYcM69jMBoeVJCoVrA8Gh8JMY32imlbVFKnCvPtPtNlmU0ikwp3YzviaIUM6WIYiUuRakHH3wQxcXFYV32/vvvx5/+9Cf89re/xdtvv42WlhbceOONMT5C7Yii1GfTfFrnCwTVCj0372kjw2FFupJ/xG6pmWnpG1ZPnm+6gKN7REZkMZvwrasXAgD+693T6BrgG0oyloXODEgS0DXgUTtMjKpn0KueQ15kkDwpQRSl2t3slJqtDpfSKaXx+J4zS0wVzOw57Rn0YsgbyissydZJUUoZ36vvGuTmWUooQ14/rvmPd/DPvz8Ej1/bnNCYF6VeffVVvPHGG3j00UenvWx/fz+effZZPPbYY7jyyiuxcuVK/PKXv8T777+PDz74INaHqonRT+umLkqd6hyELyAj3W7RzScHyUh8AtTCXKkZ+f0nzZBlYPXcOSg3UO4FEY23YYkTy0qzMOQN4CluJCWDSbVZUJkbik842mbsEb7dJ7shy6HzyXyNR7cipRalZpE/RCGisKdl0DkwplNqhs9pozK6V5hp181Sp6JMB1JtZviDMs4oeVdEiWB/Yx8+a3Vhx9EO2MzapjrF9N7b29tx11134Ve/+hVSU6d/A7p37174fD6sW7dO/V5NTQ3Ky8uxe/fuWB6qZsQGvvquQQx7J69QipOm0Kd7HHnSSpHyqc1stookK1mW1dE9BpwTGZskSXhwfQ0A4P/94Iy6wpvIKBYVhfehoN6J7mOjdUkBgDMrVEBhp9TsiU4prYPOR0cyZ/acNqmje/r54NJkklClZACf5AY+SiAf1/cCAGor52heX4hZUUqWZWzatAl33303amtrw7pOW1sbbDYbsrOzx32/sLAQbW1tE17H4/HA5XKN+zKS/Aw7ctNsCMrAsfbJsw1EyDlH97RVLDqlOL4XsX1nenG6axCpNjOuWVqk9eEQ0SxdMj8PF1XnwhsI4vG3jmt9OEQRqVHyOY8aPFfq/ZOhotQl83M1PpLIFWQwUyoaZFlWx1D10inV2j8CWZYjvn5jjxJyrrOpEOZKUSL6qL4HAHBhZY7GRzKDotRDDz0ESZKm/Dp69CiefPJJuN1ubNmyJRbHrdq6dSuysrLUr7KyspjeX7RJkhTWFpijys9YlNKWU/yy5fhexP7v3mYAobGfNCWbi4iM7dvrQ9lSv9vXhONTfLBCpDfhZnrqWWPPEBq6h2AxSVg113hFKXFO1T3oYVbPLPQN+eBV/v60HuEUz+mwLwDXcORLgUTIuV427wmiKHWCnVKUIPyBIPY1hDqlVlYYsCj1wAMP4MiRI1N+VVVVYceOHdi9ezfsdjssFgvmzZsHAKitrcXtt98+4W07nU54vV709fWN+357ezucTueE19myZQv6+/vVr8bGxkgfkubCaSGvE51SRdy8p6ViEeDI9cURGfEF8MqBFgAc3SNKJOeX5+DqxYUIysC/v3FM68MhCps49zrZOQCv35gFEdEltaIsW13EYiRzUm2wmiXIMgwfOK8lMf6Yk2qF3aJtDpPDakZOqhUA0OqKfKqgqVd0SumrKDWvQHRKsShFieFomxuD3gDS7Ra1c1hLEf8Gy8/PR35+/rSXe+KJJ/DDH/5Q/XNLSwvWr1+Pl156CatXr57wOitXroTVasX27dtx0003AQDq6upw5swZrFmzZsLr2O122O3GCnY822in1MSfMvcP+dCizGYvZKeUpoqyRacUx/ci8fqnbXB7/CjJTsHnDPhpLhFN7lvrF+LNI+147dM2HGjsw/KybK0PiWhaJdkpyHBY4B7x42TngHouZiTvnugGAFw0z3h5UkAoq6cgw4HmvmG0u0ZQrJNta0ajl817gjMrBb1DPrT2j0T8ZrepR2RK6ev/hdHxvQHIsqx5/g5NzOsP4n/2N+PyBfma56vp3V6lS+qCihyYTdr//xyzTKny8nIsWbJE/VqwYAEAoLq6GqWloU6J5uZm1NTUYM+ePQCArKws3Hnnndi8eTN27tyJvXv34o477sCaNWvwuc99LlaHqjnxgn2kzTXh/HWdMhJRkp2CTIc1rsdG4xXNctVtshIB5zetLIVJBy98RBQ9Cwoz8DfnlwAAfvJ6ncZHQxQeSZKwSORKGXADXzAo430l5PwSgxalAKAgUwk7Zwf6jIm/O61H94SZbuALBuXRTimdje9V5KbCJAHuET86B9jVp1ePvXkM3/6/B/HoGzwXmY6aJ6WD0T0gxtv3puPz+VBXV4ehodGtPT/96U/xxS9+ETfddBMuu+wyOJ1O/O53v9PwKGNvXkE6rGYJ7hE/mifowBEnS8yT0p74RTvg8cM14tP4aIyhtX9Y3Q500wUlGh8NEcXC/esWwGqW8O6JLvWNMpHe1ajxCcbLQ6trd6N70IsUqxkrDNyd6FS6GdpdfKM/Ux1uvXVKjYadR6JzwANvIAizSVLPt/XCYTWrhTLmSulTp9uDbe/XA+BzNB1ZltWiVG3lHI2PJiRuRanKykrIsowVK1ac870rrrhC/Z7D4cBTTz2Fnp4eDA4O4ne/+92keVKJwmYxqW2hE50Yie9xdE97aXYLMh2hqdeZrrtNNr//pBmyDKyqnIOK3DStD4eIYqBsTir+flU5AODHr9fNaOsSUbyFs2hGr95Tir+rq+bAZtH0M+ZZEYWUNnZKzViH8nen9eY9oUg8pxFuqm5URvecmQ5YzPr7f3oeN/Dp2tNvn8SwLwAAEzZ50Kim3mG0uzywmCTdfKihv3/xSWrxFCdGdaJTyoB5B4lIZB608AVvWrIsq6N7DDgnSmz3XTkfKVYzDjT24fVP27U+HKJpiQ50I3ZKiaLUxdXGHd0DRotSHN+bOb12SrVF2P02unlPX3lSQrUIO2cXju60u0bw3x80qH/ucHsMu8AiHj5uCHVJnVeShRSbtssRBBaldGKyT+uCQVndvLeInVK6MNO25GT0SWMfTnUOIsVqxjXLirQ+HCKKofwMO+68ZC4A4PG3uImP9G+hMwOSBHQNeAy1/c0XCOLD06E3FRcbOE8KAAqZKTVr7XrrlFLyVyPvlNLn5j2hOj/U7c8NfPrz1M4T8PiDWFmRA7vFBFnma8pUPq4PhZzrJU8KYFFKNyYrSjX3DWPQG4DNbEJlHkef9EANO2en1LREl9TGJU5DrqsmosiIotTRNjdz9xRPPfUUKisr4XA4sHr1anW5y3RefPFFSJKEG2644ZyfHTlyBNdddx2ysrKQlpaGCy+8EGfOnFF/PjIygnvvvRe5ublIT0/HTTfdhPZ2dq+dLdVmQaUyVm6ksPP9jX0Y8gYwJ81m+LxRZkrNnvi708u2sZl+eNukdkrptSjFTik9au4bxot7GgEAD1y9ACXKRIsIzadziaKUXvKkABaldEOEbTb0DGHQ41e/L4pUoTB0Pl16UMxOqbCM+AL404EWABzdI0oWOWk25KbZAABnuoemuXTie+mll7B582Y8/PDD2LdvH5YvX47169ejo6NjyuvV19fjW9/6Fi699NJzfnby5ElccsklqKmpwa5du3Dw4EF85zvfgcMx+ob0/vvvx5/+9Cf89re/xdtvv42WlhbceOONUX98iWCRGnZunKKUGN27qDrX8BttRSGlnedUMyLLstrlp5dOKVGUco/4MTDmPc10RKdUaY5Ox/eUolRL/8i492qkrZ/tOA5vIIjPVc3BRdV5jFmZRv+QD3XtoSms2kp2StFZ8tLtyM+wQ5ah/o8CQB3dM/onYYmkSHmxY1Fqan891gn3iB/FWQ58ripX68MhojgRn3KL0Npk9thjj+Guu+7CHXfcgcWLF+Ppp59GamoqnnvuuUmvEwgEcMstt+CRRx5BVVXVOT//53/+Z1xzzTX4t3/7N5x//vmorq7Gddddh4KCAgBAf38/nn32WTz22GO48sorsXLlSvzyl7/E+++/jw8++CBmj9WoapyhTvWjBsqVUvOkDD66B4wpYHj8fKM/A31DPngDoeycgkx9FKXS7RZk2CNfCtSo806psR+6nO5i2LkenOkewm8/Dk1lPHD1QgBAcXboNYVFqYntPRMa/Z6bl4a8dH28ZgAsSunKRCN8R0VRqohFKb0Qa2pbIpyVTzY760KdAJ9fXGj4T3KJKHzlyhuKM0lelPJ6vdi7dy/WrVunfs9kMmHdunXYvXv3pNf7/ve/j4KCAtx5553n/CwYDOLPf/4zFixYgPXr16OgoACrV6/GH/7wB/Uye/fuhc/nG3e/NTU1KC8vn/J+k5U49/rMIJ1Sgx4/PjnTBwC4JAGKUul2C9KUoF1mwEROhJznpFpht+gjsBgYE3YeZlHKHwiqH/bqNVMKGDPCx1wpXXhix3H4gzIunZ+HC5VRtJLs0P8/3MA3sY/E6J6O8qQAFqV0ZaIW8iNi856Tm/f0QhSlWvtGuPZ8ErIsY+fRTgDA2poCjY+GiOKpIpdFKQDo6upCIBBAYWHhuO8XFhaira1twuu8++67ePbZZ/HMM89M+POOjg4MDAzgRz/6ETZs2IA33ngDf/M3f4Mbb7wRb7/9NgCgra0NNpsN2dnZYd8vAHg8HrhcrnFfyUB0op/sHDDEtqY9p3vgD8oon5Oq246SSBVmMVdqpkZDzvWRJyWM5kqFVxho7R9BICjDZjbpZgxxItUFoQy6E8yV0typzgH8bt/4LilgtFOKRamJfVwf6pS6UEd5UgCLUrqyWO2UCnVHjfgCqFfaQzm+px8i6HzYF4BrmK3mE/ms1YU21whSrGaO7hElmTJ2Ss2I2+3GrbfeimeeeQZ5eRN3wASDoaLJ9ddfj/vvvx8rVqzAQw89hC9+8Yt4+umnZ3X/W7duRVZWlvpVVlY2q9szitKcFGTYLfAFZEN0P4yO7iXO79bCDFGUYqdUpESnlF5G94SiCDulxOheSU6Krrvr2SmlH/+x/TiCMnBVTQFWlGWr3y9hptSkPP4ADjT1A9BXnhTAopSuiBbyo60uBIMyjrcPICgDc9JsyNfxpwbJJsVmRk6qFQBH+Caz82hodO/ieblwWPXTTk5EscfxvZC8vDyYzeZztt61t7fD6XSec/mTJ0+ivr4e1157LSwWCywWC1544QX88Y9/hMViwcmTJ5GXlweLxYLFixePu+6iRYvU7XtOpxNerxd9fX1h3a+wZcsW9Pf3q1+NjY0zfOTGIkmSGpFghA187yZQnpTgzGJRaqb02yml5K+G+Zw26TzkXKguEBv4mCmlpWPtbvxRWaZ0/+cXjPtZifL/UHPfMCdaznK4uR9efxC5aTbMzUvT+nDGYVFKR+bmpcFmNmHQG0Bj75B6crSwMAOSpN9PDZKR+suWRakJ7azj6B5RshJFqebeYfgD+h+HihWbzYaVK1di+/bt6veCwSC2b9+ONWvWnHP5mpoaHDp0CPv371e/rrvuOqxduxb79+9HWVkZbDYbLrzwQtTV1Y277rFjx1BRUQEAWLlyJaxW67j7raurw5kzZya8X8FutyMzM3PcV7JYdFanul51DXjUrNE1CdSFLLp82liUipjYvFdo8E6pJp2HnAvzlE6p012DCARjV/CQZRnvHu/Czb/YjaXfex2Hm/tjdl9G9PhbxyDLwIbznFhSkjXuZ6LIPeILonfIp8Xh6ZaaJ1WZo7vagkXrA6BRVrMJ8wvT8WmLC0da3Qw517HiLAeOtLq4gW8CvYNefHIm9KJ3xUIWpYiSTWGmAzazCV4luFbvbzJiafPmzbj99ttRW1uLVatW4fHHH8fg4CDuuOMOAMBtt92GkpISbN26FQ6HA0uWLBl3fZELNfb73/72t3HzzTfjsssuw9q1a/Haa6/hT3/6E3bt2gUAyMrKwp133onNmzdjzpw5yMzMxNe+9jWsWbMGn/vc5+LyuI1G5HYe0XnY+fsnuwGE4h5ydbQ1abacmaE3kR3MlIrYaKeUvv5/iDTovLHXGJ1SxdkpsFtM8PiDaOwZQmWUu01kWcauY514YvtxdaEBAPzxQMs5xZdk9VmLC3851AZJAr75+fnn/NxuMSM/w45OtwctfcOYo2xMpNE8qdoKfeVJASxK6c6iokylKOVSO6UWMeRcd4qyR8POaby3j3UiKIdy0MRcNxElD7NJQumcFJzqHMSZnqGkLkrdfPPN6OzsxHe/+120tbVhxYoVeO2119Tw8zNnzsBkiqxp/W/+5m/w9NNPY+vWrfj617+OhQsX4uWXX8Yll1yiXuanP/0pTCYTbrrpJng8Hqxfvx7/+Z//GdXHlkhGF83ou1Pq/QTMkwJChWyAnVIz0aF2SulrfE/tlArzOW1Uxr31vHkPCP1+q8pPx5FWF052DkStKCXLMt78rB0/23kCB5XMH7vFhKUlWfi4oRcHm/qicj+J4KdvHQMAfGFp0aSLwEqyU9Dp9qCpd5jFPEUwKOPjhtFOKb1hUUpnRlvIXahTOqUWMuRcd0TYOTOlzrVDyZPi6B5R8iqfk6oWpS7W+mA0dt999+G+++6b8Geiu2kyzz///ITf/8pXvoKvfOUrk17P4XDgqaeewlNPPRXuYSa1hc4MSFJoPK7T7dFljqcsy3jneOLlSQGjBRVmSkVO7ZTS2/heZug8uWfQixFfYNp80SalU8oIH2JU56epRamrFhVOf4UpBIMyXj3chid3HFcnZFKsZty6pgL/cOlcdA94sfE/3sHh5lDesJ5D4OPhYFMf3vysHSYJ+Oa6BZNeriQ7Bfsb+xh2PsaprgH0DfngsJpwXrH+CnUsSumM+LTuw9M96B/2QZKABYUsSumN+ASInVLj+QNBvH0slCd1JYtSREmLYedkJKk2Cypz03C6axBH21zIz8jX+pDOcaZnCM19w7CaJayaq7/Ri9kQeUgdLg9kWdZd1oleybI8un1PZ0HnmSkWpFjNGPYF0O4aQUXu5B1FHn8A7e7Q+bTex/eAMRv4ZhF2HgjKeOVgC3624wSOd4Q2+aXbLbhtTQXuvGSuOp47J9UGh9WEAY8fp7oGMU8JWk9Wj70Z6pK6YUXJlH8XxcpEC4tSo0Se1IqybNgs+osVZ1FKZ8SoXv9wKJitMjcNKTZuL9Mb0SnFVvPx9jf2oX/Yh6wUK84fs56ViJILi1JkNDXOjFBRqtWNS+frryj13olQntT55TlItSXW6bsoqHgDoWBiZsCEp3/YB68/tExCb919kiTBmeXA6a5BtPZPXZRq7h2GLIc6hHIN8NyLDXwnOgcivq4vEMQfPmnGf+46idNdoaJWpsOCOy6eizsurkR26vjHbzGHulr2NvTiUHNfUhel9jb0YlddJ8wmCV+/6twsqbFEfEgzi1Kqj5Q8qQsr9fmhRmL9VksAOWk2ODMdarGjhqN7ujS2As9P9UaJ0b3LFuTDYtZfFZ6I4kMtSnWzKEXGsKgoE68ebtNt2Pl7Ik+qOrFG9wDAZjEhN82G7kEv2vpHWJQKk+iSyk61TjsepwVnZqgoNV3YeaM6updiiPNpsYHvRMdA2O8BvP4gXt7XhP/cdQKNPaHHm51qxT9cMhe3XVSJTId10usuLQkVpQ429eNvzi+NzoMwoJ8qXVL/64LSabO8ipWiFDulRn2sdEqtrNBfnhTAopQuLSrKUItSzJPSJ5F/4PHzU72xRFHqyhr9fcpMRPFTnstOKTIW8SHgkTb9hZ17/AH8VRmNv2R+4hWlgNB5VfegF+3uESwGF/yEQ6+b9wQ16mKaolRTb+j3RKnOQ86FuXlpkKRQp1rPoHfKTZgjvgB+83Ejnt51Ei3K30Neug13XVqF/+dzFUizT/9WfHlZKP9HBKAnow9PdePdE12wmiXcd+W8aS9frHZKcaIFADpcIzjTMwRJAi5gUYrCtagoEzvrQicfk20VIG05lBbj7kEv140qWvqGcbTNDUkCLl/APCmiZCY2KPUP+9A/5ENW6uSfAhPpgVg0c6LDDa8/qKvMjfdPdMPt8aMw056wo/GFmXZ81gq0T1PAoFEdLn1u3hOcYgPfNEuBROdQmQHypAAgxWZGSXYKmnqHcbJzcMKi1LA3gP//njP4xdsnx+R+2XH35dX4u1XlEUWzLC3JBgB82tIPfyCYdJMIsizj35UuqS/VloUVhi+yyboGPGEF7Sc6sXWvxpk5ZVeelliU0iFxYhT6b3ZK6VVRtkNtNee6UWBnXahL6vyybBbpiJJcmt2CvHQbuga8aOwdQlYqXyNJ30pzUpBht8Dt8eNU14CuPhT8y6FWAMCG85wJu31LFDDalUILTU+Eg+stT0oIt1OqUemUMsLmPaE6Px1NvcM40TEwbvHAoMePX33QgP965xS6BrwAgOIsB+65ohp/W1s2o+JIVV4a0u0WDHj8ONGpr9emeHjvRDf2nO6BzWIKq0sKALJSrEi1mTHkDaC1fwRzpxn3S3SjeVL67JICgOQqtRrEstIsSFJo1rjMIK2syUiEnbdO8wlQsth5lFv3iGgUw87JSCRJQo3yQaCecqV8gSDe+KwdALBxaZHGRxM7IuycC2TCp/9OqfCWAjUpmVJGGd8DoAaOn1TCzl0jPvxsx3Fc/OMd+NGrR9E14EXZnBRsvXEpdn17LW5dUznjbh2TScKSklAh6mBjco3whbqk6gAAf7+qXH3vNR1JkpgrNYbIk6rVacg5wE4pXarITcMv/p+VyE23JewnYolAfALUwlZzjPgCagjrWhaliAihotS+M31oYNg5GUSNMxMf1ffiaKsbOF/rownZfbIb/cM+5KXbdLs1KRpEp1QHi1Jh63AnSKZUj8iUMsb4HhDqlAKAQ039+Ombx/DL907DNeIHEMqcunftPFy/ohjWKI3aLSvNxgenenCwuQ9furAsKrdpBLvqOvHJmT44rCb877XVEV23JDsFJzoG0Nyb3EWpAY8fn7aEipm1Os2TAliU0q2rz3NqfQg0DbVTihV4fHi6B8O+AAoz7VhclFxtxUQ0MXZKkdGI+ITPdNQp9erhNgDA+vOcMCfwB5WFmaHCCjulwteu+06p0HF1DXjgCwQnLNAMevzoHgyNuRlrfC80Dranvgd7lNGo+QXpuO/KefjisuKo/1tdqsSEHEqisHNZlvGYkiV125pKtZsyXKNh58n9Pm3/mT4E5VCRTvyd6BGLUkQzVJwd3idAyWCnsnVv7cICQ6zzJaLYE28wGlmUIoMQ43tHdbKBzx8I4o1PQ0WpjUsSd3QPGC2sMFMqfHrvlJqTaoPNbII3EESH24OSCd4Qi4JBpsOCrBR9BjBPZEFhBiwmCf6gjEVFmfjalfNimvm2vDQbAHCkVX+LGGLlzc/acai5H6k2M756WVXE1y9R3qcl+/jexw2hommtjvOkABaliGZsNFMquYtSsixjhyhKcXSPiBQVuaFPktkpRUaxsDADkgR0uj3oGvAgb4pV7/Gwp74H3YNe5KRasboqcUf3gNGiVPfg5F01NEqWZd13SplMEgqz7GjsGUZb//CERalGdXTPOF1SAJCTZsPzd6xCQJZx2fy8mH8gWzYnBVkpVvQP+1DX5sbS0sReHhIMjnZJbbqocsINh9MpyWGnFGCMPCmAQedEM1akrrodQTAoa3w02jnZOYgzPUOwmU24ZF6e1odDRDohxvea+4bhCwQ1Phqi6aXZLahQ/r892qp9t9Srh0JdUlcvdiZ8kWZOqg1WswRZDhUFaWquYT+8/tDrql637wFAUebUH+CKolTZHP2OFU3mkvl5uHxBflwmBCRJwjKlEHWwuS/m96e1Vw+34WibGxl2C/5xBl1SAFCcxaBzfyCIfWdCRSk9b94DWJQimrHCTAckCfAGguo8fDLaVRfqklpdNQdpdjZfElFIQYYdNosJgaCM1r7k7igl4xC5Ulpv4AsGZbwmRveWJn7OqMkkcQNfBNqV0b2sFOuMt7rFQ+GYD3An0qiEUHPb+PTUolSCb+ALBGX89K1Ql9RXLpmL7FTbjG5H3b6XxM0DR1rdGPIGkOGwYEFBhtaHMyUWpYhmyGYxqa39k/2yTQY7xuRJEREJJpOEMqV9niN8ZBQ1TqUo1aZtUWrvmV50uj3IcFhwUXVydCEXKGHn3MA3vQ51dE+/XVLA9Bv4mnpFpxSLUtNZWpINADjYnNhFqT8daMGJjgFkpVhx56VzZ3w7ziwHTBLg9QfRNZic3ZcfKSH8KytyYpZ3Fi0sShHNQrHyy7alPzlbQ90jPuw5HXrBY54UEZ2NG/jIaBYpYedHNB7f+8uhVgDA5xcXJkWoMQA4M6fuqqFR7S4Rcq7PPClhuue0sSd0/lyaY7zxvXgTnVLH2t0Y8QU0PprY8AeC+I/txwEA/3hZFTIdMw+/t5pNat5aS5J2a4uQ8wt1nicFsChFNCti3W1rks4rv3u8C/6gjLl5aZibl6b14RCRzoiw84aeQY2PhCg8YnzvRIdbsyy0YFDGa4dDo3vXJPjWvbHUDXzMlJpWh/J3VGCYTqmJz5Mb2SkVtqIsB/LS7QgEZXzaom0nZ6z8/pNmnO4axJw0G26/qHLWt6eO8CXh+zRZlkdDziv0nScFsChFNCvqBr4kbTXn6B4RTUW80WhkpxQZRGlOCjLsFvgCMk52DmhyDAea+tDaP4I0mxmXzE+O0T1gTFGKnVLTMkyn1BSZUv3DPrhH/ADYKRWOsWHnh5r6tD2YGPAFgnhiR6hL6quXVSE9Cjm1YuNjc2/yFaUae4bR4fbAapawvCxb68OZFotSRLNQnC06pZLvBCoYlLHrWCcA4EqO7hHRBDi+R0YjSRJq1BE+bboRXlW6pK5aVKjrEOtoE/lIIsTbCBq6B9E1EP/OLrGhUP+ZUqGiQLvbg8BZYdPiw4rcNBtSbVyUE46lJWIDX+LlSv324yY09gwjL92O29ZURuU2RadUcxJ2Sok8qaUlWYb4PcKiFNEsqJ1SSZgp9WmLC51uD9JsZlw4V/9toUQUf2pRqptFKTIOEXZ+VINcKVmW1Typa5Jg695YRsuU6nR7sP7xv+L6n72HYW98M36M0imVn2GH2SQhEJTPKd6JkPNSju6FbXmZUpRqSqyilMcfwM+ULqn/fUU1UmzRKaKUZItMqeR7nybypGoNkCcFsChFNCtiVj4ZA/TE6N7F8/Jgt+i/Ak9E8SeKUq4RP/qGvBofDVF4RK7UZxp0Sn3a4kJT7zBSrGZcviC5upALlKKU2Cynd0daXRjxBdHcN4z/98OGuN53h0E6pcwmCQUZE2+qFiHnZRzdC9sSpVPqZOcABjx+jY8mel76qBEt/SNwZjrw96vLo3a7JTnJ3CllnDwpgEUpolkpUtpC210jCJ7VlpzodtSFilIc3SOiyaTYzMhX3pBwhI+MQozvHW2Lf6eU6JJaW5MftW4BoxD5Q26PH4MGeMPd0D26wOHnu07G7ZhlWTZMpxQwZinQWUWpJoacR6wgw4GiLAdkGfg0QUb4RnwB/GzHCQDAvVfOi+qoWbIGnfcOenGiI5SJuJJFKaLEV5hhh0kC/BO0JSeyrgEPDiohi2tZlCKiKTBXioxmYWEGJCk0nhXP3+1jR/c2JtHWPSHdbkGaUohrN8ACmfoxY8ndg15s210fl/t1jfjh8Yc2Q+p9+x4wOlXQdlbURaMSPs2Q88iIsPNEGeH77w8a0OH2oCQ7BTfXlkX1tkVRqnfIhyGv/gvd0bK3IdQlVZ2fhtx0/b9GACxKEc2KxWxSP6VqMUgGQjS8XdcJWQbOK85Ut+UQEU2ERSkymjS7BRXK/7fxzJU62uZGffcQ7BZT0n7gUygKGEYoSnWFOqXEeMz/76+n4B7xxfx+O5S/m0yHxRABxs7MiTdVi6Dzshx2SkViWWk2gMQIOx/y+vH02ycBAF+7ch5sluiWJjIdVmQ4QiH6ydQt9ZGSJ3WhQfKkABaliGZNbUtOohc7ju4RUbhEUaqRRSkyEJErFc8NfGLr3uUL8qOyDt2ICjOMkytVr4zv3XflPFTnp6FvyIfn3q2P+f2O5kkZ40NBZ9a5mVKyLKNJ6ZTi+F5kxAa+Q8rEgpFte78BXQNelM9JxU0rS2NyHyXqBj79F7qj5WMlT8ooo3sAi1JEs1acPfGsfKLyBYL467FOABzdI6LpiaJUAzfwkYGIDXxH2uJYlBKje0m2dW8sp0E6pQJBWQ3qrs5PxzfXLQAA/Ne7p9A/FNtuKTVPygCjewDgVDdVjz6n3YNeDPsCkKTR82gKjxjfq+8eivn/a7HkHvHhF38NdUl946r5sJpjU5ZItlypEV8Ah5TRTnZKESWRIvWXbXK82O1t6IV7xI85aTYsV1qIiYgmU57L8T0ynkVK2PmROI3vnehw43jHAKxmCVctKozLfeqRKLToPVOqtX8Y3kAQVrOE4uwUfGFpEWqcGXCP+PFf756K6X23K11khQYIOQfGZkqNPqeic7Yww8ENzhHKTrWpH/YcMvAI3/Pv1aNvyIeq/DTccH5JzO5H7ZTqTY73aYea++ENBJGXbkdFrnG6EFmUIpol8cs2WTKldh4Nje5dviAfZpOk8dEQkd6Jk+eWvmH4AkGNj4YoPGJ870SHOy7/3756KDS6d+n8fGQ6rDG/P71yKiNpei9Kic7PsjmpMJskmEyS2i313Lun0TPojdl9d7hDfzf5RumUyhwtSslyaFN1ozq6x5DzmVgqws6b+7Q9kBnqH/bhmXdCxdtvrlsQ0/cTydYp9VG9yJPKgSQZ530ai1JEs6R2SiXJi90OpSjF0T0iCkdBhh12iwlBOXlOCsn4SnNSkG63wBeQcbJzIOb39xclT2rDkuQd3QNGc5LadZ4pJfKkKnPT1O+tP68QS0oyMegNqGNJsdBhsE4p8Zx6A0G1WNfUGyrqlTLkfEaWKblSBxuN2Sn17Dun4BrxY2FhBr64NLabRsV4aHOSnH+IPKlaA43uASxKEc1aUfa5bcmJqrFnCMc7BmA2Sbh8fr7Wh0NEBiBJEnOlyHAkSUKNMzTCF+sNfKe7BnGk1QWLScLVi5N3dA8YLWDo/ZxKvJaNHY+RJAmbPx/qlnrh/QZ0umNTWBOdUkbJlLJZTMhT1tKLXCmRx1WWw06pmRAb+Iw4vtc76MVz79UDAO7//HyYYjx1UZojgs4TvygVDMr4WOmUqjVQyDnAohTRrBUrnVLtbg8CQVnjo4mtXcrWvZXlOchKTd7xAiKKjChKMVeKjCReG/hePRwKOF9TnYvsVFtM70vvCpVCS4d7dNRLj+q7zu2UAoC1Cwuwoiwbw74Afr4rNt1SaqaUQbbvAaNRF2IsU+2U4ua9GVlSEnptau4bRteAvrsKz/aLv57CgMePxUWZuHpx7DtDxfheW/9Iwr9PO94xANeIHylWMxYXZ2p9OBGJS1HK4/FgxYoVkCQJ+/fvn/RyPT09+NrXvoaFCxciJSUF5eXl+PrXv47+fuNVgSl55GfYYTZJCARl9dOrRMXRPSKaCbHyu5FFKTKQGhF23hbbTqnXlNG9a2I8xmIEBcpImi8gxzSXabYm6pQCxndL/feHDVHPxpLl0XNNo4zvAaNbFUWnVJPIlOL43oxkOKyoyg8VRMWmNSPodHuw7f16AMDmzy+IeZcUEHpNsZgk+INyzLoX9eLjhlCX1Pnl2THbZhgrcTnaBx98EMXFxdNerqWlBS0tLXj00Udx+PBhPP/883jttddw5513xuEoiWbGbJJQmDG+LTkRDXsDeP9kNwBgbQ1H94gofOyUIiOKR6dUY88QDjb1wyQh6Uf3gNCoV25aqFtMr7lSwaCMhp6JO6UA4NL5ebiwMgdefxBP7TwR1ft2jfgx4gsF7xtlfA8Yv4EvGJTVTWilHN+bMbEB+6CBilK/ePskhn0BLC/LxlWL4vMBt9kkqUXR5r7EPgcxap4UEIei1Kuvvoo33ngDjz766LSXXbJkCV5++WVce+21qK6uxpVXXol//dd/xZ/+9Cf4/f5YHyrRjBVli7DzxC1K7T7VBY8/iOIsBxYWZmh9OERkIKKbgJlSZCQLCzMgSaFP92M1IiO6pFbPzUVuunGKDLFUqPMNfB1uD0Z8QZhNEkomKKqEuqUWAgBe3NMY1SybTqVLKtNhgcNqjtrtxtrYTql29wi8gdDfnyhWUeSWKmHnhwyyga/dNYJffdAAINQlFc/NcGKErzmB36cB4zfvGU1Mi1Lt7e2466678Ktf/QqpqTNrz+zv70dmZiYsFsuEP/d4PHC5XOO+iOKtSP1lm7ghejuPdgIIje4ZacUoEWmvfMz4np5zYojGSrNbUKH8vxursPO/KHlS1yxN7q17Y4lcKb0WpcTmvdKclElHZNZU5+Ki6lx4A0H8bMfxqN236B4rMFCeFDCmU8o1rIacF2c7YDHYiJGeLCtVNvAZpFPq57tOwuMPorYiB5fNz4vrfZcoRalE3gDc2j+Mpt5hmCTg/HIWpVSyLGPTpk24++67UVtbO6Pb6Orqwg9+8AP84z/+46SX2bp1K7KystSvsrKymR4y0YyJX7YtCVqBl2VZzZO6knlSRBQhsfbb7fGjb8in8dEQha/GGRrhO9oW/Q89W/uH8cmZPkgSsP48FqUEp1rA0Oc5lQg5r5hgdG8skS3124+bcCZKXaJqnpSBRveA0e631v6R0ZDzbOZJzcZ5xVkwSaHOPb1vq/QFgnh5XxMA4Bvr5sf9w21RlBJjo4lIjO4tKspEun3iZh49i7go9dBDD0GSpCm/jh49iieffBJutxtbtmyZ0YG5XC584QtfwOLFi/G9731v0stt2bIF/f396ldjY+OM7o9oNoqUDXxtrsR8sTveMYDmvmHYLSZcVB3fTzeIyPhSbGYUKNl7zJUiIxG5Up/FIFdKjO7VVuQYrvMllkTYuV4zpeqVAlNl7tRFldrKObhsQT78QRn/sT063VJqp5SBQs6BMefJ/SNqp1TZHOZJzUaKzYwFSpzGwaY+bQ9mGnsbeuEe8WNOmk2T9xHFSdAptbchVJS60IB5UgAQcRntgQcewKZNm6a8TFVVFXbs2IHdu3fDbh9fya+trcUtt9yCbdu2TXp9t9uNDRs2ICMjA7///e9htU6+et5ut59zH0TxVpyd2J1SoktqTXUuUmzGyTAgIv2oyE1Fh9uDMz1DWF6WrfXhEIVFbOCLxfjeq4dCRamNS7h1byzRKaXX8b2G7slDzs+2+fML8Ndjnfj9J024d201qvLTZ3XfHer4nrHe+ziVouuQN4BPW0LjZty8N3tLS7JwtM2NQ839uFrH3ZY7lfcRly/IhzkOG/fOJt6nRTPfTW9EnlStAfOkgBkUpfLz85GfP/3mrSeeeAI//OEP1T+3tLRg/fr1eOmll7B69epJr+dyubB+/XrY7Xb88Y9/hMNhrE8CKDmJT4ASNVNKFKXWLuToHhHNTNmcVHxU38tOKTKUxUqn1ImOAfgCwait2e5wj+AjZX33hiX6fTOpBf1nSimdUnnTF1VWlGVj3aICvHWkA/+x/Tj+48vnz+q+25XxPaN1SqXYzMhOtaJvyIePlY6OsjksSs3WsrJs/HZvEw7oPFdqZ53yPkKjCBCx5TFRi1LuEZ+6Jba2wpidUjHLlCovL8eSJUvUrwULQnPV1dXVKC0tBQA0NzejpqYGe/bsARAqSF199dUYHBzEs88+C5fLhba2NrS1tSEQCMTqUIlmTWRKdbg98AWCGh9NdPUP+9SWUOZJEdFMibDzaGWrEMVDSXYK0u0WeANBnOocjNrtvv5pO2Q5VLQQoyUUoufte7Isq51S02VKCfcr2VJ/PNCCY+2z67jrVDqljJYpBYx2S/UMegGMFgpo5paJDXxNfbpdItLUO4Rj7QMwmyRcPn/6xpZYEM0D7hE/XCOJl2v5yZk+BOXQSKzToBstNV154PP5UFdXh6Gh0Anqvn378OGHH+LQoUOYN28eioqK1C9mRZGe5aXbYTVLkOVQYSqRvHO8E4GgjHkF6fxUi4hmTC1KsVOKDMRkklDjDI3wHYlirtSrh7h1bzKiKNU14NXdB32dAx4MeQMwSeEXVc4rzsLGJU7IMvD4W8dmdf9G7ZQCRj/AFXhOOXs1RRmwmiX0DvnQpNMQbzG6t7I8B1mpk0fyxFKa3YJs5b5bEzBq5WMxumfQLikgjkWpyspKyLKMFStWnPO9K664AgBwxRVXQJblCb8qKyvjdahEETOZpNHNIgnWGsqte0QUDRW5LEqRMYlcqSNR2sDXPeDBB6e6ATBPaiJzUm2wmkO5M3r7oK9B6fQszk6B3RJ+xuY31y2AJAF/OdSmZipFSpZlNVPKkJ1SWaNFPJvFhPx04z0GvbFbzOqG0IM6HeET7yOuqNGmS0pQN/D1Jd45yEfK5j2j5kkBGndKESWSYuWXbYvO17JGIhiU8XZdJwDmSRHR7IhPxVv7h+H166v7gWgqYgPfkSiFnb/5WTuCciikmN0i5zKZpDEb+PR1TlXfFX7I+VgLnRn44rJiAMBP35zZJj63x49hXyjOxOidUqXZKTBpEHidiJaWhkb4Djb3aXsgExjxBfD+yVABXusPt4vVopS+XlNmyxcIYn9jHwDjbt4DWJQiihoxw5tInVIHmvrQPehFht1i6Oo7EWkvP90Oh9WEoJy4YaOUmEQnwtEoje/95XBo6x4Dzienhp3r7IM+0SklOj8j8c1182GSgLeOtONgU1/E1xddUhkOiyE3IY/NuillMTZqRnOl9NcptftkNzz+IIqzHFhYmKHpsYhOqZYEO//4rMWFYV8AWSlWzJvldk8tsShFFCVFyrrRVp2dQM2GmAO/dEFe1DYOEVFykiSJuVJkSCJTqsPtQffA7MbJ+oa8eP9EFwBgI4tSk9Jr2Hl998w6pQCgOj8dN5xfAgB47M3Is6U6XCJPyphjb2M7pcoYch41y0qzAYSKUsGgvsLO1e3dNQWQJG0749TxPZ1mb83UR2qeVI6huw/5LpMoSsT4Xmt/4rzY7eToHhFFUfmc0Bs5FqXISNLsFrUzZrYjfG9+1g5/UEaNMwNVBv5UO9ZEUarNpc9MqZl0SgHAN66aD7NJwq66TnWzcbhEyLn4uzGaceN7OeyUipb5hemwW0xwe/xq0VQPZFkeLUrp4H1EcYJ2Sn2s5EmtNPhEC4tSRFGiju8lSKdUh2sEh5pDrcBX6OCXCREZn9oppaMTZ6JwLBIjfLMMO39NGd1jwPnUROGlQ0edUrIsj2ZK5UXeKQUAFblp+F8XlAIAHnuzLqLrivE9o3ZKjS2mlc1hp1S0WM0mLC4OvT6J83Y9ONExgOa+YdgsJlw0L1frw0FJTuIVpWRZxscNoU4pI+dJASxKEUWNGnSeIAF6u5QuqeWlWcg36AkQEelLufJGhJ1SZDRiA99ns8iVco348M7x0OjeNUs5ujcVZ1bovKNNR0WpnkEv3B4/gNEC+0x87ap5sJolvHeiW93CGI52dfOeMTulMhxWZDgsAGb390fnWq6M8B1o1E9RSnRJranKRarNovHRAMXZovtyBL5AYixbaegeQteAFzazCUuVbDGjYlGKKEpEplTXgCchNkuNnQMnIoqG8lyRKZU4n1RSchAb+I7OYnxvx5EOeANBzCtIx3yNQ3/1rlCH2/fqldG94iwHHNaZB42X5qTi5gvLAISypWQ5vBygDmV8z8gfFD60sQa3fq4CS4qN/QZab0RB4pCONvCJ9xFab90T8tLssJlDy1b09LoyGyJPallp1qxek/SARSmiKMlNs8GmhIEb/cXO6w/iXSWIVQ9z4ESUGMSn4409Q2G/ESPSg8VKUepEx8CMP2V/9XArAOAaBpxPqzBLFKX0kynVoIwdV8wg5Pxs962dD5vFhD2ne/DeifC6pToM3ikFALesrsAPblhi6EBmPVpWGipKHW52IaCDsPP+YR8+VjLT9PI+wmSS1G6pRAk7F3lStQYf3QNYlCKKGkmS1Fwpo88rf1zfgwGPH3npdsO3gxKRfohw2wGPHz2DXo2Phih8JdkpSLdb4A0Ecaoz8ky0QY9fHYvfwDypaYnCy4DHjwFlZE5rolOqMm/2o2fOLAduWV0OAPj3N+vCKtKLTimjZkpR7FTlpyPNZsawL4ATHQNaHw7ePd6FQFBGdX6a2iGtB2rYeYIspfqoYXTzntGxKEUURWKziJ4yEGZCtNxesTCfn2YRUdQ4rGY4lTebzJUiIzGZJNQ4QyN3R2aQK7WzrgMefxCVualYVMTRvemk2y1It4dyaPTSfR7NTikAuOeKajisJnxypk8tWE5GlmXDZ0pR7JhNEs5TPkQ+2NSn7cFAf6N7wugGPn28psxG94BH/YBkJYtSRDRWorzY7ajT5y8TIjI+dQMfi1JkMCLs/MgMNvC9ekjZure0CJLED3vCUZAZ6gjSS1FK7ZSKUudHQYYDt62pBDB9ttSAx49hXyB0vUx2StG5lqm5UtqGnQeDMt4+ps9c2hLlfVpTAozv7VXGI+cXpCMnzabx0cwei1JEUSQ6pVoN3Bba0D2IU52DsJgkXDI/T+vDIaIEUzYmVyoZPPXUU6isrITD4cDq1auxZ8+esK734osvQpIk3HDDDeO+v2nTJkiSNO5rw4YN4y5z7NgxXH/99cjLy0NmZiYuueQS7Ny5M1oPKWmJsPMjEYadD3sD2Kl82HMNR/fCJroqO3SSKxXtTikA+OplVUi1mXGouR9vfNY+6eVEl1SG3aKLTWakP8vKsgEAB5q0LUodbO5H14AX6XYLLtRZ1lGJ2jxg3PdpgsjsSoQ8KYBFKaKoKlIzpfTxqd5MiJbbCyvnINNh1fhoiCjRVOQmT6fUSy+9hM2bN+Phhx/Gvn37sHz5cqxfvx4dHR1TXq++vh7f+ta3cOmll0748w0bNqC1tVX9+vWvfz3u51/84hfh9/uxY8cO7N27F8uXL8cXv/hFtLW1Re2xJaMap9jAF1mn1NvHOjHkDaA0JwVLSjJjcWgJSYyp6SESoW/Ii74hH4DR17BoyE23446LKwEAP33zGIKThFSreVLskqJJiE6pI60uTbeAi/cRl87Pg9Wsr1JDcQIVpcTmvQsrjT+6B7AoRRRVRVmhF7s2l3Ff7HYquQZra/I1PhIiSkRifK+hO/GLUo899hjuuusu3HHHHVi8eDGefvpppKam4rnnnpv0OoFAALfccgseeeQRVFVVTXgZu90Op9OpfuXkjJ6UdnV14fjx43jooYewbNkyzJ8/Hz/60Y8wNDSEw4cPR/0xJhORKdXh9qB7IPzuHbF1b+MSJ0f3IiCKUnoY3xOvVwUZ9qh3Kt11aRUy7BYcbXPj1cMTF45Ft1hBBvOkaGIVuanIdFjg9QdxrD2ybs5o2lWnz9E9ACjJCb1Pa+4bNvQG4GFvAIeVMU29daPNFItSRFFUpKwabTVop9SQ148PToVWEzNPiohiIVnG97xeL/bu3Yt169ap3zOZTFi3bh1279496fW+//3vo6CgAHfeeeekl9m1axcKCgqwcOFC3HPPPejuHl0pn5ubi4ULF+KFF17A4OAg/H4/fvGLX6CgoAArV66c9DY9Hg9cLte4LxovzW5Ru2SOtoX3ps/jD2D7kdCbtI1LOboXiUIdZUrVK6N7lVEc3ROyU22489K5AICfvnUMgQm6pUSnVCE7pWgSkiRhWWk2AOCgRiN8He4R9b6vWKi/D7fFRMuQN4D+YZ/GRzNzB5r64AvIKMiwo1QptBkdi1JEUVSsdEp1D3oxogRSGsl7J7rh9QdRNicF1fnpWh8OESUg0SnV6hqBx2+818lwdXV1IRAIoLCwcNz3CwsLJx2je/fdd/Hss8/imWeemfR2N2zYgBdeeAHbt2/Hj3/8Y7z99tvYuHEjAoHQ36UkSXjrrbfwySefICMjAw6HA4899hhee+21cR1VZ9u6dSuysrLUr7Kyshk86sS3yClypcIr2r17vAsDHj+cmQ6sUN4wUnicaqeU9plS9V2hIno0R/fG+solc5GVYsWJjgH86UDLOT8XfwcF3LxHU1haKsLO+zS5f7FFcllpli67+hxWM/LSQ6HgzQYe4RMh5xdWzkmY7lsWpYiiKDvVCrsl9M+qrV/7T/Yipa5wXViQMC9yRKQveek2pNrMkGWgOQE24ESL2+3GrbfeimeeeQZ5eZMvmfjyl7+M6667DkuXLsUNN9yAV155BR999BF27doFILQ6/t5770VBQQHeeecd7NmzBzfccAOuvfZatLa2Tnq7W7ZsQX9/v/rV2NgY7YeYENQNfGGGnf9F2bq3YYkTJhN/r0ZCFGD0cD4lQs4r86LfKQUAmQ4r/vGy0Ljuf2w/Dn9gfCZQh1uM77FTiiYncqW06pTaqbyPuGKhfqctRNi5kc8/RJ5UbYLkSQEsShFFlSRJaoheqw5OoiIhy7Ku58CJKDFIkjSaK5XAI3x5eXkwm81obx+/Uau9vR1Op/Ocy588eRL19fW49tprYbFYYLFY8MILL+CPf/wjLBYLTp48OeH9VFVVIS8vDydOnAAA7NixA6+88gpefPFFXHzxxbjgggvwn//5n0hJScG2bdsmPV673Y7MzMxxX3Su0Q1803dKef1BvPlZqCh1DUf3IuZURm063COa57/Uq5v3YtMpBQCbLqrEnDQbTncN4nefNI/7mRhhZKcUTUVs4Ktrc8d9YsMXCOKd410A9B0BYvSw80BQHtcplShYlCKKMjGv3NpvrBe7I61utPaPwGE14XNVuVofDhElsGTIlbLZbFi5ciW2b9+ufi8YDGL79u1Ys2bNOZevqanBoUOHsH//fvXruuuuw9q1a7F///5Jx+mamprQ3d2NoqJQ0WNoKPR3ajKNP8UzmUwIBrXbyJQoxPjeiY4B+AJT/32+f7ILrhE/8jPsWFmROJ9ox0t+eqgryBeQ0TPo1fRYRNB5LDKlhDS7BXdfHuqWemL78XH/f3UqnVKF7JSiKRRnOZCbZoM/KIc9YhwtH9X3YMDjR166Te3Y0iO1KGWw5gHhWLsb7hE/0mxmdflGImBRiijKxAY+o3VK7VS6pC6uzoPDatb4aIgokYlOqTMJvoFv8+bNeOaZZ7Bt2zYcOXIE99xzDwYHB3HHHXcAAG677TZs2bIFAOBwOLBkyZJxX9nZ2cjIyMCSJUtgs9kwMDCAb3/72/jggw9QX1+P7du34/rrr8e8efOwfv16AMCaNWuQk5OD22+/HQcOHMCxY8fw7W9/G6dPn8YXvvAFzf4uEkVpTgrS7RZ4A0Gc6hyc8rKvKZvU1p9XCDNH9yJms5jU/Bctc6VcIz50K0WxWHZKAcCtn6tEfoYdTb3D+O3HTer32SlF4ZAkaUyuVHxH+MTo3uULCnQ9qmz08b2PldG988tzYDEnTikncR4JkU6ITimjtYWKXyYc3SOiWBNv7M4kcKcUANx888149NFH8d3vfhcrVqzA/v378dprr6nh52fOnJky5+lsZrMZBw8exHXXXYcFCxbgzjvvxMqVK/HOO+/Abg91UOTl5eG1117DwMAArrzyStTW1uLdd9/F//zP/2D58uUxeZzJxGSSsFD5dPpo2+SdCP5AEK9/qozuLeHo3kyJsGQtN/CJ4nleug0ZDmtM7yvFZsb/vqIaAPCzHcfh8Qcw4PFjyBsaxWKmFE1HbOA70BjfotQO9X2E/rbujSU6pYwadP6xMrqXSHlSAGDR+gCIEk1Rtn6COcPVO+jFvjOhFzkWpYgo1sT4XqIXpQDgvvvuw3333Tfhz0Q4+WSef/75cX9OSUnB66+/Pu191tbWhnU5mplFRRnY29CLz1pduH5FyYSX+fB0D3qHfJiTZsOquYmT+xFvziwHPmt1aVqUGs2Tit3o3lh/t6ocv3j7FFr6R/DinkZcMj+0+CDdbkGanW/daGpidC6eG/jOdA/hZOcgzCYJl87Xd1GqxOCZUh/XJ16eFMBOKaKoK84y3qzyX493IigDNc4M9cWaiChWyscUpbQOMCaKVI2SK3V0ig18rx4OdcCtP68woUYs4q0wM9QZ1KZhUUrkScV6dE9wWM2498p5AICndp5QO7UKMtklRdNbpozvnegYwKDHH5f7FBEgtRU5yEqJbTfhbJXkhN7ndLg98PjjGwY/W819w2juG4bZJGGFEmqfKPhbkijKnAYMOt/B0T0iiqOS7BRIEjDkDahZLURGMd0GvkBQxmuHQ1sXN3B0b1YKM8X4nnaZUvVdoU6pWIacn+3m2jKUZKegw+3B49uPA+DoHoWnINMBZ6YDQRn4tCU+YefifYSet+4JOalWOKyhEoiRplqA0Typ84ozE65rkkUpoigTnVJ9Qz4Me/VfgQ8EZbx9rBMAsHah/n+ZEJHxOaxmOJU3m8kwwkeJRWRKdbg96B44t1jycX0PugY8yEqx4qJqbrOdjdGiVPJ0SgGhkPevXxXqljrQ2Adg9O+CaDoi7PxgU1/M72vI68fuU90AjPHhtiRJhs2VEqN7ibjNlUUpoijLTLEg1RbaXmeEbqn9jb3oG/IhK8WKC8qztT4cIkoSybKBjxJPut2iFiiOtp07wveqsnXv84sLYeXo3qw4dVCUEplS8eyUAoAbLygdVwhjpxSFazRXKvZh57tPdsPrD6IkOwXzC9Jjfn/RYNQNfB8pnVKJlicFsChFFHWSJKkb+FoN0BYqWm4vW5DP3AsiipvyJAo7p8RTo3RLnT3CFwzKeE0pSm1c4oz7cSUakaOkVVFqyOtHhzvUDRfvopTVbMI3rpqv/pmdUhSuZUre0MGm2Belxo7uSZIU8/uLhtGwc/2/TxNcIz7UtYc+BKllpxQRhaMoyzibHXYcDY3uXanzFa5ElFhYlCIjG82VGt8p9UljH9pcI0i3W9StaTRzolOqa8ALXyAY9/uv7wq9PmWnWpGVGv8A5+tXlKA6P1QMi9f2PzK+pUqn1OmuQfQP+2J2P7IsY6eB8qSEYgNu4NvX0AtZDo0RFyRggZpFKaIYEJ1Seg/Qa+0fxpFWFyQJuHyBcX6ZEJHxleeyKEXGpW7gaxvfKfXqodDWvXWLCmC3mON+XIkmJ9UGqznUfSE6luKpQRnd06ogZDZJ2PaVVfjRjUtxlYHe9JO25qTZUKpsmfs0hiN8de1utPSPwG4x4XNVxsnPKzFgppTIk6qtSLzRPYBFKaKYKBIVeJ0XpXbVhbqkzi/Lxpw0m8ZHQ0TJRHRKNbIoRQa0WOmUOt4+oHbwyLKs5kltXMqte9FgMkkoyNAuV6peybybG8eQ87OV5qTiy6vKYTIZYzSK9GF5aTYA4EAMR/h2KtMWF1XnIsVmnCK8ETulRvOkEm90D2BRiigmitVMKX2/2Ik5cG7dI6J4E0WpNtcIRnz631RKNFZpTgrSbGZ4A0Gc6gx10xxq7kdz3zBSbWZcvoAj8dFSKHKlNPigT+tOKaKZEhv4DjX3xew+jDi6B4zvlJJlWeOjmZ7XH8R+ZQtnLYtSRBQupyhK6ThAz+MP4L0TXQCMscKViBLLnDQb0mxmyDLQZLANOEQmk4SaovEjfH85FOqSWltTAIfVOF0DeifOqbTplFI27+Vp1ylFNBNiA1+sws77h3zYeyY0UnaFwT7cdmY5IEmAxx9E96BX68OZ1qct/fD4g8hJtaI63xgbDiPFohRRDIi2UD13Sn14qgdD3gAKM+04rzhT68MhoiQjSRLKOMJHBiY28H3W6lJG90J5Utcs4eheNInxvTaXFplSodcmdkqR0SxROqWaeofRPRD9fztvH+9EIChjfkG6+rvcKGwWEwoyQh2YRhjhE3lSKyvmGGbDYaRYlCKKARF07hrxY9Dj1/hoJjZ2dC9RX+CISN8qGHZOBiY28B1tdeNIqxsN3UNwWE24YiFH96JJdEp1xLlTasQXQKsyMljJohQZTKbDiqq80P+3h2IQdr7LoKN7gpFypRI9TwpgUYooJjIcVmTYLQD02S0lyzJ21ilFKYP+MiEi4xO5UqIbgchIFhWFOqWOtLrULqnLF+QjTfn9T9EhMqXa4lyUEsXyDIcFOanWuN43UTSouVJRHuELBGXsOhYKOTfq+wiRK6X3+ABZlvFxg7J5rzIxN+8BLEoRxYz4ZK9Fh7lSp7sG0dA9BKtZwsXz8rQ+HCJKUqIoxU4pMqKFzlCnVIfbg99+3AQAuIZb96KuMFObTKn6LiVPKjeNHeVkSMtitIHvQFMfega9yHBYsLLCmN07JWqnlP7ep411qmsQPYNe2CwmLClJ3LgVFqWIYqRIebFr02BbzHTE6N7qublI5ye6RKQRZkqRkaXbLeO2SNrMJsOOsujZaFEqvplSo3lSxsrLIRKWxWgDn9i6d9n8fFjNxiwnGGV8b6+SJ7WiNBt2S+Iu0DDm/0VEBlAsOqV0OL73/sluAGDuBRFpSoQHn+kZMsRaZqKziRE+ALh0fh4yHBzzijZRlBrw+DEQx5xOdfMe86TIoM4rzoRJChV0o9lpmAgRIKJTqlnnRSmRJ1WbwHlSAItSRDFTlKVs4NNZW6gsyzjQ2AcAhm25JaLEUJKdAkkChn0BdMZgOxBRrNU4R8cpNnJ0LybS7Ra1qzueI3yiKMVOKTKqVJsF8wtChfODURrh63CN4HCzC5Jk7A+3jdIpJfKkLkzgPCmARSmimCnSaadUU+8wuge9sJoldXMQEZEWbBYTipUCPkf4yIjE71GLScLnFxVqfDSJS4Sdx7Uo1RV6TarMY6cUGddo2HlfVG5PdEktK81GXro9KrepBdEp1T3oxYgvoPHRTKzT7cHprkFIEnBBeWI3ErAoRRQjRdmhopTeMqUOKL+UFhVlwmFN3NlkIjKGsjmhE0OGnZMRXTI/D6sq5+Duy6uRxQ1tMRPvsHOPP6B+qMjxPTIykSt1sDk6nVIil3atgbukACAzZbQDU68jfHsbQqN7CwoyEv73C4tSRDGiju/prSiljO4tVzZyEBFpSd3A163Pk0KiqaTbLfjN3WvwrfULtT6UhOaMc9h5Y88wZBlIs5mRl26Ly30SxYLYwHewqX/W2Y1efxDvHu8CAMMvdZAkCcXZYlO6Ps8/PlZCzhM9TwpgUYooZsT43oDHD9eIT+OjGSXWwi4vy9b2QIiIMD7snIhoIgWZ8e0+b1DzpNIgSVJc7pMoFmqcGbCYJPQMemfdEfRRfQ8GvQHkpduxpDgrSkeoHb3nSn2UJHlSQJyKUh6PBytWrIAkSdi/f39Y15FlGRs3boQkSfjDH/4Q0+MjioU0uwWZjlBbqF5G+PyBIA4pRakVZcb/ZUJExlcmOqV6BjU+EiLSK6eSKdXhjs/5VH23yJNiyDkZm8NqxkJnKOz80CzDzseO7plMxi/Wqhv4evVXlBry+vGpMnLJTqkoefDBB1FcXBzRdR5//HF+MkGGp7cK/InOAQz7Aki3W1CVl6714RARjY7vsVOKiCZRqGGnFJHRiRG+A7MsSu1UilJGH90TxPu0Zp1tSgeA/Y198AdlFGU51OJZIot5UerVV1/FG2+8gUcffTTs6+zfvx///u//jueeey6GR0YUe2KETy+5UiJPallpVkJ8wkFExieKUu0uj2434BCRtgqz4psppXZK5bJTioxPhJ0fau6b8W3Udw3iVNcgLCYJF8/Pi9KRaatEZ80DY4k8qZUVOUnRqGOJ5Y23t7fjrrvuwh/+8Aekpob3oj40NIS///u/x1NPPQWn0znt5T0eDzye0V9QLpdrxsdLFG1OEXaukxe7/Y3MkyIifclJtSLDboHb40dT7xDmFWRofUhEpDOiU6rDPYJgUI75B2vslKJEsrRE2cCnhJ3PpMixsy7UJXVh5RxkOhJjE1xJjuiU0sf7tLH2KnlStRWJP7oHxLBTSpZlbNq0CXfffTdqa2vDvt7999+Piy66CNdff31Yl9+6dSuysrLUr7KyspkeMlHUFeu0U4qb94hILyRJUnOlGro5wkdE5yrICGVK+QIyeoe8Mb0vXyCIJiVjppJFKUoAC50ZsFlMcI/41S7ASO1IsNE9YHR8r7V/GMHg7DYTRlMwKGPfGbF5L/FDzoEZFKUeeughSJI05dfRo0fx5JNPwu12Y8uWLWHf9h//+Efs2LEDjz/+eNjX2bJlC/r7+9WvxsbGSB8SUcwUqS922helhr0B1LW7AQAr2ClFRDrCXCkimorVbEJeug0A0OaK7TlVc+8wAkEZDqtJLYYRGZnVbMLiokwAwMGmvoivP+jx48NTPQCAtQlUlCrMsMNskuALyOgaiM9ocDiOdwzAPeJHqs2MGmdydI9HPL73wAMPYNOmTVNepqqqCjt27MDu3btht49/Ma+trcUtt9yCbdu2nXO9HTt24OTJk8jOzh73/ZtuugmXXnopdu3adc517Hb7OfdBpBeiU6qlX/u20E9b+hEIyijMtMOpHBcRkR6U57IoRURTK8x0oGvAiw6XB+dFtj8pIqfF6N6cNOZvUsJYVpqF/Y19ONTUj+tXlER03fdOdMEbCKJsTgqq8xOne9BiNsGZ6UBz3zCa+oZRkKmP90didG9FWTYs5rjspdNcxEWp/Px85OfnT3u5J554Aj/84Q/VP7e0tGD9+vV46aWXsHr16gmv89BDD+Ef/uEfxn1v6dKl+OlPf4prr7020kMl0pwo/rT2jcx4hjta9qsh59maHQMR0UREp1Qji1JENInCTAc+bXHFvFOqoUvkSTHknBJH6Py/AQdnsIFvZ10nAODKhQUJF7pdnB0qSrX0DeOCcn3kN4mi1MokyZMCYhh0Xl5ePu7P6emh9fPV1dUoLS0FADQ3N+Oqq67CCy+8gFWrVsHpdE4Ybl5eXo65c+fG6lCJYqZICTof9gXgGvYjK1W7YECxBpaje0SkN+XMlCKiaYiw8/YYF6VE5s7cvMTpCCESG/gOK5MT5jC7AGVZxi4l5DyRRveEUK5Ur6428Ik8qQuSqCilaT+Yz+dDXV0dhoZ4EkqJKcVmRo5SiNJ6hI8h50SkV2MzpWRZP2GjRKQfhZmhuI5YF6W4eY8SUXV+OlJtZgx5AzjVORD29Y60utHaPwKH1YTPVeXG8Ai1UaLk/zb36qMo1T3gwWmlW/OCsuQpSsWsU+pslZWV55xoTvS9s/HklIyuKCsFvUM+tPYPY5ESMhhvPYNeNatlqfJJCRGRXhRnp8AkAR5/EJ1uj25yHYhIP5xqp1RsA4lFx2Ylx/cogZhNEpYUZ2FPfQ8ONPVjfmF4Ado7lS6pi6vz4LCaY3mImhAb+Jr7tF9KBQD7zvQBAOYXpGs6YRNvyZGcRaShIhF2ruGL3QFl00ZVfhqyUpLnBY6IjMFmMaknhgw7J6KJiPG9thhuNPYHgmjsDb0GVXB8jxKM+GD6UAQb+HYeTdzRPWC0U0ov43sfN4S2HCZTnhTAohRRzBVlx/4kajpidG8FR/eISKeYK0VEUxFFqQ537M6nWvtH4AvIsFlMKGLHJiUYkSt1sDm8sPPeQa+ab5SwRakc0Smlj6LUvobky5MCWJQiijkRdq5lppSaJ8WQcyLSqbG5UkREZxOZUl0DXnj9wZjcR72SJ1U+JxWmMIOgiYxCbOD+rMUFX2D6f0N/Pd6JoAzUODPUjqJEIyZa+od9GPD4NT0Wrz+oLqaqZVGKiKKpWOmUatVofE+WZXX9K4tSRKRXZUpRqpFFKSKawJw0G6zmUKGocyA2uVL1zJOiBFYxJxUZDgs8/iCOtbunvfwOZXTvioWJ2SUFABkOKzIdoZjtVo27pT5t6YfXH0ROqjXptn+yKEUUY87M0CcLbTHeFjOZpt5hdA96YTVLWFQUXqghEVG8sVOKiKYiSRIKMmIbidDQxc17lLhMJglLS0Su1NQjfIGgjLePdQIArkzQ0T2hJCd0/tGkcVFqrzK6t7IiB5KUXJ2aLEoRxZjolGrpG9Zkm6QIOV9clAm7JfG2ZhBRYqjIZVGKiKbmVEZtOmL0QR87pSjRiRG+6XKl9jf2om/Ih6wUKy4oz479gWmoZMx7NS2J/K5ky5MCWJQiijkRzOnxB9E75Iv7/TNPioiMQHRKdbg9GPYGND4aItIjkSsVq+7zhm52SlFiU8POp9nAJ0b3LluQD4s5sUsGxTrYwCfL8minVDmLUkQUZQ6rGblpNgDavNgdaFTypLh5j4h0LCvFigwl10GsZCciGkt80Nfuin6mVDAoo6FHdEqxKEWJSYzv1bW5MeKb/AOgHUdDo3trF+bH5bi0JELcm3u1K0o19Q6j3eWBxSSp3WzJhEUpojgoyo5tBsJk/IEgDjUz5JyI9E+SpNFcqW4WpYjoXKNFqeifT7W6RuD1B2ExSWr0AlGiKc1JwZw0G3wBGXVtE4edt/WP4EirC5IEXL4g8YtSo51S2uT/AqOje+cVZyLFlnxxKyxKEcVBUVboxa61P74V+OMdAxj2BZBht6AqybY4EJHxMFeKiKbijGFRSoScl89JTfhxJUpekjQadj7ZCN/OutDo3oqybOSm2+N1aJoRRalmDcf3RkPO52h2DFriKy5RHBQrwZwtce6UEnlSy8qyYDIl1xYHIjKeMm7gI6IpFMQwU0qEnFcw5JwS3Giu1MRh5yJP6sqFib11TyjNGd2U7g8ENTmGsZv3khGLUkRx4BSdUnGuwIvNe8k4m0xExlPOohQRTUGM73XEIFOKIeeULMT7gkMTbODz+AN470QXAGBtTXIUpfLT7bCaJQSCMjrc0X9tmc6gx48jrS4AwAUV2XG/fz1gUYooDkQ2QWucO6X2M+SciAyERSkimoooSg14/Bjw+KN62/VKUaqSnVKU4ESn1LF2N4a84/8d7TndgyFvAAUZdpxXnKnF4cWdySTBKaZaNFlK1YegHApcF5EvyYZFKaI4GM2Uil9Rasjrx7H2UIDhCoacE5EBVMwJdSg09gwhGJQ1Phoi0pt0uwXp9tCWzmjnSjWI8T1mcFKCK8x0oCDDjqAMfNbiGvczMbq3dmEBJCl5oj9KNMyVEqN7FyTp6B7AohRRXBRljW7fi9cbrU9bXAgEZRRm2tXqPxGRnhVlO2A2SfD4g5q00BOR/hUquVLtUfygT5blMZ1SLEpR4hMjfGfnSu0URakkGd0TtAw7/1jkSZVnx/2+9YJFKaI4KMx0QJIAbyCI7kFvXO5ThJxzdI+IjMJqNqnjzhzhI6KJiBG+dnf0ilIdbg9GfEGYTZLaMUGUyEbDzvvU753qHEB99xCsZgmXzM/T6Mi0If7dx3t8LxiUse9Mcm/eA1iUIooLm8WEPGWlalucRvj2i6IUR/eIyECYK0VEU3Fmiu7z6HVT1neFuqRKslNgs/DtESW+paIoNSbsfGddJwBg1dw56phsslDH93rjW5Q60TkA94gfKVYzFhVlxPW+9YSvukRxUiwC9Prj82InNu8xT4qIjIRFKSKaSoHolIpippSaJ8WQc0oSy0pCRalTnYNwj/gAjBndW5hco3vA6PheS198l1KJPKkVZdmwmJO3NJO8j5woztSw8zi0hXYPeNDYE7of8UkIEZERlCth52eUfBciorGcIlMqikWp08yToiSTm25Xu4MONfdjwOPHh6e7AQBXJlmeFACU5GgzvieKUiuTOOQcYFGKKG5E2Hg8NvCJVtzq/DRkOqwxvz8iomhJtE6pp556CpWVlXA4HFi9ejX27NkT1vVefPFFSJKEG264Ydz3N23aBEmSxn1t2LDhnOv/+c9/xurVq5GSkoKcnJxzbofIqApj0ikVKkqxU4qSiciVOtTUj3ePd8EXkFGZm4qq/HSNjyz+ipXmAbfHj/5hX9zudx+LUgBYlCKKGxHeG4+i1AHmSRGRQY0WpeK/ASfaXnrpJWzevBkPP/ww9u3bh+XLl2P9+vXo6OiY8nr19fX41re+hUsvvXTCn2/YsAGtra3q169//etxP3/55Zdx66234o477sCBAwfw3nvv4e///u+j9riItFSYJYpS0cyUChXB5+axU4qSh7qBr7kfu+pCv5euSMLRPQBIsZkxJ80GIH7dUj2DXpxS8uzOT+LNewCLUkRxo47vxSFTShSlmCdFREYjilJdAx4Mef0aH83sPPbYY7jrrrtwxx13YPHixXj66aeRmpqK5557btLrBAIB3HLLLXjkkUdQVVU14WXsdjucTqf6lZMz+gmr3+/HN77xDfzkJz/B3XffjQULFmDx4sX40pe+FPXHR6QF0SnV4R5BMCjP+vZkWR7TKcWiFCUP0Sl1oLEPO5WiVDKO7gnx3sAnRvfmFaQjO9UWl/vUKxaliOJEdErFOkBPlmUcaAqN7y1XPgEhIjKKrFQrslJCY8eNBu6W8nq92Lt3L9atW6d+z2QyYd26ddi9e/ek1/v+97+PgoIC3HnnnZNeZteuXSgoKMDChQtxzz33oLu7W/3Zvn370NzcDJPJhPPPPx9FRUXYuHEjDh8+POXxejweuFyucV9EelSQEcqU8gVk9Ax5Z317XQNeDHoDkCSgbE7KrG+PyCiWKGHnTb3DaHd5kGozY3XVHI2PSjvivVpznItStUk+ugewKEUUN06lU6rdFZ1P9ibT1DuMnkEvbGYTapJ4tSgRGZfolmowcNh5V1cXAoEACgsLx32/sLAQbW1tE17n3XffxbPPPotnnnlm0tvdsGEDXnjhBWzfvh0//vGP8fbbb2Pjxo0IBAIAgFOnTgEAvve97+Ff/uVf8MorryAnJwdXXHEFenp6Jr3drVu3IisrS/0qKyuL9CETxYXVbEJeeqirIBq5UuJ1pjgrBXaLeda3R2QUWSnWcSOrF8/LS+p/A2IDX7yKUiJP6gIWpViUIoqXwgw7TBLgD8roGoheDsLZ9iuje4uKM5P6FwsRGVeihZ2Hw+1249Zbb8UzzzyDvLy8SS/35S9/Gddddx2WLl2KG264Aa+88go++ugj7Nq1CwAQDAYBAP/8z/+Mm266CStXrsQvf/lLSJKE3/72t5Pe7pYtW9Df369+NTY2RvXxEUVTNMPO67tDrzOVeQw5p+SztGR0S/faJM2TEkbH92Kf/+v1B3GgqQ8AQ84BwKL1ARAlC4vZhIIMB9pcI2jpH0GBckIVbWqeVGnW1BckItKpMqUo1WjgolReXh7MZjPa29vHfb+9vR1Op/Ocy588eRL19fW49tpr1e+JApPFYkFdXR2qq6vPuV5VVRXy8vJw4sQJXHXVVSgqKgIALF68WL2M3W5HVVUVzpw5M+nx2u122O32yB4kkUYKMx34tMUVlbBz5klRMltWmoU/HmgBAKytydf4aLQlilLNvbE/9/is1QWPP4jsVCuquGCBnVJE8eRUNsa0xrAtVFTduXmPiIxKrGU3cqeUzWbDypUrsX37dvV7wWAQ27dvx5o1a865fE1NDQ4dOoT9+/erX9dddx3Wrl2L/fv3TzpO19TUhO7ubrUYtXLlStjtdtTV1amX8fl8qK+vR0VFRZQfJZE2RKdUWxQ2GqudUrnslKLks6Y6F5IUyjUSS5mSVXEcO6VEntTK8hxIkhTz+9M7dkoRxVFxtgP7G4HWKJxETcQfCOJQcyjkfBlDzonIoNRMKQMXpQBg8+bNuP3221FbW4tVq1bh8ccfx+DgIO644w4AwG233YaSkhJs3boVDocDS5YsGXf97OxsAFC/PzAwgEceeQQ33XQTnE4nTp48iQcffBDz5s3D+vXrAQCZmZm4++678fDDD6OsrAwVFRX4yU9+AgD427/92zg9cqLYKswMdfV1uKOXKcVOKUpG5xVn4ZWvXQJnjCY4jKQkR8n/dY/AFwjCao5d/w7zpMZjUYoojsQnEK39semUOtY+gBFfEBl2C1tBiciwRFGqqWcYwaAMk8mYnyLefPPN6OzsxHe/+120tbVhxYoVeO2119Tw8zNnzsBkCv+k12w24+DBg9i2bRv6+vpQXFyMq6++Gj/4wQ/Gjd795Cc/gcViwa233orh4WGsXr0aO3bsQE4OT34pMTij1CklyzJOd4WKUpUsSlGSOq+YkR8AkJtmg81igtcfRFv/iBolEG2yLOPjhtDiEeZJhbAoRRRHRcr4XkuMOqXE6N6ysizDvokjIirKcsBikuANBNHuHjH0SMF9992H++67b8KfiXDyyTz//PPj/pySkoLXX3992vu0Wq149NFH8eijj4Z7mESGMhp0PrtMqd4hH9wjfgCjxXAiSk6SJKEkOwWnuwbR3Dccs6JUc98w2l0eWEwSlnOyBQAzpYjiSu2UilGmlAg55wscERmZxWxS2+jPdBt7hI+Ioi9a2/fqldE9Z6YDKTZuLCZKdqMb+GKX/yvypM4rzuTrjoJFKaI4KsqOXjDnRPaLohRDzonI4BIlV4qIok9kSnUPeuH1B2d8OyJPqjKPXVJEFMr/BYDm3tgVpZgndS4WpYjiqDhLBOh5EAjKUb3tIa8fx9rdAIAVLEoRkcGJtvlGFqWI6Cxz0mywmkMxBbMJO6/vEpv3mCdFRGM28MUo/xcA9p5RNu+xKKViUYoojvIz7DCbJASCclQ2xox1uNmFoBxqQS/kBg0iMjjRKXWGRSkiOoskSSjImH2uFDfvEdFYYnyvuS82Uy2DHj+OtIaaCFiUGsWiFFEcmU0SCjNCLectUX6xO6iEnC8v4wYNIjI+FqWIaCrOrNnnStV3i04pju8R0ZiiVG9szj0ONPUhEJRRnOUw9BKXaGNRiijOipQXu2jnSjFPiogSSTnH94hoCiJXajZFKXZKEdFY6vhe3whkObpRKwCwt555UhNhUYoozoqUT/ZaozyrfEDplFrBzXtElADKlc6FrgEvBjx+jY+GiPRGRBW0zbAo1T/kQ++QDwBQwU4pIsLoUqphXwB9yutDNIk8qVoWpcZhUYoozsZW4KOle8CDxp5hSBKwpJTje0RkfJkOK7JTrQDYLUVE5xJFqY4ZZko19IS6pPIz7EizW6J2XERkXHaLGflK1EpzX3QbCIJBWd28t7JiTlRv2+hYlCKKM6f6yV70XugONvUDAKrz05HpsEbtdomItMRcKSKajHo+NcM4BOZJEdFEitWw8+gWpU52DsA14keK1Yyaooyo3rbRsShFFGfFSltoNDul1Dwpju4RUQJhrhQRTaZAZErNcJtxQxfzpIjoXKXqVEt0i1J7lS6p5WVZsJpZhhmLfxtEcSY2LUQzU0rNk+LmPSJKIKIo1dDNohQRjSc6pdpn2Cl1Wgk5Z6cUEY0lGgiae2NTlFrJPKlzsChFFGciQK/D7YEvEJz17cmyjAPcvEdECYjje0Q0GZEpNegNzGgZgih2s1OKiMZS83+jvJRKhJyzKHUuFqWI4iwvzQ6rWYIshwpTs9XYM4zeIR9sZhNqnJlROEIiIn3g+B4RTSbNbkGGElA+k1ypBrVTikUpIhpVIjKlotgp1TPoxanO0GvOBeUsSp0t5kUpj8eDFStWQJIk7N+/f9rL7969G1deeSXS0tKQmZmJyy67DMPD0a1SEmnJZJLUT/daozCrvF8Z3VtUnAmbhXVmIkoc5cpYTVPvMAJBWeOjISK9EblSHa7IilLuER+6BrwAgIo8ju8R0ajRoPPo5f+KrXvzCtKRnWqL2u0mipi/g33wwQdRXFwc1mV3796NDRs24Oqrr8aePXvw0Ucf4b777oPJxDfalFiKs0Rb6Oxf7MTo3opS5kkRUWIpykqBxSTBGwiiLcI3nUSU+JxZYqNxZK8PYnQvN83GrcVENI7olOoa8GDEF4jKbaqje+ySmpAlljf+6quv4o033sDLL7+MV199ddrL33///fj617+Ohx56SP3ewoULY3mIRJoQJ1HR6JRinhQRJSqzSUJpTgrqu4dwpntIPVEkIgKAwgwl7NwVWRzCaJ4Uu6SIaLzsVCtSbWYMeQNo7R/B3LzZj/gy5HxqMWtBam9vx1133YVf/epXSE2d/gW/o6MDH374IQoKCnDRRRehsLAQl19+Od59990pr+fxeOByucZ9EemdCDtvnWWnlC8QxOGWfgAsShFRYipjrhQRTaIwSxSlIjufqmeeFBFNQpKk0bDzKDQQ+AJBtYngAhalJhSTopQsy9i0aRPuvvtu1NbWhnWdU6dOAQC+973v4a677sJrr72GCy64AFdddRWOHz8+6fW2bt2KrKws9ausrCwqj4EolsT4Xusstzoca3djxBdEhsOCuTyxIqIEJDoZuIGPiM5WmBHKlIq0KCVCzrl5j4gmMporNfui1GctLnj8QWSnWlEVha6rRBRRUeqhhx6CJElTfh09ehRPPvkk3G43tmzZEvZtB4NBAMBXv/pV3HHHHTj//PPx05/+FAsXLsRzzz036fW2bNmC/v5+9auxsTGSh0SkiaKs6HRKHWhUuqRKs2EySbM+LiIivREb+BpYlCKis8w0U6peGd+rZMg5EU0gmhv4xOjeBeU5fL82iYgypR544AFs2rRpystUVVVhx44d2L17N+x2+7if1dbW4pZbbsG2bdvOuV5RUREAYPHixeO+v2jRIpw5c2bS+7Pb7efcD5HeFYmg81ludRjNk2LIORElJlGUYqcUEZ2tQNlm3BFxphQ7pYhociVK1Eo0xveYJzW9iIpS+fn5yM/Pn/ZyTzzxBH74wx+qf25pacH69evx0ksvYfXq1RNep7KyEsXFxairqxv3/WPHjmHjxo2RHCaR7olMqa4BD7z+IGyWmU3SHmjqAxDqlCIiSkTMlCKiyTgzRzOlgkE5rC6EIa9fDUavZNA5EU1AzZSaZdSKLMv4uKEHAItSU4nJ9r3y8vJxf05PTwcAVFdXo7S0FADQ3NyMq666Ci+88AJWrVoFSZLw7W9/Gw8//DCWL1+OFStWYNu2bTh69Cj+7//9v7E4TCLN5KbZYLOY4PUH0e4aUd90RWLI68exdjcAYAVDzokoQYlOqZ5BL9wjPmRwfTsRKfIz7JAkwB+U0TPkRV769NMTousyK8WK7FRbrA+RiAwoWuN7Lf0jaHd5YDZJbCKYQkyKUuHw+Xyoq6vD0NDoJ5/f/OY3MTIygvvvvx89PT1Yvnw53nzzTVRXV2t1mEQxIUkSirIcaOgeQkvf8IyKUoebXQjKoXwq0b5ORJRoMhxWzEmzoWfQi8aeYSwuZlGKiEKsZhNy0+zoGvCgrX8krKJUfZfYvMcuKSKa2GinVPhdmBMRo3vnFWcixWaO2vElmrgUpSorKyHL8rTfA0Jh6g899FA8DotIU87MUFFqpmHnap4Uq+5ElODK5qSiZ9CLMz2DWFycqfXhEJGOFGaGilId7hEA02dsipBz5kkR0WScWQ6YJMDrD6J70Iv8jJllWO8bE3JOk5tZkA0RzZqowM+0KLVf5ElxdI+IEhzDzoloMiJXqq0/vLBzEXJeydXsRDQJq9mEQuW1pXkWYecMOQ8Pi1JEGilS1hi3zjBAj5v3iChZlM8JFfFZlCKisxWMCTsPR31X6HWE43tENBV1hG+GRakhrx+ftboAsCg1HRaliDRSpL7QRd4p1TXgQVPvMCQJWFrCohQRJbaKOaGOhjM9s1/NTESJxRlhUUp0SnF8j4imMtui1P7GPgSCMoqzHOpt0cRYlCLSSJFoN3dF/kJ3UBndq85P5yYqIkp4YhnEGeXNJBGRUJgZynoJpyg14gugRYlNYKcUEU1FbOBrmuEGPjVPil1S02JRikgjRdnK+N4MOqX2N/YDYMg5ESWHcuXNY1PvMALBc5ekEFHyKswSnVLTZ0o1KiPAGXYL5qTZYnpcRGRsJcp7tZl2SjFPKnwsShFppDgrVH3vHvRixBeI6LoiT2oF86SIKAk4Mx2wmiX4g/KMc/iIKDEVZoQ/vqdu3stLhSTNbMU7ESUHdXxvBucdwaCMfWf6ALAoFQ4WpYg0kp1qhd0S+ifYFsEGPlmWcYCb94goiZhNEspyuIGPiM7lVDqluge98PqDU16WeVJEFK6SnFBRqnkG43unugbQP+yDw2rCoqLMaB9awmFRikgjkiSpFfjWCIpSZ3qG0Dfkg81sQo2TL3JElBxGc6VYlCKiUTmpVtjMobc0He6pz6fqlaIU86SIaDrifVrvkA9DXn9E1xWje8tLs2E1s+QyHf4NEWmoSPl0L5JxlP3K6N7i4kzYLPwnTETJoXwOO6WI6FySJKFADTufOleqQYzvsVOKiKaR6bAiw24BEPm2dOZJRYbvaIk0VJQVeafUASXkfAVH94goibAoRUSTKcwML1fqdJfolGJRioimp47wRRh2/rFSlKqtZFEqHCxKEWlIdEpFstVhNE+KIedElDzEBr5GFqWI6CzOMIpSHn9APd/i+B4RhUMNO4/gvVrPoBenOkMF8PPLWJQKB4tSRBoqUlaNhht07gsE8WlLqFNqeWl2rA6LiEh3RKdUA4tSRHQWMb7XNkVRqql3GEEZSLWZkZ9hj9ehEZGBFWdH3kDwyZlQl1R1fhpy0mwxOa5Ew6IUkYaKs8Sq0fCKUsfa3RjxBZHpsLD1nIiSigg67xvyoX/Yp/HREJGeiE6pjikypcZu3pMkKS7HRUTGVpIdOveIZAMf86Qix6IUkYZEp1S4QeciT2p5WTZMJp5QEVHySLdbkKt84sgRPiIaS2RKTdV5Xt8Vet3g6B4RhUt0SkWSKcWiVORYlCLSUFFmqFOqb8iHYW9g2ssfUDbvcXSPiJKR6JZiUYqIxlKDzt2TF6XGdkoREYWjRGRKhdlA4AsE1fxfFqXCx6IUkYYyUyxItZkBhNctNRpynh3DoyIi0qeKXG7gI6JzFSqZUu1TdUp1s1OKiCIjtu+19o0gEJSnvfyRVhdGfEFkpVhRlZce68NLGCxKEWlIkgUGO2gAABtySURBVCR1A1/rNLlSgx4/jrW7AQDLS7l5j4iSD8POiWgiolNq0BvAgMc/4WXYKUVEkSrIcMBskuAPyuh0T55ZJ4wd3WPUSvhYlCLSWLirRg839yMoA8VZDhQoJ19ERMmE43tENJE0uwUZdguAiXOlfIEgmpSg4so8dkoRUXjMJkldpBBOrtTHzJOaERaliDTmDCOcE+DoHhGR6JTi+B4Rna0wS2zgO/d8qqVvGP6gDLvFhMIMfrBHROETI3zhFKX2KUWpC8pZlIoEi1JEGitSA/SmKUopm/eWMeSciJKUyJRq7h2GPxDU+GiISE9ErlTbBEUpkSdVkZvKkRoiikhJmFMtLX3DaO0fgdkkYXkZo1YiwaIUkcaK1UypqV/o9ovNe3yRI6IkVZjhgM1sgj8oT5vDR0TJRd3A5zo396W+i3lSRDQzxdmh15bpilIiT2pxUSZSbZaYH1ciYVGKSGOiU6q1b/I3WJ1uD5r7hiFJwNISFqWIKDmZTBJK54ReMznCR0RjjRalJuqUChWluHmPiCJVkj3apT2VvcyTmjEWpYg0VhRGp9RBJU9qXn46MhzWeBwWEZEuMVeKiCbinKIo1aCM71XmsVOKiCIjOqWmy5Tad0bJk2JRKmIsShFpTBSlXCN+DE6yxviAOrqXHaejIiLSpwoWpYhoAlNnSolOKRaliCgy4WRKDXn9+LTFBQCoZVEqYixKEWksw2FV1xhP1i21vykUcs6iFBEluzJRlOpmUYqIRonxvY6zMqUCQRmNPaNB50REkShWilKuET/cI74JL3OgsR+BoIyiLId6eQofi1JEOuDMEgF65366J8uy2im1gpv3iCjJcXyPiCYyNlMqGJTV77f0DcMXkGEzm1CUxTeLRBSZNLsF2amh+JSJ3qsBHN2bLRaliHRAhJ23TbBNqqF7CP3DPtgsJix0ZsT70IiIdKU8l0UpIjpXfoYdkgT4gzJ6hrzq90WeVNmcFJhNklaHR0QGVpw19QifGnJezqLUTLAoRaQDxaJTaoLxvQNKyPl5xZmwWfhPloiSm+iU6h/2oX9o4jZ6vXnqqadQWVkJh8OB1atXY8+ePWFd78UXX4QkSbjhhhvGfX/Tpk2QJGnc14YNGya8DY/HgxUrVkCSJOzfv3+Wj4RIv6xmE3LTlFypMR/yMU+KiGarJCdUlGqaoCgVDMpqpxQ3780M3+ES6YBoJ2+doCX0QKOSJ8XRPSIipNosyEsPvfE0QrfUSy+9hM2bN+Phhx/Gvn37sHz5cqxfvx4dHR1TXq++vh7f+ta3cOmll0748w0bNqC1tVX9+vWvfz3h5R588EEUFxfP+nEQGYEzK/Ta0OEePZ9qUIpSFSxKEdEMTRV2fqprEH1DPjisJiwuzoz3oSUEFqWIdKAojE6pFQw5JyICAJTPCZ0cGqEo9dhjj+Guu+7CHXfcgcWLF+Ppp59GamoqnnvuuUmvEwgEcMstt+CRRx5BVVXVhJex2+1wOp3qV07OuZ/Ovvrqq3jjjTfw6KOPRu3xEOlZYUbofKqtfzTsvF4Z36vMY8g5Ec1McbbI/z33vdo+ZXRveWk2rGaWV2aCf2tEOlCULU6ixndK+QJBHG7m5j0iorHECF9jr76LUl6vF3v37sW6devU75lMJqxbtw67d++e9Hrf//73UVBQgDvvvHPSy+zatQsFBQVYuHAh7rnnHnR3d4/7eXt7O+666y786le/Qmrq9G/GPR4PXC7XuC8ioynMGg07F9gpRUSzVZId+j3a3HtuUUrNk+Lo3oxZtD4AIhozvndWUaquzQ2PP4hMhwWVXGNMRAQAeHBDDf7PFxYhXxnj06uuri4EAgEUFhaO+35hYSGOHj064XXeffddPPvss1PmP23YsAE33ngj5s6di5MnT+L//J//g40bN2L37t0wm82QZRmbNm3C3XffjdraWtTX1097rFu3bsUjjzwSycMj0h3RKSWKUsGgrAad8zyKiGZqqk6pjxt6ALAoNRssShHpgBjfG/D44RrxIdMRWjsqRveWl2VDkrgxhogIAIqzE3Otu9vtxq233opnnnkGeXl5k17uy1/+svrfS5cuxbJly1BdXY1du3bhqquuwpNPPgm3240tW7aEfd9btmzB5s2b1T+7XC6UlZXN7IEQaURkSomiVJtrBB5/EBaTpGbCEBFFSgSdt7lG4A8EYVHG9HoHvTjZGerGPJ+b92aMRSkiHUizW5DpsMA14kdr3wgynUpRqrEPAPOkiIiMKC8vD2azGe3t7eO+397eDqfTec7lT548ifr6elx77bXq94LBIADAYrGgrq4O1dXV51yvqqoKeXl5OHHiBK666irs2LEDu3fvht0+vpOstrYWt9xyC7Zt23bObdjt9nMuT2Q0BZlKHIIrlCklNu+V5qSobyKJiCKVl2aHzWyCNxBEm2sEpTmhzstPGkOje1X5aZiTZtPyEA2Nr85EOiE++W8dE3YuNu8t4+Y9IiLDsdlsWLlyJbZv365+LxgMYvv27VizZs05l6+pqcGhQ4ewf/9+9eu6667D2rVrsX///kk7l5qamtDd3Y2ioiIAwBNPPIEDBw6ot/GXv/wFQGgT4L/+67/G4JES6YNTKUp1KJ1S6uheHvOkiGjmTCZJzQBuGbMtXc2TYpfUrLBTikgnirIcONrmVnOlBjx+HOtwAwCWl2ZpeWhERDRDmzdvxu23347a2lqsWrUKjz/+OAYHB3HHHXcAAG677TaUlJRg69atcDgcWLJkybjrZ2dnA4D6/YGBATzyyCO46aab4HQ6cfLkSTz44IOYN28e1q9fDwAoLy8fdxvp6ekAgOrqapSWlsby4RJpqlApSnUPeuHxB9ROqUqGnBPRLJVkp6Che2hcrhRDzqODRSkinSgSnVLKC93h5n7IMlCc5VDb0YmIyFhuvvlmdHZ24rvf/S7a2tqwYsUKvPbaa2r4+ZkzZ2Ayhd+4bjabcfDgQWzbtg19fX0oLi7G1VdfjR/84Accv6Okl5NqVUdsOt0eNHSFOqUqGHJORLMkplqalfdqvkBQnWqprWRRajZYlCLSiSKl8CQ6pUSe1HLmSRERGdp9992H++67b8Kf7dq1a8rrPv/88+P+nJKSgtdffz2i+6+srIQsyxFdh8iIJElCQaYdTb3DaHeNsFOKiKLm7KLU0VY3hn0BZKVYUZWXruWhGR4zpYh0Qu2UEkWpMZv3iIiIiGh6Ileqrd+jZkqxU4qIZqtUea8mxvc+bugBAFxQng2TiVvSZ4NFKSKdKM5SwvOUoHPRDrqcIedEREREYRG5Uoea+zHsC8AkQd2URUQ0U2qnVG/ovRrzpKKHRSkinRjNlBpBh3sEzX3DkCRgKUPOiYiIiMIiilJ7TncDAEpyUmCz8C0PEc1Osbp9bxiyLGOfUpS6gEWpWeMrNJFOiHbzYV8A7xzrAgDML0hHup3Rb0REREThKMwMBf4fbAp1nDNPioiiQXRKDXoDONrmRkv/CMwmiVMtURDzopTH48GKFSsgSRL2798/5WXb2tpw6623wul0Ii0tDRdccAFefvnlWB8ikS6k2MzISbUCAF77tA0AR/eIiIiIIuFU4hD8wVC4P/OkiCgaHFYz8tJtAIA/HWgBACwqykAaGwhmLeZFqQcffBDFxcVhXfa2225DXV0d/vjHP+LQoUO48cYb8aUvfQmffPJJjI+SSB+KskIV+L8e6wTAkHMiIiKiSBRkOMb9mZ1SRBQtolvqTwdDRanaijlaHk7CiGlR6tVXX8Ubb7yBRx99NKzLv//++/ja176GVatWoaqqCv/yL/+C7Oxs7N27N5aHSaQbRcqnex5/EACwgkUpIiIiorCJTimhgkUpIoqSEqUo1dgTCjtnnlR0xKwo1d7ejrvuugu/+tWvkJoaXtvsRRddhJdeegk9PT0IBoN48cUXMTIygiuuuCJWh0mkK0XZoydSNosJC50ZGh4NERERkbGITClhbh7H94goOkSnlMDNe9ERkwFIWZaxadMm3H333aitrUV9fX1Y1/vNb36Dm2++Gbm5ubBYLEhNTcXvf/97zJs3b9LreDweeDwe9c8ul2u2h0+kGTG+BwBLijNhNXMXAREREVG4Um0WZDgscI/4IUlAaQ6LUkQUHWOLUs5MB4rP6sykmYnoHe9DDz0ESZKm/Dp69CiefPJJuN1ubNmyJaKD+c53voO+vj689dZb+Pjjj7F582Z86UtfwqFDhya9ztatW5GVlaV+lZWVRXSfRHpSPKZTinlSRERERJErVDYaF2elwGE1a3w0RJQoSsYUpVZW5ECSJA2PJnFE1Cn1wAMPYNOmTVNepqqqCjt27MDu3btht49vn62trcUtt9yCbdu2nXO9kydP4mc/+xkOHz6M8847DwCwfPlyvPPOO3jqqafw9NNPT3h/W7ZswebNm9U/u1wuFqbIsJyZoy90zJMiIiIiipwz04ETHQPcvEdEUTW2KMU8qeiJqCiVn5+P/Pz8aS/3xBNP4Ic//KH655aWFqxfvx4vvfQSVq9ePeF1hoaGAAAm0/jmLbPZjGAwOOl92e32c4pfREY1tlNqWWm2dgdCREREZFAFSq4UQ86JKJrGvldjnlT0xCRTqry8fNyf09PTAQDV1dUoLS0FADQ3N+Oqq67CCy+8gFWrVqGmpgbz5s3DV7/6VTz66KPIzc3FH/7wB7z55pt45ZVXYnGYRLpTkp2CxUWZSLGZUclP94iIiIgitnZhAV4/3Iaragq0PhQiSiBz0my4bEE+BkZ8OK84U+vDSRgxKUqFw+fzoa6uTu2Qslqt+Mtf/oKHHnoI1157LQYGBjBv3jxs27YN11xzjVaHSRRXFrMJr3ztEkgSOKNMRERENAPXLi/GF5YWwWTiuRQRRY8kSXjhK6u0PoyEE5eiVGVlJWRZnvZ78+fPx8svvxyPQyLSLZ5AEREREc0Oz6eIiIyB++aJiIiIiIiIiCjuWJQiIiIiIiIiIqK4Y1GKiIiIiIiIiIjijkUpIiIiIiIiIiKKOxaliIiIiIiIiIgo7liUIiIiIiIiIiKiuGNRioiIiIiIiIiI4o5FKSIiIiIiIiIiijsWpYiIiIiIiIiIKO5YlCIiIiIiIiIiorhjUYqIiIiIiIiIiOKORSkiIiIiIiIiIoo7FqWIiIiIiIiIiCjuWJQiIiIiIiIiIqK4Y1GKiIiIiIiIiIjizqL1AUSbLMsAAJfLpfGREBERkVGJ8whxXpFseD5FREREsxHuuVTCFaXcbjcAoKysTOMjISIiIqNzu93IysrS+jDijudTREREFA3TnUtJcoJ9BBgMBtHS0oKMjAxIkhT123e5XCgrK0NjYyMyMzOjfvt6x8efvI8/mR87kNyPP5kfO5Dcjz+ZH7ssy3C73SguLobJlHxpBzyfiq1kfvzJ/NiB5H78yfzYgeR+/Mn82IHkffzhnkslXKeUyWRCaWlpzO8nMzMzqf6HOhsff/I+/mR+7EByP/5kfuxAcj/+ZH3sydghJfB8Kj6S+fEn82MHkvvxJ/NjB5L78SfzYweS8/GHcy6VfB/9ERERERERERGR5liUIiIiIiIiIiKiuGNRKkJ2ux0PP/ww7Ha71oeiCT7+5H38yfzYgeR+/Mn82IHkfvzJ/NgptpL9/61kfvzJ/NiB5H78yfzYgeR+/Mn82AE+/ukkXNA5ERERERERERHpHzuliIiIiIiIiIgo7liUIiIiIiIiIiKiuGNRioiIiIiIiIiI4o5FKSIiIiIiIiIiijsWpSbw1FNPobKyEg6HA6tXr8aePXumvPxvf/tb1NTUwOFwYOnSpfjLX/4SpyONrq1bt+LCCy9ERkYGCgoKcMMNN6Curm7K6zz//POQJGncl8PhiNMRR9f3vve9cx5LTU3NlNdJlOe+srLynMcuSRLuvffeCS9v9Of9r3/9K6699loUFxdDkiT84Q9/GPdzWZbx3e9+F0VFRUhJScG6detw/PjxaW830tcOLUz12H0+H/7pn/4JS5cuRVpaGoqLi3HbbbehpaVlytucyb8drUz33G/atOmcx7Jhw4Zpb9fozz2ACV8DJEnCT37yk0lv00jPPcUfz6eS73wqmc+lgOQ6n0rmcykguc+nkvlcCuD5VCywKHWWl156CZs3b8bDDz+Mffv2Yfny5Vi/fj06OjomvPz777+Pv/u7v8Odd96JTz75BDfccANuuOEGHD58OM5HPntvv/027r33XnzwwQd488034fP5cPXVV2NwcHDK62VmZqK1tVX9amhoiNMRR99555037rG8++67k142kZ77jz76aNzjfvPNNwEAf/u3fzvpdYz8vA8ODmL58uV46qmnJvz5v/3bv+GJJ57A008/jQ8//BBpaWlYv349RkZGJr3NSF87tDLVYx8aGsK+ffvwne98B/v27cPvfvc71NXV4brrrpv2diP5t6Ol6Z57ANiwYcO4x/LrX/96yttMhOcewLjH3Nraiueeew6SJOGmm26a8naN8txTfPF8KnnPp5L1XApIrvOpZD6XApL7fCqZz6UAnk/FhEzjrFq1Sr733nvVPwcCAbm4uFjeunXrhJf/0pe+JH/hC18Y973Vq1fLX/3qV2N6nPHQ0dEhA5DffvvtSS/zy1/+Us7KyorfQcXQww8/LC9fvjzsyyfyc/+Nb3xDrq6uloPB4IQ/T6TnHYD8+9//Xv1zMBiUnU6n/JOf/ET9Xl9fn2y32+Vf//rXk95OpK8denD2Y5/Inj17ZAByQ0PDpJeJ9N+OXkz0+G+//Xb5+uuvj+h2EvW5v/766+Urr7xyyssY9bmn2OP51KhkOp/iudR4yXI+lcznUrKc3OdTyXwuJcs8n4oWdkqN4fV6sXfvXqxbt079nslkwrp167B79+4Jr7N79+5xlweA9evXT3p5I+nv7wcAzJkzZ8rLDQwMoKKiAmVlZbj++uvx6aefxuPwYuL48eMoLi5GVVUVbrnlFpw5c2bSyybqc+/1evHf//3f+MpXvgJJkia9XCI972OdPn0abW1t457brKwsrF69etLndiavHUbR398PSZKQnZ095eUi+bejd7t27UJBQQEWLlyIe+65B93d3ZNeNlGf+/b2dvz5z3/GnXfeOe1lE+m5p+jg+dR4yXY+xXOpkGQ+n+K51LmS7XyK51IhPJ8KD4tSY3R1dSEQCKCwsHDc9wsLC9HW1jbhddra2iK6vFEEg0F885vfxMUXX4wlS5ZMermFCxfiueeew//8z//gv//7vxEMBnHRRRehqakpjkcbHatXr8bzzz+P1157DT//+c9x+vRpXHrppXC73RNePlGf+z/84Q/o6+vDpk2bJr1MIj3vZxPPXyTP7UxeO4xgZGQE//RP/4S/+7u/Q2Zm5qSXi/Tfjp5t2LABL7zwArZv344f//jHePvtt7Fx40YEAoEJL5+oz/22bduQkZGBG2+8ccrLJdJzT9HD86lRyXY+xXOpUcl8PsVzqfGS7XyK51KjeD4VHovWB0D6dO+99+Lw4cPTzrKuWbMGa9asUf980UUXYdGiRfjFL36BH/zgB7E+zKjauHGj+t/Lli3D6tWrUVFRgd/85jdhVbcTxbPPPouNGzeiuLh40ssk0vNOE/P5fPjSl74EWZbx85//fMrLJtK/nS9/+cvqfy9duhTLli1DdXU1du3ahauuukrDI4uv5557Drfccsu0gbuJ9NwTxUKynU/xNWEUz6cISM7zKZ5LjeL5VHjYKTVGXl4ezGYz2tvbx32/vb0dTqdzwus4nc6ILm8E9913H1555RXs3LkTpaWlEV3XarXi/PPPx4kTJ2J0dPGTnZ2NBQsWTPpYEvG5b2howFtvvYV/+Id/iOh6ifS8i+cvkud2Jq8deiZOoBoaGvDmm29O+aneRKb7t2MkVVVVyMvLm/SxJNpzDwDvvPMO6urqIn4dABLruaeZ4/lUCM+nkvNcCuD5FM+lQng+FZKM51IAz6ciwaLUGDabDStXrsT27dvV7wWDQWzfvn3cpxhjrVmzZtzlAeDNN9+c9PJ6Jssy7rvvPvz+97/Hjh07MHfu3IhvIxAI4NChQygqKorBEcbXwMAATp48OeljSaTnXvjlL3+JgoICfOELX4joeon0vM+dOxdOp3Pcc+tyufDhhx9O+tzO5LVDr8QJ1PHjx/HWW28hNzc34tuY7t+OkTQ1NaG7u3vSx5JIz73w7LPPYuXKlVi+fHnE102k555mjudTPJ8SkvFcCuD5VLKfSwE8nxorGc+lAJ5PRUTbnHX9efHFF2W73S4///zz8meffSb/4z/+o5ydnS23tbXJsizLt956q/zQQw+pl3/vvfdki8UiP/roo/KRI0fkhx9+WLZarfKhQ4e0eggzds8998hZWVnyrl275NbWVvVraGhIvczZj/+RRx6RX3/9dfnkyZPy3r175S9/+cuyw+GQP/30Uy0ewqw88MAD8q5du+TTp0/L7733nrxu3To5Ly9P7ujokGU5sZ97WQ5tuSgvL5f/6Z/+6ZyfJdrz7na75U8++UT+5JNPZADyY489Jn/yySfqRpQf/ehHcnZ2tvw///M/8sGDB+Xrr79enjt3rjw8PKzexpVXXik/+eST6p+ne+3Qi6keu9frla+77jq5tLRU3r9//7jXAY/Ho97G2Y99un87ejLV43e73fK3vvUteffu3fLp06flt956S77gggvk+fPnyyMjI+ptJOJzL/T398upqanyz3/+8wlvw8jPPcUXz6eS83wq2c+lZDl5zqeS+VxKlpP7fCqZz6VkmedTscCi1ASefPJJuby8XLbZbPKqVavkDz74QP3Z5ZdfLt9+++3jLv+b3/xGXrBggWyz2eTzzjtP/vOf/xznI44OABN+/fKXv1Qvc/bj/+Y3v6n+XRUWFsrXXHONvG/fvvgffBTcfPPNclFRkWyz2eSSkhL55ptvlk+cOKH+PJGfe1mW5ddff10GINfV1Z3zs0R73nfu3Dnh/+viMQaDQfk73/mOXFhYKNvtdvmqq6465++loqJCfvjhh8d9b6rXDr2Y6rGfPn160teBnTt3qrdx9mOf7t+Onkz1+IeGhuSrr75azs/Pl61Wq1xRUSHfdddd55wQJeJzL/ziF7+QU1JS5L6+vglvw8jPPcUfz6eS73wq2c+lZDl5zqeS+VxKlpP7fCqZz6VkmedTsSDJsizPtMuKiIiIiIiIiIhoJpgpRUREREREREREcceiFBERERERERERxR2LUkREREREREREFHcsShERERERERERUdyxKEVERERERERERHHHohQREREREREREcUdi1JERERERERERBR3LEoREREREREREVHcsShFRERERERERERxx6IUERERERERERHFHYtSREREREREREQUdyxKERERERERERFR3P1/wEbI1dCmQUcAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from custom_feature_dir.utils import plot_stats\n", + "\n", + "plot_stats(log_dir)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/debug/custom_feature_dir/custom_feature_example_config.yaml b/docs/debug/custom_feature_dir/custom_feature_example_config.yaml new file mode 100644 index 000000000..ab0369866 --- /dev/null +++ b/docs/debug/custom_feature_dir/custom_feature_example_config.yaml @@ -0,0 +1,15 @@ +stats: + enabled: True + layers: + layer_name_regex_pattern: .* + transformer_engine: + PercentageGreaterThanThreshold: + enabled: True + tensors: [activation] + threshold: 0.1 + freq: 5 + LogTensorStats: + enabled: True + tensors: [activation] + stats: [min] + freq: 5 \ No newline at end of file diff --git a/docs/debug/custom_feature_dir/percentage_greater_than_threshold.py b/docs/debug/custom_feature_dir/percentage_greater_than_threshold.py new file mode 100644 index 000000000..80311ec49 --- /dev/null +++ b/docs/debug/custom_feature_dir/percentage_greater_than_threshold.py @@ -0,0 +1,78 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""PercentageGreaterThanThreshold Feature support for nvidia-dlframework-inspect""" + +from typing import Dict, Optional + +import torch + +from nvdlfw_inspect.registry import Registry, api_method +from nvdlfw_inspect.logging import MetricLogger +import nvdlfw_inspect.api as debug_api + +from transformer_engine.debug.features.api import TEConfigAPIMapper +from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer + + +# Class should inherit from TEConfigAPIMapper and be registered to the transformer_engine namespace. +@Registry.register_feature(namespace="transformer_engine") +class PercentageGreaterThanThreshold(TEConfigAPIMapper): + + @api_method + def inspect_tensor( + self, + config: Dict, + layer_name: str, + tensor_name: str, + iteration: int, + tp_group: torch.distributed.ProcessGroup, + tensor: torch.Tensor, + rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + quantizer: Optional[Quantizer] = None, + ): + # API call inspect_tensor is used to gather the data about the tensor. + # All API calls are documented in the `Precision debug tools / API / Calls to Nvidia-DL-Framework-Inspect` + # section of the documentation. + + threshold = config["threshold"] + + # Get the reduction group from the debug tool + # one can set it using debug_api.set_tensor_reduction_group(group) + reduction_group = debug_api.get_tensor_reduction_group() + + # Compute percentage on local tensor + count = (torch.abs(tensor) > threshold).sum().float() + total = torch.tensor(tensor.numel(), dtype=torch.float32, device=tensor.device) + + # Perform reduction across the group if needed. + # Note that we perform all_reduce twice per every tensor, which is suboptimal. + # For guidance on implementing efficient statistics reduction, see the implementation in the `LogTensorStats` feature. + # In this tutorial we only showcase basic implementation of the feature. + if reduction_group is not None: + torch.distributed.all_reduce(count, group=reduction_group) + torch.distributed.all_reduce(total, group=reduction_group) + + percentage = count / total + + # MetricLogger is a class from nvidia-dlframework-inspect. + # By using it we can also use functionalities provided by nvidia-dlframework-inspect, + # like logging to TensorBoard, etc. + MetricLogger.log_scalar( + f"{layer_name}_{tensor_name}_percentage_greater_than_threshold", percentage, iteration + ) + + @api_method + def inspect_tensor_enabled( + self, config: Dict, layer_name: str, tensor_name: str, iteration: int + ): + # This call is used by TE to determine if the unfused debug layer - which is slower - needs to be run. + # It returns a tuple (bool, int), where the int indicates the next iteration when the feature will be enabled + # and bool indicates if the feature should be enabled at the current iteration. + + run_current = iteration % config["freq"] == 0 + # run in next multiple of freq + next_iter = iteration + (config["freq"] - iteration % config["freq"]) + return run_current, next_iter diff --git a/docs/debug/custom_feature_dir/utils.py b/docs/debug/custom_feature_dir/utils.py new file mode 100644 index 000000000..cc954b12b --- /dev/null +++ b/docs/debug/custom_feature_dir/utils.py @@ -0,0 +1,48 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Utils for plotting stats in the tutorial""" + + +import os +import re +import matplotlib.pyplot as plt + + +def plot_stats(log_dir): + + # print and plot the stats + stat_file = os.path.join( + log_dir, "nvdlfw_inspect_statistics_logs", "nvdlfw_inspect_globalrank-0.log" + ) + + min_values = [] + custom_feature_values = [] + + with open(stat_file, "r") as f: + number_pattern = re.compile(r"[-+]?\d*\.\d+|\d+") + + for line in f: + if "min" in line: + matches = number_pattern.findall(line) + if matches: + min_values.append(float(matches[-1])) + if "percentage_greater_than_threshold" in line: + matches = number_pattern.findall(line) + if matches: + custom_feature_values.append(float(matches[-1])) + + # plot 2 figures side by side + fig, axs = plt.subplots(1, 2, figsize=(12, 5)) + + axs[0].plot(min_values, label="min") + axs[0].legend() + axs[0].set_title("Min values") + + axs[1].plot(custom_feature_values, label="percentage_greater_than_threshold_0.1") + axs[1].legend() + axs[1].set_title("Percentage greater than threshold 0.1 values") + + plt.tight_layout() + plt.show() diff --git a/docs/envvars.rst b/docs/envvars.rst index 86b313b13..85445430f 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -72,6 +72,12 @@ Build Configuration :Default: Not set :Description: Internal flag set to ``1`` during the build process to indicate that the project is being built. Not intended for external use. +.. envvar:: NVTE_BUILD_NUM_PHILOX_ROUNDS + + :Type: ``int`` (positive integer) + :Default: ``10`` + :Description: Number of Philox4x32 rounds used by stochastic rounding kernels. Must be a positive integer. + Optional Dependencies ^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 4b2ed8049..e7253415d 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -151,6 +151,7 @@ "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n", "- flash-attention supports KV-caching and paged attention, and cuDNN attention does not.\n", "- flash-attention uses bottom right diagonal for `causal` mask in cross attention (see [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)), and cuDNN attention supports both top left and bottom right.\n", + "- **Sliding window attention (SWA):** flash-attention has SWA(left, right) support for all mask types except top-left causal masks, with or without dropout, and without bias. cuDNN attention supports SWA(left, 0) starting from 9.2 and SWA(left, right) starting from 9.6, without dropout, and with `bias_type=\"no_bias\"`.\n", "- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n", "\n", "To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0." @@ -389,7 +390,7 @@ "\n", "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n", "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n", - "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n", + "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | Yes (cuDNN 9.2+) | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n", "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes | Yes (`bshd`,`thd`) | Yes |\n", "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n", "\n", diff --git a/docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb b/docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb index 56bc3b13c..338ce7fdd 100644 --- a/docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb +++ b/docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb @@ -28,6 +28,7 @@ "source": [ "### Question 1: Why choose Striped>1 ?\n", "\n", + "\n", "Prior to the addition of this feature, Transformer Engine JAX attention already supported load balancing via a striping pattern, i.e., `stripe_size=1` for `CP + THD + P2P(Ring) + Striped + SWA`. However, this reordering technique does not lend itself well to an all-gathered (post-AG) pattern. The following example illustrates this distinction. For this example, `cp_size=4`, `num_segments=4`, `window_size=(8,0)`, and the pattern is for a single rank after striped reordering has been performed: \n", "\n", "#### I. Striped (`stripe_size=1`)\n", diff --git a/docs/examples/op_fuser/fp8_layernorm_linear.png b/docs/examples/op_fuser/fp8_layernorm_linear.png new file mode 100644 index 000000000..b5916a615 Binary files /dev/null and b/docs/examples/op_fuser/fp8_layernorm_linear.png differ diff --git a/docs/examples/op_fuser/layernorm_mlp.png b/docs/examples/op_fuser/layernorm_mlp.png new file mode 100644 index 000000000..f388c88fa Binary files /dev/null and b/docs/examples/op_fuser/layernorm_mlp.png differ diff --git a/docs/examples/op_fuser/op_fuser.rst b/docs/examples/op_fuser/op_fuser.rst new file mode 100644 index 000000000..9613ba74b --- /dev/null +++ b/docs/examples/op_fuser/op_fuser.rst @@ -0,0 +1,353 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Operation fuser API +=================== + +Motivation +---------- + +Transformer Engine relies heavily on operation fusion to achieve high +performance. A typical training workload involves many memory-bound +operations such as activation functions and normalization, so +replacing them with fused kernels can deliver a significant +performance benefit. This is especially true for low-precision +training (e.g. FP8 and FP4) because it involves extra cast operations. + +Managing these fusions can be challenging because they differ based on +operation types, communication patterns, data types, and GPU +architectures. The most straightforward solution is to provide +monolithic modules like ``Linear``, ``LayerNormLinear``, or +``TransformerLayer``. These conform to the interface of a standard +PyTorch module, but can perform arbitrary fusions internally. These +hand-tuned implementations can achieve maximum performance, but they +tend to be complicated and difficult to modify. + +As an alternative to this "top-down" design, TE exposes a "bottom-up" +operation-based API. The user constructs individual operations and +passes them into a fuser, resulting in the same fused kernels as the +monolithic modules. This approach is more flexible, making it easier +to support new model architectures or to experiment with fusions. + +Basic usage +----------- + +Sequential operations +^^^^^^^^^^^^^^^^^^^^^ + +At the most basic level, the operation fuser API involves two classes +in the ``transformer_engine.pytorch.ops`` submodule: + +- ``FusibleOperation``: An abstract base class for tensor operations. + Examples include ``Linear``, ``LayerNorm``, and ``AllReduce``. It is + a subclass of ``torch.nn.Module``, so it can hold trainable + parameters and can be called to perform the operation's forward + pass. +- ``Sequential``: A container of modules in sequential order. Its + interface is very similar to ``torch.nn.Sequential``. If it contains + any ``FusibleOperation`` s, then it may attempt to fuse them in the + forward and backward passes. + +Thus, using the operation fuser simply involves constructing +``FusibleOperation`` s and passing them into a ``Sequential``. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Options + hidden_size = 4096 + ffn_size = 28672 + batch_size = 16384 + + # Construct operations and fuse + mlp = te.ops.Sequential( + te.ops.LayerNorm(hidden_size), + te.ops.Linear(hidden_size, ffn_size), + te.ops.SwiGLU(), + te.ops.Linear(ffn_size // 2, hidden_size), + ) + + # Forward pass + x = torch.randn(batch_size, hidden_size, device="cuda") + y = mlp(x) + +.. figure:: ./layernorm_mlp.png + :align: center + + Operations that match ``LayerNormMLP`` module. Note that different + fusions have been applied in the forward and backward passes. + +Quantization +^^^^^^^^^^^^ + +The operation fuser respects TE's APIs for low-precision ("quantized") +data formats like FP8 and FP4. Constructing operations within a +``quantized_model_init`` context will enable quantized weights and +performing the forward pass within an ``autocast`` context will enable +quantized compute. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Construct layer with quantized weights + with te.quantized_model_init(): + fc1 = te.ops.Sequential( + te.ops.LayerNorm(4096), + te.ops.Linear(4096, 28672), + ) + + # Forward pass within autocast context + x = torch.randn(16384, 4096, device="cuda") + with te.autocast(): + y = fc1(x) + + # Backward pass outside of autocast context + y.sum().backward() + +Branching operations +^^^^^^^^^^^^^^^^^^^^ + +The operation fuser supports very limited branching behavior. While +the operations must be in sequential order, some operations can accept +extra inputs or produce extra outputs. For example, ``AddExtraInput`` +will add an extra input tensor to the intermediate tensor and +``MakeExtraOutput`` will return the intermediate tensor as an extra +output. When calling a ``Sequential`` that contains any of these +branching operations, the extra inputs should be passed in as +arguments and the extra outputs will be returned. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Construct MLP with residual connection + fc1 = te.ops.Sequential( + te.ops.LayerNorm(4096), + te.ops.MakeExtraOutput(), # Output residual + te.ops.Linear(4096, 28672), + te.ops.SwiGLU(), + ) + fc2 = te.ops.Sequential( + te.ops.Linear(14336, 4096), + te.ops.AddExtraInput(), # Add residual + ) + + # Forward pass + x = torch.randn(16384, 4096, device="cuda") + y, residual = fc1(x) + y = fc2(y, residual) + +.. figure:: ./residual_layernorm_mlp.png + :align: center + + Operations for an MLP block with a residual connection. Note that + the block has been split into two sections, each with one branching + operation. + +Developer guide +--------------- + +Infrastructure +^^^^^^^^^^^^^^ + +In addition to ``FusibleOperation`` and ``Sequential``, the fuser +infrastructure relies on the following classes: + +- ``BasicOperation``: The most basic type of ``FusibleOperation``. + Examples include ``BasicLinear``, ``Bias``, and ``ReLU``. It holds + parameters and state, and it implements both a forward and backward + pass. The ``op_forward`` and ``op_backward`` functions have an + interface reminiscent of ``torch.autograd.Function``, e.g. they + accept a context object that caches state from the forward pass to + the backward pass. +- ``FusedOperation``: A ``FusibleOperation`` that can replace one or + more ``BasicOperation`` s. Examples include + ``ForwardLinearBiasActivation`` and ``BackwardActivationBias``. Its + forward and backward passes (the ``fuser_forward`` and + ``fuser_backward`` functions) must produce equivalent results as its + corresponding ``BasicOperation`` s. This also means that the + ``FusedOperation`` is stateless since it can access parameters and + state from the ``BasicOperation`` s. Note that different fusions may + be applied in the forward and backward pass, so a ``FusedOperation`` + may be missing its forward and/or backward implementation. +- ``OperationFuser``: This is the class that manages the operation + fusions. It launches the forward and backward passes within a + ``torch.autograd.Function``. It can also replace operations with + equivalent ``FusedOperation`` s. + +The first time that a ``Sequential`` is called, it will group adjacent +``FusibleOperation`` s together into ``OperationFuser`` s. The first +time an ``OperationFuser`` is called, it will attempt to fuse +operations for the forward pass and backward pass. Subsequent calls +will reuse the same state unless it has been invalidated, e.g. by +changing the quantization recipe. + +Quantization +^^^^^^^^^^^^ + +Each operation that supports quantized compute holds one or more +``Quantizer`` s, which are builder classes for converting +high-precision tensors (e.g. in FP32 or BF16) to quantized tensors. In +order to enable fused quantization kernels, operations can access the +quantizers of neighboring operations and quantize eagerly. + +.. figure:: ./fp8_layernorm_linear.png + :align: center + + Operations that match ``LayerNormLinear`` module with FP8 + quantization. + +In some situations, like when operations are split across multiple +``Sequential`` s, it may be helpful to encourage the fuser by manually +adding ``Quantize`` operations. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + # Construct layer with quantized weights + with te.quantized_model_init(): + norm = te.ops.Sequential( + te.ops.LayerNorm(4096), + te.ops.Quantize(), + ) + fc1 = te.ops.Sequential( + te.ops.Linear(4096, 28672), + ) + + # Forward pass + x = torch.randn(16384, 4096, device="cuda") + with te.autocast(): + y = norm(x) # y is a QuantizedTensor + z = fc1(y) + +.. warning:: + + This is an expert technique. Quantizer configurations can be quite + complicated, so the ``Quantize`` operation's quantizers may be + suboptimal. + +Implementing new operations +--------------------------- + +Implementing a basic operation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Subclasses of ``BasicOperation`` must implement ``op_forward`` and +``op_backward``, which are reminiscent of the ``forward`` and +``backward`` methods of ``torch.autograd.Function``. They have an +argument for a context object that can be used to cache state from the +forward pass for use in the backward pass. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + + class LearnableScale(te.ops.BasicOperation): + + def __init__(self) -> None: + super().__init__() + scale = torch.ones((), dtype=torch.float32, device="cuda") + self.register_parameter("scale", torch.nn.Parameter(scale)) + + def op_forward(self, ctx, input_: torch.Tensor, **unused) -> torch.Tensor: + out = self.scale * input_ + ctx.save_for_backward(self.scale, input_) + return out + + def op_backward( + self, + ctx, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]: + scale, input_ = ctx.saved_tensors + grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1)).reshape(()) + grad_input = scale * grad_output + return ( + grad_input, # Input gradient + (grad_scale,), # Param gradients + ) + +Implementing a fused operation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Subclasses of ``FusedOperation`` should declare their corresponding +``BasicOperation`` s in the constructor. They should also implement +``fuser_forward`` and ``fuser_backward``, depending on usage. These +functions are similar to ``op_forward`` and ``op_backward`` from +``BasicOperation``, but some arguments and returns are lists. For +example, instead of taking a single context object, they take a list +of context objects for all the corresponding ``BasicOperation`` s. + +.. code-block:: python + + import torch + import transformer_engine.pytorch as te + from typing import Optional + + class ForwardAxpy(te.ops.FusedOperation): + + def __init__(self, scale: te.ops.ConstantScale, add: te.ops.AddExtraInput) -> None: + super().__init__((scale, add)) # Equivalent basic ops + + def fuser_forward( + self, + basic_op_ctxs: list, + input_: torch.Tensor, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + **unused, + ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, ...]]]: + scale_op, add_op = self.basic_ops + extra_input = basic_op_extra_inputs[1][0] # Extra input to add op + out = scale_op.scale * input_ + extra_input + scale_ctx, add_ctx = basic_op_ctxs # No state needed for backward + return ( + out, # Output + [(), ()], # Extra outputs for each basic op + ) + +.. warning:: + + Remember the contract that the fused operation must produce outputs + that are interchangeable with the corresponding basic operation + outputs. + +In order to make these fused operations useful, they should be +registered with the operation fuser. To do this, first implement a +fusion function that can replace operations with the fused operation, +and then register it with the ``register_forward_fusion`` or +``register_backward_fusion`` functions. + +.. code-block:: python + + def fuse_axpy_ops( + ops: list[te.ops.FusibleOperation], + **unused, + ) -> list[te.ops.FusibleOperation]: + """Sliding window scan to perform ForwardAxpy fusion""" + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + if ( + isinstance(window[0], te.ops.ConstantScale) + and isinstance(window[1], te.ops.AddExtraInput) + ): + window = [ForwardAxpy(window[0], window[1])] + else: + out.append(window[0]) + window = window[1:] + window, ops = window + ops[:1], ops[1:] + out.extend(window + ops) + return out + + # Register fusion with operation fuser + te.ops.register_forward_fusion(fuse_axpy_ops) diff --git a/docs/examples/op_fuser/residual_layernorm_mlp.png b/docs/examples/op_fuser/residual_layernorm_mlp.png new file mode 100644 index 000000000..fa95114a6 Binary files /dev/null and b/docs/examples/op_fuser/residual_layernorm_mlp.png differ diff --git a/docs/examples/te_llama/requirements.txt b/docs/examples/te_llama/requirements.txt new file mode 100644 index 000000000..093849001 --- /dev/null +++ b/docs/examples/te_llama/requirements.txt @@ -0,0 +1,5 @@ +transformers==4.57.0 +accelerate==1.10.0 +peft==0.15.2 +datasets==4.0.0 +sentencepiece==0.2.1 diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index b2d4d183a..6dfa9b67b 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -72,10 +72,15 @@ def forward(self, hidden_states, *args, attention_mask, **kwargs): forward pass of the `TransformerLayer`. Also, make sure the output format matches the output of the HF's `LlamaDecoderLayer`. """ - return ( - super().forward( - hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb - ), + # Handle case where hidden_states might be a tuple (from previous layer output) + # This can happen with older versions of HuggingFace transformers + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + # Return tensor directly for HuggingFace transformers >= 4.57 + # (older versions wrapped output in tuple and extracted with layer_outputs[0]) + return super().forward( + hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb ) @@ -162,7 +167,7 @@ def replace_params(hf_state_dict, te_state_dict, config): # collect all layer prefixes to update all_layer_prefixes = set() for param_key in hf_state_dict.keys(): - layer_prefix_pat = "model.layers.\d+." + layer_prefix_pat = r"model.layers.\d+." m = re.match(layer_prefix_pat, param_key) if m is not None: all_layer_prefixes.add(m.group()) diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index 00499cff5..ac9252ff1 100644 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -1,763 +1,784 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "6a5b2993", - "metadata": {}, - "source": [ - "# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n", - "\n", - "
\n", - "\n", - "Goal\n", - "\n", - "This tutorial showcases how to accelerate finetuning a full [Llama 2](https://huggingface.co/meta-llama/Llama-2-7b-hf) or [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B) models from Hugging Face by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n", - "\n", - "
\n" - ] - }, - { - "cell_type": "markdown", - "id": "331f476a", - "metadata": {}, - "source": [ - "## Dependencies for this tutorial\n", - "\n", - "Following files and media are necessary to effectively run this tutorial:\n", - "\n", - "1. `te_llama.py`\n", - " - This file contains the code to load a Hugging Face Llama 2 or Llama 3 checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", - "2. `utils.py`\n", - " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", - "3. `media/`\n", - " - This directory contains the images used in the following tutorial.\n", - "\n", - "These packages are necessary to run this tutorial:\n", - "`pytorch`, `transformer_engine`, `accelerate`, `transformers`, `peft`, `datasets`.\n", - "\n", - "\n", - "
\n", - "\n", - "Note on running the tutorial with Llama 3 weights\n", - "\n", - "This tutorial shows the cell outputs when run with Llama 2 7B weights. It can be run with Llama 3 8B weights simply by providing the directory with those weights (in Hugging Face format) instead of Llama 2 7B weights. These two models are almost identical, the biggest difference being the model dimension (the smallest Llama 3 model has 8B parameters, whereas the smallest Llama 2 has 7B), which enables this tutorial to work for both of them.\n", - "\n", - "
\n" - ] - }, - { - "cell_type": "markdown", - "id": "44abae4f", - "metadata": {}, - "source": [ - "## Table of contents\n", - "1. From \"Transformer\" to \"Llama\"\n", - "2. Hugging Face's `LlamaModel`\n", - " - Hugging Face's `LlamaDecoderLayer`\n", - "3. [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n", - "6. [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", - " - Transformer Engine's `TransformerLayer`\n", - " - `TransformerLayer` options explained\n", - " - Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n", - "7. [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", - "8. Conclusion" - ] - }, - { - "cell_type": "markdown", - "id": "e37e2cc1", - "metadata": {}, - "source": [ - "## From \"Transformer\" to \"Llama\" \n", - "\n", - "
\n", - "\n", - "
Fig 1: Llama visualized as a transformer. (generated with [Nvidia's AI-foundation models](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/sdxl))
\n", - "
\n", - "\n", - "A flashback:\n", - "\n", - "- 2017: [\"Attention Is All You Need\"](https://arxiv.org/abs/1706.03762) paper introduced pioneering \"Transformer\" architecture and changed the NLP field forever.\n", - "- 2018-2020: Emergence of GPT model series that showed causal decoder architectures are great fit for pretraining, few-shot and zero-shot learning.\n", - "- Fast forward to 2023-2024: Following GPT-3/GPT-4 success stories, researchers and companies raced to produce the next best pretrained model that could further be finetuned for application-specific use-cases.\n", - "- February 2023: Meta releases [Llama 2](https://llama.meta.com/llama2) models (Large Language Model Meta AI). \n", - " - These models range from 7B to 70B parameters.\n", - " - LLaMA 2 was pretrained on 2 trillion tokens.\n", - "- April 2024: Meta releases [Llama 3](https://llama.meta.com/llama3) models.\n", - " - These models range from 8B to 70B parameters.\n", - " - LLaMA 3 was pretrained on 15 trillion tokens.\n", - "\n", - "For more information on Llama 2 consider reading the [Huggingface tutorial](https://huggingface.co/blog/llama2). As a quick summary, here are some of the important differences b/w the conventional transformer decoder architecture vs Llama 2 architecture:\n", - "\n", - "1. Decoder only model (causal language modeling and next word prediction)\n", - "2. RMSNorm in place of the LayerNorm\n", - "3. SwiGLU activation function\n", - "4. RoPE as positional embeddings \n", - "5. Grouped Query Attention for the 70B model\n", - "6. Trained on 4K context length\n", - "\n", - "Hugging Face also released a [tutorial about Llama 3](https://huggingface.co/blog/llama3). The key points are:\n", - "\n", - "1. Use of bigger tokenizer - 128256 vs 32K.\n", - "2. Grouped Query Attention is used also by smaller 8B model.\n", - "3. The context length increased to 8K for all models.\n", - "3. Llama 3 was trained on 8x more data than Llama 2.\n", - "\n", - "
\n", - "\n", - "
Fig 2: Comparing GPT and Llama architectures.
\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "a110de1a", - "metadata": {}, - "source": [ - "## Hugging Face's `LlamaModel`\n", - "Hugging Face provides an open-source implementation of `Llama` model in [modeling_llama.py](https://github.com/huggingface/transformers/blob/3d2900e829ab16757632f9dde891f1947cfc4be0/src/transformers/models/llama/modeling_llama.py#L4).\n", - "\n", - "Here's a block diagram that shows how Llama model is implemented in the Hugging Face repo. Notice the modular encapsulated form and `LlamaDecoderLayer` at the core of the model implementation.\n", - "\n", - "
\n", - "\n", - "
Fig 3: Causal Llama Model Block Diagram.
\n", - "
\n", - "\n", - "The above diagram translates to the following text output of the model in PyTorch. Notice that the core of the model has 32 `LlamaDecoderLayer`s. \n", - "\n", - "```\n", - "LlamaForCausalLM(\n", - " (model): LlamaModel(\n", - " (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n", - " (layers): ModuleList(\n", - " (0-31): 32 x LlamaDecoderLayer(\n", - " (self_attn): LlamaFlashAttention2(\n", - " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (rotary_emb): LlamaRotaryEmbedding()\n", - " )\n", - " (mlp): LlamaMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): LlamaRMSNorm()\n", - " (post_attention_layernorm): LlamaRMSNorm()\n", - " )\n", - " )\n", - " (norm): LlamaRMSNorm()\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n", - ")\n", - "```\n", - "\n", - "#### Hugging Face's `LlamaDecoderLayer`\n", - "\n", - "Let's take a closer look at `LlamaDecoderLayer`. It is composed of `input_layernorm`, `self_attn`, `post_attention_layernorm` and `mlp` modules. Each module has associated weights as shown in the diagram.\n", - "\n", - "
\n", - "\n", - "
Fig 4: Causal Llama Model Block Diagram (with simplified illustration of the [LlamaDecoderLayer](https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/llama/modeling_llama.py#L695)).
\n", - "
\n", - "\n", - "##### Self_Attn Layer\n", - "For simplicity in the block diagram illustration of the \"self_attn\" box, we omit the \"Grouped Query Attention\" operation and only showcase the modules which have associated weights.\n", - " \n", - "##### MLP Layer\n", - "\n", - "SwiGLU is an activation defined as follows in the [modeling_llama.py](https://github.com/huggingface/transformers/blob/7c4995f93d8d24aae05e1e43279c96dce736e5c8/src/transformers/models/llama/modeling_llama.py#L236) file in the Hugging Face github repo:\n", - "```\n", - "\"\"\"\n", - "1. `self.up_proj`, `self.gate_proj` and `self.down_proj` are \"Linear\" layers\n", - "2. `self.act_fn` is a \"Swish\" function\n", - "\n", - "\"\"\"\n", - "down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n", - "```\n", - "It requires a set of 3 weights as compared to 2 weights in conventional \"MLP\" layers e.g. in the traditional transformer or GPT architectures. This is also illustrated in the following figure:\n", - "\n", - "
\n", - "\n", - "
Fig 5: A look inside the feedforward layer with swiglu activation function.
\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "c9529229", - "metadata": {}, - "source": [ - "## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n", - "\n", - "Llama 2 weights are loaded into the Hugging Face native implementation `LlamaForCausalLM` (refer to [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)). \n", - "\n", - "For this and other subsequent runs, the `batch_size` is `8`. The `LlamaDecoderLayer` is left unchanged in the baseline as follows:\n", - "\n", - "
\n", - "\n", - "
Fig 6: Revisiting \"LlamaDecoderLayer\".
\n", - "
\n", - "\n", - "
\n", - "Note\n", - "\n", - "The baseline implementation will be run in `BF16` precision.\n", - "\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "b38eb3ac", - "metadata": {}, - "source": [ - "
\n", - "\n", - "Note\n", - " \n", - "This tutorial loads and trains a Llama 3 8B or a Llama 2 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", - "\n", - "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", - "\n", - "
\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "2e9d7a8c", - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "10 finetuning steps complete!\n", - "Average time taken per step: 248 milliseconds\n" - ] - } - ], - "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "restart_jupyter_notebook()\n", - "\n", - "\n", - "# Import necessary packages, methods and variables\n", - "from utils import *\n", - "\n", - "\n", - "# Provide Huggingface Access Token\n", - "hyperparams.hf_access_token = \"\"\n", - "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", - "\n", - "# Provide a directory to cache weights in to avoid downloading them every time.\n", - "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", - "hyperparams.weights_cache_dir = \"\"\n", - "\n", - "# For Llama 2, uncomment this line (also set by default)\n", - "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", - "\n", - "# For Llama 3, uncomment this line\n", - "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", - "\n", - "hyperparams.mixed_precision = \"bf16\"\n", - "\n", - "\n", - "# Init the model and accelerator wrapper\n", - "model = init_baseline_model(hyperparams)\n", - "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", - "\n", - "\n", - "# Finetune the model\n", - "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" - ] - }, - { - "cell_type": "markdown", - "id": "4035ccb7", - "metadata": {}, - "source": [ - "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", - "\n", - "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 248 | 1 |" - ] - }, - { - "cell_type": "markdown", - "id": "3db90dff", - "metadata": {}, - "source": [ - "## [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", - "\n", - "In addition to basic layers like `Linear` and `LayerNorm`, Transformer Engine offers larger modules like `MultiheadAttention` (combines \"LayerNorm\" and \"Self Attention\") and `LayerNormMLP` (combines \"LayerNorm\" and \"MLP\") that could replace their counterparts in the `LlamaDecoderLayer` and potentially provide a speedup. Transformer Engine also offers a full `TransformerLayer` (which further combines `MultiheadAttention` and `LayerNormMLP` layers) which could replace `LlamaDecoderLayer` and provide a speedup (with careful mapping of the weights since the name of the weights are different for those two layers). Let's take a closer look at Transformer Engine's `TransformerLayer`. \n", - "\n", - "#### Transformer Engine's `TransformerLayer`\n", - "\n", - "At a higher level, TE's `TransformerLayer` could be visualized as an apt replacement for the `LlamaDecoderLayer`. But the internals of the `TransformerLayer` are organized a bit differently. \n", - "\n", - "
\n", - "\n", - "
Fig 7: Transformer Engine's `TransformerLayer`
\n", - "
\n", - "\n", - "Just like Hugging Face's `LlamaDecoderLayer`, Transformer Engine's `TransformerLayer` encapsulates `self_attention` (as `MultiheadAttention`) and `mlp` (as `LayerNormMLP`). A major difference is that the two `Norm`s are included in the `MultiheadAttention` and `LayerNormMLP` layers as shown in the following output prompt:\n", - "\n", - "```\n", - "TransformerLayer(\n", - " (self_attention): MultiheadAttention(\n", - " (layernorm_qkv): LayerNormLinear()\n", - " (core_attention): DotProductAttention()\n", - " (proj): Linear()\n", - " )\n", - " (layernorm_mlp): LayerNormMLP()\n", - ")\n", - "```\n", - "\n", - "Another difference is that Transformer Engine implements an efficient version of feedforward layer with SwiGLU in which the weights from the `up_proj` and `gate_proj` modules are merged together and SwiGLU is applied using a custom fused kernel. This is done so that only one big and efficient Matrix Multiplication operation is issued to the GPU instead of two smaller ones.\n", - "\n", - "
\n", - "\n", - "
Fig 8: Abstract illustration of the SwiGLU implementation in Transformer Engine.
\n", - "
\n", - "\n", - "#### `TransformerLayer` options explained\n", - "\n", - "
\n", - "\n", - "Note\n", - " \n", - "Here, we go over some of the options in `TransformerLayer` that are needed for the tutorial. For a complete list of options, refer the [TransformerLayer API documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html?highlight=transformerlayer#transformer_engine.pytorch.TransformerLayer).\n", - "\n", - "
\n", - "\n", - "In the accompanying `te_llama.py` file, `TELlamaDecoderLayer` is defined as a wrapper over TE's `TransformerLayer` with a few needed options that make `TransformerLayer` a plug-in replacement for the HF's `LlamaDecoderLayer`.\n", - "\n", - "```\n", - "class TELlamaDecoderLayer(te.pytorch.TransformerLayer):\n", - " def __init__(self, config):\n", - " super().__init__(\n", - " config.hidden_size,\n", - " config.intermediate_size,\n", - " config.num_attention_heads,\n", - " bias=False,\n", - " layernorm_epsilon=config.rms_norm_eps,\n", - " hidden_dropout=0,\n", - " attention_dropout=0,\n", - " fuse_qkv_params=False,\n", - " normalization=\"RMSNorm\",\n", - " activation=\"swiglu\",\n", - " attn_input_format=\"bshd\",\n", - " num_gqa_groups=config.num_key_value_heads,\n", - " )\n", - " te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n", - " self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n", - "```\n", - "\n", - "Here's a list summarizing each option briefly:\n", - "\n", - "1. `hidden_size`: size of each input sample.\n", - "2. `ffn_hidden_size`: intermediate size to which samples are projected.\n", - "3. `num_attention_heads`: number of attention heads in the transformer layer.\n", - "4. `bias`: switch to add additive biases to the submodule layers.\n", - "5. `layernorm_epsilon`: a value added to the denominator of layer normalization for numerical stability. Default is `1e-5`.\n", - "6. `hidden_dropout`: dropout probability for the dropout op after FC2 layer (fully connected layer no. 2). Default is `0.1`.\n", - "7. `attention_dropout`: dropout probability for the dropout op during multi-head attention. Default is `0.1`. \n", - "8. `fuse_qkv_params`: if set to True, TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument fuse_wgrad_accumulation.\n", - "9. `normalization`: type of normalization applied. Default is `LayerNorm`.\n", - "10. `activation`: type of activation used in the MLP block. Default is `gelu`.\n", - "11. `attn_input_format`: controls whether the dimensions of the intermediate hidden states is 'batch first' ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, `b` batch size, `h` the number of heads, `d` head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules.\n", - "12. `num_gqa_groups`: number of GQA groups in the transformer layer. Grouped Query Attention is described in [this paper](https://arxiv.org/pdf/2305.13245.pdf). This only affects the keys and values, not the querys. GQA-1 is equivalent to Multi-Query Attention ([MQA](https://arxiv.org/pdf/1911.02150.pdf)), while GQA-H is equivalent to MultiHead Attention, i.e. `num_gqa_groups = num_attention_heads`.\n", - "\n", - "\n", - "Further, note that `RotaryPositionEmbedding` is defined as part of the `TELlamaDecoderLayer` (wrapper around TE's `TransformerLayer`) itself since it expects this rope cache if RoPE is used in the model. \n", - "\n", - "Let's revisit how `LlamaDecoderLayer`s form the core of the decoder layer stack in HF's llama implementation:\n", - "```\n", - "ModuleList(\n", - " (0-31): 32 x LlamaDecoderLayer(\n", - " (self_attn): LlamaAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (rotary_emb): LlamaRotaryEmbedding()\n", - " )\n", - " (mlp): LlamaMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n", - " (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): LlamaRMSNorm()\n", - " (post_attention_layernorm): LlamaRMSNorm()\n", - " )\n", - ")\n", - "```\n", - "\n", - "A major portion of the Hugging Face model implementation (32 `LlamaDecoderLayer` layers) could be potentially replaced with Transformer Engine's `TransformerLayer` layers. Let's see how it is made possible.\n", - "\n", - "\n", - "#### Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n", - "\n", - "Refer the accompanying file `te_llama.py` which provides a reference to create a Llama 2 model with TE's `TransformerLayer` after replacing HF's `LlamaDecoderLayer`.\n", - "\n", - "Briefly, following pieces of code are put together:\n", - "\n", - "1. `TELlamaDecoderLayer` is added as a wrapper for `TransformerLayer`. \n", - "```\n", - "class TELlamaDecoderLayer(te.pytorch.TransformerLayer):\n", - " \"\"\"\n", - " Wrapper class over TE's `TransformerLayer`. This makes the wrapper very\n", - " similar to HF's `LlamaDecoderLayer` and easier to replace it in the code.\n", - "\n", - " Args:\n", - " config: LlamaConfig\n", - " args: positional args (for compatibility with `LlamaDecoderLayer`)\n", - " kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)\n", - " \"\"\"\n", - " def __init__(self, config, *args, **kwargs):\n", - " super().__init__(\n", - " hidden_size=config.hidden_size,\n", - " ffn_hidden_size=config.intermediate_size,\n", - " num_attention_heads=config.num_attention_heads,\n", - " bias=False,\n", - " layernorm_epsilon=config.rms_norm_eps,\n", - " hidden_dropout=0,\n", - " attention_dropout=0,\n", - " fuse_qkv_params=False,\n", - " normalization=\"RMSNorm\",\n", - " activation=\"swiglu\",\n", - " attn_input_format=\"bshd\",\n", - " )\n", - " te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n", - " self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n", - "\n", - " def forward(self,\n", - " hidden_states,\n", - " *args,\n", - " attention_mask,\n", - " **kwargs):\n", - " \"\"\"\n", - " Custom forward to make sure we only pass relevant arguments to the\n", - " forward pass of the `TransformerLayer`. Also, make sure the output\n", - " format matches the output of the HF's `LlamaDecoderLayer`.\n", - " \"\"\"\n", - " return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),)\n", - "```\n", - "\n", - "2. Before creating a `LlamaForCausalLM`, `replace_decoder` context manager is used to monkey-patch `LlamaDecoderLayer` with `TELlamaDecoderLayer`.\n", - "\n", - "```\n", - "@contextmanager\n", - "def replace_decoder(te_decoder_cls):\n", - " \"\"\"\n", - " Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.\n", - " \"\"\"\n", - " original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer\n", - " transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls\n", - " try:\n", - " yield\n", - " finally:\n", - " transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls\n", - ".\n", - ".\n", - ".\n", - "class TELlamaForCausalLM:\n", - " \"\"\"\n", - " Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`\n", - " class is monkey-patched with `TELlamaDecoderLayer` class before\n", - " initializing the causal LM with `LlamaForCausalLM`.\n", - "\n", - " Args:\n", - " config: LlamaConfig\n", - " \"\"\"\n", - "\n", - " def __new__(cls, config: LlamaConfig):\n", - " with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):\n", - " llama_for_causal_lm = LlamaForCausalLM(config)\n", - " return llama_for_causal_lm\n", - ".\n", - ".\n", - ".\n", - "```\n", - "\n", - "3. A custom `pretrained_from_local` method is added that copies the weights from the checkpoint (which is meant for HF Llama implementation) to the modified `TELlamaForCausalLM` by carefully mapping the weights from the `LlamaDecoderLayer` (HF) to `TransformerLayer` (TE). The method `replace_params` maps and copies apt weights from `LlamaDecoderLayer` to the `TransformerLayer`. Refer to the following diagram for more details.\n", - "\n", - "```\n", - "def replace_params(hf_state_dict, te_state_dict):\n", - " # collect all layer prefixes to update\n", - " all_layer_prefixes = set()\n", - " for param_key in hf_state_dict.keys():\n", - " layer_prefix_pat = 'model.layers.\\d+.'\n", - " m = re.match(layer_prefix_pat, param_key)\n", - " if m is not None:\n", - " all_layer_prefixes.add(m.group())\n", - "\n", - " for layer_prefix in all_layer_prefixes:\n", - " # When loading weights into models with less number of layers, skip the\n", - " # copy if the corresponding layer doesn't exist in TE model\n", - " if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict:\n", - " te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]\n", - "\n", - " if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict:\n", - " te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]\n", - "\n", - " if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict:\n", - " te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]\n", - " .\n", - " .\n", - " .\n", - "\n", - " return all_layer_prefixes\n", - "```\n", - "\n", - "The following figure shows how the weights get mapped from the HF's `LlamaDecoderLayer` to TE's `TransformerLayer`.\n", - "\n", - "
\n", - "\n", - "
Fig 9: Replace `LlamaDecoderLayer` with `TransformerLayer`.
\n", - "
\n", - "\n", - "After initializing the modified Llama model this way, the core decoder layers get changed to `TELlamaDecoderLayer` (wrapper around `TransformerLayer`) as shown in the following output:\n", - "```\n", - "ModuleList(\n", - " (0-31): 32 x TELlamaDecoderLayer(\n", - " (self_attention): MultiheadAttention(\n", - " (layernorm_qkv): LayerNormLinear()\n", - " (core_attention): DotProductAttention(\n", - " (flash_attention): FlashAttention()\n", - " (fused_attention): FusedAttention()\n", - " (unfused_attention): UnfusedDotProductAttention(\n", - " (scale_mask_softmax): FusedScaleMaskSoftmax()\n", - " (attention_dropout): Dropout(p=0, inplace=False)\n", - " )\n", - " )\n", - " (proj): Linear()\n", - " )\n", - " (layernorm_mlp): LayerNormMLP()\n", - " )\n", - ")\n", - "```\n", - "\n", - "In summary, the model gets changed as follows with a large chunk of the implementation (core decoder layers) coming from Transformer Engine.\n", - "\n", - "
\n", - "\n", - "
Fig 10: Language model after the HF's `LlamaDecoderLayer`s are replaced with TE's `TransformerLayer`s.
\n", - "
\n", - "\n", - "\n", - "
\n", - "Note\n", - "\n", - "Let's first run this \"TELlama\" implementation in `BF16` precision.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "bdb34b91", - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n", + "\n", + "
\n", + "\n", + "Goal\n", + "\n", + "This tutorial showcases how to accelerate finetuning a full [Llama 2](https://huggingface.co/meta-llama/Llama-2-7b-hf) or [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B) models from Hugging Face by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n", + "\n", + "
\n" + ], + "id": "6a5b2993" + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "10 finetuning steps complete!\n", - "Average time taken per step: 185 milliseconds\n" - ] - } - ], - "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "restart_jupyter_notebook()\n", - "\n", - "\n", - "# Import necessary packages, methods and variables\n", - "from utils import *\n", - "\n", - "\n", - "# Provide Huggingface Access Token\n", - "hyperparams.hf_access_token = \"\"\n", - "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", - "\n", - "# Provide a directory to cache weights in to avoid downloading them every time.\n", - "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", - "hyperparams.weights_cache_dir = \"\"\n", - "\n", - "# For Llama 2, uncomment this line (also set by default)\n", - "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", - "\n", - "# For Llama 3, uncomment this line\n", - "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", - "\n", - "hyperparams.mixed_precision = \"bf16\"\n", - "\n", - "\n", - "# Init the model and accelerator wrapper\n", - "model = init_te_llama_model(hyperparams)\n", - "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", - "\n", - "\n", - "# Finetune the model\n", - "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" - ] - }, - { - "cell_type": "markdown", - "id": "0c9fbd65", - "metadata": {}, - "source": [ - "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `LlamaDecoderLayer` gives a speedup of **34%** even when using only BF16 precision!\n", - "\n", - "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 248 | 1 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 185 | 1.34 |" - ] - }, - { - "cell_type": "markdown", - "id": "98cd8efb", - "metadata": {}, - "source": [ - "## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", - "\n", - "Now that most of the HF Llama model implementation (`LlamaDecoderLayer`s) has been swapped with Transformer Engine implementation (`TELlamaDecoderLayer` or `TransformerLayer`), let's see how finetuning in `FP8` precision helps improve performance.\n", - "\n", - "#### How to run the model in `FP8` precision\n", - "\n", - "After the substitution, the model can be run in `FP8` precision by the following change over the previous BF16 runs. (For more information, refer the corresponding `wrap_with_accelerator` function in the accompanying `utils.py` file).\n", - "\n", - "```\n", - "# Specify the `FP8RecipeKwargs` (additional argument required to run in `fp8` precision)\n", - "fp8_kwarg_handler = [FP8RecipeKwargs(backend=\"te\")]\n", - "\n", - "# Pass the `FP8RecipeKwargs` to the `Accelerator` init call\n", - "accelerator = Accelerator(\n", - " ...\n", - " kwargs_handlers=fp8_kwarg_handler\n", - ")\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "772c6f22", - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dependencies for this tutorial\n", + "\n", + "Following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_llama.py`\n", + " - This file contains the code to load a Hugging Face Llama 2 or Llama 3 checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", + "2. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", + "3. `requirements.txt`\n", + " - This file contains the necessary Python packages for this tutorial.\n", + "4. `media/`\n", + " - This directory contains the images used in the following tutorial.\n", + "\n", + "\n", + "
\n", + "\n", + "Note on running the tutorial with Llama 3 weights\n", + "\n", + "This tutorial shows the cell outputs when run with Llama 2 7B weights. It can be run with Llama 3 8B weights simply by providing the directory with those weights (in Hugging Face format) instead of Llama 2 7B weights. These two models are almost identical, the biggest difference being the model dimension (the smallest Llama 3 model has 8B parameters, whereas the smallest Llama 2 has 7B), which enables this tutorial to work for both of them.\n", + "\n", + "
\n", + "" + ], + "id": "331f476a" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup\n", + "\n", + "Install the required Python packages using the following command:" + ], + "id": "b56526b3" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Uncomment and run this cell when running the tutorial for the first time\n", + "# %pip install -r requirements.txt" + ], + "id": "099697e2", + "execution_count": null, + "outputs": [] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "10 finetuning steps complete!\n", - "Average time taken per step: 160 milliseconds\n" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Table of contents\n", + "1. From \"Transformer\" to \"Llama\"\n", + "2. Hugging Face's `LlamaModel`\n", + " - Hugging Face's `LlamaDecoderLayer`\n", + "3. [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n", + "6. [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", + " - Transformer Engine's `TransformerLayer`\n", + " - `TransformerLayer` options explained\n", + " - Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n", + "7. [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", + "8. Conclusion" + ], + "id": "44abae4f" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## From \"Transformer\" to \"Llama\" \n", + "\n", + "
\n", + "\n", + "
Fig 1: Llama visualized as a transformer. (generated with [Nvidia's AI-foundation models](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/sdxl))
\n", + "
\n", + "\n", + "A flashback:\n", + "\n", + "- 2017: [\"Attention Is All You Need\"](https://arxiv.org/abs/1706.03762) paper introduced pioneering \"Transformer\" architecture and changed the NLP field forever.\n", + "- 2018-2020: Emergence of GPT model series that showed causal decoder architectures are great fit for pretraining, few-shot and zero-shot learning.\n", + "- Fast forward to 2023-2024: Following GPT-3/GPT-4 success stories, researchers and companies raced to produce the next best pretrained model that could further be finetuned for application-specific use-cases.\n", + "- February 2023: Meta releases [Llama 2](https://llama.meta.com/llama2) models (Large Language Model Meta AI). \n", + " - These models range from 7B to 70B parameters.\n", + " - LLaMA 2 was pretrained on 2 trillion tokens.\n", + "- April 2024: Meta releases [Llama 3](https://llama.meta.com/llama3) models.\n", + " - These models range from 8B to 70B parameters.\n", + " - LLaMA 3 was pretrained on 15 trillion tokens.\n", + "\n", + "For more information on Llama 2 consider reading the [Huggingface tutorial](https://huggingface.co/blog/llama2). As a quick summary, here are some of the important differences b/w the conventional transformer decoder architecture vs Llama 2 architecture:\n", + "\n", + "1. Decoder only model (causal language modeling and next word prediction)\n", + "2. RMSNorm in place of the LayerNorm\n", + "3. SwiGLU activation function\n", + "4. RoPE as positional embeddings \n", + "5. Grouped Query Attention for the 70B model\n", + "6. Trained on 4K context length\n", + "\n", + "Hugging Face also released a [tutorial about Llama 3](https://huggingface.co/blog/llama3). The key points are:\n", + "\n", + "1. Use of bigger tokenizer - 128256 vs 32K.\n", + "2. Grouped Query Attention is used also by smaller 8B model.\n", + "3. The context length increased to 8K for all models.\n", + "3. Llama 3 was trained on 8x more data than Llama 2.\n", + "\n", + "
\n", + "\n", + "
Fig 2: Comparing GPT and Llama architectures.
\n", + "
" + ], + "id": "e37e2cc1" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hugging Face's `LlamaModel`\n", + "Hugging Face provides an open-source implementation of `Llama` model in [modeling_llama.py](https://github.com/huggingface/transformers/blob/3d2900e829ab16757632f9dde891f1947cfc4be0/src/transformers/models/llama/modeling_llama.py#L4).\n", + "\n", + "Here's a block diagram that shows how Llama model is implemented in the Hugging Face repo. Notice the modular encapsulated form and `LlamaDecoderLayer` at the core of the model implementation.\n", + "\n", + "
\n", + "\n", + "
Fig 3: Causal Llama Model Block Diagram.
\n", + "
\n", + "\n", + "The above diagram translates to the following text output of the model in PyTorch. Notice that the core of the model has 32 `LlamaDecoderLayer`s. \n", + "\n", + "```\n", + "LlamaForCausalLM(\n", + " (model): LlamaModel(\n", + " (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n", + " (layers): ModuleList(\n", + " (0-31): 32 x LlamaDecoderLayer(\n", + " (self_attn): LlamaFlashAttention2(\n", + " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (mlp): LlamaMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LlamaRMSNorm()\n", + " (post_attention_layernorm): LlamaRMSNorm()\n", + " )\n", + " )\n", + " (norm): LlamaRMSNorm()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n", + ")\n", + "```\n", + "\n", + "### Hugging Face's `LlamaDecoderLayer`\n", + "\n", + "Let's take a closer look at `LlamaDecoderLayer`. It is composed of `input_layernorm`, `self_attn`, `post_attention_layernorm` and `mlp` modules. Each module has associated weights as shown in the diagram.\n", + "\n", + "
\n", + "\n", + "
Fig 4: Causal Llama Model Block Diagram (with simplified illustration of the [LlamaDecoderLayer](https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/llama/modeling_llama.py#L695)).
\n", + "
\n", + "\n", + "#### Self_Attn Layer\n", + "For simplicity in the block diagram illustration of the \"self_attn\" box, we omit the \"Grouped Query Attention\" operation and only showcase the modules which have associated weights.\n", + " \n", + "#### MLP Layer\n", + "\n", + "SwiGLU is an activation defined as follows in the [modeling_llama.py](https://github.com/huggingface/transformers/blob/7c4995f93d8d24aae05e1e43279c96dce736e5c8/src/transformers/models/llama/modeling_llama.py#L236) file in the Hugging Face github repo:\n", + "```\n", + "\"\"\"\n", + "1. `self.up_proj`, `self.gate_proj` and `self.down_proj` are \"Linear\" layers\n", + "2. `self.act_fn` is a \"Swish\" function\n", + "\n", + "\"\"\"\n", + "down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n", + "```\n", + "It requires a set of 3 weights as compared to 2 weights in conventional \"MLP\" layers e.g. in the traditional transformer or GPT architectures. This is also illustrated in the following figure:\n", + "\n", + "
\n", + "\n", + "
Fig 5: A look inside the feedforward layer with swiglu activation function.
\n", + "
" + ], + "id": "a110de1a" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n", + "\n", + "Llama 2 weights are loaded into the Hugging Face native implementation `LlamaForCausalLM` (refer to [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)). \n", + "\n", + "For this and other subsequent runs, the `batch_size` is `8`. The `LlamaDecoderLayer` is left unchanged in the baseline as follows:\n", + "\n", + "
\n", + "\n", + "
Fig 6: Revisiting \"LlamaDecoderLayer\".
\n", + "
\n", + "\n", + "
\n", + "Note\n", + "\n", + "The baseline implementation will be run in `BF16` precision.\n", + "\n", + "
" + ], + "id": "c9529229" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "Note\n", + " \n", + "This tutorial loads and trains a Llama 3 8B or a Llama 2 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", + "\n", + "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", + "\n", + "
\n" + ], + "id": "b38eb3ac" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages, methods and variables\n", + "from utils import *\n", + "\n", + "\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_baseline_model(hyperparams)\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "10 finetuning steps complete!\n", + "Average time taken per step: 248 milliseconds\n" + ] + } + ], + "id": "2e9d7a8c" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 248 | 1 |" + ], + "id": "4035ccb7" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", + "\n", + "In addition to basic layers like `Linear` and `LayerNorm`, Transformer Engine offers larger modules like `MultiheadAttention` (combines \"LayerNorm\" and \"Self Attention\") and `LayerNormMLP` (combines \"LayerNorm\" and \"MLP\") that could replace their counterparts in the `LlamaDecoderLayer` and potentially provide a speedup. Transformer Engine also offers a full `TransformerLayer` (which further combines `MultiheadAttention` and `LayerNormMLP` layers) which could replace `LlamaDecoderLayer` and provide a speedup (with careful mapping of the weights since the name of the weights are different for those two layers). Let's take a closer look at Transformer Engine's `TransformerLayer`. \n", + "\n", + "### Transformer Engine's `TransformerLayer`\n", + "\n", + "At a higher level, TE's `TransformerLayer` could be visualized as an apt replacement for the `LlamaDecoderLayer`. But the internals of the `TransformerLayer` are organized a bit differently. \n", + "\n", + "
\n", + "\n", + "
Fig 7: Transformer Engine's `TransformerLayer`
\n", + "
\n", + "\n", + "Just like Hugging Face's `LlamaDecoderLayer`, Transformer Engine's `TransformerLayer` encapsulates `self_attention` (as `MultiheadAttention`) and `mlp` (as `LayerNormMLP`). A major difference is that the two `Norm`s are included in the `MultiheadAttention` and `LayerNormMLP` layers as shown in the following output prompt:\n", + "\n", + "```\n", + "TransformerLayer(\n", + " (self_attention): MultiheadAttention(\n", + " (layernorm_qkv): LayerNormLinear()\n", + " (core_attention): DotProductAttention()\n", + " (proj): Linear()\n", + " )\n", + " (layernorm_mlp): LayerNormMLP()\n", + ")\n", + "```\n", + "\n", + "Another difference is that Transformer Engine implements an efficient version of feedforward layer with SwiGLU in which the weights from the `up_proj` and `gate_proj` modules are merged together and SwiGLU is applied using a custom fused kernel. This is done so that only one big and efficient Matrix Multiplication operation is issued to the GPU instead of two smaller ones.\n", + "\n", + "
\n", + "\n", + "
Fig 8: Abstract illustration of the SwiGLU implementation in Transformer Engine.
\n", + "
\n", + "\n", + "### `TransformerLayer` options explained\n", + "\n", + "
\n", + "\n", + "Note\n", + " \n", + "Here, we go over some of the options in `TransformerLayer` that are needed for the tutorial. For a complete list of options, refer the [TransformerLayer API documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html?highlight=transformerlayer#transformer_engine.pytorch.TransformerLayer).\n", + "\n", + "
\n", + "\n", + "In the accompanying `te_llama.py` file, `TELlamaDecoderLayer` is defined as a wrapper over TE's `TransformerLayer` with a few needed options that make `TransformerLayer` a plug-in replacement for the HF's `LlamaDecoderLayer`.\n", + "\n", + "```\n", + "class TELlamaDecoderLayer(te.pytorch.TransformerLayer):\n", + " def __init__(self, config):\n", + " super().__init__(\n", + " config.hidden_size,\n", + " config.intermediate_size,\n", + " config.num_attention_heads,\n", + " bias=False,\n", + " layernorm_epsilon=config.rms_norm_eps,\n", + " hidden_dropout=0,\n", + " attention_dropout=0,\n", + " fuse_qkv_params=False,\n", + " normalization=\"RMSNorm\",\n", + " activation=\"swiglu\",\n", + " attn_input_format=\"bshd\",\n", + " num_gqa_groups=config.num_key_value_heads,\n", + " )\n", + " te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n", + " self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n", + "```\n", + "\n", + "Here's a list summarizing each option briefly:\n", + "\n", + "1. `hidden_size`: size of each input sample.\n", + "2. `ffn_hidden_size`: intermediate size to which samples are projected.\n", + "3. `num_attention_heads`: number of attention heads in the transformer layer.\n", + "4. `bias`: switch to add additive biases to the submodule layers.\n", + "5. `layernorm_epsilon`: a value added to the denominator of layer normalization for numerical stability. Default is `1e-5`.\n", + "6. `hidden_dropout`: dropout probability for the dropout op after FC2 layer (fully connected layer no. 2). Default is `0.1`.\n", + "7. `attention_dropout`: dropout probability for the dropout op during multi-head attention. Default is `0.1`. \n", + "8. `fuse_qkv_params`: if set to True, TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument fuse_wgrad_accumulation.\n", + "9. `normalization`: type of normalization applied. Default is `LayerNorm`.\n", + "10. `activation`: type of activation used in the MLP block. Default is `gelu`.\n", + "11. `attn_input_format`: controls whether the dimensions of the intermediate hidden states is 'batch first' ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, `b` batch size, `h` the number of heads, `d` head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules.\n", + "12. `num_gqa_groups`: number of GQA groups in the transformer layer. Grouped Query Attention is described in [this paper](https://arxiv.org/pdf/2305.13245.pdf). This only affects the keys and values, not the querys. GQA-1 is equivalent to Multi-Query Attention ([MQA](https://arxiv.org/pdf/1911.02150.pdf)), while GQA-H is equivalent to MultiHead Attention, i.e. `num_gqa_groups = num_attention_heads`.\n", + "\n", + "\n", + "Further, note that `RotaryPositionEmbedding` is defined as part of the `TELlamaDecoderLayer` (wrapper around TE's `TransformerLayer`) itself since it expects this rope cache if RoPE is used in the model. \n", + "\n", + "Let's revisit how `LlamaDecoderLayer`s form the core of the decoder layer stack in HF's llama implementation:\n", + "```\n", + "ModuleList(\n", + " (0-31): 32 x LlamaDecoderLayer(\n", + " (self_attn): LlamaAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (mlp): LlamaMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LlamaRMSNorm()\n", + " (post_attention_layernorm): LlamaRMSNorm()\n", + " )\n", + ")\n", + "```\n", + "\n", + "A major portion of the Hugging Face model implementation (32 `LlamaDecoderLayer` layers) could be potentially replaced with Transformer Engine's `TransformerLayer` layers. Let's see how it is made possible.\n", + "\n", + "\n", + "### Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n", + "\n", + "Refer the accompanying file `te_llama.py` which provides a reference to create a Llama 2 model with TE's `TransformerLayer` after replacing HF's `LlamaDecoderLayer`.\n", + "\n", + "Briefly, following pieces of code are put together:\n", + "\n", + "1. `TELlamaDecoderLayer` is added as a wrapper for `TransformerLayer`. \n", + "```\n", + "class TELlamaDecoderLayer(te.pytorch.TransformerLayer):\n", + " \"\"\"\n", + " Wrapper class over TE's `TransformerLayer`. This makes the wrapper very\n", + " similar to HF's `LlamaDecoderLayer` and easier to replace it in the code.\n", + "\n", + " Args:\n", + " config: LlamaConfig\n", + " args: positional args (for compatibility with `LlamaDecoderLayer`)\n", + " kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)\n", + " \"\"\"\n", + " def __init__(self, config, *args, **kwargs):\n", + " super().__init__(\n", + " hidden_size=config.hidden_size,\n", + " ffn_hidden_size=config.intermediate_size,\n", + " num_attention_heads=config.num_attention_heads,\n", + " bias=False,\n", + " layernorm_epsilon=config.rms_norm_eps,\n", + " hidden_dropout=0,\n", + " attention_dropout=0,\n", + " fuse_qkv_params=False,\n", + " normalization=\"RMSNorm\",\n", + " activation=\"swiglu\",\n", + " attn_input_format=\"bshd\",\n", + " )\n", + " te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n", + " self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n", + "\n", + " def forward(self,\n", + " hidden_states,\n", + " *args,\n", + " attention_mask,\n", + " **kwargs):\n", + " \"\"\"\n", + " Custom forward to make sure we only pass relevant arguments to the\n", + " forward pass of the `TransformerLayer`. Also, make sure the output\n", + " format matches the output of the HF's `LlamaDecoderLayer`.\n", + " \"\"\"\n", + " return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),)\n", + "```\n", + "\n", + "2. Before creating a `LlamaForCausalLM`, `replace_decoder` context manager is used to monkey-patch `LlamaDecoderLayer` with `TELlamaDecoderLayer`.\n", + "\n", + "```\n", + "@contextmanager\n", + "def replace_decoder(te_decoder_cls):\n", + " \"\"\"\n", + " Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.\n", + " \"\"\"\n", + " original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer\n", + " transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls\n", + " try:\n", + " yield\n", + " finally:\n", + " transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls\n", + ".\n", + ".\n", + ".\n", + "class TELlamaForCausalLM:\n", + " \"\"\"\n", + " Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`\n", + " class is monkey-patched with `TELlamaDecoderLayer` class before\n", + " initializing the causal LM with `LlamaForCausalLM`.\n", + "\n", + " Args:\n", + " config: LlamaConfig\n", + " \"\"\"\n", + "\n", + " def __new__(cls, config: LlamaConfig):\n", + " with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):\n", + " llama_for_causal_lm = LlamaForCausalLM(config)\n", + " return llama_for_causal_lm\n", + ".\n", + ".\n", + ".\n", + "```\n", + "\n", + "3. A custom `pretrained_from_local` method is added that copies the weights from the checkpoint (which is meant for HF Llama implementation) to the modified `TELlamaForCausalLM` by carefully mapping the weights from the `LlamaDecoderLayer` (HF) to `TransformerLayer` (TE). The method `replace_params` maps and copies apt weights from `LlamaDecoderLayer` to the `TransformerLayer`. Refer to the following diagram for more details.\n", + "\n", + "```\n", + "def replace_params(hf_state_dict, te_state_dict):\n", + " # collect all layer prefixes to update\n", + " all_layer_prefixes = set()\n", + " for param_key in hf_state_dict.keys():\n", + " layer_prefix_pat = 'model.layers.\\d+.'\n", + " m = re.match(layer_prefix_pat, param_key)\n", + " if m is not None:\n", + " all_layer_prefixes.add(m.group())\n", + "\n", + " for layer_prefix in all_layer_prefixes:\n", + " # When loading weights into models with less number of layers, skip the\n", + " # copy if the corresponding layer doesn't exist in TE model\n", + " if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict:\n", + " te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]\n", + "\n", + " if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict:\n", + " te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]\n", + "\n", + " if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict:\n", + " te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]\n", + " .\n", + " .\n", + " .\n", + "\n", + " return all_layer_prefixes\n", + "```\n", + "\n", + "The following figure shows how the weights get mapped from the HF's `LlamaDecoderLayer` to TE's `TransformerLayer`.\n", + "\n", + "
\n", + "\n", + "
Fig 9: Replace `LlamaDecoderLayer` with `TransformerLayer`.
\n", + "
\n", + "\n", + "After initializing the modified Llama model this way, the core decoder layers get changed to `TELlamaDecoderLayer` (wrapper around `TransformerLayer`) as shown in the following output:\n", + "```\n", + "ModuleList(\n", + " (0-31): 32 x TELlamaDecoderLayer(\n", + " (self_attention): MultiheadAttention(\n", + " (layernorm_qkv): LayerNormLinear()\n", + " (core_attention): DotProductAttention(\n", + " (flash_attention): FlashAttention()\n", + " (fused_attention): FusedAttention()\n", + " (unfused_attention): UnfusedDotProductAttention(\n", + " (scale_mask_softmax): FusedScaleMaskSoftmax()\n", + " (attention_dropout): Dropout(p=0, inplace=False)\n", + " )\n", + " )\n", + " (proj): Linear()\n", + " )\n", + " (layernorm_mlp): LayerNormMLP()\n", + " )\n", + ")\n", + "```\n", + "\n", + "In summary, the model gets changed as follows with a large chunk of the implementation (core decoder layers) coming from Transformer Engine.\n", + "\n", + "
\n", + "\n", + "
Fig 10: Language model after the HF's `LlamaDecoderLayer`s are replaced with TE's `TransformerLayer`s.
\n", + "
\n", + "\n", + "\n", + "
\n", + "Note\n", + "\n", + "Let's first run this \"TELlama\" implementation in `BF16` precision.\n", + "
" + ], + "id": "3db90dff" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages, methods and variables\n", + "from utils import *\n", + "\n", + "\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_llama_model(hyperparams)\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "10 finetuning steps complete!\n", + "Average time taken per step: 185 milliseconds\n" + ] + } + ], + "id": "bdb34b91" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `LlamaDecoderLayer` gives a speedup of **34%** even when using only BF16 precision!\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 248 | 1 |\n", + "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 185 | 1.34 |" + ], + "id": "0c9fbd65" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", + "\n", + "Now that most of the HF Llama model implementation (`LlamaDecoderLayer`s) has been swapped with Transformer Engine implementation (`TELlamaDecoderLayer` or `TransformerLayer`), let's see how finetuning in `FP8` precision helps improve performance.\n", + "\n", + "### How to run the model in `FP8` precision\n", + "\n", + "After the substitution, the model can be run in `FP8` precision by the following change over the previous BF16 runs. (For more information, refer the corresponding `wrap_with_accelerator` function in the accompanying `utils.py` file).\n", + "\n", + "```\n", + "# Specify the `FP8RecipeKwargs` (additional argument required to run in `fp8` precision)\n", + "fp8_kwarg_handler = [FP8RecipeKwargs(backend=\"te\")]\n", + "\n", + "# Pass the `FP8RecipeKwargs` to the `Accelerator` init call\n", + "accelerator = Accelerator(\n", + " ...\n", + " kwargs_handlers=fp8_kwarg_handler\n", + ")\n", + "```" + ], + "id": "98cd8efb" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages, methods and variables\n", + "from utils import *\n", + "\n", + "\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", + "hyperparams.mixed_precision = \"fp8\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_llama_model(hyperparams)\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "10 finetuning steps complete!\n", + "Average time taken per step: 160 milliseconds\n" + ] + } + ], + "id": "772c6f22" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 248 | 1 |\n", + "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 185 | 1.34 |\n", + "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 160 | 1.55 |\n", + "\n", + "\n", + "After turning on FP8 precision, we get even more speedup of **55%** (with Llama 2 7B)!\n", + "\n", + "### Llama 3 performance results\n", + "Running the same tutorial with **Llama 3 8B** yields the following performance numbers:\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 270 | 1 |\n", + "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 217 | 1.24 |\n", + "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 185 | 1.46 |\n", + "\n", + "For Llama 3 8B, we get the most speedup of **46%** with FP8 precision!\n", + "\n" + ], + "id": "e7cf9c3a" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Llama 2 and Llama 3 implementations. This needs careful initialization of the model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!" + ], + "id": "95d6c42b" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" } - ], - "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "restart_jupyter_notebook()\n", - "\n", - "\n", - "# Import necessary packages, methods and variables\n", - "from utils import *\n", - "\n", - "\n", - "# Provide Huggingface Access Token\n", - "hyperparams.hf_access_token = \"\"\n", - "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", - "\n", - "# Provide a directory to cache weights in to avoid downloading them every time.\n", - "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", - "hyperparams.weights_cache_dir = \"\"\n", - "\n", - "# For Llama 2, uncomment this line (also set by default)\n", - "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", - "\n", - "# For Llama 3, uncomment this line\n", - "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", - "\n", - "hyperparams.mixed_precision = \"fp8\"\n", - "\n", - "\n", - "# Init the model and accelerator wrapper\n", - "model = init_te_llama_model(hyperparams)\n", - "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", - "\n", - "\n", - "# Finetune the model\n", - "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" - ] - }, - { - "cell_type": "markdown", - "id": "e7cf9c3a", - "metadata": {}, - "source": [ - "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 248 | 1 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 185 | 1.34 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 160 | 1.55 |\n", - "\n", - "\n", - "After turning on FP8 precision, we get even more speedup of **55%** (with Llama 2 7B)!\n", - "\n", - "#### Llama 3 performance results\n", - "Running the same tutorial with **Llama 3 8B** yields the following performance numbers:\n", - "\n", - "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 270 | 1 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 217 | 1.24 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 185 | 1.46 |\n", - "\n", - "For Llama 3 8B, we get the most speedup of **46%** with FP8 precision!\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "id": "95d6c42b", - "metadata": {}, - "source": [ - "## Conclusion\n", - "\n", - "Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Llama 2 and Llama 3 implementations. This needs careful initialization of the model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst new file mode 100644 index 000000000..48d17db8d --- /dev/null +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst @@ -0,0 +1,254 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +FP8 Blockwise Scaling +=================================== + +.. warning:: + + ``Float8BlockScaling`` is **currently not supported** in JAX. + +FP8 Blockwise Scaling recipe is inspired by the quantization scheme used to train the `DeepSeek-v3 model `__ – +the first open-source large-scale LLM trained entirely in FP8 precision. +Unlike the previous recipes, it assigns a dedicated scaling factor to each block of elements. + + +Data Format +-------------------------- + +The representation of an FP8 tensor element ``x`` in blockwise precision is given by: + +.. code-block:: python + + x = x_fp8 * s_block + +where + +* ``x_fp8`` is the FP8 value (E4M3 or E5M2), +* ``s_block`` is a local **FP32** scaling factor shared by a block of elements. + + +.. raw:: html + :file: img/combined_scaling.svg + +*Figure 1. Top: Comparison of standard FP8 scaling (left) using a single scaling factor per tensor versus +FP8 blockwise scaling in 1 dimension (right) using multiple scaling factors, one per block of 128 elements. +Bottom: FP8 blockwise scaling in 2 dimensions where each 128×128 block in the data tensor has a corresponding +scaling factor.* + +**FP8 format** + +Unlike FP8 Current/Delayed Scaling, E4M3 is used by default for both forward and backward passes. +Tensor-scaled recipes used E5M2 for gradients due to its higher dynamic range, +but with multiple scaling factors per tensor the dynamic range requirement is lowered, so E4M3 is usually sufficient. +The ``fp8_format`` parameter also supports ``HYBRID`` mode (E4M3 for forward, E5M2 for backward). +Pure E5M2 training is not supported. + + +**Block size** + +Block size is 128. +Blocks can be: + +* one dimensional – containing 128 consecutive values, +* two dimensional – containing tiles of 128×128 values. + +By default: + +* activations use 1D scaling (``x_block_scaling_dim=1``), +* weights use 2D scaling (``w_block_scaling_dim=2``), +* gradients use 1D scaling (``grad_block_scaling_dim=1``). + +These can be changed in the recipe, but 2D × 2D GEMMs are not supported +– at most one operand can use 2D scaling. + +One-dimensional scaling is more granular, but 2D scaling offers two advantages: + +* *Performance*: On Hopper, block-scaled GEMMs are software-emulated. GEMMs with mixed + 1D/2D scaled tensors have lower overhead than pure 1D scaled GEMMs. +* *Numerical stability*: 2D scaling behaves better when transposed (details in the next section). + +There are some assumptions on the dimensions of the tensor (for both 1D and 2D scaling): + +* the tensor must have at least 2 dimensions, +* the last dimension must be divisible by 128, +* the product of all dimensions except the last must be divisible by 128. + +**Scaling factors** + +Scaling factors are stored as 32-bit floating point numbers. +By default, they are constrained to powers of 2 (utilizing the 8 exponent bits of FP32). +On Hopper, this constraint can be relaxed by setting the environment variable ``NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1``. +On Blackwell, only powers of 2 are supported. + +Each block's scaling factor is computed through the following steps: + +1. Find the maximum absolute value (``amax_block``) across all elements in the block + (128 consecutive values for 1D blocks, or 128×128 values for 2D blocks). +2. Calculate ``s_block = max_fp8 / amax_block``, where ``max_fp8`` is + the maximum representable value in the FP8 format (448 for E4M3, 57344 for E5M2). +3. If the power-of-2 constraint is enabled, round down to the nearest power of 2 + by zeroing out the mantissa bits, retaining only the sign and exponent. +4. Multiply each element in the block by ``s_block`` before converting to FP8. + +This approach ensures that the largest value in each block fits within the FP8 representable range without overflow. + + +Handling transposes +------------------------ + +On Hopper, columnwise tensor access requires data to be transposed in memory. +For 1D scaling, the block direction must align with the access pattern: + +* *Rowwise access*: 1 scaling factor per 128 consecutive elements in a row. +* *Columnwise access*: 1 scaling factor per 128 consecutive elements in a row of the transposed tensor, + corresponding to 128 consecutive elements in a column of the original tensor. + +For 2D scaling, each 128×128 tile has one scaling factor regardless of access direction. + +This is illustrated below: + +.. raw:: html + :file: img/transpose_handling.svg + +*Figure 2. Quantization directions for original and transposed tensors.* + +Note that for 1D scaling, the rowwise and columnwise quantized tensors may be numerically different, +so the gradient computation may be affected. This issue is not present for 2D scaling. + + +Activations and weights use the rowwise version in the forward pass and the columnwise version in the backward pass. +Experiments have shown that 2D scaling for weights is more helpful for numerical stability than for activations, +so by default 1D scaling is used for activations – as it is more granular – and 2D scaling is used for weights. + + +Unlike FP8 Current/Delayed Scaling, transposing a 1D quantized tensor is not supported. +Rowwise and columnwise blocks cover different sets of elements, so their scaling factors differ. +Both versions must be quantized separately from the high-precision source. + +For 2D scaling, columnwise data can be created from rowwise data by transposing +both the quantized data and the scaling factors. Each 128×128 block covers the same +elements regardless of access direction, so the scaling factors remain valid. + + +Distributed training +----------------------- + +**Scale synchronization** + +The blockwise scaled tensor does not need any scale synchronization among the nodes. +This is because each scaling factor is local to its 128 or 128×128 element block, +unlike FP8 Current/Delayed Scaling where a single global scale applies to the entire tensor, even when sharded. + +**Quantized all-gather** + +FP8 Blockwise Scaling all-gather is supported. + + +Examples +-------- + +Here's how to use the FP8 Blockwise Scaling recipe in PyTorch and JAX: + +.. note:: + + Requires SM90 (Hopper) or later. + +.. tabs:: + + .. tab:: PyTorch + + .. literalinclude:: pytorch_blockwise_scaling_example.py + :language: python + :start-after: # START_BLOCKWISE_SCALING_EXAMPLE + :end-before: # END_BLOCKWISE_SCALING_EXAMPLE + + .. tab:: JAX + + ``Float8BlockScaling`` is **not currently supported** in JAX. + +Supported devices +----------------- + +Hopper (SM 9.0) + +Blackwell and later (SM >= 10.0) – the recipe is emulated with MXFP8. Note that MXFP8 is the preferred recipe on Blackwell. + Only scaling factors that are powers of 2 are supported. + + +---- + +Developer Notes +--------------- + +This section contains implementation details that may be useful for developers +but are not required for using FP8 Blockwise Scaling in practice. + +Swizzle of scaling factors +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +FP8 Blockwise Scaling supports all-gather of both rowwise and columnwise tensors. +To support that, it implements different data layouts for communication (all-gather) +and computation (GEMM). We refer to the conversion between these formats as *swizzling*. + +A tensor of shape ``[A, B]`` can exist in two formats: + +**Compact format** (used for all-gather): + +The all-gather primitive only supports gathering non-transposed shards into a non-transposed full tensor, +so all tensor components in this layout are stored without transposition. +Moreover, all component tensors are stored without padding. + +.. list-table:: + :widths: 30 70 + :header-rows: 1 + + * - Component + - Shape + * - rowwise data + - ``[A, B]`` + * - columnwise data + - ``[A, B]`` + * - rowwise scales + - ``[A, B/128]`` + * - columnwise scales + - ``[A/128, B]`` + +**GEMM-ready format** (used for computation): + +Tensors are transposed and padded as required by the GEMM kernel. + +.. list-table:: + :widths: 30 70 + :header-rows: 1 + + * - Component + - Shape + * - rowwise data + - ``[A, B]`` + * - columnwise data + - ``[B, A]`` (transposed) + * - rowwise scales + - ``[B/128, pad4(A)]`` (transposed, padded) + * - columnwise scales + - ``[A/128, pad4(B)]`` (padded) + +Swizzling converts from compact to GEMM-ready format. This can be fused with quantization +when no all-gather is needed, or performed separately after all-gather. + +.. raw:: html + :file: img/blockwise_swizzle_flow.svg + +*Figure 3. FP8 Blockwise Scaling swizzle paths. Top: With all-gather communication – quantization produces +compact format, then swizzle is performed separately after communication. Bottom: Without all-gather – +quantize and swizzle are fused into a single operation, directly producing GEMM-ready format.* + +All-gather of columnwise tensors +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +All-gather of columnwise tensors is supported and necessary because: + +- columnwise quantized tensors cannot be computed from rowwise quantized ones, +- gathering high-precision tensors is avoided in most cases for performance reasons. diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/img/blockwise_swizzle_flow.svg b/docs/features/low_precision_training/fp8_blockwise_scaling/img/blockwise_swizzle_flow.svg new file mode 100644 index 000000000..afad96d76 --- /dev/null +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/img/blockwise_swizzle_flow.svg @@ -0,0 +1,146 @@ + + + + + + + + + + + + + + + + Input Tensor + + FP32/BF16 + + + + + + + + Quantize + + + + + + + FP8 (Compact) + + + + + FP32 Scales + + + + FP8 Data + + + + + + + + All-Gather + + + + + + + Swizzle + + + + + + + FP8 (GEMM Ready) + + + + + Swizzled Scales + + + + FP8 Data + + + + + + + + GEMM + + + + + + + + + + Input Tensor + + FP32/BF16 + + + + + + + + Quantize + + + Swizzle + + + + + + + FP8 (GEMM Ready) + + + + + Swizzled Scales + + + + FP8 Data + + + + + + + + GEMM + + diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/img/combined_scaling.svg b/docs/features/low_precision_training/fp8_blockwise_scaling/img/combined_scaling.svg new file mode 100644 index 000000000..dbf6999ae --- /dev/null +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/img/combined_scaling.svg @@ -0,0 +1,342 @@ + + + + + + + + + + Delayed/Current FP8 Scaling + (Single scaling factor per tensor) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1 scaling factor + + + + + Blockwise FP8 Scaling – 1 dimension + (One scaling factor per 128 elements) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Scaling factors (one per block) + + + + + Blockwise FP8 Scaling – 2 dimensions + (One scaling factor per 128x128 block of elements) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Scaling factors (1 per 2D block) + + diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/img/transpose_handling.svg b/docs/features/low_precision_training/fp8_blockwise_scaling/img/transpose_handling.svg new file mode 100644 index 000000000..e9a3b7b7d --- /dev/null +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/img/transpose_handling.svg @@ -0,0 +1,347 @@ + + + + + + + 1D Blockwise Scaling + + + + Rowwise Quantization + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Columnwise Quantization + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 2D Blockwise Scaling + + + + Rowwise Quantization + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Columnwise Quantization + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py b/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py new file mode 100644 index 000000000..5100fc1a1 --- /dev/null +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Check for Hopper or newer GPU +major, minor = torch.cuda.get_device_capability() +assert major >= 9, f"FP8 Blockwise Scaling requires SM90 (Hopper) or later, got SM{major}{minor}" + +# START_BLOCKWISE_SCALING_EXAMPLE + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Float8BlockScaling + +# Create FP8 Blockwise Scaling recipe +recipe = Float8BlockScaling( + fp8_format=te.common.recipe.Format.E4M3, # E4M3 or HYBRID (default: E4M3) + x_block_scaling_dim=1, # 1D scaling for activations (default: 1) + w_block_scaling_dim=2, # 2D scaling for weights (default: 2) + grad_block_scaling_dim=1, # 1D scaling for gradients (default: 1) +) + +# Create a linear layer with bfloat16 parameters +layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) + +# Forward and backward pass +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +with te.autocast(enabled=True, recipe=recipe): + output = layer(inp) + loss = output.sum() + +loss.backward() + +# END_BLOCKWISE_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst new file mode 100644 index 000000000..a4830a3fd --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst @@ -0,0 +1,180 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +FP8 Current Scaling +=================================== + +FP8 current scaling recipe is the simplest low precision recipe provided by Transformer Engine. +To understand how this recipe works, we first need to examine what the FP8 data type is and how it differs from other floating point formats. + + +FP8 data type +------------- + +The FP8 datatype, introduced in Hopper architecture, is actually 2 distinct datatypes, useful in different parts of the training of neural networks: + +* E4M3 -- consists of 1 sign bit, 4 exponent bits and 3 bits of mantissa. It can store values up to +/-448 and ``nan``. +* E5M2 -- consists of 1 sign bit, 5 exponent bits and 2 bits of mantissa. It can store values up to +/-57344, +/- ``inf`` and ``nan``. The tradeoff of the increased dynamic range is lower precision of the stored values. + +.. raw:: html + :file: img/fp8_formats.svg + +*Figure 1: Structure of the floating point datatypes. All of the values shown (in FP16, BF16, FP8 E4M3 and FP8 E5M2) are the closest representations of value 0.3952.* + + +**E4M3 and E5M2 usage in training** + +By default, Transformer Engine uses a hybrid approach: + +* *Forward pass* - activations and weights require more precision, so E4M3 datatype is used to store them. +* *Backward pass* - gradients are less susceptible to precision loss but require higher dynamic range, so E5M2 datatype is preferred. + +The user can configure this behavior via the ``fp8_format`` parameter of the recipe. + + +Scaling factors +--------------- + + +Limited dynamic range of FP8 datatype is insufficient for many tensors. +To address this, values in the tensor are scaled. FP8 Current Scaling recipe uses one **FP32** scale factor per tensor. The representation of a tensor element ``x`` in FP8 precision is given by: + +.. code-block:: python + + x = x_fp8 * s + +where + +* ``x_fp8`` is the FP8 value (E4M3 or E5M2), +* ``s`` is a global **FP32** scaling factor applied to the entire tensor. + +**FP8 Current Scaling quantization** + +Let's take a closer look at how quantization to FP8 with scaling factor is implemented in +the FP8 Current Scaling recipe. + +.. raw:: html + :file: img/fp8_scaling_concept.svg + +*Figure 3: Quantization to FP8 consists of amax (absolute maximum) computation, scaling to fit the FP8 range and casting to the respective FP8 format.* + +Quantization to FP8 consists of 3 steps: + +1. Computation of the absolute maximum value of the tensor - we refer to it as ``amax``. +2. Applying the scaling factor of ``fp8_max / amax`` to the tensor, to fit it into the FP8 range +3. Casting into the respective FP8 format using *Round To Nearest Even (RTNE)*. Values round to the nearest representable FP8 value. When exactly halfway between two values, rounds to the one with even mantissa to minimize systematic bias. + +**Performance analysis** + +Quantization is a memory-bound operation that requires reading the tensor twice: + +* First read: compute ``amax`` across all elements. +* Second read: apply the scaling factor and cast to FP8. + +This is a significant overhead compared to other recipes, which typically require only a single memory read. + +.. raw:: html + :file: img/fp8_cast_process.svg + +*Figure 4: FP8 quantization with current scaling recipe - two tensor reads are needed, one to compute amax and one to apply the scaling factor and cast to FP8.* + + +Transpose handling +------------------ + + + +*Ada and Hopper* + +On Ada and Hopper, the backward pass requires a transposed FP8 tensor. +The columnwise layout is physically different from the rowwise layout, so a transpose operation is needed. +All 3 options from :ref:`Performance Considerations Transpose handling section ` are supported. + +*Blackwell and later* + +Blackwell hardware supports multiple GEMM layouts natively, eliminating the need for explicit transposes. +The rowwise and columnwise tensors share the same physical memory layout. + +.. figure:: ../performance_considerations/img/hopper_vs_blackwell_layout.svg + :align: center + :alt: Comparison of rowwise and columnwise tensor layouts on Blackwell vs Hopper + + *Figure 6: On Blackwell, rowwise and columnwise usages share the same memory layout. On Hopper, columnwise usage requires a physical transpose.* + + +Distributed training +-------------------- + +**Quantized all-gather** + +FP8 all-gather is supported on all architectures (Ada and later). + +**Amax reduction** + +Tensors that are gathered across nodes (e.g. input and gradient in sequence parallelism) require amax synchronization before quantization. +Each node computes its local ``amax``, then a reduction produces the global maximum across all nodes. +All nodes use this synchronized amax to compute identical scaling factors, enabling quantized all-gather. + +.. raw:: html + :file: img/fp8_current_scaling_all_gather.svg + +*Figure 7: Quantization and all-gather flow for FP8 current scaling showing amax computation and synchronization.* + + +Supported devices +----------------- + +Ada and later (SM 8.9+) + +Examples +-------- + +Here's how to use FP8 Current Scaling recipe in PyTorch and JAX: + +.. tabs:: + + .. tab:: PyTorch + + .. raw:: html + +
+ Requires SM89 (Ada) or later +
+ + .. literalinclude:: pytorch_current_scaling_example.py + :language: python + :start-after: # START_CURRENT_SCALING_EXAMPLE + :end-before: # END_CURRENT_SCALING_EXAMPLE + + .. tab:: JAX + + .. raw:: html + +
+ Requires SM89 (Ada) or later +
+ + .. literalinclude:: jax_current_scaling_example.py + :language: python + :start-after: # START_CURRENT_SCALING_EXAMPLE + :end-before: # END_CURRENT_SCALING_EXAMPLE + + +---- + +Developer Notes +--------------- + +This section contains implementation details that may be useful for developers +but are not required for using FP8 Current Scaling in practice. + +All-gather of columnwise tensors +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +On Blackwell and later, rowwise and columnwise tensors share the same memory layout, +so all-gather of columnwise tensors is directly supported. + +For Hopper and Ada, all-gather of transposed FP8 tensors is not supported. +The rowwise tensor is gathered first, then transposed to columnwise format. \ No newline at end of file diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg new file mode 100644 index 000000000..294fca318 --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg @@ -0,0 +1,55 @@ + + + + + + + + + + + FP8 quantization + + + + High Precision + Tensor + + + + + + + Quantize + + + + Compute amax + 1 tensor read + + + + + + + Apply Scale + + Cast + 1 tensor read + + + + + + + FP8 + Tensor + + diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg new file mode 100644 index 000000000..f984e1dd3 --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg @@ -0,0 +1,78 @@ + + + + + + + + + + + Quantization + all gather for FP8 current scaling + + + + High Precision + Tensor + + + + + + + Compute + Amax + + + + + + + Synchronize + Amax + + + + + + + Scale + + Cast + + + + + + + FP8 + Tensor + + + + + + + All-Gather + + + + + + + FP8 Gathered + Tensor + + + diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg new file mode 100644 index 000000000..bf86a29a6 --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg @@ -0,0 +1,164 @@ + + + + + + + sign + exponent + mantissa + + + FP16 + + + + 0 + + + + 0 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + 0 + + 1 + + 1 + + = 0.395264 + + + + BF16 + + + + 0 + + + + 0 + + 1 + + 1 + + 1 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + = 0.394531 + + + + FP8 E4M3 + + + + 0 + + + + 0 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 1 + + = 0.40625 + + + + FP8 E5M2 + + + + 0 + + + + 0 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + = 0.375 + + + diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg new file mode 100644 index 000000000..9442b4e4a --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg @@ -0,0 +1,112 @@ + + + + + Original Tensor Values + + + + + + + 0 + + + + + + + + + + + + + + + + + amax + + + + + + Original range + + + + + + Scaled Values (fit FP8 range) + + + + + + + 0 + + + + + + FP8 range + + + + - FP8 range max + + + + + + + + + + + + + + Cast to FP8 (quantized values) + + + + + + + 0 + + + + + + FP8 range + + + + + + + + + + + + + + + + + + + diff --git a/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py b/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py new file mode 100644 index 000000000..107b13c53 --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py @@ -0,0 +1,33 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_CURRENT_SCALING_EXAMPLE + +import jax +import jax.numpy as jnp +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.common.recipe import Float8CurrentScaling, Format + +# Create FP8 Current Scaling recipe +# Available formats: +# - Format.HYBRID (default) -- E4M3 for forward pass, E5M2 for backward pass +# - Format.E4M3 -- E4M3 for both forward and backward pass +recipe = Float8CurrentScaling(fp8_format=Format.HYBRID) + +with te.autocast(enabled=True, recipe=recipe): + # Create and initialize layer + layer = DenseGeneral(features=1024) + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) + var_collect = layer.init(key, x) + + # Forward and backward pass + def loss_fn(var_collect): + output = layer.apply(var_collect, x) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(var_collect) + +# END_CURRENT_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py b/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py new file mode 100644 index 000000000..7ac127189 --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_CURRENT_SCALING_EXAMPLE + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Float8CurrentScaling, Format + +# Create FP8 Current Scaling recipe +# Available formats: +# - Format.HYBRID (default) -- E4M3 for forward pass, E5M2 for backward pass +# - Format.E4M3 -- E4M3 for both forward and backward pass +recipe = Float8CurrentScaling(fp8_format=Format.HYBRID) + +# Create a simple linear layer with bfloat16 parameters +layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) + +# Forward and backward pass +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +with te.autocast(enabled=True, recipe=recipe): + output = layer(inp) + loss = output.sum() + +loss.backward() + +# END_CURRENT_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst new file mode 100644 index 000000000..9d05305ed --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst @@ -0,0 +1,163 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +FP8 Delayed Scaling +=================================== + +FP8 Delayed Scaling recipe estimates scaling factors from historical amax values rather than computing them +for each tensor. Compared to Current Scaling recipe, +this reduces tensor reads per quantization from two to one, +improving memory efficiency. + +Both this and :doc:`FP8 Current Scaling <../fp8_current_scaling/fp8_current_scaling>` recipe use +the same FP8 formats (E4M3/E5M2) with one FP32 scaling factor per tensor. +Reading the FP8 Current Scaling documentation first is recommended. + +Quantization with delayed scaling factors +----------------------------------------- + +FP8 Current Scaling requires two tensor reads per quantization: one to compute amax, +one to cast. FP8 Delayed Scaling eliminates the first read by predicting the scaling factor +from historical amax values - hence *delayed* (using past values) versus *current* (using present values). + +The quantization process works as follows: + +1. **Compute scaling factor from history** (no tensor read needed): + The scaling factor is derived from stored ``amax_history`` using the formula: + + ``scaling_factor = FP8_MAX / amax`` + + where ``amax`` is computed from history using either ``max`` (maximum over window, default) or ``most_recent`` algorithm. + +2. **Quantize the tensor** (one tensor read): + Apply the scaling factor and cast to FP8. Values exceeding FP8 range are clipped. + +3. **Update history**: + Record the actual amax from this quantization for future iterations. + +Each module maintains an ``amax_history`` tensor of configurable length (``amax_history_len``) +for each quantized tensor. + +.. raw:: html + :file: img/scaling_comparison.svg + +*Figure 1. Comparison of FP8 Current Scaling and FP8 Delayed Scaling quantization processes.* + +Amax History Management +----------------------- + +The ``amax_history`` buffer acts as a sliding window of recent amax values. +Position 0 serves as a staging area for the current amax, while positions 1 to N-1 +store the history from oldest to newest. Each quantization writes the observed amax +to position 0, and after the pass completes, the history is rotated: + +.. code-block:: text + + Before rotation: [amax_N, amax_1, amax_2, ..., amax_N-1] (amax_N = current, amax_1 = oldest) + After rotation: [0, amax_2, ..., amax_N-1, amax_N] (amax_1 dropped, amax_N appended) + +The scaling factor is computed **before** the rotation, so it uses all ``amax_history_len`` values. +Position 0 serves as a staging area — it is zeroed after the scale update, ready for the next iteration's amax. + +The implementation differs between PyTorch and JAX: + +.. tabs:: + + .. tab:: PyTorch + + Each module creates two ``amax_history`` tensors, initialized to zero: + + - Forward: shape ``(amax_history_len, num_gemms * 3)`` — three FP8 tensors per GEMM (input, weight, output) + - Backward: shape ``(amax_history_len, num_gemms * 2)`` — two FP8 tensors per GEMM (grad_output, grad_input) + + When the autocast context exits, a single CUDA kernel processes all tensors at once — + performing amax reduction across GPUs and history rotation. This batched approach + minimizes kernel launch overhead compared to updating each tensor separately. + + .. tab:: JAX + + Each quantizer maintains its own ``amax_history`` with shape ``(amax_history_len,)`` + and updates independently. + +Here's how to use FP8 Delayed Scaling in PyTorch and JAX: + +.. tabs:: + + .. tab:: PyTorch + + .. raw:: html + +
+ Requires SM89 (Ada) or later +
+ + .. literalinclude:: pytorch_delayed_scaling_example.py + :language: python + :start-after: # START_DELAYED_SCALING_EXAMPLE + :end-before: # END_DELAYED_SCALING_EXAMPLE + + .. tab:: JAX + + .. raw:: html + +
+ Requires SM89 (Ada) or later +
+ + .. literalinclude:: jax_delayed_scaling_example.py + :language: python + :start-after: # START_DELAYED_SCALING_EXAMPLE + :end-before: # END_DELAYED_SCALING_EXAMPLE + + +Distributed Training +-------------------- + +FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling - quantized all-gather is supported. +However, amax reduction works slightly differently in different frameworks. + +.. tabs:: + + .. tab:: PyTorch + + Amax reduction is controlled by two parameters: + + - ``reduce_amax`` in recipe: enables/disables reduction (required for SP and CP) + - ``amax_reduction_group`` in ``autocast``: specifies the process group for reduction + + We recommend reducing amax across all GPUs where the tensor is sharded, + including data parallel ranks. + + .. literalinclude:: pytorch_delayed_scaling_distributed_example.py + :language: python + :start-after: # START_AMAX_REDUCTION_EXAMPLE + :end-before: # END_AMAX_REDUCTION_EXAMPLE + + In data parallel training, some modules may not execute on certain ranks + (e.g., MoE experts that receive no tokens). This is handled as follows: + + - **First iteration**: All modules must execute on all ranks to register + their ``amax_history`` tensors in the global buffer. Mismatched registration + would cause the ``all_reduce`` to hang due to different tensor sizes across ranks. + - **Subsequent iterations**: The ``autocast`` context must be entered and exited + on all ranks (this triggers the collective reduction). Individual modules can be + skipped - if no rank executes a module, its history is not rotated and scale + remains unchanged. + + + .. tab:: JAX + + Amax reduction is always enabled and managed automatically. + Reduction scope: all parallelism axes except pipeline parallelism (TP, SP, DP/FSDP). + + .. literalinclude:: jax_delayed_scaling_distributed_example.py + :language: python + :start-after: # START_AMAX_REDUCTION_EXAMPLE + :end-before: # END_AMAX_REDUCTION_EXAMPLE + +Supported devices +----------------- + +Ada and later (SM 8.9+) \ No newline at end of file diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg b/docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg new file mode 100644 index 000000000..aff4ba0da --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg @@ -0,0 +1,82 @@ + + + + + + + + + + + Current Scaling + + + + Tensor + + + + + + + Amax Computation + + + + + + + Quantization + (uses tensor + amax) + + + + + + + FP8 Tensor + + + + Delayed Scaling + + + + Tensor + + + + amax history + + + + read amax + + + + Quantization + (uses tensor + amax from history) + (updates amax history) + + + + update amax + + + + + + + FP8 Tensor + + + diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py new file mode 100644 index 000000000..f354ddaf7 --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py @@ -0,0 +1,15 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_AMAX_REDUCTION_EXAMPLE +import transformer_engine.jax as te +from transformer_engine.common.recipe import DelayedScaling + +# Amax reduction scope is managed internally +recipe = DelayedScaling(reduce_amax=True) # Must be True in JAX + +with te.autocast(enabled=True, recipe=recipe): + output = layer.apply(params, inp) + +# END_AMAX_REDUCTION_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py new file mode 100644 index 000000000..597111768 --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from transformer_engine.jax.quantize import get_device_compute_capability + +# Requires Ada (SM89) or newer for FP8 support +assert get_device_compute_capability() >= 89, "This example requires SM89 (Ada) or newer" + +# START_DELAYED_SCALING_EXAMPLE + +import jax +import jax.numpy as jnp +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.common.recipe import DelayedScaling + +# Create FP8 Delayed Scaling recipe +recipe = DelayedScaling( + margin=0, # Margin for scaling factor computation (default: 0) + amax_history_len=1024, # Length of amax history window (default: 1024) + amax_compute_algo="max", # How to compute amax from history (default: "max") +) + +with te.autocast(enabled=True, recipe=recipe): + # Initialize layer and data + layer = DenseGeneral(features=1024) + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) + var_collect = layer.init(key, x) + + # Forward and backward pass + def loss_fn(var_collect): + output = layer.apply(var_collect, x) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(var_collect) + +# END_DELAYED_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py new file mode 100644 index 000000000..863b71e8c --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_AMAX_REDUCTION_EXAMPLE +import torch.distributed as dist +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling + +# Create process group for amax reduction (e.g., all 8 GPUs) +amax_reduction_group = dist.new_group(ranks=[0, 1, 2, 3, 4, 5, 6, 7]) + +recipe = DelayedScaling(reduce_amax=True) + +with te.autocast(recipe=recipe, amax_reduction_group=amax_reduction_group): + output = model(inp) + +# END_AMAX_REDUCTION_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py new file mode 100644 index 000000000..45d244f47 --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or newer for FP8 support +assert torch.cuda.get_device_capability()[0] >= 9 or ( + torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9 +), "This example requires SM89 (Ada) or newer" + +# START_DELAYED_SCALING_EXAMPLE + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling + +# Create FP8 Delayed Scaling recipe +recipe = DelayedScaling( + margin=0, # Margin for scaling factor computation (default: 0) + amax_history_len=1024, # Length of amax history window (default: 1024) + amax_compute_algo="max", # How to compute amax from history (default: "max") +) + +# Create a linear layer with bfloat16 parameters +layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) + +# Forward and backward pass +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +with te.autocast(enabled=True, recipe=recipe): + output = layer(inp) + loss = output.sum() + +loss.backward() + +# END_DELAYED_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/index.rst b/docs/features/low_precision_training/index.rst new file mode 100644 index 000000000..8b392d2bb --- /dev/null +++ b/docs/features/low_precision_training/index.rst @@ -0,0 +1,17 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Low precision training +=================================== + +.. toctree:: + + introduction/introduction.rst + performance_considerations/performance_considerations.rst + fp8_current_scaling/fp8_current_scaling.rst + fp8_delayed_scaling/fp8_delayed_scaling.rst + fp8_blockwise_scaling/fp8_blockwise_scaling.rst + mxfp8/mxfp8.rst + nvfp4/nvfp4.rst \ No newline at end of file diff --git a/docs/features/low_precision_training/introduction/autocast_jax.py b/docs/features/low_precision_training/introduction/autocast_jax.py new file mode 100644 index 000000000..0abb67006 --- /dev/null +++ b/docs/features/low_precision_training/introduction/autocast_jax.py @@ -0,0 +1,83 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from transformer_engine.jax.quantize import get_device_compute_capability + +# Requires Ada (SM89) or newer for FP8 support +assert get_device_compute_capability() >= 89, "This example requires SM89 (Ada) or newer" + +# START_AUTOCAST_BASIC + +import jax +import jax.numpy as jnp +import transformer_engine.jax as te +from transformer_engine.jax.flax import TransformerLayer +from transformer_engine.common.recipe import DelayedScaling, Format + +# Set up recipe +recipe = DelayedScaling() + +# Model initialization must happen inside autocast +with te.autocast(enabled=True, recipe=recipe): + layer = TransformerLayer( + hidden_size=1024, + mlp_hidden_size=4096, + num_attention_heads=16, + ) + + init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0)) + x = jax.random.normal(init_key, (32, 128, 1024), dtype=jnp.bfloat16) + var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x) + + # Forward and backward pass (both inside autocast for JAX) + def loss_fn(var_collect): + output = layer.apply(var_collect, x, rngs={"dropout": dropout_key}) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(var_collect) + +# END_AUTOCAST_BASIC + + +# START_AUTOCAST_SEQUENTIAL + +encoder_recipe = DelayedScaling(fp8_format=Format.E4M3) +decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID) + +with te.autocast(enabled=True, recipe=encoder_recipe): + encoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + encoder_var_collect = encoder.init({"params": init_key, "dropout": dropout_key}, x) + hidden = encoder.apply(encoder_var_collect, x, rngs={"dropout": dropout_key}) + +with te.autocast(enabled=True, recipe=decoder_recipe): + decoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + decoder_var_collect = decoder.init({"params": init_key, "dropout": dropout_key}, hidden) + output = decoder.apply(decoder_var_collect, hidden, rngs={"dropout": dropout_key}) + +# END_AUTOCAST_SEQUENTIAL + + +# START_AUTOCAST_NESTED + +outer_recipe = DelayedScaling(fp8_format=Format.E4M3) +inner_recipe = DelayedScaling(fp8_format=Format.HYBRID) + +with te.autocast(enabled=True, recipe=outer_recipe): + # layer1 uses outer_recipe + layer1 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + var_collect1 = layer1.init({"params": init_key, "dropout": dropout_key}, x) + hidden = layer1.apply(var_collect1, x, rngs={"dropout": dropout_key}) + + with te.autocast(enabled=True, recipe=inner_recipe): + # layer2 uses inner_recipe (overrides outer) + layer2 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + var_collect2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden) + hidden = layer2.apply(var_collect2, hidden, rngs={"dropout": dropout_key}) + + # layer3 uses outer_recipe again + layer3 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + var_collect3 = layer3.init({"params": init_key, "dropout": dropout_key}, hidden) + output = layer3.apply(var_collect3, hidden, rngs={"dropout": dropout_key}) + +# END_AUTOCAST_NESTED diff --git a/docs/features/low_precision_training/introduction/autocast_pytorch.py b/docs/features/low_precision_training/introduction/autocast_pytorch.py new file mode 100644 index 000000000..2c1528ff9 --- /dev/null +++ b/docs/features/low_precision_training/introduction/autocast_pytorch.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or newer for FP8 support +assert torch.cuda.get_device_capability()[0] >= 9 or ( + torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9 +), "This example requires SM89 (Ada) or newer" + +# START_AUTOCAST_BASIC + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling, Format + +recipe = DelayedScaling() +layer = te.Linear(1024, 1024) +inp = torch.randn(32, 1024, dtype=torch.float32, device="cuda") + +with te.autocast(enabled=True, recipe=recipe): + output = layer(inp) + +# .backward() is called outside of autocast +loss = output.sum() +loss.backward() + +# END_AUTOCAST_BASIC + + +# START_AUTOCAST_SEQUENTIAL + +encoder_recipe = DelayedScaling(fp8_format=Format.E4M3) +decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID) + +encoder = te.Linear(1024, 1024) +decoder = te.Linear(1024, 1024) + +with te.autocast(enabled=True, recipe=encoder_recipe): + hidden = encoder(inp) + +with te.autocast(enabled=True, recipe=decoder_recipe): + output = decoder(hidden) + +# END_AUTOCAST_SEQUENTIAL + + +# START_AUTOCAST_NESTED + +outer_recipe = DelayedScaling(fp8_format=Format.E4M3) +inner_recipe = DelayedScaling(fp8_format=Format.HYBRID) + +layer1 = te.Linear(1024, 1024) +layer2 = te.Linear(1024, 1024) +layer3 = te.Linear(1024, 1024) + +with te.autocast(enabled=True, recipe=outer_recipe): + # layer1 uses outer_recipe + x = layer1(inp) + + with te.autocast(enabled=True, recipe=inner_recipe): + # layer2 uses inner_recipe (overrides outer) + x = layer2(x) + + # layer3 uses outer_recipe again + output = layer3(x) + +# END_AUTOCAST_NESTED diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py new file mode 100644 index 000000000..a3c9c2ae4 --- /dev/null +++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_BF16_FP16_TRAINING + +import jax +import jax.numpy as jnp +from transformer_engine.jax.flax import TransformerLayer + + +def run_forward_backward(params_dtype, compute_dtype): + # Create TransformerLayer + layer = TransformerLayer( + hidden_size=1024, + mlp_hidden_size=4096, + num_attention_heads=16, + dtype=params_dtype, + ) + + # Initialize parameters and optimizer + init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0)) + x = jax.random.normal(init_key, (32, 128, 1024), dtype=compute_dtype) + var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x) + + # Forward and backward pass + def loss_fn(var_collect): + output = layer.apply(var_collect, x, rngs={"dropout": dropout_key}) + assert output.dtype == compute_dtype + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(var_collect) + + +run_forward_backward(jnp.float32, jnp.float32) # high precision training +run_forward_backward(jnp.float32, jnp.bfloat16) # bfloat16 training with master weights in FP32 +run_forward_backward(jnp.bfloat16, jnp.bfloat16) # bfloat16 training with weights in BF16 + +# END_BF16_FP16_TRAINING diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py new file mode 100644 index 000000000..4eb6ce1f8 --- /dev/null +++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_BF16_FP16_TRAINING + +import torch +import transformer_engine.pytorch as te +from contextlib import nullcontext + + +def run_forward_backward(params_dtype, autocast_precision, grad_scaler_enabled): + if grad_scaler_enabled: + grad_scaler = torch.amp.GradScaler("cuda") + + layer = te.TransformerLayer( + hidden_size=1024, + ffn_hidden_size=4096, + num_attention_heads=16, + params_dtype=params_dtype, + ) + x = torch.randn(32, 128, 1024, dtype=params_dtype, device="cuda") + + autocast_ctx = ( + torch.autocast(device_type="cuda", dtype=autocast_precision) + if autocast_precision is not None + else nullcontext() + ) + with autocast_ctx: + output = layer(x) + assert ( + output.dtype == autocast_precision if autocast_precision is not None else params_dtype + ) + loss = output.sum() + if grad_scaler_enabled: + grad_scaler.scale(loss).backward() + else: + loss.backward() + + +run_forward_backward(torch.float32, torch.float32, False) # high precision training +run_forward_backward( + torch.float32, torch.bfloat16, False +) # bfloat16 training with master weights in FP32 +run_forward_backward( + torch.float32, torch.float16, True +) # fp16 training with master weights in FP32, needs loss scaling +run_forward_backward( + torch.bfloat16, torch.bfloat16, False +) # bfloat16 training with weights in BF16 + +# END_BF16_FP16_TRAINING diff --git a/docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg b/docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg new file mode 100644 index 000000000..e1861ebc1 --- /dev/null +++ b/docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg @@ -0,0 +1,172 @@ + + + + + + + + + + + FP8 Linear Layer – Forward and Backward Pass + + + Forward Pass + + + + InputT + + + + Input + + + + + + + Quantize + + + + + + + + + Input + + + + N + + + + Weight + + + + + + + Quantize + + + + + + + + + Weight + + + + WeightT + + + + T + + + + FP8 GEMM + (TN) + + + + + + + Output + + + + + + Backward Pass + + + + WeightT + + + + Output grad. + + + + + + + Quantize + + + + + + + + + Output grad. + + + + Output grad.T + + + + FP8 GEMM + (TN) + + + + Input grad. + + + + FP8 GEMM + (TN) + + + + Weight grad. + + + + InputT + + + + + N + + + T + + + + + + N + + + T + + + + + + + + Higher Precision (FP32/BF16/FP16) + + + + Lower Precision (FP8, MXFP8 etc.) + + + diff --git a/docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg b/docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg new file mode 100644 index 000000000..a6c46b364 --- /dev/null +++ b/docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg @@ -0,0 +1,183 @@ + + + + + + + sign + exponent + mantissa + + + FP32 + + + + 0 + + + + 0 + + 1 + + 1 + + 1 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + 1 + + 1 + + 1 + + 1 + + 0 + + 1 + + 0 + + 1 + + 0 + + 0 + + 0 + + = 0.3952 + + + + BF16 + + + + 0 + + + + 0 + + 1 + + 1 + + 1 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + ≈ 0.3945 + + + + FP16 + + + + 0 + + + + 0 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + 0 + + 1 + + 0 + + ≈ 0.3950 + + diff --git a/docs/features/low_precision_training/introduction/img/master_weights_approaches.svg b/docs/features/low_precision_training/introduction/img/master_weights_approaches.svg new file mode 100644 index 000000000..b231fefd9 --- /dev/null +++ b/docs/features/low_precision_training/introduction/img/master_weights_approaches.svg @@ -0,0 +1,112 @@ + + + + + + + + + + + Master Weights Storage Approaches + + + + + + + Low Precision Weights + (no master weights) + + + + Model + + Weights (BF16/FP16) + + + + + + + Forward/Backward + + + + + + + Optimizer + + State (FP32) + + + Master Weights in Model + + + + Model + + Weights (FP32) + + + + + cast to BF16/FP16 + + + + Forward/Backward + + + + + + + Optimizer + + State (FP32) + + + Master Weights in Optimizer + + + + cast to BF16/FP16 + + + + + + + Model + + Weights (BF16/FP16) + + + + + + + Forward/Backward + + + + + + + Optimizer + + State (FP32) + + Master (FP32) + + + + + diff --git a/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg b/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg new file mode 100644 index 000000000..7a6175918 --- /dev/null +++ b/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg @@ -0,0 +1,105 @@ + + + + + + + + + + + Transformer Layer – default precision of operation in low precision recipe + + + + Input + + + + + Layer Norm + + + + + QKV Linear + + + + + QK^T + + + + + Softmax + + + + + + Scores * V + + + + + Output Linear + + + + + Dropout + Add + + + + + + Layer Norm + + + + + FFN Linear 1 + + + + + GELU + + + + + FFN Linear 2 + + + + + Output + + + + + + + Parameters + + + + Gradients + + + + + + Higher Precision (FP32/BF16/FP16) + + + + Lower Precision (FP8, MXFP8 etc.) + + + diff --git a/docs/features/low_precision_training/introduction/introduction.rst b/docs/features/low_precision_training/introduction/introduction.rst new file mode 100644 index 000000000..760a63b0b --- /dev/null +++ b/docs/features/low_precision_training/introduction/introduction.rst @@ -0,0 +1,285 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Introduction +=================================== + +Transformer Engine accelerates deep learning on NVIDIA GPUs in several ways, +with low precision training being one of the most important. +This chapter introduces mixed precision training and FP8 support. + + +Training in BF16/FP16 +--------------------- + +Deep learning traditionally uses 32-bit floating-point (FP32) numbers. +NVIDIA GPUs support lower precision formats—FP16 since Pascal, BF16 since Ampere—which offer higher throughput and lower memory usage. +Let's compare these formats. + +.. raw:: html + :file: img/fp_formats_comparison.svg + +*Figure 1: Comparison of FP32, BF16, and FP16 floating-point formats showing bit allocation for sign, exponent, and mantissa.* + +The key differences between these formats are: + +* **FP32** (32 bits total): 1 sign bit + 8 exponent bits + 23 mantissa bits – standard single-precision format +* **BF16** (16 bits total): 1 sign bit + 8 exponent bits + 7 mantissa bits – maintains FP32's exponent range but has reduced precision +* **FP16** (16 bits total): 1 sign bit + 5 exponent bits + 10 mantissa bits – reduced range but higher precision than BF16 + +BF16's advantage is that it shares the same exponent range as FP32, +making it easier to convert between the two formats without overflow/underflow issues. +FP16 offers better precision for smaller values but has a limited dynamic range, +which results in the need to perform loss scaling to avoid overflow/underflow—see `this paper on loss scaling `__ for more details. + +**Mixed precision** + +Not all operations should be run in reduced precision to preserve accuracy. +Modern deep learning frameworks use *mixed precision training*, +where different operations use different precisions based on their numerical properties: + +* Matrix multiplications are compute-heavy and remain numerically stable at lower precision, making them ideal candidates for acceleration. +* Operations like layer normalization and softmax can work with low precision inputs and outputs, but may use high precision internally or for their weights. +* Operations like loss computation and exponentiation need high precision throughout. + +**Master weights** + +Another consideration in mixed precision training is how to store the model weights. +Lower precision formats like FP16 and BF16 have limited representational granularity, +which becomes problematic during gradient updates. +When a small gradient is added to a not so small weight stored in low precision, +the result may round back to the original value if the update falls below the format's precision threshold. +Moreover, some elements of the gradient itself can be too small to be represented in low precision, +especially after the accumulation from multiple GPUs in the data parallel training setting. + +The solution is to maintain *master weights* in FP32. +During training, weights are cast to lower precision for forward and backward passes, +but the gradient updates are applied to the full-precision master copy. +This ensures that even small gradients accumulate correctly over time. + +There are two common software approaches to storing master weights: + +* *In the optimizer*: + The model holds low-precision weights, + while the optimizer maintains FP32 copies alongside momentum and other state. + During each step, + the optimizer updates its FP32 copy and casts the result back to the model's low-precision weights. + + This approach makes it easier to shard master weights together with other optimizer state, for example in ZeRO optimizer. + + Since the casting happens only during the optimizer step, this approach is also faster when optimizer runs less frequently than the model, e.g. when performing gradient accumulation or pipeline parallel training. + +* *In the model*: + The model stores weights directly in FP32, + and they are cast to lower precision on-the-fly during forward and backward passes. + This approach works seamlessly with any standard optimizer, requiring no special support. + +.. raw:: html + :file: img/master_weights_approaches.svg + +*Figure 2: Three approaches to weight storage—low precision only (no master weights), master weights stored in the model, and master weights stored in the optimizer.* + +.. tabs:: + + .. tab:: PyTorch + + The PyTorch API of Transformer Engine provides several mechanisms to control precision: + + * **Weight precision**: Use the ``params_dtype`` argument in any TE layer constructor. + * **Computation precision**: Use the ``torch.autocast`` context manager. When enabled, inputs are cast to the autocast dtype before computation. + * **Input dtype**: When ``torch.autocast`` is not used, the input tensor's dtype determines the computation precision. In this case, inputs and parameters must have matching dtypes. + + .. literalinclude:: bf16_fp16_training_pytorch.py + :language: python + :start-after: # START_BF16_FP16_TRAINING + :end-before: # END_BF16_FP16_TRAINING + + + .. tab:: JAX + + The JAX API of Transformer Engine provides two mechanisms to control precision: + + * **Weight precision**: Use the ``dtype`` argument in any TE layer constructor. + * **Computation precision**: Determined by the dtype of the input tensor. + + For training with master weights in FP32 and computation in BF16, + cast the input tensor to BF16 before passing it to the layer. + + .. literalinclude:: bf16_fp16_training_jax.py + :language: python + :start-after: # START_BF16_FP16_TRAINING + :end-before: # END_BF16_FP16_TRAINING + + + +Lower precisions +---------------- + +Transformer Engine's primary feature is supporting even lower precision than BF16/FP16, such as FP8, MXFP8, NVFP4, etc. +The logic of these precisions is more complicated than the logic of BF16/FP16 – they require scaling factors to +properly represent the full range of values in the tensor. Sometimes it is one scaling factor per tensor, +sometimes it is one scaling factor per block of values. A precision format combined with the logic for training +is called **a recipe**. + +In this section we present common logic for all the recipes. Each one of them is described in more detail in a separate section later. +Let's now see how we can train in lower precisions in supported frameworks. + +.. tabs:: + + .. tab:: PyTorch + + The PyTorch API of Transformer Engine provides an ``autocast`` context manager to control precision. + It's similar to the ``torch.autocast`` context manager, but tailored for low precision training. + The most important argument is the ``recipe`` argument, which accepts objects inheriting from + :class:`~transformer_engine.common.recipe.Recipe`. + + Forward computations need to be performed inside the ``autocast`` context manager, + while the ``.backward()`` call should be outside of it (it inherits the setting from the + corresponding forward pass). + + Here is a basic example: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_pytorch.py + :language: python + :start-after: # START_AUTOCAST_BASIC + :end-before: # END_AUTOCAST_BASIC + + You can use multiple recipes in the same model in the following ways: + + **Sequential contexts** – apply different recipes to different parts of your model: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_pytorch.py + :language: python + :start-after: # START_AUTOCAST_SEQUENTIAL + :end-before: # END_AUTOCAST_SEQUENTIAL + + **Nested contexts** – the inner context overrides the outer one for its scope: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_pytorch.py + :language: python + :start-after: # START_AUTOCAST_NESTED + :end-before: # END_AUTOCAST_NESTED + + + .. tab:: JAX + + The JAX API of Transformer Engine provides an ``autocast`` context manager similar to PyTorch. + The key difference is that in JAX, model initialization must happen inside the ``autocast`` context + to properly capture quantization metadata in the parameter tree. + + Here is a basic example: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_jax.py + :language: python + :start-after: # START_AUTOCAST_BASIC + :end-before: # END_AUTOCAST_BASIC + + You can use multiple recipes in the same model in the following ways: + + **Sequential contexts** – apply different recipes to different parts of your model: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_jax.py + :language: python + :start-after: # START_AUTOCAST_SEQUENTIAL + :end-before: # END_AUTOCAST_SEQUENTIAL + + **Nested contexts** – the inner context overrides the outer one for its scope: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_jax.py + :language: python + :start-after: # START_AUTOCAST_NESTED + :end-before: # END_AUTOCAST_NESTED + + .. note:: + Python context managers like ``autocast`` may interact unexpectedly with JAX's JIT compilation. + For finer-grained control, consider passing the recipe directly to TE modules instead. + See the `TE JAX Integration notebook `_ + for details. + +**Mixed precision with 8- or 4-bit precisions** + +From now on, we will refer to FP8/MXFP8/NVFP4 etc. as *low precision* +and to FP32/BF16/FP16 as *high precision*. This terminology will be +used throughout the rest of the documentation. + +Not all operations run in low precision: + +- **Linear operations**: run in low precision. +- **Attention computations**: run in high precision by default (some recipes allow low precision as an option). +- **Other operations** (layer normalization, softmax, etc.): run in high precision. + +Within high-precision operations, there are two categories: + +- **Configurable precision**: most operations run in parameter precision (FP32/BF16/FP16) or the precision specified by ``torch.autocast``. +- **Fixed FP32 precision**: some operations, or parts of operations—such as the division in layernorm—always run in FP32, regardless of other settings. + +.. raw:: html + :file: img/mixed_precision_operations.svg + +*Figure 3: Default precision of operations in a TransformerLayer forward pass. Only linear operations are in lower precision. Dot product attention is shown as three separate operations (QK^T, Softmax, Scores * V), though in practice these may be fused into a single kernel.* + +**Linear layer data flow** + +Let's see how data flow of a linear layer works by default on a single H100 GPU with FP8 precision: + +H100 (Hopper) architecture natively supports FP8 Matrix Multiplication only in **TN** layout (Transpose-NoTranspose), +so GEMM with tensors ``A`` and ``B`` returns ``B * A^T``. + +*Forward pass* + +* Input is quantized to FP8 – both ``input`` and ``input^T`` quantized versions are created. +* Weights are stored in high precision and quantized to low precision before the GEMM – both ``weight`` and ``weight^T`` quantized versions are created. +* FP8 GEMM with layout **TN** is run with ``weight`` and ``input`` tensors, +* Outputs – ``input * weight^T`` tensor – are returned in high precision. + +*Backward pass* + +* Output gradients are quantized to FP8 – both ``output_grad`` and ``output_grad^T`` quantized versions are created. +* FP8 GEMM with layout **TN** is performed with ``weight^T`` and ``output_grad`` tensors to compute input gradients. +* FP8 GEMM with layout **TN** is performed with ``input^T`` and ``output_grad^T`` tensors to compute weight gradients. +* Input gradients – ``output_grad * weight`` tensor – are returned in high precision. +* Weight gradients – ``output_grad^T * input`` tensor – are returned in high precision. + + +.. raw:: html + :file: img/fp8_linear_flow.svg + +*Figure 4: Forward pass of a Linear layer with low precision data flow.* diff --git a/docs/features/low_precision_training/mxfp8/img/fp8_1d_scaling.svg b/docs/features/low_precision_training/mxfp8/img/fp8_1d_scaling.svg new file mode 100644 index 000000000..30f16d9a7 --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/img/fp8_1d_scaling.svg @@ -0,0 +1,177 @@ + + + + + + + + MXFP8 + (One scaling factor per 32 elements) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + E8M0 scaling factors (one per 32 elements) + + + diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_row_col.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_row_col.svg new file mode 100644 index 000000000..42ea0308b --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_row_col.svg @@ -0,0 +1,266 @@ + + + + + + + Rowwise (1x32 blocks) + + + + Data + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Scales + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Columnwise (32x1 blocks) + + + + Data + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Scales + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_scale_linearize_and_swizzle.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_scale_linearize_and_swizzle.svg new file mode 100644 index 000000000..6e4ed44d5 --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_scale_linearize_and_swizzle.svg @@ -0,0 +1,190 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + 1 + 2 + 3 + + K + + + 1 + K + + + 2 + K + + + 3 + + 2K + + + 1 + 2K + + + 1 + 2K + + + 3 + + + + + + + + + + + + + 128x4 + + + + + + + + + + + + 1 + + + 2 + + + + + + K + 1 + + + K + 2 + + + + + + 1x512 + + + + + + + 128 4-bit elements + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + ... + + + + + + + + + + + + + + + + + + + + + + + 0 + 32 + 64 + 96 + 1 + 33 + 65 + 97 + ... + + + + diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_swizzle_both_tensors.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_swizzle_both_tensors.svg new file mode 100644 index 000000000..d8489ecc4 --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_swizzle_both_tensors.svg @@ -0,0 +1,101 @@ + + + + + + + + + + + + + + + + Input Tensor + + FP32/BF16 + + + + + + + + Quantize + + + + + + + MXFP8 Tensor + + + + + Scales + + + + FP8 Data + + + + + + + + Communication + (All-Gather) + (Optional) + + + + + + + Swizzle + + + + + + + MXFP8 Tensor + + + + + Swizzle Scales + + + + FP8 Data + + + + + + + + GEMM + + diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_tensor_scaling_layout.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_tensor_scaling_layout.svg new file mode 100644 index 000000000..3b81ff0a3 --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_tensor_scaling_layout.svg @@ -0,0 +1,63 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + FP8 Tensor (128×128 blocks) + + + + + + + + + + + + + + + + + + + + + + + + + + + Scaling Factors (128×4 blocks) + diff --git a/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py new file mode 100644 index 000000000..96ef1a257 --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# Check for Blackwell or newer GPU +from transformer_engine.jax.quantize import get_device_compute_capability + +assert ( + get_device_compute_capability() >= 100 +), f"MXFP8 requires SM100 (Blackwell) or later, got SM{get_device_compute_capability()}" + +# START_MXFP8_EXAMPLE + +import jax +import jax.numpy as jnp +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.common.recipe import MXFP8BlockScaling, Format + +# Create MXFP8 recipe +recipe = MXFP8BlockScaling( + fp8_format=Format.E4M3, # FP8 format (default: E4M3, E5M2 not supported) +) + +with te.autocast(enabled=True, recipe=recipe): + # Initialize layer and data + layer = DenseGeneral(features=1024) + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) + var_collect = layer.init(key, x) + + # Forward and backward pass + def loss_fn(var_collect): + output = layer.apply(var_collect, x) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(var_collect) + +# END_MXFP8_EXAMPLE diff --git a/docs/features/low_precision_training/mxfp8/mxfp8.rst b/docs/features/low_precision_training/mxfp8/mxfp8.rst new file mode 100644 index 000000000..f8f8f48b0 --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/mxfp8.rst @@ -0,0 +1,213 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +MXFP8 +===== + + +MXFP8 (Microscaling FP8) is an enhanced FP8 blockwise scaling recipe that leverages native hardware +acceleration on Blackwell GPUs (SM 10.0+). By using one scaling factor per 32 consecutive values +(rather than 128), MXFP8 delivers finer-grained quantization with improved numerical precision. + + + +Data Format +----------- + +The representation of an FP8 tensor element ``x`` in MXFP8 precision is given by: + +.. code-block:: python + + x = x_fp8 * s_block + +where + +* ``x_fp8`` is the FP8 value in E4M3 format, +* ``s_block`` is a local **E8M0** scaling factor shared by a block of 32 elements. + E8M0 is an 8-bit format with 8 exponent bits and 0 mantissa bits, representing only powers of 2. + + +**FP8 format** + +Like FP8 Blockwise Scaling, E4M3 is used by default for both forward and backward passes. +The finer-grained scaling provides sufficient dynamic range without requiring the E5M2 format. +The ``fp8_format`` parameter also supports ``HYBRID`` mode (E4M3 for forward, E5M2 for backward). +Pure E5M2 training is not supported. + + +**Block size** + +Block size is 32. +Blocks are one-dimensional, containing 32 consecutive values. No 2D scaling is performed. + +There are some assumptions on the dimensions of the tensor: + +* the tensor must have at least 2 dimensions, +* the last dimension must be divisible by 32, +* the product of all dimensions except the last must be divisible by 32. + + +**Scaling factors** + +Scaling factors are stored as E8M0 (8 exponent bits, 0 mantissa bits), which inherently represents +powers of 2. This differs from FP8 Blockwise Scaling, which uses 32-bit floating point numbers +optionally constrained to powers of 2. Note that FP32 also has 8 exponent bits, so the representable +ranges are the same when the power-of-2 constraint is enabled. + +Each block's scaling factor is computed through the following steps: + +1. Find the maximum absolute value (``amax_block``) across all 32 elements in the block. +2. Compute the E8M0 biased exponent: ``e = float_to_e8m0(amax_block / max_fp8)``, where ``max_fp8 = 448`` + (the maximum representable value in E4M3 format). + + Since E8M0 and FP32 share the same exponent bias (127), ``float_to_e8m0`` simply extracts + the 8-bit exponent from the FP32 representation, rounding up if the mantissa is non-zero. + +3. The scaling factor is ``s_block = 2^(e - 127)``. + +This ensures that the largest value in each block fits within the FP8 representable range without overflow. + + +.. raw:: html + :file: img/fp8_1d_scaling.svg + +*Figure 1. MXFP8 uses one E8M0 scaling factor per 32 consecutive elements, providing fine-grained +quantization and compact scaling factor representation.* + + +Handling transposes +------------------- + +Blackwell architecture supports multiple FP8 GEMM layouts (TN, NT, NN), so columnwise usage +does not require explicit transposition. However, rowwise and columnwise quantizations are different: + +- *Rowwise* - 1 scaling factor per 32 consecutive elements along a row (1×32 blocks). +- *Columnwise* - 1 scaling factor per 32 consecutive elements along a column (32×1 blocks). + +Since the scaling factor blocks have different orientations, rowwise and columnwise MXFP8 tensors +are numerically different — one cannot derive one from the other. Both must be quantized +independently from the full-precision data. + +.. raw:: html + :file: img/mxfp8_row_col.svg + +*Figure 2. MXFP8 rowwise vs columnwise quantization layout.* + + +Distributed training +-------------------- + +**Scale synchronization** + +The blockwise scaled tensor does not need any scale synchronization among the nodes. +This is because each scaling factor is local to its 32-element block, +unlike :doc:`FP8 Current <../fp8_current_scaling/fp8_current_scaling>`/:doc:`Delayed Scaling <../fp8_delayed_scaling/fp8_delayed_scaling>` where a single global scale applies to the entire tensor, even when sharded. + +**Quantized all-gather** + +MXFP8 all-gather is supported. + + +Examples +-------- + +Here's how to use MXFP8 recipe in PyTorch and JAX: + +.. tabs:: + + .. tab:: PyTorch + + .. raw:: html + +
+ Requires SM100 (Blackwell) or later +
+ + .. literalinclude:: pytorch_mxfp8_example.py + :language: python + :start-after: # START_MXFP8_EXAMPLE + :end-before: # END_MXFP8_EXAMPLE + + .. tab:: JAX + + .. raw:: html + +
+ Requires SM100 (Blackwell) or later +
+ + .. literalinclude:: jax_mxfp8_example.py + :language: python + :start-after: # START_MXFP8_EXAMPLE + :end-before: # END_MXFP8_EXAMPLE + + +Supported devices +----------------- + +SM 10.0, SM 10.3 + + +---- + +Developer Notes +--------------- + +This section contains implementation details that may be useful for developers +but are not required for using MXFP8 in practice. + +Swizzling scaling factors +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Like :doc:`FP8 Blockwise Scaling <../fp8_blockwise_scaling/fp8_blockwise_scaling>`, MXFP8 uses different data layouts for communication and computation. +MXFP8 GEMMs require scaling factors in a specific hardware layout +(see `cuBLAS documentation `__). +The conversion to this GEMM-ready layout is called *swizzling*. When no communication is needed, +swizzling can be fused with quantization. When communication is required, swizzled scaling factors +cannot be communicated across devices, so Transformer Engine performs swizzling after communication, +just before each GEMM operation. + +.. raw:: html + :file: img/mxfp8_swizzle_both_tensors.svg + +*Figure 3. MXFP8 swizzling process: standard scaling factors are rearranged into the hardware-required layout.* + + +Blackwell Tensor Cores compute matrix multiplications using ``128x128`` tiles. +Scaling factors are stored in row-major order, but to process a tile, we need a ``128x4`` vertical +slice of scaling factors. In row-major storage, these vertical slices are scattered in memory +with gaps between each row. The hardware requires them to be stored contiguously. + +.. raw:: html + :file: img/mxfp8_tensor_scaling_layout.svg + +*Figure 4. FP8 tensor (left) is divided into 128x128 tiles. Each tile requires a 128x4 block of scaling factors (right). These vertical blocks are not contiguous in memory.* + +Swizzling transforms the layout to meet hardware requirements by: + +1. **Linearizing** the ``128x4`` blocks so they are stored contiguously one after another. +2. **Permuting** the 4-byte elements within each block. + +Specifically, if we index the 128 4-byte elements in a scaling factor block as :math:`0, 1, \dots, 127`, the hardware expects them in the following interleaved order: + +.. code-block:: text + + 0, 32, 64, 96, 1, 33, 65, 97, ..., k, 32 + k, 64 + k, 96 + k, ..., 31, 63, 95, 127 + + +.. raw:: html + :file: img/mxfp8_scale_linearize_and_swizzle.svg + +*Figure 5. Linearization and swizzling of scaling factors. The 2D grid of scaling factors is first flattened into a contiguous sequence of blocks (top), then the rows within each block are interleaved to match the hardware access pattern (bottom).* + +For columnwise scaling factors, the process is analogous but with ``4x128`` horizontal blocks instead of ``128x4`` vertical blocks. + +All-gather of columnwise tensors +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +All-gather of columnwise tensors is supported and necessary because: + +- columnwise quantized tensors cannot be computed from rowwise quantized ones, +- gathering high-precision tensors is avoided in most cases for performance reasons. \ No newline at end of file diff --git a/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py new file mode 100644 index 000000000..3cc70137b --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Check for Blackwell or newer GPU +major, minor = torch.cuda.get_device_capability() +assert major >= 10, f"MXFP8 requires SM100 (Blackwell) or later, got SM{major}{minor}" + +# START_MXFP8_EXAMPLE + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import MXFP8BlockScaling, Format + +# Create MXFP8 recipe +recipe = MXFP8BlockScaling( + fp8_format=Format.E4M3, # E4M3 (default) or HYBRID; pure E5M2 not supported +) + +# Create a linear layer with bfloat16 parameters +layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) + +# Forward and backward pass +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +with te.autocast(enabled=True, recipe=recipe): + output = layer(inp) + loss = output.sum() + +loss.backward() + +# END_MXFP8_EXAMPLE diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_all_gather.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_all_gather.svg new file mode 100644 index 000000000..3e215551a --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_all_gather.svg @@ -0,0 +1,118 @@ + + + + + + + + + + + Quantization + All-Gather for NVFP4 + + + + High Precision + Tensor + + + + + + + Compute + Amax + + + + + + + Synchronize + Amax + + + + + + + Compute + s_global + + + + + + + Scale + Cast + (s_block, + s_global) + + + + + + + NVFP4 + Tensor + + + + + + + All-Gather + + + + + + + NVFP4 Gathered + Tensor + + + diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_hierarchical_scaling.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_hierarchical_scaling.svg new file mode 100644 index 000000000..05e67b788 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_hierarchical_scaling.svg @@ -0,0 +1,186 @@ + + + + + + + + NVFP4 Hierarchical Scaling + (Block scaling + Global scaling) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + E4M3 scaling factors (one per 16 elements) + + + + + Global Scale (FP32) + (one per tensor) + + + + + + \ No newline at end of file diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_row_col.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_row_col.svg new file mode 100644 index 000000000..30363d6ce --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_row_col.svg @@ -0,0 +1,208 @@ + + + + + + + Rowwise (1×16 blocks) + + + + Data [A, B] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + s_block [A, B/16] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + s_global + + + + + Columnwise (16×1 blocks) — transposed storage + + + + Data [B, A] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + s_block [B, A/16] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + s_global + + + diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_vs_fp8.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_vs_fp8.svg new file mode 100644 index 000000000..68f6bf903 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_vs_fp8.svg @@ -0,0 +1,91 @@ + + + + + + + FP8 E4M3 + + + + 0 + + + + 1 + + 0 + + 0 + + 0 + + + + 1 + + 1 + + 1 + + (1 sign, 4 exp, 3 mantissa) + + + + FP8 E5M2 + + + + 0 + + + + 1 + + 0 + + 0 + + 0 + + 0 + + + + 1 + + 1 + + (1 sign, 5 exp, 2 mantissa) + + + + NVFP4 + + + + 0 + + + + 1 + + 0 + + + + 1 + + (1 sign, 2 exp, 1 mantissa) + + + + diff --git a/docs/features/low_precision_training/nvfp4/img/rht.svg b/docs/features/low_precision_training/nvfp4/img/rht.svg new file mode 100644 index 000000000..0250c27ae --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/rht.svg @@ -0,0 +1,138 @@ + + + + + + + + + + + + Random Hadamard Transform for WGRAD GEMM + + + + + + + Without RHT + + + + + Activations + + + + + + + Quantize + + + + + + + WGRAD + GEMM + + + + + Output Grad + + + + + + + Quantize + + + + + + + + + + Weight Grad + + + + + With RHT + + + + + Activations + + + + + + + RHT + + + + + + + Quantize + + + + + + + WGRAD + GEMM + + + + + Output Grad + + + + + + + RHT + + + + + + + Quantize + + + + + + + + + + Weight Grad + + + diff --git a/docs/features/low_precision_training/nvfp4/img/stochastic_rounding.svg b/docs/features/low_precision_training/nvfp4/img/stochastic_rounding.svg new file mode 100644 index 000000000..eb745f6e8 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/stochastic_rounding.svg @@ -0,0 +1,95 @@ + + + + + + + + + + + + Round to Nearest + + + + + + + v₁ + + + + v₂ + + + + x + + + + + Round to v₁ + + + 100% + + + Round to v₂ + + + 0% + + + + + + + Stochastic Rounding + + + + + + + v₁ + + + + v₂ + + + + x + + + + + Round to v₁ + + + 60% + + + Round to v₂ + + + 40% + + + + + diff --git a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py new file mode 100644 index 000000000..6c94f3134 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# Check for Blackwell or newer GPU +from transformer_engine.jax.quantize import get_device_compute_capability + +assert ( + get_device_compute_capability() >= 100 +), f"NVFP4 requires SM100 (Blackwell) or later, got SM{get_device_compute_capability()}" + +# START_NVFP4_EXAMPLE + +import jax +import jax.numpy as jnp +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.common.recipe import NVFP4BlockScaling + +# Define NVFP4 recipe +# 2D weight quantization and RHT are enabled by default +recipe = NVFP4BlockScaling() +# To disable features, use: +# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True) + +with te.autocast(enabled=True, recipe=recipe): + # Initialize layer and data + layer = DenseGeneral(features=1024) + key, sr_key = jax.random.split(jax.random.PRNGKey(0)) + x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) + + # NVFP4 requires sr_rng for stochastic rounding + rngs = {"sr_rng": sr_key} + var_collect = layer.init({"params": key, "sr_rng": sr_key}, x) + + # Forward and backward pass + def loss_fn(var_collect): + output = layer.apply(var_collect, x, rngs=rngs) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(var_collect) + +# END_NVFP4_EXAMPLE diff --git a/docs/features/low_precision_training/nvfp4/nvfp4.rst b/docs/features/low_precision_training/nvfp4/nvfp4.rst new file mode 100644 index 000000000..0415963a7 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/nvfp4.rst @@ -0,0 +1,275 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +NVFP4 +=================================== + +NVFP4 is the first 4-bit recipe introduced in Transformer Engine – +please refer to the `NVFP4 paper `__ for more details. +It is a more complex recipe than the previous ones – apart from the new data format, +it introduces multiple features which help training stability. + +Data Format +---------------------- + +The NVFP4 datatype consists of 1 sign bit, 2 exponent bits, and 1 mantissa bit (E2M1). +It can represent values of magnitude up to +/- 6. +NVFP4 uses a hierarchical block scaling approach where multiple scaling factors are combined to recover the high precision value. + +.. raw:: html + :file: img/nvfp4_vs_fp8.svg + +*Figure 1. Bit layout comparison between standard FP8 formats (E4M3 and E5M2) and NVFP4 (E2M1).* + + +The representation of an NVFP4 tensor element ``x`` is given by: + +.. code-block:: python + + x = x_e2m1 * s_block * s_global + +where + +* ``x_e2m1`` is the 4-bit value, +* ``s_block`` is a local **FP8 E4M3** scaling factor shared by a block of 16 consecutive elements, +* ``s_global`` is a global **FP32** scaling factor applied to the entire tensor. + +**Scaling Factor Computation** + +The scaling factors are computed as follows: + +1. Global scaling factor (``s_global``): + +.. code-block:: python + + s_global = global_amax / (fp8_max * fp4_max) + # where: + # - global_amax: maximum absolute value across the entire tensor + # - fp8_max: maximum representable value in FP8 E4M3 (448.0) + # - fp4_max: maximum representable value in NVFP4 E2M1 (6.0) + +2. Block scaling factor (``s_block``): + +.. code-block:: python + + s_block = (block_amax / fp4_max) / s_global + # where: + # - block_amax: maximum absolute value within the block + # - fp4_max: maximum representable value in NVFP4 E2M1 (6.0) + # - s_block is stored in FP8 E4M3 format + + +.. raw:: html + :file: img/nvfp4_hierarchical_scaling.svg + +*Figure 2. NVFP4 hierarchical scaling structure showing the combination of block-level and global scaling factors.* + +This hierarchical structure uses fine-grained block scaling to handle the tensor's dynamic range, +while the FP4 values represent the block-level dynamic range. The global scaling factor +aligns values to the representable range of the E4M3 × E2M1 combination. + +**2D weight scaling** + +NVFP4 can be: + +* 1 dimensional - each block of 16 consecutive elements shares a scaling factor, +* 2 dimensional - each block of 16x16 elements shares a scaling factor. + +By default, NVFP4 uses 2D scaling for weights and 1D scaling for activations and gradients. +Set ``disable_2d_quantization=True`` in the recipe configuration to force 1D scaling for weights as well (activations and gradients always use 1D). +The motivation for using 2D scaling for weights is to ensure that rowwise and columnwise +quantized tensors are numerically equivalent. +Please refer to the `NVFP4 paper `__ for more details. + + +Stochastic Rounding +------------------- + +Stochastic rounding is applied when casting scaled values to NVFP4 format. Instead of deterministic rounding +(always rounding to nearest even value), each scaled value is probabilistically rounded to one of the two +nearest representable NVFP4 values. The probability of rounding to a given value is inversely proportional to +the distance to that value, which ensures that the expected value of the quantized +tensor equals the original value, eliminating systematic quantization bias during training. +Stochastic rounding is hardware-accelerated using native GPU instructions introduced with the +Blackwell architecture. + +.. raw:: html + :file: img/stochastic_rounding.svg + +*Figure 3. Stochastic rounding illustration. Given a value* ``x`` *to be quantized, and the two nearest +representable NVFP4 values* ``v1`` *(lower) and* ``v2`` *(higher), deterministic rounding always +rounds to the nearest value, while stochastic rounding probabilistically rounds to either value. +If* ``x`` *is 40% of the way from* ``v1`` *to* ``v2``, *there is a 60% chance of rounding to* ``v1`` +*and a 40% chance of rounding to* ``v2``. + +Stochastic rounding is enabled only for gradients. It can be disabled by setting +``disable_stochastic_rounding=True`` in the recipe configuration. + + +Random Hadamard Transform +-------------------------- + +Random Hadamard Transform (RHT) applies an orthogonal rotation to the tensor **before quantization**, +smoothing outliers in the tensor distributions and making them easier to represent accurately in NVFP4. +RHT is applied to columnwise quantization of inputs and gradients, which are operands +for the **wgrad GEMM**. This GEMM is particularly sensitive +to quantization errors, hence the additional outlier smoothing. +RHT is supported only for BF16 inputs/gradients. + +The transform is defined as: + +.. math:: + + x' = x H + +where :math:`H` is the RHT matrix defined below. The quantization scale factor is computed +from the rotated tensor :math:`x'`. + +**Hadamard matrix** + +The :math:`d \times d` Hadamard matrix has elements :math:`\pm 1` and satisfies :math:`H_d H_d^T = d I`. +When normalized by :math:`1/\sqrt{d}`, the matrix becomes orthogonal and can be applied +to both operands of a matrix multiplication: + +.. math:: + + C = (AH)(H^T B) = AB + +where the transforms cancel within the dot-product since :math:`H H^T = I`. + +**Sign matrix** + +In the RHT implementation, a :math:`d`-dimensional diagonal sign matrix :math:`S_d` is applied +together with the Hadamard matrix: + +.. math:: + + H = \frac{1}{\sqrt{d}} S_d H_d + +where diagonal entries of :math:`S_d` are :math:`\{-1, 1\}` and flip the signs of different rows of :math:`H_d`. +As described in the paper, a single random sign vector is shared across all linear layers throughout training. +In the implementation, this vector is fixed and the RHT matrix is computed once at initialization and cached. + +**Tiled implementation** + +The Hadamard transform is performed in a tiled approach along the last dimension of the tensor. +For an :math:`m \times k` tensor, the data is reshaped to :math:`(mk/d) \times d` +and multiplied by the :math:`d \times d` matrix :math:`H`. In this implementation, :math:`d = 16`. + + +.. raw:: html + :file: img/rht.svg + +*Figure 4. WGRAD GEMM pipeline comparison: without RHT (left) and with RHT applied (right).* + +Handling transposes +------------------- + +Like :doc:`MXFP8 <../mxfp8/mxfp8>`, NVFP4 requires both rowwise and columnwise quantized tensors +for different GEMM operands. Unlike MXFP8 which supports multiple layouts (TN, NT, NN), +**NVFP4 GEMM only supports the TN layout**. + +NVFP4 stores columnwise data and scaling factors in a **transposed layout**: + +- **Rowwise**: data ``[A, B]`` with 1×16 horizontal blocks, ``scales`` shape ``[A, B/16]`` +- **Columnwise**: data ``[B, A]`` (transposed) with 1×16 horizontal blocks, ``scales`` shape ``[B, A/16]`` + +Scale tensors are padded for hardware alignment: first dimension to a multiple of 128, +second dimension to a multiple of 4 (e.g. rowwise: ``[roundup(A, 128), roundup(B/16, 4)]``). + +.. raw:: html + :file: img/nvfp4_row_col.svg + +*Figure 5. NVFP4 rowwise vs columnwise quantization layout. Unlike MXFP8, columnwise scales are stored transposed.* + + +Distributed training +-------------------- + +**Amax reduction** + +Block scaling factors (``s_block``) do not require synchronization between nodes, +as each scaling factor is local to its block of 16 elements. +However, the global scaling factor (``s_global``) requires amax synchronization for gathered tensors. +For tensors that are gathered (e.g., input and gradient in sequence parallelism), +amax reduction is performed before quantization. +If before synchronization there was ``amax_1`` on node 1, +``amax_2`` on node 2, etc., after synchronization there will be ``max(amax_1, amax_2, ...)`` on all nodes. + +**Quantized all-gather** + +NVFP4 all-gather is supported. + +.. raw:: html + :file: img/nvfp4_all_gather.svg + +*Figure 6. Quantization and all-gather flow for NVFP4 showing amax synchronization and hierarchical scaling.* + +Examples +-------- + +Here's how to use NVFP4 recipe in PyTorch and JAX. The examples show how to configure features like 2D weight quantization and Random Hadamard Transform (RHT): + +.. tabs:: + + .. tab:: PyTorch + + .. raw:: html + +
+ Requires SM100 (Blackwell) or later +
+ + .. literalinclude:: pytorch_nvfp4_example.py + :language: python + :start-after: # START_NVFP4_EXAMPLE + :end-before: # END_NVFP4_EXAMPLE + + .. tab:: JAX + + .. raw:: html + +
+ Requires SM100 (Blackwell) or later +
+ + .. literalinclude:: jax_nvfp4_example.py + :language: python + :start-after: # START_NVFP4_EXAMPLE + :end-before: # END_NVFP4_EXAMPLE + + +Supported devices +----------------- + +* **Training**: SM 10.0, SM 10.3 +* **Inference**: SM 10.0+ + + +---- + +Developer Notes +--------------- + +This section contains implementation details that may be useful for developers +but are not required for using NVFP4 in practice. + +Swizzling scaling factors +^^^^^^^^^^^^^^^^^^^^^^^^^ + +NVFP4 requires swizzling of block scaling factors (``s_block``) before GEMM operations, +similar to :doc:`MXFP8 <../mxfp8/mxfp8>`. Key differences: + +- Block size is 16 (vs 32 for MXFP8) +- Both rowwise and columnwise scaling factors are swizzled, but thanks to the transposed + columnwise layout, a single rowwise swizzle kernel handles both cases. +- Scaling factors are stored as FP8 E4M3 (vs E8M0 for MXFP8) + +All-gather of columnwise tensors +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +All-gather of columnwise tensors is supported. To enable quantized all-gather, +all nodes must use the same ``s_global``, which is computed from the synchronized global amax. +This is automatically enabled for column-parallel and row-parallel linear layers. diff --git a/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py new file mode 100644 index 000000000..07b680def --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Check for Blackwell or newer GPU +major, minor = torch.cuda.get_device_capability() +assert major >= 10, f"NVFP4 requires SM100 (Blackwell) or later, got SM{major}{minor}" + +# START_NVFP4_EXAMPLE + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import NVFP4BlockScaling + +# Define NVFP4 recipe +# 2D weight quantization and RHT are enabled by default +recipe = NVFP4BlockScaling() +# To disable features, use: +# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True) + +# Create a linear layer with bfloat16 parameters +layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) + +# Forward and backward pass +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +with te.autocast(enabled=True, recipe=recipe): + output = layer(inp) + loss = output.sum() + +loss.backward() + +# END_NVFP4_EXAMPLE diff --git a/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py b/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py new file mode 100644 index 000000000..4f2f39ca3 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py @@ -0,0 +1,41 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ + +# START_FUSED_LAYERS + +import jax +import jax.numpy as jnp +import transformer_engine.jax as te +from transformer_engine.jax.flax import LayerNorm, DenseGeneral, LayerNormDenseGeneral +from transformer_engine.common.recipe import DelayedScaling + +key = jax.random.PRNGKey(0) +x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) + +# Example 1: Separate LayerNorm and DenseGeneral layers +layer_norm = LayerNorm() +dense = DenseGeneral(features=1024) + +# Initialize parameters +ln_params = layer_norm.init(key, x) +dense_params = dense.init(key, x) + +# Two separate operations +normalized = layer_norm.apply(ln_params, x) +output_separate = dense.apply(dense_params, normalized) + +# Example 2: Fused LayerNormDenseGeneral layer +fused_layer = LayerNormDenseGeneral(features=1024) + +# Initialize and apply with FP8 autocast +recipe = DelayedScaling() +with te.autocast(enabled=True, recipe=recipe): + fused_params = fused_layer.init(key, x) + output_fused, _ = fused_layer.apply(fused_params, x) # Returns (output, ln_output) + +# The fused layer is more efficient as it combines LayerNorm and quantization + +# END_FUSED_LAYERS diff --git a/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py b/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py new file mode 100644 index 000000000..2108f45a0 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ +cc = torch.cuda.get_device_capability() +assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" + +# START_FUSED_LAYERS + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling + +# Example 1: Separate LayerNorm and Linear layers +layer_norm = te.LayerNorm(1024) +linear = te.Linear(1024, 1024) + +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +# Two separate operations: LayerNorm produces FP32, then Linear quantizes it +normalized = layer_norm(inp) +output_separate = linear(normalized) + +# Example 2: Fused LayerNormLinear layer +fused_layer = te.LayerNormLinear(1024, 1024, params_dtype=torch.bfloat16) + +# Single operation: LayerNorm output is directly quantized +recipe = DelayedScaling() +with te.autocast(enabled=True, recipe=recipe): + output_fused = fused_layer(inp) + +# The fused layer is more efficient as it avoids redundant quantization + +# END_FUSED_LAYERS diff --git a/docs/features/low_precision_training/performance_considerations/img/fused_layers.svg b/docs/features/low_precision_training/performance_considerations/img/fused_layers.svg new file mode 100644 index 000000000..8b7ffb5b5 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/img/fused_layers.svg @@ -0,0 +1,120 @@ + + + + + + + + + + + LayerNorm + Linear: Separate vs Fused + + + + + + Scenario 1: Separate Layers + + + + Input + + + + + + + LayerNorm + + + + + + + Output + + + + + + + Linear + + + + Quantize + + + + + + + FP8 tensor + + + + + + + ... + + + + + + + Output + + + + Scenario 2: Fused Layer + + + + Input + + + + + + + LayerNormLinear + + + + + LayerNorm + Quantize + + + + + + + FP8 tensor + + + + + + + ... + + + + + + + Output + + diff --git a/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg b/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg new file mode 100644 index 000000000..fa720427e --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg @@ -0,0 +1,214 @@ + + + + + + + + + + NN GEMM + + + + A + + + + + + + + + + + + + + + + + + + + + + + + + + + rowwise + + + + + B + + + + + + + + + + + + + + + + + + + + + + + + + + + columnwise + + + + + A×B + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + TN GEMM + + + + A + + + + + + + + + + + + + + + + + + + + + + + + + + + rowwise + + + + + B + + + + + + + + + + + + + + + + + + + + + + + + + + + rowwise + + + + + A×BT + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/features/low_precision_training/performance_considerations/img/hopper_vs_blackwell_layout.svg b/docs/features/low_precision_training/performance_considerations/img/hopper_vs_blackwell_layout.svg new file mode 100644 index 000000000..6f9bc4d5a --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/img/hopper_vs_blackwell_layout.svg @@ -0,0 +1,122 @@ + + + + + + + + FP8 tensor on Hopper + + + + rowwise + + + 0 + + 1 + + 2 + + 3 + + + 4 + + 5 + + 6 + + 7 + + + 8 + + 9 + + 10 + + 11 + + + + + columnwise + + + 0 + + 4 + + 8 + + + 1 + + 5 + + 9 + + + 2 + + 6 + + 10 + + + 3 + + 7 + + 11 + + + + + + + + FP8 tensor on Blackwell + + + + rowwise and columnwise + + + 0 + + 1 + + 2 + + 3 + + + 4 + + 5 + + 6 + + 7 + + + 8 + + 9 + + 10 + + 11 + + + diff --git a/docs/features/low_precision_training/performance_considerations/img/sequence_parallel_quantization.svg b/docs/features/low_precision_training/performance_considerations/img/sequence_parallel_quantization.svg new file mode 100644 index 000000000..5b61ac247 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/img/sequence_parallel_quantization.svg @@ -0,0 +1,159 @@ + + + + + + + + + + + All-Gather of Quantized Tensors (one scenario) + + + Input Tensor quantized all-gather + + + FWD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Rowwise + Quantized + + + + + + + All-Gather + + + + + + ... + + + BWD: + + + + + + + Columnwise + Quantized + + + + + + + All-Gather + + + + + + ... + + + + + + Gradient Tensor quantized all-gather + + + BWD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Col. Quantized + + + + + + + Row. Quantized + + + + + + + + + + All-Gather + + + + + + ... + + + + + High Precision (FP32/BF16/FP16) + + + Lower Precision (FP8, etc.) + + + Quantization + + + All-Gather + + + + diff --git a/docs/features/low_precision_training/performance_considerations/img/transpose_fusion.svg b/docs/features/low_precision_training/performance_considerations/img/transpose_fusion.svg new file mode 100644 index 000000000..194b1237e --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/img/transpose_fusion.svg @@ -0,0 +1,181 @@ + + + + + + + + + + + Option 1: Quantize both usages in forward + + + FORWARD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Quantized + Rowwise + + + BACKWARD: + + + + + + + Quantized + Columnwise + + + + + + Option 2: Separate Quantizations (quantize when needed) + + + FORWARD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Quantized + Rowwise + + + + + + BACKWARD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Quantized + Columnwise + + + + + + Option 3: Convert Rowwise to Columnwise in Backward (reuse saved tensor) + + + FORWARD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Quantized + Rowwise + + + + + + BACKWARD: + + + + Quantized + Rowwise + + + + + + + Make + Columnwise + + + + + + + Quantized + Columnwise + + + + + High Precision (FP32/BF16/FP16) + + + Lower Precision (FP8, etc.) + + + Quantization / Make Columnwise + + + + diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out new file mode 100644 index 000000000..717769b1e --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out @@ -0,0 +1,9 @@ +# START_MEMORY_USAGE_1 +Tensors in memory: + Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB + Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB + Total from all live arrays: 4.00 MB +# END_MEMORY_USAGE_1 +Processing events... +Generated: + No reports were generated diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py new file mode 100644 index 000000000..8c1250575 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ + +print("# START_MEMORY_USAGE_1") + +import jax +import jax.numpy as jnp +from transformer_engine.jax.flax import DenseGeneral + + +key = jax.random.PRNGKey(0) +jax.clear_caches() + + +# Initialize layer with BF16 parameters +layer = DenseGeneral(features=1024, dtype=jnp.bfloat16) +x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16) +var_collect = layer.init(key, x) + + +@jax.jit +def loss_fn(var_collect, x): + output = layer.apply(var_collect, x) + return output.sum() + + +# Trace the backward pass - this allocates saved tensors +_, backward_fn = jax.vjp(loss_fn, var_collect, x) + + +del x + +print("Tensors in memory:") +total_bytes = 0 +for arr in jax.live_arrays(): + total_bytes += arr.nbytes + if arr.nbytes > 200000: # do not count small tensors + print(f" Shape: {arr.shape}, Dtype: {arr.dtype}, Size: {arr.nbytes / 1024:.1f} KB") +print(f" Total from all live arrays: {total_bytes / (1024**2):.2f} MB") + + +print("# END_MEMORY_USAGE_1") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out new file mode 100644 index 000000000..b00749241 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out @@ -0,0 +1,4 @@ + +# START_MEMORY_USAGE_1 +Memory usage after forward pass: 6.00 MB +# END_MEMORY_USAGE_1 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py new file mode 100644 index 000000000..dd4ce2447 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py @@ -0,0 +1,38 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ +cc = torch.cuda.get_device_capability() +assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" + +print("# START_MEMORY_USAGE_1") +import torch +import transformer_engine.pytorch as te + + +def measure_memory(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + init_memory = torch.cuda.memory_allocated() + layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) + + inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") + out = layer(inp) + del inp # Input is saved by model for backward, not by user script + + mem_after_forward = torch.cuda.memory_allocated() - init_memory + return mem_after_forward + + +# Warmup run +measure_memory() + +# Actual measurement +mem_after_forward = measure_memory() +print(f"Memory usage after forward pass: {mem_after_forward/1024**2:.2f} MB") +# END_MEMORY_USAGE_1 +print("# END_MEMORY_USAGE_1") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out new file mode 100644 index 000000000..ab720b57a --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out @@ -0,0 +1,10 @@ +# START_MEMORY_USAGE_2 +Tensors in memory: + Shape: (1024, 1024), Dtype: float8_e4m3fn, Size: 1024.0 KB + Shape: (1024, 1024), Dtype: float8_e4m3fn, Size: 1024.0 KB + Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB + Total from all live arrays: 4.02 MB +# END_MEMORY_USAGE_2 +Processing events... +Generated: + No reports were generated diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py new file mode 100644 index 000000000..3baa55bb8 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py @@ -0,0 +1,48 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ + +print("# START_MEMORY_USAGE_2") + +import jax +import jax.numpy as jnp +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.common.recipe import DelayedScaling + + +key = jax.random.PRNGKey(0) +recipe = DelayedScaling() +jax.clear_caches() + + +# Initialize layer with BF16 parameters (outside autocast) +layer = DenseGeneral(features=1024, dtype=jnp.bfloat16) +x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16) + + +# Forward and backward pass with FP8 compute +with te.autocast(enabled=True, recipe=recipe): + var_collect = layer.init(key, x) + + @jax.jit + def loss_fn(var_collect, x): + output = layer.apply(var_collect, x) + return output.sum() + + # Trace the backward pass - this allocates saved tensors + _, backward_fn = jax.vjp(loss_fn, var_collect, x) + +del x + +print("Tensors in memory:") +total_bytes = 0 +for arr in jax.live_arrays(): + total_bytes += arr.nbytes + if arr.nbytes > 200000: # do not count small tensors + print(f" Shape: {arr.shape}, Dtype: {arr.dtype}, Size: {arr.nbytes / 1024:.1f} KB") +print(f" Total from all live arrays: {total_bytes / (1024**2):.2f} MB") + +print("# END_MEMORY_USAGE_2") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out new file mode 100644 index 000000000..cc1e40258 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out @@ -0,0 +1,4 @@ + +# START_MEMORY_USAGE_2 +Memory after forward pass: 6.02 MB +# END_MEMORY_USAGE_2 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py new file mode 100644 index 000000000..5c247177d --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ +cc = torch.cuda.get_device_capability() +assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" + +print("# START_MEMORY_USAGE_2") +import torch +import transformer_engine.pytorch as te + + +def measure_memory(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + init_memory = torch.cuda.memory_allocated() + layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) + + inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") + with te.autocast(enabled=True): + out = layer(inp) + del inp # Input is saved by model for backward, not by user script + + mem_after_forward = torch.cuda.memory_allocated() - init_memory + return mem_after_forward + + +# Warmup run +measure_memory() + +# Actual measurement +mem_after_forward = measure_memory() +print(f"Memory after forward pass: {mem_after_forward/1024**2:.2f} MB") +# END_MEMORY_USAGE_2 +print("# END_MEMORY_USAGE_2") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out new file mode 100644 index 000000000..ea4d0dc89 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out @@ -0,0 +1,4 @@ + +# START_MEMORY_USAGE_3 +Memory after forward pass: 3.02 MB +# END_MEMORY_USAGE_3 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py new file mode 100644 index 000000000..ce6905ce4 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py @@ -0,0 +1,44 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ +cc = torch.cuda.get_device_capability() +assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" + +print("# START_MEMORY_USAGE_3") +import torch +import transformer_engine.pytorch as te + + +def measure_memory(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + init_memory = torch.cuda.memory_allocated() + + # FP8 inference with FP8 weights + with te.quantized_model_init(enabled=True), torch.no_grad(): + layer_fp8 = te.Linear(1024, 1024, params_dtype=torch.bfloat16) + + with torch.no_grad(): + inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") + with te.autocast(enabled=True): + out = layer_fp8(inp) + del inp # Input is not saved by model for backward in inference + + mem_after_forward = torch.cuda.memory_allocated() - init_memory + + return mem_after_forward + + +# Warmup run +measure_memory() + +# Actual measurement +mem_after_forward = measure_memory() +print(f"Memory after forward pass: {mem_after_forward/1024**2:.2f} MB") +# END_MEMORY_USAGE_3 +print("# END_MEMORY_USAGE_3") diff --git a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst new file mode 100644 index 000000000..a495af56c --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst @@ -0,0 +1,473 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Performance Considerations +=================================== + +.. _handling_transposes: + +Handling transposes +------------------- + +In the last chapter we demonstrated that for FP8 on Hopper architecture, +some tensors need to be physically transposed in memory to perform needed GEMMs. +Dealing with transposes in Transformer low precision training is a bit tricky. +Let's start by introducing the concept of *tensor usages*. + +**Tensor usages** + +Each quantized tensor may have two usages: + +- *rowwise usage* -- which is used for matrix multiplication, when the consecutive elements in row are accessed, +- *columnwise usage* -- which is used for matrix multiplication, when the consecutive elements in column are accessed, + +To understand what access of consecutive elements means, let's consider two matrices ``A`` and ``B`` +and analyze how their elements are accessed during multiplication. + +For NN (non-transposed, non-transposed) multiplication ``C = A * B``, the formula is ``C_ij = sum_k(A_ik * B_kj)``. +To compute element ``C_ij``, we iterate over the i-th row of ``A`` (elements ``A_i0, A_i1, ...``) +and the j-th column of ``B`` (elements ``B_0j, B_1j, ...``). Thus, ``A`` is accessed rowwise +and ``B`` is accessed columnwise. + +For NT (non-transposed, transposed) multiplication ``C = A * B^T``, the formula changes to ``C_ij = sum_k(A_ik * B_jk)``. +Now we iterate over the i-th row of ``A`` and the j-th row of ``B`` (elements ``B_j0, B_j1, ...``). +Both tensors are accessed rowwise. + +The figure below illustrates these access patterns: + +.. figure:: img/gemm_access_pattern.svg + :align: center + :width: 60% + :alt: Matrix multiplication access pattern showing rowwise access for first tensor and columnwise access for second tensor + + Figure 1: Access patterns in matrix multiplication for matrices in ``A * B`` and ``A * B^T`` operations. + +Based on the visualization above, we can derive general rules for when each matrix +is accessed in rowwise or columnwise fashion. The key insight is that: + +- The **first tensor** in a matrix multiplication is accessed along its rows (rowwise) when non-transposed, + or along its columns (columnwise) when transposed. +- The **second tensor** follows the opposite pattern: columnwise when non-transposed, rowwise when transposed. + +.. table:: Table 1: Summary of tensor access patterns based on transpose state. + :align: center + + +------------------+--------------+---------------+ + | | First tensor | Second tensor | + +------------------+--------------+---------------+ + | Non-transposed | rowwise | columnwise | + +------------------+--------------+---------------+ + | Transposed | columnwise | rowwise | + +------------------+--------------+---------------+ + +**Input, weight and output gradient usages** + +Now let's apply these rules to a Linear layer. During training, a Linear layer performs +three GEMM operations: one in the forward pass and two in the backward pass. + + +.. table:: Table 2: Tensor access patterns for GEMM operations in a Linear layer during training. + :align: center + + +-------------------+-------------------------------------+---------------------------+---------------------------+ + | GEMM | Formula | First tensor usage | Second tensor usage | + +===================+=====================================+===========================+===========================+ + | Forward | ``output = input * weight^T`` | input: rowwise | weight: rowwise | + +-------------------+-------------------------------------+---------------------------+---------------------------+ + | Weight gradient | ``wgrad = output_grad^T * input`` | output_grad: columnwise | input: columnwise | + +-------------------+-------------------------------------+---------------------------+---------------------------+ + | Input gradient | ``dgrad = output_grad * weight`` | output_grad: rowwise | weight: columnwise | + +-------------------+-------------------------------------+---------------------------+---------------------------+ + +An important observation is that the **forward pass uses only rowwise tensors** - both input +and weight are accessed rowwise. + +The backward pass introduces columnwise access. For weight gradient, both output gradient and input +are accessed columnwise. For input gradient, output gradient is rowwise while weight is columnwise. + +As a result, each tensor (input, weight, output gradient) needs both rowwise and columnwise +usages during training. This has implications for memory layout and transpose operations. + + +**Architecture differences** + +The physical memory layout requirements for rowwise and columnwise usages differ between architectures +and recipes. For FP8 tensors: + +- *Hopper*: cannot efficiently access elements in columnwise fashion, so columnwise tensors need to be physically transposed in memory. Note that higher precision formats (BF16/FP16) do not have this limitation. +- *Blackwell*: supports columnwise access natively, so no transpose is needed. + +We will see that for most of the recipes and devices, rowwise usage and columnwise usage need different tensors. +Thus by *rowwise tensor* and *columnwise tensor* we mean tensors that are used in rowwise and columnwise usages respectively. + +.. figure:: img/hopper_vs_blackwell_layout.svg + :align: center + :alt: Comparison of rowwise and columnwise tensor layouts on Blackwell vs Hopper + + Figure 2: On Blackwell, rowwise and columnwise usages share the same memory layout. + On Hopper, columnwise usage requires a physical transpose. + +**Quantization fusions** + +This section is relevant only for recipes for which columnwise tensors +are different from rowwise tensors. + +Note that performing rowwise and columnwise quantization at the same time +enables some fusions, which usually lead to better performance. +We showcase 3 example scenarios of producing quantized tensors in rowwise and columnwise usages, +TE will use best possible fusion for given recipe and TE module configuration: + +1. *Computation of quantized tensor in both rowwise and columnwise usages in a single kernel in forward pass*. + + This is the fastest one, + but since the columnwise usage is saved for backward pass, it may lead to increased memory usage, + if the high precision tensor also needs to be saved for backward - for example if it is the attention output which is saved anyway. + +2. *Computation of quantized tensor in rowwise usage in forward pass and fused quantization to produce columnwise usage in backward pass*. + + This is usually slower than the previous one, since high precision tensor needs to be read twice. + It is used for example when high precision tensor is gathered both in forward and in backward + and quantized tensor gather is not implemented for such recipe. + +3. *Computation of quantized tensor in rowwise usage in forward pass and transpose to columnwise usage in backward pass*. + + It is more memory efficient than Option 1, but not all recipes can utilize it (otherwise + the quantization accuracy would drop due to double quantization errors). + +Transformer Engine chooses the best possible fusion internally taking the recipe and the operation into account. + +.. raw:: html + :file: img/transpose_fusion.svg + +*Figure 3: Three scenarios of producing quantized tensors in rowwise and columnwise usages.* + + + +Memory usage +------------ + +This section discusses memory usage in low precision training. +Contrary to intuition, FP8 training does not always reduce memory compared to BF16/FP16. + +*Master weights* + +Transformer Engine by default stores weights in high precision and quantizes them to low precision before each GEMM. +Moreover, one can specify which high precision should be used to store the weights in the +model (FP32/BF16/FP16) -- or choose not to store high precision weights in the model at all. +There are multiple scenarios to consider, three of them are listed below: + +1. model weights are in FP32, quantized to low precision before each GEMM, +2. model weights are in BF16/FP16, quantized to low precision before each GEMM, master weights in optimizer are in FP32. +3. model weights are stored directly in low precision, and master weights in optimizer are in FP32. + +Note that each of these scenarios may have different memory footprint. + +*Activations saved for backward* + +Unlike weights, activations do not require a high precision copy for optimizer updates. +As shown in Table 2, the input needs rowwise usage in forward and columnwise usage +for weight gradient computation in backward — so it must be saved between passes. + +The memory impact depends on which scenario from Figure 3 is used. +Additionally, on architectures where rowwise and columnwise usage tensors share the same memory layout +(e.g., FP8 on Blackwell, as shown in Figure 2), a single quantized tensor serves both usages, +reducing memory overhead compared to architectures requiring separate tensors. + +Output gradients, on the other hand, are computed during backward and do not need to be saved — +both rowwise and columnwise usages are produced on the fly as needed. + +The FP8 examples below are analyzed on Hopper (SM90) or Ada (SM89) architecture, where rowwise +and columnwise tensors require separate memory layouts. + +.. tabs:: + + .. tab:: PyTorch + + **1. Baseline: high precision forward pass** + + Let's start with a forward pass in higher precision to establish a baseline. + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: memory_usage_1_pytorch.py + :language: python + :start-after: # START_MEMORY_USAGE_1 + :end-before: # END_MEMORY_USAGE_1 + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: memory_usage_1_pytorch.out + :language: text + :start-after: # START_MEMORY_USAGE_1 + :end-before: # END_MEMORY_USAGE_1 + + Layer size is ``1024 * 1024 * 2 (2 bytes per parameter) = 2MB``. + Memory after forward pass is ``2 MB (weight) + 2 MB (input saved for backward) + 2 MB (output) = 6 MB``. + + **2. FP8 training with model weights in BF16** + + Now let's see the memory usage in FP8 training with high precision weights. + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: memory_usage_2_pytorch.py + :language: python + :start-after: # START_MEMORY_USAGE_2 + :end-before: # END_MEMORY_USAGE_2 + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: memory_usage_2_pytorch.out + :language: text + :start-after: # START_MEMORY_USAGE_2 + :end-before: # END_MEMORY_USAGE_2 + + Total memory usage is ``2 MB (weight) + 1 MB (weight in FP8) + 1 MB (input in FP8 saved for backward) + 2 MB (output) = 6 MB``. + + **3. FP8 inference with model weights stored directly in low precision** + + For inference scenarios, model weights can be stored directly in low precision. Since we are only + performing forward passes without gradient updates, master weights in high precision are not needed. + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: memory_usage_3_pytorch.py + :language: python + :start-after: # START_MEMORY_USAGE_3 + :end-before: # END_MEMORY_USAGE_3 + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: memory_usage_3_pytorch.out + :language: text + :start-after: # START_MEMORY_USAGE_3 + :end-before: # END_MEMORY_USAGE_3 + + Total memory usage is ``1 MB (weight in FP8) + 2 MB (output) = 3 MB``. + This is lower than the BF16 baseline (6 MB) since no copies are saved for backward in inference mode. + + **4. Saving original input instead of quantized** + + By default, TE saves the columnwise quantized input for the backward pass (needed for weight gradient). + However, when the high precision input is already being saved (e.g., for a residual connection), + keeping an additional quantized copy wastes memory. + + The ``save_original_input=True`` option tells the layer to reference the original high precision input + instead of caching a separate quantized copy. The input is re-quantized during backward when needed. + Below is an example with a residual block where input is kept for the addition: + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: save_original_input_pytorch.py + :language: python + :start-after: # START_SAVE_ORIGINAL_INPUT + :end-before: # END_SAVE_ORIGINAL_INPUT + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: save_original_input_pytorch.out + :language: text + :start-after: # START_SAVE_ORIGINAL_INPUT + :end-before: # END_SAVE_ORIGINAL_INPUT + + .. tab:: JAX + + **1. Baseline: high precision forward pass** + + Let's start with a forward pass in higher precision to establish a baseline. + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: memory_usage_1_jax.py + :language: python + :start-after: # START_MEMORY_USAGE_1 + :end-before: # END_MEMORY_USAGE_1 + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: memory_usage_1_jax.out + :language: text + :start-after: # START_MEMORY_USAGE_1 + :end-before: # END_MEMORY_USAGE_1 + + Layer size is ``1024 * 1024 * 2 (2 bytes per parameter) = 2MB``. + Memory after forward pass is ``2 MB (weight) + 2 MB (input saved for backward) = 4 MB``. + + **2. FP8 training with master weights in BF16** + + Now let's see the memory usage in FP8 training with high precision weights. + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: memory_usage_2_jax.py + :language: python + :start-after: # START_MEMORY_USAGE_2 + :end-before: # END_MEMORY_USAGE_2 + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: memory_usage_2_jax.out + :language: text + :start-after: # START_MEMORY_USAGE_2 + :end-before: # END_MEMORY_USAGE_2 + + Memory after forward pass is ``2 MB (weight in BF16) + 1 MB (input in FP8) + 1 MB (weight in FP8) = 4 MB``. + +Fused layers +------------ + + +Transformer Engine provides fused layers such as ``LayerNormLinear`` (``LayerNormDenseGeneral`` in JAX) and ``LayerNormMLP`` +that enable kernel fusion optimizations. One key optimization is fusing layer normalization +with quantization. + +In a typical Transformer architecture, LayerNorm precedes a Linear layer. Without fusion, +the LayerNorm outputs in high precision, and the Linear layer must then quantize this input before +performing the GEMM — adding overhead. With ``LayerNormLinear``, these operations are fused +into a single kernel: the LayerNorm output is quantized directly, eliminating the separate +quantization step and reducing memory movement. + + +.. raw:: html + :file: img/fused_layers.svg + +*Figure 4: Comparison of separate LayerNorm and Linear layers versus fused LayerNormLinear layer, showing reduced quantization overhead.* + + +Let's see how we can use fused layers in different frameworks. + +.. tabs:: + + .. tab:: PyTorch + + In PyTorch, Transformer Engine provides fused layers like ``LayerNormLinear`` and ``LayerNormMLP``. + These layers combine normalization and linear operations with optimized quantization. + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada, Hopper, Blackwell, or newer) +
+ + .. literalinclude:: fused_layers_pytorch.py + :language: python + :start-after: # START_FUSED_LAYERS + :end-before: # END_FUSED_LAYERS + + The fused ``LayerNormLinear`` layer is particularly efficient in FP8 training because + it avoids an intermediate quantization step. The LayerNorm output is directly quantized + for the GEMM operation, reducing memory movement and improving performance. + + .. tab:: JAX + + In JAX, Transformer Engine provides fused layers like ``LayerNormDenseGeneral`` and ``LayerNormMLP``. + These layers combine normalization and dense operations with optimized quantization. + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada, Hopper, Blackwell, or newer) +
+ + .. literalinclude:: fused_layers_jax.py + :language: python + :start-after: # START_FUSED_LAYERS + :end-before: # END_FUSED_LAYERS + + The fused ``LayerNormDenseGeneral`` layer is particularly efficient in FP8 training because + it avoids an intermediate quantization step. The LayerNorm output is directly quantized + for the GEMM operation, reducing memory movement and improving performance. + + +Distributed training +-------------------- + +Transformer Engine handles collective operations internally, so users typically don't need to manage +the interaction between communication and low precision computation. + +Recall that each Linear layer involves six tensors: weight, input, output, and their gradients. +Of these, output and gradients are returned in high precision, and weights are generally not +communicated (except in FSDP, which is outside the scope of this section). This leaves two +tensors where low precision communication matters: **input** and **output gradient**. + +For sequence parallelism, TE supports all-gather of quantized tensors. This provides several benefits: + +1. *Reduced memory usage* — no need to store high precision tensors for backward pass. +2. *Reduced communication* — smaller tensors mean less data to transfer. +3. *Parallelized quantization* — quantization work is distributed across GPUs. + +Support varies by recipe — for example, columnwise quantized all-gather is not available +for all configurations. + +The figure below illustrates one possible all-gather scenario for input and output gradient tensors. +Actual behavior depends on the recipe and module configuration. + +.. raw:: html + :file: img/sequence_parallel_quantization.svg + +*Figure 5: All-gather of quantized tensors for input and gradient tensors. +This is one possible scenario — actual behavior varies depending on the recipe and module configuration.* + + diff --git a/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out new file mode 100644 index 000000000..21227220f --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out @@ -0,0 +1,4 @@ +# START_SAVE_ORIGINAL_INPUT +save_original_input=False: 25.0 MB +save_original_input=True: 24.0 MB +# END_SAVE_ORIGINAL_INPUT diff --git a/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py new file mode 100644 index 000000000..c9efa7107 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py @@ -0,0 +1,51 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ +cc = torch.cuda.get_device_capability() +assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" + +print("# START_SAVE_ORIGINAL_INPUT") +# START_SAVE_ORIGINAL_INPUT +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Float8CurrentScaling + +recipe = Float8CurrentScaling() + + +def residual_block(layer, inp): + """Residual connection: input is saved for addition after linear.""" + out = layer(inp) + return out + inp # inp must be kept for this addition + + +def measure_memory(use_save_original): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + layer = te.Linear( + 1024, 1024, params_dtype=torch.bfloat16, save_original_input=use_save_original + ) + inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True) + + with te.autocast(enabled=True, recipe=recipe): + out = residual_block(layer, inp) + out.sum().backward() + + return torch.cuda.max_memory_allocated() / 1024**2 + + +# Warmup runs +measure_memory(False) +measure_memory(True) + +# Actual measurements +for use_save_original in [False, True]: + peak = measure_memory(use_save_original) + print(f"save_original_input={use_save_original}: {peak:.1f} MB") +# END_SAVE_ORIGINAL_INPUT +print("# END_SAVE_ORIGINAL_INPUT") diff --git a/docs/features/other_optimizations/cpu_offloading/cpu_offloading.rst b/docs/features/other_optimizations/cpu_offloading/cpu_offloading.rst new file mode 100644 index 000000000..47ea35a83 --- /dev/null +++ b/docs/features/other_optimizations/cpu_offloading/cpu_offloading.rst @@ -0,0 +1,290 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +CPU Offloading +=================================== + +.. note:: + + CPU Offloading in Transformer Engine is currently available only for **PyTorch**. + It supports all PyTorch modules, not just TE layers. + +CPU offloading moves activation tensors from GPU to CPU memory during the +forward pass and reloads them during backward. Transfers are **asynchronous**, +enabling significant GPU memory savings with minimal overhead. + +Unlike activation checkpointing, offloading avoids recomputation — activations +are stored on CPU instead of being recalculated, making it faster when +CPU-GPU bandwidth is sufficient. + + +Hardware Support +---------------- + +CPU offloading benefits greatly from fast CPU-GPU interconnects. +The faster the link, the more effectively transfer time can be hidden +behind computation. + +.. raw:: html + :file: img/pcie_vs_nvlink.svg + +*Figure 1. Traditional PCIe system vs GB200 Superchip with NVLink-C2C.* + +Traditional **PCIe Gen5 x16** systems offer **128 GB/s** bidirectional bandwidth +between CPU and GPU, which limits offloading benefits. + +With **NVLink-C2C** (GB200), bandwidth jumps to **900 GB/s** bidirectional per link, +making offloading increasingly attractive on modern NVIDIA superchips. +The GB200 pairs a Grace CPU with 480 GB LPDDR5X memory and two Blackwell GPUs, +each with 192 GB HBM3e (384 GB total), providing ample CPU memory for offloading +activations. + +Offloading/reloading consumes HBM bandwidth, which may compete with +other GPU operations — even when transfers are asynchronous. +This is unlikely to affect compute-bound operations like GEMMs, but the impact on +memory-bound operations like quantization may be noticeable. + + +CPU Offloading in Transformer Engine +------------------------------------ + +Transformer Engine supports CPU offloading of activations for **sequential models**. +A model is considered sequential if it satisfies the following conditions: + +1. The model is a sequence of layers: ``x₁ = Layer₁(x₀)``, ``x₂ = Layer₂(x₁)``, ..., ``xₙ = Layerₙ(xₙ₋₁)``. + **The layers may be any PyTorch modules**, not just TE layers. +2. Each intermediate tensor ``xᵢ`` is used only as input to the next layer (not elsewhere in the model). +3. ``xᵢ`` is only needed as input to ``Layerᵢ₊₁``'s backward pass and can be freed once that pass completes. + +Most LLM architectures (stacked Transformer blocks) satisfy these conditions. + +.. raw:: html + :file: img/layer_sequence.svg + +*Figure 2. Sequential model: xᵢ₊₁ = Layerᵢ₊₁(xᵢ). Each layer consumes only the output of the previous one.* + +The example below shows how to offload activations for a sequence of ``torch.nn.Linear`` layers using the default scheduling algorithm: + +.. tabs:: + + .. tab:: PyTorch + + .. literalinclude:: pytorch_basic_offload_example.py + :language: python + :start-after: # START_BASIC_EXAMPLE + :end-before: # END_BASIC_EXAMPLE + + + +Let's take a look at the API in detail: + +.. tabs:: + + .. tab:: PyTorch + + .. code-block:: python + + def get_cpu_offload_context( + enabled: bool = False, + num_layers: Optional[int] = 1, + model_layers: int = 1, + manual_synchronization: bool = False, + offload_stream: Optional[torch.cuda.Stream] = None, + # ... (legacy parameters omitted, see :func:`get_cpu_offload_context`) + ) -> Union[Tuple[ContextManager, Callable], Tuple[ContextManager, Callable, ManualOffloadSynchronizer]]: + ... + +The ``model_layers`` parameter must always be set to the total number of layers in the model. +There are two modes of operation: + +1. **Default scheduling** — set ``num_layers`` to the number of layers to offload. + The algorithm automatically schedules offload/reload operations to overlap with computation. + +2. **Manual synchronization** — set ``manual_synchronization=True`` (``num_layers`` is ignored in this mode). + This mode provides explicit control over when to start offload/reload using the returned ``ManualOffloadSynchronizer``. + +The :func:`transformer_engine.pytorch.get_cpu_offload_context` function returns: + +- **context manager** — wraps each layer's forward pass to intercept tensors saved for backward. +- **sync function** — registers a backward hook on the output tensor to trigger activation reload. +- **ManualOffloadSynchronizer** *(only in manual mode)* — provides explicit control over offload/reload. + +The usage pattern for default scheduling is: + +.. tabs:: + + .. tab:: PyTorch + + .. code-block:: python + + cpu_offload_context, sync_function = get_cpu_offload_context(...) + + for layer in layers: + with cpu_offload_context: + x = layer(x) + x = sync_function(x) + + +Default Offloading Scheduling +----------------------------- + +Default scheduling is enabled when ``manual_synchronization=False`` (the default). +The ``num_layers`` parameter must be specified to set the number of layers to offload. +The algorithm then automatically determines when to offload and reload activations +to maximize overlap with computation. + +For ``num_layers`` layers offloaded of ``model_layers`` layers: + +- First ``num_layers`` layers are offloaded to CPU. +- Offloading starts as soon as tensors are saved for backward — it does not wait + for the layer's forward pass to complete. +- At most ``(model_layers - num_layers)`` sets of activations are on GPU at any time; + both compute and reload may be stalled to enforce this limit. +- Reloading must complete by the time the tensor is needed for the layer's backward pass. +- ``num_layers`` must be at most ``model_layers - 1`` (setting it to ``model_layers`` + raises an assertion error). However, ``model_layers - 1`` leaves only 1 activation set + on GPU at a time — compute and transfers cannot overlap, and a warning is raised. + For full overlap, use ``model_layers - 2`` or less. + +Specifying a low enough ``num_layers`` enables full overlap of computation +and offload/reload. The following two scenarios illustrate this — one with full overlap, and one with stalls. + +.. raw:: html + :file: img/scheduling.svg + +*Figure 3. With* ``num_layers=2``\ *and* ``model_layers=5``\ *, at most 3 sets of activations are on GPU. Layer 1 offloading starts during its forward pass (when the first tensor is saved for backward). Offloading fully overlaps with forward, reloading fully overlaps with backward.* + +When ``num_layers`` is too high, the GPU memory limit forces stalls: + +.. raw:: html + :file: img/scheduling_stall.svg + +*Figure 4. With* ``num_layers=3``\ *and* ``model_layers=5``\ *, at most 2 sets of activations can be on GPU (5-3=2), which causes stalls. In forward, Layer 4 cannot start until Layer 2 is offloaded, otherwise there would be 3 sets of activations on GPU (Layers 2, 3, 4). In backward, Layer 3 cannot start immediately — its activations are still on CPU and must be reloaded first. Some tensors may finish reloading earlier, allowing parts of the layer (e.g., a sublayer) to run while the rest waits. The same applies to Layers 2 and 1.* + + +Manual Synchronization +---------------------- + +For custom scheduling, set ``manual_synchronization=True``. +Optionally, pass a custom ``offload_stream`` for fine-grained synchronization. +This mode returns a ``ManualOffloadSynchronizer`` with explicit control over transfers. + +This mode is useful when training does not follow the standard "all forwards then all backwards" +pattern — for example, in pipeline parallelism. Providing a custom ``offload_stream`` enables +additional synchronization logic (e.g., waiting, recording events) tailored to the specific workload. + +The ``ManualOffloadSynchronizer`` object provides the following methods: + +- ``start_offload_layer(layer_id)`` — queue async GPU→CPU copies on the offload stream. + Before each copy, the offload stream waits for an event recorded when that tensor + was saved for backward. +- ``release_activation_forward_gpu_memory(layer_id)`` — make the current stream wait for + this layer's offload to complete, then release GPU memory. +- ``start_reload_layer(layer_id)`` — queue async CPU→GPU copies on the offload stream. + When tensors are accessed in backward, compute stream waits for each tensor's reload + to complete. + +To skip offloading for a specific layer, simply do not call any of these methods for that layer. + +.. tabs:: + + .. tab:: PyTorch + + The example demonstrates: + + 1. **Forward pass**: After each layer, call ``start_offload_layer(i)`` to begin + async copy of layer ``i``'s activations to CPU. + 2. **Release GPU memory**: Call ``release_activation_forward_gpu_memory(i)`` to free + the GPU tensors. Each call waits internally for that layer's offload to complete. + 3. **Before backward**: Call ``start_reload_layer(i)`` to begin async reload. + The compute stream will automatically wait for each tensor to be reloaded + before it's accessed in backward. + + .. literalinclude:: pytorch_manual_offload_example.py + :language: python + :start-after: # START_MANUAL_EXAMPLE + :end-before: # END_MANUAL_EXAMPLE + + +CPU Offloading and CUDA Graphs +------------------------------ + +CPU offloading works with CUDA graphs — async copies and stream synchronization +are GPU operations that can be captured and replayed, even when accessing +pinned CPU memory (via PCIe DMA, without CPU involvement). + +.. note:: + + We recommend capturing the entire forward and backward pass in a single graph. + Async copy operations (offload/reload) must complete within the same graph where + they started. If the graph ends before copies finish, PyTorch will block waiting + for them, defeating the purpose of graph capture. + +.. tabs:: + + .. tab:: PyTorch + + .. literalinclude:: pytorch_cuda_graphs_example.py + :language: python + :start-after: # START_CUDA_GRAPHS_EXAMPLE + :end-before: # END_CUDA_GRAPHS_EXAMPLE + +.. note:: + + In PyTorch versions prior to 2.11, CPU offloading with CUDA graphs required passing + ``retain_pinned_cpu_buffers=True`` to :func:`get_cpu_offload_context`. The root cause + was that ``torch.empty`` with pinned CPU memory was not supported inside CUDA graph + capture — buffers had to be pre-allocated and reused across iterations to avoid + invalidating DMA addresses captured in the graph. This was fixed in + `pytorch#167507 `_ (merged December 2025, + shipping in PyTorch 2.11). On PyTorch 2.11+, ``retain_pinned_cpu_buffers`` is no longer needed. + +Caveats +------- + +.. warning:: + + **Heuristic activation detection**: + + CPU Offloading is implemented using + `PyTorch saved tensors hooks `_. + PyTorch saves various tensors for backward — not just activations, but also weights and other data. + + Activation detection is heuristic. A CUDA tensor is offloaded if it: + + - has at least 256×1024 elements (~1 MB for float32), + - is not a ``torch.nn.Parameter``, + - is not marked with ``mark_not_offload()``. + + Additionally, non-contiguous tensors are skipped to avoid memory layout changes (see below). + For TE layers, tensors that should not be offloaded are manually excluded. + For non-TE layers, no such exclusions exist, so some tensors may remain pinned in GPU memory + even after being copied to CPU (e.g., if the layer stores references in ``ctx``), + resulting in wasted bandwidth with no memory savings. + + To exclude specific tensors from offloading, use :func:`mark_not_offload`: + + .. code-block:: python + + from transformer_engine.pytorch import mark_not_offload + mark_not_offload(tensor) + +.. warning:: + + **Memory layout changes**: + + Offloading/reloading can change tensor memory layout and relations: + + 1. Views of the same storage may be restored as separate allocations. + 2. Adjacent tensors may not be adjacent after reload. + + CUDA kernels that rely on specific memory layout may produce unexpected results. + To mitigate (1), non-trivial views are excluded from offloading by default. + TE attention kernels are an exception — they use internal handling that is tested and supported. + Issue (2) is not mitigated — custom kernels that assume adjacent tensors share + contiguous memory may still fail. + + If you encounter layout-related issues, use :func:`mark_not_offload` to exclude + problematic tensors from offloading. diff --git a/docs/features/other_optimizations/cpu_offloading/img/layer_sequence.svg b/docs/features/other_optimizations/cpu_offloading/img/layer_sequence.svg new file mode 100644 index 000000000..cdb8814a9 --- /dev/null +++ b/docs/features/other_optimizations/cpu_offloading/img/layer_sequence.svg @@ -0,0 +1,66 @@ + + + + + + + + + + + x₀ + + + + + + Layer 1 + + + + + + x₁ + + + + + + Layer 2 + + + + + + x₂ + + + + + + Layer 3 + + + + + ··· + + + + + + Layer N + + + + + + xₙ + + diff --git a/docs/features/other_optimizations/cpu_offloading/img/pcie_vs_nvlink.svg b/docs/features/other_optimizations/cpu_offloading/img/pcie_vs_nvlink.svg new file mode 100644 index 000000000..0b8ec3912 --- /dev/null +++ b/docs/features/other_optimizations/cpu_offloading/img/pcie_vs_nvlink.svg @@ -0,0 +1,132 @@ + + + + + + + + + + Traditional PCIe System + + + + + + + CPU + + + + RAM + + + + + + + + + + GPU + + + + HBM + + + + + + + + PCIe + + 128 GB/s + + + + GB200 Superchip + NVIDIA Grace Blackwell + + + + + + + + + + Blackwell + GPU 1 + + + + HBM + + + + + + + NVLink + C2C + + + + + + + Grace CPU + + + + RAM + + + + + + + NVLink + C2C + + + + + + + Blackwell + GPU 2 + + + + HBM + + + + 900 GB/s per NVLink-C2C link + + diff --git a/docs/features/other_optimizations/cpu_offloading/img/scheduling.svg b/docs/features/other_optimizations/cpu_offloading/img/scheduling.svg new file mode 100644 index 000000000..19255c347 --- /dev/null +++ b/docs/features/other_optimizations/cpu_offloading/img/scheduling.svg @@ -0,0 +1,110 @@ + + + + + + + Model (model_layers = 5) + + + + Layer 1 + + + Layer 2 + + + Layer 3 + + + Layer 4 + + + Layer 5 + + + + num_layers = 2 (offloaded) + + + + + + Forward Pass + + + compute stream + offload stream + + + + Layer 1 fwd + + + Layer 2 fwd + + + Layer 3 fwd + + + Layer 4 fwd + + + Layer 5 fwd + + + + Layer 1 offload + + + Layer 2 offload + + + + + + Backward Pass + + + compute stream + reload stream + + + + Layer 5 bwd + + + Layer 4 bwd + + + Layer 3 bwd + + + Layer 2 bwd + + + Layer 1 bwd + + + + Layer 2 reload + + + Layer 1 reload + + diff --git a/docs/features/other_optimizations/cpu_offloading/img/scheduling_stall.svg b/docs/features/other_optimizations/cpu_offloading/img/scheduling_stall.svg new file mode 100644 index 000000000..cd2d1a660 --- /dev/null +++ b/docs/features/other_optimizations/cpu_offloading/img/scheduling_stall.svg @@ -0,0 +1,143 @@ + + + + + + + Model (model_layers = 5) + + + + Layer 1 + + + Layer 2 + + + Layer 3 + + + Layer 4 + + + Layer 5 + + + + num_layers = 3 (offloaded) + + + + + + Forward Pass + + + compute stream + offload stream + + + Layer 1 fwd + + + Layer 2 fwd + + + Layer 3 fwd + + + + wait + + + Layer 4 fwd + + + + wait + + + Layer 5 fwd + + + + Layer 1 offload + + + Layer 2 offload + + + Layer 3 offload + + + + + + Backward Pass + + + compute stream + reload stream + + + Layer 5 bwd + + + Layer 4 bwd + + + + Layer 3 bwd + + + wait + + + + + Layer 2 bwd + + + wait + + + + + Layer 1 bwd + + wait + + + wait + + + + + Layer 3 reload + + + Layer 2 reload + + + Layer 1 reload + + diff --git a/docs/features/other_optimizations/cpu_offloading/pytorch_basic_offload_example.py b/docs/features/other_optimizations/cpu_offloading/pytorch_basic_offload_example.py new file mode 100644 index 000000000..b453e824a --- /dev/null +++ b/docs/features/other_optimizations/cpu_offloading/pytorch_basic_offload_example.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_BASIC_EXAMPLE +import torch +from transformer_engine.pytorch import get_cpu_offload_context + +# Setup +num_layers = 12 +offloaded_layers = 3 +layers = [torch.nn.Linear(1024, 1024).cuda() for _ in range(num_layers)] +x = torch.randn(16, 1024, 1024, device="cuda") + +# Get offloading context and sync function +cpu_offload_context, sync_function = get_cpu_offload_context( + enabled=True, + model_layers=num_layers, + num_layers=offloaded_layers, +) + +# Forward pass +for i in range(num_layers): + # Context manager captures tensors saved for backward. + # These tensors will be offloaded to CPU asynchronously. + with cpu_offload_context: + x = layers[i](x) + + # sync_function must be called after each layer's forward pass. + # This cannot be done inside the context manager because + # it needs the output tensor after the layer has finished. + x = sync_function(x) + +loss = x.sum() +loss.backward() +# END_BASIC_EXAMPLE diff --git a/docs/features/other_optimizations/cpu_offloading/pytorch_cuda_graphs_example.py b/docs/features/other_optimizations/cpu_offloading/pytorch_cuda_graphs_example.py new file mode 100644 index 000000000..a42bd8908 --- /dev/null +++ b/docs/features/other_optimizations/cpu_offloading/pytorch_cuda_graphs_example.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_CUDA_GRAPHS_EXAMPLE +import torch +from transformer_engine.pytorch import get_cpu_offload_context, make_graphed_callables + +# Setup +num_layers = 12 +offloaded_layers = 3 +layers = [torch.nn.Linear(1024, 1024).cuda() for _ in range(num_layers)] + +# Enable offloading for CUDA graphs +cpu_offload_context, sync_function = get_cpu_offload_context( + enabled=True, + model_layers=num_layers, + num_layers=offloaded_layers, +) + + +# Wrap layers in a module that uses offloading +class OffloadedModel(torch.nn.Module): + def __init__(self, layers): + super().__init__() + self.layers = torch.nn.ModuleList(layers) + + def forward(self, x): + for layer in self.layers: + with cpu_offload_context: + x = layer(x) + x = sync_function(x) + return x + + +model = OffloadedModel(layers) +sample_input = (torch.randn(16, 1024, 1024, device="cuda"),) + +# Create graphed callable (warmup is handled internally) +graphed_model = make_graphed_callables(model, sample_input) + +# Use the graphed model +x = torch.randn(16, 1024, 1024, device="cuda") +out = graphed_model(x) +out.sum().backward() +# END_CUDA_GRAPHS_EXAMPLE diff --git a/docs/features/other_optimizations/cpu_offloading/pytorch_manual_offload_example.py b/docs/features/other_optimizations/cpu_offloading/pytorch_manual_offload_example.py new file mode 100644 index 000000000..92e0768c8 --- /dev/null +++ b/docs/features/other_optimizations/cpu_offloading/pytorch_manual_offload_example.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_MANUAL_EXAMPLE +import torch +from transformer_engine.pytorch import get_cpu_offload_context + +# Setup +num_layers = 12 +layers = [torch.nn.Linear(1024, 1024).cuda() for _ in range(num_layers)] +x = torch.randn(16, 1024, 1024, device="cuda") + +offload_stream = torch.cuda.Stream() +cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context( + enabled=True, + model_layers=num_layers, + manual_synchronization=True, + offload_stream=offload_stream, +) + +# Forward pass - manually trigger offload after each layer +for i in range(num_layers): + with cpu_offload_context: + x = layers[i](x) + x = sync_function(x) + manual_controller.start_offload_layer(i) + +# Release GPU memory (each call waits for that layer's offload to complete) +for i in range(num_layers): + manual_controller.release_activation_forward_gpu_memory(i) + +# Start reloading before backward +for i in range(num_layers - 1, -1, -1): + manual_controller.start_reload_layer(i) + +# Backward pass +loss = x.sum() +loss.backward() +# END_MANUAL_EXAMPLE diff --git a/docs/features/other_optimizations/index.rst b/docs/features/other_optimizations/index.rst new file mode 100644 index 000000000..05e89c4b0 --- /dev/null +++ b/docs/features/other_optimizations/index.rst @@ -0,0 +1,12 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Other optimizations +=================================== + +.. toctree:: + + cpu_offloading/cpu_offloading.rst + diff --git a/docs/index.rst b/docs/index.rst index 0edcb863b..738955367 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -39,6 +39,15 @@ Transformer Engine documentation api/common api/framework + +.. toctree:: + :hidden: + :caption: Features + + features/low_precision_training/index.rst + features/other_optimizations/index.rst + + .. toctree:: :hidden: :caption: Examples and Tutorials @@ -49,6 +58,7 @@ Transformer Engine documentation examples/te_gemma/tutorial_generation_gemma_with_te.ipynb examples/onnx/onnx_export.ipynb examples/te_jax_integration.ipynb + examples/op_fuser/op_fuser.rst .. toctree:: :hidden: diff --git a/examples/README.md b/examples/README.md index 004d1631f..782dc42f5 100644 --- a/examples/README.md +++ b/examples/README.md @@ -23,8 +23,6 @@ Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/Tr - **FP8 Weight Caching**: Avoiding redundant FP8 casting during multiple gradient accumulation steps to improve efficiency. - [Introduction to FP8](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/fp8_primer.ipynb) - Overview of FP8 datatypes (E4M3, E5M2), mixed precision training, delayed scaling strategies, and code examples for FP8 configuration and usage. -- [TE Quickstart](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb) - - Introduction to TE, building a Transformer Layer using PyTorch, and instructions on integrating TE modules like Linear and LayerNorm. - [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/mnist) # JAX @@ -34,7 +32,9 @@ Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/Tr - Model Parallelism: Divide a model across multiple GPUs for parallel training. - Multiprocessing with Model Parallelism: Multiprocessing for model parallelism, including multi-node support and hardware affinity setup. - [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/mnist) - +- [TE JAX Integration Tutorial](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_jax_integration.ipynb) + - Introduction to integrating TE into an existing JAX model framework, building a Transformer Layer, and instructions on integrating TE modules like Linear and LayerNorm. + # Third party - [Hugging Face Accelerate + TE](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8/transformer_engine) - Scripts for training with Accelerate and TE. Supports single GPU, and multi-GPU via DDP, FSDP, and DeepSpeed ZeRO 1-3. diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 0d812da05..681593239 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -1,47 +1,52 @@ # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""Shared functions for the comm_overlap tests""" +"""Shared functions for the collective GEMM tests""" +import argparse + +import jax import jax.numpy as jnp import numpy as np +from jax.experimental import mesh_utils + +from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap -# Add this after your existing imports def dtype_tols(dtype, rtol=None, atol=None): """Expected numerical tolerance for a data type.""" - # Return immediately if tolerances are fully specified if rtol is not None and atol is not None: return {"rtol": rtol, "atol": atol} - # Default tolerances for common dtypes if dtype in [jnp.float32, "float32"]: return {"rtol": 1e-5, "atol": 1e-8} elif dtype in [jnp.float16, "float16"]: return {"rtol": 1e-3, "atol": 1e-6} elif dtype in [jnp.bfloat16, "bfloat16"]: return {"rtol": 1e-2, "atol": 1e-5} + elif dtype in [jnp.float8_e4m3fn, "float8_e4m3fn", jnp.float8_e5m2, "float8_e5m2"]: + # FP8 quantization introduces ~1% error; match C++ getTolerances for fp8 types + return {"rtol": 1e-2, "atol": 1e-2} else: return {"rtol": 1e-5, "atol": 1e-8} -def assert_allclose( - actual, - desired, - rtol=None, - atol=None, - dtype=None, - **kwargs, -): +def get_tolerance_dtype(quantizer_set): + """Return the dtype used to select numerical tolerances based on the active quantizer. + + Reads q_dtype from quantizer_set.x; falls back to bfloat16 when no quantizer is + active (NO_SCALING / noop path, where quantizer_set.x is None). + """ + if quantizer_set.x is not None: + return quantizer_set.x.q_dtype + return jnp.bfloat16 + + +def assert_allclose(actual, desired, rtol=None, atol=None, dtype=None, **kwargs): """Check if two tensors are close.""" - # Infer data type if needed if dtype is None: - if isinstance(actual, float): - dtype = "float32" - else: - dtype = actual.dtype + dtype = "float32" if isinstance(actual, float) else actual.dtype - # Determine tolerances tols = {} if rtol is None or atol is None: tols = dtype_tols(dtype) @@ -50,49 +55,26 @@ def assert_allclose( if atol is not None: tols["atol"] = atol - # Cast tensors to fp32 if not isinstance(actual, float): actual = actual.astype(jnp.float32) if not isinstance(desired, float): desired = desired.astype(jnp.float32) - # Check if tensors are close np.testing.assert_allclose(actual, desired, **tols, **kwargs) -def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8): - if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol): - diff = jnp.abs(ref_output - gathered_output) - mask = diff > (atol + rtol * jnp.abs(gathered_output)) - print(mask.astype(int)) - print(jnp.where(mask, diff, 0)) - - -# Shared constants for all tests +# Shared constants DP_AXIS = "data" TPSP_AXIS = "tensor_sequence" -PARAMS_KEY = "params" - -# Shared functions for distributed testing -import argparse -import jax -from jax.experimental import mesh_utils -from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap # Global flag to track if distributed has been initialized _distributed_initialized = False -def _is_distributed_initialized(): - """Check if JAX distributed has been initialized.""" - return _distributed_initialized - - def _initialize_distributed(args): """Initialize JAX distributed with custom arguments.""" global _distributed_initialized - # Check if already initialized if _distributed_initialized: return @@ -105,14 +87,10 @@ def _initialize_distributed(args): assert ( args.num_devices_per_process is not None ), "Either local_device_ids or num_devices_per_process must be provided" - # Calculate device range for this process - # Single process single device: each process gets one unique device - # Single process multiple devices: each process gets a unique range of devices start_device = args.process_id * args.num_devices_per_process device_range = range(start_device, start_device + args.num_devices_per_process) global_device_ids_for_this_process = ",".join(map(str, device_range)) else: - # Use explicitly provided global device IDs global_device_ids_for_this_process = args.local_device_ids args.num_devices_per_process = len(args.local_device_ids.split(",")) @@ -131,10 +109,6 @@ def _initialize_distributed(args): ) _distributed_initialized = True - jax.clear_caches() - jax.config.update( - "jax_use_shardy_partitioner", False - ) # CollectiveGEMM does not work with Shardy yet assert jax.local_device_count() == 1, ( f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found" @@ -233,7 +207,16 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para help="Type of collective operation", ) parser.add_argument( - "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use" + "--quantize-recipe", + type=str, + default=None, + choices=[ + "DelayedScaling", + "Float8CurrentScaling", + "MXFP8BlockScaling", + "NVFP4BlockScaling", + ], + help="Quantization recipe to use. Omit for BF16 (no quantization).", ) parser.add_argument( "--enable-data-parallel", action="store_true", help="Enable data parallelism" diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 388c87837..8340d2010 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -23,11 +23,36 @@ else echo "NVLINK support detected" fi -# Define the test files to run -TEST_FILES=( -"test_gemm.py" -"test_dense_grad.py" -"test_layernorm_mlp_grad.py" +# Define individual test cases to run (file::class::method) +# DelayedScalingFP8 and CurrentScalingFP8 use the same GEMM so we don't need to test both cases all +# the time. +TEST_CASES=( +# test_gemm.py cases +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" +# +# # test_dense_grad.py cases +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" + +# test_layernorm_mlp_grad.py cases +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" ) echo @@ -57,24 +82,27 @@ cleanup() { # Set up signal handlers to cleanup on exit trap cleanup EXIT INT TERM -# Run each test file across all GPUs -for TEST_FILE in "${TEST_FILES[@]}"; do +# Run each test case across all GPUs +for TEST_CASE in "${TEST_CASES[@]}"; do echo - echo "=== Starting test file: $TEST_FILE ..." + echo "=== Starting test: $TEST_CASE ..." + + # Extract just the test method name for log/xml file naming + TEST_NAME=$(echo "$TEST_CASE" | awk -F'::' '{print $NF}') - # Clear PIDs array for this test file + # Clear PIDs array for this test case PIDS=() for i in $(seq 0 $(($NUM_GPUS - 1))); do # Define output file for logs - LOG_FILE="${TEST_FILE}_gpu_${i}.log" + LOG_FILE="${TEST_NAME}_gpu_${i}.log" if [ $i -eq 0 ]; then # For process 0: show live output AND save to log file using tee echo "=== Live output from process 0 ===" pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \ - "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \ + "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ --process-id=$i 2>&1 | tee "$LOG_FILE" & PID=$! @@ -82,7 +110,7 @@ for TEST_FILE in "${TEST_FILES[@]}"; do else # For other processes: redirect to log files only pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ --process-id=$i > "$LOG_FILE" 2>&1 & PID=$! @@ -93,22 +121,22 @@ for TEST_FILE in "${TEST_FILES[@]}"; do # Wait for all processes to finish wait - # Check and print the log content from process 0 (now has log file thanks to tee) - if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE SKIPPED" - elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE FAILED" + # Check and print the log content from process 0 + if grep -q "SKIPPED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE SKIPPED" + elif grep -q "FAILED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE FAILED" HAS_FAILURE=1 - elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE PASSED" + elif grep -q "PASSED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE PASSED" else - echo "... $TEST_FILE INVALID" + echo "... $TEST_CASE INVALID" HAS_FAILURE=1 fi # Remove the log files after processing them wait - rm ${TEST_FILE}_gpu_*.log + rm ${TEST_NAME}_gpu_*.log done wait diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 94c7dc5b6..1d300f8e9 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -2,7 +2,6 @@ # # See LICENSE for license information. """Collective Dense Gradient test on multi-GPU with tensor parallelism""" -import argparse import unittest import os @@ -13,18 +12,24 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, DP_AXIS, TPSP_AXIS, - PARAMS_KEY, cgemm_parser, ) from transformer_engine.jax.dense import dense -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_quantize_recipe_supported, + get_quantization_recipe, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import ( CollectiveOp, CollectiveOpSet, @@ -56,7 +61,9 @@ def _get_operand_sharding(mesh, collective_op): return x_sharding, weight_sharding, bias_sharding -def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): +def _mean_dense( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set +): output = dense( x, weight, @@ -66,13 +73,16 @@ def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collectiv kernel_axes=weight_axes, output_axes=output_axes, collective_op_set=collective_op_set, + quantizer_set=quantizer_set, ) return jnp.mean(output.astype(jnp.float32)) -def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): +def _value_and_grad_dense( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set +): return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))( - x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set ) @@ -98,11 +108,16 @@ def run_dense_grad_tests(args, mesh=None): ) collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) + use_quantization = args.quantize_recipe is not None + recipe = get_quantization_recipe(args.quantize_recipe) if use_quantization else None with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_quantization, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_set inside autocast so create_set() reads the global recipe + # for correct fwd/bwd dtypes. + quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) @@ -123,6 +138,7 @@ def run_dense_grad_tests(args, mesh=None): weight_axes, output_axes, noop_collective_op_set, + quantizer_set, ) output, sharded_grads = _value_and_grad_dense( x_sharded, @@ -132,6 +148,7 @@ def run_dense_grad_tests(args, mesh=None): weight_axes, output_axes, collective_op_set, + quantizer_set, ) jax.block_until_ready(ref_output) jax.block_until_ready(output) @@ -148,9 +165,10 @@ def run_dense_grad_tests(args, mesh=None): jax.block_until_ready(gathered_ref_grads) if args.enable_result_check and args.process_id == 0: - assert_allclose(ref_output, output, dtype=jnp.bfloat16) + tol_dtype = get_tolerance_dtype(quantizer_set) + assert_allclose(ref_output, output, dtype=tol_dtype) for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): - assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + assert_allclose(ref_grad, gathered_grad, dtype=tol_dtype) class TestCollectiveDenseGradient(unittest.TestCase): @@ -187,6 +205,82 @@ def test_te_bf16_reduce_scatter(self): self.args.collective_type = "reduce_scatter" run_dense_grad_tests(self.args, self.mesh) + def test_te_delayed_scaling_fp8_all_gather(self): + """Test Collective Dense Gradient with FP8 DelayedScaling + AllGather""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_reduce_scatter(self): + """Test Collective Dense Gradient with FP8 DelayedScaling + ReduceScatter""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_all_gather(self): + """Test Collective Dense Gradient with FP8 Float8CurrentScaling + AllGather""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_reduce_scatter(self): + """Test Collective Dense Gradient with FP8 Float8CurrentScaling + ReduceScatter""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_mxfp8_all_gather(self): + """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_mxfp8_reduce_scatter(self): + """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + # def test_te_nvfp4_all_gather(self): + # """Test Collective Dense Gradient with NVFP4BlockScaling + AllGather""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + # if not is_supported: + # self.skipTest(reason) + # self.args.collective_type = "all_gather" + # run_dense_grad_tests(self.args, self.mesh) + + # def test_te_nvfp4_reduce_scatter(self): + # """Test Collective Dense Gradient with NVFP4BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + # if not is_supported: + # self.skipTest(reason) + # self.args.collective_type = "reduce_scatter" + # run_dense_grad_tests(self.args, self.mesh) + if __name__ == "__main__": import sys @@ -209,6 +303,6 @@ def test_te_bf16_reduce_scatter(self): args = cgemm_parser( "Collective Dense Gradient test on multi-GPU with tensor parallelism" - ).parse_args([]) + ).parse_args() _initialize_distributed(args) run_dense_grad_tests(args, mesh=None) diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index d2994723b..c2db8fc44 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -22,17 +22,23 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, DP_AXIS, TPSP_AXIS, - PARAMS_KEY, cgemm_parser, ) import transformer_engine.jax.cpp_extensions as tex -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_quantize_recipe_supported, + get_quantization_recipe, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp from transformer_engine.jax.sharding import MeshResource @@ -54,31 +60,15 @@ def _get_operand_sharding(mesh, collective_op, is_with_dp): return x_sharding, weight_sharding, bias_sharding, output_sharding -def _get_dp_and_tp_sizes(args): - num_gpu = args.num_processes * args.num_devices_per_process - if args.tensor_parallel_size is None: - num_gpu_dp = 2 if args.enable_data_parallel else 1 - assert ( - num_gpu > 1 and num_gpu % num_gpu_dp == 0 - ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" - num_gpu_tp = num_gpu // num_gpu_dp - else: - num_gpu_tp = args.tensor_parallel_size - assert ( - num_gpu > 1 and num_gpu % num_gpu_tp == 0 - ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" - num_gpu_dp = num_gpu // num_gpu_tp - return num_gpu_dp, num_gpu_tp - - @partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding")) -def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding): +def _jitted_cgemm(x, weight, bias, quantizer_set, contracting_dims, collective_op, output_sharding): output = tex.gemm( x, weight, bias=bias, contracting_dims=contracting_dims, collective_op=collective_op, + quantizer_set=quantizer_set, ) if output_sharding is not None: output = jax.lax.with_sharding_constraint(output, output_sharding) @@ -88,8 +78,6 @@ def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_shard def run_gemm_tests(args, mesh=None): """Execute GEMM tests.""" print(args) - # Collective GEMM requires Shardy partitioner to be disabled - jax.config.update("jax_use_shardy_partitioner", False) # Initialize distributed with provided arguments _initialize_distributed(args) @@ -109,11 +97,20 @@ def run_gemm_tests(args, mesh=None): else CollectiveOp.REDUCE_SCATTER ) + use_quantization = args.quantize_recipe is not None + recipe = get_quantization_recipe(args.quantize_recipe) if use_quantization else None + + # autocast sets the global recipe (fwd/bwd dtypes) AND the global MeshResource + # (via global_shard_guard) required for collective GEMM sharding axis resolution. with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_quantization, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_set inside autocast so create_set() can read the global recipe + # for correct fwd/bwd dtypes. autocast does not inject quantizers into raw + # tex.gemm() calls, so we must pass quantizer_set explicitly. + quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set print(f"Device mesh: {mesh}") x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding( @@ -127,6 +124,7 @@ def run_gemm_tests(args, mesh=None): x_sharded, weight_sharded, bias_sharded, + quantizer_set, contracting_dims=((2,), (0,)), collective_op=CollectiveOp.NONE, output_sharding=output_sharding, @@ -135,10 +133,10 @@ def run_gemm_tests(args, mesh=None): x_sharded, weight_sharded, bias_sharded, + quantizer_set, contracting_dims=((2,), (0,)), collective_op=collective_op, - # CollectiveGEMM output should have a correct sharding without applying sharding constraint - output_sharding=None, + output_sharding=output_sharding, ) assert ( ref_output.sharding == output.sharding @@ -153,7 +151,9 @@ def run_gemm_tests(args, mesh=None): jax.block_until_ready(gathered_output) if args.enable_result_check and args.process_id == 0: - assert_allclose(gathered_ref_output, gathered_output) + assert_allclose( + gathered_ref_output, gathered_output, dtype=get_tolerance_dtype(quantizer_set) + ) class TestCollectiveGemmWithDP(unittest.TestCase): @@ -189,6 +189,84 @@ def test_te_bf16_reduce_scatter_with_dp(self): self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) + def test_te_delayed_scaling_fp8_all_gather_with_dp(self): + """Test Collective GEMM with FP8 DelayedScaling + AllGather""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with FP8 DelayedScaling + ReduceScatter""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_all_gather_with_dp(self): + """Test Collective GEMM with FP8 Float8CurrentScaling + AllGather""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with FP8 Float8CurrentScaling + ReduceScatter""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + def test_te_mxfp8_all_gather_with_dp(self): + """Test Collective GEMM with MXFP8BlockScaling + AllGather""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_mxfp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with MXFP8BlockScaling + ReduceScatter""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + # def test_te_nvfp4_all_gather_with_dp(self): + # """Test Collective GEMM with NVFP4BlockScaling + AllGather""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + # if not is_supported: + # self.skipTest(reason) + # self.args.collective_type = "all_gather" + # run_gemm_tests(self.args, self.mesh) + + # def test_te_nvfp4_reduce_scatter_with_dp(self): + # """Test Collective GEMM with NVFP4BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + # if not is_supported: + # self.skipTest(reason) + # self.args.collective_type = "reduce_scatter" + # run_gemm_tests(self.args, self.mesh) + if __name__ == "__main__": import sys diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 61c960a7a..be94c68d3 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -2,7 +2,6 @@ # # See LICENSE for license information. """Collective Dense Gradient test on multi-GPU with tensor parallelism""" -import argparse import unittest import os @@ -13,18 +12,24 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, DP_AXIS, TPSP_AXIS, - PARAMS_KEY, cgemm_parser, ) from transformer_engine.jax.layernorm_mlp import layernorm_mlp -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_quantize_recipe_supported, + get_quantization_recipe, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import ( CollectiveOpSet, CollectiveOp, @@ -68,6 +73,7 @@ def _mean_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ): output = layernorm_mlp( x, @@ -82,6 +88,7 @@ def _mean_layernorm_mlp( kernel_2_axes=weight_2_axes, activation_type=("gelu",), collective_op_sets=collective_op_sets, + quantizer_sets=quantizer_sets, ) return jnp.mean(output) @@ -98,6 +105,7 @@ def _value_and_grad_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ): return jax.jit( jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10) @@ -113,14 +121,13 @@ def _value_and_grad_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ) def run_layernorm_mlp_grad_tests(args, mesh=None): - """Execute Dense Gradient tests.""" + """Execute LayerNorm MLP Gradient tests.""" print(args) - # Collective GEMM requires Shardy partitioner to be disabled - jax.config.update("jax_use_shardy_partitioner", False) # Initialize distributed with provided arguments _initialize_distributed(args) @@ -151,11 +158,21 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): collective_op_sets = (collective_op_set_1, collective_op_set_2) noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) + use_quantization = args.quantize_recipe is not None + recipe = get_quantization_recipe(args.quantize_recipe) if use_quantization else None with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_quantization, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_sets inside autocast so create_set() reads the global recipe + # for correct fwd/bwd dtypes. One set per dense layer (GEMM1=AG, GEMM2=RS). + quantizer_sets = ( + QuantizerFactory.create_set(n_quantizer_sets=2) + if use_quantization + else (noop_quantizer_set, noop_quantizer_set) + ) + # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) @@ -183,6 +200,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): weight_1_axes, weight_2_axes, noop_collective_op_sets, + quantizer_sets, ) output, sharded_grads = _value_and_grad_layernorm_mlp( x_sharded, @@ -196,6 +214,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ) jax.block_until_ready(ref_output) jax.block_until_ready(output) @@ -212,13 +231,14 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): jax.block_until_ready(gathered_ref_grads) if args.enable_result_check and args.process_id == 0: - assert_allclose(ref_output, output, dtype=jnp.bfloat16) + tol_dtype = get_tolerance_dtype(quantizer_sets[0]) + assert_allclose(ref_output, output, dtype=tol_dtype) for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): - assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + assert_allclose(ref_grad, gathered_grad, dtype=tol_dtype) class TestCollectiveLayerNormMLPGradient(unittest.TestCase): - """Collective Dense Gradient unittests""" + """Collective LayerNorm MLP Gradient unittests""" def setUp(self): self.args = cgemm_parser( @@ -242,9 +262,43 @@ def tearDown(self): os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) def test_te_bf16_layernorm_mlp_grad(self): - """Test Collective Dense Gradient with AllGather""" + """Test Collective LayerNorm MLP Gradient with BF16""" + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with FP8 DelayedScaling""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + run_layernorm_mlp_grad_tests(self.args, self.mesh) + def test_te_current_scaling_fp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with FP8 Float8CurrentScaling""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + def test_te_mxfp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with MXFP8BlockScaling""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + if not is_supported: + self.skipTest(reason) + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + # def test_te_nvfp4_layernorm_mlp_grad(self): + # """Test Collective LayerNorm MLP Gradient with NVFP4BlockScaling""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_quantize_recipe_supported(self.args.quantize_recipe) + # if not is_supported: + # self.skipTest(reason) + # run_layernorm_mlp_grad_tests(self.args, self.mesh) + if __name__ == "__main__": import sys @@ -267,6 +321,6 @@ def test_te_bf16_layernorm_mlp_grad(self): args = cgemm_parser( "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism" - ).parse_args([]) + ).parse_args() _initialize_distributed(args) run_layernorm_mlp_grad_tests(args, mesh=None) diff --git a/examples/jax/datasets.txt b/examples/jax/datasets.txt new file mode 100644 index 000000000..fd3f5bc41 --- /dev/null +++ b/examples/jax/datasets.txt @@ -0,0 +1,3 @@ +# Datasets used by TE encoder tests. Pull these to pre-emptively cache datasets +ylecun/mnist +nyu-mll/glue \ No newline at end of file diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index f2ef33da4..3c1f2ba1f 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -11,10 +11,6 @@ TEST_CASES=( "test_te_current_scaling_fp8" "test_te_mxfp8" "test_te_nvfp4" -"test_te_bf16_shardy" -"test_te_delayed_scaling_fp8_shardy" -"test_te_current_scaling_fp8_shardy" -"test_te_nvfp4_shardy" ) : ${TE_PATH:=/opt/transformerengine} diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 61b74b35a..a672b862b 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -241,7 +241,6 @@ def check_fp8(state, var_collect, inputs, masks, labels): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) - jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) @@ -476,9 +475,6 @@ def encoder_parser(args): parser.add_argument( "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism." ) - parser.add_argument( - "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." - ) return parser.parse_args(args) @@ -561,70 +557,6 @@ def test_te_nvfp4_with_sp(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.40 and actual[1] > 0.82 - @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - def test_te_bf16_shardy(self): - """Test Transformer Engine with BF16""" - self.args.enable_shardy = True - actual = train_and_evaluate(self.args) - assert actual[0] < 0.36 and actual[1] > 0.84 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_delayed_scaling_fp8_shardy(self): - """Test Transformer Engine with DelayedScaling FP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "DelayedScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.362 and actual[1] > 0.84 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_delayed_scaling_fp8_with_sp_shardy(self): - """Test Transformer Engine with DelayedScaling FP8 + SP""" - self.args.enable_shardy = True - self.args.enable_sp = True - self.args.use_fp8 = True - self.args.fp8_recipe = "DelayedScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.362 and actual[1] > 0.84 - - @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - def test_te_mxfp8_shardy(self): - """Test Transformer Engine with MXFP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "MXFP8BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.36 and actual[1] > 0.84 - - @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) - def test_te_nvfp4_shardy(self): - """Test Transformer Engine with NVFP4""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "NVFP4BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.40 and actual[1] > 0.82 - - @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - def test_te_mxfp8_with_sp_shardy(self): - """Test Transformer Engine with MXFP8 + SP""" - self.args.enable_shardy = True - self.args.enable_sp = True - self.args.use_fp8 = True - self.args.fp8_recipe = "MXFP8BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.36 and actual[1] > 0.84 - - @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) - def test_te_nvfp4_with_sp_shardy(self): - """Test Transformer Engine with NVFP4""" - self.args.enable_shardy = True - self.args.enable_sp = True - self.args.use_fp8 = True - self.args.fp8_recipe = "NVFP4BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.40 and actual[1] > 0.82 - if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 51fb20ebc..5b9be7b73 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -251,7 +251,6 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) - jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) num_gpu = jax.local_device_count() @@ -440,9 +439,6 @@ def encoder_parser(args): default="DelayedScaling", help="Use FP8 recipe (default: DelayedScaling)", ) - parser.add_argument( - "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." - ) return parser.parse_args(args) @@ -496,49 +492,6 @@ def test_te_nvfp4(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.52 and actual[1] > 0.74 - @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - def test_te_bf16_shardy(self): - """Test Transformer Engine with BF16""" - self.args.enable_shardy = True - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.75 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_delayed_scaling_fp8_shardy(self): - """Test Transformer Engine with DelayedScaling FP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "DelayedScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.75 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_current_scaling_fp8_shardy(self): - """Test Transformer Engine with CurrentScaling FP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "Float8CurrentScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.749 - - @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - def test_te_mxfp8_shardy(self): - """Test Transformer Engine with MXFP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "MXFP8BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.75 - - @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) - def test_te_nvfp4_shardy(self): - """Test Transformer Engine with NVFP4""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "NVFP4BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 - if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 66993f290..873febaee 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -361,7 +361,6 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) - jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) if args.process_id == 0: nltk.download("punkt_tab") @@ -607,9 +606,6 @@ def encoder_parser(args): default=0, help="the ID number of the current process (default: 0)", ) - parser.add_argument( - "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." - ) return parser.parse_args(args) @@ -618,7 +614,7 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): + def exec(self, use_fp8, fp8_recipe): """Run 5 epochs for testing""" args = encoder_parser(["--epochs", "5"]) @@ -634,7 +630,6 @@ def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): args.num_process = num_gpu args.process_id = self.process_id args.fp8_recipe = fp8_recipe - args.enable_shardy = enable_shardy return train_and_evaluate(args) @@ -676,44 +671,6 @@ def test_te_nvfp4(self): result = self.exec(True, "NVFP4BlockScaling") assert result[0] < 0.451 and result[1] > 0.787 - @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - def test_te_bf16_shardy(self): - """Test Transformer Engine with BF16""" - result = self.exec(False, None, enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 - - @unittest.skipIf( - not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" - ) - def test_te_delayed_scaling_fp8_shardy(self): - """Test Transformer Engine with DelayedScaling FP8""" - result = self.exec(True, "DelayedScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 - - @unittest.skipIf( - not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" - ) - def test_te_current_scaling_fp8_shardy(self): - """Test Transformer Engine with CurrentScaling FP8""" - result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) - assert result[0] < 0.432 and result[1] > 0.80 - - @unittest.skipIf( - not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" - ) - def test_te_mxfp8_shardy(self): - """Test Transformer Engine with MXFP8""" - result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 - - @unittest.skipIf( - not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4" - ) - def test_te_nvfp4_shardy(self): - """Test Transformer Engine with NVFP4""" - result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True) - assert result[0] < 0.451 and result[1] > 0.787 - if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index b469ef56b..ac7a2fac7 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -18,7 +18,12 @@ ) import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.common.recipe import ( + Format, + DelayedScaling, + MXFP8BlockScaling, + NVFP4BlockScaling, +) from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) @@ -64,10 +69,21 @@ def torch_dtype(d): "bfloat16": torch.bfloat16, } if lowercase(d) not in typemap.keys(): - raise TypeError + raise argparse.ArgumentTypeError( + f"invalid dtype '{d}'. Supported values: fp32/float32, fp16/float16, bf16/bfloat16" + ) return typemap[lowercase(d)] +def precision(d): + typemap = ["fp32", "fp16", "fp8", "mxfp8", "nvfp4"] + if lowercase(d) not in typemap: + raise argparse.ArgumentTypeError( + f"invalid precision '{d}'. Supported values: {', '.join(typemap)}" + ) + return lowercase(d) + + te_layer_map = { "linear": te.Linear, "layernorm": te.LayerNorm, @@ -91,7 +107,6 @@ def get_layer_args(opts): hidden_size = opts.num_heads * opts.head_dim layer_args = (hidden_size,) layer_kwargs = { - "params_dtype": opts.dtype, "device": "cuda" if opts.no_defer_init else "meta", "get_rng_state_tracker": get_cuda_rng_tracker, } @@ -112,6 +127,15 @@ def get_layer_args(opts): return layer_args, layer_kwargs +class StoreExplicitAction(argparse.Action): + """Custom action that tracks whether an argument was explicitly set.""" + + def __call__(self, parser, namespace, values, option_string=None): + # values already converted by argparse via action.type + setattr(namespace, self.dest, values) + setattr(namespace, f"{self.dest}_explicitly_set", True) + + def parse_fsdp_args(): parser = argparse.ArgumentParser( description="Run Transformer Engine modules with the " @@ -173,7 +197,10 @@ def parse_fsdp_args(): "--no-fp8", action="store_true", default=False, - help="Disables the te.autocast() context.", + help=( + "Disable te.autocast() FP8 context. Incompatible with --precision fp8/mxfp8/nvfp4." + " Default: False." + ), ) parser.add_argument( "--no-defer-init", @@ -189,7 +216,21 @@ def parse_fsdp_args(): "--dtype", type=torch_dtype, default=torch.bfloat16, - help="Data type for input tensor and Transformer Engine module parameters.", + action=StoreExplicitAction, + help=( + "Parameter dtype: fp32/float32, fp16/float16, bf16/bfloat16. Overrides --precision" + " dtype when explicitly set. Default: bfloat16." + ), + ) + parser.add_argument( + "--precision", + type=precision, + default=None, + help=( + "Precision preset: fp32, fp16, fp8, mxfp8, nvfp4. Configures dtype and FP8 recipe" + " automatically. Overridden by explicit --dtype. Default: None (use --dtype and" + " --no-fp8 directly)." + ), ) return parser.parse_args() @@ -200,15 +241,118 @@ def dist_print(text, all_ranks=False, no_new_line=False): print(f"[GPU-{LOCAL_RANK}] " + text, end=end) +def get_precision_preset(precision_value): + """Get dtype, no_fp8, and recipe based on precision preset. + + Returns: + tuple: (dtype, no_fp8, recipe) + """ + match precision_value: + case "fp32": + return torch.float32, True, None + case "fp16": + return torch.float16, True, None + case "fp8": + recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + return torch.bfloat16, False, recipe + case "mxfp8": + recipe = MXFP8BlockScaling() + return torch.bfloat16, False, recipe + case "nvfp4": + recipe = NVFP4BlockScaling() + return torch.bfloat16, False, recipe + case _: + raise ValueError( + f"Invalid precision preset: {precision_value}. " + "Supported values: fp32, fp16, fp8, mxfp8, nvfp4" + ) + + def train(opts): + # Check which flags were explicitly set + dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False) + + # Validate flag combinations before touching distributed state. + # Error if user requests FP8-based precision but also sets --no-fp8 + # Safe to raise here because torchrun guarantees all ranks receive + # identical CLI arguments; all ranks will raise simultaneously. + if opts.precision in ["fp8", "mxfp8", "nvfp4"] and opts.no_fp8: + raise ValueError( + f"Cannot use --no-fp8 with --precision {opts.precision}. " + "These flags are incompatible. " + f"Either remove --no-fp8 to use {opts.precision} training, " + "or use --precision fp32/fp16 for non-FP8 training." + ) + if opts.precision in ["fp32", "fp16"] and opts.no_fp8: + dist_print( + f"Warning: --no-fp8 is redundant when using --precision {opts.precision} " + "(FP8 is already disabled by this preset). The flag will be ignored." + ) + # Initialize torch.distributed global process group dist.init_process_group(backend="nccl") torch.cuda.set_device(LOCAL_RANK) dist_print(f"WORLD_SIZE = {WORLD_SIZE}") torch.manual_seed(opts.seed) + preset_dtype: torch.dtype = opts.dtype # sensible fallback + preset_recipe = None + + if opts.precision is not None: + preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision) + dtype, no_fp8, recipe = preset_dtype, preset_no_fp8, preset_recipe + dist_print(f"Using precision preset: {opts.precision}") + else: + # Original behavior: --dtype and --no-fp8 control training directly + dtype = opts.dtype + no_fp8 = opts.no_fp8 + recipe = ( + DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max") + if not no_fp8 + else None + ) + + dtype_name = str(dtype).replace("torch.", "") + + # Apply explicit dtype override with warning + if dtype_explicitly_set and opts.precision is not None: + new_dtype = opts.dtype + if new_dtype != preset_dtype: + if opts.precision in ["fp8", "mxfp8", "nvfp4"] and new_dtype == torch.float16: + dist_print( + "Warning: --dtype float16 may be incompatible with --precision" + f" {opts.precision}, which expects bfloat16 accumulation." + ) + + dtype = new_dtype + dtype_name = str(dtype).replace("torch.", "") + + dist_print( + f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype" + " setting" + ) + else: + new_dtype_name = str(new_dtype).replace("torch.", "") + dist_print( + f"Info: --dtype {new_dtype_name} matches --precision {opts.precision} preset" + " default, no override needed" + ) + + # recipe is already set correctly from preset_recipe above; + # dtype only affects parameter storage, not the quantization recipe + + # Always log the final configuration being used + dist_print( + f"Training configuration: dtype={dtype_name}, " + f"quantization={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}" + ) + # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM layer_args, layer_kwargs = get_layer_args(opts) + layer_kwargs["params_dtype"] = dtype + if opts.num_layers > 1: te_layer_list = [] for i in range(opts.num_layers): @@ -239,7 +383,7 @@ def train(opts): process_group=all_gpus, use_orig_params=True, mixed_precision=MixedPrecision( - param_dtype=opts.dtype, + param_dtype=dtype, reduce_dtype=torch.float32, ), auto_wrap_policy=fsdp_wrap_policy, @@ -258,10 +402,6 @@ def train(opts): dist_print(f"Post-FSDP memory use = {post_mem_use}MiB") dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}") - # Fp8 setup for TE - fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") - # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded optim = torch.optim.Adam(te_model.parameters(), lr=0.0001) @@ -275,17 +415,33 @@ def train(opts): torch.cuda.synchronize() start.record() + # MXFP8 and NVFP4 use local block scaling — no distributed amax reduction group needed. + # amax_reduction_group is only required for DelayedScaling (global AMAX allreduce). + # Also skip when FP8 is disabled to avoid unnecessary distributed communication. + # Compute amax_group BEFORE the recipe fallback so isinstance() reflects the actual + # recipe, not the defensive DelayedScaling() substituted for None. + amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None + + # Ensure recipe is always a concrete object before passing to te.autocast. + # When FP8 is disabled, te.autocast ignores the recipe, but some TE versions + # perform attribute access on it regardless of the enabled flag. + if recipe is None: + recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + for i in range(opts.num_iters): # Generate a random input batch x = torch.rand( opts.seq_length, opts.batch_size, opts.num_heads * opts.head_dim, - dtype=opts.dtype, + dtype=dtype, device="cuda", ) + # autocast needs to be given the FSDP process group for amax reductions - with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus): + with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=amax_group): y = te_model(x) loss = y.sum() # calculate gradient and take training step outside the autocast context diff --git a/examples/pytorch/quantized_model_init/fully_shard.py b/examples/pytorch/quantized_model_init/fully_shard.py new file mode 100644 index 000000000..613171200 --- /dev/null +++ b/examples/pytorch/quantized_model_init/fully_shard.py @@ -0,0 +1,266 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""FSDP2 distributed training with quantized model initialization. + +Extends the single-GPU ``main.py`` example to multi-GPU training using +PyTorch-native FSDP2 (``fully_shard``). The script demonstrates: + +1. **Meta-device initialization** -- Model parameters are created on the + ``meta`` device (zero memory), then FSDP2 sharding is applied, and + finally ``reset_parameters()`` materializes and quantizes only the + local shards on each rank's GPU. +2. ``quantized_model_init`` -- Flags the model for FP8 weight initialization + (actual quantization happens in ``reset_parameters`` after sharding). +3. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer. +4. ``FusedAdam`` with FP32 master weights for full-precision training updates. + +.. note:: + ``fuse_wgrad_accumulation`` is **not** used here. That feature writes + weight gradients directly into ``main_grad`` buffers, bypassing the + autograd gradient flow. FSDP2 requires gradients to go through its + reduce-scatter, so ``fuse_wgrad_accumulation`` needs Megatron-Core's + FSDP integration (which provides ``get_main_grad()``). + +Usage:: + + torchrun --nproc-per-node 2 fully_shard.py +""" + +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor + +import transformer_engine.pytorch as te +from transformer_engine.pytorch import QuantizedTensor +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule + +# ── Configuration (matches main.py) ────────────────────────────────── +HIDDEN_SIZE = 256 +FFN_HIDDEN_SIZE = 1024 +NUM_ATTENTION_HEADS = 8 +NUM_LAYERS = 3 +SEQ_LEN = 32 +BATCH_PER_RANK = 2 +NUM_STEPS = 5 +DTYPE = torch.bfloat16 + + +def dist_print(msg): + """Print only on rank 0.""" + if int(os.environ.get("RANK", "0")) == 0: + print(msg) + + +def main(): + # ── 1. Distributed setup ───────────────────────────────────────── + assert "TORCHELASTIC_RUN_ID" in os.environ, ( + "This script must be launched with torchrun, e.g.:\n" + " torchrun --nproc-per-node 2 fully_shard.py" + ) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl") + device = torch.device(f"cuda:{local_rank}") + + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # ── 2. Create model on meta device (zero memory) ──────────────── + # quantized_model_init sets the flag for FP8 weight initialization, + # but with device="meta" no actual memory is allocated yet. + with te.quantized_model_init(enabled=True): + model = torch.nn.Sequential( + *[ + te.TransformerLayer( + HIDDEN_SIZE, + FFN_HIDDEN_SIZE, + NUM_ATTENTION_HEADS, + fuse_qkv_params=True, + params_dtype=DTYPE, + hidden_dropout=0.0, + attention_dropout=0.0, + device="meta", + ) + for _ in range(NUM_LAYERS) + ] + ) + + # Verify all parameters are on meta device (no GPU memory used). + for name, param in model.named_parameters(): + assert param.device == torch.device("meta"), f"{name} is not on meta device" + dist_print("Model created on meta device (zero GPU memory).") + + # ── 3. FSDP2 sharding ──────────────────────────────────────────── + # Apply sharding to the meta-device model. FSDP2 wraps parameters + # as DTensors but no GPU memory is allocated yet. + mesh = DeviceMesh("cuda", list(range(world_size))) + for child in model.children(): + fully_shard(child, mesh=mesh) + fully_shard(model, mesh=mesh) + dist_print("FSDP2 sharding applied to meta-device model.") + + # ── 4. Materialize parameters on GPU ────────────────────────────── + # reset_parameters() on each TE module materializes the local shard + # on CUDA, applies weight initialization, and quantizes to FP8. + for module in model.modules(): + if isinstance(module, TransformerEngineBaseModule): + module.reset_parameters() + + # Post-materialization verification. + for name, param in model.named_parameters(): + assert isinstance(param, DTensor), f"{name} is not a DTensor after sharding" + qt_count = sum( + 1 + for _, p in model.named_parameters() + if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor) + ) + assert qt_count > 0, "No QuantizedTensor local tensors after materialization" + dist_print( + f"Parameters materialized: {qt_count} FP8 (QuantizedTensor) weight params " + "wrapped in DTensors." + ) + + # ── 5. Optimizer ───────────────────────────────────────────────── + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + dist_print("Using FusedAdam with master_weights=True.") + + # ── 6. Training loop ───────────────────────────────────────────── + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device) + target = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device) + + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + + with te.autocast(enabled=True): + output = model(x) + + loss = F.mse_loss(output, target) + loss.backward() + optimizer.step() + dist_print(f" Step {step}: loss = {loss.item():.6f}") + + # ── 7. Post-training assertions ────────────────────────────────── + dist_print("\nVerifying invariants ...") + + qt_after = 0 + for name, param in model.named_parameters(): + assert isinstance(param, DTensor), f"{name} lost DTensor wrapping" + if isinstance(param._local_tensor, QuantizedTensor): + qt_after += 1 + assert qt_after > 0, "No QuantizedTensor local tensors after training" + dist_print(f" {qt_after} params still have QuantizedTensor local tensors.") + + # Optimizer states: master weights and moments should be float32. + for param in model.parameters(): + state = optimizer.state[param] + if "master_param" in state: + assert ( + state["master_param"].dtype == torch.float32 + ), f"Master weight dtype {state['master_param'].dtype}, expected float32" + assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32" + assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32" + + dist_print("All assertions passed!") + dist_print(" - Linear weight parameters: QuantizedTensor (FP8) wrapped in DTensor") + dist_print(" - Optimizer master weights: float32") + dist_print(" - Optimizer states (exp_avg, exp_avg_sq): float32") + + # ── 8. Distributed checkpoint: save and load ───────────────────── + # torch.distributed.checkpoint (DCP) saves sharded state — each rank + # writes only its local shard. This preserves FP8 compute weights + # and the full optimizer state (master weights, moments, step count). + import torch.distributed.checkpoint as dcp + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_optimizer_state_dict, + ) + + # Use a fixed path so all ranks agree on the checkpoint location. + checkpoint_dir = "/tmp/te_fsdp2_example_checkpoint" + dist_print(f"\nSaving distributed checkpoint to {checkpoint_dir} ...") + + # Save sharded checkpoint. DCP handles DTensor shards natively — + # each rank writes only its local shard to the filesystem. + dcp.save( + {"model": model.state_dict(), "optimizer": optimizer.state_dict()}, + checkpoint_id=checkpoint_dir, + ) + dist_print(" Checkpoint saved (FP8 weights + optimizer state).") + + # Load checkpoint back. Provide empty state dict containers with the + # same structure; DCP fills them from the saved files. + state_to_load = {"model": model.state_dict(), "optimizer": optimizer.state_dict()} + dcp.load(state_to_load, checkpoint_id=checkpoint_dir) + model.load_state_dict(state_to_load["model"]) + optimizer.load_state_dict(state_to_load["optimizer"]) + dist_print(" Checkpoint loaded — FP8 weights and optimizer state restored.") + + # Verify training continues after checkpoint load. + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True): + output = model(x) + loss = F.mse_loss(output, target) + loss.backward() + optimizer.step() + dist_print(f" Post-checkpoint training step: loss = {loss.item():.6f}") + + # ── 9. Save full-precision (FP32) model to safetensors ─────────── + # For inference or fine-tuning you typically want FP32 weights, not + # FP8 compute weights. The optimizer's master weight copies are the + # authoritative FP32 values (more precise than dequantizing FP8). + # All ranks must participate in gathering; only rank 0 saves. + from safetensors.torch import save_file + + full_opts = StateDictOptions(full_state_dict=True, cpu_offload=True) + + full_model_state = get_model_state_dict(model, options=full_opts) + full_opt_state = get_optimizer_state_dict(model, optimizer, options=full_opts) + + rank = int(os.environ.get("RANK", "0")) + if rank == 0: + fp32_state = {} + opt_param_states = full_opt_state.get("state", {}) + + for key, value in full_model_state.items(): + if key in opt_param_states and "master_param" in opt_param_states[key]: + # Prefer optimizer's FP32 master weight (maintained throughout training). + fp32_state[key] = opt_param_states[key]["master_param"].float() + elif isinstance(value, QuantizedTensor): + # Fallback: dequantize FP8 → FP32 (e.g. if master_weights was off). + fp32_state[key] = value.dequantize().float() + else: + # Non-FP8 params (e.g. LayerNorm weights): cast to FP32. + fp32_state[key] = value.float() + + save_path = "/tmp/te_fsdp2_example_model_fp32.safetensors" + save_file(fp32_state, save_path) + dist_print(f"\nSaved FP32 model ({len(fp32_state)} params) to {save_path}") + + # Quick verification: all saved tensors are float32. + from safetensors.torch import load_file + + loaded = load_file(save_path) + for k, v in loaded.items(): + assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}" + dist_print(f" Verified: all {len(loaded)} tensors are float32.") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/quantized_model_init/main.py b/examples/pytorch/quantized_model_init/main.py new file mode 100644 index 000000000..a9d3480ca --- /dev/null +++ b/examples/pytorch/quantized_model_init/main.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Quantized model initialization with FusedAdam and gradient accumulation fusion. + +Demonstrates three Transformer Engine features working together: + +1. ``quantized_model_init`` -- Initialize a model with low-precision (FP8) + parameters, avoiding the memory cost of storing both high-precision and + quantized copies of every weight. + +2. ``FusedAdam`` with master weights -- Maintain FP32 master copies of the + weights inside the optimizer so that the training update retains full + precision despite the model parameters being FP8. + +3. Gradient accumulation fusion -- Use ``fuse_wgrad_accumulation=True`` + together with per-parameter ``main_grad`` buffers so that weight + gradients are accumulated directly in FP32 via Tensor Cores, avoiding a + separate FP8-to-FP32 cast kernel. + +Usage:: + + python main.py +""" + +import torch +import transformer_engine.pytorch as te +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor + +# ── Configuration ────────────────────────────────────────────────────── +HIDDEN_SIZE = 256 +FFN_HIDDEN_SIZE = 1024 +NUM_ATTENTION_HEADS = 8 +SEQ_LEN = 32 +BATCH_SIZE = 2 +NUM_STEPS = 5 +DTYPE = torch.bfloat16 + + +def main(): + # ── 1. Create model with quantized parameters ───────────────────── + # + # Inside quantized_model_init, TransformerEngine modules store only the + # FP8 quantized copy of each parameter (a Float8Tensor), eliminating the + # memory overhead of a high-precision shadow copy. + with te.quantized_model_init(enabled=True): + model = te.TransformerLayer( + HIDDEN_SIZE, + FFN_HIDDEN_SIZE, + NUM_ATTENTION_HEADS, + fuse_wgrad_accumulation=True, + fuse_qkv_params=True, # required for fuse_wgrad_accumulation + params_dtype=DTYPE, + hidden_dropout=0.0, # disable dropout for this synthetic example + attention_dropout=0.0, + ) + + # Verify that linear-layer weight parameters are quantized. + # Biases and LayerNorm parameters are *not* quantized. + quantized_count = 0 + for name, param in model.named_parameters(): + if isinstance(param, QuantizedTensor): + quantized_count += 1 + assert quantized_count > 0, "No QuantizedTensor parameters found" + print(f"Found {quantized_count} QuantizedTensor (FP8) weight parameters.") + + # ── 2. Allocate main_grad buffers (FP32) ────────────────────────── + # + # fuse_wgrad_accumulation causes weight-gradient GEMMs to write directly + # into ``param.main_grad`` in FP32 (via Tensor Core accumulation). + # Non-weight parameters (e.g. LayerNorm) still receive gradients through + # the normal ``param.grad`` path. + for param in model.parameters(): + param.main_grad = torch.zeros(param.shape, dtype=torch.float32, device=param.device) + + # ── 3. Optimizer with FP32 master weights ───────────────────────── + # + # use_decoupled_grad=True tells FusedAdam to read gradients from + # ``param.decoupled_grad`` instead of ``param.grad``. This avoids + # the dtype-mismatch error that would occur when assigning FP32 + # gradients to bfloat16 parameters via ``.grad``. + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + use_decoupled_grad=True, + ) + + # ── 4. Training loop ────────────────────────────────────────────── + # + # Use a fixed synthetic dataset so that loss decreases over steps. + x = torch.randn(SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE, device="cuda") + target = torch.randn(SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE, device="cuda") + + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + for param in model.parameters(): + param.main_grad.zero_() + + # Forward pass inside autocast to enable FP8 compute. + with te.autocast(enabled=True): + output = model(x) + + loss = torch.nn.functional.mse_loss(output, target) + loss.backward() + + # Consolidate gradients into main_grad. + # * Weight params with fuse_wgrad_accumulation: backward already + # accumulated the gradient directly into main_grad (FP32). + # * Other params (e.g. LayerNorm): autograd set param.grad. + for param in model.parameters(): + if param.grad is not None: + param.main_grad.copy_(param.grad) + param.grad = None + + # Expose main_grad as decoupled_grad so FusedAdam can read it. + for param in model.parameters(): + param.decoupled_grad = param.main_grad + + optimizer.step() + print(f" Step {step}: loss = {loss.item():.6f}") + + # ── 5. Post-training assertions ─────────────────────────────────── + print("\nVerifying invariants ...") + + # Optimizer states. + for param in model.parameters(): + state = optimizer.state[param] + if "master_param" in state: + master = state["master_param"] + assert ( + master.dtype == torch.float32 + ), f"Master weight dtype {master.dtype}, expected float32" + assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32" + assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32" + + # main_grad buffers. + for param in model.parameters(): + assert param.main_grad.dtype == torch.float32, "main_grad should be float32" + + print("All assertions passed!") + print(" - Linear weight parameters: QuantizedTensor (FP8)") + print(" - Optimizer master weights: float32") + print(" - Optimizer states (exp_avg, exp_avg_sq): float32") + print(" - Gradient accumulation buffers (main_grad): float32") + + +if __name__ == "__main__": + main() diff --git a/qa/L0_jax_lint/test.sh b/qa/L0_jax_lint/test.sh old mode 100644 new mode 100755 diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index ee9ce130a..3453e35d2 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" mkdir -p "$XML_LOG_DIR" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 21eed2836..f2b0b07fe 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -32,10 +32,12 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" @@ -45,9 +47,14 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" -NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" +export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint +if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then + python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all || error_exit "Failed to generate checkpoint files" +fi +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index fcf1a52b9..fe4aab456 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -28,11 +28,11 @@ WHL_BASE="transformer_engine-${VERSION}" # Core wheel. NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation -vvv --wheel-dir ./dist . || error_exit "Failed to setup bdist_wheel" -wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl" +python3 -m wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" -wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" +python3 -m wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" rm dist/*.whl || error_exit "Failed to remove dist/*.whl" mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup metapackage" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index b3a520e12..6f9ff54e4 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -6,4 +6,5 @@ : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py +# NVTE_UnfusedDPA_Emulate_FP8=1 enables FP8 attention emulation when no native backend is available +NVTE_UnfusedDPA_Emulate_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/setup.py b/setup.py index c66af39df..a6e247305 100644 --- a/setup.py +++ b/setup.py @@ -122,11 +122,6 @@ def setup_common_extension() -> CMakeExtension: f"nvidia-cublasmp-cu{cuda_version()[0]}" ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}") cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") - nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution( - f"nvidia-nvshmem-cu{cuda_version()[0]}" - ).locate_file("nvidia/nvshmem") - cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") - print("CMAKE_FLAGS:", cmake_flags[-2:]) # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 8a19e84f5..e85f70630 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -4,51 +4,49 @@ # # See LICENSE for license information. -list(APPEND test_cuda_sources - test_cast.cu - test_cast_current_scaling.cu - test_cast_dbias.cu - test_cast_dbias_dgelu.cu - test_cast_gated_swiglu.cu - test_cast_mxfp8_gated_swiglu.cu - test_qdq.cu - test_cast_mxfp8.cu - test_dequantize_mxfp8.cu - test_dequantize_nvfp4.cu - test_cast_nvfp4_transpose.cu - test_transpose.cu - test_cast_transpose.cu - test_cast_transpose_current_scaling.cu - test_cast_transpose_dbias.cu - test_cast_transpose_dbias_dgelu.cu - test_cast_transpose_dgeglu.cu - test_act.cu - test_normalization.cu - test_normalization_mxfp8.cu - test_memset.cu - test_multi_cast_transpose.cu - test_multi_padding.cu - test_multi_unpadding.cu - test_causal_softmax.cu - test_swap_first_dims.cu - ../test_common.cu) -if(USE_CUDA) - list(APPEND test_cuda_sources - test_cast_float8blockwise.cu - test_swizzle.cu) -else() +set(GX_CUDA $) +add_executable(test_operator + test_cast.cu + test_cast_current_scaling.cu + test_cast_dbias.cu + test_cast_dbias_dgelu.cu + test_cast_gated_swiglu.cu + test_cast_mxfp8_gated_swiglu.cu + test_qdq.cu + test_cast_mxfp8.cu + test_cast_mxfp8_grouped.cu + test_cast_nvfp4_transpose.cu + $<${GX_CUDA}:test_cast_float8blockwise.cu> + test_dequantize_mxfp8.cu + test_transpose.cu + test_cast_transpose.cu + test_cast_transpose_current_scaling.cu + test_cast_transpose_dbias.cu + test_cast_transpose_dbias_dgelu.cu + test_cast_transpose_dgeglu.cu + test_act.cu + test_normalization.cu + test_normalization_mxfp8.cu + test_memset.cu + test_splits_to_offsets.cu + test_multi_cast_transpose.cu + test_multi_padding.cu + test_multi_unpadding.cu + test_causal_softmax.cu + $<${GX_CUDA}:test_swizzle.cu> + test_swap_first_dims.cu + $<${GX_CUDA}:test_grouped_gemm.cu> + ../test_common.cu) + +if(USE_ROCM) + get_target_property(test_cuda_sources test_operator SOURCES) list(APPEND test_cuda_sources + test_dequantize_nvfp4.cu test_cublaslt_gemm.cu) -endif() - -if(USE_CUDA) - add_executable(test_operator ${test_cuda_sources}) -else() TE_GetHipifiedSources("${test_cuda_sources}" ${CMAKE_CURRENT_SOURCE_DIR} test_hip_sources) message("${message_line}") message(STATUS "test_operator hipified sources: ${test_hip_sources}") - - add_executable(test_operator ${test_hip_sources}) + set_target_properties(test_operator PROPERTIES SOURCES "${test_hip_sources}") endif() # Find required packages diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu new file mode 100644 index 000000000..9f2523cb6 --- /dev/null +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -0,0 +1,886 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ProcessingMethod { + CAST_ONLY, + CAST_DBIAS, + CAST_DBIAS_DACT, + CAST_DACT, + CAST_ACT +}; + +enum ActivationKind { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +template +void compute_ref(const ProcessingMethod processing_method, + float (*OP)(const float), + const bool rowwise, + const bool colwise, + const InputType* input, + const InputType* grad, + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) +{ +#ifdef __HIP_PLATFORM_AMD__ + using std::isnan, std::isinf; +#endif + const size_t tile_size_Y = 32; + const size_t tile_size_X = 32; + const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; + const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; + + std::vector output_dbias_fp32(cols, 0); + #pragma omp parallel proc_bind(spread) + { + // Buffers to cache intermediate computations + std::vector cache_buffer(tile_size_Y * tile_size_X); + + std::vector thread_dbias(cols, 0); + #pragma omp for schedule(static) + for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { + const size_t tile_Y = t / tiles_num_X; + const size_t tile_X = t % tiles_num_X; + const size_t tile_offset_Y = tile_Y * tile_size_Y; + const size_t tile_offset_X = tile_X * tile_size_X; + + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(i_min + tile_size_Y, rows); + + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(j_min + tile_size_X, cols); + + // Cache computations + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } + thread_dbias[j] += elt; + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + elt = static_cast(static_cast(elt)); + + cache_buffer[cache_idx] = elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + } + } + + if (rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax = 0.0f; + + for (size_t j = j_min; j < j_max; ++j) { + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const size_t scale_idx = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_rowwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + if (colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax = 0.0f; + + for (size_t i = i_min; i < i_max; ++i) { + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const size_t scale_idx = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t i = i_min; i < i_max; ++i) { + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_colwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + } + #pragma omp critical + { + for (size_t j = 0; j < cols; ++j) { + output_dbias_fp32[j] += thread_dbias[j]; + } + } + } + + for (size_t j = 0; j < cols; ++j) { + output_dbias[j] = static_cast(output_dbias_fp32[j]); + } +} + +template +void compare_scaled_elts(const std::string &name, + const T* ref_data, + const T* test_data, + const size_t rows, + const size_t cols, + const bool rowwise, + const size_t tolerable_mismatches_limit = 0, + const double atol = 1e-5, + const double rtol = 1e-8) { + size_t mismatches_num = 0; + int first_mismatch_idx = -1; + + for (size_t i = 0; i < rows * cols; ++i) { + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = false; + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + std::string direction = rowwise ? "rowwise" : "columnwise"; + if (assertion) { + mismatches_num++; + if (first_mismatch_idx == -1) { + first_mismatch_idx = i; + } + } + if (mismatches_num > tolerable_mismatches_limit) { + const double first_mismatch_t = static_cast(test_data[first_mismatch_idx]); + const double first_mismatch_r = static_cast(ref_data[first_mismatch_idx]); + + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "First mismatch at place " << first_mismatch_idx + << " (" << std::to_string(first_mismatch_idx) << "): " + << first_mismatch_t << " vs " << first_mismatch_r; + } + } +} + +/** + * Scaling along single dimension (either rows or columns) + * Produces one set of output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * OR + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest(const ProcessingMethod processing_method, + float (*OP)(const float), + const ShapeRepresentation shape_rep, + const size_t num_tensors, + const std::vector& logical_shape_vec, + const std::vector& first_dims_h, + const std::vector& last_dims_h, + const std::vector& offsets_h, + const bool rowwise, + const bool colwise) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const bool compute_dbias = (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT); + + const size_t rows = logical_shape_vec[0]; + const size_t cols = logical_shape_vec[1]; + + size_t elts_num = 0; + size_t rowwise_sfs_num = 0; + size_t colwise_sfs_num = 0; + size_t sum_of_last_dims = 0; + + std::vector rowwise_scales_first_dim(num_tensors, 0); + std::vector rowwise_scales_last_dim(num_tensors, 0); + std::vector rowwise_scales_offset(num_tensors + 1, 0); + std::vector colwise_scales_first_dim(num_tensors, 0); + std::vector colwise_scales_last_dim(num_tensors, 0); + std::vector colwise_scales_offset(num_tensors + 1, 0); + std::vector dbias_offsets(num_tensors + 1, 0); + + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; + const size_t elts = M * K; + elts_num += elts; + + const size_t unpadded_rowwise_blocks_Y = M; + const size_t unpadded_rowwise_blocks_X = divide_round_up(K, 32); + const size_t unpadded_colwise_blocks_Y = divide_round_up(M, 32); + const size_t unpadded_colwise_blocks_X = K; + + rowwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_Y, 128); + rowwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4); + colwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_Y, 4); + colwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128); + + const size_t rowwise_sfs = rowwise_scales_first_dim[t] * rowwise_scales_last_dim[t]; + const size_t colwise_sfs = colwise_scales_first_dim[t] * colwise_scales_last_dim[t]; + + rowwise_sfs_num += rowwise_sfs; + colwise_sfs_num += colwise_sfs; + sum_of_last_dims += K; + + rowwise_scales_offset[t+1] = rowwise_sfs_num; + colwise_scales_offset[t+1] = colwise_sfs_num; + dbias_offsets[t+1] = sum_of_last_dims; + } + + std::vector scales_rowwise_shape = {rowwise_sfs_num}; + std::vector scales_colwise_shape = {colwise_sfs_num}; + + std::mt19937 gen; + std::uniform_real_distribution<> dis(-2.0, 1.0); + + std::vector in_data(elts_num); + std::vector grad_data(elts_num); + + std::vector out_data_rowwise_h(rowwise ? elts_num : 0); + std::vector out_data_colwise_h(colwise ? elts_num : 0); + std::vector out_scales_rowwise_h(rowwise ? rowwise_sfs_num : 0); + std::vector out_scales_colwise_h(colwise ? colwise_sfs_num : 0); + + std::vector out_data_rowwise_ref(rowwise ? elts_num : 0); + std::vector out_data_colwise_ref(colwise ? elts_num : 0); + std::vector out_scales_rowwise_ref(rowwise ? rowwise_sfs_num : 0); + std::vector out_scales_colwise_ref(colwise ? colwise_sfs_num : 0); + + std::vector ref_output_dbias(sum_of_last_dims, static_cast(0.0f)); + + for (size_t i = 0; i < elts_num; ++i) { + const float val = dis(gen); + grad_data[i] = static_cast(val); + in_data[i] = static_cast(val); + } + + const OutputType zero_elt = static_cast(0.0f); + const fp8e8m0 zero_SF = static_cast(0.0f); + if (rowwise) { + std::fill(out_data_rowwise_h.begin(), out_data_rowwise_h.end(), zero_elt); + std::fill(out_data_rowwise_ref.begin(), out_data_rowwise_ref.end(), zero_elt); + std::fill(out_scales_rowwise_h.begin(), out_scales_rowwise_h.end(), zero_SF); + std::fill(out_scales_rowwise_ref.begin(), out_scales_rowwise_ref.end(), zero_SF); + } + if (colwise) { + std::fill(out_data_colwise_h.begin(), out_data_colwise_h.end(), zero_elt); + std::fill(out_data_colwise_ref.begin(), out_data_colwise_ref.end(), zero_elt); + std::fill(out_scales_colwise_h.begin(), out_scales_colwise_h.end(), zero_SF); + std::fill(out_scales_colwise_ref.begin(), out_scales_colwise_ref.end(), zero_SF); + } + + const size_t in_data_size = elts_num * sizeof(InputType); + const size_t out_data_size = elts_num * sizeof(OutputType); + const size_t dbias_data_size = sum_of_last_dims * sizeof(InputType); + const size_t rowwise_scales_size = rowwise_sfs_num * sizeof(fp8e8m0); + const size_t colwise_scales_size = colwise_sfs_num * sizeof(fp8e8m0); + + const size_t first_dims_size = num_tensors * sizeof(size_t); + const size_t last_dims_size = num_tensors * sizeof(size_t); + const size_t offsets_size = (num_tensors + 1) * sizeof(size_t); + + InputType* grad_data_d = nullptr; + InputType* in_data_d = nullptr; + InputType* dbias_out_data_d = nullptr; + OutputType* out_data_rowwise_d = nullptr; + OutputType* out_data_colwise_d = nullptr; + fp8e8m0* out_scales_rowwise_d = nullptr; + fp8e8m0* out_scales_colwise_d = nullptr; + size_t* first_dims_d = nullptr; + size_t* last_dims_d = nullptr; + size_t* offsets_d = nullptr; + + cudaMalloc((void**)&grad_data_d, in_data_size); + cudaMalloc((void**)&in_data_d, in_data_size); + cudaMalloc((void**)&first_dims_d, first_dims_size); + cudaMalloc((void**)&last_dims_d, last_dims_size); + cudaMalloc((void**)&offsets_d, offsets_size); + + cudaMemcpy(grad_data_d, grad_data.data(), in_data_size, cudaMemcpyHostToDevice); + cudaMemcpy(in_data_d, in_data.data(), in_data_size, cudaMemcpyHostToDevice); + cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice); + cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice); + cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice); + + NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); + + std::vector dbias_logical_shape_vec= {num_tensors, cols}; + NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(), + dbias_logical_shape_vec.size()); + + NVTEShape first_dims_shape_; + NVTEShape last_dims_shape_; + NVTEShape offsets_shape_; + + first_dims_shape_.ndim = 1; + last_dims_shape_.ndim = 1; + offsets_shape_.ndim = 1; + + first_dims_shape_.data[0] = num_tensors; + last_dims_shape_.data[0] = num_tensors; + offsets_shape_.data[0] = num_tensors + 1; + + NVTEGroupedTensor grad_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor output_dbias_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, dbias_logical_shape_); + + NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast(itype), logical_shape_}; + NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; + nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, + &in_data_tensor, sizeof(in_data_tensor)); + nvte_set_grouped_tensor_param(grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, + &grad_data_tensor, sizeof(grad_data_tensor)); + + if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_}; + nvte_set_grouped_tensor_param(grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, + &first_dims_tensor, sizeof(first_dims_tensor)); + nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, + &first_dims_tensor, sizeof(first_dims_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, + &first_dims_tensor, sizeof(first_dims_tensor)); + } + + if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_}; + nvte_set_grouped_tensor_param(grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, + &last_dims_tensor, sizeof(last_dims_tensor)); + nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, + &last_dims_tensor, sizeof(last_dims_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, + &last_dims_tensor, sizeof(last_dims_tensor)); + } + + if (shape_rep != SAME_BOTH_DIMS) { + NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_}; + nvte_set_grouped_tensor_param(grad_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, + &offsets_tensor, sizeof(offsets_tensor)); + nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, + &offsets_tensor, sizeof(offsets_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, + &offsets_tensor, sizeof(offsets_tensor)); + } + + if (rowwise) { + cudaMalloc((void**)&out_data_rowwise_d, out_data_size); + cudaMalloc((void**)&out_scales_rowwise_d, rowwise_scales_size); + cudaMemset(out_data_rowwise_d, 0, out_data_size); + cudaMemset(out_scales_rowwise_d, 0, rowwise_scales_size); + NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast(otype), logical_shape_}; + NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_rowwise_shape.data(), scales_rowwise_shape.size()); + NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0, scales_rowwise_shape_}; + nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, + &out_data_rowwise_tensor, sizeof(out_data_rowwise_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, + &out_scales_rowwise_tensor, sizeof(out_scales_rowwise_tensor)); + } + + if (colwise) { + cudaMalloc((void**)&out_data_colwise_d, out_data_size); + cudaMalloc((void**)&out_scales_colwise_d, colwise_scales_size); + cudaMemset(out_data_colwise_d, 0, out_data_size); + cudaMemset(out_scales_colwise_d, 0, colwise_scales_size); + NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, static_cast(otype), logical_shape_}; + NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_colwise_shape.data(), scales_colwise_shape.size()); + NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0, scales_colwise_shape_}; + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, + &out_data_colwise_tensor, sizeof(out_data_colwise_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, + &out_scales_colwise_tensor, sizeof(out_scales_colwise_tensor)); + } + + if (compute_dbias) { + cudaMalloc((void**)&dbias_out_data_d, dbias_data_size); + cudaMemset(dbias_out_data_d, 0, dbias_data_size); + NVTEBasicTensor output_dbias_data_tensor = {dbias_out_data_d, static_cast(itype), dbias_logical_shape_}; + nvte_set_grouped_tensor_param(output_dbias_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, + &output_dbias_data_tensor, sizeof(output_dbias_data_tensor)); + } + + // Reference (CPU) + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; + + const size_t scales_stride_rowwise = rowwise_scales_last_dim[t]; + const size_t scales_stride_colwise = colwise_scales_last_dim[t]; + const size_t data_offset = offsets_h[t]; + const size_t rowwise_sfs_offset = rowwise_scales_offset[t]; + const size_t colwise_sfs_offset = colwise_scales_offset[t]; + const size_t dbias_offset = dbias_offsets[t]; + + const InputType* const grad_ptr = grad_data.data() + data_offset; + const InputType* const in_ptr = in_data.data() + data_offset; + OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; + OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; + fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset; + fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset; + InputType* const ref_output_dbias_ptr = ref_output_dbias.data() + dbias_offset; + + compute_ref( + processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, + out_data_rowwise_ptr, out_data_colwise_ptr, + out_scales_rowwise_ptr, out_scales_colwise_ptr, + ref_output_dbias_ptr, M, K, + scales_stride_rowwise, + scales_stride_colwise); + } + + // GPU + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_group_quantize(in_group_tensor, out_group_tensor, 0); + break; + } + case ProcessingMethod::CAST_DBIAS: { + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0); + break; + } + case ProcessingMethod::CAST_DBIAS_DACT: { + auto nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dgelu; + if (OP == &dsilu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; } + + nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, + output_dbias_tensor, workspace.data(), 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, + output_dbias_tensor, workspace.data(), 0); + break; + } + case ProcessingMethod::CAST_ACT: { + auto nvte_group_act = &nvte_group_gelu; + if (OP == &silu) { nvte_group_act = &nvte_group_silu; } + else if (OP == &relu) { nvte_group_act = &nvte_group_relu; } + else if (OP == &qgelu) { nvte_group_act = &nvte_group_qgelu; } + else if (OP == &srelu) { nvte_group_act = &nvte_group_srelu; } + nvte_group_act(in_group_tensor, out_group_tensor, 0); + break; + } + case ProcessingMethod::CAST_DACT: { + auto nvte_group_dact = &nvte_group_dgelu; + if (OP == &dsilu) { nvte_group_dact = &nvte_group_dsilu; } + else if (OP == &drelu) { nvte_group_dact = &nvte_group_drelu; } + else if (OP == &dqgelu) { nvte_group_dact = &nvte_group_dqgelu; } + else if (OP == &dsrelu) { nvte_group_dact = &nvte_group_dsrelu; } + nvte_group_dact(grad_group_tensor, in_group_tensor, out_group_tensor, 0); + break; + } + } + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol, rtol] = getTolerances(otype); + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; + + if (rowwise) { + cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost); + cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, cudaMemcpyDeviceToHost); + + size_t mismatches_scales = 0; +#ifdef USE_ROCM + std::vector mismatches_scales_indices; +#endif + compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(), + 1, rowwise_sfs_num, rowwise_sfs_num, +#ifdef USE_ROCM + mismatches_scales_indices, +#endif + mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); + +#ifdef USE_ROCM + if (::testing::Test::HasFatalFailure()) return; + adjust_ref_for_e8m0_scale_error("rowwise_scales", mismatches_scales_indices, + out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(), + rowwise_sfs_num, rows, cols, true, + out_data_rowwise_ref.data(), otype); + mismatches_scales = 0; +#endif + const size_t mismatches_elts = 32 * mismatches_scales; + + compare_scaled_elts("rowwise_output", out_data_rowwise_ref.data(), + out_data_rowwise_h.data(), rows, cols, true, mismatches_elts); + } + + if (colwise) { + cudaMemcpy(out_data_colwise_h.data(), out_data_colwise_d, out_data_size, cudaMemcpyDeviceToHost); + cudaMemcpy(out_scales_colwise_h.data(), out_scales_colwise_d, colwise_scales_size, cudaMemcpyDeviceToHost); + + size_t mismatches_scales = 0; +#ifdef USE_ROCM + std::vector mismatches_scales_indices; +#endif + compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(), + 1, colwise_sfs_num, colwise_sfs_num, +#ifdef USE_ROCM + mismatches_scales_indices, +#endif + mismatches_scales, scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); + +#ifdef USE_ROCM + if (::testing::Test::HasFatalFailure()) return; + adjust_ref_for_e8m0_scale_error("colwise_scales", mismatches_scales_indices, + out_scales_colwise_h.data(), out_scales_colwise_ref.data(), + colwise_sfs_num, rows, cols, false, + out_data_colwise_ref.data(), otype); + mismatches_scales = 0; +#endif + const size_t mismatches_elts = 32 * mismatches_scales; + + compare_scaled_elts("colwise_output", out_data_colwise_ref.data(), + out_data_colwise_h.data(), rows, cols, false, mismatches_elts); + } + + if (compute_dbias) { + Tensor output_dbias("output_dbias", std::vector{ sum_of_last_dims }, itype); + cudaMemcpy(output_dbias.rowwise_dptr(), dbias_out_data_d, dbias_data_size, cudaMemcpyDeviceToDevice); + + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + if (itype == DType::kFloat32) { + atol_dbias = 1e-4; + rtol_dbias *= sqrt(static_cast(rows)) ; + } else { + rtol_dbias *= 4; + } + compareResults("output_dbias", output_dbias, ref_output_dbias.data(), true, atol_dbias, rtol_dbias); + } + + cudaFree(grad_data_d); + cudaFree(in_data_d); + cudaFree(dbias_out_data_d); + cudaFree(first_dims_d); + cudaFree(last_dims_d); + cudaFree(offsets_d); + if (rowwise) { + cudaFree(out_data_rowwise_d); + cudaFree(out_scales_rowwise_d); + } + if (colwise) { + cudaFree(out_data_colwise_d); + cudaFree(out_scales_colwise_d); + } +} + +std::vector processing_methods = { + ProcessingMethod::CAST_ONLY, + ProcessingMethod::CAST_DBIAS, + ProcessingMethod::CAST_DBIAS_DACT, + ProcessingMethod::CAST_DACT, + ProcessingMethod::CAST_ACT, +}; + +std::vector activation_kinds = { + ActivationKind::Identity, + ActivationKind::GeLU, + // ActivationKind::SiLU, + // ActivationKind::ReLU, + // ActivationKind::QGeLU, + // ActivationKind::SReLU, +}; + +enum ScalingDirection { + ROWWISE = 0, + COLWISE = 1, + BOTH = 2 +}; + +std::vector scaling_directions = { + ScalingDirection::ROWWISE, + ScalingDirection::COLWISE, + ScalingDirection::BOTH, +}; + +// {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} +std::vector> input_config = { + {SAME_BOTH_DIMS, 1, 128,128}, + {SAME_BOTH_DIMS, 2, 256,128}, + {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, + {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, +}; + +} // namespace + +class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam + , // Config + transformer_engine::DType, // InputType + transformer_engine::DType // OutputType + >> {}; + +TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationKind activation = std::get<1>(GetParam()); + const ScalingDirection scaling_direction = std::get<2>(GetParam()); + const std::vector input_config = std::get<3>(GetParam()); + const DType input_type = std::get<4>(GetParam()); + const DType output_type = std::get<5>(GetParam()); + + const ShapeRepresentation shape_rep = static_cast(input_config[0]); + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); + + const size_t num_tensors = input_config[1]; + const std::vector logical_shape = {input_config[2], input_config[3]}; + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + std::vector offsets(num_tensors + 1, 0); + for (size_t t = 0; t < num_tensors; ++t) { + switch (shape_rep) { + case SAME_BOTH_DIMS: { + first_dims[t] = logical_shape[0] / num_tensors; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_FIRST_DIM: { + first_dims[t] = input_config[t + 4]; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_LAST_DIM: { + first_dims[t] = logical_shape[0]; + last_dims[t] = input_config[t + 4]; + break; + } + case VARYING_BOTH_DIMS: { + first_dims[t] = input_config[t + 4]; + last_dims[t] = input_config[t + (4 + num_tensors)]; + break; + } + } + offsets[t+1] = offsets[t] + first_dims[t] * last_dims[t]; + // Skip tests when the tensor shape is incompatible with the kernel. + // The TMA engine requires strides to be 16-byte aligned. + if ((first_dims[t] % 128 != 0) || (last_dims[t] % 16 != 0)) { + GTEST_SKIP(); + } + // If a grouped tensor has a varying last dimension, it must be a multiple of 128. + // Otherwise, computing the grid size adds runtime overhead in the non-persistent kernel, + // since the relevant tensor metadata resides in device memory. + constexpr size_t CHUNK_DIM_X = 128; + if (!is_single_tensor && (last_dims[t] % CHUNK_DIM_X != 0)) { + GTEST_SKIP(); + } + } + // Skip dBias tests when tensors in the group have different last dimensions. + if ((processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + && !is_single_tensor) { + GTEST_SKIP(); + } + + // Skip non-activation tests when the activation type is not Identity. + if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + && activation != ActivationKind::Identity) { + GTEST_SKIP(); + } + // Skip activation tests when the activation type is Identity. + if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + || processing_method == ProcessingMethod::CAST_DACT + || processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) { + GTEST_SKIP(); + } + + bool rowwise = false; + bool colwise = false; + switch (scaling_direction) { + case ScalingDirection::ROWWISE: rowwise = true; break; + case ScalingDirection::COLWISE: colwise = true; break; + case ScalingDirection::BOTH: rowwise = true; colwise = true; break; + } + + auto OP = &identity; + + if (processing_method == ProcessingMethod::CAST_ACT) { + switch (activation) { + case ActivationKind::GeLU: OP = &gelu; break; + case ActivationKind::SiLU: OP = &silu; break; + case ActivationKind::ReLU: OP = &relu; break; + case ActivationKind::QGeLU: OP = &qgelu; break; + case ActivationKind::SReLU: OP = &srelu; break; + case ActivationKind::Identity: /*ROCm: comiler warining*/ break; + } + } else if (processing_method == ProcessingMethod::CAST_DACT + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + switch (activation) { + case ActivationKind::GeLU: OP = &dgelu; break; + case ActivationKind::SiLU: OP = &dsilu; break; + case ActivationKind::ReLU: OP = &drelu; break; + case ActivationKind::QGeLU: OP = &dqgelu; break; + case ActivationKind::SReLU: OP = &dsrelu; break; + case ActivationKind::Identity: /*ROCm: comiler warining*/ break; + } + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + performTest(processing_method, OP, shape_rep, num_tensors, + logical_shape, first_dims, last_dims, offsets, + rowwise, colwise); + ); + ); +} + +std::string to_string(const ProcessingMethod method) { + switch (method) { + case ProcessingMethod::CAST_ONLY: return "CAST_ONLY"; + case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS"; + case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT"; + case ProcessingMethod::CAST_DACT: return "CAST_DACT"; + case ProcessingMethod::CAST_ACT: return "CAST_ACT"; + default: return ""; + } +} + +std::string to_string(const ActivationKind activation) { + switch (activation) { + case ActivationKind::Identity: return "Identity"; + case ActivationKind::GeLU: return "GeLU"; + case ActivationKind::SiLU: return "SiLU"; + case ActivationKind::ReLU: return "ReLU"; + case ActivationKind::QGeLU: return "QGeLU"; + case ActivationKind::SReLU: return "SReLU"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + GroupedFusedCastMXFP8TestSuite, + ::testing::Combine( + ::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(activation_kinds), + ::testing::ValuesIn(scaling_directions), + ::testing::ValuesIn(input_config), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + [](const testing::TestParamInfo& info) { + const ProcessingMethod method = std::get<0>(info.param); + std::string name = to_string(method); + name += "X" + to_string(std::get<1>(info.param)); + + switch (std::get<2>(info.param)) { + case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break; + case ScalingDirection::COLWISE: name += "_COLWISE_"; break; + case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break; + } + + const std::vector input = std::get<3>(info.param); + + switch(static_cast(input[0])) { + case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break; + case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break; + case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break; + case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break; + }; + + name += "_N_" + std::to_string(input[1]); + + name += "_SHAPE_" + + std::to_string(input[2]) + + "X" + std::to_string(input[3]); + + name += "_" + test::typeName(std::get<4>(info.param)) + + "_" + test::typeName(std::get<5>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 50f2d36fe..a8d8c08b9 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -64,12 +64,16 @@ std::vector create_transpose(const InputType* const input, const size } // Compute the global encode scale factor for a given global amax -float compute_global_encode_scaling_factor_FP4(const float global_amax) { +float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) { constexpr float fp8_max = 448.0f; // 448.0f; constexpr float fp4_max = 6.0f; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; - // If scale is infinity, return max value of float32 - global_encode_scale = fminf(global_encode_scale, Numeric_Traits::maxNorm); + // If scale is infinity, return the max normalized value + const float max_norm_clamp = use_fast_math + ? Numeric_Traits::maxNorm + : Numeric_Traits::maxNorm; + + global_encode_scale = fminf(global_encode_scale, max_norm_clamp); // If global amax is 0 or infinity, return 1 if (global_amax == 0.0f || global_encode_scale == 0.0f) { return 1.0f; @@ -86,10 +90,11 @@ void quantize_nvfp4_1d(float (*OP)(const float), const size_t rows, const size_t cols, const size_t scales_stride, - const float global_amax) { + const float global_amax, + const bool use_fast_math) { // Compute a global encoding/decoding scaling factor for all S_dec_b - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); constexpr size_t block_size_X = 16; const size_t blocks_X = divide_round_up(cols, block_size_X); @@ -124,14 +129,20 @@ void quantize_nvfp4_1d(float (*OP)(const float), const float S_dec_b = block_amax / 6.0f; // Scale & Store per-block decoding scaling factor - const float S_dec_b_fp8 = S_dec_b * S_enc; + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); // Compute "correct" per-block encoding scaling factor - const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32; const size_t scale_idx = i * scales_stride + block_X; - scales[scale_idx] = static_cast(S_dec_b_fp8); - const float scale_reciprocal = S_enc_b_fp8; + scales[scale_idx] = S_dec_b_fp8; + + float scale_reciprocal = S_enc_b_fp8; + if (use_fast_math) { + // Numerical truncation to match GPU implementation, if mixed precision FMA instruction is used + scale_reciprocal = static_cast(static_cast(scale_reciprocal)); + } for (size_t j = j_min; j < j_max; j += 2) { const int idx_pair = (i * cols + j) / 2; @@ -146,7 +157,7 @@ void quantize_nvfp4_1d(float (*OP)(const float), fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); output[idx_pair] = casted_to_e2m1_pair; - // const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); + const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); } } } @@ -159,9 +170,10 @@ void compute_2d_mathematical_scales(float (*OP)(const float), const size_t rows, const size_t cols, const float global_amax, - std::vector>& math_scales) { + std::vector>& math_scales, + const bool use_fast_math) { - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -205,13 +217,14 @@ void quantize_nvfp4_2d(float (*OP)(const float), const size_t rows, const size_t cols, const size_t scales_stride, - const float global_amax) { + const float global_amax, + const bool use_fast_math) { // Step 1: Compute mathematical 8x8 scaling factors std::vector> math_scales; - compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -292,11 +305,12 @@ void quantize_nvfp4(float (*OP)(const float), const size_t cols, const size_t scales_stride, const float global_amax, + const bool use_fast_math, const bool use_2d_quantization = false) { if (use_2d_quantization) { - quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); } else { - quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); } } @@ -312,6 +326,7 @@ void compute_ref(float (*OP)(const float), const size_t cols, const size_t scales_stride, const size_t scales_stride_t, + const bool use_fast_math, const bool use_2d_quantization = false) { std::vector input_t = create_transpose(input, rows, cols); @@ -319,7 +334,7 @@ void compute_ref(float (*OP)(const float), if (use_2d_quantization) { // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; - compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; @@ -346,12 +361,16 @@ void compute_ref(float (*OP)(const float), // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d // (This part processes the actual FP4 data using the mathematical scaling factors) - quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled - quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled + quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax, + use_fast_math); // scales already filled + quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax, + use_fast_math); // scales_t already filled } else { - quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization); - quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization); + quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, + use_fast_math, use_2d_quantization); + quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, + use_fast_math, use_2d_quantization); } } @@ -359,6 +378,8 @@ void compare_nvfp4_tensors(const std::string& name, const fp4e2m1 *test_data, const fp4e2m1 *ref_data, const int rows, const int cols, double atol = 1e-5, double rtol = 1e-8) { + constexpr int max_mismatches_to_print = 3; + std::vector mismatch_messages; size_t total_mismatches = 0; @@ -372,10 +393,11 @@ void compare_nvfp4_tensors(const std::string& name, const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); - bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); - /* For Float32 the floating point comparison is enough to error out */ - bool assertion = false; - if (mismatch && !assertion) { +#ifndef __HIP_PLATFORM_AMD__ + const bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol); +#else + bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol); + if (mismatch) { /* Check if it is just a failure of round to nearest choosing different side of the real value */ const double mean = (t + r) / 2; @@ -383,18 +405,18 @@ void compare_nvfp4_tensors(const std::string& name, const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); const double cast_mean_p = static_cast(static_cast(mean_p)); const double cast_mean_m = static_cast(static_cast(mean_m)); - assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + mismatch = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); } - if (assertion) { +#endif + if (mismatch) { total_mismatches++; - std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + - std::to_string(t) + " vs " + std::to_string(r) + - " (abs_diff: " + std::to_string(fabs(t - r)) + - ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; - mismatch_messages.push_back(msg); - // Optional: limit number of detailed messages to avoid overwhelming output - if (mismatch_messages.size() <= 100) { + if (total_mismatches <= max_mismatches_to_print) { + std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + + std::to_string(t) + " vs " + std::to_string(r) + + " (abs_diff: " + std::to_string(fabs(t - r)) + + ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; + mismatch_messages.push_back(msg); std::cout << "Error in tensor " << name << ": " << msg << std::endl; } } @@ -410,8 +432,9 @@ void compare_nvfp4_tensors(const std::string& name, std::cout << "STATUS: FAILED for output" << std::endl; std::cout << "Total mismatches found: " << total_mismatches << std::endl; std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; - if (mismatch_messages.size() > 100) { - std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; + if (mismatch_messages.size() > max_mismatches_to_print) { + std::cout << "... and " << (mismatch_messages.size() - max_mismatches_to_print) + << " more mismatches (showing first " << max_mismatches_to_print << ")" << std::endl; } std::cout << "============================" << std::endl; @@ -529,7 +552,8 @@ void compareResults_nvfp4(const Tensor &test, template void performTest(float (*OP)(const float), - const std::vector& shape) { + const std::vector& shape, + const bool use_fast_math) { using namespace test; DType itype = TypeInfo::dtype; @@ -601,15 +625,16 @@ void performTest(float (*OP)(const float), cols, scales_stride, scales_stride_t, + use_fast_math, use_2d_quantization); - - QuantizationConfigWrapper quant_config; - // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed rng_state.rowwise_cpu_dptr()[1] = 321; // rng_sequence rng_state.from_cpu(); + + QuantizationConfigWrapper quant_config; + quant_config.set_use_fast_math(use_fast_math); #ifdef __HIP_PLATFORM_AMD__ quant_config.set_stochastic_rounding(use_stochastic_rounding); #else @@ -644,8 +669,8 @@ void performTest(float (*OP)(const float), } ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - const double atol = 0.05; - const double rtol = 0.1; + const double atol = 1.0E-6; + const double rtol = 1.0E-6; // Set dump_data=true to enable dumping tensor data to files for analysis compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); @@ -710,7 +735,8 @@ std::vector Activation_types = { class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam , - transformer_engine::DType>> {}; + transformer_engine::DType, + bool>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { #ifndef __HIP_PLATFORM_AMD__ @@ -726,6 +752,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const ActivationType Act_type = std::get<0>(GetParam()); const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); + const bool use_fast_math = std::get<3>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -743,7 +770,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims); + performTest(OP, tensor_dims, use_fast_math); ); } @@ -765,7 +792,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(tensor_dims), - ::testing::Values(DType::kBFloat16)), + ::testing::Values(DType::kBFloat16), + ::testing::Values(false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)); const auto& shape = std::get<1>(info.param); @@ -773,5 +801,8 @@ INSTANTIATE_TEST_SUITE_P( name += "X" + std::to_string(s); } name += "X" + test::typeName(std::get<2>(info.param)); + if (std::get<3>(info.param)) { + name += "X_FAST_SCALING"; + } return name; }); diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu new file mode 100644 index 000000000..34bb729b2 --- /dev/null +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -0,0 +1,401 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum class InputCase { + kFP8Current, + kBF16, + kMXFP8, +}; + +enum class ShapeCase { + kAllSame, + kSameFirst, + kSameLast, + kAllDifferent, +}; + +size_t grouped_setup_workspace_size(const size_t num_tensors) { + const size_t ptr_bytes = num_tensors * sizeof(void*); + const size_t int_bytes = num_tensors * sizeof(int); + // Layout: 8 pointer arrays (A, B, C, D, alpha, beta, a_scale, b_scale) + 6 int arrays + size_t size = 8 * ptr_bytes + 6 * int_bytes; + const size_t alignment = 256; + size = ((size + alignment - 1) / alignment) * alignment; + return size; +} + +Tensor make_fp8_operand(const std::string& name, const std::vector& shape) { + Tensor input_fp32(name + "_fp32", shape, DType::kFloat32); + + const size_t numel = shape[0] * shape[1]; + std::vector data(numel); + std::mt19937 gen(std::hash{}(name)); + // Random mean and stddev -> different amax per tensor -> different scales + std::uniform_real_distribution param_dis(0.1f, 10.0f); + float mean = param_dis(gen); + float stddev = param_dis(gen); + std::normal_distribution dis(mean, stddev); + for (size_t i = 0; i < numel; ++i) { + data[i] = dis(gen); + } + NVTE_CHECK_CUDA(cudaMemcpy(input_fp32.rowwise_dptr(), data.data(), + numel * sizeof(float), cudaMemcpyHostToDevice)); + + Tensor fp8(name, shape, TypeInfo::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING); + + nvte_compute_amax(input_fp32.data(), fp8.data(), 0); + QuantizationConfigWrapper config; + nvte_compute_scale_from_amax(fp8.data(), config, 0); + nvte_quantize(input_fp32.data(), fp8.data(), 0); + return fp8; +} + +Tensor make_bf16_operand(const std::string& name, const std::vector& shape) { + Tensor t(name, shape, DType::kBFloat16); + const size_t numel = shape[0] * shape[1]; + std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f)); + NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(), + numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice)); + return t; +} + + +// Creates an MXFP8 operand with the correct data layout for GEMM. +// MXFP8 GEMM requirements (scales are along K dimension): +// A transposed -> needs rowwise data/scales +// A non-transposed -> needs columnwise data/scales +// B transposed -> needs columnwise data/scales +// B non-transposed -> needs rowwise data/scales +Tensor make_mxfp8_operand(const std::string& name, const std::vector& shape, + bool is_A, bool transposed) { + // Determine which data layout we need + bool use_rowwise, use_colwise; + if (is_A) { + // A: transposed -> rowwise, non-transposed -> columnwise + use_rowwise = transposed; + use_colwise = !transposed; + } else { + // B: transposed -> columnwise, non-transposed -> rowwise (opposite of A!) + use_rowwise = !transposed; + use_colwise = transposed; + } + + // Create BF16 input with random data + Tensor input_bf16(name + "_bf16", shape, DType::kBFloat16); + fillUniform(&input_bf16); + + // Create MXFP8 tensor with only the required data layout + Tensor mxfp8(name, shape, TypeInfo::dtype, use_rowwise, use_colwise, + NVTE_MXFP8_1D_SCALING); + + // Quantize BF16 -> MXFP8 + nvte_quantize(input_bf16.data(), mxfp8.data(), 0); + + // Create output tensor for swizzled scales (same data shape, same layout) + Tensor mxfp8_swizzled(name + "_swizzled", shape, TypeInfo::dtype, + use_rowwise, use_colwise, NVTE_MXFP8_1D_SCALING); + mxfp8_swizzled.set_with_gemm_swizzled_scales(true); // Must be set BEFORE swizzle call + + // Copy quantized data from mxfp8 to mxfp8_swizzled + if (use_rowwise) { + size_t data_bytes = test::bytes(mxfp8.rowwise_shape(), mxfp8.dtype()); + NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.rowwise_dptr(), mxfp8.rowwise_dptr(), + data_bytes, cudaMemcpyDeviceToDevice)); + } + if (use_colwise) { + size_t data_bytes = test::bytes(mxfp8.columnwise_shape(), mxfp8.dtype()); + NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.columnwise_dptr(), mxfp8.columnwise_dptr(), + data_bytes, cudaMemcpyDeviceToDevice)); + } + + // Swizzle scales for GEMM + nvte_swizzle_scaling_factors(mxfp8.data(), mxfp8_swizzled.data(), 0); + + // Sync to ensure operations are complete + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + return mxfp8_swizzled; +} + +struct TestParams { + InputCase input_case; + bool transa; + bool transb; + ShapeCase shape_case; + bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0) +}; + +// Returns a vector of (M, N, K) tuples for each GEMM in the group. +// M - number of rows in output D +// N - number of columns in output D +// K - reduction dimension shared between A and B +std::vector> make_shapes(ShapeCase scase) { + switch (scase) { + case ShapeCase::kAllSame: + return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}}; + case ShapeCase::kSameFirst: + // Same M (first dim), varying N and K + return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}}; + case ShapeCase::kSameLast: + // Same N (last dim), varying M and K + return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}}; + case ShapeCase::kAllDifferent: + default: + return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}}; + } +} + +void run_grouped_gemm_case(const TestParams& params) { +#if CUBLAS_VERSION < 130200 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " + << CUBLAS_VERSION << "."; +#else + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; + } + + const std::vector> shapes = make_shapes(params.shape_case); + + const size_t num_gemms = shapes.size(); + std::vector A_tensors; + std::vector B_tensors; + std::vector D_multi; + + A_tensors.reserve(num_gemms); + B_tensors.reserve(num_gemms); + D_multi.reserve(num_gemms); + + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + const std::vector a_shape = params.transa ? std::vector{N, K} + : std::vector{K, N}; + const std::vector b_shape = params.transb ? std::vector{K, M} + : std::vector{M, K}; + switch (params.input_case) { + case InputCase::kFP8Current: { + A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kBF16: { + A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kMXFP8: { + A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape, + /*is_A=*/true, params.transa)); + B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape, + /*is_A=*/false, params.transb)); + break; + } + } + D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + } + + std::vector A_ptrs(num_gemms); + std::vector B_ptrs(num_gemms); + std::vector D_ptrs(num_gemms); + std::vector workspaces(num_gemms); + std::vector workspace_ptrs(num_gemms, nullptr); + std::vector A_views; + std::vector B_views; + A_views.reserve(num_gemms); + B_views.reserve(num_gemms); + + // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues) + std::vector bias_ptrs(num_gemms, nullptr); + std::vector gelu_ptrs(num_gemms, nullptr); + + const size_t cublas_ws_bytes = 32ull * 1024 * 1024; + + for (size_t i = 0; i < num_gemms; ++i) { + A_ptrs[i] = A_tensors[i].data(); + B_ptrs[i] = B_tensors[i].data(); + D_ptrs[i] = D_multi[i].data(); + workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); + workspace_ptrs[i] = workspaces[i].data(); + A_views.push_back(&A_tensors[i]); + B_views.push_back(&B_tensors[i]); + } + + nvte_multi_tensor_gemm(A_ptrs.data(), + B_ptrs.data(), + D_ptrs.data(), + bias_ptrs.data(), + gelu_ptrs.data(), + static_cast(num_gemms), + params.transa, + params.transb, + false, // grad + workspace_ptrs.data(), + false, // accumulate + false, // use_split_accumulator + 0, // sm_count + 0); + + GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); + + std::vector C_tensors; + std::vector D_group_tensors; + C_tensors.reserve(num_gemms); + D_group_tensors.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + (void)K; + if (!params.use_null_c) { + C_tensors.emplace_back(Tensor("C" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + } + D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, bytes(D_group_tensors.back().rowwise_shape(), D_group_tensors.back().dtype()))); + } + + std::vector C_views, D_views; + for (size_t i = 0; i < num_gemms; ++i) { + if (!params.use_null_c) { + C_views.push_back(&C_tensors[i]); + } + D_views.push_back(&D_group_tensors[i]); + } + + std::optional grouped_C; + if (!params.use_null_c) { + grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + } + GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); + + // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) + Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); + std::vector alpha_vals(num_gemms, 1.f); + std::vector beta_vals(num_gemms, 0.f); + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + + const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); + Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); + + nvte_grouped_gemm(grouped_A.get_handle(), + params.transa, + grouped_B.get_handle(), + params.transb, + params.use_null_c ? nullptr : grouped_C->get_handle(), + grouped_D.get_handle(), + alpha_tensor.data(), + beta_tensor.data(), + setup_ws.data(), + cublas_ws.data(), + nullptr, // config (use defaults) + 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // Compare results + for (size_t i = 0; i < num_gemms; ++i) { + Tensor grouped_split("grouped_D" + std::to_string(i), + std::vector{static_cast(std::get<0>(shapes[i])), + static_cast(std::get<1>(shapes[i]))}, + D_multi[i].dtype()); + const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), + static_cast(grouped_D.get_data()) + offset_bytes, + grouped_D.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + grouped_split.to_cpu(); + D_multi[i].to_cpu(); + auto [atol, rtol] = getTolerances(D_multi[i].dtype()); + compareResults("grouped_vs_multi", + grouped_split, + D_multi[i].rowwise_cpu_dptr(), + true, + atol, + rtol); + } +#endif // CUBLAS_VERSION >= 130200 +} + +class GroupedGemmTest : public ::testing::TestWithParam {}; + +TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { + run_grouped_gemm_case(GetParam()); +} + +std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { + constexpr const char* kInputNames[] = {"FP8Current", "BF16", "MXFP8"}; + constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; + const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") + + "tb" + (info.param.transb ? "T" : "N"); + const std::string null_c = info.param.use_null_c ? "_NullC" : ""; + return std::string(kInputNames[static_cast(info.param.input_case)]) + "_" + + kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c; +} + +// TestParams: {input_case, transa, transb, shape_case, use_null_c} +const std::vector kTestParams = { + // FP8 tests (each tensor has random mean/stddev -> different scales) + {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, + // BF16 tests + {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, + {InputCase::kBF16, false, true, ShapeCase::kSameLast, false}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame, false}, + {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false}, + // Test NULL C (valid when beta=0) + {InputCase::kBF16, false, false, ShapeCase::kAllSame, true}, + // MXFP8 tests + {InputCase::kMXFP8, true, false, ShapeCase::kAllSame, false}, + {InputCase::kMXFP8, true, false, ShapeCase::kAllDifferent, false}, + {InputCase::kMXFP8, false, true, ShapeCase::kAllSame, false}, + {InputCase::kMXFP8, false, true, ShapeCase::kAllDifferent, false}, + {InputCase::kMXFP8, false, false, ShapeCase::kAllSame, false}, + {InputCase::kMXFP8, false, false, ShapeCase::kAllDifferent, false}, + {InputCase::kMXFP8, false, false, ShapeCase::kSameFirst, false}, + // MXFP8 with NULL C + {InputCase::kMXFP8, true, false, ShapeCase::kAllSame, true}, +}; + +INSTANTIATE_TEST_SUITE_P(OperatorTest, + GroupedGemmTest, + ::testing::ValuesIn(kTestParams), + MakeGroupedGemmTestName); + +} // namespace diff --git a/tests/cpp/operator/test_splits_to_offsets.cu b/tests/cpp/operator/test_splits_to_offsets.cu new file mode 100644 index 000000000..faac4b7b6 --- /dev/null +++ b/tests/cpp/operator/test_splits_to_offsets.cu @@ -0,0 +1,80 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include + +#include +#include "../test_common.h" + +class SplitsToOffsetsTestSuite : public ::testing::TestWithParam> {}; + +TEST_P(SplitsToOffsetsTestSuite, TestSplitsToOffsets) { + const size_t num_tensors = std::get<0>(GetParam()); + const int64_t logical_last_dim = std::get<1>(GetParam()); + + std::vector h_first_dims(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + h_first_dims[i] = static_cast((i % 17) + 1); + } + + std::vector h_expected(num_tensors + 1, 0); + for (size_t i = 0; i < num_tensors; ++i) { + h_expected[i + 1] = h_expected[i] + h_first_dims[i] * logical_last_dim; + } + + std::vector h_output(num_tensors + 1, -1); + + int64_t *d_first_dims = nullptr; + int64_t *d_output = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&d_first_dims, sizeof(int64_t) * num_tensors)); + NVTE_CHECK_CUDA(cudaMalloc(&d_output, sizeof(int64_t) * (num_tensors + 1))); + NVTE_CHECK_CUDA(cudaMemcpy(d_first_dims, h_first_dims.data(), sizeof(int64_t) * num_tensors, + cudaMemcpyHostToDevice)); + + nvte_splits_to_offsets(d_first_dims, d_output, num_tensors, logical_last_dim, 0 /* stream */); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + NVTE_CHECK_CUDA(cudaMemcpy(h_output.data(), d_output, sizeof(int64_t) * (num_tensors + 1), + cudaMemcpyDeviceToHost)); + + NVTE_CHECK_CUDA(cudaFree(d_first_dims)); + NVTE_CHECK_CUDA(cudaFree(d_output)); + + for (size_t i = 0; i < h_output.size(); ++i) { + EXPECT_EQ(h_output[i], h_expected[i]) + << "Mismatch at index " << i << ": expected " << h_expected[i] << ", got " << h_output[i]; + } +} + +namespace { + +std::vector splits_to_offsets_num_tensors = { + 1, + 4, + 255, + 256, + 257, + 1024, +}; + +} // namespace + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, SplitsToOffsetsTestSuite, + ::testing::Combine(::testing::ValuesIn(splits_to_offsets_num_tensors), + ::testing::Values(static_cast(1), static_cast(7), + static_cast(128))), + [](const testing::TestParamInfo &info) { + std::string name = std::to_string(std::get<0>(info.param)) + "X" + + std::to_string(std::get<1>(info.param)); + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 062eb6b5e..188f30ce7 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -1282,4 +1283,250 @@ std::array get_scale_tensor_dims(const size_t rows, return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X}; } +GroupedBuffers build_grouped_tensor(const std::vector& tensors, + const NVTEScalingMode scaling_mode) { + NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build."); + + // Check which data layouts are available (all tensors must have the same) + const bool has_rowwise = tensors[0]->rowwise(); + const bool has_columnwise = tensors[0]->columnwise(); + NVTE_CHECK(has_rowwise || has_columnwise, "Tensors must have at least one data layout."); + + const NVTEShape shape = has_rowwise ? tensors[0]->rowwise_shape() + : tensors[0]->columnwise_shape(); + const DType dtype = tensors[0]->dtype(); + const size_t num_tensors = tensors.size(); + const size_t elem_size = typeToNumBits(dtype) / 8; + GroupedBuffers grouped; + grouped.elem_size = elem_size; + grouped.num_tensors = num_tensors; + grouped.dtype = dtype; + grouped.scaling_mode = scaling_mode; + grouped.tensor_bytes.resize(num_tensors); + grouped.offsets_host.resize(num_tensors, 0); + + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + const auto s = has_rowwise ? tensors[i]->rowwise_shape() + : tensors[i]->columnwise_shape(); + NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors."); + first_dims[i] = static_cast(s.data[0]); + last_dims[i] = static_cast(s.data[1]); + grouped.tensor_bytes[i] = bytes(s, dtype); + } + + const bool same_first = std::all_of(first_dims.begin(), first_dims.end(), + [&](int64_t v) { return v == first_dims[0]; }); + const bool same_last = std::all_of(last_dims.begin(), last_dims.end(), + [&](int64_t v) { return v == last_dims[0]; }); + + std::vector offsets(num_tensors, 0); + auto random_padding = [&]() -> int64_t { + // Random padding ensuring 16-byte alignment regardless of element size + // cuBLAS requires aligned pointers for vectorized loads + static std::mt19937 gen(12345); + std::uniform_int_distribution dist(0, 3); + // Calculate elements needed for 16-byte alignment in bytes, rounded up + const size_t align_elements = + std::max(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size + return dist(gen) * static_cast(align_elements); + }; + + auto numel = [&](size_t idx) -> int64_t { + return first_dims[idx] * last_dims[idx]; + }; + + const bool need_offsets = !same_first || !same_last; + const bool use_random_padding = need_offsets && scaling_mode != NVTE_MXFP8_1D_SCALING; + if (need_offsets) { + offsets[0] = 0; + for (size_t i = 1; i < num_tensors; ++i) { + offsets[i] = offsets[i - 1] + numel(i - 1) + (use_random_padding ? random_padding() : 0); + } + } else { + for (size_t i = 0; i < num_tensors; ++i) { + offsets[i] = static_cast(i) * numel(0); + } + } + grouped.offsets_host = offsets; + + int64_t logical_first = 0; + int64_t logical_last = 0; + if (same_first && same_last) { + logical_first = first_dims[0] * static_cast(num_tensors); + logical_last = last_dims[0]; + } else if (same_first && !same_last) { + logical_first = first_dims[0]; + logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0}); + } else if (!same_first && same_last) { + logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0}); + logical_last = last_dims[0]; + } else { + logical_first = 1; + logical_last = 0; + for (size_t i = 0; i < num_tensors; ++i) { + logical_last += first_dims[i] * last_dims[i]; + } + } + size_t logical_data[2] = {static_cast(logical_first), + static_cast(logical_last)}; + grouped.logical_shape = nvte_make_shape(logical_data, 2); + grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape)); + + const int64_t last_idx = static_cast(num_tensors - 1); + const int64_t total_elems = need_offsets + ? (offsets[last_idx] + numel(last_idx)) + : (logical_first * logical_last); + const size_t total_bytes = static_cast(total_elems) * elem_size; + + NVTEGroupedTensor h = grouped.handle.get(); + + size_t total_elems_size = static_cast(total_elems); + NVTEShape flat_shape = nvte_make_shape(&total_elems_size, 1); + // Copy rowwise data if available + if (has_rowwise) { + grouped.data = cuda_alloc(total_bytes); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes, + tensors[i]->rowwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + NVTEBasicTensor data_tensor{grouped.data.get(), static_cast(dtype), flat_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor)); + } + + // Copy columnwise data if available + if (has_columnwise) { + grouped.columnwise_data = cuda_alloc(total_bytes); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data.get()) + offset_bytes, + tensors[i]->columnwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + NVTEBasicTensor col_tensor{grouped.columnwise_data.get(), + static_cast(dtype), + flat_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseData, &col_tensor, sizeof(col_tensor)); + } + + if (!same_first) { + grouped.first_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedFirstDims, &fd_tensor, sizeof(fd_tensor)); + } + + if (!same_last) { + grouped.last_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedLastDims, &ld_tensor, sizeof(ld_tensor)); + } + + if (!same_first || !same_last) { + grouped.offsets_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape off_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor)); + } + + if (isFp8Type(dtype) && scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + // FP8 tensor scaling: one float scale_inv per tensor + // For delayed scaling, rowwise and columnwise share the same scale + std::vector scale_inv_cpu(num_tensors, 1.f); + for (size_t i = 0; i < num_tensors; ++i) { + tensors[i]->to_cpu(); + if (has_rowwise) { + scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; + } else { + scale_inv_cpu[i] = tensors[i]->columnwise_cpu_scale_inv_ptr()[0]; + } + } + grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), + sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); + NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &scale_tensor, + sizeof(scale_tensor)); + nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor, + sizeof(scale_tensor)); + } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + // MXFP8: E8M0 scale_inv per block of 32 elements + // Helper to gather scale_inv from individual tensors into a contiguous buffer + auto gather_scales = [&]( + auto get_shape_fn, + auto get_cpu_ptr_fn) -> std::pair, size_t> { + // Compute total size and offsets + size_t total_bytes = 0; + std::vector scale_offsets(num_tensors); + std::vector numels(num_tensors); + + for (size_t i = 0; i < num_tensors; ++i) { + scale_offsets[i] = total_bytes; + const NVTEShape shape = get_shape_fn(tensors[i]); + size_t numel = 1; + for (size_t d = 0; d < shape.ndim; ++d) { + numel *= shape.data[d]; + } + numels[i] = numel; + total_bytes += numel; // E8M0 is 1 byte per element + } + + // Allocate and copy + CudaPtr<> buffer = cuda_alloc(total_bytes); + for (size_t i = 0; i < num_tensors; ++i) { + tensors[i]->to_cpu(); + NVTE_CHECK_CUDA(cudaGetLastError()); + void* dst = static_cast(buffer.get()) + scale_offsets[i]; + const void* src = get_cpu_ptr_fn(tensors[i]); + NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice)); + } + return {std::move(buffer), total_bytes}; + }; + + // Gather rowwise scale_inv if available + if (has_rowwise) { + auto [row_buffer, row_total] = gather_scales( + [](Tensor* t) { return t->rowwise_scale_inv_shape(); }, + [](Tensor* t) { return t->rowwise_cpu_scale_inv_ptr(); }); + grouped.scale_inv = std::move(row_buffer); + + NVTEShape row_shape = nvte_make_shape(&row_total, 1); + NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat8E8M0, row_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor)); + } + + // Gather columnwise scale_inv if available + if (has_columnwise) { + auto [col_buffer, col_total] = gather_scales( + [](Tensor* t) { return t->columnwise_scale_inv_shape(); }, + [](Tensor* t) { return t->columnwise_cpu_scale_inv_ptr(); }); + grouped.columnwise_scale_inv = std::move(col_buffer); + + NVTEShape col_shape = nvte_make_shape(&col_total, 1); + NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat8E8M0, col_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor)); + } + + // Mark as having swizzled scales (required for GEMM) + const uint8_t swizzled = 1; + nvte_set_grouped_tensor_param(h, kNVTEGroupedWithGEMMSwizzledScales, &swizzled, + sizeof(swizzled)); + } + + return grouped; +} + } // namespace test diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 796e66999..03cc49114 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -483,10 +483,14 @@ inline fp8e8m0 float_to_e8m0(float val) { } inline float exp2f_rcp(fp8e8m0 biased_exp) { - if (biased_exp == 0) { - return 1.0f; + int32_t int_val = 0; + if (biased_exp == 255) { + int_val = 0x7fffffff; + } else if (biased_exp == 254) { + int_val = 0x00400000; + } else { + int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127) } - int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127) float fp32_val = *reinterpret_cast(&int_val); return fp32_val; } @@ -572,6 +576,61 @@ int32_t getDeviceComputeCapability(); constexpr int32_t hopperComputeCapability = 90; constexpr int32_t blackwellComputeCapability = 100; +// Custom deleters for RAII +struct CudaDeleter { + void operator()(void* p) const { if (p) cudaFree(p); } +}; +struct GroupedTensorDeleter { + void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); } +}; + +template +using CudaPtr = std::unique_ptr; +using GroupedTensorHandle = std::unique_ptr, GroupedTensorDeleter>; + +// Helper to allocate CUDA memory into a CudaPtr +template +CudaPtr cuda_alloc(size_t bytes) { + void* ptr = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes)); + return CudaPtr(static_cast(ptr)); +} + +// Helper owning GPU buffers that back NVTEGroupedTensor. +// NVTEGroupedTensor does not own memory; data/offsets/scales +// must be allocated and freed by the test. +struct GroupedBuffers { + GroupedTensorHandle handle; + CudaPtr<> data; + CudaPtr<> scale_inv; + CudaPtr<> columnwise_scale_inv; + CudaPtr first_dims_dev; + CudaPtr last_dims_dev; + CudaPtr offsets_dev; + CudaPtr<> columnwise_data; + NVTEShape logical_shape{}; + std::vector offsets_host; + std::vector tensor_bytes; + size_t num_tensors{0}; + size_t elem_size{0}; + DType dtype{DType::kFloat32}; + NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING}; + + GroupedBuffers() = default; + GroupedBuffers(const GroupedBuffers&) = delete; + GroupedBuffers& operator=(const GroupedBuffers&) = delete; + GroupedBuffers(GroupedBuffers&&) = default; + GroupedBuffers& operator=(GroupedBuffers&&) = default; + ~GroupedBuffers() = default; + + // Convenience accessors for raw pointers + NVTEGroupedTensor get_handle() const { return handle.get(); } + void* get_data() const { return data.get(); } +}; + +GroupedBuffers build_grouped_tensor(const std::vector& tensors, + const NVTEScalingMode scaling_mode); + } // namespace test #if FP4_TYPE_SUPPORTED diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 6b7520d14..db30f0ed3 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -11,6 +11,10 @@ import transformer_engine.jax from transformer_engine_jax import get_device_compute_capability +from transformer_engine.jax.version_utils import ( + TRITON_EXTENSION_MIN_JAX_VERSION, + is_triton_extension_supported, +) @pytest.fixture(autouse=True, scope="function") @@ -83,5 +87,28 @@ def pytest_sessionfinish(self, session, exitstatus): def pytest_configure(config): + config.addinivalue_line( + "markers", + "triton: mark test (or test class) as requiring JAX Triton kernel support" + f" (JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION})." + " Apply per test/class with @pytest.mark.triton so non-Triton tests in the same file run on" + " old JAX.", + ) if os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1": config.pluginmanager.register(TestTimingPlugin(), "test_timing") + + +def pytest_collection_modifyitems(config, items): + """Skip tests marked 'triton' when JAX is too old for Triton kernel dispatch.""" + if is_triton_extension_supported(): + return + skip_triton = pytest.mark.skip( + reason=( + f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for Triton kernel support. " + "Triton kernel dispatch segfaults with older jaxlib. " + "Upgrade with: pip install --upgrade jax jaxlib" + ) + ) + for item in items: + if item.get_closest_marker("triton"): + item.add_marker(skip_triton) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 3d9362cb4..c228d0bf3 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1826,8 +1826,6 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout, use_async_d2h_group dtype, input_shape, layout ) if use_async_d2h_group_size: - if is_hip_extension(): - pytest.skip("ROCm does not support use_async_d2h_group_sizes yet.") num_gemms = input_shape[0] _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))( group_sizes, @@ -1976,3 +1974,37 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype) assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) assert_allclose(prim_dbias, ref_dbias, dtype=dtype) + + +class TestDebugInspectFFI: + + @pytest_parametrize_wrapper("shape", [(256, 128)]) + @pytest_parametrize_wrapper( + "dtype", + [ + jnp.float32, + jnp.bfloat16, + jnp.float16, + # Note: fp4 currently doesn't work + # jnp.float4_e2m1fn + ] + + ([jnp_float8_e4m3_type, jnp_float8_e5m2_type] if is_fp8_supported else []), + ) + def test_debug_inspect_ffi(self, shape, dtype): + from transformer_engine.jax.debug.experimental import inspect_array, load_array_dump + + def f(x): + x = x + 1 + x = inspect_array(x, "my_array") + x = x + 1 + return x + + key = jax.random.PRNGKey(0) + x = jax.random.uniform(key, shape, jnp.float32) + x = x.astype(dtype) + _ = jax.jit(f)(x) + + expected = x + 1 + actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype) + + assert_allclose(actual, expected, dtype=dtype) diff --git a/tests/jax/test_distributed_dense.py b/tests/jax/test_distributed_dense.py index b8caf188d..0c2ac8b24 100644 --- a/tests/jax/test_distributed_dense.py +++ b/tests/jax/test_distributed_dense.py @@ -161,16 +161,21 @@ def test_distributed_gemm( # Compare results assert_allclose(gathered_te, gathered_jax, dtype=dtype) - def _te_sum_dense(self, x, weight, bias, contracting_dims): + def _te_sum_dense(self, x, weight, bias, contracting_dims, output_sharding): """TE GEMM function for gradient testing""" - return jnp.sum(dense(x, weight, bias=bias, contracting_dims=contracting_dims)) + output = dense(x, weight, bias=bias, contracting_dims=contracting_dims) + if output_sharding is not None: + output = jax.lax.with_sharding_constraint(output, output_sharding) + return jnp.sum(output) - def _jax_sum_dense(self, x, weight, bias, contracting_dims): + def _jax_sum_dense(self, x, weight, bias, contracting_dims, output_sharding): """JAX dot function for gradient testing""" - result = ( + output = ( jax.lax.dot_general(x, weight, dimension_numbers=(contracting_dims, ((), ()))) + bias ) - return jnp.sum(result) + if output_sharding is not None: + output = jax.lax.with_sharding_constraint(output, output_sharding) + return jnp.sum(output) @pytest_parametrize_wrapper( "device_count,mesh_shape,mesh_axes,mesh_resource", @@ -213,18 +218,18 @@ def test_te_distributed_dense_grad( # Test gradients w.r.t. all inputs te_grad_func = jax.jit( jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)), - static_argnames=("contracting_dims",), + static_argnames=("contracting_dims", "output_sharding"), ) jax_grad_func = jax.jit( jax.value_and_grad(self._jax_sum_dense, argnums=(0, 1, 2)), - static_argnames=("contracting_dims",), + static_argnames=("contracting_dims", "output_sharding"), ) te_val, te_grads = te_grad_func( - x_sharded, weight_sharded, bias_sharded, contracting_dims + x_sharded, weight_sharded, bias_sharded, contracting_dims, output_sharding ) jax_val, jax_grads = jax_grad_func( - x_sharded, weight_sharded, bias_sharded, contracting_dims + x_sharded, weight_sharded, bias_sharded, contracting_dims, output_sharding ) # Compare forward pass diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index fb40e2b1f..c218a8c8c 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -72,9 +72,7 @@ def impl_test_self_attn( attn_mask_type, dtype, softmax_type, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) dropout_prob = 0.0 is_training = True batch, seqlen, num_head, hidden = data_shape @@ -183,48 +181,6 @@ def test_self_attn( attn_mask_type, dtype, softmax_type, - use_shardy=False, - ) - - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize( - "attn_bias_type, bias_shape", - [ - pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), - pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), - ], - ) - @pytest.mark.parametrize( - "softmax_type", - [ - pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"), - pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"), - pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"), - ], - ) - def test_self_attn_shardy( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - attn_bias_type, - bias_shape, - softmax_type, - ): - data_shape = (32, 512, 12, 64) - self.impl_test_self_attn( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - attn_bias_type, - bias_shape, - AttnMaskType.PADDING_MASK, - jnp.bfloat16, - softmax_type, - use_shardy=True, ) @@ -354,7 +310,6 @@ def impl_test_context_parallel_attn( qkv_layout, load_balanced, cp_strategy, - use_shardy, use_scan_ring=False, window_size=None, stripe_size=None, @@ -372,8 +327,6 @@ def impl_test_context_parallel_attn( os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" else: os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" - - jax.config.update("jax_use_shardy_partitioner", use_shardy) attn_bias_type = AttnBiasType.NO_BIAS bias_shape = None dropout_prob = 0.0 @@ -466,46 +419,6 @@ def check_has_backend_for_mask(mask_type): runner.test_backward() del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] - @pytest_parametrize_wrapper( - "device_count,mesh_shape,mesh_axes,mesh_resource", - generate_context_parallel_configs_for_attn(), - ) - @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) - @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) - @pytest.mark.parametrize( - "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, - ) - @pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0") - def test_context_parallel_allgather_attn_shardy( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - attn_mask_type, - dtype, - qkv_layout, - ): - if qkv_layout.is_thd(): - pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention") - kv_groups = 8 - self.impl_test_context_parallel_attn( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - kv_groups, - attn_mask_type, - dtype, - qkv_layout, - load_balanced=True, - cp_strategy=CPStrategy.ALL_GATHER, - use_shardy=True, - ) - @pytest_parametrize_wrapper( "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs_for_attn(), @@ -568,7 +481,6 @@ def test_context_parallel_allgather_striped_attn( qkv_layout, load_balanced, CPStrategy.ALL_GATHER, - use_shardy=False, window_size=window_size, stripe_size=stripe_size, num_segments_per_seq=num_segments_per_seq, @@ -616,7 +528,6 @@ def test_context_parallel_allgather_attn( qkv_layout, load_balanced, CPStrategy.ALL_GATHER, - use_shardy=False, ) @pytest_parametrize_wrapper( @@ -681,54 +592,11 @@ def test_context_parallel_ring_attn( qkv_layout, load_balanced, CPStrategy.RING, - use_shardy=False, use_scan_ring=use_scan, window_size=window_size, stripe_size=stripe_size, ) - @pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0") - @pytest_parametrize_wrapper( - "device_count,mesh_shape,mesh_axes,mesh_resource", - generate_context_parallel_configs_for_attn(), - ) - @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) - @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) - @pytest.mark.parametrize( - "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, - ) - def test_context_parallel_ring_attn_shardy( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - attn_mask_type, - dtype, - qkv_layout, - ): - kv_groups = 8 - # Set the stripe size to 1 (ring attention only support stripe_size=1) - stripe_size = 1 if qkv_layout.is_thd() else None - self.impl_test_context_parallel_attn( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - kv_groups, - attn_mask_type, - dtype, - qkv_layout, - load_balanced=True, - cp_strategy=CPStrategy.RING, - use_shardy=False, - use_scan_ring=True, - stripe_size=stripe_size, - ) - REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = { "L0": [[]], diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 21359cedf..85108f47c 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -89,7 +89,6 @@ def generate_collectives_count_ref( @pytest_parametrize_wrapper("zero_centered_gamma", [False, True]) @pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("use_shardy", [False, True] if version.parse(jax.__version__) >= version.parse("0.5.0") else [False]) def test_layernorm( self, device_count, @@ -101,9 +100,7 @@ def test_layernorm( zero_centered_gamma, shard_weights, fp8_recipe, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) epsilon = 1e-6 ln_type = "layernorm" q_dtype = get_jnp_float8_e4m3_type() @@ -180,7 +177,6 @@ def ref_func(x, gamma, beta): @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("use_shardy", [False, True] if version.parse(jax.__version__) >= version.parse("0.5.0") else [False]) def test_rmsnorm( self, device_count, @@ -191,9 +187,7 @@ def test_rmsnorm( dtype, shard_weights, fp8_recipe, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) epsilon = 1e-6 ln_type = "rmsnorm" q_dtype = get_jnp_float8_e4m3_type() diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 6a2f395b1..12c66c755 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -210,7 +210,6 @@ def _test_layernorm_mlp_grad( input_shape, dtype, quantization_recipe, - use_shardy, with_jax_gemm, ): if ( @@ -221,6 +220,7 @@ def _test_layernorm_mlp_grad( and (dtype == jnp.bfloat16) ): pytest.xfail("Skip known failure case.") + #ROCm: skip unsupported MXFP8 layernorm MLP grad test cases if isinstance(quantization_recipe, recipe.MXFP8BlockScaling): _check_mxfp8_layernorm_mlp_grad_support( input_shape[0]*input_shape[1], @@ -232,7 +232,6 @@ def _test_layernorm_mlp_grad( use_bias, with_jax_gemm ) - jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" @@ -354,36 +353,6 @@ def test_layernorm_mlp_grad( dtype, quantization_recipe, with_jax_gemm, - ): - if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): - pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") - self._test_layernorm_mlp_grad( - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - quantization_recipe, - use_shardy=False, - with_jax_gemm=with_jax_gemm, - ) - - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) - @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_mlp_grad_shardy( - self, - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - quantization_recipe, - with_jax_gemm, ): if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") @@ -394,7 +363,6 @@ def test_layernorm_mlp_grad_shardy( input_shape, dtype, quantization_recipe=quantization_recipe, - use_shardy=True, with_jax_gemm=with_jax_gemm, ) @@ -407,9 +375,9 @@ def _test_layernorm_mlp( dtype, use_fp8, quantization_recipe, - use_shardy, with_jax_gemm, ): + #ROCm: skip unsupported MXFP8 layernorm MLP test cases if isinstance(quantization_recipe, recipe.MXFP8BlockScaling): _check_mxfp8_layernorm_mlp_support( input_shape[0]*input_shape[1], @@ -421,7 +389,6 @@ def _test_layernorm_mlp( use_bias, with_jax_gemm ) - jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape layernorm_type = "rmsnorm" @@ -533,7 +500,6 @@ def test_layernorm_mlp_layer( dtype, use_fp8=False, quantization_recipe=None, - use_shardy=False, with_jax_gemm=with_jax_gemm, ) @@ -564,58 +530,5 @@ def test_layernorm_mlp_layer_fp8( dtype, use_fp8=True, quantization_recipe=quantization_recipe, - use_shardy=False, - with_jax_gemm=with_jax_gemm, - ) - - @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) - @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) - @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_mlp_layer_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm - ): - self._test_layernorm_mlp( - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - use_fp8=False, - quantization_recipe=None, - use_shardy=True, - with_jax_gemm=with_jax_gemm, - ) - - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) - @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_mlp_layer_fp8_shardy( - self, - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - quantization_recipe, - with_jax_gemm, - ): - if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): - pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") - self._test_layernorm_mlp( - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - use_fp8=True, - quantization_recipe=quantization_recipe, - use_shardy=True, with_jax_gemm=with_jax_gemm, ) diff --git a/tests/jax/test_distributed_permutation.py b/tests/jax/test_distributed_permutation.py index 5b6d8fec4..ee7a56a7e 100644 --- a/tests/jax/test_distributed_permutation.py +++ b/tests/jax/test_distributed_permutation.py @@ -34,11 +34,28 @@ from distributed_test_base import generate_configs from utils import assert_allclose, pytest_parametrize_wrapper -# High-level API with VJP support -from transformer_engine.jax.permutation import ( - token_dispatch, - token_combine, -) + +@pytest.fixture(autouse=True, scope="function") +def _inject_permutation(request): + """Lazy-load permutation API only for tests marked 'triton'. Other tests run without importing. + + We inject into sys.modules[__name__] so test code in this module can use + token_dispatch, token_combine as module-level names (fixture locals are not + visible to test methods). + """ + if not request.node.get_closest_marker("triton"): + yield + return + import sys + from transformer_engine.jax.permutation import token_dispatch, token_combine + + mod = sys.modules[__name__] + mod.token_dispatch = token_dispatch + mod.token_combine = token_combine + yield + + +# High-level API with VJP support (injected by _inject_permutation) # Reference implementations from test_permutation.py from test_permutation import ( @@ -80,6 +97,7 @@ } +@pytest.mark.triton class TestDistributedPermutation: """Test distributed/sharded execution of MoE permutation primitives. @@ -135,7 +153,6 @@ def generate_routing_map( DISPATCH_COMBINE_CASES, ) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_local_token_dispatch( self, device_count, @@ -147,7 +164,6 @@ def test_local_token_dispatch( hidden_size, topk, dtype, - use_shardy, ): """ Test token_dispatch with sharded inputs. @@ -164,7 +180,6 @@ def test_local_token_dispatch( matching the sharded execution's output ordering. Tests both forward pass (output values) and backward pass (gradients). """ - jax.config.update("jax_use_shardy_partitioner", use_shardy) key = jax.random.PRNGKey(42) # Generate global inputs @@ -307,7 +322,6 @@ def ref_chunk_loss(inp_chunk, routing_chunk, probs_chunk): DISPATCH_COMBINE_CASES, ) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_local_roundtrip( self, device_count, @@ -319,7 +333,6 @@ def test_local_roundtrip( hidden_size, topk, dtype, - use_shardy, ): """ Test roundtrip: token_dispatch followed by token_combine with sharded inputs. @@ -332,7 +345,6 @@ def test_local_roundtrip( Tests both forward pass and backward pass (gradient should be 2*x). """ - jax.config.update("jax_use_shardy_partitioner", use_shardy) key = jax.random.PRNGKey(42) # Generate global inputs @@ -403,7 +415,6 @@ def roundtrip_loss(x, rm, mprobs): DISPATCH_COMBINE_PADDING_CASES, ) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_local_token_dispatch_with_padding( self, device_count, @@ -416,14 +427,12 @@ def test_local_token_dispatch_with_padding( topk, align_size, dtype, - use_shardy, ): """ Test token_dispatch with padding using sharded inputs. Tests both forward pass (output values) and backward pass (gradients). """ - jax.config.update("jax_use_shardy_partitioner", use_shardy) key = jax.random.PRNGKey(42) # Generate global inputs @@ -502,7 +511,6 @@ def loss_with_padding(x, rm, p): DISPATCH_COMBINE_PADDING_CASES, ) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_local_roundtrip_with_padding( self, device_count, @@ -515,7 +523,6 @@ def test_local_roundtrip_with_padding( topk, align_size, dtype, - use_shardy, ): """ Test roundtrip with padding/alignment using sharded inputs. @@ -523,7 +530,6 @@ def test_local_roundtrip_with_padding( With uniform merging probs, should recover original input. Tests both forward pass and backward pass. """ - jax.config.update("jax_use_shardy_partitioner", use_shardy) key = jax.random.PRNGKey(42) # Generate inputs diff --git a/tests/jax/test_distributed_router.py b/tests/jax/test_distributed_router.py new file mode 100644 index 000000000..35f59c897 --- /dev/null +++ b/tests/jax/test_distributed_router.py @@ -0,0 +1,475 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for distributed/sharded execution of fused MoE router primitives. + +Testing Strategy: +================= +Router operations process each token independently (1 warp per token), so +sharded execution on the token dimension should produce identical results +to processing each shard independently with the reference implementation. + +For fused_topk_with_score_function (including compute_aux_scores mode): +- Input logits [num_tokens, num_experts] are sharded on num_tokens (DP axis) +- Expert dimension is replicated +- Each GPU processes its local tokens independently +- We verify sharded output matches per-shard reference, concatenated + +For fused_moe_aux_loss: +- This is a global reduction to a scalar +- All inputs and outputs are replicated (partition function forces this) +- We verify the op works correctly under a mesh context + +These tests exercise: batcher and shardy_sharding_rule from the router primitives. +""" + +import pytest + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from distributed_test_base import generate_configs +from utils import assert_allclose, pytest_parametrize_wrapper + + +@pytest.fixture(autouse=True, scope="function") +def _inject_router(request): + """Lazy-load router API only for tests marked 'triton'. Other tests run without importing. + + We inject into sys.modules[__name__] so test code can use fused_topk_with_score_function, + fused_moe_aux_loss as module-level names (fixture locals are not visible to tests). + """ + if not request.node.get_closest_marker("triton"): + yield + return + import sys + from transformer_engine.jax.router import ( + fused_topk_with_score_function, + fused_moe_aux_loss, + ) + + mod = sys.modules[__name__] + mod.fused_topk_with_score_function = fused_topk_with_score_function + mod.fused_moe_aux_loss = fused_moe_aux_loss + yield + + +jax.config.update("jax_use_shardy_partitioner", True) + +from test_fused_router import ( + reference_topk_softmax_sigmoid, + reference_compute_scores_for_aux_loss, + reference_aux_loss, + make_logits, +) + +# (num_tokens, num_experts, topk) +ALL_TOPK_CASES = [ + (128, 32, 4), + (2048, 128, 8), +] +TOPK_CASES = { + "L0": ALL_TOPK_CASES[0:1], + "L2": ALL_TOPK_CASES, +} + +ALL_AUX_LOSS_CASES = [ + (128, 32, 4), + (2048, 128, 4), +] +AUX_LOSS_CASES = { + "L0": ALL_AUX_LOSS_CASES[0:1], + "L2": ALL_AUX_LOSS_CASES, +} + + +@pytest.mark.triton +class TestDistributedFusedTopk: + """Test distributed execution of fused_topk_with_score_function. + + Shards logits on the token dimension. Each GPU independently runs the + fused kernel on its local tokens. We compare against the reference + implementation run per-shard and concatenated. + """ + + def _impl_test( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ): + logits = make_logits(num_tokens, num_experts, score_function) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1 + local_num_tokens = num_tokens // num_dp_devices + + with mesh: + logits_sharding = NamedSharding(mesh, sharded_pspec) + logits_sharded = jax.device_put(logits, logits_sharding) + + # === Forward === + @jax.jit + def target_fwd(x): + return fused_topk_with_score_function( + x, + topk=topk, + score_function=score_function, + ) + + target_probs, target_routing_map = target_fwd(logits_sharded) + + logits_shards = jnp.reshape(logits, (num_dp_devices, local_num_tokens, num_experts)) + ref_fwd_fn = jax.jit( + lambda x: reference_topk_softmax_sigmoid( + x, + topk=topk, + score_function=score_function, + ) + ) + ref_probs_list = [] + ref_routing_list = [] + for i in range(num_dp_devices): + p, rm = ref_fwd_fn(logits_shards[i]) + ref_probs_list.append(p) + ref_routing_list.append(rm) + + ref_probs = jnp.concatenate(ref_probs_list, axis=0) + ref_routing = jnp.concatenate(ref_routing_list, axis=0) + + assert_allclose( + jax.device_get(target_probs), + ref_probs, + dtype=jnp.float32, + ) + assert jnp.array_equal( + jax.device_get(target_routing_map), + ref_routing, + ), "Routing map mismatch in distributed fused_topk" + + # === Backward === + def target_loss(x): + p, _ = fused_topk_with_score_function( + x, + topk=topk, + score_function=score_function, + ) + return jnp.sum(p) + + def ref_chunk_loss(x_chunk): + p, _ = reference_topk_softmax_sigmoid( + x_chunk, + topk=topk, + score_function=score_function, + ) + return jnp.sum(p) + + target_grad = jax.jit(jax.grad(target_loss))(logits_sharded) + + ref_grads = [] + ref_chunk_grad_fn = jax.jit(jax.grad(ref_chunk_loss)) + for i in range(num_dp_devices): + ref_grads.append(ref_chunk_grad_fn(logits_shards[i])) + ref_grad = jnp.concatenate(ref_grads, axis=0) + + assert_allclose( + jax.device_get(target_grad), + ref_grad, + dtype=jnp.float32, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + TOPK_CASES, + ) + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_distributed_topk( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ) + + +@pytest.mark.triton +class TestDistributedScoreForAuxLoss: + """Test distributed execution of fused_topk_with_score_function with compute_aux_scores=True. + + Same sharding strategy as fused_topk: shard on token dim, replicate experts. + Each GPU independently computes scores and routing map for its local tokens. + """ + + def _impl_test( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ): + logits = make_logits(num_tokens, num_experts, score_function) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1 + local_num_tokens = num_tokens // num_dp_devices + + with mesh: + logits_sharding = NamedSharding(mesh, sharded_pspec) + logits_sharded = jax.device_put(logits, logits_sharding) + + # === Forward === + @jax.jit + def target_fwd(x): + return fused_topk_with_score_function( + x, + topk=topk, + score_function=score_function, + compute_aux_scores=True, + ) + + target_scores, target_routing_map = target_fwd(logits_sharded) + + logits_shards = jnp.reshape(logits, (num_dp_devices, local_num_tokens, num_experts)) + ref_fwd_fn = jax.jit( + lambda x: reference_compute_scores_for_aux_loss( + x, + topk=topk, + score_function=score_function, + ) + ) + ref_routing_list = [] + ref_scores_list = [] + for i in range(num_dp_devices): + rm, s = ref_fwd_fn(logits_shards[i]) + ref_routing_list.append(rm) + ref_scores_list.append(s) + + ref_routing = jnp.concatenate(ref_routing_list, axis=0) + ref_scores = jnp.concatenate(ref_scores_list, axis=0) + + assert_allclose( + jax.device_get(target_scores), + ref_scores, + dtype=jnp.float32, + ) + assert jnp.array_equal( + jax.device_get(target_routing_map), + ref_routing, + ), "Routing map mismatch in distributed score_for_aux_loss" + + # === Backward === + def target_loss(x): + s, _ = fused_topk_with_score_function( + x, + topk=topk, + score_function=score_function, + compute_aux_scores=True, + ) + return jnp.sum(s) + + def ref_chunk_loss(x_chunk): + _, s = reference_compute_scores_for_aux_loss( + x_chunk, + topk=topk, + score_function=score_function, + ) + return jnp.sum(s) + + target_grad = jax.jit(jax.grad(target_loss))(logits_sharded) + + ref_grads = [] + ref_chunk_grad_fn = jax.jit(jax.grad(ref_chunk_loss)) + for i in range(num_dp_devices): + ref_grads.append(ref_chunk_grad_fn(logits_shards[i])) + ref_grad = jnp.concatenate(ref_grads, axis=0) + + assert_allclose( + jax.device_get(target_grad), + ref_grad, + dtype=jnp.float32, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + TOPK_CASES, + ) + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_distributed_score_for_aux_loss( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ) + + +@pytest.mark.triton +class TestDistributedMoEAuxLoss: + """Test distributed execution of fused_moe_aux_loss. + + Aux loss is a global reduction to a scalar. The partition function forces + all inputs to be replicated. We verify the op produces correct results + under a mesh context with replicated sharding, testing both forward + (scalar loss) and backward (gradient w.r.t. probs). + """ + + def _impl_test( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + ): + key = jax.random.PRNGKey(42) + _, subkey1, _ = jax.random.split(key, 3) + + offset = jnp.arange(-num_tokens // 2, num_tokens // 2, dtype=jnp.float32) * 1e-4 + probs = jnp.arange(-num_experts // 2, num_experts // 2, dtype=jnp.float32) * 1e-2 + probs = probs[None, :].repeat(num_tokens, axis=0) + offset[:, None] + + tokens_per_expert = jax.random.randint(subkey1, (num_experts,), 1, 1000).astype(jnp.int32) + coeff = 0.01 + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + replicated_2d_pspec = PartitionSpec(None, None) + replicated_1d_pspec = PartitionSpec(None) + + with mesh: + probs_sharding = NamedSharding(mesh, replicated_2d_pspec) + tpe_sharding = NamedSharding(mesh, replicated_1d_pspec) + + probs_dev = jax.device_put(probs, probs_sharding) + tpe_dev = jax.device_put(tokens_per_expert, tpe_sharding) + + # === Forward === + @jax.jit + def target_fwd(p, tpe): + return fused_moe_aux_loss(p, tpe, topk=topk, coeff=coeff) + + target_loss = target_fwd(probs_dev, tpe_dev) + + ref_fwd_fn = jax.jit( + lambda p: reference_aux_loss( + p, + tokens_per_expert, + num_tokens, + topk, + num_experts, + coeff, + ) + ) + ref_loss = ref_fwd_fn(probs) + + assert_allclose( + jax.device_get(target_loss), + ref_loss, + dtype=jnp.float32, + ) + + # === Backward === + def target_loss_fn(p): + return fused_moe_aux_loss( + p, + tokens_per_expert, + topk=topk, + coeff=coeff, + ) + + def ref_loss_fn(p): + return reference_aux_loss( + p, + tokens_per_expert, + num_tokens, + topk, + num_experts, + coeff, + ) + + target_grad = jax.jit(jax.grad(target_loss_fn))(probs_dev) + ref_grad = jax.jit(jax.grad(ref_loss_fn))(probs) + + assert_allclose( + jax.device_get(target_grad), + ref_grad, + dtype=jnp.float32, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + AUX_LOSS_CASES, + ) + def test_distributed_aux_loss( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + ) diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 0665baa4e..ca1dcf117 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -87,12 +87,9 @@ def impl_test_softmax( dtype, bad_sharding, broadcast_batch_mask, - use_shardy, ): if broadcast_batch_mask and softmax_fusion_type != SoftmaxFusionType.SCALED_MASKED: pytest.skip("Softmax type has no mask.") - - jax.config.update("jax_use_shardy_partitioner", use_shardy) target_func = partial( self.target_func, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type ) @@ -181,35 +178,4 @@ def test_softmax( dtype, bad_sharding, broadcast_batch_mask, - use_shardy=True, - ) - - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize( - "softmax_fusion_type", [SoftmaxFusionType.SCALED, SoftmaxFusionType.SCALED_MASKED] - ) - @pytest.mark.parametrize("bad_sharding", [False, True]) - @pytest.mark.parametrize("broadcast_batch_mask", [False, True]) - def test_softmax_gspmd( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - softmax_fusion_type, - bad_sharding, - broadcast_batch_mask, - ): - self.impl_test_softmax( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape=[32, 12, 128, 128], - softmax_fusion_type=softmax_fusion_type, - scale_factor=1.0, - dtype=DTYPES[0], - bad_sharding=bad_sharding, - broadcast_batch_mask=broadcast_batch_mask, - use_shardy=False, ) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 8639af79b..7890ee3d3 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -4,6 +4,7 @@ # # See LICENSE for license information. """Tests for fused attention""" +import os from enum import Enum, auto from dataclasses import dataclass, field from functools import partial @@ -52,6 +53,9 @@ from distributed_test_base import assert_equal_collectives from utils import assert_allclose, print_debug_tensor_stats +# Get determinism +_deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + @pytest.fixture(autouse=True, scope="module") def init(): @@ -417,16 +421,25 @@ def _check_configs(self): pytest.skip( "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) - # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support - if ( - get_device_compute_capability(0) >= 100 - and self.dropout_prob == 0.1 - and self.attn_bias_type is not AttnBiasType.NO_BIAS - and not is_hip_extension() - ): - pytest.skip( - "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" - ) + + if not is_hip_extension() and get_device_compute_capability(0) >= 100 and self.is_training: + if FusedAttnHelper.is_non_deterministic_allowed() and ( + (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS) + or get_cudnn_version() < 90700 + ): + pytest.skip( + "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with" + " dropout" + ) + if not FusedAttnHelper.is_non_deterministic_allowed() and ( + self.dropout_prob != 0.0 + or self.attn_bias_type != AttnBiasType.NO_BIAS + or get_cudnn_version() < 91801 + ): + pytest.skip( + "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or" + " dropout" + ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate(): @@ -1065,7 +1078,7 @@ def check_dqkv(primitive, reference, pad, idx): # Assume all batch has the same actual_seqlen, probably needs to extend the tests bias_mask = self.mask[0, 0] - + # Assert all masked dbias are 0s assert_allclose( jnp.where(bias_mask, primitive_dbias, 0), @@ -1346,6 +1359,7 @@ def check_dqkv(primitive, reference, pad, idx): pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"), ], ) +@pytest.mark.skipif(_deterministic, reason="Test non-determinism only") class TestFusedAttn: """ Fused attention tester @@ -1474,6 +1488,185 @@ def test_backward( ) runner.test_backward() + +@pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), + pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"), + pytest.param( + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT" + ), + ], +) +@pytest.mark.parametrize( + "softmax_type", + [ + pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"), + ], +) +@pytest.mark.parametrize( + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", + [ + # large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate + pytest.param( + 2, + 1024, + 2048, + 12, + 6, + 128, + 64, + jnp.bfloat16, + QKVLayout.BSHD_BSHD_BSHD, + id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-SEPARATE", + ), + pytest.param( + 2, + 1024, + 2048, + 12, + 6, + 128, + 64, + jnp.bfloat16, + QKVLayout.THD_THD_THD, + id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE", + ), + ], +) +@pytest.mark.parametrize( + "dropout_prob", + [ + pytest.param(0.0, id="DROP_0.0"), + ], +) +@pytest.mark.parametrize( + "swa", + [ + pytest.param(False, id="NO_SWA"), + ], +) +@pytest.mark.parametrize( + "seq_desc_format", + [ + pytest.param(SeqDescFormat.Seqlens, id="Seqlens"), + ], +) +@pytest.mark.skipif(not _deterministic, reason="Test determinism only") +class TestFusedAttnWithDeterminism: + """ + Fused attention tester with determinism + """ + + @staticmethod + @pytest.mark.parametrize( + "is_training", + [ + pytest.param(True, id="TRAINING"), + ], + ) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + def _test_forward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ): + """ + Test forward with parameterized configs + This test is not intended to run automatically during CI as it is time-consuming + It is kept for development and debugging + """ + TestFusedAttn._test_forward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ) + + @staticmethod + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + def test_backward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ): + """ + Test backward with parameterized configs + """ + TestFusedAttn.test_backward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ) + # Single test with new-style RNG @pytest.mark.skipif( not is_hip_extension(), reason="New-style RNGs only enabled on AMD hardware" diff --git a/tests/jax/test_fused_router.py b/tests/jax/test_fused_router.py new file mode 100644 index 000000000..89a32f1ce --- /dev/null +++ b/tests/jax/test_fused_router.py @@ -0,0 +1,561 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for fused MoE router CUDA kernels (JAX wrappers).""" + +import sys +from functools import partial +from typing import Optional + +import jax +import jax.numpy as jnp +import pytest + +from utils import pytest_parametrize_wrapper + + +@pytest.fixture(autouse=True, scope="function") +def _inject_router(request): + """Lazy-load router API only for tests marked 'triton'. Other tests run without importing. + + We inject into sys.modules[__name__] so test code can use fused_topk_with_score_function, + fused_moe_aux_loss as module-level names (fixture locals are not visible to tests). + """ + if not request.node.get_closest_marker("triton"): + yield + return + from transformer_engine.jax.router import ( + fused_topk_with_score_function, + fused_moe_aux_loss, + ) + + mod = sys.modules[__name__] + mod.fused_topk_with_score_function = fused_topk_with_score_function + mod.fused_moe_aux_loss = fused_moe_aux_loss + yield + + +# ============================================================================= +# Test case definitions (L0 = fast smoke, L2 = comprehensive) +# ============================================================================= + +# (num_tokens, num_experts, topk) +ALL_TOPK_CASES = [ + (128, 32, 4), + (2048, 32, 4), + (2048, 128, 8), + (7168, 128, 4), + (7168, 32, 8), +] +TOPK_CASES = { + "L0": ALL_TOPK_CASES[0:2], + "L2": ALL_TOPK_CASES, +} + +ALL_GROUP_TOPK_OPTIONS = [None, 4] +GROUP_TOPK_OPTIONS = { + "L0": [None], + "L2": ALL_GROUP_TOPK_OPTIONS, +} + +ALL_SCALING_FACTOR_OPTIONS = [None, 1.2] +SCALING_FACTOR_OPTIONS = { + "L0": [None], + "L2": ALL_SCALING_FACTOR_OPTIONS, +} + +ALL_ENABLE_BIAS_OPTIONS = [True, False] +ENABLE_BIAS_OPTIONS = { + "L0": [False], + "L2": ALL_ENABLE_BIAS_OPTIONS, +} + +ALL_USE_PRE_SOFTMAX_OPTIONS = [True, False] +USE_PRE_SOFTMAX_OPTIONS = { + "L0": [False], + "L2": ALL_USE_PRE_SOFTMAX_OPTIONS, +} + +# (num_tokens, num_experts, topk) +ALL_SCORE_AUX_LOSS_CASES = [ + (128, 32, 4), + (2048, 128, 4), + (2048, 256, 8), + (7168, 128, 8), + (7168, 32, 4), +] +SCORE_AUX_LOSS_CASES = { + "L0": ALL_SCORE_AUX_LOSS_CASES[0:2], + "L2": ALL_SCORE_AUX_LOSS_CASES, +} + +ALL_SCORE_FUNCTIONS = ["softmax", "sigmoid"] +SCORE_FUNCTIONS = { + "L0": ["softmax"], + "L2": ALL_SCORE_FUNCTIONS, +} + +# (num_tokens, num_experts, topk) +ALL_AUX_LOSS_CASES = [ + (128, 32, 4), + (2048, 128, 4), + (2048, 256, 4), + (7168, 128, 4), + (7168, 32, 4), +] +AUX_LOSS_CASES = { + "L0": ALL_AUX_LOSS_CASES[0:2], + "L2": ALL_AUX_LOSS_CASES, +} + +ALL_DTYPES = [jnp.float32] +DTYPES = { + "L0": [jnp.float32], + "L2": ALL_DTYPES, +} + +SEED = 42 + + +# ============================================================================= +# Reference Implementations +# ============================================================================= + + +def reference_group_limited_topk( + scores: jnp.ndarray, + topk: int, + num_tokens: int, + num_experts: int, + num_groups: int, + group_topk: int, +): + """Reference implementation for grouped top-k. + + Only valid when num_groups and group_topk are both positive integers. + For plain top-k without grouping, use jax.lax.top_k directly. + """ + assert num_groups is not None and num_groups > 0, ( + "reference_group_limited_topk requires valid num_groups > 0. " + "For plain top-k, use jax.lax.top_k directly." + ) + assert ( + group_topk is not None and group_topk > 0 + ), "reference_group_limited_topk requires valid group_topk > 0." + assert ( + num_experts % num_groups == 0 + ), f"num_experts ({num_experts}) must be divisible by num_groups ({num_groups})" + group_size = num_experts // num_groups + experts_per_group = topk // group_topk + + group_scores = ( + scores.reshape(num_tokens, num_groups, group_size) + .sort(axis=-1)[..., -experts_per_group:] + .sum(axis=-1) + ) + group_idx = jax.lax.top_k(group_scores, k=group_topk)[1] + group_mask = jnp.zeros_like(group_scores).at[jnp.arange(num_tokens)[:, None], group_idx].set(1) + + score_mask = (group_mask[:, :, None] * jnp.ones((num_tokens, num_groups, group_size))).reshape( + num_tokens, -1 + ) + + masked_scores = jnp.where(score_mask.astype(bool), scores, -jnp.inf) + probs, top_indices = jax.lax.top_k(masked_scores, k=topk) + return probs, top_indices + + +def reference_topk_softmax_sigmoid( + logits: jnp.ndarray, + topk: int, + use_pre_softmax: bool = False, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + scaling_factor: Optional[float] = None, + score_function: str = "softmax", + expert_bias: Optional[jnp.ndarray] = None, +): + """Reference implementation for topk + softmax/sigmoid.""" + num_tokens, num_experts = logits.shape + + def compute_topk(scores, topk, num_groups=None, group_topk=None): + if group_topk: + return reference_group_limited_topk( + scores=scores, + topk=topk, + num_tokens=num_tokens, + num_experts=num_experts, + num_groups=num_groups, + group_topk=group_topk, + ) + else: + return jax.lax.top_k(scores, k=topk) + + if score_function == "softmax": + if use_pre_softmax: + scores = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(logits.dtype) + probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) + else: + scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) + probs = jax.nn.softmax(scores.astype(jnp.float32), axis=-1).astype(logits.dtype) + elif score_function == "sigmoid": + scores = jax.nn.sigmoid(logits.astype(jnp.float32)).astype(logits.dtype) + if expert_bias is not None: + scores_for_routing = scores + expert_bias + _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + scores = jnp.take_along_axis(scores, top_indices, axis=1).astype(logits.dtype) + else: + scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) + probs = scores / (scores.sum(axis=-1, keepdims=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + if scaling_factor: + probs = probs * scaling_factor + + topk_masked_gates = ( + jnp.zeros_like(logits).at[jnp.arange(num_tokens)[:, None], top_indices].set(probs) + ) + topk_map = ( + jnp.zeros_like(logits, dtype=jnp.bool_) + .at[jnp.arange(num_tokens)[:, None], top_indices] + .set(True) + ) + + return topk_masked_gates, topk_map + + +def reference_compute_scores_for_aux_loss(logits: jnp.ndarray, topk: int, score_function: str): + """Reference implementation for computing routing scores for aux loss.""" + if score_function == "softmax": + scores = jax.nn.softmax(logits.astype(jnp.float32), axis=-1) + elif score_function == "sigmoid": + scores = jax.nn.sigmoid(logits.astype(jnp.float32)) + scores = scores / (scores.sum(axis=-1, keepdims=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + _, top_indices = jax.lax.top_k(scores, k=topk) + num_tokens = logits.shape[0] + routing_map = ( + jnp.zeros_like(logits, dtype=jnp.bool_) + .at[jnp.arange(num_tokens)[:, None], top_indices] + .set(True) + ) + return routing_map, scores + + +def reference_aux_loss( + probs: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + total_num_tokens: int, + topk: int, + num_experts: int, + moe_aux_loss_coeff: float, +): + """Reference implementation for MoE auxiliary loss.""" + aggregated_probs_per_expert = probs.sum(axis=0) + aux_loss = jnp.sum(aggregated_probs_per_expert * tokens_per_expert) * ( + num_experts * moe_aux_loss_coeff / (topk * total_num_tokens * total_num_tokens) + ) + return aux_loss + + +# ============================================================================= +# Helper: logits generation +# ============================================================================= + + +def make_logits(num_tokens, num_experts, score_function, dtype=jnp.float32): + """Create deterministic logits for testing.""" + if score_function == "sigmoid": + offset = jnp.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype) * 1e-4 + logits = jnp.arange(-num_experts // 2, num_experts // 2, dtype=dtype) * 1e-2 + logits = logits[None, :].repeat(num_tokens, axis=0) + offset[:, None] + else: + logits = ( + jnp.arange( + -num_tokens * num_experts // 2, + num_tokens * num_experts // 2, + dtype=dtype, + ) + * 1e-4 + ) + logits = logits.reshape(num_tokens, num_experts) + return logits + + +# ============================================================================= +# Test: Fused Top-K with Score Function +# ============================================================================= + + +def run_topk_comparison( + dtype, + num_tokens, + num_experts, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + enable_bias, +): + """Compare fused vs reference top-k implementation, both jitted.""" + logits = make_logits(num_tokens, num_experts, score_function, dtype) + + if enable_bias and score_function == "sigmoid": + expert_bias = jnp.arange(num_experts, dtype=jnp.float32) * 0.1 + expert_bias = jnp.flip(expert_bias) + else: + expert_bias = None + + # Forward: reference (jitted) + ref_fwd_fn = jax.jit( + partial( + reference_topk_softmax_sigmoid, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=expert_bias, + ) + ) + probs_ref, routing_map_ref = ref_fwd_fn(logits) + + # Forward: fused (jitted) + fused_fwd_fn = jax.jit( + partial( + fused_topk_with_score_function, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups if num_groups else -1, + group_topk=group_topk if group_topk else -1, + scaling_factor=scaling_factor if scaling_factor else 1.0, + score_function=score_function, + expert_bias=expert_bias, + ) + ) + probs_fused, routing_map_fused = fused_fwd_fn(logits) + + assert jnp.allclose( + probs_ref, probs_fused, atol=1e-5, rtol=1e-5 + ), f"Probs mismatch: max diff = {jnp.abs(probs_ref - probs_fused).max()}" + assert jnp.array_equal(routing_map_ref, routing_map_fused), "Routing map mismatch" + + # Backward: reference (jitted) + def loss_ref(logits_): + p, _ = reference_topk_softmax_sigmoid( + logits_, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + ) + return p.sum() + + def loss_fused(logits_): + p, _ = fused_topk_with_score_function( + logits_, + topk, + use_pre_softmax, + num_groups if num_groups else -1, + group_topk if group_topk else -1, + scaling_factor if scaling_factor else 1.0, + score_function, + expert_bias, + ) + return p.sum() + + grad_ref = jax.jit(jax.grad(loss_ref))(logits) + grad_fused = jax.jit(jax.grad(loss_fused))(logits) + assert jnp.allclose( + grad_ref, grad_fused, atol=1e-5, rtol=1e-5 + ), f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + TOPK_CASES, +) +@pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) +@pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS) +@pytest_parametrize_wrapper("enable_bias", ENABLE_BIAS_OPTIONS) +@pytest.mark.triton +def test_topk_sigmoid( + dtype, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias +): + num_groups = 8 if group_topk else None + run_topk_comparison( + dtype=dtype, + num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + use_pre_softmax=False, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function="sigmoid", + enable_bias=enable_bias, + ) + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + TOPK_CASES, +) +@pytest_parametrize_wrapper("use_pre_softmax", USE_PRE_SOFTMAX_OPTIONS) +@pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) +@pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS) +@pytest.mark.triton +def test_topk_softmax( + dtype, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor +): + num_groups = 8 if group_topk else None + run_topk_comparison( + dtype=dtype, + num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function="softmax", + enable_bias=False, + ) + + +# ============================================================================= +# Test: Fused Score for MoE Aux Loss +# ============================================================================= + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + SCORE_AUX_LOSS_CASES, +) +@pytest_parametrize_wrapper("score_function", SCORE_FUNCTIONS) +@pytest.mark.triton +def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): + logits = make_logits(num_tokens, num_experts, score_function, dtype) + + # Forward: reference (jitted) + ref_fwd_fn = jax.jit( + partial( + reference_compute_scores_for_aux_loss, + topk=topk, + score_function=score_function, + ) + ) + routing_map_ref, scores_ref = ref_fwd_fn(logits) + + # Forward: fused (jitted) + fused_fwd_fn = jax.jit( + partial( + fused_topk_with_score_function, + topk=topk, + score_function=score_function, + compute_aux_scores=True, + ) + ) + scores_fused, routing_map_fused = fused_fwd_fn(logits) + + assert jnp.allclose( + scores_ref, scores_fused, atol=1e-5, rtol=1e-5 + ), f"Scores mismatch: max diff = {jnp.abs(scores_ref - scores_fused).max()}" + assert jnp.array_equal(routing_map_ref, routing_map_fused), "Routing map mismatch" + + # Backward (jitted) + def loss_ref(logits_): + _, s = reference_compute_scores_for_aux_loss(logits_, topk, score_function) + return s.sum() + + def loss_fused(logits_): + s, _ = fused_topk_with_score_function( + logits_, + topk, + score_function=score_function, + compute_aux_scores=True, + ) + return s.sum() + + grad_ref = jax.jit(jax.grad(loss_ref))(logits) + grad_fused = jax.jit(jax.grad(loss_fused))(logits) + assert jnp.allclose( + grad_ref, grad_fused, atol=1e-5, rtol=1e-5 + ), f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" + + +# ============================================================================= +# Test: Fused MoE Aux Loss +# ============================================================================= + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + AUX_LOSS_CASES, +) +@pytest.mark.triton +def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): + key = jax.random.PRNGKey(SEED) + + offset = jnp.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype) * 1e-4 + probs = jnp.arange(-num_experts // 2, num_experts // 2, dtype=dtype) * 1e-2 + probs = probs[None, :].repeat(num_tokens, axis=0) + offset[:, None] + probs = probs.reshape(num_tokens, num_experts) + + tokens_per_expert = jax.random.randint(key, (num_experts,), 1, 1000).astype(jnp.int32) + coeff = 0.01 + + # Forward: reference (jitted) + ref_fwd_fn = jax.jit( + partial( + reference_aux_loss, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + topk=topk, + num_experts=num_experts, + moe_aux_loss_coeff=coeff, + ) + ) + aux_loss_ref = ref_fwd_fn(probs) + + # Forward: fused (jitted) + fused_fwd_fn = jax.jit( + partial( + fused_moe_aux_loss, + tokens_per_expert=tokens_per_expert, + topk=topk, + coeff=coeff, + ) + ) + aux_loss_fused = fused_fwd_fn(probs) + + assert jnp.allclose( + aux_loss_ref, aux_loss_fused, atol=1e-5, rtol=1e-5 + ), f"Aux loss mismatch: ref={aux_loss_ref}, fused={aux_loss_fused}" + + # Backward (jitted) + def loss_ref_fn(probs_): + return reference_aux_loss(probs_, tokens_per_expert, num_tokens, topk, num_experts, coeff) + + def loss_fused_fn(probs_): + return fused_moe_aux_loss(probs_, tokens_per_expert, topk, coeff) + + grad_ref = jax.jit(jax.grad(loss_ref_fn))(probs) + grad_fused = jax.jit(jax.grad(loss_fused_fn))(probs) + assert jnp.allclose( + grad_ref, grad_fused, atol=1e-5, rtol=1e-5 + ), f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 5bb59c6ed..38fbee18e 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -5,25 +5,46 @@ """Tests for permutation Triton kernels and high-level APIs""" import functools +import sys import jax import jax.numpy as jnp import pytest -# High-level API with VJP support -from transformer_engine.jax.permutation import ( - token_dispatch, - token_combine, - sort_chunks_by_index, -) from utils import assert_allclose, pytest_parametrize_wrapper +@pytest.fixture(autouse=True, scope="function") +def _inject_permutation(request): + """Lazy-load permutation API only for tests marked 'triton'. Other tests run without importing. + + We inject into sys.modules[__name__] so that test code in this module can use + token_dispatch, token_combine, etc. as module-level names. A plain import inside + this fixture would only bind those names in the fixture's local scope; the test + methods (e.g. in TestHighLevelPermutationAPI) reference them as globals, so they + must exist on the module's namespace. + """ + if not request.node.get_closest_marker("triton"): + yield + return + from transformer_engine.jax.permutation import ( + token_dispatch, + token_combine, + sort_chunks_by_index, + ) + + mod = sys.modules[__name__] + mod.token_dispatch = token_dispatch + mod.token_combine = token_combine + mod.sort_chunks_by_index = sort_chunks_by_index + yield + + ALL_DISPATCH_COMBINE_CASES = [ (128, 5, 128, 3), (1024, 8, 128, 8), (4096, 32, 1280, 2), - (4096, 256, 4096, 6), + (4096, 64, 4096, 6), ] DISPATCH_COMBINE_CASES = { "L0": ALL_DISPATCH_COMBINE_CASES[0:2], @@ -44,7 +65,7 @@ (128, 5, 128, 3, 8), (1024, 8, 128, 8, 16), (4096, 32, 1280, 2, 128), - (4096, 256, 4096, 6, 16), + (4096, 64, 4096, 6, 16), ] DISPATCH_COMBINE_PADDING_CASES = { "L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2], @@ -449,6 +470,7 @@ def reference_sort_chunks_by_map( return output, permuted_probs +@pytest.mark.triton class TestHighLevelPermutationAPI: """Test high-level permutation APIs (token_dispatch, token_combine, etc.) diff --git a/tests/jax/test_triton_custom_calls.py b/tests/jax/test_triton_custom_calls.py index 6d969de0d..846d26a41 100644 --- a/tests/jax/test_triton_custom_calls.py +++ b/tests/jax/test_triton_custom_calls.py @@ -7,7 +7,9 @@ import jax.numpy as jnp import pytest -from utils import assert_allclose, pytest_parametrize_wrapper +from utils import assert_allclose, pytest_parametrize_wrapper, require_triton_or_skip_test_file + +require_triton_or_skip_test_file() import triton import triton.language as tl @@ -23,6 +25,7 @@ def init(): yield +@pytest.mark.triton class TestTritonBinding: """Test Triton binding primitive.""" diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 373f0a938..67d160b73 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -28,9 +28,13 @@ canonicalize_attn_mask_type, make_swa_mask, ) +from transformer_engine.jax.cpp_extensions.misc import is_hip_extension from transformer_engine.jax.quantize.helper import DType as TEDType +from transformer_engine.jax.version_utils import ( + TRITON_EXTENSION_MIN_JAX_VERSION, + is_triton_extension_supported, +) from transformer_engine.jax.util import get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type -from transformer_engine.jax.cpp_extensions.misc import is_hip_extension PRNGKey = Any Shape = Tuple[int, ...] @@ -45,6 +49,17 @@ NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0))) +def require_triton_or_skip_test_file(): + """Skip the current test file if JAX is too old for Triton kernel support (calls pytest.skip).""" + if not is_triton_extension_supported(): + pytest.skip( + f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for Triton kernel support. " + "Triton kernel dispatch segfaults with older jaxlib. " + "Upgrade with: pip install --upgrade jax jaxlib", + allow_module_level=True, + ) + + def is_devices_enough(required): """ Check if the available GPUs is enough diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 77af73830..cb7c636ee 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -190,10 +190,13 @@ def run_dpa_with_cp( fp8_mha="False", scaling_mode="delayed", f16_O="False", + is_training="True", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" logging.root.setLevel(log_level) + # When is_training is False, gradient outputs are None. + is_training = is_training == "True" # set up environment variables and config fp8_bwd = fp8_bwd == "True" and dtype == "fp8" @@ -268,7 +271,9 @@ def run_dpa_with_cp( softmax_type=config.softmax_type, return_max_logit=config.return_max_logit, ).cuda() - if config.softmax_type != "vanilla": + if not is_training: + core_attn.eval() + if is_training and config.softmax_type != "vanilla": core_attn.softmax_offset.requires_grad = True # generate attention inputs @@ -316,8 +321,25 @@ def run_dpa_with_cp( x.requires_grad = True if config.attn_bias_type not in ["no_bias", "alibi"]: - attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) + bias_shape_map = { + "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), + "11ss": (1, 1, config.max_seqlen_q, config.max_seqlen_kv), + "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), + "bhss": ( + config.batch_size, + config.num_heads, + config.max_seqlen_q, + config.max_seqlen_kv, + ), + "111s": (1, 1, 1, config.max_seqlen_kv), + } + attn_bias_shape = bias_shape_map.get(config.bias_shape) + if attn_bias_shape is None: + assert False, f"cuDNN does not support {config.bias_shape=}" bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() + # cuDNN does not support dbias calculation for 111s as of cuDNN 9.18 + # TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported + bias.requires_grad = True if config.bias_shape != "111s" else False else: bias = None @@ -344,15 +366,20 @@ def run_dpa_with_cp( ) if config.return_max_logit: out, max_logit = out - if fp8_bwd and fp8_mha: - dout_fp8 = dout_quantizer(dout) - out.backward(dout_fp8) - else: - out.backward(dout) - dq, dk, dv = q.grad, k.grad, v.grad - d_softmax_offset = None - if config.softmax_type != "vanilla": - d_softmax_offset = core_attn.softmax_offset.grad + if is_training: + if fp8_bwd and fp8_mha: + dout_fp8 = dout_quantizer(dout) + out.backward(dout_fp8) + else: + out.backward(dout) + if is_training: + dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None + d_softmax_offset = ( + core_attn.softmax_offset.grad if config.softmax_type != "vanilla" else None + ) + else: + dq, dk, dv, dbias = None, None, None, None + d_softmax_offset = None ############ run with CP ############ logging.info(f"[Rank {rank}] Run with context parallelism") @@ -398,13 +425,30 @@ def run_dpa_with_cp( dout_quantizer.amax.fill_(0.0) if fp8_mha: q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) - q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] + if is_training: + q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: - bias_ = bias_.view( - *bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1] - ) - bias_ = bias_.index_select(2, seq_idx) - bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) + ndim = bias_.ndim + seq_q_dim = ndim - 2 + if qkv_format == "thd": + bias_seq_idx = seq_idx_q + else: + bias_seq_idx = seq_idx + shape_before_seq = bias_.shape[:seq_q_dim] + seq_q_size = bias_.shape[seq_q_dim] + seq_kv_size = bias_.shape[-1] + if seq_q_size == 1: + # TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s + bias_.requires_grad = False + # Bias is broadcast, no need to partition along sequence dimension + pass + else: + bias_ = bias_.view( + *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + ) + bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) + bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) + bias_.requires_grad = True # set up environment core_attn.set_context_parallel_group( cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, @@ -439,96 +483,149 @@ def run_dpa_with_cp( ) if config.return_max_logit: out_, max_logit_ = out_ - if fp8_bwd and fp8_mha: - dout_fp8_ = dout_quantizer(dout_) - out_.backward(dout_fp8_) - else: - out_.backward(dout_) - dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad - d_softmax_offset_ = None - if config.softmax_type != "vanilla": - d_softmax_offset_ = core_attn.softmax_offset.grad.clone() + if is_training: + if fp8_bwd and fp8_mha: + dout_fp8_ = dout_quantizer(dout_) + out_.backward(dout_fp8_) + else: + out_.backward(dout_) + if is_training: + dq_, dk_, dv_, dbias_ = ( + q_.grad, + k_.grad, + v_.grad, + bias_.grad if bias_ is not None else None, + ) + d_softmax_offset_ = ( + core_attn.softmax_offset.grad.clone() if config.softmax_type != "vanilla" else None + ) + else: + dq_, dk_, dv_, dbias_ = None, None, None, None + d_softmax_offset_ = None # get outputs - tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_] + tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] if fp8_mha: tensors_to_deq = [out, out_] if not fp8_bwd else tensors for i, tensor in enumerate(tensors_to_deq): - tensors_to_deq[i] = tensor.dequantize() + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: - tensors[0], tensors[4] = tensors_to_deq - for tensor in tensors[4:]: - assert torch.all(~torch.isnan(tensor)) - assert torch.all(~torch.isinf(tensor)) - out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors + tensors[0], tensors[5] = tensors_to_deq + for tensor in tensors: + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + assert torch.all(~torch.isnan(tensor)) + assert torch.all(~torch.isinf(tensor)) + out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ if qkv_format == "bshd" or qkv_format == "sbhd": - dq, dk, dv, out = [ - x.view( - *x.shape[:seq_dim], + if is_training: + dq, dk, dv, out = [ + x.view( + *x.shape[:seq_dim], + 2 * world_size, + x.shape[seq_dim] // (2 * world_size), + *x.shape[(seq_dim + 1) :], + ) + for x in [dq, dk, dv, out] + ] + dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] + dq_, dk_, dv_, out_ = [ + x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) + for x in [dq_, dk_, dv_, out_] + ] + if dbias is not None and dbias_ is not None: + ndim = dbias.ndim + # Query seq is at dim -2 + seq_q_dim = ndim - 2 + shape_before_seq = dbias.shape[:seq_q_dim] + seq_q_size = dbias.shape[seq_q_dim] + seq_kv_size = dbias.shape[-1] + # Reshape to split seq_q dimension + dbias = dbias.view( + *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + ) + # Index select on the newly created dimension (now at position seq_q_dim) + dbias = dbias.index_select(seq_q_dim, seq_idx) + dbias_ = dbias_.view( + *shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size + ) + else: + # Forward-only: reshape only out/out_ for comparison + out = out.view( + *out.shape[:seq_dim], 2 * world_size, - x.shape[seq_dim] // (2 * world_size), - *x.shape[(seq_dim + 1) :], + out.shape[seq_dim] // (2 * world_size), + *out.shape[(seq_dim + 1) :], ) - for x in [dq, dk, dv, out] - ] - dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] - dq_, dk_, dv_, out_ = [ - x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) - for x in [dq_, dk_, dv_, out_] - ] + out = out.index_select(seq_dim, seq_idx) + out_ = out_.view( + *out_.shape[:seq_dim], 2, out_.shape[seq_dim] // 2, *out_.shape[(seq_dim + 1) :] + ) + elif qkv_format == "thd": - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded // world_size - cu_seqlens_q = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True - ) - cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q - num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] - for x in [dq, out, dq_, out_]: - if IS_HIP_EXTENSION and torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() != 0: - warnings.warn(f"Rank:{rank} non-zero elements in padding region") - x[cu_seqlens_q_padded[-1] :] = 0 - assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_q[b] == 0 - or torch.count_nonzero( - x[(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[b + 1]] - ).item() - == 0 - ) - cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size - cu_seqlens_kv = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True - ) - cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv - num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] - for x in [dk, dv, dk_, dv_]: - if IS_HIP_EXTENSION and torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() != 0: - warnings.warn(f"Rank:{rank} non-zero elements in padding region") - x[cu_seqlens_kv_padded[-1] :] = 0 - assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_kv[b] == 0 - or torch.count_nonzero( - x[ - (cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]) : cu_seqlens_kv_padded[ - b + 1 + if is_training: + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] + dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] + cu_seqlens_q_padded = cu_seqlens_q_padded // world_size + cu_seqlens_q = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True + ) + cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q + num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] + for x in [dq, out, dq_, out_]: + if IS_HIP_EXTENSION and torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() != 0: + warnings.warn(f"Rank:{rank} non-zero elements in padding region") + x[cu_seqlens_q_padded[-1] :] = 0 + assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_q[b] == 0 + or torch.count_nonzero( + x[ + (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ + b + 1 + ] ] - ] - ).item() - == 0 - ) + ).item() + == 0 + ) + cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size + cu_seqlens_kv = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True + ) + cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv + num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] + for x in [dk, dv, dk_, dv_]: + if IS_HIP_EXTENSION and torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() != 0: + warnings.warn(f"Rank:{rank} non-zero elements in padding region") + x[cu_seqlens_kv_padded[-1] :] = 0 + assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_kv[b] == 0 + or torch.count_nonzero( + x[ + ( + cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] + ) : cu_seqlens_kv_padded[b + 1] + ] + ).item() + == 0 + ) + else: + # Forward-only: reshape only out/out_ for comparison + out = out.index_select(0, seq_idx_q).contiguous() + out_ = out_ atol, rtol, rmse_tol = get_tols(config, dtype) - tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_] - tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit] - names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"] + tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] + tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] + names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"] names_cp = [x + "_cp" for x in names] names_no_cp = [x + "_no_cp" for x in names] is_fp8 = dtype == "fp8" @@ -536,47 +633,113 @@ def run_dpa_with_cp( if t is not None: if "softmax_offset" not in names[i] and "max_logit" not in names[i]: if qkv_format == "bshd": - compare_and_assert( - t[:, 0], - tensors_cp[i][:, 0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[:, 1], - tensors_cp[i][:, 1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) + # Compare the two sequence chunks separately + # Compare dbias + if names[i] == "dbias": + # Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 + ndim_bias = t.ndim + slice_0 = [slice(None)] * ndim_bias + slice_0[seq_q_dim_bias] = 0 + slice_1 = [slice(None)] * ndim_bias + slice_1[seq_q_dim_bias] = 1 + compare_and_assert( + t[tuple(slice_0)], + tensors_cp[i][tuple(slice_0)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[tuple(slice_1)], + tensors_cp[i][tuple(slice_1)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare the two chunks along dimension 1 (the split sequence dimension) + compare_and_assert( + t[:, 0], + tensors_cp[i][:, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, 1], + tensors_cp[i][:, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) elif qkv_format == "sbhd": - compare_and_assert( - t[0], - tensors_cp[i][0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[1], - tensors_cp[i][1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) + # Compare the two sequence chunks separately + # Compare dbias (same as BSHD) + if names[i] == "dbias": + # Same as bshd: Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 + ndim_bias = t.ndim + slice_0 = [slice(None)] * ndim_bias + slice_0[seq_q_dim_bias] = 0 + slice_1 = [slice(None)] * ndim_bias + slice_1[seq_q_dim_bias] = 1 + compare_and_assert( + t[tuple(slice_0)], + tensors_cp[i][tuple(slice_0)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[tuple(slice_1)], + tensors_cp[i][tuple(slice_1)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare the two chunks along dimension 0 (the split sequence dimension) + compare_and_assert( + t[0], + tensors_cp[i][0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[1], + tensors_cp[i][1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) elif qkv_format == "thd": compare_and_assert( t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 5eff52c45..a2a8eaa10 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -48,6 +48,7 @@ scaled_init_method_normal, ) from transformer_engine.pytorch.utils import get_cudnn_version +from transformer_engine.pytorch.constants import FP8BwdTensorIdx, FP8FwdTensorIdx import transformer_engine_torch as tex from transformer_engine.pytorch.quantized_tensor import ( Quantizer, @@ -76,6 +77,14 @@ f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}" ) + +# Get determinism +_deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() +) + + # Reset RNG seed and states seed = 1234 reset_rng_states() @@ -223,6 +232,7 @@ def test_dot_product_attention( if config.window_size == (-1, -1) and swa: config.window_size = [2, 2] + config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] if qkv_format == "thd" and "padding" not in config.attn_mask_type: @@ -231,15 +241,26 @@ def test_dot_product_attention( ) # Get backends + # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. + # For all other shapes test fwd+bwd is_training = True + # TODO(KshitijLakhani): Set is_training to True for all cases once cuDNN supports dbias for 111s. + if config.bias_shape == "111s": + is_training = False + logging.info( + "Setting is_training to False as cuDNN does not support dbias for" + f" {config.bias_shape=} " + ) available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, pad_between_seqs=pad_between_seqs, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported: is_training = False available_backends, _, fused_attn_backends = get_available_attention_backends( @@ -248,6 +269,7 @@ def test_dot_product_attention( qkv_layout=qkv_layout, pad_between_seqs=pad_between_seqs, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends @@ -529,6 +551,15 @@ def test_dpa_softmax(dtype, model_configs, model): ) +@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.") +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_softmax]) +@pytest.mark.parametrize("model", model_configs_softmax.keys()) +def test_dpa_softmax_thd(dtype, model_configs, model): + """Test DotProductAttention module with different softmax types""" + test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False) + + model_configs_mla = { # test: ModelConfig(b, sq, hq, dqk) "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), @@ -725,7 +756,8 @@ def test_dpa_bias(dtype, model_configs, model): "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"), "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"), - "bias_1_4": ModelConfig( + "bias_1_4": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="111s"), + "bias_1_5": ModelConfig( 4, 2048, 24, @@ -735,7 +767,7 @@ def test_dpa_bias(dtype, model_configs, model): bias_shape="1hss", alibi_type="custom", ), - "bias_1_5": ModelConfig( + "bias_1_6": ModelConfig( 2, 2048, 24, @@ -792,9 +824,10 @@ def test_dpa_bias_shapes(dtype, model_configs, model): @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_swa]) @pytest.mark.parametrize("model", model_configs_swa.keys()) -def test_dpa_sliding_window(dtype, model_configs, model): +@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"]) +def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with sliding window attention""" - test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False) + test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False) model_configs_alibi_slopes = { @@ -1018,11 +1051,14 @@ def _run_dot_product_attention( reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True # Create seqlens @@ -1250,10 +1286,16 @@ def _run_dot_product_attention( bias = None if config.attn_bias_type == "post_scale_bias": shape = "_".join(config.bias_shape) + # For 1hss, 11ss, b1ss, bhss + shape_cache = shape shape = shape.replace("_s_s", "_sq_skv") + # For 111s + if shape == shape_cache: + shape = shape.replace("_1_s", "_1_skv") tensor_shape = [dim_to_num[j] for j in shape.split("_")] bias = torch.randn(tensor_shape, dtype=dtype, device="cuda") - if config.bias_shape != "1hss": + # For 111s, dbias calculation is not supported as of cuDNN 9.18 + if config.bias_shape == "111s": bias.requires_grad = False # Create RNG @@ -1424,6 +1466,7 @@ def test_transformer_layer( qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd") ), is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: @@ -1437,6 +1480,7 @@ def test_transformer_layer( else qkv_format.replace("hd", "3hd") ), is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends @@ -1625,10 +1669,13 @@ def _run_transformer_layer( reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True # Create input tensor @@ -1822,6 +1869,7 @@ def test_dpa_fp8_extra_state(model, dtype): qkv_dtype=torch.float8_e4m3fn, qkv_layout="sb3hd", is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported and not flash_attn_supported: @@ -1981,10 +2029,16 @@ def get_model(dtype, config): @pytest.mark.parametrize("is_training", [True, False]) @pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) def test_mha_fp8_vs_f16( - dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode + dtype, + model, + qkv_format, + input_layernorm, + fp8_dpa_bwd, + RoPE, + is_training, + scaling_mode, ): """Test MultiHeadAttention module in FP8""" - os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] @@ -2006,32 +2060,33 @@ def test_mha_fp8_vs_f16( ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe - available_backends, _, fused_attn_backends = get_available_attention_backends( + available_backends, _, _ = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_format.replace("hd", "h3d"), fp8=True, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_format.replace("hd", "h3d"), + is_training=is_training, + deterministic=_deterministic, + ) + _, fused_attn_supported_f16, _ = available_backends if flash_attn_supported + fused_attn_supported_fp8 < 1: pytest.skip("No FP8 attention backend available.") - fused_attn_supported_f16 = False - if not fp8_dpa_bwd: - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_format.replace("hd", "h3d"), - is_training=is_training, - ) - _, fused_attn_supported_f16, _ = available_backends - if not fused_attn_supported_f16: - pytest.skip("No attention backend available.") + if not fused_attn_supported_f16: + pytest.skip("No reference backend available.") if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( @@ -2041,6 +2096,7 @@ def test_mha_fp8_vs_f16( if fused_attn_supported_fp8: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( @@ -2050,6 +2106,7 @@ def test_mha_fp8_vs_f16( if fused_attn_supported_f16: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( @@ -2237,7 +2294,6 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal # config.dropout_p = 0.1 os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" - os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1" # Test backend availability @@ -2256,33 +2312,35 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe - available_backends, _, fused_attn_backends = get_available_attention_backends( + available_backends, _, _ = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_layout, fp8=True, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) - flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - if flash_attn_supported + fused_attn_supported < 1: + flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends + available_backends, _, _ = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + is_training=is_training, + deterministic=_deterministic, + ) + _, fused_attn_supported_f16, _ = available_backends + if flash_attn_supported + fused_attn_supported_fp8 < 1: pytest.skip("No FP8 attention backend available.") - if not fp8_dpa_bwd: - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_layout, - is_training=is_training, - ) - _, fused_attn_supported, _ = available_backends - if not fused_attn_supported: - pytest.skip("No attention backend available.") + if not fused_attn_supported_f16: + pytest.skip("No reference backend available.") if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: pytest.skip("qkv_layout not applicable for MQA/GQA") if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)") flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( @@ -2292,34 +2350,39 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if unfused_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( dtype, config, True, qkv_layout, is_training, fp8_recipe ) - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "1" - _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") - fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training, fp8_recipe - ) - - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "1" - if config.dropout_p == 0.0: - # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") - fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, qkv_layout, is_training, fp8_recipe + if fused_attn_supported_fp8: + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training, fp8_recipe ) + if fused_attn_supported_f16: + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + if config.dropout_p == 0.0: + # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") + fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + dtype, config, False, qkv_layout, is_training, fp8_recipe + ) + atol = 5e-1 rtol = 5e-2 rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] - if flash_attn_supported: + if flash_attn_supported and fused_attn_supported_f16: logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( @@ -2332,7 +2395,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal rmse_tol, True, ) - if unfused_attn_supported: + if unfused_attn_supported and fused_attn_supported_f16: logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( @@ -2358,37 +2421,38 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal rmse_tol, True, ) - if config.dropout_p != 0.0: - # test cuDNN FP8 dropout - assert torch.all( - fused_attn_fwd_fp8 == 1 - ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." - else: - logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) - logging.debug("========== {:^25s} ==========".format("forward output")) - compare_and_assert( - fused_attn_fwd_fp8, - fused_attn_fwd_f16, - "fused_attn_fwd_fp8", - "fused_attn_fwd_f16", - atol, - rtol, - rmse_tol, - True, - ) - if is_training: - for i, _ in enumerate(fused_attn_bwd_f16): - logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - compare_and_assert( - fused_attn_bwd_fp8[i], - fused_attn_bwd_f16[i], - f"fused_attn_bwd_fp8[{i}]", - f"fused_attn_bwd_f16[{i}]", - atol, - rtol, - rmse_tol, - True, - ) + if fused_attn_supported_fp8 and fused_attn_supported_f16: + if config.dropout_p != 0.0: + # test cuDNN FP8 dropout + assert torch.all( + fused_attn_fwd_fp8 == 1 + ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." + else: + logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( + fused_attn_fwd_fp8, + fused_attn_fwd_f16, + "fused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, + True, + ) + if is_training: + for i, _ in enumerate(fused_attn_bwd_f16): + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + compare_and_assert( + fused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"fused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, + True, + ) os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0" @@ -2563,13 +2627,16 @@ def test_custom_mha_fp8_vs_f16(dtype, model): qkv_dtype=torch.float8_e4m3fn, qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd", is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not (fused_attn_backends and unfused_attn_supported): pytest.skip("Not enough backends to run this test with.") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") - unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") + unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16( + dtype, config, "UnfusedDotProductAttention" + ) atol = 5e-1 rtol = 5e-1 @@ -2602,10 +2669,13 @@ def _run_custom_mha_fp8(dtype, config, backend): reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True inp = 0.0001 * torch.randint( @@ -2656,10 +2726,13 @@ def _run_ref_mha_f16(dtype, config, backend): os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True inp = torch.load("qkv.pt").to(device="cuda") @@ -2705,12 +2778,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: _2X_ACC_DGRAD = False _2X_ACC_WGRAD = False -META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT -META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 -META_O = tex.FP8FwdTensors.GEMM2_INPUT -META_DO = tex.FP8BwdTensors.GRAD_INPUT2 -META_S = tex.FP8FwdTensors.GEMM3_OUTPUT -META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +META_QKV = FP8FwdTensorIdx.GEMM1_OUTPUT +META_DQKV = FP8BwdTensorIdx.GRAD_OUTPUT1 +META_O = FP8FwdTensorIdx.GEMM2_INPUT +META_DO = FP8BwdTensorIdx.GRAD_INPUT2 +META_S = FP8FwdTensorIdx.GEMM3_OUTPUT +META_DP = FP8BwdTensorIdx.GRAD_INPUT3 class _custom_mha_fp8(torch.autograd.Function): @@ -2738,14 +2811,14 @@ def forward( d = in_features // h b = cu_seqlens.numel() - 1 - input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] - qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] - o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] - dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] - dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] - s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2] - dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3] + input_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] + qkv_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_INPUT] + qkv_weight_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] + o_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] + dO_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] + dQKV_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] + s_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT2] + dP_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT3] inp_fp8 = input_quantizer(inp) @@ -2947,7 +3020,7 @@ def forward( cu_seqlens, max_s, ) -> torch.Tensor: - with self.prepare_forward(inp, num_gemms=3) as inp: + with self.prepare_forward_ctx(inp, num_gemms=3) as inp: out = _custom_mha_fp8.apply( inp, self.qkv_weight, diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 3b7573b9a..cfea62d00 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -155,7 +155,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA - "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_1_4": ModelConfig( + 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" + ), # MHA + "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA "cp_2_2": ModelConfig( @@ -168,10 +171,31 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): attn_bias_type="post_scale_bias", ), # GQA "cp_2_3": ModelConfig( - 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" + 2, + 4096, + 12, + 128, + num_gqa_groups=2, + attn_mask_type="causal", + attn_bias_type="post_scale_bias", + bias_shape="11ss", ), # GQA "cp_2_4": ModelConfig( - 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) + 2, + 4096, + 12, + 128, + num_gqa_groups=2, + attn_mask_type="causal", + attn_bias_type="post_scale_bias", + bias_shape="111s", + return_max_logit=True, + ), # GQA + "cp_2_5": ModelConfig( + 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" + ), # GQA + "cp_2_6": ModelConfig( + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA @@ -179,6 +203,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA + "cp_3_4": ModelConfig( + 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss", head_dim_v=64 + ), # MLA "cp_4_0": ModelConfig( 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla" ), # GQA @@ -195,7 +222,19 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + configs = [ + "cp_1_0", + "cp_1_1", + "cp_1_4", + "cp_1_5", + "cp_2_0", + "cp_2_2", + "cp_2_3", + "cp_2_4", + "cp_3_2", + "cp_3_4", + "cp_4_2", + ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] qkv_formats = ["sbhd", "thd"] @@ -293,9 +332,14 @@ def test_cp_with_fused_attention( pytest.skip( "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" ) - if config.softmax_type != "vanilla" and qkv_format == "thd": + if ( + get_cudnn_version() < (9, 18, 0) + and config.softmax_type != "vanilla" + and qkv_format == "thd" + ): pytest.skip( - "CP implementation does not support qkv_format=thd for non-vanilla softmax types!" + "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for" + " non-vanilla softmax types!" ) dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -320,12 +364,15 @@ def test_cp_with_fused_attention( Float8CurrentScaling(fp8_dpa=True), DelayedScaling(fp8_dpa=True), ] + # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. + is_training = False if config.bias_shape == "111s" else True available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtypes[dtype] if dtype != "fp8" else get_torch_float8_e4m3_type(), qkv_layout="_".join([qkv_format] * 3), fp8=fp8, fp8_meta=fp8_meta, + is_training=is_training, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -344,6 +391,7 @@ def test_cp_with_fused_attention( fp8_mha=fp8_mha, scaling_mode=scaling_mode, f16_O=f16_O, + is_training=is_training, log_level=pytest_logging_level, ), check=True, diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 7edc0cc90..b16291ff6 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -15,6 +15,7 @@ is_fp8_available, is_mxfp8_available, is_fp8_block_scaling_available, + is_nvfp4_available, ) from transformer_engine.pytorch.quantization import RecipeState from transformer_engine.debug.pytorch.debug_state import TEDebugState @@ -29,6 +30,7 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available( return_reason=True ) +nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True) LOG_QUANTIZED_CONFIG_BASE = """ log: @@ -149,6 +151,58 @@ def test_sanity(feature_dirs): assert stat in output, f"Stat {stat} not found in output" +LOG_FP8_MODEL_PARAMETERS_CONFIG_BASE = """ +log: + layers: + layer_name_regex_pattern: .* + enabled: + True + transformer_engine: + LogTensorStats: + enabled: + True + stats: [min] + tensors: [weight, activation, gradient] + freq: 1 + LogFp8TensorStats: + enabled: + True + tensors_struct: + - tensor: activation + stats: [scale_inv_min, scale_inv_max, underflows%] + - tensor: weight + stats: [scale_inv_min, scale_inv_max] + freq: 1 +""" + + +def test_sanity_log_fp8_model_parameters(feature_dirs): + """ + Tests logging stats when model parameters are in fp8. + It tests 3 things: + - LogTensorStats for weight tensor should work without change, + - LogTensorStats and LogFp8TensorStats for non-weight tensors should work without change, + - LogFp8TensorStats should support scale_inv_min, scale_inv_max for weight tensor. + + """ + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + with debug_session(LOG_FP8_MODEL_PARAMETERS_CONFIG_BASE, feature_dirs) as log_dir: + with te.fp8_model_init(recipe=recipe.DelayedScaling()): + model = te.Linear(128, 128, params_dtype=torch.bfloat16) + inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda() + for _ in range(10): + with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()): + output = model(inp) + loss = output.sum() + loss.backward() + debug_api.step() + output = read_log(log_dir) + assert output, "Output is empty" + TEDebugState._reset() + + fp8_recipes = [ recipe.MXFP8BlockScaling(), recipe.DelayedScaling(), @@ -363,6 +417,124 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): TEDebugState._reset() +# NVFP4 tests +LOG_NVFP4_CONFIG_BASE = """ +log: + layers: + layer_name_regex_pattern: .* + enabled: + True + transformer_engine: + LogNvfp4TensorStats: + enabled: True + stats: [ + {stats} + ] + tensors: [activation, gradient, weight] + freq: 2 + start_step: 0 + end_step: 10 +""" + + +def test_nvfp4_numeric(feature_dirs): + """Test that NVFP4 underflows% and MSE stats are computed correctly with known values.""" + if not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + + log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse") + + with debug_session(log_nvfp4_config, feature_dirs) as log_dir: + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + from transformer_engine.pytorch.quantization import RecipeState + + recipe_state = RecipeState.create( + recipe.NVFP4BlockScaling(), + mode="forward", + num_quantizers=3, + ) + + # Create test tensor with known distribution + torch.manual_seed(42) + tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda() + # Add some small values that should underflow to zero in FP4 + tensor[0, :16] = 0.0001 + + quantizer = recipe_state.make_quantizers()[0] + quantized_tensor = quantizer(tensor) + + debug_api.transformer_engine.inspect_tensor( + layer_name="test_layer", + tensor_name="activation", + iteration=0, + tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=quantized_tensor, + columnwise_quantized_tensor=quantized_tensor, + ) + debug_api.step() + + dequantized_tensor = quantized_tensor.dequantize() + output = read_log(log_dir) + + # Validate both stats are present + assert "nvfp4_underflows%" in output, "underflows% stat missing" + assert "nvfp4_mse" in output, "mse stat missing" + + # Extract values and validate numerics + underflows_value = None + mse_value = None + + for line in output.splitlines(): + if "nvfp4_underflows%" in line and "value=" in line: + underflows_value = float(line.split("value=")[1].split()[0]) + if "nvfp4_mse" in line and "value=" in line: + mse_value = float(line.split("value=")[1].split()[0]) + + # Compute expected underflows: non-zero elements that became zero after quantization + orig_nonzero_mask = tensor != 0 + dequant_zero_mask = dequantized_tensor == 0 + expected_underflows = ( + (orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100 + ) + + # Allow some tolerance + assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4) + + # Compute expected MSE + expected_mse = torch.nn.functional.mse_loss( + dequantized_tensor.float(), tensor.float(), reduction="mean" + ) + + assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4) + + +def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs): + """Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis.""" + if not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + + # Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately) + log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse") + + with debug_session(log_fp8_config, feature_dirs) as log_dir: + model = te.Linear(128, 128, params_dtype=torch.bfloat16) + inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda() + + # Should work - recipe-prefixed stats compute MXFP8 separately for comparison + for _ in range(2): + with te.autocast(recipe=recipe.NVFP4BlockScaling()): + output = model(inp) + loss = output.sum() + loss.backward() + debug_api.step() + + output = read_log(log_dir) + # Should have logged MXFP8 MSE stat (what-if scenario) + assert "mxfp8_mse" in output + + def test_log_grouped_gemm(feature_dirs): if not fp8_available: pytest.skip(reason_for_no_fp8) diff --git a/tests/pytorch/debug/test_sanity.py b/tests/pytorch/debug/test_sanity.py index aee5474e7..2bc4b3559 100644 --- a/tests/pytorch/debug/test_sanity.py +++ b/tests/pytorch/debug/test_sanity.py @@ -30,10 +30,17 @@ stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range] start_step : 0 end_step: 1 +""", + "log_fp8": """log_fp8: + layers: + layer_types: [linear] + enabled: + True + transformer_engine: LogFp8TensorStats: enabled: True tensors: [activation, gradient, weight] - stats: [underflows, overflows] + stats: [underflows%] start_step : 0 end_step: 1 """, @@ -46,22 +53,26 @@ FakeQuant: enabled: True gemms: [fprop, dgrad, wgrad] + tensors: [activation, weight, gradient] quant_format: FP8E5M2 """, } +# Configs that require FP8 to be enabled +fp8_required_configs = {"log_fp8"} + def _get_model(model_key): if model_key == "linear": - return te.Linear(D, D) + return te.Linear(D, D, name="layer") if model_key == "layernorm_linear": - return te.LayerNormLinear(D, D) + return te.LayerNormLinear(D, D, name="layer") if model_key == "layernorm_mlp": - return te.LayerNormMLP(D, D, D) + return te.LayerNormMLP(D, D, D, name="layer") if model_key == "mha_attention": - return te.MultiheadAttention(D, H) + return te.MultiheadAttention(D, H, name="layer") if model_key == "transformer_layer": - return te.TransformerLayer(D, D, H) + return te.TransformerLayer(D, D, H, name="layer") def _run_forward_backward(model, fp8): @@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir): def test_sanity_debug(model_key, fp8, config_key, feature_dirs): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if not fp8 and config_key in fp8_required_configs: + pytest.skip(f"Config '{config_key}' requires FP8") _run_test(model_key, fp8, configs[config_key], feature_dirs) diff --git a/tests/pytorch/distributed/run_fsdp2_fp8_model.py b/tests/pytorch/distributed/run_fsdp2_fp8_model.py index 5f73e476d..205724263 100644 --- a/tests/pytorch/distributed/run_fsdp2_fp8_model.py +++ b/tests/pytorch/distributed/run_fsdp2_fp8_model.py @@ -1,4 +1,6 @@ #!/usr/bin/python3 +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # See LICENSE for license information. @@ -17,6 +19,7 @@ from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import fully_shard from torch.distributed.device_mesh import init_device_mesh +from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch import torch_version from transformer_engine.pytorch.quantization import quantized_model_init from torch.nn.parallel import DistributedDataParallel as DDP @@ -257,6 +260,9 @@ def _train(args): output = model(input_data) target = torch.randn(args.batch_size, args.output_size).to(device) loss = F.mse_loss(output, target) + # AIPYTORCH-427 Forward and backward pass overlap with FSDP2 can cause RCCL deadlock. + if IS_HIP_EXTENSION: + torch.cuda.current_stream().synchronize() loss.backward() optimizer.step() if LOCAL_RANK == 0: diff --git a/tests/pytorch/distributed/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/run_fsdp2_fused_adam.py new file mode 100644 index 000000000..403755713 --- /dev/null +++ b/tests/pytorch/distributed/run_fsdp2_fused_adam.py @@ -0,0 +1,737 @@ +#!/usr/bin/python3 +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""FSDP2 + FusedAdam compatibility tests. + +Launched via torchrun from test_fused_optimizer.py. +""" + +import argparse +import functools +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.utils.cpp_extension import IS_HIP_EXTENSION + +import transformer_engine.pytorch as te +from transformer_engine.pytorch import QuantizedTensor +import transformer_engine.common.recipe + + +def get_recipe_from_string(recipe): + return getattr(transformer_engine.common.recipe, recipe)() + + +HIDDEN_SIZE = 256 +FFN_HIDDEN_SIZE = 1024 +NUM_ATTENTION_HEADS = 8 +NUM_LAYERS = 2 +SEQ_LEN = 32 +BATCH_PER_RANK = 2 +NUM_STEPS = 3 + + +def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + if isinstance(param, QuantizedTensor): + ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")] + else: + ignore_keys = [] + attrs = vars(param) + custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys} + return custom_attrs + + +def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + +def _setup(): + """Common distributed setup. Returns (world_size, local_rank, device).""" + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + # CPU backend required for async save + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + device = torch.device(f"cuda:{local_rank}") + torch.manual_seed(42) + torch.cuda.manual_seed(42) + return world_size, local_rank, device + + +def _build_model(fp8_init, fuse_wgrad_accumulation=False, recipe=None, use_meta_device=True): + """Build a Sequential of TransformerLayers, optionally with FP8 init. + + When fp8_init=True and use_meta_device=True (the default), the model is + created on the meta device to avoid FSDP2 incompatibility with + QuantizedTensor wrapper subclasses (e.g. MXFP8Tensor) whose storage is + inaccessible via data_ptr(). Parameters are materialized after FSDP2 + sharding via reset_parameters() in _shard_model(). + + When use_meta_device=False, the model is created directly on CUDA. + This is the legacy path that does NOT work for block-scaling quantized + tensors (MXFP8, Float8Blockwise, NVFP4) because FSDP2's + reset_sharded_param() crashes on wrapper subclass tensors with + data_ptr() == 0. + """ + if fp8_init: + ctx = te.quantized_model_init(enabled=True, recipe=recipe) + else: + from contextlib import nullcontext + + ctx = nullcontext() + kwargs = dict( + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + fuse_qkv_params=True, + params_dtype=torch.bfloat16, + hidden_dropout=0.0, + attention_dropout=0.0, + ) + if fp8_init and use_meta_device: + kwargs["device"] = "meta" + with ctx: + model = torch.nn.Sequential( + *[ + te.TransformerLayer( + HIDDEN_SIZE, + FFN_HIDDEN_SIZE, + NUM_ATTENTION_HEADS, + **kwargs, + ) + for _ in range(NUM_LAYERS) + ] + ) + return model + + +def _shard_model(model, world_size): + """Apply FSDP2 sharding with save/restore custom attrs. + + If the model was created on the meta device (e.g. for FP8 init), + parameters are materialized after sharding via reset_parameters(). + + restore_custom_attrs is called last so it applies to the final parameter + objects. For meta-device models, reset_parameters() replaces params via + module_setattr (base.py:1336-1339), so attrs must be restored afterward. + """ + has_meta_params = any(p.is_meta for p in model.parameters()) + custom_attrs = save_custom_attrs(model) + mesh = DeviceMesh("cuda", list(range(world_size))) + for child in model.children(): + fully_shard(child, mesh=mesh) + fully_shard(model, mesh=mesh) + if has_meta_params: + for module in model.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + # Restore after reset_parameters so attrs land on the final param objects. + # save_custom_attrs skips private attrs (_*) on QuantizedTensor params; + # reset_parameters fully reinitializes quantizer state from + # self.param_init_meta, so no private attrs need restoring. + restore_custom_attrs(model, custom_attrs) + return model + + +def test_fused_adam_fp8_master_weights(recipe=None): + """FusedAdam with master_weights + FSDP2 + quantized_model_init (FP8 params). + + Verifies: + - Optimizer states are created with correct dtype (float32) + - Training loop completes without error + - DTensor wrapping and QuantizedTensor local tensors are preserved + """ + world_size, _, device = _setup() + + model = _build_model(fp8_init=True, recipe=recipe) + model = _shard_model(model, world_size) + + # Verify params are DTensors with QuantizedTensor local shards + for name, param in model.named_parameters(): + assert isinstance(param, DTensor), f"{name} is not DTensor" + qt_count = sum( + 1 + for _, p in model.named_parameters() + if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor) + ) + assert qt_count > 0, "No QuantizedTensor local tensors after sharding" + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + # AIPYTORCH-427 Forward and backward pass overlap with FSDP2 can cause RCCL deadlock. + if IS_HIP_EXTENSION: + torch.cuda.current_stream().synchronize() + loss.backward() + optimizer.step() + + # Verify optimizer states + for param in model.parameters(): + state = optimizer.state[param] + assert ( + state["exp_avg"].dtype == torch.float32 + ), f"exp_avg dtype {state['exp_avg'].dtype}, expected float32" + assert ( + state["exp_avg_sq"].dtype == torch.float32 + ), f"exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32" + if "master_param" in state: + assert ( + state["master_param"].dtype == torch.float32 + ), f"master_param dtype {state['master_param'].dtype}, expected float32" + + # Verify FP8 params preserved + qt_count = sum( + 1 + for _, p in model.named_parameters() + if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor) + ) + assert qt_count > 0, "No QuantizedTensor local tensors after training" + + dist.destroy_process_group() + + +def test_fused_adam_fp8_master_weights_no_meta(recipe=None): + """FusedAdam with master_weights + FSDP2 + quantized_model_init WITHOUT meta device. + + This is the legacy path that creates quantized params directly on CUDA. + FSDP2's reset_sharded_param() crashes on block-scaling QuantizedTensor + wrapper subclasses (data_ptr() == 0). This test documents that failure. + + For per-tensor FP8 (DelayedScaling, Float8CurrentScaling) this works + because Float8Tensor's storage is accessible via data_ptr(). + """ + world_size, _, device = _setup() + + model = _build_model(fp8_init=True, recipe=recipe, use_meta_device=False) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + # AIPYTORCH-427 Forward and backward pass overlap with FSDP2 can cause RCCL deadlock. + if IS_HIP_EXTENSION: + torch.cuda.current_stream().synchronize() + loss.backward() + optimizer.step() + + dist.destroy_process_group() + + +def test_fused_adam_bf16(recipe=None): + """FusedAdam with master_weights + FSDP2 + bf16 params (no FP8). + + Verifies the non-FP8 DTensor param path in step() works correctly. + """ + world_size, _, device = _setup() + + model = _build_model(fp8_init=False) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + losses = [] + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + losses.append(loss.item()) + loss.backward() + optimizer.step() + + # Verify optimizer states are float32 + for param in model.parameters(): + state = optimizer.state[param] + assert state["exp_avg"].dtype == torch.float32 + assert state["exp_avg_sq"].dtype == torch.float32 + + # Verify loss decreased (basic sanity) + assert losses[-1] < losses[0], f"Loss did not decrease: {losses}" + + dist.destroy_process_group() + + +def test_fused_adam_fp8_no_master(recipe=None): + """FusedAdam without master_weights + FSDP2 + FP8 params. + + Verifies FusedAdam works with FSDP2 even without master weights enabled. + """ + world_size, _, device = _setup() + + model = _build_model(fp8_init=True, recipe=recipe) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=False, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + # AIPYTORCH-427 Forward and backward pass overlap with FSDP2 can cause RCCL deadlock. + if IS_HIP_EXTENSION: + torch.cuda.current_stream().synchronize() + loss.backward() + optimizer.step() + + # Verify DTensors preserved + for name, param in model.named_parameters(): + assert isinstance(param, DTensor), f"{name} lost DTensor wrapping" + + dist.destroy_process_group() + + +def test_fused_adam_bf16_store_param_remainders(recipe=None): + """FusedAdam with master_weights + store_param_remainders + FSDP2 + bf16 params. + + store_param_remainders stores only the trailing 16 remainder bits (int16) + instead of full FP32 master params. The FP32 master can be reconstructed + from BF16 params + int16 remainders. Only works with bf16 params + fp32 + master weights. + + Verifies: + - Training loop completes without error + - Optimizer master_param states are int16 (remainder bits) + - exp_avg and exp_avg_sq are float32 + - Loss decreases (basic sanity) + """ + world_size, _, device = _setup() + + model = _build_model(fp8_init=False) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + store_param_remainders=True, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + losses = [] + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + losses.append(loss.item()) + loss.backward() + optimizer.step() + + # Verify model params are bf16 (required for store_param_remainders) + for name, param in model.named_parameters(): + assert ( + param.dtype == torch.bfloat16 + ), f"{name}: param dtype {param.dtype}, expected bfloat16" + + # Verify optimizer states + for name, param in model.named_parameters(): + state = optimizer.state[param] + assert ( + state["exp_avg"].dtype == torch.float32 + ), f"{name}: exp_avg dtype {state['exp_avg'].dtype}, expected float32" + assert ( + state["exp_avg_sq"].dtype == torch.float32 + ), f"{name}: exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32" + # store_param_remainders stores master_param as int16 remainder bits + if "master_param" in state: + assert ( + state["master_param"].dtype == torch.int16 + ), f"{name}: master_param dtype {state['master_param'].dtype}, expected int16" + + # Verify loss decreased (basic sanity) + assert losses[-1] < losses[0], f"Loss did not decrease: {losses}" + + dist.destroy_process_group() + + +def test_fuse_wgrad_accumulation(recipe=None): + """fuse_wgrad_accumulation=True + FSDP2 -- expected to fail. + + With vanilla FSDP2, PyTorch's autograd Function.apply unwraps DTensor + inputs to local tensors. The local Float8Tensor inside the autograd + function does not have the `main_grad` attribute (which is set on the + DTensor parameter). This causes an AttributeError during backward. + + Additionally, even if main_grad were accessible, fuse_wgrad_accumulation + writes the gradient directly into main_grad and returns None to autograd, + bypassing FSDP2's reduce-scatter. + """ + world_size, _, device = _setup() + + model = _build_model(fp8_init=True, fuse_wgrad_accumulation=True, recipe=recipe) + + # Allocate main_grad buffers on the DTensor params + for param in model.parameters(): + param.main_grad = torch.zeros(param.shape, dtype=torch.float32, device=param.device) + + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + use_decoupled_grad=True, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + # This is currently failing during backward because the local Float8Tensor + # inside the autograd function doesn't have main_grad. + optimizer.zero_grad(set_to_none=True) + for param in model.parameters(): + param.main_grad.zero_() + + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + + loss = F.mse_loss(output, target) + loss.backward() # Expected to raise AttributeError + + dist.destroy_process_group() + + +def test_safetensors_fp32_export(recipe=None): + """Export full-precision (FP32) model to safetensors from optimizer master weights. + + Verifies: + - get_model_state_dict with full_state_dict gathers all params + - get_optimizer_state_dict with full_state_dict gathers optimizer state + - FP32 state dict is built from optimizer master weights + - All saved tensors are float32 + - Saved tensor shapes match expected (unsharded) shapes + """ + from safetensors.torch import load_file, save_file + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_optimizer_state_dict, + ) + + world_size, _, device = _setup() + + model = _build_model(fp8_init=True, recipe=recipe) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + # Train a few steps. + for step in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + # AIPYTORCH-427 Forward and backward pass overlap with FSDP2 can cause RCCL deadlock. + if IS_HIP_EXTENSION: + torch.cuda.current_stream().synchronize() + loss.backward() + optimizer.step() + + # Gather full state dicts (all ranks participate). + full_opts = StateDictOptions(full_state_dict=True, cpu_offload=True) + full_model_state = get_model_state_dict(model, options=full_opts) + full_opt_state = get_optimizer_state_dict(model, optimizer, options=full_opts) + + rank = int(os.environ.get("RANK", "0")) + save_path = "/tmp/te_test_fsdp2_model_fp32.safetensors" + + if rank == 0: + # Build FP32 state dict from optimizer master weights. + fp32_state = {} + opt_param_states = full_opt_state.get("state", {}) + + for key, value in full_model_state.items(): + if key in opt_param_states and "master_param" in opt_param_states[key]: + fp32_state[key] = opt_param_states[key]["master_param"].float() + else: + fp32_state[key] = value.float() + + assert len(fp32_state) > 0, "FP32 state dict is empty" + + # Save and verify. + save_file(fp32_state, save_path) + loaded = load_file(save_path) + + assert len(loaded) == len( + fp32_state + ), f"Loaded {len(loaded)} tensors, expected {len(fp32_state)}" + for k, v in loaded.items(): + assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}" + + # Clean up. + os.remove(save_path) + + dist.destroy_process_group() + + +def test_dcp_output_parity(recipe=None, async_save=False): + """DCP save/load round-trip produces bitwise-identical model outputs. + + 1. Builds and trains a model for NUM_STEPS + 2. Runs a forward pass and records the output + 3. Saves model + optimizer state via DCP + 4. Builds a *fresh* model + optimizer (same architecture) + 5. Loads the DCP checkpoint into the fresh model + 6. Runs the same forward pass and asserts outputs are identical + 7. Runs one more training step on both models and asserts outputs still match + """ + import torch.distributed.checkpoint as dcp + + world_size, local_rank, device = _setup() + + # ── Build and train the original model ─────────────────────────── + model = _build_model(fp8_init=True, recipe=recipe) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + for _ in range(NUM_STEPS): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + # AIPYTORCH-427 Forward and backward pass overlap with FSDP2 can cause RCCL deadlock. + if IS_HIP_EXTENSION: + torch.cuda.current_stream().synchronize() + loss.backward() + optimizer.step() + + # Record reference output from the trained model. + with torch.no_grad(): + with te.autocast(enabled=True, recipe=recipe): + ref_output = model(x).clone() + + # ── Save checkpoint ────────────────────────────────────────────── + checkpoint_dir = "/tmp/te_test_fsdp2_dcp_parity" + + if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling): + # We need to remove the _extra_state keys from the model state dict for DelayedScaling, + # since otherwise we'll run into an error that the tensor sizes are different. The + # alternative is a LoadPlanner that dynamically re-sizes the input tensors, see + # NVIDIA/TransformerEngine#1860 for more details. + model_state = { + k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state") + } + else: + model_state = model.state_dict() + + save_state = {"model": model_state, "optimizer": optimizer.state_dict()} + + if not async_save: + dcp.save(save_state, checkpoint_id=checkpoint_dir) + else: + future = dcp.async_save(save_state, checkpoint_id=checkpoint_dir) + future.result() # Block on async save completion + + # ── Build a fresh model and load the checkpoint ────────────────── + model2 = _build_model(fp8_init=True, recipe=recipe) + model2 = _shard_model(model2, world_size) + + optimizer2 = te.optimizers.FusedAdam( + model2.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + # Populate optimizer state so load_state_dict has matching structure. + optimizer2.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + out_tmp = model2(x) + # AIPYTORCH-427 Forward and backward pass overlap with FSDP2 can cause RCCL deadlock. + if IS_HIP_EXTENSION: + torch.cuda.current_stream().synchronize() + F.mse_loss(out_tmp, target).backward() + optimizer2.step() + + if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling): + model2_state = { + k: v for k, v in model2.state_dict().items() if not k.endswith("_extra_state") + } + else: + model2_state = model2.state_dict() + + state_to_load = {"model": model2_state, "optimizer": optimizer2.state_dict()} + + dcp.load(state_to_load, checkpoint_id=checkpoint_dir) + model2.load_state_dict( + state_to_load["model"], + strict=( + False if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling) else True + ), + ) + optimizer2.load_state_dict(state_to_load["optimizer"]) + + # ── Verify identical forward-pass output ───────────────────────── + with torch.no_grad(): + with te.autocast(enabled=True, recipe=recipe): + loaded_output = model2(x) + + if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling): + # DelayedScaling stores amax history and scaling factors in _extra_state, + # which cannot be saved via DCP due to non-deterministic pickle sizes + # across ranks. The fresh model therefore uses default scaling factors, + # producing small numerical differences from FP8 re-quantization. + torch.testing.assert_close( + loaded_output, + ref_output, + rtol=0.05, + atol=0.1, + msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}", + ) + else: + torch.testing.assert_close( + loaded_output, + ref_output, + rtol=0, + atol=0, + msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}", + ) + + # ── Verify one more training step produces identical results ───── + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + out1 = model(x) + loss1 = F.mse_loss(out1, target) + # AIPYTORCH-427 Forward and backward pass overlap with FSDP2 can cause RCCL deadlock. + if IS_HIP_EXTENSION: + torch.cuda.current_stream().synchronize() + loss1.backward() + optimizer.step() + + optimizer2.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + out2 = model2(x) + loss2 = F.mse_loss(out2, target) + # AIPYTORCH-427 Forward and backward pass overlap with FSDP2 can cause RCCL deadlock. + if IS_HIP_EXTENSION: + torch.cuda.current_stream().synchronize() + loss2.backward() + optimizer2.step() + + if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling): + torch.testing.assert_close( + out2, + out1, + rtol=0.05, + atol=0.1, + msg="Training step after DCP load produces different output", + ) + else: + torch.testing.assert_close( + out2, out1, msg="Training step after DCP load produces different output" + ) + + # ── Cleanup ────────────────────────────────────────────────────── + import shutil + + if int(os.environ.get("RANK", "0")) == 0: + shutil.rmtree(checkpoint_dir, ignore_errors=True) + + dist.destroy_process_group() + + +TESTS = { + "fused_adam_fp8_master_weights": test_fused_adam_fp8_master_weights, + "fused_adam_fp8_master_weights_no_meta": test_fused_adam_fp8_master_weights_no_meta, + "fused_adam_bf16": test_fused_adam_bf16, + "fused_adam_fp8_no_master": test_fused_adam_fp8_no_master, + "fused_adam_bf16_store_param_remainders": test_fused_adam_bf16_store_param_remainders, + "fuse_wgrad_accumulation": test_fuse_wgrad_accumulation, + "dcp_output_parity": functools.partial(test_dcp_output_parity, async_save=False), + "dcp_output_parity_async": functools.partial(test_dcp_output_parity, async_save=True), + "safetensors_fp32_export": test_safetensors_fp32_export, +} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--test", required=True, choices=list(TESTS.keys())) + parser.add_argument( + "--recipe", + type=str, + default="MXFP8BlockScaling", + help="Quantizer type.", + choices=[ + "DelayedScaling", + "Float8CurrentScaling", + "Float8BlockScaling", + "MXFP8BlockScaling", + "NVFP4BlockScaling", + ], + ) + args = parser.parse_args() + recipe = get_recipe_from_string(args.recipe) + TESTS[args.test](recipe) diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index 3b9264279..49653d3b2 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -11,13 +11,7 @@ import argparse import transformer_engine.pytorch as te -from transformer_engine.common.recipe import ( - Format, - DelayedScaling, - Float8CurrentScaling, - MXFP8BlockScaling, -) -from transformer_engine.pytorch import torch_version +import transformer_engine.common.recipe import torch import torch.distributed as dist @@ -27,6 +21,7 @@ from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import fully_shard from torch.distributed.device_mesh import init_device_mesh +from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch import QuantizedTensor from contextlib import nullcontext @@ -46,14 +41,23 @@ def _parse_args(argv=None, namespace=None): parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input") parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.") parser.add_argument( - "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." + "--fp8-init", + action="store_true", + default=False, + help="Initialize primary weights in FP8.", ) parser.add_argument( "--recipe", type=str, - default="mx_fp8_block_scaling", + default="MXFP8BlockScaling", help="Quantizer type.", - choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"], + choices=[ + "DelayedScaling", + "Float8CurrentScaling", + "Float8BlockScaling", + "MXFP8BlockScaling", + "NVFP4BlockScaling", + ], ) parser.add_argument( "--layer-type", @@ -113,15 +117,8 @@ def get_te_layer_from_string(layer_name): return te_layer_map[layer_name.lower()] -def get_recipe_from_string(recipe, fp8_format=Format.HYBRID): - if recipe == "delayed_scaling": - return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") - elif recipe == "current_scaling": - return Float8CurrentScaling(fp8_format=fp8_format) - elif recipe == "mx_fp8_block_scaling": - return MXFP8BlockScaling(fp8_format=fp8_format) - else: - raise ValueError(f"Unknown quantizer type: {recipe}") +def get_recipe_from_string(recipe): + return getattr(transformer_engine.common.recipe, recipe)() def init_te_model(config): @@ -247,7 +244,7 @@ def test_fp8_fsdp2_allgather(model): module.unshard() # Make sure allgathered parameters match exactly for name, param in model.named_parameters(): - assert torch.allclose(param.dequantize(), fp32_allgathered_params[name]) + torch.testing.assert_close(param.dequantize(), fp32_allgathered_params[name]) # Revert model to original sharded state for module in model.modules(): # Not all modules are wrapped/sharded with FSDP2. @@ -281,8 +278,7 @@ def _train(args): device = torch.device(f"cuda:{LOCAL_RANK}") # FP8 Configuration - fp8_format = Format.HYBRID - fp8_recipe = get_recipe_from_string(args.recipe, fp8_format) + fp8_recipe = get_recipe_from_string(args.recipe) build_model_context_args = {} if not args.fp8_init: @@ -295,13 +291,13 @@ def _train(args): build_model_context_args["enabled"] = True build_model_context_args["recipe"] = fp8_recipe - dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB") + dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device) / 1e6} MB") # Create the model on the meta/cuda device as per args with build_model_context(**build_model_context_args): model, inp_shape, out_shape = init_te_model(args) dist_print( f"Memory after model init on device {args.device}:" - f" {torch.cuda.memory_allocated(device)/1e6} MB" + f" {torch.cuda.memory_allocated(device) / 1e6} MB" ) # Creating a DeviceMesh for fully_shard @@ -322,7 +318,7 @@ def _train(args): dist_print(f" Sharded parameters materialized and initialized on cuda device.") dist_print( - f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB" + f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device) / 1e6} MB" ) optimizer = optim.Adam(model.parameters(), lr=1e-3) @@ -330,11 +326,23 @@ def _train(args): for iteration in range(args.iter): # Zero the parameter gradients optimizer.zero_grad() - input_data = torch.randn(inp_shape).to(device) - with te.autocast(enabled=True, recipe=fp8_recipe): - output = model(input_data) - target = torch.randn(out_shape).to(device) - loss = F.mse_loss(output, target) + + input_data = torch.randn(inp_shape, device=device) + target = torch.randn(out_shape, device=device) + + # NVFP4BlockScaling requires bfloat16 inputs in both the forward and backward passes. + with ( + torch.autocast(device_type="cuda", dtype=torch.bfloat16) + if args.recipe == "NVFP4BlockScaling" + else nullcontext() + ): + with te.autocast(enabled=True, recipe=fp8_recipe): + output = model(input_data) + loss = F.mse_loss(output, target) + + # AIPYTORCH-427 Forward and backward pass overlap with FSDP2 can cause RCCL deadlock. + if IS_HIP_EXTENSION: + torch.cuda.current_stream().synchronize() loss.backward() optimizer.step() dist_print(f"Iteration {iteration} completed with loss {loss.item()}") @@ -348,7 +356,7 @@ def _train(args): # destroy_process_group() while other ranks still have in-flight NCCL ops, # which can trigger a NCCL/RCCL comm error. Newer releases (>= 2.6) fixed # this, but we kept a version-guarded barrier on older Torch for stability. - if torch_version() < (2, 6, 0): + if te.torch_version() < (2, 6, 0): dist.barrier(device_ids=[torch.cuda.current_device()]) dist.destroy_process_group() return 0 diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 8a944805c..34fdb1eb9 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -7,6 +7,7 @@ import argparse import datetime import os +import tempfile import subprocess import sys import pathlib @@ -20,6 +21,7 @@ DelayedScaling, Float8CurrentScaling, Float8BlockScaling, + NVFP4BlockScaling, MXFP8BlockScaling, Format, Recipe, @@ -28,14 +30,21 @@ from transformer_engine.pytorch import ( is_fp8_available, is_fp8_block_scaling_available, - is_mxfp8_available, + is_nvfp4_available, QuantizedTensor, Float8Tensor, Float8BlockwiseQTensor, + NVFP4Tensor, + is_mxfp8_available, MXFP8Tensor, ) -from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8 +from transformer_engine.pytorch.tensor.utils import ( + quantize_master_weights, + cast_master_weights_to_fp8, +) +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data + from torch.utils.cpp_extension import IS_HIP_EXTENSION def _get_quantization_recipe(quantization) -> Recipe: @@ -69,6 +78,12 @@ def _get_raw_data(quantized_tensor, colwise=False): quantized_tensor._rowwise_data.dtype == torch.uint8 ), "Float8BlockwiseQTensor _rowwise_data must be uint8" return quantized_tensor._rowwise_data + elif isinstance(quantized_tensor, NVFP4Tensor): + assert hasattr(quantized_tensor, "_rowwise_data"), "NVFP4Tensor missing _rowwise_data" + assert ( + quantized_tensor._rowwise_data.dtype == torch.uint8 + ), "NVFP4Tensor _rowwise_data must be uint8" + return quantized_tensor._rowwise_data elif isinstance(quantized_tensor, MXFP8Tensor): if colwise: assert hasattr( @@ -137,22 +152,45 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals self.offsets = [0] for weight in self.weights: self.offsets.append(self.offsets[-1] + weight.numel()) - # Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may # not be the end range of the last weight. if self.offsets[-1] % self.world_size != 0: self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size + self.weights_are_nvfp4 = isinstance(self.weights[0], NVFP4Tensor) + + # Storage offsets operate on the packed representation. + # For NVFP4: packed size (2 values per byte) + # For others: same as numel() + self.storage_offsets = [0] + self.storage_sizes = [] + for weight in self.weights: + if self.weights_are_nvfp4: + storage_size = _get_raw_data(weight).view(-1).numel() + else: + storage_size = weight.numel() + self.storage_sizes.append(storage_size) + self.storage_offsets.append(self.storage_offsets[-1] + storage_size) + if self.storage_offsets[-1] % self.world_size != 0: + self.storage_offsets[-1] += self.world_size - self.storage_offsets[-1] % self.world_size + self.storage_total = self.storage_offsets[-1] + self.master_weights = [] # The start offset of the master weight in the weight self.start_offsets = [] # The overlapping area of the weight and this rank's local buffer self.overlapping_areas = [] + # Storage equivalents (only populated for NVFP4 tensors). + self.storage_start_offsets = [None] * len(self.weights) + self.storage_overlapping_areas = [None] * len(self.weights) - # The start and end of this rank's local buffer in the global buffer + # The start and end of this rank's local buffer in the global buffer (logical offsets) rank_start = self.offsets[-1] // self.world_size * self.rank rank_end = rank_start + self.offsets[-1] // self.world_size + # Storage-based rank boundaries (for NVFP4: packed size, for others: same as logical) + storage_rank_start = self.storage_total // self.world_size * self.rank + storage_rank_end = storage_rank_start + self.storage_total // self.world_size for weight, offset in zip(self.weights, self.offsets[:-1]): if offset >= rank_end or (offset + weight.numel()) <= rank_start: # This weight is not in this rank's local buffer @@ -180,6 +218,20 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals self.start_offsets.append(start_offset) self.overlapping_areas.append(overlapping_area) + if self.weights_are_nvfp4: + for idx, (weight, storage_offset, storage_size) in enumerate( + zip(self.weights, self.storage_offsets[:-1], self.storage_sizes) + ): + if ( + storage_offset >= storage_rank_end + or (storage_offset + storage_size) <= storage_rank_start + ): + continue + overlap_start = max(storage_rank_start, storage_offset) + overlap_end = min(storage_rank_end, storage_offset + storage_size) + self.storage_start_offsets[idx] = overlap_start - storage_offset + self.storage_overlapping_areas[idx] = (overlap_start, overlap_end) + # Create global buffer for grads reduce-scatter self.grad_buffer = torch.empty( [self.offsets[-1]], dtype=torch.float32, device=weights[0].device @@ -192,9 +244,9 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals else: weight_buffer_dtype = weights[0].dtype self.weight_buffer = torch.empty( - [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device + [self.storage_total], dtype=weight_buffer_dtype, device=weights[0].device ) - self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end] + self.weight_buffer_slice = self.weight_buffer[storage_rank_start:storage_rank_end] def step(self): # ----------------------------------------------------------------------------------------- @@ -233,10 +285,20 @@ def step(self): # ----------------------------------------------------------------------------------------- # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight # ----------------------------------------------------------------------------------------- - if isinstance(self.weights[0], QuantizedTensor): - # FP8 weights case - for i in range(1, len(self.weights)): - assert isinstance(self.weights[i], QuantizedTensor) + first_weight = self.weights[0] + if isinstance(first_weight, NVFP4Tensor): + for weight in self.weights: + assert isinstance(weight, NVFP4Tensor) + quantize_master_weights( + self.weights, + self.master_weights, + self.start_offsets, + self.dp_group, + manual_post_all_gather_processing=self.manual_post_all_gather_processing, + ) + elif isinstance(first_weight, (Float8Tensor, Float8BlockwiseQTensor, MXFP8Tensor)): + for weight in self.weights: + assert isinstance(weight, QuantizedTensor) cast_master_weights_to_fp8( self.weights, self.master_weights, @@ -255,20 +317,31 @@ def step(self): end = start_offset + master_weight.numel() weight.data.view(-1)[start:end].copy_(master_weight) + # ----------------------------------------------------------------------------------------- + # Step 5: Copy the updated weights (not all weights) to the weight buffer + # ----------------------------------------------------------------------------------------- colwise_list = [False] if isinstance(self.weights[0], MXFP8Tensor): colwise_list.append(True) for colwise in colwise_list: - # ------------------------------------------------------------------------------------- - # Step 5: Copy the updated weights (not all weights) to the weight buffer - # ------------------------------------------------------------------------------------- for i in range(len(self.weights)): master_weight = self.master_weights[i] if master_weight is None: continue start_offset = self.start_offsets[i] - if isinstance(self.weights[i], QuantizedTensor): + if isinstance(self.weights[i], NVFP4Tensor): + storage_start = self.storage_start_offsets[i] + storage_overlap = self.storage_overlapping_areas[i] + if storage_start is None or storage_overlap is None: + continue + weight = _get_raw_data(self.weights[i]).view(-1) + storage_len = storage_overlap[1] - storage_overlap[0] + weight_slice = weight[storage_start : storage_start + storage_len] + overlapping_start, overlapping_end = storage_overlap + self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) + continue + elif isinstance(self.weights[i], QuantizedTensor): weight = _get_raw_data(self.weights[i], colwise) else: weight = self.weights[i] @@ -286,12 +359,22 @@ def step(self): # ------------------------------------------------------------------------------------- # Step 7: Copy the gathered weights from weight buffer to the actual weights # ------------------------------------------------------------------------------------- - for weight, offset in zip(self.weights, self.offsets[:-1]): - start = offset - end = offset + weight.numel() - if isinstance(weight, QuantizedTensor): - weight = _get_raw_data(weight, colwise) - weight.view(-1).data.copy_(self.weight_buffer[start:end]) + if self.weights_are_nvfp4: + # NVFP4: use storage offsets (packs 2 values per byte) + for weight, storage_offset, storage_size in zip( + self.weights, self.storage_offsets[:-1], self.storage_sizes + ): + start = storage_offset + end = storage_offset + storage_size + raw_data = _get_raw_data(weight) + raw_data.view(-1).data.copy_(self.weight_buffer[start:end]) + else: + for weight, offset in zip(self.weights, self.offsets[:-1]): + start = offset + end = offset + weight.numel() + if isinstance(weight, QuantizedTensor): + weight = _get_raw_data(weight, colwise) + weight.view(-1).data.copy_(self.weight_buffer[start:end]) if self.manual_post_all_gather_processing: quantized_weights = [ @@ -466,8 +549,23 @@ def step(self): # Update the master weight using gradient descent master_weight -= grad * self.lr - # Step 3: Cast master weights to FP8 or BF16 precision - if isinstance(self.weights[0], QuantizedTensor): + # Step 3: Cast master weights to quantized or BF16 precision + first_weight = self.weights[0] + if isinstance(first_weight, NVFP4Tensor): + local_weights = [] + for local_weight in self.local_weights: + if local_weight is None: + local_weights.append(None) + continue + local_weights.append(local_weight) + quantize_master_weights( + self.weights, + self.master_weights, + [idx[0] for idx in self.weight_indices], + self.dp_group, + local_weights, + ) + elif isinstance(first_weight, QuantizedTensor): local_weights = [] for i, local_weight in enumerate(self.local_weights): if self.flatten_columnwise is not None: @@ -740,6 +838,90 @@ def _test_fsdp_cast_master_weights_to_fp8( ), f"Loss mismatch at rank {rank}, step {i} for {quantization} (FSDP)" +def _test_cast_master_weights_to_nvfp4(dp_group, manual_post_all_gather_processing): + available, reason = is_nvfp4_available(return_reason=True) + if not available: + pytest.skip(reason) + + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} + # Disable stochastic rounding for deterministic gradients + nvfp4_recipe = NVFP4BlockScaling(disable_stochastic_rounding=True) + + with te.quantized_model_init( + enabled=True, recipe=nvfp4_recipe, preserve_high_precision_init_val=True + ): + model_nvfp4 = nn.Sequential( + te.Linear(128, 256 + 64, **linear_kwargs), + te.Linear(256 + 64, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + # Create model with bf16 weights + model = nn.Sequential( + te.Linear(128, 256 + 64, **linear_kwargs), + te.Linear(256 + 64, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): + high_precision_init_val = w_nvfp4.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): + w_nvfp4.main_grad = torch.zeros_like(w_nvfp4, dtype=torch.float32, device="cuda") + w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda") + + optimizer_nvfp4 = MiniZero_1( + [w for w in model_nvfp4.parameters()], 10.0, dp_group, manual_post_all_gather_processing + ) + optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) + + for i in range(500): + for w_nvfp4, w in zip(model_nvfp4.parameters(), model.parameters()): + w_nvfp4.main_grad.zero_() + w.main_grad.zero_() + + inputs = [ + torch.randn(2048, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + x = inputs[rank] + + with te.autocast( + enabled=True, + recipe=nvfp4_recipe, + amax_reduction_group=mock_group, + ): + y_nvfp4 = model_nvfp4(x) + + with te.autocast( + enabled=True, + recipe=nvfp4_recipe, + amax_reduction_group=mock_group, + ): + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + target = targets[rank] + loss_nvfp4 = nn.MSELoss()(y_nvfp4, target) + loss = nn.MSELoss()(y, target) + + loss_nvfp4.backward() + loss.backward() + + optimizer.step() + optimizer_nvfp4.step() + + torch.testing.assert_close(loss_nvfp4, loss, atol=0, rtol=0) + + def run_parallel_tests() -> None: """Run parallel tests""" @@ -775,9 +957,9 @@ def run_parallel_tests() -> None: keep_fp8_weight_transpose_caches = [True] if IS_HIP_EXTENSION: keep_fp8_weight_transpose_caches.append(False) - + print("starting mini optimizer test") _test_mini_optimizer(dp_group) - + print("starting cast master weights to fp8 test") for quantization in quantizations: for post_ag_processing in manual_post_all_gather_processings: for keep_fp8_weight_transpose_cache in keep_fp8_weight_transpose_caches: @@ -787,6 +969,38 @@ def run_parallel_tests() -> None: _test_fsdp_cast_master_weights_to_fp8( quantization, dp_group, post_ag_processing, keep_fp8_weight_transpose_cache ) + nvfp4_available, _ = is_nvfp4_available(return_reason=True) + if nvfp4_available: + print("starting cast master weights to nvfp4 test") + for post_ag_processing in manual_post_all_gather_processings: + _test_cast_master_weights_to_nvfp4(dp_group, post_ag_processing) + + dist.destroy_process_group() + + +def run_parallel_nvfp4_partial_cast_test() -> None: + """Run the NVFP4 partial-cast distributed worker test.""" + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + dp_group = dist.new_group(backend="nccl") + + _test_nvfp4_partial_cast_matches_full(dp_group) dist.destroy_process_group() @@ -818,9 +1032,281 @@ def test_cast_master_weights_to_fp8(world_size: int) -> None: def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--parallel", action="store_true", help="Run parallel tests") + parser.add_argument( + "--parallel-nvfp4-partial", + action="store_true", + help="Run NVFP4 partial-cast distributed worker test", + ) args = parser.parse_args() if args.parallel: run_parallel_tests() + elif args.parallel_nvfp4_partial: + run_parallel_nvfp4_partial_cast_test() + + +# Debugging tests for NVFP4 +def test_nvfp4_transpose_kernel() -> None: + """Test that nvfp4_transpose kernel produces bitwise identical results to reference.""" + available, reason = is_nvfp4_available(return_reason=True) + if not available: + pytest.skip(reason) + + torch.manual_seed(1234) + device = torch.device("cuda") + shape = (2048, 5120) + master_weight = torch.randn(shape, dtype=torch.float32, device=device) + + print("\n=== Testing NVFP4 transpose kernel ===") + + # Create reference with both rowwise and columnwise data + quantizer_with_colwise = NVFP4Quantizer( + rowwise=True, columnwise=True, with_2d_quantization=True + ) + reference_tensor = quantizer_with_colwise(master_weight.to(torch.bfloat16)) + assert reference_tensor._columnwise_data is not None, "Reference should have columnwise data" + assert ( + reference_tensor._columnwise_scale_inv is not None + ), "Reference should have columnwise scale_inv" + reference_columnwise_data = reference_tensor._columnwise_data.detach().clone() + reference_columnwise_scale_inv = reference_tensor._columnwise_scale_inv.detach().clone() + reference_columnwise_amax = ( + reference_tensor._amax_columnwise.detach().clone() + if reference_tensor._amax_columnwise is not None + else None + ) + + # Create tensor with only rowwise data, then call _create_columnwise() + quantizer_rowwise_only = NVFP4Quantizer( + rowwise=True, columnwise=False, with_2d_quantization=True + ) + test_tensor = quantizer_rowwise_only(master_weight.to(torch.bfloat16)) + assert test_tensor._columnwise_data is None, "Test tensor should not have columnwise data yet" + + # Now call _create_columnwise() which uses our nvfp4_transpose kernel + test_tensor.update_usage(rowwise_usage=True, columnwise_usage=True) + assert ( + test_tensor._columnwise_data is not None + ), "Test tensor should have columnwise data after _create_columnwise()" + assert ( + test_tensor._columnwise_scale_inv is not None + ), "Test tensor should have columnwise scale_inv after _create_columnwise()" + + # Compare columnwise data - should be bitwise identical + torch.testing.assert_close( + test_tensor._columnwise_data, + reference_columnwise_data, + atol=0, + rtol=0, + msg="NVFP4 transpose kernel produced different columnwise data than reference!", + ) + + torch.testing.assert_close( + test_tensor._columnwise_scale_inv, + reference_columnwise_scale_inv, + atol=0, + rtol=0, + msg="NVFP4 _create_columnwise produced different columnwise scale_inv than reference!", + ) + + torch.testing.assert_close( + test_tensor._amax_columnwise, + reference_columnwise_amax, + atol=0, + rtol=0, + msg="NVFP4 _create_columnwise produced different columnwise amax than reference!", + ) + + +def _test_nvfp4_partial_cast_matches_full(dp_group) -> None: + """Multi-GPU worker: split master weight, partial cast on each rank, gather, compare.""" + WORLD_RANK = dist.get_rank(dp_group) + WORLD_SIZE = dist.get_world_size(dp_group) + + torch.manual_seed(1234) + device = torch.device("cuda") + # Shape must be divisible by WORLD_SIZE for even splitting + # Also ensure dimensions are multiples of 16 for NVFP4 tiles + shape = (4096, 4096) + total_elements = shape[0] * shape[1] + assert total_elements % WORLD_SIZE == 0, "Total elements must be divisible by WORLD_SIZE" + + # Full master weight (same on all ranks due to same seed) + full_master_weight = torch.randn(shape, dtype=torch.float32, device=device) + + # Create reference using full quantization + quantizer = NVFP4Quantizer(rowwise=True, columnwise=False, with_2d_quantization=True) + reference_tensor = quantizer(full_master_weight.to(torch.bfloat16)) + reference_data = reference_tensor._rowwise_data.detach().clone() + reference_scale = reference_tensor._rowwise_scale_inv.detach().clone() + reference_amax = reference_tensor._amax_rowwise.detach().clone() + + # Split master weight evenly across ranks + shard_size = total_elements // WORLD_SIZE + start_offset = WORLD_RANK * shard_size + end_offset = start_offset + shard_size + master_weight_shard = full_master_weight.view(-1)[start_offset:end_offset].clone() + + # Create empty NVFP4 tensor for this rank (full shape, but we'll only fill our shard) + nvfp4_tensor = quantizer.make_empty(shape, dtype=torch.bfloat16, device=device) + nvfp4_tensor._rowwise_data.zero_() + nvfp4_tensor._rowwise_scale_inv.zero_() + if nvfp4_tensor._amax_rowwise is not None: + nvfp4_tensor._amax_rowwise.zero_() + + # Partial cast on each rank's shard + quantize_master_weights( + [nvfp4_tensor], + [master_weight_shard], + [start_offset], + dp_group, + ) + + # All-gather the rowwise data (packed FP4 bytes) + # Each rank has the full tensor but only its shard is filled + # We need to all-gather the shards + rowwise_data_flat = nvfp4_tensor._rowwise_data.view(-1) + + # For NVFP4, 2 elements are packed per byte, so byte shard size is shard_size // 2 + byte_shard_size = shard_size // 2 + byte_start = WORLD_RANK * byte_shard_size + byte_end = byte_start + byte_shard_size + my_shard_bytes = rowwise_data_flat[byte_start:byte_end].contiguous() + + # Gather all shards + gathered_shards = [torch.empty_like(my_shard_bytes) for _ in range(WORLD_SIZE)] + dist.all_gather(gathered_shards, my_shard_bytes, group=dp_group) + + # Reconstruct the full rowwise data + gathered_data = torch.cat(gathered_shards, dim=0).view(reference_data.shape) + + # Compare with reference + torch.testing.assert_close( + gathered_data, + reference_data, + atol=0, + rtol=0, + msg=f"[Rank {WORLD_RANK}] Gathered rowwise data does not match reference!", + ) + + # Also verify scale matches (scale should be identical on all ranks after all-reduce) + torch.testing.assert_close( + nvfp4_tensor._rowwise_scale_inv, + reference_scale, + atol=0, + rtol=0, + msg=f"[Rank {WORLD_RANK}] Scale does not match reference!", + ) + + # Verify amax matches + torch.testing.assert_close( + nvfp4_tensor._amax_rowwise, + reference_amax, + atol=0, + rtol=0, + msg=f"[Rank {WORLD_RANK}] Amax does not match reference!", + ) + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="NVFP4 partial-cast test needs at least 2 GPUs." +) +@pytest.mark.parametrize("world_size", [2]) +def test_nvfp4_partial_cast_matches_full(world_size: int) -> None: + """Launch a distributed job for NVFP4 partial-cast equivalence test.""" + + available, reason = is_nvfp4_available(return_reason=True) + if not available: + pytest.skip(reason) + + python_exe = pathlib.Path(sys.executable).resolve() + current_file = pathlib.Path(__file__).resolve() + command = [ + python_exe, + "-m", + "torch.distributed.run", + f"--nproc_per_node={world_size}", + current_file, + "--parallel-nvfp4-partial", + ] + subprocess.run(command, check=True) + + +def test_single_gpu_partial_cast_vs_full(): + """ + Single GPU test: compare quantize_master_weights (offset=0) vs quantizer(). + This isolates whether the issue is in our manual Python scale computation or elsewhere. + """ + available, reason = is_nvfp4_available(return_reason=True) + if not available: + pytest.skip(reason) + + torch.manual_seed(1234) + device = torch.device("cuda") + + # Test with same shape as the optimizer test + shape = (2048, 2048) + + # Create BF16 master weight + master_weight = torch.randn(shape, dtype=torch.bfloat16, device=device) + + # === Reference: Use NVFP4Quantizer directly === + quantizer = NVFP4Quantizer(rowwise=True, columnwise=False, with_2d_quantization=True) + ref = quantizer(master_weight) + ref_data = ref._rowwise_data.clone() + ref_scale = ref._rowwise_scale_inv.clone() + ref_amax = ref._amax_rowwise.clone() + + # === Test: Use quantize_master_weights with offset=0 (full tensor) === + # Create empty NVFP4 tensor + test_tensor = quantizer.make_empty(shape, dtype=torch.bfloat16, device=device) + test_tensor._rowwise_data.zero_() + test_tensor._rowwise_scale_inv.zero_() + if test_tensor._amax_rowwise is not None: + test_tensor._amax_rowwise.zero_() + + # Create a local single-rank process group when running under plain pytest. + initialized_here = False + rendezvous_file = None + if not dist.is_initialized(): + torch.cuda.set_device(0) + with tempfile.NamedTemporaryFile(delete=False) as f: + rendezvous_file = pathlib.Path(f.name) + dist.init_process_group( + backend="nccl", + init_method=rendezvous_file.resolve().as_uri(), + rank=0, + world_size=1, + ) + initialized_here = True + + if dist.get_world_size() != 1: + pytest.skip("test_single_gpu_partial_cast_vs_full requires world_size == 1") + + mock_group = dist.new_group(ranks=[0], backend="nccl") + try: + quantize_master_weights( + [test_tensor], + [master_weight.view(-1)], # Flatten as expected + [0], # offset=0 means full tensor + mock_group, + ) + finally: + if initialized_here: + dist.destroy_process_group() + if rendezvous_file is not None: + rendezvous_file.unlink(missing_ok=True) + + # Compare amax + amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax) + assert amax_match, f"Amax mismatch: {test_tensor._amax_rowwise} vs {ref_amax}" + + # Compare scale + scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale) + assert scale_match, f"Scale mismatch: {test_tensor._rowwise_scale_inv} vs {ref_scale}" + + # Compare data + data_match = torch.equal(test_tensor._rowwise_data, ref_data) + assert data_match, f"Data mismatch" if __name__ == "__main__": diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 8d98c6263..db9352fcf 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -228,7 +228,6 @@ def test_bulk_overlaps(comm_type, quantization, connections): (te.Linear.__name__, "row", False), (te.Linear.__name__, "column", False), (te.Linear.__name__, "column", True), - (te.LayerNormLinear.__name__, "row", False), (te.LayerNormLinear.__name__, "column", False), (te.LayerNormLinear.__name__, "column", True), ] @@ -243,7 +242,6 @@ def test_bulk_overlaps(comm_type, quantization, connections): f" {te.Linear.__name__} - ROW-PARALLEL ", f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", f" {te.Linear.__name__} - COL-PARALLEL - DGRAD+RS ", - f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ", f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ", ] @@ -272,7 +270,6 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d (te.Linear.__name__, "row", False), (te.Linear.__name__, "column", False), (te.Linear.__name__, "column", True), - (te.LayerNormLinear.__name__, "row", False), (te.LayerNormLinear.__name__, "column", False), (te.LayerNormLinear.__name__, "column", True), ] @@ -287,7 +284,6 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d f"{te.Linear.__name__}-row_tensor_parallel", f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD", f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS", - f"{te.LayerNormLinear.__name__}-row_tensor_parallel", f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD", f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS", ] diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index e328e5775..02e45d99c 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -3,19 +3,64 @@ # See LICENSE for license information. import os -import pytest import subprocess from pathlib import Path -import transformer_engine.pytorch as te +import pytest import torch +import transformer_engine.pytorch as te +from transformer_engine.pytorch import fp8 -fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) -mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) NUM_PROCS: int = torch.cuda.device_count() +def check_nvfp4_support(): + supported, reason = fp8.check_nvfp4_support() + if supported and torch.cuda.get_device_capability()[0] == 12: + return ( + False, + ( + "NVFP4BlockScaling is failing on SM120 with " + "hadamard_transform/hadamard_transform_cast_fusion.cu:672 in function " + "rht_gemm_ntt_w_sfc: CUDA Error: invalid argument" + ), + ) + + return supported, reason + + +# Each entry: (recipe_class_name, check_fn) +_FP8_RECIPE_CONFIGS = [ + ("DelayedScaling", fp8.check_fp8_support), + ("Float8CurrentScaling", fp8.check_fp8_support), + ("Float8BlockScaling", fp8.check_fp8_block_scaling_support), + ("MXFP8BlockScaling", fp8.check_mxfp8_support), + ("NVFP4BlockScaling", check_nvfp4_support), +] + + +def _parametrize_fp8_recipes(): + """Generate pytest.param objects with skip marks for unsupported FP8 recipes.""" + params = [] + for name, check_fn in _FP8_RECIPE_CONFIGS: + supported, reason = check_fn() + params.append( + pytest.param( + name, + id=name, + marks=pytest.mark.skipif(not supported, reason=reason), + ) + ) + return params + + +@pytest.fixture(params=_parametrize_fp8_recipes()) +def fp_recipe(request): + """Parametrized fixture providing FP8 recipe Hydra overrides for each supported TE recipe.""" + return request.param + + def _run_test(fp_init, sharding_dims, recipe, layer_type): test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py" test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)] @@ -32,28 +77,175 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type): test_cmd += ["--recipe", recipe] test_cmd += ["--layer-type", layer_type] - result = subprocess.run(test_cmd, env=os.environ, check=True) + subprocess.run(test_cmd, env=os.environ, check=True) -@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") @pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2])) @pytest.mark.parametrize("fp8_init", (False, True)) -@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling")) @pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer")) -def test_distributed(fp8_init, sharding_dims, recipe, layer_type): +def test_distributed(fp8_init, sharding_dims, fp_recipe, layer_type): + + if fp_recipe in ("Float8BlockScaling", "NVFP4BlockScaling") and fp8_init: + pytest.xfail(f"{fp_recipe} + fp8_init: test_fp8_fsdp2_allgather is currently failing.") + + _run_test(fp8_init, sharding_dims, fp_recipe, layer_type) + + +## ── FusedAdam + FSDP2 tests ───────────────────────────────────────── + + +def _run_fused_adam_test(test_name, recipe="delayed_scaling"): + """Launch an FSDP2 + FusedAdam test via torchrun.""" + test_path = Path(__file__).parent.resolve() / "run_fsdp2_fused_adam.py" + nproc = min(NUM_PROCS, 2) # These tests only need 2 GPUs + test_cmd = [ + "torchrun", + f"--nproc_per_node={nproc}", + str(test_path), + "--test", + test_name, + "--recipe", + recipe, + ] + + subprocess.run(test_cmd, env=os.environ, check=True) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_fused_adam_fp8_master_weights(fp_recipe): + """FusedAdam(master_weights=True) + FSDP2 + quantized_model_init (meta device init).""" + if fp_recipe in ("NVFP4BlockScaling",): + pytest.xfail( + f"{fp_recipe}: quantized_model_init and FSDP2 is not currently supported, since the " + "block tensor is dequantized before we flatten it for FSDP2." + ) + _run_fused_adam_test("fused_adam_fp8_master_weights", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_fused_adam_fp8_master_weights_no_meta(fp_recipe): + """FusedAdam(master_weights=True) + FSDP2 + quantized_model_init (CUDA init, no meta device). + + Block-scaling QuantizedTensors (MXFP8, Float8Blockwise, NVFP4) are wrapper + subclasses with data_ptr() == 0. Without meta-device init, FSDP2's + reset_sharded_param() crashes with 'invalid python storage'. + Per-tensor FP8 (DelayedScaling, Float8CurrentScaling) works because + Float8Tensor's storage is accessible. + """ + if fp_recipe in ("MXFP8BlockScaling", "Float8BlockScaling", "NVFP4BlockScaling"): + pytest.xfail( + f"{fp_recipe}: FSDP2 without meta-device init crashes on block-scaling " + "QuantizedTensor wrapper subclasses (data_ptr() == 0). " + "Use device='meta' + reset_parameters() after sharding." + ) + _run_fused_adam_test("fused_adam_fp8_master_weights_no_meta", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_fused_adam_bf16(fp_recipe): + """FusedAdam(master_weights=True) + FSDP2 + bf16 params (no FP8).""" + _run_fused_adam_test("fused_adam_bf16", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_fused_adam_fp8_no_master(fp_recipe): + """FusedAdam(master_weights=False) + FSDP2 + FP8 params.""" + if fp_recipe in ("MXFP8BlockScaling", "Float8BlockScaling", "NVFP4BlockScaling"): + pytest.xfail( + f"{fp_recipe}: FusedAdam without master_weights does not support " + "block-scaling quantized tensors. Use master_weights=True." + ) + _run_fused_adam_test("fused_adam_fp8_no_master", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_fused_adam_bf16_store_param_remainders(fp_recipe): + """FusedAdam(master_weights=True, store_param_remainders=True) + FSDP2 + bf16.""" + _run_fused_adam_test("fused_adam_bf16_store_param_remainders", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_dcp_output_parity(fp_recipe): + """DCP save/load round-trip into a fresh model produces identical outputs.""" + if fp_recipe == "MXFP8BlockScaling": + pytest.xfail( + "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " + "MXFP8 quantized tensors, causing illegal memory access" + ) + + if fp_recipe == "NVFP4BlockScaling": + pytest.xfail( + "NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() " + "which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage" + ) + + if fp_recipe == "Float8BlockScaling" and torch.cuda.get_device_capability()[0] == 12: + pytest.xfail( + "Float8BlockScaling is failing on SM120 with RuntimeError: " + "transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu:534 " + "in function quantize_transpose_vector_blockwise: Assertion failed: pow2_scale. On " + "Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, which " + "requires using power of two scaling factors." + ) + + _run_fused_adam_test("dcp_output_parity", fp_recipe) + + +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_dcp_output_parity_async(fp_recipe): + """DCP save/load round-trip into a fresh model produces identical outputs.""" + if fp_recipe == "MXFP8BlockScaling": + pytest.xfail( + "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " + "MXFP8 quantized tensors, causing illegal memory access: " + "/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh:92 in function " + "multi_tensor_apply: CUDA Error: an illegal memory access was encountered" + ) + + if fp_recipe == "NVFP4BlockScaling": + pytest.xfail( + "NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() " + "which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage" + ) + + if fp_recipe == "Float8BlockScaling": + pytest.xfail( + "Float8BlockScaling: async DCP save/load round-trip produces different model " + "outputs — quantization metadata (scales) is not correctly persisted through " + "async distributed checkpointing. On SM120, additionally fails with pow2_scale " + "assertion in quantize_transpose_vector_blockwise." + ) + + _run_fused_adam_test("dcp_output_parity_async", fp_recipe) + - # Skip invalid configurations - if torch.cuda.device_count() < 4: - pytest.skip("FSDP2 test requires at least 4 GPUs") +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_safetensors_fp32_export(fp_recipe): + """Export FP32 model from optimizer master weights to safetensors.""" + if fp_recipe == "MXFP8BlockScaling": + pytest.xfail( + "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " + "MXFP8 quantized tensors, causing illegal memory access" + ) + _run_fused_adam_test("safetensors_fp32_export", fp_recipe) - if recipe == "mx_fp8_block_scaling" and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - elif not fp8_available: - pytest.skip(reason_for_no_fp8) - _run_test(fp8_init, sharding_dims, recipe, layer_type) +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +@pytest.mark.xfail( + reason=( + "fuse_wgrad_accumulation is incompatible with vanilla FSDP2: " + "autograd Function.apply unwraps DTensors to local tensors, so " + "main_grad (set on the DTensor) is inaccessible during backward. " + "Additionally, the fused wgrad GEMM bypasses FSDP2's reduce-scatter." + ), + raises=subprocess.CalledProcessError, + strict=True, +) +def test_fsdp2_fuse_wgrad_accumulation(fp_recipe): + """fuse_wgrad_accumulation=True + FSDP2 -- expected to fail.""" + _run_fused_adam_test("fuse_wgrad_accumulation", fp_recipe) def test_dummy() -> None: diff --git a/tests/pytorch/mxfp8/mxfp8_utils.py b/tests/pytorch/mxfp8/mxfp8_utils.py new file mode 100644 index 000000000..99e088a20 --- /dev/null +++ b/tests/pytorch/mxfp8/mxfp8_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch +import math + + +# Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization without padding +def get_mxfp8_scale_shape_no_padding(shape, columnwise): + M, K = 1, 1 + M = math.prod(shape[:-1]) + K = shape[-1] + + if columnwise: + outer = M // 32 + inner = K + return (outer, inner) + # rowwise + outer = M + inner = K // 32 + return (outer, inner) + + +def _rowwise_swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor) -> torch.Tensor: + assert scale.dim() == 2 + assert input_M == scale.shape[0] + assert input_N // 32 == scale.shape[1] + + x = scale.view(input_M // 128, 4, 32, input_N // 128, 4) + x = x.permute(0, 3, 2, 1, 4) + x = x.contiguous() + # View back as original 2D shape + x = x.view(input_M, input_N // 32) + return x + + +def _columnwise_swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor) -> torch.Tensor: + assert scale.dim() == 2 + assert input_M // 32 == scale.shape[0] + assert input_N == scale.shape[1] + + x = scale.view(input_M // 128, 4, input_N // 128, 4, 32) + x = x.permute(2, 0, 4, 3, 1) + x = x.contiguous() + + # alternative way: transpose the scale and do rowwise swizzle with M, N swapped + x1 = _rowwise_swizzle_mxfp8_scale(input_N, input_M, scale.transpose(0, 1).contiguous()) + torch.testing.assert_close( + x.view(-1), x1.view(-1), atol=0.0, rtol=0.0, msg="columnwise swizzle sanity check failed" + ) + + # View back as original 2D shape + x = x.view(input_M // 32, input_N) + return x + + +def swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor, columnwise: bool) -> torch.Tensor: + if not columnwise: + return _rowwise_swizzle_mxfp8_scale(input_M, input_N, scale) + else: + return _columnwise_swizzle_mxfp8_scale(input_M, input_N, scale) diff --git a/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py b/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py new file mode 100644 index 000000000..3c197bc6f --- /dev/null +++ b/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py @@ -0,0 +1,475 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import MXFP8Quantizer + +import pytest +import torch +import random +import math + +from mxfp8_utils import swizzle_mxfp8_scale, get_mxfp8_scale_shape_no_padding + +recipe_available, reason_for_no_recipe = te.is_mxfp8_available(return_reason=True) + + +def generate_random_multiples_sum(total=8192, n=4, multiple=64): + if total % multiple != 0: + raise ValueError(f"Total ({total}) must be a multiple of {multiple}") + if (total // multiple) < n: + raise ValueError("Total too small for given n and multiple.") + + # Work in units of multiples + total_units = total // multiple + + # choose n−1 random cut points in [1, total_units−1) + cuts = sorted(random.sample(range(1, total_units), n - 1)) + + # convert to segment lengths + parts = ( + [cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]] + ) + + # convert back to multiples + return [p * multiple for p in parts] + + +def generate_split_sections(M: int, N: int, edge_cases: str) -> list[int]: + least_multiple = 128 + num_chunks = 4 + split_sections = None + + avg_split = M // num_chunks + + if M == 0 or N == 0: + # all zeros + return [0] * num_chunks + if edge_cases == "regular": + split_sections = [avg_split] * num_chunks + elif edge_cases == "zero_tokens_all": + split_sections = [0] * num_chunks + elif edge_cases == "zero_tokens_front": + split_sections = [0] + [avg_split] * (num_chunks - 2) + [avg_split * 2] + elif edge_cases == "zero_tokens_end": + split_sections = [avg_split * 2] + [avg_split] * (num_chunks - 2) + [0] + elif edge_cases == "zero_tokens_middle": + split_sections = [avg_split] * (num_chunks - 2) + [0] + [avg_split * 2] + elif edge_cases == "random_uneven_split": + split_sections = generate_random_multiples_sum(M, num_chunks, least_multiple) + else: + raise ValueError(f"Invalid edge case: {edge_cases}") + + # adds up the split_sections to make it M + assert sum(split_sections) == M, "The split_sections do not add up to M" + + # make sure every split_section is a multiple of least_multiple + for split_section in split_sections: + assert ( + split_section % least_multiple == 0 + ), "The split_sections are not multiples of least_multiple" + + return split_sections + + +def reference_group_quantize( + x: torch.Tensor, + quantizers: list[MXFP8Quantizer], + split_sections: list[int], + return_identity: bool, + return_transpose: bool, +) -> torch.Tensor: + x_chunks = torch.split(x, split_sections) + + # rowwise quantization + x_qx = [] + x_sx = [] + # columnwise quantization + x_qx_t = [] + x_sx_t = [] + + for i in range(len(x_chunks)): + x_chunk = x_chunks[i] + x_mxfp8_res = quantizers[i](x_chunk) + if return_identity: + x_qx.append(x_mxfp8_res._rowwise_data.view(dtype=torch.uint8)) + x_sx.append(x_mxfp8_res._rowwise_scale_inv) + else: + x_qx.append(None) + x_sx.append(None) + if return_transpose: + x_qx_t.append(x_mxfp8_res._columnwise_data.view(dtype=torch.uint8)) + x_sx_t.append(x_mxfp8_res._columnwise_scale_inv) + else: + x_qx_t.append(None) + x_sx_t.append(None) + + return x_qx, x_sx, x_qx_t, x_sx_t + + +def fused_grouped_quantize( + x: torch.Tensor, split_section_tensor: torch.Tensor, quantizer: MXFP8Quantizer +): + + # view x as a 2D tensor + hidden_dim = x.shape[-1] + x = x.view(-1, hidden_dim) + num_tensors = split_section_tensor.shape[0] + + grouped_output = tex.group_quantize(x, quantizer, num_tensors, split_section_tensor) + + return grouped_output + + +def assert_same_shape_and_dtype(x: torch.Tensor, y: torch.Tensor) -> None: + assert x.shape == y.shape + assert x.dtype == y.dtype + + +def check_grouped_tensor_mxfp8_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_identity: bool, + return_transpose: bool, + split_sections: list[int], + optimize_for_gemm: bool = False, +) -> None: + + te_dtype = tex.DType.kFloat8E4M3 + + split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + x_splits = torch.split(x, split_sections) + + # Quantize + quantizers = [ + MXFP8Quantizer( + fp8_dtype=te_dtype, + rowwise=return_identity, + columnwise=return_transpose, + ) + for _ in range(len(split_sections)) + ] + + grouped_quantizer = quantizers[0].copy() + # configure grouped quantizer with swizzle fusion + # and compare with reference without swizzle fusion + grouped_quantizer.optimize_for_gemm = optimize_for_gemm + + x_qx_ref, x_sx_ref, x_qx_t_ref, x_sx_t_ref = reference_group_quantize( + x, quantizers, split_sections, return_identity, return_transpose + ) + + group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) + # get a list of MXFP8 quantized tensors for testing + split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + + if return_identity: + x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] + x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] + + for i in range(len(x_qx)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) + assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) + else: + torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x_splits[i].shape, False) + assert ( + valid_scale_shape == x_sx[i].shape + ), "The scale shape is not correctly aligned" + x_sx_i = x_sx[i].clone() + x_sx_ref_i = x_sx_ref[i].clone() + if optimize_for_gemm: + x_sx_ref_i = swizzle_mxfp8_scale( + split_sections[i], N, x_sx_ref_i, columnwise=False + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + + if return_transpose: + x_qx_t = [ + output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs + ] + x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs] + # assert with zero tolerance + for i in range(len(x_qx_t)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) + assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) + else: + torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x_splits[i].shape, True) + assert ( + valid_scale_shape == x_sx_t[i].shape + ), "The scale shape is not correctly aligned" + x_sx_t_i = x_sx_t[i].clone() + x_sx_t_ref_i = x_sx_t_ref[i].clone() + if optimize_for_gemm: + x_sx_t_ref_i = swizzle_mxfp8_scale( + split_sections[i], N, x_sx_t_ref_i, columnwise=True + ) + torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=0.0, rtol=0.0) + + +def check_grouped_tensor_mxfp8_with_paged_stashing( + x_dtype: torch.dtype, + M: int, + N: int, + return_identity: bool, + return_transpose: bool, + split_sections: list[int], + valid_M: int = None, + optimize_for_gemm: bool = False, +) -> None: + + te_dtype = tex.DType.kFloat8E4M3 + + assert valid_M is not None, "valid_M must be provided when with_paged_stashing is True" + assert valid_M < M, "valid_M must be less than M when with_paged_stashing is True" + + split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input (fill the entire tensor with garbage too) + x = torch.randn((M, N), dtype=x_dtype, device=device) + valid_x = x[:valid_M, :].clone() + x_splits = torch.split(valid_x, split_sections) + + # Quantize + quantizers = [ + MXFP8Quantizer( + fp8_dtype=te_dtype, + rowwise=return_identity, + columnwise=return_transpose, + ) + for _ in range(len(split_sections)) + ] + + grouped_quantizer = quantizers[0].copy() + # configure grouped quantizer with swizzle fusion + # and compare with reference without swizzle fusion + grouped_quantizer.optimize_for_gemm = optimize_for_gemm + + x_qx_ref, x_sx_ref, x_qx_t_ref, x_sx_t_ref = reference_group_quantize( + valid_x, quantizers, split_sections, return_identity, return_transpose + ) + + # Note: for grouped quantize with paged stashing + # it's expected that we can just pass in the regular input x, not the valid_x + # the kernel is expected to porcess it correctly by becoming no-op for cuda graph + group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) + + # get a list of MXFP8 quantized tensors for testing + split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + + if return_identity: + x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] + x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] + + for i in range(len(x_qx)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) + assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) + else: + torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x_splits[i].shape, False) + assert ( + valid_scale_shape == x_sx[i].shape + ), "The scale shape is not correctly aligned" + x_sx_i = x_sx[i].clone() + x_sx_ref_i = x_sx_ref[i].clone() + if optimize_for_gemm: + x_sx_ref_i = swizzle_mxfp8_scale( + split_sections[i], N, x_sx_ref_i, columnwise=False + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + + if return_transpose: + x_qx_t = [ + output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs + ] + x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs] + # assert with zero tolerance + for i in range(len(x_qx_t)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) + assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) + else: + torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x_splits[i].shape, True) + assert ( + valid_scale_shape == x_sx_t[i].shape + ), "The scale shape is not correctly aligned" + x_sx_t_i = x_sx_t[i].clone() + x_sx_t_ref_i = x_sx_t_ref[i].clone() + if optimize_for_gemm: + x_sx_t_ref_i = swizzle_mxfp8_scale( + split_sections[i], N, x_sx_t_ref_i, columnwise=True + ) + torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # edge case, zero tokens for all + (0, 512), + # full tile cases + (1024, 256), + # larger sizes + (8192, 1024), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + "zero_tokens_front", + "zero_tokens_end", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] +) +@pytest.mark.parametrize( + "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] +) +def test_grouped_tensor_mxfp8_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + quantize_mode: str, + optimize_for_gemm: bool, +) -> None: + + split_sections = generate_split_sections(M, N, edge_cases) + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + check_grouped_tensor_mxfp8_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + return_identity=return_identity, + return_transpose=return_transpose, + split_sections=split_sections, + optimize_for_gemm=optimize_for_gemm, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # M won't be empty in paged stashing + # full tile cases + (1024, 256), + # larger sizes + (8192, 1024), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + # even if buffer is not empty, but the token splits are all zero + "zero_tokens_all", + # partially zero tokens + "zero_tokens_front", + "zero_tokens_end", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] +) +@pytest.mark.parametrize( + "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] +) +def test_grouped_tensor_mxfp8_with_paged_stashing( + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + quantize_mode: str, + optimize_for_gemm: bool, +) -> None: + + # paged stashing means that the sum of total tokens is less than + # or equal to the buffer size, you can have buffer [2048, 1024] + # and when you only receive 1024 tokens, the last half is garbage + # so input has shape [2048, 1024] + # split sections can be [256, 256, 256, 256], sums to 1024 + valid_M = 0 if edge_cases == "zero_tokens_all" else M // 2 + split_sections = generate_split_sections(valid_M, N, edge_cases) + + # sanity check + if edge_cases == "zero_tokens_all": + assert valid_M == 0, "valid_M must be 0 when edge_cases is zero_tokens_all" + else: + assert valid_M == M // 2, "valid_M must be M // 2 when edge_cases is not zero_tokens_all" + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + check_grouped_tensor_mxfp8_with_paged_stashing( + x_dtype=x_dtype, + M=M, + N=N, + return_identity=return_identity, + return_transpose=return_transpose, + split_sections=split_sections, + valid_M=valid_M, + optimize_for_gemm=optimize_for_gemm, + ) diff --git a/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py b/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py new file mode 100644 index 000000000..94ea699d1 --- /dev/null +++ b/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py @@ -0,0 +1,134 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import MXFP8Quantizer +from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage + +import pytest +import torch +import random +import math + +from typing import Tuple + +from mxfp8_utils import swizzle_mxfp8_scale, get_mxfp8_scale_shape_no_padding + +recipe_available, reason_for_no_recipe = te.is_mxfp8_available(return_reason=True) + + +def unpack_quantized_tensor( + quantized_tensor: MXFP8TensorStorage, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + qx, sx, qx_t, sx_t = None, None, None, None + if quantized_tensor._rowwise_data is not None: + qx = quantized_tensor._rowwise_data.view(dtype=torch.uint8) + if quantized_tensor._rowwise_scale_inv is not None: + sx = quantized_tensor._rowwise_scale_inv + if quantized_tensor._columnwise_data is not None: + qx_t = quantized_tensor._columnwise_data.view(dtype=torch.uint8) + if quantized_tensor._columnwise_scale_inv is not None: + sx_t = quantized_tensor._columnwise_scale_inv + return qx, sx, qx_t, sx_t + + +def check_mxfp8_quantize_swizzle_fusion( + x_dtype: torch.dtype, + M: int, + N: int, + return_identity: bool, + return_transpose: bool, +) -> None: + + te_dtype = tex.DType.kFloat8E4M3 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + + # Quantize + quantizer = MXFP8Quantizer( + fp8_dtype=te_dtype, + rowwise=return_identity, + columnwise=return_transpose, + ) + + quantizer_swizzle_fusion = quantizer.copy() + quantizer_swizzle_fusion.optimize_for_gemm = True + + x_qx_swf, x_sx_swf, x_qx_t_swf, x_sx_t_swf = unpack_quantized_tensor( + quantizer_swizzle_fusion(x) + ) + x_qx_ref, x_sx_ref, x_qx_t_ref, x_sx_t_ref = unpack_quantized_tensor(quantizer(x)) + + if return_identity: + torch.testing.assert_close(x_qx_swf, x_qx_ref, atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x.shape, False) + assert valid_scale_shape == x_sx_swf.shape, ( + "The scale shape is not correctly aligned, this test assumes no padding is needed for" + " scaling factors" + ) + x_sx_ref_swizzled = swizzle_mxfp8_scale(M, N, x_sx_ref, columnwise=False) + torch.testing.assert_close(x_sx_swf, x_sx_ref_swizzled, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(x_qx_t_swf, x_qx_t_ref, atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x.shape, True) + assert valid_scale_shape == x_sx_t_swf.shape, ( + "The scale shape is not correctly aligned, this test assumes no padding is needed for" + " scaling factors" + ) + x_sx_t_ref_swizzled = swizzle_mxfp8_scale(M, N, x_sx_t_ref, columnwise=True) + torch.testing.assert_close(x_sx_t_swf, x_sx_t_ref_swizzled, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (1024, 256), + # larger sizes + (8192, 1024), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] +) +def test_mxfp8_quantize_swizzle_fusion( + x_dtype: torch.dtype, + M: int, + N: int, + quantize_mode: str, +) -> None: + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + check_mxfp8_quantize_swizzle_fusion( + x_dtype=x_dtype, + M=M, + N=N, + return_identity=return_identity, + return_transpose=return_transpose, + ) diff --git a/tests/pytorch/nvfp4/nvfp4_utils.py b/tests/pytorch/nvfp4/nvfp4_utils.py new file mode 100644 index 000000000..5f1b5ac36 --- /dev/null +++ b/tests/pytorch/nvfp4/nvfp4_utils.py @@ -0,0 +1,159 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer + +import torch +import math +import random + + +# Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization without padding +def get_nvfp4_scale_shape_no_padding(shape, columnwise): + M, K = 1, 1 + M = math.prod(shape[:-1]) + K = shape[-1] + + if columnwise: + outer = K + inner = math.ceil(M / 16) + return (outer, inner) + # rowwise + outer = M + inner = math.ceil(K / 16) + return (outer, inner) + + +def _rowwise_swizzle_nvfp4_scale(input_M, input_N, scale: torch.Tensor) -> torch.Tensor: + assert scale.dim() == 2 + assert input_M == scale.shape[0] + assert input_N // 16 == scale.shape[1] + + x = scale.view(input_M // 128, 4, 32, input_N // 64, 4) + x = x.permute(0, 3, 2, 1, 4) + x = x.contiguous() + # View back as original 2D shape + x = x.view(input_M, input_N // 16) + return x + + +# TN-only layout for NVFP4 means that there is only rowwise swizzle +# just need to switch the M, N which means transposing the input +def swizzle_nvfp4_scale(input_M, input_N, scale: torch.Tensor, columnwise: bool) -> torch.Tensor: + if not columnwise: + return _rowwise_swizzle_nvfp4_scale(input_M, input_N, scale) + else: + return _rowwise_swizzle_nvfp4_scale(input_N, input_M, scale) + + +# Helper function to generate random multiples sum +def _generate_random_multiples_sum(total=8192, n=4, multiple=64): + if total % multiple != 0: + raise ValueError(f"Total ({total}) must be a multiple of {multiple}") + if (total // multiple) < n: + raise ValueError("Total too small for given n and multiple.") + + # Work in units of multiples + total_units = total // multiple + + # choose n−1 random cut points in [1, total_units−1) + cuts = sorted(random.sample(range(1, total_units), n - 1)) + + # convert to segment lengths + parts = ( + [cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]] + ) + + # convert back to multiples + return [p * multiple for p in parts] + + +# Generate split sections for NVFP4 1D blockwise quantization +def generate_split_sections( + M: int, N: int, edge_cases: str, least_multiple: int = 128 +) -> list[int]: + num_chunks = 4 + split_sections = None + + avg_split = M // num_chunks + + if M == 0 or N == 0: + # all zeros + return [0] * num_chunks + if edge_cases == "regular": + split_sections = [avg_split] * num_chunks + elif edge_cases == "zero_tokens_all": + split_sections = [0] * num_chunks + elif edge_cases == "zero_tokens_front": + split_sections = [0] + [avg_split] * (num_chunks - 2) + [avg_split * 2] + elif edge_cases == "zero_tokens_end": + split_sections = [avg_split * 2] + [avg_split] * (num_chunks - 2) + [0] + elif edge_cases == "zero_tokens_middle": + split_sections = [avg_split] * (num_chunks - 2) + [0] + [avg_split * 2] + elif edge_cases == "random_uneven_split": + split_sections = _generate_random_multiples_sum(M, num_chunks, least_multiple) + else: + raise ValueError(f"Invalid edge case: {edge_cases}") + + # adds up the split_sections to make it M + assert sum(split_sections) == M, "The split_sections do not add up to M" + + # make sure every split_section is a multiple of least_multiple + for split_section in split_sections: + assert ( + split_section % least_multiple == 0 + ), "The split_sections are not multiples of least_multiple" + + return split_sections + + +# Reference implementation of group quantization for NVFP4 1D blockwise quantization +def reference_group_quantize( + x: torch.Tensor, + quantizers: list[NVFP4Quantizer], + split_sections: list[int], + return_identity: bool, + return_transpose: bool, +) -> torch.Tensor: + x_view = x.reshape(-1, x.size(-1)) + x_chunks = torch.split(x, split_sections) + + # rowwise quantization + x_qx = [] + x_sx = [] + x_amax_rowwise = [] + # columnwise quantization + x_qx_t = [] + x_sx_t = [] + x_amax_colwise = [] + + for i in range(len(x_chunks)): + x_chunk = x_chunks[i] + x_nvfp4_res = quantizers[i](x_chunk) + if return_identity: + x_qx.append(x_nvfp4_res._rowwise_data.view(dtype=torch.uint8)) + x_sx.append(x_nvfp4_res._rowwise_scale_inv) + x_amax_rowwise.append(x_nvfp4_res._amax_rowwise) + else: + x_qx.append(None) + x_sx.append(None) + x_amax_rowwise.append(None) + if return_transpose: + x_qx_t.append(x_nvfp4_res._columnwise_data.view(dtype=torch.uint8)) + x_sx_t.append(x_nvfp4_res._columnwise_scale_inv) + x_amax_colwise.append(x_nvfp4_res._amax_columnwise) + else: + x_qx_t.append(None) + x_sx_t.append(None) + x_amax_colwise.append(None) + + return x_qx, x_sx, x_amax_rowwise, x_qx_t, x_sx_t, x_amax_colwise + + +# Function to assert that two tensors have the same shape and dtype +def assert_same_shape_and_dtype(x: torch.Tensor, y: torch.Tensor) -> None: + assert x.shape == y.shape + assert x.dtype == y.dtype diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 01a4a0120..d4bf1fd3a 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -23,126 +23,14 @@ import random import math -recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) - - -def generate_random_multiples_sum(total=8192, n=4, multiple=64): - if total % multiple != 0: - raise ValueError(f"Total ({total}) must be a multiple of {multiple}") - if (total // multiple) < n: - raise ValueError("Total too small for given n and multiple.") - - # Work in units of multiples - total_units = total // multiple - - # choose n−1 random cut points in [1, total_units−1) - cuts = sorted(random.sample(range(1, total_units), n - 1)) - - # convert to segment lengths - parts = ( - [cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]] - ) - - # convert back to multiples - return [p * multiple for p in parts] - - -def generate_split_sections(M: int, N: int, edge_cases: str) -> list[int]: - least_multiple = 64 - num_chunks = 4 - split_sections = None - - avg_split = M // num_chunks - - if M == 0 or N == 0: - # all zeros - return [0] * num_chunks - if edge_cases == "regular": - split_sections = [avg_split] * num_chunks - elif edge_cases == "zero_tokens_front": - split_sections = [0] + [avg_split] * (num_chunks - 2) + [avg_split * 2] - elif edge_cases == "zero_tokens_end": - split_sections = [avg_split * 2] + [avg_split] * (num_chunks - 2) + [0] - elif edge_cases == "zero_tokens_middle": - split_sections = [avg_split] * (num_chunks - 2) + [0] + [avg_split * 2] - elif edge_cases == "random_uneven_split": - split_sections = generate_random_multiples_sum(M, num_chunks, least_multiple) - else: - raise ValueError(f"Invalid edge case: {edge_cases}") - - # adds up the split_sections to make it M - assert sum(split_sections) == M, "The split_sections do not add up to M" - - # make sure every split_section is a multiple of least_multiple - for split_section in split_sections: - assert ( - split_section % least_multiple == 0 - ), "The split_sections are not multiples of least_multiple" - - return split_sections - - -# Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization without padding -def get_nvfp4_scale_shape_no_padding(shape, columnwise): - M, K = 1, 1 - M = math.prod(shape[:-1]) - K = shape[-1] - - if columnwise: - outer = K - inner = math.ceil(M / 16) - return (outer, inner) - # rowwise - outer = M - inner = math.ceil(K / 16) - return (outer, inner) - - -def reference_group_quantize( - x: torch.Tensor, - quantizers: list[NVFP4Quantizer], - split_sections: list[int], - return_identity: bool, - return_transpose: bool, -) -> torch.Tensor: - x_view = x.reshape(-1, x.size(-1)) - x_chunks = torch.split(x, split_sections) - - # rowwise quantization - x_qx = [] - x_sx = [] - x_amax_rowwise = [] - # columnwise quantization - x_qx_t = [] - x_sx_t = [] - x_amax_colwise = [] - - for i in range(len(x_chunks)): - x_chunk = x_chunks[i] - x_nvfp4_res = quantizers[i](x_chunk) - if return_identity: - x_qx.append(x_nvfp4_res._rowwise_data.view(dtype=torch.uint8)) - x_sx.append(x_nvfp4_res._rowwise_scale_inv) - x_amax_rowwise.append(x_nvfp4_res._amax_rowwise) - else: - x_qx.append(None) - x_sx.append(None) - x_amax_rowwise.append(None) - if return_transpose: - x_qx_t.append(x_nvfp4_res._columnwise_data.view(dtype=torch.uint8)) - x_sx_t.append(x_nvfp4_res._columnwise_scale_inv) - x_amax_colwise.append(x_nvfp4_res._amax_columnwise) - else: - x_qx_t.append(None) - x_sx_t.append(None) - x_amax_colwise.append(None) - - return x_qx, x_sx, x_amax_rowwise, x_qx_t, x_sx_t, x_amax_colwise - +from nvfp4_utils import ( + get_nvfp4_scale_shape_no_padding, + generate_split_sections, + assert_same_shape_and_dtype, + reference_group_quantize, +) -def assert_same_shape_and_dtype(x: torch.Tensor, y: torch.Tensor) -> None: - assert x.shape == y.shape - assert x.dtype == y.dtype +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) def check_group_quantization_nvfp4_versus_reference( @@ -242,6 +130,8 @@ def check_group_quantization_nvfp4_versus_reference( [ # edge case, zero tokens for all (0, 512), + # edge case, not 128 multiple hidden dimension + (1024, 320), # full tile cases (256, 1024), (1024, 256), @@ -279,7 +169,7 @@ def test_rht_with_quantization_block_tiling_versus_reference( with_rht: bool, ) -> None: - split_sections = generate_split_sections(M, N, edge_cases) + split_sections = generate_split_sections(M, N, edge_cases, least_multiple=64) # currently disable pre-RHT amax with_post_rht_amax = with_rht diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py new file mode 100644 index 000000000..8d81d578a --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py @@ -0,0 +1,451 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes import utils +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.common.recipe import NVFP4BlockScaling +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor + +import pytest +import torch +import random +import math + +from nvfp4_utils import ( + get_nvfp4_scale_shape_no_padding, + swizzle_nvfp4_scale, + generate_split_sections, + assert_same_shape_and_dtype, + reference_group_quantize, +) + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +def fused_grouped_quantize( + x: torch.Tensor, split_section_tensor: torch.Tensor, quantizer: NVFP4Quantizer +): + + # view x as a 2D tensor + hidden_dim = x.shape[-1] + x = x.view(-1, hidden_dim) + num_tensors = split_section_tensor.shape[0] + + grouped_output = tex.group_quantize(x, quantizer, num_tensors, split_section_tensor) + + return grouped_output + + +def check_grouped_tensor_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_identity: bool, + return_transpose: bool, + split_sections: list[int], + with_rht: bool = True, + with_post_rht_amax: bool = True, + with_random_sign_mask: bool = True, + optimize_for_gemm: bool = False, +) -> None: + + te_dtype = tex.DType.kFloat4E2M1 + + split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + num_chunks = len(split_sections) + + x_splits = torch.split(x, split_sections) + + # Quantize + quantizers = [ + NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=return_identity, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + ) + for _ in range(len(split_sections)) + ] + + grouped_quantizer = quantizers[0].copy() + # configure grouped quantizer with swizzle fusion + # and compare with reference without swizzle fusion + grouped_quantizer.optimize_for_gemm = optimize_for_gemm + + x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = ( + reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose) + ) + + group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) + # get a list of nvfp4 quantized tensors for testing + split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + + if return_identity: + x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] + x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] + x_amax_rowwise = [output._amax_rowwise for output in split_quantize_outputs] + + for i in range(len(x_qx)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i]) + assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) + assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) + else: + torch.testing.assert_close( + x_amax_rowwise[i], x_amax_rowwise_ref[i], atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, False) + assert ( + valid_scale_shape == x_sx[i].shape + ), "The scale shape is not correctly aligned" + x_sx_i = x_sx[i].clone() + x_sx_ref_i = x_sx_ref[i].clone() + if optimize_for_gemm: + x_sx_ref_i = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_ref_i, columnwise=False + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + + if return_transpose: + x_qx_t = [ + output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs + ] + x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs] + x_amax_colwise = [output._amax_columnwise for output in split_quantize_outputs] + # assert with zero tolerance + for i in range(len(x_qx_t)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i]) + assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) + assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) + else: + torch.testing.assert_close( + x_amax_colwise[i], x_amax_colwise_ref[i], atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True) + assert ( + valid_scale_shape == x_sx_t[i].shape + ), "The scale shape is not correctly aligned" + x_sx_t_i = x_sx_t[i].clone() + x_sx_t_ref_i = x_sx_t_ref[i].clone() + if optimize_for_gemm: + x_sx_t_ref_i = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_t_ref_i, columnwise=True + ) + torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=0.0, rtol=0.0) + + +def check_grouped_tensor_nvfp4_with_paged_stashing( + x_dtype: torch.dtype, + M: int, + N: int, + return_identity: bool, + return_transpose: bool, + split_sections: list[int], + with_rht: bool = True, + with_post_rht_amax: bool = True, + with_random_sign_mask: bool = True, + valid_M: int = None, + optimize_for_gemm: bool = False, +) -> None: + + te_dtype = tex.DType.kFloat4E2M1 + + assert valid_M is not None, "valid_M must be provided when with_paged_stashing is True" + assert valid_M < M, "valid_M must be less than M when with_paged_stashing is True" + + split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input (fill the entire tensor with garbage too) + x = torch.randn((M, N), dtype=x_dtype, device=device) + valid_x = x[:valid_M, :].clone() + num_chunks = len(split_sections) + + x_splits = torch.split(valid_x, split_sections) + + # Quantize + quantizers = [ + NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=return_identity, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + ) + for _ in range(len(split_sections)) + ] + + grouped_quantizer = quantizers[0].copy() + # configure grouped quantizer with swizzle fusion + # and compare with reference without swizzle fusion + grouped_quantizer.optimize_for_gemm = optimize_for_gemm + + x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = ( + reference_group_quantize( + valid_x, quantizers, split_sections, return_identity, return_transpose + ) + ) + + # Note: for grouped quantize with paged stashing + # it's expected that we can just pass in the regular input x, not the valid_x + # the kernel is expected to porcess it correctly by becoming no-op for cuda graph + group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) + + # get a list of nvfp4 quantized tensors for testing + split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + + if return_identity: + x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] + x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] + x_amax_rowwise = [output._amax_rowwise for output in split_quantize_outputs] + + for i in range(len(x_qx)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i]) + assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) + assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) + else: + torch.testing.assert_close( + x_amax_rowwise[i], x_amax_rowwise_ref[i], atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, False) + assert ( + valid_scale_shape == x_sx[i].shape + ), "The scale shape is not correctly aligned" + x_sx_i = x_sx[i].clone() + x_sx_ref_i = x_sx_ref[i].clone() + if optimize_for_gemm: + x_sx_ref_i = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_ref_i, columnwise=False + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + + if return_transpose: + x_qx_t = [ + output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs + ] + x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs] + x_amax_colwise = [output._amax_columnwise for output in split_quantize_outputs] + # assert with zero tolerance + for i in range(len(x_qx_t)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i]) + assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) + assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) + else: + torch.testing.assert_close( + x_amax_colwise[i], x_amax_colwise_ref[i], atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True) + x_sx_t_i = x_sx_t[i].clone() + x_sx_t_ref_i = x_sx_t_ref[i].clone() + if optimize_for_gemm: + x_sx_t_ref_i = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_t_ref_i, columnwise=True + ) + torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # edge case, zero tokens for all + (0, 512), + # full tile cases + (1024, 256), + # larger sizes + (8192, 1024), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + "zero_tokens_front", + "zero_tokens_end", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +@pytest.mark.parametrize("with_rht", [True], ids=["with_rht"]) +@pytest.mark.parametrize( + "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] +) +def test_grouped_tensor_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + quantize_mode: str, + with_random_sign_mask: bool, + with_rht: bool, + optimize_for_gemm: bool, +) -> None: + + split_sections = generate_split_sections(M, N, edge_cases, least_multiple=128) + + # currently disable pre-RHT amax + with_post_rht_amax = with_rht + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + check_grouped_tensor_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + return_identity=return_identity, + return_transpose=return_transpose, + split_sections=split_sections, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + optimize_for_gemm=optimize_for_gemm, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # M won't be empty in paged stashing + # full tile cases + (1024, 256), + # larger sizes + (8192, 1024), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + # even if buffer is not empty, but the token splits are all zero + "zero_tokens_all", + # partially zero tokens + "zero_tokens_front", + "zero_tokens_end", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +@pytest.mark.parametrize("with_rht", [True], ids=["with_rht"]) +@pytest.mark.parametrize( + "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] +) +def test_grouped_tensor_nvfp4_with_paged_stashing( + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + quantize_mode: str, + with_random_sign_mask: bool, + with_rht: bool, + optimize_for_gemm: bool, +) -> None: + + # paged stashing means that the sum of total tokens is less than + # or equal to the buffer size, you can have buffer [2048, 1024] + # and when you only receive 1024 tokens, the last half is garbage + # so input has shape [2048, 1024] + # split sections can be [256, 256, 256, 256], sums to 1024 + valid_M = 0 if edge_cases == "zero_tokens_all" else M // 2 + split_sections = generate_split_sections(valid_M, N, edge_cases, least_multiple=128) + + # sanity check + if edge_cases == "zero_tokens_all": + assert valid_M == 0, "valid_M must be 0 when edge_cases is zero_tokens_all" + else: + assert valid_M == M // 2, "valid_M must be M // 2 when edge_cases is not zero_tokens_all" + + # currently disable pre-RHT amax + with_post_rht_amax = with_rht + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + check_grouped_tensor_nvfp4_with_paged_stashing( + x_dtype=x_dtype, + M=M, + N=N, + return_identity=return_identity, + return_transpose=return_transpose, + split_sections=split_sections, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + valid_M=valid_M, + optimize_for_gemm=optimize_for_gemm, + ) diff --git a/tests/pytorch/test_checkpoint.py b/tests/pytorch/test_checkpoint.py index 1383264fd..0427886b8 100644 --- a/tests/pytorch/test_checkpoint.py +++ b/tests/pytorch/test_checkpoint.py @@ -101,7 +101,7 @@ def _save_checkpoint(name: str, checkpoint_dir: Optional[pathlib.Path] = None) - # Path to save checkpoint if checkpoint_dir is None: checkpoint_dir = TestLoadCheckpoint._checkpoint_dir() - checkpoint_dir.mkdir(exist_ok=True) + checkpoint_dir.mkdir(parents=True, exist_ok=True) checkpoint_file = checkpoint_dir / f"{name}.pt" # Create module and save checkpoint diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 4de49115b..536d43adc 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -8,6 +8,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.common import recipe +from transformer_engine.pytorch.constants import FP8BwdTensorIdx, FP8FwdTensorIdx from transformer_engine.pytorch import ( autocast, Linear, @@ -169,11 +170,11 @@ def test_custom_recipe_matches_current_scaling(): with autocast(enabled=True, recipe=ref_recipe): out_ref = model_ref(inp_ref) # Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd) - ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] - ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] - ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] - ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + ref_fwd_in = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] + ref_fwd_w = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] + ref_fwd_out = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] + ref_bwd_go = model_ref.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] + ref_bwd_gi = model_ref.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3 assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3 assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3 @@ -200,11 +201,11 @@ def quantizer_factory(role): with autocast(enabled=True, recipe=custom_recipe): out_custom = model_custom(inp_custom) # Assert dtypes for custom quantizers match reference mapping - cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] - cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] - cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] - cus_bwd_gi = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + cus_fwd_in = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] + cus_fwd_w = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] + cus_fwd_out = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] + cus_bwd_go = model_custom.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] + cus_bwd_gi = model_custom.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3 assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3 assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3 diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 091943d02..e2c498f90 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -13,8 +13,9 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te -from transformer_engine.common.recipe import DelayedScaling +from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8BlockScaling from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available +from transformer_engine.pytorch import QuantizedTensor from transformer_engine.pytorch.utils import gpu_autocast_ctx from transformer_engine.pytorch.utils import get_device_compute_capability @@ -414,6 +415,20 @@ def test_bf16_exp_avg_sq(self): master_atol=2e-3, ) + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") + def test_bf16_exp_avg_and_exp_avg_sq(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.bfloat16, + exp_avg_sq_dtype=torch.bfloat16, + master_rtol=2e-3, + master_atol=2e-3, + ) + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg_sq(self): @@ -512,6 +527,269 @@ def test_fp8_model_weight_cast(self): ) +class TestFusedAdamMXFP8(TestFusedOptimizer): + """FusedAdam with MXFP8BlockScaling quantized primary weights (single GPU, no FSDP).""" + + def setup_method(self) -> None: + super().setup_method(iters=5) + mxfp8_available, self.mxfp8_reason = te.is_mxfp8_available(return_reason=True) + self.mxfp8_available = mxfp8_available + + def _build_model(self): + recipe = MXFP8BlockScaling() + with quantized_model_init(enabled=True, recipe=recipe): + model = te.Linear(256, 256, params_dtype=torch.bfloat16).cuda() + return model, recipe + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + def test_mxfp8_linear_fused_adam_master_weights(self): + """quantized_model_init(MXFP8) + te.Linear + FusedAdam(master_weights=True). + + Verifies: + - Model params are MXFP8 QuantizedTensors after init + - FP32 master weights track a reference Adam optimizer + - Params remain QuantizedTensors after training + - Loss decreases over training steps + """ + if not self.mxfp8_available: + pytest.skip(self.mxfp8_reason) + + model, recipe = self._build_model() + + # Verify weight params are QuantizedTensors (bias stays bf16) + for name, p in model.named_parameters(): + if "bias" not in name: + assert isinstance( + p, QuantizedTensor + ), f"Expected QuantizedTensor for {name}, got {type(p).__name__}" + + # Build reference: clone dequantized weights for a plain Adam + ref_params = [p.detach().clone().float() for p in model.parameters()] + + options = {"lr": 5e-4, "betas": (0.9, 0.999), "eps": 1e-8, "weight_decay": 0} + ref_optim = torch.optim.Adam(ref_params, **options) + tst_optim = te.optimizers.FusedAdam( + list(model.parameters()), + master_weights=True, + master_weight_dtype=torch.float32, + use_decoupled_grad=True, + **options, + ) + + for _ in range(self.iters): + for p_ref, p in zip(ref_params, model.parameters()): + p_ref.grad = torch.rand_like(p_ref) + p.decoupled_grad = p_ref.grad.clone() + ref_optim.step() + tst_optim.step() + + # FP32 master weights should match reference Adam exactly + master_params = [ + tst_optim.get_unscaled_state(p, "master_param") for p in model.parameters() + ] + torch.testing.assert_close(ref_params, master_params) + + # Weight params should still be QuantizedTensors after training + for name, p in model.named_parameters(): + if "bias" not in name: + assert isinstance( + p, QuantizedTensor + ), f"{name} lost QuantizedTensor type after training: {type(p).__name__}" + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + def test_mxfp8_linear_forward_backward_step(self): + """End-to-end: quantized_model_init + autocast forward + backward + FusedAdam.step(). + + Uses te.autocast with MXFP8BlockScaling recipe for the forward pass, + verifying the full training loop works with quantized compute. + """ + if not self.mxfp8_available: + pytest.skip(self.mxfp8_reason) + + model, recipe = self._build_model() + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + batch_size, seq_len, hidden = 4, 32, 256 + x = torch.randn(batch_size, seq_len, hidden, dtype=torch.bfloat16, device="cuda") + target = torch.randn_like(x) + + losses = [] + for i in range(self.iters): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = torch.nn.functional.mse_loss(output, target) + losses.append(loss.item()) + loss.backward() + + # Verify all params have non-None gradients after backward + for name, p in model.named_parameters(): + assert p.grad is not None, f"Step {i}: {name} has no gradient after backward" + assert ( + p.grad.shape == p.shape + ), f"Step {i}: {name} grad shape {p.grad.shape} != param shape {p.shape}" + assert torch.isfinite(p.grad).all(), f"Step {i}: {name} has non-finite gradients" + assert p.grad.any(), f"Step {i}: {name} gradient is all zeros" + + optimizer.step() + + # Verify loss decreased + assert losses[-1] < losses[0], f"Loss did not decrease: {losses}" + + # Verify weight params remain QuantizedTensors + for name, p in model.named_parameters(): + if "bias" not in name: + assert isinstance( + p, QuantizedTensor + ), f"{name} lost QuantizedTensor type: {type(p).__name__}" + + # Verify optimizer states are float32 + for name, p in model.named_parameters(): + state = optimizer.state[p] + assert state["exp_avg"].dtype == torch.float32 + assert state["exp_avg_sq"].dtype == torch.float32 + if "bias" not in name: + assert state["master_param"].dtype == torch.float32 + + +class TestFusedAdamFloat8Block(TestFusedOptimizer): + """FusedAdam with Float8BlockScaling quantized primary weights (single GPU, no FSDP).""" + + def setup_method(self) -> None: + super().setup_method(iters=5) + fp8_block_available, self.fp8_block_reason = te.is_fp8_block_scaling_available( + return_reason=True + ) + self.fp8_block_available = fp8_block_available + + def _build_model(self): + recipe = Float8BlockScaling() + with quantized_model_init(enabled=True, recipe=recipe): + model = te.Linear(256, 256, params_dtype=torch.bfloat16).cuda() + return model, recipe + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + def test_float8block_linear_fused_adam_master_weights(self): + """quantized_model_init(Float8BlockScaling) + te.Linear + FusedAdam(master_weights=True). + + Verifies: + - Model params are QuantizedTensors after init + - FP32 master weights track a reference Adam optimizer + - Params remain QuantizedTensors after training + """ + if not self.fp8_block_available: + pytest.skip(self.fp8_block_reason) + + model, recipe = self._build_model() + + # Verify weight params are QuantizedTensors (bias stays bf16) + for name, p in model.named_parameters(): + if "bias" not in name: + assert isinstance( + p, QuantizedTensor + ), f"Expected QuantizedTensor for {name}, got {type(p).__name__}" + + # Build reference: clone dequantized weights for a plain Adam + ref_params = [p.detach().clone().float() for p in model.parameters()] + + options = {"lr": 5e-4, "betas": (0.9, 0.999), "eps": 1e-8, "weight_decay": 0} + ref_optim = torch.optim.Adam(ref_params, **options) + tst_optim = te.optimizers.FusedAdam( + list(model.parameters()), + master_weights=True, + master_weight_dtype=torch.float32, + use_decoupled_grad=True, + **options, + ) + + for _ in range(self.iters): + for p_ref, p in zip(ref_params, model.parameters()): + p_ref.grad = torch.rand_like(p_ref) + p.decoupled_grad = p_ref.grad.clone() + ref_optim.step() + tst_optim.step() + + # FP32 master weights should match reference Adam exactly + master_params = [ + tst_optim.get_unscaled_state(p, "master_param") for p in model.parameters() + ] + torch.testing.assert_close(ref_params, master_params) + + # Weight params should still be QuantizedTensors after training + for name, p in model.named_parameters(): + if "bias" not in name: + assert isinstance( + p, QuantizedTensor + ), f"{name} lost QuantizedTensor type after training: {type(p).__name__}" + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + def test_float8block_linear_forward_backward_step(self): + """End-to-end: quantized_model_init + autocast forward + backward + FusedAdam.step(). + + Uses te.autocast with Float8BlockScaling recipe for the forward pass, + verifying the full training loop works with quantized compute. + """ + if not self.fp8_block_available: + pytest.skip(self.fp8_block_reason) + + model, recipe = self._build_model() + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + batch_size, seq_len, hidden = 4, 32, 256 + x = torch.randn(batch_size, seq_len, hidden, dtype=torch.bfloat16, device="cuda") + target = torch.randn_like(x) + + losses = [] + for i in range(self.iters): + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = torch.nn.functional.mse_loss(output, target) + losses.append(loss.item()) + loss.backward() + + # Verify all params have non-None gradients after backward + for name, p in model.named_parameters(): + assert p.grad is not None, f"Step {i}: {name} has no gradient after backward" + assert ( + p.grad.shape == p.shape + ), f"Step {i}: {name} grad shape {p.grad.shape} != param shape {p.shape}" + assert torch.isfinite(p.grad).all(), f"Step {i}: {name} has non-finite gradients" + assert p.grad.any(), f"Step {i}: {name} gradient is all zeros" + + optimizer.step() + + # Verify loss decreased + assert losses[-1] < losses[0], f"Loss did not decrease: {losses}" + + # Verify weight params remain QuantizedTensors + for name, p in model.named_parameters(): + if "bias" not in name: + assert isinstance( + p, QuantizedTensor + ), f"{name} lost QuantizedTensor type: {type(p).__name__}" + + # Verify optimizer states are float32 + for name, p in model.named_parameters(): + state = optimizer.state[p] + assert state["exp_avg"].dtype == torch.float32 + assert state["exp_avg_sq"].dtype == torch.float32 + if "bias" not in name: + assert state["master_param"].dtype == torch.float32 + + class TestFusedSGD(TestFusedOptimizer): def setup_method(self) -> None: @@ -560,7 +838,7 @@ def forward(self, x): return y -class AdamTest: +class TestAdamTest: def setup_method(self, *, seed: int = 0) -> None: torch.manual_seed(seed) @@ -576,8 +854,8 @@ def setup_method(self, *, seed: int = 0) -> None: def test_grad_scaler(self): params_ = [p for p in self.model_.parameters() if p.requires_grad] optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False) - scaler = torch.cuda.amp.GradScaler(enabled=True) - scaler_ = torch.cuda.amp.GradScaler(enabled=True) + scaler = torch.amp.GradScaler("cuda", enabled=True) + scaler_ = torch.amp.GradScaler("cuda", enabled=True) for i in range(100): x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) @@ -627,8 +905,8 @@ def test_grad_scaler(self): def test_grad_scaler_capturable(self): params_ = [p for p in self.model_.parameters() if p.requires_grad] optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=True) - scaler = torch.cuda.amp.GradScaler(enabled=True) - scaler_ = torch.cuda.amp.GradScaler(enabled=True) + scaler = torch.amp.GradScaler("cuda", enabled=True) + scaler_ = torch.amp.GradScaler("cuda", enabled=True) for i in range(100): x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) @@ -685,8 +963,8 @@ def test_grad_scaler_capturable_master(self): optimizer_ = te.optimizers.FusedAdam( params_, lr=self.lr, capturable=True, master_weights=master_weights ) - scaler = torch.cuda.amp.GradScaler(enabled=True) - scaler_ = torch.cuda.amp.GradScaler(enabled=True) + scaler = torch.amp.GradScaler("cuda", enabled=True) + scaler_ = torch.amp.GradScaler("cuda", enabled=True) for i in range(100): x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index f559362d8..36c09060e 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -47,7 +47,7 @@ def group_limited_topk( # Pytorch-based topk softmax/sigmoid -def topk_softmax_sigmoid_pytorch( +def topk_score_function_pytorch( logits: torch.Tensor, topk: int, use_pre_softmax: bool = False, @@ -74,17 +74,20 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): if score_function == "softmax": if use_pre_softmax: - scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + scores = torch.softmax(logits, dim=-1, dtype=torch.float32) probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) else: scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) - probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) - elif score_function == "sigmoid": - scores = torch.sigmoid(logits.float()).type_as(logits) + probs = torch.softmax(scores, dim=-1, dtype=torch.float32) + elif score_function in ("sigmoid", "sqrtsoftplus"): + if score_function == "sigmoid": + scores = torch.sigmoid(logits.float()) + else: + scores = torch.nn.functional.softplus(logits.float()).sqrt() if expert_bias is not None: scores_for_routing = scores + expert_bias _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) - scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) + scores = torch.gather(scores, dim=1, index=top_indices) else: scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores @@ -94,6 +97,8 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): if scaling_factor: probs = probs * scaling_factor + probs = probs.type_as(logits) + topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs) topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() @@ -107,8 +112,11 @@ def compute_scores_for_aux_loss_pytorch( if score_function == "softmax": scores = torch.softmax(logits, dim=-1, dtype=torch.float32) elif score_function == "sigmoid": - scores = torch.sigmoid(logits) - scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + scores = torch.sigmoid(logits.float()) + scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) + elif score_function == "sqrtsoftplus": + scores = torch.nn.functional.softplus(logits.float()).sqrt() + scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) else: raise ValueError(f"Invalid score_function: {score_function}") @@ -146,8 +154,9 @@ def run_comparison( enable_bias, ): # Set some parameters - if score_function == "sigmoid": - # Construct the special logits to avoid inf in the sigmoid function + if score_function in ("sigmoid", "sqrtsoftplus"): + # Construct logits with a narrow range to avoid very small activation values, + # which would cause precision loss when adding/subtracting expert bias in float32. offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 logits = ( torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 @@ -165,8 +174,8 @@ def run_comparison( ) logits = logits.view(num_tokens, num_experts) logits.requires_grad = True - if enable_bias and score_function == "sigmoid": - expert_bias = torch.arange(num_experts, device="cuda") * 0.1 + if enable_bias and score_function in ("sigmoid", "sqrtsoftplus"): + expert_bias = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1 expert_bias = torch.flip(expert_bias, dims=[0]) expert_bias.requires_grad = True else: @@ -183,7 +192,7 @@ def run_comparison( # Run the original implementation # We do not support the capacity factor case - probs, routing_map = topk_softmax_sigmoid_pytorch( + probs, routing_map = topk_score_function_pytorch( logits=logits, topk=topk, use_pre_softmax=use_pre_softmax, @@ -252,6 +261,37 @@ def test_topk_sigmoid( ) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992]) +@pytest.mark.parametrize("num_experts", [128, 32]) +@pytest.mark.parametrize("topk", [4, 8]) +@pytest.mark.parametrize("group_topk", [None, 4]) +@pytest.mark.parametrize("scaling_factor", [None, 1.2]) +@pytest.mark.parametrize("enable_bias", [True, False]) +def test_topk_sqrtsoftplus( + dtype, + num_tokens, + num_experts, + topk, + group_topk, + scaling_factor, + enable_bias, +): + num_groups = 8 if group_topk else None + run_comparison( + dtype=dtype, + num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + use_pre_softmax=False, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function="sqrtsoftplus", + enable_bias=enable_bias, + ) + + @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) @pytest.mark.parametrize("num_experts", [128, 32]) @@ -284,13 +324,13 @@ def test_topk_softmax( @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) +@pytest.mark.parametrize("num_tokens", [2048, 7168]) @pytest.mark.parametrize("num_experts", [256, 128, 32]) -@pytest.mark.parametrize("topk", [4, 8]) -@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("topk", [1, 4, 8]) +@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"]) def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): - if score_function == "sigmoid": - # Construct the special logits to avoid inf in the sigmoid function + if score_function in ("sigmoid", "sqrtsoftplus"): + # Construct logits with a narrow range to avoid very small activation values offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 logits = ( torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 @@ -396,15 +436,6 @@ def profile_topk_softmax( test_topk_softmax( torch.float32, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor ) - - -if __name__ == "__main__": - test_topk_softmax( - dtype=torch.float32, - num_tokens=1024, - num_experts=128, - topk=4, - use_pre_softmax=False, - group_topk=None, - scaling_factor=None, + test_topk_sqrtsoftplus( + torch.float32, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias ) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e4647ac82..20ac80207 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -7,8 +7,10 @@ from __future__ import annotations from collections.abc import Iterable +import functools import io import math +import random from typing import Optional import pytest @@ -40,7 +42,14 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION # Import utility functions -from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states +from utils import ( + assert_close, + assert_close_grads, + dtype_tols, + make_recipe, + quantization_tols, + reset_rng_states, +) # Check for supported quantization schemes fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) @@ -111,6 +120,9 @@ def maybe_skip_quantization( @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + *, + min: float = 0.0, + max: float = 1.0, quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", @@ -131,7 +143,8 @@ def make_reference_and_test_tensors( """ # Random reference tensor - ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + ref = torch.empty(shape, dtype=ref_dtype, device=ref_device) + ref.uniform_(min, max) # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) @@ -169,7 +182,7 @@ def make_reference_and_test_tensors( test = test.dequantize() # Make sure reference and test tensors match each other - ref.copy_(test) + ref.copy_(test.to(dtype=ref.dtype)) ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) @@ -1573,7 +1586,19 @@ def test_make_extra_output( @pytest.mark.parametrize( "activation", - ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"), + ( + "gelu", + "geglu", + "qgelu", + "qgeglu", + "relu", + "reglu", + "glu", + "srelu", + "sreglu", + "silu", + "swiglu", + ), ) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) @@ -1593,7 +1618,7 @@ def test_activation( # Tensor dimensions in_shape = list(out_shape) - if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"): + if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"): in_shape[-1] *= 2 # Skip invalid configurations @@ -1633,6 +1658,13 @@ def test_activation( elif activation == "reglu": x1, x2 = x_ref.chunk(2, dim=-1) y_ref = torch.nn.functional.relu(x1) * x2 + elif activation == "sigmoid": + y_ref = torch.nn.functional.sigmoid(x_ref) + elif activation == "glu": + x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2) + x = x.flip(-2) # PyTorch GLU swaps gate and linear unit + x = x.reshape(in_shape) + y_ref = torch.nn.functional.glu(x) elif activation == "srelu": y_ref = torch.nn.functional.relu(x_ref) ** 2 elif activation == "sreglu": @@ -1652,6 +1684,7 @@ def test_activation( make_op = dict( gelu=te_ops.GELU, geglu=te_ops.GEGLU, + glu=te_ops.GLU, qgelu=te_ops.QGELU, qgeglu=te_ops.QGEGLU, relu=te_ops.ReLU, @@ -1696,6 +1729,7 @@ def test_swiglu( quantization: Optional[str], quantize_forward: bool, quantize_backward: bool, + glu_interleave_size: Optional[int] = None, ): # Tensor dimensions @@ -1722,7 +1756,17 @@ def test_swiglu( ) # Plain PyTorch implementation - x1, x2 = x_ref.chunk(2, dim=-1) + x = x_ref + if glu_interleave_size is not None: + x = x.reshape( + *in_shape[:-1], + in_shape[-1] // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(-3, -2) + x = x.reshape(in_shape) + x1, x2 = x.chunk(2, dim=-1) y_ref = torch.nn.functional.silu(x1) * x2 y_ref.backward(dy_ref) @@ -1730,7 +1774,7 @@ def test_swiglu( recipe = make_recipe(quantization) forward = te_ops.Sequential( te_ops.Quantize(forward=False, backward=quantize_backward), - te_ops.SwiGLU(), + te_ops.SwiGLU(glu_interleave_size=glu_interleave_size), te_ops.Quantize(forward=quantize_forward, backward=False), ) with te.autocast(enabled=quantized_compute, recipe=recipe): @@ -1743,10 +1787,19 @@ def test_swiglu( tols = quantization_tols(quantization) # Check results - y_test = y_test.to(dtype=torch.float64, device="cpu") - dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(y_test, y_ref, **tols) - torch.testing.assert_close(dx_test, x_ref.grad, **tols) + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + + def test_interleaved_swiglu(self): + """SwiGLU with block interleaved input format""" + self.test_swiglu( + out_shape=(32, 192), + dtype=torch.float32, + quantization=None, + quantize_forward=False, + quantize_backward=False, + glu_interleave_size=32, + ) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) @@ -1756,6 +1809,7 @@ def test_clamped_swiglu( self, *, out_shape: Iterable[int] = (32, 32), + glu_interleave_size: Optional[int] = None, dtype: torch.dtype, device: torch.device = "cuda", quantization: Optional[str], @@ -1764,7 +1818,7 @@ def test_clamped_swiglu( limit: float = 0.75, alpha: float = 1.702, ): - # Test SwiGLU variant used in GPT OSS. + """SwiGLU variant used in GPT-OSS""" # Tensor dimensions in_shape = list(out_shape) in_shape[-1] *= 2 @@ -1789,7 +1843,17 @@ def test_clamped_swiglu( ) # Plain PyTorch implementation - x_glu, x_linear = x_ref.chunk(2, dim=-1) + x = x_ref + if glu_interleave_size is not None: + x = x.reshape( + *in_shape[:-1], + in_shape[-1] // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(-3, -2) + x = x.reshape(in_shape) + x_glu, x_linear = x.chunk(2, dim=-1) x_glu = x_glu.clamp(min=None, max=limit) x_linear = x_linear.clamp(min=-limit, max=limit) out_glu = x_glu * torch.sigmoid(alpha * x_glu) @@ -1801,7 +1865,11 @@ def test_clamped_swiglu( forward = te_ops.Sequential( te_ops.Quantize(forward=False, backward=quantize_backward), - te_ops.ClampedSwiGLU(limit=limit, alpha=alpha), + te_ops.ClampedSwiGLU( + limit=limit, + alpha=alpha, + glu_interleave_size=glu_interleave_size, + ), te_ops.Quantize(forward=quantize_forward, backward=False), ) with te.autocast(enabled=quantized_compute, recipe=recipe): @@ -1817,10 +1885,19 @@ def test_clamped_swiglu( tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results - y_test = y_test.to(dtype=torch.float64, device="cpu") - dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(y_test, y_ref, **tols) - torch.testing.assert_close(dx_test, x_ref.grad, **tols) + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + + def test_interleaved_clamped_swiglu(self): + """GPT-OSS SwiGLU with block interleaved input format""" + self.test_clamped_swiglu( + out_shape=(32, 192), + dtype=torch.float32, + quantization=None, + quantize_forward=False, + quantize_backward=False, + glu_interleave_size=32, + ) @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2))) @@ -1895,7 +1972,7 @@ def test_dropout( ) with torch.no_grad(): x_test += 1 - x_ref.copy_(x_test) + x_ref.copy_(x_test.to(dtype=x_ref.dtype)) dy_ref, dy_test = make_reference_and_test_tensors( shape, test_dtype=dtype, @@ -1940,6 +2017,231 @@ def test_dropout( abs(z_score) < 2.5758 ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})" + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_compute", (False, True)) + @pytest.mark.parametrize("quantized_weight", (False, True)) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("weight_requires_grad", (False, True)) + def test_grouped_linear( + self, + *, + group_size: int = 4, + bias: bool, + weight_shape: tuple[int, int] = (128, 128), + split_alignment: int = 128, + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_compute: bool, + quantized_weight: bool, + input_requires_grad: bool, + weight_requires_grad: bool, + ) -> None: + """Grouped GEMM""" + + # Split sizes + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = (split_sizes.sum().item(), in_features) + out_shape = (in_shape[0], out_features) + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not used") + if quantization is not None and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + ws_ref, ws_test = [], [] + bs_ref, bs_test = [], [] + for _ in range(group_size): + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + b_ref, b_test = None, None + if bias: + b_ref, b_test = make_reference_and_test_tensors( + out_features, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + ws_ref.append(w_ref) + ws_test.append(w_test) + bs_ref.append(b_ref) + bs_test.append(b_test) + + # Plain PyTorch implementation + xs_ref = torch.split(x_ref, split_sizes.tolist()) + ys_ref = [] + for x, w, b in zip(xs_ref, ws_ref, bs_ref): + ys_ref.append(torch.nn.functional.linear(x, w, bias=b)) + y_ref = torch.cat(ys_ref) + if input_requires_grad or weight_requires_grad: + y_ref.backward(dy_ref) + + # Construct fusible operation + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + op = te_ops.GroupedLinear( + group_size, + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + ) + with torch.no_grad(): + for group_idx in range(group_size): + getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx]) + if bias: + getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx]) + del ws_test, bs_test + for param in op.parameters(): + param.requires_grad_(requires_grad=weight_requires_grad) + + # Forward and backward pass with op + with te.autocast(enabled=quantized_compute, recipe=recipe): + y_test = op(x_test, split_sizes) + if input_requires_grad or weight_requires_grad: + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + if input_requires_grad: + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + else: + assert x_test.grad is None + for group_idx in range(group_size): + w_test = getattr(op, f"weight{group_idx}") + if weight_requires_grad: + dw_test = w_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols) + else: + assert w_test.grad is None + if bias: + b_test = getattr(op, f"bias{group_idx}") + if weight_requires_grad: + db_test = b_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols) + else: + assert b_test.grad is None + + @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("scales_requires_grad", (False, True)) + def test_scaled_swiglu( + self, + *, + in_shape: Iterable[int], + glu_interleave_size: Optional[int] = None, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + input_requires_grad: bool, + scales_requires_grad: bool, + ) -> None: + """SwiGLU with post-scale""" + + # Tensor dims + out_shape = list(in_shape) + out_shape[-1] //= 2 + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + scales_ref, scales_test = make_reference_and_test_tensors( + in_shape[:-1], + test_dtype=dtype, + test_device=device, + requires_grad=scales_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x = x_ref + if glu_interleave_size is not None: + x = x.reshape( + -1, + in_shape[-1] // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(1, 2) + x = x.reshape(in_shape) + x1, x2 = x.chunk(2, dim=-1) + y = torch.nn.functional.silu(x1) * x2 + y_ref = scales_ref.unsqueeze(-1) * y + if input_requires_grad or scales_requires_grad: + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + y_test = op(x_test, scales_test) + if input_requires_grad or scales_requires_grad: + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + assert_close_grads(scales_test, scales_ref, **tols) + + def test_interleaved_scaled_swiglu(self): + """SwiGLU with post-scale and block interleaved input format""" + self.test_scaled_swiglu( + in_shape=(32, 192), + glu_interleave_size=32, + input_requires_grad=True, + scales_requires_grad=True, + ) + class TestFusedOps: """Tests for fused operations""" @@ -2345,13 +2647,13 @@ def test_backward_activation_bias( backward_ops = model._module_groups[0]._backward_ops if with_quantization: assert len(backward_ops) == 2 - assert isinstance(backward_ops[0][0], BackwardActivationBias) - assert isinstance(backward_ops[1][0], te_ops.Quantize) + assert isinstance(backward_ops[0][0], te_ops.Quantize) + assert isinstance(backward_ops[1][0], BackwardActivationBias) else: assert len(backward_ops) == 3 - assert isinstance(backward_ops[0][0], act_type) + assert isinstance(backward_ops[0][0], te_ops.Quantize) assert isinstance(backward_ops[1][0], te_ops.Bias) - assert isinstance(backward_ops[2][0], te_ops.Quantize) + assert isinstance(backward_ops[2][0], act_type) # Expected numerical error tols = dtype_tols(dtype) @@ -2946,3 +3248,607 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: if bias: torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols) torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols) + + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("glu_interleave_size", (None, 32)) + def test_grouped_mlp( + self, + *, + group_size: int = 4, + bias: bool, + hidden_size: int = 256, + dtype: torch.dtype, + quantization: Optional[str], + device: torch.device = "cuda", + split_alignment: int = 256, + glu_interleave_size: Optional[int], + ) -> None: + """GroupedLinear + ScaledSwiGLU + GroupedLinear""" + + # Split sizes + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + + # Make input shape + in_shape = (split_sizes.sum().item(), hidden_size) + out_shape = in_shape + + # Skip invalid configurations + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + if with_quantization and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + min=-0.25, + max=0.25, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + min=-0.25, + max=0.25, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + probs_ref, probs_test = make_reference_and_test_tensors( + (in_shape[0],), + test_dtype=dtype, + test_device=device, + ) + fc1_ws_ref, fc1_ws_test = [], [] + fc1_bs_ref, fc1_bs_test = [], [] + fc2_ws_ref, fc2_ws_test = [], [] + fc2_bs_ref, fc2_bs_test = [], [] + for _ in range(group_size): + fc1_w_ref, fc1_w_test = make_reference_and_test_tensors( + (2 * hidden_size, hidden_size), + min=-0.25, + max=0.25, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + fc2_w_ref, fc2_w_test = make_reference_and_test_tensors( + (hidden_size, hidden_size), + min=-0.25, + max=0.25, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + fc1_b_ref, fc1_b_test = None, None + fc2_b_ref, fc2_b_test = None, None + if bias: + fc1_b_ref, fc1_b_test = make_reference_and_test_tensors( + (2 * hidden_size,), + min=-0.5, + max=0.5, + test_dtype=dtype, + test_device=device, + ) + fc2_b_ref, fc2_b_test = make_reference_and_test_tensors( + (hidden_size,), + min=-0.5, + max=0.5, + test_dtype=dtype, + test_device=device, + ) + fc1_ws_ref.append(fc1_w_ref) + fc1_bs_ref.append(fc1_b_ref) + fc1_ws_test.append(fc1_w_test) + fc1_bs_test.append(fc1_b_test) + fc2_ws_ref.append(fc2_w_ref) + fc2_bs_ref.append(fc2_b_ref) + fc2_ws_test.append(fc2_w_test) + fc2_bs_test.append(fc2_b_test) + + # Reference implementation + xs = torch.split(x_ref, split_sizes.tolist()) + probs = torch.split(probs_ref, split_sizes.tolist()) + ys = [] + for group_idx in range(group_size): + x = xs[group_idx] + x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) + if glu_interleave_size is not None: + x = x.reshape( + -1, + 2 * hidden_size // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(1, 2) + x = x.reshape(-1, 2 * hidden_size) + x1, x2 = x.chunk(2, dim=-1) + x = torch.nn.functional.silu(x1) * x2 + x = x * probs[group_idx].unsqueeze(-1) + x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx]) + ys.append(x) + y_ref = torch.cat(ys) + y_ref.backward(dy_ref) + + # Construct operations + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=with_quantization, recipe=recipe): + fc1 = te_ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=bias, + device=device, + dtype=dtype, + ) + fc2 = te_ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + ) + module = te_ops.Sequential( + fc1, + te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), + fc2, + ) + + # Copy weights + with torch.no_grad(): + for group_idx in range(group_size): + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) + if bias: + getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) + getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test + + # Fuse ops and perform forward and backward pass + with te.autocast(enabled=with_quantization, recipe=recipe): + y_test = module(x_test, split_sizes, probs_test, split_sizes) + y_test.backward(dy_test) + + # Loose tols for sanity checking + tols = {"rtol": 0.125, "atol": 0.25} + if quantization == "nvfp4": + tols = {"rtol": 0.25, "atol": 0.5} + + # Check values + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + assert_close_grads(probs_test, probs_ref, **tols) + for group_idx in range(group_size): + assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols) + assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols) + assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols) + assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols) + + +class TestCustomOps: + """Test with ops that are defined externally""" + + def test_custom_basic_op( + self, + *, + shape: Iterable[int] = (7, 5), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ) -> None: + """Custom basic op""" + + class LearnableScale(te.ops.BasicOperation): + """Custom op that applies a learnable scale + + This class is as an example in the op fuser guide at + docs/examples/op_fuser/op_fuser.rst (see "Implementing a + basic operation"). Any changes made to this class should + also be made there. + + """ + + def __init__(self) -> None: + super().__init__() + self.scale: torch.nn.Parameter + scale = torch.ones((), dtype=dtype, device=device) + scale = torch.nn.Parameter(scale) + self.register_parameter("scale", scale) + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + **unused, + ) -> torch.Tensor: + out = self.scale * input_ + ctx.save_for_backward(self.scale, input_) + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]: + scale, input_ = ctx.saved_tensors + grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1)).reshape(()) + grad_input = scale * grad_output + return grad_input, (grad_scale,) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = w_ref * x_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = LearnableScale() + forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity()) + with torch.no_grad(): + op.scale.copy_(w_test) + del w_test + y_test = forward(x_test) + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = op.scale.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + def test_custom_forward_fused_op1( + self, + *, + shape: Iterable[int] = (5, 11), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ): + """Custom fused op in forward pass""" + + class ForwardAxpy(te.ops.FusedOperation): + """Custom op that computes BLAS SAXPY in forward pass + + This class is as an example in the op fuser guide at + docs/examples/op_fuser/op_fuser.rst (see "Implementing a + fused operation"). Any changes made to this class should + also be made there. + + """ + + _enabled = True + + def __init__( + self, + scale: te.ops.ConstantScale, + add: te.ops.AddExtraInput, + ) -> None: + super().__init__((scale, add)) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + **unused, + ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, ...]]]: + scale_op, add_op = self.basic_ops + extra_input = basic_op_extra_inputs[1][0] # Extra input to add op + out = scale_op.scale * input_ + extra_input + scale_ctx, add_ctx = basic_op_ctxs # No state needed for backward + return ( + out, # Output + [(), ()], # Extra outputs for each basic op + ) + + def fuse_axpy_ops( + ops: list[te.ops.FusibleOperation], + **unused, + ) -> list[te.ops.FusibleOperation]: + """Apply fusion the first time this function is called""" + if ForwardAxpy._enabled: + ForwardAxpy._enabled = False + else: + return ops + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + if isinstance(window[0], te.ops.ConstantScale) and isinstance( + window[1], te.ops.AddExtraInput + ): + window = [ForwardAxpy(*window)] + else: + out.append(window[0]) + window = window[1:] + window, ops = window + ops[:1], ops[1:] + out.extend(window + ops) + return out + + # Random data + scale = 0.5 + x1_ref, x1_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + x2_ref, x2_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = scale * x1_ref + x2_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operation + te.ops.register_forward_fusion(fuse_axpy_ops) + model = te.ops.Sequential( + te.ops.ConstantScale(scale=scale), + te.ops.AddExtraInput(), + ) + y_test = model(x1_test, x2_test) + y_test.backward(dy_test) + + # Check values + tols = dtype_tols(dtype) + assert_close(y_test, y_ref, **tols) + assert_close_grads(x1_test, x1_ref, **tols) + assert_close_grads(x2_test, x2_ref, **tols) + + def test_custom_forward_fused_op2( + self, + *, + shape: Iterable[int] = (7, 11), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ): + """Custom fused op in forward pass""" + + class CustomForwardLinearSiLU(te.ops.FusedOperation): + """Custom fused op for GEMM + SiLU""" + + _enabled = True + + def __init__(self, *, linear, silu) -> None: + super().__init__((linear, silu)) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + **unused, + ) -> torch.Tensor: + weight = self.basic_ops[0].weight + dtype = weight.dtype + device = weight.device + + # Perform compute on CPU, because why not? + x = input_.cpu() + w = weight.cpu() + y = torch.matmul(x, w.T) + z = torch.nn.functional.silu(y) + out = z.to(device=device) + + # Save state for linear backward + linear_op_ctx = basic_op_ctxs[0] + linear_op_ctx.save_for_backward(input_, weight) + linear_op_ctx.with_quantized_compute = False + linear_op_ctx.input_quantizer = None + linear_op_ctx.weight_quantizer = None + linear_op_ctx.grad_output_quantizer = None + linear_op_ctx.grad_input_quantizer = None + linear_op_ctx.dtype = dtype + linear_op_ctx.input_requires_grad = True + linear_op_ctx.weight_requires_grad = True + + # Save state for SiLU backward + silu_op_ctx = basic_op_ctxs[1] + silu_op_ctx.save_for_backward(y.to(device=device)) + silu_op_ctx.dtype = dtype + silu_op_ctx.prev_op_grad_output_quantizer = None + + return out, [(), ()] + + @staticmethod + def fuse_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply fusion the first time this function is called""" + if CustomForwardLinearSiLU._enabled: + CustomForwardLinearSiLU._enabled = False + op = CustomForwardLinearSiLU(linear=ops[0], silu=ops[1]) + return [op] + ops[2:] + return ops + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (shape[-1], shape[-1]), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref) + y_ref = torch.nn.functional.silu(y_ref) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + te.ops.register_forward_fusion(CustomForwardLinearSiLU.fuse_ops) + model = te.ops.Sequential( + te.ops.Linear(shape[-1], shape[-1], bias=False), + te.ops.SiLU(), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + del w_test + y_test = model(x_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = model._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + def test_custom_backward_fused_op( + self, + *, + shape: Iterable[int] = (13, 5), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ): + """Custom fused op in backward pass""" + + class CustomBackwardLinearScale(te.ops.FusedOperation): + """Custom fused op for backward linear + scale""" + + _enabled: bool = True + + def __init__(self, *, scale, linear) -> None: + super().__init__((scale, linear)) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + **unused, + ) -> torch.Tensor: + + # Load state from linear forward + linear_op_ctx = basic_op_ctxs[1] + x, w = linear_op_ctx.saved_tensors + dtype = linear_op_ctx.dtype + device = w.device + + # Perform compute in FP64 and apply scale before dgrad + # GEMM instead of after + scale = self.basic_ops[0].scale + dy = grad_output.double() + x = x.double() + w = w.double() + dx = torch.matmul(dy, scale * w) + dw = torch.matmul(dy.T, x) + dx = dx.to(dtype=dtype) + dw = dw.to(dtype=dtype) + + return dx, [(), (dw,)], [(), ()] + + @staticmethod + def fuse_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply fusion the first time this function is called""" + if CustomBackwardLinearScale._enabled: + CustomBackwardLinearScale._enabled = False + op = CustomBackwardLinearScale(scale=ops[0], linear=ops[1]) + return [op] + ops[2:] + return ops + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (shape[-1], shape[-1]), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + scale = 1.234 + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(scale * x_ref, w_ref) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + te.ops.register_backward_fusion(CustomBackwardLinearScale.fuse_ops, prepend=True) + model = te.ops.Sequential( + te.ops.ConstantScale(scale), + te.ops.Linear(shape[-1], shape[-1], bias=False), + ) + with torch.no_grad(): + model[1].weight.copy_(w_test) + del w_test + y_test = model(x_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + backward_ops = model._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance(backward_ops[0][0], CustomBackwardLinearScale) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py new file mode 100644 index 000000000..9dd965fa9 --- /dev/null +++ b/tests/pytorch/test_grouped_tensor.py @@ -0,0 +1,466 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for GroupedTensor class""" + +from typing import List, Tuple +import pytest +import torch +import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor +from transformer_engine.pytorch import ( + Quantizer, + Float8Quantizer, + Float8CurrentScalingQuantizer, + Float8BlockQuantizer, + MXFP8Quantizer, + NVFP4Quantizer, +) +from transformer_engine.pytorch.constants import TE_DType_To_Torch +import transformer_engine_torch as tex + +# Check available recipes +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +_quantization_params = [ + pytest.param( + "fp8_delayed_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_blockwise", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + ), +] + + +def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer: + """Create quantizer for given quantization scheme""" + + if quantization == "fp8_delayed_scaling": + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device="cuda"), + amax=torch.zeros(1, dtype=torch.float32, device="cuda"), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + quantizer.set_usage(rowwise=True, columnwise=False) + elif quantization == "fp8_blockwise": + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=False, + force_pow_2_scales=True, + amax_epsilon=0.0, + block_scaling_dim=1, + ) + elif quantization == "mxfp8": + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + elif quantization == "nvfp4": + quantizer = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + else: + raise ValueError(f"Unknown quantization scheme: {quantization}") + + quantizer.internal = False + + return quantizer + + +def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor: + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"): + return qtensor._data + if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"): + return qtensor._rowwise_data + raise ValueError(f"Unknown quantization scheme: {quantization}") + + +def _rowwise_offset_bytes(numel: int, quantization: str) -> int: + if quantization == "nvfp4": + return numel // 2 + return numel + + +class TestGroupedTensor: + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_basic_construction_all_same_shape(self) -> None: + """Test GroupedTensor construction with all tensors having same shape""" + num_tensors = 4 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert grouped_tensor.all_same_shape() + assert grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert grouped_tensor.logical_shape == (num_tensors * 256, 512) + assert grouped_tensor.get_common_first_dim() == 256 + assert grouped_tensor.get_common_last_dim() == 512 + assert grouped_tensor.has_data() + + def test_basic_construction_varying_first_dim(self) -> None: + """Test GroupedTensor construction with varying first dimension""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert not grouped_tensor.all_same_shape() + assert not grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert grouped_tensor.get_common_last_dim() == shape[0][1] + assert grouped_tensor.logical_shape == ( + sum(v for v, _ in shape), + shape[0][1], + ) # sum of first dims + + def test_split_into_quantized_tensors_no_quantization(self) -> None: + """Test split_into_quantized_tensors for unquantized tensors""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + # GroupedTensor is a wrapper; use backing storage buffer pointer. + storage = grouped_tensor.rowwise_data + if storage is None: + storage = grouped_tensor.columnwise_data + assert storage is not None + original_data_ptr = storage.data_ptr() + + # Split into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify each tensor has correct shape and shares storage + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + assert isinstance(tensor, torch.Tensor) + assert not hasattr(tensor, "_data") # Not a quantized tensor + + # Verify data pointer is within the original grouped tensor storage + # The tensor should be a view of the original data + assert tensor.data_ptr() >= original_data_ptr + + # Calculate expected offset + expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size() + assert tensor.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None: + """Test split_into_quantized_tensors for quantized tensors""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shape, + quantizer=quantizer, + device="cuda", + dtype=torch.float32, + ) + + # GroupedTensor is a wrapper; use backing storage buffer pointer. + storage = grouped_tensor.rowwise_data + if storage is None: + storage = grouped_tensor.columnwise_data + assert storage is not None + original_data_ptr = storage.data_ptr() + + # Split into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify each tensor shares storage with the grouped tensor + for i, tensor in enumerate(tensors): + rowwise_data = _get_rowwise_data_tensor(tensor, quantization) + assert rowwise_data is not None + assert rowwise_data.data_ptr() >= original_data_ptr + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + def test_split_varying_shapes(self) -> None: + """Test split_into_quantized_tensors with varying shapes""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + storage = grouped_tensor.rowwise_data + if storage is None: + storage = grouped_tensor.columnwise_data + assert storage is not None + original_data_ptr = storage.data_ptr() + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify shapes and storage + cumulative_offset = 0 + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + expected_offset = cumulative_offset * tensor.element_size() + assert tensor.data_ptr() == original_data_ptr + expected_offset + cumulative_offset += shape[i][0] * shape[i][1] + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_inplace(self, quantization: str) -> None: + """Test that quantize is done in-place for all recipes""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizer = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shape, + quantizer=quantizer, + device="cuda", + dtype=torch.float32, + ) + + # Get original data pointers before quantization + storage = grouped_tensor.rowwise_data + if storage is None: + storage = grouped_tensor.columnwise_data + assert storage is not None + original_data_ptr = storage.data_ptr() + original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr() + original_scale_ptr = ( + grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None + ) + + # Create input tensors + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize in place + quantized_tensors = grouped_tensor.quantize(input_tensors) + + # Verify data pointers haven't changed (in-place operation) + assert storage.data_ptr() == original_data_ptr + assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr + if original_scale_ptr is not None: + assert grouped_tensor.scale.data_ptr() == original_scale_ptr + + # Verify returned tensors point to the same storage + for i, qtensor in enumerate(quantized_tensors): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_varying_shapes(self, quantization: str) -> None: + """Test quantize with varying shapes""" + num_tensors = 3 + shape = [(256, 512), (512, 512), (768, 512)] + quantizer = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shape, + quantizer=quantizer, + device="cuda", + dtype=torch.float32, + ) + + # Get original data pointers + storage = grouped_tensor.rowwise_data + if storage is None: + storage = grouped_tensor.columnwise_data + assert storage is not None + original_data_ptr = storage.data_ptr() + + # Create input tensors with varying shapes + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize in place + quantized_tensors = grouped_tensor.quantize(input_tensors) + + # Verify data pointer hasn't changed + assert storage.data_ptr() == original_data_ptr + + # Verify each tensor points to correct location + cumulative_numel = 0 + for qtensor, tensor_shape in zip(quantized_tensors, shape): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + cumulative_numel += tensor_shape[0] * tensor_shape[1] + + @pytest.mark.parametrize( + "shape", + [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], + ) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: + """Test grouped quantization for MXFP8 against per-tensor quantization.""" + # Test wont pass until the grouped quantization PR from Oleg is merged. + num_tensors = 2 + shape = [(512, 1024) for _ in range(num_tensors)] + + # Create BF16 input tensors and pack into a 2D tensor + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + grouped_input = torch.cat(input_tensors, dim=0) + + # Create MXFP8 output grouped tensor (rowwise only for easier validation) + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=False) + first_dims = torch.tensor( + [shape[0][0] for _ in range(num_tensors)], + dtype=torch.int64, + device="cuda", + ) + + # Quantize using grouped API + grouped_output = tex.group_quantize( + grouped_input, + quantizer, + num_tensors, + first_dims, + ) + # Build expected output by quantizing each tensor independently + expected_data = [] + expected_scale_inv = [] + for tensor in input_tensors: + qtensor = quantizer(tensor) + expected_data.append(qtensor._rowwise_data.reshape(-1)) + expected_scale_inv.append(qtensor._rowwise_scale_inv.reshape(-1)) + + expected_data = torch.cat(expected_data) + expected_scale_inv = torch.cat(expected_scale_inv) + + assert torch.equal(grouped_output.rowwise_data, expected_data) + assert torch.equal(grouped_output.scale_inv, expected_scale_inv) + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_group_quantize_cudagraph_capturable(self) -> None: + """Ensure group_quantize is CUDA graph capturable.""" + num_tensors = 2 + shape = [(512, 1024) for _ in range(num_tensors)] + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + grouped_input = torch.cat(input_tensors, dim=0) + + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=False) + first_dims = torch.tensor( + [shape[0][0] for _ in range(num_tensors)], + dtype=torch.int64, + device="cuda", + ) + + torch.cuda.synchronize() + static_input = grouped_input.clone() + static_first_dims = first_dims.clone() + + # Warmup to initialize kernels and allocator state + _ = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + static_output = tex.group_quantize( + static_input, + quantizer, + num_tensors, + static_first_dims, + ) + + fresh_input = torch.cat( + [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape], + dim=0, + ) + static_input.copy_(fresh_input) + graph.replay() + torch.cuda.synchronize() + + expected = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims) + assert torch.equal(static_output.rowwise_data, expected.rowwise_data) + assert torch.equal(static_output.scale_inv, expected.scale_inv) + + def test_clear(self) -> None: + """Test clear method""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.has_data() + assert grouped_tensor.num_tensors == num_tensors + + grouped_tensor.clear() + + assert not grouped_tensor.has_data() + assert grouped_tensor.num_tensors == 0 + assert grouped_tensor.rowwise_data is None + assert grouped_tensor.logical_shape == (0, 0) diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index 359155f00..20b513c07 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -140,6 +140,117 @@ def find_inf( ) +@pytest.mark.parametrize("input_size_pair", input_size_pairs) +@pytest.mark.parametrize("applier", appliers) +@pytest.mark.parametrize("repeat", [1, 55]) +@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("out_type", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("inplace", [False, True]) +def test_multi_tensor_scale_tensor(input_size_pair, applier, repeat, in_type, out_type, inplace): + if inplace is True and (out_type is not in_type): + pytest.skip("inplace=True and out_type != in_type is not supported.") + elif (in_type == torch.float16 and out_type == torch.bfloat16) or ( + in_type == torch.bfloat16 and out_type == torch.float16 + ): + pytest.skip("float16 to bfloat16 is not necessary and vice versa.") + + device = torch.device("cuda") + scale = 4.0 + inv_scale_cuda = torch.tensor([1.0 / scale], dtype=torch.float32, device=device) + overflow_buf = torch.zeros(1, dtype=torch.int32, device=device) + ref = torch.tensor([1.0], dtype=torch.float32, device=device) + sizea, sizeb = input_size_pair + + def downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=False): + overflow_buf.zero_() + a = torch.full([sizea], scale, dtype=torch.float32, device=device) + b = torch.full([sizeb], scale, dtype=torch.float32, device=device) + + out_list = [] + for _ in range(repeat): + out_list += [a.clone().to(out_type), b.clone().to(out_type)] + + if inplace: + in_list = out_list + else: + in_list = [out.clone().to(in_type) for out in out_list] + + applier(tex.multi_tensor_scale_tensor, overflow_buf, [in_list, out_list], inv_scale_cuda) + + assert all([torch.allclose(out, ref.to(out_type)) for out in out_list]) + assert overflow_buf.item() == 0 + + def find_inf( + sizea, + sizeb, + applier, + repeat, + in_type, + out_type, + t, + ind, + val, + inplace=False, + ): + overflow_buf.zero_() + a = torch.full([sizea], scale, dtype=torch.float32, device=device) + b = torch.full([sizeb], scale, dtype=torch.float32, device=device) + + out_list = [] + for _ in range(repeat): + out_list += [a.clone().to(out_type), b.clone().to(out_type)] + + if inplace: + in_list = out_list + else: + in_list = [out.clone().to(in_type) for out in out_list] + + applier(tex.multi_tensor_scale_tensor, overflow_buf, [in_list, out_list], inv_scale_cuda) + + overflow_buf.zero_() + in_list[t][ind] = val + applier(tex.multi_tensor_scale_tensor, overflow_buf, [in_list, out_list], inv_scale_cuda) + assert overflow_buf.item() > 0 + + downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=inplace) + find_inf( + sizea, + sizeb, + applier, + repeat, + in_type, + out_type, + 0, + 0, + float("nan"), + inplace=inplace, + ) + find_inf( + sizea, + sizeb, + applier, + repeat, + in_type, + out_type, + 2 * repeat - 1, + sizeb - 1, + float("inf"), + inplace=inplace, + ) + find_inf( + sizea, + sizeb, + applier, + repeat, + in_type, + out_type, + 2 * (repeat // 2), + sizea // 2, + float("inf"), + inplace=inplace, + ) + + @pytest.mark.parametrize("input_size_pair", input_size_pairs) @pytest.mark.parametrize("applier", appliers) @pytest.mark.parametrize("repeat", [1, 55]) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index d1e9b341e..7b672a640 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -54,7 +54,12 @@ ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils as fa_utils from transformer_engine.pytorch import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm +from transformer_engine.pytorch.cpp_extensions import ( + general_gemm, + general_grouped_gemm, + general_grouped_gemm_for_grouped_tensor, +) +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.common import recipe import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states @@ -109,6 +114,7 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]: all_activations = [ "gelu", "geglu", + "glu", "qgelu", "qgeglu", "relu", @@ -512,6 +518,7 @@ def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor: _supported_act = { "gelu": nn.GELU(approximate="tanh"), "geglu": nn.GELU(approximate="tanh"), + "glu": nn.Sigmoid(), "qgelu": TorchQuickGELU(), "qgeglu": TorchQuickGELU(), "relu": nn.ReLU(), @@ -3079,6 +3086,343 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) +def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: + data = grouped_tensor.rowwise_data + if data is None: + data = grouped_tensor.columnwise_data + if data is None: + raise ValueError("GroupedTensor has no data buffers to pack.") + offset = 0 + for tensor in tensors: + numel = tensor.numel() + data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + +def _make_grouped_tensor_from_splits( + m_sizes: List[int], + last_dim: int, + device: torch.device, + dtype: torch.dtype, +) -> GroupedTensor: + first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64) + return GroupedTensor.make_grouped_tensor( + num_tensors=len(m_sizes), + first_dims=first_dims, + last_dims=None, + logical_first_dim=sum(m_sizes), + logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + +def _make_grouped_tensor_uniform( + num_tensors: int, + first_dim: int, + last_dim: int, + device: torch.device, + dtype: torch.dtype, +) -> GroupedTensor: + return GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=None, + last_dims=None, + logical_first_dim=num_tensors * first_dim, + logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + +@pytest.mark.skipif(IS_HIP_EXTENSION, reason="Grouped GEMM is not yet supported in ROCm TE") +@pytest.mark.parametrize( + "z, m, n, k", + [ + (4, 256, 256, 256), + (4, 512, 256, 512), + (4, 512, 512, 256), + (8, 512, 256, 512), + ], +) +@pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("accumulate", [False, True]) +def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate) -> None: + if tex.get_cublasLt_version() < 130200: + pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + + dtype = torch.bfloat16 + + split_points = torch.randperm(m - 1)[: z - 1] + 1 + split_points = torch.sort(split_points).values.tolist() + m_sizes = [split_points[0]] + m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])] + m_sizes.append(m - split_points[-1]) + assert sum(m_sizes) == m and len(m_sizes) == z + + if layout == "NT": + A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + out_ref = [torch.matmul(B[i].transpose(0, 1).float(), A[i].float()) for i in range(z)] + else: + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [ + torch.randn(ms, k if layout == "TN" else n, dtype=dtype, device="cuda") + for ms in m_sizes + ] # TN --> input, NN --> grad_output + out = [ + torch.randn(ms, n if layout == "TN" else k, dtype=dtype, device="cuda") + for ms in m_sizes + ] # TN --> output, NN --> dgrad + if layout == "NN": + out_ref = [torch.matmul(B[i].float(), A[i].float()) for i in range(z)] + else: # layout == "TN" + out_ref = [torch.matmul(B[i].float(), A[i].transpose(0, 1).float()) for i in range(z)] + + if accumulate: + out_ref = [out[i].float() + o for i, o in enumerate(out_ref)] + + # Bias is applied after GEMM (broadcasted along rows) + # Match kernel behavior: GEMM output is already in output dtype when bias is added. + out_ref_no_bias = [o.to(dtype) for o in out_ref] + if layout == "TN": + bias_last_dim = n + else: # layout == "NT" or "NN" + bias_last_dim = k + bias = ( + [torch.randn(1, bias_last_dim, dtype=dtype, device="cuda") for _ in range(z)] + if case != "discrete_out" + else None + ) + # Bias add in grouped kernel accumulates in FP32 for BF16/FP16. + out_ref = ( + [(o.float() + b.float()).to(dtype) for o, b in zip(out_ref_no_bias, bias)] + if bias is not None + else out_ref_no_bias + ) + # Create grouped tensors based on case + device = A[0].device + grouped_A = A + grouped_out = out + grouped_out_bias = [o.clone() for o in out] + grouped_out_no_bias = [o.clone() for o in out] + grouped_bias = None + if layout == "TN": + grouped_A = ( + _make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A + ) # weight + grouped_B = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) # input + if case != "discrete_out": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # output + grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + elif layout == "NN": + grouped_A = ( + _make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A + ) # weight + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output + if case != "discrete_out": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + else: # layout == "NT" + grouped_A = ( + _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + if case != "discrete_in" + else A + ) # input + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output + if case != "discrete_out": + grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) # wgrad + grouped_out_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype) + grouped_out_no_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype) + _pack_grouped_tensor(grouped_B, B) + if case != "discrete_out": + _pack_grouped_tensor(grouped_out, out) + _pack_grouped_tensor(grouped_out_bias, out) + _pack_grouped_tensor(grouped_out_no_bias, out) + if case != "discrete_in": + _pack_grouped_tensor(grouped_A, A) + + if bias is not None: + grouped_bias = _make_grouped_tensor_uniform(z, 1, bias_last_dim, device, dtype) + _pack_grouped_tensor(grouped_bias, bias) + + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out_no_bias, + layout=layout, + accumulate=accumulate, + bias=None, + ) + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out_bias, + layout=layout, + accumulate=accumulate, + bias=grouped_bias, + ) + out_grouped_no_bias = ( + grouped_out_no_bias + if isinstance(grouped_out_no_bias, list) + else grouped_out_no_bias.split_into_quantized_tensors() + ) + out_grouped_bias = ( + grouped_out_bias + if isinstance(grouped_out_bias, list) + else grouped_out_bias.split_into_quantized_tensors() + ) + + out_grouped_manual_bias = ( + [(o.float() + b.float()).to(dtype) for o, b in zip(out_grouped_no_bias, bias)] + if bias is not None + else out_grouped_no_bias + ) + tols = dtype_tols(dtype) + for o, o_ref in zip(out_grouped_no_bias, out_ref_no_bias): + torch.testing.assert_close(o, o_ref, **tols) + if bias is not None: + for o, o_ref in zip(out_grouped_bias, out_grouped_manual_bias): + torch.testing.assert_close(o, o_ref, **tols) + + +def _make_grouped_tensor_quantized_mxfp8( + tensors: List[torch.Tensor], + *, + is_a: bool, + transposed: bool, + device: torch.device, + optimize_for_gemm: bool = True, +) -> GroupedTensor: + if not tensors: + raise ValueError("Expected non-empty tensor list for grouped quantization.") + if is_a: + rowwise = transposed + columnwise = not transposed + else: + rowwise = not transposed + columnwise = transposed + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=rowwise, + columnwise=columnwise, + ) + quantizer.optimize_for_gemm = optimize_for_gemm + grouped_input = torch.cat(tensors, dim=0) + first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device) + return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims) + + +@pytest.mark.parametrize( + "shape", + [ + (1, 128, 128, 512), + (8, 1024, 128, 512), + (16, 4096, 128, 512), + ], +) +@pytest.mark.parametrize("accumulate", [False, True]) +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_grouped_gemm_grouped_tensor_mxfp8( + shape, accumulate, layout: str, case: str, dtype: torch.dtype +) -> None: + if not IS_HIP_EXTENSION and tex.get_cublasLt_version() < 130200: + pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") + if IS_HIP_EXTENSION: + if not is_mxfp8_available(): + pytest.skip("MXFP8 is not supported on this config") + elif torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if dtype == torch.bfloat16 and not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + z, m, k, n = shape + m_sizes = [m // z] * z + + if layout == "TN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + out = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # output + grad = False + elif layout == "NN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # dgrad + grad = True + else: # layout == "NT" + A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + grad = True + + out_ref = [o.clone() for o in out] + + transa = layout[0] == "T" + transb = layout[1] == "T" + grouped_A = _make_grouped_tensor_quantized_mxfp8(A, is_a=True, transposed=transa, device="cuda") + grouped_B = _make_grouped_tensor_quantized_mxfp8( + B, is_a=False, transposed=transb, device="cuda" + ) + A_fp8 = grouped_A.split_into_quantized_tensors() + B_fp8 = grouped_B.split_into_quantized_tensors() + + general_grouped_gemm( + A_fp8, + B_fp8, + out_ref, + [None] * z, + dtype, + m_splits=m_sizes, + grad=grad, + accumulate=accumulate, + layout=layout, + single_output=False, + ) + + device = A[0].device + + grouped_out = None + if case != "discrete_out": + if layout == "TN": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + elif layout == "NN": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + else: # layout == "NT" + grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) + _pack_grouped_tensor(grouped_out, out) + + grouped_out_input = out if case == "discrete_out" else grouped_out + grouped_A_input = A_fp8 if case == "discrete_in" else grouped_A + general_grouped_gemm_for_grouped_tensor( + grouped_A_input, + grouped_B, + grouped_out_input, + layout=layout, + accumulate=accumulate, + ) + + out_grouped = out if case == "discrete_out" else grouped_out.split_into_quantized_tensors() + tols = dict(rtol=0.125, atol=0.0675) # mxfp8 tolerance + + for o, o_ref in zip(out_grouped, out_ref): + torch.testing.assert_close(o, o_ref, **tols) + + @pytest.mark.parametrize("N", [32]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 50cd150c4..9aea3bc27 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -713,6 +713,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation): _test_export_layernorm_mlp(activation=activation) +# Quantization recipes with fp8_dpa=True for attention emulation export test +dpa_quantization_recipes = [None] # None = no quantization +if fp8_available: + dpa_quantization_recipes.append(recipe.DelayedScaling(fp8_dpa=True)) + dpa_quantization_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True)) + + +@pytest.mark.parametrize("fp8_recipe", dpa_quantization_recipes) @pytest.mark.parametrize( "precision, use_mask, attn_mask_type", [ @@ -730,6 +738,7 @@ def test_export_core_attention( precision: torch.dtype, use_mask: bool, attn_mask_type: str, + fp8_recipe: recipe.Recipe, ): # Set dimensions (these are arbitrary). seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64) @@ -749,22 +758,25 @@ def test_export_core_attention( mask_str = get_attn_mask_str(use_mask, attn_mask_type) high_prec_str = dtype2str(precision) - fname = f"te.core_attention{mask_str}{high_prec_str}.onnx" + fp8_str = "_fp8_dpa" if fp8_recipe is not None else "" + fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx" + + is_fp8 = fp8_recipe is not None model = te.attention.DotProductAttention( num_attention_heads=num_attention_heads, kv_channels=kv_channels, - attention_dropout=0.5, qkv_format=qkv_format, attn_mask_type=attn_mask_type, ).to(device="cuda") - do_export(model, inp, fname, input_names=input_names, fp8_recipe=None) - te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None) + do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe) + te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) if precision in (torch.bfloat16,): return + atol = 5e-1 if is_fp8 else 1e-2 validate_result( - fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs + fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs ) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index 7b92672af..b4ea193f0 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -167,3 +167,24 @@ def test_ignore_idx_reduced_loss(self): reduce_loss=True, ignore_idx=True, ) + + +def test_non_contiguous_transposed_input(): + """Regression test: stride(-2) != shape[-1] should not produce wrong results.""" + s, b, v = 4, 2, 8 + torch.manual_seed(42) + logits = torch.randn(s, b, v, device="cuda") + target = torch.randint(0, v, (b, s), device="cuda") + + logits_transposed = logits.transpose(0, 1) # stride(-2) != shape[-1] + logits_contiguous = logits_transposed.contiguous() + + assert logits_transposed.stride(-1) == 1 + assert logits_transposed.stride(-2) != logits_transposed.shape[-1] + + loss_t = parallel_cross_entropy(logits_transposed, target, 0.0, False, None) + loss_c = parallel_cross_entropy(logits_contiguous, target, 0.0, False, None) + + assert torch.allclose( + loss_t, loss_c + ), f"Non-contiguous transposed input gave wrong results: {loss_t} vs {loss_c}" diff --git a/tests/pytorch/test_qk_norm.py b/tests/pytorch/test_qk_norm.py index 873bd9186..b182d175e 100644 --- a/tests/pytorch/test_qk_norm.py +++ b/tests/pytorch/test_qk_norm.py @@ -11,7 +11,8 @@ @pytest.mark.parametrize("qk_norm_type", [None, "L2Normalization", "RMSNorm", "LayerNorm"]) @pytest.mark.parametrize("attention_type", ["self", "cross"]) @pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5]) -def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None: +@pytest.mark.parametrize("params_dtype", [torch.float32, torch.bfloat16]) +def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps, params_dtype) -> None: """Test QK normalization functionality, module structure, and numerical behavior.""" hidden_size = 256 num_attention_heads = 8 @@ -26,6 +27,7 @@ def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> Non qk_norm_eps=qk_norm_eps, bias=False, device="cuda", + params_dtype=params_dtype, ).cuda() # Check module structure based on qk_norm_type parameter @@ -78,13 +80,11 @@ def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> Non # Create input tensors batch_size = 2 # Use a fixed batch size for testing - hidden_states = torch.randn( - seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 - ) + hidden_states = torch.randn(seq_len, batch_size, hidden_size, device="cuda", dtype=params_dtype) if attention_type == "cross": encoder_output = torch.randn( - seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + seq_len, batch_size, hidden_size, device="cuda", dtype=params_dtype ) else: encoder_output = None @@ -109,7 +109,7 @@ def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> Non if attention_type == "self": head_dim = hidden_size // num_attention_heads rotary_dim = head_dim // 2 - rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32) + rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=params_dtype) with torch.no_grad(): output_with_rope = mha(hidden_states, rotary_pos_emb=rotary_pos_emb) diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py index b2e8fca7c..620fc834d 100644 --- a/tests/pytorch/test_quantized_tensor.py +++ b/tests/pytorch/test_quantized_tensor.py @@ -173,7 +173,7 @@ def make_reference_and_test_tensors( raise ValueError(f"Unsupported quantization scheme ({quantization})") # Make sure reference and test tensors match each other - ref.copy_(test) + ref.copy_(test.to(dtype=ref.dtype)) ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) @@ -656,3 +656,86 @@ def test_chunk( tols = dict(rtol=0, atol=0) # Chunking is exact y_test = y_test.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(y_test, y_ref, **tols) + + +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +class TestMXFP8Tensor: + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("dims", [[128, 128], [256, 256], [128, 256]]) + def test_mxfp8_dequantize_columnwise_only( + self, + fp8_dtype: tex.DType, + dtype: torch.dtype, + dims: DimsType, + ) -> None: + """Check dequantization of MXFP8 tensor with only columnwise data""" + + # Initialize random data + x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1 + + # Quantize with both rowwise and columnwise + quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) + x_mxfp8 = quantizer(x_ref) + + # Dequantize from rowwise (default path) + x_deq_rowwise = x_mxfp8.dequantize(dtype=dtype) + + # Rowwise dequantization should be close to the original + torch.testing.assert_close(x_deq_rowwise, x_ref, **_tols[fp8_dtype]) + + # Strip rowwise data, keeping only columnwise + x_mxfp8.update_usage(rowwise_usage=False, columnwise_usage=True) + assert x_mxfp8._rowwise_data is None + assert x_mxfp8._columnwise_data is not None + + # Dequantize from columnwise only + x_deq_columnwise = x_mxfp8.dequantize(dtype=dtype) + + # Columnwise dequantization should be close to the original + torch.testing.assert_close(x_deq_columnwise, x_ref, **_tols[fp8_dtype]) + + # Rowwise and columnwise dequantizations should match each other + torch.testing.assert_close(x_deq_columnwise, x_deq_rowwise, **_tols[fp8_dtype]) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_deq_columnwise, -x_ref, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dims", [[128, 128], [256, 256]]) + def test_mxfp8_dequantize_columnwise_only_quantized_separately( + self, + fp8_dtype: tex.DType, + dims: DimsType, + ) -> None: + """Check dequantization of MXFP8 tensor quantized with columnwise only""" + + dtype = torch.bfloat16 + + # Initialize random data + x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1 + + # Quantize with columnwise only (no rowwise) + quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=False, columnwise=True) + x_mxfp8 = quantizer(x_ref) + assert x_mxfp8._rowwise_data is None + assert x_mxfp8._columnwise_data is not None + + # Dequantize from columnwise only + x_deq = x_mxfp8.dequantize(dtype=dtype) + + # Should be close to the original + torch.testing.assert_close(x_deq, x_ref, **_tols[fp8_dtype]) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_deq, -x_ref, **_tols[fp8_dtype]) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 30a993976..1286aef48 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -4,7 +4,7 @@ # # See LICENSE for license information. -from typing import Optional +from typing import Optional, List import torch import pytest @@ -116,6 +116,7 @@ def nvfp4_vanilla(): all_activations = [ "gelu", "geglu", + "glu", "qgelu", "qgeglu", "relu", @@ -140,6 +141,23 @@ def reset_global_fp8_state(): FP8GlobalStateManager.reset() +def check_grouped_weight( + module: GroupedLinear, num_gemms: int, out_features: int, in_features: int +): + """ + Verify GroupedLinear exposes one grouped weight parameter with shape + [num_gemms, out_features, in_features]. + """ + weight_params = [(name, p) for name, p in module.named_parameters() if "weight" in name] + assert len(weight_params) == 1, f"Expected 1 grouped weight parameter, got {len(weight_params)}" + name, weight = weight_params[0] + assert name == "weight", f"Expected grouped parameter name 'weight', got {name}" + assert tuple(weight.shape) == (num_gemms, out_features, in_features), ( + "Grouped weight has unexpected shape. " + f"Expected {(num_gemms, out_features, in_features)}, got {tuple(weight.shape)}" + ) + + def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( (config.max_seqlen_q, config.batch_size, config.hidden_size), @@ -442,8 +460,6 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias): - if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params: - pytest.skip("Quantized model parameters are not supported in debug mode.") config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size num_tokens = bs * config.max_seqlen_q @@ -476,13 +492,20 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) +@pytest.mark.parametrize("single_param", all_boolean) @pytest.mark.parametrize("empty_split", ["first", "last", "middle"]) @pytest.mark.parametrize("num_gemms", [4]) def test_sanity_grouped_linear( - dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split + dtype, + bs, + model, + fp8_recipe, + fp8_model_params, + use_bias, + single_param, + num_gemms, + empty_split, ): - if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params: - pytest.skip("FP8 model parameters are not supported in debug mode.") config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. @@ -498,9 +521,19 @@ def test_sanity_grouped_linear( use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): te_grouped_linear = GroupedLinear( - num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype + num_gemms, + config.hidden_size, + ffn_hidden_size, + bias=use_bias, + params_dtype=dtype, + single_grouped_parameter=single_param, ).cuda() + # Verify grouped linear exposes a single grouped weight parameter. + if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()): + if single_param: + check_grouped_weight(te_grouped_linear, num_gemms, ffn_hidden_size, config.hidden_size) + inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() @@ -959,7 +992,13 @@ def test_replace_raw_data_for_float8tensor(): random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda") fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor) - attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"] + attrs_to_check = [ + "_quantizer", + "_fp8_dtype", + "_scale_inv", + "_transpose", + "_transpose_invalid", + ] attrs = {} for attr in attrs_to_check: attrs[attr] = getattr(fp8_tensor, attr) @@ -1082,8 +1121,6 @@ def test_inference_mode( quantization: Optional[str], ) -> None: """Test heuristics for initializing quantized weights""" - if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None: - pytest.skip("Quantized model parameters are not supported in debug mode.") # Tensor dimensions sequence_length = 32 diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index c0398b801..92c857f70 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -18,7 +18,7 @@ import transformer_engine import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch import InferenceParams +from transformer_engine.pytorch import InferenceParams, QuantizedTensor from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends from transformer_engine.pytorch.attention.dot_product_attention.utils import ( get_attention_backend, @@ -297,7 +297,6 @@ def get_available_attention_backends( os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True - alibi_slopes_shape = None if config.attn_bias_type == "alibi" and config.alibi_type == "custom": if config.bias_shape == "1hss": @@ -315,7 +314,9 @@ def get_available_attention_backends( and config.head_dim_qk <= 128 and config.head_dim_v <= 128 ): - core_attention_bias_requires_grad = True + # TODO(KshitijLakhani): Remove this guard when cuDNN starts support dbias calculation for bias shape 111s + if core_attention_bias_shape != "111s": + core_attention_bias_requires_grad = True fused_attn_backends = [] available_backends = None @@ -395,11 +396,56 @@ def test(): backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} if AttentionLogging._is_logging_setup is False: AttentionLogging.setup_logging() - with logging_context(highest_level=AttentionLogging._log_level): - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, flash_attention_backend, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) + + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) return available_backends, flash_attention_backend, fused_attn_backends + + +@torch.no_grad +def assert_close( + actual: Optional[torch.Tensor], + expected: Optional[torch.Tensor], + *, + check_device: bool = False, + check_dtype: bool = False, + check_layout: bool = False, + **kwargs, +) -> None: + """Assert that two tensors are close. + + This function is a wrapper around torch.testing.assert_close. It + changes the defaults for device and dtype checks (useful when the + reference implementation is computed in high precision on CPU) and + it can handle quantized tensors. + + """ + if isinstance(actual, QuantizedTensor): + actual = actual.dequantize() + if isinstance(expected, QuantizedTensor): + expected = expected.dequantize() + torch.testing.assert_close( + actual, + expected, + check_device=check_device, + check_dtype=check_dtype, + check_layout=check_layout, + **kwargs, + ) + + +def assert_close_grads( + actual: Optional[torch.Tensor], + expected: Optional[torch.Tensor], + **kwargs, +) -> None: + """Assert that two tensors have close gradients.""" + if actual is None and expected is None: + return + assert actual is not None + assert expected is not None + assert_close(actual.grad, expected.grad, **kwargs) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 0bdfd6085..4714bff79 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -28,7 +28,7 @@ else() else() message(FATAL_ERROR "Could not find NVCC at '$ENV{CUDA_HOME}/bin/nvcc'") endif() - + endif() # Language options @@ -166,23 +166,26 @@ set(transformer_engine_cuda_sources) set(transformer_engine_cuda_arch_specific_sources) # Source files in both cuda and rocm +set(GX_CUDA $) list(APPEND transformer_engine_cpp_sources + $<${GX_CUDA}:cudnn_utils.cpp> transformer_engine.cpp - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/comm_gemm_overlap.cpp + $<${GX_CUDA}:fused_attn/fused_attn.cpp> gemm/config.cpp normalization/common.cpp normalization/layernorm/ln_api.cpp normalization/rmsnorm/rmsnorm_api.cpp util/cuda_driver.cpp + $<${GX_CUDA}:util/cuda_nvml.cpp> util/cuda_runtime.cpp util/multi_stream.cpp - util/rtc.cpp) + util/rtc.cpp + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/comm_gemm_overlap.cpp) list(APPEND transformer_engine_cuda_sources common.cu - comm_gemm_overlap/userbuffers/userbuffers.cu multi_tensor/adam.cu multi_tensor/l2norm.cu multi_tensor/scale.cu @@ -192,18 +195,26 @@ list(APPEND transformer_engine_cuda_sources transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu + $<${GX_CUDA}:transpose/quantize_transpose_vector_blockwise.cu> transpose/swap_first_dims.cu dropout/dropout.cu fused_attn/flash_attn.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu + $<${GX_CUDA}:fused_attn/fused_attn_f16_max512_seqlen.cu> + $<${GX_CUDA}:fused_attn/fused_attn_f16_arbitrary_seqlen.cu> + $<${GX_CUDA}:fused_attn/fused_attn_fp8.cu> + $<${GX_CUDA}:fused_attn/utils.cu> gemm/cublaslt_gemm.cu + gemm/cublaslt_grouped_gemm.cu normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu util/padding.cu + swizzle/swizzle.cu + $<${GX_CUDA}:swizzle/swizzle_block_scaling.cu> fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -214,63 +225,49 @@ list(APPEND transformer_engine_cuda_sources recipe/current_scaling.cu recipe/delayed_scaling.cu recipe/fp8_block_scaling.cu - swizzle/swizzle.cu) + comm_gemm_overlap/userbuffers/userbuffers.cu) list(APPEND transformer_engine_cuda_arch_specific_sources activation/gelu.cu + activation/glu.cu activation/relu.cu activation/swiglu.cu cast/cast.cu + $<${GX_CUDA}:gemm/cutlass_grouped_gemm.cu> + $<${GX_CUDA}:hadamard_transform/group_hadamard_transform.cu> + $<${GX_CUDA}:hadamard_transform/graph_safe_group_hadamard_transform.cu> + $<${GX_CUDA}:hadamard_transform/hadamard_transform.cu> + $<${GX_CUDA}:hadamard_transform/hadamard_transform_cast_fusion.cu> + $<${GX_CUDA}:hadamard_transform/group_hadamard_transform_cast_fusion.cu> + $<${GX_CUDA}:hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu> + $<${GX_CUDA}:hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu> multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu + recipe/nvfp4.cu + $<${GX_CUDA}:transpose/quantize_transpose_square_blockwise.cu> transpose/quantize_transpose_vector_blockwise_fp4.cu) -if(USE_CUDA) -#NV specific source codes - list(APPEND transformer_engine_cpp_sources - cudnn_utils.cpp - fused_attn/fused_attn.cpp - util/cuda_nvml.cpp) - list(APPEND transformer_engine_cuda_sources - transpose/quantize_transpose_vector_blockwise.cu - fused_attn/fused_attn_f16_max512_seqlen.cu - fused_attn/fused_attn_f16_arbitrary_seqlen.cu - fused_attn/fused_attn_fp8.cu - fused_attn/utils.cu - swizzle/swizzle_block_scaling.cu - recipe/nvfp4.cu) - list(APPEND transformer_engine_cuda_arch_specific_sources - gemm/cutlass_grouped_gemm.cu - hadamard_transform/group_hadamard_transform.cu - transpose/quantize_transpose_square_blockwise.cu - hadamard_transform/hadamard_transform.cu - hadamard_transform/hadamard_transform_cast_fusion.cu - hadamard_transform/group_hadamard_transform_cast_fusion.cu - hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu - transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise_fp4.cu) -else() -#ROCm specific source codes - list(APPEND transformer_engine_cpp_sources - comm_gemm_overlap/rocm_comm_gemm_overlap.cpp - fused_attn_rocm/fused_attn.cpp - gemm/rocm_gemm.cu - gemm/ck_grouped_gemm/ck_grouped_gemm.cpp - gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp - gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp - amd_detail/system.cpp) - list(APPEND transformer_engine_cuda_sources - fused_attn_rocm/fused_attn_aotriton.cpp - fused_attn_rocm/fused_attn_ck.cpp - fused_attn_rocm/utils.cpp) -endif() - # Compiling the files with the worst compilation time first to hopefully overlap # better with the faster-compiling cpp files list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_sources} ${transformer_engine_cuda_sources} ${transformer_engine_cpp_sources}) +if(USE_ROCM) + #ROCm specific source codes + list(APPEND transformer_engine_SOURCES + amd_detail/system.cpp + comm_gemm_overlap/rocm_comm_gemm_overlap.cpp + fused_attn_rocm/fused_attn.cpp + fused_attn_rocm/fused_attn_aotriton.cpp + fused_attn_rocm/fused_attn_ck.cpp + fused_attn_rocm/utils.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp + gemm/rocm_gemm.cu) +endif() + if(USE_CUDA) # Set compile options for CUDA sources with generic architectures foreach(cuda_source IN LISTS transformer_engine_cuda_sources) @@ -313,7 +310,7 @@ endif() add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) -else() +else() #USE_ROCM # process source code files include("${CMAKE_CURRENT_SOURCE_DIR}/../../build_tools/hipify/hipify.cmake") @@ -398,26 +395,40 @@ endif() option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) if (NVTE_WITH_CUBLASMP) target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) - target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) + target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include) find_library(CUBLASMP_LIB NAMES cublasmp libcublasmp PATHS ${CUBLASMP_DIR} PATH_SUFFIXES lib REQUIRED) - find_library(NVSHMEM_HOST_LIB - NAMES nvshmem_host libnvshmem_host.so.3 - PATHS ${NVSHMEM_DIR} + find_library(NCCL_LIB + NAMES nccl libnccl PATH_SUFFIXES lib REQUIRED) - target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) + target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB}) message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") - message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") endif() +# Number of philox4x32 rounds for stochastic rounding (build-time constant). +set(NVTE_BUILD_NUM_PHILOX_ROUNDS_STR $ENV{NVTE_BUILD_NUM_PHILOX_ROUNDS}) +if (NOT NVTE_BUILD_NUM_PHILOX_ROUNDS_STR) + set(NVTE_BUILD_NUM_PHILOX_ROUNDS_STR "10") +endif() +if (NOT NVTE_BUILD_NUM_PHILOX_ROUNDS_STR MATCHES "^[1-9][0-9]*$") + message(FATAL_ERROR + "Environment variable NVTE_BUILD_NUM_PHILOX_ROUNDS must be a positive integer, " + "but got '${NVTE_BUILD_NUM_PHILOX_ROUNDS_STR}'.") +endif() +set(NVTE_BUILD_NUM_PHILOX_ROUNDS ${NVTE_BUILD_NUM_PHILOX_ROUNDS_STR}) + +target_compile_definitions(transformer_engine + PUBLIC NVTE_BUILD_NUM_PHILOX_ROUNDS=${NVTE_BUILD_NUM_PHILOX_ROUNDS}) +message(STATUS "Philox rounds for stochastic rounding: ${NVTE_BUILD_NUM_PHILOX_ROUNDS}") + # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) -else() +else() # USE_ROCM option(NVTE_ENABLE_ROCSHMEM "Compile with ROCSHMEM library" OFF) if (NVTE_ENABLE_ROCSHMEM) add_subdirectory(rocshmem_api) @@ -534,6 +545,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) list(APPEND nvte_sources_with_fast_math activation/gelu.cu + activation/glu.cu activation/relu.cu activation/swiglu.cu) endif() diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 9dbf998e5..5864c533e 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -292,11 +292,13 @@ def _nvidia_cudart_include_dir() -> str: return "" # Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia" - # above doesn't through. However, they don't set "__file__" attribute. - if nvidia.__file__ is None: - return "" + # above doesn't throw. However, they don't set "__file__" attribute. + if nvidia.__file__ is not None: + nvidia_root = Path(nvidia.__file__).parent + else: + nvidia_root = Path(nvidia.__path__[0]) # namespace package - include_dir = Path(nvidia.__file__).parent / "cuda_runtime" + include_dir = nvidia_root / "cuda_runtime" return str(include_dir) if include_dir.exists() else "" diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 675341f7d..ea864813b 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -13,6 +13,14 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } +void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_gelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgelu); @@ -20,6 +28,20 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dgelu); + using namespace transformer_engine; + NVTEGroupedTensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -33,6 +55,20 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; @@ -54,6 +90,15 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) act_fn>(input, output, stream); } +void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_qgelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgelu); @@ -61,6 +106,20 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dqgelu); + using namespace transformer_engine; + NVTEGroupedTensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -74,6 +133,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/glu.cu b/transformer_engine/common/activation/glu.cu new file mode 100644 index 000000000..45a667067 --- /dev/null +++ b/transformer_engine/common/activation/glu.cu @@ -0,0 +1,24 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_glu); + using namespace transformer_engine; + Empty e = {}; + gated_act_fn>(input, output, e, stream); +} + +void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_dglu); + using namespace transformer_engine; + Empty e = {}; + dgated_act_fn, dsigmoid>(grad, input, output, e, + stream); +} diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index fd70e38c1..fc9122b7e 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -13,6 +13,14 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } +void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_relu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_drelu); @@ -20,6 +28,20 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_drelu); + using namespace transformer_engine; + NVTEGroupedTensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -33,6 +55,20 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; @@ -54,6 +90,15 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) act_fn>(input, output, stream); } +void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_srelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsrelu); @@ -61,6 +106,20 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dsrelu); + using namespace transformer_engine; + NVTEGroupedTensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -74,6 +133,20 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index cc812a17f..12478af4c 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -13,6 +13,14 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } +void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_silu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsilu); @@ -20,6 +28,20 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dsilu); + using namespace transformer_engine; + NVTEGroupedTensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { @@ -33,6 +55,20 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 46c9b64be..4e7e3c4da 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -30,6 +30,15 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea dispatch::quantize_fwd_helper(input, output, nullptr, stream); } +void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::group_quantize_fwd_helper(input, output, nullptr, stream); +} + void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, cudaStream_t stream) { NVTE_API_CALL(nvte_quantize_noop); @@ -64,6 +73,19 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d input, activation_input, output, dbias, workspace, nullptr, stream); } +void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr const NVTEGroupedTensor activation_input = nullptr; + + dispatch::group_quantize_bwd_helper( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dequantize); using namespace transformer_engine; @@ -106,7 +128,8 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, } // Group quantize assumes contiguous inputs and outputs in memory allocation -// TODO (zhongbo): find a better way to make it a more generalized API +// Note: this API assumes knowing split sections from the host, if split information +// comes from D2H copy, it will break cuda graph capture void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections, const size_t num_tensors, const NVTEQuantizationConfig quant_config, @@ -116,6 +139,6 @@ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *out constexpr bool IS_ACT = false; - dispatch::group_quantize_fwd_helper(input, outputs, split_sections, - num_tensors, quant_config, stream); + dispatch::group_quantize_fwd_host_aware_helper( + input, outputs, split_sections, num_tensors, quant_config, stream); } diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 3f82798f9..f70baf05d 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -26,6 +26,16 @@ namespace transformer_engine { namespace dispatch { namespace common { + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +#ifndef __HIP_PLATFORM_AMD__ + inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { const size_t N = product(t->data.shape); const bool isFullTile = (N % elems_per_block == 0); @@ -39,6 +49,14 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) { return cols % alignment_requirement == 0; } +__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) { + size_t addr = reinterpret_cast(p); + addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1); + return reinterpret_cast(addr); +} + +#endif //!__HIP_PLATFORM_AMD__ + namespace kernel { constexpr size_t THREADS_PER_BLOCK = 256; @@ -76,6 +94,56 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) } stg_vec.store_to(thread_out_base); } + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const offsets_ptr, const int64_t *const first_dims_ptr, + const int64_t *const last_dims_ptr, OType *const dbias_output, + const float *dbias_partial, const size_t chunk_dim_Y) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const size_t tensor_id = blockIdx.y; + const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) + ? (first_logical_dim / num_tensors) + : first_dims_ptr[tensor_id]; + + const size_t rows = tensor_rows / chunk_dim_Y; + const size_t cols = last_logical_dim; + + const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) + ? (tensor_id * (tensor_rows / chunk_dim_Y)) + : (offsets_ptr[tensor_id] / cols / chunk_dim_Y); + + const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + dbias_in_offset_Y * cols + thread_id * nvec; + OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base); +} } // namespace kernel template @@ -94,6 +162,32 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, NVTE_CHECK_CUDA(cudaGetLastError()); } +template +void grouped_reduce_dbias(const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const data_tensor_offsets_ptr, + const int64_t *const data_tensor_first_dims_ptr, + const int64_t *const data_tensor_last_dims_ptr, GroupedTensor *dbias, + const float *workspace_ptr, const size_t chunk_dim_Y, + cudaStream_t stream) { + using namespace kernel; + constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 + constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(last_logical_dim % reduce_dbias_nvec == 0, "Unsupported shape."); + + const size_t blocks_X = DIVUP(last_logical_dim, THREADS_PER_BLOCK * reduce_dbias_nvec); + const size_t blocks_Y = num_tensors; + const dim3 grid(blocks_X, blocks_Y); + + group_reduce_dbias_kernel<<>>( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, data_tensor_offsets_ptr, + data_tensor_first_dims_ptr, data_tensor_last_dims_ptr, + reinterpret_cast(dbias->data.dptr), workspace_ptr, chunk_dim_Y); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace common } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 4f9ef80dc..809581db4 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -20,6 +20,7 @@ #include "../../util/vectorized_pointwise.h" #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" +#include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" //TODO: ROCm TE does not support nvfp4 yet #ifndef __HIP_PLATFORM_AMD__ @@ -324,10 +325,12 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens } } +// Host-aware and not graph-safe: group quantization with split section info from the host. template -void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs, - const size_t *split_sections, const size_t num_tensors, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { +void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *outputs, + const size_t *split_sections, const size_t num_tensors, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { using namespace detail; const Tensor *input_tensor = convertNVTETensorCheck(input); @@ -390,6 +393,90 @@ void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs, } } +template +void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + const NVTEGroupedTensor activation = nullptr; + NVTEGroupedTensor dbias = nullptr; + NVTETensor workspace = nullptr; + + const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + mxfp8::group_quantize( + input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } +} + +template +void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); + const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + mxfp8::group_quantize( + grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } +} + } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh new file mode 100644 index 000000000..b56e28968 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -0,0 +1,1029 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file group_quantize_mxfp8.cuh + * \brief CUDA kernels to quantize grouped tensors to MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "../core/common.cuh" +#include "swizzle.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +#ifndef __HIP_PLATFORM_AMD__ +namespace group_quantize_kernel { + +using namespace dispatch::common; + +constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64; +__device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 128; + +constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X; + +constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; +constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + +constexpr size_t BUFF_DIM_Y = THREADS_Y; +constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; +static_assert(BUFF_DIM_Y == 32); + +constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; +static_assert(STAGES >= 1); + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + +__device__ __forceinline__ size_t get_current_tensor_id( + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, + const size_t block_Y, const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr) { + if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + const size_t current_row = block_Y * CHUNK_DIM_Y; + const size_t rows_per_tensor = first_logical_dim / num_tensors; + return current_row / rows_per_tensor; + } else { + size_t low = 1; + size_t hi = num_tensors; // [low, hi] + + while (low < hi) { + const size_t mid = low + (hi - low) / 2; + const size_t mid_offset = static_cast(offsets_ptr[mid]); + + if (mid_offset <= current_offset) { + low = mid + 1; + } else { + hi = mid; + } + } + return low - 1; + } +} + +__device__ __forceinline__ size_t get_tensor_rows_num( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim, + const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { + size_t rows_num = 0; + switch (shape_rep) { + case ShapeRepresentation::SAME_BOTH_DIMS: + case ShapeRepresentation::VARYING_LAST_DIM: + rows_num = first_logical_dim; + break; + case ShapeRepresentation::VARYING_FIRST_DIM: + case ShapeRepresentation::VARYING_BOTH_DIMS: + rows_num = static_cast(first_dims_ptr[tensor_id]); + break; + } + if (rows_num % 128 != 0) { + NVTE_DEVICE_ERROR("First dimension of each tensor in a group must be divisible by 128."); + } + return rows_num; +} + +__device__ __forceinline__ size_t get_tensor_cols_num( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim, + const int64_t *const __restrict__ last_dims_ptr) { + size_t cols_num = 0; + switch (shape_rep) { + case ShapeRepresentation::SAME_BOTH_DIMS: + case ShapeRepresentation::VARYING_FIRST_DIM: + cols_num = last_logical_dim; + break; + case ShapeRepresentation::VARYING_LAST_DIM: + case ShapeRepresentation::VARYING_BOTH_DIMS: + cols_num = static_cast(last_dims_ptr[tensor_id]); + break; + } + return cols_num; +} + +// Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index +__device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_tensor_map, + CUtensorMap *global_tensor_map, + const uintptr_t global_data_ptr, + const size_t global_dim_Y, + const size_t global_dim_X, + const size_t data_type_size_bytes) { + __shared__ CUtensorMap shared_tensor_map; + shared_tensor_map = base_tensor_map; // Copy the base tensor map into shmem + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + const size_t global_stride_bytes = global_dim_X * data_type_size_bytes; + if (global_stride_bytes % TMA_GMEM_ALIGNMENT != 0) { + NVTE_DEVICE_ERROR("Shape not supported. Data stride must be 16B aligned."); + } + if (global_data_ptr % TMA_GMEM_ALIGNMENT != 0) { + NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned"); + } + + asm volatile( + "{\n\t" + ".reg.b64 tensor_map_ptr; \n\t" + "mov.b64 tensor_map_ptr, %0; \n\t" + "tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1; \n\t" + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2; \n\t" // DIM Y + "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3; \n\t" // DIM X + "tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4; \n" + "}\n" ::"l"(reinterpret_cast(&shared_tensor_map)), + "l"(global_data_ptr), "r"(static_cast(global_dim_Y)), + "r"(static_cast(global_dim_X)), "l"(static_cast(global_stride_bytes)) + : "memory"); + *global_tensor_map = shared_tensor_map; + } else { + NVTE_DEVICE_ERROR( + "tensormap.replace is architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } +} + +template +__global__ void update_tma_descriptors( + const __grid_constant__ CUtensorMap base_tensor_map_input, + const __grid_constant__ CUtensorMap base_tensor_map_act_input, + const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap base_tensor_map_output_colwise, + const IType *const __restrict__ input_data_ptr, + const IType *const __restrict__ act_input_data_ptr, + const OType *const __restrict__ output_rowwise_data_ptr, + const OType *const __restrict__ output_colwise_data_ptr, const ShapeRepresentation shape_rep, + const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise, + const bool compute_dactivations) { + const bool leading_thread = (threadIdx.x == 0); + const size_t tensor_id = blockIdx.x; + + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + + const size_t offset_elts = offsets_ptr[tensor_id]; + + if (leading_thread && (tensor_id < num_tensors)) { + { + const uintptr_t global_data_ptr = reinterpret_cast(input_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], + global_data_ptr, rows, cols, sizeof(IType)); + } + if (compute_dactivations) { + const uintptr_t global_data_ptr = + reinterpret_cast(act_input_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_act_input, &g_tensor_maps_act_input[tensor_id], + global_data_ptr, rows, cols, sizeof(IType)); + } + if (rowwise) { + const uintptr_t global_data_ptr = + reinterpret_cast(output_rowwise_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_output_rowwise, + &g_tensor_maps_output_rowwise[tensor_id], global_data_ptr, rows, cols, + sizeof(OType)); + } + if (colwise) { + const uintptr_t global_data_ptr = + reinterpret_cast(output_colwise_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_output_colwise, + &g_tensor_maps_output_colwise[tensor_id], global_data_ptr, rows, cols, + sizeof(OType)); + } + } +} + +__device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tensor_map) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" ::"l"(tensor_map)); +#else + NVTE_DEVICE_ERROR("fence_acquire_tensormap is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( + const __grid_constant__ CUtensorMap tensor_map_input_static, + const __grid_constant__ CUtensorMap tensor_map_act_input_static, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, + const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim, + const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, + e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, + float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx; + + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + + const size_t block_ID = blockIdx.y * gridDim.x + blockIdx.x; + const size_t block_global_offset = + is_single_tensor ? (blockIdx.y * CHUNK_DIM_Y * last_logical_dim + blockIdx.x * CHUNK_DIM_X) + : (block_ID * ELTS_PER_CHUNK); + + const size_t tensor_id = + get_current_tensor_id(shape_rep, num_tensors, block_global_offset, blockIdx.y, + first_logical_dim, last_logical_dim, offsets_ptr); + + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + + const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); + const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); + + // grouped tensor can be treated as continuous tensor for MXFP8 + const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); + // For grouped tensors represented as a single logical tensor, scale swizzle must still be + // computed per tensor (expert) and then concatenated along dim-0. + const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) + ? static_cast(offsets_ptr[tensor_id]) + : tensor_base; + + // In graph-safe paged stashing, the logical shape can include trailing garbage. Skip CTAs that + // map outside the current tensor's valid [rows, cols] region. + if (rows == 0 || cols == 0) { + return; + } + if (shape_rep != SAME_BOTH_DIMS) { + const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); + if (block_global_offset >= tensor_end_offset) { + return; + } + const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; + if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { + return; + } + } + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = + is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = + is_single_tensor ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; + + const bool leading_thread = (threadIdx.x == 0); + + if (leading_thread && (!is_single_tensor)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); + } + if constexpr (ROWWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_rowwise); + } + if constexpr (COLWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_colwise); + } + } + + const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); + const size_t block_id_in_current_tensor = + is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); + + const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; + const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + + const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + + e8m0_t *const scales_rowwise = + scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); + e8m0_t *const scales_colwise = + scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); + + const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { +#pragma unroll + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } + } + + float block_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, leading_thread); + + int parity = 0; + + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], leading_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], leading_thread); + } + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + leading_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], leading_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + const size_t tensor_base_row = tensor_base_for_scales / cols; + const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; + const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; + const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; + scale_idx = tensor_scales_offset_colwise_base + + gemm_swizzled_scale_idx(global_scales_offset_X, local_scales_offset_Y, + DIVUP(rows, static_cast(128))); + } else { + scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + } + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } + + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, + DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } + scales_rowwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (leading_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } + + parity ^= 1; + + if constexpr (IS_DBIAS) { + if (is_single_tensor) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); + + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_offset_Y = block_id_Y; + const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (leading_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, leading_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace group_quantize_kernel +#endif //!__HIP_PLATFORM_AMD__ + +template +void group_quantize(const GroupedTensor *input, const GroupedTensor *activations, + const Tensor *noop, GroupedTensor *output, GroupedTensor *dbias, + Tensor *workspace, cudaStream_t stream) { +#ifdef __HIP_PLATFORM_AMD__ + NVTE_ERROR("group_quantize is not supported on ROCm yet."); +#else + using namespace group_quantize_kernel; + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + + const bool use_rowwise_scaling = output->has_data(); + const bool use_colwise_scaling = output->has_columnwise_data(); + NVTE_CHECK(use_rowwise_scaling || use_colwise_scaling, + "Either rowwise or columnwise output data need to be allocated."); + + ScalingType scaling_type = ScalingType::BIDIMENSIONAL; + if (!use_colwise_scaling) { + scaling_type = ScalingType::ROWWISE; + } else if (!use_rowwise_scaling) { + scaling_type = ScalingType::COLWISE; + } + + ShapeRepresentation shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + if (output->all_same_shape()) { + shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + } else if (output->all_same_first_dim()) { + shape_rep = ShapeRepresentation::VARYING_LAST_DIM; + } else if (output->all_same_last_dim()) { + shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + } else if (output->varying_both_dims()) { + shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; + } + + // Treat a grouped tensor with const last dims as a single tensor + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + + NVTE_CHECK(input->num_tensors == output->num_tensors, + "Number of input and output tensors must be same."); + NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + if (IS_DACT) { + NVTE_CHECK(activations->has_data(), "Activations tensor must have data."); + NVTE_CHECK(input->num_tensors == activations->num_tensors, + "Number of grad and activations tensors must be same."); + NVTE_CHECK(input->dtype() == activations->dtype(), + "Grad and activations tensors must have the same type."); + } + + const size_t first_logical_dim = input->logical_shape.data[0]; + const size_t last_logical_dim = input->logical_shape.data[1]; + const size_t elts_total = first_logical_dim * last_logical_dim; + + const size_t num_tensors = input->num_tensors; + + size_t blocks_X = 0; + size_t blocks_Y = 0; + + if (is_single_tensor) { + blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); + blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); + } else { + NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, + "Number of tensors in a group is larger than " + "the MAX number of supported descriptors (64)."); + blocks_Y = 1; + blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + } + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; + + // Logical shape of a tensor with varying all dims is [1, M*K] + if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) { + NVTE_CHECK(first_logical_dim % 128 == 0, + "First logical dimension of a grouped tensor must be divisible by 128."); + } + + const int64_t *const offsets_ptr = reinterpret_cast(output->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(output->first_dims.dptr); + const int64_t *const last_dims_ptr = reinterpret_cast(output->last_dims.dptr); + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + + e8m0_t *const scales_rowwise_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + + if (use_rowwise_scaling) { + NVTE_CHECK(scales_rowwise_ptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); + } + + if constexpr (IS_DBIAS) { + NVTE_CHECK(is_single_tensor, + "DBias is only supported for tensors with the const last dimension."); + NVTE_CHECK(dbias->data.dtype == input->dtype(), + "DBias must have the same type as input_tensor."); + + std::vector expected_shape_dbias_tensor = {num_tensors, last_logical_dim}; + NVTE_CHECK(dbias->data.shape == expected_shape_dbias_tensor, "Wrong shape of DBias."); + + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + const size_t dbias_workspace_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t dbias_workspace_cols = last_logical_dim; + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_workspace_rows, dbias_workspace_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, output_type_bit_size); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + auto kernel = + group_quantize_mxfp8_kernel; + switch (scaling_type) { + case ScalingType::ROWWISE: { + kernel = + group_quantize_mxfp8_kernel; + break; + } + case ScalingType::COLWISE: { + kernel = + group_quantize_mxfp8_kernel; + break; + } + case ScalingType::BIDIMENSIONAL: { + kernel = + group_quantize_mxfp8_kernel; + break; + } + } + + // Update tensor descriptors before launching the kernel + if (!is_single_tensor) { + const IType *const input_dptr = reinterpret_cast(input->data.dptr); + + const IType *const act_input_dptr = + IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; + + OType *const output_rowwise_dptr = + use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; + + OType *const output_colwise_dptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) + : nullptr; + update_tma_descriptors<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, + output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, + last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, + use_rowwise_scaling, use_colwise_scaling, IS_DACT); + } + + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, + last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, + scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); + + if constexpr (IS_DBIAS) { + common::grouped_reduce_dbias( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, + first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); + } + + NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +#endif //__HIP_PLATFORM_AMD__ +} + +} // namespace mxfp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh index bdbe5cddc..8d2d80655 100644 --- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -88,10 +88,11 @@ __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const return global_encode_scale; } -__device__ __forceinline__ uint32_t -get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10> &rng, - // philox4x32_native_state<10>: 10 rounds of philox4_32 - uint4 &random_uint4, int &rnd_idx) { +__device__ __forceinline__ uint32_t get_rbits( + transformer_engine::curanddx::detail::philox4x32_native_state + &rng, + // philox4x32_native_state: compile-time configurable rounds + uint4 &random_uint4, int &rnd_idx) { if (rnd_idx == 4) { rnd_idx = 0; random_uint4 = rng.generate4(); diff --git a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh index 1ceb08a9d..a2f3dac15 100644 --- a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh @@ -191,7 +191,7 @@ __global__ void __launch_bounds__(THREADS_NUM) threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + transformer_engine::curanddx::detail::philox4x32_native_state rng; rng.init(rng_seed, rng_sequence, rng_offset); uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 5da9cc5a5..f164636e3 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -21,6 +21,7 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" #include "core_nvfp4.cuh" +#include "specialized/quantize_transpose_nvfp4_tuned_1D.cuh" namespace transformer_engine { namespace dispatch { @@ -134,7 +135,7 @@ __global__ void __launch_bounds__(THREADS_NUM) threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + transformer_engine::curanddx::detail::philox4x32_native_state rng; rng.init(rng_seed, rng_sequence, rng_offset); uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x @@ -646,7 +647,7 @@ __global__ void __launch_bounds__(THREADS_NUM) threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + transformer_engine::curanddx::detail::philox4x32_native_state rng; rng.init(rng_seed, rng_sequence, rng_offset); uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; int rnd_idx = @@ -1159,6 +1160,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, #if FP4_TYPE_SUPPORTED using namespace quantize_transpose_kernel; using namespace ptx; + bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to @@ -1166,6 +1168,11 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, // TODO(Frank): Is there a better way to do this? bool return_transpose = output->has_columnwise_data(); + if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { + quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); + return; + } + constexpr bool COMPUTE_ACTIVATIONS = false; using ParamOP = Empty; constexpr float (*OP)(float, const ParamOP &) = nullptr; diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh new file mode 100644 index 000000000..fc337f607 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -0,0 +1,805 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_transpose_nvfp4_tuned_1D.cuh + * \brief Tuned kernel to cast to NVFP4 and transpose. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ + +#include +#include +#include +#include + +#include "../../../common.h" +#include "../../../util/math.h" +#include "../../../util/ptx.cuh" +#include "../../../utils.cuh" +#include "../core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +namespace quantize_transpose_tuned_kernel { + +using namespace quantization_and_transposition_SF; +using namespace core; +using namespace ptx; + +#if FP4_TYPE_SUPPORTED + +struct TunableConfig { + static constexpr int CHUNK_DIM_Y = 128; + static constexpr int CHUNK_DIM_X = 128; + static constexpr int PREFETCH_STAGES = 1; + static constexpr bool PERSISTENT = false; +}; + +constexpr int SCALE_DIM = 16; // NVFP4 block (x16 elts) +constexpr int THREADS_NUM = 128; +constexpr int ELTS_PER_THREAD = 16; +constexpr int TILE_DIM_Y = 64; +constexpr int TILE_DIM_X = 64; + +static_assert(ELTS_PER_THREAD == SCALE_DIM && "Hardcoded and fixed parameter\0"); + +static_assert((THREADS_NUM * ELTS_PER_THREAD <= TILE_DIM_Y * TILE_DIM_X) && + "Unbalanced threads workload\0"); + +static_assert((TunableConfig::CHUNK_DIM_Y % TILE_DIM_Y == 0) && + "Chunk size Y must be evenly divisible by the tile size Y\0"); +static_assert((TunableConfig::CHUNK_DIM_X % TILE_DIM_X == 0) && + "Chunk size X must be evenly divisible by the tile size X\0"); + +static_assert((TILE_DIM_Y % SCALE_DIM == 0) && + "Tile size Y must be evenly divisible by the scale dim\0"); +static_assert((TILE_DIM_X % SCALE_DIM == 0) && + "Tile size X must be evenly divisible by the scale dim\0"); + +constexpr int TILES_Y = TunableConfig::CHUNK_DIM_Y / TILE_DIM_Y; +constexpr int TILES_X = TunableConfig::CHUNK_DIM_X / TILE_DIM_X; + +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; + +constexpr int SCALES_PER_CHUNK_Y = TunableConfig::CHUNK_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_CHUNK_X = TunableConfig::CHUNK_DIM_X / SCALE_DIM; + +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; + +constexpr int STAGES_Y = TILES_Y; +constexpr int STAGES_X = TILES_X; +constexpr int STAGES = STAGES_Y * STAGES_X; + +constexpr int BUFFS_NUM = TunableConfig::PREFETCH_STAGES + 1; +constexpr int BUFFS_NUM_IN = BUFFS_NUM; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; +constexpr int BUFFS_NUM_OUT_TR = 2; +constexpr int BUFF_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_DIM_X = TILE_DIM_X; +constexpr int BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr int BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr int BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr int BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; +constexpr int BUFF_IN_ELTS_NUM = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr int BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr int BUFF_OUT_TR_DIM_Y = BUFF_DIM_X; +constexpr int BUFF_OUT_TR_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr int PACK_SIZE = 8; +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; + +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; + +constexpr int THREADS_X_TR = TILE_DIM_X / 2; +constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; + +constexpr int ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; +constexpr int ITERATIONS_TR = SCALES_PER_TILE_Y / THREADS_Y_TR; +static_assert(ITERATIONS_TR >= 1 && "Number of transpose iterations should be >=1\0"); +static_assert((SCALES_PER_TILE_Y % THREADS_Y_TR == 0) && + "Partial transpose iterations are not supported\0"); + +constexpr int BUFF_OUT_IT_OFFSET = BUFF_OUT_TR_DIM_X / ITERATIONS_TR / STAGES; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); +static_assert(TunableConfig::CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; + +using IType = bf16; +using IType2 = typename ptx::FPx2; +using IType3D = IType[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; +using IType2x3D = IType2[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; +using ScalesType2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesTypeTr2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; +using RNG_t = typename transformer_engine::curanddx::detail::philox4x32_native_state< + NVTE_BUILD_NUM_PHILOX_ROUNDS>; + +template +struct SCALING_COEFFICIENT_TYPE {}; +template <> +struct SCALING_COEFFICIENT_TYPE { + using type = float; +}; +template <> +struct SCALING_COEFFICIENT_TYPE { + using type = bf16; +}; + +__device__ __forceinline__ float get_amax_of_pair(const IType2 pair) { + return static_cast(__hmax(__habs(pair.x), __habs(pair.y))); +} + +// Compute "correct" per-block encoding scaling factor +template +__device__ __forceinline__ SF_TYPE +compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_enc) { + NVTE_DEVICE_ERROR("Unsupported scaling-factor type. Only FP32 and BF16 are supported."); +} + +template <> +__device__ __forceinline__ float compute_nvfp4_scaling_coefficient( + const nvfp4_scale_t S_dec_block, const float S_enc) { + const float S_dec = 1.0f / S_enc; + const float scale_rcp = + fminf(1.0f / (static_cast(S_dec_block) * S_dec), detail::TypeExtrema::max); + return scale_rcp; +} + +template <> +__device__ __forceinline__ bf16 +compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_enc) { + const float scale_rcp = + fminf(S_enc / (static_cast(S_dec_block)), detail::TypeExtrema::max); + return static_cast(scale_rcp); +} + +template +__device__ __forceinline__ void colwise_scaling(const IType *__restrict__ sIn_ptr, + fp4e2m1x2 *__restrict__ sOut_tr_ptr, + nvfp4_scale_t *__restrict__ sSFcolwise_ptr, + const float S_enc_colwise, const int stage_Y, + const int stage_X, const int buff_in, + const int buff_out_tr, RNG_t &rng, + uint4 &random_uint4, int &rnd_idx) { + using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE::type; + + const auto &sIn2x = *reinterpret_cast(sIn_ptr); + auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); + auto &sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + const int warp = threadIdx.x / THREADS_PER_WARP; + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; + const int tid_X_colwise = thread_lane; + + const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; + const int thread_offset_X_colwise = tid_X_colwise * 2; + + const int in_thread_offset_Y = thread_offset_Y_colwise; + const int in_thread_offset_X = thread_offset_X_colwise / 2; + + const int out_tr_thread_offset_Y = thread_offset_X_colwise; + const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; + + const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; + const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; + + __align__(8) IType rIn[2][SCALE_DIM]; + // Read (cache) a pair of input elements (S2R). Find NVFP4-block AMAX + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const IType2 elt_pair = + ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); + rIn[0][i] = elt_pair.x; + rIn[1][i] = elt_pair.y; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + } + const float block_amax[2] = {static_cast(__habs(thread_amax_2x.x)), + static_cast(__habs(thread_amax_2x.y))}; +#pragma unroll + for (int w = 0; w < 2; ++w) { + const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); + + // Store scaling factors to SMEM buffer (R2S) + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; + + const scaling_coeff_type SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_colwise); + + // Scale elements + __align__(8) uint32_t rOut[SCALE_DIM / 8]; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 8; ++e) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][8 * e]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][8 * e + 4]); + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); + } + } + uint64_t &out_pack_16x = *reinterpret_cast(rOut); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], + out_pack_16x); + } +} + +template +__device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_ptr, + fp4e2m1x2 *__restrict__ sOut_ptr, + nvfp4_scale_t *__restrict__ sSFrowwise_ptr, + const float S_enc_rowwise, const int stage_Y, + const int stage_X, const int buff_in, + const int buff_out, RNG_t &rng, uint4 &random_uint4, + int &rnd_idx) { + using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE::type; + + const auto &sIn = *reinterpret_cast(sIn_ptr); + auto &sOut = *reinterpret_cast(sOut_ptr); + auto &sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; + + const int SF_thread_offset_rowwise_Y = tid_Y_rowwise; + const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; + + const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); + + const int stage_rowwise_scales_offset_Y = SF_thread_offset_rowwise_Y + stage_Y * TILE_DIM_Y; + const int stage_rowwise_scales_offset_X = + SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X; +#pragma unroll + for (int it = 0; it < ITERATIONS_NORMAL; ++it) { + const int it_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + __align__(16) IType2 rIn[WAVES][PACK_SIZE / 2]; + + // Read (cache) input elements (S2R). Find NVFP4-block AMAX + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + + // Load elements + __uint128_t &elts_8x = *reinterpret_cast<__uint128_t *>(&rIn[w]); + elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); + } + } + const float block_amax = get_amax_of_pair(thread_amax_2x); + + const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + const scaling_coeff_type SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); + + // Store scaling factors to SMEM buffer (R2S) + if (SF_storing_thread) { + const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; + } + +// Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][0]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][2]); + + uint32_t out_x8; + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); + } + + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + } + } +} + +template +__global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const scales_t_ptr, const float *noop, const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, const size_t cols, + const size_t scale_stride, const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG_t rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + int rnd_idx = 0; + + const bool leading_thread = (threadIdx.x == 0); + + constexpr int buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; + + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + + constexpr int in_mem = buff_size_aligned_in; + + constexpr int out_mem_rowwise_data = buff_size_aligned_out; + constexpr int out_mem_colwise_data = RETURN_TRANSPOSE ? buff_size_aligned_out_t : 0; + constexpr int out_mem_rowwise_scales = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + + IType *sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2 *sOut_ptr = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *sOut_tr_ptr = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + auto &sIn = *reinterpret_cast(sIn_ptr); + auto &sOut = *reinterpret_cast(sOut_ptr); + auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); + + nvfp4_scale_t *sSFrowwise_ptr = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *sSFcolwise_ptr = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + + auto &sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + auto &sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = + (amax_rowwise_ptr == nullptr) + ? 1.0f + : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + + const float S_enc_colwise = + (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + + __shared__ uint64_t workID_mbar; + __shared__ __uint128_t workID_response; + constexpr uint32_t workID_response_size = sizeof(workID_response); + static_assert(workID_response_size == 16); + + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + + // Coordinates of the first chunk (CTA) to process + int32_t ctaid_X = blockIdx.x; + int32_t ctaid_Y = blockIdx.y; + + // Initialize shared memory barriers with the number of threads participating in them + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::mbarrier_init(&workID_mbar, 1); + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + bool job_finished = false; + int buff_in = 0; + int buff_out = 0; + int buff_out_tr = 0; + int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; + int ctaid_parity = 0; + +// Prefetch input data only when processing the first chunk, +// which enables the one-iteration overlap throughout the entire kernel life +#pragma unroll + for (int stage = 0; stage < TunableConfig::PREFETCH_STAGES; ++stage) { + const int buff_in = stage; + const int stage_Y = stage / STAGES_X; + const int stage_X = stage % STAGES_X; + + const int stage_offset_Y = stage_Y * TILE_DIM_Y; + const int stage_offset_X = stage_X * TILE_DIM_X; + + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; + + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X + stage_offset_X; + + uint64_t *barrier = &IN_buff_readable_mbar[buff_in]; + if (leading_thread) { + uint64_t *dst = reinterpret_cast(&sIn[buff_in]); + const uint64_t *src = reinterpret_cast(&tensor_map_input); + + // Arrive on the barrier and tell how many bytes are expected to come in + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + barrier); + } + } + + while (!job_finished) { + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; + + const int block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X; + const int block_offset_X_tr = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + + const int chunk_rows = rows - block_offset_Y; + const int chunk_cols = cols - block_offset_X; + + const int scales_block_offset_Y_rowwise = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = ctaid_X * SCALES_PER_CHUNK_X; + const int scales_block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X; + const int scales_block_offset_X_tr = ctaid_Y * SCALES_PER_CHUNK_Y; + + if constexpr (TunableConfig::PERSISTENT) { + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); + ptx::try_cancel_cta(&workID_mbar, &workID_response); + } + } + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int stage_Y = stage / STAGES_X; + const int stage_X = stage % STAGES_X; + + const int stage_offset_Y = stage_Y * TILE_DIM_Y; + const int stage_offset_X = stage_X * TILE_DIM_X; + + if (stage == STAGES - TunableConfig::PREFETCH_STAGES) { + if constexpr (TunableConfig::PERSISTENT) { + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); + ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); + ctaid_parity ^= 1; + } else { + ctaid_X = -1; + ctaid_Y = -1; + } + if (ctaid_X == -1 && ctaid_Y == -1) { + job_finished = true; + } + } + + // Prefetch next stage Input data + if (!job_finished || (stage < STAGES - TunableConfig::PREFETCH_STAGES)) { + const int next_prefetch_buff = (buff_in + TunableConfig::PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + TunableConfig::PREFETCH_STAGES) % STAGES; + const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X; + const int next_prefetch_stage_X = next_prefetch_stage % STAGES_X; + + const int next_prefetch_stage_offset_Y = next_prefetch_stage_Y * TILE_DIM_Y; + const int next_prefetch_stage_offset_X = next_prefetch_stage_X * TILE_DIM_X; + + // Offsets change, because coordinates of the next "to-be-prefetched" CTA do also chage + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; + + const int global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; + const int global_offset_X = block_offset_X + next_prefetch_stage_offset_X; + + uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + if (leading_thread) { + uint64_t *dst = reinterpret_cast(&sIn[next_prefetch_buff]); + const uint64_t *src = reinterpret_cast(&tensor_map_input); + + // Arrive on the barrier and tell how many bytes are expected to come in + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + barrier); + } + ptx::fence_proxy_async_shared_cta(); + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + + // Wait for TMA transfer to have finished reading shared memory + // I.e. the OUT buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read(); + + // NVFP4 Quantization + rowwise_scaling( + sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, + rng, random_uint4, rnd_idx); + + if constexpr (RETURN_TRANSPOSE) { + colwise_scaling( + sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, stage_Y, stage_X, buff_in, + buff_out_tr, rng, random_uint4, rnd_idx); + } + + // Wait for shared memory writes to be visible to TMA engine + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine + + // Initiate TMA transfer to copy shared memory to global memory + if (leading_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X + stage_offset_X; + const int global_offset_Y_tr = block_offset_Y_tr + stage_offset_X; + const int global_offset_X_tr = block_offset_X_tr + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, + global_offset_Y, reinterpret_cast(&sOut[buff_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_tr, + global_offset_Y_tr, reinterpret_cast(&sOut_tr[buff_out_tr])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation + ptx::cp_async_bulk_commit_group(); + } + + buff_in = (buff_in + 1) % BUFFS_NUM_IN; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; + } // end of stages + + // Vectorized store of scaling factors (S2G) + { + // Rowwise + { + using ScalesVec = Vec; + // number of scales in X dimension of this chunk + const int count = min(SCALES_PER_CHUNK_X, chunk_cols / SCALE_DIM); + + for (size_t row = threadIdx.x; row < TunableConfig::CHUNK_DIM_Y; row += THREADS_NUM) { + const size_t row_global = scales_block_offset_Y_rowwise + row; + if (row_global < rows) { + ScalesVec &scales_vec = *reinterpret_cast(sSFrowwise[row]); + const size_t scale_idx_global = + row_global * scale_stride + scales_block_offset_X_rowwise; + scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); + } + } + } + + // Colwise + if constexpr (RETURN_TRANSPOSE) { + using ScalesVec = Vec; + // number of scales in Y dimension of this chunk + const int count = min(SCALES_PER_CHUNK_Y, chunk_rows / SCALE_DIM); + + for (size_t row_tr = threadIdx.x; row_tr < TunableConfig::CHUNK_DIM_X; + row_tr += THREADS_NUM) { + const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; + if (row_tr_global < cols) { + ScalesVec &scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); + const size_t scale_idx_global = + row_tr_global * scale_stride_t + scales_block_offset_X_tr; + scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count); + } + } + } + + if (!job_finished) { + // Ensures all reads from SFs buffer have completed and it's ready to be reused + __syncthreads(); + } + } + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + ptx::mbarrier_invalid(&workID_mbar); + } +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +#endif // FP4_TYPE_SUPPORTED +} // namespace quantize_transpose_tuned_kernel + +inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_transpose_tuned_kernel; + using namespace ptx; + + const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + + // If transposed output is allocated, return the transposed data + // Otherwise, it's not necesary to return the transposed data. + const bool return_transpose = output->has_columnwise_data(); + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + if (return_transpose) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated"); + } + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + NVTE_CHECK(rows % 32 == 0, + "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA + NVTE_CHECK(cols % 32 == 0, + "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA + + const int blocks_Y = DIVUP(rows, static_cast(TunableConfig::CHUNK_DIM_Y)); + const int blocks_X = DIVUP(cols, static_cast(TunableConfig::CHUNK_DIM_X)); + const dim3 grid(blocks_X, blocks_Y); + const int block_size = THREADS_NUM; + + const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride_transpose = + return_transpose ? output->columnwise_scale_inv.shape[1] : 0; + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + nvfp4_scale_t *const scales_transpose_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const float *const amax_colwise_ptr = + reinterpret_cast(output->columnwise_amax.dptr); + + const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; + const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + 4); + if (return_transpose) { + create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + } + + constexpr int buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + + constexpr int buff_size_scales = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + + const int in_mem = buff_size_aligned_in; + + const int out_data_mem = buff_size_aligned_out; + const int out_data_transpose_mem = return_transpose ? buff_size_aligned_out_t : 0; + const int out_scales_mem = buff_size_scales; + const int out_scales_transpose_mem = return_transpose ? buff_size_scales_transpose : 0; + + const int out_mem = out_data_mem + out_data_transpose_mem; + + const int dshmem_size = + in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, USE_FAST_MATH, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + }););); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index 66a3da55d..7be3d1bb4 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -8,7 +8,6 @@ #include #include -#include #include #include @@ -236,7 +235,7 @@ void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n ctx->grid_row_major.get(), ctx->d_desc.get())); const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE; - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue)); } @@ -273,46 +272,46 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N; const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N; - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, sizeof trans_a)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b, sizeof trans_b)); cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr, sizeof algo_attr)); const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32; if (is_fp8_dtype(a->dtype())) { NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, &a->scale_inv.dptr, sizeof(void*))); } if (is_fp8_dtype(b->dtype())) { NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, &b->scale_inv.dptr, sizeof(void*))); } if (is_fp8_dtype(d->dtype())) { NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER, &d->scale.dptr, sizeof(void*))); if (d->amax.dptr) { - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER, &d->amax.dptr, sizeof(void*))); } @@ -321,7 +320,7 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo // Might be set to ALLREDUCE before, need to OR with the new flags to set. cublasMpMatmulEpilogue_t epilogue{}; size_t size_read{}; - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorGetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue, &size_read)); NVTE_CHECK(size_read == sizeof epilogue); @@ -339,42 +338,42 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad}); it != flags_to_epilogue.end()) { epilogue = static_cast(epilogue | it->second); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue)); } if (bias && bias->data.dptr) { cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type, sizeof bias_type)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr, sizeof bias->data.dptr)); } if (pre_act_out && pre_act_out->data.dptr) { cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof aux_type)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER, &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd, sizeof ldd)); if (is_fp8_dtype(pre_act_out->dtype())) { NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER, &pre_act_out->scale.dptr, sizeof(void*))); if (pre_act_out->amax.dptr) { - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER, &pre_act_out->amax.dptr, sizeof(void*))); } @@ -382,12 +381,12 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo } if (comm_sm_count) { - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT, &comm_sm_count, sizeof comm_sm_count)); } - NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream)); + NVTE_CHECK_CUBLASMP(cublasMpSetStream(ctx->cublas_mp.get(), main_stream)); size_t wrksp_size_device{}; size_t wrksp_size_host{}; @@ -423,8 +422,14 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo std::vector workspace_host(wrksp_size_host); if (ctx->workspace_size < wrksp_size_device) { - nvshmem_free(ctx->workspace); - ctx->workspace = nvshmem_malloc(wrksp_size_device); + if (ctx->workspace) { + NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); + NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); + } + NVTE_CHECK_CUBLASMP( + cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device)); + NVTE_CHECK_CUBLASMP( + cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device)); ctx->workspace_size = wrksp_size_device; } @@ -473,7 +478,10 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { NVTE_API_CALL(nvte_comm_gemm_ctx_destroy); - nvshmemx_sync_all_on_stream(ctx->stream.get()); + if (ctx->workspace) { + NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); + NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); + } delete ctx; } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index eb6b86e07..a1a58b81c 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -124,10 +124,11 @@ bool has_mnnvl_fabric(int device_id) { NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetHandleByIndex_v2, device_id, &local_device); nvmlGpuFabricInfoV_t fabricInfo = {}; fabricInfo.version = nvmlGpuFabricInfo_v2; - fabricInfo.clusterUuid[0] = '\0'; NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetGpuFabricInfoV, local_device, &fabricInfo); NVTE_CALL_CHECK_CUDA_NVML(nvmlShutdown); - if (fabricInfo.state >= NVML_GPU_FABRIC_STATE_COMPLETED && fabricInfo.clusterUuid[0] != '\0') { + const unsigned char zero_uuid[NVML_GPU_FABRIC_UUID_LEN] = {0}; + if (fabricInfo.state == NVML_GPU_FABRIC_STATE_COMPLETED && + memcmp(fabricInfo.clusterUuid, zero_uuid, NVML_GPU_FABRIC_UUID_LEN) != 0) { mnnvl_fabric_support = true; } } diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index bcb01e2bb..426d476b7 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -91,6 +91,48 @@ __global__ void __launch_bounds__(kThreadsPerBlock) reinterpret_cast(ptr)[idx] = data.value; } +__global__ void __launch_bounds__(kThreadsPerBlock) + splits_to_offsets_kernel(const int64_t *__restrict__ first_dims, int64_t *__restrict__ output, + size_t num_tensors, int64_t logical_last_dim) { + __shared__ int64_t block_scan[kThreadsPerBlock]; + __shared__ int64_t chunk_prefix; + + const size_t tid = threadIdx.x; + if (tid == 0) { + output[0] = 0; + chunk_prefix = 0; + } + __syncthreads(); + + for (size_t chunk_start = 0; chunk_start < num_tensors; chunk_start += kThreadsPerBlock) { + const size_t idx = chunk_start + tid; + int64_t value = 0; + if (idx < num_tensors) { + value = first_dims[idx] * logical_last_dim; + } + block_scan[tid] = value; + __syncthreads(); + + // Inclusive scan in shared memory. + for (size_t offset = 1; offset < kThreadsPerBlock; offset <<= 1) { + const int64_t addend = (tid >= offset) ? block_scan[tid - offset] : 0; + __syncthreads(); + block_scan[tid] += addend; + __syncthreads(); + } + + if (idx < num_tensors) { + output[idx + 1] = chunk_prefix + block_scan[tid]; + } + __syncthreads(); + + if (tid == kThreadsPerBlock - 1) { + chunk_prefix += block_scan[tid]; + } + __syncthreads(); + } +} + } // namespace #define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \ @@ -120,6 +162,19 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream); MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream); } + +void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t num_tensors, + int64_t logical_last_dim, cudaStream_t stream) { + NVTE_API_CALL(nvte_splits_to_offsets); + NVTE_CHECK(output != nullptr, "Output pointer must be allocated."); + NVTE_CHECK(num_tensors > 0, "num_tensors must be greater than 0."); + NVTE_CHECK(first_dims != nullptr, "first_dims pointer must be allocated."); + NVTE_CHECK(logical_last_dim > 0, "logical_last_dim must be greater than 0."); + + splits_to_offsets_kernel<<<1, kThreadsPerBlock, 0, stream>>>(first_dims, output, num_tensors, + logical_last_dim); + NVTE_CHECK_CUDA(cudaGetLastError()); +} } // extern "C" #ifndef __HIP_PLATFORM_AMD__ diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index ed2674f42..6bf3aff0e 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -20,6 +20,12 @@ #endif #endif //#ifndef __HIP_PLATFORM_AMD__ +#ifndef NVTE_BUILD_NUM_PHILOX_ROUNDS +#define NVTE_BUILD_NUM_PHILOX_ROUNDS 10 +#endif +static_assert(NVTE_BUILD_NUM_PHILOX_ROUNDS > 0, + "NVTE_BUILD_NUM_PHILOX_ROUNDS must be a positive integer."); + #include #include #include @@ -48,6 +54,8 @@ namespace transformer_engine { std::string to_string(const DType type); std::string to_string(const NVTEScalingMode &mode); +inline std::string to_string_like(const DType &val) { return to_string(val); } + inline bool is_tensor_scaling(const NVTEScalingMode &mode) { return mode == NVTE_DELAYED_TENSOR_SCALING; } @@ -323,6 +331,9 @@ struct GroupedTensor { SimpleTensor columnwise_amax; SimpleTensor scale; // for FP8-DS only + NVTEScalingMode scaling_mode; + size_t num_tensors; + // Shape information (OPTIONAL - empty if dimension is uniform across all tensors) // first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim) // last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim) @@ -340,10 +351,29 @@ struct GroupedTensor { // Always 2D with positive dimensions NVTEShape logical_shape; - NVTEScalingMode scaling_mode; - size_t num_tensors; NVTEGroupedTensor nvte_tensor; + /*! \brief Whether scaling factors are in format expected by GEMM + * + * Only meaningful for MXFP8 and NVFP4. + */ + bool with_gemm_swizzled_scales = false; + + /*! Map from NVTEGroupedTensorParam to parameter sizes */ + static constexpr size_t attr_sizes[] = { + sizeof(NVTEBasicTensor), // kNVTEGroupedRowwiseData + sizeof(NVTEBasicTensor), // kNVTEGroupedColumnwiseData + sizeof(NVTEBasicTensor), // kNVTEGroupedScale + sizeof(NVTEBasicTensor), // kNVTEGroupedAmax + sizeof(NVTEBasicTensor), // kNVTEGroupedRowwiseScaleInv + sizeof(NVTEBasicTensor), // kNVTEGroupedColumnwiseScaleInv + sizeof(NVTEBasicTensor), // kNVTEGroupedColumnwiseAmax + sizeof(NVTEBasicTensor), // kNVTEGroupedFirstDims + sizeof(NVTEBasicTensor), // kNVTEGroupedLastDims + sizeof(NVTEBasicTensor), // kNVTEGroupedTensorOffsets + sizeof(uint8_t) // kNVTEGroupedWithGEMMSwizzledScales + }; + GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors) : data(), columnwise_data(), @@ -352,13 +382,14 @@ struct GroupedTensor { amax(), columnwise_amax(), scale(), + scaling_mode(scaling_mode), num_tensors(num_tensors), first_dims(nullptr, std::vector{0}, DType::kInt64), last_dims(nullptr, std::vector{0}, DType::kInt64), tensor_offsets(nullptr, std::vector{0}, DType::kInt64), logical_shape(nvte_make_shape(nullptr, 1)), - scaling_mode(scaling_mode), - nvte_tensor(0) {} + nvte_tensor(0), + with_gemm_swizzled_scales(false) {} explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; } @@ -410,6 +441,7 @@ struct GroupedTensor { num_tensors = 0; scaling_mode = NVTE_DELAYED_TENSOR_SCALING; nvte_tensor = 0; + with_gemm_swizzled_scales = false; } }; @@ -652,125 +684,149 @@ struct TypeInfo { #define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing #endif -#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kByte: { \ - using type = unsigned char; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt16: { \ - using type = int16_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt32: { \ - using type = int32_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt64: { \ - using type = int64_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E8M0: { \ - using type = byte; \ - { __VA_ARGS__ } \ - } break; \ - SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kByte: { \ + using type = unsigned char; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt16: { \ + using type = int16_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt32: { \ + using type = int32_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt64: { \ + using type = int64_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E8M0: { \ + using type = byte; \ + { __VA_ARGS__ } \ + } break; \ + SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Byte, Int16, Int32, Int64, Float32, " \ + "Float16, BFloat16, Float8E4M3, Float8E5M2, " \ + "Float8E8M0, Float4E2M1."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16, " \ + "Float8E4M3, Float8E5M2."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported output dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16, " \ + "Float8E5M2, Float8E4M3."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, BFloat16."); \ } // Add a pack_size argument to select the packed type for FP4 @@ -782,80 +838,90 @@ struct TypeInfo { { __VA_ARGS__ } \ } break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected: Float4E2M1."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float8E5M2, Float8E4M3."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: \ - case DType::kFloat8E4M3: { \ - NVTE_ERROR("FP8 type not instantiated for input."); \ - } break; \ - case DType::kFloat4E2M1: { \ - NVTE_ERROR("FP4 type not instantiated for input."); \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: \ + case DType::kFloat8E4M3: { \ + NVTE_ERROR("FP8 dtype ", to_string(static_cast(dtype)), \ + " is not instantiated for input. " \ + "Expected one of: Float32, Float16, BFloat16."); \ + } break; \ + case DType::kFloat4E2M1: { \ + NVTE_ERROR( \ + "FP4 dtype Float4E2M1 is not instantiated " \ + "for input. Expected one of: Float32, Float16, " \ + "BFloat16."); \ + } break; \ + default: \ + NVTE_ERROR("Unsupported input dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat16: { \ - using type = fp16; \ - __VA_ARGS__; \ - break; \ - } \ - case DType::kBFloat16: { \ - using type = bf16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - NVTE_ERROR("Invalid type for 16 bit."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat16: { \ + using type = fp16; \ + __VA_ARGS__; \ + break; \ + } \ + case DType::kBFloat16: { \ + using type = bf16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + NVTE_ERROR("Unsupported 16-bit dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float16, BFloat16."); \ } -#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ - switch (SCALE_DIM) { \ - case 1: { \ - constexpr size_t DIM = 1; \ - { __VA_ARGS__ } \ - } break; \ - case 32: { \ - constexpr size_t DIM = 32; \ - { __VA_ARGS__ } \ - } break; \ - default: { \ - NVTE_ERROR("Invalid size of the MX scaling factor."); \ - } \ +#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ + switch (SCALE_DIM) { \ + case 1: { \ + constexpr size_t DIM = 1; \ + { __VA_ARGS__ } \ + } break; \ + case 32: { \ + constexpr size_t DIM = 32; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported MX scaling factor dimension ", SCALE_DIM, \ + ". Expected one of: 1, 32."); \ + } \ } #define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index fde0d3892..6a136c67e 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -15,73 +15,87 @@ #include "fused_attn_fp8.h" #include "utils.h" -namespace { -// Helper function to create a tensor view with modified shape and optional pointer offset -transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *source, - const std::vector &shape, - size_t offset_bytes = 0) { - transformer_engine::Tensor view = *source; - if (offset_bytes > 0) { - view.data.dptr = static_cast(static_cast(source->data.dptr) + offset_bytes); - } - view.data.shape = shape; - view.nvte_tensor = 0; // Mark as unmanaged/local tensor view - return view; -} - -// Helper function to calculate stride in bytes for packed QKV tensor unpacking -size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, - size_t h, size_t d) { - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (transformer_engine::typeToNumBits(dtype) * h * d) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (transformer_engine::typeToNumBits(dtype) * d) / 8; +namespace transformer_engine { + +std::string to_string(NVTE_QKV_Layout layout) { + switch (layout) { + case NVTE_SB3HD: + return "NVTE_SB3HD"; + case NVTE_SBH3D: + return "NVTE_SBH3D"; + case NVTE_SBHD_SB2HD: + return "NVTE_SBHD_SB2HD"; + case NVTE_SBHD_SBH2D: + return "NVTE_SBHD_SBH2D"; + case NVTE_SBHD_SBHD_SBHD: + return "NVTE_SBHD_SBHD_SBHD"; + case NVTE_BS3HD: + return "NVTE_BS3HD"; + case NVTE_BSH3D: + return "NVTE_BSH3D"; + case NVTE_BSHD_BS2HD: + return "NVTE_BSHD_BS2HD"; + case NVTE_BSHD_BSH2D: + return "NVTE_BSHD_BSH2D"; + case NVTE_BSHD_BSHD_BSHD: + return "NVTE_BSHD_BSHD_BSHD"; + case NVTE_T3HD: + return "NVTE_T3HD"; + case NVTE_TH3D: + return "NVTE_TH3D"; + case NVTE_THD_T2HD: + return "NVTE_THD_T2HD"; + case NVTE_THD_TH2D: + return "NVTE_THD_TH2D"; + case NVTE_THD_THD_THD: + return "NVTE_THD_THD_THD"; + case NVTE_SBHD_BSHD_BSHD: + return "NVTE_SBHD_BSHD_BSHD"; + case NVTE_BSHD_SBHD_SBHD: + return "NVTE_BSHD_SBHD_SBHD"; + case NVTE_THD_BSHD_BSHD: + return "NVTE_THD_BSHD_BSHD"; + case NVTE_THD_SBHD_SBHD: + return "NVTE_THD_SBHD_SBHD"; + case NVTE_Paged_KV_BSHD_BSHD_BSHD: + return "NVTE_Paged_KV_BSHD_BSHD_BSHD"; + case NVTE_Paged_KV_BSHD_SBHD_SBHD: + return "NVTE_Paged_KV_BSHD_SBHD_SBHD"; + case NVTE_Paged_KV_SBHD_BSHD_BSHD: + return "NVTE_Paged_KV_SBHD_BSHD_BSHD"; + case NVTE_Paged_KV_SBHD_SBHD_SBHD: + return "NVTE_Paged_KV_SBHD_SBHD_SBHD"; + case NVTE_Paged_KV_THD_BSHD_BSHD: + return "NVTE_Paged_KV_THD_BSHD_BSHD"; + case NVTE_Paged_KV_THD_SBHD_SBHD: + return "NVTE_Paged_KV_THD_SBHD_SBHD"; + default: + return "UNKNOWN_QKV_LAYOUT(" + std::to_string(static_cast(layout)) + ")"; } - return stride; } -// Helper function to determine unpacked shape for QKV packed tensor -std::vector calculate_qkv_unpacked_shape(const transformer_engine::Tensor *qkv_tensor, - size_t h, size_t d) { - std::vector unpacked_shape; - if (qkv_tensor->data.shape.size() == 4) { - // T3HD or TH3D (4D) -> THD (3D): remove dimension "3" at position 1 - unpacked_shape = {qkv_tensor->data.shape[0], h, d}; - } else { - // BS3HD/SB3HD or BSH3D/SBH3D (5D) -> BSHD/SBHD (4D): remove dimension "3" at position 2 - unpacked_shape = {qkv_tensor->data.shape[0], qkv_tensor->data.shape[1], h, d}; +std::string to_string(NVTE_QKV_Format format) { + switch (format) { + case NVTE_SBHD: + return "NVTE_SBHD"; + case NVTE_BSHD: + return "NVTE_BSHD"; + case NVTE_THD: + return "NVTE_THD"; + case NVTE_BSHD_2SBHD: + return "NVTE_BSHD_2SBHD"; + case NVTE_SBHD_2BSHD: + return "NVTE_SBHD_2BSHD"; + case NVTE_THD_2BSHD: + return "NVTE_THD_2BSHD"; + case NVTE_THD_2SBHD: + return "NVTE_THD_2SBHD"; + default: + return "UNKNOWN_QKV_FORMAT(" + std::to_string(static_cast(format)) + ")"; } - return unpacked_shape; } -// Helper function to calculate stride for packed KV tensor unpacking -size_t calculate_kv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, - size_t h_kv, size_t d) { - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (transformer_engine::typeToNumBits(dtype) * h_kv * d) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (transformer_engine::typeToNumBits(dtype) * d) / 8; - } - return stride; -} - -// Helper function to determine unpacked shape for KV packed tensor -std::vector calculate_kv_unpacked_shape(const transformer_engine::Tensor *kv_tensor, - NVTE_QKV_Layout_Group layout_group, - NVTE_QKV_Format kv_format, size_t t_kv, size_t h_kv, - size_t d) { - std::vector unpacked_kv_shape; - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - unpacked_kv_shape = {t_kv, h_kv, d}; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD || - layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - unpacked_kv_shape = {kv_tensor->data.shape[0], kv_tensor->data.shape[1], h_kv, d}; - } - return unpacked_kv_shape; -} -} // namespace +} // namespace transformer_engine // map NVTE_QKV_Layout to NVTE_QKV_Layout_Group NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { @@ -118,7 +132,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), + " in nvte_get_qkv_layout_group."); } } @@ -158,7 +173,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), + " in nvte_get_qkv_format."); } } @@ -177,7 +193,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), + " in nvte_get_q_format."); } } @@ -196,7 +213,8 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), + " in nvte_get_kv_format."); } } @@ -206,7 +224,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph) { + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -406,9 +424,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (window_size_right == -1 || window_size_right == 0)) || // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} (cudnn_runtime_version >= 90200 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ((window_size_left == -1 && window_size_right == -1 && + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || + ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && @@ -418,12 +438,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} (cudnn_runtime_version >= 90600 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + ((window_size_left >= 0 || window_size_left == -1) && + (window_size_right >= 0 || window_size_right == -1) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && // TODO(cyang): fix bug for BRCM + cross-attention on sm100 (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && cudnn_runtime_version <= 90700) || cudnn_runtime_version > 90700)))) || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && @@ -440,7 +462,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.13.1+: vanilla, off-by-one, learnable (cudnn_runtime_version >= 91301 || (cudnn_runtime_version < 91301 && - softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) { + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && + // determinism on Blackwell + // pre-9.18.1: fwd: deterministic; bwd: non-deterministic + // 9.18.1+: fwd: deterministic; bwd: non-deterministic/deterministic + (sm_arch_ < 100 || + (sm_arch_ >= 100 && (!is_training || + (is_training && !deterministic && + (dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) || + (is_training && deterministic && cudnn_runtime_version >= 91801 && + dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { @@ -503,585 +534,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( return backend; } -// NVTE fused attention FWD with packed QKV -// DEPRECATED: This API is deprecated. -// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); - using namespace transformer_engine; - - const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); - const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_QKV = convertNVTETensorCheck(QKV); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); - Tensor *input_output_S = convertNVTETensorCheck(S); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensor(workspace); - - auto ndim = input_QKV->data.shape.size(); - size_t b = input_cu_seqlens->data.shape[0] - 1; - size_t h = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); - } - size_t d = input_QKV->data.shape[ndim - 1]; - size_t t = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - t = input_QKV->data.shape[0]; - } - - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); - - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, - cuda_graph); - - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { -#if (CUDNN_VERSION >= 8901) - // Unpack QKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - fused_attn_max_512_fwd(b, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, - input_cu_seqlens, input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { -#if (CUDNN_VERSION >= 8900) - // Unpack QKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - fused_attn_arbitrary_seqlen_fwd( - b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, - return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, - wkspace, stream, handle); -#else - NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { -#if (CUDNN_VERSION >= 8900) - // Unpack QKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, - input_cu_seqlens, input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); -#endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); - } -} -// NVTE fused attention BWD with packed QKV -// DEPRECATED: This API is deprecated. -// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. -void nvte_fused_attn_bwd_qkvpacked( - const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, - NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, - NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); - using namespace transformer_engine; - - const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); - const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); - const Tensor *input_QKV = convertNVTETensorCheck(QKV); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - const Tensor *input_S = convertNVTETensorCheck(S); - Tensor *input_output_dP = convertNVTETensorCheck(dP); - Tensor *output_dQKV = convertNVTETensorCheck(dQKV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); - Tensor *wkspace = convertNVTETensor(workspace); - - auto ndim = input_QKV->data.shape.size(); - size_t b = input_cu_seqlens->data.shape[0] - 1; - size_t h = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); - } - size_t d = input_QKV->data.shape[ndim - 1]; - size_t t = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - t = input_QKV->data.shape[0]; - } - - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); - - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph); - - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { -#if (CUDNN_VERSION >= 8901) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - - // Unpack QKV and dQKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V and dQ, dK, dV - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); - Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); - Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - - fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, output_S, - &dQ_view, &dK_view, &dV_view, output_dBias, input_cu_seqlens, - input_cu_seqlens, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { -#if (CUDNN_VERSION >= 8900) - size_t i = 0; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_Bias, *input_SoftmaxOffset; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - - // Unpack QKV and dQKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V and dQ, dK, dV - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); - Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); - Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - - fused_attn_arbitrary_seqlen_bwd( - b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, - &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, - &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); -#else - const char *err_msg = - "cuDNN 8.9.0 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; - NVTE_ERROR(err_msg); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { -#if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - - // Unpack QKV and dQKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V and dQ, dK, dV - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); - Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); - Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - - fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, - input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, - input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, - handle); -#else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); -#endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); - } -} -// NVTE fused attention FWD with packed KV -// DEPRECATED: This API is deprecated. -// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. -void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, - NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); - using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); - const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_KV = convertNVTETensorCheck(KV); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); - Tensor *input_output_S = convertNVTETensorCheck(S); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensor(workspace); - - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - auto ndim = input_Q->data.shape.size(); - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; - auto ndim_kv = input_KV->data.shape.size(); - size_t h_kv = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); - } - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_KV->data.shape[0]; - } - int64_t num_pages_k = 0; - int64_t num_pages_v = 0; - int64_t page_size_k = 0; - int64_t page_size_v = 0; - int64_t max_pages_per_seq_k = 0; - int64_t max_pages_per_seq_v = 0; - if (input_page_table_k->data.dptr != nullptr) { - max_pages_per_seq_k = input_page_table_k->data.shape[1]; - } - if (input_page_table_v->data.dptr != nullptr) { - max_pages_per_seq_v = input_page_table_v->data.shape[1]; - } - if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { - num_pages_k = input_KV->data.shape[0]; - page_size_k = input_KV->data.shape[1]; - num_pages_v = num_pages_v; - page_size_v = page_size_v; - } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { - num_pages_k = input_KV->data.shape[1]; - page_size_k = input_KV->data.shape[0]; - num_pages_v = num_pages_v; - page_size_v = page_size_v; - } - } - - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_KV->data.dtype); - - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, - return_max_logit, cuda_graph); - - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { -#if (CUDNN_VERSION >= 8901) - // Unpack KV and call the non-packed function - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { -#if (CUDNN_VERSION >= 8903) - // Unpack KV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR( - "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { -#if (CUDNN_VERSION >= 8900) - // Unpack KV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); -#endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); - } -} -// NVTE fused attention BWD with packed KV -// DEPRECATED: This API is deprecated. -// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. -void nvte_fused_attn_bwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); - using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_KV = convertNVTETensorCheck(KV); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - const Tensor *input_S = convertNVTETensorCheck(S); - Tensor *input_output_dP = convertNVTETensorCheck(dP); - Tensor *output_dQ = convertNVTETensorCheck(dQ); - Tensor *output_dKV = convertNVTETensorCheck(dKV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); - Tensor *wkspace = convertNVTETensor(workspace); - - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - auto ndim = input_Q->data.shape.size(); - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; - auto ndim_kv = input_KV->data.shape.size(); - size_t h_kv = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); - } - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_KV->data.shape[0]; - } - - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_KV->data.dtype); - - NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, - softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - d, window_size_left, window_size_right, false, cuda_graph); - - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { -#if (CUDNN_VERSION >= 8901) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - - // Unpack KV and dKV and call the non-packed function - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); - Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); - - fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_dO, output_S, - output_dQ, &dK_view, &dV_view, output_dBias, input_cu_seqlens_q, - input_cu_seqlens_kv, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { -#if (CUDNN_VERSION >= 8903) - size_t i = 0; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_Bias, *input_SoftmaxOffset; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - - // Unpack KV and dKV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - // Create tensor views for dK, dV - Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); - Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); - - fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, - input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, - output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, - wkspace, stream, handle); -#else - const char *err_msg = - "cuDNN 8.9.3 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; - NVTE_ERROR(err_msg); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { -#if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - - // Unpack KV and dKV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); - Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); - - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, - &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, - stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); -#endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); - } -} // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, @@ -1094,8 +546,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1166,7 +618,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -1183,13 +635,14 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " + "\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) @@ -1215,8 +668,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1262,7 +716,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph); + cuda_graph, deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -1289,8 +743,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, - input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, + input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -1305,9 +759,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, + qkv_layout, bias_type, attn_mask_type, deterministic, input_Q, input_K, + input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, + output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index d3746fc04..eb2ebcff3 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -52,12 +52,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, - int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, - bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, - void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, + bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, + void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, + void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -75,6 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; + bottom_right_diagonal = false; } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); @@ -120,6 +122,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( max_pages_per_seq_v, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, is_training, dropout_probability, @@ -129,6 +133,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, true, tensorType, cudnn_frontend::DataType_t::NOT_SET, @@ -248,23 +253,30 @@ void fused_attn_arbitrary_seqlen_fwd_impl( fe::graph::SDPA_attributes sdpa_options; sdpa_options = fe::graph::SDPA_attributes() .set_name("flash_attention") - .set_is_inference(false) .set_generate_stats(generate_stats) .set_causal_mask(is_causal) .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_options.set_diagonal_band_right_bound(window_size_right); + } sdpa_options.set_alibi_mask(is_alibi); if (is_bias) { - bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + bias = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_options.set_bias(bias); } @@ -540,12 +552,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, - void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, - void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, - void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, + int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose, + void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, + void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, + void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -563,6 +576,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; + bottom_right_diagonal = false; } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); @@ -612,6 +626,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( 0, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, true, dropout_probability, @@ -621,6 +637,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, tensorType, cudnn_frontend::DataType_t::NOT_SET, @@ -781,9 +798,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_max_total_seq_len_kv(s_kv); } + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } if (cudnn_runtime_version >= 90000) { sdpa_backward_options.set_deterministic_algorithm(deterministic); @@ -792,19 +817,20 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_alibi_mask(is_alibi); if (is_bias) { - bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); - dBias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dBias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + bias = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_backward_options.set_bias(bias); - // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] - // are not supported for dbias calculation but they are - // supported for forward bias calculation - if ((bias_b == 1) && (bias_h == h)) { + // bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation + // bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18 + if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) { + dBias = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("dBias") + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_backward_options.set_dbias(dBias); } } @@ -955,10 +981,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_bias) { variant_pack[bias] = devPtrBias; - if ((bias_b == 1) && (bias_h == h)) { + if (dBias != nullptr) { variant_pack[dBias] = devPtrdBias; - } else { - variant_pack[dBias] = nullptr; } } @@ -1044,8 +1068,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, @@ -1064,10 +1088,14 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrBias = nullptr; size_t bias_b = 0; size_t bias_h = 0; + size_t bias_sq = 0; + size_t bias_skv = 0; if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { devPtrBias = input_Bias->data.dptr; bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; + bias_sq = input_Bias->data.shape[2]; + bias_skv = input_Bias->data.shape[3]; } void *devPtrSoftmaxOffset = nullptr; if (softmax_type != NVTE_VANILLA_SOFTMAX) { @@ -1133,7 +1161,7 @@ void fused_attn_arbitrary_seqlen_fwd( if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; + output_bias->data.shape = {bias_b, bias_h, bias_sq, bias_skv}; output_bias->data.dtype = QKV_type; } @@ -1178,13 +1206,13 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, - devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, + is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, + devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1206,13 +1234,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, + Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; @@ -1224,11 +1253,15 @@ void fused_attn_arbitrary_seqlen_bwd( void *devPtrdBias = nullptr; size_t bias_b = 0; size_t bias_h = 0; + size_t bias_sq = 0; + size_t bias_skv = 0; if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { devPtrBias = input_Bias->data.dptr; devPtrdBias = output_dBias->data.dptr; bias_b = output_dBias->data.shape[0]; bias_h = output_dBias->data.shape[1]; + bias_sq = output_dBias->data.shape[2]; + bias_skv = output_dBias->data.shape[3]; } size_t max_batch_size = 0; @@ -1271,11 +1304,11 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, - devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, + devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index c34eae4e6..4dd7f3d1d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -25,8 +25,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, @@ -37,13 +37,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, + Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 3630041cc..80e64370f 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1671,6 +1671,8 @@ void fused_attn_fp8_fwd_impl_v1( bool is_dropout = (is_training && dropout_probability != 0.0f); auto bias_b = b; auto bias_h = h; + auto bias_sq = s_q; + auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || @@ -1697,6 +1699,8 @@ void fused_attn_fp8_fwd_impl_v1( 0, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, is_training, dropout_probability, @@ -1707,6 +1711,7 @@ void fused_attn_fp8_fwd_impl_v1( 0, 0, true, + true, qkv_tensor_type, o_tensor_type, cudnn_frontend::DataType_t::NOT_SET, @@ -1809,7 +1814,7 @@ void fused_attn_fp8_fwd_impl_v1( fe::graph::SDPA_fp8_attributes sdpa_options; sdpa_options = fe::graph::SDPA_fp8_attributes() .set_name("sdpa_fp8") - .set_is_inference(false) + .set_generate_stats(true) .set_causal_mask(is_causal) .set_attn_scale(attn_scale); @@ -1817,8 +1822,8 @@ void fused_attn_fp8_fwd_impl_v1( // if (is_bias) { // bias = mha_graph->tensor(fe::graph::Tensor_attributes() // .set_name("bias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + // .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); // sdpa_options.set_bias(bias); // } @@ -1977,13 +1982,13 @@ void fused_attn_fp8_fwd_impl_v1( void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, - void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, - void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, - void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, - void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, - void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, - void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, + NVTE_Mask_Type mask_type, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, + void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, + void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, + void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, + void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, + void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size, @@ -1998,6 +2003,9 @@ void fused_attn_fp8_bwd_impl_v1( bool is_dropout = (dropout_probability != 0.0f); auto bias_b = b; auto bias_h = h; + const auto cudnn_runtime_version = cudnnGetVersion(); + auto bias_sq = s_q; + auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || @@ -2026,6 +2034,8 @@ void fused_attn_fp8_bwd_impl_v1( 0, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, true, dropout_probability, @@ -2035,7 +2045,8 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, - false, + true, + deterministic, qkv_tensor_type, o_tensor_type, do_tensor_type, @@ -2192,21 +2203,24 @@ void fused_attn_fp8_bwd_impl_v1( // if (is_bias) { // bias = mha_graph->tensor(fe::graph::Tensor_attributes() // .set_name("bias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + // .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); // dBias = mha_graph->tensor(fe::graph::Tensor_attributes() // .set_name("dBias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + // .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); // sdpa_backward_options.set_bias(bias); - // // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] - // // are not supported for dbias calculation but they are - // // supported for forward bias calculation - // if ((bias_b == 1) && (bias_h == h)) { - // sdpa_backward_options.set_dbias(dBias); - // } + // bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation + // bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18 + // if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) { + // sdpa_backward_options.set_dbias(dBias); + // } // } + if (cudnn_runtime_version >= 91900) { + sdpa_backward_options.set_deterministic_algorithm(deterministic); + } + if (is_padding) { seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("seq_q") @@ -2510,11 +2524,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q, - const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, - const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv, - const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, - const Tensor* output_dK, const Tensor* output_dV, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic, + const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, + const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, + const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, + const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -2572,11 +2586,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + p_dropout, qkv_layout, bias_type, mask_type, deterministic, devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, + devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, + devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, + devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index a1a932fdf..225e700ef 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -28,11 +28,11 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv, - const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, - const Tensor *output_dK, const Tensor *output_dV, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, + const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, + const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index d85d62cf2..3ede91b8d 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -539,11 +539,13 @@ size_t get_max_batch_size(size_t batch_size) { // batch size is expected to be 10s-100s // b = 1, ..., 32 -> max_b = 32 // b = 33, ..., 512 -> max_b = next power of 2 - // otherwise -> max_b = b + // b = 513, ... -> max_b = increment by 512 if (log2_b <= 5) { max_b = 32; } else if (log2_b <= 9) { max_b = pow(2, log2_b); + } else { + max_b = (batch_size + 511) / 512 * 512; } return max_b; } diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 7d23bb5c5..08a56cda6 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -101,6 +101,8 @@ struct FADescriptor_v1 { std::int64_t max_pages_per_seq_v; std::int64_t bias_b; std::int64_t bias_h; + std::int64_t bias_sq; + std::int64_t bias_skv; float attnScale; bool isTraining; float dropoutProbability; @@ -110,6 +112,7 @@ struct FADescriptor_v1 { NVTE_Softmax_Type softmax_type; std::int64_t window_size_left; std::int64_t window_size_right; + bool bottom_right_diagonal; bool deterministic; cudnn_frontend::DataType_t qkv_tensor_type; cudnn_frontend::DataType_t o_tensor_type; @@ -119,15 +122,17 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, - attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, - window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, - o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) < + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, + bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, + dqkv_tensor_type, generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, - rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, - rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, - rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, + rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, + rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, + rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, + rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); } diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index e787b31c8..3b798406b 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -273,13 +273,17 @@ void log_fused_attn_config( std::cout<data.shape.size(); - size_t b = input_cu_seqlens->data.shape[0] - 1; - size_t h = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); - } - size_t d = input_QKV->data.shape[ndim - 1]; - - const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); - - //log the fused attn config at NVTE common level - log_fused_attn_config( - __FUNCTION__, - QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, b, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); - - // fix the incompatible window size from upstream frameworks pytorch/jax - std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); - - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, - cuda_graph); - - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { - fused_attn_ck_fwd_qkvpacked( - b, h, max_seqlen, d, - is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - window_size_left, window_size_right, - input_QKV, input_Bias, - output_O, Aux_CTX_Tensors, - input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, - wkspace, - stream); - } else if(fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_AOTriton){ - fused_attn_aotriton_fwd_qkvpacked( - b, h, max_seqlen, d, - is_training, attn_scale, dropout, - window_size_left, window_size_right, - qkv_layout, bias_type, attn_mask_type, - input_QKV, - output_O, Aux_CTX_Tensors, - input_cu_seqlens, - input_rng_state, - wkspace, - stream); - }else{ - NVTE_ERROR("Invalid combination of data type and sequence length for rocm fused attention. \n"); - } -} - -// NVTE fused attention BWD with packed QKV -void nvte_fused_attn_bwd_qkvpacked( - const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, - NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, - NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); - using namespace transformer_engine; - - const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); - const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); - const Tensor *input_QKV = convertNVTETensorCheck(QKV); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - Tensor *output_dQKV = convertNVTETensorCheck(dQKV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *wkspace = convertNVTETensorCheck(workspace); - - // auxiliary tensors - const Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); //softmax lse - //extract the saved rng state from aux_ctx_tensor - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor *input_Bias = nullptr; - - auto ndim = input_QKV->data.shape.size(); - size_t b = input_cu_seqlens->data.shape[0] - 1; - size_t h = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); - } - size_t d = input_QKV->data.shape[ndim - 1]; - - const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); - - //log the fused attn config at NVTE common level - log_fused_attn_config( - __FUNCTION__, - QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, b, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); - - // fix the incompatible window size from upstream frameworks pytorch/jax - std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); - - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph); - - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { - if((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)){ - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } - fused_attn_ck_bwd_qkvpacked( - b, h, max_seqlen, d, - attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, - window_size_left, window_size_right, - false, // TODO: enable deterministic after CK team show us how - input_QKV, input_O, input_dO, input_Bias, output_S, - output_dQKV, output_dBias, - input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, - wkspace, - stream); - } else if(fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_AOTriton){ - //currently aotriton is deterministic - fused_attn_aotriton_bwd_qkvpacked( - b, h, max_seqlen, d, - attn_scale, dropout, - window_size_left, window_size_right, - qkv_layout, bias_type, attn_mask_type, - input_QKV, input_O, input_dO, output_S, - output_dQKV, - input_cu_seqlens, - input_rng_state, - wkspace, - stream); - }else{ - NVTE_ERROR("Invalid combination of data type and sequence length for rocm fused attention. \n"); - } -} - -// NVTE fused attention FWD with packed KV -void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, - NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { - - NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); - using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); - const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_KV = convertNVTETensorCheck(KV); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensorCheck(workspace); - - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - auto ndim = input_Q->data.shape.size(); - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; - auto ndim_kv = input_KV->data.shape.size(); - size_t h_kv = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); - } - - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_KV->data.dtype); - - //log the fused attn config at NVTE common level - log_fused_attn_config( - __FUNCTION__, - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, b, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); - - // fix the incompatible window size from upstream frameworks pytorch/jax - std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); - - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, - return_max_logit, cuda_graph); - - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { - fused_attn_ck_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, - window_size_left, window_size_right, - input_Q, input_KV, input_Bias, - output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, - input_cu_seqlens_kv, - input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, - input_rng_state, - wkspace, - stream); - } else if(fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_AOTriton){ - fused_attn_aotriton_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - is_training, attn_scale, dropout, - window_size_left, window_size_right, - qkv_layout, bias_type, attn_mask_type, - input_Q, input_KV, - output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, - input_cu_seqlens_kv, - input_rng_state, - wkspace, - stream); - }else{ - NVTE_ERROR("Invalid combination of data type and sequence length for rocm fused attention. \n"); - } -} - -// NVTE fused attention BWD with packed KV -void nvte_fused_attn_bwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); - using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_KV = convertNVTETensorCheck(KV); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - Tensor *output_dQ = convertNVTETensorCheck(dQ); - Tensor *output_dKV = convertNVTETensorCheck(dKV); - Tensor *wkspace = convertNVTETensorCheck(workspace); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - - // auxiliary tensors (to be propagated to the backward pass later) - const Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); //softmax lse - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor *input_Bias = nullptr; - - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - auto ndim = input_Q->data.shape.size(); - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; - auto ndim_kv = input_KV->data.shape.size(); - size_t h_kv = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_bwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); - } - - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_KV->data.dtype); - - //log the fused attn config at NVTE common level - log_fused_attn_config( - __FUNCTION__, - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, b, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); - - // fix the incompatible window size from upstream frameworks pytorch/jax - std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); - - NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, - softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - d, window_size_left, window_size_right, false, cuda_graph); - - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } - fused_attn_ck_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, - window_size_left, window_size_right, - false, // TODO: enable deterministic after CK team show us how - input_Q, input_KV, input_O, input_dO, input_Bias, - output_S, - output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, - input_cu_seqlens_kv, - input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, - input_rng_state, - wkspace, - stream); - } else if(fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_AOTriton){ - // currently aotriton is deterministic - fused_attn_aotriton_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - attn_scale, dropout, - window_size_left, window_size_right, - qkv_layout, bias_type, attn_mask_type, - input_Q, input_KV, input_O, input_dO, - output_S, - output_dQ, output_dKV, - input_cu_seqlens_q, - input_cu_seqlens_kv, - input_rng_state, - wkspace, - stream); - }else{ - NVTE_ERROR("Invalid combination of data type and sequence length for rocm fused attention. \n"); - } -} - // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, @@ -720,8 +361,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -760,7 +401,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd( @@ -806,8 +447,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -852,7 +494,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph); + cuda_graph, deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -907,3 +549,5 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se using namespace transformer_engine::fused_attn_rocm; PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream); } + +#pragma GCC diagnostic pop diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 9a0161ca5..33557d214 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -495,328 +495,6 @@ void fused_attn_aotriton_bwd_impl( } // namespace fused_attn_rocm using namespace transformer_engine::fused_attn_rocm; -void fused_attn_aotriton_fwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, - bool is_training, float attn_scale, float dropout, - int32_t window_left, int32_t window_right, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const Tensor* input_QKV, - Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor* input_cu_seqlens, - const Tensor* rng_state, - Tensor *workspace, - cudaStream_t stream){ - -#ifdef USE_FUSED_ATTN_AOTRITON - const DType QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - // determine the stride based on qkv layout - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = nvte_dtype_size(QKV_type) * h * d; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = nvte_dtype_size(QKV_type) * d; - } - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - //save the input rng state to Aux_CTX_Tensors - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {b, h, max_seqlen, 1}; - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - size_t workspace_size = 0; - - fused_attn_aotriton_fwd_impl( - b, h, h, max_seqlen, max_seqlen, d, - is_training, attn_scale, dropout, - window_left, window_right, - qkv_layout, - bias_type, attn_mask_type, - devPtrQ, devPtrK, devPtrV, - devPtrS, devPtrO, - reinterpret_cast(rng_state->data.dptr), - reinterpret_cast(rng_state->data.dptr) + 1, - input_cu_seqlens->data.dptr, input_cu_seqlens->data.dptr, - nvte_to_aotriton_dtype(QKV_type), - workspace->data.dptr, - &workspace_size, - stream); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -#else - NVTE_ERROR("AOTriton backend not compiled."); -#endif // USE_FUSED_ATTN_AOTRITON -} - -void fused_attn_aotriton_bwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, - float attn_scale, float dropout, - int32_t window_size_left, int32_t window_size_right, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, - const Tensor* output_S, - Tensor* output_dQKV, - const Tensor* input_cu_seqlens, - const Tensor* rng_state, - Tensor* workspace, - cudaStream_t stream){ - -#ifdef USE_FUSED_ATTN_AOTRITON - const DType QKV_type = input_QKV->data.dtype; - //input tensor - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = nvte_dtype_size(QKV_type) * h * d; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = nvte_dtype_size(QKV_type) * d; - } - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - void *devPtrSoftmaxStats = output_S->data.dptr; - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - - // output tensor - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = static_cast(devPtrdQKV); - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - size_t workspace_size = 0; - fused_attn_aotriton_bwd_impl( - b, h, h, max_seqlen, max_seqlen, d, - attn_scale, dropout, - window_size_left, window_size_right, - qkv_layout, - bias_type, attn_mask_type, - devPtrQ, devPtrK, devPtrV, - devPtrO, devPtrSoftmaxStats, - devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, - input_cu_seqlens->data.dptr, input_cu_seqlens->data.dptr, - reinterpret_cast(rng_state->data.dptr), - reinterpret_cast(rng_state->data.dptr) + 1, - nvte_to_aotriton_dtype(QKV_type), - workspace->data.dptr, - &workspace_size, - stream); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -#else - NVTE_ERROR("AOTriton backend not compiled."); -#endif // USE_FUSED_ATTN_AOTRITON -} - -void fused_attn_aotriton_fwd_kvpacked( - size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, - bool is_training, float attn_scale, float dropout, - int32_t window_left, int32_t window_right, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const Tensor* input_Q, const Tensor* input_KV, - Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor* input_cu_seqlens_q, - const Tensor* input_cu_seqlens_kv, - const Tensor* rng_state, - Tensor *workspace, - cudaStream_t stream){ - -#ifdef USE_FUSED_ATTN_AOTRITON - const DType QKV_type = input_Q->data.dtype; - //input tensor - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = nvte_dtype_size(QKV_type)*h_kv*d; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = nvte_dtype_size(QKV_type) * d; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {b, h_q, max_seqlen_q, 1}; - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - size_t workspace_size = 0; - - fused_attn_aotriton_fwd_impl( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - is_training, attn_scale, dropout, - window_left, window_right, - qkv_layout, - bias_type, attn_mask_type, - devPtrQ, devPtrK, devPtrV, - devPtrS, devPtrO, - reinterpret_cast(rng_state->data.dptr), - reinterpret_cast(rng_state->data.dptr) + 1, - input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, - nvte_to_aotriton_dtype(QKV_type), - workspace->data.dptr, - &workspace_size, - stream); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -#else - NVTE_ERROR("AOTriton backend not compiled."); -#endif // USE_FUSED_ATTN_AOTRITON -} - -void fused_attn_aotriton_bwd_kvpacked( - size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, - float attn_scale, float dropout, - int32_t window_size_left, int32_t window_size_right, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, - const Tensor* output_S, - Tensor* output_dQ, Tensor* output_dKV, - const Tensor* input_cu_seqlens_q, - const Tensor* input_cu_seqlens_kv, - const Tensor* rng_state, - Tensor* workspace, - cudaStream_t stream){ - -#ifdef USE_FUSED_ATTN_AOTRITON - const DType QKV_type = input_Q->data.dtype; - //input tensor - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = nvte_dtype_size(QKV_type) * h_kv * d; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = nvte_dtype_size(QKV_type) * d; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - // output tensor - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdKV = output_dKV->data.dptr; - void *devPtrdK = devPtrdKV; - void *devPtrdV = static_cast(static_cast(devPtrdKV) + stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - - void *devPtrSoftmaxStats = output_S->data.dptr; - - size_t workspace_size = 0; - fused_attn_aotriton_bwd_impl( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - attn_scale, dropout, - window_size_left, window_size_right, - qkv_layout, - bias_type, attn_mask_type, - devPtrQ, devPtrK, devPtrV, - devPtrO, devPtrSoftmaxStats, - devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, - input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, - reinterpret_cast(rng_state->data.dptr), - reinterpret_cast(rng_state->data.dptr) + 1, - nvte_to_aotriton_dtype(QKV_type), - workspace->data.dptr, - &workspace_size, - stream); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -#else - NVTE_ERROR("AOTriton backend not compiled."); -#endif // USE_FUSED_ATTN_AOTRITON -} void fused_attn_aotriton_fwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h index bc96b2c7c..6d53a4556 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h @@ -33,58 +33,6 @@ bool is_aotriton_backend_supported( int64_t window_size_right); } // namespace fused_attn_rocm -void fused_attn_aotriton_fwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, - bool is_training, float attn_scale, float dropout, - int32_t window_left, int32_t window_right, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const Tensor* input_QKV, - Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor* input_cu_seqlens, - const Tensor* rng_state, - Tensor *workspace, - cudaStream_t stream); - -void fused_attn_aotriton_bwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, - float attn_scale, float dropout, - int32_t window_size_left, int32_t window_size_right, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, - const Tensor* output_S, - Tensor* output_dQKV, - const Tensor* input_cu_seqlens, - const Tensor* rng_state, - Tensor* workspace, - cudaStream_t stream); - -void fused_attn_aotriton_fwd_kvpacked( - size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, - bool is_training, float attn_scale, float dropout, - int32_t window_left, int32_t window_right, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const Tensor* input_Q, const Tensor* input_KV, - Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor* input_cu_seqlens_q, - const Tensor* input_cu_seqlens_kv, - const Tensor* rng_state, - Tensor *workspace, - cudaStream_t stream); - -void fused_attn_aotriton_bwd_kvpacked( - size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, - float attn_scale, float dropout, - int32_t window_size_left, int32_t window_size_right, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, - const Tensor* output_S, - Tensor* output_dQ, Tensor* output_dKV, - const Tensor* input_cu_seqlens_q, - const Tensor* input_cu_seqlens_kv, - const Tensor* rng_state, - Tensor* workspace, - cudaStream_t stream); - void fused_attn_aotriton_fwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, bool is_training, float attn_scale, float dropout, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 02bc9ce94..11f6bc7b4 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -1300,502 +1300,6 @@ void fused_attn_ck_bwd_impl( } // namespace fused_attn_rocm using namespace transformer_engine::fused_attn_rocm; -void fused_attn_ck_fwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, - bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - const Tensor* input_QKV, const Tensor* input_Bias, - Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor* input_cu_seqlens, - const Tensor* input_cu_seqlens_padded, - const Tensor* rng_state, - Tensor *workspace, - cudaStream_t stream){ - -#ifdef USE_FUSED_ATTN_CK - const DType QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - // determine the stride based on qkv layout - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride_to_k = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride_to_k = nvte_dtype_size(QKV_type) * h * d; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride_to_k = nvte_dtype_size(QKV_type) * d; - } - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride_to_k); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2*stride_to_k); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - void *devPtrCuSeqlens = input_cu_seqlens->data.dptr; - void *devPtrSeqOffsets = input_cu_seqlens_padded->data.dptr; - - size_t max_tokens = std::accumulate((input_QKV->data).shape.begin(), (input_QKV->data).shape.end(), static_cast(1), std::multiplies())/h/d/3; - bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; - - if (Aux_CTX_Tensors->size == 0) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if(is_ragged){ - output_S->data.shape = {max_tokens, h, 1}; - }else{ - output_S->data.shape = {b, h, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; - output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if(is_ragged){ - output_S->data.shape = {max_tokens, h, 1}; - }else{ - output_S->data.shape = {b, h, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - size_t workspace_size = 0; - - bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - fused_attn_ck_fwd_impl( - b, h, h, max_seqlen, max_seqlen, d, d, bias_b, bias_h, - max_tokens, max_tokens, - is_training, attn_scale, dropout, - qkv_layout, - bias_type, attn_mask_type, - window_size_left, window_size_right, - devPtrQ, - devPtrK, - devPtrV, - devPtrBias, - devPtrS, - devPtrO, - rng_state->data.dptr, - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1), - devPtrCuSeqlens, devPtrCuSeqlens, - devPtrSeqOffsets, devPtrSeqOffsets, - QKV_type, - workspace->data.dptr, - &workspace_size, - stream); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -#else - NVTE_ERROR("CK fused attn backend not compiled."); -#endif // USE_FUSED_ATTN_CK -} - -void fused_attn_ck_bwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, - float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, - const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_Bias, - const Tensor* output_S, - Tensor* output_dQKV, - Tensor* output_dBias, - const Tensor* input_cu_seqlens, - const Tensor* input_cu_seqlens_padded, - const Tensor* rng_state, - Tensor* workspace, - cudaStream_t stream){ - -#ifdef USE_FUSED_ATTN_CK - const DType QKV_type = input_QKV->data.dtype; - //input tensor - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride_to_k = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride_to_k = nvte_dtype_size(QKV_type) * h * d; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride_to_k = nvte_dtype_size(QKV_type) * d; - } - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride_to_k); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2*stride_to_k); - void *devPtrSoftmaxStats = output_S->data.dptr; - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - // output tensor - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = static_cast(devPtrdQKV); - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride_to_k); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2*stride_to_k); - - void *devPtrCuSeqlens = input_cu_seqlens->data.dptr; - void *devPtrSeqOffsets = input_cu_seqlens_padded->data.dptr; - - size_t workspace_size = 0; - - // extract the qkv and o storage bytes to clear dq buffer - bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; - bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - // extract the max_tokens for padding/unpadding and softmax_lse buffer - // b from cu_seqlen and max_seqlen are not the actual storage batch and seqlen for pad_between_seqs case - size_t max_tokens = std::accumulate((input_QKV->data).shape.begin(), (input_QKV->data).shape.end(), static_cast(1), std::multiplies())/h/d/3; - - // in qkvpacked layouts, o is of the same max_tokens as q - // dqkv has the same shape as qkv - // do has the same shape as o - - fused_attn_ck_bwd_impl( - b, h, h, max_seqlen, max_seqlen, d, d, bias_b, bias_h, - max_tokens, max_tokens, - attn_scale, dropout, - qkv_layout, - bias_type, attn_mask_type, - window_size_left, window_size_right, - deterministic, - devPtrQ, devPtrK, devPtrV, - devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdBias, - rng_state->data.dptr, - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1), - devPtrCuSeqlens, devPtrCuSeqlens, - devPtrSeqOffsets, devPtrSeqOffsets, - QKV_type, - workspace->data.dptr, - &workspace_size, - stream); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -#else - NVTE_ERROR("CK fused attn backend not compiled."); -#endif // USE_FUSED_ATTN_CK -} - -void fused_attn_ck_fwd_kvpacked( - size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, - bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_Bias, - Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor* input_cu_seqlens_q, - const Tensor* input_cu_seqlens_kv, - const Tensor* input_cu_seqlens_q_padded, - const Tensor* input_cu_seqlens_kv_padded, - const Tensor* rng_state, - Tensor *workspace, - cudaStream_t stream){ - -#ifdef USE_FUSED_ATTN_CK - const DType QKV_type = input_Q->data.dtype; - //input tensor - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = nvte_dtype_size(QKV_type)*h_kv*d; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = nvte_dtype_size(QKV_type) * d; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - void *devPtrCuSeqlensQ = input_cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = input_cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = input_cu_seqlens_q_padded->data.dptr; - void *devPtrSeqOffsetsKV = input_cu_seqlens_kv_padded->data.dptr; - - size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast(1), std::multiplies())/h_q/d; - size_t max_tokens_kv = std::accumulate((input_KV->data).shape.begin(), (input_KV->data).shape.end(), static_cast(1), std::multiplies())/h_kv/d/2; - - bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; - - if (Aux_CTX_Tensors->size == 0) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if(is_ragged){ - output_S->data.shape = {max_tokens_q, h_q, 1}; - }else{ - output_S->data.shape = {b, h_q, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; - output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if(is_ragged){ - output_S->data.shape = {max_tokens_q, h_q, 1}; - }else{ - output_S->data.shape = {b, h_q, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - size_t workspace_size = 0; - - bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - - fused_attn_ck_fwd_impl( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, bias_b, bias_h, - max_tokens_q, max_tokens_kv, - is_training, attn_scale, dropout, - qkv_layout, - bias_type, attn_mask_type, - window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, - devPtrS, devPtrO, - rng_state->data.dptr, - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1), - devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - QKV_type, - workspace->data.dptr, - &workspace_size, - stream); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -#else - NVTE_ERROR("CK fused attn backend not compiled."); -#endif // USE_FUSED_ATTN_CK -} - -void fused_attn_ck_bwd_kvpacked( - size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, - float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, - const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_Bias, - const Tensor* output_S, - Tensor* output_dQ, Tensor* output_dKV, - Tensor* output_dBias, - const Tensor* input_cu_seqlens_q, - const Tensor* input_cu_seqlens_kv, - const Tensor* input_cu_seqlens_q_padded, - const Tensor* input_cu_seqlens_kv_padded, - const Tensor* rng_state, - Tensor* workspace, - cudaStream_t stream){ -#ifdef USE_FUSED_ATTN_CK - const DType QKV_type = input_Q->data.dtype; - //input tensor - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = nvte_dtype_size(QKV_type) * h_kv * d; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = nvte_dtype_size(QKV_type) * d; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - // output tensor - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdKV = output_dKV->data.dptr; - void *devPtrdK = devPtrdKV; - void *devPtrdV = static_cast(static_cast(devPtrdKV) + stride); - - void *devPtrSoftmaxStats = output_S->data.dptr; - - void *devPtrCuSeqlensQ = input_cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = input_cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = input_cu_seqlens_q_padded->data.dptr; - void *devPtrSeqOffsetsKV = input_cu_seqlens_kv_padded->data.dptr; - - size_t workspace_size = 0; - - bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; - bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - - // extract the max_tokens for padding/unpadding and softmax_lse buffer - // b from cu_seqlen and max_seqlen are not the actual storage batch and seqlen for pad_between_seqs case - size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast(1), std::multiplies())/h_q/d; - size_t max_tokens_kv = std::accumulate((input_KV->data).shape.begin(), (input_KV->data).shape.end(), static_cast(1), std::multiplies())/h_kv/d/2; - - fused_attn_ck_bwd_impl( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, bias_b, bias_h, - max_tokens_q, max_tokens_kv, - attn_scale, dropout, - qkv_layout, - bias_type, attn_mask_type, - window_size_left, window_size_right, - deterministic, - devPtrQ, devPtrK, devPtrV, - devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdBias, - rng_state->data.dptr, - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1), - devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - QKV_type, - workspace->data.dptr, - &workspace_size, - stream); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -#else - NVTE_ERROR("CK fused attn backend not compiled."); -#endif // USE_FUSED_ATTN_CK -} void fused_attn_ck_fwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d_qk, size_t d_v, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h index 0772609ff..7aafc883f 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h @@ -32,68 +32,6 @@ bool is_ck_backend_supported( int64_t window_size_right); } // namespace fused_attn_rocm -void fused_attn_ck_fwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, - bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - const Tensor* input_QKV, const Tensor* input_Bias, - Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor* input_cu_seqlens, - const Tensor* input_cu_seqlens_padded, - const Tensor* rng_state, - Tensor *workspace, - cudaStream_t stream); - -void fused_attn_ck_bwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, - float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, - const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_Bias, - const Tensor* output_S, - Tensor* output_dQKV, - Tensor* output_dBias, - const Tensor* input_cu_seqlens, - const Tensor* input_cu_seqlens_padded, - const Tensor* rng_state, - Tensor* workspace, - cudaStream_t stream); - -void fused_attn_ck_fwd_kvpacked( - size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, - bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_Bias, - Tensor* output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor* input_cu_seqlens_q, - const Tensor* input_cu_seqlens_kv, - const Tensor* input_cu_seqlens_q_padded, - const Tensor* input_cu_seqlens_kv_padded, - const Tensor* rng_state, - Tensor *workspace, - cudaStream_t stream); - -void fused_attn_ck_bwd_kvpacked( - size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, - float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, - const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_Bias, - const Tensor* output_S, - Tensor* output_dQ, Tensor* output_dKV, - Tensor* output_dBias, - const Tensor* input_cu_seqlens_q, - const Tensor* input_cu_seqlens_kv, - const Tensor* input_cu_seqlens_q_padded, - const Tensor* input_cu_seqlens_kv_padded, - const Tensor* rng_state, - Tensor* workspace, - cudaStream_t stream); - void fused_attn_ck_fwd( size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d_qk, size_t d_v, bool is_training, float attn_scale, float dropout, diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index a2151aa21..dcafe72d6 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -18,9 +18,7 @@ #include "utils.h" namespace transformer_engine { - -// Using Double to hanld all the calculations -using CompType = double; +namespace fused_router { template __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, @@ -100,7 +98,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: Compute the aux_loss */ float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens; - aux_loss[0] = static_cast(static_cast(intermediate_result) * C_coeff); + aux_loss[0] = static_cast(intermediate_result * C_coeff); Const_buf[0] = C_coeff; } } @@ -156,7 +154,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: Compute the aux_loss */ float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens; - aux_loss[0] = static_cast(static_cast(intermediate_result) * C_coeff); + aux_loss[0] = static_cast(intermediate_result * C_coeff); Const_buf[0] = C_coeff; } } @@ -236,8 +234,8 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf, // Loop: for all positions in each row for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { float C_coeff = Const_buf[0]; - double tokens_per_expert_i = static_cast(tokens_per_expert[i]); - double grad_aux_loss_value = static_cast(grad_aux_loss[0]); + CompType tokens_per_expert_i = static_cast(tokens_per_expert[i]); + CompType grad_aux_loss_value = static_cast(grad_aux_loss[0]); // Loop: for all rows for (int j = global_warp_id; j < num_rows; j += global_warp_num) { grad_probs[j * num_cols + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value; @@ -272,6 +270,7 @@ void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_p reinterpret_cast(grad_probs.data.dptr), stream););); } +} // namespace fused_router } // namespace transformer_engine void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert, @@ -280,7 +279,7 @@ void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor to NVTETensor Const_buf, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_moe_aux_loss_forward); using namespace transformer_engine; - fused_moe_aux_loss_forward( + fused_router::fused_moe_aux_loss_forward( *convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, *convertNVTETensorCheck(aux_loss), *convertNVTETensorCheck(Const_buf), stream); @@ -292,8 +291,8 @@ void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_moe_aux_loss_backward); using namespace transformer_engine; - fused_moe_aux_loss_backward(*convertNVTETensorCheck(Const_buf), - *convertNVTETensorCheck(tokens_per_expert), num_rows, num_cols, - *convertNVTETensorCheck(grad_aux_loss), - *convertNVTETensorCheck(grad_probs), stream); + fused_router::fused_moe_aux_loss_backward(*convertNVTETensorCheck(Const_buf), + *convertNVTETensorCheck(tokens_per_expert), num_rows, + num_cols, *convertNVTETensorCheck(grad_aux_loss), + *convertNVTETensorCheck(grad_probs), stream); } diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 7540b5c41..ac13e7159 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -14,17 +16,16 @@ #include "utils.h" namespace transformer_engine { +namespace fused_router { template __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens, int num_experts, int topk, - int score_function, DataType *scores, + int score_function, float *scores, bool *routing_map, - DataType *intermediate_output) { + CompType *intermediate_output) { /*** * Section: Global Variables/Addresses init - * - Assume the sizeof(DataType) >= sizeof(int), - * So DataType address is assigned firstly to avoid the alignment issue * - Each warp is responsible for one token, and has own shared memory buffer. * Then __syncwarp() is used instead of __syncthreads() */ @@ -33,13 +34,13 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; extern __shared__ float shmem_scores_for_aux_loss[]; - DataType *logits_buf = reinterpret_cast(shmem_scores_for_aux_loss); - DataType *topk_logits_buf = - reinterpret_cast(logits_buf + num_experts * num_token_per_block); + CompType *logits_buf = reinterpret_cast(shmem_scores_for_aux_loss); + CompType *topk_logits_buf = + reinterpret_cast(logits_buf + num_experts * num_token_per_block); int *topk_indices_buf = reinterpret_cast(topk_logits_buf + topk * num_token_per_block); // The address of buffers on the current warp - DataType *local_logits = logits_buf + warp_id * num_experts; - DataType *topk_logits = topk_logits_buf + warp_id * topk; + CompType *local_logits = logits_buf + warp_id * num_experts; + CompType *topk_logits = topk_logits_buf + warp_id * topk; int *topk_indices = topk_indices_buf + warp_id * topk; /*** @@ -63,12 +64,12 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { routing_map[pos_offset + i] = false; if (score_function == 1) { - intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); + intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); } } // Load the logits to shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_logits[i] = logits[pos_offset + i]; + local_logits[i] = static_cast(logits[pos_offset + i]); } __threadfence_block(); __syncwarp(); @@ -78,11 +79,11 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi * Possible preprocess the scores before the topk operation * - Pre-softmax * - Sigmoid - * - Sigmoid post-processing when topk > 1 + * - Sqrtsoftplus + * - Sigmoid/Sqrtsoftplus post-processing when topk > 1 * This is in-place scores update */ - // score_function == 1 means softmax - if (score_function == 1) { + if (score_function == 1) { // score_function == 1 means softmax // Apply softmax to the logits before the topk apply_softmax_on_float(local_logits, num_experts, lane_id); __syncwarp(); @@ -90,10 +91,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = local_logits[i]; } - } - - // score_function == 0 means sigmoid - if (score_function == 0) { + } else if (score_function == 0) { // score_function == 0 means sigmoid // Apply sigmoid to the logits apply_sigmoid_on_float(local_logits, num_experts, lane_id); __syncwarp(); @@ -101,18 +99,24 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = local_logits[i]; } + } else if (score_function == 2) { // score_function == 2 means sqrtsoftplus + // First save the original logits for backward (needed for gradient computation) + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = local_logits[i]; // Save original logits + } + __syncwarp(); + // Apply sqrtsoftplus to the logits + apply_sqrtsoftplus_on_float(local_logits, num_experts, lane_id); } - __syncwarp(); //Confirm the scores is written to the softmax/sigmoid output + __syncwarp(); //Confirm the scores is written to the output - if (score_function == 0) { - if (topk > 1) { - auto sum_logits = - warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_logits[i] = static_cast(static_cast(local_logits[i]) / - (static_cast(sum_logits) + epsilon)); - } + // Sigmoid/Sqrtsoftplus post-processing + if (score_function == 0 || score_function == 2) { + auto sum_logits = + warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_logits[i] /= (sum_logits + epsilon); } __syncwarp(); } @@ -140,12 +144,12 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi template void fused_score_for_moe_aux_loss_forward_kernel_launcher( const DataType *logits, int num_tokens, int num_experts, int topk, int score_function, - DataType *scores, bool *routing_map, DataType *intermediate_output, cudaStream_t stream) { + float *scores, bool *routing_map, CompType *intermediate_output, cudaStream_t stream) { // Meta data for the kernel size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // logits - + topk * num_token_per_block * sizeof(DataType) // topk_logits + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // logits + + topk * num_token_per_block * sizeof(CompType) // topk_logits + topk * num_token_per_block * sizeof(int); // topk_indices fused_score_for_moe_aux_loss_forward_kernel <<>>( @@ -162,20 +166,19 @@ void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, logits.data.dtype, DataType, fused_score_for_moe_aux_loss_forward_kernel_launcher( reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, - score_function, reinterpret_cast(scores.data.dptr), + score_function, reinterpret_cast(scores.data.dptr), reinterpret_cast(routing_map.data.dptr), - reinterpret_cast(intermediate_output.data.dptr), stream);); + reinterpret_cast(intermediate_output.data.dptr), stream);); } template -__global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *intermediate_output, - const DataType *grad_scores, +__global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *intermediate_output, + const float *grad_scores, int num_tokens, int num_experts, int topk, int score_function, DataType *grad_logits) { /*** * Section: Global Variables/Addresses init - * - Assume the sizeof(DataType) >= sizeof(int), * - Each warp is responsible for one token, and has own shared memory buffer. * Then __syncwarp() is used instead of __syncthreads() */ @@ -184,16 +187,14 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; extern __shared__ float shmem[]; - DataType *grad_scores_buf = reinterpret_cast(shmem); - // To store the output of softmax/sigmoid from the fwd - DataType *act_from_fwd_buf = - reinterpret_cast(grad_scores_buf + num_experts * num_token_per_block); - DataType *comp_buf = - reinterpret_cast(act_from_fwd_buf + num_experts * num_token_per_block); + CompType *grad_scores_buf = reinterpret_cast(shmem); + // To store the output of softmax/sigmoid from fwd, or original logits for sqrtsoftplus + CompType *act_from_fwd_buf = grad_scores_buf + num_experts * num_token_per_block; + CompType *comp_buf = act_from_fwd_buf + num_experts * num_token_per_block; // The address of buffers on the current warp - DataType *local_grad = grad_scores_buf + warp_id * num_experts; - DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; - DataType *local_comp_buf = comp_buf + warp_id * num_experts; + CompType *local_grad = grad_scores_buf + warp_id * num_experts; + CompType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; + CompType *local_comp_buf = comp_buf + warp_id * num_experts; /*** * Section: Main Loop @@ -212,10 +213,6 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int * - Load the dgrad/output_from_fwd to shmem */ int pos_offset = token_offset_cur_warp * num_experts; - // Clear the logits_grad in global mem - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - grad_logits[pos_offset + i] = 0.0f; - } // Load the dgrad/output_from_fwd to shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_grad[i] = grad_scores[pos_offset + i]; @@ -227,31 +224,54 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int /*** * Section: Backward of ops before the topk * - Pre-softmax bwd - * - Sigmoid Post-processing bwd when topk > 1 + * - Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 * - Sigmoid bwd + * - Sqrtsoftplus bwd * - Write the grad_logits to the global mem */ - // Sigmoid Post-processing bwd when topk > 1 - if (topk > 1 && score_function == 0) { - auto sum_fwd_input = - warp_reduce_on_shmem(local_act_from_fwd, num_experts, ReduceFuncType::SUM, lane_id); - // Put the result of output * grad to the comp_buf + // Sqrtsoftplus: First compute sqrtsoftplus output from original logits + // (needed for both post-processing bwd and activation bwd, compute once here) + // For sqrtsoftplus, intermediate_output stores original logits + if (score_function == 2) { + // Copy original logits to local_comp_buf and apply sqrtsoftplus in-place for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i]; + local_comp_buf[i] = local_act_from_fwd[i]; } __syncwarp(); - auto sum_Output_x_Grad = - warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id); + apply_sqrtsoftplus_on_float(local_comp_buf, num_experts, lane_id); + __syncwarp(); + } + + // Sigmoid/Sqrtsoftplus Post-processing bwd (normalization backward) + if (score_function == 0 || score_function == 2) { + // Select the correct activation output buffer: + // - Sigmoid: local_act_from_fwd already contains sigmoid output + // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above + CompType *act_output = (score_function == 0) ? local_act_from_fwd : local_comp_buf; + + auto sum_fwd_input = + warp_reduce_on_shmem(act_output, num_experts, ReduceFuncType::SUM, lane_id); + // Compute sum of output * grad using registers + CompType local_sum_Output_x_Grad = 0.0; + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_sum_Output_x_Grad += local_grad[i] * act_output[i]; + } + // Warp reduce the sum + for (int s = 16; s > 0; s /= 2) { +#ifdef __HIP_PLATFORM_AMD__ + local_sum_Output_x_Grad += __shfl_xor(local_sum_Output_x_Grad, s, kThreadsPerWarp); +#else + local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); +#endif + } + CompType sum_Output_x_Grad = local_sum_Output_x_Grad; // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_grad[i] = - static_cast(local_grad[i]) / (static_cast(sum_fwd_input) + epsilon) - - static_cast(sum_Output_x_Grad) / - ((static_cast(sum_fwd_input) + epsilon) * - (static_cast(sum_fwd_input) + epsilon)); + local_grad[i] = local_grad[i] / (sum_fwd_input + epsilon) - + sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon)); } + __syncwarp(); } - __syncwarp(); // Pre-softmax bwd if (score_function == 1) { @@ -264,9 +284,17 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); __syncwarp(); } + // Sqrtsoftplus bwd + // For sqrtsoftplus, local_comp_buf already contains sqrtsoftplus output computed earlier + // Now compute gradient: dy/dx = sigmoid(x) / (2 * y) + if (score_function == 2) { + apply_sqrtsoftplus_bwd_on_float(local_grad, local_comp_buf, local_act_from_fwd, num_experts, + lane_id); + __syncwarp(); + } // Write the grad_logits to the global mem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - grad_logits[pos_offset + i] = local_grad[i]; + grad_logits[pos_offset + i] = static_cast(local_grad[i]); } __syncwarp(); } @@ -274,15 +302,15 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int template void fused_score_for_moe_aux_loss_backward_kernel_launcher( - const DataType *intermediate_output, const DataType *grad_scores, int num_tokens, - int num_experts, int topk, int score_function, DataType *grad_logits, cudaStream_t stream) { + const CompType *intermediate_output, const float *grad_scores, int num_tokens, int num_experts, + int topk, int score_function, DataType *grad_logits, cudaStream_t stream) { // Meta data for the kernel size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // grad_scores + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // grad_scores + - num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd - + num_experts * num_token_per_block * sizeof(DataType); // comp_buf + num_experts * num_token_per_block * sizeof(CompType) // act_from_fwd + + num_experts * num_token_per_block * sizeof(CompType); // comp_buf fused_score_for_moe_aux_loss_backward_kernel <<>>( intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, @@ -295,13 +323,14 @@ void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output, int num_experts, int topk, int score_function, Tensor &grad_logits, cudaStream_t stream) { TE_ROUTER_PROBS_TYPE_SWITCH_ALL( - grad_scores.data.dtype, DataType, + grad_logits.data.dtype, DataType, fused_score_for_moe_aux_loss_backward_kernel_launcher( - reinterpret_cast(intermediate_output.data.dptr), - reinterpret_cast(grad_scores.data.dptr), num_tokens, num_experts, topk, + reinterpret_cast(intermediate_output.data.dptr), + reinterpret_cast(grad_scores.data.dptr), num_tokens, num_experts, topk, score_function, reinterpret_cast(grad_logits.data.dptr), stream);); } +} // namespace fused_router } // namespace transformer_engine void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens, @@ -311,10 +340,10 @@ void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_ cudaStream_t stream) { NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_forward); using namespace transformer_engine; - fused_score_for_moe_aux_loss_forward(*convertNVTETensorCheck(logits), num_tokens, num_experts, - topk, score_function, *convertNVTETensorCheck(scores), - *convertNVTETensorCheck(routing_map), - *convertNVTETensorCheck(intermediate_output), stream); + fused_router::fused_score_for_moe_aux_loss_forward( + *convertNVTETensorCheck(logits), num_tokens, num_experts, topk, score_function, + *convertNVTETensorCheck(scores), *convertNVTETensorCheck(routing_map), + *convertNVTETensorCheck(intermediate_output), stream); } void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output, @@ -323,7 +352,7 @@ void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_ou NVTETensor grad_logits, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_backward); using namespace transformer_engine; - fused_score_for_moe_aux_loss_backward( + fused_router::fused_score_for_moe_aux_loss_backward( *convertNVTETensorCheck(intermediate_output), *convertNVTETensorCheck(grad_scores), num_tokens, num_experts, topk, score_function, *convertNVTETensorCheck(grad_logits), stream); } diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index fbd6dcee6..513c9da0d 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -16,17 +16,16 @@ #include "utils.h" namespace transformer_engine { +namespace fused_router { template __global__ void fused_topk_with_score_function_forward_kernel( const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, int num_groups, int group_topk, float scaling_factor, int score_function, const BiasType *expert_bias, DataType *probs, bool *routing_map, - DataType *intermediate_output) { + CompType *intermediate_output) { /*** * Section: Global Variables/Addresses init - * - Assume the sizeof(DataType) >= sizeof(int), - * So DataType address is assigned firstly to avoid the alignment issue * - Each warp is responsible for one token, and has own shared memory buffer. * Then __syncwarp() is used instead of __syncthreads() */ @@ -35,24 +34,22 @@ __global__ void fused_topk_with_score_function_forward_kernel( int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; extern __shared__ float shmem[]; - DataType *scores_buf = reinterpret_cast(shmem); - DataType *topk_scores_buf = - reinterpret_cast(scores_buf + num_experts * num_token_per_block); - DataType *group_scores_buf = nullptr, *masked_scores_buf = nullptr; + CompType *scores_buf = reinterpret_cast(shmem); + CompType *topk_scores_buf = scores_buf + num_experts * num_token_per_block; + CompType *group_scores_buf = nullptr, *masked_scores_buf = nullptr; int *topk_indices_buf = nullptr; if (group_topk > 0) { - masked_scores_buf = reinterpret_cast(topk_scores_buf + topk * num_token_per_block); - group_scores_buf = - reinterpret_cast(masked_scores_buf + num_experts * num_token_per_block); + masked_scores_buf = topk_scores_buf + topk * num_token_per_block; + group_scores_buf = masked_scores_buf + num_experts * num_token_per_block; topk_indices_buf = reinterpret_cast(group_scores_buf + num_groups * num_token_per_block); } else { topk_indices_buf = reinterpret_cast(topk_scores_buf + topk * num_token_per_block); } // The address of buffers on the current warp - DataType *scores = scores_buf + warp_id * num_experts; - DataType *topk_scores = topk_scores_buf + warp_id * topk; - DataType *masked_scores = masked_scores_buf + warp_id * num_experts; - DataType *group_scores = group_scores_buf + warp_id * num_groups; + CompType *scores = scores_buf + warp_id * num_experts; + CompType *topk_scores = topk_scores_buf + warp_id * topk; + CompType *masked_scores = masked_scores_buf + warp_id * num_experts; + CompType *group_scores = group_scores_buf + warp_id * num_groups; int *topk_indices = topk_indices_buf + warp_id * topk; /*** @@ -74,10 +71,10 @@ __global__ void fused_topk_with_score_function_forward_kernel( int pos_offset = token_offset_cur_warp * num_experts; // Clear the probs/routing_map (num_experts) for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - probs[pos_offset + i] = 0.0f; + probs[pos_offset + i] = 0.0; routing_map[pos_offset + i] = false; if (score_function == 1) { - intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); + intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); } } // Load the logits to shmem @@ -87,7 +84,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( // If group_topk > 0, init the masked_scores to -inf if (group_topk > 0) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - masked_scores[i] = -std::numeric_limits::infinity(); + masked_scores[i] = -std::numeric_limits::infinity(); } } __threadfence_block(); @@ -98,11 +95,11 @@ __global__ void fused_topk_with_score_function_forward_kernel( * Possible preprocess the scores before the topk operation * - Pre-softmax * - Sigmoid + * - Sqrtsoftplus * - Expert bias * This is in-place scores update */ - // score_function == 1 means softmax - if (use_pre_softmax && score_function == 1) { + if (use_pre_softmax && score_function == 1) { // score_function == 1 means softmax // Apply softmax to the logits before the topk apply_softmax_on_float(scores, num_experts, lane_id); __syncwarp(); @@ -110,10 +107,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = scores[i]; } - } - - // score_function == 0 means sigmoid - if (score_function == 0) { + } else if (score_function == 0) { // score_function == 0 means sigmoid // Apply sigmoid to the logits apply_sigmoid_on_float(scores, num_experts, lane_id); __syncwarp(); @@ -121,18 +115,25 @@ __global__ void fused_topk_with_score_function_forward_kernel( for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = scores[i]; } + } else if (score_function == 2) { // score_function == 2 means sqrtsoftplus + // First save the original logits for backward (needed for sqrtsoftplus gradient computation) + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = scores[i]; // Save original logits + } + __syncwarp(); + // Apply sqrtsoftplus to the logits + apply_sqrtsoftplus_on_float(scores, num_experts, lane_id); } - __syncwarp(); //Confirm the scores is written to the softmax/sigmoid output + __syncwarp(); //Confirm the scores is written to the output - // Expert bias is only used at the sigmoid case - if (expert_bias && score_function == 0) { + // Expert bias is only used at the sigmoid/sqrtsoftplus case + if (expert_bias && (score_function == 0 || score_function == 2)) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - scores[i] = static_cast(static_cast(scores[i]) + - static_cast(expert_bias[i])); + scores[i] += static_cast(expert_bias[i]); } + __syncwarp(); } - __syncwarp(); /*** * Section: Topk @@ -142,7 +143,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( * - topk with expert bias */ // Topk on the scores - // The bias is not empty only happens at the sigmod case + // The bias being not empty happens at the sigmoid/sqrtsoftplus case if (group_topk > 0) { int group_size = num_experts / num_groups; // Top2 @@ -159,9 +160,9 @@ __global__ void fused_topk_with_score_function_forward_kernel( if (lane_id == 0) { //TODO: release after /opt/rocm/include/hip/amd_detail/amd_hip_bfloat16.h remove explict constructor restriction #ifdef __HIP_PLATFORM_AMD__ - DataType tmp(0.0f); + CompType tmp(0.0f); #else - DataType tmp = 0.0f; + CompType tmp = 0.0; #endif for (int j = 0; j < topk / group_topk; j++) { tmp = tmp + topk_scores[j]; @@ -201,17 +202,16 @@ __global__ void fused_topk_with_score_function_forward_kernel( * Possible postprocess the scores after the topk operation * - Revert Expert bias * - Softmax - * - Sigmoid post-processing when topk > 1 + * - Sigmoid/Sqrtsoftplus post-processing when topk > 1 * - Write the result with scaling_factor */ // Revert Expert bias from the topk scores - if (expert_bias && score_function == 0) { + if (expert_bias && (score_function == 0 || score_function == 2)) { for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - topk_scores[i] = - static_cast(topk_scores[i]) - static_cast(expert_bias[topk_indices[i]]); + topk_scores[i] = topk_scores[i] - static_cast(expert_bias[topk_indices[i]]); } + __syncwarp(); } - __syncwarp(); // score_function == 1 means softmax if (!use_pre_softmax && score_function == 1) { @@ -222,14 +222,15 @@ __global__ void fused_topk_with_score_function_forward_kernel( for (int i = lane_id; i < topk; i += kThreadsPerWarp) { intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i]; } + __syncwarp(); } - // score_function == 0 means sigmoid - if (score_function == 0) { + // Sigmoid/Sqrtsoftplus post-processing when topk > 1 + if (score_function == 0 || score_function == 2) { if (topk > 1) { - double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); + CompType sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - topk_scores[i] = static_cast(topk_scores[i]) / (sum_scores + epsilon); + topk_scores[i] = topk_scores[i] / (sum_scores + epsilon); } } __syncwarp(); @@ -238,7 +239,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( // Write the probs/routing_map to the output tensor for (int i = lane_id; i < topk; i += kThreadsPerWarp) { routing_map[pos_offset + topk_indices[i]] = true; - probs[pos_offset + topk_indices[i]] = scaling_factor * static_cast(topk_scores[i]); + probs[pos_offset + topk_indices[i]] = scaling_factor * topk_scores[i]; } __threadfence_block(); __syncwarp(); @@ -249,16 +250,16 @@ template void fused_topk_with_score_function_forward_kernel_launcher( const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, int num_groups, int group_topk, float scaling_factor, int score_function, - const BiasType *expert_bias, DataType *probs, bool *routing_map, DataType *intermediate_output, + const BiasType *expert_bias, DataType *probs, bool *routing_map, CompType *intermediate_output, cudaStream_t stream) { size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // scores - + topk * num_token_per_block * sizeof(DataType) // topk_scores + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // scores + + topk * num_token_per_block * sizeof(CompType) // topk_scores + topk * num_token_per_block * sizeof(int); // topk_indices if (group_topk > 0) { - shared_memory_size += num_groups * num_token_per_block * sizeof(DataType); // group_scores - shared_memory_size += num_experts * num_token_per_block * sizeof(DataType); // maksed_scores + shared_memory_size += num_groups * num_token_per_block * sizeof(CompType); // group_scores + shared_memory_size += num_experts * num_token_per_block * sizeof(CompType); // maksed_scores } fused_topk_with_score_function_forward_kernel <<>>( @@ -283,13 +284,13 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, reinterpret_cast(expert_bias.data.dptr), reinterpret_cast(probs.data.dptr), reinterpret_cast(routing_map.data.dptr), - reinterpret_cast(intermediate_output.data.dptr), stream););); + reinterpret_cast(intermediate_output.data.dptr), stream););); } template __global__ void fused_topk_with_score_function_backward_kernel( // Inputs tensor - const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs, + const bool *routing_map, const CompType *intermediate_output, const DataType *grad_probs, // Other parameters int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, @@ -297,7 +298,6 @@ __global__ void fused_topk_with_score_function_backward_kernel( DataType *grad_logits) { /*** * Section: Global Variables/Addresses init - * - Assume the sizeof(DataType) >= sizeof(int), * - Each warp is responsible for one token, and has own shared memory buffer. * Then __syncwarp() is used instead of __syncthreads() */ @@ -306,18 +306,16 @@ __global__ void fused_topk_with_score_function_backward_kernel( int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; extern __shared__ float shmem[]; - DataType *grad_probs_buf = reinterpret_cast(shmem); - // To store the output of softmax/sigmoid from the fwd - DataType *act_from_fwd_buf = - reinterpret_cast(grad_probs_buf + num_experts * num_token_per_block); - DataType *comp_buf = - reinterpret_cast(act_from_fwd_buf + num_experts * num_token_per_block); + CompType *grad_probs_buf = reinterpret_cast(shmem); + // To store the output of softmax/sigmoid from fwd, or original logits for sqrtsoftplus + CompType *act_from_fwd_buf = grad_probs_buf + num_experts * num_token_per_block; + CompType *comp_buf = act_from_fwd_buf + num_experts * num_token_per_block; // To store the routing_map from the fwd bool *routing_map_buf = reinterpret_cast(comp_buf + num_experts * num_token_per_block); // The address of buffers on the current warp - DataType *local_grad = grad_probs_buf + warp_id * num_experts; - DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; - DataType *local_comp_buf = comp_buf + warp_id * num_experts; + CompType *local_grad = grad_probs_buf + warp_id * num_experts; + CompType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; + CompType *local_comp_buf = comp_buf + warp_id * num_experts; bool *local_routing_map = routing_map_buf + warp_id * num_experts; /*** @@ -337,10 +335,6 @@ __global__ void fused_topk_with_score_function_backward_kernel( * - Load the dgrad/output_from_fwd to shmem */ int pos_offset = token_offset_cur_warp * num_experts; - // Clear the logits_grad in global mem - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - grad_logits[pos_offset + i] = 0.0f; - } // Load the dgrad/output_from_fwd to shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_grad[i] = grad_probs[pos_offset + i]; @@ -353,48 +347,72 @@ __global__ void fused_topk_with_score_function_backward_kernel( /*** * Section: Backward of ops after the topk * - Backward of the used scaling_factor - * - Sigmoid Post-processing bwd when topk > 1 + * - Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 * - Softmax bwd if use_pre_softmax is false */ // Backward of the used scaling_factor // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { if (local_routing_map[i]) { - local_grad[i] = static_cast(local_grad[i]) * scaling_factor; + local_grad[i] = local_grad[i] * scaling_factor; } } __syncwarp(); - // Sigmoid Post-processing bwd when topk > 1 - if (topk > 1 && score_function == 0) { - double sum_fwd_input = masked_warp_reduce_on_shmem( - /*data ptr = */ local_act_from_fwd, - /*mask ptr = */ local_routing_map, - /*data size = */ num_experts, - /*reduce func = */ ReduceFuncType::SUM, lane_id); - // Put the result of output * grad to the comp_buf + + // Sqrtsoftplus: First compute sqrtsoftplus output from original logits + // (needed for both post-processing bwd and activation bwd, compute once here) + // For sqrtsoftplus, intermediate_output stores original logits + if (score_function == 2) { + // Copy original logits to local_comp_buf and apply sqrtsoftplus in-place for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_comp_buf[i] = (local_routing_map[i] ? static_cast(local_grad[i]) * - static_cast(local_act_from_fwd[i]) - : 0.0f); + local_comp_buf[i] = local_act_from_fwd[i]; } __syncwarp(); - double sum_Output_x_Grad = masked_warp_reduce_on_shmem( - /*data ptr = */ local_comp_buf, + apply_sqrtsoftplus_on_float(local_comp_buf, num_experts, lane_id); + __syncwarp(); + } + + // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) + if (topk > 1 && (score_function == 0 || score_function == 2)) { + // Select the correct activation output buffer: + // - Sigmoid: local_act_from_fwd already contains sigmoid output + // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above + CompType *act_output = (score_function == 0) ? local_act_from_fwd : local_comp_buf; + + CompType sum_fwd_input = masked_warp_reduce_on_shmem( + /*data ptr = */ act_output, /*mask ptr = */ local_routing_map, /*data size = */ num_experts, /*reduce func = */ ReduceFuncType::SUM, lane_id); + // Compute sum of output * grad using registers + CompType local_sum_Output_x_Grad = 0.0; + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + if (local_routing_map[i]) { + local_sum_Output_x_Grad += local_grad[i] * act_output[i]; + } + } + // Warp reduce the sum + for (int s = 16; s > 0; s /= 2) { +#ifdef __HIP_PLATFORM_AMD__ + local_sum_Output_x_Grad += __shfl_xor(local_sum_Output_x_Grad, s, kThreadsPerWarp); +#else + local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); +#endif + } + CompType sum_Output_x_Grad = local_sum_Output_x_Grad; // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { if (local_routing_map[i]) { local_grad[i] = - static_cast(local_grad[i]) / (sum_fwd_input + epsilon) - + local_grad[i] / (sum_fwd_input + epsilon) - sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon)); } else { - local_grad[i] = 0.0f; + local_grad[i] = 0.0; } } + __syncwarp(); } - __syncwarp(); + // Softmax bwd if use_pre_softmax is false if (!use_pre_softmax && score_function == 1) { apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, local_routing_map, @@ -408,7 +426,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( */ for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { if (!local_routing_map[i]) { - local_grad[i] = 0.0f; + local_grad[i] = 0.0; } } __syncwarp(); @@ -417,6 +435,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( * Section: Backward of ops before the topk * - Pre-softmax bwd * - Sigmoid bwd + * - Sqrtsoftplus bwd * - Write the grad_logits to the global mem */ // Pre-softmax bwd @@ -430,6 +449,14 @@ __global__ void fused_topk_with_score_function_backward_kernel( apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); __syncwarp(); } + // Sqrtsoftplus bwd + // For sqrtsoftplus, local_comp_buf already contains sqrtsoftplus output computed earlier + // Now compute gradient: dy/dx = sigmoid(x) / (2 * y) + if (score_function == 2) { + apply_sqrtsoftplus_bwd_on_float(local_grad, local_comp_buf, local_act_from_fwd, num_experts, + lane_id); + __syncwarp(); + } // Write the grad_logits to the global mem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { grad_logits[pos_offset + i] = local_grad[i]; @@ -440,16 +467,16 @@ __global__ void fused_topk_with_score_function_backward_kernel( template void fused_topk_with_score_function_backward_kernel_launcher( - const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs, + const bool *routing_map, const CompType *intermediate_output, const DataType *grad_probs, int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, DataType *grad_logits, cudaStream_t stream) { // Meta data for the kernel size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // grad_probs + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // grad_probs + - num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd - + num_experts * num_token_per_block * sizeof(DataType) // comp_buf + num_experts * num_token_per_block * sizeof(CompType) // act_from_fwd + + num_experts * num_token_per_block * sizeof(CompType) // comp_buf + num_experts * num_token_per_block * sizeof(bool); // routing_map fused_topk_with_score_function_backward_kernel <<>>( @@ -468,12 +495,13 @@ void fused_topk_with_score_function_backward(const Tensor &routing_map, grad_logits.data.dtype, DataType, fused_topk_with_score_function_backward_kernel_launcher( reinterpret_cast(routing_map.data.dptr), - reinterpret_cast(intermediate_output.data.dptr), + reinterpret_cast(intermediate_output.data.dptr), reinterpret_cast(grad_probs.data.dptr), num_tokens, num_experts, topk, use_pre_softmax, scaling_factor, score_function, reinterpret_cast(grad_logits.data.dptr), stream);); } +} // namespace fused_router } // namespace transformer_engine void nvte_fused_topk_with_score_function_forward( @@ -483,7 +511,7 @@ void nvte_fused_topk_with_score_function_forward( NVTETensor intermediate_output, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_topk_with_score_function_forward); using namespace transformer_engine; - fused_topk_with_score_function_forward( + fused_router::fused_topk_with_score_function_forward( *convertNVTETensorCheck(logits), num_tokens, num_experts, topk, static_cast(use_pre_softmax), num_groups, group_topk, scaling_factor, score_function, *convertNVTETensorCheck(expert_bias), *convertNVTETensorCheck(probs), @@ -498,7 +526,7 @@ void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map, NVTETensor grad_logits, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_topk_with_score_function_backward); using namespace transformer_engine; - fused_topk_with_score_function_backward( + fused_router::fused_topk_with_score_function_backward( *convertNVTETensorCheck(routing_map), *convertNVTETensorCheck(intermediate_output), *convertNVTETensorCheck(grad_probs), num_tokens, num_experts, topk, static_cast(use_pre_softmax), scaling_factor, score_function, diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 3dcb593fc..7e9c74376 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -12,9 +12,15 @@ #include "transformer_engine/transformer_engine.h" namespace transformer_engine { +namespace fused_router { -#ifdef __HIP_PLATFORM_AMD__ -// TODO: remove after rocm supports NV __syncwarp equivalent +// Using FP32 to handle all the calculations. +// Currently, only FP32 is supported because +// 1. The score functions (sigmoid, softmax, sqrtsoftplus) are implemented in FP32. +// 2. The intermediate buffer is initialized in FP32. +using CompType = float; + +#if defined(__HIP_PLATFORM_AMD__) && __HIP_VERSION__ < 70000000 __device__ inline void __syncwarp() { __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront"); @@ -22,7 +28,6 @@ __device__ inline void __syncwarp() __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront"); } - #endif constexpr size_t kThreadsPerWarp = 32; constexpr int kThreadsPerBlock = @@ -48,35 +53,30 @@ template __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type, int lane_id) { T (*reduce_func)(T, T); - double default_val = 0; + CompType default_val = 0.0; if (type == ReduceFuncType::SUM) { reduce_func = sum; - default_val = 0; + default_val = 0.0; } else if (type == ReduceFuncType::MAX) { reduce_func = max; - default_val = -std::numeric_limits::infinity(); + default_val = -std::numeric_limits::infinity(); } // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread - volatile double val = lane_id < data_size ? static_cast(data_ptr[lane_id]) : default_val; + CompType val = lane_id < data_size ? data_ptr[lane_id] : default_val; for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { -//TODO: release after /opt/rocm/include/hip/amd_detail/amd_hip_bfloat16.h provide bf16 constructor from double -#ifdef __HIP_PLATFORM_AMD__ - val = reduce_func(static_cast(val), data_ptr[i]); -#else val = reduce_func(val, data_ptr[i]); -#endif } // Warp shuffle between threads #ifdef __HIP_PLATFORM_AMD__ - val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 16, kThreadsPerWarp))); - val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 8, kThreadsPerWarp))); - val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 4, kThreadsPerWarp))); - val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 2, kThreadsPerWarp))); - val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 1, kThreadsPerWarp))); + val = reduce_func(val, __shfl_xor(val, 16, kThreadsPerWarp)); + val = reduce_func(val, __shfl_xor(val, 8, kThreadsPerWarp)); + val = reduce_func(val, __shfl_xor(val, 4, kThreadsPerWarp)); + val = reduce_func(val, __shfl_xor(val, 2, kThreadsPerWarp)); + val = reduce_func(val, __shfl_xor(val, 1, kThreadsPerWarp)); #else val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8)); @@ -88,44 +88,36 @@ __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncT return T(val); } -template -__device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, int lane_id) { - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - scores[i] = static_cast(1.0f / (1.0f + exp(-static_cast(scores[i])))); - } -} - template __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size, ReduceFuncType type, int lane_id) { T (*reduce_func)(T, T); - double default_val = 0; + CompType default_val = 0.0; if (type == ReduceFuncType::SUM) { reduce_func = sum; - default_val = 0; + default_val = 0.0; } else if (type == ReduceFuncType::MAX) { reduce_func = max; - default_val = -std::numeric_limits::infinity(); + default_val = -std::numeric_limits::infinity(); } // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread - volatile double val = - lane_id < data_size && mask[lane_id] ? static_cast(data_ptr[lane_id]) : default_val; + CompType val = lane_id < data_size && mask[lane_id] ? data_ptr[lane_id] : default_val; for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { if (mask[i]) { - val = reduce_func(static_cast(val), data_ptr[i]); + val = reduce_func(val, data_ptr[i]); } } // Warp shuffle between threads #ifdef __HIP_PLATFORM_AMD__ - val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 16, kThreadsPerWarp))); - val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 8, kThreadsPerWarp))); - val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 4, kThreadsPerWarp))); - val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 2, kThreadsPerWarp))); - val = reduce_func(static_cast(val), static_cast(__shfl_xor(val, 1, kThreadsPerWarp))); + val = reduce_func(val, __shfl_xor(val, 16, kThreadsPerWarp)); + val = reduce_func(val, __shfl_xor(val, 8, kThreadsPerWarp)); + val = reduce_func(val, __shfl_xor(val, 4, kThreadsPerWarp)); + val = reduce_func(val, __shfl_xor(val, 2, kThreadsPerWarp)); + val = reduce_func(val, __shfl_xor(val, 1, kThreadsPerWarp)); #else val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8)); @@ -137,28 +129,70 @@ __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int dat return T(val); } -template -__device__ inline void apply_sigmoid_bwd_on_float(DataType *grad, DataType *fwd_output, - int data_size, int lane_id) { +__device__ inline void apply_sigmoid_on_float(float *scores, int data_size, int lane_id) { for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - grad[i] = static_cast(grad[i]) * static_cast(fwd_output[i]) * - (1 - static_cast(fwd_output[i])); + scores[i] = 1.0f / (1.0f + expf(-scores[i])); } } -template -__device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_output, - DataType *comp_buf, bool *mask, int data_size, +__device__ inline void apply_sigmoid_bwd_on_float(float *grad, float *fwd_output, int data_size, int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + grad[i] = grad[i] * fwd_output[i] * (1.0f - fwd_output[i]); + } +} + +// sqrtsoftplus: y = sqrt(softplus(x)) = sqrt(log(1 + exp(x))) +__device__ inline void apply_sqrtsoftplus_on_float(float *scores, int data_size, int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + float x = scores[i]; + // softplus(x) = log(1 + exp(x)), numerically stable version + // Matches PyTorch's Softplus(beta=1.0, threshold=20.0) + float softplus_val; + if (x > 20.0f) { + softplus_val = x; // for large x, softplus(x) ≈ x + } else { + softplus_val = log1pf(expf(x)); + } + scores[i] = sqrtf(softplus_val); + } +} + +// sqrtsoftplus backward: +// y = sqrt(softplus(x)) +// Matches PyTorch's Softplus(beta=1.0, threshold=20.0) +// We need the original logits (x) to compute the gradient +__device__ inline void apply_sqrtsoftplus_bwd_on_float(float *grad, float *fwd_output, + float *logits_buf, int data_size, + int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + float x = logits_buf[i]; // original logit + float y = fwd_output[i]; // sqrtsoftplus output + float dy_dx; + if (x > 20.0f) { + // When softplus(x) = x, y = sqrt(x), dy/dx = 1/(2*y) + dy_dx = 1.0f / (2.0f * y + epsilon); + } else { + // When softplus(x) = log(1+exp(x)), dy/dx = sigmoid(x) / (2*y) + // where sigmoid(x) = 1 / (1 + exp(-x)) + float sigmoid_x = 1.0f / (1.0f + expf(-x)); + dy_dx = sigmoid_x / (2.0f * y + epsilon); + } + grad[i] = grad[i] * dy_dx; + } +} + +__device__ inline void apply_softmax_bwd_on_float(float *grad, float *fwd_output, float *comp_buf, + bool *mask, int data_size, int lane_id) { // Put the result of output * grad to the comp_buf for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { if (mask) { if (mask[i]) - comp_buf[i] = static_cast(grad[i]) * static_cast(fwd_output[i]); + comp_buf[i] = grad[i] * fwd_output[i]; else comp_buf[i] = 0.0f; } else { - comp_buf[i] = static_cast(grad[i]) * static_cast(fwd_output[i]); + comp_buf[i] = grad[i] * fwd_output[i]; } } __syncwarp(); @@ -170,40 +204,34 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_ for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { if (mask) { if (mask[i]) - grad[i] = - static_cast(fwd_output[i]) * (static_cast(grad[i]) - sum_Output_x_Grad); + grad[i] = fwd_output[i] * (grad[i] - sum_Output_x_Grad); else grad[i] = 0.0f; } else { - grad[i] = - static_cast(fwd_output[i]) * (static_cast(grad[i]) - sum_Output_x_Grad); + grad[i] = fwd_output[i] * (grad[i] - sum_Output_x_Grad); } } } -template -__device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) { +__device__ inline void apply_softmax_on_float(float *scores, int data_size, int lane_id) { // 1. compute the max of value - float max_val = - static_cast(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id)); + float max_val = warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id); // 2. value -> exp_value for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - scores[i] = static_cast(exp(static_cast(scores[i]) - max_val)); + scores[i] = expf(scores[i] - max_val); } __syncwarp(); // 3. compute the sum of exp_value - float sum_val = - static_cast(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id)); + float sum_val = warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id); // 4. update the softmax value for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - scores[i] = static_cast(scores[i]) / sum_val; + scores[i] = scores[i] / sum_val; } __syncwarp(); } -template -__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices, - T *topk_scores, int lane_id) { +__device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int topk, + int *topk_indices, CompType *topk_scores, int lane_id) { // Check if the index is masked by the later iteration auto is_masked = [&topk_indices](int k, int index) { if (k == 0) return false; @@ -217,16 +245,15 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i // After looping topk times, the topk_indices will be the topk indices for (int k = 0; k < topk; k++) { // Find the max value and its index - volatile double val = (lane_id < data_size && !is_masked(k, lane_id)) - ? static_cast(scores[lane_id]) - : -std::numeric_limits::infinity(); - volatile int index = (lane_id < data_size) ? lane_id : 0; + CompType val = (lane_id < data_size && !is_masked(k, lane_id)) + ? scores[lane_id] + : -std::numeric_limits::infinity(); + int index = (lane_id < data_size) ? lane_id : 0; // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { - volatile double cur_val = (is_masked(k, i)) ? -std::numeric_limits::infinity() - : static_cast(scores[i]); + CompType cur_val = (is_masked(k, i)) ? -std::numeric_limits::infinity() : scores[i]; if (cur_val > val) { val = cur_val; index = i; @@ -235,11 +262,11 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i // Warp shuffle between threads for (int s = 16; s > 0; s /= 2) { #ifdef __HIP_PLATFORM_AMD__ - volatile auto shuffled_val = __shfl_xor(val, s, kThreadsPerWarp); - volatile auto shuffled_index = __shfl_xor(index, s, kThreadsPerWarp); + auto shuffled_val = __shfl_xor(val, s, kThreadsPerWarp); + auto shuffled_index = __shfl_xor(index, s, kThreadsPerWarp); #else - volatile auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s); - volatile auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s); + auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s); + auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s); #endif if (shuffled_val > val) { val = shuffled_val; @@ -255,46 +282,51 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i } // Current TE only support float32/bf16/fp16, float64 probs should be considered in the future -#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported router probs dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16."); \ } -#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kInt32: { \ - using type = int32_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt64: { \ - using type = int64_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kInt32: { \ + using type = int32_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt64: { \ + using type = int64_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported router index dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Int32, Int64, BFloat16, " \ + "Float32."); \ } +} // namespace fused_router } // namespace transformer_engine -#endif + +#endif // TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp index 2532e96bb..de533909f 100644 --- a/transformer_engine/common/gemm/config.cpp +++ b/transformer_engine/common/gemm/config.cpp @@ -126,3 +126,124 @@ void nvte_destroy_matmul_config(NVTEMatmulConfig config) { delete reinterpret_cast(config); } } + +NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config() { + return new transformer_engine::GroupedMatmulConfig; +} + +void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, void *buf, + size_t size_in_bytes, size_t *size_written) { + // Write attribute size + NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes, + "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); + const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr]; + *size_written = attr_size; + + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for grouped matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + + // bool size is implementation-dependent, so we explicitly specify + // uint8_t in the user-facing API. + auto bool_to_uint8 = [](bool in, void *out) { + *reinterpret_cast(out) = static_cast(in); + }; + + // Write to buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)"); + const auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEGroupedMatmulConfigAvgM: { + int64_t val = config_.avg_m.value_or(0); + std::memcpy(buf, &val, attr_size); + break; + } + case kNVTEGroupedMatmulConfigAvgN: { + int64_t val = config_.avg_n.value_or(0); + std::memcpy(buf, &val, attr_size); + break; + } + case kNVTEGroupedMatmulConfigAvgK: { + int64_t val = config_.avg_k.value_or(0); + std::memcpy(buf, &val, attr_size); + break; + } + case kNVTEGroupedMatmulConfigUseSplitAccumulator: + bool_to_uint8(config_.use_split_accumulator, buf); + break; + case kNVTEGroupedMatmulConfigSMCount: + std::memcpy(buf, &config_.sm_count, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes, + "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for grouped matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); + + // bool size is implementation-dependent, so we explicitly specify + // uint8_t in the user-facing API. + auto uint8_to_bool = [](const void *in, bool &out) { + out = static_cast(*reinterpret_cast(in)); + }; + + // Read from buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)"); + auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEGroupedMatmulConfigAvgM: { + int64_t val; + std::memcpy(&val, buf, attr_size); + config_.avg_m = val; + break; + } + case kNVTEGroupedMatmulConfigAvgN: { + int64_t val; + std::memcpy(&val, buf, attr_size); + config_.avg_n = val; + break; + } + case kNVTEGroupedMatmulConfigAvgK: { + int64_t val; + std::memcpy(&val, buf, attr_size); + config_.avg_k = val; + break; + } + case kNVTEGroupedMatmulConfigUseSplitAccumulator: + uint8_to_bool(buf, config_.use_split_accumulator); + break; + case kNVTEGroupedMatmulConfigSMCount: + std::memcpy(&config_.sm_count, buf, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config) { + if (config != nullptr) { + delete reinterpret_cast(config); + } +} diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h index 86a617b5f..eed47e23d 100644 --- a/transformer_engine/common/gemm/config.h +++ b/transformer_engine/common/gemm/config.h @@ -9,6 +9,9 @@ #include +#include +#include + namespace transformer_engine { struct MatmulConfig { @@ -31,6 +34,25 @@ struct MatmulConfig { }; }; +struct GroupedMatmulConfig { + // Average dimension hints for cuBLASLt algorithm selection heuristics. + // nullopt means "not set" - compute automatically from tensor shapes. + std::optional avg_m; + std::optional avg_n; + std::optional avg_k; + + // Number of streaming multiprocessors to use in GEMM kernel + int sm_count = 0; + + // Split accumulator mode. Only taken into account on Hopper. + bool use_split_accumulator = false; + + // Note: API transfers the value type, not std::optional + static constexpr size_t attr_sizes[] = { + sizeof(decltype(avg_m)::value_type), sizeof(decltype(avg_n)::value_type), + sizeof(decltype(avg_k)::value_type), sizeof(sm_count), sizeof(uint8_t)}; +}; + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7326f330f..35cad5092 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -131,6 +131,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // Set conditions for MXFP8 and NVFP4 gemm execution. const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode); const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode); + int is_nvte_non_tn_fp8_gemm_supported = 0; // needed only for per tensor scaling + if (is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) { + is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + } // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { @@ -140,7 +144,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; ret.lda = is_A_transposed ? k : m; - if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { + if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -151,7 +155,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), @@ -231,7 +235,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; - if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { + if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -242,7 +246,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), @@ -313,13 +317,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla return ret; } -/* cuBLAS version number at run-time */ -size_t cublas_version() { - // Cache version to avoid cuBLAS logging overhead - static size_t version = cublasLtGetVersion(); - return version; -} - } // namespace #endif // __HIP_PLATFORM_AMD__ @@ -524,8 +521,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #endif // CUBLAS_VERSION >= 120800 } else if (mxfp8_gemm) { #if CUBLAS_VERSION >= 120800 - NVTE_CHECK(cublas_version() >= 120800, - "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); + NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120800, + "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", + transformer_engine::cuda::cublas_version()); // Check that scales are in expected format NVTE_CHECK(inputA->with_gemm_swizzled_scales, @@ -547,7 +545,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. - if (cublas_version() <= 120803) { + if (transformer_engine::cuda::cublas_version() <= 120803) { const int64_t dummy_a_vec_stride = 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, @@ -559,8 +557,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #endif // CUBLAS_VERSION >= 120800 } else if (use_fp4) { // NVFP4 GEMM #if CUBLAS_VERSION >= 120800 - NVTE_CHECK(cublas_version() >= 120800, - "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); + NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120800, + "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", + transformer_engine::cuda::cublas_version()); // Check that scales are in expected format NVTE_CHECK(inputA->with_gemm_swizzled_scales, @@ -595,9 +594,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { #if CUBLAS_VERSION >= 120900 - NVTE_CHECK(cublas_version() >= 120900, + NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120900, "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ", - cublas_version()); + transformer_engine::cuda::cublas_version()); // Check that matrix formats are valid NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && @@ -630,7 +629,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } #if CUBLAS_VERSION >= 120800 - if (cublas_version() >= 120800) { + if (transformer_engine::cuda::cublas_version() >= 120800) { NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a))); @@ -647,7 +646,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); #if CUBLAS_VERSION >= 120800 - if (cublas_version() >= 120800) { + if (transformer_engine::cuda::cublas_version() >= 120800) { // NOTE: In all current cases where FP8 output is supported, the input is // scaled identically to the output. NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -731,12 +730,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); #else - NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, + NVTE_CHECK(transformer_engine::cuda::cudart_version() >= 12020 && + transformer_engine::cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ", - cuda::cudart_version()); - NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000, + transformer_engine::cuda::cudart_version()); + NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120205 && + transformer_engine::cuda::cublas_version() < 130000, "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", - cublas_version()); + transformer_engine::cuda::cublas_version()); if (m_split == 0) m_split = 1; if (n_split == 0) n_split = 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( @@ -966,10 +967,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ", transformer_engine::cuda::cudart_version()); NVTE_CHECK( - cublas_version() >= 120205 && cublas_version() < 130000, + transformer_engine::cuda::cublas_version() >= 120205 && + transformer_engine::cuda::cublas_version() < 130000, "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", - cublas_version()); -#endif + transformer_engine::cuda::cublas_version()); +#endif const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu new file mode 100644 index 000000000..50b36c058 --- /dev/null +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -0,0 +1,1435 @@ +/************************************************************************* +* This file was modified for portability to AMDGPU +* Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. +* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +* +* See LICENSE for license information. +************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/handle_manager.h" +#include "../util/logging.h" +#include "../util/vectorized_pointwise.h" +#include "./config.h" + +#ifndef __HIP_PLATFORM_AMD__ + +namespace { + +inline void CreateCublasHandle(cublasLtHandle_t *handle) { + NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); +} + +} // namespace + +// MXFP8 support for grouped GEMM requires cuBLAS 13.2+ +#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130200 + +#if CUBLAS_VERSION >= 130200 + +namespace { + +// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) +struct TensorShapeInfo { + const int64_t *first_dims; // nullptr if uniform + const int64_t *last_dims; // nullptr if uniform + const int64_t *offsets; // nullptr if need to compute + int64_t uniform_first; // used if first_dims == nullptr + int64_t uniform_last; // used if last_dims == nullptr + + // Create from GroupedTensor + static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { + const bool has_first = t->first_dims.has_data(); + const bool has_last = t->last_dims.has_data(); + // When per-tensor dims are not provided, we must be in the uniform-shape case. + NVTE_CHECK(has_first || t->all_same_first_dim(), + "GroupedTensor is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || t->all_same_last_dim(), + "GroupedTensor is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; + + const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); + + return {first_ptr, last_ptr, + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) + : nullptr, + uniform_first, uniform_last}; + } + + // Create for C tensor (uses D's dimensions, only has offsets) + static TensorShapeInfo create_shape_info_for_C(const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D) { + const bool has_first = D->first_dims.has_data(); + const bool has_last = D->last_dims.has_data(); + NVTE_CHECK(has_first || D->all_same_first_dim(), + "GroupedTensor D is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || D->all_same_last_dim(), + "GroupedTensor D is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? static_cast(D->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; + const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); + + return {first_ptr, last_ptr, + C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) + : nullptr, + uniform_first, uniform_last}; + } +}; + +// Helper functions to compute average dimensions from logical_shape for heuristics +// These are hints for cuBLASLt algorithm selection, don't need to be exact +inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { + // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) + // In both cases, dividing by num_tensors gives the average + return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); +} + +inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { + if (t->all_same_last_dim()) { + // logical_shape[1] is the common N + return static_cast(t->logical_shape.data[1]); + } + // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. + return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); +} + +// Constants for grouped GEMM workspace (declared early for use in helpers) +static constexpr size_t kGroupedGemmAlignment = 256; +static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB + +// Workspace layout for grouped GEMM +struct GroupedGemmSetupWorkspace { + void **A_ptrs; + void **B_ptrs; + void **C_ptrs; + void **D_ptrs; + float **alpha_ptrs; + float **beta_ptrs; + void ** + a_scale_inv_ptrs; // Per-tensor FP8 scale pointers for A (float* for tensor scaling, E8M0* for MXFP8) + void ** + b_scale_inv_ptrs; // Per-tensor FP8 scale pointers for B (float* for tensor scaling, E8M0* for MXFP8) + // Storage dimensions for cuBLAS matrix layouts + int *a_rows; + int *a_cols; + int *b_rows; + int *b_cols; + int *d_rows; // M (first dim) - also used for C + int *d_cols; // N (last dim) - also used for C + + // Initialize from workspace buffer + // Layout: all pointer arrays first (16-byte aligned for cuBLAS), then int arrays + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { + GroupedGemmSetupWorkspace ws; + size_t offset = 0; + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + constexpr size_t kPtrAlignment = 16; // cuBLAS requires 16-byte alignment for pointer arrays + + // Helper to align offset to kPtrAlignment + auto align_offset = [&]() { + offset = (offset + kPtrAlignment - 1) / kPtrAlignment * kPtrAlignment; + }; + + // Pointer arrays first (all 16-byte aligned for cuBLAS grouped GEMM) + align_offset(); + ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + align_offset(); + ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + align_offset(); + ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + align_offset(); + ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + align_offset(); + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + align_offset(); + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + align_offset(); + ws.a_scale_inv_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + align_offset(); + ws.b_scale_inv_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + + // Int arrays for storage dimensions (4-byte aligned is fine) + align_offset(); + ws.a_rows = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.a_cols = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.b_rows = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.b_cols = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.d_rows = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.d_cols = reinterpret_cast(setup_ws_ptr + offset); + + return ws; + } + + // Calculate required size for setup workspace + static size_t required_setup_size(size_t num_tensors, size_t alignment) { + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + constexpr size_t kPtrAlignment = 16; // Must match from_buffers + + // Layout: 8 ptr arrays (each 16-byte aligned), then 6 int arrays + // Each ptr array takes ptr_size bytes but needs to start at 16-byte boundary + auto aligned_ptr_size = ((ptr_size + kPtrAlignment - 1) / kPtrAlignment) * kPtrAlignment; + size_t size = 8 * aligned_ptr_size + 6 * int_size; + size = ((size + alignment - 1) / alignment) * alignment; + return size; + } +}; + +inline size_t validate_grouped_gemm_inputs( + size_t num_tensors, std::initializer_list inputs, + const transformer_engine::Tensor *alpha_tensor, const transformer_engine::Tensor *beta_tensor) { + NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: number of tensors must be at least 1"); + for (const auto *tensor : inputs) { + NVTE_CHECK(tensor->num_tensors == num_tensors, + "Grouped GEMM: inputs must have the same number of tensors"); + } + + const size_t alpha_numel = alpha_tensor->data.numel(); + const size_t beta_numel = beta_tensor->data.numel(); + NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, + ") elements, got ", alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, + ") elements, got ", beta_numel); + + auto is_supported_input_dtype = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2 || + dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16; + }; + bool dtype_ok = true; + for (const auto *tensor : inputs) { + dtype_ok = dtype_ok && is_supported_input_dtype(tensor->dtype()); + } + NVTE_CHECK(dtype_ok, "Grouped GEMM inputs must be FP8, BF16, or FP16."); + for (const auto *tensor : inputs) { + NVTE_CHECK(tensor->has_data() || tensor->has_columnwise_data(), + "Grouped GEMM: input tensor is missing both row-wise and column-wise data"); + } + + // Cross-operand consistency across all inputs. + const auto *ref = *inputs.begin(); + const bool ref_is_fp8 = is_fp8_dtype(ref->dtype()); + const bool ref_is_mxfp8 = transformer_engine::is_mxfp_scaling(ref->scaling_mode); + for (const auto *tensor : inputs) { + NVTE_CHECK(is_fp8_dtype(tensor->dtype()) == ref_is_fp8, + "Grouped GEMM: A and B must both be FP8 or both be non-FP8."); + NVTE_CHECK(transformer_engine::is_mxfp_scaling(tensor->scaling_mode) == ref_is_mxfp8, + "Grouped GEMM: A and B must both use MXFP8 scaling or both use tensor scaling."); + if (ref_is_mxfp8) { + NVTE_CHECK(tensor->with_gemm_swizzled_scales, + "MXFP8 grouped GEMM: scales must be swizzled for GEMM."); + } + } + return num_tensors; +} + +inline void validate_grouped_gemm_outputs( + size_t num_tensors, std::initializer_list outputs) { + auto is_output_dtype = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16 || + dtype == transformer_engine::DType::kFloat32; + }; + for (const auto *tensor : outputs) { + if (tensor == nullptr) { + continue; + } + NVTE_CHECK(tensor->num_tensors == num_tensors, + "Grouped GEMM: outputs must have the same number of tensors as inputs"); + NVTE_CHECK(is_output_dtype(tensor->dtype()), + "Grouped GEMM: outputs must be BF16, FP16, or FP32."); + } +} + +inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { + return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); +} + +inline void check_grouped_gemm_requirements(const char *api_name) { + const int current_device = transformer_engine::cuda::current_device(); + NVTE_CHECK(transformer_engine::cuda::sm_arch(current_device) >= 100, api_name, + " requires Blackwell (SM100) or newer architecture."); + NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 130200, api_name, + " requires cuBLAS 13.2+, but run-time cuBLAS version is ", + transformer_engine::cuda::cublas_version()); +} + +inline transformer_engine::GroupedMatmulConfig parse_grouped_gemm_config( + NVTEGroupedMatmulConfig config) { + transformer_engine::GroupedMatmulConfig config_; + if (config != nullptr) { + config_ = *reinterpret_cast(config); + } + return config_; +} + +// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. +// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and +// fallback to column-wise data when row-wise is absent. +// Contains all information needed for GEMM setup - shape already accounts for storage layout. +struct GroupedOperandSelection { + TensorShapeInfo shape; // Shape info with dims already swapped for columnwise if needed + char *dptr = nullptr; + void *scale_inv = nullptr; // Contiguous array of scales (input) + transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + bool with_gemm_swizzled_scales = false; + bool trans = false; +}; + +constexpr int kMaxTensorsPerKernel = 64; +// Arguments for the grouped GEMM kernel that operates on multiple output tensors. +struct MultiTensorGroupGemmOutputArgs { + void *data_ptrs[kMaxTensorsPerKernel]; + int rows[kMaxTensorsPerKernel]; + int cols[kMaxTensorsPerKernel]; +}; + +// Arguments for the grouped GEMM kernel that operates on multiple inputA tensors. +struct MultiTensorGroupGemmInputArgs { + void *data_ptrs[kMaxTensorsPerKernel]; + void *scale_inv_ptrs[kMaxTensorsPerKernel]; + int rows[kMaxTensorsPerKernel]; + int cols[kMaxTensorsPerKernel]; +}; +struct MultiTensorListInfo { + bool all_row = true; + bool all_col = true; + transformer_engine::DType row_dtype = transformer_engine::DType::kNumTypes; + transformer_engine::DType col_dtype = transformer_engine::DType::kNumTypes; + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + bool with_gemm_swizzled_scales = false; +}; + +struct OperandStorageChoice { + bool use_rowwise = true; + bool swap_dims = true; + bool trans = false; +}; + +inline OperandStorageChoice choose_grouped_operand_storage(bool trans, bool is_A, bool is_mxfp8, + bool is_fp8, bool non_tn_fp8_ok, + bool has_row, bool has_col, + const char *name) { + NVTE_CHECK(has_row || has_col, "Grouped GEMM: ", name, + " is missing both row-wise and column-wise data"); + if (is_mxfp8) { + if (is_A) { + if (trans) { + NVTE_CHECK(has_row, "Grouped GEMM: MXFP8 transposed ", name, " is missing row-wise data"); + return {true, true, trans}; + } + NVTE_CHECK(has_col, "Grouped GEMM: MXFP8 non-transposed ", name, + " is missing column-wise data"); + return {false, false, trans}; + } + if (trans) { + NVTE_CHECK(has_col, "Grouped GEMM: MXFP8 transposed ", name, " is missing column-wise data"); + return {false, false, trans}; + } + NVTE_CHECK(has_row, "Grouped GEMM: MXFP8 non-transposed ", name, " is missing row-wise data"); + return {true, true, trans}; + } + + // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. + if (is_fp8 && !non_tn_fp8_ok) { + if (is_A && !trans) { + NVTE_CHECK(has_col, "Grouped GEMM: ", name, + " is missing column-wise data needed for FP8 TN layout"); + return {false, true, true}; + } + if (!is_A && trans) { + NVTE_CHECK(has_col, "Grouped GEMM: ", name, + " is missing column-wise data needed for FP8 TN layout"); + return {false, true, false}; + } + } + + // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). + if (!has_row && has_col) { + NVTE_CHECK(!is_fp8 || non_tn_fp8_ok, + "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose config."); + return {false, true, !trans}; + } + + NVTE_CHECK(has_row, "Grouped GEMM: ", name, " is missing row-wise data"); + return {true, true, trans}; +} + +// Build Kernel Arguments detailing out addresses and other metadata for list of C/D tensors +// passed to the grouped GEMM kernel. Use-case: C/D --> List of wgrads for experts in MOE +inline MultiTensorGroupGemmOutputArgs build_grouped_gemm_multi_out_args( + const NVTETensor *tensor_list, size_t list_size, size_t expected_num_tensors, + transformer_engine::DType expected_dtype, const char *name) { + MultiTensorGroupGemmOutputArgs args{}; + if (list_size == 0) { + NVTE_CHECK(tensor_list == nullptr, "Grouped GEMM: ", name, "_list provided with num_", name, + "_tensors=0"); + return args; + } + NVTE_CHECK(tensor_list != nullptr, "Grouped GEMM: ", name, "_list is null but num_", name, + "_tensors=", list_size); + NVTE_CHECK(list_size == expected_num_tensors, "Grouped GEMM: ", name, + "_list must have num_tensors (", expected_num_tensors, ") entries, got ", list_size); + NVTE_CHECK(list_size <= static_cast(kMaxTensorsPerKernel), "Grouped GEMM: ", name, + "_list supports up to ", kMaxTensorsPerKernel, " tensors per kernel, got ", list_size); + + for (size_t i = 0; i < list_size; ++i) { + const transformer_engine::Tensor *t = + transformer_engine::convertNVTETensorCheck(tensor_list[i]); + NVTE_CHECK(t->has_data(), "Grouped GEMM: ", name, "_list tensor ", i, " has no data"); + NVTE_CHECK(t->dtype() == expected_dtype, "Grouped GEMM: ", name, "_list tensor ", i, + " dtype mismatch. Expected ", transformer_engine::to_string(expected_dtype), " got ", + transformer_engine::to_string(t->dtype())); + const auto &shape = t->shape(); + NVTE_CHECK(shape.size() == 2, "Grouped GEMM: ", name, "_list tensor ", i, " must be 2D."); + args.data_ptrs[i] = t->data.dptr; + args.rows[i] = static_cast(shape[1]); + args.cols[i] = static_cast(shape[0]); + } + return args; +} + +// Build Kernel Arguments detailing out addresses and other metadata for list of A tensors +// passed to the grouped GEMM kernel. Use-case: A --> List of Expert weights +inline MultiTensorGroupGemmInputArgs build_grouped_gemm_multi_inputA_args( + const NVTETensor *tensor_list, size_t list_size, bool use_rowwise, bool is_fp8, + int64_t *avg_first_dim, int64_t *avg_last_dim, const char *name) { + using namespace transformer_engine; + MultiTensorGroupGemmInputArgs args{}; + *avg_first_dim = 0; + *avg_last_dim = 0; + if (list_size == 0) { + return args; + } + for (size_t i = 0; i < list_size; ++i) { + const transformer_engine::Tensor *t = + transformer_engine::convertNVTETensorCheck(tensor_list[i]); + const transformer_engine::SimpleTensor &data = use_rowwise ? t->data : t->columnwise_data; + const transformer_engine::SimpleTensor &scale_inv = + use_rowwise ? t->scale_inv : t->columnwise_scale_inv; + NVTE_CHECK(data.has_data(), "Grouped GEMM: ", name, "_list tensor ", i, + " is missing required data."); + NVTE_CHECK(data.shape.size() == 2, "Grouped GEMM: ", name, "_list tensor ", i, " must be 2D."); + args.data_ptrs[i] = data.dptr; + args.rows[i] = static_cast(data.shape[1]); + args.cols[i] = static_cast(data.shape[0]); + *avg_first_dim += static_cast(data.shape[0]); + *avg_last_dim += static_cast(data.shape[1]); + + if (is_fp8) { + NVTE_CHECK(scale_inv.has_data(), "Grouped GEMM: ", name, "_list tensor ", i, + " requires scale_inv for FP8."); + args.scale_inv_ptrs[i] = scale_inv.dptr; + } else { + args.scale_inv_ptrs[i] = nullptr; + } + } + *avg_first_dim /= static_cast(list_size); + *avg_last_dim /= static_cast(list_size); + return args; +} + +inline MultiTensorListInfo validate_grouped_gemm_multi_inputA_list(const NVTETensor *tensor_list, + size_t list_size, + size_t expected_num_tensors, + const char *name) { + using namespace transformer_engine; + MultiTensorListInfo info{}; + if (list_size == 0) { + NVTE_CHECK(tensor_list == nullptr, "Grouped GEMM: ", name, "_list provided with num_", name, + "_tensors=0"); + return info; + } + NVTE_CHECK(tensor_list != nullptr, "Grouped GEMM: ", name, "_list is null but num_", name, + "_tensors=", list_size); + NVTE_CHECK(list_size == expected_num_tensors, "Grouped GEMM: ", name, + "_list must have num_tensors (", expected_num_tensors, ") entries, got ", list_size); + NVTE_CHECK(list_size <= static_cast(kMaxTensorsPerKernel), "Grouped GEMM: ", name, + "_list supports up to ", kMaxTensorsPerKernel, " tensors per kernel, got ", list_size); + + const transformer_engine::Tensor *t0 = transformer_engine::convertNVTETensorCheck(tensor_list[0]); + info.scaling_mode = t0->scaling_mode; + info.with_gemm_swizzled_scales = t0->with_gemm_swizzled_scales; + const bool mxfp8 = transformer_engine::is_mxfp_scaling(info.scaling_mode); + NVTE_CHECK(info.scaling_mode == NVTE_DELAYED_TENSOR_SCALING || mxfp8, + "Grouped GEMM: input list only supports tensor scaling or MXFP8."); + + for (size_t i = 0; i < list_size; ++i) { + const transformer_engine::Tensor *t = + transformer_engine::convertNVTETensorCheck(tensor_list[i]); + NVTE_CHECK(t->scaling_mode == info.scaling_mode, "Grouped GEMM: ", name, + "_list tensors must share the same scaling mode."); + NVTE_CHECK(t->with_gemm_swizzled_scales == info.with_gemm_swizzled_scales, + "Grouped GEMM: ", name, "_list tensors must share GEMM swizzled scale state."); + + if (t->has_data()) { + if (info.row_dtype == DType::kNumTypes) { + info.row_dtype = t->data.dtype; + } + // Check all tensors have the same dtype + NVTE_CHECK(t->data.dtype == info.row_dtype, "Grouped GEMM: ", name, + "_list rowwise dtypes must match."); + } else { + // All tensors must have either data or columnwise data + info.all_row = false; + } + + if (t->has_columnwise_data()) { + if (info.col_dtype == DType::kNumTypes) { + info.col_dtype = t->columnwise_data.dtype; + } + NVTE_CHECK(t->columnwise_data.dtype == info.col_dtype, "Grouped GEMM: ", name, + "_list columnwise dtypes must match."); + } else { + // All tensors must have either data or columnwise data + info.all_col = false; + } + } + + return info; +} + +// Helper to create TensorShapeInfo from a GroupedTensor, optionally swapping first/last dims. +// When swap_dims=true, first_dims and last_dims are swapped to account for columnwise storage. +// Note: tensor_offsets are the same for rowwise and columnwise data (same element count per tensor). +inline TensorShapeInfo create_shape_info(const transformer_engine::GroupedTensor *t, + bool swap_dims) { + const bool has_first = t->first_dims.has_data(); + const bool has_last = t->last_dims.has_data(); + NVTE_CHECK(has_first || t->all_same_first_dim(), + "GroupedTensor is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || t->all_same_last_dim(), + "GroupedTensor is missing last_dims for varying shapes"); + + const int64_t *first_ptr = has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; + const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); + + const int64_t *offsets_ptr = + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr; + + if (swap_dims) { + // Swap first/last to account for columnwise (transposed) storage + return {last_ptr, first_ptr, offsets_ptr, uniform_last, uniform_first}; + } + return {first_ptr, last_ptr, offsets_ptr, uniform_first, uniform_last}; +} + +inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, + bool trans, bool is_A) { + using namespace transformer_engine; + const bool has_row = t->has_data(); + const bool has_col = t->has_columnwise_data(); + NVTE_CHECK(has_row || has_col, + "Grouped GEMM operand is missing both row-wise and column-wise data"); + + const auto sm = t->scaling_mode; + const bool mxfp8 = is_mxfp_scaling(sm); + + // Validate scaling mode + NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING || mxfp8, + "Grouped GEMM is only supported with bf16, fp8 tensor scaling and MXFP8"); + + const DType row_dtype = t->data.dtype; + const DType col_dtype = t->columnwise_data.dtype; + GroupedOperandSelection sel{}; + sel.trans = trans; + sel.scaling_mode = sm; + sel.with_gemm_swizzled_scales = t->with_gemm_swizzled_scales; + + const DType rep_dtype = has_row ? row_dtype : col_dtype; + const bool is_fp8 = is_fp8_dtype(rep_dtype); + const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); + + // Helper to select columnwise storage. + // swap_dims=true (default): swap first/last dims in shape info (used when columnwise == transposed). + // swap_dims=false: keep original dims (MXFP8: columnwise data has different scale direction, + // but the logical matrix shape and transpose flag remain unchanged). + auto use_columnwise = [&](bool swap_dims = true) { + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.scale_inv = t->columnwise_scale_inv.dptr; + sel.dtype = col_dtype; + sel.shape = create_shape_info(t, swap_dims); + }; + + // Helper to select row-wise storage + auto use_rowwise = [&]() { + sel.dptr = static_cast(t->data.dptr); + sel.scale_inv = t->scale_inv.dptr; + sel.dtype = row_dtype; + sel.shape = create_shape_info(t, /*swap_dims=*/false); + }; + + const auto choice = choose_grouped_operand_storage(trans, is_A, mxfp8, is_fp8, non_tn_fp8_ok, + has_row, has_col, is_A ? "A" : "B"); + sel.trans = choice.trans; + if (choice.use_rowwise) { + use_rowwise(); + } else { + use_columnwise(choice.swap_dims); + } + return sel; +} + +inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, + const char *workspace_name) { + NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); + const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); + NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, + ". Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); + return ws->data.dptr; +} + +inline void init_matrix_layouts( + cublasLtMatrixLayoutOpaque_t &descA, cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, cublasLtMatrixLayoutOpaque_t &descD, + const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, transformer_engine::DType d_dtype, size_t num_tensors) { + const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); + const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); + const cudaDataType_t D_type = get_cuda_dtype(d_dtype); + + // Storage dimensions computed by kernel, leading dimension = rows + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, ws.a_rows, + ws.a_cols, ws.a_rows)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, ws.b_rows, + ws.b_cols, ws.b_rows)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.d_rows, + ws.d_cols, ws.d_rows)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.d_rows, + ws.d_cols, ws.d_rows)); +} + +inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, + cublasOperation_t op_B, bool use_fp8, bool use_split_accumulator) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, + sizeof(op_A))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, + sizeof(op_B))); + + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode))); + + int64_t alphabeta_batch_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); + + // Fast accumulation is only supported for FP8 (mirrors non-grouped GEMM logic). + int8_t fastAccuMode = use_split_accumulator ? 0 : static_cast(use_fp8); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &fastAccuMode, sizeof(fastAccuMode))); +} + +// Configures cuBLAS for MXFP8 grouped GEMM: sets VEC32_UE8M0 scale mode and scale pointers +// for both A and B. +inline void set_mxfp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, + void **a_scale_inv_ptrs, void **b_scale_inv_ptrs) { +#if CUBLAS_VERSION >= CUBLAS_MXFP8_GROUPED_GEMM_VERSION + NVTE_CHECK(transformer_engine::cuda::cublas_version() >= CUBLAS_MXFP8_GROUPED_GEMM_VERSION, + "MXFP8 grouped GEMM requires cuBLAS ", CUBLAS_MXFP8_GROUPED_GEMM_VERSION, + "+, but run-time cuBLAS version is ", transformer_engine::cuda::cublas_version()); + const cublasLtMatmulMatrixScale_t scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, + &scale_mode, sizeof(scale_mode))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, + &scale_mode, sizeof(scale_mode))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &a_scale_inv_ptrs, sizeof(a_scale_inv_ptrs))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &b_scale_inv_ptrs, sizeof(b_scale_inv_ptrs))); +#else + NVTE_CHECK(false, "MXFP8 grouped GEMM requires cuBLAS ", CUBLAS_MXFP8_GROUPED_GEMM_VERSION, + "+, but compile-time cuBLAS version is ", CUBLAS_VERSION); +#endif // CUBLAS_VERSION >= CUBLAS_MXFP8_GROUPED_GEMM_VERSION +} + +// Configures cuBLAS for tensor-scaling FP8 grouped GEMM: sets PER_BATCH_SCALAR_32F scale mode +// and scale pointers for A and B. Both operands are guaranteed FP8 by the caller. +inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, void **a_scale_inv_ptrs, + void **b_scale_inv_ptrs) { + const cublasLtMatmulMatrixScale_t scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_PER_BATCH_SCALAR_32F; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, + &scale_mode, sizeof(scale_mode))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &a_scale_inv_ptrs, sizeof(a_scale_inv_ptrs))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, + &scale_mode, sizeof(scale_mode))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &b_scale_inv_ptrs, sizeof(b_scale_inv_ptrs))); +} +inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, + cublasLtMatmulDescOpaque_t &matmulDesc, + cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, + int64_t avg_m, int64_t avg_n, int64_t avg_k) { + cublasLtMatmulPreferenceOpaque_t preference; + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); + NVTE_CHECK_CUBLAS( + cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); + + cublasLtMatmulHeuristicResult_t heuristicResult; + int returnedResults = 0; + auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, + &preference, 1, &heuristicResult, &returnedResults); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, + "Unable to find suitable cuBLAS grouped GEMM algorithm"); + NVTE_CHECK_CUBLAS(status); + NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); + return heuristicResult.algo; +} + +struct GroupedGemmWorkspace { + GroupedGemmSetupWorkspace setup_workspace; + void *cublas_workspace_ptr = nullptr; + size_t num_tensors = 0; +}; + +inline GroupedGemmWorkspace setup_grouped_gemm_workspace(transformer_engine::Tensor *wspace_setup, + transformer_engine::Tensor *wspace_cublas, + size_t num_tensors) { + const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); + const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; + void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, + "Grouped GEMM setup workspace"); + void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, + "Grouped GEMM cuBLAS workspace"); + auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( + static_cast(setup_workspace_ptr), num_tensors); + return {std::move(setup_workspace), cublas_workspace_ptr, num_tensors}; +} + +inline void execute_grouped_gemm(const GroupedGemmSetupWorkspace &setup_workspace, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, + transformer_engine::DType d_dtype, size_t num_tensors, + bool use_split_accumulator, bool use_fp8, int64_t avg_m_val, + int64_t avg_n_val, int64_t avg_k_val, void *cublas_workspace_ptr, + cudaStream_t stream) { + using cublasHandleManager = + transformer_engine::detail::HandleManager; + cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); + + cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + + cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, d_dtype, + num_tensors); + + cublasLtMatmulDescOpaque_t matmulDesc; + init_matmul_desc(matmulDesc, op_A, op_B, use_fp8, use_split_accumulator); + if (transformer_engine::is_mxfp_scaling(A_sel.scaling_mode)) { + set_mxfp8_scale_pointers(matmulDesc, setup_workspace.a_scale_inv_ptrs, + setup_workspace.b_scale_inv_ptrs); + } else if (use_fp8) { + set_fp8_scale_pointers(matmulDesc, setup_workspace.a_scale_inv_ptrs, + setup_workspace.b_scale_inv_ptrs); + } + + cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, + descD, avg_m_val, avg_n_val, avg_k_val); + + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, + setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, + setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, + setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, + kGroupedGemmCublasWorkspaceSize, stream)); +} + +// Device helper: compute the element offset for tensor `idx` given shape metadata. +// Three cases: +// 1. Explicit per-tensor offset array provided → use it directly. +// 2. Per-tensor first/last dims provided but no offsets → cumulative sum of (first*last) products. +// 3. Fully uniform shapes → idx * uniform_first * uniform_last. +__forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta, + size_t idx) { + if (meta.offsets) { + return meta.offsets[idx]; + } else if (meta.first_dims != nullptr || meta.last_dims != nullptr) { + // offset[i] = sum_{j < i} (first_dims[j] * last_dims[j]) + int64_t cumsum = 0; + for (size_t i = 0; i < idx; i++) { + int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first; + int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last; + cumsum += f * l; + } + return cumsum; + } else { + return static_cast(idx) * meta.uniform_first * meta.uniform_last; + } +} + +// Kernel that performs bias addition to the Grouped GEMM output tensors. +// Bias itself is a grouped tensor with the collections of same number of tensors +// as the output tensors. +template +__global__ void grouped_bias_add_kernel(char *d_base, const char *bias_base, TensorShapeInfo d_meta, + TensorShapeInfo bias_meta, size_t num_tensors) { + const size_t tensor_idx = blockIdx.x; + if (tensor_idx >= num_tensors) return; + + const int64_t m = d_meta.first_dims ? d_meta.first_dims[tensor_idx] : d_meta.uniform_first; + const int64_t n = d_meta.last_dims ? d_meta.last_dims[tensor_idx] : d_meta.uniform_last; + if (m == 0 || n == 0) return; + + const int64_t d_offset = compute_grouped_tensor_offset(d_meta, tensor_idx); + const int64_t bias_offset = compute_grouped_tensor_offset(bias_meta, tensor_idx); + + auto *d_ptr = reinterpret_cast(d_base + d_offset * sizeof(T)); + const auto *bias_ptr = reinterpret_cast(bias_base + bias_offset * sizeof(T)); + + const int64_t elements = m * n; + const int64_t vec_count = elements / kVec; + using VecStorage = transformer_engine::VectorizedStorage; + using VecType = typename VecStorage::LType; + transformer_engine::VectorizedLoader loader(d_ptr, elements); + transformer_engine::VectorizedStorer storer(d_ptr, elements); + const int64_t vec_id = static_cast(blockIdx.y) * blockDim.x + threadIdx.x; + if (vec_id >= vec_count) return; + const int64_t vec_start = vec_id * kVec; + const int64_t col = vec_start % n; + loader.load(vec_id, elements); + const auto *b_vec = reinterpret_cast(bias_ptr + col); + VecStorage b_in; + b_in.scratch_.aligned = *b_vec; +#pragma unroll + for (int i = 0; i < kVec; ++i) { + storer.separate()[i] = loader.separate()[i] + b_in.scratch_.separate[i]; + } + storer.store(vec_id, elements); +} + +// Single kernel that sets up all GEMM parameters. +// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions, +// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. +// We bridge the mismatch on GPU by computing per-group pointers and storage dims in one kernel. +__global__ void setup_grouped_gemm_kernel( + // Output arrays + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *a_rows, int *a_cols, + int *b_rows, int *b_cols, int *d_rows, int *d_cols, float **alpha_ptrs, float **beta_ptrs, + void **a_scale_inv_ptrs, void **b_scale_inv_ptrs, + // Inputs + char *a_base, char *b_base, char *c_base, char *d_base, TensorShapeInfo A_meta, + TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, size_t a_elem_size, + size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, float *beta_ptr, + // Scale inputs: for tensor scaling, pass float* and set mxfp8_base to nullptr + // For MXFP8, pass nullptr for tensor_scale and set mxfp8_base + float *a_scale_base, float *b_scale_base, NVTEScalingMode scaling_mode, size_t num_tensors, + MultiTensorGroupGemmInputArgs a_multi_tensor_args, + MultiTensorGroupGemmOutputArgs c_multi_tensor_args, + MultiTensorGroupGemmOutputArgs d_multi_tensor_args) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_tensors) return; + + // Get dimensions for this tensor (from array or uniform value) + const bool has_a_multi_tensor = (a_base == nullptr); + const bool has_c_multi_tensor = (c_base == nullptr); + const bool has_d_multi_tensor = (d_base == nullptr); + int64_t a_first = 0; + int64_t a_last = 0; + if (has_a_multi_tensor) { + a_first = static_cast(a_multi_tensor_args.cols[idx]); + a_last = static_cast(a_multi_tensor_args.rows[idx]); + } else { + a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; + a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; + } + int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first; + int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; + int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first; + int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last; + + // Compute offsets (from explicit array, cumulative from per-tensor dims, or uniform) + int64_t a_offset = has_a_multi_tensor ? 0 : compute_grouped_tensor_offset(A_meta, idx); + int64_t b_offset = compute_grouped_tensor_offset(B_meta, idx); + int64_t c_offset = compute_grouped_tensor_offset(C_meta, idx); + int64_t d_offset = compute_grouped_tensor_offset(D_meta, idx); + + // Compute data pointers + A_ptrs[idx] = + has_a_multi_tensor ? a_multi_tensor_args.data_ptrs[idx] : (a_base + a_offset * a_elem_size); + B_ptrs[idx] = b_base + b_offset * b_elem_size; + C_ptrs[idx] = + has_c_multi_tensor ? c_multi_tensor_args.data_ptrs[idx] : (c_base + c_offset * c_elem_size); + D_ptrs[idx] = + has_d_multi_tensor ? d_multi_tensor_args.data_ptrs[idx] : (d_base + d_offset * d_elem_size); + + // Compute storage dimensions for cuBLAS matrix layouts. + // For INPUTS (A, B): Row-wise storage is seen as transposed column-major by cuBLAS, + // so rows=last, cols=first. For columnwise, dims are already swapped. + a_rows[idx] = static_cast(a_last); + a_cols[idx] = static_cast(a_first); + b_rows[idx] = static_cast(b_last); + b_cols[idx] = static_cast(b_first); + if (has_d_multi_tensor) { + d_rows[idx] = d_multi_tensor_args.rows[idx]; + d_cols[idx] = d_multi_tensor_args.cols[idx]; + } else { + d_rows[idx] = static_cast(d_last); + d_cols[idx] = static_cast(d_first); + } + + // Fill alpha/beta pointers (per-matrix) + alpha_ptrs[idx] = alpha_ptr + idx; + beta_ptrs[idx] = beta_ptr + idx; + + // Fill scale pointers (per-matrix). + // The interpretation of the scale buffers depends on the shared scaling recipe: + // NVTE_MXFP8_1D_SCALING : E8M0 byte stream; offset = data_offset / 32 elements + // otherwise : one float per tensor, indexed by tensor index + if (a_scale_base) { + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + a_scale_inv_ptrs[idx] = reinterpret_cast( + static_cast(static_cast(a_scale_base)) + a_offset / 32); + } else { + a_scale_inv_ptrs[idx] = static_cast(a_scale_base) + idx; + } + } else { + a_scale_inv_ptrs[idx] = a_multi_tensor_args.scale_inv_ptrs[idx]; + } + if (b_scale_base) { + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + b_scale_inv_ptrs[idx] = reinterpret_cast( + static_cast(static_cast(b_scale_base)) + b_offset / 32); + } else { + b_scale_inv_ptrs[idx] = static_cast(b_scale_base) + idx; + } + } +} + +// Launch the setup kernel to populate workspace arrays +inline void launch_grouped_gemm_setup( + const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream, + const MultiTensorGroupGemmInputArgs &a_multi_tensor_args, const NVTETensor *C_list, + const NVTETensor *D_list, char *a_base, transformer_engine::DType c_dtype, + transformer_engine::DType d_dtype) { + // Use shape info from selection (already accounts for columnwise dimension swap) + TensorShapeInfo A_meta = A_sel.shape; + TensorShapeInfo B_meta = B_sel.shape; + TensorShapeInfo C_meta{}; + TensorShapeInfo D_meta{}; + + const bool has_d_multi_tensor = (D_list != nullptr); + const bool has_c_multi_tensor = (C_list != nullptr) || has_d_multi_tensor; + MultiTensorGroupGemmOutputArgs c_multi_tensor_args{}; + MultiTensorGroupGemmOutputArgs d_multi_tensor_args{}; + if (has_d_multi_tensor) { + d_multi_tensor_args = + build_grouped_gemm_multi_out_args(D_list, num_tensors, num_tensors, d_dtype, "D"); + } + if (C_list != nullptr) { + c_multi_tensor_args = + build_grouped_gemm_multi_out_args(C_list, num_tensors, num_tensors, d_dtype, "C"); + } else if (has_d_multi_tensor) { + c_multi_tensor_args = d_multi_tensor_args; + } + + char *c_base = nullptr; + char *d_base = nullptr; + + if (!has_c_multi_tensor) { + NVTE_CHECK(C != nullptr && D != nullptr, + "Grouped GEMM: C/D grouped tensors are required when no C list is provided"); + C_meta = TensorShapeInfo::create_shape_info_for_C(C, D); + c_base = static_cast(C->data.dptr); + } + if (!has_d_multi_tensor) { + NVTE_CHECK(D != nullptr, + "Grouped GEMM: D grouped tensor is required when no D list is provided"); + D_meta = TensorShapeInfo::from_tensor(D); + d_base = static_cast(D->data.dptr); + } + + const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); + const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + const size_t c_elem_size = transformer_engine::typeToSize(c_dtype); + const size_t d_elem_size = transformer_engine::typeToSize(d_dtype); + + const int threads_per_block = 256; + const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; + + // A and B share the same scaling recipe (validated in validate_grouped_gemm_inputs). + // Pass scale buffers as void* and let the kernel interpret them via scaling_mode. + setup_grouped_gemm_kernel<<>>( + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols, + ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, ws.a_scale_inv_ptrs, ws.b_scale_inv_ptrs, + A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, + b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), reinterpret_cast(A_sel.scale_inv), + reinterpret_cast(B_sel.scale_inv), A_sel.scaling_mode, num_tensors, + a_multi_tensor_args, c_multi_tensor_args, d_multi_tensor_args); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace + +size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors) { + NVTE_API_CALL(nvte_get_grouped_gemm_setup_workspace_size); + return grouped_gemm_setup_workspace_size(num_tensors); +} + +void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, + const NVTETensor beta, NVTETensor workspace_setup, + NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_grouped_gemm); + using namespace transformer_engine; + + // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.2+ + check_grouped_gemm_requirements("nvte_grouped_gemm"); + + // Convert to internal types + const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); + const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); + const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL + GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); + const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); + const Tensor *beta_tensor = convertNVTETensorCheck(beta); + Tensor *wspace_setup = convertNVTETensor(workspace_setup); + Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); + + // Parse config (if provided) + GroupedMatmulConfig config_ = parse_grouped_gemm_config(config); + + // Validate inputs and outputs. + const size_t num_tensors = validate_grouped_gemm_inputs(inputA->num_tensors, {inputA, inputB}, + alpha_tensor, beta_tensor); + validate_grouped_gemm_outputs(num_tensors, {inputC_raw, outputD}); + + // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) + const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; + // num_tensors validated above. + // Select operand storage (row-wise vs column-wise) and adjust transpose flags to + // mirror the non-grouped GEMM logic for FP8 layout constraints. + auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); + auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); + + // Workspaces: setup (pointer arrays) and cuBLAS + auto workspace = setup_grouped_gemm_workspace(wspace_setup, wspace_cublas, num_tensors); + + MultiTensorGroupGemmInputArgs a_multi_tensor_args{}; + launch_grouped_gemm_setup(workspace.setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, + beta_tensor, num_tensors, stream, a_multi_tensor_args, + /*C_list=*/nullptr, /*D_list=*/nullptr, A_sel.dptr, inputC->dtype(), + outputD->dtype()); + + // Compute average dimensions for heuristics + // K dimension: if transa, K is A's first dim; if not, K is A's last dim + // Use original inputA and transa for heuristics (not modified A_sel.trans) + int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD)); + int64_t avg_n_val = config_.avg_n.value_or(compute_avg_last_dim(outputD)); + int64_t avg_k_val = + config_.avg_k.value_or(transa ? compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); + const bool use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); + execute_grouped_gemm(workspace.setup_workspace, A_sel, B_sel, outputD->dtype(), num_tensors, + config_.use_split_accumulator, use_fp8, avg_m_val, avg_n_val, avg_k_val, + workspace.cublas_workspace_ptr, stream); +} + +void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num_a_tensors, + int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, + const NVTETensor alpha, const NVTETensor beta, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEGroupedMatmulConfig config, cudaStream_t stream) { + NVTE_API_CALL(nvte_grouped_gemm_with_discrete_inputA); + using namespace transformer_engine; + + // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.2+ + check_grouped_gemm_requirements("nvte_grouped_gemm_with_discrete_inputA"); + + NVTE_CHECK(A_list != nullptr, "Grouped GEMM: A_list is null."); + NVTE_CHECK(num_a_tensors > 0, "Grouped GEMM: num_a_tensors must be > 0."); + + const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); + const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL + GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); + const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); + const Tensor *beta_tensor = convertNVTETensorCheck(beta); + Tensor *wspace_setup = convertNVTETensor(workspace_setup); + Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); + + // Parse config (if provided) + GroupedMatmulConfig config_ = parse_grouped_gemm_config(config); + + // Validate inputs and outputs. + const size_t num_tensors = + validate_grouped_gemm_inputs(num_a_tensors, {inputB}, alpha_tensor, beta_tensor); + validate_grouped_gemm_outputs(num_tensors, {inputC_raw, outputD}); + + // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) + const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; + + // Validate A list and selection + auto A_list_info = + validate_grouped_gemm_multi_inputA_list(A_list, num_a_tensors, num_tensors, "A"); + auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2 || + dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16; + }; + NVTE_CHECK(is_fp8_or_16bit(A_list_info.all_row ? A_list_info.row_dtype : A_list_info.col_dtype), + "Grouped GEMM: A_list tensors must be FP8, BF16, or FP16."); + + // Cross-operand consistency (mirrors validate_grouped_gemm_inputs). + const DType a_rep_dtype = A_list_info.all_row ? A_list_info.row_dtype : A_list_info.col_dtype; + NVTE_CHECK(is_fp8_dtype(a_rep_dtype) == is_fp8_dtype(inputB->dtype()), + "Grouped GEMM: A and B must both be FP8 or both be non-FP8."); + NVTE_CHECK(transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode) == + transformer_engine::is_mxfp_scaling(inputB->scaling_mode), + "Grouped GEMM: A and B must both use MXFP8 scaling or both use tensor scaling."); + if (transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode)) { + NVTE_CHECK(A_list_info.with_gemm_swizzled_scales, + "MXFP8 grouped GEMM: A scales must be swizzled for GEMM."); + NVTE_CHECK(inputB->with_gemm_swizzled_scales, + "MXFP8 grouped GEMM: B scales must be swizzled for GEMM."); + } + + // Select operand storage for B (row-wise vs column-wise) + auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); + + GroupedOperandSelection A_sel{}; + A_sel.scaling_mode = A_list_info.scaling_mode; + A_sel.with_gemm_swizzled_scales = A_list_info.with_gemm_swizzled_scales; + A_sel.trans = static_cast(transa); + + const DType rep_dtype = A_list_info.all_row ? A_list_info.row_dtype : A_list_info.col_dtype; + const bool is_fp8 = is_fp8_dtype(rep_dtype); + const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); + const bool mxfp8 = transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode); + + int64_t avg_first_dim = 0; + int64_t avg_last_dim = 0; + MultiTensorGroupGemmInputArgs a_multi_tensor_args{}; + + const auto choice = + choose_grouped_operand_storage(static_cast(transa), /*is_A=*/true, mxfp8, is_fp8, + non_tn_fp8_ok, A_list_info.all_row, A_list_info.all_col, "A"); + A_sel.trans = choice.trans; + if (choice.use_rowwise) { + NVTE_CHECK(A_list_info.all_row, "Grouped GEMM: A_list is missing row-wise data"); + A_sel.dtype = A_list_info.row_dtype; + a_multi_tensor_args = build_grouped_gemm_multi_inputA_args( + A_list, num_a_tensors, /*use_rowwise=*/true, is_fp8, &avg_first_dim, &avg_last_dim, "A"); + } else { + NVTE_CHECK(A_list_info.all_col, "Grouped GEMM: A_list is missing column-wise data"); + A_sel.dtype = A_list_info.col_dtype; + a_multi_tensor_args = build_grouped_gemm_multi_inputA_args( + A_list, num_a_tensors, /*use_rowwise=*/false, is_fp8, &avg_first_dim, &avg_last_dim, "A"); + } + + // For discrete A_list, scale pointers are per-tensor; use multi-tensor args. + // Base pointer is unused when providing per-tensor pointers. + A_sel.scale_inv = nullptr; + A_sel.dptr = nullptr; + + // Workspaces: setup (pointer arrays) and cuBLAS + auto workspace = setup_grouped_gemm_workspace(wspace_setup, wspace_cublas, num_tensors); + + launch_grouped_gemm_setup(workspace.setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, + beta_tensor, num_tensors, stream, a_multi_tensor_args, + /*C_list=*/nullptr, /*D_list=*/nullptr, nullptr, inputC->dtype(), + outputD->dtype()); + + // Compute average dimensions for heuristics + int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD)); + int64_t avg_n_val = + config_.avg_n.value_or(transb ? compute_avg_first_dim(inputB) : compute_avg_last_dim(inputB)); + int64_t avg_k_val = + config_.avg_k.value_or(static_cast(transa) ? avg_last_dim : avg_first_dim); + const bool use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); + execute_grouped_gemm(workspace.setup_workspace, A_sel, B_sel, outputD->dtype(), num_tensors, + config_.use_split_accumulator, use_fp8, avg_m_val, avg_n_val, avg_k_val, + workspace.cublas_workspace_ptr, stream); +} + +void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, + const NVTEGroupedTensor B, int transb, + const NVTETensor *C_list, size_t num_c_tensors, + NVTETensor *D_list, size_t num_d_tensors, + const NVTETensor alpha, const NVTETensor beta, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEGroupedMatmulConfig config, cudaStream_t stream) { + NVTE_API_CALL(nvte_grouped_gemm_with_discrete_out); + using namespace transformer_engine; + + // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.2+ + check_grouped_gemm_requirements("nvte_grouped_gemm_with_discrete_out"); + + NVTE_CHECK(D_list != nullptr, "Grouped GEMM: D_list is null."); + NVTE_CHECK(num_d_tensors > 0, "Grouped GEMM: num_d_tensors must be > 0."); + if (num_c_tensors > 0) { + NVTE_CHECK(C_list != nullptr, "Grouped GEMM: C_list is null but num_c_tensors > 0."); + } + + const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); + const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); + const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); + const Tensor *beta_tensor = convertNVTETensorCheck(beta); + Tensor *wspace_setup = convertNVTETensor(workspace_setup); + Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); + + const Tensor *d0 = convertNVTETensorCheck(D_list[0]); + const DType d_dtype = d0->dtype(); + + const size_t num_tensors = validate_grouped_gemm_inputs(inputA->num_tensors, {inputA, inputB}, + alpha_tensor, beta_tensor); + NVTE_CHECK(num_d_tensors == num_tensors, "Grouped GEMM: D_list must have num_tensors (", + num_tensors, ") entries, got ", num_d_tensors); + if (num_c_tensors > 0) { + NVTE_CHECK(num_c_tensors == num_tensors, "Grouped GEMM: C_list must have num_tensors (", + num_tensors, ") entries, got ", num_c_tensors); + } + auto is_output_dtype = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16 || + dtype == transformer_engine::DType::kFloat32; + }; + NVTE_CHECK(is_output_dtype(d_dtype), "Grouped GEMM: D must be BF16, FP16, or FP32."); + + // Parse config (if provided) + GroupedMatmulConfig config_ = parse_grouped_gemm_config(config); + + // Select operand storage (row-wise vs column-wise) and adjust transpose flags to + // mirror the non-grouped GEMM logic for FP8 layout constraints. + auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); + auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); + // Workspaces: setup (pointer arrays) and cuBLAS + auto workspace = setup_grouped_gemm_workspace(wspace_setup, wspace_cublas, num_tensors); + + MultiTensorGroupGemmInputArgs a_multi_tensor_args{}; + launch_grouped_gemm_setup(workspace.setup_workspace, A_sel, B_sel, /*C=*/nullptr, /*D=*/nullptr, + alpha_tensor, beta_tensor, num_tensors, stream, a_multi_tensor_args, + C_list, D_list, A_sel.dptr, d_dtype, d_dtype); + + // Compute average dimensions for heuristics + int64_t avg_m_val = + config_.avg_m.value_or(transa ? compute_avg_last_dim(inputA) : compute_avg_first_dim(inputA)); + int64_t avg_n_val = + config_.avg_n.value_or(transb ? compute_avg_first_dim(inputB) : compute_avg_last_dim(inputB)); + int64_t avg_k_val = + config_.avg_k.value_or(transa ? compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); + const bool use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); + execute_grouped_gemm(workspace.setup_workspace, A_sel, B_sel, d_dtype, num_tensors, + config_.use_split_accumulator, use_fp8, avg_m_val, avg_n_val, avg_k_val, + workspace.cublas_workspace_ptr, stream); +} + +void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, + cudaStream_t stream) { + NVTE_API_CALL(nvte_grouped_bias_add); + using namespace transformer_engine; + + const GroupedTensor *outputD = convertNVTEGroupedTensorCheck(output); + const GroupedTensor *bias_tensor = convertNVTEGroupedTensorCheck(bias); + + NVTE_CHECK(outputD->num_tensors >= 1, "Grouped bias add: number of tensors must be at least 1"); + NVTE_CHECK(outputD->num_tensors == bias_tensor->num_tensors, + "Grouped bias add: output and bias must have the same number of tensors"); + NVTE_CHECK(outputD->has_data(), "Grouped bias add: output is missing row-wise data"); + NVTE_CHECK(bias_tensor->has_data(), "Grouped bias add: bias is missing row-wise data"); + NVTE_CHECK(outputD->dtype() == bias_tensor->dtype(), + "Grouped bias add: output and bias must have matching dtypes"); + NVTE_CHECK(bias_tensor->all_same_first_dim(), + "Grouped bias add: bias must have uniform first dim (expected 1)"); + NVTE_CHECK(bias_tensor->get_common_first_dim() == 1, + "Grouped bias add: bias first dim must be 1"); + NVTE_CHECK(outputD->all_same_last_dim() && bias_tensor->all_same_last_dim(), + "Grouped bias add requires uniform last dim for output and bias"); + NVTE_CHECK(outputD->get_common_last_dim() == bias_tensor->get_common_last_dim(), + "Grouped bias add: output and bias last dims must match"); + constexpr int kVec = 4; + NVTE_CHECK(outputD->get_common_last_dim() % kVec == 0, + "Grouped bias add requires last dim divisible by ", kVec); + + const TensorShapeInfo d_meta = TensorShapeInfo::from_tensor(outputD); + const TensorShapeInfo bias_meta = TensorShapeInfo::from_tensor(bias_tensor); + + const DType dtype = outputD->dtype(); + constexpr int kThreads = 256; + const size_t total_elements = static_cast(outputD->logical_shape.data[0]) * + static_cast(outputD->logical_shape.data[1]); + const size_t total_vec_count = (total_elements + kVec - 1) / kVec; + int blocks_per_tensor = static_cast((total_vec_count + kThreads - 1) / kThreads); + const dim3 grid(outputD->num_tensors, blocks_per_tensor); + const dim3 block(kThreads); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, T, { + grouped_bias_add_kernel<<>>( + static_cast(outputD->data.dptr), static_cast(bias_tensor->data.dptr), + d_meta, bias_meta, outputD->num_tensors); + }); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +#else // CUBLAS_VERSION < 130200 + +void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, + const NVTETensor beta, NVTETensor workspace_setup, + NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, + cudaStream_t stream) { + NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); +} + +void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num_a_tensors, + int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, + const NVTETensor alpha, const NVTETensor beta, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEGroupedMatmulConfig config, cudaStream_t stream) { + NVTE_ERROR( + "nvte_grouped_gemm_with_discrete_inputA requires cuBLAS 13.2+, but compile-time " + "cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); +} + +void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, + const NVTEGroupedTensor B, int transb, + const NVTETensor *C_list, size_t num_c_tensors, + NVTETensor *D_list, size_t num_d_tensors, + const NVTETensor alpha, const NVTETensor beta, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEGroupedMatmulConfig config, cudaStream_t stream) { + NVTE_ERROR( + "nvte_grouped_gemm_with_discrete_out requires cuBLAS 13.2+, but compile-time " + "cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); +} + +void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, + cudaStream_t stream) { + NVTE_ERROR("nvte_grouped_bias_add requires cuBLAS 13.2+, but compile-time cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); +} + +size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors) { + NVTE_ERROR( + "nvte_get_grouped_gemm_setup_workspace_size requires cuBLAS 13.2+, but compile-time cuBLAS " + "version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); + return 0; +} + +#endif // CUBLAS_VERSION >= 130200 + +#else //__HIP_PLATFORM_AMD__ + +void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, + const NVTETensor beta, NVTETensor workspace_setup, + NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, + cudaStream_t stream) { + NVTE_ERROR("nvte_grouped_gemm is not supported on ROCm yet"); +} + +void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num_a_tensors, + int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, + const NVTETensor alpha, const NVTETensor beta, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEGroupedMatmulConfig config, cudaStream_t stream) { + NVTE_ERROR("nvte_grouped_gemm_with_discrete_inputA is not supported on ROCm yet"); +} + +void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, + const NVTEGroupedTensor B, int transb, + const NVTETensor *C_list, size_t num_c_tensors, + NVTETensor *D_list, size_t num_d_tensors, + const NVTETensor alpha, const NVTETensor beta, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEGroupedMatmulConfig config, cudaStream_t stream) { + NVTE_ERROR("nvte_grouped_gemm_with_discrete_out is not supported on ROCm yet"); +} + +void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, + cudaStream_t stream) { + NVTE_ERROR("nvte_grouped_bias_add is not supported on ROCm yet"); +} + +size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors) { + NVTE_ERROR("nvte_get_grouped_gemm_setup_workspace_size is not supported on ROCm yet"); + return 0; +} + +#endif // __HIP_PLATFORM_AMD__ +namespace { + +__global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) dst[idx] = static_cast(src[idx]); +} + +} // namespace + +void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) { + NVTE_API_CALL(nvte_convert_int32_to_int64); + if (n == 0) return; + const int threads = 256; + const int blocks = static_cast((n + threads - 1) / threads); + convert_int32_to_int64_kernel<<>>(src, dst, n); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh index eb99edc4d..aa2bde420 100644 --- a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh @@ -326,17 +326,17 @@ void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, // Check can implement the kernel. if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) { - NVTE_CHECK(false, "Failed to implement CUTLASS Grouped GEMM"); + NVTE_ERROR("Failed to implement CUTLASS Grouped GEMM with ", num_gemms, " GEMMs"); } // Initialize the kernel. if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) { - NVTE_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); + NVTE_ERROR("Failed to initialize CUTLASS Grouped GEMM with ", num_gemms, " GEMMs"); } // Execute the kernel in the current stream. if (gemm.run(stream) != cutlass::Status::kSuccess) { - NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); + NVTE_ERROR("Failed to run CUTLASS Grouped GEMM with ", num_gemms, " GEMMs"); } } diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 3bc8d9bc8..ce39283d5 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1498,6 +1498,9 @@ void release_service_stream(hipStream_t stream, struct ServiceStreamCtl &ctl) } // namespace +#pragma GCC diagnostic push +#pragma GCC diagnostic error "-Wmissing-declarations" + void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, @@ -1564,4 +1567,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } } +#pragma GCC diagnostic pop + } //namespace transformer_engine diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu new file mode 100644 index 000000000..04e965a9d --- /dev/null +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -0,0 +1,591 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "hadamard_transform_utils.cuh" + +namespace transformer_engine { +namespace { + +constexpr int kMaxTensorsPerKernel = 64; +constexpr int kThreadsPerWarp = 32; + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +__device__ __forceinline__ size_t get_current_tensor_id( + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t* const __restrict__ offsets_ptr) { + if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + const size_t current_row = current_offset / last_logical_dim; + const size_t rows_per_tensor = first_logical_dim / num_tensors; + return current_row / rows_per_tensor; + } else { + // upper_bound(offsets, current_offset) - 1 in range i in [0..num_tensors) + size_t low = 0; + size_t hi = num_tensors; // half-open [low, hi) + + while (low < hi) { + const size_t mid = low + (hi - low) / 2; + const size_t mid_offset = static_cast(offsets_ptr[mid]); + + if (mid_offset <= current_offset) { + low = mid + 1; + } else { + hi = mid; + } + } + + // low = first index where offsets[low] > current_offset (or low == num_tensors) + // id = low - 1, but need to evaluate if current_offset < offsets[0] + return (low == 0) ? 0 : (low - 1); + } +} + +template +__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], + IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + uint32_t& local_amax_reg, + uint32_t& local_amax_t_reg) { + uint32_t a_frag[4]; // A matrix fragment + uint32_t c_frag[4]; // Result fragment + + int warp_id = threadIdx.x / kThreadsPerWarp; + int local_rank = (threadIdx.x % kThreadsPerWarp); + + int ld_row_idx = local_rank % kHadamardDimension; + int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + + uint32_t temp_amax_reg; + uint32_t temp_amax_t_reg; + + if (kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_reg) + : "r"(local_amax_reg), "r"(temp_amax_reg)); + } + + if (kReturnTransposedAmax) { + // TODO(Frank): This is not efficient, since we could directly load the + // matrix in transposed layout. + if (!kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + matrix_transpose_m8_n8_b16_inplace(a_frag[0]); + matrix_transpose_m8_n8_b16_inplace(a_frag[1]); + matrix_transpose_m8_n8_b16_inplace(a_frag[2]); + matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], + b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_t_reg) + : "r"(local_amax_t_reg), "r"(temp_amax_t_reg)); + } + + if (kReturnPreRhtAmax) { + if (!kReturnIdentityAmax && !kReturnTransposedAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[1])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[2]) + : "r"(a_frag[2]), "r"(a_frag[3])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[2])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_pre_rht_amax_reg) + : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg)); + } +} + +template +__device__ __host__ constexpr int NextPowerOf2() { + static_assert(kN > 0, "kN must be > 0"); + // Round up to the next power of 2 by counting leading zeros. + return 1 << (32 - __builtin_clz(kN - 1)); +} + +template +__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax, + const float transpose_amax, float* staging_for_pre_rht, + float* staging_for_identity, float* staging_for_transpose, + float* output_pre_rht_amax_ptr, + float* output_identity_amax_ptr, + float* output_transpose_amax_ptr, const int warpid) { + // intra-warp reduction + constexpr int kWarpSize = 32; + int local_rank = threadIdx.x % 32; + float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max(pre_rht_amax) : 0.0f; + float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max(identity_amax) : 0.0f; + float warp_transpose_amax = + kReturnTransposedAmax ? warp_reduce_max(transpose_amax) : 0.0f; + + // inter-warp reduction + if (threadIdx.x % 32 == 0) { + if (kReturnPreRhtAmax) { + staging_for_pre_rht[warpid] = warp_pre_rht_amax; + } + if (kReturnIdentityAmax) { + staging_for_identity[warpid] = warp_identity_amax; + } + if (kReturnTransposedAmax) { + staging_for_transpose[warpid] = warp_transpose_amax; + } + } + __syncthreads(); + constexpr int kNumWarpsPow2 = NextPowerOf2(); + if (warpid == 0) { + if (kReturnIdentityAmax) { + float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f; + identity_accum = warp_reduce_max(identity_accum); + if (local_rank == 0) { + atomicMaxFloat(output_identity_amax_ptr, identity_accum); + } + } + } + if (warpid == 1) { + if (kReturnTransposedAmax) { + float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f; + transpose_accum = warp_reduce_max(transpose_accum); + if (local_rank == 0) { + atomicMaxFloat(output_transpose_amax_ptr, transpose_accum); + } + } + } + if (warpid == 2) { + if (kReturnPreRhtAmax) { + float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f; + pre_rht_accum = warp_reduce_max(pre_rht_accum); + if (local_rank == 0) { + atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum); + } + } + } +} + +__global__ void GraphSafeMultiZeroAmaxKernel(const size_t num_tensors, float* amax_rowwise_ptr, + float* amax_colwise_ptr) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + // Assign each thread a range for rowwise and colwise independently + if (amax_rowwise_ptr != nullptr) { + for (int i = tid; i < num_tensors; i += stride) { + amax_rowwise_ptr[i] = 0.f; + } + } + if (amax_colwise_ptr != nullptr) { + for (int i = tid; i < num_tensors; i += stride) { + amax_colwise_ptr[i] = 0.f; + } + } +} + +__global__ void GraphSafeMultiAmaxMemcpyD2DKernelPreRHT(const size_t num_tensors, + float* amax_rowwise_ptr, + float* amax_colwise_ptr) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + if (amax_rowwise_ptr != nullptr && amax_colwise_ptr != nullptr) { + for (; tid < num_tensors; tid += stride) { + float* output_pre_rht_amax_ptr = amax_rowwise_ptr + tid; + float* output_transpose_amax_ptr = amax_colwise_ptr + tid; + *output_transpose_amax_ptr = *output_pre_rht_amax_ptr; + } + } +} + +template +__global__ void GraphSafeGroupHadamardAmaxTmaKernel( + const __grid_constant__ CUtensorMap tensor_map_input, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t* const __restrict__ offsets_ptr, const int64_t* const __restrict__ first_dims_ptr, + float* const __restrict__ amax_rowwise_ptr, float* const __restrict__ amax_colwise_ptr) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + float* output_pre_rht_amax_ptr; + float* output_identity_amax_ptr = nullptr; + float* output_transpose_amax_ptr; + + // calculate the global offset to get tensor id + size_t global_offset = blockIdx.y * CHUNK_DIM_Y * last_logical_dim; + // paged stashing: will have input buffer [M, N], where M is larger than sum(first_dims) + // also need to early return if this CTA is processing a region larger than the last offsets[num_tensors] + if (global_offset >= offsets_ptr[num_tensors]) { + return; + } + int tensor_id = get_current_tensor_id(shape_rep, num_tensors, global_offset, first_logical_dim, + last_logical_dim, offsets_ptr); + output_pre_rht_amax_ptr = static_cast(amax_rowwise_ptr) + tensor_id; + output_transpose_amax_ptr = static_cast(amax_colwise_ptr) + tensor_id; + + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0); + static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0); + + constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y; + constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X; + + constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp; + + const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X; + + extern __shared__ __align__(128) char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uint8_t* dshmem = reinterpret_cast((base_shmem_ptr + 127) & ~127ULL); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + IType* in_sh_0 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + IType* in_sh_1 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + + IType* in_shs[2] = {in_sh_0, in_sh_1}; + + constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + + const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0); + + // Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + uint64_t* mbar = reinterpret_cast(dshmem); + dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y); + + float* max_staging_identity = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_transpose = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_pre_rht = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + + initialize_barriers(mbar, + is_master_thread); + + copy_2d_to_shared(in_shs[0], reinterpret_cast(&tensor_map_input), + input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0], + is_master_thread); + + uint32_t had_frag_i[4]; + uint32_t had_frag_t[4]; + get_hadamard_matrix_fragment( + had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t); + + float local_pre_rht_amax = 0.0; + float local_amax = 0.0; + float local_amax_t = 0.0; + uint32_t local_pre_rht_amax_reg = *reinterpret_cast(&local_pre_rht_amax); + uint32_t local_amax_reg = *reinterpret_cast(&local_amax); + uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { + for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { + int stage = STAGES_X * stage_y + stage_x; + + const int next_stage = stage + 1; + const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1; + const int next_stage_y = stage_x + 1 == STAGES_X ? stage_y + 1 : stage_y; + + if (next_stage < STAGES_X * STAGES_Y) { + const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y; + const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X; + + copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong + reinterpret_cast(&tensor_map_input), input_global_offset_X, + input_global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + const size_t compute_stage_x_num = + BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)); + const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y); + + const size_t in_row_stride = BUFF_DIM_X; + + IType* in_sh_ptr = in_shs[stage % 2]; + +#pragma unroll + for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) { + const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y + + threadIdx.y * kHadamardDimension); + const int in_row_offset = row_idx_offset * in_row_stride; + +#pragma unroll + for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) { + ComputeKernel( + had_frag_i, had_frag_t, + in_sh_ptr + in_row_offset + + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), + local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + } + + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. + __syncthreads(); + } + + // Ensure generic shared-memory accesses are visible before the next TMA write. + ptx::fence_proxy_async_shared_cta(); + } + } + + const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp; + + if constexpr (kReturnPreRhtAmax) { + unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax); + } + if constexpr (kReturnIdentityAmax) { + unpack_max_of_packed_bf16(local_amax_reg, local_amax); + } + if constexpr (kReturnTransposedAmax) { + unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t); + } + + ReduceMax( + local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity, + max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr, + output_transpose_amax_ptr, warpid); + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +} // namespace + +// broadcast_pre_rht_amax: when it's true, hadamard transform will be disabled +// if at this time, the amax buffers for output expects both amax_rowwise and amax_colwise +// then call MultiAmaxMemcpyD2DKernelPreRHT to D2D copy the amax values +void group_hadamard_transform_amax_graph_safe(const GroupedTensor* input, GroupedTensor* output, + uint16_t random_sign_mask, + uint16_t random_sign_mask_t, + bool broadcast_pre_rht_amax, cudaStream_t stream) { + NVTE_API_CALL(group_hadamard_transform_amax_graph_safe); +#if CUDA_VERSION >= 12080 + + NVTE_CHECK(input->num_tensors == output->num_tensors, + "Number of input and output tensors must be same."); + NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data."); + + checkCuDriverContext(stream); + + bool all_return_pre_rht_amax = output->has_data(); + // there is no rowwise RHT transform in current recipe + bool all_return_identity_amax = false; + bool all_return_transposed_amax = output->has_columnwise_data(); + + NVTE_CHECK(all_return_pre_rht_amax || all_return_identity_amax || all_return_transposed_amax, + "At least one of return_pre_rht_amax, return_identity_amax, or return_transposed_amax " + "must be true"); + + if (broadcast_pre_rht_amax) { + NVTE_CHECK(all_return_pre_rht_amax, + "broadcast_pre_rht_amax is only supported when we compute pre-RHT amax"); + // if all_return_identity_amax and all_return_transposed_amax both are false, there is no need to broadcast anything + broadcast_pre_rht_amax &= (all_return_identity_amax || all_return_transposed_amax); + } + + const size_t num_tensors = input->num_tensors; + const size_t first_logical_dim = input->logical_shape.data[0]; + const size_t last_logical_dim = input->logical_shape.data[1]; + // const size_t elts_total = first_logical_dim * last_logical_dim; + NVTE_CHECK(first_logical_dim % 128 == 0, + "First dimension of a grouped tensor should be divisible by 128."); + NVTE_CHECK(last_logical_dim % 128 == 0, + "Last dimension of a grouped tensor should be divisible by 128."); + + float* const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + float* const amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); + + const int64_t* const offsets_ptr = reinterpret_cast(output->tensor_offsets.dptr); + const int64_t* const first_dims_ptr = reinterpret_cast(output->first_dims.dptr); + + // some sanity checks + if (all_return_pre_rht_amax) { + NVTE_CHECK(amax_rowwise_ptr != nullptr, "Amax rowwise pointer should not be nullptr."); + } + if (all_return_transposed_amax) { + NVTE_CHECK(amax_colwise_ptr != nullptr, "Amax columnwise pointer should not be nullptr."); + } + + // Multi zero out multiple amaxes if needed + dim3 block_setup_amax(kMaxTensorsPerKernel); + dim3 grid_setup_amax(1); + GraphSafeMultiZeroAmaxKernel<<>>( + num_tensors, amax_rowwise_ptr, amax_colwise_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); + + using IType = bf16; + constexpr int kHadamardDimension = 16; + + // four (1x4) 64x64 sub-tiles for ping-pong overlap + constexpr uint64_t kChunkBlockXSmall = 256; + constexpr uint64_t kChunkBlockYSmall = 64; + constexpr uint64_t kBuffDimX = 64; + constexpr uint64_t kBuffDimY = 64; + + alignas(64) CUtensorMap tensor_map_input{}; + + create_2D_tensor_map( + /*tensorMap=*/tensor_map_input, + /*tensor=*/input->data, + /*globalY=*/first_logical_dim, + /*globalX=*/last_logical_dim, + /*shmemY=*/kBuffDimY, + /*shmemX=*/kBuffDimX, + /*stride_elems=*/last_logical_dim, + /*offset_elems=*/0, + /*type_num_bits=*/sizeof(IType) * 8, + /*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B); + + constexpr uint64_t kThreadBlockX = 4; + constexpr uint64_t kThreadBlockY = 1; + constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY; + + dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY); + dim3 grid(DIVUP(last_logical_dim, kChunkBlockXSmall), + DIVUP(first_logical_dim, kChunkBlockYSmall)); + + ShapeRepresentation shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + if (output->all_same_shape()) { + shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + } else if (output->all_same_first_dim()) { + shape_rep = ShapeRepresentation::VARYING_LAST_DIM; + } else if (output->all_same_last_dim()) { + shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + } else if (output->varying_both_dims()) { + shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; + } + + const bool is_const_last_dim = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS || + shape_rep == ShapeRepresentation::VARYING_FIRST_DIM); + + NVTE_CHECK(is_const_last_dim, + "Currently we only support const last dimension for graph safe hadamard transform."); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + (all_return_transposed_amax && !broadcast_pre_rht_amax), kReturnTransposedAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + (all_return_identity_amax && !broadcast_pre_rht_amax), kReturnIdentityAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_return_pre_rht_amax, kReturnPreRhtAmax, + + // *2 for ping-pong + size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType); + size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) * + (kChunkBlockYSmall / kBuffDimY); + size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3; + // Add padding in case shmem ptr is not aligned to 128 bytes. + shmem_bytes = (shmem_bytes + 128); + + auto kernel = GraphSafeGroupHadamardAmaxTmaKernel< + IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY, + kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax, + kReturnIdentityAmax, kReturnTransposedAmax>; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem_bytes); + + kernel<<>>( + tensor_map_input, random_sign_mask, random_sign_mask_t, shape_rep, num_tensors, + first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr, + amax_rowwise_ptr, amax_colwise_ptr); + if (broadcast_pre_rht_amax) { + GraphSafeMultiAmaxMemcpyD2DKernelPreRHT<<>>(num_tensors, amax_rowwise_ptr, + amax_colwise_ptr); + }))); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ", + CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace transformer_engine + +void nvte_group_hadamard_transform_amax_graph_safe(const NVTEGroupedTensor input, + NVTEGroupedTensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_hadamard_transform_amax_graph_safe); + using namespace transformer_engine; + + GroupedTensor* input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor* output_tensor = convertNVTEGroupedTensorCheck(output); + + if (input_tensor->num_tensors == 0) { + return; + } + + // Call the group tensor Hadamard transform amax implementation. + group_hadamard_transform_amax_graph_safe( + input_tensor, output_tensor, static_cast(random_sign_mask), + static_cast(random_sign_mask_t), false, stream); +} + +// Grouped-tensor amax without doing hadamard transform +void nvte_group_amax_graph_safe(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_amax_graph_safe); + using namespace transformer_engine; + + GroupedTensor* input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor* output_tensor = convertNVTEGroupedTensorCheck(output); + + if (input_tensor->num_tensors == 0) { + return; + } + + group_hadamard_transform_amax_graph_safe(input_tensor, output_tensor, 0, 0, true, stream); +} diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu new file mode 100644 index 000000000..6f3cf90d9 --- /dev/null +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -0,0 +1,1516 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_runtime.h" +#include "common/util/curanddx.hpp" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "customized_pipeline.cuh" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/float8.h" +#include "cutlass/float_subbyte.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/platform/platform.h" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/print_error.hpp" + +namespace transformer_engine { +namespace detail { +namespace { + +using namespace cute; + +// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor +using cute::Tensor; + +constexpr int kMaxTensorsPerKernel = 64; +constexpr int kNVFP4BlockSize = 16; + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +__device__ __forceinline__ size_t get_current_tensor_id( + const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr) { + if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + const size_t current_row = current_offset / last_logical_dim; + const size_t rows_per_tensor = first_logical_dim / num_tensors; + return current_row / rows_per_tensor; + } else { + // upper_bound(offsets, current_offset) - 1 in range i in [0..num_tensors) + size_t low = 0; + size_t hi = num_tensors; // half-open [low, hi) + + while (low < hi) { + const size_t mid = low + (hi - low) / 2; + const size_t mid_offset = static_cast(offsets_ptr[mid]); + + if (mid_offset <= current_offset) { + low = mid + 1; + } else { + hi = mid; + } + } + + // low = first index where offsets[low] > current_offset (or low == num_tensors) + // id = low - 1, but need to evaluate if current_offset < offsets[0] + return (low == 0) ? 0 : (low - 1); + } +} + +CUTLASS_DEVICE +cutlass::Array StochasticNumericConverterBase( + cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + auto output_ptr = reinterpret_cast(&output); + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" + "}" + : "=h"(output_ptr[0]), "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]), + "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1])); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return output; +} + +CUTLASS_DEVICE +cutlass::Array StochasticNumericConverter( + cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + cutlass::Array *result_ptr = + reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = + reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = + reinterpret_cast const *>(&rbits); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; i++) { + result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); + } + return output; +} + +template +struct SharedStorage { + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr EpilogueUnrollFactor = EpilogueUnrollFactor_; + using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = + cutlass::detail::CustomizedPipelineTmaUmmaAsync, + AtomThrShapeMNK>; + using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; + using SchedPipeline = cutlass::PipelineCLCFetchAsync; + using SchedPipelineStorage = typename SchedPipeline::SharedStorage; + using SchedThrottlePipeline = cutlass::PipelineAsync; + using SchedThrottlePipelineStorage = typename SchedThrottlePipeline::SharedStorage; + + struct TensorStorage : cute::aligned_struct<128, _1> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) cute::uint64_t tma_barrier[1]; + alignas(16) SchedPipelineStorage sched; + alignas(16) SchedThrottlePipelineStorage sched_throttle; + alignas(16) int32_t atomic_tile_id[SchedulerPipelineStageCount_]; + alignas(16) float global_a_amax[kMaxTensorsPerKernel]; + alignas(16) float global_d_amax[kMaxTensorsPerKernel]; + uint32_t atomic_tile_counter[SchedulerPipelineStageCount_]; + uint32_t tmem_base_ptr; +}; + +// Main RHT GEMM kernel entry -- highly templated for flexible architecture/config support +template +__launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_graph_safe( + MShape M, NShape packed_N, KShape K, ClusterShape cluster_shape, ClusterTileShape cluster_tile, + TA const *A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, + TB const *B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, + TQA *QA, QAStride dQA, TSFA *SFA, TSFALayout sfa_layout, TQA *QA_COLWISE, TSFA *SFA_COLWISE, + float *amax_rowwise, float *amax_colwise, const int64_t *offsets, const int64_t *first_dims, + size_t num_tensors, ShapeRepresentation shape_rep, uint32_t *tile_scheduler_workspace, + TiledMMA mma, const size_t *rng_state) { + using namespace cute; + + // Abort immediately if compilation is not supported + constexpr bool is_blackwell_arch = ARCH_BLACKWELL_FAMILY; + if constexpr (!is_blackwell_arch) { + NVTE_DEVICE_ERROR( + "group_row_col_rht_gemm_device_graph_safe is only supported on Blackwell " + "with architecture-specific compilation. " + "Try recompiling with sm_100a or similar."); + return; + } + static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, + "group_row_col_rht_gemm_device_graph_safe must generate row-wise " + "and/or column-wise output."); +#if !defined(CUTLASS_ARCH_CLC_ENABLED) + CUTLASS_NOT_IMPLEMENTED(); + return; +#endif + + using X = Underscore; + // Accumulator data type for main computation + using ElementAccumulator = float; + static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes( + size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v); + static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_; + static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; + static constexpr bool kEnableRowQuant = kEnableRowQuant_; + static constexpr bool kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_; + static constexpr bool kUseFastMath = kUseFastMath_; + + // Constant for RHT tensor processing (tile size etc) + static int constexpr RhtTensorSize = 16; + + // Get the total number of tokens to process + // Note that here M is the hidden size, which is the last logical dimension of the input tensor x + // The kernel is designed in column major, so M is the hidden size + size_t sum_token_dims = offsets[num_tensors] / M; + + // Transaction bytes for TMA transfer on RHT tensor blocks + static int constexpr kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + + // Mainloop pipeline stage calculation, vectorization parameters for scaling factors + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + static int constexpr SFVecSize = 16; + // Swizzle output layout for scaling factor arrays + using SwizzledSFALayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + + // Mainloop pipeline types for TMA async execution and epilogue cluster scheduling + using MainloopPipeline = + cutlass::detail::CustomizedPipelineTmaUmmaAsync; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + using SchedPipeline = cutlass::PipelineCLCFetchAsync; + using SchedPipelineState = typename SchedPipeline::PipelineState; + using SchedThrottlePipeline = cutlass::PipelineAsync; + using SchedThrottlePipelineState = typename SchedThrottlePipeline::PipelineState; + + static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>"); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static int constexpr VectorSize = RhtTensorSize; + + // Compile-time safety: static shapes required for shared memory layouts + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + // CUTE_STATIC_ASSERT(is_static::value); + + auto cluster_size = size<0>(cluster_shape); + auto mainloop_tiler = Shape<_128, _16, _128>{}; + auto epilogue_tiler = Shape<_128, _128, _128>{}; + + static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile); + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + int const K_TILE_MAX = min(packed_N, K) / size<2>(epilogue_tiler); + + struct TileScheduler { + uint32_t tiles_in_m = 0; + uint32_t tiles_in_n = 0; + uint32_t linear_idx = 0; + uint32_t next_linear_idx = 0; + uint32_t start_idx = 0; + uint32_t tile_m_idx = 0; + uint32_t tile_n_idx = 0; + int k_tile_max = 0; + uint32_t *atomic_tile_index_; + uint32_t *smem_tile_counter; + uint32_t atomic_offset; + cutlass::FastDivmodU64 divmod_tiles_in_m; + + CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, + uint32_t *atomic_tile_index, uint32_t *smem_tile_counter) + : tiles_in_m(tiles_m), + tiles_in_n(tiles_n), + linear_idx(blockIdx.x), + next_linear_idx(blockIdx.x), + start_idx(blockIdx.x), + k_tile_max(kmax), + atomic_tile_index_(atomic_tile_index), + smem_tile_counter(smem_tile_counter), + atomic_offset(gridDim.x), + divmod_tiles_in_m(uint64_t(tiles_m)) { + update_tile_idx(); + } + CUTLASS_DEVICE void update_tile_idx() { + uint64_t q, r; + divmod_tiles_in_m(q, r, uint64_t(linear_idx)); + tile_m_idx = static_cast(r); + tile_n_idx = static_cast(q) * uint32_t(k_tile_max); + } + CUTLASS_DEVICE uint32_t tile_m() const { return tile_m_idx; } + CUTLASS_DEVICE uint32_t tile_n_base() const { return tile_n_idx; } + CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; } + + CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; } + + CUTLASS_DEVICE bool is_valid() const { + return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()), + cute::make_coord(tiles_in_m, tiles_in_n)); + } + + CUTLASS_DEVICE bool is_first_wave() const { return linear_idx == start_idx; } + + CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return linear_idx; } + + // Fetch a new tile_id using atomics. + CUTLASS_DEVICE uint32_t fetch_tile_id_counter(int pred) { + uint32_t tile_id_counter = 0; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p atom.global.add.u32 %0, [%1], 1; \n\t" + "}" + : "=r"(tile_id_counter) + : "l"(atomic_tile_index_), "r"(pred)); + + return tile_id_counter; + } + + CUTLASS_DEVICE auto fetch_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState sched_pipeline_consumer_state) { + sched_pipeline.consumer_wait(sched_pipeline_consumer_state); + next_linear_idx = smem_tile_counter[sched_pipeline_consumer_state.index()]; + cutlass::arch::fence_view_async_shared(); + sched_pipeline.consumer_release(sched_pipeline_consumer_state); + return; + } + + CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState sched_pipeline_producer_state) { + uint32_t mbarrier_addr = sched_pipeline.producer_get_barrier(sched_pipeline_producer_state); + // Wait for clcID buffer to become empty with a flipped phase + sched_pipeline.producer_acquire(sched_pipeline_producer_state); + auto is_leading_thread = cute::elect_one_sync(); + uint32_t tile_id_counter = fetch_tile_id_counter(is_leading_thread) + atomic_offset; + uint32_t smem_addr = + cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]); + if (is_leading_thread) { + cute::store_shared_remote(tile_id_counter, smem_addr, mbarrier_addr, 0); + } + + ++sched_pipeline_producer_state; + return sched_pipeline_producer_state; + } + + CUTLASS_DEVICE auto update_work_tile_info() { + linear_idx = next_linear_idx; + update_tile_idx(); + return; + } + }; + + // Allocate and alias shared memory to the kernel's shared storage type + extern __shared__ char shared_memory[]; + using SharedStorage = + SharedStorage; + SharedStorage &shared_storage = *reinterpret_cast(shared_memory); + + // Compute the number of tiles in M and N after tiling and assign scheduler + uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); + uint32_t tiles_in_n = uint32_t(size(ceil_div(sum_token_dims, size<2>(epilogue_tiler)))); + + TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace, + shared_storage.atomic_tile_counter); + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Shapes for accumulated tiles in mainloop and epilogue + auto acc_shape_mma = make_shape(take<0, 2>(mainloop_tiler), _1{}, _1{}); + auto acc_shape_epilogue = make_shape(take<0, 2>(epilogue_tiler), _1{}, _1{}); + + // Shape of the accumulator fragment for the main loop pipeline, with pipeline stages appended + auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int{}); + auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape); + + // Number of threads assigned for various epilogue roles depending on quantization settings + static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0; + static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0; + static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0; + static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant ? 1 : 0; + static int constexpr NumSchedThreads = 32; + static int constexpr NumMainloopLoadThreads = 32; + static int constexpr NumEpilogueThreads = + NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount; + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier( + NumMmaThreadCount + NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // warp assignment + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_sched_warp = (warp_idx == 2); + bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7); + bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15); + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + mainloop_pipeline_params.num_consumers = NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount; + + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, + cluster_shape, cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = + cutlass::make_producer_start_state(); + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + using AccumulatorPipelineInitBarriers = cute::bool_constant; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = + cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_col_quant_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = + size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, + cluster_shape, AccumulatorPipelineInitBarriers{}, + cute::true_type{}); // Delay mask calculation + typename SchedPipeline::Params sched_pipeline_params; + if (is_sched_warp) { + sched_pipeline_params.role = SchedPipeline::ThreadCategory::ProducerConsumer; + } else { + sched_pipeline_params.role = SchedPipeline::ThreadCategory::Consumer; + } + sched_pipeline_params.producer_blockid = 0; + sched_pipeline_params.producer_arv_count = 1; + sched_pipeline_params.consumer_arv_count = + NumSchedThreads + + cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); + sched_pipeline_params.transaction_bytes = sizeof(uint32_t); + sched_pipeline_params.initializing_warp = 3; + SchedPipeline sched_pipeline(shared_storage.sched, sched_pipeline_params, cluster_shape); + SchedPipelineState sched_pipeline_consumer_state; + SchedPipelineState sched_pipeline_producer_state = + cutlass::make_producer_start_state(); + + typename SchedThrottlePipeline::Params sched_throttle_pipeline_params; + if (is_dma_warp) { + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Producer; + } + if (is_sched_warp) { + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Consumer; + } + sched_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + sched_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + sched_throttle_pipeline_params.dst_blockid = 0; + sched_throttle_pipeline_params.initializing_warp = 4; + + SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle, + sched_throttle_pipeline_params); + SchedThrottlePipelineState sched_pipeline_throttle_consumer_state; + SchedThrottlePipelineState sched_pipeline_throttle_producer_state = + cutlass::make_producer_start_state(); + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + + // Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer + if (is_dma_warp) { + // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Access). + cutlass::arch::warpgroup_reg_dealloc<32>(); + // Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory. + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); + + // Partition tensors for tiling according to the mainloop and cluster tilers. + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gB_nk = + local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) + + // Shared memory tensors for pipeline + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + // Determine warp/tile positioning + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Partition global to local fragments for A and B + Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = + tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = + tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); + + auto [tBgB, tBsB] = + tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + if constexpr (kEnableRHTColQuant) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], + kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), + tBsB(_, 0)); + } + } + + do { + // is_first_wave indicates whether this scheduler wave is the first among a group. + bool is_first_wave = scheduler.is_first_wave(); + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_, scheduler.tile_m(), _); + int k_tile = 0; + + sched_throttle_pipeline.producer_acquire(sched_pipeline_throttle_producer_state); + sched_throttle_pipeline.producer_commit(sched_pipeline_throttle_producer_state); + ++sched_pipeline_throttle_producer_state; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) { + int k_tile_idx_n = scheduler.tile_n_base() + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType *tma_barrier = + mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), + tAsA(_, write_stage)); + } + } + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + // scheduler.advance(); + } while (scheduler.is_valid()); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + // This warp executes the main tensor core matrix-multiply-accumulate for the Hadamard transform. + cutlass::arch::warpgroup_reg_dealloc<32>(); + if constexpr (kEnableRHTColQuant) { + // Setup shared memory fragments for A and B tiles. + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + // Wait until the B (Hadamard) tensor copy is complete + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + + auto barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_, _, _, read_stage); + auto tCrB_nk = tCrB(_, _, 0, 0); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) { + int accumulator_k_block = + accumulator_pipe_producer_state.index() * EpilogueUnrollFactor; + int tCrA_k_block = k_block * EpilogueUnrollFactor; + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < EpilogueUnrollFactor; i++) { + auto accumulators = bulk_tmem_mma(_, _, _, accumulator_k_block + i); + gemm(mma, tCrA_mk(_, _, tCrA_k_block + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state); + barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } else if (is_sched_warp) { + // Scheduler warp manages tile assignment and pipeline progress for warps + cutlass::arch::warpgroup_reg_dealloc<32>(); + do { + sched_throttle_pipeline.consumer_wait(sched_pipeline_throttle_consumer_state); + sched_throttle_pipeline.consumer_release(sched_pipeline_throttle_consumer_state); + ++sched_pipeline_throttle_consumer_state; + sched_pipeline_producer_state = + scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } else if (is_epilogue_col_quant_warp) { + // Warp responsible for quantizing output of Hadamard transform to FP4 for columnwise usage, + // and writing result tensors/scales to global memory. + cutlass::arch::warpgroup_reg_alloc<192>(); + if constexpr (kEnableRHTColQuant) { + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + auto acc_epilogue_pipelined_shape = + append(acc_shape_epilogue, Int{}); + auto bulk_tmem_epilogue_layout = make_layout( + acc_epilogue_pipelined_shape, + make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler))); + auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr(), bulk_tmem_epilogue_layout); + + // Use 256-bit fragments for aligned bulk stores + static int constexpr FragmentSize = 256 / sizeof_bits_v; + + // Wait for TMEM allocation for this pipeline to finish + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup; + // g2s load all global_d_amax + CUTLASS_PRAGMA_NO_UNROLL + for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueColQuantThreadCount) { + shared_storage.global_d_amax[g] = __ldg(reinterpret_cast(amax_colwise + g)); + } + + size_t rng_seed = 0; + size_t rng_offset = 0; + // Setup RNG for stochastic rounding + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; + } + // TODO(zhongbo): double check the logic here + int group_idx = get_current_tensor_id(shape_rep, num_tensors, + (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M, + packed_N, M, offsets); + + // Determine quantization scale factor layouts/output splits for this group + TSFDLayout sfd_layout; + int cur_N = static_cast(first_dims[group_idx]); + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + // Build output tensors for columns and their quant scales + // TODO(zhongbo): double check the logic here + Tensor mD = make_tensor(cute::subbyte_iterator(reinterpret_cast( + reinterpret_cast(QA_COLWISE) + offsets[group_idx] / 2)), + make_shape(M, cur_N), DStride{}); // (M,packed_N) + Tensor gD_mn = + local_tile(mD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + // for every tensor [x, y] row major, x y both a multiple of 128 + // both of its rowwise and colwise scaling factors will have exactly x * y / 16 elements in FP8 E4M3 + Tensor mSFD = make_tensor( + make_gmem_ptr(reinterpret_cast(reinterpret_cast(SFA_COLWISE) + + offsets[group_idx] / kNVFP4BlockSize)), + sfd_layout); + Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); + + // Setup tile-level TMEM (t2r) and global memory (r2g) copy descriptors + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); + auto tiled_r2g = + make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(local_thread_idx); + auto thr_r2g = tiled_r2g.get_slice(local_thread_idx); + + cutlass::arch::NamedBarrier::sync(NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + float c_global_amax_val = shared_storage.global_d_amax[group_idx]; + float global_encode_scale = c_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + float global_decode_scale = 1.0f / global_encode_scale; + + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + + do { + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); + ++k_tile) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + + // TODO(zhongbo): double check the logic here + int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors, + global_tile_n_offset * M, packed_N, M, offsets); + + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + c_global_amax_val = shared_storage.global_d_amax[group_idx]; + // update amax + global_encode_scale = c_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + // TODO(zhongbo): double check the logic here + cur_N = first_dims[group_idx]; + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = + tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = + make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + // update tensor + mD = make_tensor(cute::subbyte_iterator(reinterpret_cast( + reinterpret_cast(QA_COLWISE) + offsets[group_idx] / 2)), + make_shape(M, cur_N), DStride{}); + gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + mSFD = make_tensor( + make_gmem_ptr(reinterpret_cast(reinterpret_cast(SFA_COLWISE) + + offsets[group_idx] / kNVFP4BlockSize)), + sfd_layout); + gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); + } + int group_start_offset = offsets[group_idx] / M; + int local_tile_n_idx = + (global_tile_n_offset - group_start_offset) / size<1>(epilogue_tiler); + Tensor tDgD_mn = gD_mn_view(_, _, _, scheduler.tile_m(), local_tile_n_idx); + + Tensor tDgSFD_mn = gSFD_mn(_, _, scheduler.tile_m(), local_tile_n_idx); + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto Acc = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index()); + Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = + make_tensor(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrD = make_tensor(shape(tDgD)); + Tensor tTR_rAcc_frag = + recast>(coalesce(tTR_rAcc)); + Tensor tDrD_frag = recast>(coalesce(tDrD)); + + Tensor src = thr_r2g.retile_S(tDrD); + Tensor dst = thr_r2g.retile_D(tDgD); + + Tensor tDgSFD_view = make_tensor( + tDgSFD_mn.data(), make_layout(make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{}))); + Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); + Tensor tDrSFD = make_tensor(shape(tDgSFD)); + + static int constexpr NumVecs = size(tDgD) / VectorSize; + Tensor tD_rRowSFD_frg = recast>(tDrSFD); + + // Compute amax and quantization scales for this tile + cutlass::maximum_absolute_value_reduction, + true> + amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // Copy from TMEM to registers + copy(tiled_t2r, tDtAcc, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + ++accumulator_pipe_consumer_state; + + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with + // unfused kernels + auto convert_accum_to_bf16 = + cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = + cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{}))); + } + + auto compute_frgs = reinterpret_cast *>( + tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales = cutlass::multiplies>{}( + vec_maxs, global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales = + cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}( + pvscales, global_encode_scale); + } + auto pvscales_cvted = + cutlass::NumericArrayConverter{}(pvscales); + + tD_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}( + tD_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}( + qpvscale_ups, global_decode_scale); + cutlass::Array acc_scales; + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal + acc_scales = + cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // Accurate math: compute reciprocal with division + acc_scales = cutlass::divides>{}( + 1.0, qpvscale_scaled); + } + + // Prepare stochastic rounding random state if enabled + uint4 random_uint4 = uint4{0, 0, 0, 0}; + transformer_engine::curanddx::detail::philox4x32_native_state< + NVTE_BUILD_NUM_PHILOX_ROUNDS> + rng; + // "Prefetch" a stochastic rounding state for the first tile + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + // Apply round/quantize to each fragment, with or without stochastic rounding + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales[v], cutlass::platform::numeric_limits::max()); + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale)); + } + } + + // Write quantized FP4 tile and dequant scale to gmem + copy(tiled_r2g, src, dst); + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + } else if (is_epilogue_row_quant_warp) { + // Warp responsible for quantizing the input (before Hadamard transform) to FP4 for row-wise usage. + cutlass::arch::warpgroup_reg_alloc<136>(); + if constexpr (kEnableRowQuant) { + using S2RVectorType = uint128_t; + + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % 256; + size_t rng_seed = 0; + size_t rng_offset = 0; + // g2s load all global_a_amax for all groups/tensors + CUTLASS_PRAGMA_NO_UNROLL + for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueRowQuantThreadCount) { + shared_storage.global_a_amax[g] = __ldg(reinterpret_cast(amax_rowwise + g)); + } + // RNG for stochastic rounding + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; + } + // Input/output tensors/partitions for row quant warp + Tensor mQA = + make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, packed_N), dQA)); + Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); + + Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_N) + // Swizzled shared memory A tile, with layout + Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>( + coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) + + // Set up layouts for partitioning – tile-by-warp, with vector granularity + using S2RWarpLayout = Layout>; + using WarpGroupLayout = Layout>; + using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{})); + using S2RValLayout = Layout, _1>>; + using S2RAtomA = Copy_Atom; + using R2GAtomQA = Copy_Atom; + using R2GAtomSFA = Copy_Atom; + auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_QA = make_tiled_copy(R2GAtomQA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_SFA = make_tiled_copy(R2GAtomSFA{}, S2RThreadLayout{}, S2RValLayout{}); + + auto thr_s2r = tiled_s2r.get_slice(local_thread_idx); + auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); + auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx); + Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) + + // Allocate temporary register tensors for copying quantization => output + Tensor tQArA = make_tensor_like( + make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) + Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn); + Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); + + Tensor tQAgSFA = thr_r2g_SFA.partition_S(gSFA_mn); + Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); + + // Will result in barrier_id=10 passed to bar.sync instr as cutlass adds 8 + // in order to go over the reserved named barrier count. + constexpr int row_quant_barrier_id = 2; + cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id); + + int group_idx = get_current_tensor_id(shape_rep, num_tensors, + (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M, + packed_N, M, offsets); + float a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + float global_encode_scale = a_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + + float global_decode_scale = 1.0f / global_encode_scale; + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + auto sfa_converter = cutlass::NumericConverter{}; + do { + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + + int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors, + global_tile_n_offset * M, packed_N, M, offsets); + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // Update group quantization parameters/scaling + global_encode_scale = a_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + } + + auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state); + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA); + cutlass::arch::fence_view_async_shared(); + mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state); + ++mainloop_pipe_consumer_state; + ++k_tile; + + // static int constexpr NumVecs = size(tQArA) / VectorSize; + cutlass::maximum_absolute_value_reduction, + true> + amax_reduction; + auto compute_frgs = reinterpret_cast *>(tQArA.data()); + auto output_frgs = + reinterpret_cast *>(raw_pointer_cast(tQArQA.data())); + Tensor amax = + make_tensor(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{})); + Tensor pvscales = make_tensor_like(amax); + transformer_engine::curanddx::detail::philox4x32_native_state< + NVTE_BUILD_NUM_PHILOX_ROUNDS> + rng; + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 + + tiles_in_m * tiles_in_n * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size<1>(group_modes<1, rank(tQArA)>(tQArA)); v++) { + auto amax_view = group_modes<1, rank(amax)>(amax); + auto pvscales_view = group_modes<1, rank(pvscales)>(pvscales); + auto compute_frgs_up = + cutlass::NumericArrayConverter{}( + compute_frgs[v]); + amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up); + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales_view(_0{}, v) = cutlass::multiplies{}( + amax_view(_0{}, v), global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales_view(_0{}, v) = + cutlass::divides{}(amax_view(_0{}, v), fp4_max); + pvscales_view(_0{}, v) = cutlass::multiplies{}( + pvscales_view(_0{}, v), global_encode_scale); + } + filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v)); + auto qpvscale_ups = + cutlass::NumericConverter{}(filter(tQArSFA)(v)); + auto qpvscale_scaled = + cutlass::multiplies{}(qpvscale_ups, global_decode_scale); + ElementAccumulator acc_scales; + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal + acc_scales = + cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // Accurate math: compute reciprocal with division + acc_scales = cutlass::divides{}(1.0, qpvscale_scaled); + } + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales, cutlass::platform::numeric_limits::max()); + uint4 random_uint4 = uint4{0, 0, 0, 0}; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = + cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale)); + } + } + copy(tiled_r2g_QA, tQArQA, tQAgQA_mn); + copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn)); + } + // scheduler.advance(); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + + } else { + cutlass::arch::warpgroup_reg_dealloc<32>(); + } +} // NOLINT(readability/fn_size) + +template +void group_row_col_rht_gemm_ntt_w_sfc_graph_safe( + int packed_sequence_length, int hidden_size, size_t num_tensors, ShapeRepresentation shape_rep, + TA const *A, TB const *B, TQA *QA, TSFA *SFA, TQA *QA_COLWISE, TSFA *SFA_COLWISE, + float *amax_rowwise, float *amax_colwise, const int64_t *offsets, const int64_t *first_dims, + const size_t *rng_state, uint32_t *tile_scheduler_workspace, uint32_t sm_count, + cudaStream_t stream, int k_tile_size = 1024) { + using namespace cute; + static int constexpr SFVecSize = 16; + static int constexpr RhtTensorSize = 16; + + static_assert(RhtTensorSize == 16, "RhtTensorSize must be 16"); + using LinearSFALayout = decltype(make_layout(make_shape(make_shape(Int{}, 0), 0), + make_stride(make_stride(_0{}, _1{}), 0))); + using LinearSFDLayout = decltype(make_layout(make_shape(0, make_shape(Int{}, 0)), + make_stride(0, make_stride(_0{}, _1{})))); + + using SwizzledSFALayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFALayout = decltype(tile_to_shape( + SwizzledSFALayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{})); + using SwizzledSFDLayout = decltype(tile_to_shape( + SwizzledSFDLayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{})); + + using SFALayout = cute::conditional_t; + using SFDLayout = cute::conditional_t; + SFALayout sfa_layout; + SFDLayout sfd_layout; + + if constexpr (kEnableSwizzleSFOutput) { + sfa_layout = tile_to_shape(SwizzledSFALayoutAtom{}, + make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{}); + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, + make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{}); + } else { + sfa_layout = make_layout( + make_shape(make_shape(Int{}, hidden_size / SFVecSize), packed_sequence_length), + make_stride(make_stride(_0{}, _1{}), hidden_size / SFVecSize)); + sfd_layout = make_layout( + make_shape(hidden_size, make_shape(Int{}, packed_sequence_length / SFVecSize)), + make_stride(packed_sequence_length / SFVecSize, make_stride(_0{}, _1{}))); + } + + // Define shapes (dynamic) + auto M = hidden_size; + auto N = packed_sequence_length; + Tensor tensorA = make_tensor(A, make_shape(hidden_size, packed_sequence_length), LayoutLeft{}); + Tensor tensorB = make_tensor(B, make_shape(RhtTensorSize, RhtTensorSize), LayoutLeft{}); + Tensor tensorQA = make_tensor(QA, make_shape(hidden_size, packed_sequence_length), LayoutLeft{}); + Tensor tensorSFA = make_tensor(SFA, sfa_layout); + + // Define strides (from tensors) + auto dA = stride(tensorA); // (dM,dK) + auto dB = stride(tensorB); // (dN,dK) + auto dD = LayoutRight{}; // (dM,dN) + auto dQA = stride(tensorQA); // (dM,dK) + using ClusterShape = Shape<_1, _1, _1>; + auto cluster_shape = ClusterShape{}; + auto cluster_tile_shape = Shape<_128, Int, Int>{}; + auto cluster_tile_mainloop = Shape<_128, Int, _128>{}; + + // Each mainloop / epilogue loads 128 x 64 tiles while each MMA proceeds with 128 x 16 tiles + static int constexpr EpilogueUnrollFactor = + size<2>(cluster_tile_mainloop) / size<2>(cluster_tile_shape); + // Construct the MMA + auto mma = make_tiled_mma( + SM100_MMA_F16BF16_SS(cluster_tile_shape), size<1>(cluster_tile_shape), + UMMA::Major::MN, UMMA::Major::MN>{}, + Layout>{}); + + // Assert that the TiledMMA uses all CTAs in the CGA. + CUTE_STATIC_ASSERT_V(size(cluster_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cluster_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = + partition_shape_B(mma, make_shape(size<1>(cluster_tile_shape), size<2>(cluster_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div( + shape<0>(cluster_tile_shape), + shape_div(shape<0>(cluster_tile_shape), size<0>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div( + shape<1>(cluster_tile_shape), + shape_div(shape<1>(cluster_tile_shape), size<1>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cluster_tile_shape)); + + using SmemLayoutAtomB = + decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + + auto mma_shape_A = partition_shape_A( + mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = + decltype(shape_div(shape<0>(cluster_tile_mainloop), + shape_div(shape<0>(cluster_tile_mainloop), + size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + static uint32_t constexpr TotalTmemRows = 128; + static uint32_t constexpr Sm100TmemCapacityColumns = 512; + static uint32_t constexpr TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns; + static uint32_t constexpr AccumulatorPipelineStageCount = + TotalTmem / (cute::size<0>(cluster_tile_shape) * cute::size<1>(cluster_tile_shape)); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int SchedulerPipelineStageCount = 4; + static int constexpr MainloopPipelineBytes = sizeof( + typename cutlass::detail::CustomizedPipelineTmaUmmaAsync<1, Shape<_1, _1, _1>, + Shape<_1, _1, _1>>::SharedStorage); + + static int constexpr SchedulerWorkspaceBytes = sizeof(int) * SchedulerPipelineStageCount; + static int constexpr SchedulerThrottlePipelineBytes = + sizeof(typename cutlass::PipelineAsync::SharedStorage); + static int constexpr SchedulerPipelineBytes = + sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + + static int constexpr TmemDeallocBytes = sizeof(cutlass::arch::ClusterBarrier); + static int constexpr BTensorBytes = cute::size(mma_shape_B) * sizeof(TB); + static int constexpr AccPipelineBytes = sizeof( + typename cutlass::PipelineUmmaAsync>::SharedStorage); + static int constexpr TmemBasePtrsBytes = sizeof(uint32_t); + static int constexpr kBlackwellSmemSize = 232448; // 232KB in bytes + static int constexpr kBytesPerStage = + cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes; + static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes + + SchedulerPipelineBytes + TmemBasePtrsBytes + + TmemDeallocBytes + BTensorBytes + + AccPipelineBytes; // Reserve for barriers and other uses + static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP), + Step<_2, _1, _3>{}); // (MMA,MMA_M,MMA_K,PIPE) + auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, + append(mma_shape_B, _1{})); // (MMA,MMA_N,MMA_K, _1) + auto sD = Layout<_1>{}; // XXX Dummy + + auto tma_load_a = + make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma); + auto tma_load_b = + make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cluster_tile_shape, mma); + + // Assert checks on tile sizes -- no predication + assert(M % size<0>(cluster_tile_shape) == 0); + assert(N % size<1>(cluster_tile_shape) == 0); + + dim3 dimBlock(512); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(sm_count, 1, 1); + + int smem_size = sizeof( + SharedStorage); + + auto *kernel_ptr = &group_row_col_rht_gemm_device_graph_safe< + decltype(M), decltype(N), decltype(k_tile_size), decltype(cluster_shape), + decltype(cluster_tile_shape), TA, decltype(dA), decltype(sA), decltype(tma_load_a), TB, + decltype(dB), decltype(sB), decltype(tma_load_b), TD, decltype(dD), decltype(sD), TSFD, + decltype(sfd_layout), TQA, decltype(dQA), TSFA, decltype(sfa_layout), decltype(mma), + AccumulatorPipelineStageCount, SchedulerPipelineStageCount, kEnableStochasticRounding, + kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kUseFastMath>; + + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // Set workspace and set to zero + NVTE_CHECK_CUDA(cudaMemsetAsync(reinterpret_cast(tile_scheduler_workspace), 0, + sizeof(uint32_t), stream)); + + // Launch kernel + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream}; + cutlass::Status status = cutlass::launch_kernel_on_cluster( + params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, A, dA, + sA, tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, QA_COLWISE, SFA_COLWISE, + amax_rowwise, amax_colwise, offsets, first_dims, num_tensors, shape_rep, + tile_scheduler_workspace, mma, rng_state); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed."); +} + +} // namespace +} // namespace detail + +void group_hadamard_transform_cast_fusion_graph_safe(const GroupedTensor *input, + GroupedTensor *output, + const Tensor &hadamard_matrix_, + QuantizationConfig &quant_config, + Tensor &quant_workspace, cudaStream_t stream) { + NVTE_API_CALL(group_hadamard_transform_cast_fusion_graph_safe); + + using transformer_engine::detail::kMaxTensorsPerKernel; + using transformer_engine::detail::ShapeRepresentation; + + void *input_base_ptr = reinterpret_cast(input->data.dptr); + // TODO(zhongbo): add input sanity checks here + + bool all_has_row_quant = output->has_data(); + bool all_has_col_quant = output->has_columnwise_data(); + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (use_stochastic_rounding) { + NVTE_CHECK(quant_config.rng_state != nullptr, + "Enabled stochastic rounding without providing RNG state"); + const Tensor &rng_state_tensor = *convertNVTETensorCheck(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + uint32_t *tile_scheduler_workspace = nullptr; + NVTE_CHECK(quant_workspace.data.dptr != nullptr, "Quantization workspace must be provided."); + NVTE_CHECK(quant_workspace.data.buffer_size_bytes() >= sizeof(uint32_t), + "Quantization workspace must be at least 4 bytes."); + tile_scheduler_workspace = reinterpret_cast(quant_workspace.data.dptr); + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TD = cutlass::float_e2m1_t; + using TSFD = cutlass::float_ue4m3_t; + using TQA = TD; + using TSFA = TSFD; + + checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t num_tensors = input->num_tensors; + const size_t first_logical_dim = input->logical_shape.data[0]; + const size_t last_logical_dim = input->logical_shape.data[1]; + // const size_t elts_total = first_logical_dim * last_logical_dim; + NVTE_CHECK(first_logical_dim % 128 == 0, + "First dimension of a grouped tensor should be divisible by 128."); + NVTE_CHECK(last_logical_dim % 128 == 0, + "Last dimension of a grouped tensor should be divisible by 128."); + NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel, + "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel); + + ShapeRepresentation shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + if (output->all_same_shape()) { + shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + } else if (output->all_same_first_dim()) { + shape_rep = ShapeRepresentation::VARYING_LAST_DIM; + } else if (output->all_same_last_dim()) { + shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + } else if (output->varying_both_dims()) { + shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; + } + + TQA *const rowwise_data_base_ptr = reinterpret_cast(output->data.dptr); + TSFA *const rowwise_scale_inv_base_ptr = reinterpret_cast(output->scale_inv.dptr); + TQA *const colwise_data_base_ptr = reinterpret_cast(output->columnwise_data.dptr); + TSFA *const colwise_scale_inv_base_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + float *const amax_rowwise_base_ptr = reinterpret_cast(output->amax.dptr); + float *const amax_colwise_base_ptr = reinterpret_cast(output->columnwise_amax.dptr); + + const int64_t *const offsets_ptr = reinterpret_cast(output->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(output->first_dims.dptr); + + const bool is_const_last_dim = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS || + shape_rep == ShapeRepresentation::VARYING_FIRST_DIM); + NVTE_CHECK(is_const_last_dim, + "Currently we only support const last dimension for graph safe hadamard transform."); + + auto sm_count = transformer_engine::cuda::sm_count(); + + int k_tile_size = 1024; + + const bool use_swizzle_sf_output = output->with_gemm_swizzled_scales; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kEnableStochasticRounding, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_has_col_quant, kEnableRhtColQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_has_row_quant, kEnableRowQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_swizzle_sf_output, kEnableSwizzleSFOutput, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config.use_fast_math, kUseFastMath, + + if constexpr (kEnableRhtColQuant || kEnableRowQuant) { + detail::group_row_col_rht_gemm_ntt_w_sfc_graph_safe< + kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant, + kEnableSwizzleSFOutput, TA, TB, TQA, TSFA, TD, TSFD, kUseFastMath>( + /*packed_sequence_length=*/first_logical_dim, + /*hidden_size=*/last_logical_dim, + /*num_tensors=*/num_tensors, + /*shape_rep=*/shape_rep, + /*A=*/reinterpret_cast(input_base_ptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*QA=*/reinterpret_cast(rowwise_data_base_ptr), + /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), + /*QA_COLWISE=*/reinterpret_cast(colwise_data_base_ptr), + /*SFA_COLWISE=*/reinterpret_cast(colwise_scale_inv_base_ptr), + /*amax_rowwise=*/reinterpret_cast(amax_rowwise_base_ptr), + /*amax_colwise=*/reinterpret_cast(amax_colwise_base_ptr), + /*offsets=*/offsets_ptr, + /*first_dims=*/first_dims_ptr, + /*rng_state=*/rng_state, + /*tile_scheduler_workspace=*/tile_scheduler_workspace, + /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size); + } else { + NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", + kEnableRhtColQuant, ", kEnableRowQuant=", kEnableRowQuant, ")."); + } + + ););););); +} + +} // namespace transformer_engine + +void nvte_group_hadamard_transform_cast_fusion_graph_safe( + const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, NVTETensor quant_workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion_graph_safe); + using namespace transformer_engine; + + GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + + Tensor *quant_workspace_tensor = convertNVTETensorCheck(quant_workspace); + + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + if (input_tensor->num_tensors == 0) { + return; + } + + // Call the multi-tensor Hadamard transform amax implementation. + group_hadamard_transform_cast_fusion_graph_safe( + input_tensor, output_tensor, *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, + *quant_workspace_tensor, stream); +} diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 5d45996dc..07813be05 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -323,8 +323,6 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t is_master_thread); } - ptx::fence_proxy_async_shared_cta(); - // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], 0); @@ -356,6 +354,9 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t // memory. __syncthreads(); } + + // Ensure generic shared-memory accesses are visible before the next TMA write. + ptx::fence_proxy_async_shared_cta(); } } diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index 85bb98f0f..1e40fd4a5 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -660,7 +660,8 @@ __global__ static void group_rht_gemm_device( // Initialize RNG for tile const size_t rng_sequence = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + transformer_engine::curanddx::detail::philox4x32_native_state + rng; rng.init(rng_seed, rng_sequence, rng_offset); uint4 random_uint4 = uint4{0, 0, 0, 0}; diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 1ef1f81e8..4013fdf11 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -891,7 +891,9 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( // Prepare stochastic rounding random state if enabled uint4 random_uint4 = uint4{0, 0, 0, 0}; - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + transformer_engine::curanddx::detail::philox4x32_native_state< + NVTE_BUILD_NUM_PHILOX_ROUNDS> + rng; // "Prefetch" a stochastic rounding state for the first tile if constexpr (kEnableStochasticRounding) { const size_t rng_sequence = global_thread_idx + k_tile * 512 + @@ -1048,7 +1050,9 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( Tensor amax = make_tensor(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{})); Tensor pvscales = make_tensor_like(amax); - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + transformer_engine::curanddx::detail::philox4x32_native_state< + NVTE_BUILD_NUM_PHILOX_ROUNDS> + rng; if constexpr (kEnableStochasticRounding) { const size_t rng_sequence = global_thread_idx + k_tile * 512 + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 + diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index de930aa2c..4adc83688 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -266,8 +266,6 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor is_master_thread); } - ptx::fence_proxy_async_shared_cta(); - // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], 0); @@ -299,6 +297,9 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor // memory. __syncthreads(); } + + // Ensure generic shared-memory accesses are visible before the next TMA write. + ptx::fence_proxy_async_shared_cta(); } } diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 0696deaaa..1a2462e6f 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -516,7 +516,8 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, const size_t rng_sequence = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + transformer_engine::curanddx::detail::philox4x32_native_state + rng; rng.init(rng_seed, rng_sequence, rng_offset); uint4 random_uint4 = uint4{0, 0, 0, 0}; diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 55cd44d9d..854f52c20 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -31,6 +31,7 @@ extern "C" { enum class NVTE_Activation_Type { GELU, GEGLU, + GLU, SILU, SWIGLU, RELU, @@ -52,6 +53,17 @@ enum class NVTE_Activation_Type { */ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the GeLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the SiLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -62,6 +74,17 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -72,6 +95,17 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the Quick GeLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -82,6 +116,17 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the Squared ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -92,6 +137,17 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation of the grouped input. + * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream); + /*! \brief Computes the GeLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -104,6 +160,19 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the GeLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); + /*! \brief Computes the SiLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -116,6 +185,19 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); + /*! \brief Computes the ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -128,6 +210,19 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); + /*! \brief Computes the Quick GeLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -140,6 +235,19 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); + /*! \brief Computes the Squared ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -152,6 +260,45 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation gradient of the grouped input. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] grad Incoming grouped gradient. + * \param[in] input Input grouped tensor for activation. + * \param[in,out] output Output grouped tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTETensor output, cudaStream_t stream); + +/*! \brief Computes the GLU (Gated Linear Unit) activation of the input. + * GLU(a,b) = sigmoid(a) * b + * See "Language Modeling with Gated Convolutional Networks" (arXiv:1612.08083) + * and "GLU Variants Improve Transformer" (arXiv:2002.05202). + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes sigmoid(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! \brief Computes the GLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream); + /*! \brief Computes the gated GeLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 576494a4d..755052d6d 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -89,6 +89,18 @@ extern "C" { */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Casts input grouped tensor to MXFP8. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. See file level comments. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in,out] output Output grouped MXFP8 tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream); + /*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. * The type of quantized tensor in the output depends on the scaling mode of the output @@ -132,6 +144,27 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output, void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Casts input grouped tensor to MXFP8. Additionally, reduces the input along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); + /*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the GeLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -155,6 +188,31 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of GeLU operation on the grouped input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the GeLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); + /*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the SiLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -178,6 +236,31 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of SiLU operation on the grouped input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the SiLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); + /*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the ReLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -201,6 +284,31 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of ReLU operation on the grouped input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the ReLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); + /*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Quick GeLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -224,6 +332,31 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of Quick GeLU operation on the grouped input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Quick GeLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); + /*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Squared ReLU backward along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -247,6 +380,31 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); +/*! \brief Computes backward of Squared ReLU operation on the grouped input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Squared ReLU backward along columns. + * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in] act_input Activation input grouped tensor. + * \param[in,out] output Output grouped FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor act_input, NVTEGroupedTensor output, + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); + /*! \brief Casts input tensor from reduced to higher precision. * If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, * the block dequantization (MXFP8) of the specified shape of the block will be used. @@ -261,11 +419,11 @@ void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t str /*! \brief Casts multiple input tensors to quantized output tensors. * - * \param[in] inputs List of input tensors to be cast. - * \param[in,out] outputs List of output quantized tensors. - * \param[in] quant_config (Optional) Quantization configurations. - * \param[in] num_tensors Number of input and output tensors. - * \param[in] stream CUDA stream used for the operation. + * \param[in] inputs List of input tensors to be cast. + * \param[in,out] outputs List of output quantized tensors. + * \param[in] quant_config (Optional) Quantization configurations. + * \param[in] num_tensors Number of input and output tensors. + * \param[in] stream CUDA stream used for the operation. */ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, const NVTEQuantizationConfig quant_config, const size_t num_tensors, @@ -274,11 +432,11 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, /*! \brief Casts grouped input tensor to quantized output tensors. * * \param[in] input Input tensor to be cast. - * \param[in,out] outputs Output quantized tensors. - * \param[in] split_sections Split sections of the input tensor. - * \param[in] num_tensors Number of output tensors. + * \param[in,out] outputs Output quantized tensors. + * \param[in] split_sections Split sections of the input tensor. + * \param[in] num_tensors Number of output tensors. * \param[in] quant_config (Optional) Quantization configurations. - * \param[in] stream CUDA stream used for the operation. + * \param[in] stream CUDA stream used for the operation. */ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections, size_t num_tensors, diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm.h b/transformer_engine/common/include/transformer_engine/comm_gemm.h index 06b56789a..65d3aa5d9 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm.h @@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank /*! \brief Destroy a comm-gemm context. * * \param[in] ctx Context to destroy. + * + * It's the caller's responsibility to synchronize all streams involved before calling this function. */ void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 726cc4e47..d6ef408cd 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -221,316 +221,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] window_size_right Sliding window size (the right half). * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. + * \param[in] deterministic Whether determinism is required or not. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph); - -/*! \brief Compute dot product attention with packed QKV input. - * - * \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. - * - * Computes: - * - P = Q * Transpose(K) + Bias - * - S = ScaleMaskSoftmax(P) - * - D = Dropout(S) - * - O = D * Transpose(V) - * - * Support Matrix for ROCm AOTriton: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | aotriton| FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO | NO/CAUSAL | Yes | arbitrary | arbitrary | - \endverbatim - * - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | - \endverbatim - * - * Notes: - * - * Tensor `cu_seqlens_padded` helps identify the correct offsets of different sequences - * in tensors Q, K, V and O. - * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, - * the offset tensor is not used in the attention calculation and can be set to empty `NVTETensor`. - * When the QKV format is `thd`, this tensor should follow the following rules. - * When there is no padding between sequences, the offset tensor should be equal to `cu_seqlens`, - * When there is padding between sequences, users are responsible to adjust the offsets as needed. - * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have - * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. - * - * \param[in] QKV The QKV tensor in packed format, H3D or 3HD. - * \param[in] Bias The Bias tensor. - * \param[in] SoftmaxOffset The SoftmaxOffset tensor. - * \param[in,out] S The S tensor. - * \param[out] O The output O tensor. - * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, - * e.g. M, ZInv, rng_state. - * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. - * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1]. - * \param[in] rng_state Seed and offset of CUDA random number generator. - * \param[in] max_seqlen Max sequence length used for computing, - * it may be >= max(seqlen_i) for i=0,...batch_size-1. - * \param[in] is_training Whether this is in training mode or inference. - * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. - * \param[in] cuda_graph Whether cuda graph capture is enabled or not. - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. - * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] softmax_type Attention softmax type. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. - */ -[[deprecated( - "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " - "Q, K, V tensors instead.")]] -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); - -/*! \brief Compute the backward of the dot product attention with packed QKV input. - * - * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. - * Support Matrix for ROCm AOTriton: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | aotriton| FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO | NO/CAUSAL | Yes | arbitrary | arbitrary | - \endverbatim - * - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | - \endverbatim - * - * Notes: - * - * Tensor `cu_seqlens_padded` helps identify the correct offsets of different sequences - * in tensors Q, K, V and O. - * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, - * the offset tensor is not used in the attention calculation and can be set to empty `NVTETensor`. - * When the QKV format is `thd`, this tensor should follow the following rules. - * When there is no padding between sequences, the offset tensor should be equal to `cu_seqlens`, - * When there is padding between sequences, users are responsible to adjust the offsets as needed. - * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have - * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. - * - * \param[in] QKV The QKV tensor in packed format, H3D or 3HD. - * \param[in] O The O tensor from forward. - * \param[in] dO The gradient of the O tensor. - * \param[in] S The S tensor. - * \param[in,out] dP The gradient of the P tensor. - * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, - * e.g. M, ZInv, rng_state. - * \param[out] dQKV The gradient of the QKV tensor. - * \param[out] dBias The gradient of the Bias tensor. - * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. - * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. - * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1]. - * \param[in] max_seqlen Max sequence length used for computing, - * it may be >= max(seqlen_i) for i=0,...batch_size-1. - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. - * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] softmax_type Attention softmax type. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). - * \param[in] deterministic Whether to execute with deterministic behaviours. - * \param[in] cuda_graph Whether cuda graph capture is enabled or not. - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. - */ -[[deprecated( - "nvte_fused_attn_bwd_qkvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate " - "Q, K, V tensors instead.")]] -void nvte_fused_attn_bwd_qkvpacked( - const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, - NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, - NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); - -/*! \brief Compute dot product attention with packed KV input. - * - * \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. - * - * Computes: - * - P = Q * Transpose(K) + Bias - * - S = ScaleMaskSoftmax(P) - * - D = Dropout(S) - * - O = D * Transpose(V) - * - * Support Matrix for ROCm AOTriton: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | aotriton| FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO | NO/CAUSAL | Yes | arbitrary | arbitrary | - \endverbatim - * - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - \endverbatim - * - * Notes: - * - * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` - * help identify the correct offsets of different sequences in tensors Q, K, V and O. - * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, - * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. - * When the QKV format is `thd`, these tensors should follow the following rules. - * When there is no padding between sequences, the offset tensors should be equal to - * `cu_seqlens_q` and `cu_seqlens_kv` respectively. - * When there is padding between sequences, users are responsible to adjust the offsets as needed. - * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have - * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. - * - * \param[in] Q The Q tensor, in HD layouts. - * \param[in] KV The KV tensor, in 2HD or H2D layouts. - * \param[in] Bias The Bias tensor. - * \param[in] SoftmaxOffset The SoftmaxOffset tensor. - * \param[in,out] S The S tensor. - * \param[out] O The output O tensor. - * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, - * e.g. M, ZInv, rng_state. - * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. - * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. - * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. - * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. - * \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k]. - * \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v]. - * \param[in] rng_state Seed and offset of CUDA random number generator. - * \param[in] max_seqlen_q Max sequence length used for computing for Q. - * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. - * \param[in] max_seqlen_kv Max sequence length used for computing for KV. - * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. - * \param[in] is_training Whether this is in training mode or inference. - * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. - * \param[in] cuda_graph Whether cuda graph capture is enabled or not. - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. - * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] softmax_type Attention softmax type. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. - */ -[[deprecated( - "nvte_fused_attn_fwd_kvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " - "Q, K, V tensors instead.")]] -void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, - NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); - -/*! \brief Compute the backward of the dot product attention with packed KV input. - * - * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. - * Support Matrix for ROCm AOTriton: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | aotriton| FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO | NO/CAUSAL | Yes | arbitrary | arbitrary | - \endverbatim - * - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - \endverbatim - * - * Notes: - * - * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` - * help identify the correct offsets of different sequences in tensors Q, K, V and O. - * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, - * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. - * When the QKV format is `thd`, these tensors should follow the following rules. - * When there is no padding between sequences, the offset tensors should be equal to - * `cu_seqlens_q` and `cu_seqlens_kv` respectively. - * When there is padding between sequences, users are responsible to adjust the offsets as needed. - * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have - * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. - * - * \param[in] Q The Q tensor, in HD layouts. - * \param[in] KV The KV tensor, in H2D or 2HD layouts. - * \param[in] O The O tensor from forward. - * \param[in] dO The gradient of the O tensor. - * \param[in] S The S tensor. - * \param[in,out] dP The gradient of the P tensor. - * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, - * e.g. M, ZInv, rng_state. - * \param[out] dQ The gradient of the Q tensor. - * \param[out] dKV The gradient of the KV tensor. - * \param[out] dBias The gradient of the Bias tensor. - * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. - * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. - * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. - * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. - * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. - * \param[in] max_seqlen_q Max sequence length used for computing for Q. - * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. - * \param[in] max_seqlen_kv Max sequence length used for computing for KV. - * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. - * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] softmax_type Attention softmax type. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). - * \param[in] deterministic Whether to execute with deterministic behaviours. - * \param[in] cuda_graph Whether cuda graph capture is enabled or not. - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. - */ -[[deprecated( - "nvte_fused_attn_bwd_kvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate " - "Q, K, V tensors instead.")]] -void nvte_fused_attn_bwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream); + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); /*! \brief Compute dot product attention with separate Q, K and V. * @@ -602,19 +300,23 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd( - const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -679,6 +381,7 @@ void nvte_fused_attn_fwd( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -694,8 +397,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. * diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h index 1f026a703..794880d32 100644 --- a/transformer_engine/common/include/transformer_engine/fused_router.h +++ b/transformer_engine/common/include/transformer_engine/fused_router.h @@ -23,8 +23,8 @@ extern "C" { * \param[in] num_groups Number of groups in grouped topk. * \param[in] group_topk Grouped topk value. * \param[in] scaling_factor Scaling factor. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. - * \param[in] expert_bias Expert bias. (Only used at the sigmoid case) + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. + * \param[in] expert_bias Expert bias. (Used at the sigmoid/sqrtsoftplus cases) * \param[out] probs Output tensor for probabilities. * \param[out] routing_map Output tensor for routing map. * \param[out] intermediate_output Output tensor for intermediate output. (Softmax/sigmoid output) @@ -46,7 +46,7 @@ void nvte_fused_topk_with_score_function_forward( * \param[in] topk Topk value. * \param[in] use_pre_softmax Whether to use softmax before topk. * \param[in] scaling_factor Scaling factor. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[out] grad_logits Gradient of logits. * \param[in] stream CUDA stream used for the operation. */ @@ -63,7 +63,7 @@ void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map, * \param[in] num_tokens Number of tokens. * \param[in] num_experts Number of experts. * \param[in] topk Topk value. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[out] scores Output tensor for scores. * \param[in] routing_map Routing map. * \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output) @@ -82,7 +82,7 @@ void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_ * \param[in] num_tokens Number of tokens. * \param[in] num_experts Number of experts. * \param[in] topk Topk value. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[out] grad_logits Gradient of logits. * \param[in] stream CUDA stream used for the operation. */ diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index fcff444e7..e09a0f154 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -13,6 +13,8 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ +#include + #include "transformer_engine.h" #ifdef __cplusplus @@ -22,6 +24,9 @@ extern "C" { /*! \brief Configuration for matrix multiplication. */ typedef void *NVTEMatmulConfig; +/*! \brief Configuration for grouped matrix multiplication. */ +typedef void *NVTEGroupedMatmulConfig; + /*! \enum NVTEMatmulConfigAttribute * \brief Type of option for matrix multiplication. */ @@ -54,6 +59,38 @@ enum NVTEMatmulConfigAttribute { kNVTEMatmulConfigNumAttributes }; +/*! \enum NVTEGroupedMatmulConfigAttribute + * \brief Type of option for grouped matrix multiplication. + */ +enum NVTEGroupedMatmulConfigAttribute { + /*! Average M dimension hint + * + * Optional hint for average M dimension across all matrices in the group. + * Used by cuBLASLt for algorithm selection heuristics. If not set, + * computed automatically from D's logical shape. + */ + kNVTEGroupedMatmulConfigAvgM = 0, + /*! Average N dimension hint + * + * Optional hint for average N dimension across all matrices in the group. + * Used by cuBLASLt for algorithm selection heuristics. If not set, + * computed automatically from D's logical shape. + */ + kNVTEGroupedMatmulConfigAvgN = 1, + /*! Average K (reduction) dimension hint + * + * Optional hint for average K dimension across all matrices in the group. + * Used by cuBLASLt for algorithm selection heuristics. If not set, + * computed automatically from A's logical shape. + */ + kNVTEGroupedMatmulConfigAvgK = 2, + /*! Number of streaming multiprocessors to use in GEMM kernel. */ + kNVTEGroupedMatmulConfigSMCount = 3, + /*! Split accumulator mode. Only taken into account on Hopper. Default: true. */ + kNVTEGroupedMatmulConfigUseSplitAccumulator = 4, + kNVTEGroupedMatmulConfigNumAttributes +}; + /*! \brief Create a matrix multiplication configuration. */ NVTEMatmulConfig nvte_create_matmul_config(); @@ -84,6 +121,38 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA /*! \brief Destroy a matrix multiplication configuration. */ void nvte_destroy_matmul_config(NVTEMatmulConfig config); +/*! \brief Create a grouped matrix multiplication configuration. */ +NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config(); + +/*! \brief Query an option in grouped matrix multiplication configuration. + * + * \param[in] config Grouped matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. Ignored if + * NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. + */ +void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, void *buf, + size_t size_in_bytes, size_t *size_written); + +/*! \brief Set an option in grouped matrix multiplication configuration. + * + * \param[in] config Grouped matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. + */ +void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes); + +/*! \brief Destroy a grouped matrix multiplication configuration. */ +void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config); + /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated). * * This has been deprecated in favor of nvte_cublas_gemm_v2. @@ -230,6 +299,116 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); + +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C + * + * \note Requires cuBLAS 13.2+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture. + * Will error at runtime if compiled with an older cuBLAS version or run on + * a pre-Blackwell GPU. + * + * Performs batched GEMM on a collection of matrices with potentially different shapes. + * All tensors in the group must have compatible dimensions for matrix multiplication. + * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous + * memory layout and shape metadata. + * + * \param[in] A Input grouped tensor A. + * \param[in] transa Whether to transpose A matrices. + * \param[in] B Input grouped tensor B. + * \param[in] transb Whether to transpose B matrices. + * \param[in] C Input grouped tensor C (can be NULL for beta=0). + * \param[out] D Output grouped tensor D. + * \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). + * \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). + * \param[in] workspace_setup Workspace tensor for pointer array setup. + * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. + * \param[in] config Additional configuration (can be NULL for defaults). + * \param[in] stream CUDA stream for the operation. + * + * Requirements: + * - cuBLAS 13.2+ (CUDA 13.1+) + * - Blackwell (SM100) or newer GPU architecture + * - A, B, C (if provided), D must have the same num_tensors + * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] + * - Shape compatibility: if transa=false, transb=false: + * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) + */ +/*! \brief Return the required size in bytes for the setup workspace of grouped GEMM. + * + * The setup workspace stores pointer arrays and per-matrix dimension arrays used + * by the grouped GEMM kernel. Its size depends only on the number of tensors (GEMMs) + * in the group and is independent of matrix dimensions. + * + * Pass the result as the size of the workspace_setup tensor in nvte_grouped_gemm. + * + * \param[in] num_tensors Number of tensors (GEMMs) in the group. + * \return Required size in bytes for workspace_setup. + */ +size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors); + +/*! \brief Convert a device array of int32 values to int64 values. + * + * Useful for preparing group_sizes for nvte_grouped_gemm when the caller + * holds int32 sizes and needs int64 values on the device. + * + * \param[in] src Device pointer to source int32 array. + * \param[out] dst Device pointer to destination int64 array. + * \param[in] n Number of elements. + * \param[in] stream CUDA stream. + */ +void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream); + +void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, + const NVTETensor beta, NVTETensor workspace_setup, + NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, + cudaStream_t stream); + +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Grouped matrix multiplication with discrete A input tensors. + * + * Identical to nvte_grouped_gemm, but A is provided as a list of tensors + * instead of NVTEGroupedTensor. This enables discrete per-expert weights as inputA + * for Grouped GEMM. + * + * \param[in] A_list List of A tensors (length = num_tensors). + * \param[in] num_a_tensors Number of tensors in A_list. + */ +void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num_a_tensors, + int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, + const NVTETensor alpha, const NVTETensor beta, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEGroupedMatmulConfig config, cudaStream_t stream); + +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Grouped matrix multiplication with discrete output tensors. +* +* Identical to nvte_grouped_gemm, but C and D are provided as lists of tensors +* instead of NVTEGroupedTensor. This enables accumulation into non-contiguous +* per-expert buffers (for wgrads). +* +* \param[in] C_list Optional list of C tensors (length = num_tensors). +* \param[in] num_c_tensors Number of tensors in C_list (Can be 0 if C is not provided). +* \param[out] D_list List of D tensors (length = num_tensors). +* \param[in] num_d_tensors Number of tensors in D_list. +* \note All tensors in C_list and D_list must share the same dtype. +*/ +void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, + const NVTEGroupedTensor B, int transb, + const NVTETensor *C_list, size_t num_c_tensors, + NVTETensor *D_list, size_t num_d_tensors, + const NVTETensor alpha, const NVTETensor beta, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEGroupedMatmulConfig config, cudaStream_t stream); + +/*! \brief Grouped bias add for grouped GEMM outputs. +* +* Requires uniform last-dimension across all output tensors and bias tensors. +*/ +void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus @@ -335,6 +514,77 @@ class MatmulConfigWrapper { NVTEMatmulConfig config_ = nullptr; }; +/*! \struct GroupedMatmulConfigWrapper + * \brief C++ wrapper for NVTEGroupedMatmulConfig. + */ +class GroupedMatmulConfigWrapper { + public: + GroupedMatmulConfigWrapper() : config_{nvte_create_grouped_matmul_config()} {} + + GroupedMatmulConfigWrapper(const GroupedMatmulConfigWrapper &) = delete; + GroupedMatmulConfigWrapper &operator=(const GroupedMatmulConfigWrapper &) = delete; + + GroupedMatmulConfigWrapper(GroupedMatmulConfigWrapper &&other) : config_{other.config_} { + other.config_ = nullptr; + } + GroupedMatmulConfigWrapper &operator=(GroupedMatmulConfigWrapper &&other) { + if (config_ != nullptr) { + nvte_destroy_grouped_matmul_config(config_); + } + config_ = other.config_; + other.config_ = nullptr; + return *this; + } + + ~GroupedMatmulConfigWrapper() { + if (config_ != nullptr) { + nvte_destroy_grouped_matmul_config(config_); + config_ = nullptr; + } + } + + /*! \brief Get the underlying NVTEGroupedMatmulConfig. + * + * \return NVTEGroupedMatmulConfig held by this GroupedMatmulConfigWrapper. + */ + operator NVTEGroupedMatmulConfig() const noexcept { return config_; } + + /*! \brief Set average M dimension hint for algorithm selection. */ + void set_avg_m(int64_t avg_m) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgM, &avg_m, + sizeof(int64_t)); + } + + /*! \brief Set average N dimension hint for algorithm selection. */ + void set_avg_n(int64_t avg_n) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgN, &avg_n, + sizeof(int64_t)); + } + + /*! \brief Set average K dimension hint for algorithm selection. */ + void set_avg_k(int64_t avg_k) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgK, &avg_k, + sizeof(int64_t)); + } + + /*! \brief Set number of streaming multiprocessors to use. */ + void set_sm_count(int sm_count) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, &sm_count, + sizeof(int)); + } + + /*! \brief Set split accumulator mode. Only taken into account on Hopper. */ + void set_use_split_accumulator(bool use_split_accumulator) { + const auto val = static_cast(use_split_accumulator); + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigUseSplitAccumulator, + &val, sizeof(val)); + } + + private: + /*! \brief Wrapped NVTEGroupedMatmulConfig. */ + NVTEGroupedMatmulConfig config_ = nullptr; +}; + } // namespace transformer_engine #endif // __cplusplus diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index 68722d61a..fcd37cfc1 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -90,6 +90,24 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp int random_sign_mask, int random_sign_mask_t, cudaStream_t stream); +/*! \brief Grouped-tensor amax with Hadamard transform (graph safe, device-managed grouping). + * + * This function is experimental and the API is not stable. + * + * This API assumes that the split info (grouping of tensors) is on device and unknown to the host; + * therefore, this is a graph safe API and the grouped-tensor argument is passed as a single device structure. + * + * \param[in] input NVTEGroupedTensor representing grouped input tensors. + * \param[in,out] output NVTEGroupedTensor for output amax (row/col). Only the row-wise and + * column-wise amaxes are updated. + * \param[in] random_sign_mask 16-bit sign mask for RHT. + * \param[in] random_sign_mask_t 16-bit sign mask for transposed RHT. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_hadamard_transform_amax_graph_safe(const NVTEGroupedTensor input, + NVTEGroupedTensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream); + /*! * \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation. * @@ -128,6 +146,22 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso const NVTEQuantizationConfig quant_config, NVTETensor quant_workspace, cudaStream_t stream); +/*! + * \brief Perform the grouped-tensor Hadamard transform cast fusion operation in graph-safe mode. + * + * This function is experimental and the API is not stable. Group_ prefix means contiguous input concatenated. + * + * \param[in] input NVTEGroupedTensor representing grouped input tensors. + * \param[in,out] output NVTEGroupedTensor for output (row/column-wise quantized results). + * \param[in] hadamard_matrix Hadamard matrix to use for transformation. + * \param[in] quant_config Quantization configuration. + * \param[in] quant_workspace Workspace buffer. Must be at least 4 bytes. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_hadamard_transform_cast_fusion_graph_safe( + const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, NVTETensor quant_workspace, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index 303801a88..09ab260f1 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -233,17 +233,34 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor * \warning This API is **experimental** and subject to change. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. - * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[out] is_infinite Whether the kernel detected a non-finite input value. * \param[in,out] tensor_lists 2D array of input tensors. * \param[in] num_tensor_lists Size (dim0) of tensor_lists. * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. * \param[in] scale Scalar for the scaling operation. * \param[in] stream CUDA stream used for this operation. */ -void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, +void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor is_infinite, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, float scale, cudaStream_t stream); +/*! \brief Check overflow and scale a list of tensors. scale is tensor input. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[out] is_infinite Whether the kernel detected a non-finite input value. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] scale Tensor for the scaling operation. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor is_infinite, + NVTETensor **tensor_lists, const size_t num_tensor_lists, + const size_t num_tensors_per_list, NVTETensor scale, + cudaStream_t stream); + /*! \brief Check overflow and scale a list of tensors. * * \warning This API is **experimental** and subject to change. @@ -296,6 +313,17 @@ void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor ** void nvte_group_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections, size_t num_tensors, cudaStream_t stream); +/*! \brief Grouped-tensor amax without doing hadamard transform. + * + * This function is experimental and the API is not stable. + * + * \param[in] input NVTEGroupedTensor Input tensor. + * \param[in,out] output NVTEGroupedTensor Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_amax_graph_safe(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index effc6bad0..ac4ef9c97 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -338,6 +338,118 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r const NVTETensor inpB, const bool use_rowwise_amax_B, float alpha_in, NVTETensor alpha_out, cudaStream_t stream); +/*! \brief Compute tile-level amax for a partial shard of a 2D tensor. + * + * For NVFP4 2D quantization with 16x16 tiles. Computes the maximum absolute + * value within each tile, but only for elements in [start_offset, start_offset + len) + * of the flattened tensor. Used in distributed settings where each rank owns a shard. + * + * \param[in] inp Input tensor (partial shard, high-precision). + * \param[out] amax Output amax buffer [tile_rows, tile_cols], float32. + * \param[in] h Number of rows in the full 2D tensor. + * \param[in] w Number of columns in the full 2D tensor. + * \param[in] amax_stride_h Stride for amax in tile-row dimension. + * \param[in] amax_stride_w Stride for amax in tile-col dimension. + * \param[in] start_offset Starting element offset in the flattened tensor. + * \param[in] block_len Tile dimension (must be 16 for NVFP4 2D). + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w, + size_t amax_stride_h, size_t amax_stride_w, + size_t start_offset, size_t block_len, cudaStream_t stream); + +/*! \brief Cast a partial shard of a tensor to NVFP4 using 2D tile-based quantization. + * + * Quantizes elements in [start_offset, start_offset + len) of the flattened tensor + * using precomputed per-tile scales. Each 16x16 tile uses its own scale factor. + * Used in distributed settings where each rank casts its owned shard. + * + * \param[in] inp Input tensor (partial shard, high-precision). + * \param[out] out Output NVFP4 packed tensor (2 values per byte). + * \param[in] scale Per-tile scale factors [tile_rows, tile_cols], float32. + * \param[in] global_scale Global scale factor [1], float32. + * \param[in] h Number of rows in the full 2D tensor. + * \param[in] w Number of columns in the full 2D tensor. + * \param[in] scale_stride_h Stride for scale in tile-row dimension. + * \param[in] scale_stride_w Stride for scale in tile-col dimension. + * \param[in] start_offset Starting element offset in the flattened tensor. + * \param[in] block_len Tile dimension (must be 16 for NVFP4 2D). + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale, + const NVTETensor global_scale, size_t h, size_t w, + size_t scale_stride_h, size_t scale_stride_w, size_t start_offset, + size_t block_len, cudaStream_t stream); + +/*! \brief Expand tile-level scales to row-level scales and convert to FP8 E4M3, used in partial cast. + * + * Each tile row's scale is repeated block_len times in the output. + * + * \param[in] input Input tensor with tile scales [tile_rows, tile_cols], float32. + * \param[out] output Output tensor with expanded scales [rows_padded, tile_cols], uint8 (E4M3). + * \param[in] tile_rows Number of tile rows. + * \param[in] tile_cols Number of tile columns. + * \param[in] rows_padded Padded row count in output. + * \param[in] block_len Block length (typically 16 for NVFP4). + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, size_t tile_rows, + size_t tile_cols, size_t rows_padded, size_t block_len, + cudaStream_t stream); + +/*! \brief Compute per-block decode scale from block amax and global amax. + * + * Computes: + * global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax + * per_block_decode_scale = block_amax / fp4_max * global_scale + * + * This matches the CUDA device function compute_decoding_scaling_factor() in core_nvfp4.cuh. + * + * \param[in] block_amax Input block amax tensor [tile_rows, tile_cols], float32. + * \param[out] scale Output scale tensor [tile_rows, tile_cols], float32. + * \param[in] global_amax Global amax tensor (single element), float32. Avoids D2H transfer. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor scale, + const NVTETensor global_amax, cudaStream_t stream); + +/*! \brief Fused kernel for NVFP4 scale computation. + * + * Fuses three operations into one kernel: + * 1. Compute per-block decode scales from block amax and global amax + * 2. Copy global amax to target tensor + * 3. Expand tile-level scales to row-level and convert to FP8 E4M3 + * + * Saves 2 kernel launches per parameter. + * + * \param[in] block_amax Input block amax tensor [tile_rows, tile_cols], float32. + * \param[in] global_amax Global amax tensor [1], float32. + * \param[out] per_block_scale Output per-block scale [tile_rows, tile_cols], float32 (for partial_cast). + * \param[out] target_scale Output scale tensor [rows_padded, tile_cols], uint8 (E4M3). + * \param[out] target_amax Output amax tensor [1], float32 (copy of global_amax). + * \param[in] tile_rows Number of tile rows. + * \param[in] tile_cols Number of tile columns. + * \param[in] rows_padded Total padded rows in output. + * \param[in] block_len Block length (16 for NVFP4). + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_fused_scale(const NVTETensor block_amax, const NVTETensor global_amax, + NVTETensor per_block_scale, NVTETensor target_scale, + NVTETensor target_amax, size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, cudaStream_t stream); + +/*! \brief Compute global encode scale from global amax. + * + * Computes: global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax + * If global_amax <= 0, returns 1.0. + * + * \param[in] global_amax Input global amax tensor [num_params], float32. + * \param[out] global_scale Output global scale tensor [num_params], float32. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_compute_global_scale(const NVTETensor global_amax, NVTETensor global_scale, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 5e420b2d4..904812118 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -63,6 +63,21 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen * */ void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM (grouped tensor) + * + * \param[in] input Input grouped tensor with non-swizzled scale_inv. + * \param[in,out] output Output grouped tensor which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements(for now, more features will be added later): + * - scaling mode must be MXFP8 1D scaling. + * - scale_inv is stored in row-major per group. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. + * - all tensors in the grouped tensor must have the same shape. + */ +void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 479be3df5..5e1bb22eb 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -429,6 +429,22 @@ int nvte_is_non_tn_fp8_gemm_supported(); */ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream); +/*! \brief Compute scaled prefix-sum offsets for grouped tensors. + * + * Computes: + * output[0] = 0 + * output[i + 1] = sum_{j=0..i}(first_dims[j] * logical_last_dim) + * for i in [0, num_tensors - 1]. + * + * \param[in] first_dims Pointer to device int64 array of size num_tensors. + * \param[out] output Pointer to device int64 array of size num_tensors + 1. + * \param[in] num_tensors Number of entries in first_dims. + * \param[in] logical_last_dim Scale factor applied to each first_dims entry. + * \param[in] stream CUDA stream to use for the operation. + */ +void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t num_tensors, + int64_t logical_last_dim, cudaStream_t stream); + /*! \brief TE Grouped Tensor type * * NVTEGroupedTensor is a collection of tensors with potentially different shapes @@ -451,6 +467,8 @@ enum NVTEGroupedTensorParam { kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */ kNVTEGroupedTensorOffsets = 9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */ + kNVTEGroupedWithGEMMSwizzledScales = + 10, /*!< Whether scaling factors are in format expected by GEMM */ kNVTENumGroupedTensorParams }; @@ -481,25 +499,30 @@ NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_ void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor); /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ -/*! \brief Set a parameter of the grouped tensor. +/*! \brief Set a grouped tensor parameter. * - * \param[in/out] tensor Grouped tensor. - * \param[in] param_name The parameter to be set. - * \param[in] param The value to be set (NVTEBasicTensor). + * \param[in/out] tensor Grouped tensor. + * \param[in] param Grouped tensor parameter type. + * \param[in] buf Memory address to read parameter value. + * \param[in] size_in_bytes Size of buf. */ -void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name, - const NVTEBasicTensor *param); +void nvte_set_grouped_tensor_param(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + const void *buf, size_t size_in_bytes); /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ -/*! \brief Get a value of the parameter of the grouped tensor. - * - * \param[in] tensor Grouped tensor. - * \param[in] param_name The parameter to be queried. +/*! \brief Query a grouped tensor parameter. * - * \return NVTEBasicTensor containing the parameter data. + * \param[in] tensor Grouped tensor. + * \param[in] param Grouped tensor parameter type. + * \param[out] buf Memory address to write parameter value. + * Ignored if NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. */ -NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, - NVTEGroupedTensorParam param_name); +void nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + void *buf, size_t size_in_bytes, size_t *size_written); /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Get the number of tensors in a grouped tensor. @@ -970,8 +993,235 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; -/*! \warning Deprecated */ -enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; +/*! \struct GroupedTensorWrapper + * \brief C++ wrapper for the NVTEGroupedTensor class. + */ + +class GroupedTensorWrapper { + public: + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * TE grouped tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} + + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : GroupedTensorWrapper(num_tensors, + nvte_make_shape(logical_shape.data(), logical_shape.size()), + scaling_mode) {} + + /*! \brief GroupedTensorWrapper destructor. */ + ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } + + GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; + GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; + + /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ + GroupedTensorWrapper(GroupedTensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing GroupedTensorWrapper. */ + GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_grouped_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_grouped_tensor_param(tensor_, param, &data, sizeof(data)); + return *this; + } + + template + GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedScale, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); + } + + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) { + const auto val = static_cast(with_gemm_swizzled_scales); + nvte_set_grouped_tensor_param(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val)); + } + + // Parameter getters + NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { + NVTEBasicTensor ret; + nvte_get_grouped_tensor_param(tensor_, param, &ret, sizeof(ret), nullptr); + return ret; + } + + NVTEBasicTensor get_rowwise_data() const noexcept { + return get_parameter(kNVTEGroupedRowwiseData); + } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedRowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseAmax); + } + + NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } + + NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } + + NVTEBasicTensor get_tensor_offsets() const noexcept { + return get_parameter(kNVTEGroupedTensorOffsets); + } + + bool get_with_gemm_swizzled_scales() const { + uint8_t val = 0; + nvte_get_grouped_tensor_param(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val), + nullptr); + return static_cast(val); + } + + /*! \brief Get an underlying NVTEGroupedTensor. + * + * \return NVTEGroupedTensor held by this GroupedTensorWrapper. + */ + NVTEGroupedTensor data() const noexcept { return tensor_; } + + /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ + size_t num_tensors() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_grouped_tensor_num_tensors(tensor_); + } + + /*! \brief Get the data type of this GroupedTensorWrapper. */ + DType dtype() const noexcept { + if (tensor_ == nullptr) return DType::kNumTypes; + return static_cast(nvte_grouped_tensor_type(tensor_)); + } + + /*! \brief Get a scaling mode of the grouped tensor. */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_grouped_tensor_scaling_mode(tensor_); + } + + /*! \brief Get the logical shape of this GroupedTensorWrapper. */ + const NVTEShape logical_shape() const noexcept { + if (tensor_ == nullptr) { + return emptyShape; + } + return nvte_get_grouped_tensor_logical_shape(tensor_); + } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = { + {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } + + /*! \brief Wrapped NVTEGroupedTensor. */ + NVTEGroupedTensor tensor_ = nullptr; +}; + +/*! \enum Float8BlockScaleTensorFormat + * \brief Data format for an FP8 block-scaled tensor + */ +enum class Float8BlockScaleTensorFormat { + /*! FP8 data is transposed if needed and scales are swizzled */ + GEMM_READY = 0, + /*! FP8 data is untransposed and scales are not swizzled or padded */ + COMPACT = 1, + INVALID +}; /*! \struct QuantizationConfigWrapper * \brief C++ wrapper for NVTEQuantizationConfigWrapper. diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index 5f9a8fe14..659a48d97 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -326,6 +326,32 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_in */ void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Transpose NVFP4 packed data. + * + * Unlike FP8, NVFP4 packs two 4-bit values per byte. This function correctly + * handles the nibble repacking during transpose. + * + * \param[in] input Input tensor with packed FP4 data. Shape: [M, K/2] bytes. + * \param[out] output Output tensor with transposed packed data. Shape: [K, M/2] bytes. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_data_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! \brief Transpose NVFP4 tile-level scales from rowwise to columnwise format. + * + * Takes rowwise_scale_inv where scales are stored at every 16th row (tile boundaries) + * and produces columnwise_scale_inv where scales are repeated 16 times per tile row. + * Scale values are stored as E4M3 (fp8) in uint8 tensors. + * + * \param[in] input Input tensor with rowwise scales [M_padded, K_tiles], uint8 (E4M3). + * \param[out] output Output tensor with columnwise scales [K_padded, M_tiles], uint8 (E4M3). + * \param[in] M_tiles Number of tiles in M dimension. + * \param[in] K_tiles Number of tiles in K dimension. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, size_t M_tiles, + size_t K_tiles, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/multi_tensor/adam.cu b/transformer_engine/common/multi_tensor/adam.cu index 2154102f0..2c4a681ab 100644 --- a/transformer_engine/common/multi_tensor/adam.cu +++ b/transformer_engine/common/multi_tensor/adam.cu @@ -56,7 +56,7 @@ struct FP8Data { template <> struct FP8Data {}; -template +template struct AdamFunctorMaster { static constexpr bool is_fp8_type = is_fp8::value; @@ -86,10 +86,10 @@ struct AdamFunctorMaster { PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; - FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); m += chunk_idx * chunk_size; - FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); v += chunk_idx * chunk_size; FULL_T *p_master = reinterpret_cast(tl.addresses[4][tensor_loc]); @@ -154,8 +154,8 @@ struct AdamFunctorMaster { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { p_master[i] = static_cast(r_p[ii]); - m[i] = static_cast(r_m[ii]); - v[i] = static_cast(r_v[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); if constexpr (is_fp8_type) { __builtin_assume(fp8_data.max >= 0); fp8_data.max = fmaxf(fabsf(r_p[ii]), fp8_data.max); @@ -182,7 +182,7 @@ struct AdamFunctorMaster { } }; -template +template struct AdamFunctorMasterParamRemainder { __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, TensorListMetadata<5> &tl, // NOLINT(*) @@ -201,10 +201,10 @@ struct AdamFunctorMasterParamRemainder { int16_t *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; - FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); m += chunk_idx * chunk_size; - FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); v += chunk_idx * chunk_size; int16_t *p_remainder = reinterpret_cast(tl.addresses[4][tensor_loc]); @@ -290,15 +290,15 @@ struct AdamFunctorMasterParamRemainder { p_remainder[i] = local_p_rem[ii]; p[i] = local_p[ii]; - m[i] = static_cast(r_m[ii]); - v[i] = static_cast(r_v[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); } } } } }; -template +template struct AdamFunctor { __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, // NOLINT(*) @@ -324,10 +324,10 @@ struct AdamFunctor { PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; - FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); m += chunk_idx * chunk_size; - FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); v += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; @@ -379,15 +379,15 @@ struct AdamFunctor { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { p[i] = static_cast(r_p[ii]); - m[i] = static_cast(r_m[ii]); - v[i] = static_cast(r_v[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); } } } } }; -template +template struct AdamCapturableFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, // NOLINT(*) @@ -417,10 +417,10 @@ struct AdamCapturableFunctor { T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; - FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); m += chunk_idx * chunk_size; - FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); v += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; @@ -473,15 +473,15 @@ struct AdamCapturableFunctor { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { p[i] = static_cast(r_p[ii]); - m[i] = static_cast(r_m[ii]); - v[i] = static_cast(r_v[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); } } } } }; -template +template struct AdamCapturableMasterFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<5> &tl, // NOLINT(*) @@ -511,10 +511,10 @@ struct AdamCapturableMasterFunctor { T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; - FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); m += chunk_idx * chunk_size; - FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); v += chunk_idx * chunk_size; FULL_T *p_master = reinterpret_cast(tl.addresses[4][tensor_loc]); @@ -571,8 +571,8 @@ struct AdamCapturableMasterFunctor { if (i < n && i < chunk_size) { p[i] = static_cast(r_p[ii]); p_master[i] = static_cast(r_p[ii]); - m[i] = static_cast(r_m[ii]); - v[i] = static_cast(r_v[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); } } } @@ -613,12 +613,17 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK(tensor_lists[1][j]->dtype() == p_in_type_te, "Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()), ", but expected dtype=", to_string(p_in_type_te)); - NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, - " has dtype=", to_string(tensor_lists[2][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); - NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, - " has dtype=", to_string(tensor_lists[3][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); + { + const bool m_is_fp32 = tensor_lists[2][j]->dtype() == DType::kFloat32; + const bool m_is_bf16 = tensor_lists[2][j]->dtype() == DType::kBFloat16; + const bool v_is_fp32 = tensor_lists[3][j]->dtype() == DType::kFloat32; + const bool v_is_bf16 = tensor_lists[3][j]->dtype() == DType::kBFloat16; + NVTE_CHECK((m_is_fp32 && v_is_fp32) || (m_is_bf16 && v_is_bf16), + "First and second moment tensors must both be Float32 or both be BFloat16, but " + "tensor ", + j, " has first moment dtype=", to_string(tensor_lists[2][j]->dtype()), + " and second moment dtype=", to_string(tensor_lists[3][j]->dtype())); + } if (num_tensor_lists == 5) { NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j, " has dtype=", to_string(tensor_lists[4][j]->dtype()), @@ -640,6 +645,9 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, } } + // Get moment dtype (m and v have the same dtype, already validated above) + const auto moment_type_te = tensor_lists[2][0]->dtype(); + // Launch kernel if (requires_64bit_indexing) { if (num_tensor_lists == 4) { @@ -648,22 +656,26 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, - tensor_lists, - AdamFunctor(), stream, - beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, - (adamMode_t)mode, weight_decay);)); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<4>( + (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctor(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);))); } else { // g, p, m, v, p_master TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, - tensor_lists, - AdamFunctorMaster(), - stream, beta1, beta2, bias_correction1, bias_correction2, - epsilon, lr, (adamMode_t)mode, weight_decay);)); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<5>( + (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), + stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);))); } } else { if (num_tensor_lists == 4) { @@ -672,20 +684,26 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctor(), stream, - beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, - (adamMode_t)mode, weight_decay);)); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<4>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);))); } else { // g, p, m, v, p_master TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctorMaster(), - stream, beta1, beta2, bias_correction1, bias_correction2, - epsilon, lr, (adamMode_t)mode, weight_decay);)); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<5>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), + stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);))); } } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -723,24 +741,35 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK(tensor_lists[1][j]->dtype() == DType::kBFloat16, "Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()), ", but expected dtype=", to_string(DType::kBFloat16)); - NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, - " has dtype=", to_string(tensor_lists[2][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); - NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, - " has dtype=", to_string(tensor_lists[3][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); + { + const bool m_is_fp32 = tensor_lists[2][j]->dtype() == DType::kFloat32; + const bool m_is_bf16 = tensor_lists[2][j]->dtype() == DType::kBFloat16; + const bool v_is_fp32 = tensor_lists[3][j]->dtype() == DType::kFloat32; + const bool v_is_bf16 = tensor_lists[3][j]->dtype() == DType::kBFloat16; + NVTE_CHECK((m_is_fp32 && v_is_fp32) || (m_is_bf16 && v_is_bf16), + "First and second moment tensors must both be Float32 or both be BFloat16, but " + "tensor ", + j, " has first moment dtype=", to_string(tensor_lists[2][j]->dtype()), + " and second moment dtype=", to_string(tensor_lists[3][j]->dtype())); + } NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kInt16, "Param remainder tensor ", j, " has dtype=", to_string(tensor_lists[4][j]->dtype()), ", but expected dtype=", to_string(DType::kInt16)); } + // Get moment dtype (m and v have the same dtype, already validated above) + const auto moment_type_te = tensor_lists[2][0]->dtype(); + // Launch kernel TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctorMasterParamRemainder(), stream, - beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, - (adamMode_t)mode, weight_decay);); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<5>( + (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctorMasterParamRemainder(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, + weight_decay);)); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -819,17 +848,17 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, g_in_type_te, g_in_type, multi_tensor_apply<5, true>( (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctorMaster(), stream, beta1, beta2, + AdamFunctorMaster(), stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( fp8_dtype, FP8_T, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctorMaster(), - stream, beta1, beta2, bias_correction1, bias_correction2, - epsilon, lr, (adamMode_t)mode, weight_decay);)); + multi_tensor_apply<5, true>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), stream, beta1, beta2, + bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)); } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -859,22 +888,32 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()), ", but expected dtype=", to_string(g_in_type_te)); - NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, - " has dtype=", to_string(tensor_lists[2][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); - NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, - " has dtype=", to_string(tensor_lists[3][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); + { + const bool m_is_fp32 = tensor_lists[2][j]->dtype() == DType::kFloat32; + const bool m_is_bf16 = tensor_lists[2][j]->dtype() == DType::kBFloat16; + const bool v_is_fp32 = tensor_lists[3][j]->dtype() == DType::kFloat32; + const bool v_is_bf16 = tensor_lists[3][j]->dtype() == DType::kBFloat16; + NVTE_CHECK((m_is_fp32 && v_is_fp32) || (m_is_bf16 && v_is_bf16), + "First and second moment tensors must both be Float32 or both be BFloat16, but " + "tensor ", + j, " has first moment dtype=", to_string(tensor_lists[2][j]->dtype()), + " and second moment dtype=", to_string(tensor_lists[3][j]->dtype())); + } } + // Get moment dtype (m and v have the same dtype, already validated above) + const auto moment_type_te = tensor_lists[2][0]->dtype(); + // Launch kernel TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, - multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamCapturableFunctor(), stream, beta1, beta2, - reinterpret_cast(step.data.dptr), bias_correction, epsilon, - reinterpret_cast(lr.data.dptr), (adamMode_t)mode, weight_decay, - reinterpret_cast(inv_scale.data.dptr));) + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamCapturableFunctor(), stream, beta1, + beta2, reinterpret_cast(step.data.dptr), bias_correction, + epsilon, reinterpret_cast(lr.data.dptr), (adamMode_t)mode, + weight_decay, reinterpret_cast(inv_scale.data.dptr));)) NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -904,25 +943,36 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()), ", but expected dtype=", to_string(g_in_type_te)); - NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j, - " has dtype=", to_string(tensor_lists[2][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); - NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j, - " has dtype=", to_string(tensor_lists[3][j]->dtype()), - ", but expected dtype=", to_string(DType::kFloat32)); + { + const bool m_is_fp32 = tensor_lists[2][j]->dtype() == DType::kFloat32; + const bool m_is_bf16 = tensor_lists[2][j]->dtype() == DType::kBFloat16; + const bool v_is_fp32 = tensor_lists[3][j]->dtype() == DType::kFloat32; + const bool v_is_bf16 = tensor_lists[3][j]->dtype() == DType::kBFloat16; + NVTE_CHECK((m_is_fp32 && v_is_fp32) || (m_is_bf16 && v_is_bf16), + "First and second moment tensors must both be Float32 or both be BFloat16, but " + "tensor ", + j, " has first moment dtype=", to_string(tensor_lists[2][j]->dtype()), + " and second moment dtype=", to_string(tensor_lists[3][j]->dtype())); + } NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j, " has dtype=", to_string(tensor_lists[4][j]->dtype()), ", but expected dtype=", to_string(DType::kFloat32)); } + // Get moment dtype (m and v have the same dtype, already validated above) + const auto moment_type_te = tensor_lists[2][0]->dtype(); + // Launch kernel TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, - multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamCapturableMasterFunctor(), stream, beta1, beta2, - reinterpret_cast(step.data.dptr), bias_correction, epsilon, - reinterpret_cast(lr.data.dptr), (adamMode_t)mode, weight_decay, - reinterpret_cast(inv_scale.data.dptr));) + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamCapturableMasterFunctor(), stream, + beta1, beta2, reinterpret_cast(step.data.dptr), + bias_correction, epsilon, reinterpret_cast(lr.data.dptr), + (adamMode_t)mode, weight_decay, + reinterpret_cast(inv_scale.data.dptr));)) NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/multi_tensor/scale.cu b/transformer_engine/common/multi_tensor/scale.cu index b3266200c..6b9b66faa 100644 --- a/transformer_engine/common/multi_tensor/scale.cu +++ b/transformer_engine/common/multi_tensor/scale.cu @@ -33,97 +33,141 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int s ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*) } -template -struct ScaleFunctor { - __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, - TensorListMetadata<2> &tl, // NOLINT(*) - float scale) { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - in_t *in = reinterpret_cast(tl.addresses[0][tensor_loc]); - in += chunk_idx * chunk_size; - - out_t *out = reinterpret_cast(tl.addresses[1][tensor_loc]); - out += chunk_idx * chunk_size; - - n -= chunk_idx * chunk_size; - - bool finite = true; - in_t r_in[ILP]; - out_t r_out[ILP]; - - // to make things simple, we put aligned case in a different code path - if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) { - for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; - i_start += blockDim.x) { - // load - load_store(r_in, in, 0, i_start); +__device__ __forceinline__ float get_scale_value(float scale) { return scale; } + +__device__ __forceinline__ float get_scale_value(const float *scale_ptr) { return *scale_ptr; } + +template +__device__ __forceinline__ void scale_chunk(int chunk_size, volatile int *is_infinite_gmem, + TensorListMetadata<2> &tl, scale_t scale_arg) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + const float scale = get_scale_value(scale_arg); + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + in_t *in = reinterpret_cast(tl.addresses[0][tensor_loc]); + in += chunk_idx * chunk_size; + + out_t *out = reinterpret_cast(tl.addresses[1][tensor_loc]); + out += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + bool finite = true; + in_t r_in[ILP]; + out_t r_out[ILP]; + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) { + for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_in, in, 0, i_start); #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - r_out[ii] = static_cast(r_in[ii]) * scale; - finite = finite && isfinite(static_cast(r_in[ii])); - } - // store - load_store(out, r_out, i_start, 0); + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = static_cast(r_in[ii]) * scale; + finite = finite && isfinite(static_cast(r_in[ii])); } - } else { - // Non-divergent exit condition for __syncthreads, not necessary here - for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + // store + load_store(out, r_out, i_start, 0); + } + } else { + // Non-divergent exit condition for __syncthreads, not necessary here + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - r_in[ii] = 0.f; - int i = i_start + threadIdx.x + ii * blockDim.x; - if (i < n && i < chunk_size) r_in[ii] = in[i]; - } - // note for clarification to future michael: - // From a pure memory dependency perspective, there's likely no point unrolling - // the write loop, since writes just fire off once their LDGs arrive. - // Put another way, the STGs are dependent on the LDGs, but not on each other. - // There is still compute ILP benefit from unrolling the loop though. + for (int ii = 0; ii < ILP; ii++) { + r_in[ii] = 0.f; + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) r_in[ii] = in[i]; + } + // From a pure memory dependency perspective, there's likely no point unrolling + // the write loop, since writes just fire off once their LDGs arrive. + // Put another way, the STGs are dependent on the LDGs, but not on each other. + // There is still compute ILP benefit from unrolling the loop though. #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - r_out[ii] = static_cast(r_in[ii]) * scale; - finite = finite && isfinite(static_cast(r_in[ii])); - } + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = static_cast(r_in[ii]) * scale; + finite = finite && isfinite(static_cast(r_in[ii])); + } #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if (i < n && i < chunk_size) out[i] = r_out[ii]; - } + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) out[i] = r_out[ii]; } } - if (!finite) *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. + } + if (!finite) *is_infinite_gmem = 1; // Blindly fire off a write. These will race but that's ok. +} + +template +struct ScaleFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *is_infinite_gmem, + TensorListMetadata<2> &tl, // NOLINT(*) + float scale) { + scale_chunk(chunk_size, is_infinite_gmem, tl, scale); } }; -void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag, +template +struct ScalePtrFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *is_infinite_gmem, + TensorListMetadata<2> &tl, // NOLINT(*) + float *scale_ptr) { + scale_chunk(chunk_size, is_infinite_gmem, tl, scale_ptr); + } +}; + +void multi_tensor_scale_cuda(int chunk_size, Tensor is_infinite, std::vector> tensor_lists, float scale, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[1][0]->dtype(), g_in_type, - multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, is_infinite, tensor_lists, ScaleFunctor(), stream, scale);)) NVTE_CHECK_CUDA(cudaGetLastError()); } +void multi_tensor_scale_tensor_cuda(int chunk_size, Tensor is_infinite, + std::vector> tensor_lists, float *scale, + cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + tensor_lists[0][0]->dtype(), p_in_type, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + tensor_lists[1][0]->dtype(), g_in_type, + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, is_infinite, tensor_lists, + ScalePtrFunctor(), stream, scale);)) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace multi_tensor_scale } // namespace transformer_engine -void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, +void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor is_infinite, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, float scale, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_scale_cuda); using namespace transformer_engine; multi_tensor_scale::multi_tensor_scale_cuda( - chunk_size, *convertNVTETensorCheck(noop_flag), + chunk_size, *convertNVTETensorCheck(is_infinite), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream); } + +void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor is_infinite, + NVTETensor **tensor_lists, const size_t num_tensor_lists, + const size_t num_tensors_per_list, NVTETensor scale, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_scale_tensor_cuda); + using namespace transformer_engine; + + Tensor *scale_tensor = convertNVTETensorCheck(scale); + multi_tensor_scale::multi_tensor_scale_tensor_cuda( + chunk_size, *convertNVTETensorCheck(is_infinite), + convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), + reinterpret_cast(scale_tensor->data.dptr), stream); +} diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index d6aa55b37..f2726a37a 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -144,7 +144,9 @@ void TeNormalizationPlan::execute(Tensor* z, void* x_dptr, void* beta_dptr, void* mean_dptr, void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, cudaStream_t stream) { - NVTE_ERROR("Backward normalization should not call the forward execute function!"); + NVTE_ERROR( + "Backward normalization should not call the forward execute function. " + "Use the backward-specific execute overload instead."); } template @@ -201,7 +203,9 @@ void TeNormalizationPlan::execute(void* x_dptr, void* gamma void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) { - NVTE_ERROR("Forward normalization should not call the backward execute function!"); + NVTE_ERROR( + "Forward normalization should not call the backward execute function. " + "Use the forward-specific execute overload instead."); } template <> diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 366f43d1f..d8ba9d3c9 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -111,33 +111,40 @@ class Recipe: Base recipe class. """ - def nvfp4(self): + @classmethod + def nvfp4(cls): """Whether the given recipe is NVFP4 1D block scaling.""" - return isinstance(self, NVFP4BlockScaling) + return issubclass(cls, NVFP4BlockScaling) - def mxfp8(self): + @classmethod + def mxfp8(cls): """Whether the given recipe is MXFP8 block scaling.""" - return isinstance(self, MXFP8BlockScaling) + return issubclass(cls, MXFP8BlockScaling) - def delayed(self): + @classmethod + def delayed(cls): """Whether the given recipe is delayed scaling.""" - return isinstance(self, DelayedScaling) + return issubclass(cls, DelayedScaling) - def float8_current_scaling(self): + @classmethod + def float8_current_scaling(cls): """Whether the given recipe is (per-tensor) current scaling.""" - return isinstance(self, Float8CurrentScaling) + return issubclass(cls, Float8CurrentScaling) - def float8_per_tensor_scaling(self): + @classmethod + def float8_per_tensor_scaling(cls): """Whether the given recipe is per-tensor scaling.""" - return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + return issubclass(cls, (DelayedScaling, Float8CurrentScaling)) - def float8_block_scaling(self): + @classmethod + def float8_block_scaling(cls): """Whether the given recipe is float8 blockwise scaling.""" - return isinstance(self, Float8BlockScaling) + return issubclass(cls, Float8BlockScaling) - def custom(self): + @classmethod + def custom(cls): """Whether the given recipe is custom.""" - return isinstance(self, CustomRecipe) + return issubclass(cls, CustomRecipe) @dataclass() diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 682d8b53f..9a6ebdeec 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -1,21 +1,75 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include +#include #include +#include #include "../common.h" +#include "../util/ptx.cuh" #include "../utils.cuh" namespace transformer_engine { namespace nvfp4_recipe { +/* + * --------------------------------------------------------------------------- + * NVFP4 2D PARTIAL-SHARD KERNEL DESIGN + * + * These kernels mirror the FP8 block-scaling helpers but operate on shard-local + * slices and nibble-packed FP4 rowwise buffers. One CUDA block covers a logical + * 16x16 tile (grid = ceil(W/16) x ceil(H/16), blockDim = 256 threads). + * + * 1) Partial Amax (`nvfp4_2d_compute_partial_amax_kernel`) + * - Warps sweep the tile using nested loops, accumulating local maxima only + * for elements in [start_offset, start_offset + len). + * - Shared memory reduces the 8 warp maxima; the block writes a float into + * `amax_ptr[tile_row * stride_h + tile_col * stride_w]`. + * + * Tile/warp mapping (each '#' = elements visited by that warp): + * + * +------------------+ + * |########..........| Warp 0 + * |########..........| Warp 1 + * | ... | + * |########..........| Warp 7 + * +------------------+ + * + * 2) Partial Cast (`nvfp4_2d_partial_cast_kernel`) + * - Stage the tile into shared memory (same pattern as FP8). + * - For each 4-value group, build float2 pairs and call + * `ptx::mul_cvt_fp32_to_fp4_4x`, producing packed FP4 nibbles. + * - Compute a shard-local byte index and update only the owned nibble(s) + * using read-modify-write: + * + * packed_bits = [mw3 | mw2 | mw1 | mw0] + * byte_idx = (ref_elem_idx - start_offset) >> 1 + * if elem_idx % 2 == 0: // low nibble + * byte = (byte & 0xF0) | nibble + * else: // high nibble + * byte = (byte & 0x0F) | (nibble << 4) + * + * Thread coverage inside a tile: + * + * rows: 16 columns: 16 + * Warp 0 -> rows 0-1 lanes sweep cols 0..3, 4..7, ... + * Warp 1 -> rows 2-3 (groups of 4 elements per thread) + * ... + * Warp 7 -> rows 14-15 + * --------------------------------------------------------------------------- + */ + // constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); +constexpr int kTileDim = 16; +constexpr int kThreadsPerBlock = 256; // Kernel to compute alpha *= amax_A * amax_B / factor __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, @@ -24,9 +78,816 @@ __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const floa *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; } +template +__global__ void __launch_bounds__(kThreadsPerBlock) + nvfp4_2d_compute_partial_amax_kernel(const IType *input, float *amax_ptr, + const size_t amax_stride_h, const size_t amax_stride_w, + const size_t h, const size_t w, const size_t start_offset, + const size_t len) { + constexpr int kThreadsPerWarp = 32; + constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; + static_assert(kTileDim * kTileDim == kThreadsPerBlock); + + const size_t tile_col = blockIdx.x; + const size_t tile_row = blockIdx.y; + const size_t end_offset = start_offset + len; + const IType *input_minus_offset = input - start_offset; + + __shared__ float smem[kNumWarps]; + float amax = 0.0f; + + size_t r = tile_row * kTileDim + threadIdx.x / kTileDim; + size_t c = tile_col * kTileDim + threadIdx.x % kTileDim; + size_t idx = r * w + c; + if (r < h && c < w && idx >= start_offset && idx < end_offset) { + amax = fabs(static_cast(input_minus_offset[idx])); + } + + for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { +#ifdef __HIP_PLATFORM_AMD__ + float other_amax = __shfl_down(amax, delta, kThreadsPerWarp); +#else + float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta); +#endif + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + + if (threadIdx.x % kThreadsPerWarp == 0) { + smem[threadIdx.x / kThreadsPerWarp] = amax; + } + + __syncthreads(); + + if (threadIdx.x == 0) { + for (int i = 0; i < kNumWarps; ++i) { + float other_amax = smem[i]; + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax_ptr[tile_row * amax_stride_h + tile_col * amax_stride_w] = amax; + } +} + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + nvfp4_2d_partial_cast_kernel(const IType *input, uint8_t *output, const float *decode_scale_ptr, + const size_t scale_stride_h, const size_t scale_stride_w, + const float *global_scale_ptr, const size_t h, const size_t w, + const size_t start_offset, const size_t len) { + constexpr int kNumOutputElemsPerBank = 4; + constexpr int kThreadsPerWarp = 32; + constexpr int kLoopsPerRow = (kTileDim + kThreadsPerWarp - 1) / kThreadsPerWarp; + constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; + constexpr int kRowsPerWarp = (kTileDim + kNumWarps - 1) / kNumWarps; + + __shared__ float smem[kTileDim][kTileDim + kNumOutputElemsPerBank]; + + const int tile_w = blockIdx.x; + const int tile_h = blockIdx.y; + const size_t shard_end = start_offset + len; + const IType *input_minus_offset = input - start_offset; + + float global_encode_scale = global_scale_ptr[0]; + if (global_encode_scale <= 0.f) { + global_encode_scale = 1.f; + } + const float global_decode_scale = 1.0f / global_encode_scale; + + float tile_decode_scale = decode_scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; + tile_decode_scale = static_cast(static_cast(tile_decode_scale)); + constexpr float kFp32Max = 3.402823466e+38F; + float tile_encode_val = + (tile_decode_scale > 0.f) ? 1.0f / (tile_decode_scale * global_decode_scale) : kFp32Max; + tile_encode_val = fminf(tile_encode_val, kFp32Max); + const float2 scale_vec = make_float2(tile_encode_val, tile_encode_val); + + bool skip_store = true; + for (int i = 0; i < kRowsPerWarp; ++i) { + for (int j = 0; j < kLoopsPerRow; ++j) { + const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i; + const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j; + if (h_in_smem >= kTileDim || w_in_smem >= kTileDim) { + continue; + } + const int h_in_input = tile_h * kTileDim + h_in_smem; + const int w_in_input = tile_w * kTileDim + w_in_smem; + const size_t idx_in_input = static_cast(h_in_input) * w + w_in_input; + if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset && + idx_in_input < shard_end) { + smem[h_in_smem][w_in_smem] = static_cast(input_minus_offset[idx_in_input]); + skip_store = false; + } + } + } + + for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { +#ifdef __HIP_PLATFORM_AMD__ + bool other = __shfl_down(skip_store, delta, kThreadsPerWarp); +#else + bool other = __shfl_down_sync(0xFFFFFFFF, skip_store, delta); +#endif + skip_store = skip_store && other; + } +#ifdef __HIP_PLATFORM_AMD__ + skip_store = __shfl(skip_store, 0, kThreadsPerWarp); +#else + skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0); +#endif + if (skip_store) { + return; + } + + for (int i = 0; i < kRowsPerWarp; ++i) { + const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i; + const int row_in_output = tile_h * kTileDim + row_in_smem; + if (row_in_output >= h) { + continue; + } + const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank; + if (col_in_smem >= kTileDim) { + continue; + } + const int col_in_output = tile_w * kTileDim + col_in_smem; + + float vals[kNumOutputElemsPerBank]; + bool mask[kNumOutputElemsPerBank]; + size_t elem_idx[kNumOutputElemsPerBank]; + bool any_valid = false; + + for (int j = 0; j < kNumOutputElemsPerBank; ++j) { + const int col = col_in_output + j; + const bool in_width = col < w; + const size_t idx = static_cast(row_in_output) * w + col; + elem_idx[j] = idx; + const bool in_shard = in_width && idx >= start_offset && idx < shard_end; + mask[j] = in_shard; + const bool in_tile = (col_in_smem + j) < kTileDim; + const float tile_val = in_tile ? smem[row_in_smem][col_in_smem + j] : 0.0f; + vals[j] = in_shard ? tile_val : 0.0f; + any_valid |= in_shard; + } + + if (!any_valid) { + continue; + } + + const float2 in01 = make_float2(vals[0], vals[1]); + const float2 in23 = make_float2(vals[2], vals[3]); + const auto packed = + transformer_engine::ptx::mul_cvt_fp32_to_fp4_4x(in01, in23, scale_vec, 0); + const uint16_t packed_bits = reinterpret_cast(packed); + + for (int pair = 0; pair < 2; ++pair) { + const int first = pair * 2; + const int second = first + 1; + if (!mask[first] && !mask[second]) { + continue; + } + const size_t ref_idx = mask[first] ? elem_idx[first] : elem_idx[second]; + const size_t byte_idx = (ref_idx - start_offset) >> 1; + uint8_t byte = output[byte_idx]; + + if (mask[first]) { + const uint8_t nibble = static_cast((packed_bits >> (4 * first)) & 0xF); + if ((elem_idx[first] & 1u) == 0) { + byte = static_cast((byte & 0xF0u) | nibble); + } else { + byte = static_cast((byte & 0x0Fu) | (nibble << 4)); + } + } + + if (mask[second]) { + const uint8_t nibble = static_cast((packed_bits >> (4 * second)) & 0xF); + if ((elem_idx[second] & 1u) == 0) { + byte = static_cast((byte & 0xF0u) | nibble); + } else { + byte = static_cast((byte & 0x0Fu) | (nibble << 4)); + } + } + + output[byte_idx] = byte; + } + } +} + +void nvfp4_2d_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w, + size_t amax_stride_h, size_t amax_stride_w, size_t start_offset, + size_t block_len, cudaStream_t stream) { + NVTE_CHECK(block_len == 16, "NVFP4 2D supports 16x16 tiles only (block_len = 16)."); + + size_t len = inp.numel(); + + assert(h > 0 && w > 0); + assert(start_offset < h * w); + assert(start_offset + len <= h * w); + + size_t blocks_x = (w + kTileDim - 1) / kTileDim; + size_t blocks_y = (h + kTileDim - 1) / kTileDim; + assert(blocks_x <= std::numeric_limits::max()); + assert(blocks_y <= std::numeric_limits::max()); + dim3 grid(blocks_x, blocks_y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + inp.dtype(), inp_dtype, + nvfp4_2d_compute_partial_amax_kernel<<>>( + reinterpret_cast(inp.data.dptr), + reinterpret_cast(amax.data.dptr), amax_stride_h, amax_stride_w, h, w, + start_offset, len);) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void nvfp4_2d_partial_cast(const Tensor inp, Tensor out, const Tensor scale, + const Tensor global_scale, size_t h, size_t w, size_t scale_stride_h, + size_t scale_stride_w, size_t start_offset, size_t block_len, + cudaStream_t stream) { + NVTE_CHECK(block_len == 16, "NVFP4 2D supports 16x16 tiles only (block_len = 16)."); + NVTE_CHECK(out.dtype() == DType::kByte, "NVFP4 rowwise data must be uint8."); + + size_t len = inp.numel(); + + assert(h > 0 && w > 0); + assert(start_offset < h * w); + assert(start_offset + len <= h * w); + + size_t blocks_x = (w + kTileDim - 1) / kTileDim; + size_t blocks_y = (h + kTileDim - 1) / kTileDim; + assert(blocks_x <= std::numeric_limits::max()); + assert(blocks_y <= std::numeric_limits::max()); + dim3 grid(blocks_x, blocks_y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + inp.dtype(), inp_dtype, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + w % kTileDim == 0, kWidthAligned, + nvfp4_2d_partial_cast_kernel + <<>>( + reinterpret_cast(inp.data.dptr), + reinterpret_cast(out.data.dptr), + reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, + reinterpret_cast(global_scale.data.dptr), h, w, start_offset, len);)) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +/* + * --------------------------------------------------------------------------- + * NVFP4 TRANSPOSE KERNEL + * + * Unlike FP8, NVFP4 packs two 4-bit values into each byte. A simple byte-wise + * transpose doesn't work because the packing changes: + * - Before transpose: elements [m, 2c] and [m, 2c+1] share a byte + * - After transpose: elements [k, 2*m_packed] and [k, 2*m_packed+1] share a byte + * which were originally [2*m_packed, k] and [2*m_packed+1, k] + * --------------------------------------------------------------------------- + */ + +// Vectorized transpose kernel parameters +constexpr int TRANSPOSE_TILE_DIM = 64; // Logical FP4 elements per tile dimension +constexpr int TRANSPOSE_TILE_PACKED = 32; // TILE_DIM / 2 bytes +constexpr int TRANSPOSE_BLOCK_SIZE = 256; // threads per block + +// Shared memory: store unpacked 4-bit values as bytes for easy transpose +// Size: TILE_DIM x (TILE_DIM + 4) to avoid bank conflicts +constexpr int TRANSPOSE_SHMEM_STRIDE = TRANSPOSE_TILE_DIM + 4; + +/* + * Vectorized transpose kernel with uint2 loads/stores (256 threads) + * Tile: 64x64 logical FP4 = 64x32 packed bytes + */ +__global__ void __launch_bounds__(TRANSPOSE_BLOCK_SIZE) + nvfp4_transpose_kernel(const uint8_t *__restrict__ input, uint8_t *__restrict__ output, + const size_t M, const size_t K) { + const size_t K_packed = K / 2; + const size_t M_packed = M / 2; + + const size_t tile_m_start = blockIdx.x * TRANSPOSE_TILE_DIM; + const size_t tile_k_start = blockIdx.y * TRANSPOSE_TILE_DIM; + + __shared__ uint8_t shmem[TRANSPOSE_TILE_DIM][TRANSPOSE_SHMEM_STRIDE]; + + const int tid = threadIdx.x; + + // Phase 1: Load input tile with VECTORIZED uint2 reads + // 256 threads, each loads 8 bytes (uint2) = 2048 bytes total + // Input tile: [64 rows, 32 cols] = 2048 bytes + { + const int thread_row = tid / 4; // 64 rows, 4 threads per row + const int thread_col = (tid % 4) * 8; // 4 x 8 = 32 bytes per row + + const size_t global_m = tile_m_start + thread_row; + const size_t global_k_packed_base = tile_k_start / 2 + thread_col; + + // Load 8 bytes as uint2 + uint2 loaded = make_uint2(0, 0); + if (global_m < M && global_k_packed_base + 7 < K_packed) { + loaded = *reinterpret_cast(&input[global_m * K_packed + global_k_packed_base]); + } else if (global_m < M) { + // Boundary: scalar loads + uint8_t *bytes = reinterpret_cast(&loaded); +#pragma unroll + for (int b = 0; b < 8; ++b) { + size_t col = global_k_packed_base + b; + bytes[b] = (col < K_packed) ? input[global_m * K_packed + col] : 0; + } + } + + // Unpack 8 bytes -> 16 nibbles and store to shared memory + const uint8_t *bytes = reinterpret_cast(&loaded); +#pragma unroll + for (int b = 0; b < 8; ++b) { + const int k0 = thread_col * 2 + b * 2; + const int k1 = k0 + 1; + shmem[thread_row][k0] = bytes[b] & 0x0F; + shmem[thread_row][k1] = (bytes[b] >> 4) & 0x0F; + } + } + + __syncthreads(); + + // Phase 2: Write output with VECTORIZED uint2 stores + // Output tile: [64 rows, 32 cols] = 2048 bytes + { + const int thread_row = tid / 4; // output K dimension [0, 64) + const int thread_col_base = (tid % 4) * 8; // output M_packed [0, 32) in steps of 8 + + const size_t global_k = tile_k_start + thread_row; + const size_t global_m_packed_base = tile_m_start / 2 + thread_col_base; + + if (global_k >= K) return; + + // Build 8 output bytes in registers + uint8_t out_bytes[8]; + +#pragma unroll + for (int b = 0; b < 8; ++b) { + const int out_m_packed = thread_col_base + b; + + if (global_m_packed_base + b >= M_packed) { + out_bytes[b] = 0; + continue; + } + + // Two M positions that pack into this output byte + const int m0 = out_m_packed * 2; + const int m1 = out_m_packed * 2 + 1; + const int k = thread_row; + + // Read from shared memory (transposed access) + const uint8_t val0 = shmem[m0][k]; + const uint8_t val1 = shmem[m1][k]; + + out_bytes[b] = val0 | (val1 << 4); + } + + // Vectorized store as uint2 + if (global_m_packed_base + 7 < M_packed) { + *reinterpret_cast(&output[global_k * M_packed + global_m_packed_base]) = + *reinterpret_cast(out_bytes); + } else { + // Boundary: scalar stores + for (int b = 0; b < 8 && global_m_packed_base + b < M_packed; ++b) { + output[global_k * M_packed + global_m_packed_base + b] = out_bytes[b]; + } + } + } +} + +void nvfp4_transpose(const Tensor input, Tensor output, cudaStream_t stream) { + // Input has logical shape [M, K], stored as [M, K/2] bytes + // Output has logical shape [K, M], stored as [K, M/2] bytes + + NVTE_CHECK(input.dtype() == DType::kByte, "NVFP4 transpose input must be uint8."); + NVTE_CHECK(output.dtype() == DType::kByte, "NVFP4 transpose output must be uint8."); + + // Get dimensions from packed storage + // input.shape() = [M, K/2], so M = shape[0], K = shape[1] * 2 + const auto in_shape = input.shape(); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 transpose expects 2D input (packed), got ", + in_shape.size(), "D."); + const size_t M = in_shape[0]; + const size_t K_packed = in_shape[1]; + const size_t K = K_packed * 2; + + // Output should be [K, M/2] + const size_t M_packed = M / 2; + NVTE_CHECK(M % 2 == 0, "NVFP4 transpose requires M (", M, ") to be even."); + + const auto out_shape = output.shape(); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 transpose expects 2D output."); + NVTE_CHECK(out_shape[0] == K && out_shape[1] == M_packed, + "NVFP4 transpose output shape mismatch. Expected [", K, ", ", M_packed, "], got [", + out_shape[0], ", ", out_shape[1], "]."); + + if (M == 0 || K == 0) return; + + // Use vectorized kernel (faster than TMA for pure transpose) + // 128x128 tiles with 512 threads and uint4 vectorized access + dim3 block(TRANSPOSE_BLOCK_SIZE); + dim3 grid((M + TRANSPOSE_TILE_DIM - 1) / TRANSPOSE_TILE_DIM, + (K + TRANSPOSE_TILE_DIM - 1) / TRANSPOSE_TILE_DIM); + + nvfp4_transpose_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), M, K); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +/* + * --------------------------------------------------------------------------- + * NVFP4 SCALE TRANSPOSE KERNEL + * + * Transposes tile-level scales from rowwise to columnwise format. + * Scale values are stored as E4M3 (fp8) in uint8 tensors. + * + * Input (rowwise_scale_inv): [M_padded, K_tiles] where scales are stored + * at every 16th row (i.e., row 0, 16, 32, ... contain the actual scales, + * and each row i within a tile block has the same scale as row (i // 16) * 16). + * + * Output (columnwise_scale_inv): [K_padded, M_tiles] where scales are + * repeated 16 times per tile row. + * + * Mapping: + * output[k_tile * 16 + i, m_tile] = input[m_tile * 16, k_tile] + * for i in [0, 16) and valid (k_tile, m_tile) indices. + * --------------------------------------------------------------------------- + */ +__global__ void nvfp4_scale_transpose_kernel( + const uint8_t *__restrict__ input, // [M_padded, K_tiles], E4M3 stored as uint8 + uint8_t *__restrict__ output, // [K_padded, M_tiles], E4M3 stored as uint8 + const size_t M_tiles, // Number of M tiles + const size_t K_tiles, // Number of K tiles + const size_t input_stride, // K_tiles (input row stride) + const size_t output_stride, // M_tiles (output row stride) + const size_t K_padded // Output height +) { + // Each thread handles one output element + const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; + + if (out_row >= K_padded || out_col >= M_tiles) return; + + // Determine which tile row this belongs to + const size_t k_tile = out_row / kTileDim; + + // Read from input: row = m_tile * 16 (first row of the tile), col = k_tile + // m_tile = out_col + if (k_tile < K_tiles) { + const size_t in_row = out_col * kTileDim; // m_tile * 16 + const uint8_t scale = input[in_row * input_stride + k_tile]; + output[out_row * output_stride + out_col] = scale; + } else { + output[out_row * output_stride + out_col] = 0; + } +} + +void nvfp4_scale_transpose(const Tensor input, Tensor output, size_t M_tiles, size_t K_tiles, + cudaStream_t stream) { + NVTE_CHECK(input.dtype() == DType::kByte, "NVFP4 scale transpose input must be uint8 (E4M3)."); + NVTE_CHECK(output.dtype() == DType::kByte, "NVFP4 scale transpose output must be uint8 (E4M3)."); + + const auto in_shape = input.shape(); + const auto out_shape = output.shape(); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 scale transpose expects 2D input."); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 scale transpose expects 2D output."); + + const size_t input_stride = in_shape[1]; // K_tiles + const size_t output_stride = out_shape[1]; // M_tiles + const size_t K_padded = out_shape[0]; + + if (M_tiles == 0 || K_tiles == 0 || K_padded == 0) return; + + constexpr int kBlockDim = 16; + dim3 block(kBlockDim, kBlockDim); + dim3 grid((M_tiles + kBlockDim - 1) / kBlockDim, (K_padded + kBlockDim - 1) / kBlockDim); + + nvfp4_scale_transpose_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), M_tiles, K_tiles, input_stride, output_stride, + K_padded); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +/* + * --------------------------------------------------------------------------- + * NVFP4 SCALE EXPANSION KERNEL + * + * Expands tile-level scales to row-level scales and converts to FP8 E4M3, used in partial cast. + * + * Input (per_block_decode_scale): [tile_rows, tile_cols] in float32 + * Output (target_scale): [rows_padded, tile_cols] in uint8 (E4M3) + * + * Each tile row's scale is repeated block_len times in the output. + * --------------------------------------------------------------------------- + */ +__global__ void nvfp4_expand_scale_to_fp8_kernel( + const float *__restrict__ input, // [tile_rows, tile_cols] + uint8_t *__restrict__ output, // [rows_padded, tile_cols] + const size_t tile_rows, const size_t tile_cols, const size_t rows_padded, + const size_t block_len) { + const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; + + if (out_row >= rows_padded || out_col >= tile_cols) return; + + // Determine which tile row this output row belongs to + const size_t tile_row = out_row / block_len; + + float scale_val = 0.0f; + if (tile_row < tile_rows) { + scale_val = input[tile_row * tile_cols + out_col]; + } + + // Convert float32 to FP8 E4M3 + // Clamp to FP8 E4M3 range and convert + fp8e4m3 fp8_val = static_cast(scale_val); + output[out_row * tile_cols + out_col] = reinterpret_cast(fp8_val); +} + +void nvfp4_expand_scale_to_fp8(const Tensor input, Tensor output, size_t tile_rows, + size_t tile_cols, size_t rows_padded, size_t block_len, + cudaStream_t stream) { + NVTE_CHECK(input.dtype() == DType::kFloat32, "Scale input must be float32."); + NVTE_CHECK(output.dtype() == DType::kByte, "Scale output must be uint8 (E4M3)."); + + if (tile_rows == 0 || tile_cols == 0 || rows_padded == 0) return; + + constexpr int kBlockDim = 16; + dim3 block(kBlockDim, kBlockDim); + dim3 grid((tile_cols + kBlockDim - 1) / kBlockDim, (rows_padded + kBlockDim - 1) / kBlockDim); + + nvfp4_expand_scale_to_fp8_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), tile_rows, tile_cols, rows_padded, block_len); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +/* + * --------------------------------------------------------------------------- + * NVFP4 COMPUTE PER-BLOCK DECODE SCALE KERNEL + * + * Computes per-block decode scale from block amax and global amax: + * global_scale = (fp8_max * fp4_max) / global_amax = 2688 / global_amax + * per_block_decode_scale = block_amax / fp4_max * global_scale + * = block_amax * 448 / global_amax + * + * This matches the CUDA device function compute_decoding_scaling_factor() in core_nvfp4.cuh + * + * Input (block_amax): [tile_rows, tile_cols] in float32 + * Input (global_amax): scalar float32 (per-tensor amax after all-reduce) + * Output (scale): [tile_rows, tile_cols] in float32 + * Output (global_scale_out): scalar float32 (the computed global encode scale) + * --------------------------------------------------------------------------- + */ +__global__ void nvfp4_compute_per_block_scale_kernel( + const float *__restrict__ block_amax, // [tile_rows, tile_cols] + float *__restrict__ scale, // [tile_rows, tile_cols] + const float *__restrict__ global_amax_ptr, // Pointer to single float value (avoids D2H) + const size_t numel) { + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= numel) return; + + constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float flt_max = 3.402823466e+38f; + constexpr float tiny = 1.17549435e-38f; // FLT_MIN + + // Read global_amax from device memory (avoids D2H transfer) + float global_amax = *global_amax_ptr; + + // Compute global encode scale: S_enc = (fp8_max * fp4_max) / global_amax + float safe_global_amax = fmaxf(global_amax, tiny); + float global_scale = + (global_amax > 0.0f) ? fminf((fp8_max * fp4_max) / safe_global_amax, flt_max) : 1.0f; + + // Compute per-block decode scale: S_dec_b = block_amax / fp4_max * S_enc + float amax_val = block_amax[idx]; + float result = fminf((amax_val / fp4_max) * global_scale, flt_max); + scale[idx] = result; +} + +// Simple kernel to compute global encode scale from global amax +__global__ void nvfp4_compute_global_scale_kernel( + const float *__restrict__ global_amax, // [num_params] + float *__restrict__ global_scale, // [num_params] + const size_t num_params) { + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_params) return; + + constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float flt_max = 3.402823466e+38f; + constexpr float tiny = 1.17549435e-38f; // FLT_MIN + + float amax = global_amax[idx]; + float safe_amax = fmaxf(amax, tiny); + float scale = (amax > 0.0f) ? fminf((fp8_max * fp4_max) / safe_amax, flt_max) : 1.0f; + global_scale[idx] = scale; +} + +void nvfp4_compute_per_block_scale(const Tensor block_amax, Tensor scale, const Tensor global_amax, + cudaStream_t stream) { + NVTE_CHECK(block_amax.dtype() == DType::kFloat32, "Block amax must be float32."); + NVTE_CHECK(scale.dtype() == DType::kFloat32, "Scale must be float32."); + NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + + size_t numel = block_amax.numel(); + if (numel == 0) return; + + constexpr int kBlockSize = 256; + int grid_size = (numel + kBlockSize - 1) / kBlockSize; + + nvfp4_compute_per_block_scale_kernel<<>>( + reinterpret_cast(block_amax.data.dptr), + reinterpret_cast(scale.data.dptr), + reinterpret_cast(global_amax.data.dptr), numel); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void nvfp4_compute_global_scale(const Tensor global_amax, Tensor global_scale, + cudaStream_t stream) { + NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); + NVTE_CHECK(global_scale.dtype() == DType::kFloat32, "Global scale must be float32."); + + size_t num_params = global_amax.numel(); + if (num_params == 0) return; + + constexpr int kBlockSize = 256; + int grid_size = (num_params + kBlockSize - 1) / kBlockSize; + + nvfp4_compute_global_scale_kernel<<>>( + reinterpret_cast(global_amax.data.dptr), + reinterpret_cast(global_scale.data.dptr), num_params); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +/* + * --------------------------------------------------------------------------- + * FUSED NVFP4 SCALE COMPUTATION KERNEL + * + * Fuses three operations into one kernel: + * 1. nvfp4_compute_per_block_scale: compute tile-level decode scales from block amax + * 2. target_amax.copy_: copy global amax to target tensor + * 3. nvfp4_expand_scale_to_fp8: expand to row-level and convert to FP8 E4M3 + * + * Input (block_amax): [tile_rows, tile_cols] float32 + * Input (global_amax): [1] float32 + * Output (per_block_scale): [tile_rows, tile_cols] float32 (intermediate, for partial_cast) + * Output (target_scale): [rows_padded, tile_cols] uint8 (E4M3) + * Output (target_amax): [1] float32 (copy of global_amax) + * + * Saves 2 kernel launches per parameter (eliminates nvfp4_compute_per_block_scale and + * nvfp4_expand_scale_to_fp8 as separate calls, plus the amax copy). + * --------------------------------------------------------------------------- + */ +__global__ void nvfp4_fused_scale_kernel( + const float *__restrict__ block_amax, // [tile_rows, tile_cols] + const float *__restrict__ global_amax, // [1] + float *__restrict__ per_block_scale, // [tile_rows, tile_cols] - for partial_cast + uint8_t *__restrict__ target_scale, // [rows_padded, tile_cols] + float *__restrict__ target_amax, // [1] + const size_t tile_rows, const size_t tile_cols, const size_t rows_padded, + const size_t block_len) { + const size_t out_row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t out_col = blockIdx.x * blockDim.x + threadIdx.x; + + // Read global amax once per thread (broadcast) + const float g_amax = *global_amax; + + // Thread (0,0) copies global_amax to target_amax + if (out_row == 0 && out_col == 0) { + *target_amax = g_amax; + } + + if (out_row >= rows_padded || out_col >= tile_cols) return; + + // Determine which tile row this output row belongs to + const size_t tile_row = out_row / block_len; + + // Compute the scale value + constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float flt_max = 3.402823466e+38f; + constexpr float tiny = 1.17549435e-38f; + + float scale_val = 0.0f; + if (tile_row < tile_rows) { + float safe_global_amax = fmaxf(g_amax, tiny); + float global_scale = + (g_amax > 0.0f) ? fminf((fp8_max * fp4_max) / safe_global_amax, flt_max) : 1.0f; + + // Read block amax and compute per-block decode scale + float amax_val = block_amax[tile_row * tile_cols + out_col]; + scale_val = fminf((amax_val / fp4_max) * global_scale, flt_max); + + // Write per-block scale (only once per tile, when out_row % block_len == 0) + if (out_row % block_len == 0) { + per_block_scale[tile_row * tile_cols + out_col] = scale_val; + } + } + + // Convert float32 to FP8 E4M3 and write expanded scale + fp8e4m3 fp8_val = static_cast(scale_val); + target_scale[out_row * tile_cols + out_col] = reinterpret_cast(fp8_val); +} + +void nvfp4_fused_scale(const Tensor block_amax, const Tensor global_amax, Tensor per_block_scale, + Tensor target_scale, Tensor target_amax, size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, cudaStream_t stream) { + NVTE_CHECK(block_amax.dtype() == DType::kFloat32, "Block amax must be float32."); + NVTE_CHECK(global_amax.dtype() == DType::kFloat32, "Global amax must be float32."); + NVTE_CHECK(per_block_scale.dtype() == DType::kFloat32, "Per-block scale must be float32."); + NVTE_CHECK(target_scale.dtype() == DType::kByte, "Target scale must be uint8 (E4M3)."); + NVTE_CHECK(target_amax.dtype() == DType::kFloat32, "Target amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor."); + + if (tile_rows == 0 || tile_cols == 0 || rows_padded == 0) return; + + constexpr int kBlockDim = 16; + dim3 block(kBlockDim, kBlockDim); + dim3 grid((tile_cols + kBlockDim - 1) / kBlockDim, (rows_padded + kBlockDim - 1) / kBlockDim); + + nvfp4_fused_scale_kernel<<>>( + reinterpret_cast(block_amax.data.dptr), + reinterpret_cast(global_amax.data.dptr), + reinterpret_cast(per_block_scale.data.dptr), + reinterpret_cast(target_scale.data.dptr), + reinterpret_cast(target_amax.data.dptr), tile_rows, tile_cols, rows_padded, + block_len); + NVTE_CHECK_CUDA(cudaGetLastError()); +} } // namespace nvfp4_recipe } // namespace transformer_engine +void nvte_nvfp4_expand_scale_to_fp8(const NVTETensor input, NVTETensor output, size_t tile_rows, + size_t tile_cols, size_t rows_padded, size_t block_len, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_expand_scale_to_fp8); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_expand_scale_to_fp8(*convertNVTETensorCheck(input), + *convertNVTETensorCheck(output), tile_rows, tile_cols, + rows_padded, block_len, stream); +} + +void nvte_nvfp4_compute_per_block_scale(const NVTETensor block_amax, NVTETensor scale, + const NVTETensor global_amax, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_compute_per_block_scale); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_compute_per_block_scale(*convertNVTETensorCheck(block_amax), + *convertNVTETensorCheck(scale), + *convertNVTETensorCheck(global_amax), stream); +} + +void nvte_nvfp4_compute_global_scale(const NVTETensor global_amax, NVTETensor global_scale, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_compute_global_scale); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_compute_global_scale(*convertNVTETensorCheck(global_amax), + *convertNVTETensorCheck(global_scale), stream); +} + +void nvte_nvfp4_scale_transpose(const NVTETensor input, NVTETensor output, size_t M_tiles, + size_t K_tiles, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_scale_transpose); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_scale_transpose(*convertNVTETensorCheck(input), + *convertNVTETensorCheck(output), M_tiles, K_tiles, stream); +} + +void nvte_nvfp4_data_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_data_transpose); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + stream); +} + +void nvte_nvfp4_2d_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w, + size_t amax_stride_h, size_t amax_stride_w, + size_t start_offset, size_t block_len, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_2d_compute_partial_amax); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_2d_compute_partial_amax(*convertNVTETensorCheck(inp), + *convertNVTETensorCheck(amax), h, w, amax_stride_h, + amax_stride_w, start_offset, block_len, stream); +} + +void nvte_nvfp4_2d_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale, + const NVTETensor global_scale, size_t h, size_t w, + size_t scale_stride_h, size_t scale_stride_w, size_t start_offset, + size_t block_len, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_2d_partial_cast); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_2d_partial_cast(*convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), + *convertNVTETensorCheck(scale), + *convertNVTETensorCheck(global_scale), h, w, scale_stride_h, + scale_stride_w, start_offset, block_len, stream); +} + void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, const NVTETensor inpB, const bool use_rowwise_amax_B, float alpha_in, NVTETensor alpha_out, @@ -52,3 +913,15 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r reinterpret_cast(amax_B_ptr), reinterpret_cast(alpha_ptr)); NVTE_CHECK_CUDA(cudaGetLastError()); } + +void nvte_nvfp4_fused_scale(const NVTETensor block_amax, const NVTETensor global_amax, + NVTETensor per_block_scale, NVTETensor target_scale, + NVTETensor target_amax, size_t tile_rows, size_t tile_cols, + size_t rows_padded, size_t block_len, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_fused_scale); + using namespace transformer_engine; + nvfp4_recipe::nvfp4_fused_scale( + *convertNVTETensorCheck(block_amax), *convertNVTETensorCheck(global_amax), + *convertNVTETensorCheck(per_block_scale), *convertNVTETensorCheck(target_scale), + *convertNVTETensorCheck(target_amax), tile_rows, tile_cols, rows_padded, block_len, stream); +} diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index c634c73fb..592992d61 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -286,6 +286,40 @@ struct MultiSwizzleArgs { int num_tensors; }; +constexpr size_t round_up_to_multiple(size_t value, size_t multiple) { + return DIVUP(value, multiple) * multiple; +} + +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + grouped_swizzle_row_scaling_uniform_shape_kernel(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, + const size_t scale_stride_bytes) { + const int tensor_id = blockIdx.z; + const uint8_t* input_base = + reinterpret_cast(input) + tensor_id * scale_stride_bytes; + uint8_t* output_base = reinterpret_cast(output) + tensor_id * scale_stride_bytes; + swizzle_row_scaling_kernel_impl( + input_base, output_base, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, + gridDim.y); +} + +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + grouped_swizzle_col_scaling_uniform_shape_kernel(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, + const size_t scale_stride_bytes) { + const int tensor_id = blockIdx.z; + const uint8_t* input_base = + reinterpret_cast(input) + tensor_id * scale_stride_bytes; + uint8_t* output_base = reinterpret_cast(output) + tensor_id * scale_stride_bytes; + swizzle_col_scaling_kernel_impl( + input_base, output_base, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, + gridDim.y); +} + template __global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) { // Find tensor corresponding to block @@ -587,9 +621,9 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, int n_tiles_in_tb = TB_DIM * vec_load_size; int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); /* Calculate number of CUDA blocks needed for each tensor. - * We have to do it here because we have to iterate over all tensors in this batch to - * get the minimum vec_load_size. - */ + * We have to do it here because we have to iterate over all tensors in this batch to + * get the minimum vec_load_size. + */ for (size_t j = 0; j < kernel_args.num_tensors; j++) { const int m = kernel_args.m_list[j]; const int k = kernel_args.k_list[j]; @@ -837,10 +871,10 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, } // namespace transformer_engine /* - * WIP (Phuong): - * - Opt for bank conflicts - * - Adding swizzle for 2d-block scaling. - */ +* WIP (Phuong): +* - Opt for bank conflicts +* - Adding swizzle for 2d-block scaling. +*/ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swizzle_scaling_factors); using namespace transformer_engine; @@ -859,3 +893,171 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen } multi_tensor_swizzle_scaling_factors(input_list, output_list, stream); } + +namespace transformer_engine { + +void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output, + cudaStream_t stream) { + // Check scaling mode + NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING, + "Grouped swizzle supports only MXFP8 scaling."); + + // Check tensors + CheckInputGroupedTensor(*input, "input"); + CheckOutputGroupedTensor(*output, "output", false); + NVTE_CHECK(!input->with_gemm_swizzled_scales, + "Expected input grouped tensor with scales in compact format."); + NVTE_CHECK(output->with_gemm_swizzled_scales, + "Expected output grouped tensor with scales in GEMM swizzled format."); + + // Check scaling factors availability + const bool has_rowwise_scale_inv = input->scale_inv.has_data(); + const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data(); + if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) { + return; + } + + // Only support uniform shapes for graph-safe grouped swizzle + NVTE_CHECK(input->all_same_shape(), "Grouped swizzle requires uniform tensor shapes."); + NVTE_CHECK(input->all_same_last_dim() && input->all_same_first_dim(), + "Grouped swizzle requires uniform tensor shapes."); + + // Assumption is that all the tensors share the same shapes and are contgiuous. + // And so we dont need to pass array of input/output pointers(due to conttiguity) + // as well as array of shapes(due to uniform shapes). + const size_t first_dim = input->get_common_first_dim(); + const size_t last_dim = input->get_common_last_dim(); + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + const dim3 block_size(TB_DIM, TB_DIM); + + auto launch_grouped_swizzle = [&](bool rowwise) { + const size_t m = rowwise ? first_dim : last_dim; + const size_t k = rowwise ? last_dim : first_dim; + const size_t padded_m = round_up_to_multiple(m, 128); + const size_t padded_k = + round_up_to_multiple(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4); + const size_t scale_elems = padded_m * padded_k; + + const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype) + : typeToSize(input->columnwise_scale_inv.dtype); + const size_t scale_stride_bytes = scale_elems * scale_elem_size; + + if (rowwise) { + NVTE_CHECK(input->scale_inv.numel() == input->num_tensors * scale_elems, + "Grouped input scale_inv size does not match expected packed size."); + NVTE_CHECK(output->scale_inv.numel() == output->num_tensors * scale_elems, + "Grouped output scale_inv size does not match expected packed size."); + } else { + NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems, + "Grouped input columnwise_scale_inv size does not match expected packed size."); + NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems, + "Grouped output columnwise_scale_inv size does not match expected packed size."); + } + + const int num_tiles_m = padded_m / SF_TILE_DIM_M; + const int num_tiles_k = padded_k / SF_TILE_DIM_K; + int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); + if (vec_load_size == 3) vec_load_size = 1; + const int n_tiles_in_tb = TB_DIM * vec_load_size; + + dim3 num_blocks; + if (rowwise) { + num_blocks = dim3(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m, input->num_tensors); + } else { + num_blocks = + dim3(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size), input->num_tensors); + } + const int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + + const int original_M = static_cast(rowwise ? first_dim : last_dim); + const int original_K = static_cast(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE))); + const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr; + void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr; + + if (rowwise) { + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_row_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_row_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + scale_stride_bytes); + break; + case 2: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_row_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_row_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + scale_stride_bytes); + break; + case 1: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_row_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_row_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + scale_stride_bytes); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + } + } else { + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_col_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_col_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + scale_stride_bytes); + break; + case 2: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_col_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_col_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + scale_stride_bytes); + break; + case 1: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_col_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_col_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + scale_stride_bytes); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + } + } + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + if (has_rowwise_scale_inv) { + launch_grouped_swizzle(true); + } + if (has_columnwise_scale_inv) { + launch_grouped_swizzle(false); + } +} + +} // namespace transformer_engine + +void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_grouped_scaling_factors); + using namespace transformer_engine; + swizzle_grouped_scaling_factors(convertNVTEGroupedTensorCheck(input), + convertNVTEGroupedTensorCheck(output), stream); +} diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index c724612cb..858b9ec1e 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -461,9 +461,9 @@ class TensorAllocator { } void Free(NVTETensor t) { - std::lock_guard lock(mutex); uintptr_t index = reinterpret_cast(t); if (index == 0) return; + std::lock_guard lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid tensor."); free_list.push_back(index); // Clean up @@ -571,9 +571,9 @@ class GroupedTensorAllocator { } void Free(NVTEGroupedTensor t) { - std::lock_guard lock(mutex); uintptr_t index = reinterpret_cast(t); if (index == 0) return; + std::lock_guard lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor."); free_list.push_back(index); // Clean up @@ -657,7 +657,7 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) { NVTEShape nvte_tensor_shape(const NVTETensor tensor) { auto *t = transformer_engine::convertNVTETensor(tensor); if (t == nullptr) { - NVTE_ERROR("Invalid tensor"); + NVTE_ERROR("Invalid tensor: received null pointer in nvte_tensor_shape"); } // Determine tensor shape depending on tensor format @@ -669,7 +669,7 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { auto *t = transformer_engine::convertNVTETensor(tensor); if (t == nullptr) { - NVTE_ERROR("Invalid tensor"); + NVTE_ERROR("Invalid tensor: received null pointer in nvte_tensor_columnwise_shape"); } const std::vector &shape = t->columnwise_data.shape; return nvte_make_shape(shape.data(), shape.size()); @@ -1152,8 +1152,8 @@ NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_ NVTEShape logical_shape) { NVTE_CHECK(num_tensors > 0, "Number of tensors must be greater than 0"); NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D"); - NVTE_CHECK(logical_shape.data[0] > 0 && logical_shape.data[1] > 0, - "Logical shape must have positive dimensions"); + // NVTE_CHECK(logical_shape.data[0] > 0 && logical_shape.data[1] > 0, + // "Logical shape must have positive dimensions"); NVTEGroupedTensor ret = transformer_engine::GroupedTensorAllocator::instance().Allocate( scaling_mode, num_tensors, logical_shape); return ret; @@ -1163,88 +1163,178 @@ void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor) { transformer_engine::GroupedTensorAllocator::instance().Free(tensor); } -void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name, - const NVTEBasicTensor *param) { +void nvte_set_grouped_tensor_param(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + const void *buf, size_t size_in_bytes) { + using namespace transformer_engine; + + // Check attribute and buffer + NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", + static_cast(param), ")"); NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL."); - auto *t = transformer_engine::convertNVTEGroupedTensor(*tensor); - NVTE_CHECK(t != nullptr, "Grouped tensor is not allocated."); - NVTE_CHECK(param != nullptr, "Grouped tensor param can't be NULL."); + auto &t = *convertNVTEGroupedTensorCheck(tensor); + const auto &attr_size = GroupedTensor::attr_sizes[param]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for grouped tensor parameter " + "(parameter ", + static_cast(param), " needs ", attr_size, " bytes, but buffer has ", + size_in_bytes, " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); - switch (param_name) { - case kNVTEGroupedRowwiseData: - t->data = *param; + // Read from buffer + switch (param) { + case kNVTEGroupedRowwiseData: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.data = *basic_tensor; break; - case kNVTEGroupedColumnwiseData: - t->columnwise_data = *param; + } + case kNVTEGroupedColumnwiseData: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_data = *basic_tensor; break; - case kNVTEGroupedScale: - t->scale = *param; + } + case kNVTEGroupedScale: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.scale = *basic_tensor; break; - case kNVTEGroupedAmax: - t->amax = *param; + } + case kNVTEGroupedAmax: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.amax = *basic_tensor; break; - case kNVTEGroupedRowwiseScaleInv: - t->scale_inv = *param; + } + case kNVTEGroupedRowwiseScaleInv: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.scale_inv = *basic_tensor; break; - case kNVTEGroupedColumnwiseScaleInv: - t->columnwise_scale_inv = *param; + } + case kNVTEGroupedColumnwiseScaleInv: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_scale_inv = *basic_tensor; break; - case kNVTEGroupedColumnwiseAmax: - t->columnwise_amax = *param; + } + case kNVTEGroupedColumnwiseAmax: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_amax = *basic_tensor; break; - case kNVTEGroupedFirstDims: - t->first_dims = *param; - // Validate it's Int64 - NVTE_CHECK(t->first_dims.dtype == transformer_engine::DType::kInt64, - "first_dims must have dtype Int64"); + } + case kNVTEGroupedFirstDims: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.first_dims = *basic_tensor; + NVTE_CHECK(t.first_dims.dtype == DType::kInt64, "first_dims must have dtype Int64"); break; - case kNVTEGroupedLastDims: - t->last_dims = *param; - // Validate it's Int64 - NVTE_CHECK(t->last_dims.dtype == transformer_engine::DType::kInt64, - "last_dims must have dtype Int64"); + } + case kNVTEGroupedLastDims: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.last_dims = *basic_tensor; + NVTE_CHECK(t.last_dims.dtype == DType::kInt64, "last_dims must have dtype Int64"); break; - case kNVTEGroupedTensorOffsets: - t->tensor_offsets = *param; - // Validate it's Int64 - NVTE_CHECK(t->tensor_offsets.dtype == transformer_engine::DType::kInt64, - "tensor_offsets must have dtype Int64"); + } + case kNVTEGroupedTensorOffsets: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.tensor_offsets = *basic_tensor; + NVTE_CHECK(t.tensor_offsets.dtype == DType::kInt64, "tensor_offsets must have dtype Int64"); + break; + } + case kNVTEGroupedWithGEMMSwizzledScales: + t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); break; default: - NVTE_ERROR("Unknown grouped tensor parameter!"); + NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); } } -NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, - NVTEGroupedTensorParam param_name) { - if (tensor == nullptr) { - return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)}; +void nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + void *buf, size_t size_in_bytes, size_t *size_written) { + using namespace transformer_engine; + + // Check param + NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", + static_cast(param), ")"); + + // Write attribute size if provided + const auto &attr_size = GroupedTensor::attr_sizes[param]; + if (size_written != nullptr) { + *size_written = attr_size; } - const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); - switch (param_name) { - case kNVTEGroupedRowwiseData: - return t.data; - case kNVTEGroupedColumnwiseData: - return t.columnwise_data; - case kNVTEGroupedScale: - return t.scale; - case kNVTEGroupedAmax: - return t.amax; - case kNVTEGroupedRowwiseScaleInv: - return t.scale_inv; - case kNVTEGroupedColumnwiseScaleInv: - return t.columnwise_scale_inv; - case kNVTEGroupedColumnwiseAmax: - return t.columnwise_amax; - case kNVTEGroupedFirstDims: - return t.first_dims; - case kNVTEGroupedLastDims: - return t.last_dims; - case kNVTEGroupedTensorOffsets: - return t.tensor_offsets; + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for grouped tensor parameter " + "(parameter ", + static_cast(param), " needs ", attr_size, " bytes, but buffer has ", + size_in_bytes, " bytes)"); + + // Get C++ grouped tensor + const GroupedTensor *t = convertNVTEGroupedTensor(tensor); + std::optional dummy; + if (t == nullptr) { + // Make dummy grouped tensor if provided tensor is invalid + dummy.emplace(NVTE_DELAYED_TENSOR_SCALING, 1); + t = &(*dummy); + } + + // Write to buffer + switch (param) { + case kNVTEGroupedRowwiseData: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->data); + break; + } + case kNVTEGroupedColumnwiseData: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_data); + break; + } + case kNVTEGroupedScale: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->scale); + break; + } + case kNVTEGroupedAmax: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->amax); + break; + } + case kNVTEGroupedRowwiseScaleInv: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->scale_inv); + break; + } + case kNVTEGroupedColumnwiseScaleInv: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_scale_inv); + break; + } + case kNVTEGroupedColumnwiseAmax: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_amax); + break; + } + case kNVTEGroupedFirstDims: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->first_dims); + break; + } + case kNVTEGroupedLastDims: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->last_dims); + break; + } + case kNVTEGroupedTensorOffsets: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->tensor_offsets); + break; + } + case kNVTEGroupedWithGEMMSwizzledScales: + *reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); + break; default: - NVTE_ERROR("Unknown grouped tensor parameter!"); + NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); } } diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 0e286009a..3a8536587 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -463,7 +463,8 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size std::is_same_v) { dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; } else { - NVTE_CHECK(false, "Invalid Output type (must be FP8)."); + NVTE_ERROR( + "Invalid output type for blockwise transpose (must be FP8: Float8E4M3 or Float8E5M2)."); } CUtensorMap tensor_map_output_trans{}; diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 59742d1e7..60e925749 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -213,10 +213,10 @@ __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_ return global_encode_scale; } -__device__ __forceinline__ uint32_t -get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10>& - rng, // philox4x32_native_state<10>: 10 rounds of philox4_32 - uint4& random_uint4, int& rnd_idx) { +__device__ __forceinline__ uint32_t get_rbits( + transformer_engine::curanddx::detail::philox4x32_native_state& + rng, // NVTE_BUILD_NUM_PHILOX_ROUNDS rounds of philox4x32 + uint4& random_uint4, int& rnd_idx) { if (rnd_idx == 4) { rnd_idx = 0; random_uint4 = rng.generate4(); @@ -390,7 +390,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + transformer_engine::curanddx::detail::philox4x32_native_state rng; rng.init(rng_seed, rng_sequence, rng_offset); uint4 random_uint4 = kApplyStochasticRounding ? rng.generate4() : uint4{0, 0, 0, 0}; diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 4602f41cf..75bb85f5e 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -563,6 +563,13 @@ def _make_chunk_sort_map_kernel( split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0 ).to(tl.int32) input_split_sizes_cumsum = tl.cumsum(input_split_sizes) + + # Compute total valid tokens and skip phantom/padding tokens. + # When the input buffer is larger than sum(split_sizes), tokens beyond + # the valid range should map to themselves (identity mapping) to avoid + # corrupting valid output positions. + total_valid_tokens = tl.sum(input_split_sizes) + input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0) input_chunk_idx = tl.sum(input_split_sizes_mask) input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask) @@ -578,6 +585,11 @@ def _make_chunk_sort_map_kernel( ).to(tl.int32) output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset + + # For tokens beyond the valid range (pid >= total_valid_tokens), + # use identity mapping to avoid corrupting valid data + dst_row = tl.where(pid < total_valid_tokens, dst_row, pid) + tl.store(dst_rows_ptr + pid, dst_row) @@ -587,6 +599,10 @@ def _sort_chunks_by_map_kernel( input_ptr, row_id_map_ptr, probs_ptr, + # Pre-allocated output buffer for JAX input_output_aliases. + # Aliased to output_ptr in JAX so they point to the same memory. + # In PyTorch, pass the same tensor as output_ptr. + output_buf_ptr, # pylint: disable=unused-argument # strides stride_input_token, stride_input_hidden, diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index 22a193dea..982ff9b28 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -11,7 +11,9 @@ #include +#include #include +#include #include "../common.h" #include "../util/string.h" @@ -31,13 +33,30 @@ void *get_symbol(const char *symbol, int cuda_version = 12010); * without GPUs. Indirect function calls into a lazily-initialized * library ensures we are accessing the correct version. * + * Symbol pointers are cached to avoid repeated lookups. + * * \param[in] symbol Function name * \param[in] args Function arguments */ template inline CUresult call(const char *symbol, ArgTs... args) { using FuncT = CUresult(ArgTs...); - FuncT *func = reinterpret_cast(get_symbol(symbol)); + + static std::unordered_map symbol_cache; + static std::mutex cache_mutex; + FuncT *func; + + { + std::lock_guard lock(cache_mutex); + auto it = symbol_cache.find(symbol); + if (it == symbol_cache.end()) { + void *ptr = get_symbol(symbol); + symbol_cache[symbol] = ptr; + func = reinterpret_cast(ptr); + } else { + func = reinterpret_cast(it->second); + } + } return (*func)(args...); } diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 3505516ba..0adb902ac 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -8,6 +8,8 @@ #include "../util/cuda_runtime.h" +#include + #include #include @@ -264,6 +266,12 @@ int cudart_version() { static int version = get_version(); return version; } + +size_t cublas_version() { + // Cache version to avoid cuBLAS logging overhead + static size_t version = cublasLtGetVersion(); + return version; +} #endif // __HIP_PLATFORM_AMD__ } // namespace cuda diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index 1cccb492f..da3d05da4 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -12,7 +12,6 @@ #include #include - namespace transformer_engine { namespace cuda { @@ -85,6 +84,12 @@ const std::string &include_directory(bool required = false); * Versions may differ between compile-time and run-time. */ int cudart_version(); + +/* \brief cuBLAS version number at run-time + * + * Versions may differ between compile-time and run-time. + */ +size_t cublas_version(); #endif } // namespace cuda diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index a70ae4398..22cc2f858 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -115,12 +115,12 @@ #ifdef NVTE_WITH_CUBLASMP -#define NVTE_CHECK_CUBLASMP(expr) \ - do { \ - const cublasMpStatus_t status = (expr); \ - if (status != CUBLASMP_STATUS_SUCCESS) { \ - NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \ - } \ +#define NVTE_CHECK_CUBLASMP(expr) \ + do { \ + const cublasMpStatus_t status = (expr); \ + if (status != CUBLASMP_STATUS_SUCCESS) { \ + NVTE_ERROR("cuBLASMp Error: ", cublasMpGetStatusString(status)); \ + } \ } while (false) #endif // NVTE_WITH_CUBLASMP diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index a186ed9d3..a4ccf4580 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -126,17 +126,6 @@ constexpr bool is_supported_arch() { #define ARCH_HAS_STOCHASTIC_ROUNDING \ NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>) -#else - -// Native FP4 stochastic rounding is available on gfx950 and later. -#if defined(__gfx950__) -#define ARCH_HAS_STOCHASTIC_ROUNDING (true) -#else -#define ARCH_HAS_STOCHASTIC_ROUNDING (false) -#endif - -#endif //#ifndef __HIP_PLATFORM_AMD__ - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -178,6 +167,18 @@ __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +__device__ __forceinline__ void mbarrier_arrive_expect_tx_cta_relaxed_shared_cta( + uint64_t *mbar, const uint32_t tx_count) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.shared::cta.b64 _, [%0], %1;" ::"r"(mbar_ptr), + "r"(tx_count)); +#else + NVTE_DEVICE_ERROR( + "mbarrier_arrive_expect_tx_cta_relaxed_shared_cta is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + __device__ __forceinline__ void fence_mbarrier_init_release_cluster() { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile("fence.mbarrier_init.release.cluster;"); @@ -257,13 +258,97 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3 #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +__device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint64_t *mbar, + uint32_t phase_parity) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile( + "{\n\t" + ".reg .b64 r1; \n\t" + ".reg .pred waitComplete; \n\t" // predicate representing if barrier condition is met + "WAIT: \n\t" // loop around barrier wait + "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 waitComplete, [%0], %1; \n\t" + "@waitComplete bra DONE; \n\t" // mbarrier conditions are met + "bra WAIT; \n\t" // just a time-out, try again + "DONE: \n\t" + "}\n\t" + : + : "r"(mbar_ptr), "r"(phase_parity) + : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_wait_parity_acquire_cta_shared_cta is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void try_cancel_cta(uint64_t *mbar, __uint128_t *response_data_ptr) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr); + asm volatile( + "clusterlaunchcontrol.try_cancel.async.mbarrier::complete_tx::bytes.multicast::cluster::" + "all.b128 " + "[%0], [%1];" ::"r"(workID_response), + "r"(mbar_ptr)); + } else { + NVTE_DEVICE_ERROR( + "Cluster Launch Control PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } +} + +__device__ __forceinline__ void get_cancelled_cta_id_2D(__uint128_t *response_data_ptr, + int32_t &ctaid_X, int32_t &ctaid_Y) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr); + asm volatile( + "{\n\t" + ".reg .s32 x_ctaid; \n\t" + ".reg .s32 y_ctaid; \n\t" + "mov .s32 x_ctaid, -1; \n\t" + "mov .s32 y_ctaid, -1; \n\t" + ".reg.b128 try_cancel_response; \n\t" + "ld.shared.b128 try_cancel_response, [%2]; \n\t" + ".reg .pred P1; \n\t" + "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 P1, try_cancel_response; \n\t" + "@P1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {x_ctaid, y_ctaid, _, " + "_}, try_cancel_response; \n\t" + "mov .s32 %0, x_ctaid; \n\t" + "mov .s32 %1, y_ctaid; \n\t" + "}\n\t" + : "=r"(ctaid_X), "=r"(ctaid_Y) + : "r"(workID_response) + : "memory"); + } else { + NVTE_DEVICE_ERROR( + "Cluster Launch Control PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } +} + +#else + +// Native FP4 stochastic rounding is available on gfx950 and later. +#if defined(__gfx950__) +#define ARCH_HAS_STOCHASTIC_ROUNDING (true) +#else +#define ARCH_HAS_STOCHASTIC_ROUNDING (false) +#endif + +#endif //#ifndef __HIP_PLATFORM_AMD__ + constexpr uint32_t FP32_MANTISSA_BITS = 23; constexpr uint32_t FP32_EXPONENT_BIAS = 127; __device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { - return (biased_exp == 0) ? 1 - : __int_as_float((254 - biased_exp) - << FP32_MANTISSA_BITS); // 127 - (biased_exp - 127) + // Handle the special case of NaN. + if (biased_exp == 255) return __int_as_float(0x7fffffff); + // Handle the special case where the unbiased exponent is 127, so the reciprocal is 2^-127 which needs the first bit of + // the mantissa to be 1, which can't be obtained by shifting `FP32_MANTISSA_BITS` bits to the left. + if (biased_exp == 254) return __int_as_float(0x00400000); + // Fast calculation when the unbiased exp is in [-126, 126], and only the exponent part is used to express the reciprocal. + return __int_as_float((254 - biased_exp) << FP32_MANTISSA_BITS); } __device__ __forceinline__ float exp2f(e8m0_t biased_exp) { @@ -308,6 +393,7 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { #endif //#ifndef __HIP_PLATFORM_AMD__ } +#ifndef __HIP_PLATFORM_AMD__ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // shared::cta -> global __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, @@ -417,6 +503,8 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() { #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } +#endif //!__HIP_PLATFORM_AMD__ + template struct alignas(2 * sizeof(T)) FPx2 { T x; @@ -699,8 +787,186 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); } } + +#ifndef __HIP_PLATFORM_AMD__ + +template +__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_round_to_nearest( + const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient) { + uint32_t out_8x = 0; + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.f32 zero; \n\t" + "mov.b32 zero, 0; \n\t" + ".reg.b16 scaling_coeff; \n\t" + "mov.b16 scaling_coeff, %3; \n\t" + ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" + "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" + "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" + + ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" + + ".reg.b8 f0, f1, f2, f3; \n\t" + // Elements reordered to match e2m1x4 packing order (v1,v0) + "cvt.rn.satfinite.e2m1x2.f32 f0, v1, v0;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v3, v2;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f2, v5, v4;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f3, v7, v6;\n\t" + "mov.b32 %0, {f0, f1, f2, f3};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient))); + } else if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.b64 scaling_coeff_2x; \n\t" + "mov.b64 scaling_coeff_2x, {%3, %3}; \n\t" + ".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t" + "mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t" + + ".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "cvt.f32.bf16 v4, v4_bf16; \n\t" + "cvt.f32.bf16 v5, v5_bf16; \n\t" + "cvt.f32.bf16 v6, v6_bf16; \n\t" + "cvt.f32.bf16 v7, v7_bf16; \n\t" + + ".reg.b64 v01, v23, v45, v67; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mov.b64 v45, {v4, v5}; \n\t" + "mov.b64 v67, {v6, v7}; \n\t" + "mul.f32x2 v01, v01, scaling_coeff_2x; \n\t" + "mul.f32x2 v23, v23, scaling_coeff_2x; \n\t" + "mul.f32x2 v45, v45, scaling_coeff_2x; \n\t" + "mul.f32x2 v67, v67, scaling_coeff_2x; \n\t" + // Elements reordered to match the packing order (v1,v0) + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "mov.b64 {v5, v4}, v45; \n\t" + "mov.b64 {v7, v6}, v67; \n\t" + + ".reg.b8 f0, f1, f2, f3; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f2, v4, v5;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f3, v6, v7;\n\t" + "mov.b32 %0, {f0, f1, f2, f3};\n\t" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "f"(scaling_coefficient)); + } else { + NVTE_DEVICE_ERROR("Not supported scaling coefficient type."); + } + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return out_8x; +} + +template +__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient, + const uint32_t rbits03, const uint32_t rbits47) { + uint32_t out_8x = 0; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.f32 zero; \n\t" + "mov.b32 zero, 0; \n\t" + ".reg.b16 scaling_coeff; \n\t" + "mov.b16 scaling_coeff, %3; \n\t" + ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" + "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" + "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" + + ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" + + ".reg.b16 b03, b47; \n\t" + // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0) + "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t" + "mov.b32 %0, {b03, b47};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient)), + "r"(rbits03), "r"(rbits47)); + } else if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t" + "mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t" + + ".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "cvt.f32.bf16 v4, v4_bf16; \n\t" + "cvt.f32.bf16 v5, v5_bf16; \n\t" + "cvt.f32.bf16 v6, v6_bf16; \n\t" + "cvt.f32.bf16 v7, v7_bf16; \n\t" + + "mul.f32 v0, v0, %3; \n\t" + "mul.f32 v1, v1, %3; \n\t" + "mul.f32 v2, v2, %3; \n\t" + "mul.f32 v3, v3, %3; \n\t" + "mul.f32 v4, v4, %3; \n\t" + "mul.f32 v5, v5, %3; \n\t" + "mul.f32 v6, v6, %3; \n\t" + "mul.f32 v7, v7, %3; \n\t" + ".reg.b16 b03, b47; \n\t" + // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0) + "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t" + "mov.b32 %0, {b03, b47};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "f"(scaling_coefficient), "r"(rbits03), "r"(rbits47)); + } else { + NVTE_DEVICE_ERROR("Not supported scaling coefficient type."); + } + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return out_8x; +} + +#endif //!__HIP_PLATFORM_AMD__ #endif // FP4_TYPE_SUPPORTED +#ifndef __HIP_PLATFORM_AMD__ + // SIMD like "Fused" cast + multiplication (x2) __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, const floatx2 &scale) { @@ -868,7 +1134,6 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) } -#ifndef __HIP_PLATFORM_AMD__ __device__ __forceinline__ int32_t elect_one_sync(uint32_t mask = 0xFFFFFFFFu) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) int32_t pred = 0; @@ -1550,8 +1815,60 @@ __device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) { : "r"(in2[0]), "r"(in2[1])); return out; } -#endif //#ifndef __HIP_PLATFORM_AMD__ +// Loads single BF16/FP16 element from shared memory state space +__device__ __forceinline__ bf16 ld_shared_b16(const bf16 *__restrict__ src_smem) { + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + bf16 dst; + asm volatile("ld.shared.b16 %0, [%1];" + : "=h"(reinterpret_cast(dst)) + : "r"(src_smem_ptr)); + return dst; +} + +// Loads pair of BF16/FP16 values from shared memory state space +__device__ __forceinline__ bf16x2 ld_shared_b32(const bf16x2 *__restrict__ src_smem) { + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + bf16x2 dst; + asm volatile("ld.shared.b32 %0, [%1];" + : "=r"(reinterpret_cast(dst)) + : "r"(src_smem_ptr)); + return dst; +} + +// Loads 8x BF16 values from shared memory state space +__device__ __forceinline__ __uint128_t ld_shared_b128(const bf16 *__restrict__ src_smem) { + uint64_t elts03, elts47; + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + asm volatile( + "{\n\t" + ".reg.b128 xy; \n\t" + "ld.shared.b128 xy, [%2]; \n\t" + "mov.b128 {%0, %1}, xy; \n" + "}\n" + : "=l"(elts03), "=l"(elts47) + : "r"(src_smem_ptr)); + return (static_cast<__uint128_t>(elts47) << 64) | static_cast<__uint128_t>(elts03); +} + +#if FP4_TYPE_SUPPORTED +// Vectorized store of x8 FP4 elements into shared memory state space +__device__ __forceinline__ void st_shared_b32(fp4e2m1x2 *__restrict__ dst_smem, + uint32_t fp4_pack_x8) { + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); + asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(fp4_pack_x8)); +} +#endif + +// Vectorized store of x16 FP4 elements into shared memory state space +#if FP4_TYPE_SUPPORTED +__device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem, + uint64_t fp4_pack_x16) { + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); + asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16)); +} +#endif +#endif //!__HIP_PLATFORM_AMD__ } // namespace ptx namespace { diff --git a/transformer_engine/debug/features/api.py b/transformer_engine/debug/features/api.py index 9c30f87c3..a1cf80dd2 100644 --- a/transformer_engine/debug/features/api.py +++ b/transformer_engine/debug/features/api.py @@ -244,7 +244,7 @@ def inspect_tensor( config: Dict, layer_name: str, tensor_name: str, - tensor: torch.Tensor, + tensor: Optional[torch.Tensor], rowwise_quantized_tensor: Optional[torch.Tensor], columnwise_quantized_tensor: Optional[torch.Tensor], quantizer: Optional[Quantizer], @@ -262,8 +262,8 @@ def inspect_tensor( layer_name: str tensor_name: str one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`], - tensor: torch.Tensor - tensor in high precision, + tensor: Optional[torch.Tensor] + tensor in high precision. It can be None only if fp8 model parameters are used and tensor name is `weight`. rowwise_quantized_tensor: Optional[torch.Tensor] rowwise quantized tensor, columnwise_quantized_tensor: Optional[torch.Tensor] @@ -479,7 +479,12 @@ def call_feature(self, call, feat_config, layer_name, **kwargs): """ if call.__name__ == "inspect_tensor": kwargs_copy = kwargs.copy() - for k in ["quantizer", "columnwise_quantized_tensor", "rowwise_quantized_tensor"]: + for k in [ + "quantizer", + "columnwise_quantized_tensor", + "rowwise_quantized_tensor", + "tp_size", + ]: if k not in call.__code__.co_varnames: kwargs_copy.pop(k) else: @@ -490,6 +495,10 @@ def call_feature(self, call, feat_config, layer_name, **kwargs): "inspect_tensor_postquantize is deprecated, use inspect_tensor instead.", DeprecationWarning, ) + kwargs_copy = kwargs.copy() + for k in ["tp_size"]: + if k not in call.__code__.co_varnames: + kwargs_copy.pop(k, None) return call(feat_config, layer_name, **kwargs_copy) diff --git a/transformer_engine/debug/features/disable_fp8_gemm.py b/transformer_engine/debug/features/disable_fp8_gemm.py index befebb412..9bbb7ef4a 100644 --- a/transformer_engine/debug/features/disable_fp8_gemm.py +++ b/transformer_engine/debug/features/disable_fp8_gemm.py @@ -2,17 +2,28 @@ # # See LICENSE for license information. -"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect""" +"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect -from nvdlfw_inspect.registry import Registry, api_method -from transformer_engine.debug.features.api import TEConfigAPIMapper +DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM. +New code should use DisableQuantizationGEMM instead, which works with all quantization formats. +""" + +import warnings + +from nvdlfw_inspect.registry import Registry +from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM @Registry.register_feature(namespace="transformer_engine") -class DisableFP8GEMM(TEConfigAPIMapper): +class DisableFP8GEMM(DisableQuantizationGEMM): """ GEMM operations are executed in higher precision, even when FP8 autocast is enabled. + .. deprecated:: + Use :class:`DisableQuantizationGEMM` instead. This class is maintained for + backward compatibility only. DisableQuantizationGEMM works with all quantization + formats (FP8, NVFP4, etc.), not just FP8. + Parameters ---------- @@ -32,22 +43,17 @@ class DisableFP8GEMM(TEConfigAPIMapper): layers: layer_types: [fc1] transformer_engine: - DisableFP8GEMM: + DisableFP8GEMM: # Deprecated: use DisableQuantizationGEMM enabled: True gemms: [dgrad, wgrad] """ - @api_method - def fp8_gemm_enabled( - self, config, layer_name: str, gemm: str, iteration: int - ): # pylint: disable=unused-argument - """API call responsible for choice between high-precision and FP8 GEMM execution.""" - - for key in config: - if key != "gemm": - raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') - - # If this feature is invoked, then FP8 GEMM is disabled. - # If not, then default behaviour in TransformerEngineAPI - # is that fp8_gemm() API call returns True. - return False, iteration + 1 + def __init__(self, *args, **kwargs): + warnings.warn( + "DisableFP8GEMM is deprecated. " + "Use DisableQuantizationGEMM instead, which works with all quantization " + "formats (FP8, NVFP4, etc.).", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/transformer_engine/debug/features/disable_fp8_layer.py b/transformer_engine/debug/features/disable_fp8_layer.py index 3839e5f2b..5ae03ef45 100644 --- a/transformer_engine/debug/features/disable_fp8_layer.py +++ b/transformer_engine/debug/features/disable_fp8_layer.py @@ -2,17 +2,27 @@ # # See LICENSE for license information. -"""DisableFP8Layer Feature support for nvidia-dlframework-inspect""" +"""DisableFP8Layer Feature support for nvidia-dlframework-inspect -import nvdlfw_inspect.api as debug_api -from nvdlfw_inspect.registry import Registry, api_method +DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer. +New code should use DisableQuantizationLayer instead, which works with all quantization formats. +""" + +import warnings + +from nvdlfw_inspect.registry import Registry +from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer @Registry.register_feature(namespace="transformer_engine") -class DisableFP8Layer: +class DisableFP8Layer(DisableQuantizationLayer): """ Disables all FP8 GEMMs in the layer. + .. deprecated:: + Use :class:`DisableQuantizationLayer` instead. This class is maintained for + backward compatibility only. DisableQuantizationLayer works with all quantization + formats (FP8, NVFP4, etc.), not just FP8. Example ------- @@ -20,36 +30,19 @@ class DisableFP8Layer: example_disable_fp8_layer: enabled: True - layers: - layer_types: [fc1] - transformer_engine: - DisableFP8Layer: - enabled: True + layers: + layer_types: [fc1] + transformer_engine: + DisableFP8Layer: # Deprecated: use DisableQuantizationLayer + enabled: True """ - @api_method - def fp8_gemm_enabled( - self, config, layer_name: str, gemm: str, iteration: int - ): # pylint: disable=unused-argument - """API call responsible for selecting between high-precision and FP8 GEMM execution.""" - for key in config: - if key not in ["enabled", "gemm"]: - raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') - # If FP8 training, disable FP8 for the selected layers if this feature is enabled in config. - debug_api.log_message("FP8 Disabled", layer_name) - - # If this feature is invoked, then FP8 GEMM is disabled. - # If not, then default behavior in TransformerEngineAPI - # is that fp8_gemm() API call returns True. - return False, iteration + 1 - - def parse_config_and_api(self, config, **_kwargs): - """Determines whether to run the API - DisableFP8Layer is the only feature provided by the Transformer Engine - which does not inherit from TEConfigAPIMapper - this mapper is primarly responsible for - parsing gemms and tensors fields from the config, which are not needed for this feature. - - Explanation of the parse_config_and_api can be found in the - nvidia-dlframework-inspect documentation. - """ - return config["enabled"], None + def __init__(self, *args, **kwargs): + warnings.warn( + "DisableFP8Layer is deprecated. " + "Use DisableQuantizationLayer instead, which works with all quantization " + "formats (FP8, NVFP4, etc.).", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/transformer_engine/debug/features/disable_quantization_gemm.py b/transformer_engine/debug/features/disable_quantization_gemm.py new file mode 100644 index 000000000..932c2f83d --- /dev/null +++ b/transformer_engine/debug/features/disable_quantization_gemm.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect""" + +from nvdlfw_inspect.registry import Registry, api_method +from transformer_engine.debug.features.api import TEConfigAPIMapper + + +@Registry.register_feature(namespace="transformer_engine") +class DisableQuantizationGEMM(TEConfigAPIMapper): + """ + Disables specific GEMM operations from using quantization, forcing high-precision execution. + + Works with any quantization format (FP8, NVFP4, etc.). + + Parameters + ---------- + + gemms: List[str] + list of gemms to disable quantization for + + - fprop + - dgrad + - wgrad + + Example + ------- + .. code-block:: yaml + + example_disable_quantization_gemm: + enabled: True + layers: + layer_types: [fc1] + transformer_engine: + DisableQuantizationGEMM: + enabled: True + gemms: [dgrad, wgrad] + """ + + @api_method + def fp8_gemm_enabled( + self, config, layer_name: str, gemm: str, iteration: int + ): # pylint: disable=unused-argument + """API call responsible for choice between high-precision and quantized GEMM execution. + + Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API, + but it applies to all quantization formats (FP8, NVFP4, etc.). + """ + + for key in config: + if key != "gemm": + raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') + + # If this feature is invoked, then quantized GEMM is disabled (returns to high precision). + # If not, then default behavior in TransformerEngineAPI + # is that fp8_gemm() API call returns True. + return False, iteration + 1 diff --git a/transformer_engine/debug/features/disable_quantization_layer.py b/transformer_engine/debug/features/disable_quantization_layer.py new file mode 100644 index 000000000..081e310ed --- /dev/null +++ b/transformer_engine/debug/features/disable_quantization_layer.py @@ -0,0 +1,61 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""DisableQuantizationLayer Feature support for nvidia-dlframework-inspect""" + +import nvdlfw_inspect.api as debug_api +from nvdlfw_inspect.registry import Registry, api_method + + +@Registry.register_feature(namespace="transformer_engine") +class DisableQuantizationLayer: + """ + Disables all quantized GEMMs in the layer, forcing high-precision execution. + + Works with any quantization format (FP8, NVFP4, etc.). + + Example + ------- + .. code-block:: yaml + + example_disable_quantization_layer: + enabled: True + layers: + layer_types: [fc1] + transformer_engine: + DisableQuantizationLayer: + enabled: True + """ + + @api_method + def fp8_gemm_enabled( + self, config, layer_name: str, gemm: str, iteration: int + ): # pylint: disable=unused-argument + """API call responsible for selecting between high-precision and quantized GEMM execution. + + Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API, + but it applies to all quantization formats (FP8, NVFP4, etc.). + """ + for key in config: + if key not in ["enabled", "gemm"]: + raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') + # If quantized training, disable quantization for the selected layers if this feature is enabled. + debug_api.log_message("Quantization Disabled", layer_name) + + # If this feature is invoked, then quantized GEMM is disabled (returns to high precision). + # If not, then default behavior in TransformerEngineAPI + # is that fp8_gemm() API call returns True. + return False, iteration + 1 + + def parse_config_and_api(self, config, **_kwargs): + """Determines whether to run the API. + + DisableQuantizationLayer is the only feature provided by the Transformer Engine + which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for + parsing gemms and tensors fields from the config, which are not needed for this feature. + + Explanation of the parse_config_and_api can be found in the + nvidia-dlframework-inspect documentation. + """ + return config["enabled"], None diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index ffcc6b1ad..cf11964e2 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -6,15 +6,17 @@ from typing import Dict, Optional, List, Tuple from contextlib import contextmanager +import warnings import torch import nvdlfw_inspect.api as debug_api - +import transformer_engine_torch as tex from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS +from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Quantizer, @@ -22,7 +24,14 @@ ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer -from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter + +try: + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + _nvfp4_available = True +except ImportError: + _nvfp4_available = False + NVFP4Quantizer = None ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"] @@ -39,6 +48,8 @@ def _get_recipe_name(quantizer: Optional[Quantizer]): return "mxfp8" if isinstance(quantizer, Float8BlockQuantizer): return "fp8_block_scaling" + if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer): + return "nvfp4" raise ValueError(f"Unsupported quantizer type: {type(quantizer)}") @@ -111,6 +122,10 @@ class LogFp8TensorStats(BaseLogTensorStats): - scale_inv_max - maximum of the inverse of the scaling factors, - mse - mean squared error of the quantized tensor and the original tensor = sum((quantized_tensor - original_tensor)**2) / num_elements, + When collecting stats for the weight tensor with FP8 model parameters enabled, + only "scale_inv_min" and "scale_inv_max" are available. + All other statistics require access to the high precision tensor. + tensors/tensors_struct: List[str] list of tensors to log - activation, @@ -148,7 +163,9 @@ class LogFp8TensorStats(BaseLogTensorStats): end_step: 80 """ - def check_if_stat_is_supported(self, stat: str, current_recipe: str): + def check_if_stat_is_supported( + self, stat: str, current_recipe: str, high_precision_tensor_provided: bool + ): """Returns True if stat is supported, raises ValueError otherwise.""" columnwise = stat.endswith("_columnwise") if columnwise: @@ -156,6 +173,17 @@ def check_if_stat_is_supported(self, stat: str, current_recipe: str): recipe_from_stat, _ = self.get_recipe_from_stat(stat, default_recipe=current_recipe) stat_without_recipe = stat.replace(recipe_from_stat + "_", "") + need_high_precision_tensor_stats = ["underflows%", "overflows%", "mse"] + if ( + stat_without_recipe in need_high_precision_tensor_stats + and not high_precision_tensor_provided + ): + raise ValueError( + f"Stat {stat} requires a high precision tensor to be provided. " + "This feature is not supported for weight tensors when using fp8 model " + "parameters." + ) + if current_recipe == "" and recipe_from_stat == "": raise ValueError( f"Stat {stat} does not contain a recipe name and the current recipe is not set." @@ -164,6 +192,16 @@ def check_if_stat_is_supported(self, stat: str, current_recipe: str): if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES: raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}") + # Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work) + # But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer + if recipe_from_stat == "nvfp4": + raise ValueError( + f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats." + " FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for" + " NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.," + " 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons." + ) + if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise: raise ValueError( f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for" @@ -189,6 +227,7 @@ def check_if_stat_is_supported(self, stat: str, current_recipe: str): def get_recipe_from_stat(self, stat: str, default_recipe: str = ""): """Returns the recipe name from the stat string.""" + columnwise_stat = stat.endswith("_columnwise") for recipe_name in ALL_RECIPE_NAMES: if recipe_name in stat: @@ -213,7 +252,7 @@ def update_aux_dict( Yields the aux_dict. Needs to clean after usage, because it possibly change the usage of the quantized tensor. """ - fp8_dtype = None + fp8_dtype = tex.DType.kFloat8E4M3 if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]: assert isinstance( quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer) @@ -268,27 +307,42 @@ def inspect_tensor( tensor_name: str, iteration: int, tp_group: torch.distributed.ProcessGroup, - tensor: torch.Tensor, + tensor: Optional[torch.Tensor], rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, quantizer: Optional[Quantizer] = None, + tp_size: int = 1, ): """ API call used to collect the data about the tensor after process_tensor()/quantization. """ assert rowwise_quantized_tensor is columnwise_quantized_tensor - assert ( - quantizer is not None - ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe." + + # Skip logging if quantizer is None (layer runs in high precision) + if quantizer is None: + warnings.warn( + f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', " + f"tensor '{tensor_name}': layer runs in high precision (no quantizer)." + ) + return quantized_tensor = rowwise_quantized_tensor - assert isinstance( - quantized_tensor, QuantizedTensor - ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor." + + # Skip logging if quantized_tensor is not a QuantizedTensor (incompatible precision) + if not isinstance(quantized_tensor, QuantizedTensor): + warnings.warn( + f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', " + f"tensor '{tensor_name}': incompatible precision " + f"(expected QuantizedTensor, got {type(quantized_tensor).__name__})." + ) + return + recipe_name = _get_recipe_name(quantizer) for stat in config["stats"]: - self.check_if_stat_is_supported(stat, recipe_name) + self.check_if_stat_is_supported( + stat, recipe_name, high_precision_tensor_provided=tensor is not None + ) start_step = config.get("start_step", None) end_step = config.get("end_step", None) @@ -304,7 +358,7 @@ def inspect_tensor( ) skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( - tensor_name, tp_group + tensor_name, tp_group, tp_size ) STATS_BUFFERS.try_add_buffer( diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py new file mode 100644 index 000000000..8a76f4edc --- /dev/null +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -0,0 +1,238 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""LogNvfp4TensorStats Feature support for nvidia-dlframework-inspect""" + +from typing import Dict, Optional +from contextlib import contextmanager +import warnings + +import torch +import nvdlfw_inspect.api as debug_api + +from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats +from nvdlfw_inspect.registry import Registry, api_method + +from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS +from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter +from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage + + +@Registry.register_feature(namespace="transformer_engine") +class LogNvfp4TensorStats(BaseLogTensorStats): + """Logs statistics of NVFP4 quantized tensors. + + In distributed runs each rank first computes its local statistics; the values + are gathered the next time `debug_api.step()` is called. Remember to call + `debug_api.step()` every training step so the logs are flushed. + + The feature is micro-batch aware: if several forward/backward passes occur + between successive `debug_api.step()` calls, statistics are accumulated for all + tensors except weights. + + Collecting NVFP4 statistics is expensive. Choosing a larger `freq` reduces the + overhead, and if the feature is skipped for a step the additional cost is + minimal. When no other debug feature is active, the layer runs at normal + Transformer Engine speed. + + Parameters + ---------- + + stats: List[str] + List of statistics to collect. Available stats: + - underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data) + - mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements + + tensors/tensors_struct: List[str] + list of tensors to log + - activation, + - gradient, + - weight, + + freq: Optional[int], default = 1 + frequency of logging stats, stats will be logged every `freq` steps + start_step: Optional[int], default = None + start step of logging stats + end_step: Optional[int], default = None + end step of logging stats + start_end_list: Optional[list([int, int])], default = None + non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step + + Example + ------- + .. code-block:: yaml + + example_nvfp4_tensor_stat_collection: + enabled: True + layers: + layer_types: [layernorm_linear] + transformer_engine: + LogNvfp4TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, mse] + freq: 1 + - tensor: gradient + stats: [underflows%, mse] + freq: 5 + start_step: 0 + end_step: 80 + """ + + def check_if_stat_is_supported(self, stat: str): + """Returns True if stat is supported, raises ValueError otherwise.""" + supported_stats = [ + "underflows%", + "mse", + ] + if stat not in supported_stats: + raise ValueError( + f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}" + ) + return True + + def get_stat_with_prefix(self, stat: str) -> str: + """Add nvfp4_ prefix to stat name for use in stats_computation.""" + return f"nvfp4_{stat}" + + @contextmanager + def update_aux_dict( + self, + aux_dict: Dict, + quantized_tensor: QuantizedTensor, + quantizer: Quantizer, # pylint: disable=unused-argument + original_tensor: torch.Tensor, + ): + """ + Updates the aux_dict with the quantized tensor and additional NVFP4-specific data. + Yields the aux_dict. + """ + aux_dict = { + "nvfp4": quantized_tensor, + "original_tensor": original_tensor, + } + + try: + yield aux_dict + finally: + pass + + @api_method + def inspect_tensor_enabled( + self, config: Dict, layer_name: str, tensor_name: str, iteration: int + ): # pylint: disable=unused-argument + """API call used to determine whether to run inspect_tensor() in the forward.""" + run_current, next_iter = next_enabled_iter( + config.get("start_step", None), + config.get("end_step", None), + config.get("start_end_list", None), + config.get("freq", 1), + iteration, + ) + STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter + return run_current, next_iter + + @api_method + def inspect_tensor( + self, + config: Dict, + layer_name: str, + tensor_name: str, + iteration: int, + tp_group, + tensor: torch.Tensor, + rowwise_quantized_tensor: Optional[QuantizedTensor] = None, + columnwise_quantized_tensor: Optional[QuantizedTensor] = None, + quantizer: Optional[Quantizer] = None, + tp_size: int = 1, + ): + """ + API call used to collect the data about the tensor after process_tensor()/quantization. + """ + assert rowwise_quantized_tensor is columnwise_quantized_tensor + + # Skip logging if quantizer is None (layer runs in high precision) + if quantizer is None: + warnings.warn( + f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', " + f"tensor '{tensor_name}': layer runs in high precision (no quantizer)." + ) + return + + quantized_tensor = rowwise_quantized_tensor + + # Skip logging if not NVFP4 quantizer (incompatible precision) + if not isinstance(quantizer, NVFP4Quantizer): + warnings.warn( + f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', " + f"tensor '{tensor_name}': incompatible precision " + f"(expected NVFP4Quantizer, got {type(quantizer).__name__})." + ) + return + + # Skip logging if quantized tensor is not NVFP4TensorStorage (incompatible precision) + if not isinstance(quantized_tensor, NVFP4TensorStorage): + warnings.warn( + f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', " + f"tensor '{tensor_name}': incompatible precision " + f"(expected NVFP4TensorStorage, got {type(quantized_tensor).__name__})." + ) + return + + for stat in config["stats"]: + self.check_if_stat_is_supported(stat) + + start_step = config.get("start_step", None) + end_step = config.get("end_step", None) + start_end_list = config.get("start_end_list", None) + if start_end_list is not None: + start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list) + + options = ( + start_step, + end_step, + start_end_list, + "nvfp4", + ) + + skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( + tensor_name, tp_group, tp_size + ) + + # Add nvfp4_ prefix to all stats for internal use + prefixed_stats = [self.get_stat_with_prefix(stat) for stat in config["stats"]] + + STATS_BUFFERS.try_add_buffer( + layer_name=layer_name, + tensor_name=tensor_name, + stats=prefixed_stats, + options=options, + reduction_group=reduction_group, + reduce_within_microbatch=reduce_within_microbatch, + ) + + with self.update_aux_dict( + aux_dict={}, + quantized_tensor=quantized_tensor, + quantizer=quantizer, + original_tensor=tensor, + ) as aux_dict: + STATS_BUFFERS.feed( + layer_name, + tensor_name, + options, + tensor, + iteration, + skip_reduction, + aux_dict=aux_dict, + ) + + debug_api.log_message( + f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}", + layer_name, + extra_cachable_args=(tensor_name,), + ) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 100fa6448..5e6ce137b 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -180,13 +180,20 @@ def inspect_tensor( tensor_name: str, iteration: int, tp_group: torch.distributed.ProcessGroup, - tensor: torch.Tensor, + tensor: Optional[torch.Tensor], rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, quantizer: Optional[Quantizer] = None, + tp_size: int = 1, ): # pylint: disable=unused-argument """API call used to collect the data about the tensor before process_tensor()/quantization.""" + # Tensor is None only if fp8 model parameters are used and tensor name is `weight`. + # If one wants to collect stats for this tensor, we need to dequantize it. + if tensor is None: + assert isinstance(rowwise_quantized_tensor, QuantizedTensor) + tensor = rowwise_quantized_tensor.dequantize() + assert ( type(tensor) not in [Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage] and tensor.dtype != torch.uint8 @@ -208,7 +215,7 @@ def inspect_tensor( ) skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( - tensor_name, tp_group + tensor_name, tp_group, tp_size ) for stat in config["stats"]: diff --git a/transformer_engine/debug/features/utils/__init__.py b/transformer_engine/debug/features/utils/__init__.py index d691c1828..813fb2add 100644 --- a/transformer_engine/debug/features/utils/__init__.py +++ b/transformer_engine/debug/features/utils/__init__.py @@ -12,7 +12,7 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState -def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup): +def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup, tp_size: int): """ Returns the statistics reduction parameters for the tensor. """ @@ -20,8 +20,14 @@ def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGr reduction_group = debug_api.get_tensor_reduction_group() reduce_within_microbatch = tensor_name != "weight" if tensor_name == "weight": - if TEDebugState.weight_tensor_tp_group_reduce: - reduction_group = tp_group + if TEDebugState.weight_tensor_tp_group_reduce and tp_size > 1: + # Do not overwrite with `None`: in torch.distributed collectives + # group=None means the default/world process group. + if tp_group is not None: + reduction_group = tp_group + else: + # "Reduce in TP group" requested, but TP group is missing. + skip_reduction = True else: skip_reduction = True return skip_reduction, reduction_group, reduce_within_microbatch diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index 9ce56dd76..ca7f22e2d 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -90,12 +90,19 @@ def feed(self, tensor, iteration, aux_dict=None): if self.modified[0] and not self.reduce_within_microbatch: return - if ( - tensor.numel() == 0 - if hasattr(tensor, "numel") - else all((t is None or t.numel() == 0) for t in tensor.get_data_tensors()) - ): - return + if tensor is not None: + # tensor can be None if we compute fp8 stats for weight and fp8 model parameters are used + # then high precision is not provided and quantized tensor from aux_dict is used. + + # This condition prevents computation of stats for empty tensor. + # This will not happen for weight - since it is the only situation then tensor can be None, + # we do not need to check similar condition for weight. + if ( + tensor.numel() == 0 + if hasattr(tensor, "numel") + else all((t is None or t.numel() == 0) for t in tensor.get_data_tensors()) + ): + return # save stats for tensor to tmp buffer for stat_name in self.stats_to_compute: diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 46a48e2ab..b0002ffee 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -443,3 +443,65 @@ def add_max_blockwise_dynamic_range_stats( add_underflows_stats(_recipe_name, _columnwise) add_scale_inv_stats(_recipe_name, _columnwise) add_mse_stats(_recipe_name, _columnwise) + + +# NVFP4-specific statistics + + +def count_nonzero_nvfp4(fp4_data: torch.Tensor) -> torch.Tensor: + """Count the number of non-zero elements in the FP4 data. + + FP4 data is stored as 2 4-bit values per byte (uint8). + We need to unpack and count non-zeros. + """ + # Each byte contains two FP4 values + # Value 0 in FP4 E2M1 format is represented as 0 (and also 8 for -0.0) + zero_vals = torch.tensor([0, 8], device=fp4_data.device, dtype=torch.uint8) + + # Extract first and second nibbles + first_nibble = fp4_data % 16 + second_nibble = fp4_data // 16 + + # Count zeros + first_zeros = torch.isin(first_nibble, zero_vals).sum() + second_zeros = torch.isin(second_nibble, zero_vals).sum() + + total_elements = fp4_data.numel() * 2 + return total_elements - first_zeros - second_zeros + + +def add_nvfp4_underflows_stats(): + """Register underflow stats for NVFP4. + + Computes underflows by counting zeros in packed FP4 data vs original tensor. + """ + stat_num = "nvfp4_underflows_num" + stat_pct = "nvfp4_underflows%" + + stats_to_num[stat_num] = len(stats_to_num) + stats_to_num[stat_pct] = len(stats_to_num) + + # Count non-zeros in original vs FP4 packed data + STATS[stat_num] = ( + lambda x, aux_dict: x.count_nonzero() + - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data), + lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)), + ) + STATS[stat_pct] = ( + lambda x, aux_dict: ( + x.count_nonzero() - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data) + ) + / aux_dict["nvfp4"].numel() + * 100, + lambda buffers, _sn_num=stat_num: 100 + * sum(_get(buffers, _sn_num)) + / sum(_get(buffers, "numel")), + ) + + DEPENDENCIES[stat_num] = {stat_num} + DEPENDENCIES[stat_pct] = {stat_num, "numel"} + + +# Register NVFP4 stats +add_nvfp4_underflows_stats() +add_mse_stats("nvfp4") # Reuse existing MSE function diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 29a108c75..ed5fdd466 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -36,7 +36,7 @@ } API_CALL_MODIFY = "modify_tensor()" -STANDARD_FP8_QUANTIZE = "FP8 Quantize" +STANDARD_QUANTIZE = "Quantize" HIGH_PRECISION = "High Precision" @@ -53,6 +53,7 @@ def __init__( tensor_name: str, parent_quantizer: Optional[Quantizer], tp_group: torch.distributed.ProcessGroup, + tp_size: int, ): super().__init__(rowwise=True, columnwise=True) @@ -60,6 +61,7 @@ def __init__( self.tensor_name = tensor_name self.parent_quantizer = parent_quantizer self.tp_group = tp_group # used in inspect_tensor calls + self.tp_size = tp_size self.iteration = TEDebugState.get_iteration() # Configure parent quantizer @@ -88,7 +90,7 @@ def __init__( # inspect_tensor*_enabled are bool fields, # indicating whether some feature will need to run inspect_tensor_* calls. # - # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION] + # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_QUANTIZE, HIGH_PRECISION] # determining what will happen when the quantizer is used for that tensor. self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"] if self.output_tensor: @@ -170,7 +172,7 @@ def get_enabled_look_at_tensors(self): def get_tensors_plan(self): """ Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of - API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior + API_CALL_MODIFY, STANDARD_QUANTIZE, or HIGH_PRECISION, indicating the behavior of this quantizer with respect to these tensors. """ import nvdlfw_inspect.api as debug_api @@ -191,16 +193,16 @@ def get_tensors_plan(self): rowwise_plan = API_CALL_MODIFY else: if self.parent_quantizer is not None: - fp8_quantize = self.process_enabled_api_call( - debug_api.transformer_engine.fp8_gemm_enabled( + quantize_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility layer_name=self.layer_name, gemm=self.rowwise_gemm_name, iteration=self.iteration, ) ) - if fp8_quantize: - rowwise_plan = STANDARD_FP8_QUANTIZE + if quantize_enabled: + rowwise_plan = STANDARD_QUANTIZE if rowwise_plan is None: rowwise_plan = HIGH_PRECISION @@ -218,16 +220,16 @@ def get_tensors_plan(self): columnwise_plan = API_CALL_MODIFY else: if self.parent_quantizer is not None: - fp8_quantize = self.process_enabled_api_call( - debug_api.transformer_engine.fp8_gemm_enabled( + quantize_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility layer_name=self.layer_name, gemm=self.columnwise_gemm_name, iteration=self.iteration, ) ) - if fp8_quantize: - columnwise_plan = STANDARD_FP8_QUANTIZE + if quantize_enabled: + columnwise_plan = STANDARD_QUANTIZE if columnwise_plan is None: columnwise_plan = HIGH_PRECISION @@ -263,11 +265,12 @@ def _call_inspect_tensor_api( "tensor_name": self.tensor_name, "iteration": TEDebugState.get_iteration(), "tp_group": self.tp_group, + "tp_size": self.tp_size, "columnwise_quantized_tensor": columnwise_gemm_tensor, "rowwise_quantized_tensor": rowwise_gemm_tensor, "quantizer": self.parent_quantizer, } - if tensor is not None and self.inspect_tensor_enabled: + if self.inspect_tensor_enabled: debug_api.transformer_engine.inspect_tensor(**args) if self.output_tensor: @@ -278,7 +281,7 @@ def _call_inspect_tensor_api( del args["quantizer"] if ( - self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE] and self.inspect_tensor_postquantize_enabled_rowwise ): args["tensor"] = rowwise_gemm_tensor @@ -286,7 +289,7 @@ def _call_inspect_tensor_api( debug_api.transformer_engine.inspect_tensor_postquantize(**args) if ( - self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE] and self.inspect_tensor_postquantize_enabled_columnwise ): args["tensor"] = columnwise_gemm_tensor @@ -317,14 +320,14 @@ def quantize( self.parent_quantizer.set_usage(rowwise=True) rowwise_gemm_tensor, columnwise_gemm_tensor = None, None - if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: quantized_tensor = self.parent_quantizer(tensor) - # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8, + # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized, # one tensor with columnwise=True and rowwise=True is computed # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. - if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE: + if self.rowwise_tensor_plan == STANDARD_QUANTIZE: rowwise_gemm_tensor = quantized_tensor - if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE: + if self.columnwise_tensor_plan == STANDARD_QUANTIZE: columnwise_gemm_tensor = quantized_tensor # 2. modify_tensor() is called, if it is used. @@ -379,7 +382,7 @@ def process_gemm_output(self, tensor: torch.Tensor): """This call is invoked after the gemm to inspect and modify the output tensor.""" import nvdlfw_inspect.api as debug_api - assert self.parent_quantizer is None, "FP8 output is not supported for debug=True." + assert self.parent_quantizer is None, "Quantized output is not supported for debug=True." assert self.output_tensor tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"} if self.rowwise_tensor_plan == API_CALL_MODIFY: @@ -420,9 +423,9 @@ def any_feature_enabled(self) -> bool: ): return True if self.parent_quantizer is not None: - if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: + if self.rowwise_tensor_plan != STANDARD_QUANTIZE: return True - if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: + if self.columnwise_tensor_plan != STANDARD_QUANTIZE: return True return False @@ -446,7 +449,7 @@ def update_quantized( if self.parent_quantizer is not None: if ( dst.rowwise_gemm_tensor is not None - and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + and self.rowwise_tensor_plan == STANDARD_QUANTIZE ): if hasattr(dst.rowwise_gemm_tensor, "quantize_"): dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None) @@ -455,7 +458,7 @@ def update_quantized( updated_rowwise_gemm = True if ( dst.columnwise_gemm_tensor is not None - and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + and self.columnwise_tensor_plan == STANDARD_QUANTIZE and not updated_rowwise_gemm ): if hasattr(dst.columnwise_gemm_tensor, "quantize_"): @@ -540,14 +543,12 @@ def _update_parent_quantizer_usage(self): """ Updates the usage of the parent quantizer. """ - rowwise_gemm_quantize = ( - self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE - ) + rowwise_gemm_quantize = self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE columnwise_gemm_quantize = ( - self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_QUANTIZE ) - if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: self.parent_quantizer.set_usage( rowwise=rowwise_gemm_quantize, columnwise=columnwise_gemm_quantize, @@ -561,6 +562,30 @@ def set_usage(self, rowwise: bool = None, columnwise: bool = None): if not self.output_tensor: self._update_parent_quantizer_usage() + def wrap_quantized_tensor(self, tensor: QuantizedTensor): + """ + Wraps the quantized tensor with the debug quantizer. + It is used for weight tensors when fp8 model parameters are enabled. + """ + + assert ( + self.rowwise_tensor_plan == STANDARD_QUANTIZE + and self.columnwise_tensor_plan == STANDARD_QUANTIZE + ), ( + "[NVTORCH INSPECT ERROR] Weight tensor with fp8 model parameters enabled cannot be" + " modified by any feature." + ) + + self._call_inspect_tensor_api(None, tensor, tensor) + + return DebugQuantizedTensor( + rowwise_gemm_tensor=tensor, + columnwise_gemm_tensor=tensor, + quantizer=self, + layer_name=self.layer_name, + tensor_name=self.tensor_name, + ) + @classmethod def multi_tensor_quantize( cls, @@ -675,3 +700,12 @@ def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None raise RuntimeError( "Cannot recreate columnwise tensor from rowwise tensor is debug mode." ) + + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self.rowwise_gemm_tensor is not None: + return self.rowwise_gemm_tensor.device + if self.columnwise_gemm_tensor is not None: + return self.columnwise_gemm_tensor.device + raise RuntimeError("DebugQuantizedTensor has no data!") diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 21db296c3..765cf2872 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -569,6 +569,8 @@ def _segment_ids_pos_to_seqlens_offsets( # using the segment ids and pos along with mask type (causal or brcm) is sufficient. # It does not need to involve SW for this mask's creation + # Currently, this function is only exercised for THD qkv_layout. + # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well if (attn_mask_type.is_causal() and window_size is None) or ( window_size == (-1, -1) and not attn_mask_type.is_bottom_right() @@ -693,26 +695,83 @@ def get_seqlens_and_offsets( self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq ): """ - Acquire the seqlens/offsets for cuDNN backend + Acquire the seqlens/offsets for cuDNN backend. """ q_segment_ids, kv_segment_ids = self.segment_ids q_segment_pos, kv_segment_pos = self.segment_pos - assert q_segment_ids.shape == q_segment_pos.shape - assert kv_segment_ids.shape == kv_segment_pos.shape # No segment_ids/segment_pos if q_segment_ids.size + kv_segment_ids.size == 0: return self.seqlens, self.seq_offsets - if qkv_layout.is_thd(): - q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets( - q_segment_ids, - kv_segment_ids, - q_segment_pos, - kv_segment_pos, - attn_mask_type, - window_size, - max_segments_per_seq, + # Allow segment_pos to have fewer leading dims than segment_ids if vmapped segment_ids and non-vmapped segment_pos + # e.g. when using from_segment_ids_and_pos() for segment_pos generation from segment_ids it is acceptable to have + # something like : segment_ids (B, batch, seq), segment_pos (batch, seq)). + if q_segment_ids.ndim < q_segment_pos.ndim or kv_segment_ids.ndim < kv_segment_pos.ndim: + raise AssertionError( + "segment_ids must not have fewer dims than segment_pos; got" + f" q_segment_ids.ndim={q_segment_ids.ndim}," + f" q_segment_pos.ndim={q_segment_pos.ndim}," + f" kv_segment_ids.ndim={kv_segment_ids.ndim}," + f" kv_segment_pos.ndim={kv_segment_pos.ndim}" + ) + if not ( + q_segment_ids.shape[-q_segment_pos.ndim :] == q_segment_pos.shape + and kv_segment_ids.shape[-kv_segment_pos.ndim :] == kv_segment_pos.shape + ): + raise AssertionError( + "segment_pos trailing shape must match segment_ids; got" + f" q_segment_ids.shape={q_segment_ids.shape}," + f" q_segment_pos.shape={q_segment_pos.shape}," + f" kv_segment_ids.shape={kv_segment_ids.shape}," + f" kv_segment_pos.shape={kv_segment_pos.shape}" ) + # THD: compute seqlens/offsets. + if qkv_layout.is_thd(): + # If there are more leading dims on segment_ids, e.g. vmap + if q_segment_ids.ndim > q_segment_pos.ndim or kv_segment_ids.ndim > kv_segment_pos.ndim: + # Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims, + # vmap seqlens/offsets computation with segment_pos broadcast, + # reshape back to the original leading batch dims. + n_extra_batch_dims_q = q_segment_ids.ndim - q_segment_pos.ndim + n_extra_batch_dims_kv = kv_segment_ids.ndim - kv_segment_pos.ndim + extra_batch_shape_q = q_segment_ids.shape[:n_extra_batch_dims_q] + extra_batch_shape_kv = kv_segment_ids.shape[:n_extra_batch_dims_kv] + extra_flat_batch_size_q = jnp.prod(extra_batch_shape_q) + extra_flat_batch_size_kv = jnp.prod(extra_batch_shape_kv) + # vmap below requires same batch size on axis 0 for q_flat and kv_flat; JAX will raise if they differ. + q_flat = q_segment_ids.reshape( + extra_flat_batch_size_q, *q_segment_ids.shape[n_extra_batch_dims_q:] + ) + kv_flat = kv_segment_ids.reshape( + extra_flat_batch_size_kv, *kv_segment_ids.shape[n_extra_batch_dims_kv:] + ) + + single_extra_batch = partial( + _segment_ids_pos_to_seqlens_offsets, + attn_mask_type=attn_mask_type, + window_size=window_size, + max_segments_per_seq=max_segments_per_seq, + ) + + q_sl, kv_sl, q_off, kv_off = jax.vmap( + single_extra_batch, in_axes=(0, 0, None, None) + )(q_flat, kv_flat, q_segment_pos, kv_segment_pos) + + q_seqlens = q_sl.reshape(*extra_batch_shape_q, *q_sl.shape[1:]) + kv_seqlens = kv_sl.reshape(*extra_batch_shape_kv, *kv_sl.shape[1:]) + q_offsets = q_off.reshape(*extra_batch_shape_q, *q_off.shape[1:]) + kv_offsets = kv_off.reshape(*extra_batch_shape_kv, *kv_off.shape[1:]) + else: + q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets( + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + attn_mask_type, + window_size, + max_segments_per_seq, + ) + # BSHD: compute seqlens/offsets. else: q_seqlens, kv_seqlens = _segment_ids_to_seqlens( q_segment_ids, diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index 6a2f9b737..d203fcea9 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -9,3 +9,4 @@ from .quantization import * from .softmax import * from .gemm import * +from .router import * diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 49ed6331c..4444460de 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -46,6 +46,7 @@ ActivationEnum = { ("gelu",): NVTE_Activation_Type.GELU, ("gelu", "linear"): NVTE_Activation_Type.GEGLU, + ("sigmoid", "linear"): NVTE_Activation_Type.GLU, ("silu",): NVTE_Activation_Type.SILU, ("silu", "linear"): NVTE_Activation_Type.SWIGLU, ("relu",): NVTE_Activation_Type.RELU, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 4d669bc46..4d64c0b5f 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -9,15 +9,13 @@ import warnings from dataclasses import dataclass, replace from functools import partial, reduce -from packaging import version from typing import Optional, Tuple import jax import jax.numpy as jnp from jax import dtypes, lax, ffi from jax.sharding import PartitionSpec, NamedSharding -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import SdyShardingRule import transformer_engine_jax from transformer_engine_jax import NVTE_Fused_Attn_Backend @@ -75,6 +73,7 @@ "is_training", "max_segments_per_seq", "window_size", + "bottom_right_diagonal", "context_parallel_load_balanced", "cp_axis", "cp_striped_window_size", @@ -96,6 +95,7 @@ class _FusedAttnConfig: is_training: bool max_segments_per_seq: int window_size: Tuple[int, int] + bottom_right_diagonal: bool context_parallel_load_balanced: bool cp_axis: str cp_striped_window_size: Tuple[int, int] # Only for CP + Ring P2P + THD + SWA @@ -149,6 +149,7 @@ def get_fused_attn_backend(self): self.head_dim_v, self.window_size[0], self.window_size[1], + not self.is_non_deterministic_allowed(), ) @staticmethod @@ -167,13 +168,25 @@ def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): kv_max_seqlen = q_max_seqlen num_gqa_groups = attn_heads v_head_dim = q_head_dim - assert nqkv == 3 + assert nqkv == 3, ( + f"Expected nqkv == 3 for qkvpacked layout, but got nqkv={nqkv} from" + f" q_aval.shape={q_aval.shape}" + ) elif qkv_layout.is_kvpacked(): *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, v_head_dim = k_aval.shape - assert q_batch_shape == kv_batch_shape - assert q_head_dim == v_head_dim - assert nkv == 2 + assert q_batch_shape == kv_batch_shape, ( + f"Mismatched batch shapes for kvpacked layout: q_batch_shape={q_batch_shape}," + f" kv_batch_shape={kv_batch_shape}" + ) + assert q_head_dim == v_head_dim, ( + f"Mismatched head dims for kvpacked layout: q_head_dim={q_head_dim}," + f" v_head_dim={v_head_dim}" + ) + assert nkv == 2, ( + f"Expected nkv == 2 for kvpacked layout, but got nkv={nkv} from" + f" k_aval.shape={k_aval.shape}" + ) elif qkv_layout.is_separate(): *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape *k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape @@ -252,9 +265,13 @@ def check_seed(self, seed, dropout_probability, is_training): seed = jax.random.key_data(seed) seed = seed.astype(self.rng_state_dtype) - assert seed.dtype == self.rng_state_dtype + assert ( + seed.dtype == self.rng_state_dtype + ), f"Expected seed.dtype={self.rng_state_dtype}, but got seed.dtype={seed.dtype}" # Backend takes an int64_t seed, so only the first two u32 elements are taken - assert seed.size >= self.seed_size + assert ( + seed.size >= self.seed_size + ), f"Expected seed.size >= {self.seed_size}, but got seed.size={seed.size}" return seed @@ -384,7 +401,9 @@ def abstract( # 32-bit unsigned int to get the buffer size we need in the C++ kernel checker = _FusedAttnRNGStateChecker() seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) - assert seed_dtype == checker.rng_state_dtype + assert ( + seed_dtype == checker.rng_state_dtype + ), f"Expected seed_dtype={checker.rng_state_dtype}, but got seed_dtype={seed_dtype}" rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) @@ -394,6 +413,11 @@ def abstract( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) + bottom_right_diagonal = config.attn_mask_type in [ + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to # prepare for the active fused-attn backend input_batch = reduce(operator.mul, batch_shape) @@ -418,16 +442,25 @@ def abstract( config.max_segments_per_seq, config.window_size[0], config.window_size[1], + bottom_right_diagonal, ) wkspace_aval = q_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - assert softmax_offset_aval.dtype == jnp.float32 + assert ( + softmax_offset_aval.dtype == jnp.float32 + ), f"Expected softmax_offset_aval.dtype=float32, but got {softmax_offset_aval.dtype}" if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: - assert softmax_offset_aval.shape == (1, attn_heads, 1, 1) + assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), ( + f"Expected softmax_offset_aval.shape=(1, {attn_heads}, 1, 1) for" + f" {config.softmax_type}, but got {softmax_offset_aval.shape}" + ) else: - assert softmax_offset_aval.shape == (0,) + assert softmax_offset_aval.shape == (0,), ( + "Expected softmax_offset_aval.shape=(0,) for VANILLA_SOFTMAX, but got" + f" {softmax_offset_aval.shape}" + ) return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval @@ -526,6 +559,7 @@ def lowering( deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_left=window_size_left, window_size_right=window_size_right, + bottom_right_diagonal=config.bottom_right_diagonal, softmax_type=int(config.softmax_type.value), ) @@ -547,7 +581,9 @@ def impl( _kv_segment_pos, config: _FusedAttnConfig, ): - assert FusedAttnFwdPrimitive.inner_primitive is not None + assert ( + FusedAttnFwdPrimitive.inner_primitive is not None + ), "FusedAttnFwdPrimitive.inner_primitive has not been registered" sequence_descriptor = SequenceDescriptor( seqlens=(q_seqlen, kv_seqlen), @@ -640,10 +676,14 @@ def convert_to_2d(offsets, batch, max_seqlen): @staticmethod def batcher(batched_args, batch_dims, *, config): + # batch_dims: each element is the batch axis (0, ...) or None. Only 0 or None allowed. check_valid_batch_dims(batch_dims) - assert FusedAttnFwdPrimitive.outer_primitive is not None + assert ( + FusedAttnFwdPrimitive.outer_primitive is not None + ), "FusedAttnFwdPrimitive.outer_primitive has not been registered" q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims - + # Pass through; segment_ids/segment_pos may have different batch dims (e.g. vmapped ids, + # replicated pos). get_seqlens_and_offsets() in attention.py handles conversion without expanding. out_bdims = q_bdim, q_bdim, seed_bdim return ( FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config), @@ -718,8 +758,6 @@ def partition(config, mesh, arg_infos, result_infos): @staticmethod def shardy_sharding_rule(config, mesh, value_types, result_types): - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del mesh, result_types # Keep in sync with `infer_sharding_from_operands`. @@ -794,8 +832,15 @@ def abstract( v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) - assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype - assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype, ( + f"Mismatched dtypes: q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}," + f" bias_dtype={bias_dtype}, doutput_dtype={doutput_dtype}" + ) + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, ( + "Mismatched seqlen dtypes:" + f" q_seqlen_or_cu_seqlen_aval.dtype={q_seqlen_or_cu_seqlen_aval.dtype}," + f" kv_seqlen_or_cu_seqlen_aval.dtype={kv_seqlen_or_cu_seqlen_aval.dtype}" + ) ( batch_shape, @@ -838,6 +883,7 @@ def abstract( config.max_segments_per_seq, config.window_size[0], config.window_size[1], + config.bottom_right_diagonal, ) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) @@ -973,6 +1019,7 @@ def lowering( deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_left=window_size_left, window_size_right=window_size_right, + bottom_right_diagonal=config.bottom_right_diagonal, softmax_type=int(config.softmax_type.value), ) @@ -997,7 +1044,9 @@ def impl( _kv_segment_pos, config, ): - assert FusedAttnBwdPrimitive.inner_primitive is not None + assert ( + FusedAttnBwdPrimitive.inner_primitive is not None + ), "FusedAttnBwdPrimitive.inner_primitive has not been registered" sequence_descriptor = SequenceDescriptor( seqlens=(q_seqlen, kv_seqlen), @@ -1037,7 +1086,9 @@ def convert_to_2d(offsets, batch, max_seqlen): batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( q, k, v, config.qkv_layout ) - assert len(batch) == 1 + assert ( + len(batch) == 1 + ), f"Expected len(batch) == 1, but got len(batch)={len(batch)}, batch={batch}" kv_batch = q_batch = batch[0] # Gather valid q_seqlen, which is greater than 0 @@ -1096,9 +1147,11 @@ def convert_to_2d(offsets, batch, max_seqlen): @staticmethod def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) - assert FusedAttnBwdPrimitive.outer_primitive is not None + assert ( + FusedAttnBwdPrimitive.outer_primitive is not None + ), "FusedAttnBwdPrimitive.outer_primitive has not been registered" q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims - + # Pass through; segment_ids/segment_pos may have different batch dims. Conversion is in attention.py. out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim return ( FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), @@ -1200,8 +1253,6 @@ def sharded_impl( @staticmethod def shardy_sharding_rule(config, mesh, value_types, result_types): - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del config, mesh # Keep in sync with `infer_sharding_from_operands`. input_spec = tuple((f"…{x}",) for x in range(len(value_types))) @@ -1384,9 +1435,10 @@ def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size): def get_step_config(self) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" + adjusted_mask = self.get_adjusted_mask() return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, - attn_mask_type=self.get_adjusted_mask(), + attn_mask_type=adjusted_mask, softmax_type=self.config.softmax_type, qkv_layout=self.config.qkv_layout, scaling_factor=self.config.scaling_factor, @@ -1394,6 +1446,7 @@ def get_step_config(self) -> _FusedAttnConfig: is_training=self.config.is_training, max_segments_per_seq=self.config.max_segments_per_seq, window_size=self.config.window_size, + bottom_right_diagonal=adjusted_mask.is_bottom_right(), context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, @@ -1402,9 +1455,10 @@ def get_step_config(self) -> _FusedAttnConfig: def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention.""" + adjusted_mask = self.get_adjusted_mask() return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, - attn_mask_type=self.get_adjusted_mask(), + attn_mask_type=adjusted_mask, softmax_type=self.config.softmax_type, qkv_layout=self.config.qkv_layout, scaling_factor=self.config.scaling_factor, @@ -1412,6 +1466,7 @@ def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: is_training=self.config.is_training, max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size), window_size=self.config.window_size, + bottom_right_diagonal=adjusted_mask.is_bottom_right(), context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, @@ -2457,6 +2512,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: is_training=self.config.is_training, max_segments_per_seq=self.config.max_segments_per_seq, window_size=self.config.window_size, + bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, @@ -3407,7 +3463,9 @@ def fused_attn_fwd( raise ValueError(f"Unknown {qkv_layout=}") if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None + assert ( + bias is None + ), f"bias must be None when attn_bias_type is NO_BIAS, but got bias={bias}" bias = jnp.zeros(0, dtype=qkv[0].dtype) if softmax_offset is None: @@ -3425,10 +3483,16 @@ def fused_attn_fwd( softmax_offset, (None, HEAD_AXES, None, None) ) else: - assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX + assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX, ( + "Expected VANILLA_SOFTMAX when softmax_offset is None and not OFF_BY_ONE_SOFTMAX," + f" but got softmax_type={softmax_type}" + ) softmax_offset = jnp.zeros(0, dtype=jnp.float32) else: - assert softmax_offset.dtype == jnp.float32 + assert softmax_offset.dtype == jnp.float32, ( + "Expected softmax_offset.dtype=float32, but got" + f" softmax_offset.dtype={softmax_offset.dtype}" + ) # Shard by heads dimension if not VANILLA_SOFTMAX if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: softmax_offset = with_sharding_constraint_by_logical_axes( @@ -3445,6 +3509,7 @@ def fused_attn_fwd( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=(-1, -1) if window_size is None else window_size, + bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, @@ -3566,7 +3631,9 @@ def fused_attn_bwd( raise ValueError(f"Unknown {qkv_layout=}") if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None + assert ( + bias is None + ), f"bias must be None when attn_bias_type is NO_BIAS, but got bias with type={type(bias)}" bias = jnp.zeros(0, dtype=qkv[0].dtype) if softmax_offset is None: @@ -3591,13 +3658,21 @@ def fused_attn_bwd( softmax_offset, (None, HEAD_AXES, None, None) ) - # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on - # sm100+ - compute_capabilities = get_all_device_compute_capability() - if any(x >= 100 for x in compute_capabilities) and not is_hip_extension(): - assert not ( - attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 - ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" + compute_capabilities = [] if is_hip_extension() else get_all_device_compute_capability() + if any(x >= 100 for x in compute_capabilities) and is_training: + assert ( + FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() >= (9, 7, 0) + and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0) + ) or ( + not FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() >= (9, 18, 1) + and attn_bias_type == AttnBiasType.NO_BIAS + and dropout_probability == 0.0 + ), ( + "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout," + " and deterministic bprop (cuDNN 9.18.1+) does not support bias or dropout" + ) fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, @@ -3609,6 +3684,7 @@ def fused_attn_bwd( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=(-1, -1) if window_size is None else window_size, + bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index e65215bec..a61eaef8e 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -10,16 +10,24 @@ from abc import ABCMeta, abstractmethod from functools import partial +import jax from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching from jax._src import dispatch from jax import ffi +from packaging.version import Version as PkgVersion from .misc import is_hip_extension import transformer_engine_jax +# GSPMD sharding propagation (infer_sharding_from_operands) is removed in JAX > 0.9.1. +# Only register it for older JAX versions to maintain backwards compatibility. +# For JAX > 0.9.1, infer_sharding_from_operands is also removed from def_partition's signature, +# so it must not be passed at all. +_JAX_GSPMD_SUPPORTED = PkgVersion(jax.__version__) <= PkgVersion("0.9.1") + class BasePrimitive(metaclass=ABCMeta): """ @@ -146,13 +154,15 @@ def batcher(): """ return NotImplemented - @staticmethod - @abstractmethod - def infer_sharding_from_operands(): + @classmethod + def infer_sharding_from_operands(cls, *args, **kwargs): """ to describe infer_sharding_from_operands for custom_partitioning """ - return NotImplemented + raise NotImplementedError( + f"{cls.__name__} does not support GSPMD sharding propagation." + " Please use Shardy partitioner instead." + ) @staticmethod @abstractmethod @@ -175,6 +185,22 @@ def shardy_sharding_rule(*args): # Registry to store all registered primitive classes _primitive_registry = {} +_gspmd_deprecation_warned = False + + +def _warn_gspmd_deprecation_once(): + global _gspmd_deprecation_warned + if not _gspmd_deprecation_warned: + warnings.warn( + "GSPMD sharding propagation rules in TE-JAX are planned to be removed in June 2026." + " They are no longer maintained or tested. Use them at your own risk." + " Please use Shardy propagation instead." + " In case you cannot upgrade to a JAX version that supports Shardy, please reach out!", + DeprecationWarning, + stacklevel=2, + ) + _gspmd_deprecation_warned = True + def register_primitive(cls, outer_only=False): """ @@ -211,10 +237,28 @@ def name_of_wrapper_p(): outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) + + if _JAX_GSPMD_SUPPORTED: + fn = cls.__dict__.get("infer_sharding_from_operands") + if fn is not None: + actual_fn = ( + cls.infer_sharding_from_operands + ) # Use descriptor protocol to unwrap staticmethod + + def _gspmd_wrapper(*args, **kwargs): + _warn_gspmd_deprecation_once() + return actual_fn(*args, **kwargs) + + gspmd_kwargs = {"infer_sharding_from_operands": _gspmd_wrapper} + else: + gspmd_kwargs = {"infer_sharding_from_operands": cls.infer_sharding_from_operands} + else: + gspmd_kwargs = {} + outer_p_lower.def_partition( - infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition, sharding_rule=cls.shardy_sharding_rule, + **gspmd_kwargs, ) mlir.register_lowering( outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results) @@ -223,7 +267,7 @@ def name_of_wrapper_p(): for _name, _value in transformer_engine_jax.registrations().items(): - ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension else "CUDA") + ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension() else "CUDA") def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4e369ebb3..17246d51c 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -7,9 +7,10 @@ import math import operator +import os from collections.abc import Iterable from dataclasses import dataclass -from functools import partial, reduce +from functools import partial, reduce, cache from typing import Tuple, Sequence, Union from enum import Enum import warnings @@ -31,13 +32,11 @@ from transformer_engine_jax import ( initialize_cgemm_communicator, get_cgemm_num_max_streams, + get_grouped_gemm_setup_workspace_size, ) from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize - -from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type - from ..quantize import ( AbstractBaseTensor, NoScaleTensor, @@ -52,8 +51,6 @@ noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, - get_quantize_config_with_recipe, - get_global_quantize_recipe, QuantizeLayout, ) from .misc import get_padded_spec, is_all_reduce_in_float32 @@ -63,6 +60,8 @@ dp_or_fsdp_axis_size, ) +from ..util import get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type + __all__ = [ "CollectiveOp", @@ -72,7 +71,6 @@ "gemm", "grouped_gemm_copy_group_sizes", "grouped_gemm", - "gemm_uses_jax_dot", "sanitize_dims", "get_non_contracting_dims", "transpose_dims", @@ -83,6 +81,21 @@ num_cublas_streams = get_num_compute_streams() +# Cache whether the CUDA-graphable grouped GEMM implementation is available at import time. +# Calling get_grouped_gemm_setup_workspace_size raises a RuntimeError mentioning "cublas" when +# compiled against cuBLAS < 13.2, in which case the cuda-graphable path is unavailable. +try: + if is_hip_extension(): + _v2_grouped_gemm_available = False + else: + get_grouped_gemm_setup_workspace_size(1) + _v2_grouped_gemm_available = True +except RuntimeError as e: + if "cublas" in str(e).lower(): + _v2_grouped_gemm_available = False + else: + raise + def get_cublas_workspace_size_bytes() -> None: """Return workspace size needed for current architecture""" @@ -184,17 +197,26 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ flatten_axis=flatten_axis, ) - assert not isinstance(lhs_q, ScaledTensor2x) - assert not isinstance(rhs_q, ScaledTensor2x) + if isinstance(lhs_q, ScaledTensor2x): + raise TypeError( + "Expected lhs_q to not be ScaledTensor2x after quantization, but got" + f" type={type(lhs_q)}" + ) + if isinstance(rhs_q, ScaledTensor2x): + raise TypeError( + "Expected rhs_q to not be ScaledTensor2x after quantization, but got" + f" type={type(rhs_q)}" + ) def has_rht_applied(q: AbstractBaseTensor) -> bool: return isinstance(q, ScaledTensor1x) and q.has_rht_applied - assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), ( - "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized" - " with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the" - " GEMM." - ) + if has_rht_applied(lhs_q) != has_rht_applied(rhs_q): + raise ValueError( + "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be" + " quantized with RHT as well. This is to ensure the RHT is applied to both and will" + " cancel out in the GEMM." + ) return lhs_q, rhs_q @@ -291,15 +313,17 @@ def collective_gemm_bootstrap( this function with its own unique process_id. """ if is_hip_extension(): - assert 0, "collective_gemm_bootstrap is not supported for ROCm yet." - assert ( - num_devices_per_process == 1 and jax.local_device_count() == 1 - ), "Only single device per process is supported at the moment!" - assert num_total_devices % num_devices_per_process == 0, ( - f"Invalid num_total_devices={num_total_devices}," - f" num_devices_per_process={num_devices_per_process}" - ) - assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}" + raise NotImplementedError("Collective GEMM is not supported for ROCm yet.") + + if not (num_devices_per_process == 1 and jax.local_device_count() == 1): + raise RuntimeError("Only single device per process is supported at the moment!") + if num_total_devices % num_devices_per_process != 0: + raise ValueError( + f"Invalid num_total_devices={num_total_devices}," + f" num_devices_per_process={num_devices_per_process}" + ) + if not 0 <= process_id < num_total_devices: + raise ValueError(f"Invalid process_id={process_id}") initialize_cgemm_communicator( num_total_devices, num_devices_per_process, @@ -386,16 +410,65 @@ def get_rhs_axis_boundary(rhs_cdims, is_transposed): return min(rhs_cdims) if is_transposed else max(rhs_cdims) + 1 +@cache +def _get_high_precision_accumulation_from_env() -> bool: + """Read NVTE_FP8_GEMM_HIGH_PRECISION_ACCUMULATION once per process (cached).""" + return os.getenv("NVTE_FP8_GEMM_HIGH_PRECISION_ACCUMULATION", "0") == "1" + + def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name): """Assert that the given tensor shape and layout meet the requirements for cuBLAS GEMM.""" if scaling_mode != ScalingMode.NO_SCALING: # Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage alignment = 32 if scaling_mode.is_nvfp4_scaling else 16 - assert contracting_size % alignment == 0, ( - f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of" - f" {alignment} when using quantized inputs. Got contracting_size={contracting_size}" - ) + if contracting_size % alignment != 0: + raise ValueError( + f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of" + f" {alignment} when using quantized inputs. Got contracting_size={contracting_size}" + ) + + +def _reorder_tpsp_leading(tensor, original_shape): + """Reorder tensor so the tpsp axis is leading: reshape (dp, n, tpsp, m, ...), transpose (2, 0, 1, 3, ...).""" + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = tensor.reshape( + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + tpsp_axis_size(), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) + return reordered.reshape(original_shape) + + +def _reorder_dp_leading(tensor, original_shape): + """Reorder tensor so the dp axis is leading: reshape (tpsp, dp, n, m, ...), transpose (1, 2, 0, 3, ...).""" + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = tensor.reshape( + tpsp_axis_size(), + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim)) + return reordered.reshape(original_shape) class GemmPrimitive(BasePrimitive): @@ -403,9 +476,9 @@ class GemmPrimitive(BasePrimitive): Primitive for cuBLAS GEMM """ - name = "te_gemm_ffi" + name = "te_gemm_v2_ffi" multiple_results = True - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) + impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14) inner_primitive = None outer_primitive = None @@ -416,15 +489,11 @@ def abstract( rhs, rhs_scale_inv, bias, - gelu_input, alpha, beta, out_dtype, contracting_dims, scaling_mode, - fuse_bias, - fuse_gelu, - grad, use_split_accumulator, transpose_batch_sequence, sequence_dim, @@ -445,55 +514,63 @@ def _dims_are_consecutive(dims): lhs_contracting_dims, rhs_contracting_dims, ) = map(sanitize_dims, operand_ndims, contracting_dims) - assert _dims_are_consecutive(lhs_contracting_dims), ( - "cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got " - f"{lhs_contracting_dims}." - ) - assert _dims_are_consecutive(rhs_contracting_dims), ( - "cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got " - f"{rhs_contracting_dims}." - ) + if not _dims_are_consecutive(lhs_contracting_dims): + raise ValueError( + "cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got " + f"{lhs_contracting_dims}." + ) + if not _dims_are_consecutive(rhs_contracting_dims): + raise ValueError( + "cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got " + f"{rhs_contracting_dims}." + ) lhs_contracting_size, rhs_contracting_size = map( lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), (lhs.shape, rhs.shape), (lhs_contracting_dims, rhs_contracting_dims), ) - assert lhs_contracting_size == rhs_contracting_size, ( - "cuBLAS GEMM operands have incompatible contracting dimensions: " - f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}." - ) + if lhs_contracting_size != rhs_contracting_size: + raise ValueError( + f"cuBLAS GEMM operands have incompatible contracting dimensions: {lhs.shape} @ idx" + f" {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}." + ) + assert_cublas_requirements(scaling_mode, lhs_contracting_size, "LHS") + assert_cublas_requirements(scaling_mode, rhs_contracting_size, "RHS") lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims) if scaling_mode != ScalingMode.NO_SCALING: - assert scaling_mode.is_nvfp4_scaling or _compatible_fp8_gemm_dtypes( - lhs.dtype, rhs.dtype - ), ( - "cuBLAS GEMM quantized operands have incompatible data types: " - f"{lhs.dtype} x {rhs.dtype}." - ) - assert ( - lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0 - ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." + if not ( + scaling_mode.is_nvfp4_scaling or _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype) + ): + raise ValueError( + "cuBLAS GEMM quantized operands have incompatible data types: " + f"{lhs.dtype} x {rhs.dtype}." + ) + if not (lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0): + raise ValueError( + "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." + ) if ( scaling_mode != ScalingMode.MXFP8_1D_SCALING and not is_fp8_gemm_with_all_layouts_supported() ): - assert not lhs_is_transposed and rhs_is_transposed, ( - "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " - "require non-transposed LHS and transposed RHS operands " - "(`contracting_dims=((-1, ), (-1, ))`)." - ) + if lhs_is_transposed or not rhs_is_transposed: + raise ValueError( + "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " + "require non-transposed LHS and transposed RHS operands " + "(`contracting_dims=((-1, ), (-1, ))`)." + ) else: - assert lhs.dtype == rhs.dtype, ( - "For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal." - f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" - ) + if lhs.dtype != rhs.dtype: + raise ValueError( + "For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal." + f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" + ) # Determine output shape and dtype - assert ( - dtypes.canonicalize_dtype(out_dtype).itemsize > 1 - ), "cuBLAS GEMM custom op does not support 8-bit quantized output types." + if not dtypes.canonicalize_dtype(out_dtype).itemsize > 1: + raise ValueError("cuBLAS GEMM custom op does not support 8-bit quantized output types.") lhs_non_contracting_shape, rhs_non_contracting_shape = map( lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims], (lhs.shape, rhs.shape), @@ -504,7 +581,8 @@ def _dims_are_consecutive(dims): # Adjust output shape for comm+GEMM overlap if not collective_op.is_none and not is_outer: # Inner abstract - assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + if sequence_dim != 1: + raise ValueError(f"Invalid sequence_dim. Got sequence_dim={sequence_dim}") overlap_out_shape = list(out_shape).copy() if collective_op.is_all_gather: overlap_out_shape[1] *= tpsp_axis_size() @@ -512,44 +590,33 @@ def _dims_are_consecutive(dims): overlap_out_shape[sequence_dim] = ( overlap_out_shape[sequence_dim] // tpsp_axis_size() ) - assert out_dtype == jnp.bfloat16, f"Unsupported out_dtype={out_dtype}" + if out_dtype != jnp.bfloat16: + raise ValueError(f"Unsupported out_dtype={out_dtype}") output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype) - # Validate bias - if fuse_bias: - assert bias.shape == tuple(rhs_non_contracting_shape), ( - "cuBLAS GEMM bias tensor has incorrect shape, " - f"expected ({tuple(rhs_non_contracting_shape)}, ) but found {bias.shape}." - ) - assert bias.dtype == out_dtype, ( - "cuBLAS GEMM bias tensor has incorrect data type, " - f"expected {out_dtype} but found {bias.dtype}." - ) - # WAR: allocate dbias regardless of fuse_bias so that the sharding propagation works as we - # change the fuse_bias value in the sharded_impl - dbias_shape = bias.shape if grad else (0,) - bias_grad = jax.core.ShapedArray(shape=dbias_shape, dtype=bias.dtype) - - # Validate pre-GeLU - pre_gelu_shape = (0,) - pre_gelu_dtype = out_dtype - if fuse_gelu: - pre_gelu_shape = out_shape - if grad: - pre_gelu_ndim = len(pre_gelu_shape) - assert gelu_input.ndim == pre_gelu_shape and all( - gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim) - ), ( - "cuBLAS GEMM pre-GeLU tensor has incorrect shape, " - f"expected {pre_gelu_shape} but found {gelu_input.shape}." + # Validate bias when present (bias.size > 0 means fuse bias) + if bias.size > 0: + if bias.shape != tuple(rhs_non_contracting_shape): + raise ValueError( + "cuBLAS GEMM bias tensor has incorrect shape, " + f"expected ({tuple(rhs_non_contracting_shape)}, ) but found {bias.shape}." ) - assert gelu_input.dtype == out_dtype, ( - "cuBLAS GEMM pre-GeLU tensor has incorrect data type, " - f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." + if bias.dtype != out_dtype: + raise ValueError( + "cuBLAS GEMM bias tensor has incorrect data type, " + f"expected {out_dtype} but found {bias.dtype}." ) - pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) - assert alpha.size == 1 and alpha.dtype == jnp.float32 - assert beta.size == 1 and beta.dtype == jnp.float32 + + if alpha.size != 1 or alpha.dtype != jnp.float32: + raise ValueError( + f"Expected alpha to be a single float32 scalar, but got alpha.size={alpha.size}," + f" alpha.dtype={alpha.dtype}" + ) + if beta.size != 1 or beta.dtype != jnp.float32: + raise ValueError( + f"Expected beta to be a single float32 scalar, but got beta.size={beta.size}," + f" beta.dtype={beta.dtype}" + ) # Declare cuBLAS workspace workspace_size = get_cublas_workspace_size_bytes() @@ -563,12 +630,12 @@ def _dims_are_consecutive(dims): workspace_size += 256 workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - return output, bias_grad, pre_gelu_out, workspace + return output, workspace @staticmethod def outer_abstract(*args, **kwargs): - outputs = GemmPrimitive.abstract(*args, **kwargs) - return outputs[:-1] # discard workspace array + output, _ = GemmPrimitive.abstract(*args, **kwargs) + return (output,) @staticmethod def lowering( @@ -578,15 +645,11 @@ def lowering( rhs, rhs_scale_inv, bias, - gelu_input, alpha, beta, out_dtype, contracting_dims, scaling_mode, - fuse_bias, - fuse_gelu, - grad, use_split_accumulator, transpose_batch_sequence, sequence_dim, @@ -601,53 +664,18 @@ def lowering( (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) ) - lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) - lhs_contracting_size = ( - reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) - if lhs_transposed - else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) - ) - assert_cublas_requirements( - scaling_mode, - lhs_contracting_size, - "LHS", - ) - rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) - rhs_contracting_size = ( - reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) - if rhs_transposed - else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) - ) - assert_cublas_requirements( - scaling_mode, - rhs_contracting_size, - "RHS", - ) - - args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) + args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, alpha, beta) kwargs = { "scaling_mode": int(scaling_mode.value), + "collective_op": int(collective_op.value), "lhs_axis_boundary": get_lhs_axis_boundary(lhs_cdims, lhs_transposed), "rhs_axis_boundary": get_rhs_axis_boundary(rhs_cdims, rhs_transposed), "lhs_transposed": lhs_transposed, "rhs_transposed": rhs_transposed, - "fuse_bias": fuse_bias, - "fuse_gelu": fuse_gelu, - "grad": grad, "use_split_accumulator": use_split_accumulator, - "collective_op": int(collective_op.value), } - operand_output_aliases = {} - if grad: - operand_output_aliases.update({4: 1}) # bias <-> bias_grad - if fuse_gelu and grad: - operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out - - return jax.ffi.ffi_lowering( - GemmPrimitive.name, - operand_output_aliases=operand_output_aliases, - )(ctx, *args, **kwargs) + return jax.ffi.ffi_lowering(GemmPrimitive.name)(ctx, *args, config=kwargs) @staticmethod def impl( @@ -656,15 +684,11 @@ def impl( rhs, rhs_scale_inv, bias, - gelu_input, alpha, beta, out_dtype, contracting_dims, scaling_mode, - fuse_bias, - fuse_gelu, - grad, use_split_accumulator, transpose_batch_sequence, sequence_dim, @@ -679,93 +703,88 @@ def impl( lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims) rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1 - lhs_scale_inv = apply_padding_to_scale_inv( - lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis - ) - rhs_scale_inv = apply_padding_to_scale_inv( - rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis - ) + if not collective_op.is_none and not is_outer: + # MXFP8 + Collective AG/RS: both sides of flatten_axis must be multiples of 128. + # No padding is needed in this case + lhs_first, lhs_last = math.prod(lhs.shape[:lhs_flatten_axis]), math.prod( + lhs.shape[lhs_flatten_axis:] + ) + assert lhs_first % 128 == 0 and lhs_last % 128 == 0, ( + "MXFP8 + Collective AG/RS requires LHS dimensions before and after the flatten" + f" axis to be multiples of 128. Got lhs.shape={lhs.shape}," + f" lhs_flatten_axis={lhs_flatten_axis}" + ) + rhs_first, rhs_last = math.prod(rhs.shape[:rhs_flatten_axis]), math.prod( + rhs.shape[rhs_flatten_axis:] + ) + assert rhs_first % 128 == 0 and rhs_last % 128 == 0, ( + "MXFP8 + Collective AG/RS requires LHS dimensions before and after the flatten" + f" axis to be multiples of 128. Got rhs.shape={rhs.shape}," + f" rhs_flatten_axis={rhs_flatten_axis}" + ) + # The scale needs to be in good shape for reordering + assert lhs_scale_inv.shape[sequence_dim] % tpsp_axis_size() == 0, ( + "MXFP8 + Collective AG/RS requires RHS scale inv sequence dimension to be" + f" multiples of tpsp_axis_size. Got lhs_scale_inv.shape={lhs_scale_inv.shape}," + f" tpsp_axis_size={tpsp_axis_size()}, sequence_dim={sequence_dim}" + ) + else: + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, + scaling_mode, + lhs.shape, + lhs_transposed, + lhs_flatten_axis, + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis + ) + # Only perform JAX-based swizzle for MXFP8, NVFP4 swizzle will go though nvte kernel if scaling_mode.is_mxfp8_scaling and not is_hip_extension(): lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) + # Determine if we need to reorder the tensor so that the input/output are in the correct layout for the collective operation + need_reorder = not transpose_batch_sequence and not is_outer and not collective_op.is_none + # Alter lhs blocks so that CGEMM RS outputs correctly + if need_reorder and collective_op.is_reduce_scatter and lhs.shape[0] != 1: + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + lhs = _reorder_tpsp_leading(lhs, lhs.shape) + if ( - collective_op.is_reduce_scatter - and not transpose_batch_sequence - and not is_outer - and not lhs.shape[0] == 1 + need_reorder + and (collective_op.is_reduce_scatter or collective_op.is_all_gather) + and lhs_scale_inv.shape[0] != 1 + and scaling_mode.is_1d_block_scaling() ): assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" - original_shape = lhs.shape - assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) - reshaped = lhs.reshape( - dp_or_fsdp_axis_size(), - int(original_shape[0] / dp_or_fsdp_axis_size()), - tpsp_axis_size(), - int(original_shape[1] / tpsp_axis_size()), - *original_shape[2:], - ) - reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) - lhs = reordered.reshape(original_shape) + lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) - (output, bias_grad, pre_gelu_out, _) = GemmPrimitive.inner_primitive.bind( + (output, _) = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, - gelu_input, alpha, beta, out_dtype=out_dtype, contracting_dims=contracting_dims, scaling_mode=scaling_mode, - fuse_bias=fuse_bias, - fuse_gelu=fuse_gelu, - grad=grad, use_split_accumulator=use_split_accumulator, - collective_op=collective_op, transpose_batch_sequence=transpose_batch_sequence, sequence_dim=sequence_dim, is_outer=is_outer, + collective_op=collective_op, ) # Alter output blocks for CGEMM AG - if ( - collective_op.is_all_gather - and not transpose_batch_sequence - and not is_outer - and not output.shape[0] == 1 - ): + if need_reorder and collective_op.is_all_gather and output.shape[0] != 1: assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" - original_shape = output.shape - assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) - reshaped = output.reshape( - tpsp_axis_size(), - dp_or_fsdp_axis_size(), - int(original_shape[0] / dp_or_fsdp_axis_size()), - int(original_shape[1] / tpsp_axis_size()), - *original_shape[2:], - ) - reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim)) - output = reordered.reshape(original_shape) + output = _reorder_dp_leading(output, output.shape) - return [output, bias_grad, pre_gelu_out] + return (output,) @staticmethod def outer_impl( @@ -774,15 +793,11 @@ def outer_impl( rhs, rhs_scale_inv, bias, - gelu_input, alpha, beta, out_dtype, contracting_dims, scaling_mode, - fuse_bias, - fuse_gelu, - grad, use_split_accumulator, transpose_batch_sequence, sequence_dim, @@ -795,15 +810,11 @@ def outer_impl( rhs, rhs_scale_inv, bias, - gelu_input, alpha, beta, out_dtype, contracting_dims, scaling_mode, - fuse_bias, - fuse_gelu, - grad, use_split_accumulator, transpose_batch_sequence, sequence_dim, @@ -818,9 +829,6 @@ def batcher( out_dtype, contracting_dims, scaling_mode, - fuse_bias, - fuse_gelu, - grad, use_split_accumulator, collective_op, transpose_batch_sequence, @@ -828,39 +836,30 @@ def batcher( is_outer, ): del transpose_batch_sequence, sequence_dim, is_outer - assert GemmPrimitive.outer_primitive is not None + if GemmPrimitive.outer_primitive is None: + raise RuntimeError("GemmPrimitive.outer_primitive has not been registered") lhs_bdims, _, rhs_bdims, *_ = batch_dims # Batched GEMM is not supported - assert ( - lhs_bdims is None and rhs_bdims is None - ), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})" + if not (lhs_bdims is None and rhs_bdims is None): + raise RuntimeError( + f"Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" + ) out_bdims = (None,) - # Bias gradient is never batched - bias_bdims = (None,) - - # Pre-GeLU output, if exists, is batched like GEMM output - pre_gelu_bdims = (None,) - if fuse_gelu and not grad: - pre_gelu_bdims = out_bdims - return ( GemmPrimitive.outer_primitive.bind( *batched_args, out_dtype=out_dtype, contracting_dims=contracting_dims, scaling_mode=scaling_mode, - fuse_bias=fuse_bias, - fuse_gelu=fuse_gelu, - grad=grad, use_split_accumulator=use_split_accumulator, collective_op=collective_op, transpose_batch_sequence=transpose_batch_sequence, sequence_dim=sequence_dim, is_outer=is_outer, ), - (out_bdims, bias_bdims, pre_gelu_bdims), + (out_bdims,), ) @staticmethod @@ -869,6 +868,7 @@ def _parse_operand_output_specs( contracting_dims, transpose_batch_sequence, collective_op, + scaling_mode, ): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) @@ -900,7 +900,8 @@ def _parse_operand_output_specs( for l in lhs_cspecs: for r in rhs_cspecs: if l is not None and l == r: - assert reduce_spec is None, "Multiple reduce dimension is detected!" + if reduce_spec is not None: + raise RuntimeError("Multiple reduce dimension is detected!") reduce_spec = l sequence_dim = None @@ -916,18 +917,20 @@ def _parse_operand_output_specs( " Please check your sharding configuration." ) from exc sequence_dim = tpsp_idx - assert (sequence_dim == 1) ^ transpose_batch_sequence, ( - "CollectiveGEMM supports only (sequence_dim=1 and transpose_batch_sequence=False)" - " or (sequence_dim=0 and transpose_batch_sequence=True). Received:" - f" sequence_dim={sequence_dim}," - f" transpose_batch_sequence={transpose_batch_sequence}." - ) + if not (sequence_dim == 1) ^ transpose_batch_sequence: + raise ValueError( + "CollectiveGEMM supports only (sequence_dim=1 and" + " transpose_batch_sequence=False) or (sequence_dim=0 and" + f" transpose_batch_sequence=True). Received: sequence_dim={sequence_dim}," + f" transpose_batch_sequence={transpose_batch_sequence}." + ) elif collective_op.is_reduce_scatter: - assert reduce_spec == gsr.tpsp_resource, ( - "Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got" - f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}" - ) + if reduce_spec != gsr.tpsp_resource: + raise ValueError( + "Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got" + f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}" + ) sequence_dim = int(not transpose_batch_sequence) if reduce_spec is not None: @@ -955,7 +958,15 @@ def _parse_operand_output_specs( # Non-contracting dims of RHS always needs to be gathered along the FSDP axis rhs_non_cspecs = tuple( - None if spec is not None and spec == gsr.fsdp_resource else spec + ( + None + if spec is not None + and ( + spec == gsr.fsdp_resource + or (isinstance(spec, tuple) and gsr.fsdp_resource in spec) + ) + else spec + ) for spec in rhs_non_cspecs ) @@ -972,14 +983,18 @@ def _parse_operand_output_specs( # Only do AG Sequence dim if not Overlap RS if collective_op.is_all_gather: - assert sequence_dim <= len( - lhs_non_cspecs - ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + if sequence_dim > len(lhs_non_cspecs): + raise ValueError( + f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs:" + f" {lhs_non_cspecs}" + ) out_specs = out_specs[:sequence_dim] + (None,) + out_specs[sequence_dim + 1 :] elif collective_op.is_reduce_scatter: - assert sequence_dim <= len( - lhs_non_cspecs - ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + if sequence_dim > len(lhs_non_cspecs): + raise ValueError( + f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs:" + f" {lhs_non_cspecs}" + ) out_specs = ( out_specs[:sequence_dim] + (gsr.tpsp_resource,) + out_specs[sequence_dim + 1 :] ) @@ -994,16 +1009,29 @@ def _parse_operand_output_specs( (lhs_non_cspecs, rhs_non_cspecs), ) - # Bias and Pre-GeLU sharding is based on GEMM output before any scatter - bias_specs = tuple(list(rhs_non_cspecs).copy()) - gelu_specs = tuple(list(out_specs).copy()) + # Bias sharding is based on GEMM output before any scatter + bias_specs = rhs_non_cspecs if arg_infos[4].size > 0 else (None,) # bias is operand index 4 + + # Scale shardings are based on the scaling_mode and collective_op + lhs_scale_specs = rhs_scale_specs = (None,) + if scaling_mode.is_1d_block_scaling(): + rhs_scale_specs = rhs_specs + # Set the seq spec to None to trigger AG the scales as TE/Common CGEMM does not handle + # scale collecting yet + if collective_op.is_all_gather: + lhs_scale_specs = tuple( + None if i == sequence_dim else s for i, s in enumerate(lhs_specs) + ) + else: + lhs_scale_specs = lhs_specs if not collective_op.is_none: - assert sequence_dim >= 0, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + if sequence_dim < 0: + raise ValueError(f"Invalid sequence_dim. Got sequence_dim={sequence_dim}") return ( - (lhs_specs, rhs_specs, bias_specs, gelu_specs), - (out_specs, bias_specs, gelu_specs), + (lhs_specs, lhs_scale_specs, rhs_specs, rhs_scale_specs, bias_specs), + out_specs, reduce_spec, sequence_dim, ) @@ -1013,9 +1041,6 @@ def infer_sharding_from_operands( out_dtype, contracting_dims, scaling_mode, - fuse_bias, - fuse_gelu, - grad, use_split_accumulator, transpose_batch_sequence, sequence_dim, @@ -1027,40 +1052,28 @@ def infer_sharding_from_operands( ): del ( out_dtype, - scaling_mode, use_split_accumulator, result_infos, is_outer, sequence_dim, ) - (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( - GemmPrimitive._parse_operand_output_specs( - arg_infos, contracting_dims, transpose_batch_sequence, collective_op - ) + (_, out_specs, *_) = GemmPrimitive._parse_operand_output_specs( + arg_infos, + contracting_dims, + transpose_batch_sequence, + collective_op, + scaling_mode, ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) - # Discard dbias gradient spec if there is no bias and grad fusion - if not (fuse_bias and grad): - dbias_specs = (None,) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs)) - - # Discard pre-GeLU output spec if there is no GeLU fusion - if not fuse_gelu: - pre_gelu_specs = (None,) - pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)) - - return [out_sharding, dbias_sharding, pre_gelu_sharding] + return (out_sharding,) @staticmethod def partition( out_dtype, contracting_dims, scaling_mode, - fuse_bias, - fuse_gelu, - grad, use_split_accumulator, transpose_batch_sequence, sequence_dim, @@ -1073,8 +1086,8 @@ def partition( del result_infos, is_outer, sequence_dim ( - (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), - (out_specs, dbias_specs, pre_gelu_specs), + (lhs_specs, lhs_scale_specs, rhs_specs, rhs_scale_specs, bias_input_specs), + out_specs, reduce_spec, inferred_sequence_dim, ) = GemmPrimitive._parse_operand_output_specs( @@ -1082,63 +1095,48 @@ def partition( contracting_dims, transpose_batch_sequence, collective_op, + scaling_mode, ) # Block scale inverses match their operands, but tensor scale inverses are unsharded. none_sharding = NamedSharding(mesh, PartitionSpec(None)) lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) + lhs_scale_sharding = NamedSharding(mesh, PartitionSpec(*lhs_scale_specs)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) + rhs_scale_sharding = NamedSharding(mesh, PartitionSpec(*rhs_scale_specs)) + arg_shardings = ( lhs_sharding, - lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + lhs_scale_sharding, rhs_sharding, - rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + rhs_scale_sharding, ) - # Discard bias input spec if there is no bias fusion - if not fuse_bias: - bias_input_specs = (None,) + # Bias arg_shardings += (NamedSharding(mesh, PartitionSpec(*bias_input_specs)),) - # Discard pre-GeLU input spec if there is no GeLU fusion - if not fuse_gelu: - gelu_input_specs = (None,) - arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),) - # Alpha, beta arg_shardings += (none_sharding, none_sharding) # Assemble output shardings - out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] - - # Discard bias gradient spec if there is no bias and grad fusion - if not (fuse_bias and grad): - dbias_specs = (None,) - out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs))) + out_sharding = (NamedSharding(mesh, PartitionSpec(*out_specs)),) - # Discard pre-GeLU output spec if there is no GeLU fusion - if not fuse_gelu: - pre_gelu_specs = (None,) - out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) - - def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta): + def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, alpha, beta): # We should not fuse bias in the output reduction case - sharded_fuse_bias = fuse_bias and reduce_spec is None - outputs = GemmPrimitive.impl( + has_bias = bias.size > 0 + fuse_bias = has_bias and reduce_spec is None + bias_for_impl = bias if fuse_bias else jnp.empty(0, dtype=bias.dtype) + (output,) = GemmPrimitive.impl( lhs, lhs_scale_inv, rhs, rhs_scale_inv, - bias, - gelu_input, + bias_for_impl, alpha, beta, out_dtype=out_dtype, contracting_dims=contracting_dims, scaling_mode=scaling_mode, - fuse_bias=sharded_fuse_bias, - fuse_gelu=fuse_gelu, - grad=grad, use_split_accumulator=use_split_accumulator, transpose_batch_sequence=transpose_batch_sequence, sequence_dim=inferred_sequence_dim, @@ -1149,27 +1147,24 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alph if reduce_spec is not None: if not collective_op.is_reduce_scatter: if is_all_reduce_in_float32(): # For unittest only - outputs[0] = jax.lax.psum( - outputs[0].astype(jnp.float32), reduce_spec - ).astype(out_dtype) + output = jax.lax.psum(output.astype(jnp.float32), reduce_spec).astype( + out_dtype + ) else: - outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + output = jax.lax.psum(output, reduce_spec) - if fuse_bias: # TODO(Phuong): rename fuse_bias to has_bias - outputs[0] += bias + if has_bias: + output += bias - return outputs + return (output,) - return mesh, _sharded_impl, out_shardings, arg_shardings + return mesh, _sharded_impl, out_sharding, arg_shardings @staticmethod def shardy_sharding_rule( out_dtype, contracting_dims, scaling_mode, - fuse_bias, - fuse_gelu, - grad, use_split_accumulator, transpose_batch_sequence, sequence_dim, @@ -1183,9 +1178,16 @@ def shardy_sharding_rule( del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer if not collective_op.is_none: - raise NotImplementedError( - "CollectiveGEMM with Shardy propagation is not supported yet! Please turn off" - " Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false" + warnings.warn( + "CollectiveGEMM with Shardy propagation may produce an incorrect sharding pattern" + " for the output.\n To resolve this, apply a sharding constraint on the output" + " using one of the following options:\n" + " - TE `dense` vjp: set `output_axes`.\n" + " - TE `layernorm_mlp` vjp: set `dot_2_input_axes`.\n" + " - TE `transformer_engine.jax.cpp_extensions.gemm`: apply" + " `jax.lax.with_sharding_constraint` on the output.\n" + " - TE via MaxText: no action needed.", + UserWarning, ) prefix = "Gemm_" @@ -1222,11 +1224,10 @@ def _generate_operand_rules(name, ndim, cdims): lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims) rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) out_spec = (*lhs_non_cspec, *rhs_non_cspec) - bias_spec = rhs_non_cspec if fuse_bias else ("…4",) - gelu_spec = out_spec if fuse_gelu else ("…5",) - alpha_spec = ("_6",) - beta_spec = ("_7",) - dbias_spec = bias_spec if grad else ("…8") + bias_aval = operand_types[4] + bias_spec = rhs_non_cspec if math.prod(bias_aval.shape) > 0 else ("…4",) + alpha_spec = ("_5",) + beta_spec = ("_6",) return SdyShardingRule( operand_mappings=( @@ -1235,56 +1236,30 @@ def _generate_operand_rules(name, ndim, cdims): rhs_specs, rhs_scale_specs, bias_spec, - gelu_spec, alpha_spec, beta_spec, ), - result_mappings=( - out_spec, - dbias_spec, - gelu_spec, - ), + result_mappings=(out_spec,), ) register_primitive(GemmPrimitive) -def gemm_uses_jax_dot() -> bool: - """Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot.""" - return not GemmPrimitive.enabled() - - +# TODO(Phuong): move this function down after GroupedGemmPrimitive after initial review. Keep it +# here for now to minimize line changes. def _te_gemm( lhs: Union[jax.Array, ScaledTensor], rhs: Union[jax.Array, ScaledTensor], bias: jax.Array = None, - gelu_input: jax.Array = None, lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), - fuse_bias: bool = False, - fuse_gelu: bool = False, - grad: bool = False, - use_split_accumulator: bool = None, + use_split_accumulator: bool = False, transpose_batch_sequence: bool = False, collective_op: CollectiveOp = CollectiveOp.NONE, ) -> Tuple[jax.Array, ...]: - if grad or fuse_gelu: - warnings.warn( - "GEMM + fused grad or fused gelu is not well tested and will be deprecated in the" - " future", - DeprecationWarning, - ) - - if use_split_accumulator is None: - # TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also - # use context of the GEMM type so we can decide between fprop, dgrad, and wgrad - use_split_accumulator = get_quantize_config_with_recipe( - get_global_quantize_recipe() - ).FP8_2X_ACC_FPROP - # Prepare non-quantized GEMM operands lhs_data = lhs rhs_data = rhs @@ -1301,10 +1276,11 @@ def _te_gemm( lhs_amax = rhs_amax = None # Extract GEMM custom op inputs from quantized operands if isinstance(lhs_q, ScaledTensor): - assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, ( - "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid " - "`Quantizer` object to quantize the RHS operand." - ) + if not isinstance(rhs_q, ScaledTensor) and rhs_quantizer is None: + raise ValueError( + "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid " + "`Quantizer` object to quantize the RHS operand." + ) if isinstance(lhs_q, ScaledTensor2x): # Choose the quantization of the contracting dimension(s) lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() @@ -1316,21 +1292,23 @@ def _te_gemm( lhs_amax = lhs_q.amax if isinstance(rhs_q, ScaledTensor): - assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( - "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid " - "`Quantizer` object to quantize the LHS operand." - ) + if not isinstance(lhs_q, ScaledTensor) and lhs_quantizer is None: + raise ValueError( + "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid " + "`Quantizer` object to quantize the LHS operand." + ) if isinstance(rhs_q, ScaledTensor2x): # Choose the quantization of the contracting dimension(s) rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() - assert ( + if not ( rhs_q.scaling_mode == lhs_q.scaling_mode or rhs_q.scaling_mode.is_nvfp4_scaling and lhs_q.scaling_mode.is_nvfp4_scaling - ), ( - "cuBLAS GEMM quantized operands have mismatched scaling types, " - f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." - ) + ): + raise ValueError( + "cuBLAS GEMM quantized operands have mismatched scaling types, " + f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." + ) rhs_data = rhs_q.data rhs_scale_inv = rhs_q.scale_inv if rhs_q.data_layout == "T": @@ -1340,39 +1318,40 @@ def _te_gemm( alpha = jnp.ones((1,), jnp.float32) beta = jnp.zeros((1,), jnp.float32) if scaling_mode.is_nvfp4_scaling: - assert lhs_amax is not None and rhs_amax is not None + if lhs_amax is None or rhs_amax is None: + raise ValueError("NVFP4 scaling requires non-None amax for both LHS and RHS operands") lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs_amax) rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv - # Dummy empties for bias and gelu + if not collective_op.is_none: + assert not scaling_mode.is_nvfp4_scaling, ( + f"Collective GEMM is not yet supported with {scaling_mode} quantization. Only" + " DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING, and MXFP8_1D_SCALING are supported." + ) + out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype - if bias is None or not (fuse_bias and not grad): + if bias is None: bias = jnp.empty(0, dtype=out_dtype) - if gelu_input is None or not (fuse_gelu and grad): - gelu_input = jnp.empty(0, dtype=out_dtype) - return GemmPrimitive.outer_primitive.bind( + (output,) = GemmPrimitive.outer_primitive.bind( lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, - gelu_input, alpha, beta, out_dtype=out_dtype, contracting_dims=(lhs_cdims, rhs_cdims), scaling_mode=scaling_mode, - fuse_bias=fuse_bias, - fuse_gelu=fuse_gelu, - grad=grad, use_split_accumulator=use_split_accumulator, transpose_batch_sequence=transpose_batch_sequence, sequence_dim=-1, # Dummy value and will be set in the primitive is_outer=True, collective_op=collective_op, ) + return output class GroupedGemmCopySizesPrimitive(BasePrimitive): @@ -1421,7 +1400,10 @@ def impl( group_sizes, num_gemms, ): - assert GroupedGemmCopySizesPrimitive.inner_primitive is not None + if GroupedGemmCopySizesPrimitive.inner_primitive is None: + raise RuntimeError( + "GroupedGemmCopySizesPrimitive.inner_primitive has not been registered" + ) out = GroupedGemmCopySizesPrimitive.inner_primitive.bind( group_sizes, num_gemms=num_gemms, @@ -1434,12 +1416,15 @@ def impl( class GroupedGemmPrimitive(BasePrimitive): """ - Primitive for grouped GEMM + Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes) or nvte_grouped_gemm (supporting BF16). """ + # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, group_offset, unused_placeholder name = "te_grouped_gemm_ffi" + # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, alpha, beta + name_graph_safe = "te_grouped_gemm_v2_ffi" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) inner_primitive = None outer_primitive = None @@ -1451,8 +1436,7 @@ def abstract( rhs_scale_inv_aval, bias_aval, group_sizes_aval, - group_offset_aval, - *, + *additional_args, # group_offset_aval, unused_placeholder OR alpha_aval, beta_aval M, N, K, @@ -1463,6 +1447,7 @@ def abstract( has_bias, is_grouped_dense_wgrad, use_async_d2h_group_sizes, + use_v2_ffi, ): """ Grouped GEMM operation. @@ -1474,7 +1459,11 @@ def abstract( rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array bias: Bias matrix of shape (G, N) group_sizes: 1D array containing the sizes of each group - group_offset: 1D array containing offsets for each group (not yet implemented) + additional_args: Either + * group_offsets: 1D array containing offsets for each group (not yet implemented) + OR + * alpha: 1D array of shape (G,) containing alpha values for each group + * beta: 1D array of shape (G,) containing beta values for each group M: Number of rows in the output matrix N: Number of columns in the output matrix K: Number of columns in the left-hand side matrix @@ -1489,10 +1478,66 @@ def abstract( Returns: A jnp.ndarray containing the result of the grouped GEMM operation """ - del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval + del lhs_data_aval, rhs_data_aval, bias_aval del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes + + num_groups = group_sizes_aval.size + + cublas_workspace_aval = jax.core.ShapedArray( + shape=( + GroupedGemmPrimitive._compute_cublas_workspace_size( + scaling_mode, lhs_scale_inv_aval, rhs_scale_inv_aval, use_v2_ffi + ), + ), + dtype=jnp.uint8, + ) + + out_shape = (M, N) + if is_grouped_dense_wgrad: + out_shape = (num_groups, M, N) + out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + + if use_v2_ffi: + setup_workspace_aval = jax.core.ShapedArray( + shape=(get_grouped_gemm_setup_workspace_size(num_groups),), dtype=jnp.uint8 + ) + # Temporary buffer for int32 -> int64 conversion of group_sizes on device. + int64_workspace_size = num_groups * jnp.dtype(jnp.int64).itemsize + int64_workspace_aval = jax.core.ShapedArray( + shape=(int64_workspace_size,), dtype=jnp.uint8 + ) + + if len(additional_args) != 2: + raise ValueError( + "Expected additional_args to contain alpha, beta for the graph-safe grouped" + f" GEMM primitive, but got {len(additional_args)} arguments." + ) + alpha_aval, beta_aval = additional_args + if alpha_aval.shape != (num_groups,): + raise ValueError(f"Expected alpha shape {(num_groups,)}, got {alpha_aval.shape}") + if alpha_aval.dtype != jnp.float32: + raise ValueError(f"Expected alpha dtype float32, got {alpha_aval.dtype}") + if beta_aval.shape != (num_groups,): + raise ValueError(f"Expected beta shape {(num_groups,)}, got {beta_aval.shape}") + if beta_aval.dtype != jnp.float32: + raise ValueError(f"Expected beta dtype float32, got {beta_aval.dtype}") + + return (out_aval, cublas_workspace_aval, setup_workspace_aval, int64_workspace_aval) + + return (out_aval, cublas_workspace_aval) + + @staticmethod + def _compute_cublas_workspace_size( + scaling_mode: ScalingMode, + lhs_scale_inv_aval, + rhs_scale_inv_aval, + use_v2_ffi: bool, + ): + """Compute the required cuBLAS workspace size based on the scaling mode and alignment requirements.""" + stream_count = 1 if use_v2_ffi else num_cublas_streams + # TODO(Phuong): move some shape checks from Cpp to here - workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams + workspace_size = get_cublas_workspace_size_bytes() * stream_count workspace_alignment_padding = 256 tensor_scaling_sinv_aligment = 16 mxfp8_scaling_sinv_alignment_padding = 256 @@ -1511,18 +1556,12 @@ def abstract( # We also pad scale_inv swizzle buffers size for 256 bytes alignment. workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding - workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - - out_shape = (M, N) - if is_grouped_dense_wgrad: - out_shape = (group_sizes_aval.size, M, N) - out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) - return (out_aval, workspace_aval) + return workspace_size @staticmethod def outer_abstract(*args, **kwargs): - (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) - return (out_aval,) + (out, *_) = GroupedGemmPrimitive.abstract(*args, **kwargs) + return (out,) @staticmethod def lowering( @@ -1538,9 +1577,24 @@ def lowering( has_bias, is_grouped_dense_wgrad, use_async_d2h_group_sizes, + use_v2_ffi, ): del out_dtype - return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( + if use_v2_ffi: + ffi_name = GroupedGemmPrimitive.name_graph_safe + return jax.ffi.ffi_lowering(ffi_name)( + ctx, + *args, + M=M, + N=N, + K=K, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, + ) + ffi_name = GroupedGemmPrimitive.name + return jax.ffi.ffi_lowering(ffi_name)( ctx, *args, M=M, @@ -1562,7 +1616,8 @@ def impl( rhs_scale_inv, bias, group_sizes, - group_offset, + additional_arg_0, # group_offset (non-graph-safe) OR alpha (graph-safe) + additional_arg_1, # unused placeholder (non-graph-safe) OR beta (graph-safe) M, N, K, @@ -1573,16 +1628,22 @@ def impl( has_bias, is_grouped_dense_wgrad, use_async_d2h_group_sizes, + use_v2_ffi, ): - assert GroupedGemmPrimitive.inner_primitive is not None - (out, _) = GroupedGemmPrimitive.inner_primitive.bind( + if GroupedGemmPrimitive.inner_primitive is None: + raise RuntimeError("GroupedGemmPrimitive.inner_primitive has not been registered") + if use_v2_ffi: + additional_args = (additional_arg_0, additional_arg_1) + else: + additional_args = (additional_arg_0,) + (out, *_) = GroupedGemmPrimitive.inner_primitive.bind( lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, - group_offset, + *additional_args, M=M, N=N, K=K, @@ -1593,6 +1654,7 @@ def impl( has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, + use_v2_ffi=use_v2_ffi, ) return (out,) @@ -1657,30 +1719,37 @@ def _jax_scaled_matmul( """ JAX GEMM for MXFP8 via scaled_matmul """ - assert rhs.scaling_mode in ( + if rhs.scaling_mode not in ( ScalingMode.MXFP8_1D_SCALING, ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING, - ), f"rhs does not have MXFP8 or NVFP4 scaling mode, got rhs.scaling_mode={rhs.scaling_mode}" + ): + raise ValueError( + "rhs does not have MXFP8 or NVFP4 scaling mode, got" + f" rhs.scaling_mode={rhs.scaling_mode}" + ) (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums expected_lhs_is_colwise = lhs_contract[-1] != lhs.data.ndim - 1 expected_rhs_is_colwise = rhs_contract[-1] != rhs.data.ndim - 1 - assert lhs.is_colwise is expected_lhs_is_colwise, ( - f"LHS with unexpected quantize dimension.\nExpect is_colwise={expected_lhs_is_colwise}, got" - f" {lhs.is_colwise}" - ) - assert rhs.is_colwise is expected_rhs_is_colwise, ( - f"RHS with unexpected quantize dimension.\nExpect is_colwise={expected_rhs_is_colwise}, got" - f" {rhs.is_colwise}" - ) + if lhs.is_colwise is not expected_lhs_is_colwise: + raise ValueError( + f"LHS with unexpected quantize dimension.\nExpect is_colwise={expected_lhs_is_colwise}," + f" got {lhs.is_colwise}" + ) + if rhs.is_colwise is not expected_rhs_is_colwise: + raise ValueError( + f"RHS with unexpected quantize dimension.\nExpect is_colwise={expected_rhs_is_colwise}," + f" got {rhs.is_colwise}" + ) if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: out_dtype = lhs.dq_dtype - assert ( - lhs.data_layout == "N" and rhs.data_layout == "N" - ), f"Got lhs.data_layout={lhs.data_layout}, rhs.data_layout={rhs.data_layout}" + if not (lhs.data_layout == "N" and rhs.data_layout == "N"): + raise ValueError( + f"Got lhs.data_layout={lhs.data_layout}, rhs.data_layout={rhs.data_layout}" + ) else: if lhs.data_layout == "T": lhs_contract = transpose_dims( @@ -1712,7 +1781,8 @@ def _jax_scaled_matmul( lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=out_dtype ) if lhs.scaling_mode.is_nvfp4_scaling: - assert lhs.amax is not None and rhs.amax is not None + if lhs.amax is None or rhs.amax is None: + raise ValueError("NVFP4 scaling requires non-None amax for both LHS and RHS operands") lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs.amax) rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs.amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv @@ -1736,6 +1806,7 @@ def _jax_gemm( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, + use_split_accumulator: bool = False, ) -> jnp.ndarray: """ FP8 GEMM via JAX @@ -1744,15 +1815,10 @@ def _jax_gemm( def _jax_gemm_impl(lhs, rhs): if lhs.scaling_mode.is_tensor_scaling(): - assert ( - rhs.scaling_mode == lhs.scaling_mode - ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" - - # TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also - # use context of the GEMM type so we can decide between fprop, dgrad, and wgrad - use_split_accumulator = get_quantize_config_with_recipe( - get_global_quantize_recipe() - ).FP8_2X_ACC_FPROP + if rhs.scaling_mode != lhs.scaling_mode: + raise ValueError( + f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" + ) precision = ( jax.lax.Precision.HIGHEST if use_split_accumulator else jax.lax.Precision.DEFAULT @@ -1783,6 +1849,7 @@ def _jax_gemm_impl(lhs, rhs): def gemm( lhs: Union[jnp.ndarray, AbstractBaseTensor], rhs: Union[jnp.ndarray, AbstractBaseTensor], + bias: jnp.ndarray = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, @@ -1798,30 +1865,15 @@ def gemm( Left-hand side operand in the matrix multiplication. rhs: Union[jax.Array, ScaledTensor] Right-hand side operand in the matrix multiplication. + bias: jax.Array, default = None + Optional additive bias term. When provided (non-empty), bias is added to the result of the Matrix Multiplication operation. + This bias addition is fused when using the TE's custom call to cuBLAS GEMM. + contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, )) + Tuple of sequences representing the contracting dimensions of the operands. lhs_quantizer: Quantizer, default = None Object for down-casting the LHS operand for quantized GEMM. rhs_quantizer: Quantizer, default = None Object for down-casting the RHS operand for quantized GEMM. - contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, )) - Tuple of sequences representing the contracting dimensions of the operands. - bias: jax.Array, default = None - Optional additive bias term, required for forward GEMM with bias fusion. Only supported - with TE's custom call to cuBLAS GEMM. - gelu_input: jax.Array, default = None - Pre-GeLU output from forward GEMM, required for backward/grad GEMM with dGeLU fusion. Only - supported with TE's custom call to cuBLAS GEMM. - fuse_bias: bool, default = False - Enable bias addition in forward GEMM or bias gradient in backward GEMM. Only supported with - TE's custom call to cuBLAS GEMM. - fuse_gelu: bool, default = False - Enable GeLU activation in forward GEMM or GeLU gradient in backward GEMM. Only supported - with TE's custom call to cuBLAS GEMM. - grad: bool, default = False - Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with - TE's custom call to cuBLAS GEMM. - use_split_accumulator: bool, default = True - Enable promoting some intermediate sums to higher precision when accumulating the result in - the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. transpose_batch_sequence: bool, default = False Transpose the batch and sequence dimensions of the input tensor. collective_op: CollectiveOp, default = CollectiveOp.NONE @@ -1830,18 +1882,7 @@ def gemm( Returns ------- jax.Array: - Result of the operation. For TE's custom call to cuBLAS GEMM, this result can include the - GeLU application when `fuse_gelu=True` and `grad=False`, the GeLU gradient contribution - when `fuse_gelu=True` and `grad=True`, and the additive bias when `fuse_bias=True` and - `grad=False`. - Optional[jax.Array]: - Bias gradient when `fuse_bias=True` and `grad=True`. Only supported with TE's custom call - to cuBLAS GEMM. - Optional[jax.Array]: - Pre-GeLU GEMM output when `fuse_gelu=True` and `grad=False`. This is required as an input - to `_te_gemm()` with `fuse_gelu=True` and `grad=True` in the backward pass in order to - compute the GeLU contribution to the gradient. Only supported with TE's custom call to - cuBLAS GEMM. + Result of the operation lhs * rhs + bias. """ if isinstance(lhs, NoScaleTensor): lhs = lhs.data @@ -1855,45 +1896,34 @@ def gemm( lhs_quantizer = quantizer_set.x rhs_quantizer = quantizer_set.kernel + # This option enable promoting some intermediate sums to higher precision when accumulating the result in + # the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. + use_split_accumulator = _get_high_precision_accumulation_from_env() + # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled - # TODO(Phuong): fuse_bias -> has_bias and has_bias = bias is not None - fuse_bias = kwargs.get("fuse_bias", False) - fuse_gelu = kwargs.get("fuse_gelu", False) if not GemmPrimitive.enabled(): - assert kwargs.get("bias", None) is None and not fuse_gelu, ( - "TE GEMM was invoked with bias fusion options that are not supported by the " - "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " - "GEMM primitive is disabled." - ) - assert kwargs.get("gelu_input", None) is None and not fuse_bias, ( - "TE GEMM was invoked with GeLU fusion options that are not supported by the " - "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " - "GEMM primitive is disabled." + if not collective_op.is_none: + raise RuntimeError("JAX GEMM does not support collective GEMM") + output = _jax_gemm( + lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer, use_split_accumulator ) - assert collective_op.is_none, "JAX GEMM does not support collective GEMM" - return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) + if bias is not None: + output += bias # Unfused + return output - outputs = _te_gemm( + output = _te_gemm( lhs, rhs, + bias, lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, contracting_dims=contracting_dims, + use_split_accumulator=use_split_accumulator, transpose_batch_sequence=transpose_batch_sequence, collective_op=collective_op, - **kwargs, ) - # Discard empty outputs - grad = kwargs.get("grad", False) - clean_outputs = outputs[0] # first output is the final result and is never empty - if (fuse_bias and grad) or (fuse_gelu and not grad): - clean_outputs = (outputs[0],) - if fuse_bias and grad: # only return bias gradient if it exists - clean_outputs += (outputs[1],) - if fuse_gelu and not grad: # only return pre-GeLU output if it exists - clean_outputs += (outputs[2],) - return clean_outputs + return output def grouped_gemm_copy_group_sizes( @@ -1914,6 +1944,23 @@ def grouped_gemm_copy_group_sizes( return out +def _can_use_v2_grouped_gemm( + scaling_mode: ScalingMode, + dtype: jnp.dtype, + has_bias: bool, +) -> bool: + """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters.""" + # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy + # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay + # feature-compatible with the main branch. + # Bias can be supported in a kernel or in pure-JAX in the future. + + if not _v2_grouped_gemm_available: + return False + + return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias + + def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], @@ -1948,14 +1995,15 @@ def grouped_gemm( lhs: [M, K] or [K, N] rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] """ - # TODO(Phuong): implement the group_offset - group_offset = group_offset or jnp.zeros((1,), jnp.int32) # TODO(Phuong): implement the precision del precision if isinstance(lhs, jnp.ndarray): - assert isinstance(rhs, jnp.ndarray) + if not isinstance(rhs, jnp.ndarray): + raise TypeError( + f"Expected rhs to be jnp.ndarray when lhs is jnp.ndarray, but got type={type(rhs)}" + ) out_dtype = lhs.dtype lhs_shape = lhs.shape rhs_shape = rhs.shape @@ -1964,7 +2012,11 @@ def grouped_gemm( lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) scaling_mode = ScalingMode.NO_SCALING elif isinstance(lhs, GroupedScaledTensor1x): - assert isinstance(rhs, GroupedScaledTensor1x) + if not isinstance(rhs, GroupedScaledTensor1x): + raise TypeError( + "Expected rhs to be GroupedScaledTensor1x when lhs is GroupedScaledTensor1x, but" + f" got type={type(rhs)}" + ) out_dtype = lhs.dq_dtype lhs_shape = lhs.original_shape rhs_shape = rhs.original_shape @@ -1972,7 +2024,11 @@ def grouped_gemm( rhs_data = rhs.data lhs_scale_inv = lhs.scale_inv rhs_scale_inv = rhs.scale_inv - assert lhs.scaling_mode == rhs.scaling_mode + if lhs.scaling_mode != rhs.scaling_mode: + raise ValueError( + f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode}," + f" rhs.scaling_mode={rhs.scaling_mode}" + ) scaling_mode = lhs.scaling_mode else: raise TypeError("Unsupported lhs type object!") @@ -2009,8 +2065,16 @@ def grouped_gemm( and not isinstance(rhs, ScaledTensor) and quantizer_set != noop_quantizer_set ): - assert isinstance(quantizer_set.x, GroupedQuantizer) - assert type(quantizer_set.x) is type(quantizer_set.kernel) + if not isinstance(quantizer_set.x, GroupedQuantizer): + raise TypeError( + "Expected quantizer_set.x to be GroupedQuantizer, but got" + f" type={type(quantizer_set.x)}" + ) + if type(quantizer_set.x) is not type(quantizer_set.kernel): + raise TypeError( + "Expected quantizer_set.x and quantizer_set.kernel to have the same type, but got" + f" {type(quantizer_set.x)} and {type(quantizer_set.kernel)}" + ) scaling_mode = quantizer_set.x.scaling_mode if ( quantizer_set.x.scaling_mode.is_tensor_scaling() @@ -2037,9 +2101,8 @@ def grouped_gemm( lhs_shape = lhs_q.original_shape rhs_shape = rhs_q.original_shape - assert not ( - lhs_data.dtype == jnp_float8_e5m2_type and rhs_data.dtype == jnp_float8_e5m2_type - ), "FP8 GEMM does not support E5M2 * E5M2" + if lhs_data.dtype == jnp_float8_e5m2_type and rhs_data.dtype == jnp_float8_e5m2_type: + raise ValueError("FP8 GEMM does not support E5M2 * E5M2") # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs # thus additional transpose is required @@ -2052,12 +2115,10 @@ def grouped_gemm( rhs_layout_is_T = rhs_q.data_layout == "T" # we can't apply _shape_normalization on the grouped input # thus we need to ensure that lhs is in N and rhs is in T - assert ( - lhs_is_trans == lhs_layout_is_T - ), "lhs input must be transposed before calling grouped_gemm" - assert ( - not rhs_is_trans == rhs_layout_is_T - ), "rhs input must be transposed before calling grouped_gemm" + if lhs_is_trans != lhs_layout_is_T: + raise RuntimeError("lhs input must be transposed before calling grouped_gemm") + if (not rhs_is_trans) != rhs_layout_is_T: + raise RuntimeError("rhs input must be transposed before calling grouped_gemm") lhs_is_trans = False rhs_is_trans = True lhs_ndim = len(lhs_shape) @@ -2076,21 +2137,46 @@ def grouped_gemm( # Calling GroupedGEMM Custom Call K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) - assert K_lhs == K_rhs + if K_lhs != K_rhs: + raise ValueError( + f"Mismatched contracting dimensions: K_lhs={K_lhs}, K_rhs={K_rhs} (from" + f" lhs_shape={lhs_shape}, rhs_shape={rhs_shape})" + ) M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G if is_grouped_dense_wgrad: N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) else: - assert group_sizes.size == rhs_shape[0] - - assert group_offset.size == 1 + if group_sizes.size != rhs_shape[0]: + raise ValueError( + "Expected group_sizes.size == rhs_shape[0], but got" + f" group_sizes.size={group_sizes.size}, rhs_shape[0]={rhs_shape[0]}" + ) has_bias = bias is not None - assert not has_bias or bias.shape == (group_sizes.size, N) + if has_bias and bias.shape != (group_sizes.size, N): + raise ValueError( + f"Expected bias.shape=({group_sizes.size}, {N}), but got bias.shape={bias.shape}" + ) bias = jnp.empty((), jnp.float32) if bias is None else bias + if group_offset is not None: + raise RuntimeError( + "group_offset is not supported yet and is instead computed" + " internally assuming contiguous grouping. Any padding is included in the group_sizes" + " and padded with zeros to not affect the result of the MoE block." + ) + + use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) + if use_v2_ffi: + num_gemms = group_sizes.shape[0] + additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha + additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta + else: + additional_arg_0 = jnp.zeros((1,), jnp.int32) # group_offset + additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder + (out,) = GroupedGemmPrimitive.outer_primitive.bind( lhs_data, lhs_scale_inv, @@ -2098,7 +2184,8 @@ def grouped_gemm( rhs_scale_inv, bias, group_sizes, - group_offset, + additional_arg_0, + additional_arg_1, M=M, N=N, K=K_lhs, @@ -2109,5 +2196,6 @@ def grouped_gemm( has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, + use_v2_ffi=use_v2_ffi, ) return out diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 5ec0c4b4e..4b32a002f 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -136,9 +136,17 @@ def abstract( ) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) - assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert scale_aval is None or scale_aval.dtype == jnp.float32 - assert amax_aval is None or amax_aval.dtype == jnp.float32 + assert x_dtype in [ + jnp.float32, + jnp.float16, + jnp.bfloat16, + ], f"Unsupported x_dtype={x_dtype}, expected one of [float32, float16, bfloat16]" + assert ( + scale_aval is None or scale_aval.dtype == jnp.float32 + ), f"Expected scale_aval.dtype=float32, but got scale_aval.dtype={scale_aval.dtype}" + assert ( + amax_aval is None or amax_aval.dtype == jnp.float32 + ), f"Expected amax_aval.dtype=float32, but got amax_aval.dtype={amax_aval.dtype}" assert ( scaling_mode != ScalingMode.MXFP8_1D_SCALING.value @@ -163,7 +171,10 @@ def abstract( mu_rsigama_dtype = jnp.float32 if norm_type == NVTE_Norm_Type.LayerNorm: - assert gamma_aval.size == beta_aval.size + assert gamma_aval.size == beta_aval.size, ( + "Expected gamma_aval.size == beta_aval.size, but got" + f" gamma_aval.size={gamma_aval.size}, beta_aval.size={beta_aval.size}" + ) assert gamma_aval.dtype == beta_aval.dtype, ( f"gamma and beta should have the same dtype, but got {gamma_aval.dtype} and " f"{beta_aval.dtype}" @@ -269,18 +280,35 @@ def lowering( del out_dtype, scale_dtype, is_outer, amax_scope, transpose_batch_sequence x_aval, scale_aval, amax_aval, gamma_aval, beta_aval = ctx.avals_in - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert scale_aval is None or scale_aval.dtype == jnp.float32 - assert amax_aval is None or amax_aval.dtype == jnp.float32 + assert x_aval.dtype in [ + jnp.float32, + jnp.float16, + jnp.bfloat16, + ], f"Unsupported x_aval.dtype={x_aval.dtype}, expected one of [float32, float16, bfloat16]" + assert ( + scale_aval is None or scale_aval.dtype == jnp.float32 + ), f"Expected scale_aval.dtype=float32, but got scale_aval.dtype={scale_aval.dtype}" + assert ( + amax_aval is None or amax_aval.dtype == jnp.float32 + ), f"Expected amax_aval.dtype=float32, but got amax_aval.dtype={amax_aval.dtype}" g_type = ir.RankedTensorType(gamma.type) g_shape = g_type.shape if norm_type == NVTE_Norm_Type.LayerNorm: - assert gamma_aval.dtype == beta_aval.dtype + assert gamma_aval.dtype == beta_aval.dtype, ( + "Expected gamma and beta to have the same dtype, but got" + f" gamma_aval.dtype={gamma_aval.dtype}, beta_aval.dtype={beta_aval.dtype}" + ) b_type = ir.RankedTensorType(beta.type) b_shape = b_type.shape - assert g_type == b_type - assert g_shape == b_shape + assert g_type == b_type, ( + f"Expected gamma and beta to have the same IR type, but got gamma_type={g_type}," + f" beta_type={b_type}" + ) + assert g_shape == b_shape, ( + f"Expected gamma and beta to have the same shape, but got gamma_shape={g_shape}," + f" beta_shape={b_shape}" + ) sm_margin = get_forward_sm_margin() return ffi.ffi_lowering( @@ -325,7 +353,9 @@ def impl( to describe implementation """ del is_outer - assert NormFwdPrimitive.inner_primitive is not None + assert ( + NormFwdPrimitive.inner_primitive is not None + ), "NormFwdPrimitive.inner_primitive has not been registered" ( out, colwise_out, @@ -399,7 +429,9 @@ def batcher( to describe batch rules for vmap """ check_valid_batch_dims(batch_dims) - assert NormFwdPrimitive.outer_primitive is not None + assert ( + NormFwdPrimitive.outer_primitive is not None + ), "NormFwdPrimitive.outer_primitive has not been registered" x, scale, amax, gamma, beta = batched_args x_bdim, scale_bdim, _, _, _ = batch_dims @@ -716,13 +748,26 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_ w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) - assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype - assert dz_aval.shape == x_aval.shape + assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype, ( + f"Expected dz_aval.dtype={w_dtype} (matching gamma dtype), but got" + f" dz_aval.dtype={dtypes.canonicalize_dtype(dz_aval.dtype)}" + ) + assert dz_aval.shape == x_aval.shape, ( + f"Expected dz_aval.shape == x_aval.shape, but got dz_aval.shape={dz_aval.shape}," + f" x_aval.shape={x_aval.shape}" + ) if norm_type == NVTE_Norm_Type.LayerNorm: mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype) - assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1] - assert mu_dtype == rsigma_dtype == jnp.float32 + assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1], ( + "Expected mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1], but got" + f" mu_aval.shape={mu_aval.shape}, rsigma_aval.shape={rsigma_aval.shape}," + f" x_aval.shape[:-1]={x_aval.shape[:-1]}" + ) + assert mu_dtype == rsigma_dtype == jnp.float32, ( + f"Expected mu_dtype == rsigma_dtype == float32, but got mu_dtype={mu_dtype}," + f" rsigma_dtype={rsigma_dtype}" + ) dx_aval = dz_aval dgamma_aval = dbeta_aval = gamma_aval @@ -766,8 +811,14 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma): g_shape = g_type.shape b_type = ir.RankedTensorType(gamma.type) b_shape = b_type.shape - assert g_type == b_type - assert g_shape == b_shape + assert g_type == b_type, ( + f"Expected gamma and beta to have the same IR type, but got gamma_type={g_type}," + f" beta_type={b_type}" + ) + assert g_shape == b_shape, ( + f"Expected gamma and beta to have the same shape, but got gamma_shape={g_shape}," + f" beta_shape={b_shape}" + ) sm_margin = get_backward_sm_margin() return ffi.ffi_lowering(NormBwdPrimitive.name)( @@ -784,7 +835,9 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma): @staticmethod def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma): - assert NormBwdPrimitive.inner_primitive is not None + assert ( + NormBwdPrimitive.inner_primitive is not None + ), "NormBwdPrimitive.inner_primitive has not been registered" dx, dgamma, dbeta, _ = NormBwdPrimitive.inner_primitive.bind( dz, x, mu, rsigma, gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma ) @@ -793,7 +846,9 @@ def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma): @staticmethod def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma): check_valid_batch_dims(batch_dims) - assert NormBwdPrimitive.outer_primitive is not None + assert ( + NormBwdPrimitive.outer_primitive is not None + ), "NormBwdPrimitive.outer_primitive has not been registered" dz, x, mu, rsigma, gamma = batched_args _, x_bdim, _, _, gamma_bdim = batch_dims diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 7720cba3b..2613d5b8f 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -99,7 +99,9 @@ def abstract( dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] out_shape = x_aval.shape - assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert ( + scale_aval is None or scale_aval.dtype == jnp.float32 + ), f"scale must be float32 but received {scale_aval}" if stochastic_rounding: assert ScalingMode( scaling_mode @@ -1215,7 +1217,7 @@ def grouped_quantize( assert n_groups == len( quantizer.quantizers ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}" - scale = jnp.empty((n_groups,), jnp.float32) + scale = jnp.ones((n_groups,), jnp.float32) if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: for i, quantizer_i in enumerate(quantizer.quantizers): diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py new file mode 100644 index 000000000..f2affacda --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -0,0 +1,704 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for fused MoE router""" +from enum import IntEnum + +import jax.numpy as jnp +from jax import dtypes, ffi +from jax.sharding import NamedSharding, PartitionSpec +from transformer_engine_jax import JAXX_Score_Function + +from .base import BasePrimitive, register_primitive +from .misc import get_padded_spec + +__all__ = [ + "ScoreFunction", + "fused_topk_with_score_function_fwd", + "fused_topk_with_score_function_bwd", + "fused_moe_aux_loss_fwd", + "fused_moe_aux_loss_bwd", +] + + +class ScoreFunction(IntEnum): + """Score function enum for fused MoE router kernels, synced with C++ JAXX_Score_Function.""" + + SIGMOID = int(JAXX_Score_Function.SIGMOID) + SOFTMAX = int(JAXX_Score_Function.SOFTMAX) + + +# =========================================== ================================== +# Fused Top-K with Score Function - Forward +# ============================================================================= + + +class FusedTopkWithScoreFunctionFwdPrimitive(BasePrimitive): + """ + Fused Top-K with Score Function Forward Primitive. + Computes score_function(logits) -> top-k -> probs, routing_map. + When compute_aux_scores=1, instead computes clean scores for aux loss. + """ + + name = "te_fused_topk_with_score_function_forward_ffi" + multiple_results = True + impl_static_args = ( + 2, + 3, + 4, + 5, + 6, + 7, + 8, + ) # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, compute_aux_scores + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + logits_aval, + expert_bias_aval, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ): + """Abstract evaluation: describe output shapes and dtypes.""" + del expert_bias_aval, topk, use_pre_softmax, num_groups, group_topk + del scaling_factor, score_function, compute_aux_scores + i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype) + i_shape = logits_aval.shape + probs_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) + routing_map_aval = logits_aval.update(shape=i_shape, dtype=jnp.bool_) + # The CUDA kernel always uses float32 (CompType) for intermediate + # computations (softmax/sigmoid values saved for backward). + intermediate_aval = logits_aval.update(shape=i_shape, dtype=jnp.float32) + return probs_aval, routing_map_aval, intermediate_aval + + @staticmethod + def lowering( + ctx, + logits, + expert_bias, + *, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ): + return ffi.ffi_lowering(FusedTopkWithScoreFunctionFwdPrimitive.name)( + ctx, + logits, + expert_bias, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ) + + @staticmethod + def impl( + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ): + if FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive is None: + raise RuntimeError( + "FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive has not been registered" + ) + return FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind( + logits, + expert_bias, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ): + if FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive is None: + raise RuntimeError( + "FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive has not been registered" + ) + logits, expert_bias = batched_args + logits_bdim, _ = batch_dims + return ( + FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive.bind( + logits, + expert_bias, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ), + (logits_bdim, logits_bdim, logits_bdim), + ) + + @staticmethod + def partition( + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, + ): + del result_infos + logits_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + routing_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + intermediate_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + out_shardings = [out_sharding, routing_sharding, intermediate_sharding] + arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding) + + def sharded_impl(logits, expert_bias): + return FusedTopkWithScoreFunctionFwdPrimitive.impl( + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return ( + "num_tokens num_experts, bias_dim -> num_tokens num_experts, num_tokens num_experts," + " num_tokens num_experts" + ) + + +register_primitive(FusedTopkWithScoreFunctionFwdPrimitive) + + +# ============================================================================= +# Fused Top-K with Score Function - Backward +# ============================================================================= + + +class FusedTopkWithScoreFunctionBwdPrimitive(BasePrimitive): + """ + Fused Top-K with Score Function Backward Primitive. + When compute_aux_scores=1, runs the score-for-aux-loss backward instead. + """ + + name = "te_fused_topk_with_score_function_backward_ffi" + multiple_results = False + impl_static_args = ( + 3, + 4, + 5, + 6, + 7, + ) # topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + routing_map_aval, + intermediate_aval, + grad_probs_aval, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + ): + del topk, use_pre_softmax, scaling_factor, score_function + del compute_aux_scores, routing_map_aval + return intermediate_aval.update( + shape=intermediate_aval.shape, + dtype=dtypes.canonicalize_dtype(grad_probs_aval.dtype), + ) + + @staticmethod + def lowering( + ctx, + routing_map, + intermediate, + grad_probs, + *, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + ): + return ffi.ffi_lowering(FusedTopkWithScoreFunctionBwdPrimitive.name)( + ctx, + routing_map, + intermediate, + grad_probs, + topk=topk, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ) + + @staticmethod + def impl( + routing_map, + intermediate, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + ): + if FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive is None: + raise RuntimeError( + "FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive has not been registered" + ) + return FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive.bind( + routing_map, + intermediate, + grad_probs, + topk=topk, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + ): + if FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive is None: + raise RuntimeError( + "FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive has not been registered" + ) + routing_map, intermediate, grad_probs = batched_args + _, _, grad_probs_bdim = batch_dims + return ( + FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive.bind( + routing_map, + intermediate, + grad_probs, + topk=topk, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ), + grad_probs_bdim, + ) + + @staticmethod + def partition( + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, + ): + del result_infos + grad_spec = get_padded_spec(arg_infos[2]) + out_sharding = NamedSharding(mesh, PartitionSpec(*grad_spec)) + arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding, arg_infos[2].sharding) + + def sharded_impl(routing_map, intermediate, grad_probs): + return FusedTopkWithScoreFunctionBwdPrimitive.impl( + routing_map, + intermediate, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return ( + "num_tokens num_experts, num_tokens num_experts, num_tokens num_experts -> num_tokens" + " num_experts" + ) + + +register_primitive(FusedTopkWithScoreFunctionBwdPrimitive) + + +# ============================================================================= +# Fused MoE Aux Loss - Forward +# ============================================================================= + + +class FusedMoEAuxLossFwdPrimitive(BasePrimitive): + """ + Fused MoE Aux Loss Forward Primitive. + """ + + name = "te_fused_moe_aux_loss_forward_ffi" + multiple_results = True + impl_static_args = (2, 3) # topk, coeff + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(probs_aval, tokens_per_expert_aval, topk, coeff): + del topk, coeff, tokens_per_expert_aval + i_dtype = dtypes.canonicalize_dtype(probs_aval.dtype) + aux_loss_aval = probs_aval.update(shape=(), dtype=i_dtype) + const_buf_aval = probs_aval.update(shape=(1,), dtype=jnp.float32) + return aux_loss_aval, const_buf_aval + + @staticmethod + def lowering(ctx, probs, tokens_per_expert, *, topk, coeff): + return ffi.ffi_lowering(FusedMoEAuxLossFwdPrimitive.name)( + ctx, + probs, + tokens_per_expert, + topk=topk, + coeff=coeff, + ) + + @staticmethod + def impl(probs, tokens_per_expert, topk, coeff): + if FusedMoEAuxLossFwdPrimitive.inner_primitive is None: + raise RuntimeError( + "FusedMoEAuxLossFwdPrimitive.inner_primitive has not been registered" + ) + return FusedMoEAuxLossFwdPrimitive.inner_primitive.bind( + probs, + tokens_per_expert, + topk=topk, + coeff=coeff, + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, topk, coeff): + if FusedMoEAuxLossFwdPrimitive.outer_primitive is None: + raise RuntimeError( + "FusedMoEAuxLossFwdPrimitive.outer_primitive has not been registered" + ) + probs, tokens_per_expert = batched_args + probs_bdim, _ = batch_dims + return ( + FusedMoEAuxLossFwdPrimitive.outer_primitive.bind( + probs, + tokens_per_expert, + topk=topk, + coeff=coeff, + ), + (probs_bdim, probs_bdim), + ) + + @staticmethod + def partition(topk, coeff, mesh, arg_infos, result_infos): + del result_infos + aux_loss_sharding = NamedSharding(mesh, PartitionSpec()) + const_buf_sharding = NamedSharding(mesh, PartitionSpec(None)) + out_shardings = [aux_loss_sharding, const_buf_sharding] + arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding) + + def sharded_impl(probs, tokens_per_expert): + return FusedMoEAuxLossFwdPrimitive.impl( + probs, + tokens_per_expert, + topk, + coeff, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return "num_tokens num_experts, num_experts -> , const_buf_one" + + +register_primitive(FusedMoEAuxLossFwdPrimitive) + + +# ============================================================================= +# Fused MoE Aux Loss - Backward +# ============================================================================= + + +class FusedMoEAuxLossBwdPrimitive(BasePrimitive): + """ + Fused MoE Aux Loss Backward Primitive. + """ + + name = "te_fused_moe_aux_loss_backward_ffi" + multiple_results = False + impl_static_args = (3,) # num_tokens + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(const_buf_aval, tokens_per_expert_aval, grad_aux_loss_aval, num_tokens): + del const_buf_aval + num_experts = tokens_per_expert_aval.shape[0] + out_dtype = dtypes.canonicalize_dtype(grad_aux_loss_aval.dtype) + return grad_aux_loss_aval.update( + shape=(num_tokens, num_experts), + dtype=out_dtype, + ) + + @staticmethod + def lowering(ctx, const_buf, tokens_per_expert, grad_aux_loss, *, num_tokens): + del num_tokens + return ffi.ffi_lowering(FusedMoEAuxLossBwdPrimitive.name)( + ctx, + const_buf, + tokens_per_expert, + grad_aux_loss, + ) + + @staticmethod + def impl(const_buf, tokens_per_expert, grad_aux_loss, num_tokens): + if FusedMoEAuxLossBwdPrimitive.inner_primitive is None: + raise RuntimeError( + "FusedMoEAuxLossBwdPrimitive.inner_primitive has not been registered" + ) + return FusedMoEAuxLossBwdPrimitive.inner_primitive.bind( + const_buf, + tokens_per_expert, + grad_aux_loss, + num_tokens=num_tokens, + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, num_tokens): + if FusedMoEAuxLossBwdPrimitive.outer_primitive is None: + raise RuntimeError( + "FusedMoEAuxLossBwdPrimitive.outer_primitive has not been registered" + ) + const_buf, tokens_per_expert, grad_aux_loss = batched_args + _, _, grad_bdim = batch_dims + return ( + FusedMoEAuxLossBwdPrimitive.outer_primitive.bind( + const_buf, + tokens_per_expert, + grad_aux_loss, + num_tokens=num_tokens, + ), + grad_bdim, + ) + + @staticmethod + def partition( + num_tokens, + mesh, + arg_infos, + result_infos, + ): + del result_infos + out_sharding = NamedSharding(mesh, PartitionSpec(None, None)) + arg_shardings = ( + arg_infos[0].sharding, + arg_infos[1].sharding, + arg_infos[2].sharding, + ) + + def sharded_impl(const_buf, tokens_per_expert, grad_aux_loss): + return FusedMoEAuxLossBwdPrimitive.impl( + const_buf, + tokens_per_expert, + grad_aux_loss, + num_tokens, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + # num_tokens only appears in the output (not in any input) because the + # backward reconstructs the full [num_tokens, num_experts] grad_probs from + # scalar inputs. Shardy will leave num_tokens unsharded, which matches the + # replicated PartitionSpec(None, None) in partition(). + return "const_buf_one, num_experts, grad_one -> i num_experts" + + +register_primitive(FusedMoEAuxLossBwdPrimitive) + + +# ============================================================================= +# Public API functions +# ============================================================================= + + +def fused_topk_with_score_function_fwd( + logits: jnp.ndarray, + topk: int, + use_pre_softmax: bool, + num_groups: int, + group_topk: int, + scaling_factor: float, + score_function, + expert_bias: jnp.ndarray, + compute_aux_scores: bool = False, +): + """ + Fused top-k with score function forward pass. + + When compute_aux_scores=True, runs the clean score-for-aux-loss kernel + instead of the full top-k kernel (expert_bias, use_pre_softmax, num_groups, + group_topk, and scaling_factor are ignored). + + Parameters + ---------- + logits : jnp.ndarray + [num_tokens, num_experts] logits from gating GEMM. + topk : int + Number of top experts to select. + use_pre_softmax : bool + If True, apply softmax before top-k. + num_groups : int + Number of groups for grouped top-k (1 to disable). + group_topk : int + Top-k at group level (1 to disable). + scaling_factor : float + Scaling factor for output probs. + score_function : ScoreFunction + ScoreFunction.SOFTMAX or ScoreFunction.SIGMOID. + expert_bias : jnp.ndarray + Expert bias (only used with sigmoid). Pass empty array if unused. + compute_aux_scores : bool + If True, compute clean scores for aux loss instead of full top-k. + + Returns + ------- + probs_or_scores, routing_map, saved_scores + """ + return FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive.bind( + logits, + expert_bias, + topk=int(topk), + use_pre_softmax=int(use_pre_softmax), + num_groups=int(num_groups), + group_topk=int(group_topk), + scaling_factor=float(scaling_factor), + score_function=int(score_function), + compute_aux_scores=int(compute_aux_scores), + ) + + +def fused_topk_with_score_function_bwd( + routing_map: jnp.ndarray, + saved_scores: jnp.ndarray, + grad_probs: jnp.ndarray, + topk: int, + use_pre_softmax: bool, + scaling_factor: float, + score_function, + compute_aux_scores: bool = False, +): + """ + Fused top-k with score function backward pass. + + When compute_aux_scores=True, routing_map is ignored and the + score-for-aux-loss backward kernel is used instead. + """ + return FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive.bind( + routing_map, + saved_scores, + grad_probs, + topk=int(topk), + use_pre_softmax=int(use_pre_softmax), + scaling_factor=float(scaling_factor), + score_function=int(score_function), + compute_aux_scores=int(compute_aux_scores), + ) + + +def fused_moe_aux_loss_fwd( + probs: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + topk: int, + coeff: float, +): + """ + Fused MoE aux loss forward pass. + + Returns + ------- + aux_loss, const_buf + """ + return FusedMoEAuxLossFwdPrimitive.outer_primitive.bind( + probs, + tokens_per_expert, + topk=int(topk), + coeff=float(coeff), + ) + + +def fused_moe_aux_loss_bwd( + const_buf: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + grad_aux_loss: jnp.ndarray, + num_tokens: int, +): + """ + Fused MoE aux loss backward pass. + """ + return FusedMoEAuxLossBwdPrimitive.outer_primitive.bind( + const_buf, + tokens_per_expert, + grad_aux_loss, + num_tokens=int(num_tokens), + ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 8c2798c68..47ca0fd15 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -51,6 +51,16 @@ struct ActivationConfig { ClampedSwigluConfig clamped_swiglu; }; +struct GemmConfig { + JAXX_Scaling_Mode scaling_mode; + JAXX_Collective_Op collective_op; + int64_t lhs_axis_boundary; + int64_t rhs_axis_boundary; + bool lhs_transposed; + bool rhs_transposed; + bool use_split_accumulator; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -119,7 +129,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right); + int64_t window_size_right, bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, @@ -127,7 +137,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, - int64_t window_size_right); + int64_t window_size_right, bool bottom_right_diagonal); pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, @@ -135,21 +145,31 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, - int64_t window_size_left, int64_t window_size_right); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal); // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmV2Handler); +#ifndef USE_ROCM XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmInitV2Handler); +#endif // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmV2Handler); #ifndef USE_ROCM // Amax XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler); +#endif +// Inspect +XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler); + +#ifndef USE_ROCM // Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); @@ -157,6 +177,12 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); #endif +// Router +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); + } // namespace jax } // namespace transformer_engine @@ -168,8 +194,19 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( transformer_engine::jax::ActivationConfig, ::xla::ffi::StructMember("clamped_swiglu")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::GemmConfig, + ::xla::ffi::StructMember("scaling_mode"), + ::xla::ffi::StructMember("collective_op"), + ::xla::ffi::StructMember("lhs_axis_boundary"), + ::xla::ffi::StructMember("rhs_axis_boundary"), + ::xla::ffi::StructMember("lhs_transposed"), + ::xla::ffi::StructMember("rhs_transposed"), + ::xla::ffi::StructMember("use_split_accumulator")); + // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Score_Function); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Quantize_Layout); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 6c5a97634..ce5828d6f 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -109,6 +109,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal case NVTE_Activation_Type::GEGLU: nvte_geglu(input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::GLU: + nvte_glu(input_tensor.data(), output_tensor.data(), stream); + break; case NVTE_Activation_Type::SILU: nvte_silu(input_tensor.data(), output_tensor.data(), stream); break; @@ -427,6 +430,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, case NVTE_Activation_Type::GEGLU: nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::GLU: + nvte_dglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; case NVTE_Activation_Type::SWIGLU: nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index ad6f13949..39e37c1c7 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -8,8 +8,6 @@ #ifndef USE_ROCM #include -#include - #include "../extensions.h" #include "transformer_engine/cast.h" #include "transformer_engine/hadamard_transform.h" diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 41347a85e..02efd1b38 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -18,12 +18,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right) { + int64_t window_size_right, bool deterministic) { auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, deterministic); return backend; } @@ -159,7 +159,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, - int64_t window_size_right) { + int64_t window_size_right, bool bottom_right_diagonal) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; @@ -207,7 +207,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); + window_size_left, window_size_right, bottom_right_diagonal, query_workspace_tensor.data(), + nullptr); } nvte_tensor_pack_destroy(&aux_output_tensors); @@ -255,7 +256,7 @@ static void FusedAttnForwardImpl( size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, - int64_t window_size_left, int64_t window_size_right) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) { FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ @@ -284,7 +285,7 @@ static void FusedAttnForwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, deterministic); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -346,7 +347,7 @@ static void FusedAttnForwardImpl( k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + window_size_left, window_size_right, bottom_right_diagonal, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_output_tensors); } @@ -364,6 +365,7 @@ static void FusedAttnForwardImpl( size_t max_segments_per_seq = get_attr_value(attrs, "max_segments_per_seq"); \ auto window_size_left = get_attr_value(attrs, "window_size_left"); \ auto window_size_right = get_attr_value(attrs, "window_size_right"); \ + bool bottom_right_diagonal = get_attr_value(attrs, "bottom_right_diagonal"); \ float scaling_factor = get_attr_value(attrs, "scaling_factor"); \ float dropout_probability = get_attr_value(attrs, "dropout_probability"); \ NVTE_Bias_Type bias_type = \ @@ -402,7 +404,7 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype, - is_training, deterministic, window_size_left, window_size_right); + is_training, deterministic, window_size_left, window_size_right, bottom_right_diagonal); return ffi_with_cuda_error_check(); } @@ -433,7 +435,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, - int64_t window_size_left, int64_t window_size_right) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); @@ -485,17 +487,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); - nvte_fused_attn_bwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, deterministic, false, + query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_input_tensors); @@ -514,7 +517,7 @@ static void FusedAttnBackwardImpl( size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, - int64_t window_size_left, int64_t window_size_right) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) { FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ @@ -540,7 +543,7 @@ static void FusedAttnBackwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, deterministic); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); @@ -611,16 +614,17 @@ static void FusedAttnBackwardImpl( } } - nvte_fused_attn_bwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(), - dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), dsoftmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, + kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, false, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_input_tensors); } @@ -649,7 +653,7 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left, - window_size_right); + window_size_right, bottom_right_diagonal); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp index 9f50826a7..fa1362f61 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -141,8 +141,8 @@ void CommunicatorHandler::init(int num_total_devices, int num_devices_per_proces // Bootstrap UB via creating a dummy CommOverlapP2PBase object std::vector buffer_shape{1, 1}; - auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32, - JAXX_Collective_Op::ALL_GATHER); + [[maybe_unused]] auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor( + buffer_shape, DType::kFloat32, JAXX_Collective_Op::ALL_GATHER); } void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 40121049a..2d9a13278 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -72,6 +72,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( std::vector scale_shape = {1}; auto is_nvfp4 = is_nvfp4_scaling(scaling_mode); auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING || is_nvfp4) { // Block scaling also needs to be collapsed to match 2D data scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), @@ -117,46 +118,75 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( } #ifndef USE_ROCM -Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, - Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta, - Result_Type output, Result_Type bias_grad, - Result_Type pre_gelu_out, Result_Type workspace, - JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, - int64_t rhs_axis_boundary, bool lhs_transposed, - bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, - bool use_split_accumulator, JAXX_Collective_Op collective_op) { +Error_Type GemmInitV2FFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type alpha, + Buffer_Type beta, Result_Type output, Result_Type workspace, + GemmConfig config) { nvte_cublas_handle_init(); // Init UB buffer - if (collective_op != JAXX_Collective_Op::NONE) { + if (config.collective_op != JAXX_Collective_Op::NONE) { auto &comm_handler = CommunicatorHandler::get(); std::vector lhs_shape = { - product(lhs.dimensions(), 0, lhs_axis_boundary), - product(lhs.dimensions(), lhs_axis_boundary, lhs.dimensions().size())}; + product(lhs.dimensions(), 0, config.lhs_axis_boundary), + product(lhs.dimensions(), config.lhs_axis_boundary, lhs.dimensions().size())}; std::vector rhs_shape = { - product(rhs.dimensions(), 0, rhs_axis_boundary), - product(rhs.dimensions(), rhs_axis_boundary, rhs.dimensions().size())}; + product(rhs.dimensions(), 0, config.rhs_axis_boundary), + product(rhs.dimensions(), config.rhs_axis_boundary, rhs.dimensions().size())}; - std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], - (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; + std::vector out_shape = {(config.lhs_transposed) ? lhs_shape[1] : lhs_shape[0], + (config.rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; std::vector buffer_shape{0, 0}; DType buffer_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); - if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + if (config.collective_op == JAXX_Collective_Op::ALL_GATHER) { buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size; buffer_shape[1] = lhs_shape[1]; buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); - } else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + } else if (config.collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { buffer_shape[0] = out_shape[0]; buffer_shape[1] = out_shape[1]; } - auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, buffer_dtype, - collective_op); + [[maybe_unused]] auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor( + buffer_shape, buffer_dtype, config.collective_op); } return ffi_with_cuda_error_check(); } +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmInitV2Handler, GemmInitV2FFI, + FFI::Bind() + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // alpha + .Arg() // beta + .Ret() // output + .Ret() // workspace + .Attr("config"), + FFI_CudaGraph_Traits); + +Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type workspace, + JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + int64_t rhs_axis_boundary, bool lhs_transposed, + bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, + bool use_split_accumulator, JAXX_Collective_Op collective_op) { + static std::once_flag gemm_init_warned; + std::call_once(gemm_init_warned, []() { + std::cerr << "[CollectiveGemmInitFFI] Deprecation: This API is deprecated and will be removed " + "in September 2026. Use GemmInitV2FFI instead." + << std::endl; + }); + return GemmInitV2FFI(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, alpha, beta, output, workspace, + GemmConfig{scaling_mode, collective_op, lhs_axis_boundary, rhs_axis_boundary, + lhs_transposed, rhs_transposed, use_split_accumulator}); +} + XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, FFI::Bind() .Arg() // lhs @@ -180,22 +210,21 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, .Attr("fuse_gelu") .Attr("grad") .Attr("use_split_accumulator") - .Attr("collective_op")); + .Attr("collective_op"), + FFI_CudaGraph_Traits); #endif //#ifndef USE_ROCM -Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, - Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type bias_grad, - Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, - int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, - bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, - bool use_split_accumulator, JAXX_Collective_Op collective_op) { +Error_Type GemmV2FFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, + Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type workspace, + GemmConfig config) { // cuBLAS workspace + 256 alignment enforcement (+ swizzle scales) uint8_t *lhs_swizzle_scale_ptr = nullptr, *rhs_swizzle_scale_ptr = nullptr; auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); size_t workspace_size = static_cast(workspace->element_count()) - 256; - if (is_nvfp4_scaling(scaling_mode)) { + + if (is_nvfp4_scaling(config.scaling_mode)) { auto lhs_scale_size = product(lhs_scale_inv.dimensions()); auto rhs_scale_size = product(rhs_scale_inv.dimensions()); workspace_size = workspace_size - lhs_scale_size - rhs_scale_size; @@ -207,60 +236,42 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) - bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || - (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); - bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; - bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; - - auto [lhs_, lhs_shape] = - xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, lhs_swizzle_scale_ptr, - scaling_mode, lhs_axis_boundary, make_lhs_rowwise); - auto [rhs_, rhs_shape] = - xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, rhs_swizzle_scale_ptr, - scaling_mode, rhs_axis_boundary, make_rhs_rowwise); - - std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], - (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; + bool always_rowwise = + (config.scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + (is_tensor_scaling(config.scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); + bool make_lhs_rowwise = (always_rowwise) ? true : !config.lhs_transposed; + bool make_rhs_rowwise = (always_rowwise) ? true : config.rhs_transposed; + + auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, lhs, lhs_scale_inv, lhs_swizzle_scale_ptr, config.scaling_mode, + config.lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, rhs, rhs_scale_inv, rhs_swizzle_scale_ptr, config.scaling_mode, + config.rhs_axis_boundary, make_rhs_rowwise); + + std::vector out_shape = {(config.lhs_transposed) ? lhs_shape[1] : lhs_shape[0], + (config.rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); // Bias input to forward pass or bias gradient output from backward pass void *bias_ptr = nullptr; size_t bias_size = 0; DType bias_dtype = out_dtype; + auto fuse_bias = bias.element_count() > 0; if (fuse_bias) { - if (grad) { - NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(), - "Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad"); - } bias_ptr = bias.untyped_data(); bias_size = product(bias.dimensions()); bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); } auto bias_ = TensorWrapper(bias_ptr, std::vector{bias_size}, bias_dtype); - // Pre-GeLU output from forward pass or input to backward pass - void *pre_gelu_ptr = nullptr; - std::vector pre_gelu_shape = {0}; - DType pre_gelu_dtype = out_dtype; - if (gelu_input.element_count() > 0) { - if (grad) { - NVTE_CHECK(pre_gelu_out->untyped_data() == gelu_input.untyped_data(), - "Missing operand-output aliasing in GemmPrimitive: gelu_input <-> pre_gelu_out"); - } - pre_gelu_ptr = pre_gelu_out->untyped_data(); - pre_gelu_shape = {product(pre_gelu_out->dimensions(), 0, pre_gelu_out->dimensions().size() - 1), - static_cast(pre_gelu_out->dimensions().back())}; - pre_gelu_dtype = convert_ffi_datatype_to_te_dtype(pre_gelu_out->element_type()); - } - auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype); - auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); float one = 1.; float zero = 0.; // alpha, beta float *alpha_ptr = &one, *beta_ptr = &zero; - if (is_nvfp4_scaling(scaling_mode)) { + if (is_nvfp4_scaling(config.scaling_mode)) { NVTE_CHECK(alpha.element_count() == 1 && convert_ffi_datatype_to_te_dtype(alpha.element_type()) == DType::kFloat32); alpha_ptr = reinterpret_cast(alpha.untyped_data()); @@ -270,16 +281,12 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i } // Construct GEMM config - transformer_engine::MatmulConfigWrapper config; - config.set_use_split_accumulator(use_split_accumulator); - config.set_sm_count(num_math_sm); - if (fuse_bias) config.set_bias_tensor(bias_.data()); - if (fuse_gelu) { - config.set_with_gelu_epilogue(true); - config.set_epilogue_aux_tensor(pre_gelu_.data()); - } + transformer_engine::MatmulConfigWrapper matmul_config; + matmul_config.set_use_split_accumulator(config.use_split_accumulator); + matmul_config.set_sm_count(num_math_sm); + if (fuse_bias) matmul_config.set_bias_tensor(bias_.data()); - if (collective_op == JAXX_Collective_Op::NONE) { + if (config.collective_op == JAXX_Collective_Op::NONE) { auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); NVTE_CHECK(out_.numel() == output->element_count(), "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", @@ -289,9 +296,10 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i ", out_shape[1]=", out_shape[1]); // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order - nvte_cublas_gemm_v2(rhs_transposed /*transa*/, lhs_transposed /*transb*/, alpha_ptr, - rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, out_.data() /*C*/, - out_.data() /*D*/, workspace_.data(), config, stream); + nvte_cublas_gemm_v2(config.rhs_transposed /*transa*/, config.lhs_transposed /*transb*/, + alpha_ptr, rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, + out_.data() /*C*/, out_.data() /*D*/, workspace_.data(), matmul_config, + stream); } else { #ifdef USE_ROCM NVTE_ERROR("ROCm TE JAX does not support comm-comp overlap yet."); @@ -299,12 +307,12 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i std::vector buffer_shape{0, 0}; DType buffer_dtype = out_dtype; auto &comm_handler = CommunicatorHandler::get(); - if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + if (config.collective_op == JAXX_Collective_Op::ALL_GATHER) { buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size; buffer_shape[1] = lhs_shape[1]; out_shape[0] = out_shape[0] * comm_handler.tp_size; buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); - } else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + } else if (config.collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { buffer_shape[0] = out_shape[0]; buffer_shape[1] = out_shape[1]; out_shape[0] = out_shape[0] / comm_handler.tp_size; @@ -312,8 +320,9 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size, ", out_shape[1]=", out_shape[1]); auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor( - buffer_shape, buffer_dtype, collective_op); - if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + buffer_shape, buffer_dtype, config.collective_op); + auto pre_gelu_ = TensorWrapper(nullptr, std::vector{0}, DType::kByte); + if (config.collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype); // Prepare the auxiliary buffer for the reduce-scattered GEMM output auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); @@ -323,11 +332,11 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i " elements ", to_string_like(output->dimensions())); // Launch GEMM+RS - executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, ubuf_out_, bias_, - pre_gelu_, workspace_, grad, false, use_split_accumulator, out_, - stream); + executor->split_overlap_rs(rhs_, config.rhs_transposed, lhs_, config.lhs_transposed, + ubuf_out_, bias_, pre_gelu_, workspace_, false /*grad*/, + false /*accumulate*/, config.use_split_accumulator, out_, stream); - } else if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + } else if (config.collective_op == JAXX_Collective_Op::ALL_GATHER) { auto aux_out_ = TensorWrapper(nullptr, std::vector{0}, out_dtype); // Empty auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); @@ -338,8 +347,9 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Copy the distributed LHS operand into the local chunk of the communication buffer executor->copy_into_buffer(stream, lhs_, true, make_lhs_rowwise); // Launch AG+GEMM - executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, - workspace_, grad, false, use_split_accumulator, aux_out_, stream); + executor->split_overlap_ag(rhs_, config.rhs_transposed, lhs_, config.lhs_transposed, out_, + bias_, pre_gelu_, workspace_, false /*grad*/, false /*accumulate*/, + config.use_split_accumulator, aux_out_, stream); } #endif //#ifdef USE_ROCM } @@ -347,6 +357,56 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i return ffi_with_cuda_error_check(); } +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmV2Handler, GemmV2FFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // alpha + .Arg() // beta + .Ret() // output + .Ret() // workspace + .Attr("config"), + GemmFFI_CudaGraph_Traits); + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, + int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, + bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, + bool use_split_accumulator, JAXX_Collective_Op collective_op) { + static std::once_flag once_fuse_bias; + static std::once_flag once_fuse_gelu_grad; + static std::once_flag once_api; + if (fuse_bias) { + std::call_once(once_fuse_bias, [] { + std::cerr << "[GemmFFI] Deprecation: fuse_bias is deprecated; bias fusion is inferred from " + "non-empty bias. This parameter will be removed in future release." + << std::endl; + }); + } + if (fuse_gelu || grad) { + std::call_once(once_fuse_gelu_grad, [] { + std::cerr << "[GemmFFI] Deprecation: fuse_gelu and grad are deprecated. These options are " + "ignored as there is no support for them in the current implementation. " + << std::endl; + }); + } + std::call_once(once_api, [] { + std::cerr << "[GemmFFI] Deprecation: This API is deprecated in Sep 2026. Use GemmV2FFI instead." + << std::endl; + }); + + return GemmV2FFI(stream, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, alpha, beta, output, + workspace, + GemmConfig{scaling_mode, collective_op, lhs_axis_boundary, rhs_axis_boundary, + lhs_transposed, rhs_transposed, use_split_accumulator}); +} + XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, FFI::Bind() .Ctx() // stream @@ -433,6 +493,387 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGro .Ret() // dummy_output .Attr("num_gemms")); +class JAXX_GroupedTensorWrapper { + public: + JAXX_GroupedTensorWrapper() = delete; + JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, size_t num_tensors, + NVTEShape const &dataShape); + JAXX_GroupedTensorWrapper(JAXX_GroupedTensorWrapper const &) = delete; + JAXX_GroupedTensorWrapper &operator=(JAXX_GroupedTensorWrapper const &) = delete; + JAXX_GroupedTensorWrapper(JAXX_GroupedTensorWrapper &&other) noexcept + : m_data_shape(other.m_data_shape), + m_grouped_tensor(other.m_grouped_tensor), + m_data_tensor(other.m_data_tensor), + m_scale_inv_tensor(other.m_scale_inv_tensor), + m_sizes_tensor(other.m_sizes_tensor), + m_offsets_tensor(other.m_offsets_tensor) { + other.m_grouped_tensor = nullptr; + } + JAXX_GroupedTensorWrapper &operator=(JAXX_GroupedTensorWrapper &&) = delete; + ~JAXX_GroupedTensorWrapper(); + + void set_rowwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, + NVTEGroupedTensorParam group_sizes_param_name); + // Set only group sizes (no offsets); the setup kernel will compute offsets from sizes. + void set_group_sizes_only(const int64_t *sizes_ptr, size_t num_tensors, + NVTEGroupedTensorParam group_sizes_param_name); + + operator NVTEGroupedTensor() const { return m_grouped_tensor; } + NVTEGroupedTensor const &get_grouped_tensor() const; + + private: + NVTEShape m_data_shape{}; + NVTEGroupedTensor m_grouped_tensor{}; + + // Internal tensors. These need to be kept alive as long as the grouped tensor is alive. + NVTEBasicTensor m_data_tensor{}; + NVTEBasicTensor m_scale_inv_tensor{}; + + NVTEBasicTensor m_sizes_tensor{}; + NVTEBasicTensor m_offsets_tensor{}; +}; + +JAXX_GroupedTensorWrapper::JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, + size_t num_tensors, + NVTEShape const &dataShape) { + m_data_shape = dataShape; + m_grouped_tensor = + nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, dataShape); +} + +JAXX_GroupedTensorWrapper::~JAXX_GroupedTensorWrapper() { + if (m_grouped_tensor != nullptr) { + nvte_destroy_grouped_tensor(m_grouped_tensor); + } +} + +void JAXX_GroupedTensorWrapper::set_rowwise(Buffer_Type const &data, + std::optional const &scale_inv) { + NVTEDType data_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())); + m_data_tensor = + NVTEBasicTensor{reinterpret_cast(data.untyped_data()), data_dtype, m_data_shape}; + + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedRowwiseData, &m_data_tensor, + sizeof(m_data_tensor)); + + if (scale_inv.has_value()) { + NVTEDType scale_inv_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())); + NVTEShape logical_scale_shape{}; + if (scale_inv->dimensions().size() == 1) { + logical_scale_shape.ndim = 1; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + } else if (scale_inv->dimensions().size() == 2) { + logical_scale_shape.ndim = 2; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + logical_scale_shape.data[1] = scale_inv->dimensions()[1]; + } else { + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", + scale_inv->dimensions().size()); + } + m_scale_inv_tensor = NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), + scale_inv_dtype, logical_scale_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedRowwiseScaleInv, + &m_scale_inv_tensor, sizeof(m_scale_inv_tensor)); + } +} + +void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, + Buffer_Type const &group_offsets, + NVTEGroupedTensorParam group_sizes_param_name) { + NVTEDType sizes_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(group_sizes.element_type())); + NVTEDType offsets_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(group_offsets.element_type())); + + NVTE_CHECK(sizes_dtype == NVTEDType::kNVTEInt64, "group_sizes must be of type int64."); + NVTE_CHECK(offsets_dtype == NVTEDType::kNVTEInt64, "group_offsets must be of type int64."); + + size_t num_tensors = group_sizes.dimensions()[0]; + NVTE_CHECK(group_sizes.dimensions().size() == 1, + "group_sizes must be a 1D tensor with length equal to the number of tensors."); + NVTE_CHECK(group_offsets.dimensions().size() == 1, + "group_offsets must be a 1D tensor with length equal to the number of tensors."); + NVTE_CHECK(group_offsets.dimensions()[0] == num_tensors, + "group_sizes and group_offsets must have the same number of elements."); + + NVTEShape shape{}; + shape.ndim = 1; + shape.data[0] = num_tensors; + + m_sizes_tensor = NVTEBasicTensor{reinterpret_cast(group_sizes.untyped_data()), + NVTEDType::kNVTEInt64, shape}; + m_offsets_tensor = NVTEBasicTensor{reinterpret_cast(group_offsets.untyped_data()), + NVTEDType::kNVTEInt64, shape}; + + nvte_set_grouped_tensor_param(m_grouped_tensor, group_sizes_param_name, &m_sizes_tensor, + sizeof(m_sizes_tensor)); + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedTensorOffsets, &m_offsets_tensor, + sizeof(m_offsets_tensor)); +} + +void JAXX_GroupedTensorWrapper::set_group_sizes_only( + const int64_t *sizes_ptr, size_t num_tensors, NVTEGroupedTensorParam group_sizes_param_name) { + NVTEShape shape{}; + shape.ndim = 1; + shape.data[0] = num_tensors; + m_sizes_tensor = NVTEBasicTensor{reinterpret_cast(const_cast(sizes_ptr)), + NVTEDType::kNVTEInt64, shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, group_sizes_param_name, &m_sizes_tensor, + sizeof(m_sizes_tensor)); + // Intentionally no offset tensor: offsets will be computed by the setup kernel. +} + +NVTEGroupedTensor const &JAXX_GroupedTensorWrapper::get_grouped_tensor() const { + return m_grouped_tensor; +} + +JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, + std::optional scale_inv, + JAXX_Scaling_Mode scaling_mode, size_t num_tensors, + NVTEShape const &dataShape) { + JAXX_GroupedTensorWrapper grouped_tensor_wrapper(scaling_mode, num_tensors, dataShape); + if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING) { + scale_inv = std::nullopt; + } + grouped_tensor_wrapper.set_rowwise(data, scale_inv); + + return std::move(grouped_tensor_wrapper); +} + +// This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. +Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, + Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, + Buffer_Type group_sizes, Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type cublas_workspace, + Result_Type setup_workspace, Result_Type int64_workspace, size_t m, + size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, + JAXX_Scaling_Mode scaling_mode, bool is_grouped_dense_wgrad) { + // Notes on matrix layouts and transpose: + // Jax uses row-major data_layout, on entering this function, each input matrix pair: + // A: row-major [m, k] for N - [k, m] for T + // B: row-major [k, n] for N - [n, k] for T + // on exiting this function, JAX expect: + // C: row-major with size [m, n]. + // cuBLAS uses column-major data_layout, in this view, each input matrix pair: + // A: column-major with size [k, m] for T - [m, k] for N + // B: column-major with size [n, k] for T - [k, n] for N + // + // If we call cuBLAS GEMM for A * B, the output will be: + // C: column-major with size [m, n] --> row-major with size [n, m]. + // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. + + // Inputs + auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); + auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); + auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv.untyped_data()); + auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv.untyped_data()); + auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type()); + auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type()); + auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type()); + auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type()); + bool has_bias = product(bias.dimensions()) > 0; + auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + + NVTE_CHECK(group_sizes.dimensions().size() == 1); + size_t num_gemms = group_sizes.dimensions()[0]; + + // Convert int32 group_sizes to int64 into the dedicated output buffer. + NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(group_sizes.untyped_data()), + int64_sizes_ptr, num_gemms, stream); + + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, + "Only non-quantized grouped GEMM is supported in current implementation."); + + // It is weird that TE/Common GEMM only use colwise for MXFP8 + const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); + const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; + const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; + const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; + + // Outputs + auto out_ptr = reinterpret_cast(output->untyped_data()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); + // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned + auto cublas_workspace_ptr = reinterpret_cast(cublas_workspace->untyped_data()); + cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr); + auto workspace_total_size = product(cublas_workspace->dimensions()); + + auto lhs_sinv_size = product(lhs_sinv.dimensions()); + auto rhs_sinv_size = product(rhs_sinv.dimensions()); + const size_t workspace_alignment_padding = 256; + const size_t tensor_scaling_sinv_aligment = 16; + const size_t mxfp8_scaling_sinv_alignment_padding = 256; + auto workspace_size = workspace_total_size - workspace_alignment_padding; + if (is_mxfp8_scaling) { + // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. + workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); + } else if (is_tensor_scaling) { + // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned + // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. + workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); + } + auto swizzled_lhs_sinv_ptr = cublas_workspace_ptr + workspace_size; + swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); + auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; + swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); + auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned + auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; + + size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); + size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); + size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); + size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); + size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); + size_t out_dtype_bytes = te_dtype_bytes(out_dtype); + + NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); + NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, + "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); + + size_t expected_lhs_size = m * k; + size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t actual_lhs_size = product(lhs_data.dimensions()); + size_t actual_rhs_size = product(rhs_data.dimensions()); + size_t actual_out_size = product(output->dimensions()); + NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", + expected_lhs_size, ", got ", actual_lhs_size); + if (!is_grouped_dense_wgrad) { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, + "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, + " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m, + " * ", n, " = ", expected_out_size, ", got ", actual_out_size); + } else { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, + " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, + "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, + " = ", expected_out_size, ", got ", actual_out_size); + } + + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + bool grad = false; + bool accumulate = false; + bool use_split_accumulator = false; + auto bias_shape = std::vector{has_bias ? n : 0}; + const int arch = cuda::sm_arch(); + + if (arch < 100 && is_fp8_gemm) { + NVTE_CHECK(!lhs_is_trans && rhs_is_trans, + "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", + "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); + } + + TensorWrapper workspace_setup(setup_workspace_ptr, + std::vector{product(setup_workspace->dimensions())}, + DType::kByte); + TensorWrapper workspace_cublas(cublas_workspace_ptr, std::vector{workspace_size}, + DType::kByte); + + TensorWrapper alpha_tensor(static_cast(alpha.untyped_data()), + std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(alpha.element_type())); + TensorWrapper beta_tensor(static_cast(beta.untyped_data()), + std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(beta.element_type())); + + if (is_grouped_dense_wgrad) { + NVTE_CHECK(lhs_is_trans && !rhs_is_trans, + "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); + + //// RHS + NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); + rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + + //// LHS + NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; + lhs_is_trans = true; + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + + //// OUTPUT + NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; + auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, + num_gemms, outShape); + + nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, + alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), + workspace_cublas.data(), + nullptr, // config (use defaults) + stream); + + return ffi_with_cuda_error_check(); + } + + // Nominal case for FWD or DGRAD + + //// RHS + NVTEShape rhsShape{.data = {num_gemms * k, n}, .ndim = 2}; + if (rhs_is_trans) { + rhsShape.data[0] = num_gemms * n; + rhsShape.data[1] = k; + } + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); + + //// LHS + NVTEShape lhsShape{.data = {m, k}, .ndim = 2}; + if (lhs_is_trans) { + std::swap(lhsShape.data[0], lhsShape.data[1]); + } + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, + lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); + + //// OUTPUT + NVTEShape outShape{.data = {m, n}, .ndim = 2}; + auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, + num_gemms, outShape); + out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + + nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, + alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), + workspace_cublas.data(), + nullptr, // config (use defaults) + stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs_data + .Arg() // lhs_sinv + .Arg() // rhs_data + .Arg() // rhs_sinv + .Arg() // bias + .Arg() // group_sizes (int32) + .Arg() // alpha + .Arg() // beta + .Ret() // output + .Ret() // cublas_workspace + .Ret() // setup_workspace + .Ret() // int64_workspace + .Attr("M") + .Attr("N") + .Attr("K") + .Attr("lhs_is_trans") + .Attr("rhs_is_trans") + .Attr("scaling_mode") + .Attr("is_grouped_dense_wgrad"), + GemmFFI_CudaGraph_Traits); + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, @@ -804,7 +1245,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type } #endif -// Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM + // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM size_t num_zero_outs = zero_out_dptr_list.size(); for (int i = 0; i < num_zero_outs; i++) { int stream_id = i % num_streams; diff --git a/transformer_engine/jax/csrc/extensions/inspect.cpp b/transformer_engine/jax/csrc/extensions/inspect.cpp new file mode 100644 index 000000000..9012cd054 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/inspect.cpp @@ -0,0 +1,99 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ +#include + +#include +#include + +#include "../extensions.h" +#include "xla/ffi/api/c_api.h" + +namespace transformer_engine { +namespace jax { + +Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf, + Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf, + Result_Type output_buf) { + NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation"); + NVTE_CHECK(output_buf->untyped_data() != nullptr, + "Output must be provided for inspect operation"); + NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(), + "Input and output must point to the same buffer for inspect operation"); + + std::vector input_data(input_buf.size_bytes()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), + input_buf.size_bytes(), cudaMemcpyDeviceToHost, stream)); + + float min_val{}, max_val{}, mean_val{}, std_val{}; + NVTE_CHECK_CUDA(cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float), + cudaMemcpyDeviceToHost, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float), + cudaMemcpyDeviceToHost, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float), + cudaMemcpyDeviceToHost, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float), + cudaMemcpyDeviceToHost, stream)); + + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + int device; + NVTE_CHECK_CUDA(cudaGetDevice(&device)); + + // Write the tensor data to a file as a binary blob + std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin"; + std::ofstream file(filename, std::ios::binary); + NVTE_CHECK(file.is_open(), "Failed to create file: ", filename); + file.write(reinterpret_cast(input_data.data()), input_data.size()); + file.close(); + + // Write out a metadata file + std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json"; + std::ofstream meta_file(meta_filename); + NVTE_CHECK(meta_file.is_open(), "Failed to create file: ", meta_filename); + meta_file << "{"; + meta_file << "\"shape\": ["; + for (size_t i = 0; i < input_buf.dimensions().size(); ++i) { + meta_file << input_buf.dimensions()[i]; + if (i < input_buf.dimensions().size() - 1) { + meta_file << ", "; + } + } + meta_file << "], "; + meta_file << "\"dtype\": " << static_cast(input_buf.element_type()); + meta_file << ", \"min\": " << min_val; + meta_file << ", \"max\": " << max_val; + meta_file << ", \"mean\": " << mean_val; + meta_file << ", \"std\": " << std_val; + meta_file << "}"; + meta_file.close(); + + // Log the tensor metadata to the console + printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str()); + for (size_t i = 0; i < input_buf.dimensions().size(); ++i) { + printf("%zu", static_cast(input_buf.dimensions()[i])); + if (i < input_buf.dimensions().size() - 1) { + printf(", "); + } + } + printf("], dtype: %d", static_cast(input_buf.element_type())); + printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // min + .Arg() // max + .Arg() // mean + .Arg() // std + .Ret() // output +); + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index c9516dba5..acd90f5af 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -128,6 +128,11 @@ void hash_combine(int64_t &seed, const T &v, Rest... rest) { (hash_combine(seed, rest), ...); } +enum class JAXX_Score_Function : int64_t { + SIGMOID = 0, + SOFTMAX = 1, +}; + enum class JAXX_Collective_Op : int64_t { NONE = 0, ALL_GATHER = 1, diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 0c56bb088..c7b3e4678 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -11,6 +11,7 @@ #include "cgemm_helper.h" #endif //#ifndef USE_ROCM #include "common/util/cuda_runtime.h" +#include "transformer_engine/gemm.h" namespace transformer_engine { namespace jax { @@ -73,6 +74,10 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler), pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); + dict["te_gemm_v2_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(GemmInitV2Handler), + pybind11::arg("execute") = EncapsulateFFI(GemmV2Handler)); + // Grouped GEMM dict["te_grouped_gemm_d2h_group_sizes_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), @@ -80,6 +85,10 @@ pybind11::dict Registrations() { dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + dict["te_grouped_gemm_v2_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GroupedGemmV2Handler)); + // Amax dict["te_rht_amax_ffi"] = pybind11::dict( pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), @@ -93,9 +102,26 @@ pybind11::dict Registrations() { dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler); dict["te_fused_attn_backward_ffi"] = EncapsulateFFI(FusedAttnBackwardHandler); + // GEMM dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); + dict["te_gemm_v2_ffi"] = EncapsulateFFI(GemmV2Handler); + + // Grouped GEMM + dict["te_grouped_gemm_d2h_group_sizes_ffi"] = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler); dict["te_grouped_gemm_ffi"] = EncapsulateFFI(GroupedGemmHandler); + dict["te_grouped_gemm_v2_ffi"] = EncapsulateFFI(GroupedGemmV2Handler); #endif + dict["te_inspect_ffi"] = + pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler)); + + // Router + dict["te_fused_topk_with_score_function_forward_ffi"] = + EncapsulateFFI(FusedTopkWithScoreFunctionForwardHandler); + dict["te_fused_topk_with_score_function_backward_ffi"] = + EncapsulateFFI(FusedTopkWithScoreFunctionBackwardHandler); + dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); + dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); + return dict; } @@ -122,6 +148,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { #ifndef USE_ROCM m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); + m.def("get_grouped_gemm_setup_workspace_size", &nvte_get_grouped_gemm_setup_workspace_size); #endif pybind11::enum_(m, "DType", pybind11::module_local()) @@ -171,6 +198,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { pybind11::enum_(m, "NVTE_Activation_Type", pybind11::module_local()) .value("GELU", NVTE_Activation_Type::GELU) .value("GEGLU", NVTE_Activation_Type::GEGLU) + .value("GLU", NVTE_Activation_Type::GLU) .value("SILU", NVTE_Activation_Type::SILU) .value("SWIGLU", NVTE_Activation_Type::SWIGLU) .value("RELU", NVTE_Activation_Type::RELU) @@ -215,6 +243,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("ROWWISE_COLWISE", JAXX_Quantize_Layout::ROWWISE_COLWISE) .export_values(); + pybind11::enum_(m, "JAXX_Score_Function", pybind11::module_local()) + .value("SIGMOID", JAXX_Score_Function::SIGMOID) + .value("SOFTMAX", JAXX_Score_Function::SOFTMAX) + .export_values(); + pybind11::enum_(m, "JAXX_Collective_Op", pybind11::module_local()) .value("NONE", JAXX_Collective_Op::NONE) .value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER) diff --git a/transformer_engine/jax/csrc/extensions/router.cpp b/transformer_engine/jax/csrc/extensions/router.cpp new file mode 100644 index 000000000..c81671f10 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/router.cpp @@ -0,0 +1,252 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "../extensions.h" +#include "xla/ffi/api/c_api.h" + +namespace transformer_engine { +namespace jax { + +// ============================================================================ +// Fused Top-K with Score Function - Forward +// ============================================================================ + +Error_Type FusedTopkWithScoreFunctionForwardFFI( + cudaStream_t stream, + Buffer_Type logits_buf, // [num_tokens, num_experts] + Buffer_Type expert_bias_buf, // [num_experts] or empty + Result_Type probs_buf, // [num_tokens, num_experts] (or scores when compute_aux_scores) + Result_Type routing_map_buf, // [num_tokens, num_experts] + Result_Type intermediate_buf, // [num_tokens, num_experts] + int64_t topk, int64_t use_pre_softmax, int64_t num_groups, int64_t group_topk, + double scaling_factor, JAXX_Score_Function score_function, int64_t compute_aux_scores) { + auto dtype = convert_ffi_datatype_to_te_dtype(logits_buf.element_type()); + auto dims = logits_buf.dimensions(); + auto num_tokens = static_cast(product(dims, 0, dims.size() - 1)); + auto num_experts = static_cast(dims[dims.size() - 1]); + + auto *logits = logits_buf.untyped_data(); + auto *expert_bias = expert_bias_buf.untyped_data(); + auto *probs = probs_buf->untyped_data(); + auto *routing_map = routing_map_buf->untyped_data(); + auto *intermediate = intermediate_buf->untyped_data(); + + auto flat_shape = + std::vector{static_cast(num_tokens), static_cast(num_experts)}; + auto logits_tensor = TensorWrapper(logits, flat_shape, dtype); + auto probs_tensor = TensorWrapper(probs, flat_shape, dtype); + auto routing_map_tensor = TensorWrapper(routing_map, flat_shape, DType::kByte); + // intermediate is always float32 (CompType) regardless of logits dtype. + auto intermediate_dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf->element_type()); + NVTE_CHECK( + intermediate_dtype == DType::kFloat32, + "intermediate_output must be float32 (CompType); got dtype ", + static_cast(intermediate_dtype), + ". Check FusedTopkWithScoreFunctionFwdPrimitive.abstract in cpp_extensions/router.py."); + auto intermediate_tensor = TensorWrapper(intermediate, flat_shape, DType::kFloat32); + + if (compute_aux_scores) { + nvte_fused_score_for_moe_aux_loss_forward( + logits_tensor.data(), num_tokens, num_experts, static_cast(topk), + static_cast(score_function), probs_tensor.data(), routing_map_tensor.data(), + intermediate_tensor.data(), stream); + } else { + auto bias_dims = expert_bias_buf.dimensions(); + auto expert_bias_tensor = + (bias_dims.size() > 0 && bias_dims[0] > 0) + ? TensorWrapper(expert_bias, std::vector{static_cast(bias_dims[0])}, + convert_ffi_datatype_to_te_dtype(expert_bias_buf.element_type())) + : TensorWrapper(); + + nvte_fused_topk_with_score_function_forward( + logits_tensor.data(), num_tokens, num_experts, static_cast(topk), + static_cast(use_pre_softmax), static_cast(num_groups), + static_cast(group_topk), static_cast(scaling_factor), + static_cast(score_function), expert_bias_tensor.data(), probs_tensor.data(), + routing_map_tensor.data(), intermediate_tensor.data(), stream); + } + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionForwardHandler, + FusedTopkWithScoreFunctionForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // logits + .Arg() // expert_bias + .Ret() // probs (or scores) + .Ret() // routing_map + .Ret() // intermediate_output + .Attr("topk") + .Attr("use_pre_softmax") + .Attr("num_groups") + .Attr("group_topk") + .Attr("scaling_factor") + .Attr("score_function") + .Attr("compute_aux_scores"), + FFI_CudaGraph_Traits); + +// ============================================================================ +// Fused Top-K with Score Function - Backward +// ============================================================================ + +Error_Type FusedTopkWithScoreFunctionBackwardFFI( + cudaStream_t stream, + Buffer_Type routing_map_buf, // [num_tokens, num_experts] (unused when compute_aux_scores) + Buffer_Type intermediate_buf, // [num_tokens, num_experts] + Buffer_Type grad_probs_buf, // [num_tokens, num_experts] (grad_scores when compute_aux_scores) + Result_Type grad_logits_buf, // [num_tokens, num_experts] + int64_t topk, int64_t use_pre_softmax, double scaling_factor, + JAXX_Score_Function score_function, int64_t compute_aux_scores) { + // intermediate is always float32 (CompType) regardless of logits dtype. + auto intermediate_dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()); + NVTE_CHECK( + intermediate_dtype == DType::kFloat32, + "intermediate_output must be float32 (CompType); got dtype ", + static_cast(intermediate_dtype), + ". Check FusedTopkWithScoreFunctionFwdPrimitive.abstract in cpp_extensions/router.py."); + auto grad_dtype = convert_ffi_datatype_to_te_dtype(grad_probs_buf.element_type()); + auto dims = intermediate_buf.dimensions(); + auto num_tokens = static_cast(product(dims, 0, dims.size() - 1)); + auto num_experts = static_cast(dims[dims.size() - 1]); + + auto flat_shape = + std::vector{static_cast(num_tokens), static_cast(num_experts)}; + + auto intermediate_tensor = + TensorWrapper(intermediate_buf.untyped_data(), flat_shape, DType::kFloat32); + auto grad_probs_tensor = TensorWrapper(grad_probs_buf.untyped_data(), flat_shape, grad_dtype); + auto grad_logits_tensor = TensorWrapper(grad_logits_buf->untyped_data(), flat_shape, grad_dtype); + + if (compute_aux_scores) { + nvte_fused_score_for_moe_aux_loss_backward(intermediate_tensor.data(), grad_probs_tensor.data(), + num_tokens, num_experts, static_cast(topk), + static_cast(score_function), + grad_logits_tensor.data(), stream); + } else { + auto routing_map_tensor = + TensorWrapper(routing_map_buf.untyped_data(), flat_shape, DType::kByte); + + nvte_fused_topk_with_score_function_backward( + routing_map_tensor.data(), intermediate_tensor.data(), grad_probs_tensor.data(), num_tokens, + num_experts, static_cast(topk), static_cast(use_pre_softmax), + static_cast(scaling_factor), static_cast(score_function), + grad_logits_tensor.data(), stream); + } + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler, + FusedTopkWithScoreFunctionBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // routing_map + .Arg() // intermediate_output + .Arg() // grad_probs + .Ret() // grad_logits + .Attr("topk") + .Attr("use_pre_softmax") + .Attr("scaling_factor") + .Attr("score_function") + .Attr("compute_aux_scores"), + FFI_CudaGraph_Traits); + +// ============================================================================ +// Fused MoE Aux Loss - Forward +// ============================================================================ + +Error_Type FusedMoEAuxLossForwardFFI(cudaStream_t stream, + Buffer_Type probs_buf, // [num_tokens, num_experts] + Buffer_Type tokens_per_expert_buf, // [num_experts] + Result_Type aux_loss_buf, // scalar + Result_Type const_buf, // scalar + int64_t topk, double coeff) { + auto dtype = convert_ffi_datatype_to_te_dtype(probs_buf.element_type()); + auto probs_dims = probs_buf.dimensions(); + auto num_tokens = static_cast(probs_dims[0]); + auto num_experts = static_cast(probs_dims[1]); + + auto probs_shape = + std::vector{static_cast(num_tokens), static_cast(num_experts)}; + auto tpe_dtype = convert_ffi_datatype_to_te_dtype(tokens_per_expert_buf.element_type()); + auto tpe_shape = std::vector{static_cast(num_experts)}; + auto scalar_shape = std::vector{1}; + + auto probs_tensor = TensorWrapper(probs_buf.untyped_data(), probs_shape, dtype); + auto tpe_tensor = TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype); + auto aux_loss_tensor = TensorWrapper(aux_loss_buf->untyped_data(), scalar_shape, dtype); + auto const_buf_tensor = TensorWrapper(const_buf->untyped_data(), scalar_shape, DType::kFloat32); + + nvte_fused_moe_aux_loss_forward(probs_tensor.data(), tpe_tensor.data(), num_tokens, num_experts, + num_tokens, num_experts, static_cast(topk), + static_cast(coeff), aux_loss_tensor.data(), + const_buf_tensor.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler, FusedMoEAuxLossForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // probs + .Arg() // tokens_per_expert + .Ret() // aux_loss + .Ret() // const_buf + .Attr("topk") + .Attr("coeff"), + FFI_CudaGraph_Traits); + +// ============================================================================ +// Fused MoE Aux Loss - Backward +// ============================================================================ + +Error_Type FusedMoEAuxLossBackwardFFI(cudaStream_t stream, + Buffer_Type const_buf_in, // scalar float32 + Buffer_Type tokens_per_expert_buf, // [num_experts] + Buffer_Type grad_aux_loss_buf, // scalar + Result_Type grad_probs_buf) { // [num_tokens, num_experts] + auto grad_dtype = convert_ffi_datatype_to_te_dtype(grad_aux_loss_buf.element_type()); + auto tpe_dtype = convert_ffi_datatype_to_te_dtype(tokens_per_expert_buf.element_type()); + + auto grad_probs_dims = grad_probs_buf->dimensions(); + auto num_tokens = static_cast(grad_probs_dims[0]); + auto num_experts = static_cast(grad_probs_dims[1]); + + auto scalar_shape = std::vector{1}; + auto tpe_dims = tokens_per_expert_buf.dimensions(); + auto tpe_shape = std::vector{static_cast(tpe_dims[0])}; + auto grad_probs_shape = + std::vector{static_cast(num_tokens), static_cast(num_experts)}; + + auto const_buf_tensor = TensorWrapper(const_buf_in.untyped_data(), scalar_shape, DType::kFloat32); + auto tpe_tensor = TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype); + auto grad_aux_loss_tensor = + TensorWrapper(grad_aux_loss_buf.untyped_data(), scalar_shape, grad_dtype); + auto grad_probs_tensor = + TensorWrapper(grad_probs_buf->untyped_data(), grad_probs_shape, grad_dtype); + + nvte_fused_moe_aux_loss_backward(const_buf_tensor.data(), tpe_tensor.data(), num_tokens, + num_experts, grad_aux_loss_tensor.data(), + grad_probs_tensor.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler, FusedMoEAuxLossBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // const_buf + .Arg() // tokens_per_expert + .Arg() // grad_aux_loss + .Ret(), // grad_probs + FFI_CudaGraph_Traits); + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/debug/__init__.py b/transformer_engine/jax/debug/__init__.py new file mode 100644 index 000000000..7fcf194d7 --- /dev/null +++ b/transformer_engine/jax/debug/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""EXPERIMENTAL debugging utilities for Transformer Engine JAX. + +This API is experimental and may change or be removed without deprecation in future releases. +""" + +__all__ = [ + "experimental", +] diff --git a/transformer_engine/jax/debug/experimental/__init__.py b/transformer_engine/jax/debug/experimental/__init__.py new file mode 100644 index 000000000..44a484766 --- /dev/null +++ b/transformer_engine/jax/debug/experimental/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""EXPERIMENTAL debugging utilities for Transformer Engine JAX. + +This API is experimental and may change or be removed without deprecation in future releases. +""" + +from .inspect import inspect_array, load_array_dump + +__all__ = [ + "inspect_array", + "load_array_dump", +] diff --git a/transformer_engine/jax/debug/experimental/inspect.py b/transformer_engine/jax/debug/experimental/inspect.py new file mode 100644 index 000000000..9ce46426c --- /dev/null +++ b/transformer_engine/jax/debug/experimental/inspect.py @@ -0,0 +1,174 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Experimental JAX array inspection utilities.""" + +from functools import partial + +import jax +import jax.numpy as jnp +from jax import ffi + +from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive + +__all__ = ["inspect_array", "load_array_dump"] + + +class InspectPrimitive(BasePrimitive): + """ + No-op used for inspect array values. + """ + + name = "te_inspect_ffi" + multiple_results = False + impl_static_args = () + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + x_min_aval, + x_max_aval, + x_mean_aval, + x_std_aval, + ): + """ + inspect abstract + """ + assert ( + x_min_aval.shape == () and x_min_aval.dtype == jnp.float32 + ), "x_min must be a scalar with dtype float32" + assert ( + x_max_aval.shape == () and x_max_aval.dtype == jnp.float32 + ), "x_max must be a scalar with dtype float32" + assert ( + x_mean_aval.shape == () and x_mean_aval.dtype == jnp.float32 + ), "x_mean must be a scalar with dtype float32" + assert ( + x_std_aval.shape == () and x_std_aval.dtype == jnp.float32 + ), "x_std must be a scalar with dtype float32" + return x_aval + + @staticmethod + def lowering( + ctx, + x, + x_min, + x_max, + x_mean, + x_std, + ): + """ + inspect lowering rules + """ + + return ffi.ffi_lowering( + InspectPrimitive.name, + operand_output_aliases={0: 0}, # donate input buffer to output buffer + )( + ctx, + x, + x_min, + x_max, + x_mean, + x_std, + ) + + @staticmethod + def impl( + x, + x_min, + x_max, + x_mean, + x_std, + ): + """ + inspect implementation + """ + assert InspectPrimitive.inner_primitive is not None + (x) = InspectPrimitive.inner_primitive.bind( + x, + x_min, + x_max, + x_mean, + x_std, + ) + return x + + +register_primitive(InspectPrimitive) + + +def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray: + assert InspectPrimitive.outer_primitive is not None, ( + "InspectPrimitive FFI is not registered. Please ensure the C++ extension is properly built" + " and registered." + ) + return InspectPrimitive.outer_primitive.bind( + x, + jnp.min(x).astype(jnp.float32), + jnp.max(x).astype(jnp.float32), + jnp.mean(x.astype(jnp.float32)), + jnp.std(x.astype(jnp.float32)), + ) + + +@partial(jax.custom_vjp, nondiff_argnums=()) +def _inspect( + x, +): + """ """ + output, _ = _inspect_fwd_rule( + x, + ) + return output + + +def _inspect_fwd_rule( + x, +): + """""" + ctx = () + x = _inspect_array_inner(x) + return x, ctx + + +def _inspect_bwd_rule( + ctx, + grad, +): + """""" + del ctx + return (grad,) + + +_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule) + + +def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray: + """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics. + + Args: + x (jnp.ndarray): The JAX array to inspect. + name (str): The name of the array for identification in the output. + """ + del name # Name is currently unused, but can be included in the future for more informative output + return _inspect(x) + + +def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray: + """Utility function to load a JAX array from a dumped binary file. + + Args: + filename (str): The path to the binary file containing the array data. + shape (tuple): The shape of the array to be loaded. + dtype (jnp.dtype): The data type of the array to be loaded. + + Returns: + jnp.ndarray: The loaded JAX array. + """ + with open(filename, "rb") as f: + data = f.read() + array = jnp.frombuffer(data, dtype=dtype).reshape(shape) + return array diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 23d91f7db..fe02e61fc 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -94,6 +94,13 @@ def dense( if transpose_batch_sequence: warnings.warn("transpose_batch_sequence is not well tested, use with caution!") + if collective_op_set != tex.noop_collective_op_set and not output_axes: + warnings.warn( + "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern" + " for the output. Set `output_axes` to apply the correct sharding constraint.", + UserWarning, + ) + if quantizer_set == noop_quantizer_set: input_dtype = x.dtype kernel = kernel.astype(input_dtype) @@ -210,30 +217,25 @@ def _dense_fwd_rule( casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # GEMM NN - use_bias = bias is not None output = tex.gemm( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), + bias=bias, contracting_dims=(x_contracting_dims, k_contracting_dims), transpose_batch_sequence=transpose_batch_sequence, - bias=bias if not tex.gemm_uses_jax_dot() else None, - fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, collective_op=collective_op_set.forward, ) output = with_sharding_constraint_by_logical_axes(output, output_axes) - if use_bias and tex.gemm_uses_jax_dot(): - bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape - output += jnp.reshape(bias, bias_new_shape) - + has_bias = bias is not None ctx = ( casted_x.get_tensor(usage=TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x), casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel), x.shape, kernel.shape, - use_bias, quantizer_set, flatten_axis_k, + has_bias, ) return output, ctx @@ -258,9 +260,9 @@ def _dense_bwd_rule( casted_kernel_rhs, x_shape, kernel_shape, - use_bias, quantizer_set, flatten_axis_k, + has_bias, ) = ctx grad = with_sharding_constraint_by_logical_axes(grad, output_axes) @@ -270,7 +272,7 @@ def _dense_bwd_rule( casted_grad, dbias = tex.quantize_dbias( grad, - is_dbias=use_bias, + is_dbias=has_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, amax_scope=AmaxScope.TPSP, diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index dd7d2a47b..92a968f06 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -4,7 +4,11 @@ """Transformer Engine bindings for JAX""" from .module import DenseGeneral, LayerNorm from .module import LayerNormDenseGeneral, LayerNormMLP -from .module import wrap_function_in_te_state_module, make_dot_general_cls +from .module import ( + wrap_function_in_te_state_module, + make_dot_general_cls, + make_grouped_dense_cls, +) from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -16,6 +20,7 @@ "LayerNormMLP", "wrap_function_in_te_state_module", "make_dot_general_cls", + "make_grouped_dense_cls", "extend_logical_axis_rules", "DotProductAttention", "MultiHeadAttention", diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 3d82d8f0b..31ce6e72e 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -17,7 +17,7 @@ from jax.ad_checkpoint import checkpoint_name -from ..dense import dense +from ..dense import dense, grouped_dense from ..layernorm import canonicalize_norm_type from ..layernorm import layernorm @@ -377,6 +377,7 @@ def generate_quantizer_set( variable_collection: str = None, quantization_checkpoint_name: Optional[str] = None, fp8_recipe=None, + n_groups: int = None, ): """ Generate a set of FP8 meta for a GEMM. @@ -409,6 +410,7 @@ def generate_quantizer_set( fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set, checkpoint_name=quantization_checkpoint_name, + n_groups=n_groups, ) return quantizer_set @@ -1379,12 +1381,13 @@ def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] class TEWrapper(te.flax.module.TransformerEngineBase): """Wrapper Flax module for TransformerEngine quantization support.""" - def generate_quantizer_set(self, postfix: str = ""): + def generate_quantizer_set(self, postfix: str = "", n_groups: int = None): OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" return super().generate_quantizer_set( postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, fp8_recipe=quantization_recipe, + n_groups=n_groups, ) @nn.compact @@ -1438,3 +1441,27 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): ) return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") + + +def make_grouped_dense_cls(quantization_recipe): + """Creates a grouped dense (grouped GEMM) instance for use with TE state module.""" + if quantization_recipe is not None: + raise ValueError("Ragged dot grouped GEMM does not support quantization yet") + + def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): + del kwargs # Unused + num_groups = group_sizes.shape[0] + quantizer_set = generate_quantizer_set(n_groups=num_groups) + + out = grouped_dense( + x, + kernel, + group_sizes=group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + ) + return out + + return wrap_function_in_te_state_module( + te_grouped_dot_general, quantization_recipe, "ragged_dot" + )() diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index ad5a60e4c..513677e4a 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -182,7 +182,9 @@ def __call__( is_gqa = h_q != h_kv if is_gqa: - assert (h_q % h_kv == 0) and (h_q >= h_kv) + assert (h_q % h_kv == 0) and ( + h_q >= h_kv + ), f"num_query_heads ({h_q}) must be divisible by and >= num_kv_heads ({h_kv})" group_size = h_q // h_kv grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) @@ -428,7 +430,9 @@ def __call__( if self.transpose_batch_sequence: x = x.transpose([1, 0, 2, 3]) - assert x.dtype == query.dtype + assert ( + x.dtype == query.dtype + ), f"output dtype {x.dtype} does not match query dtype {query.dtype}" return x @@ -713,9 +717,13 @@ def __call__( del self.attn_bias_type, self.attn_mask_type, self.qkv_layout if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None + assert ( + bias is None + ), f"bias must be None when attn_bias_type is NO_BIAS, but got bias={bias}" else: - assert bias is not None + assert ( + bias is not None + ), f"bias must not be None when attn_bias_type is {attn_bias_type}" bias = bias.astype(input_dtype) self._assert_dtypes(query, key, value, qkv_layout) @@ -823,11 +831,13 @@ def __call__( key, value = jnp.split(key, [1], axis=-3) key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value]) else: - assert qkv_layout.is_separate() + assert ( + qkv_layout.is_separate() + ), f"Expected separate qkv_layout, but got {qkv_layout}" assert sequence_descriptor is None or isinstance( sequence_descriptor, (jnp.ndarray, np.ndarray) - ) + ), f"sequence_descriptor must be None or ndarray, but got {type(sequence_descriptor)}" x = _UnfusedDotProductAttention( attention_dropout=self.attention_dropout, @@ -994,7 +1004,7 @@ def _canonicalize_lora_scope(scope): SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP, - ] + ], f"Unsupported LoRA scope: {scope}" lora_scope = LoRAScope() @@ -1307,8 +1317,10 @@ def query_init(*args): return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0) def qkv_init(key, shape, dtype): - assert len(shape) == 3 - assert shape[-2] == 3 + assert ( + len(shape) == 3 + ), f"qkv_init expects 3D shape, but got {len(shape)}D shape {shape}" + assert shape[-2] == 3, f"qkv_init expects shape[-2] == 3, but got shape={shape}" q_key, k_key, v_key = jax_random.split(key, num=3) @@ -1323,8 +1335,8 @@ def qkv_init(key, shape, dtype): return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype) def kv_init(key, shape, dtype): - assert len(shape) == 3 - assert shape[-2] == 2 + assert len(shape) == 3, f"kv_init expects 3D shape, but got {len(shape)}D shape {shape}" + assert shape[-2] == 2, f"kv_init expects shape[-2] == 2, but got shape={shape}" k_key, v_key = jax_random.split(key) @@ -1415,7 +1427,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): )(inputs_q) if is_self_attn: - assert ln_out is not None + assert ln_out is not None, "ln_out must not be None for self-attention" inputs_kv = ln_out kv_proj = DenseGeneral( @@ -1475,7 +1487,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): )(inputs_q) if is_self_attn: - assert ln_out is not None + assert ln_out is not None, "ln_out must not be None for self-attention" inputs_kv = ln_out query = query.astype(input_dtype) @@ -1494,7 +1506,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): elif qkv_layout == QKVLayout.BSHD_BS2HD: key, value = jnp.split(kv_proj, [1], axis=-2) else: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + assert ( + qkv_layout == QKVLayout.BSHD_BSHD_BSHD + ), f"Expected QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact) query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) @@ -1520,7 +1534,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) if decode: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + assert ( + qkv_layout == QKVLayout.BSHD_BSHD_BSHD + ), f"decode mode requires QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" is_initialized = self.has_variable("cache", "cached_key") cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) @@ -1588,7 +1604,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint) dpa_args = [query, kv_proj, None] else: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + assert ( + qkv_layout == QKVLayout.BSHD_BSHD_BSHD + ), f"Expected QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) @@ -2101,7 +2119,9 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): l = inputs.shape[sequence_dim] attn_bias = rel_emb(l, l, False) - assert inputs.ndim == 3 + assert ( + inputs.ndim == 3 + ), f"inputs must be 3D (batch, sequence, hidden), but got {inputs.ndim}D" # Make name be the exactly same as T5X, since names would affect # RNGKey during init and apply. Myabe no need in the feature. @@ -2151,10 +2171,15 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode) def hidden_dropout(x, deterministic): - assert isinstance(self.hidden_dropout_dims, Sequence) + assert isinstance( + self.hidden_dropout_dims, Sequence + ), f"hidden_dropout_dims must be a Sequence, but got {type(self.hidden_dropout_dims)}" x_shape_len = len(x.shape) for dims in self.hidden_dropout_dims: - assert -x_shape_len <= dims < x_shape_len + assert -x_shape_len <= dims < x_shape_len, ( + f"hidden_dropout_dims value {dims} is out of range " + f"[{-x_shape_len}, {x_shape_len}) for input with {x_shape_len} dimensions" + ) return nn.Dropout( rate=self.hidden_dropout, @@ -2179,7 +2204,9 @@ def hidden_dropout(x, deterministic): )(x, deterministic=deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None + assert ( + ln_out is not None + ), "ln_out must not be None when apply_residual_connection_post_layernorm is True" residual = ln_out x = x + residual @@ -2239,7 +2266,9 @@ def hidden_dropout(x, deterministic): y = hidden_dropout(y, deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None + assert ( + ln_out is not None + ), "ln_out must not be None when apply_residual_connection_post_layernorm is True" residual = ln_out mlp_input = y + residual @@ -2284,7 +2313,9 @@ def hidden_dropout(x, deterministic): )(mlp_input, deterministic=deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None + assert ( + ln_out is not None + ), "ln_out must not be None when apply_residual_connection_post_layernorm is True" residual = ln_out z = with_sharding_constraint_by_logical_axes( diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index 3f3f3802d..0f173a89e 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -31,7 +31,11 @@ def canonicalize_norm_type(x): Canonicalized normalization type string """ canonicalized = x.lower().strip().replace("-", "").replace("_", "") - assert canonicalized in ["layernorm", "rmsnorm"] + if canonicalized not in ["layernorm", "rmsnorm"]: + raise ValueError( + f"Unsupported normalization type '{x}' (canonicalized: '{canonicalized}'). " + "Valid options are: 'layernorm', 'rmsnorm'." + ) return canonicalized diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 8c21496ff..63e6daf9d 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -220,20 +220,15 @@ def _layernorm_dense_fwd_rule( # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out...) - use_bias = bias is not None output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), transpose_batch_sequence=transpose_batch_sequence, - bias=bias if not tex.gemm_uses_jax_dot() else None, - fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + bias=bias, ) - if use_bias and tex.gemm_uses_jax_dot(): - bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape - output += jnp.reshape(bias, bias_new_shape) - + has_bias = bias is not None ctx = ( casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x), casted_kernel.get_tensor(TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel), @@ -246,7 +241,7 @@ def _layernorm_dense_fwd_rule( beta, x_contracting_dims, k_contracting_dims, - use_bias, + has_bias, quantizer_set, flatten_axis, ) @@ -289,14 +284,14 @@ def _layernorm_dense_bwd_rule( beta, x_contracting_dims_in_fwd, k_contracting_dims_in_fwd, - use_bias, + has_bias, quantizer_set, flatten_axis, ) = ctx casted_grad, dbias = tex.quantize_dbias( grad, - is_dbias=use_bias, + is_dbias=has_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, amax_scope=AmaxScope.TPSP, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index f76ee54b8..4c324c208 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -15,6 +15,7 @@ from typing import List, Tuple, Sequence, Union, Callable from functools import partial +import warnings import jax import jax.numpy as jnp @@ -275,6 +276,13 @@ def _layernorm_mlp_fwd_rule( assert not collective_op_set_1.forward.is_reduce_scatter assert not collective_op_set_2.forward.is_all_gather + if collective_op_set_1 != tex.noop_collective_op_set and not dot_2_input_axes: + warnings.warn( + "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern" + " for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.", + UserWarning, + ) + # x should be in shape of (batch..., hidden) # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate) # Kernel_2 should be in shape of (intermediate, hidden_in) @@ -287,8 +295,8 @@ def _layernorm_mlp_fwd_rule( assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] - use_bias_1 = bias_1 is not None - use_bias_2 = bias_2 is not None + has_bias_1 = bias_1 is not None + has_bias_2 = bias_2 is not None x = with_sharding_constraint_by_logical_axes(x, norm_input_axes) @@ -320,16 +328,10 @@ def _layernorm_mlp_fwd_rule( casted_kernel_1.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), transpose_batch_sequence=transpose_batch_sequence, - bias=bias_1 if not tex.gemm_uses_jax_dot() else None, - fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, + bias=bias_1, collective_op=collective_op_set_1.forward, ) - if use_bias_1 and tex.gemm_uses_jax_dot(): - bias_1_shape = bias_1.shape - bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape - dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) - # This sharding constraint is needed to correct the Shardy sharding propagation if dot_2_input_axes is not None: dot_1_output_axes = ( @@ -369,16 +371,10 @@ def _layernorm_mlp_fwd_rule( casted_kernel_2.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), transpose_batch_sequence=transpose_batch_sequence, - bias=bias_2 if not tex.gemm_uses_jax_dot() else None, - fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, + bias=bias_2, collective_op=collective_op_set_2.forward, ) - if use_bias_2 and tex.gemm_uses_jax_dot(): - bias_2_shape = bias_2.shape - bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape - dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) - # sharding of outputs should be the same as dot_1's input dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes) dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) @@ -398,8 +394,8 @@ def _layernorm_mlp_fwd_rule( k_contracting_dims, kernel_1.shape, kernel_2.shape, - use_bias_1, - use_bias_2, + has_bias_1, + has_bias_2, quantizer_sets, ) @@ -453,8 +449,8 @@ def _layernorm_mlp_bwd_rule( k_contracting_dims_in_fwd, kernel_1_shape, kernel_2_shape, - use_bias_1, - use_bias_2, + has_bias_1, + has_bias_2, quantizer_sets, ) = ctx @@ -469,7 +465,7 @@ def _layernorm_mlp_bwd_rule( casted_grad, dbias_2 = tex.quantize_dbias( grad, - is_dbias=use_bias_2, + is_dbias=has_bias_2, quantizer=ffn1_quantizer_set.dgrad, amax_scope=AmaxScope.TPSP, transpose_batch_sequence=transpose_batch_sequence, @@ -514,7 +510,7 @@ def _layernorm_mlp_bwd_rule( dgrad_2, dot_1_output, activation_type=activation_type, - is_dbias=use_bias_1, + is_dbias=has_bias_1, quantizer=ffn2_quantizer_set.dgrad, act_params=( tex.activation.ActivationParams.create(activation_type, **activation_params) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 405d5f766..6a0a3229d 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -52,7 +52,7 @@ def token_dispatch( Optional[jnp.ndarray], jnp.ndarray, Optional[jnp.ndarray], - Optional[jnp.ndarray], + jnp.ndarray, ]: """ Dispatch tokens to experts based on routing map. @@ -101,9 +101,11 @@ def token_dispatch( pad_offsets : Optional[jnp.ndarray] Per-expert cumulative padding offsets of shape [num_experts] when using padding, None otherwise. Pass this to token_combine when unpadding is needed. - target_tokens_per_expert : Optional[jnp.ndarray] - Aligned token counts per expert of shape [num_experts] when using padding, - None otherwise. + tokens_per_expert : jnp.ndarray + Token counts per expert of shape [num_experts]: + - Without padding: actual token counts (sum of routing_map columns) + - With padding: aligned token counts (ceil(actual / align_size) * align_size) + This gives the effective number of tokens per expert in the output buffer. Note ---- @@ -151,10 +153,10 @@ def _token_dispatch( Optional[jnp.ndarray], jnp.ndarray, Optional[jnp.ndarray], - Optional[jnp.ndarray], + jnp.ndarray, ]: """Internal token_dispatch with custom VJP.""" - (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = ( + (output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert), _ = ( _token_dispatch_fwd_rule( inp, routing_map, @@ -165,7 +167,7 @@ def _token_dispatch( use_padding, ) ) - return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert + return output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert def _token_dispatch_fwd_rule( @@ -182,7 +184,7 @@ def _token_dispatch_fwd_rule( Optional[jnp.ndarray], jnp.ndarray, Optional[jnp.ndarray], - Optional[jnp.ndarray], + jnp.ndarray, ], Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool], ]: @@ -212,11 +214,11 @@ def _token_dispatch_fwd_rule( with_probs = probs is not None - if use_padding: - # Compute tokens_per_expert internally from routing_map - # This can be a traced value since output shape uses worst_case_out_tokens - tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32) + # Compute tokens_per_expert from routing_map (actual counts) + # This is well-optimized by XLA as a simple column-wise reduction + tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32) + if use_padding: # Calculate aligned token counts per expert target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype( jnp.int32 @@ -242,10 +244,12 @@ def _token_dispatch_fwd_rule( hidden_size, align_size=align_size, ) + + # Return aligned counts when using padding + out_tokens_per_expert = target_tokens_per_expert else: # No padding pad_offsets = None - target_tokens_per_expert = None output, permuted_probs = permute_with_mask_map( inp, @@ -257,14 +261,20 @@ def _token_dispatch_fwd_rule( hidden_size, ) + # Return actual counts when not using padding + out_tokens_per_expert = tokens_per_expert + # Return (primals, residuals) + # out_tokens_per_expert is: + # - target_tokens_per_expert (aligned) when using padding + # - tokens_per_expert (actual) when not using padding residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs) return ( output, permuted_probs, row_id_map, pad_offsets, - target_tokens_per_expert, + out_tokens_per_expert, ), residuals @@ -571,7 +581,7 @@ def sort_chunks_by_index( return _sort_chunks_by_index(inp, split_sizes, sorted_indices) -@partial(jax.custom_vjp, nondiff_argnums=(1, 2)) +@jax.custom_vjp def _sort_chunks_by_index( inp: jnp.ndarray, split_sizes: jnp.ndarray, @@ -586,7 +596,7 @@ def _sort_chunks_by_index_fwd_rule( inp: jnp.ndarray, split_sizes: jnp.ndarray, sorted_indices: jnp.ndarray, -) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]: +) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int]]: """Forward pass rule for sort_chunks_by_index.""" # Validate input dimensions assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D" @@ -608,18 +618,17 @@ def _sort_chunks_by_index_fwd_rule( ) # Return (primals, residuals) - residuals = (row_id_map, num_tokens, hidden_size) + # Include split_sizes and sorted_indices in residuals since we removed nondiff_argnums + residuals = (row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size) return (output, row_id_map), residuals def _sort_chunks_by_index_bwd_rule( - _split_sizes: jnp.ndarray, - _sorted_indices: jnp.ndarray, - residuals: Tuple[jnp.ndarray, int, int], + residuals: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int], g: Tuple[jnp.ndarray, jnp.ndarray], -) -> Tuple[jnp.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Backward pass rule for sort_chunks_by_index.""" - row_id_map, num_tokens, hidden_size = residuals + row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size = residuals output_grad, _ = g # Backward: reverse the sort @@ -632,7 +641,12 @@ def _sort_chunks_by_index_bwd_rule( is_forward=False, ) - return (inp_grad,) + # Return gradients for all inputs: (inp, split_sizes, sorted_indices) + # split_sizes and sorted_indices are integer arrays, so their gradients are zeros + split_sizes_grad = jnp.zeros_like(split_sizes, dtype=split_sizes.dtype) + sorted_indices_grad = jnp.zeros_like(sorted_indices, dtype=sorted_indices.dtype) + + return (inp_grad, split_sizes_grad, sorted_indices_grad) _sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 366c31726..fa481f795 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -16,11 +16,9 @@ from enum import Enum import hashlib from typing import Optional, Tuple, Dict, Union, Sequence, Type, List -from functools import reduce, lru_cache +from functools import reduce import operator -from importlib.metadata import version as get_pkg_version import warnings -from packaging.version import Version as PkgVersion import jax import jax.numpy as jnp @@ -49,6 +47,7 @@ get_all_mesh_axes, with_sharding_constraint, ) +from transformer_engine.jax.version_utils import jax_version_meet_requirement from .metadata import QuantizeMeta from .scaling_modes import ScalingMode @@ -61,6 +60,8 @@ "fp8_autocast", "is_fp8_available", "is_scaling_mode_supported", + "is_quantize_recipe_supported", + "get_quantization_recipe", "get_supported_scaling_modes", "get_supported_quantization_recipes", "update_collections", @@ -77,16 +78,6 @@ NVTE_FP8_COLLECTION_NAME = "fp8_metas" -@lru_cache(maxsize=None) -def _jax_version_meet_requirement(version: str): - """ - Helper function checking if required JAX version is available - """ - jax_version = PkgVersion(get_pkg_version("jax")) - jax_version_required = PkgVersion(version) - return jax_version >= jax_version_required - - def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: """Check if delayed scaling FP8 is supported on the given GPU architecture. @@ -129,7 +120,7 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: return False, "CublasLt version 12.8.0 or higher required for MXFP8 execution." if get_cuda_version() < 12080: return False, "Cuda version 12.8 or higher required for MXFP8 execution." - if not _jax_version_meet_requirement("0.5.3"): + if not jax_version_meet_requirement("0.5.3"): return False, "Jax version 0.5.3 or higher required for MXFP8 execution." return True, "" @@ -144,7 +135,7 @@ def _check_fp4_support(gpu_arch) -> Tuple[bool, str]: return False, "CublasLt version 12.8.0 or higher required for NVFP4 execution." if get_cuda_version() < 12080: return False, "Cuda version 12.8 or higher required for NVFP4 execution." - if not _jax_version_meet_requirement("0.5.3"): + if not jax_version_meet_requirement("0.5.3"): return False, "Jax version 0.5.3 or higher required for NVFP4 execution." return True, "" @@ -193,6 +184,54 @@ def is_scaling_mode_supported( return _is_scaling_mode_supported[scaling_mode], _reason_for_no_scaling_mode[scaling_mode] +_RECIPE_NAME_TO_RECIPE = { + "DelayedScaling": DelayedScaling, + "Float8CurrentScaling": Float8CurrentScaling, + "MXFP8BlockScaling": MXFP8BlockScaling, + "NVFP4BlockScaling": NVFP4BlockScaling, +} + + +def get_quantization_recipe(name: str) -> Recipe: + """Return a recipe object from a recipe name string. + + Args: + name: Recipe name. One of "DelayedScaling", "Float8CurrentScaling", + "MXFP8BlockScaling", or "NVFP4BlockScaling". + + Returns: + A new instance of the corresponding recipe class. + + Raises: + ValueError: If ``name`` does not match any known recipe. + """ + recipe_cls = _RECIPE_NAME_TO_RECIPE.get(name) + if recipe_cls is None: + valid = list(_RECIPE_NAME_TO_RECIPE) + raise ValueError(f"Invalid quantization recipe '{name}'. Valid options: {valid}") + return recipe_cls() + + +def is_quantize_recipe_supported(recipe_name: str) -> Tuple[bool, str]: + """Check if the given quantization recipe (by name) is supported on the current GPU. + + Args: + recipe_name: Name of the recipe, e.g. "DelayedScaling", "Float8CurrentScaling", + "MXFP8BlockScaling", "NVFP4BlockScaling". + + Returns: + A tuple of (supported: bool, reason: str). + """ + recipe = get_quantization_recipe(recipe_name) + config = get_quantize_config_with_recipe(recipe) + for tensor_source in TensorSource: + scaling_mode = config.get_scaling_mode(tensor_source) + is_supported, reason = is_scaling_mode_supported(scaling_mode) + if not is_supported: + return is_supported, reason + return True, None + + def is_fp8_available( scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, gpu_id=None, @@ -294,9 +333,6 @@ class BaseQuantizeConfig(ABC): COLLECTION_NAME: Name of the collection for quantization metadata FWD_DTYPE: Forward pass data type BWD_DTYPE: Backward pass data type - FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass - FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients - FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients INFERENCE_MODE: Whether to enable optimization for inference AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling AMAX_COMPUTE_ALGO: Algorithm for AMAX computation @@ -307,9 +343,6 @@ class BaseQuantizeConfig(ABC): COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME FWD_DTYPE: DType = None BWD_DTYPE: DType = None - FP8_2X_ACC_FPROP: bool = False - FP8_2X_ACC_DGRAD: bool = False - FP8_2X_ACC_WGRAD: bool = False INFERENCE_MODE: bool = False # DelayedScaling @@ -455,9 +488,6 @@ def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: } self.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo] - self.FP8_2X_ACC_DGRAD = True - self.FP8_2X_ACC_WGRAD = True - def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.DELAYED_TENSOR_SCALING @@ -956,6 +986,7 @@ def apply_padding_to_scale_inv( unpadded_scale_shape = scaling_mode.get_scale_shape( data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) + assert scale_inv.shape == unpadded_scale_shape, ( f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " f"{scale_inv.shape}." diff --git a/transformer_engine/jax/router.py b/transformer_engine/jax/router.py new file mode 100644 index 000000000..65f2e8a7f --- /dev/null +++ b/transformer_engine/jax/router.py @@ -0,0 +1,318 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused MoE Router API for JAX. + +This module provides high-level fused router operations for Mixture of Experts (MoE) +models with proper automatic differentiation support. These wrap the CUDA kernels in +transformer_engine/common/fused_router/. + +Functions: + fused_topk_with_score_function: + Fused score_function + top-k selection. Supports softmax/sigmoid, + grouped top-k, expert bias, and scaling factor. When compute_aux_scores=True, + switches to the clean score-for-aux-loss kernel (no bias/groups/scaling, + dense output). + + fused_moe_aux_loss: + Compute the MoE auxiliary load-balancing loss scalar. +""" + +from functools import partial +from typing import Optional, Tuple, Union + +import jax +import jax.numpy as jnp + +from transformer_engine.jax.cpp_extensions.router import ( + ScoreFunction, + fused_topk_with_score_function_fwd, + fused_topk_with_score_function_bwd, + fused_moe_aux_loss_fwd, + fused_moe_aux_loss_bwd, +) + +__all__ = [ + "ScoreFunction", + "fused_topk_with_score_function", + "fused_moe_aux_loss", +] + + +def _validate_score_function(score_function: Union[str, ScoreFunction]) -> ScoreFunction: + """Validate and convert score_function to a ScoreFunction enum.""" + if isinstance(score_function, ScoreFunction): + return score_function + try: + return ScoreFunction[score_function.upper()] + except (KeyError, AttributeError): + raise ValueError( + "score_function must be 'softmax', 'sigmoid', or a ScoreFunction enum, " + f"got {score_function!r}" + ) from None + + +# ============================================================================= +# Fused Top-K with Score Function +# ============================================================================= + + +def fused_topk_with_score_function( + logits: jnp.ndarray, + topk: int, + use_pre_softmax: bool = False, + num_groups: int = -1, + group_topk: int = -1, + scaling_factor: float = 1.0, + score_function: Union[str, ScoreFunction] = ScoreFunction.SOFTMAX, + expert_bias: Optional[jnp.ndarray] = None, + compute_aux_scores: bool = False, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Fused top-k with score function router. + + When compute_aux_scores=False (default), runs the main routing kernel: + score_function(logits) -> [optional bias] -> top-k -> [optional post-softmax] -> scale. + Returns sparse probs (only top-k positions nonzero) and routing_map. + + When compute_aux_scores=True, runs the score-for-aux-loss kernel instead: + score_function(logits) -> top-k (clean, no bias/groups/scaling). + Returns dense scores (all expert positions) and routing_map. + The expert_bias, use_pre_softmax, num_groups, group_topk, and scaling_factor + parameters are ignored in this mode. + + Parameters + ---------- + logits : jnp.ndarray + Logits from the gating GEMM, shape [num_tokens, num_experts]. + topk : int + Number of top experts to select per token. + use_pre_softmax : bool + If True, apply softmax before top-k (only for softmax score function). Else, apply post top-k. + Ignored when compute_aux_scores=True. + num_groups : int + Number of groups for grouped top-k. <= 0 disables grouping (default). + Ignored when compute_aux_scores=True. + group_topk : int + Top-k at group level. <= 0 disables group-level selection (default). + Ignored when compute_aux_scores=True. + scaling_factor : float + Scaling factor applied to output probs. + Ignored when compute_aux_scores=True. + score_function : Union[str, ScoreFunction] + Score function: "softmax" / "sigmoid" or ScoreFunction.SOFTMAX / ScoreFunction.SIGMOID. + expert_bias : Optional[jnp.ndarray] + Expert bias, shape [num_experts]. Only used with sigmoid. + Ignored when compute_aux_scores=True. + compute_aux_scores : bool + If True, use the clean score-for-aux-loss kernel. Returns dense scores + over all experts instead of sparse probs. + + Returns + ------- + probs_or_scores : jnp.ndarray + When compute_aux_scores=False: Sparse probability tensor, shape [num_tokens, num_experts]. + Non-zero only at selected expert positions. + When compute_aux_scores=True: Dense score tensor, shape [num_tokens, num_experts]. + All expert positions contain scores. + routing_map : jnp.ndarray + Boolean mask, shape [num_tokens, num_experts]. + True at selected expert positions. + """ + if not isinstance(scaling_factor, (int, float)): + raise TypeError( + f"scaling_factor must be a Python float or int, not {type(scaling_factor).__name__}. " + "If you used jnp.sqrt() or similar, use math.sqrt() instead." + ) + + score_function = _validate_score_function(score_function) + + if compute_aux_scores: + expert_bias = jnp.empty((0,), dtype=logits.dtype) + use_pre_softmax = False + num_groups = -1 + group_topk = -1 + scaling_factor = 1.0 + else: + if expert_bias is not None and score_function != ScoreFunction.SIGMOID: + raise ValueError( + "expert_bias is only supported with score_function='sigmoid'. " + f"Got score_function='{score_function.name}'." + ) + if expert_bias is None: + expert_bias = jnp.empty((0,), dtype=logits.dtype) + + probs_or_scores, routing_map = _fused_topk_with_score_function( + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ) + + return probs_or_scores, routing_map + + +@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6, 7, 8)) +def _fused_topk_with_score_function( + logits: jnp.ndarray, + expert_bias: jnp.ndarray, + topk: int, + use_pre_softmax: bool, + num_groups: int, + group_topk: int, + scaling_factor: float, + score_function: ScoreFunction, + compute_aux_scores: bool, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + (probs, routing_map), _ = _fused_topk_with_score_function_fwd( + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ) + return probs, routing_map + + +def _fused_topk_with_score_function_fwd( + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, +): + probs, routing_map, saved_scores = fused_topk_with_score_function_fwd( + logits, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + compute_aux_scores, + ) + residuals = (routing_map, saved_scores) + return (probs, routing_map), residuals + + +def _fused_topk_with_score_function_bwd( + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + residuals, + g, +): + del num_groups, group_topk + routing_map, saved_scores = residuals + grad_probs, _ = g + + grad_logits = fused_topk_with_score_function_bwd( + routing_map, + saved_scores, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + ) + return grad_logits, None + + +_fused_topk_with_score_function.defvjp( + _fused_topk_with_score_function_fwd, + _fused_topk_with_score_function_bwd, +) + + +# ============================================================================= +# Fused MoE Aux Loss +# ============================================================================= + + +def fused_moe_aux_loss( + probs: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + topk: int, + coeff: float, +) -> jnp.ndarray: + """ + Compute the MoE auxiliary load-balancing loss. + + loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens_per_expert[i]) + + where T = probs.shape[0] (num_tokens) and E = probs.shape[1] (num_experts). + + Parameters + ---------- + probs : jnp.ndarray + Probability/score tensor, shape [num_tokens, num_experts]. + tokens_per_expert : jnp.ndarray + Token counts per expert, shape [num_experts]. Integer tensor. + topk : int + Top-k value. + coeff : float + Loss coefficient. + + Returns + ------- + aux_loss : jnp.ndarray + Scalar loss value. + """ + return _fused_moe_aux_loss(probs, tokens_per_expert, topk, coeff) + + +@partial(jax.custom_vjp, nondiff_argnums=(2, 3)) +def _fused_moe_aux_loss( + probs: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + topk: int, + coeff: float, +) -> jnp.ndarray: + aux_loss, _ = _fused_moe_aux_loss_fwd(probs, tokens_per_expert, topk, coeff) + return aux_loss + + +def _fused_moe_aux_loss_fwd(probs, tokens_per_expert, topk, coeff): + aux_loss, const_buf = fused_moe_aux_loss_fwd(probs, tokens_per_expert, topk, coeff) + residuals = (const_buf, tokens_per_expert, probs.shape[0]) + return aux_loss, residuals + + +def _fused_moe_aux_loss_bwd(topk, coeff, residuals, g): + del topk, coeff + const_buf, tokens_per_expert, num_tokens = residuals + grad_aux_loss = g.reshape(1) + + grad_probs = fused_moe_aux_loss_bwd( + const_buf, + tokens_per_expert, + grad_aux_loss, + num_tokens, + ) + return grad_probs, None + + +_fused_moe_aux_loss.defvjp( + _fused_moe_aux_loss_fwd, + _fused_moe_aux_loss_bwd, +) diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index d9708fde9..150a5fbf1 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -54,6 +54,9 @@ def lowering(ctx, x, **kwargs): from transformer_engine.jax.triton_extensions import get_triton_info info = get_triton_info() print(f"Using Triton {info['version']} from {info['source']}") + + # Check if JAX version supports Triton (without importing triton_extensions) + from transformer_engine.jax.version_utils import is_triton_extension_supported """ from .utils import * diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py index bd8bd8ff1..98c54e52b 100644 --- a/transformer_engine/jax/triton_extensions/permutation.py +++ b/transformer_engine/jax/triton_extensions/permutation.py @@ -65,8 +65,6 @@ class RowIdMapPass1Primitive(BasePrimitive): @staticmethod def abstract(routing_map_aval, *, num_tokens, num_experts, block_size): """Shape/dtype inference for pass 1.""" - del block_size # Only affects grid, not output shape - assert routing_map_aval.shape == ( num_tokens, num_experts, @@ -75,7 +73,7 @@ def abstract(routing_map_aval, *, num_tokens, num_experts, block_size): row_id_map_shape = (num_tokens, num_experts * 2 + 1) workspace_shape = ( num_experts, - triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE), + triton.cdiv(num_tokens, block_size), ) return ( @@ -134,9 +132,10 @@ def infer_sharding_from_operands( desc="RowIdMapPass1.row_id_map_sharding", ) # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, routing_map_spec[0]), desc="RowIdMapPass1.workspace_sharding", ) return [row_id_map_sharding, workspace_sharding] @@ -156,9 +155,11 @@ def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos PartitionSpec(routing_map_spec[0], None), desc="RowIdMapPass1.row_id_map_sharding", ) + # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, routing_map_spec[0]), desc="RowIdMapPass1.workspace_sharding", ) out_shardings = [row_id_map_sharding, workspace_sharding] @@ -186,7 +187,8 @@ def shardy_sharding_rule(num_tokens, num_experts, block_size, mesh, value_types, # Note: row_id_cols != experts since it's num_experts * 2 + 1 row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols") # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) - workspace_spec = (f"{prefix}_experts", f"{prefix}_ws_blocks") + # Second dim depends on num_tokens, so use same factor to ensure same sharding + workspace_spec = (f"{prefix}_experts", f"{prefix}_tokens") return SdyShardingRule((input_spec,), (row_id_map_spec, workspace_spec)) @@ -208,10 +210,9 @@ class RowIdMapPass2Primitive(BasePrimitive): def abstract(row_id_map_aval, workspace_aval, *, num_tokens, num_experts, block_size): """Shape/dtype inference for pass 2 (in-place operation).""" del row_id_map_aval, workspace_aval - del block_size row_id_map_shape = (num_tokens, num_experts * 2 + 1) - workspace_shape = (num_experts, triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE)) + workspace_shape = (num_experts, triton.cdiv(num_tokens, block_size)) return ( jax.core.ShapedArray(row_id_map_shape, jnp.int32), @@ -270,9 +271,11 @@ def infer_sharding_from_operands( PartitionSpec(*row_id_map_spec), desc="RowIdMapPass2.row_id_map_sharding", ) + # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, row_id_map_spec[0]), desc="RowIdMapPass2.workspace_sharding", ) return [row_id_map_sharding, workspace_sharding] @@ -292,9 +295,11 @@ def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos PartitionSpec(*row_id_map_spec), desc="RowIdMapPass2.row_id_map_sharding", ) + # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, row_id_map_spec[0]), desc="RowIdMapPass2.workspace_sharding", ) out_shardings = [row_id_map_sharding, workspace_sharding] @@ -317,7 +322,9 @@ def shardy_sharding_rule(num_tokens, num_experts, block_size, mesh, value_types, del num_tokens, num_experts, block_size, mesh, value_types, result_types prefix = "RowIdMapPass2" row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols") - workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_ws_blocks") + # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so use same factor to ensure same sharding + workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_tokens") return SdyShardingRule((row_id_map_spec, workspace_spec), (row_id_map_spec, workspace_spec)) @@ -1659,10 +1666,19 @@ class SortChunksByMapPrimitive(BasePrimitive): @staticmethod def abstract( - inp_aval, row_id_map_aval, probs_aval, *, num_tokens, hidden_size, is_forward, with_probs + inp_aval, + row_id_map_aval, + probs_aval, + output_buf_aval=None, # Pre-allocated output buffer (inner primitive only) + *, + num_tokens, + hidden_size, + is_forward, + with_probs, ): """Shape/dtype inference.""" del row_id_map_aval, is_forward + del output_buf_aval # Used for input_output_aliases only output_aval = jax.core.ShapedArray((num_tokens, hidden_size), inp_aval.dtype) @@ -1677,10 +1693,14 @@ def abstract( def impl(inp, row_id_map, probs, num_tokens, hidden_size, is_forward, with_probs): """Forward to inner primitive.""" assert SortChunksByMapPrimitive.inner_primitive is not None + + output_buf = jnp.empty((num_tokens, hidden_size), dtype=inp.dtype) + return SortChunksByMapPrimitive.inner_primitive.bind( inp, row_id_map, probs, + output_buf, num_tokens=num_tokens, hidden_size=hidden_size, is_forward=is_forward, @@ -1688,7 +1708,9 @@ def impl(inp, row_id_map, probs, num_tokens, hidden_size, is_forward, with_probs ) @staticmethod - def lowering(ctx, inp, row_id_map, probs, *, num_tokens, hidden_size, is_forward, with_probs): + def lowering( + ctx, inp, row_id_map, probs, output_buf, *, num_tokens, hidden_size, is_forward, with_probs + ): """MLIR lowering using triton_call_lowering.""" # Compute strides inp_stride_token = hidden_size @@ -1702,13 +1724,22 @@ def lowering(ctx, inp, row_id_map, probs, *, num_tokens, hidden_size, is_forward block_size = _get_min_block_size(_sort_chunks_by_map_kernel) grid = (num_tokens, triton.cdiv(hidden_size, block_size)) + # Declare input_output_aliases so XLA knows output slot 0 is claimed by + # input 3 (output_buf). This prevents XLA from implicitly aliasing any + # other input (like output_grad in backward) to the output buffer. + # Input indices: 0=inp, 1=row_id_map, 2=probs, 3=output_buf + # Output indices: 0=output, 1=permuted_probs + input_output_aliases = {3: 0} + return triton_call_lowering( ctx, _sort_chunks_by_map_kernel, inp, row_id_map, probs, + output_buf, grid=grid, + input_output_aliases=input_output_aliases, constexprs={ "stride_input_token": inp_stride_token, "stride_input_hidden": inp_stride_hidden, diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 6ea4092cb..28e3f08e1 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -36,10 +36,17 @@ from typing import Any, Callable, Mapping import zlib +from packaging import version + from jax import core import jax import jax.numpy as jnp +from ..version_utils import ( + TRITON_EXTENSION_MIN_JAX_VERSION, + is_triton_extension_supported, +) + # Placeholder package version on PyPI that should never be used _PYTORCH_TRITON_PLACEHOLDER_VERSION = "0.0.1" @@ -148,6 +155,21 @@ def _check_triton_compatibility(): # Perform compatibility check and get triton info _TRITON_VERSION, _IS_PYTORCH_TRITON = _check_triton_compatibility() +# Enforce minimum JAX version before importing gpu_triton. The segfault on old +# jaxlib occurs at Triton kernel dispatch time, not at import time, so gpu_triton +# itself is safe to import on older jaxlib. The guard is placed here (before the +# import) as a belt-and-suspenders measure so that if the import behaviour ever +# changes, we still fail fast with a clear error rather than a cryptic crash. +if not is_triton_extension_supported(): + raise RuntimeError( + f"JAX >= {TRITON_EXTENSION_MIN_JAX_VERSION} required for " + "transformer_engine.jax.triton_extensions. " + "Triton kernel dispatch segfaults with older jaxlib. " + f"Current jax version: {jax.__version__}. " + "Please upgrade: pip install --upgrade jax jaxlib. " + "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." + ) + try: from jax._src.lib import gpu_triton from triton.compiler import compiler as tc @@ -274,13 +296,16 @@ def compile_triton( return _TRITON_KERNEL_CACHE[cache_key] # Compile kernel + cuda_option_kwargs = {} + if version.parse(_TRITON_VERSION) < version.parse("3.6.0"): + cuda_option_kwargs["cluster_dims"] = (1, 1, 1) options = cb.CUDAOptions( num_warps=num_warps, num_stages=num_stages, num_ctas=num_ctas, - cluster_dims=(1, 1, 1), debug=False, enable_fp_fusion=enable_fp_fusion, + **cuda_option_kwargs, ) # Mark constants as constexpr in signature @@ -303,8 +328,6 @@ def compile_triton( # Create kernel object for JAX # From jax/jaxlib/gpu/triton_kernels.cc: - from packaging import version - if version.parse(jax.__version__) >= version.parse("0.8.2"): kernel = gpu_triton.TritonKernel( compiled.name, # arg0: kernel_name (str) diff --git a/transformer_engine/jax/version_utils.py b/transformer_engine/jax/version_utils.py new file mode 100644 index 000000000..04b7ff879 --- /dev/null +++ b/transformer_engine/jax/version_utils.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" +JAX version helpers. + +Provides version checks for JAX that can be used across TE JAX (quantize, triton +extensions, etc.) without pulling in feature-specific code. +""" + +from functools import lru_cache +from importlib.metadata import version as get_pkg_version + +from packaging.version import Version as PkgVersion + + +@lru_cache(maxsize=None) +def jax_version_meet_requirement(version: str): + """Return True if the installed JAX version is >= the required version.""" + jax_version = PkgVersion(get_pkg_version("jax")) + jax_version_required = PkgVersion(version) + return jax_version >= jax_version_required + + +# Minimum JAX version required for Triton kernel dispatch (jaxlib < 0.8.0 segfaults). +TRITON_EXTENSION_MIN_JAX_VERSION = "0.8.0" + + +def is_triton_extension_supported() -> bool: + """Return True if the current JAX version supports Triton kernel dispatch. + + JAX/jaxlib >= 0.8.0 is required. Older versions segfault when dispatching + Triton kernels. Use this to skip tests or gate features without importing + triton_extensions (which would raise immediately on old jax). + """ + return jax_version_meet_requirement(TRITON_EXTENSION_MIN_JAX_VERSION) + + +__all__ = [ + "jax_version_meet_requirement", + "is_triton_extension_supported", + "TRITON_EXTENSION_MIN_JAX_VERSION", +] diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 5e1eb6954..cd18ca75a 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -54,7 +54,11 @@ from transformer_engine.pytorch.graph import make_graphed_callables from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import CudaRNGStatesTracker -from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context +from transformer_engine.pytorch.cpu_offload import ( + get_cpu_offload_context, + mark_not_offload, + ManualOffloadSynchronizer, +) from transformer_engine.pytorch import ops from transformer_engine.pytorch import optimizers from transformer_engine.pytorch.export import onnx_export diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 2d6f1da7e..21058520b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -170,6 +170,11 @@ class FP8EmulationFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout): # pylint: disable=missing-function-docstring + if is_in_onnx_export_mode(): + return FP8EmulationFunc.onnx_forward( + tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout + ) + if quantizer_name == "QKV_quantizer": query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] @@ -208,6 +213,47 @@ def backward(ctx, grad1, grad2, grad3): tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None + @staticmethod + def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None): + """ + ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations. + """ + # pylint: disable=unused-argument + is_qkv_quantizer = quantizer_name == "QKV_quantizer" + assert isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ), "ONNX FP8 emulation path supports only Float8 quantizers." + + if is_qkv_quantizer: + # Flatten + concatenate + quantize + split. Equivalent to combine_and_quantize Case 3. + orig_dtype = tensor1.dtype + shapes = [tensor1.shape, tensor2.shape, tensor3.shape] + numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()] + + # Flatten and concatenate + combined = torch.cat( + [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0 + ) + + # Quantize + dequantize combined tensor using quantizer's ONNX methods + combined_fp8 = quantizer.onnx_quantize(combined) + out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype) + + # Split back + out1 = out[: numels[0]].reshape(shapes[0]) + out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1]) + out3 = out[numels[0] + numels[1] :].reshape(shapes[2]) + + return out1, out2, out3 + if quantizer_name in ["S_quantizer", "O_quantizer"]: + # Emulate FP8 on single tensor using quantizer's ONNX methods + orig_dtype = tensor1.dtype + t_fp8 = quantizer.onnx_quantize(tensor1) + out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype) + return out, tensor2, tensor3 + # Pass-through + return tensor1, tensor2, tensor3 + class UnfusedDotProductAttention(torch.nn.Module): """Parallel attention w/o QKV and Proj Gemms @@ -255,6 +301,10 @@ def mask_func(x, y): bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None ) + def fast_setattr(self, name: str, value: Any) -> None: + """Fast attribute set for non-parameter fields.""" + self.__dict__[name] = value + def forward( self, _alibi_cache: Dict[str, Any], @@ -269,6 +319,7 @@ def forward( attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, @@ -354,6 +405,11 @@ def forward( attention_mask=attention_mask, window_size=window_size, attention_type=self.attention_type, + bottom_right_alignment=( + attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None + else bottom_right_diagonal + ), ) ) @@ -457,7 +513,11 @@ def forward( actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_alignment=( + attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None + else bottom_right_diagonal + ), ) matmul_result = torch.baddbmm( matmul_result, @@ -1118,6 +1178,7 @@ def forward( attn_mask_type, softmax_type, window_size, + bottom_right_diagonal, rng_gen, fused_attention_backend, use_FAv2_bwd, @@ -1221,6 +1282,7 @@ def forward( attn_mask_type, softmax_type, window_size, + bottom_right_diagonal, rng_gen, softmax_offset, cuda_graph=is_graph_capturing(), @@ -1298,6 +1360,7 @@ def forward( attn_mask_type, softmax_type, window_size, + bottom_right_diagonal, rng_gen, softmax_offset, return_max_logit, @@ -1385,6 +1448,7 @@ def forward( ctx.attn_mask_type = attn_mask_type ctx.softmax_type = softmax_type ctx.window_size = window_size + ctx.bottom_right_diagonal = bottom_right_diagonal ctx.fused_attention_backend = ( fused_attention_backend if (IS_HIP_EXTENSION or ctx.fp8) else FusedAttnBackend["F16_arbitrary_seqlen"] ) @@ -1535,6 +1599,7 @@ def backward(ctx, d_out, *_args): ctx.attn_mask_type, ctx.softmax_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), ) @@ -1600,6 +1665,7 @@ def backward(ctx, d_out, *_args): ctx.attn_mask_type, ctx.softmax_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), ) @@ -1639,6 +1705,7 @@ def backward(ctx, d_out, *_args): None, None, None, + None, d_softmax_offset, None, None, @@ -1738,6 +1805,7 @@ def forward( attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -1945,6 +2013,7 @@ def forward( attn_mask_type, self.softmax_type, window_size, + bottom_right_diagonal, None, # rng_gen fused_attention_backend, use_FAv2_bwd, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index f8fad6993..6c3f3ea10 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -843,13 +843,24 @@ def cp_p2p_fwd_fused_attn( q_part = q_part.contiguous() if attn_bias is not None: idx = (rank - step) % cp_size - attn_bias_inputs = torch.cat( - ( - attn_bias_[..., 1, :, idx, :], - attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() + # For bias shape 111s, only the s_kv dim is split, i.e. [b, h, sq, 2*cp, sk//(2*cp)]) + if attn_bias.shape[-3] == 1: + attn_bias_inputs = torch.cat( + ( + attn_bias_[..., :, idx, :], + attn_bias_[..., :, (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + # For bias shapes 1hss, 11ss, bhss, b1ss, the s_kv and s_q dims are split, i.e. [b, h, 2, sq//2, 2*cp, sk//(2*cp)]) + else: + attn_bias_inputs = torch.cat( + ( + attn_bias_[..., 1, :, idx, :], + attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() max_seqlen_q_ = max_seqlen_q // 2 max_seqlen_kv_ = max_seqlen_kv cu_seqlens_q_ = cu_seqlens_q_per_step @@ -929,9 +940,9 @@ def cp_p2p_fwd_flash_attn( elif section == "upper-triangle": max_seqlen_q_ = max_seqlen_q // 2 if section in ["lower-triangle", "upper-triangle"]: - if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: + elif use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 @@ -1181,9 +1192,9 @@ def cp_p2p_bwd_flash_attn( ): """Per-tile backward call of CP P2P with FlashAttention backend""" dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] - if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: + elif use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = -1 if not use_flash_attn_3: @@ -1193,9 +1204,9 @@ def cp_p2p_bwd_flash_attn( softmax_lse__ = softmax_lse causal_ = False if section == "diagonal": - if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = (-1, 0) - elif fa_utils.v2_7_0_plus: + elif use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 fa_backward_kwargs["window_size_right"] = 0 causal_ = True @@ -1217,6 +1228,10 @@ def cp_p2p_bwd_flash_attn( dk=dk, dv=dv, ) + if use_flash_attn_3: + fa_backward_kwargs["is_causal"] = causal_ + else: + fa_backward_kwargs["causal"] = causal_ flash_attn_bwd( dout_part, q_part, @@ -1225,7 +1240,6 @@ def cp_p2p_bwd_flash_attn( out_part, softmax_lse__, *fa_backward_args_thd, - causal=causal_, **fa_backward_kwargs, ) @@ -1445,20 +1459,33 @@ def forward( attn_bias_ = None if attn_bias is not None: assert len(attn_bias.shape) == 4, ( - "Only support bias shape of [b, h, sq, sk] for forward, " - "and [1, h, sq, sk] for backward!" - ) - assert ( - attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 - ), "Sequence length does not meet divisible requirements!" - # [b, h, sq, sk] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] - attn_bias_ = attn_bias.view( - *attn_bias.shape[:-2], - 2, - attn_bias.shape[-2] // 2, - 2 * cp_size, - attn_bias.shape[-1] // (2 * cp_size), + "Only support bias shape of [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv], [b,h,sq,skv]," + " [1,1,1,skv] for forward, and [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv]," + " [b,h,sq,skv] for backward!" ) + # For all bias shapes except 111s, sq must be divisible by 2 and skv must be divisible by 2*cp_size + # For bias shape 111s, only skv must be divisible by 2*cp_size + if attn_bias.shape[-2] != 1: + assert ( + attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 + ), "Sequence length does not meet divisible requirements!" + # [b, h, sq, sk] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] + attn_bias_ = attn_bias.view( + *attn_bias.shape[:-2], + 2, + attn_bias.shape[-2] // 2, + 2 * cp_size, + attn_bias.shape[-1] // (2 * cp_size), + ) + else: + assert ( + attn_bias.shape[-1] % (2 * cp_size) == 0 + ), "Sequence length does not meet divisible requirements!" + # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] + attn_bias_ = attn_bias.view( + *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) + ) + # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] attn_bias = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) @@ -1487,7 +1514,8 @@ def forward( flash_attn_fwd = ( _flash_attn_fwd_v3 # pylint: disable=possibly-used-before-assignment ) - fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = 0 if causal else -1 else: if qkv_format == "thd": from transformer_engine.pytorch.attention.dot_product_attention.backends import ( @@ -2079,10 +2107,13 @@ def backward(ctx, dout, *_args): attn_dbias = torch.zeros( *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device ) - # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] - attn_dbias_ = attn_dbias.view( - *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:] - ) + # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] only when sq > 1 (i.e. all supported bias shapes except 111s) + if attn_dbias.shape[-3] > 1: + attn_dbias_ = attn_dbias.view( + *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:] + ) + else: + attn_dbias_ = None else: attn_dbias = None attn_dbias_ = None @@ -2510,8 +2541,8 @@ def backward(ctx, dout, *_args): elif i >= (cp_size - rank - 1): # [b, h, sq, sk//(2*cp)] attn_dbias[..., idx, :].copy_(dbias_) - else: - # [b, h, sq//2, sk//cp] -> [b, h, sq//2, 2, sk//(2*cp)] + elif attn_dbias_ is not None: + # upper-triangle: [b, h, sq//2, sk//cp] -> [b, h, sq//2, 2, sk//(2*cp)] dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) @@ -2961,9 +2992,9 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv_, ) - if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size"] = window_size_per_step[i] - elif fa_utils.v2_7_0_plus: + elif use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] fa_outputs = flash_attn_fwd( @@ -3182,13 +3213,15 @@ def backward(ctx, dout, *_args): ) if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = window_size_per_step[i] - elif fa_utils.v2_7_0_plus: + elif ctx.use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] + if ctx.use_flash_attn_3: + fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type + else: + fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type flash_attn_bwd( dout_, q_, @@ -3197,7 +3230,6 @@ def backward(ctx, dout, *_args): out_, softmax_lse_per_step[i], *fa_backward_args_thd, - causal="causal" in ctx.attn_mask_type, **fa_backward_kwargs, ) @@ -3337,7 +3369,8 @@ def forward( ) flash_attn_fwd = _flash_attn_fwd_v3 - fa_forward_kwargs["window_size"] = window_size + fa_forward_kwargs["window_size_left"] = window_size[0] + fa_forward_kwargs["window_size_right"] = window_size[1] else: if qkv_format == "thd": from transformer_engine.pytorch.attention.dot_product_attention.backends import ( @@ -3714,7 +3747,8 @@ def backward(ctx, dout, *_args): flash_attn_bwd = ( _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment ) - fa_backward_kwargs["window_size"] = ctx.window_size + fa_backward_kwargs["window_size_left"] = ctx.window_size[0] + fa_backward_kwargs["window_size_right"] = ctx.window_size[1] fa_backward_kwargs["deterministic"] = ctx.deterministic else: if qkv_format == "thd": @@ -3797,6 +3831,10 @@ def backward(ctx, dout, *_args): ) if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state + fa_backward_kwargs["causal"] = causal + else: + fa_backward_kwargs["is_causal"] = causal + flash_attn_bwd( dout, q, @@ -3805,7 +3843,6 @@ def backward(ctx, dout, *_args): out, softmax_lse, *fa_backward_args_thd, - causal=causal, **fa_backward_kwargs, ) @@ -4029,28 +4066,30 @@ def attn_forward_func_with_cp( assert not sliding_window_attn or cp_comm_type in [ "a2a", "all_gather", - ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!" + ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] assert not enable_mla or cp_comm_type in [ "p2p", "a2a+p2p", - ], "Context parallelism does not support MLA with {cp_comm_type=}!" + ], f"Context parallelism does not support MLA with {cp_comm_type=}!" if fp8 and fp8_meta is not None: if fp8_meta["recipe"].fp8_dpa: assert ( softmax_type == "vanilla" - ), "Context parallelism does not support {softmax_type=} with FP8 attention!" + ), f"Context parallelism does not support {softmax_type=} with FP8 attention!" assert ( softmax_type == "vanilla" or use_fused_attention - ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!" + ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!" assert ( softmax_type == "vanilla" or cp_comm_type == "a2a" - ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" - assert ( - softmax_type == "vanilla" or qkv_format != "thd" - ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!" + ), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" + if get_cudnn_version() < (9, 18, 0): + assert softmax_type == "vanilla" or qkv_format != "thd", ( + f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with" + " qkv_format = 'thd'!" + ) args = [ is_training, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index fa4fb9a48..672a23c25 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -232,6 +232,11 @@ class DotProductAttention(TransformerEngineBaseModule): map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can be overridden by :attr:`window_size` in ``forward`` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. attention_type : str, default = "self" type of attention, either ``"self"`` and ``"cross"``. layer_number : int, default = None @@ -328,6 +333,7 @@ def __init__( qkv_format: str = "sbhd", attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, sequence_parallel: bool = False, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, @@ -356,6 +362,7 @@ def __init__( attn_mask_type = "padding_causal" self.attn_mask_type = attn_mask_type self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + self.bottom_right_diagonal = bottom_right_diagonal if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -438,7 +445,7 @@ def __init__( if self.softmax_type == "learnable": self.register_parameter( "softmax_offset", - Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")), + Parameter(torch.zeros(self.num_attention_heads // self.tp_size, device="cuda")), get_rng_state_tracker=get_rng_state_tracker, ) @@ -682,9 +689,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # assume attention uses the same fp8_group as GEMMs fp8_group = FP8GlobalStateManager.get_fp8_group() - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters()) + self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled()) + self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration()) fp8_enabled = self.fp8 or self.fp8_calibration self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration if self.fp8_parameters or fp8_enabled: @@ -709,7 +716,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ) else: # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False + self.fast_setattr("fp8_initialized", False) return if self.fp8_parameters and not self.fp8_initialized: @@ -727,7 +734,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Allocate scales and amaxes self.init_fp8_meta_tensors(fp8_recipes) - self.fp8_initialized = True + self.fast_setattr("fp8_initialized", True) self.fp8_meta["recipe"] = fp8_recipe_dpa if fp8_recipe != fp8_recipe_dpa: @@ -817,6 +824,7 @@ def forward( max_seqlen_kv: int = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, checkpoint_core_attention: bool = False, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -969,6 +977,16 @@ def forward( causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention. + bottom_right_diagonal: Optional[bool], default = None + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. + Note: This parameter will be automatically overridden based on the + `attn_mask_type` - it will be forced to `False` for 'causal' and + 'padding_causal' mask types, and forced to `True` for mask types + containing 'bottom_right' (e.g., 'causal_bottom_right', + 'padding_causal_bottom_right'), regardless of the explicitly passed value. checkpoint_core_attention : bool, default = False If true, forward activations for attention are recomputed during the backward pass in order to save memory that would @@ -1006,7 +1024,7 @@ def forward( cases. It is ignored for other backends and when context parallelism is enabled. """ - with self.prepare_forward( + with self.prepare_forward_ctx( query_layer, num_gemms=3, allow_non_contiguous=True, @@ -1087,6 +1105,15 @@ def forward( if window_size is None: window_size = self.window_size window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True # checks for qkv_format if qkv_format is None: @@ -1150,11 +1177,14 @@ def forward( assert "padding" in attn_mask_type, "KV caching requires padding mask!" if attn_mask_type == "padding_causal": attn_mask_type = attn_mask_type + "_bottom_right" + # since attention mask is changed, set `bottom_right_diagonal` to True + bottom_right_diagonal = True - self.attention_type = "cross" - self.flash_attention.attention_type = self.attention_type - self.fused_attention.attention_type = self.attention_type - self.unfused_attention.attention_type = self.attention_type + if self.attention_type != "cross": + self.fast_setattr("attention_type", "cross") + self.flash_attention.attention_type = self.attention_type + self.fused_attention.attention_type = self.attention_type + self.unfused_attention.attention_type = self.attention_type query_layer, key_layer, value_layer = [ x.contiguous() if not x.is_contiguous() else x @@ -1262,7 +1292,6 @@ def forward( if self.layer_number == 1: _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True - bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],) if core_attention_bias_type == "alibi": assert ( core_attention_bias is None @@ -1271,7 +1300,7 @@ def forward( _alibi_cache["_num_heads"] != query_layer.shape[-2] or _alibi_cache["_max_seqlen_q"] != max_seqlen_q or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv - or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment + or _alibi_cache["_bottom_right_alignment"] != bottom_right_diagonal or _alibi_cache["_alibi_slopes"] is None ): _alibi_cache["_alibi_slopes_require_update"] = True @@ -1295,11 +1324,14 @@ def forward( ): core_attention_bias_shape = "b1ss" elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1: - core_attention_bias_shape = "11ss" + if core_attention_bias.shape[2] == 1: + core_attention_bias_shape = "111s" + else: + core_attention_bias_shape = "11ss" else: assert ( False - ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" + ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss, 111s} shapes" # check if there is padding between sequences when qkv_format='thd' if pad_between_seqs is None: @@ -1328,6 +1360,7 @@ def forward( head_dim_v=head_dim_v, attn_mask_type=attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, core_attention_bias_type=core_attention_bias_type, core_attention_bias_shape=core_attention_bias_shape, @@ -1451,9 +1484,7 @@ def forward( if use_fused_attention: fu_core_attention_bias_type = core_attention_bias_type fu_core_attention_bias = core_attention_bias - if core_attention_bias_type == "alibi" and ( - alibi_slopes is not None or max_seqlen_q != max_seqlen_kv - ): + if core_attention_bias_type == "alibi" and (alibi_slopes is not None): fu_core_attention_bias_type = "post_scale_bias" _, fu_core_attention_bias = dpa_utils.get_alibi( _alibi_cache, @@ -1462,7 +1493,7 @@ def forward( max_seqlen_kv, alibi_slopes=alibi_slopes, bias_dtype=query_layer.dtype, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_alignment=bottom_right_diagonal, ) if checkpoint_core_attention: return self._checkpointed_attention_forward( @@ -1480,6 +1511,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, @@ -1510,6 +1542,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, @@ -1528,7 +1561,9 @@ def forward( ) if use_unfused_attention: - allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + allow_emulation = ( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() + ) if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, @@ -1544,6 +1579,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, @@ -1567,6 +1603,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9b7147106..3ee89a176 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -203,6 +203,9 @@ class AttentionParams: `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} window_size : Tuple[int, int], default = None Sliding window attention size. + bottom_right_diagonal: bool, default = `None` + Whether to align sliding window and ALiBi diagonal to the bottom right corner + of the softmax matrix. alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. core_attention_bias_type : str, default = no_bias @@ -252,6 +255,7 @@ class AttentionParams: head_dim_v: int = 64 attn_mask_type: str = "no_mask" window_size: Union[Tuple[int, int], None] = None + bottom_right_diagonal: bool = True alibi_slopes_shape: Union[torch.Size, List, None] = None core_attention_bias_type: str = "no_bias" core_attention_bias_shape: str = "1hss" @@ -330,6 +334,7 @@ def get_attention_backend( head_dim_v = attention_params.head_dim_v attn_mask_type = attention_params.attn_mask_type window_size = attention_params.window_size + bottom_right_diagonal = attention_params.bottom_right_diagonal alibi_slopes_shape = attention_params.alibi_slopes_shape core_attention_bias_type = attention_params.core_attention_bias_type core_attention_bias_shape = attention_params.core_attention_bias_shape @@ -479,7 +484,9 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False if use_unfused_attention: - allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + allow_emulation = ( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() + ) if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False @@ -723,22 +730,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_unfused_attention = False if qkv_format == "thd": - logger.debug( - "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type - ) - use_fused_attention = False - logger.debug( - "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd", - softmax_type, - ) - use_unfused_attention = False + if not IS_HIP_EXTENSION and cudnn_version < (9, 18, 0): + logger.debug( + "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" + " version < 9.18", + softmax_type, + ) + use_fused_attention = False if context_parallel: - logger.debug( - "Disabling UnfusedDotProductAttention for context parallelism with softmax_type" - " = %s", - softmax_type, - ) - use_unfused_attention = False if cp_comm_type != "a2a": logger.debug( "Disabling FusedAttention for context parallelism with softmax_type = %s and" @@ -874,39 +873,43 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt # backend | window_size | diagonal alignment # --------------------------------------------------------------------------------- # FlashAttention | (-1, -1) or (>=0, >=0) | bottom right - # FusedAttention | (-1, 0) or (>=0, 0) | top left - # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both; + # FusedAttention | (-1, 0) or (>=0, >=0) | top left, bottom right + # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | top left, bottom right # | | converts window_size to an 'arbitrary' mask if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) - else: - if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention" - " for FP8" - ) - use_fused_attention = False - elif (not IS_HIP_EXTENSION) and (window_size[1] != 0 or attention_dropout != 0.0): - logger.debug( - "Disabling FusedAttention as it only supports sliding window attention " - "with (left, 0) and no dropout" - ) - use_fused_attention = False - elif max_seqlen_q > max_seqlen_kv: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with s_q > s_kv for cross-attention" - ) - use_fused_attention = False - if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if not FlashAttentionUtils.is_installed: - FlashAttentionUtils.version_required = PkgVersion("2.3") - elif not FlashAttentionUtils.v2_3_plus: - logger.debug( - "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" - ) - use_flash_attention_2 = False + if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention for FP8" + ) + use_fused_attention = False + elif not IS_HIP_EXTENSION and attention_dropout != 0.0: + logger.debug( + "Disabling FusedAttention as it only supports sliding window attention " + "without dropout" + ) + use_fused_attention = False + elif max_seqlen_q > max_seqlen_kv: + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention " + "with s_q > s_kv for cross-attention" + ) + use_fused_attention = False + if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.3") + elif not FlashAttentionUtils.v2_3_plus: + logger.debug( + "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" + ) + use_flash_attention_2 = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it only supports sliding window with bottom right" + " diagonal alignment for cross-attention" + ) + use_flash_attention = False # Filter: Attention bias # backend | bias types | ALiBi diagonal alignment @@ -928,6 +931,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt elif not FlashAttentionUtils.v2_4_plus: logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") use_flash_attention_2 = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" + " alignment for cross-attention" + ) + use_flash_attention = False if ( core_attention_bias_type not in ["no_bias", "alibi"] @@ -945,13 +954,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( use_fused_attention and core_attention_bias_type == "alibi" - and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv) + and (alibi_slopes_shape is not None) ): fu_core_attention_bias_type = "post_scale_bias" fu_core_attention_bias_requires_grad = False - if alibi_slopes_shape is None: - fu_core_attention_bias_shape = "1hss" - elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: + + if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: fu_core_attention_bias_shape = "1hss" elif ( len(alibi_slopes_shape) == 2 @@ -960,19 +968,23 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ): fu_core_attention_bias_shape = "bhss" - # rocm ck backend support all 4 bias shapes (11ss, 1hss, b1ss, and bhss) - if ( - not IS_HIP_EXTENSION and + # rocm ck backend support 4 bias shapes (11ss, 1hss, b1ss, and bhss) + if IS_HIP_EXTENSION: + if use_fused_attention and fu_core_attention_bias_shape == "111s": + logger.debug("Disabling FusedAttention as ROCm backends do not support 111s") + use_fused_attention = False + elif ( use_fused_attention and fu_core_attention_bias_type == "post_scale_bias" and fu_core_attention_bias_shape != "1hss" ): - if fu_core_attention_bias_requires_grad: - # remove this line when cuDNN adds bwd support for - # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] - logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape") + # dbias calculation is not supported for 111s as of cuDNN 9.18. So, use fused attention backend only if bias does not require grad. + if fu_core_attention_bias_requires_grad and fu_core_attention_bias_shape == "111s": + logger.warning( + "Disabling FusedAttention as dbias calculation is not supported for 111s" + ) use_fused_attention = False - else: + elif not fu_core_attention_bias_requires_grad: # max512 backend will only support [1, h, s, s] os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" @@ -1003,6 +1015,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt window_size[1], return_max_logit, cuda_graph, + deterministic, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: logger.debug("Disabling FusedAttention as no backend supports the provided input") @@ -1057,8 +1070,24 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_flash_attention_2 = False if use_fused_attention and deterministic and (not IS_HIP_EXTENSION): - if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: - logger.debug("Disabling FusedAttention for determinism reasons with FP8") + if softmax_type != "vanilla": + logger.debug( + "Disabling FusedAttention for determinism reasons with softmax_type = %s. " + "Sink attention (off-by-one and learnable softmax) requires " + "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1", + softmax_type, + ) + use_fused_attention = False + fused_attention_backend = None + if ( + fused_attention_backend == FusedAttnBackend["FP8"] + and is_training + and (device_compute_capability < (9, 0) or cudnn_version < (9, 19, 0)) + ): + logger.debug( + "Disabling FusedAttention for determinism reasons with FP8 on arch < sm90 or cuDNN" + " < 9.19.0" + ) use_fused_attention = False fused_attention_backend = None if ( @@ -1073,10 +1102,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") use_fused_attention = False fused_attention_backend = None - if is_training and device_compute_capability >= (10, 0): - logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") - use_fused_attention = False - fused_attention_backend = None # TODO: remove the filtering after ck team tells us how to enable more deterministic bwd kernels if use_fused_attention and deterministic and IS_HIP_EXTENSION: if ( diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index f875fd1e0..d95d327c7 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -5,10 +5,9 @@ """Multi-head Attention.""" import os import collections -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch -from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -32,6 +31,7 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb +from transformer_engine.pytorch.attention.dot_product_attention import utils as dpa_utils from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled @@ -93,6 +93,11 @@ class MultiheadAttention(torch.nn.Module): map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can be overridden by :attr:`window_size` in :meth:`forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. num_gqa_groups : int, default = None number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -248,6 +253,7 @@ def __init__( layer_number: Optional[int] = None, attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, num_gqa_groups: Optional[int] = None, @@ -286,6 +292,7 @@ def __init__( self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type self.window_size = window_size + self.bottom_right_diagonal = bottom_right_diagonal self.layer_number = 1 if layer_number is None else layer_number self.input_layernorm = input_layernorm self.attention_type = attention_type @@ -335,6 +342,7 @@ def __init__( self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.name = name + TransformerEngineBaseModule._validate_name(self) common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, @@ -347,7 +355,7 @@ def __init__( } self.q_norm, self.k_norm = self._create_qk_norm_modules( - qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size + qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size, params_dtype ) qkv_parallel_mode = "column" if set_parallel_mode else None @@ -470,6 +478,10 @@ def __init__( **common_gemm_kwargs, ) + def fast_setattr(self, name: str, value: Any) -> None: + """Fast attribute set for non-parameter fields.""" + self.__dict__[name] = value + def _create_qk_norm_modules( self, qk_norm_type: Optional[str], @@ -477,6 +489,7 @@ def _create_qk_norm_modules( device: Union[torch.device, str], seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, ) -> Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]: """ Create query and key normalization modules based on the specified normalization type. @@ -493,6 +506,8 @@ def _create_qk_norm_modules( Sequence length for L2Normalization optimization micro_batch_size : Optional[int], default = None Micro batch size for L2Normalization optimization + params_dtype : Optional[torch.dtype], default = None + Data type for the normalization modules Returns ------- @@ -516,11 +531,13 @@ def _create_qk_norm_modules( normalized_shape=self.hidden_size_per_attention_head, eps=qk_norm_eps, device=device, + params_dtype=params_dtype, ) k_norm = RMSNorm( normalized_shape=self.hidden_size_per_attention_head, eps=qk_norm_eps, device=device, + params_dtype=params_dtype, ) return q_norm, k_norm @@ -529,11 +546,13 @@ def _create_qk_norm_modules( normalized_shape=self.hidden_size_per_attention_head, eps=qk_norm_eps, device=device, + params_dtype=params_dtype, ) k_norm = LayerNorm( normalized_shape=self.hidden_size_per_attention_head, eps=qk_norm_eps, device=device, + params_dtype=params_dtype, ) return q_norm, k_norm @@ -621,6 +640,7 @@ def forward( encoder_output: Optional[torch.Tensor] = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, @@ -667,6 +687,11 @@ def forward( aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None sliding window size for local attention. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. encoder_output : Optional[torch.Tensor], default = None Output of the encoder block to be fed into the decoder block if using ``layer_type="decoder"``. @@ -731,6 +756,17 @@ def forward( if window_size is None: window_size = self.window_size + window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True + if "padding" in attn_mask_type and attention_mask is not None: for mask in attention_mask: assert mask.dtype == torch.bool, "Attention mask must be in boolean type!" @@ -739,9 +775,6 @@ def forward( core_attention_bias_type in AttnBiasTypes ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" - if TEDebugState.debug_enabled: - TransformerEngineBaseModule._validate_name(self) - # ================================================= # Pre-allocate memory for key-value cache for inference # ================================================= @@ -1004,6 +1037,7 @@ def forward( attention_mask=attention_mask, attn_mask_type=attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, checkpoint_core_attention=checkpoint_core_attention, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 120e63c04..5f4a95bb8 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -5,6 +5,7 @@ # See LICENSE for license information. """Enums for e2e transformer""" +from types import SimpleNamespace import torch import torch.distributed import transformer_engine_torch as tex @@ -51,6 +52,25 @@ def __missing__(self, key): tex.DType.kBFloat16: torch.bfloat16, }) +# Cache enum -> int conversions to avoid repeated PyObject lookups. +FP8FwdTensorIdx = SimpleNamespace( + GEMM1_INPUT=int(tex.FP8FwdTensors.GEMM1_INPUT), + GEMM1_WEIGHT=int(tex.FP8FwdTensors.GEMM1_WEIGHT), + GEMM1_OUTPUT=int(tex.FP8FwdTensors.GEMM1_OUTPUT), + GEMM2_INPUT=int(tex.FP8FwdTensors.GEMM2_INPUT), + GEMM2_WEIGHT=int(tex.FP8FwdTensors.GEMM2_WEIGHT), + GEMM2_OUTPUT=int(tex.FP8FwdTensors.GEMM2_OUTPUT), + GEMM3_OUTPUT=int(tex.FP8FwdTensors.GEMM3_OUTPUT), +) +FP8BwdTensorIdx = SimpleNamespace( + GRAD_INPUT1=int(tex.FP8BwdTensors.GRAD_INPUT1), + GRAD_INPUT2=int(tex.FP8BwdTensors.GRAD_INPUT2), + GRAD_INPUT3=int(tex.FP8BwdTensors.GRAD_INPUT3), + GRAD_OUTPUT1=int(tex.FP8BwdTensors.GRAD_OUTPUT1), + GRAD_OUTPUT2=int(tex.FP8BwdTensors.GRAD_OUTPUT2), + GRAD_OUTPUT3=int(tex.FP8BwdTensors.GRAD_OUTPUT3), +) + AttnMaskTypes = ( "no_mask", "padding", diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 788a9d7ef..daef41a3e 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -19,6 +19,7 @@ NVTE_Fused_Attn_Backend, ) from ..quantized_tensor import Quantizer +from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx __all__ = [ @@ -113,12 +114,13 @@ BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 -META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT -META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 -META_O = tex.FP8FwdTensors.GEMM2_INPUT -META_DO = tex.FP8BwdTensors.GRAD_INPUT2 -META_S = tex.FP8FwdTensors.GEMM3_OUTPUT -META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +META_QKV = FP8FwdTensorIdx.GEMM1_OUTPUT +META_DQKV = FP8BwdTensorIdx.GRAD_OUTPUT1 +META_O = FP8FwdTensorIdx.GEMM2_INPUT +META_DO = FP8BwdTensorIdx.GRAD_INPUT2 +META_S = FP8FwdTensorIdx.GEMM3_OUTPUT +META_DP = FP8BwdTensorIdx.GRAD_INPUT3 + def fused_attn_fwd( is_training: bool, @@ -146,6 +148,7 @@ def fused_attn_fwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = None, rng_gen: torch.Generator = None, softmax_offset: torch.Tensor = None, return_max_logit: bool = False, @@ -221,6 +224,9 @@ def fused_attn_fwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = None + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. rng_gen : torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen @@ -267,46 +273,64 @@ def fused_attn_fwd( if IS_HIP_EXTENSION: assert not return_max_logit, "ROCm does not support return_max_logit yet." + if bottom_right_diagonal is None: + bottom_right_diagonal = attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + } + if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) if attn_bias_type not in ["no_bias", "alibi"]: - assert ( - attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert attn_bias.dtype == q.dtype, "attn_bias tensor must be in the same dtype as q and kv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + if attn_bias is None: + raise ValueError( + f"attn_bias tensor cannot be None when attn_bias_type={attn_bias_type!r}." + ) + if attn_bias.dtype != q.dtype: + raise ValueError( + "attn_bias tensor must have the same dtype as q and kv: " + f"attn_bias.dtype={attn_bias.dtype} but q.dtype={q.dtype}." + ) + + if fused_attention_backend == FusedAttnBackend["No_Backend"]: + raise ValueError( + "Fused attention does not support this input combination:" + f" qkv_layout={qkv_layout!r}, attn_bias_type={attn_bias_type!r}," + f" attn_mask_type={attn_mask_type!r}, q.shape={list(q.shape)}," + f" q.dtype={q.dtype}, backend={fused_attention_backend}." + ) if IS_HIP_EXTENSION: # Both CK/aiter and aotriton follow the flash-attn rng design rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + # BF16/FP16 fused attention API from fmha_v1 apex + elif fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: + rng_elts_per_thread = ( + max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 + ) // BACKEND_F16m512_FP8_THREADS_PER_CTA + # BF16/FP16 fused attention API from fmha_v2 + elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: + rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + # FP8 fused attention API from fmha_v2 + elif fused_attention_backend == FusedAttnBackend["FP8"]: + rng_elts_per_thread = ( + max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 + ) // BACKEND_F16m512_FP8_THREADS_PER_CTA + + if s_quantizer is None: + raise ValueError( + "s_quantizer is required for FP8 fused attention forward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + if o_quantizer is None: + raise ValueError( + "o_quantizer is required for FP8 fused attention forward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) else: - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - # FP8 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["FP8"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention." - assert ( - o_quantizer is not None - ), "o_quantizer is required as an input for FP8 fused attention." - else: - raise ValueError(f"Unsupported backend {fused_attention_backend}") + raise ValueError(f"Unsupported backend {fused_attention_backend}") # execute kernel @@ -322,6 +346,7 @@ def fused_attn_fwd( AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], window_size, + bottom_right_diagonal, cu_seqlens_q, cu_seqlens_kv, q, @@ -386,6 +411,7 @@ def fused_attn_bwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = None, deterministic: bool = False, cuda_graph: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -458,6 +484,9 @@ def fused_attn_bwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = None + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. deterministic : bool, default = False whether to execute the backward pass with deterministic behaviours. cuda_graph : bool, default = False @@ -478,33 +507,54 @@ def fused_attn_bwd( gradient tensor of softmax offset of shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details. """ + if bottom_right_diagonal is None: + bottom_right_diagonal = attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + } + if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - if not IS_HIP_EXTENSION: - if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert ( - len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." - - if fused_attention_backend == FusedAttnBackend["FP8"]: - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention backward." - assert ( - dp_quantizer is not None - ), "dp_quantizer is required as an input for FP8 fused attention backward." - assert ( - dqkv_dtype is not None - ), "dqkv_dtype is required as an input for FP8 fused attention backward." - assert ( - len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." + if fused_attention_backend == FusedAttnBackend["No_Backend"]: + raise ValueError( + "Fused attention backward does not support this input combination:" + f" qkv_layout={qkv_layout!r}, attn_bias_type={attn_bias_type!r}," + f" attn_mask_type={attn_mask_type!r}, q.shape={list(q.shape)}," + f" q.dtype={q.dtype}, backend={fused_attention_backend}." + ) + + if not IS_HIP_EXTENSION and fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: + if len(aux_ctx_tensors) < 1: + raise ValueError( + "aux_ctx_tensors must contain rng_state as its last element," + f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" + f" for backend={fused_attention_backend}." + ) + + if not IS_HIP_EXTENSION and fused_attention_backend == FusedAttnBackend["FP8"]: + if s_quantizer is None: + raise ValueError( + "s_quantizer is required for FP8 fused attention backward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + if dp_quantizer is None: + raise ValueError( + "dp_quantizer is required for FP8 fused attention backward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + if dqkv_dtype is None: + raise ValueError( + "dqkv_dtype is required for FP8 fused attention backward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + if len(aux_ctx_tensors) != 3: + raise ValueError( + "aux_ctx_tensors must be [M, ZInv, rng_state] for FP8 fused attention," + f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" + f" (backend={fused_attention_backend})." + ) output_tensors = tex.fused_attn_bwd( max_seqlen_q, @@ -517,6 +567,7 @@ def fused_attn_bwd( AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], window_size, + bottom_right_diagonal, deterministic, cu_seqlens_q, cu_seqlens_kv, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 35fae5ac1..939a898b9 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -7,6 +7,7 @@ """Python interface for GEMM extensions""" from typing import Iterable, Optional, Tuple, Union, List +import ctypes import os import functools import torch @@ -27,6 +28,7 @@ __all__ = [ "general_gemm", "general_grouped_gemm", + "general_grouped_gemm_for_grouped_tensor", ] @@ -78,28 +80,6 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 -def get_tensor_device(tensor: torch.Tensor) -> int: - """ - Returns tensor device as an integer. - - This method is used because checking instances of - QuantizedTensor or Storage incurs more CPU overhead. - The order of attributes checked is important to also - minimize overhead. - """ - if hasattr(tensor, "device"): - return tensor.device.index - if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None: - return tensor._rowwise_data.device.index - if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None: - return tensor._columnwise_data.device.index - if hasattr(tensor, "_data") and tensor._data is not None: - return tensor._data.device.index - if hasattr(tensor, "_transpose") and tensor._transpose is not None: - return tensor._transpose.device.index - return torch.cuda.current_device() - - def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -128,7 +108,7 @@ def general_gemm( alpha = validate_gemm_scale(alpha, True) beta = validate_gemm_scale(beta, accumulate) - workspace = get_cublas_workspace(get_tensor_device(A), ub is not None, False) + workspace = get_cublas_workspace(A.device.index, ub is not None, False) if ub_type is not None: assert ub is not None, ( @@ -246,7 +226,7 @@ def general_grouped_gemm( out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype sm_count = get_sm_count() - workspaces = get_cublas_workspace(get_tensor_device(A[0]), False, True) + workspaces = get_cublas_workspace(A[0].device.index, False, True) if grad and use_bias: grad_bias = [ @@ -317,3 +297,113 @@ def general_grouped_gemm( ) return out, bias, gelu_input + + +@functools.lru_cache(maxsize=None) +def get_grouped_gemm_setup_workspace_size(num_tensors: int) -> int: + """Return workspace size for grouped GEMM pointer setup. + Must match GroupedGemmSetupWorkspace::required_setup_size in cublaslt_grouped_gemm.cu. + """ + ptr_bytes = ctypes.sizeof(ctypes.c_void_p) + int_bytes = ctypes.sizeof(ctypes.c_int) + ptr_size = num_tensors * ptr_bytes + int_size = num_tensors * int_bytes + k_ptr_alignment = 16 + # Each pointer array is placed at a 16-byte-aligned offset (matching kPtrAlignment in C++). + # aligned_ptr_size = round_up(num_tensors * ptr_bytes, 16) + aligned_ptr_size = ((ptr_size + k_ptr_alignment - 1) // k_ptr_alignment) * k_ptr_alignment + size = 8 * aligned_ptr_size + 6 * int_size + alignment = 256 + return ((size + alignment - 1) // alignment) * alignment + + +def general_grouped_gemm_for_grouped_tensor( + A, + B, + out, + *, + layout: str = "TN", + accumulate: bool = False, + use_split_accumulator: bool = False, + bias=None, + grad: bool = False, + alpha: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Grouped GEMM using GroupedTensor inputs. + + This uses nvte_grouped_gemm and supports different per-matrix shapes. + + The caller must ensure that GroupedTensor metadata is already compatible with the + underlying GEMM implementation (e.g., aligned offsets and output metadata layout). + """ + assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." + if grad: + raise NotImplementedError("grad is not supported for grouped_tensor GEMM yet.") + transa = layout[0] == "T" + transb = layout[1] == "T" + is_discrete_out = isinstance(out, list) + is_discrete_in = isinstance(A, list) + if is_discrete_in and is_discrete_out: + raise ValueError("Both A and out are discrete. This is not supported yet.") + + if is_discrete_out: + # wgrad case. + grouped_gemm_impl = tex.te_general_grouped_gemm_for_discrete_out + elif is_discrete_in: + # Use-case: forward pass with list of weights. + grouped_gemm_impl = tex.te_general_grouped_gemm_for_discrete_in + else: + # Use-case: Single Grouped Parameter for Weight/ Weight Grads. + grouped_gemm_impl = tex.te_general_grouped_gemm_for_grouped_tensor + + if is_discrete_out and bias is not None: + raise ValueError( + "Bias is not supported when out is a list (discrete_out mode) yet. " + "Apply bias manually after the GEMM." + ) + + num_tensors = B.num_tensors + rowwise = B.rowwise_data + device = rowwise.device if rowwise is not None else B.columnwise_data.device + + if alpha is None: + alpha = torch.ones(num_tensors, dtype=torch.float32, device=device) + if beta is None: + if accumulate: + beta = torch.ones(num_tensors, dtype=torch.float32, device=device) + else: + beta = torch.zeros(num_tensors, dtype=torch.float32, device=device) + + if not alpha.is_cuda or not beta.is_cuda: + raise ValueError("alpha and beta must be CUDA tensors.") + + workspace_setup = torch.empty( + get_grouped_gemm_setup_workspace_size(num_tensors), + dtype=torch.uint8, + device=device, + ) + workspace_cublas = torch.empty( + get_cublas_workspace_size_bytes(), + dtype=torch.uint8, + device=device, + ) + + sm_count = get_sm_count() + sm_count = sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))) + + return grouped_gemm_impl( + A, + transa, + B, + transb, + out, + bias, + alpha, + beta, + workspace_setup, + workspace_cublas, + use_split_accumulator, + sm_count, + ) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 05219b7b1..d0b314a64 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -124,7 +124,11 @@ def tensor_group_process_after_reload(tensor_group: TensorGroup): """ Call for a tensor group, just after reload logic. """ - assert tensor_group.aux is not None + if tensor_group.aux is None: + raise RuntimeError( + "TensorGroup.aux must be set before post-reload processing, " + f"but got aux=None for tensor_group with {len(tensor_group.tensor_list)} tensors" + ) tensor_group = TensorGroupProcessor._restore_tensor_duplicates(tensor_group) tensor_group = TensorGroupProcessor._switch_to_views(tensor_group) return tensor_group @@ -158,9 +162,8 @@ def _check_if_offload_base_tensor(tensor: torch.Tensor) -> bool: if _check_if_offload_base_tensor(tensor): aux["views"].append((tensor.shape, tensor.stride(), tensor.storage_offset())) tensor = tensor._base - assert ( - tensor is not None - ), "Cannot offload base tensor, if the tensor is not a view." + if tensor is None: + raise RuntimeError("Cannot offload base tensor, if the tensor is not a view.") tensor_group.tensor_list[tensor_id] = tensor else: aux["views"].append(None) @@ -247,9 +250,10 @@ def __init__( self.state = "not_offloaded" def _validate_state(self, func_name: str, allowed_states: list[str]): - assert ( - self.state in allowed_states - ), f"Invalid state: {self.state} for {func_name}, must be one of {allowed_states}" + if self.state not in allowed_states: + raise RuntimeError( + f"Invalid state: {self.state} for {func_name}, must be one of {allowed_states}" + ) def start_offload(self): """ @@ -271,7 +275,12 @@ def start_offload(self): ) for tensor_id, tensor in enumerate(self.fwd_gpu_tensor_group.tensor_list): - assert tensor.is_contiguous() + if not tensor.is_contiguous(): + raise ValueError( + f"Tensor at index {tensor_id} must be contiguous for CPU offloading, " + f"but got non-contiguous tensor with shape={tensor.shape}, " + f"stride={tensor.stride()}, dtype={tensor.dtype}" + ) # Wait for the moment the tensor is ready to be offloaded. self.offload_stream.wait_event(self.fwd_gpu_tensor_group.events[tensor_id]) # type: ignore[arg-type] @@ -284,12 +293,13 @@ def start_offload(self): self.cpu_tensor_group.tensor_list.append(offloaded_tensor) else: offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id] - assert offloaded_tensor.shape == tensor.shape, ( - "CPU buffer shape does not match the offloaded tensor shape:" - f" {offloaded_tensor.shape} != {tensor.shape} " - "Make sure that tensor shapes do not change between" - " iterations if retain_pinned_cpu_buffers is True." - ) + if offloaded_tensor.shape != tensor.shape: + raise ValueError( + "CPU buffer shape does not match the offloaded tensor shape:" + f" {offloaded_tensor.shape} != {tensor.shape} " + "Make sure that tensor shapes do not change between" + " iterations if retain_pinned_cpu_buffers is True." + ) offloaded_tensor.copy_(tensor, non_blocking=True) # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated, @@ -420,7 +430,11 @@ def pop_tensor( return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id] # 4. the layer was offloaded - assert self.state == "reload_started" + if self.state != "reload_started": + raise RuntimeError( + "Expected state='reload_started' when popping an offloaded tensor, " + f"but got state='{self.state}' for tensor={tensor_or_tensor_id}" + ) # wait for the tensor to be reloaded torch.cuda.current_stream().wait_event( self.bwd_gpu_tensor_group.events[tensor_or_tensor_id] @@ -685,7 +699,7 @@ def get_cpu_offload_context( offload_stream: Optional[torch.cuda.Stream] = None, ): """ - CPU Offloading feature for seqeuences of layers. Can be used for arbitrary layers, not necessarily + CPU Offloading feature for sequences of layers. Can be used for arbitrary layers, not necessarily for these provided by the TE. Usage: @@ -710,7 +724,7 @@ def get_cpu_offload_context( Number of layers in the model that will be used under this context. offload_activations : bool, default = True Deprecated. - offload_weights : bool, default = True + offload_weights : bool, default = False Deprecated. double_buffering : bool, default = False Deprecated. @@ -769,14 +783,14 @@ def get_cpu_offload_context( out[i] = sync_function(out[i]) manual_controller.start_offload_layer(i) - offload_stream.synchronize() + # Release GPU memory - each call inserts a GPU-side wait_event on the compute stream for i in range(num_layers): manual_controller.release_activation_forward_gpu_memory(i) + # Start reloading - backward will wait for each tensor's reload via wait_event for i in range(num_layers - 1, -1, -1): manual_controller.start_reload_layer(i) - offload_stream.synchronize() for i in range(num_layers): out[i].sum().backward() @@ -824,18 +838,19 @@ def get_cpu_offload_context( raise RuntimeError("CPU offload is not supported in debug mode.") if not manual_synchronization: - assert ( - num_layers <= model_layers - 1 - ), "Cannot offload all layers without manual synchronization - last layer is not offloaded." + if num_layers > model_layers - 1: + raise ValueError( + "Cannot offload all layers without manual synchronization - last layer is not" + f" offloaded. Got num_layers={num_layers}, model_layers={model_layers}." + ) if num_layers == model_layers - 1: warnings.warn( "Offloading num_layers == model_layers - 1 is not recommended, it prevents" " overlapping of computation and offload/reload." ) - assert ( - offload_stream is None or manual_synchronization - ), "offload_stream can be provided only if manual_synchronization is True" + if offload_stream is not None and not manual_synchronization: + raise ValueError("offload_stream can be provided only if manual_synchronization is True") if manual_synchronization: offload_synchronizer = ManualOffloadSynchronizer( @@ -858,9 +873,10 @@ def __init__(self): self.inside_context = False def __enter__(self): - assert ( - self.inside_context is False - ), "Offloading context was entered without synchronization function being called." + if self.inside_context: + raise RuntimeError( + "Offloading context was entered without synchronization function being called." + ) self.inside_context = True self._hooks_ctx = saved_tensors_hooks( offload_synchronizer.push_tensor, offload_synchronizer.pop_tensor @@ -882,12 +898,23 @@ def synchronization_function(self, tensor): """ This function is used to catch the backward pass of the model. """ - assert tensor.requires_grad is True - assert self.current_layer is not None + if not tensor.requires_grad: + raise ValueError( + "Tensor passed to synchronization_function must require grad to " + "register backward hooks, but got requires_grad=False for tensor " + f"with shape={tensor.shape}, dtype={tensor.dtype}" + ) + if self.current_layer is None: + raise RuntimeError( + "synchronization_function called but no layer has been set via __enter__. " + f"inside_context={self.inside_context}, " + f"offload_synchronizer num_layers={self.offload_synchronizer.num_layers}" + ) cur_layer = self.current_layer - assert ( - self.inside_context is False - ), "Synchronization function was called without offloading context being entered." + if self.inside_context: + raise RuntimeError( + "Synchronization function was called without offloading context being entered." + ) def hook(_): # offload_synchronizer.finish_part_of_bwd needs diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 9cbba112b..e958fc9a4 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -278,7 +278,8 @@ at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType } else if (size == 1) { return at::empty({static_cast(shape.data[0])}, at::CUDA(GetATenDType(type))); } - NVTE_CHECK(false, "Should never reach here! func: allocateSpace"); + NVTE_ERROR("Unsupported tensor allocation: ndim=", size, ", init_to_zeros=", init_to_zeros, + ". Only 1D and 2D tensors are supported."); } at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype) { diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 02b135f4f..c46059377 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -111,6 +111,12 @@ class Quantizer { virtual std::pair create_tensor(const std::vector& shape, DType dtype) const = 0; + /*! @brief Construct a grouped tensor with uninitialized data */ + virtual std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const = 0; + /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor * * The PyTorch tensor's attributes are modified to match the @@ -146,6 +152,11 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; @@ -172,6 +183,11 @@ class Float8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, std::optional data, @@ -204,6 +220,11 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. * * The amax is zeroed out. Most TE kernels that output amax expect @@ -261,6 +282,11 @@ class Float8BlockQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, @@ -282,6 +308,11 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, @@ -317,6 +348,11 @@ class NVFP4Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer * * The amax is zeroed out. Most TE kernels that output amax expect diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index df8b548cc..fd0013f3e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -29,23 +29,22 @@ namespace transformer_engine::pytorch { **************************************************************************************************/ std::tuple fused_topk_with_score_function_fwd( - at::Tensor logits, int topk, bool use_pre_softmax, c10::optional num_groups, - c10::optional group_topk, c10::optional scaling_factor, std::string score_function, - c10::optional expert_bias); + at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, + std::optional group_topk, std::optional scaling_factor, std::string score_function, + std::optional expert_bias); -at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts, - at::Tensor routing_map, - at::Tensor intermediate_output, at::Tensor grad_probs, - int topk, bool use_pre_softmax, - c10::optional scaling_factor, - std::string score_function); +void fused_topk_with_score_function_bwd(int num_tokens, int num_experts, at::Tensor routing_map, + at::Tensor intermediate_output, at::Tensor grad_probs, + at::Tensor grad_logits, int topk, bool use_pre_softmax, + std::optional scaling_factor, + std::string score_function); std::tuple fused_score_for_moe_aux_loss_fwd( at::Tensor logits, int topk, std::string score_function); -at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, - at::Tensor intermediate_output, at::Tensor grad_probs, - int topk, std::string score_function); +void fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, + at::Tensor intermediate_output, at::Tensor grad_probs, + at::Tensor grad_logits, int topk, std::string score_function); std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, at::Tensor tokens_per_expert, @@ -83,15 +82,16 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph); + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, + const std::vector window_size, bool bottom_right_diagonal, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, @@ -101,10 +101,10 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, + bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -151,6 +151,25 @@ std::optional> te_general_grouped_gemm( std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +py::object te_general_grouped_gemm_for_grouped_tensor( + py::handle A, bool transa, py::handle B, bool transb, py::handle D, py::object bias, + at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup, at::Tensor workspace_cublas, + bool use_split_accumulator, int math_sm_count); + +py::object te_general_grouped_gemm_for_discrete_in(py::handle A, bool transa, py::handle B, + bool transb, py::handle D, py::object bias, + at::Tensor alpha, at::Tensor beta, + at::Tensor workspace_setup, + at::Tensor workspace_cublas, + bool use_split_accumulator, int math_sm_count); + +py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, py::handle B, + bool transb, py::handle D, py::object bias, + at::Tensor alpha, at::Tensor beta, + at::Tensor workspace_setup, + at::Tensor workspace_cublas, + bool use_split_accumulator, int math_sm_count); + /*************************************************************************************************** * Transpose **************************************************************************************************/ @@ -158,12 +177,50 @@ std::optional> te_general_grouped_gemm( at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output = std::nullopt); +at::Tensor nvfp4_data_transpose(at::Tensor input, std::optional output = std::nullopt); + +void nvfp4_2d_scale_transpose(at::Tensor input, at::Tensor output, int64_t M_tiles, + int64_t K_tiles); + +void nvfp4_2d_multi_tensor_transpose(std::vector rowwise_data_list, + std::vector columnwise_data_list, + std::vector rowwise_scale_inv_list, + std::vector columnwise_scale_inv_list, + std::vector M_list, std::vector K_list); + +void nvfp4_multi_tensor_compute_partial_amax( + std::vector master_weight_list, std::vector partial_amax_list, + std::vector global_amax_list, std::vector h_list, + std::vector w_list, std::vector start_offset_list, int64_t block_len); + +void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, int64_t tile_rows, + int64_t tile_cols, int64_t rows_padded, int64_t block_len); + +void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, at::Tensor global_amax); + +void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, at::Tensor per_block_scale, + at::Tensor target_scale, at::Tensor target_amax, int64_t tile_rows, + int64_t tile_cols, int64_t rows_padded, int64_t block_len); + +void nvfp4_multi_tensor_fused_scale( + std::vector block_amax_list, std::vector global_amax_list, + std::vector per_block_scale_list, std::vector target_scale_list, + std::vector target_amax_list, std::vector tile_rows_list, + std::vector tile_cols_list, std::vector rows_padded_list, int64_t block_len); + +void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale); + at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = std::nullopt); /*************************************************************************************************** * Activations **************************************************************************************************/ +/* GLU (sigmoid gate) */ +py::object glu(const at::Tensor &input, py::handle quantizer); + +py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + /* GELU and variants*/ py::object gelu(const at::Tensor &input, py::handle quantizer); @@ -251,6 +308,9 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob py::object dequantize(const py::handle &input, DType otype); +py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, + std::optional first_dims); + std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); @@ -338,6 +398,19 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const size_t h, size_t w, size_t start_offset, size_t block_len, const DType out_dtype); +void nvfp4_2d_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, size_t w, + size_t start_offset, size_t block_len); + +void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tensor &scale, + const at::Tensor &global_scale, size_t h, size_t w, size_t start_offset, + size_t block_len); + +void nvfp4_multi_tensor_2d_partial_cast(std::vector inp_list, + std::vector out_list, + std::vector scale_list, + std::vector global_scale_list, + std::vector h_list, std::vector w_list, + std::vector start_offset_list, int64_t block_len); void mxfp8_scaling_compute_partial_amax(const at::Tensor &input, at::Tensor amax_rowwise, at::Tensor amax_colwise, int rows, int cols, size_t start_offset); @@ -382,9 +455,11 @@ at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tenso #ifndef USE_ROCM size_t get_cublasLt_version(); size_t get_cudnn_version(); +#else +void placeholder(); #endif -void placeholder(); +at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim); /*************************************************************************************************** * Support THD format for Context Parallel @@ -417,6 +492,10 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, float scale); +void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor is_infinite, + std::vector> tensor_lists, + at::Tensor scale); + std::tuple multi_tensor_l2norm_cuda( int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::optional per_tensor_python); @@ -570,6 +649,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve ~CommOverlap() {} + using transformer_engine::CommOverlapCore::copy_into_buffer; void copy_into_buffer(const at::Tensor &input, bool local_chunk = false); at::Tensor get_buffer(bool local_chunk = false, @@ -591,6 +671,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm ~CommOverlapP2P() {} + using transformer_engine::CommOverlapP2PBase::copy_into_buffer; void copy_into_buffer(const at::Tensor &input, bool local_chunk = false); at::Tensor get_buffer(bool local_chunk = false, diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 7bf9b35a0..df6906c18 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -260,6 +260,14 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua return dactivation_helper(grad, input, quantizer); } +py::object glu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); +} + +py::object dglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + py::object geglu(const at::Tensor& input, py::handle quantizer) { return activation_helper(input, quantizer, 2); } diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index b455e0375..bf62db8c3 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -45,12 +45,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph) { + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, deterministic); return fused_attention_backend; } @@ -100,9 +100,10 @@ std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, + const std::vector window_size, bool bottom_right_diagonal, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, @@ -235,7 +236,7 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -295,7 +296,7 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -310,10 +311,10 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, + bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -532,14 +533,14 @@ std::vector fused_attn_bwd( // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), - te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), - te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd( + te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -549,14 +550,14 @@ std::vector fused_attn_bwd( // execute kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), - te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), - te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd( + te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 5a6a98442..985af18d3 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -82,6 +83,163 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob return output_py; } +#ifndef USE_ROCM +namespace { + +// helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy) +void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor, + GroupedTensorWrapper &grouped_output_tensor, + NVFP4Quantizer *nvfp4_quantizer_cpp, cudaStream_t stream) { + size_t num_tensors = grouped_input_tensor.num_tensors(); + + // assert the 2D scaling case, since 2D scaling grouped quant kernel is not ready yet + NVTE_CHECK(!nvfp4_quantizer_cpp->with_2d_quantization, + "2D scaling grouped quant kernel is not ready yet"); + + auto quant_config_cpp = QuantizationConfigWrapper(); + + // stochastic rounding + bool need_stochastic_rounding = nvfp4_quantizer_cpp->stochastic_rounding; + auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + at::Tensor rng_states_tensor; // Declare tensor outside, do not allocate yet + TensorWrapper te_rng_state; + + if (need_stochastic_rounding) { + // in fused kernel, one rng state will be used by the grouped kernel to generate random + // number for different tensors in the group, so we only need to allocate one rng state + const size_t rng_elts_per_thread = 1024 * num_tensors; + rng_states_tensor = torch::empty({2}, opts); + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); + philox_unpack(philox_args, static_cast(rng_states_tensor.data_ptr())); + + te_rng_state = makeTransformerEngineTensor(rng_states_tensor); + quant_config_cpp.set_rng_state(te_rng_state.data()); + quant_config_cpp.set_stochastic_rounding(true); + } + + // fast math + const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); + if (use_fast_math) { + quant_config_cpp.set_use_fast_math(true); + } + + // so far, only the RHT path has grouped kernel support + // grouped kernels for non-RHT path will be added later + + if (nvfp4_quantizer_cpp->with_rht) { + // post-RHT amax or not + if (nvfp4_quantizer_cpp->with_post_rht_amax) { + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_hadamard_transform_amax_graph_safe( + grouped_input_tensor.data(), grouped_output_tensor.data(), 0, + nvfp4_quantizer_cpp->rht_matrix_random_sign_mask_t, stream); + }); + } else { + NVTE_ERROR("graph safe grouped quant kernel for non-RHT path is not ready yet"); + } + + // RHT cast fusion + auto tile_scheduler_workspace_torch = + at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32)); + auto nvte_tile_scheduler_workspace = + makeTransformerEngineTensor(tile_scheduler_workspace_torch); + + auto rht_matrix_nvte = makeTransformerEngineTensor(nvfp4_quantizer_cpp->rht_matrix); + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_hadamard_transform_cast_fusion_graph_safe( + grouped_input_tensor.data(), grouped_output_tensor.data(), rht_matrix_nvte.data(), + quant_config_cpp, nvte_tile_scheduler_workspace.data(), stream); + }); + + } else { + NVTE_ERROR("graph safe grouped quant kernel for non-RHT path is not ready yet"); + } +} + +} // namespace +#endif // USE_ROCM + +// NOTE: Only supports varying first dim. +py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, + std::optional first_dims) { + using namespace transformer_engine::pytorch::detail; + init_extension(); + + NVTE_CHECK(tensor.dim() == 2, "Tensor must be 2D"); + + std::vector logical_shape; + for (const auto &d : tensor.sizes()) { + logical_shape.push_back(d); + } + const auto logical_first_dim = logical_shape[0]; + const auto logical_last_dim = logical_shape[1]; + + bool empty_input_buffer = logical_first_dim == 0 || logical_last_dim == 0; + + auto quantizer_cpp = convert_quantizer(quantizer); + + // Create input GroupedTensor. + auto grouped_input_tensor = GroupedTensorWrapper(num_tensors, logical_shape); + grouped_input_tensor.set_rowwise_data( + tensor.data_ptr(), GetTransformerEngineDType(tensor.scalar_type()), getTensorShape(tensor)); + + // Create output GroupedTensor. + auto [grouped_output_tensor_cpp, grouped_output_py] = quantizer_cpp->create_grouped_tensor( + num_tensors, logical_shape, GetTransformerEngineDType(tensor.scalar_type()), + py::reinterpret_borrow(quantizer), first_dims, logical_first_dim, + logical_last_dim); + + // dispatch to scaling methods + enum class GroupedQuantizationMode { + MXFP8_GROUPED_QUANTIZE, + NVFP4_GROUPED_QUANTIZE, + INVALID_FOR_GROUPED_QUANTIZE + }; + GroupedQuantizationMode grouped_quantization_mode = + GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE; + if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + grouped_quantization_mode = GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE; + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + grouped_quantization_mode = GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE; + } + + if (empty_input_buffer) { + // early return for empty input buffer + // just return the output tensor as is + // no need to quantize + return py::reinterpret_borrow(grouped_output_py); + } + + switch (grouped_quantization_mode) { + case GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE: { +#ifdef USE_ROCM + NVTE_ERROR("NVFP4 grouped quantization is not supported on ROCm platform."); +#else + // NVFP4 grouped quantization + NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + group_quantize_nvfp4_impl(grouped_input_tensor, grouped_output_tensor_cpp, + nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream()); +#endif // USE_ROCM + break; + } + case GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE: { + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_quantize(grouped_input_tensor.data(), grouped_output_tensor_cpp.data(), + at::cuda::getCurrentCUDAStream()); + }); + break; + } + case GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE: + default: + NVTE_ERROR("group_quantize: only support NVFP4 or MXFP8 quantizer."); + break; + } + + return py::reinterpret_borrow(grouped_output_py); +} + py::object dequantize(const py::handle &input, transformer_engine::DType otype) { init_extension(); @@ -1211,9 +1369,19 @@ std::vector split_quantize(const at::Tensor &tensor, for (auto &quantizer : quantizer_cpp_list) { nvfp4_quantizers.push_back(static_cast(quantizer.get())); } - bool contiguous_data_and_scale; + bool contiguous_data_and_scale = false; std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); + if (!input_shape.empty() && input_shape.back() % 128 != 0) { + static std::once_flag once_unfused_nvfp4_fallback_warning; + std::call_once(once_unfused_nvfp4_fallback_warning, []() { + NVTE_WARN( + "Unfused NVFP4 quantization fallback is triggered because the input tensor inner " + "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. " + "NVFP4 might bring performance regressions for this input tensor shape."); + }); + quantization_method = QuantizationMethod::UNFUSED; + } if (!contiguous_data_and_scale) { // Avoid fused quantize kernel if data is not contiguous quantization_method = QuantizationMethod::UNFUSED; diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 941b88e36..44c642202 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -80,6 +80,46 @@ bool checkGemmShape(const std::vector& expected, const NVTEShape& actual return true; } +struct GroupedGemmConfig { + TensorWrapper te_alpha; + TensorWrapper te_beta; + TensorWrapper te_workspace_setup; + TensorWrapper te_workspace_cublas; + std::optional matmul_config; +}; + +GroupedGemmConfig prepare_grouped_gemm_config(at::Tensor alpha, at::Tensor beta, + at::Tensor workspace_setup, + at::Tensor workspace_cublas, size_t num_tensors, + int math_sm_count, bool use_split_accumulator) { + NVTE_CHECK(alpha.numel() == static_cast(num_tensors), + "Grouped GEMM expects alpha to have num_tensors elements."); + NVTE_CHECK(beta.numel() == static_cast(num_tensors), + "Grouped GEMM expects beta to have num_tensors elements."); + + GroupedGemmConfig grouped_gemm_config{ + makeTransformerEngineTensor(alpha), + makeTransformerEngineTensor(beta), + makeTransformerEngineTensor(workspace_setup.data_ptr(), + std::vector{static_cast(workspace_setup.numel())}, + DType::kByte), + makeTransformerEngineTensor( + workspace_cublas.data_ptr(), + std::vector{static_cast(workspace_cublas.numel())}, DType::kByte), + std::nullopt, + }; + + if (math_sm_count > 0 || use_split_accumulator) { + grouped_gemm_config.matmul_config.emplace(); + if (math_sm_count > 0) { + grouped_gemm_config.matmul_config->set_sm_count(math_sm_count); + } + grouped_gemm_config.matmul_config->set_use_split_accumulator(use_split_accumulator); + } + + return grouped_gemm_config; +} + } // namespace detail std::pair createOutputTensor(const std::vector& shape, @@ -598,4 +638,180 @@ std::optional> te_general_grouped_gemm( return bias; } +py::object te_general_grouped_gemm_for_grouped_tensor( + py::handle A, bool transa, py::handle B, bool transb, py::handle D, py::object bias, + at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup, at::Tensor workspace_cublas, + bool use_split_accumulator, int math_sm_count) { + using namespace transformer_engine::pytorch::detail; + + init_extension(); + + // Ensure that cublasLt handle is created on the correct device, + // overriding torch.cuda.set_device calls from user side. + // Assumes all tensors passed are on the same device. + at::cuda::CUDAGuard device_guard(workspace_cublas.device()); + + auto grouped_A = GroupedTensorFromPyTorchGroupedTensor(A); + auto grouped_B = GroupedTensorFromPyTorchGroupedTensor(B); + auto grouped_D = GroupedTensorFromPyTorchGroupedTensor(D); + + const size_t num_tensors = grouped_A.num_tensors(); + NVTE_CHECK(num_tensors > 0, "Grouped GEMM requires non-empty inputs."); + NVTE_CHECK(grouped_B.num_tensors() == num_tensors, + "Grouped GEMM requires A and B to have the same num_tensors."); + NVTE_CHECK(grouped_D.num_tensors() == num_tensors, + "Grouped GEMM requires D to have the same num_tensors as inputs."); + + auto gemm_config = prepare_grouped_gemm_config(alpha, beta, workspace_setup, workspace_cublas, + num_tensors, math_sm_count, use_split_accumulator); + + [[maybe_unused]] auto swizzled_scales_A = maybe_swizzle_grouped_tensor_for_gemm(grouped_A); + [[maybe_unused]] auto swizzled_scales_B = maybe_swizzle_grouped_tensor_for_gemm(grouped_B); + + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm(grouped_A.data(), transa, grouped_B.data(), transb, grouped_D.data(), + grouped_D.data(), gemm_config.te_alpha.data(), gemm_config.te_beta.data(), + gemm_config.te_workspace_setup.data(), gemm_config.te_workspace_cublas.data(), + gemm_config.matmul_config.has_value() + ? static_cast(*gemm_config.matmul_config) + : nullptr, + at::cuda::getCurrentCUDAStream()); + }); + + if (!bias.is_none()) { + auto grouped_bias = GroupedTensorFromPyTorchGroupedTensor(bias); + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_bias_add(grouped_D.data(), grouped_bias.data(), + at::cuda::getCurrentCUDAStream()); + }); + } + + return py::reinterpret_borrow(D); +} + +py::object te_general_grouped_gemm_for_discrete_in(py::handle A, bool transa, py::handle B, + bool transb, py::handle D, py::object bias, + at::Tensor alpha, at::Tensor beta, + at::Tensor workspace_setup, + at::Tensor workspace_cublas, + bool use_split_accumulator, int math_sm_count) { + using namespace transformer_engine::pytorch::detail; + + init_extension(); + + // Ensure that cublasLt handle is created on the correct device, + // overriding torch.cuda.set_device calls from user side. + // Assumes all tensors passed are on the same device. + at::cuda::CUDAGuard device_guard(workspace_cublas.device()); + + auto grouped_B = GroupedTensorFromPyTorchGroupedTensor(B); + auto grouped_D = GroupedTensorFromPyTorchGroupedTensor(D); + + const auto A_list = py::cast>(A); + const size_t num_tensors = grouped_B.num_tensors(); + NVTE_CHECK(num_tensors > 0, "Grouped GEMM requires non-empty inputs."); + NVTE_CHECK(A_list.size() == num_tensors, + "Grouped GEMM requires A_list to have num_tensors elements."); + NVTE_CHECK(grouped_D.num_tensors() == num_tensors, + "Grouped GEMM requires D to have the same num_tensors as inputs."); + + auto gemm_config = prepare_grouped_gemm_config(alpha, beta, workspace_setup, workspace_cublas, + num_tensors, math_sm_count, use_split_accumulator); + + std::vector te_A_wrappers; + std::vector te_A_vector; + te_A_wrappers.reserve(num_tensors); + te_A_vector.reserve(num_tensors); + const auto none = py::none(); + for (const auto& tensor : A_list) { + te_A_wrappers.emplace_back(makeTransformerEngineTensor(tensor, none)); + te_A_vector.emplace_back(te_A_wrappers.back().data()); + } + + std::vector> swizzled_scale_inverses_list; + swizzled_scale_inverses_list.emplace_back( + multi_tensor_swizzle_scales_for_gemm(te_A_wrappers, transa, !transa)); + + [[maybe_unused]] auto swizzled_scales_B = maybe_swizzle_grouped_tensor_for_gemm(grouped_B); + + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm_with_discrete_inputA( + te_A_vector.data(), num_tensors, transa, grouped_B.data(), transb, grouped_D.data(), + grouped_D.data(), gemm_config.te_alpha.data(), gemm_config.te_beta.data(), + gemm_config.te_workspace_setup.data(), gemm_config.te_workspace_cublas.data(), + gemm_config.matmul_config.has_value() + ? static_cast(*gemm_config.matmul_config) + : nullptr, + at::cuda::getCurrentCUDAStream()); + }); + + if (!bias.is_none()) { + auto grouped_bias = GroupedTensorFromPyTorchGroupedTensor(bias); + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_bias_add(grouped_D.data(), grouped_bias.data(), + at::cuda::getCurrentCUDAStream()); + }); + } + + return py::reinterpret_borrow(D); +} + +py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, py::handle B, + bool transb, py::handle D, py::object bias, + at::Tensor alpha, at::Tensor beta, + at::Tensor workspace_setup, + at::Tensor workspace_cublas, + bool use_split_accumulator, int math_sm_count) { + using namespace transformer_engine::pytorch::detail; + + init_extension(); + + // Ensure that cublasLt handle is created on the correct device, + // overriding torch.cuda.set_device calls from user side. + // Assumes all tensors passed are on the same device. + at::cuda::CUDAGuard device_guard(workspace_cublas.device()); + + NVTE_CHECK(bias.is_none(), "Bias is not supported for discrete output grouped GEMM."); + + auto grouped_A = GroupedTensorFromPyTorchGroupedTensor(A); + auto grouped_B = GroupedTensorFromPyTorchGroupedTensor(B); + + const auto D_list = py::cast>(D); + const size_t num_tensors = grouped_A.num_tensors(); + NVTE_CHECK(num_tensors > 0, "Grouped GEMM requires non-empty inputs."); + NVTE_CHECK(grouped_B.num_tensors() == num_tensors, + "Grouped GEMM requires A and B to have the same num_tensors."); + NVTE_CHECK(D_list.size() == num_tensors, + "Grouped GEMM requires D_list to have num_tensors elements."); + + auto gemm_config = prepare_grouped_gemm_config(alpha, beta, workspace_setup, workspace_cublas, + num_tensors, math_sm_count, use_split_accumulator); + + std::vector te_D_wrappers; + std::vector te_D_vector; + te_D_wrappers.reserve(num_tensors); + te_D_vector.reserve(num_tensors); + const auto none = py::none(); + for (const auto& tensor : D_list) { + te_D_wrappers.emplace_back(makeTransformerEngineTensor(tensor, none)); + te_D_vector.emplace_back(te_D_wrappers.back().data()); + } + + [[maybe_unused]] auto swizzled_scales_A = maybe_swizzle_grouped_tensor_for_gemm(grouped_A); + [[maybe_unused]] auto swizzled_scales_B = maybe_swizzle_grouped_tensor_for_gemm(grouped_B); + + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm_with_discrete_out( + grouped_A.data(), transa, grouped_B.data(), transb, te_D_vector.data(), num_tensors, + te_D_vector.data(), num_tensors, gemm_config.te_alpha.data(), gemm_config.te_beta.data(), + gemm_config.te_workspace_setup.data(), gemm_config.te_workspace_cublas.data(), + gemm_config.matmul_config.has_value() + ? static_cast(*gemm_config.matmul_config) + : nullptr, + at::cuda::getCurrentCUDAStream()); + }); + + return py::reinterpret_borrow(D); +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index 4f1c22fc6..2949f79cb 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -9,10 +9,31 @@ #include "../extensions.h" namespace transformer_engine::pytorch { + #ifndef USE_ROCM size_t get_cublasLt_version() { return cublasLtGetVersion(); } size_t get_cudnn_version() { return cudnnGetVersion(); } -#endif +#else void placeholder() {} +#endif + +at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim) { + NVTE_CHECK(first_dims.is_cuda(), "first_dims must be on CUDA."); + NVTE_CHECK(first_dims.scalar_type() == at::kLong, "first_dims must have dtype int64."); + NVTE_CHECK(first_dims.dim() == 1, "first_dims must be a 1D tensor."); + NVTE_CHECK(logical_last_dim > 0, "logical_last_dim must be greater than 0."); + + auto first_dims_contiguous = first_dims.contiguous(); + const auto num_tensors = static_cast(first_dims_contiguous.numel()); + auto output = at::empty({static_cast(num_tensors) + 1}, + first_dims_contiguous.options().dtype(at::kLong)); + + nvte_splits_to_offsets(static_cast(first_dims_contiguous.data_ptr()), + static_cast(output.data_ptr()), num_tensors, logical_last_dim, + at::cuda::getCurrentCUDAStream()); + + return output; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp index 4bb83bfee..687eb34f3 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp @@ -8,14 +8,26 @@ namespace transformer_engine::pytorch { -void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, +void multi_tensor_scale_cuda(int chunk_size, at::Tensor is_infinite, std::vector> tensor_lists, float scale) { - auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); + auto is_infinite_cu = makeTransformerEngineTensor(is_infinite); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, - num_tensors, scale, at::cuda::getCurrentCUDAStream()); + nvte_multi_tensor_scale_cuda(chunk_size, is_infinite_cu.data(), tensor_lists_ptr.data(), + num_lists, num_tensors, scale, at::cuda::getCurrentCUDAStream()); +} + +void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor is_infinite, + std::vector> tensor_lists, + at::Tensor scale) { + auto is_infinite_cu = makeTransformerEngineTensor(is_infinite); + auto scale_cu = makeTransformerEngineTensor(scale); + auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = + makeTransformerEngineTensorList(tensor_lists); + nvte_multi_tensor_scale_tensor_cuda(chunk_size, is_infinite_cu.data(), tensor_lists_ptr.data(), + num_lists, num_tensors, scale_cu.data(), + at::cuda::getCurrentCUDAStream()); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp new file mode 100644 index 000000000..685250d13 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp @@ -0,0 +1,156 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../extensions.h" + +namespace transformer_engine::pytorch { + +void nvfp4_2d_compute_partial_amax(const at::Tensor& tensor, at::Tensor amax, size_t h, size_t w, + size_t start_offset, size_t block_len) { + TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); + TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor"); + TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor"); + TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float || + tensor.scalar_type() == at::ScalarType::BFloat16, + "tensor must be a float or bfloat16 tensor"); + + const TensorWrapper tensor_cu = makeTransformerEngineTensor(tensor.contiguous()); + TensorWrapper amax_cu = makeTransformerEngineTensor(amax); + + nvte_nvfp4_2d_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w, amax.stride(0), + amax.stride(1), start_offset, block_len, + at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_2d_partial_cast(const at::Tensor& inp, py::handle out, const at::Tensor& scale, + const at::Tensor& global_scale, size_t h, size_t w, size_t start_offset, + size_t block_len) { + TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); + TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); + TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor"); + TORCH_CHECK(global_scale.numel() == 1, "global_scale must be a scalar tensor"); + TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float, + "global_scale must be a float tensor"); + TORCH_CHECK( + inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16, + "input must be a float or bfloat16 tensor"); + + const TensorWrapper inp_cu = makeTransformerEngineTensor(inp.contiguous()); + const TensorWrapper out_cu = makeTransformerEngineTensor(out, py::none()); + const TensorWrapper scale_cu = makeTransformerEngineTensor(scale); + const TensorWrapper global_scale_cu = makeTransformerEngineTensor(global_scale); + + nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), global_scale_cu.data(), + h, w, scale.stride(0), scale.stride(1), start_offset, block_len, + at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_multi_tensor_2d_partial_cast(std::vector inp_list, + std::vector out_list, + std::vector scale_list, + std::vector global_scale_list, + std::vector h_list, std::vector w_list, + std::vector start_offset_list, int64_t block_len) { + TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); + + const size_t num_tensors = inp_list.size(); + TORCH_CHECK(out_list.size() == num_tensors, "out_list size mismatch"); + TORCH_CHECK(scale_list.size() == num_tensors, "scale_list size mismatch"); + TORCH_CHECK(global_scale_list.size() == num_tensors, "global_scale_list size mismatch"); + TORCH_CHECK(h_list.size() == num_tensors, "h_list size mismatch"); + TORCH_CHECK(w_list.size() == num_tensors, "w_list size mismatch"); + TORCH_CHECK(start_offset_list.size() == num_tensors, "start_offset_list size mismatch"); + + if (num_tensors == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + for (size_t i = 0; i < num_tensors; ++i) { + const auto& inp = inp_list[i]; + const auto& out = out_list[i]; + const auto& scale = scale_list[i]; + const auto& global_scale = global_scale_list[i]; + const size_t h = static_cast(h_list[i]); + const size_t w = static_cast(w_list[i]); + const size_t start_offset = static_cast(start_offset_list[i]); + + TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); + TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor"); + TORCH_CHECK(global_scale.numel() == 1, "global_scale must be a scalar tensor"); + TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float, + "global_scale must be a float tensor"); + TORCH_CHECK( + inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16, + "input must be a float or bfloat16 tensor"); + + const TensorWrapper inp_cu = makeTransformerEngineTensor(inp.contiguous()); + const TensorWrapper out_cu = makeTransformerEngineTensor(out); + const TensorWrapper scale_cu = makeTransformerEngineTensor(scale); + const TensorWrapper global_scale_cu = makeTransformerEngineTensor(global_scale); + + nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), + global_scale_cu.data(), h, w, scale.stride(0), scale.stride(1), + start_offset, static_cast(block_len), stream); + } +} + +void nvfp4_multi_tensor_compute_partial_amax( + std::vector master_weight_list, std::vector partial_amax_list, + std::vector global_amax_list, std::vector h_list, + std::vector w_list, std::vector start_offset_list, int64_t block_len) { + TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); + + const size_t num_tensors = master_weight_list.size(); + TORCH_CHECK(partial_amax_list.size() == num_tensors, "partial_amax_list size mismatch"); + TORCH_CHECK(global_amax_list.size() == num_tensors, "global_amax_list size mismatch"); + TORCH_CHECK(h_list.size() == num_tensors, "h_list size mismatch"); + TORCH_CHECK(w_list.size() == num_tensors, "w_list size mismatch"); + TORCH_CHECK(start_offset_list.size() == num_tensors, "start_offset_list size mismatch"); + + if (num_tensors == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + for (size_t i = 0; i < num_tensors; ++i) { + const auto& master_weight = master_weight_list[i]; + auto& partial_amax = partial_amax_list[i]; + auto& global_amax = global_amax_list[i]; + const size_t h = static_cast(h_list[i]); + const size_t w = static_cast(w_list[i]); + const size_t start_offset = static_cast(start_offset_list[i]); + + TORCH_CHECK(partial_amax.dim() == 2, "partial_amax must be a 2D tensor"); + TORCH_CHECK(partial_amax.scalar_type() == at::ScalarType::Float, + "partial_amax must be a float tensor"); + TORCH_CHECK(master_weight.scalar_type() == at::ScalarType::Float || + master_weight.scalar_type() == at::ScalarType::BFloat16, + "master_weight must be a float or bfloat16 tensor"); + TORCH_CHECK(global_amax.scalar_type() == at::ScalarType::Float, + "global_amax must be a float tensor"); + TORCH_CHECK(global_amax.numel() == 1, "global_amax must have exactly one element"); + + // Compute partial amax (per-block amax) + const TensorWrapper tensor_cu = makeTransformerEngineTensor(master_weight.contiguous()); + TensorWrapper amax_cu = makeTransformerEngineTensor(partial_amax); + + nvte_nvfp4_2d_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w, + partial_amax.stride(0), partial_amax.stride(1), start_offset, + static_cast(block_len), stream); + + // Compute global amax + auto* global_amax_ptr = global_amax.data_ptr(); + TensorWrapper fake_te_output( + /*dptr=*/nullptr, tensor_cu.shape(), DType::kFloat32, global_amax_ptr); + + nvte_compute_amax(tensor_cu.data(), fake_te_output.data(), stream); + } +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c462e9236..27c2c562e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -37,9 +37,11 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; PyTypeObject *NVFP4TensorPythonClass = nullptr; PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; PyTypeObject *NVFP4QuantizerClass = nullptr; +PyTypeObject *GroupedTensorPythonClass = nullptr; +PyTypeObject *GroupedTensorStoragePythonClass = nullptr; +std::once_flag extension_init_flag; void init_float8_extension() { - if (Float8TensorPythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); Float8QuantizerClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer")); @@ -56,7 +58,6 @@ void init_float8_extension() { } void init_mxfp8_extension() { - if (MXFP8TensorPythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor"); MXFP8QuantizerClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer")); @@ -71,7 +72,6 @@ void init_mxfp8_extension() { } void init_float8blockwise_extension() { - if (Float8BlockwiseQTensorStoragePythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); auto fp8_base_module = py::module_::import( @@ -92,7 +92,6 @@ void init_float8blockwise_extension() { } void init_nvfp4_extensions() { - if (NVFP4TensorPythonClass) return; auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor"); NVFP4QuantizerClass = reinterpret_cast( PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer")); @@ -106,11 +105,30 @@ void init_nvfp4_extensions() { "Internal error: could not initialize pyTorch NVFP4 extension."); } +void init_grouped_tensor_extension() { + if (GroupedTensorPythonClass && GroupedTensorStoragePythonClass) return; + auto grouped_tensor_module = + py::module_::import("transformer_engine.pytorch.tensor.grouped_tensor"); + GroupedTensorPythonClass = reinterpret_cast( + PyObject_GetAttrString(grouped_tensor_module.ptr(), "GroupedTensor")); + auto grouped_tensor_storage_module = + py::module_::import("transformer_engine.pytorch.tensor.storage.grouped_tensor_storage"); + GroupedTensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(grouped_tensor_storage_module.ptr(), "GroupedTensorStorage")); + NVTE_CHECK(GroupedTensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch grouped tensor extension."); + NVTE_CHECK(GroupedTensorStoragePythonClass != nullptr, + "Internal error: could not initialize pyTorch grouped tensor extension."); +} + void init_extension() { - init_float8_extension(); - init_mxfp8_extension(); - init_float8blockwise_extension(); - init_nvfp4_extensions(); + std::call_once(extension_init_flag, []() { + init_float8_extension(); + init_mxfp8_extension(); + init_float8blockwise_extension(); + init_nvfp4_extensions(); + init_grouped_tensor_extension(); + }); } } // namespace transformer_engine::pytorch @@ -123,7 +141,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output") = py::none(), py::arg("noop") = py::none()); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); - + m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"), + py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", @@ -134,6 +153,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false, py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt); + /* GLU (sigmoid gate) */ + m.def("glu", transformer_engine::pytorch::glu, "GLU activation", py::arg("input"), + py::arg("quantizer")); /* GELU and variants*/ m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), py::arg("quantizer")); @@ -160,6 +182,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu, "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); + /* Backward of GLU */ + m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); @@ -253,9 +278,56 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); + m.def("te_general_grouped_gemm_for_grouped_tensor", + &transformer_engine::pytorch::te_general_grouped_gemm_for_grouped_tensor, + "Grouped GEMM for GroupedTensor"); + m.def("te_general_grouped_gemm_for_discrete_in", + &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_in, + "Grouped GEMM for discrete A input list"); + m.def("te_general_grouped_gemm_for_discrete_out", + &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out, + "Grouped GEMM for discrete output list"); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); + m.def("nvfp4_data_transpose", &transformer_engine::pytorch::nvfp4_data_transpose, + "Transpose NVFP4 packed data with nibble repacking", py::arg("input"), py::kw_only(), + py::arg("out"), py::call_guard()); + m.def( + "nvfp4_2d_scale_transpose", &transformer_engine::pytorch::nvfp4_2d_scale_transpose, + "Transpose NVFP4 tile-level scales (E4M3 stored as uint8) from rowwise to columnwise format", + py::arg("input"), py::arg("output"), py::arg("M_tiles"), py::arg("K_tiles"), + py::call_guard()); + m.def("nvfp4_expand_scale_to_fp8", &transformer_engine::pytorch::nvfp4_expand_scale_to_fp8, + "Expand tile-level scales to row-level scales and convert to FP8 E4M3", py::arg("input"), + py::arg("output"), py::arg("tile_rows"), py::arg("tile_cols"), py::arg("rows_padded"), + py::arg("block_len"), py::call_guard()); + m.def("nvfp4_compute_per_block_scale", + &transformer_engine::pytorch::nvfp4_compute_per_block_scale, + "Compute per-block decode scale from block amax and global amax", py::arg("block_amax"), + py::arg("scale"), py::arg("global_amax"), py::call_guard()); + m.def("nvfp4_compute_global_scale", &transformer_engine::pytorch::nvfp4_compute_global_scale, + "Compute global encode scale from global amax", py::arg("global_amax"), + py::arg("global_scale"), py::call_guard()); + m.def("nvfp4_fused_scale", &transformer_engine::pytorch::nvfp4_fused_scale, + "Fused kernel: compute per-block decode scale, copy global amax, expand to row-level FP8", + py::arg("block_amax"), py::arg("global_amax"), py::arg("per_block_scale"), + py::arg("target_scale"), py::arg("target_amax"), py::arg("tile_rows"), py::arg("tile_cols"), + py::arg("rows_padded"), py::arg("block_len"), py::call_guard()); + m.def("nvfp4_multi_tensor_fused_scale", + &transformer_engine::pytorch::nvfp4_multi_tensor_fused_scale, + "Batched fused scale: compute per-block decode scale, copy global amax, expand to FP8 for " + "multiple tensors", + py::arg("block_amax_list"), py::arg("global_amax_list"), py::arg("per_block_scale_list"), + py::arg("target_scale_list"), py::arg("target_amax_list"), py::arg("tile_rows_list"), + py::arg("tile_cols_list"), py::arg("rows_padded_list"), py::arg("block_len"), + py::call_guard()); + m.def("nvfp4_2d_multi_tensor_transpose", + &transformer_engine::pytorch::nvfp4_2d_multi_tensor_transpose, + "Batched NVFP4 columnwise creation: transpose data and scales for multiple tensors", + py::arg("rowwise_data_list"), py::arg("columnwise_data_list"), + py::arg("rowwise_scale_inv_list"), py::arg("columnwise_scale_inv_list"), py::arg("M_list"), + py::arg("K_list"), py::call_guard()); m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims, "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"), py::call_guard()); @@ -278,6 +350,29 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"), py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"), py::arg("out_dtype"), py::call_guard()); + // NVFP4 2D + m.def("nvfp4_2d_compute_partial_amax", + &transformer_engine::pytorch::nvfp4_2d_compute_partial_amax, + "Compute partial amax from master weights for NVFP4 2D", py::arg("tensor"), py::arg("amax"), + py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len") = 16, + py::call_guard()); + m.def("nvfp4_multi_tensor_compute_partial_amax", + &transformer_engine::pytorch::nvfp4_multi_tensor_compute_partial_amax, + "Batched compute partial and global amax from master weights for NVFP4 2D", + py::arg("master_weight_list"), py::arg("partial_amax_list"), py::arg("global_amax_list"), + py::arg("h_list"), py::arg("w_list"), py::arg("start_offset_list"), + py::arg("block_len") = 16, py::call_guard()); + m.def("nvfp4_2d_partial_cast", &transformer_engine::pytorch::nvfp4_2d_partial_cast, + "Partial cast from master weights for NVFP4 2D", py::arg("inp"), py::arg("out"), + py::arg("scale"), py::arg("global_scale"), py::arg("h"), py::arg("w"), + py::arg("start_offset"), py::arg("block_len") = 16, + py::call_guard()); + m.def("nvfp4_multi_tensor_2d_partial_cast", + &transformer_engine::pytorch::nvfp4_multi_tensor_2d_partial_cast, + "Batched partial cast from master weights for NVFP4 2D", py::arg("inp_list"), + py::arg("out_list"), py::arg("scale_list"), py::arg("global_scale_list"), py::arg("h_list"), + py::arg("w_list"), py::arg("start_offset_list"), py::arg("block_len") = 16, + py::call_guard()); m.def("mxfp8_scaling_compute_partial_amax", &transformer_engine::pytorch::mxfp8_scaling_compute_partial_amax, "Compute partial amax from master weights for fp8 mxfp8 scaling", py::arg("input"), @@ -327,19 +422,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"), py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), py::arg("group_topk"), py::arg("scaling_factor"), py::arg("score_function"), py::arg("expert_bias"), - "Fused topk softmax fwd"); + "Fused topk with score function fwd"); m.def("fused_topk_with_score_function_bwd", &transformer_engine::pytorch::fused_topk_with_score_function_bwd, py::arg("num_tokens"), py::arg("num_experts"), py::arg("routing_map"), py::arg("intermediate_output"), - py::arg("grad_probs"), py::arg("topk"), py::arg("use_pre_softmax"), - py::arg("scaling_factor"), py::arg("score_function"), "Fused topk softmax bwd"); + py::arg("grad_probs"), py::arg("grad_logits"), py::arg("topk"), py::arg("use_pre_softmax"), + py::arg("scaling_factor"), py::arg("score_function"), "Fused topk with score function bwd"); m.def("fused_score_for_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_score_for_moe_aux_loss_fwd, py::arg("logits"), - py::arg("topk"), py::arg("score_function"), "Fused topk softmax fwd"); + py::arg("topk"), py::arg("score_function"), "Fused aux loss with score function fwd"); m.def("fused_score_for_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_score_for_moe_aux_loss_bwd, py::arg("num_tokens"), py::arg("num_experts"), py::arg("intermediate_output"), py::arg("grad_scores"), - py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd"); + py::arg("grad_logits"), py::arg("topk"), py::arg("score_function"), + "Fused aux loss with score function bwd"); m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd, py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"), py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"), @@ -362,6 +458,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version", py::call_guard()); #endif + m.def("splits_to_offsets", &transformer_engine::pytorch::splits_to_offsets, + "Compute grouped tensor offsets from split sizes", py::arg("first_dims"), + py::arg("logical_last_dim"), py::call_guard()); m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams", py::call_guard()); @@ -429,6 +528,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_scale", &transformer_engine::pytorch::multi_tensor_scale_cuda, "Fused overflow check + scale for a list of contiguous tensors", py::call_guard()); + m.def("multi_tensor_scale_tensor", &transformer_engine::pytorch::multi_tensor_scale_tensor_cuda, + "Fused overflow check + scale for a list of contiguous tensors with scale passed as tensor", + py::call_guard()); m.def("multi_tensor_l2norm", &transformer_engine::pytorch::multi_tensor_l2norm_cuda, "Computes L2 norm for a list of contiguous tensors", py::call_guard()); @@ -476,7 +578,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard(), py::arg("allgather_communicator"), py::arg("send_stream"), py::arg("recv_stream")); #else - m.def("bulk_overlap_ag_with_external_gemm", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); + m.def("bulk_overlap_ag_with_external_gemm", &transformer_engine::pytorch::placeholder, + "Dummy function for python side annotations"); #endif // Data structures @@ -521,8 +624,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) - .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"), - py::arg("local_chunk") = false) + .def("copy_into_buffer", + static_cast( + &CommOverlap::copy_into_buffer), + py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) .def("get_communication_stream", &CommOverlap::get_communication_stream); @@ -539,8 +644,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) - .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), - py::arg("local_chunk") = false) + .def("copy_into_buffer", + static_cast( + &CommOverlapP2P::copy_into_buffer), + py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) .def("get_communication_stream", &CommOverlapP2P::get_communication_stream); diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 2ae0d648a..94625c0f1 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -9,12 +9,13 @@ namespace transformer_engine::pytorch { -static std::map score_function_map = {{"sigmoid", 0}, {"softmax", 1}}; +static std::map score_function_map = { + {"sigmoid", 0}, {"softmax", 1}, {"sqrtsoftplus", 2}}; std::tuple fused_topk_with_score_function_fwd( - at::Tensor logits, int topk, bool use_pre_softmax, c10::optional num_groups, - c10::optional group_topk, c10::optional scaling_factor, std::string score_function, - c10::optional expert_bias) { + at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, + std::optional group_topk, std::optional scaling_factor, std::string score_function, + std::optional expert_bias) { int num_tokens = logits.size(0); int num_experts = logits.size(1); // Check if the input is valid @@ -22,13 +23,16 @@ std::tuple fused_topk_with_score_function_fw "num_tokens and num_experts must be greater than 0"); // Expert bias only happens at the sigmoid case if (expert_bias.has_value()) { - TORCH_CHECK(score_function == "sigmoid", - "score_function must be sigmoid when expert_bias is not None"); + TORCH_CHECK(score_function == "sigmoid" || score_function == "sqrtsoftplus", + "score_function must be sigmoid or sqrtsoftplus when expert_bias is not None"); + TORCH_CHECK(expert_bias.value().scalar_type() == at::kFloat, + "expert_bias must be a float32 tensor"); } // Check if the score function is valid - TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid", - "score_function must be softmax or sigmoid for router fusion"); - if (score_function == "sigmoid") { + TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid" || + score_function == "sqrtsoftplus", + "score_function must be softmax, sigmoid or sqrtsoftplus for router fusion"); + if (score_function == "sigmoid" || score_function == "sqrtsoftplus") { use_pre_softmax = false; // Pre-softmax only happens at the softmax case } @@ -44,7 +48,7 @@ std::tuple fused_topk_with_score_function_fw at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA)); // Intermediate output is used to store the output of the softmax/sigmoid function at::Tensor intermediate_output = - at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); + at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); auto logits_cu = makeTransformerEngineTensor(logits); auto probs_cu = makeTransformerEngineTensor(probs); @@ -64,18 +68,14 @@ std::tuple fused_topk_with_score_function_fw return std::make_tuple(probs, routing_map, intermediate_output); } -at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts, - at::Tensor routing_map, - at::Tensor intermediate_output, at::Tensor grad_probs, - int topk, bool use_pre_softmax, - c10::optional scaling_factor, - std::string score_function) { +void fused_topk_with_score_function_bwd(int num_tokens, int num_experts, at::Tensor routing_map, + at::Tensor intermediate_output, at::Tensor grad_probs, + at::Tensor grad_logits, int topk, bool use_pre_softmax, + std::optional scaling_factor, + std::string score_function) { // Get the value of the parameters auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f; auto score_function_value = score_function_map[score_function]; - // Init the output tensor - at::Tensor grad_logits = at::empty( - {num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA)); auto routing_map_cu = makeTransformerEngineTensor(routing_map); auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); @@ -86,8 +86,6 @@ at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts, routing_map_cu.data(), intermediate_output_cu.data(), grad_probs_cu.data(), num_tokens, num_experts, topk, use_pre_softmax, scaling_factor_value, score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream()); - - return grad_logits; } std::tuple fused_score_for_moe_aux_loss_fwd( @@ -99,17 +97,17 @@ std::tuple fused_score_for_moe_aux_loss_fwd( "num_tokens and num_experts must be greater than 0"); TORCH_CHECK(topk > 0, "topk must be greater than 0"); // Check if the score function is valid - TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid", - "score_function must be softmax or sigmoid for router fusion"); + TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid" || + score_function == "sqrtsoftplus", + "score_function must be softmax, sigmoid or sqrtsoftplus for router fusion"); int score_function_value = score_function_map[score_function]; // Construct the output tensor - at::Tensor scores = - at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); + at::Tensor scores = at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); at::Tensor routing_map = at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA)); at::Tensor intermediate_output = - at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); + at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); auto logits_cu = makeTransformerEngineTensor(logits); auto scores_cu = makeTransformerEngineTensor(scores); @@ -123,14 +121,12 @@ std::tuple fused_score_for_moe_aux_loss_fwd( return std::make_tuple(scores, routing_map, intermediate_output); } -at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, - at::Tensor intermediate_output, at::Tensor grad_scores, - int topk, std::string score_function) { +void fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, + at::Tensor intermediate_output, at::Tensor grad_scores, + at::Tensor grad_logits, int topk, + std::string score_function) { // Get the value of the parameters int score_function_value = score_function_map[score_function]; - // Init the output tensor - at::Tensor grad_logits = at::empty( - {num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA)); auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); auto grad_scores_cu = makeTransformerEngineTensor(grad_scores); @@ -139,8 +135,6 @@ at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, nvte_fused_score_for_moe_aux_loss_backward( intermediate_output_cu.data(), grad_scores_cu.data(), num_tokens, num_experts, topk, score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream()); - - return grad_logits; } std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index 4ad57bbf1..bd5524b56 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -37,6 +37,13 @@ void reset_tensor_data(transformer_engine::TensorWrapper &tensor, bool rowwise, } } +bool is_empty_grouped_tensor_param(const NVTEBasicTensor &t) { + if (t.data_ptr == nullptr) { + return true; + } + return t.shape.ndim == 1 && t.shape.data[0] == 0; +} + } // namespace std::tuple, std::optional> swizzle_scales_for_gemm( @@ -335,6 +342,83 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp } #endif // !USE_ROCM +std::optional maybe_swizzle_grouped_tensor_for_gemm( + GroupedTensorWrapper &input) { + if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { + return std::nullopt; + } + if (input.get_with_gemm_swizzled_scales()) { + return std::nullopt; + } + + const auto row_scales = input.get_rowwise_scale_inv(); + const auto col_scales = input.get_columnwise_scale_inv(); + const bool has_rowwise_scales = !is_empty_grouped_tensor_param(row_scales); + const bool has_columnwise_scales = !is_empty_grouped_tensor_param(col_scales); + if (!has_rowwise_scales && !has_columnwise_scales) { + return std::nullopt; + } + const auto first_dims = input.get_first_dims(); + const auto last_dims = input.get_last_dims(); + if (first_dims.data_ptr != nullptr || last_dims.data_ptr != nullptr) { + NVTE_ERROR( + "Grouped GEMM swizzle requires uniform shapes for now (first_dims/last_dims must be " + "absent)."); + } + + std::optional rowwise_scales_pyt; + std::optional columnwise_scales_pyt; + GroupedTensorWrapper output(input.num_tensors(), input.logical_shape(), input.scaling_mode()); + + const auto rowwise_data = input.get_rowwise_data(); + if (rowwise_data.data_ptr != nullptr) { + output.set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + } + const auto columnwise_data = input.get_columnwise_data(); + if (columnwise_data.data_ptr != nullptr) { + output.set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); + } + const auto tensor_offsets = input.get_tensor_offsets(); + if (tensor_offsets.data_ptr != nullptr) { + output.set_tensor_offsets(tensor_offsets.data_ptr, static_cast(tensor_offsets.dtype), + tensor_offsets.shape); + } + + if (has_rowwise_scales) { + const auto scales_dtype = static_cast(row_scales.dtype); + rowwise_scales_pyt = allocateSpace(row_scales.shape, scales_dtype, false); + void *output_scales_dptr = getDataPtr(*rowwise_scales_pyt); + output.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, row_scales.shape); + } + if (has_columnwise_scales) { + const auto scales_dtype = static_cast(col_scales.dtype); + columnwise_scales_pyt = allocateSpace(col_scales.shape, scales_dtype, false); + void *output_scales_dptr = getDataPtr(*columnwise_scales_pyt); + output.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, col_scales.shape); + } + + output.set_with_gemm_swizzled_scales(true); + NVTE_SCOPED_GIL_RELEASE({ + nvte_swizzle_grouped_scaling_factors(input.data(), output.data(), + at::cuda::getCurrentCUDAStream()); + }); + + if (has_rowwise_scales) { + const auto scales_dtype = static_cast(row_scales.dtype); + input.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype, row_scales.shape); + } + if (has_columnwise_scales) { + const auto scales_dtype = static_cast(col_scales.dtype); + input.set_columnwise_scale_inv(getDataPtr(*columnwise_scales_pyt), scales_dtype, + col_scales.shape); + } + input.set_with_gemm_swizzled_scales(true); + + return SwizzledGroupedScales{std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)}; +} + void inplace_swizzle_scale_for_gemm(py::handle &tensor) { // Convert Python tensor to C++ tensor auto tensor_nvte = makeTransformerEngineTensor(tensor, py::none()); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 477d7c87e..aaa27a104 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -5,6 +5,8 @@ ************************************************************************/ #include +#include +#include #include #include @@ -52,11 +54,218 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output) { + init_extension(); + + // Input is packed FP4: logical [M, K] stored as [M, K/2] bytes + // Output is packed FP4: logical [K, M] stored as [K, M/2] bytes + const auto shape = getTensorShape(input); + NVTE_CHECK(shape.size() == 2, "NVFP4 transpose expects 2D input (packed storage)."); + + const size_t M = shape[0]; + const size_t K_packed = shape[1]; + const size_t K = K_packed * 2; // logical K + const size_t M_packed = M / 2; + + NVTE_CHECK(M % 2 == 0, "NVFP4 transpose requires M (", M, ") to be even."); + + // Output shape: [K, M/2] + std::vector output_shape = {static_cast(K), static_cast(M_packed)}; + + // Output tensor + at::Tensor out; + if (output.has_value()) { + out = *output; + NVTE_CHECK( + static_cast(out.size(0)) == K && static_cast(out.size(1)) == M_packed, + "Output shape mismatch for NVFP4 transpose."); + } else { + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + out = at::empty(output_shape, opts); + } + + // Return immediately if tensor is empty + if (M == 0 || K == 0) { + return out; + } + + // Call the NVFP4 transpose kernel + auto input_cu = + makeTransformerEngineTensor(input.data_ptr(), std::vector{M, K_packed}, DType::kByte); + auto output_cu = + makeTransformerEngineTensor(out.data_ptr(), std::vector{K, M_packed}, DType::kByte); + nvte_nvfp4_data_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return out; +} + +void nvfp4_2d_scale_transpose(at::Tensor input, at::Tensor output, int64_t M_tiles, + int64_t K_tiles) { + init_extension(); + + // Input: rowwise_scale_inv [M_padded, K_tiles], uint8 (E4M3 stored as bytes) + // Output: columnwise_scale_inv [K_padded, M_tiles], uint8 (E4M3 stored as bytes) + const auto in_shape = getTensorShape(input); + const auto out_shape = getTensorShape(output); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 scale transpose expects 2D input."); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 scale transpose expects 2D output."); + NVTE_CHECK(input.scalar_type() == at::kByte, "NVFP4 scale transpose input must be uint8 (E4M3)."); + NVTE_CHECK(output.scalar_type() == at::kByte, + "NVFP4 scale transpose output must be uint8 (E4M3)."); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), std::vector{in_shape[0], in_shape[1]}, DType::kByte); + auto output_cu = makeTransformerEngineTensor( + output.data_ptr(), std::vector{out_shape[0], out_shape[1]}, DType::kByte); + + nvte_nvfp4_scale_transpose(input_cu.data(), output_cu.data(), static_cast(M_tiles), + static_cast(K_tiles), at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, int64_t tile_rows, + int64_t tile_cols, int64_t rows_padded, int64_t block_len) { + init_extension(); + + // Input: per_block_decode_scale [tile_rows, tile_cols], float32 + // Output: target_scale [rows_padded, tile_cols], uint8 (E4M3) + const auto in_shape = getTensorShape(input); + const auto out_shape = getTensorShape(output); + NVTE_CHECK(in_shape.size() == 2, "NVFP4 expand scale expects 2D input."); + NVTE_CHECK(out_shape.size() == 2, "NVFP4 expand scale expects 2D output."); + NVTE_CHECK(input.scalar_type() == at::kFloat, "NVFP4 expand scale input must be float32."); + NVTE_CHECK(output.scalar_type() == at::kByte, "NVFP4 expand scale output must be uint8 (E4M3)."); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), std::vector{in_shape[0], in_shape[1]}, DType::kFloat32); + auto output_cu = makeTransformerEngineTensor( + output.data_ptr(), std::vector{out_shape[0], out_shape[1]}, DType::kByte); + + nvte_nvfp4_expand_scale_to_fp8(input_cu.data(), output_cu.data(), static_cast(tile_rows), + static_cast(tile_cols), static_cast(rows_padded), + static_cast(block_len), at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, + at::Tensor global_amax) { + init_extension(); + + // block_amax and scale: [tile_rows, tile_cols], float32 + // global_amax: single element tensor, float32 (avoids D2H transfer) + NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); + NVTE_CHECK(scale.scalar_type() == at::kFloat, "Scale must be float32."); + NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + + auto block_amax_cu = makeTransformerEngineTensor(block_amax); + auto scale_cu = makeTransformerEngineTensor(scale); + auto global_amax_cu = makeTransformerEngineTensor(global_amax); + + nvte_nvfp4_compute_per_block_scale(block_amax_cu.data(), scale_cu.data(), global_amax_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, at::Tensor per_block_scale, + at::Tensor target_scale, at::Tensor target_amax, int64_t tile_rows, + int64_t tile_cols, int64_t rows_padded, int64_t block_len) { + init_extension(); + + // block_amax: [tile_rows, tile_cols], float32 + // global_amax: [1], float32 + // per_block_scale: [tile_rows, tile_cols], float32 (for partial_cast) + // target_scale: [rows_padded, tile_cols], uint8 (E4M3) + // target_amax: [1], float32 + NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); + NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(per_block_scale.scalar_type() == at::kFloat, "Per-block scale must be float32."); + NVTE_CHECK(target_scale.scalar_type() == at::kByte, "Target scale must be uint8 (E4M3)."); + NVTE_CHECK(target_amax.scalar_type() == at::kFloat, "Target amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor."); + + auto block_amax_cu = makeTransformerEngineTensor(block_amax); + auto global_amax_cu = makeTransformerEngineTensor(global_amax); + auto per_block_scale_cu = makeTransformerEngineTensor(per_block_scale); + auto target_scale_cu = makeTransformerEngineTensor(target_scale); + auto target_amax_cu = makeTransformerEngineTensor(target_amax); + + nvte_nvfp4_fused_scale(block_amax_cu.data(), global_amax_cu.data(), per_block_scale_cu.data(), + target_scale_cu.data(), target_amax_cu.data(), + static_cast(tile_rows), static_cast(tile_cols), + static_cast(rows_padded), static_cast(block_len), + at::cuda::getCurrentCUDAStream()); +} + +void nvfp4_multi_tensor_fused_scale( + std::vector block_amax_list, std::vector global_amax_list, + std::vector per_block_scale_list, std::vector target_scale_list, + std::vector target_amax_list, std::vector tile_rows_list, + std::vector tile_cols_list, std::vector rows_padded_list, int64_t block_len) { + init_extension(); + + const size_t num_tensors = block_amax_list.size(); + NVTE_CHECK(global_amax_list.size() == num_tensors, "global_amax_list size mismatch"); + NVTE_CHECK(per_block_scale_list.size() == num_tensors, "per_block_scale_list size mismatch"); + NVTE_CHECK(target_scale_list.size() == num_tensors, "target_scale_list size mismatch"); + NVTE_CHECK(target_amax_list.size() == num_tensors, "target_amax_list size mismatch"); + NVTE_CHECK(tile_rows_list.size() == num_tensors, "tile_rows_list size mismatch"); + NVTE_CHECK(tile_cols_list.size() == num_tensors, "tile_cols_list size mismatch"); + NVTE_CHECK(rows_padded_list.size() == num_tensors, "rows_padded_list size mismatch"); + + if (num_tensors == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + for (size_t i = 0; i < num_tensors; ++i) { + const auto& block_amax = block_amax_list[i]; + const auto& global_amax = global_amax_list[i]; + auto& per_block_scale = per_block_scale_list[i]; + auto& target_scale = target_scale_list[i]; + auto& target_amax = target_amax_list[i]; + const size_t tile_rows = static_cast(tile_rows_list[i]); + const size_t tile_cols = static_cast(tile_cols_list[i]); + const size_t rows_padded = static_cast(rows_padded_list[i]); + + NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); + NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(per_block_scale.scalar_type() == at::kFloat, "Per-block scale must be float32."); + NVTE_CHECK(target_scale.scalar_type() == at::kByte, "Target scale must be uint8 (E4M3)."); + NVTE_CHECK(target_amax.scalar_type() == at::kFloat, "Target amax must be float32."); + NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); + NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor."); + + auto block_amax_cu = makeTransformerEngineTensor(block_amax); + auto global_amax_cu = makeTransformerEngineTensor(global_amax); + auto per_block_scale_cu = makeTransformerEngineTensor(per_block_scale); + auto target_scale_cu = makeTransformerEngineTensor(target_scale); + auto target_amax_cu = makeTransformerEngineTensor(target_amax); + + nvte_nvfp4_fused_scale(block_amax_cu.data(), global_amax_cu.data(), per_block_scale_cu.data(), + target_scale_cu.data(), target_amax_cu.data(), tile_rows, tile_cols, + rows_padded, static_cast(block_len), stream); + } +} + +void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale) { + init_extension(); + + // global_amax and global_scale: [num_params], float32 + NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); + NVTE_CHECK(global_scale.scalar_type() == at::kFloat, "Global scale must be float32."); + + auto global_amax_cu = makeTransformerEngineTensor(global_amax); + auto global_scale_cu = makeTransformerEngineTensor(global_scale); + + nvte_nvfp4_compute_global_scale(global_amax_cu.data(), global_scale_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { init_extension(); // Make sure input is contiguous - const auto &input = tensor.contiguous(); + const auto& input = tensor.contiguous(); // Allocate output tensor if needed if (!out) { @@ -77,5 +286,70 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { return std::move(*out); } +void nvfp4_2d_multi_tensor_transpose(std::vector rowwise_data_list, + std::vector columnwise_data_list, + std::vector rowwise_scale_inv_list, + std::vector columnwise_scale_inv_list, + std::vector M_list, std::vector K_list) { + init_extension(); + + const size_t num_tensors = rowwise_data_list.size(); + NVTE_CHECK(columnwise_data_list.size() == num_tensors, "Tensor list size mismatch"); + NVTE_CHECK(rowwise_scale_inv_list.size() == num_tensors, "Tensor list size mismatch"); + NVTE_CHECK(columnwise_scale_inv_list.size() == num_tensors, "Tensor list size mismatch"); + NVTE_CHECK(M_list.size() == num_tensors, "M_list size mismatch"); + NVTE_CHECK(K_list.size() == num_tensors, "K_list size mismatch"); + + if (num_tensors == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + // Process each tensor - the main benefit is reduced Python overhead + // by doing the iteration in C++ rather than Python + constexpr size_t TILE_SIZE = 16; + + for (size_t i = 0; i < num_tensors; ++i) { + const auto& rowwise_data = rowwise_data_list[i]; + auto& columnwise_data = columnwise_data_list[i]; + const auto& rowwise_scale_inv = rowwise_scale_inv_list[i]; + auto& columnwise_scale_inv = columnwise_scale_inv_list[i]; + const int64_t M = M_list[i]; + const int64_t K = K_list[i]; + + // Transpose data: [M, K/2] -> [K, M/2] + const auto data_shape = getTensorShape(rowwise_data); + NVTE_CHECK(data_shape.size() == 2, "NVFP4 data must be 2D."); + const size_t M_packed = static_cast(M) / 2; + const size_t K_packed = data_shape[1]; + + auto input_cu = makeTransformerEngineTensor( + rowwise_data.data_ptr(), std::vector{static_cast(M), K_packed}, + DType::kByte); + auto output_cu = makeTransformerEngineTensor( + columnwise_data.data_ptr(), std::vector{static_cast(K), M_packed}, + DType::kByte); + nvte_nvfp4_data_transpose(input_cu.data(), output_cu.data(), stream); + + // Transpose scales + const size_t M_tiles = (static_cast(M) + TILE_SIZE - 1) / TILE_SIZE; + const size_t K_tiles = (static_cast(K) + TILE_SIZE - 1) / TILE_SIZE; + + const auto scale_in_shape = getTensorShape(rowwise_scale_inv); + const auto scale_out_shape = getTensorShape(columnwise_scale_inv); + + auto scale_input_cu = makeTransformerEngineTensor( + rowwise_scale_inv.data_ptr(), std::vector{scale_in_shape[0], scale_in_shape[1]}, + DType::kByte); + auto scale_output_cu = makeTransformerEngineTensor( + columnwise_scale_inv.data_ptr(), + std::vector{scale_out_shape[0], scale_out_shape[1]}, DType::kByte); + + nvte_nvfp4_scale_transpose(scale_input_cu.data(), scale_output_cu.data(), M_tiles, K_tiles, + stream); + } +} + } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index d5fd4a4fe..e4575002a 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -45,6 +45,8 @@ extern PyTypeObject *Float8BlockwiseQuantizerClass; extern PyTypeObject *NVFP4TensorPythonClass; extern PyTypeObject *NVFP4TensorStoragePythonClass; extern PyTypeObject *NVFP4QuantizerClass; +extern PyTypeObject *GroupedTensorPythonClass; +extern PyTypeObject *GroupedTensorStoragePythonClass; void init_extension(); @@ -97,6 +99,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index d6a9f7e1d..1efba2de6 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -33,6 +33,23 @@ std::vector make_transpose_shape(const std::vector& shape) { return ret; } +/*! @brief Calculate stride from shape for contiguous tensors */ +template +std::vector stride_from_shape(const std::vector& shape) { + std::vector stride; + if (shape.empty()) { + return stride; + } + std::vector rstride; + rstride.reserve(shape.size()); + rstride.push_back(static_cast(1)); + for (size_t i = shape.size(); i > 1; --i) { + rstride.push_back(rstride.back() * shape[i - 1]); + } + stride.assign(rstride.rbegin(), rstride.rend()); + return stride; +} + /*! @brief Convert shape for FP4 data by dividing the last dimension by 2 */ template std::vector convert_shape_for_fp4(const std::vector& shape) { @@ -44,6 +61,44 @@ std::vector convert_shape_for_fp4(const std::vector& shape) { return ret; } +std::optional build_grouped_tensor_offsets(const size_t num_tensors, + const std::optional& first_dims, + const size_t logical_last_dim) { + if (!first_dims.has_value()) { + return std::nullopt; + } + + const auto& first_dims_tensor = first_dims.value(); + NVTE_CHECK(first_dims_tensor.is_cuda(), "first_dims must be on CUDA."); + NVTE_CHECK(first_dims_tensor.scalar_type() == at::kLong, "first_dims must have dtype int64."); + NVTE_CHECK(static_cast(first_dims_tensor.numel()) == num_tensors, + "first_dims must have length ", num_tensors, "."); + + const int64_t logical_last_dim_i64 = static_cast(logical_last_dim); + const auto first_dims_contiguous = first_dims_tensor.contiguous(); + auto tensor_offsets = + at::empty({static_cast(num_tensors) + 1}, first_dims_contiguous.options()); + NVTE_SCOPED_GIL_RELEASE({ + nvte_splits_to_offsets(static_cast(first_dims_contiguous.data_ptr()), + static_cast(tensor_offsets.data_ptr()), num_tensors, + logical_last_dim_i64, at::cuda::getCurrentCUDAStream()); + }); + return tensor_offsets; +} + +at::TensorOptions grouped_tensor_data_options(const DType dtype) { + return at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); +} + +py::object maybe_tensor_to_py(const std::optional& tensor) { + return tensor ? py::cast(*tensor) : py::none(); +} + +py::handle grouped_tensor_python_class(const bool internal) { + PyTypeObject* cls = internal ? GroupedTensorStoragePythonClass : GroupedTensorPythonClass; + return py::handle(reinterpret_cast(cls)); +} + } // namespace constexpr size_t NVFP4_BLOCK_SIZE = 16; @@ -90,6 +145,76 @@ std::pair NoneQuantizer::create_tensor(const std::vec return {std::move(out_cpp), py::cast(data)}; } +std::pair NoneQuantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + std::optional rowwise_data; + std::optional columnwise_data; + const bool with_rowwise_data = rowwise_usage; + const bool with_columnwise_data = columnwise_usage; + if (with_rowwise_data) { + rowwise_data = at::empty({total_elements}, grouped_tensor_data_options(dtype)); + } + if (with_columnwise_data) { + columnwise_data = at::empty({total_elements}, grouped_tensor_data_options(dtype)); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (with_rowwise_data) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), dtype, getTensorShape(*rowwise_data)); + } + if (with_columnwise_data) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), dtype, + getTensorShape(*columnwise_data)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = py::none(); + kwargs["columnwise_scale_inv"] = py::none(); + kwargs["amax"] = py::none(); + kwargs["columnwise_amax"] = py::none(); + kwargs["scale"] = py::none(); + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); + kwargs["with_gemm_swizzled_scales"] = py::cast(false); + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair NoneQuantizer::convert_and_update_tensor( py::object tensor) const { auto tensor_pyt = tensor.cast(); @@ -125,9 +250,9 @@ std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Initialize data tensor - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data && !data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -138,7 +263,7 @@ std::pair Float8Quantizer::create_tensor( py::object data_py = with_data ? py::cast(*data) : py::none(); // Initialize transpose tensor - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose && !transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -147,26 +272,59 @@ std::pair Float8Quantizer::create_tensor( transpose.reset(); } py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); - // Initialize scale-inverse tensor if (!scale_inv) { scale_inv = at::reciprocal(scale); } - + py::object scale_inv_py = py::cast(*scale_inv); + at::Device device = + with_data ? data->device() + : (with_transpose ? transpose->device() + : at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + py::tuple args(0); + kwargs["data"] = data_py; + kwargs["fp8_scale_inv"] = scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["data_transpose"] = transpose_py; + kwargs["quantizer"] = this->quantizer; + kwargs["fake_dtype"] = GetATenDType(dtype); + + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); - out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + const auto stride_int64 = stride_from_shape(shape_int64); + + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + py::tuple args(0); + kwargs["shape"] = py::cast(shape_int64); + kwargs["stride"] = py::cast(stride_int64); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["data"] = data_py; + kwargs["fp8_scale_inv"] = scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["data_transpose"] = transpose_py; + kwargs["quantizer"] = this->quantizer; + kwargs["device"] = py::cast(device); + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorPythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ FP8 tensor @@ -186,13 +344,95 @@ std::pair Float8Quantizer::create_tensor( return {std::move(out_cpp), std::move(out_py)}; } +std::pair Float8Quantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + at::Tensor amax = at::empty({static_cast(num_tensors)}, float_opts); + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + rowwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + columnwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_scale_inv)); + } + out_cpp.set_amax(amax.data_ptr(), DType::kFloat32, getTensorShape(amax)); + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); + kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); + kwargs["amax"] = amax; + kwargs["columnwise_amax"] = py::none(); + kwargs["scale"] = py::none(); + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); + kwargs["with_gemm_swizzled_scales"] = py::cast(false); + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); // Extract buffers from Python tensor @@ -332,7 +572,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize data tensor at::Tensor data_tensor; - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -341,13 +582,12 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize transpose tensor at::Tensor transpose_tensor; - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); transpose_tensor = at::empty(transpose_shape, opts); } - // Initialize scale-inverse tensor at::Tensor scale_inv_tensor; { @@ -355,23 +595,56 @@ std::pair Float8CurrentScalingQuantizer::create_tenso const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); scale_inv_tensor = at::empty(scale_inv_shape, opts); } - + at::Device device = + with_data ? data_tensor.device() + : (with_transpose ? transpose_tensor.device() + : at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; + py::object scale_inv_py = py::cast(scale_inv_tensor); py::object data_py = with_data ? py::cast(data_tensor) : py::none(); py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + kwargs["data"] = data_py; + kwargs["fp8_scale_inv"] = scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["data_transpose"] = transpose_py; + kwargs["quantizer"] = this->quantizer; + kwargs["fake_dtype"] = GetATenDType(dtype); + + py::tuple args(0); + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); - out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + const auto stride_int64 = stride_from_shape(shape_int64); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + kwargs["shape"] = py::cast(shape_int64); + kwargs["stride"] = py::cast(stride_int64); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["data"] = data_py; + kwargs["fp8_scale_inv"] = scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["data_transpose"] = transpose_py; + kwargs["quantizer"] = this->quantizer; + kwargs["device"] = py::cast(device); + py::tuple args(0); + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorPythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ FP8 tensor @@ -392,6 +665,90 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } +std::pair Float8CurrentScalingQuantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + at::Tensor scale = at::empty({static_cast(num_tensors)}, float_opts); + at::Tensor amax = at::empty({static_cast(num_tensors)}, float_opts); + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + rowwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + columnwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_scale_inv)); + } + out_cpp.set_scale(scale.data_ptr(), DType::kFloat32, getTensorShape(scale)); + out_cpp.set_amax(amax.data_ptr(), DType::kFloat32, getTensorShape(amax)); + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); + kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); + kwargs["amax"] = amax; + kwargs["columnwise_amax"] = py::none(); + kwargs["scale"] = scale; + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); + kwargs["with_gemm_swizzled_scales"] = py::cast(false); + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, DType dtype, @@ -410,10 +767,10 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); // Extract buffers from Python tensor @@ -628,26 +985,141 @@ std::pair Float8BlockQuantizer::create_tensor( py::object ret; if (internal) { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); - ret = Float8BlockwiseQTensorClass( - "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, - "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, - "is_2D_scaled"_a = (block_scaling_dim == 2)); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + kwargs["rowwise_data"] = py::cast(data_rowwise); + kwargs["columnwise_data"] = py::cast(data_colwise); + kwargs["rowwise_scale_inv"] = py::cast(scale_inv_rowwise); + kwargs["columnwise_scale_inv"] = py::cast(scale_inv_colwise); + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["is_2D_scaled"] = py::cast(block_scaling_dim == 2); + kwargs["fake_dtype"] = GetATenDType(dtype); + + py::tuple args(0); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensorStorage instance"); + ret = py::reinterpret_steal(result); } else { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorPythonClass)); - ret = Float8BlockwiseQTensorClass( - "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, - "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, - "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2)); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + const auto stride_int64 = stride_from_shape(torch_shape); + kwargs["shape"] = py::cast(torch_shape); + kwargs["stride"] = py::cast(stride_int64); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["rowwise_data"] = py::cast(data_rowwise); + kwargs["columnwise_data"] = py::cast(data_colwise); + kwargs["rowwise_scale_inv"] = py::cast(scale_inv_rowwise); + kwargs["columnwise_scale_inv"] = py::cast(scale_inv_colwise); + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["is_2D_scaled"] = py::cast(block_scaling_dim == 2); + + py::tuple args(0); + PyObject* result = PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensor instance"); + ret = py::reinterpret_steal(result); } return {std::move(tensor), std::move(ret)}; } +std::pair Float8BlockQuantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, false); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + rowwise_scale_inv = at::empty({total_scale_elements}, float_opts); + } + + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, true); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + columnwise_scale_inv = at::empty({total_scale_elements}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_scale_inv)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); + kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); + kwargs["amax"] = py::none(); + kwargs["columnwise_amax"] = py::none(); + kwargs["scale"] = py::none(); + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); + kwargs["with_gemm_swizzled_scales"] = py::cast(false); + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair Float8BlockQuantizer::convert_and_update_tensor( py::object tensor) const { const DType dtype = tensor.attr("_fp8_dtype").cast(); @@ -934,18 +1406,50 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve // Construct Python MXFP8 tensor py::object out_py; if (internal) { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); - out_py = MXFP8TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py, - columnwise_scale_inv_py, this->dtype, this->quantizer, - with_gemm_swizzled_scales); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + py::tuple args(0); + kwargs["rowwise_data"] = rowwise_data_py; + kwargs["columnwise_data"] = columnwise_data_py; + kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; + kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["fake_dtype"] = GetATenDType(dtype); + + PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorStoragePythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create MXFP8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - out_py = MXFP8TensorClass( - "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + const auto stride_int64 = stride_from_shape(shape_int64); + kwargs["shape"] = py::cast(shape_int64); + kwargs["stride"] = py::cast(stride_int64); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["rowwise_data"] = rowwise_data_py; + kwargs["columnwise_data"] = columnwise_data_py; + kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; + kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + + py::tuple args(0); + PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorPythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create MXFP8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ MXFP8 tensor @@ -966,6 +1470,93 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair MXFP8Quantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, false); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + rowwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + } + + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, true); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + columnwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*columnwise_scale_inv)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); + + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); + kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); + kwargs["amax"] = py::none(); + kwargs["columnwise_amax"] = py::none(); + kwargs["scale"] = py::none(); + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); + kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair MXFP8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); @@ -1230,19 +1821,54 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // Construct Python NVFP4 tensor py::object out_py; if (internal) { - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); - out_py = NVFP4TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py, - columnwise_scale_inv_py, amax_rowwise_py, amax_columnwise_py, - this->dtype, this->quantizer, with_gemm_swizzled_scales); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + kwargs["rowwise_data"] = rowwise_data_py; + kwargs["columnwise_data"] = columnwise_data_py; + kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; + kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; + kwargs["amax_rowwise"] = amax_rowwise_py; + kwargs["amax_columnwise"] = amax_columnwise_py; + kwargs["fp4_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["fake_dtype"] = GetATenDType(dtype); + + py::tuple args(0); + + PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorStoragePythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create NVFP4TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorPythonClass)); - out_py = NVFP4TensorClass( - "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, - "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + const auto stride_int64 = stride_from_shape(shape_int64); + kwargs["shape"] = py::cast(shape_int64); + kwargs["stride"] = py::cast(stride_int64); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["rowwise_data"] = rowwise_data_py; + kwargs["columnwise_data"] = columnwise_data_py; + kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; + kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; + kwargs["amax_rowwise"] = amax_rowwise_py; + kwargs["amax_columnwise"] = amax_columnwise_py; + kwargs["fp4_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + py::tuple args(0); + PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create NVFP4Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ tensor @@ -1271,6 +1897,104 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair NVFP4Quantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + NVTE_CHECK(total_elements % 2 == 0, "NVFP4 data size must be divisible by 2."); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + std::optional rowwise_amax; + std::optional columnwise_amax; + const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; + + const int64_t total_data_elements = total_elements / 2; + + if (rowwise_usage) { + rowwise_data = at::empty({total_data_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, false); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + rowwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + rowwise_amax = at::empty({static_cast(num_tensors)}, float_opts); + } + + if (columnwise_usage) { + columnwise_data = at::empty({total_data_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, true); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + columnwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + columnwise_amax = at::empty({static_cast(num_tensors)}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*rowwise_scale_inv)); + out_cpp.set_amax(rowwise_amax->data_ptr(), DType::kFloat32, getTensorShape(*rowwise_amax)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*columnwise_scale_inv)); + out_cpp.set_columnwise_amax(columnwise_amax->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_amax)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); + + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); + kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); + kwargs["amax"] = maybe_tensor_to_py(rowwise_amax); + kwargs["columnwise_amax"] = maybe_tensor_to_py(columnwise_amax); + kwargs["scale"] = py::none(); + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); + kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair NVFP4Quantizer::create_unquantized_tensor_with_amax( TensorWrapper& quantized_tensor, DType dtype) { // Construct tensor @@ -1506,7 +2230,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // Compute amax. if (this->with_rht) { if (input.dtype() != DType::kBFloat16) { - NVTE_CHECK(false, "RHT is only supported for bfloat16 input"); + NVTE_ERROR("RHT is only supported for bfloat16 input, got dtype enum value ", + static_cast(input.dtype())); } if (this->with_post_rht_amax) { // We need: @@ -1518,7 +2243,9 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou }); } else { // raise error since it's not supported yet - NVTE_CHECK(false, "Pre-RHT amax is not supported yet"); + NVTE_ERROR( + "Pre-RHT amax is not supported yet. " + "Use with_post_rht_amax=true instead."); } } else { // Without RHT if (compute_amax) { diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 3f998bb66..e9c6ca882 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -170,6 +170,127 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +NVTEScalingMode ScalingModeFromQuantizer(py::handle quantizer) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return NVTE_MXFP8_1D_SCALING; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return NVTE_NVFP4_1D_SCALING; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + const int block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); + return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; + } + return NVTE_DELAYED_TENSOR_SCALING; +} + +DType GetTransformerEngineDTypeForScaleInv(py::handle quantizer, at::Tensor scale_inv) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return DType::kFloat8E8M0; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + return DType::kFloat32; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return DType::kFloat8E4M3; + } + return GetTransformerEngineDType(scale_inv.scalar_type()); +} + +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { + // Returns a GroupedTensorWrapper from a PyTorch GroupedTensor. + const auto num_tensors = tensor.attr("num_tensors").cast(); + const auto logical_shape = tensor.attr("logical_shape").cast>(); + py::handle quantizer = py::none(); + DType quantizer_dtype = DType::kNumTypes; + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + if (!tensor.attr("quantizer").is_none()) { + quantizer = tensor.attr("quantizer"); + if (!quantizer.is_none()) { + scaling_mode = ScalingModeFromQuantizer(quantizer); + quantizer_dtype = quantizer.attr("dtype").cast(); + } + } + auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); + + // Rowwise data + if (!tensor.attr("rowwise_data").is_none()) { + const auto &data = tensor.attr("rowwise_data").cast(); + DType data_dtype = + quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_rowwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Columnwise data + if (!tensor.attr("columnwise_data").is_none()) { + const auto &data = tensor.attr("columnwise_data").cast(); + DType data_dtype = + quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_columnwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Scale + if (!tensor.attr("scale").is_none()) { + const auto &scale = tensor.attr("scale").cast(); + ret.set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + } + + // Amax + if (!tensor.attr("amax").is_none()) { + const auto &amax = tensor.attr("amax").cast(); + ret.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + if (!tensor.attr("columnwise_amax").is_none()) { + const auto &amax = tensor.attr("columnwise_amax").cast(); + ret.set_columnwise_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + + // Scale inverse + if (!tensor.attr("scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("scale_inv").cast(); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + if (!tensor.attr("columnwise_scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("columnwise_scale_inv").cast(); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + + // Shape metadata + if (!tensor.attr("first_dims").is_none()) { + const auto &first_dims = tensor.attr("first_dims").cast(); + ret.set_first_dims(first_dims.data_ptr(), GetTransformerEngineDType(first_dims.scalar_type()), + getTensorShape(first_dims)); + } + if (!tensor.attr("last_dims").is_none()) { + const auto &last_dims = tensor.attr("last_dims").cast(); + ret.set_last_dims(last_dims.data_ptr(), GetTransformerEngineDType(last_dims.scalar_type()), + getTensorShape(last_dims)); + } + if (!tensor.attr("tensor_offsets").is_none()) { + const auto &tensor_offsets = tensor.attr("tensor_offsets").cast(); + ret.set_tensor_offsets(tensor_offsets.data_ptr(), + GetTransformerEngineDType(tensor_offsets.scalar_type()), + getTensorShape(tensor_offsets)); + } + + bool with_gemm_swizzled = false; + if (py::hasattr(tensor, "_with_gemm_swizzled_scales")) { + with_gemm_swizzled = tensor.attr("_with_gemm_swizzled_scales").cast(); + } + ret.set_with_gemm_swizzled_scales(with_gemm_swizzled); + + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 6588aa6c5..587ec289a 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -1,6 +1,4 @@ /************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -9,8 +7,6 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ -#ifndef USE_ROCM - #include #include @@ -37,6 +33,16 @@ std::optional multi_tensor_swizzle_scales_for_gemm(std::vector, std::optional>; + +/*! \brief Swizzle grouped tensor scales for GEMM if needed. + * Currently only works for MXFP8 1D scaling with uniform shapes. + * + * The returned swizzled scales should be kept alive during the GEMM. + */ +std::optional maybe_swizzle_grouped_tensor_for_gemm( + GroupedTensorWrapper& input); + /*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place. * * If rowwise==false, the columnwise data will be reinterpreted as @@ -54,6 +60,4 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(TensorWrapper& input, bool roww } // namespace pytorch } // namespace transformer_engine -#endif //!USE_ROCM - #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/custom_recipes/gemm.py b/transformer_engine/pytorch/custom_recipes/gemm.py index 8f853ff09..3d1e1cc43 100644 --- a/transformer_engine/pytorch/custom_recipes/gemm.py +++ b/transformer_engine/pytorch/custom_recipes/gemm.py @@ -32,7 +32,8 @@ def custom_gemm( grad: bool = False, ) -> Iterable[Optional[torch.Tensor]]: """Dispatch GEMM to quantizer's qgemm method.""" - assert is_custom(A) and is_custom(B), "A and B must be custom tensors" + if not (is_custom(A) and is_custom(B)): + raise TypeError("A and B must be custom tensors") A, B = B, A @@ -68,11 +69,16 @@ def custom_gemm( if gemm_type == GEMMType.FPROP: qx, sx = A.data, A.scale qw, sw = B.data, B.scale - assert qx is not None - assert sx is not None - assert qw is not None - assert sw is not None - assert A.original_shape is not None + if qx is None: + raise ValueError("FPROP GEMM: quantized activation data (A.data) is None") + if sx is None: + raise ValueError("FPROP GEMM: activation scale (A.scale) is None") + if qw is None: + raise ValueError("FPROP GEMM: quantized weight data (B.data) is None") + if sw is None: + raise ValueError("FPROP GEMM: weight scale (B.scale) is None") + if A.original_shape is None: + raise ValueError("FPROP GEMM: A.original_shape is None, cannot determine output shape") # Call quantizer's qgemm method result = quantizer.qgemm( @@ -95,10 +101,14 @@ def custom_gemm( elif gemm_type == GEMMType.DGRAD: qdy, sdy = A.data, A.scale qw_t, sw_t = B.data_t, B.scale_t - assert qdy is not None - assert sdy is not None - assert qw_t is not None - assert sw_t is not None + if qdy is None: + raise ValueError("DGRAD GEMM: quantized gradient data (A.data) is None") + if sdy is None: + raise ValueError("DGRAD GEMM: gradient scale (A.scale) is None") + if qw_t is None: + raise ValueError("DGRAD GEMM: transposed quantized weight data (B.data_t) is None") + if sw_t is None: + raise ValueError("DGRAD GEMM: transposed weight scale (B.scale_t) is None") result = quantizer.qgemm( qdy, @@ -115,10 +125,14 @@ def custom_gemm( elif gemm_type == GEMMType.WGRAD: qdy_t, sdy_t = A.data_t, A.scale_t qx_t, sx_t = B.data_t, B.scale_t - assert qdy_t is not None - assert sdy_t is not None - assert qx_t is not None - assert sx_t is not None + if qdy_t is None: + raise ValueError("WGRAD GEMM: transposed quantized gradient data (A.data_t) is None") + if sdy_t is None: + raise ValueError("WGRAD GEMM: transposed gradient scale (A.scale_t) is None") + if qx_t is None: + raise ValueError("WGRAD GEMM: transposed quantized activation data (B.data_t) is None") + if sx_t is None: + raise ValueError("WGRAD GEMM: transposed activation scale (B.scale_t) is None") result = quantizer.qgemm( qdy_t, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py index 5bdc537e4..8580cf4a3 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py @@ -218,6 +218,12 @@ def __init__( self.with_amax_reduction = False self.amax_reduction_group = None + def __getstate__(self): + """Exclude unpicklable process group from serialized state.""" + state = self.__dict__.copy() + state["amax_reduction_group"] = None + return state + @property def custom(self) -> bool: """Flag to indicate this quantizer is custom.""" diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index d00d0c8b9..f42183ec0 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -169,7 +169,8 @@ def high_precision_gemm_ref( y_shape = (mat1.size(0), mat2.size(1)) if bias is not None: - assert not accumulate, "Bias is not supported with accumulation" + if accumulate: + raise ValueError("Bias is not supported with accumulation") bias = bias.to(out_dtype) # With bias case if out_dtype == torch.float32: @@ -325,7 +326,8 @@ def size(self, *args, **kwargs): # pylint: disable=unused-argument the second dimension by half. This method returns the logical shape that users expect, not the internal packed storage shape. """ - assert self.original_shape is not None + if self.original_shape is None: + raise RuntimeError("NVFP4TensorRef.size() called but original_shape has not been set") return torch.Size(self.original_shape) @@ -374,7 +376,8 @@ def _build_hadamard_matrix( Uses Sylvester construction to avoid SciPy dependency. """ - assert (size & (size - 1)) == 0, "Hadamard size must be a power of two" + if (size & (size - 1)) != 0: + raise ValueError(f"Hadamard size must be a power of two, got {size}") h = torch.ones((1, 1), device=device, dtype=torch.float32) while h.shape[0] < size: h = torch.cat( @@ -402,9 +405,10 @@ def _apply_rht(self, x: torch.Tensor) -> torch.Tensor: # RHT dimension equals the quantization tile length (NVFP4 uses 16) rht_dim = self.quant_tile_shape[1] - assert ( - x.shape[-1] % rht_dim == 0 - ), f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}" + if x.shape[-1] % rht_dim != 0: + raise ValueError( + f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}" + ) # Build H and scale H = self._build_hadamard_matrix(rht_dim, x.device, x.dtype, self.with_random_sign_mask) @@ -446,7 +450,11 @@ def _quantize_blockwise_reference( eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.ndim == 2 + if x.ndim != 2: + raise ValueError( + f"_quantize_blockwise_reference expects a 2D tensor, got {x.ndim}D with shape" + f" {x.shape}" + ) using_2d_quantization = tile_len_x == 16 and tile_len_y == 16 m, n = x.shape # Compute vec_max based on the original x (before reshape) @@ -525,7 +533,11 @@ def _pad_tensor( tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int] ) -> torch.Tensor: - assert tensor.dim() == 2, "only supports 2D tensors" + if tensor.dim() != 2: + raise ValueError( + f"_pad_tensor only supports 2D tensors, got {tensor.dim()}D tensor with shape" + f" {tensor.shape}" + ) M, N = tensor.shape padding_needed_rows = 0 padding_needed_cols = 0 @@ -553,7 +565,11 @@ def _pad_tensor( @staticmethod def _rm_pad_tensor(tensor: torch.Tensor, original_size: tuple[int, ...]) -> torch.Tensor: - assert tensor.dim() == 2, "only supports 2D tensors" + if tensor.dim() != 2: + raise ValueError( + f"_rm_pad_tensor only supports 2D tensors, got {tensor.dim()}D tensor with shape" + f" {tensor.shape}" + ) M, N = original_size out = tensor[:M, :N].contiguous() return out @@ -584,19 +600,20 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ - sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise - global_amax_row, global_amax_col: global amax tensors """ + global_amax_col = None if self.pow_2_scales: - assert self.quant_tile_shape == ( - 1, - 32, - ), "MXFP4 only supports 1x32 tile shape." + if self.quant_tile_shape != (1, 32): + raise ValueError( + f"MXFP4 only supports 1x32 tile shape, got {self.quant_tile_shape}" + ) # TODO(etsykunov): Fix bug where global_amax_row and # global_amax_col are not defined # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32) else: - assert self.quant_tile_shape in ( - (1, 16), - (16, 16), - ), "NVFP4 only supports 1x16 or 16x16 tile shape." + if self.quant_tile_shape not in ((1, 16), (16, 16)): + raise ValueError( + f"NVFP4 only supports 1x16 or 16x16 tile shape, got {self.quant_tile_shape}" + ) # Prepare inputs once so we can reuse for both amax and quantization # Row-input will always be the original input. row_input = tensor @@ -670,7 +687,11 @@ def quantize( **kwargs, # pylint: disable=unused-argument ) -> NVFP4TensorRef: # sanity checks - assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype." + if tensor.dtype not in utils.HIGH_PRECISION_FLOAT_DTYPES: + raise TypeError( + f"Unsupported input dtype {tensor.dtype}, expected one of" + f" {utils.HIGH_PRECISION_FLOAT_DTYPES}" + ) # Make it work with 3D tensors original_shape = tensor.shape @@ -766,7 +787,10 @@ def is_data_t_transposed_in_memory(self) -> bool: TODO(etsykunov): Confirm docstring is correct. """ - raise NotImplementedError("Not implemented yet") + raise NotImplementedError( + "NVFP4QuantizerRef.is_data_t_transposed_in_memory is not implemented for FP4" + " quantization" + ) def qgemm( self, @@ -784,7 +808,8 @@ def qgemm( qresult_w: QuantizedTensorStorage | None = None, ) -> torch.Tensor: """Python implementation of microblock FP4 GEMM.""" - assert bias is None, "Bias is implemented for FP4 GEMM." + if bias is not None: + raise ValueError("Bias is not supported in NVFP4QuantizerRef.qgemm") high_precision_x = cast_from_fp4x2(qx, out_dtype) high_precision_w = cast_from_fp4x2(qw, out_dtype) @@ -814,11 +839,22 @@ def qgemm( else: - assert qresult_x is not None - assert qresult_w is not None - - assert qresult_x.global_amax_row is not None - assert qresult_w.global_amax_col is not None + if qresult_x is None: + raise ValueError( + "qresult_x is required for non-pow_2_scales NVFP4 GEMM (needed for global_amax)" + ) + if qresult_w is None: + raise ValueError( + "qresult_w is required for non-pow_2_scales NVFP4 GEMM (needed for global_amax)" + ) + if qresult_x.global_amax_row is None: + raise ValueError( + "qresult_x.global_amax_row must be set for non-pow_2_scales NVFP4 GEMM" + ) + if qresult_w.global_amax_col is None: + raise ValueError( + "qresult_w.global_amax_col must be set for non-pow_2_scales NVFP4 GEMM" + ) sx = sx.to(torch.float32) sw = sw.to(torch.float32) @@ -833,23 +869,27 @@ def qgemm( M, K = high_precision_x.shape N, K_w = high_precision_w.shape - assert K == K_w, "K dimension mismatch between qx and qw" - - assert K % 32 == 0, "K dimension must be divisible by 32" - assert N % 8 == 0, "N dimension must be divisible by 8" + if K != K_w: + raise ValueError( + f"K dimension mismatch between qx and qw: qx has K={K}, qw has K={K_w}" + ) + if K % 32 != 0: + raise ValueError(f"K dimension must be divisible by 32, got K={K}") + if N % 8 != 0: + raise ValueError(f"N dimension must be divisible by 8, got N={N}") block_length = 32 if self.pow_2_scales else 16 grid_k = K // block_length - assert sx.shape == ( - M, - K // block_length, - ), f"sx shape mismatch: expected ({M}, {K//block_length}), got {sx.shape}" - assert sw.shape == ( - N, - K // block_length, - ), f"sw shape mismatch: expected ({N}, {K//block_length}), got {sw.shape}" + if sx.shape != (M, K // block_length): + raise ValueError( + f"sx shape mismatch: expected ({M}, {K // block_length}), got {sx.shape}" + ) + if sw.shape != (N, K // block_length): + raise ValueError( + f"sw shape mismatch: expected ({N}, {K // block_length}), got {sw.shape}" + ) y = torch.zeros(M, N, dtype=torch.float32, device=qx.device) @@ -878,10 +918,12 @@ def qgemm( # accumulation happens at epilogue in float32 if accumulate: - assert out is not None, "Output tensor must be provided for accumulation." + if out is None: + raise ValueError("Output tensor must be provided for accumulation.") y += out.to(torch.float32) else: - assert out is None, "Output tensor should be None when accumulate is False." + if out is not None: + raise ValueError("Output tensor should be None when accumulate is False.") y = y.to(out_dtype) return y diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index c881bbe08..9501fadf0 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -154,7 +154,11 @@ def set_tensor_model_parallel_attributes( ) -> None: """set attributes needed for TP""" for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - assert not hasattr(tensor, attribute) + if hasattr(tensor, attribute): + raise RuntimeError( + f"Tensor already has attribute '{attribute}' set. Cannot set " + "tensor model parallel attributes on a tensor that already has them." + ) # Set the attributes. setattr(tensor, "tensor_model_parallel", is_parallel) setattr(tensor, "partition_dim", dim) @@ -172,7 +176,11 @@ def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int: @lru_cache def get_distributed_rank(group: Optional[dist_group_type] = None) -> int: """Return my rank for the distributed group.""" - assert torch.distributed.is_initialized(), "torch.distributed is not initialized." + if not torch.distributed.is_initialized(): + raise RuntimeError( + "torch.distributed is not initialized. Call torch.distributed.init_process_group() " + "before calling get_distributed_rank()." + ) return torch.distributed.get_rank(group=group) @@ -731,8 +739,8 @@ def checkpoint( if isinstance(function, TransformerEngineBaseModule): # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need # to scatter/gather activations that we will recompute anyway. - setattr(function, "fsdp_wrapped", False) - setattr(function, "fsdp_group", None) + function.fast_setattr("fsdp_wrapped", False) + function.fast_setattr("fsdp_group", None) # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments # and execute TE's own checkpointing @@ -745,7 +753,12 @@ def checkpoint( # If saved activations need to be distributed but there is no process group, # default to the world group. if distribute_saved_activations: - assert torch.distributed.is_initialized(), "torch.distributed is not initialized." + if not torch.distributed.is_initialized(): + raise RuntimeError( + "torch.distributed is not initialized. Call " + "torch.distributed.init_process_group() before using " + "distribute_saved_activations=True." + ) tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group return _CheckpointFunction.apply( @@ -921,9 +934,12 @@ def reduce_scatter_along_first_dim( return inp, None dim_size = list(inp.size()) - assert ( - dim_size[0] % world_size == 0 - ), "First dimension of the tensor should be divisible by tensor parallel size" + if dim_size[0] % world_size != 0: + raise ValueError( + "First dimension of the tensor should be divisible by tensor parallel size, " + f"but got dim_size[0]={dim_size[0]} and world_size={world_size} " + f"(remainder={dim_size[0] % world_size})." + ) dim_size[0] = dim_size[0] // world_size @@ -988,7 +1004,11 @@ def _all_gather_fp8( # Note: We cannot directly all-gather the transposed FP8 tensor, # so temporarily modify quantizer to avoid creating FP8 transpose. if not isinstance(inp, Float8TensorStorage): - assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) + if not isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + raise TypeError( + "Expected quantizer to be Float8Quantizer or Float8CurrentScalingQuantizer " + f"when input is not Float8TensorStorage, but got {type(quantizer).__name__}." + ) # we cannot directly gather the transposed fp8 tensor # so we need to disable columnwise usage for the quantizer # and then set it back to the original value after quantizing @@ -1083,7 +1103,7 @@ def _start_all_gather_fp8_blockwise( device = inp._columnwise_data.device else: raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data") - dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. + dtype = inp._dtype else: raise ValueError( "Invalid type for input tensor (expected torch.Tensor or" @@ -1104,6 +1124,9 @@ def _start_all_gather_fp8_blockwise( # Fall back to high-precision all-gather if FP8 is not supported if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1: + warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.") + if isinstance(inp, QuantizedTensorStorage): + inp = inp.dequantize(dtype=dtype) # Dequantize if needed out = torch.empty(out_shape, dtype=dtype, device=device) torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False) out = quantizer(out) @@ -1119,7 +1142,7 @@ def _start_all_gather_fp8_blockwise( "Input and quantizer do not have matching usages. " "Dequantizing and requantizing to Float8BlockwiseQTensor." ) - inp = quantizer(inp.dequantize()) + inp = quantizer(inp.dequantize(dtype=dtype)) # Construct Float8BlockwiseQTensor output tensor out = quantizer.make_empty(out_shape, dtype=dtype, device=device) @@ -1235,10 +1258,18 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int): """ shape = tensor.shape - assert len(shape) >= 2, "Wrong number of dimensions for fixing interleave." + if len(shape) < 2: + raise ValueError( + f"Wrong number of dimensions for fixing interleave: got {len(shape)}, " + f"expected at least 2 (shape={shape})." + ) first_dim = shape[0] flattened_trailing = math.prod(shape[1:]) - assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave." + if first_dim % world_size != 0: + raise ValueError( + f"Wrong dimensions for fixing interleave: first_dim={first_dim} is not divisible " + f"by world_size={world_size} (remainder={first_dim % world_size})." + ) tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing) tensor = tex.swap_first_dims(tensor, out=None) return tensor.reshape(first_dim // world_size, flattened_trailing * world_size) @@ -1321,14 +1352,18 @@ def _all_gather_nvfp4( if inp._columnwise_data is not None: in_shape_t = inp._columnwise_data.size() device = inp._columnwise_data.device - dtype = torch.bfloat16 + dtype = inp._dtype else: raise ValueError( "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, " f"found {inp.__class__.__name__})" ) - assert in_shape is not None or in_shape_t is not None, "No data found." + if in_shape is None and in_shape_t is None: + raise ValueError( + "No data found: both in_shape and in_shape_t are None. " + "Input tensor must have rowwise or columnwise data." + ) world_size = get_distributed_world_size(process_group) @@ -1342,6 +1377,9 @@ def _all_gather_nvfp4( and quantizer is not None and not quantizer.is_quantizable(inp) ): + warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.") + if isinstance(inp, QuantizedTensorStorage): + inp = inp.dequantize(dtype=dtype) # Dequantize if needed out = torch.empty( out_shape, dtype=dtype, @@ -1362,7 +1400,7 @@ def _all_gather_nvfp4( "Input and quantizer do not have matching usages. " "Dequantizing and requantizing to NVFP4." ) - inp = quantizer(inp.dequantize()) + inp = quantizer(inp.dequantize(dtype=dtype)) # Construct NVFP4 output tensor out = quantizer.make_empty(out_shape, dtype=dtype, device=device) @@ -1378,7 +1416,11 @@ def _all_gather_nvfp4( if quantizer.rowwise_usage: # Remove padding from NVFP4 scale-inverses - assert in_shape is not None, "Shape not found." + if in_shape is None: + raise RuntimeError( + "Shape not found: in_shape is None but rowwise_usage is True. " + "Input tensor must have rowwise data for NVFP4 rowwise gathering." + ) in_scale_inv = inp._rowwise_scale_inv out_scale_inv = out._rowwise_scale_inv flattened_in_shape0 = math.prod(in_shape[:-1]) @@ -1490,7 +1532,7 @@ def _all_gather_mxfp8( device = inp._columnwise_data.device else: raise ValueError("Got MXFP8 input tensor without any data") - dtype = torch.bfloat16 # Guess high-precision dtype. + dtype = inp._dtype else: raise ValueError( "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, " @@ -1509,6 +1551,9 @@ def _all_gather_mxfp8( and quantizer is not None and not quantizer.is_quantizable(inp) ): + warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.") + if isinstance(inp, QuantizedTensorStorage): + inp = inp.dequantize(dtype=dtype) # Dequantize if needed out = torch.empty( out_shape, dtype=dtype, @@ -1529,7 +1574,7 @@ def _all_gather_mxfp8( "Input and quantizer do not have matching usages. " "Dequantizing and requantizing to MXFP8." ) - inp = quantizer(inp.dequantize()) + inp = quantizer(inp.dequantize(dtype=dtype)) # Construct MXFP8 output tensor out = quantizer.make_empty(out_shape, dtype=dtype, device=device) @@ -1676,7 +1721,10 @@ def gather_along_first_dim( # MXFP8 case if isinstance(inp, MXFP8TensorStorage) or isinstance(quantizer, MXFP8Quantizer): - assert isinstance(quantizer, MXFP8Quantizer) + if not isinstance(quantizer, MXFP8Quantizer): + raise TypeError( + f"Expected MXFP8Quantizer for MXFP8 all-gather, but got {type(quantizer).__name__}." + ) return _all_gather_mxfp8( inp, process_group, @@ -1687,7 +1735,10 @@ def gather_along_first_dim( # NVFP4 case if isinstance(inp, NVFP4TensorStorage) or isinstance(quantizer, NVFP4Quantizer): - assert isinstance(quantizer, NVFP4Quantizer) + if not isinstance(quantizer, NVFP4Quantizer): + raise TypeError( + f"Expected NVFP4Quantizer for NVFP4 all-gather, but got {type(quantizer).__name__}." + ) return _all_gather_nvfp4( inp, process_group, @@ -1830,8 +1881,15 @@ def symmetric_all_reduce( - The second element is the async work handle if async_op=True, otherwise None. """ - assert async_op is False, "Async symmetric ops no supported yet" - assert HAS_TORCH_SYMMETRIC, "Could not import symetric memory from torch" + if async_op: + raise RuntimeError( + f"Async symmetric ops are not supported yet, but async_op={async_op!r} was passed." + ) + if not HAS_TORCH_SYMMETRIC: + raise RuntimeError( + "Could not import symmetric memory from torch. " + "Please ensure torch.distributed._symmetric_memory is available." + ) if get_distributed_world_size(tp_group) == 1: return inp, None @@ -1964,10 +2022,19 @@ def _fsdp_gather_tensors( *tensors: torch.Tensor, ): if fsdp_group is not None: - assert len(shapes) == len(tensors), "Number of tensors and tensor shapes must be equal." + if len(shapes) != len(tensors): + raise ValueError( + "Number of tensors and tensor shapes must be equal, " + f"but got {len(shapes)} shapes and {len(tensors)} tensors." + ) for s, t in zip(shapes, tensors): if isinstance(t, torch.Tensor): - assert s is not None, "Internal TE error." + if s is None: + raise RuntimeError( + "Internal TE error: shape is None for a non-None tensor in " + "post_optimizer_step_fwd_amax_reduction. " + f"Tensor type: {type(t).__name__}, tensor shape: {t.shape}." + ) targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t] for target in targets: safely_set_viewless_tensor_data( @@ -2015,29 +2082,37 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: fsdp_root : torch.nn.Module FSDP-wrapped root module that may contain FSDP-wrapped TE modules. """ - assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped." + if not isinstance(fsdp_root, FSDP): + raise TypeError(f"Root module must be FSDP-wrapped, but got {type(fsdp_root).__name__}.") # If the root module is a TE module, inject FSDP information into it if _is_te_module(fsdp_root.module): if hasattr(fsdp_root, "primary_weights_in_fp8"): - assert not fsdp_root.primary_weights_in_fp8, ( - "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " - "Please initialize your model without the te.quantized_model_init(...) context." - ) + if fsdp_root.primary_weights_in_fp8: + raise RuntimeError( + "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " + "Please initialize your model without the te.quantized_model_init(...) context." + ) root_state = _get_module_fsdp_state(fsdp_root) - assert root_state is not None, "Root module does not have a valid _FSDPState." - setattr(fsdp_root.module, "fsdp_group", root_state.process_group) + if root_state is None: + raise RuntimeError( + f"Root module ({type(fsdp_root.module).__name__}) does not have a valid " + "_FSDPState. Ensure the module is properly wrapped with FSDP." + ) + fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group) # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root) for state, fsdp_module in zip(fsdp_states, fsdp_modules): if _is_te_module(fsdp_module.module): if hasattr(fsdp_module.module, "primary_weights_in_fp8"): - assert not fsdp_module.module.primary_weights_in_fp8, ( - "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " - "Please initialize your model without the te.quantized_model_init(...) context." - ) - setattr(fsdp_module.module, "fsdp_group", state.process_group) + if fsdp_module.module.primary_weights_in_fp8: + raise RuntimeError( + f"TE module '{type(fsdp_module.module).__name__}' with primary weights " + "in FP8 cannot be FSDP-wrapped. Please initialize your model without " + "the te.quantized_model_init(...) context." + ) + fsdp_module.module.fast_setattr("fsdp_group", state.process_group) class FullyShardedDataParallel(FSDP): diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 15dfb1c1d..b7ed68a57 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -110,6 +110,8 @@ def _make_graphed_callables( pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = False, _reuse_graph_input_output_buffers: bool = False, + pre_warmup_hook: Optional[Callable] = None, + post_warmup_hook: Optional[Callable] = None, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -139,7 +141,7 @@ def _make_graphed_callables( # Check training/inference is_training = all(c.training for c in callables) if not is_training and any(c.training for c in callables): - assert False, ( + raise RuntimeError( "make_graphed_callables only supports when modules are all in training or all in" " inference mode." ) @@ -148,8 +150,16 @@ def _make_graphed_callables( _order_without_wgrad = None delay_wgrad_compute = False if _order is None: - assert len(sample_args) == len(callables) - assert len(sample_kwargs) == len(callables) + if len(sample_args) != len(callables): + raise ValueError( + "Expected sample_args to have the same length as callables, " + f"but got {len(sample_args)} sample_args for {len(callables)} callables" + ) + if len(sample_kwargs) != len(callables): + raise ValueError( + "Expected sample_kwargs to have the same length as callables, " + f"but got {len(sample_kwargs)} sample_kwargs for {len(callables)} callables" + ) else: # Custom logic for interleaved pipeline parallelism # Note: This is tightly coupled with the Megatron-core @@ -173,48 +183,62 @@ def _make_graphed_callables( _order_without_wgrad.append(c_id) num_model_chunks = max(_order_without_wgrad) num_microbatches = len(_order_without_wgrad) // num_model_chunks // 2 - assert num_model_chunks * num_microbatches * 2 == len(_order_without_wgrad) + if num_model_chunks * num_microbatches * 2 != len(_order_without_wgrad): + raise ValueError( + f"Pipeline-parallel order dimension mismatch: num_model_chunks ({num_model_chunks})" + f" * num_microbatches ({num_microbatches}) * 2 =" + f" {num_model_chunks * num_microbatches * 2}, but len(_order_without_wgrad) =" + f" {len(_order_without_wgrad)}" + ) # When delay_wgrad_compute is enabled, each layer is treated as a model chunk, which # allows for fine-grained graph capture order. if delay_wgrad_compute: - assert ( - _num_layers_per_chunk is not None - ), "'_num_layers_per_chunk' must be provided when delay_wgrad_compute is True." + if _num_layers_per_chunk is None: + raise ValueError( + "'_num_layers_per_chunk' must be provided when delay_wgrad_compute is True." + ) for num_layers in _num_layers_per_chunk: - assert ( - num_layers == 1 - ), "Each model chunk must have only one layer when delay_wgrad_compute is True." + if num_layers != 1: + raise ValueError( + "Each model chunk must have only one layer when delay_wgrad_compute is" + f" True, but got {num_layers} layers." + ) # Determine number of layers in each model chunk. if _num_layers_per_chunk is None: - assert len(sample_args) * 2 >= len(_order_without_wgrad) and ( - len(sample_args) * 2 % len(_order_without_wgrad) == 0 - ), ( - f"{len(sample_args)} * 2 >= {len(_order_without_wgrad)} and {len(sample_args)} * 2" - f" % {len(_order_without_wgrad)} == 0" - ) + if not ( + len(sample_args) * 2 >= len(_order_without_wgrad) + and (len(sample_args) * 2 % len(_order_without_wgrad) == 0) + ): + raise ValueError( + f"{len(sample_args)} * 2 >= {len(_order_without_wgrad)} and" + f" {len(sample_args)} * 2 % {len(_order_without_wgrad)} == 0" + ) num_layers = len(sample_args) // num_model_chunks // num_microbatches _num_layers_per_chunk = [num_layers] * num_model_chunks else: - assert ( + if not ( isinstance(_num_layers_per_chunk, int) or len(_num_layers_per_chunk) == num_model_chunks - ), ( - "If _num_layers_per_chunk is provided, it must be an integer or a list of" - f" {num_model_chunks} integers, but got {_num_layers_per_chunk}." - ) + ): + raise ValueError( + "If _num_layers_per_chunk is provided, it must be an integer or a list of" + f" {num_model_chunks} integers, but got {_num_layers_per_chunk}." + ) if isinstance(_num_layers_per_chunk, int): _num_layers_per_chunk = [_num_layers_per_chunk] * num_model_chunks total_num_layers = sum(_num_layers_per_chunk) - assert len(callables) == total_num_layers, ( - f"Callables should have ({total_num_layers}) " - + f"entries when order input is provided but got {len(callables)}." - ) - assert len(sample_args) == total_num_layers * num_microbatches, ( - f"Expected {total_num_layers * num_microbatches} " - + f"args tuple, but got {len(sample_args)}." - ) + if len(callables) != total_num_layers: + raise ValueError( + f"Callables should have ({total_num_layers}) " + + f"entries when order input is provided but got {len(callables)}." + ) + if len(sample_args) != total_num_layers * num_microbatches: + raise ValueError( + f"Expected {total_num_layers * num_microbatches} " + + f"args tuple, but got {len(sample_args)}." + ) # Calculate the starting index of each chunk in callables for future use. _prefix_num_layers = [0] @@ -222,19 +246,26 @@ def _make_graphed_callables( num_layers = _num_layers_per_chunk[m_chunk] _prefix_num_layers.append(_prefix_num_layers[-1] + num_layers) - assert len(sample_kwargs) == len(sample_args) + if len(sample_kwargs) != len(sample_args): + raise ValueError( + "Pipeline-parallel schedule requires sample_kwargs and sample_args to have " + f"the same length, but got {len(sample_kwargs)} sample_kwargs " + f"for {len(sample_args)} sample_args" + ) # Check reuse graph conditions and reorganize sample_args and sample_kwargs. # Note: When capturing a graph, we hold onto the args and kwargs so we have static buffers # when the graph is replayed. If two model chunk microbatches have no overlap between their # forward and backward, then we can reduce memory usage by reusing the same static buffers. if _reuse_graph_input_output_buffers: - assert ( - _order is not None - ), "`_order` must be provided when `_reuse_graph_input_output_buffers` is True." - assert ( - is_training - ), "`_reuse_graph_input_output_buffers` is only available in training mode." + if _order is None: + raise ValueError( + "`_order` must be provided when `_reuse_graph_input_output_buffers` is True." + ) + if not is_training: + raise RuntimeError( + "`_reuse_graph_input_output_buffers` is only available in training mode." + ) if isinstance(sample_args, tuple): sample_args = list(sample_args) if isinstance(sample_kwargs, tuple): @@ -300,20 +331,22 @@ def _make_graphed_callables( # Check callables for c in callables: if isinstance(c, torch.nn.Module): - assert ( + if not ( len(c._backward_hooks) == 0 and len(c._forward_hooks) == 0 and len(c._forward_pre_hooks) == 0 - ), ( - "Modules must not have hooks registered at the time they are passed. " - + "However, registering hooks on modules after passing them " - + "through make_graphed_callables is allowed." - ) - assert all(b.requires_grad is False for b in c.buffers()), ( - "In any :class:`~torch.nn.Module` passed to " - + ":func:`~make_graphed_callables`, only parameters may be trainable. " - + "All buffers must have ``requires_grad=False``." - ) + ): + raise RuntimeError( + "Modules must not have hooks registered at the time they are passed. " + + "However, registering hooks on modules after passing them " + + "through make_graphed_callables is allowed." + ) + if not all(b.requires_grad is False for b in c.buffers()): + raise RuntimeError( + "In any :class:`~torch.nn.Module` passed to " + + ":func:`~make_graphed_callables`, only parameters may be trainable. " + + "All buffers must have ``requires_grad=False``." + ) # Flatten callable arguments per_callable_kwargs_keys = [list(kwargs.keys()) for kwargs in sample_kwargs] @@ -322,10 +355,11 @@ def _make_graphed_callables( flatten_arg, _ = _tree_flatten(args) flatten_kwarg, _ = _tree_flatten([kwargs[key] for key in kwargs_keys]) flatten_sample_args.append(tuple(flatten_arg + flatten_kwarg)) - assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( - "In the beta API, sample_args " - + "for each callable must contain only Tensors. Other types are not allowed." - ) + if not all(isinstance(arg, torch.Tensor) for arg in flatten_arg): + raise TypeError( + "In the beta API, sample_args " + + "for each callable must contain only Tensors. Other types are not allowed." + ) # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly # passes to forward (ie, its sample_args) AND the module's parameter attributes. @@ -354,7 +388,12 @@ def _make_graphed_callables( ) else () ) - assert len(per_callable_module_params) == len(flatten_sample_args) + if len(per_callable_module_params) != len(flatten_sample_args): + raise ValueError( + "Pipeline-parallel dimension mismatch: " + f"per_callable_module_params has {len(per_callable_module_params)} entries, " + f"but flatten_sample_args has {len(flatten_sample_args)} entries" + ) per_callable_static_input_surfaces = [ flatten_sample_args[i] + per_callable_module_params[i] for i in range(len(flatten_sample_args)) @@ -400,12 +439,12 @@ def _make_graphed_callables( warmup_func_idx.append(func_idx) warmup_func.append(func) fwd_idx[m_chunk] += 1 - assert len(warmup_func) == len( - sample_args - ), f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}." - assert len(warmup_func_idx) == len( - set(warmup_func_idx) - ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." + if len(warmup_func) != len(sample_args): + raise ValueError(f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}.") + if len(warmup_func_idx) != len(set(warmup_func_idx)): + raise RuntimeError( + f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." + ) # Filter the TE modules that cudagraph can access. visited_te_modules = {} @@ -429,9 +468,10 @@ def hook_fn( modules.add(module) # If forward is called on a te.ops.Sequential it is not called on its constituent ops elif isinstance(module, Sequential): - assert ( - module._module_groups is not None - ), "Should have been initialized by warmup" + if module._module_groups is None: + raise RuntimeError( + "module._module_groups should have been initialized by warmup" + ) for module_group in module._module_groups: if isinstance(module_group, OperationFuser): for basic_op in module_group._basic_ops: @@ -442,6 +482,8 @@ def hook_fn( else: visited_te_modules[func_idx].update(modules) + if pre_warmup_hook is not None: + pre_warmup_hook() for warmup_iter in range(num_warmup_iters): hooks = [] for module in func.modules(): @@ -453,11 +495,12 @@ def hook_fn( if is_training: inputs = tuple(i for i in static_input_surface if i.requires_grad) with _none_grad_context_wrapper(inputs): + outputs_requiring_grad = tuple( + o for o in outputs if o is not None and o.requires_grad + ) torch.autograd.backward( - tuple(o for o in outputs if o.requires_grad), - grad_tensors=tuple( - torch.empty_like(o) for o in outputs if o.requires_grad - ), + outputs_requiring_grad, + grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad), ) grad_inputs = tuple(input.grad for input in inputs) @@ -477,20 +520,22 @@ def hook_fn( grad_inputs[grad_inputs_idx] is None and grad_inputs_idx < num_required_grad_sample_args ): - assert allow_unused_input, ( - "The input tensor requires grad, but the grad is None after" - " backward pass." - ) + if not allow_unused_input: + raise RuntimeError( + "The input tensor requires grad, but the grad is None after" + " backward pass." + ) elif ( grad_inputs[grad_inputs_idx] is not None and grad_inputs_idx >= num_required_grad_sample_args ): module_params_with_grad.append(static_input_surface[inputs_idx]) if len(module_params_with_grad) != len(per_callable_module_params[func_idx]): - assert warmup_iter == 0, ( - "no-grad params should only be used as inputs in the first warmup" - " iteration" - ) + if warmup_iter != 0: + raise RuntimeError( + "no-grad params should only be used as inputs in the first warmup" + f" iteration, but found in iteration {warmup_iter}" + ) per_callable_module_params[func_idx] = tuple(module_params_with_grad) static_input_surface = flatten_sample_args[func_idx] + tuple( module_params_with_grad @@ -508,11 +553,8 @@ def hook_fn( else: grad_inputs = None del outputs, grad_inputs - # The following code is added specifically for MCore's special requirements, - # aimed at preventing warmup from altering the control flow. - for module in func.modules(): - if hasattr(module, "is_first_microbatch"): - module.is_first_microbatch = True + if post_warmup_hook is not None: + post_warmup_hook() torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, @@ -531,7 +573,10 @@ def hook_fn( previous_chunk_last_callable_bwd_idx = None for i, c_id in enumerate(_order): if c_id > 0: - assert isinstance(c_id, int), "Forward order value must be an integer." + if not isinstance(c_id, int): + raise TypeError( + f"Forward order value must be an integer, but got {type(c_id).__name__}." + ) # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] m_chunk = c_id - 1 for l_no in range(_num_layers_per_chunk[m_chunk]): @@ -583,23 +628,27 @@ def hook_fn( break if wgrad_validation_list[i] is None: wgrad_validation_list[i] = False - assert wgrad_validation_list[i], ( - f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number " - f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}." - ) + if not wgrad_validation_list[i]: + raise RuntimeError( + f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number " + f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}." + ) elif ceil(c_id) != c_id: per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk] - assert is_training, "Only training mode supports backward_dw." + if not is_training: + raise RuntimeError("Only training mode supports backward_dw.") # If no one module needs the backward_dw, the bwd_dw_graph will be empty. # So skip capturing it. For backward_dw, the order value is c_id - 0.5 to indicate # the specific order of backward_dw. - assert ceil(c_id) - c_id == 0.5, ( - "The order diff of wgrad and dgrad must be 0.5, " - f"get {ceil(c_id) - c_id}." - ) - assert need_bwd_dw_graph[ - per_callable_bwd_idx - ], "No module needs wgrad computation but get float in order" + if ceil(c_id) - c_id != 0.5: + raise ValueError( + "The order diff of wgrad and dgrad must be 0.5, " + f"get {ceil(c_id) - c_id}." + ) + if not need_bwd_dw_graph[per_callable_bwd_idx]: + raise RuntimeError( + "No module needs wgrad computation but get float in order" + ) bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx] with _graph_context_wrapper(bwd_dw_graph, pool=mempool): for module in visited_te_modules[per_callable_bwd_idx]: @@ -618,19 +667,22 @@ def hook_fn( # Note for _reuse_graph_input_output_buffers: grad output is only used # within backward, so we can reuse the same static buffers every time. static_grad_outputs_keys = tuple( - (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad + (o.shape, o.dtype, o.layout) + for o in static_outputs + if o is not None and o.requires_grad ) if static_grad_outputs_keys in static_grad_outputs_dict: static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys] else: static_grad_outputs = tuple( - torch.empty_like(o) if o.requires_grad else None + torch.empty_like(o) if o is not None and o.requires_grad else None for o in static_outputs ) static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs else: static_grad_outputs = tuple( - torch.empty_like(o) if o.requires_grad else None for o in static_outputs + torch.empty_like(o) if o is not None and o.requires_grad else None + for o in static_outputs ) if is_training: inputs = tuple(i for i in static_input_surface if i.requires_grad) @@ -638,7 +690,9 @@ def hook_fn( bwd_graph, pool=mempool ): torch.autograd.backward( - tuple(o for o in static_outputs if o.requires_grad), + tuple( + o for o in static_outputs if o is not None and o.requires_grad + ), grad_tensors=tuple(o for o in static_grad_outputs if o is not None), retain_graph=retain_graph_in_backward, ) @@ -721,7 +775,8 @@ def hook_fn( ): # For now, assumes all static_outputs require grad static_grad_outputs = tuple( - torch.empty_like(o) if o.requires_grad else None for o in static_outputs + torch.empty_like(o) if o is not None and o.requires_grad else None + for o in static_outputs ) if is_training: inputs = tuple(i for i in static_input_surface if i.requires_grad) @@ -729,7 +784,7 @@ def hook_fn( bwd_graph, pool=mempool ): torch.autograd.backward( - tuple(o for o in static_outputs if o.requires_grad), + tuple(o for o in static_outputs if o is not None and o.requires_grad), grad_tensors=tuple(o for o in static_grad_outputs if o is not None), retain_graph=retain_graph_in_backward, ) @@ -777,14 +832,15 @@ class Graphed(torch.autograd.Function): """Autograd function for graph replay.""" @staticmethod - def forward(ctx, skip_fp8_weight_update, *inputs): + def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs): # pylint: disable=missing-function-docstring # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) - + ctx.cuda_graph_stream = cuda_graph_stream + ctx.cuda_graph_event = cuda_graph_event # Copy values from new tensors into static tensors for i in range(len_user_args): if ( @@ -794,9 +850,22 @@ def forward(ctx, skip_fp8_weight_update, *inputs): static_input_surface[i].copy_(inputs[i]) # Replay forward graph - fwd_graph.replay() - assert isinstance(static_outputs, tuple) - return tuple(o.detach() for o in static_outputs) + if cuda_graph_stream != torch.cuda.current_stream(): + cuda_graph_stream.wait_stream(torch.cuda.current_stream()) + with cuda_graph_stream: + fwd_graph.replay() + if cuda_graph_event is not None: + torch.cuda.current_stream().wait_event(cuda_graph_event) + else: + torch.cuda.current_stream().wait_stream(cuda_graph_stream) + else: + fwd_graph.replay() + if not isinstance(static_outputs, tuple): + raise TypeError( + "Expected static_outputs to be a tuple, but got" + f" {type(static_outputs).__name__}" + ) + return tuple(o.detach() if o is not None else o for o in static_outputs) @staticmethod @torch.autograd.function.once_differentiable @@ -804,14 +873,28 @@ def backward(ctx, *grads): # pylint: disable=missing-function-docstring # Replay backward graph - assert len(grads) == len(static_grad_outputs) + if len(grads) != len(static_grad_outputs): + raise ValueError( + "Backward graph grad dimension mismatch: " + f"received {len(grads)} grads, " + f"but expected {len(static_grad_outputs)} static_grad_outputs" + ) for g, grad in zip(static_grad_outputs, grads): if g is not None: # don't copy if autograd gods have been kind and the # incoming grad is already in the right place if g.data_ptr() != grad.data_ptr(): g.copy_(grad) - bwd_graph.replay() + if ctx.cuda_graph_stream != torch.cuda.current_stream(): + ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream()) + with ctx.cuda_graph_stream: + bwd_graph.replay() + if ctx.cuda_graph_event is not None: + torch.cuda.current_stream().wait_event(ctx.cuda_graph_event) + else: + torch.cuda.current_stream().wait_stream(ctx.cuda_graph_stream) + else: + bwd_graph.replay() # Update FP8 scale factors if needed if ctx.is_first_module: @@ -820,8 +903,12 @@ def backward(ctx, *grads): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True) # Input args that didn't require grad expect a None gradient. - assert isinstance(static_grad_inputs, tuple) - return (None,) + tuple( + if not isinstance(static_grad_inputs, tuple): + raise TypeError( + "Expected static_grad_inputs to be a tuple, but got" + f" {type(static_grad_inputs).__name__}" + ) + return (None, None, None) + tuple( b.detach() if b is not None else b for b in static_grad_inputs ) @@ -830,12 +917,33 @@ def functionalized(*user_args, **user_kwargs): # Decide whether to update FP8 weights skip_fp8_weight_update = None if cache_quantized_params: - assert "is_first_microbatch" in user_kwargs and isinstance( + if "is_first_microbatch" not in user_kwargs or not isinstance( user_kwargs["is_first_microbatch"], bool - ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." + ): + raise ValueError( + "`is_first_microbatch` boolean kwarg must be provided for FP8 weight" + " caching." + ) skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] + # The cuda_graph_stream and cuda_graph_event are used in the TE CUDA graph replay. + # When replaying the graph in the cuda graph stream, the graph replay could overlap + # with the work on main stream. + # When cuda_graph_event is given, it should be an external event recorded + # in the cuda graph and is used to sync-back to the main stream. + # If cuda_graph_event is not given, it will be None and the graph replay will block + # the main stream until it is finished. + if "cuda_graph_stream" in user_kwargs: + cuda_graph_stream = user_kwargs["cuda_graph_stream"] + user_kwargs.pop("cuda_graph_stream") + else: + cuda_graph_stream = torch.cuda.current_stream() + if "cuda_graph_event" in user_kwargs: + cuda_graph_event = user_kwargs["cuda_graph_event"] + user_kwargs.pop("cuda_graph_event") + else: + cuda_graph_event = None # Check that required kwargs are provided for key in kwargs_keys: if key not in user_kwargs: @@ -851,18 +959,30 @@ def functionalized(*user_args, **user_kwargs): flatten_user_args, _ = _tree_flatten(user_args) flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys]) func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params - out = Graphed.apply(skip_fp8_weight_update, *func_args) + out = Graphed.apply( + skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args + ) return _tree_unflatten(out, output_unflatten_spec) return functionalized def make_graphed_attribute_functions(graph_idx): + # Get te modules for current graph + te_modules = visited_te_modules.get(graph_idx, set()) # Attach backward_dw as an attribute to the graphed callable. def backward_dw(): if need_bwd_dw_graph.get(graph_idx, False): bwd_dw_graphs[graph_idx].replay() + # Trigger the grad accumulation hook for wgrad graphs. + for module in te_modules: + if ( + isinstance(module, TransformerEngineBaseModule) + and module.need_backward_dw() + ): + module._trigger_wgrad_accumulation_and_reduce_hooks() + # Attach reset as an attribute to the graphed callable. def reset(): fwd_graphs[graph_idx].reset() @@ -1027,6 +1147,8 @@ def make_graphed_callables( pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = False, _reuse_graph_input_output_buffers: bool = False, + pre_warmup_hook: Optional[Callable] = None, + post_warmup_hook: Optional[Callable] = None, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -1065,6 +1187,10 @@ def make_graphed_callables( graphs. Only supported with Mcore interleaved pipeline parallelism, i.e. when `_order` is provided. All callables in `modules` are assumed to have inputs and outputs with the same dtype and shape. + pre_warmup_hook: callable, default = None + A hook function that will be called before the warmup iterations. + post_warmup_hook: callable, default = None + A hook function that will be called after the warmup iterations. Quantization parameters ----------------------- @@ -1179,12 +1305,16 @@ def make_graphed_callables( modules = (modules,) if not isinstance(enabled, tuple): - assert isinstance(enabled, bool), "enabled must be a bool or a tuple of bools" + if not isinstance(enabled, bool): + raise TypeError( + f"enabled must be a bool or a tuple of bools, but got {type(enabled).__name__}" + ) enabled = (enabled,) * len(modules) else: - assert len(enabled) == len( - modules - ), f"enabled length ({len(enabled)}) must match modules length ({len(modules)})" + if len(enabled) != len(modules): + raise ValueError( + f"enabled length ({len(enabled)}) must match modules length ({len(modules)})" + ) if any(enabled) and recipe is None: recipe = get_default_fp8_recipe() elif not any(enabled): @@ -1220,7 +1350,8 @@ def call_func(self, *args, **kwargs): forward_funcs = [] for module in modules: - assert isinstance(module, torch.nn.Module), f"Graphing for {type(module)} is not supported." + if not isinstance(module, torch.nn.Module): + raise TypeError(f"Graphing for {type(module)} is not supported.") wrap_autocast(module) forward_funcs.append(module) @@ -1251,6 +1382,8 @@ def call_func(self, *args, **kwargs): pool=pool, retain_graph_in_backward=retain_graph_in_backward, _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers, + pre_warmup_hook=pre_warmup_hook, + post_warmup_hook=post_warmup_hook, ) # Ensures warmup does not affect numerics for ops such as dropout. diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index e8cef56bd..4d52d9b92 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -49,17 +49,35 @@ def wrapper(*args, **kwargs): # Decorator to disable Torch Dynamo # See: https://github.com/NVIDIA/TransformerEngine/issues/308 -no_torch_dynamo = lambda recursive=True: lambda func: func if torch.__version__ >= "2": import torch._dynamo - if torch.__version__ >= "2.1": - no_torch_dynamo = lambda recursive=True: lambda f: ( - f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive) - ) - else: - # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True - no_torch_dynamo = lambda recursive=True: torch._dynamo.disable + def no_torch_dynamo(recursive=True): + """Decorator to disable Torch Dynamo, except during ONNX export.""" + + def decorator(f): + # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True + disabled_f = ( + torch._dynamo.disable(f, recursive=recursive) + if torch.__version__ >= "2.1" + else torch._dynamo.disable(f) + ) + + @wraps(f) + def wrapper(*args, **kwargs): + if is_in_onnx_export_mode(): + return f(*args, **kwargs) + return disabled_f(*args, **kwargs) + + return wrapper + + return decorator + +else: + # Fallback for PyTorch < 2.0: no-op decorator + def no_torch_dynamo(recursive=True): # pylint: disable=unused-argument + """No-op decorator for PyTorch < 2.0.""" + return lambda func: func def set_jit_fusion_options() -> None: diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 4058833c9..168fec063 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -91,6 +91,8 @@ def forward( # Check first tensor if not tensors: raise ValueError("Attempted to concatenate 0 tensors") + + # Check concat dim num_dims = tensors[0].dim() if not -num_dims <= dim < num_dims: raise ValueError( @@ -123,11 +125,24 @@ def forward( ctx.dim = dim ctx.split_ranges = split_ranges - # Out-of-place concatenation if needed + # Tensor properties from first tensor dtype = tensors[0].dtype device = tensors[0].device strides = tensors[0].stride() data_ptr_stride = strides[dim] * tensors[0].element_size() + + # Out-of-place concatenation when view tensors have different storage + # Note: This works around an edge case with the split_quantize + # function, which might allocate a buffer and construct + # subviews. However, in order to reduce CPU overheads, these + # views are configured manually outside of PyTorch. PyTorch + # doesn't know these views share the same memory, and it + # blocks us from reconstructing the full tensor because it + # thinks we are accessing out-of-bounds memory. + if tensors[0].untyped_storage().nbytes() < out_shape[dim] * data_ptr_stride: + return torch.cat(tensors, dim=dim) + + # Out-of-place concatenation if tensor properties do not match data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride numel = tensors[0].numel() for tensor in tensors[1:]: @@ -149,13 +164,7 @@ def forward( return out # No-op concatenation - out = tensors[0].new() - out.set_( - tensors[0].untyped_storage(), - tensors[0].storage_offset(), - out_shape, - strides, - ) + out = tensors[0].as_strided(out_shape, strides) out.requires_grad = any(tensor.requires_grad for tensor in tensors) return out diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2d8563729..f789b164d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -12,9 +12,8 @@ import warnings from enum import Enum from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Union from contextlib import contextmanager -import logging from types import MethodType from itertools import chain @@ -24,7 +23,6 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex -from transformer_engine.common.recipe import Recipe from ._common import _ParameterInitMeta, noop_cat from ..quantization import ( @@ -43,6 +41,7 @@ _fsdp_gather_tensors, ) from ..constants import dist_group_type +from ..cpp_extensions.gemm import _NUM_MAX_UB_STREAMS from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer @@ -52,7 +51,14 @@ from ..triton_kernels.cast import te_quantize_triton from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage -from ..utils import get_device_compute_capability, is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype +from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage +from ..utils import ( + is_non_tn_fp8_gemm_supported, + torch_get_autocast_gpu_dtype, + get_nvtx_range_context, + nvtx_range_push, + nvtx_range_pop, +) from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState @@ -64,11 +70,8 @@ _2X_ACC_FPROP = False _2X_ACC_DGRAD = True _2X_ACC_WGRAD = True -_multi_stream_cublas_workspace = [] _dummy_wgrads = {} -_cublas_workspace = None _ub_communicators = None -_NUM_MAX_UB_STREAMS = 3 _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None layers_atomic_ring_exchange = [] @@ -82,44 +85,10 @@ class UserBufferQuantizationMode(Enum): FP8 = "fp8" -def get_cublas_workspace_size_bytes() -> None: - """Return workspace size needed for current architecture""" - if IS_HIP_EXTENSION: - """Return 64 MiB for gfx50x, 32 MiB for all other architectures.""" - if get_device_compute_capability() == (9, 5): - return 67_108_864 - return 33_554_432 - """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: - # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales - return 32 * 1024 * 1024 + 1024 - return 4_194_304 - - -def get_workspace() -> torch.Tensor: - """Returns workspace for cublas.""" - global _cublas_workspace - if _cublas_workspace is None: - _cublas_workspace = torch.empty( - get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda" - ) - return _cublas_workspace - - -def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: - """Returns workspace for multi-stream cublas.""" - global _multi_stream_cublas_workspace - if not _multi_stream_cublas_workspace: - for _ in range(tex.get_num_cublas_streams()): - _multi_stream_cublas_workspace.append( - torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda") - ) - return _multi_stream_cublas_workspace - - def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: """Returns a dummy tensor of given shape.""" - assert len(shape) == 2 + if len(shape) != 2: + raise ValueError(f"Expected 2D shape, got {len(shape)}D: {shape}") global _dummy_wgrads if (shape[0], shape[1], dtype) not in _dummy_wgrads: _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( @@ -144,61 +113,62 @@ def initialize_ub( ) -> None: r""" Initialize the Userbuffers communicator for overlapping tensor-parallel communications with - GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules. + GEMM compute in ``te.Linear``, ``te.LayerNormLinear`` and ``te.LayerNormMLP`` modules. Parameters ---------- shape : list shape of the communication buffer, typically set to be the same as the global shape of - the input tensor to a te.TransformerLayer forward pass, with the sequence and batch - dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)` + the input tensor to a ``te.TransformerLayer`` forward pass, with the sequence and batch + dimensions collapsed together -- i.e.: ``(sequence_length * batch_size, hidden_size)`` tp_size : int number of GPUs in the tensor-parallel process group use_fp8 : bool = False allocate the communication buffer for FP8 GEMM inputs/outputs. - DEPRECATED: Please use `quantization_modes` instead. + DEPRECATED: Please use ``quantization_modes`` instead. quantization_modes : List[UserBufferQuantizationMode] = None if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list. - falls back to the legacy `use_fp8` parameter if `None` is provided. + falls back to the legacy ``use_fp8`` parameter if ``None`` is provided. dtype : torch.dtype = torch.bfloat16 - non-FP8 data type of the communication buffer when `use_fp8 = False` - ub_cfgs: dict = None - Configuration dictionary with the structure - ``` - { - : { - "method": <"ring_exchange" or "pipeline">, - "is_reduce_scatter": bool, - "num_sm": int, - "cga_size": int, - "set_sm_margin": bool, - "num_splits": int, - "aggregate": bool, - "atomic_gemm": bool, - "use_ce": bool, - "fp8_buf": bool, - } - } - ``` - for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", + non-FP8 data type of the communication buffer when ``use_fp8 = False`` + ub_cfgs : dict = None + Configuration dictionary with the structure:: + + { + : { + "method": <"ring_exchange" or "pipeline">, + "is_reduce_scatter": bool, + "num_sm": int, + "cga_size": int, + "set_sm_margin": bool, + "num_splits": int, + "aggregate": bool, + "atomic_gemm": bool, + "use_ce": bool, + "fp8_buf": bool, + } + } + + for ``te.TransformerLayer`` GEMM layers in ``["qkv_fprop", "qkv_dgrad", "qkv_wgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", - "fc2_fprop", "fc2_wgrad"]`. - a list may be provided to specify different overlap configurations for different the quantization settings in `quantization_modes` + "fc2_fprop", "fc2_wgrad"]``. + a list may be provided to specify different overlap configurations for different the quantization settings in ``quantization_modes`` bootstrap_backend : str = None - `torch.distributed` communication backend for the all-gather, broadcast and + ``torch.distributed`` communication backend for the all-gather, broadcast and barrier collectives during Userbuffers initialization. Not all backends are valid for every cluster configuration and distributed launch method even if they are available in PyTorch. When left unset, the initialization prefers to use the MPI backend, falling back first on Gloo and then NCCL if MPI is - not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this + not available. Setting ``NVTE_UB_WITH_MPI=1`` when building TE overrides this option and always initializes Userbuffers with direct MPI calls in C++, - which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time. + which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time. """ if not tex.device_supports_multicast(): - assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( - "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " - + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." - ) + if not bool(int(os.getenv("UB_SKIPMC", "0"))): + raise RuntimeError( + "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap " + "with CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." + ) if not quantization_modes: warnings.warn( @@ -210,34 +180,48 @@ def initialize_ub( UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE ] else: - assert isinstance(quantization_modes, list), "quantization_modes must be a list" - assert all( - isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes - ), "quantization_modes must be a list of UserBufferQuantizationMode" + if not isinstance(quantization_modes, list): + raise TypeError( + f"quantization_modes must be a list, got {type(quantization_modes).__name__}" + ) + invalid_modes = [ + mode for mode in quantization_modes if not isinstance(mode, UserBufferQuantizationMode) + ] + if invalid_modes: + raise TypeError( + "quantization_modes must be a list of UserBufferQuantizationMode, " + f"got invalid entries: {invalid_modes}" + ) if isinstance(ub_cfgs, dict) or ub_cfgs is None: ub_cfgs = [ub_cfgs] * len(quantization_modes) else: - assert len(ub_cfgs) == len( - quantization_modes - ), "Number of ub_cfgs settings must match number of quantization configurations" + if len(ub_cfgs) != len(quantization_modes): + raise ValueError( + f"Number of ub_cfgs settings ({len(ub_cfgs)}) must match number of " + f"quantization configurations ({len(quantization_modes)})" + ) global _ub_communicators - assert _ub_communicators is None, "UB communicators are already initialized." + if _ub_communicators is not None: + raise RuntimeError("UB communicators are already initialized.") _ub_communicators = {} if tex.ubuf_built_with_mpi(): # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force # an MPI_Init() here by creating a new MPI process group... - assert torch.distributed.is_mpi_available() + if not torch.distributed.is_mpi_available(): + raise RuntimeError( + "MPI backend is not available in torch.distributed but is required " + "when Userbuffers is built with MPI support" + ) _ = torch.distributed.new_group(backend="mpi") helper = tex.CommOverlapHelper() else: # Bootstrapping with torch.distributed API, so check backend and construct # intra/inter-node process groups... - assert ( - torch.distributed.is_initialized() - ), "torch.distributed must be initialized before Userbuffers" + if not torch.distributed.is_initialized(): + raise RuntimeError("torch.distributed must be initialized before using Userbuffers") if bootstrap_backend is None: bootstrap_backend = "nccl" if torch.distributed.is_mpi_available(): @@ -245,15 +229,16 @@ def initialize_ub( elif torch.distributed.is_gloo_available(): bootstrap_backend = "gloo" else: - assert bootstrap_backend in [ - "gloo", - "mpi", - "nccl", - ], "Invalid torch.distributed backend for bootstrapping Userbuffers!" - assert torch.distributed.is_backend_available(bootstrap_backend), ( - f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " - f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." - ) + if bootstrap_backend not in ["gloo", "mpi", "nccl"]: + raise ValueError( + f"Invalid torch.distributed backend '{bootstrap_backend}' for bootstrapping " + "Userbuffers. Must be one of: 'gloo', 'mpi', 'nccl'" + ) + if not torch.distributed.is_backend_available(bootstrap_backend): + raise RuntimeError( + f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " + f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." + ) world_group = torch.distributed.new_group(backend=bootstrap_backend) world_rank = torch.distributed.get_rank(world_group) @@ -393,9 +378,11 @@ def add_ub( warnings.warn( "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." ) - assert ( - quantization_mode == UserBufferQuantizationMode.FP8 - ), "Atomic GEMM overlap supported only for FP8 GEMM." + if quantization_mode != UserBufferQuantizationMode.FP8: + raise ValueError( + "Atomic GEMM overlap supported only for FP8 GEMM, " + f"got quantization_mode={quantization_mode}" + ) if method in ("bulk", "external"): warnings.warn( f"At {name}, atoimic GEMM not is supported for a bulk overlap." @@ -420,20 +407,24 @@ def add_ub( "for functionality." ) if name in layers_atomic_ring_exchange: - assert atomic_gemm and method == "ring_exchange", assert_message + if not (atomic_gemm and method == "ring_exchange"): + raise ValueError(assert_message) else: if atomic_gemm and method == "ring_exchange": - assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message + if rs_ag_pairs[name] not in layers_atomic_ring_exchange: + raise ValueError(assert_message) if name in external_gemm_to_overlap: - assert method == "external", ( - f"At {name}, `external` overlap method is specified, but the selected method is" - f" {method}" - ) - assert external_gemm_to_overlap[name] in methods["ring_exchange"], ( - f"At {name}, `external` overlap method is specified, but the external gemm" - f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" - ) + if method != "external": + raise ValueError( + f"At {name}, `external` overlap method is specified, but the selected method " + f"is {method}" + ) + if external_gemm_to_overlap[name] not in methods["ring_exchange"]: + raise ValueError( + f"At {name}, `external` overlap method is specified, but the external gemm " + f"{external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" + ) buffer_dtype = ( torch.uint8 @@ -485,7 +476,12 @@ def add_ub( and user_ub_cfg[name]["method"] != "bulk" ): wgrad_name = name.replace("dgrad", "wgrad") - assert wgrad_name not in user_ub_cfg + if wgrad_name in user_ub_cfg: + raise ValueError( + f"Cannot specify user UB config for '{wgrad_name}' when its " + f"corresponding dgrad '{name}' uses a non-bulk overlap method " + f"('{user_ub_cfg[name]['method']}')" + ) if wgrad_name in layers_reduce_scatter_overlap: layers_reduce_scatter_overlap.remove(wgrad_name) if name in layers_all_gather_overlap: @@ -500,10 +496,11 @@ def add_ub( if IS_HIP_EXTENSION and user_ub_cfg is not None: for name, cfg in user_ub_cfg.items(): - assert cfg.get("method") != "bulk", ( - f"Bulk overlap method for '{name}' is not supported on HIP/ROCm. " - "Please use 'ring_exchange' method instead." - ) + if cfg.get("method") == "bulk": + raise NotImplementedError( + f"Bulk overlap method for '{name}' is not supported on HIP/ROCm. " + "Please use 'ring_exchange' method instead." + ) for name in chain.from_iterable(methods.values()): ub_cfg = get_default_config(name) @@ -522,8 +519,10 @@ def get_ub(name: str, use_fp8: bool): # So favour simplicity until the correct design becomes clear. # This is mainly an internal API so we don't need to worry about future changes key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE) - assert _ub_communicators is not None, "UB manager is not initialized." - assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered." + if _ub_communicators is None: + raise RuntimeError("UB manager is not initialized.") + if key not in _ub_communicators: + raise KeyError(f"UB for {name} with use_fp8={use_fp8} is not registered.") return _ub_communicators[key] @@ -593,6 +592,7 @@ def fill_userbuffers_buffer_for_all_gather( data=global_tensor_data, fp8_scale_inv=local_tensor._scale_inv, fp8_dtype=local_tensor._fp8_dtype, + fake_dtype=local_tensor._dtype, quantizer=quantizer, ) return global_tensor, local_tensor @@ -632,6 +632,8 @@ def fill_userbuffers_buffer_for_all_gather( "Userbuffers requires MXFP8 tensor dims that are divisible by 128, " f"but got MXFP8 tensor with shape={tuple(local_shape)}" ) + if local_tensor._with_gemm_swizzled_scales: + raise ValueError("Userbuffers assumes MXFP8 tensors have unswizzled scales") local_scale_inv = ( local_tensor._rowwise_scale_inv if with_rowwise_data @@ -664,6 +666,8 @@ def fill_userbuffers_buffer_for_all_gather( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=local_tensor._fp8_dtype, quantizer=quantizer, + with_gemm_swizzled_scales=False, + fake_dtype=local_tensor._dtype, ) return global_tensor, local_tensor @@ -674,10 +678,11 @@ def fill_userbuffers_buffer_for_all_gather( class TransformerEngineBaseModule(torch.nn.Module, ABC): """Base TE module.""" - def __init__(self) -> None: + def __init__(self, name: Optional[str] = None) -> None: super().__init__() - assert torch.cuda.is_available(), "TransformerEngine needs CUDA." - self.name = None + if not torch.cuda.is_available(): + raise RuntimeError("TransformerEngine needs CUDA.") + self.name = name self.next_iter_when_debug_should_be_run = 0 self.fp8_initialized = False self.fp8 = False @@ -704,26 +709,22 @@ def __init__(self) -> None: if not TEDebugState.debug_enabled: TEDebugState.initialize() + self._validate_name() - # Names of attributes that can be set quickly (see __setattr__ - # method) - _fast_setattr_names: Set[str] = { - "activation_dtype", - "fp8", - "fp8_initialized", - "fp8_calibration", - "fp8_parameters", - } - - def __setattr__(self, name: str, value: Any) -> None: - if name in TransformerEngineBaseModule._fast_setattr_names: - # torch.nn.Module has a custom __setattr__ that handles - # modules, parameters, and buffers. This is unnecessary - # overhead when setting plain attrs. - self.__dict__[name] = value - else: - # Default case - super().__setattr__(name, value) + def fast_setattr(self, name: str, value: Any) -> None: + """ + Fast version of the Module's set attribute function. + Should be used for regular attributes, but not properties nor parameters/buffers. + """ + self.__dict__[name] = value + + def module_setattr(self, name: str, value: Any) -> None: + """ + Regular version of the Module's set attribute function. + Should be used only when the fast version cannot be used - for the properties, + parameters and buffers. + """ + super().__setattr__(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ @@ -768,9 +769,12 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> ] for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): if buffer_key in FP8GlobalStateManager.global_amax_buffer: - assert ( - buffer_key in FP8GlobalStateManager.global_amax_history_buffer - ), "TE internal error during amax history change." + if buffer_key not in FP8GlobalStateManager.global_amax_history_buffer: + raise RuntimeError( + "TE internal error during amax history change: " + f"buffer_key '{buffer_key}' found in global_amax_buffer " + "but missing from global_amax_history_buffer" + ) FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[ meta_key ].amax_history[0] @@ -819,10 +823,11 @@ def _update_weight_quantizers(self) -> None: """Update the quantizers for the weight tensors.""" weight_tensors = self._get_weight_tensors() weight_quantizers = self._get_weight_quantizers() - assert len(weight_tensors) == len(weight_quantizers), ( - f"Number of weight tensors ({len(weight_tensors)}) and quantizers " - f"({len(weight_quantizers)}) must match" - ) + if len(weight_tensors) != len(weight_quantizers): + raise ValueError( + f"Number of weight tensors ({len(weight_tensors)}) and quantizers " + f"({len(weight_quantizers)}) must match" + ) for weight, quantizer in zip(weight_tensors, weight_quantizers): if quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_quantizer(quantizer) @@ -844,7 +849,7 @@ def init_fp8_meta_tensors(self, recipe: Recipe) -> None: self.set_meta_tensor(True, recipe) self.set_meta_tensor(False, recipe) - self.fp8_meta_tensors_initialized = True + self.fast_setattr("fp8_meta_tensors_initialized", True) def get_fp8_meta_tensors(self) -> None: """Get scales and amaxes.""" @@ -870,7 +875,11 @@ def reset(key): torch.zeros_like(self.fp8_meta[key].amax_history) ) else: - assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." + if key not in fp8_meta_tensors: + raise KeyError( + f"Cannot reset fp8 tensors: key '{key}' not found in fp8_meta_tensors. " + f"Available keys: {list(fp8_meta_tensors.keys())}" + ) self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][1]) @@ -1001,22 +1010,22 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" # Native AMP (`torch.autocast`) gets highest priority if torch.is_autocast_enabled(): - self.activation_dtype = torch_get_autocast_gpu_dtype() + self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype()) return - + dtype = inp.dtype # All checks after this have already been performed once, thus skip - if self.activation_dtype == inp.dtype: + if self.activation_dtype == dtype: return - dtype = inp.dtype if not self.allow_different_data_and_param_types: for name, param in self.named_parameters(): if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) - self.activation_dtype = dtype + if dtype != param.dtype: + raise TypeError( + "Data types for parameters must match when outside of autocasted " + f"region. Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" + ) + self.fast_setattr("activation_dtype", dtype) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ @@ -1025,11 +1034,11 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N Parameters ---------- - tp_group : ProcessGroup, default = `None` + tp_group : ProcessGroup, default = None tensor parallel process group. """ - self.tp_group = tp_group - self.tp_group_initialized = True + self.fast_setattr("tp_group", tp_group) + self.fast_setattr("tp_group_initialized", True) def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: """returns the FP8 weights.""" @@ -1045,53 +1054,56 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: # assume FP8 execution. def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" - _original_recipe = self.fp8_meta.get("recipe", None) + meta = self.fp8_meta - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() - fp8_enabled = self.fp8 or self.fp8_calibration - self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration + fp8 = FP8GlobalStateManager.is_fp8_enabled() + fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + self.fast_setattr("fp8_parameters", fp8_parameters) + self.fast_setattr("fp8", fp8) + self.fast_setattr("fp8_calibration", fp8_calibration) + fp8_enabled = fp8 or fp8_calibration + meta["fp8_checkpoint"] = fp8_enabled - if IS_HIP_EXTENSION and not FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 and hasattr(self, 'use_fsdp2') and self.use_fsdp2: - FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 = True + _original_recipe = None - if self.fp8_parameters or fp8_enabled: - if ( - self.fp8_initialized - and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] - ): + if IS_HIP_EXTENSION and not FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 and hasattr(self, 'use_fsdp2') and self.use_fsdp2: + FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 = True + + if fp8_parameters or fp8_enabled: + _original_recipe = meta.get("recipe", None) + if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe: # FP8 init has already been run and recipe is the same, don't do anything. return - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() else: # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False + self.fast_setattr("fp8_initialized", False) return - if self.fp8_parameters and not self.fp8_initialized: - self.fp8_meta["num_gemms"] = num_gemms - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + if fp8_parameters and not self.fp8_initialized: + meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(meta["recipe"]) if fp8_enabled: # Set FP8 and other FP8 metadata - self.fp8_meta["num_gemms"] = num_gemms - self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + meta["num_gemms"] = num_gemms + meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Set FP8_MAX per tensor according to recipe - if hasattr(self.fp8_meta["recipe"], "fp8_format"): - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + if hasattr(meta["recipe"], "fp8_format"): + meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd + meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) - self.fp8_initialized = True + self.init_fp8_meta_tensors(meta["recipe"]) + self.fast_setattr("fp8_initialized", True) - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - if self.fp8_meta["recipe"].mxfp8(): - self.keep_fp8_weight_transpose_cache = True + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + if meta["recipe"].mxfp8(): + self.keep_fp8_weight_transpose_cache = True - _current_recipe = self.fp8_meta["recipe"] + _current_recipe = meta["recipe"] if _original_recipe is not None and not ( issubclass(_current_recipe.__class__, _original_recipe.__class__) or issubclass(_original_recipe.__class__, _current_recipe.__class__) @@ -1104,55 +1116,87 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Clear cached workspaces as they were created with the old recipe/quantizer type self._fp8_workspaces.clear() - @contextmanager def prepare_forward( self, inp: torch.Tensor, num_gemms: int = 1, allow_non_contiguous: bool = False, allow_different_data_and_param_types: bool = False, - ) -> Generator[torch.Tensor, None, None]: - """Checks and prep for FWD. - The context manager is needed because there isn't a way for a module to know - if it's the last FP8 module in the forward autocast. It is useful - to setup the forward aggregated amax reduction for every module - just in case. The autocast exit will pick up the most recent one. - """ - self.allow_different_data_and_param_types = allow_different_data_and_param_types - self.forwarded_at_least_once = True + ) -> torch.Tensor: + """Checks and prepares for FWD execution.""" + self.fast_setattr( + "allow_different_data_and_param_types", allow_different_data_and_param_types + ) + self.fast_setattr("forwarded_at_least_once", True) + # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): + delayed_scaling_recipe = self.fp8_meta["recipe"].delayed() FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) else: - assert inp.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise RuntimeError( + f"TransformerEngine needs CUDA. Got input on device: {inp.device}" + ) if self.tp_size > 1: - assert self.tp_group_initialized, "TP group not initialized." + if not self.tp_group_initialized: + raise RuntimeError( + "Tensor parallel group not initialized. Call " + "set_tensor_parallel_group() before forward pass when tp_size > 1." + ) self.set_activation_dtype(inp) self.init_fp8_metadata(num_gemms=num_gemms) self._check_weight_tensor_recipe_correspondence() - if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): - assert self.fp8_meta["recipe"].reduce_amax, ( - "Amax reduction across tensor parallel group is " - "necessary when using sequence parallelism with FP8." - ) + delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() + if delayed_scaling_recipe: + if self.sequence_parallel: + if not self.fp8_meta["recipe"].reduce_amax: + raise ValueError( + "Amax reduction across tensor parallel group is " + "necessary when using sequence parallelism with FP8." + ) - if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) + if not FP8GlobalStateManager.fp8_graph_capturing(): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) - # Activation recomputation is used and this is the first forward phase. - if self.fp8 and self.training and is_fp8_activation_recompute_enabled(): - FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) + # Activation recomputation is used and this is the first forward phase. + if self.training and is_fp8_activation_recompute_enabled(): + FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) - with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"): - if not allow_non_contiguous and not inp.is_contiguous(): - inp = inp.contiguous() - yield inp + nvtx_range_push(self.__class__.__name__ + " forward") + if not allow_non_contiguous and not inp.is_contiguous(): + inp = inp.contiguous() + return inp - if self.fp8 and in_fp8_activation_recompute_phase(): + def end_forward(self): + """ + Required to be called at the end of the forward function to properly handle + DelayedScaling metadata handling and the NVTX ranges. + """ + delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() + if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) + nvtx_range_pop() + + @contextmanager + def prepare_forward_ctx( + self, + inp: torch.Tensor, + num_gemms: int = 1, + allow_non_contiguous: bool = False, + allow_different_data_and_param_types: bool = False, + ) -> Generator[torch.Tensor, None, None]: + """Checks and prepares for FWD execution.""" + inp = self.prepare_forward( + inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types + ) + try: + yield inp + finally: + self.end_forward() def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled @@ -1240,18 +1284,7 @@ def grad_output_preprocess( # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None if ctx.debug: grad_output_ = quantizer(grad_output) - if ( - isinstance( - grad_output_.get_tensor(True), - ( - QuantizedTensor, - Float8TensorStorage, - MXFP8TensorStorage, - Float8BlockwiseQTensorStorage, - ), - ) - and ctx.use_bias - ): + if ctx.use_bias: grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias = None @@ -1409,9 +1442,9 @@ def clear(self): # Update the parameter based on its type if not is_dtensor: - setattr(self, name, param) + self.module_setattr(name, param) else: - setattr(self, name, dtensor_param) + self.module_setattr(name, dtensor_param) @abstractmethod def forward(self): @@ -1442,7 +1475,7 @@ def get_weight_workspace( workspace is being constructed or updated. cache_name: str, optional Key for caching. - update_workspace: bool, default = `True` + update_workspace: bool, default = True Update workspace with values from `tensor`. skip_update_flag: torch.Tensor, optional GPU flag to skip updating the workspace. Take precedence @@ -1464,6 +1497,10 @@ def get_weight_workspace( rowwise_usage=update_rowwise_usage, columnwise_usage=update_columnwise_usage, ) + + if isinstance(quantizer, DebugQuantizer): + tensor = quantizer.wrap_quantized_tensor(tensor) + return tensor # Try getting workspace from cache @@ -1486,6 +1523,11 @@ def get_weight_workspace( reset_cache = True elif quantizer.columnwise_usage and out._columnwise_data is None: reset_cache = True + elif isinstance(out, NVFP4TensorStorage): + if quantizer.rowwise_usage and out._rowwise_data is None: + reset_cache = True + elif quantizer.columnwise_usage and out._columnwise_data is None: + reset_cache = True if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): reset_cache = True if reset_cache: @@ -1590,7 +1632,7 @@ def backward_dw(self): """ if not self.need_backward_dw(): return - with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): + with get_nvtx_range_context(f"_{self.__class__.__name__}_wgrad"): (wgrad, bgrad), _ = self.wgrad_store.pop() if not self.fuse_wgrad_accumulation: weight_tensor = noop_cat(self._get_weight_tensors()) @@ -1601,8 +1643,14 @@ def backward_dw(self): bias_tensor.grad = bgrad.to(bias_tensor.dtype) del wgrad del bgrad - for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: - wgrad_accumulation_and_reduce_hook() + self._trigger_wgrad_accumulation_and_reduce_hooks() + + def _trigger_wgrad_accumulation_and_reduce_hooks(self): + """ + Trigger the wgrad accumulation and reduce hooks. + """ + for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: + wgrad_accumulation_and_reduce_hook() def is_debug_iter(self) -> bool: """ @@ -1611,7 +1659,6 @@ def is_debug_iter(self) -> bool: debug = TEDebugState.debug_enabled if not debug: return False - self._validate_name() # If layer is run first time in new iteration, # we need to check if the debug should be enabled for this layer - @@ -1625,13 +1672,19 @@ def is_debug_iter(self) -> bool: debug = False else: debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run - self.debug_last_iteration = TEDebugState.get_iteration() - self.debug_enabled_in_this_iteration = debug + self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration()) + self.fast_setattr("debug_enabled_in_this_iteration", debug) else: # If this is the same iteration as previous invocation of the module, # we use the debug value from the first invocation in the iteration. debug = self.debug_enabled_in_this_iteration + self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration()) + + if self.wgrad_store is not None: + if debug and self.wgrad_store.delay_wgrad_compute(): + raise RuntimeError("Delayed wgrad compute is not supported in debug mode.") + return debug def no_debug_features_active(self, quantizers): @@ -1642,34 +1695,25 @@ def no_debug_features_active(self, quantizers): # Sometimes features inform that they will not be enabled for particular layer # for multiple next iterations. - self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers) + self.fast_setattr( + "next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers) + ) if not run_current: return True - if self.primary_weights_in_fp8: - raise RuntimeError("FP8 weights are not supported in debug mode.") return False def _validate_name(self): """ Validate name passed to the module. - This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. - If no name is assigned, it creates a default name with layer count as the variable. + It creates a default name with layer count as the variable + which may be changed by the user of the module. """ if self.name is not None: return - assert TEDebugState.debug_enabled - import nvdlfw_inspect.api as debug_api - - if self.name is None: - debug_api.log_message( - "Names are not provided to debug modules. ", - "Creating and using generic names. Pass names to debug modules for better" - " insight. ", - level=logging.WARNING, - ) - self.name = f"Layer_{TEDebugState.get_layer_count()}" + + self.name = f"Layer_{TEDebugState.get_layer_count()}" def _check_weight_tensor_recipe_correspondence(self) -> None: """ @@ -1687,6 +1731,8 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: """ if not self.fp8 and not self.fp8_calibration: return + if not self.primary_weights_in_fp8: + return if not hasattr(self, "weight_names") or not self.weight_names: return diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 8c6aa8bde..a47912b4c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -15,6 +15,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from .base import ( get_dummy_wgrad, TransformerEngineBaseModule, @@ -132,8 +133,9 @@ def forward( and not in_fp8_activation_recompute_phase() ) # No need to set the quantizer states if weight is already quantized - if weight_quantizers[0] is not None and not isinstance( - weights[0], QuantizedTensorStorage + # for debug mode we create quantizer every iteration, thus we need to set the quantizer states + if weight_quantizers[0] is not None and ( + not isinstance(weights[0], QuantizedTensorStorage) or debug ): for weight_quantizer in weight_quantizers: weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) @@ -158,7 +160,10 @@ def forward( # tensors (like scales), but bulk allocation shares storage across all tensors, # so if scales can't be offloaded, nothing in the group can be offloaded. inputmats = tex.split_quantize( - inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading + inp_view, + m_splits, + input_quantizers, + disable_bulk_allocation=cpu_offloading, ) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -389,7 +394,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], for i in range(ctx.num_gemms): grad_biases[i] = grad_output_mats[i].sum(dim=0) grad_output = DebugQuantizer.multi_tensor_quantize( - grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype + grad_output_view, + ctx.grad_output_quantizers, + ctx.m_splits, + ctx.activation_dtype, ) else: # Only split grad output. Grad bias is fused with @@ -479,7 +487,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.input_quantizers[0] is not None: for input_quantizer in ctx.input_quantizers: if isinstance( - input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + input_quantizer, + (Float8Quantizer, Float8CurrentScalingQuantizer), ): input_quantizer.set_usage(rowwise=True, columnwise=True) else: @@ -489,7 +498,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( - inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype + inp_view, + ctx.input_quantizers, + ctx.m_splits, + ctx.activation_dtype, ) else: if not ctx.use_grouped_gemm_triton: @@ -636,6 +648,10 @@ class GroupedLinear(TransformerEngineBaseModule): cast tensor. In some scenarios, the input tensor is used by multiple modules, and saving the original input tensor may reduce the memory usage. Cannot work with FP8 DelayedScaling recipe. + single_grouped_parameter : bool, default = False + If set to ``True``, grouped weights are stored as a single grouped parameter + instead of one parameter per GEMM. + EXPERIMENTAL and subject to change. Notes ----- @@ -666,11 +682,12 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, save_original_input: bool = False, + single_grouped_parameter: bool = False, name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name) - params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.num_gemms = num_gemms self.in_features = in_features self.out_features = out_features @@ -682,16 +699,22 @@ def __init__( self.ub_overlap_ag = ub_overlap_ag self.ub_name = ub_name self.save_original_input = save_original_input - assert ( - not ub_overlap_rs and not ub_overlap_ag - ), "GroupedLinear doesn't support Userbuffer overlap." + self.single_grouped_parameter = single_grouped_parameter + if ub_overlap_rs or ub_overlap_ag: + raise ValueError("GroupedLinear doesn't support Userbuffer overlap.") + self.init_method = init_method self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute) - self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1} + self._offsets = { + "input": 0, + "weight": 1, + "output": 2, + "grad_output": 0, + "grad_input": 1, + } self._num_fp8_tensors_per_gemm = { "fwd": 3, "bwd": 2, @@ -713,9 +736,11 @@ def __init__( ) self.parallel_mode = parallel_mode - assert ( - self.parallel_mode in GemmParallelModes - ), f"parallel_mode {parallel_mode} not supported" + if self.parallel_mode not in GemmParallelModes: + raise ValueError( + f"parallel_mode {parallel_mode!r} not supported." + f" Supported modes: {GemmParallelModes}" + ) if self.parallel_mode == "column": self.out_features = divide(self.out_features, self.tp_size) @@ -733,7 +758,7 @@ def __init__( self.out_features, self.in_features, device=device, - dtype=params_dtype, + dtype=self.params_dtype, ), ), init_fn=init_method, @@ -749,13 +774,13 @@ def __init__( torch.empty( self.out_features, device=device, - dtype=params_dtype, + dtype=self.params_dtype, ), ), init_fn=init_method_constant(0.0), ) else: - bias = torch.Tensor().to(dtype=params_dtype, device=device) + bias = torch.Tensor().to(dtype=self.params_dtype, device=device) setattr(self, f"bias{i}", bias) if self.primary_weights_in_fp8: @@ -779,18 +804,89 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) + def make_grouped_weights(self, defer_init=False) -> None: + """ + Convert parameters into a GroupedTensor and re-register them as parameters. + """ + + if defer_init: + return + + weight_quantizers = self._get_weight_quantizers() + recipe = ( + weight_quantizers[0]._get_compatible_recipe() + if weight_quantizers and weight_quantizers[0] is not None + else None + ) + if recipe is not None and (recipe.delayed() or recipe.float8_current_scaling()): + self.set_tensor_parallel_attributes(defer_init=defer_init) + return + + weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + + # Create the weight storage. + grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=self.num_gemms, + shapes=[(self.out_features, self.in_features)] * self.num_gemms, + quantizer=weight_quantizers[0], + dtype=self.params_dtype, + device=weights[0].device, + ) + + # Copy existing params into storage. + with torch.no_grad(): + for i in range(self.num_gemms): + if self.primary_weights_in_fp8: + grouped_weights.quantized_tensors[i].copy_from_storage(weights[i]) + else: + grouped_weights.quantized_tensors[i].copy_(weights[i]) + + # Re-register as a single grouped weight parameter. + # Re-register as a single grouped weight parameter. + if not ( + isinstance(grouped_weights, torch.Tensor) + and (weight_quantizers[0] is None or not weight_quantizers[0].internal) + ): + raise RuntimeError("Found internal quantizer with `single_grouped_parameter=True`.") + self.register_parameter( + "weight", + torch.nn.Parameter(grouped_weights), + init_fn=self.init_method, + get_rng_state_tracker=self.get_rng_state_tracker, + fp8_meta_index=self._offsets["weight"], + ) + for i in range(self.num_gemms): + self.register_parameter(f"weight{i}", None) + + self.set_tensor_parallel_attributes(defer_init=defer_init) + def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) + # Grouped tensor weights is an opt-in feature. + if self.single_grouped_parameter: + self.make_grouped_weights(defer_init=defer_init) + + def set_tensor_parallel_attributes(self, defer_init=False) -> None: + """Set attributes needed for TP""" if not defer_init: # Set parallelism attributes for linear weights - for i in range(self.num_gemms): + grouped_weight = getattr(self, "weight", None) + if grouped_weight is not None: set_tensor_model_parallel_attributes( - tensor=getattr(self, f"weight{i}"), + tensor=grouped_weight, is_parallel=True, dim=1 if self.parallel_mode == "row" else 0, stride=1, ) + else: + for i in range(self.num_gemms): + set_tensor_model_parallel_attributes( + tensor=getattr(self, f"weight{i}"), + is_parallel=True, + dim=1 if self.parallel_mode == "row" else 0, + stride=1, + ) # Set parallelism attributes for linear biases if self.use_bias: @@ -837,14 +933,18 @@ def forward( """ debug = self.is_debug_iter() - assert not isinstance( - inp, QuantizedTensorStorage - ), "GroupedLinear doesn't support input tensor in FP8." - assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." + if isinstance(inp, QuantizedTensorStorage): + raise TypeError("GroupedLinear doesn't support input tensor in FP8.") + if len(m_splits) != self.num_gemms: + raise ValueError( + f"Number of splits ({len(m_splits)}) should match number of" + f" GEMMs ({self.num_gemms})." + ) is_grad_enabled = torch.is_grad_enabled() - with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: + inp = self.prepare_forward(inp, num_gemms=self.num_gemms) + try: weight_tensors = self._get_weight_tensors() bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] @@ -855,9 +955,6 @@ def forward( debug = False quantizers = self._get_quantizers() - if isinstance(weight_tensors, QuantizedTensorStorage): - raise RuntimeError("FP8 weights are not supported in debug mode.") - ( input_quantizers, weight_quantizers, @@ -900,6 +997,9 @@ def forward( ) out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + finally: + self.end_forward() + if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out @@ -914,7 +1014,7 @@ def backward_dw(self): with get_nvtx_range_context("_GroupedLinear_wgrad"): (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() wgrad_list = tensor_list[2] - weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + weight_params = self._get_weight_tensors() bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: for i in range(self.num_gemms): @@ -926,16 +1026,16 @@ def backward_dw(self): del grad_biases_ del wgrad_list del tensor_list - for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: - wgrad_accumulation_and_reduce_hook() + self._trigger_wgrad_accumulation_and_reduce_hooks() def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + linear.""" - assert not self.tp_size > 1, ( - "GroupedLinear doesn't support TP > 1 with Float8 current scaling. " - "Because the TP communication is handled outside of this module." - ) + if self.tp_size > 1: + raise ValueError( + "GroupedLinear doesn't support TP > 1 with Float8 current scaling. " + "Because the TP communication is handled outside of this module." + ) if fwd: for i in range(self.num_gemms): @@ -965,7 +1065,14 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" - weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + grouped_weight = getattr(self, "weight", None) + if grouped_weight is not None: + weight_tensors = grouped_weight.quantized_tensors + if weight_tensors is None: + # TODO(ksivaman): Remove this after GEMM integration. + weight_tensors = grouped_weight.split_into_quantized_tensors() + else: + weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors): warnings.warn( "You are using quantized weights without quantized compute. " @@ -979,7 +1086,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8 and not self.fp8_calibration: + if not self.fp8 and not self.fp8_calibration and not self.primary_weights_in_fp8: return [None] * self.num_gemms weight_quantizers = [ self.quantizers["scaling_fwd"][ @@ -988,7 +1095,7 @@ def _get_weight_quantizers(self) -> List[Quantizer]: for i in range(self.num_gemms) ] for i in range(self.num_gemms): - weight_quantizers[i].internal = True + weight_quantizers[i].internal = not self.primary_weights_in_fp8 return weight_quantizers def _get_quantizers(self): @@ -1033,12 +1140,13 @@ def _get_quantizers(self): def _get_debug_quantizers(self): original_quantizers = self._get_quantizers() - assert TEDebugState.debug_enabled + if not TEDebugState.debug_enabled: + raise RuntimeError("TEDebugState.debug_enabled must be True to get debug quantizers") names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] return tuple( [ - DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group) + DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group, self.tp_size) for q_id, q in enumerate(qs) ] for name, qs in zip(names, original_quantizers) diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index d4f0a78ba..54fad8d1b 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -4,7 +4,7 @@ """LayerNorm API""" import warnings -from typing import Iterable, Optional, Union +from typing import Any, Iterable, Optional, Union import torch @@ -102,6 +102,10 @@ def __init__( **kwargs, ) + def fast_setattr(self, name: str, value: Any) -> None: + """Fast attribute set for non-parameter fields.""" + self.__dict__[name] = value + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7347fc138..d7241f2a6 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -33,7 +33,6 @@ from ..quantization import FP8GlobalStateManager from ..utils import ( assert_dim_for_fp8_exec, - assert_dim_for_all_gather, cast_if_needed, clear_tensor_data, divide, @@ -56,7 +55,7 @@ _fsdp_scatter_tensors, _fsdp_gather_tensors, ) -from ..constants import GemmParallelModes, dist_group_type +from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ._common import apply_normalization, noop_cat, WeightGradStore @@ -169,7 +168,6 @@ def forward( inputmat = inp if fp8: assert_dim_for_fp8_exec(inputmat, weight) - assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer) # Cast for native AMP nvtx_range_push(f"{nvtx_label}.norm_input_cast") @@ -308,7 +306,8 @@ def forward( # Configure quantizer # If weight is already quantized, no need to set quantizer states - if is_weight_param_quantized: + # for debug mode we create quantizer every iteration, thus we need to set the quantizer states + if is_weight_param_quantized and not debug: weight_quantizer = weight._quantizer elif weight_quantizer is not None: weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) @@ -1207,11 +1206,11 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, - name: str = None, + name: Optional[str] = None, keep_fp8_weight_transpose_cache: bool = True, use_fsdp2: bool = False ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -1230,7 +1229,7 @@ def __init__( self.symmetric_ar_type = symmetric_ar_type self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) - self.name = name + self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True self.use_fsdp2 = use_fsdp2 if IS_HIP_EXTENSION else False @@ -1247,6 +1246,10 @@ def __init__( assert ( self.parallel_mode in GemmParallelModes ), f"parallel_mode {parallel_mode} not supported" + if self.parallel_mode == "row": + raise NotImplementedError( + "Normalization does not support tensor-parallel distribution." + ) if self.parallel_mode == "column": self.out_features = divide(self.out_features, self.tp_size) @@ -1410,7 +1413,7 @@ def __init__( torch.nn.Parameter(weight_tensor[split_start:split_end]), init_fn=init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_meta_index=FP8FwdTensorIdx.GEMM1_WEIGHT, ) # Construct bias parameters if needed @@ -1561,10 +1564,11 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( + inp = self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer - ) as inp: + ) + try: # Get concatenated weight and bias tensors weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() @@ -1645,6 +1649,9 @@ def forward( non_tensor_args, ) + finally: + self.end_forward() + if self.return_layernorm_output: out, ln_out = out @@ -1666,20 +1673,20 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None - input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] input_quantizer.internal = True if not (self.parallel_mode == "column" and self.sequence_parallel): input_quantizer.optimize_for_gemm = True (weight_quantizer,) = self._get_weight_quantizers() if fp8_output: - output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + output_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] if is_grad_enabled: - grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] grad_output_quantizer.internal = True if not (self.parallel_mode == "row" and self.sequence_parallel): grad_output_quantizer.optimize_for_gemm = True if fp8_grad: - grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] return ( input_quantizer, @@ -1697,7 +1704,7 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] return tuple( - DebugQuantizer(self.name, name, q, self.tp_group) + DebugQuantizer(self.name, name, q, self.tp_group, self.tp_size) for name, q in zip(names, original_quantizers) ) @@ -1776,43 +1783,43 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe if fwd: # set configs about amax epsilon and power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon # also set weight quantizer with same amax_epsilon & power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon # parallel related if self.sequence_parallel and self.parallel_mode == "column": # set input_quantizer with amax reduction TP group self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here) self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon # parallel related if self.sequence_parallel and self.parallel_mode == "row": # customize grad_output_quantizer with amax reduction TP group self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: @@ -1822,19 +1829,19 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: if self.sequence_parallel and self.parallel_mode == "column": # set input_quantizer with amax reduction TP group self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: if self.sequence_parallel and self.parallel_mode == "row": # customize grad_output_quantizer with amax reduction TP group self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: @@ -1858,7 +1865,7 @@ def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration: return [None] - weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] weight_quantizer.internal = True if IS_HIP_EXTENSION: weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 49667b633..53f1de86d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -43,7 +43,6 @@ init_method_constant, cast_if_needed, assert_dim_for_fp8_exec, - assert_dim_for_all_gather, clear_tensor_data, requires_grad, needs_quantized_gemm, @@ -62,7 +61,7 @@ _get_cuda_rng_state, _set_cuda_rng_state, ) -from ..constants import dist_group_type +from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..tensor.float8_tensor import ( @@ -104,6 +103,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), + "glu": (tex.glu, tex.dglu, None), "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), "relu": (tex.relu, tex.drelu, None), @@ -120,6 +120,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): return { "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), "geglu": (tex.geglu, tex.dgeglu, None), + "glu": (tex.glu, tex.dglu, None), "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), "relu": (tex.relu, tex.drelu, tex.dbias_drelu), @@ -142,6 +143,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), + "glu": (tex.glu, tex.dglu, None), "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), "relu": (tex.relu, tex.drelu, None), @@ -227,8 +229,8 @@ def _forward( ub_overlap_ag, ub_overlap_rs, ub_overlap_rs_dgrad, - ub_bulk_wgrad, - ub_bulk_dgrad, + ub_bulk_wgrad, #ROCm: there is a but in upstream - order of dgrad and wgrad here bug here + ub_bulk_dgrad, #does not match order in LayerNormMLP::forward. Fix it there gemm_gelu_fusion, fsdp_group, module, @@ -341,7 +343,6 @@ def _forward( inputmat = inp.view((-1, in_features)) if fp8: assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) - assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer) activation_func = _act_func( activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None @@ -484,12 +485,13 @@ def _forward( # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch # No need to set the quantizer states if weights are already quantized - if isinstance(fc1_weight, QuantizedTensorStorage): + # for debug mode we create quantizer every iteration, thus we need to set the quantizer states + if isinstance(fc1_weight, QuantizedTensorStorage) and not debug: fc1_weight_quantizer = fc1_weight._quantizer elif fc1_weight_quantizer is not None: fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) - if isinstance(fc2_weight, QuantizedTensorStorage): + if isinstance(fc2_weight, QuantizedTensorStorage) and not debug: fc2_weight_quantizer = fc2_weight._quantizer elif fc2_weight_quantizer is not None: fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) @@ -555,7 +557,7 @@ def _forward( gemm_gelu_fusion = False if debug: gemm_gelu_fusion = False - + if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache: assert fc1_weight_final._transpose is None or fc1_weight_final._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled." @@ -1400,7 +1402,6 @@ def fc2_wgrad_gemm( # Overlap FC1 DGRAD reduce-scatter with WGRAD compute ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) ub_type_fc1_wgrad = tex.CommOverlapType.RS - # -------------------------------------------------- # FC1 DGRAD @@ -1676,7 +1677,7 @@ def fc1_wgrad_gemm( if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) if ctx.autocast_fp8_reduction_skipped: - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True) + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True) # FIX THIS # Scatter Fp8 tranposed-weight buffers @@ -1717,7 +1718,7 @@ class LayerNormMLP(TransformerEngineBaseModule): type of normalization applied. activation : str, default = 'gelu' activation function used. - Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, + Options: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``. activation_params : dict, default = None Additional parameters for the activation function. @@ -1853,7 +1854,7 @@ def __init__( zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, - name: str = None, + name: Optional[str] = None, ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, @@ -1864,7 +1865,7 @@ def __init__( keep_fp8_weight_transpose_cache: bool = True, use_fsdp2: bool = False, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.fuse_wgrad_accumulation = fuse_wgrad_accumulation @@ -1897,7 +1898,6 @@ def __init__( for use_fp8 in [False, True] ) ) - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) @@ -1955,7 +1955,15 @@ def __init__( self.layer_norm_bias = None # FC1 init - if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]: + if self.activation in [ + "geglu", + "glu", + "qgeglu", + "reglu", + "sreglu", + "swiglu", + "clamped_swiglu", + ]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition @@ -1968,7 +1976,7 @@ def __init__( fc1_weight, init_fn=init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_meta_index=FP8FwdTensorIdx.GEMM1_WEIGHT, ) if self.use_bias: @@ -1988,7 +1996,7 @@ def __init__( fc2_weight, init_fn=output_layer_init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + fp8_meta_index=FP8FwdTensorIdx.GEMM2_WEIGHT, ) if self.use_bias: @@ -2117,8 +2125,9 @@ def forward( if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): fp8_output = True - with self.prepare_forward(inp, num_gemms=2) as inp: + inp = self.prepare_forward(inp, num_gemms=2) + try: quantizers = ( self._get_quantizers(fp8_output, is_grad_enabled) if not debug @@ -2158,7 +2167,7 @@ def forward( # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode if ( not IS_HIP_EXTENSION and self.bias_gelu_nvfusion and not use_reentrant_activation_recompute() ): - self.bias_gelu_nvfusion = False + self.fast_setattr("bias_gelu_nvfusion", False) if is_grad_enabled: fwd_fn = _LayerNormMLP.apply @@ -2206,8 +2215,8 @@ def forward( self.ub_overlap_ag, self.ub_overlap_rs, self.ub_overlap_rs_dgrad, - self.ub_bulk_wgrad, - self.ub_bulk_dgrad, + self.ub_bulk_wgrad, #ROCm: there is a bug in upstream with dgrad and wgrad + self.ub_bulk_dgrad, #order not matching _LayerNormMLP::_forward self.gemm_gelu_fusion and not debug, self.fsdp_group, self, @@ -2230,6 +2239,9 @@ def forward( non_tensor_args, ) + finally: + self.end_forward() + if self.return_layernorm_output: out, ln_out = out @@ -2259,11 +2271,11 @@ def _get_quantizers(self, fp8_output, is_grad_enabled): ) = [None] * 10 fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers() if self.fp8 or self.fp8_calibration: - fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + fc1_input_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] fc1_input_quantizer.internal = True if not self.sequence_parallel: fc1_input_quantizer.optimize_for_gemm = True - fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] + fc2_input_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_INPUT] fc2_input_quantizer.set_usage( rowwise=True, columnwise=isinstance( @@ -2274,18 +2286,16 @@ def _get_quantizers(self, fp8_output, is_grad_enabled): fc2_input_quantizer.internal = True fc2_input_quantizer.optimize_for_gemm = True if fp8_output: - fc2_output_quantizer = self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM2_OUTPUT - ] + fc2_output_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_OUTPUT] if is_grad_enabled: fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ] fc2_grad_output_quantizer.internal = True if not self.sequence_parallel: fc2_grad_output_quantizer.optimize_for_gemm = True fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ] fc1_grad_output_quantizer.internal = True fc1_grad_output_quantizer.optimize_for_gemm = True @@ -2378,6 +2388,7 @@ def _clamped_swiglu(x, limit, alpha): activation_map = { "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "glu": lambda x: torch.sigmoid(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh") * x.chunk(2, -1)[1], @@ -2432,6 +2443,7 @@ def make_debug(prefix, offset): label, None if label in ("dgrad", "wgrad") else base_quantizers[i + offset], self.tp_group, + self.tp_size, ) for i, label in enumerate(labels) ] @@ -2446,63 +2458,63 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe if fwd: # fc1_input_quantizer: set configs about amax epsilon and power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon # fc2_input_quantizer self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM2_INPUT + FP8FwdTensorIdx.GEMM2_INPUT ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM2_INPUT + FP8FwdTensorIdx.GEMM2_INPUT ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon # fc1_weight_quantizer: also set numerical configs about weight self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon # fc2_weight_quantizer self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM2_WEIGHT + FP8FwdTensorIdx.GEMM2_WEIGHT ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM2_WEIGHT + FP8FwdTensorIdx.GEMM2_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon # parallel related if self.sequence_parallel and self.set_parallel_mode: # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: # fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon # fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon if self.sequence_parallel and self.set_parallel_mode: # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].amax_reduction_group = self.tp_group def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: @@ -2512,19 +2524,19 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: if self.sequence_parallel and self.set_parallel_mode: # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: if self.sequence_parallel and self.set_parallel_mode: # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].amax_reduction_group = self.tp_group def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: @@ -2535,11 +2547,11 @@ def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration: return [None, None] - fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + fc1_weight_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] fc1_weight_quantizer.internal = True if IS_HIP_EXTENSION: fc1_weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) - fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] + fc2_weight_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True if IS_HIP_EXTENSION: fc2_weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) @@ -2580,5 +2592,4 @@ def backward_dw(self): del fc2_wgrad del fc1_wgrad del fc1_bias_grad - for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: - wgrad_accumulation_and_reduce_hook() + self._trigger_wgrad_accumulation_and_reduce_hooks() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 01d07d91a..111a0210b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -36,7 +36,6 @@ requires_grad, needs_quantized_gemm, assert_dim_for_fp8_exec, - assert_dim_for_all_gather, nvtx_range_pop, nvtx_range_push, get_nvtx_range_context, @@ -56,7 +55,7 @@ from ..cpp_extensions import ( general_gemm, ) -from ..constants import GemmParallelModes, dist_group_type +from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..quantized_tensor import ( @@ -180,7 +179,6 @@ def forward( own_quantized_input = False if fp8: assert_dim_for_fp8_exec(inputmat, weight) - assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer) if save_original_input: assert not isinstance( input_quantizer, Float8Quantizer @@ -258,7 +256,8 @@ def forward( if fp8 or debug: # Configure quantizer # No need to set the quantizer states if weight is already quantized - if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): + # for debug mode we create quantizer every iteration, thus we need to set the quantizer states + if weight_quantizer is not None and (not isinstance(weight, QuantizedTensor) or debug): columnwise_usage = is_grad_enabled and inp.requires_grad and keep_fp8_weight_transpose_cache if not columnwise_usage and keep_fp8_weight_transpose_cache: columnwise_usage = ( @@ -441,8 +440,8 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight - if cpu_offloading: mark_not_offload(weight, weightmat, bias) + # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, @@ -1136,7 +1135,7 @@ def __init__( keep_fp8_weight_transpose_cache: bool = True, use_fsdp2: bool = False ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -1149,7 +1148,6 @@ def __init__( self.rng_tracker_name = rng_tracker_name self.symmetric_ar_type = symmetric_ar_type self.save_original_input = save_original_input - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True @@ -1312,7 +1310,7 @@ def __init__( torch.nn.Parameter(weight_tensor[split_start:split_end]), init_fn=init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_meta_index=FP8FwdTensorIdx.GEMM1_WEIGHT, ) # Construct bias parameters if needed @@ -1435,11 +1433,8 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( - inp, - allow_non_contiguous=isinstance(inp, QuantizedTensor), - ) as inp: - + inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) + try: weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() quantizers = ( @@ -1512,6 +1507,8 @@ def forward( bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, non_tensor_args, ) + finally: + self.end_forward() if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) @@ -1526,20 +1523,20 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None - input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] input_quantizer.internal = True if not (self.parallel_mode == "column" and self.sequence_parallel): input_quantizer.optimize_for_gemm = True (weight_quantizer,) = self._get_weight_quantizers() if fp8_output: - output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + output_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] if is_grad_enabled: - grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] grad_output_quantizer.internal = True if not (self.parallel_mode == "row" and self.sequence_parallel): grad_output_quantizer.optimize_for_gemm = True if fp8_grad: - grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] return ( input_quantizer, weight_quantizer, @@ -1556,7 +1553,7 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] return tuple( - DebugQuantizer(self.name, name, q, self.tp_group) + DebugQuantizer(self.name, name, q, self.tp_group, self.tp_size) for name, q in zip(names, original_quantizers) ) @@ -1644,43 +1641,43 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe if fwd: # set configs about amax epsilon and power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon # also set weight quantizer with same amax_epsilon & power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon # paralle related if self.sequence_parallel and self.parallel_mode == "column": # customize input_quantizer with amax reduction TP group self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: # set grad_output_quantizer with amax epsilon and power_2_scale self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon # parallel related if self.sequence_parallel and self.parallel_mode == "row": # customize grad_output_quantizer with amax reduction TP group self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: @@ -1690,26 +1687,26 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: if self.sequence_parallel and self.parallel_mode == "column": # customize input_quantizer with amax reduction TP group self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: if self.sequence_parallel and self.parallel_mode == "row": # customize grad_output_quantizer with amax reduction TP group self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration: return [None] - weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] weight_quantizer.internal = True if IS_HIP_EXTENSION: weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index ace4be31d..f8d5aade5 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -4,7 +4,7 @@ """RMSNorm API""" import warnings -from typing import Iterable, Optional, Union +from typing import Any, Iterable, Optional, Union import torch @@ -106,6 +106,10 @@ def __init__( **kwargs, ) + def fast_setattr(self, name: str, value: Any) -> None: + """Fast attribute set for non-parameter fields.""" + self.__dict__[name] = value + def reset_rms_norm_parameters(self) -> None: """Deprecated""" warnings.warn( diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index 2b270ea3d..99f51a9c7 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -8,7 +8,9 @@ """ -from transformer_engine.pytorch.ops.basic import * -from transformer_engine.pytorch.ops.linear import Linear -from transformer_engine.pytorch.ops.op import FusibleOperation -from transformer_engine.pytorch.ops.sequential import Sequential +from .basic import * +from .fuser import register_backward_fusion, register_forward_fusion +from .linear import Linear +from .op import BasicOperation, FusedOperation, FusibleOperation +from .sequential import Sequential +from . import fused diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 665ffe359..e0a3f4101 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -7,6 +7,7 @@ from .activation import ( GELU, GEGLU, + GLU, QGELU, QGEGLU, ReLU, @@ -14,8 +15,6 @@ SReLU, SReGLU, SiLU, - SwiGLU, - ClampedSwiGLU, ) from .add_extra_input import AddExtraInput from .all_gather import AllGather @@ -24,6 +23,7 @@ from .bias import Bias from .constant_scale import ConstantScale from .dropout import Dropout +from .grouped_linear import GroupedLinear from .identity import Identity from .l2normalization import L2Normalization from .layer_norm import LayerNorm @@ -32,3 +32,4 @@ from .reduce_scatter import ReduceScatter from .reshape import Reshape from .rmsnorm import RMSNorm +from .swiglu import ClampedSwiGLU, ScaledSwiGLU, SwiGLU diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 9d54e12db..13cb519c1 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -20,6 +20,7 @@ __all__ = [ "GELU", "GEGLU", + "GLU", "QGELU", "QGEGLU", "ReLU", @@ -27,8 +28,6 @@ "SReLU", "SReGLU", "SiLU", - "SwiGLU", - "ClampedSwiGLU", ] @@ -153,7 +152,7 @@ class GELU(_ActivationOperation): \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) - See `Gaussian Error Linear Units (GELUs)`__. + See `Gaussian Error Linear Units (GELUs) `__. """ @@ -164,6 +163,38 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dgelu(*args, **kwargs) +class GLU(_ActivationOperation): + r"""Gated Linear Unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GLU}(a,b) = \sigma(a) * b + + where :math:`\sigma` is the sigmoid function. + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `Language Modeling with Gated Convolutional Networks `__ + and `GLU Variants Improve Transformer `__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.glu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dglu(*args, **kwargs) + + class GEGLU(_ActivationOperation): r"""Gaussian Error Gated Linear Unit @@ -188,7 +219,7 @@ class GEGLU(_ActivationOperation): the first half of the input tensor, while PyTorch applies it to the second half. - See `GLU Variants Improve Transformer`__. + See `GLU Variants Improve Transformer `__. """ @@ -202,8 +233,8 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: class QGELU(_ActivationOperation): r"""Quick Gaussian Error Linear Unit - Quick GELU from `HuggingFace`__ - and `paper`__. + Quick GELU from `HuggingFace `__ + and `paper `__. .. math:: @@ -285,7 +316,7 @@ class ReGLU(_ActivationOperation): the first half of the input tensor, while PyTorch applies it to the second half. - See `GLU Variants Improve Transformer`__. + See `GLU Variants Improve Transformer `__. """ @@ -303,7 +334,7 @@ class SReLU(_ActivationOperation): \text{SReLU}(x) = \max(x^2,0) - See `Primer: Searching for Efficient Transformers for Language Modeling`__. + See `Primer: Searching for Efficient Transformers for Language Modeling `__. """ @@ -355,76 +386,3 @@ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dsilu(*args, **kwargs) - - -class SwiGLU(_ActivationOperation): - r"""Swish gated linear unit - - The input tensor is split into chunks :math:`a` and :math:`b` - along the last dimension and the following is computed: - - .. math:: - - \text{GEGLU}(a,b) = \text{SiLU}(a) * b - - where - - .. math:: - - \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} - - .. warning:: - - Transformer Engine's gated activations and PyTorch's GLU - activation follow opposite conventions for :math:`a` and - :math:`b`. Transformer Engine applies the gating function to - the first half of the input tensor, while PyTorch applies it to - the second half. - - The Sigmoid Linear Unit (SiLU) gating function is also known as - the swish function. See - `GLU Variants Improve Transformer`__ - and `Gaussian Error Linear Units (GELUs)`__. - - """ - - def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.swiglu(*args, **kwargs) - - def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.dswiglu(*args, **kwargs) - - -class ClampedSwiGLU(_ActivationOperation): - r"""GPT-OSS - Implementation based on `GPT-OSS`__. - - This activation has two differences compared to the original SwiGLU - 1. Both gate and pre-activations are clipped based on parameter limit. - 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. - - .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt - from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. - - Parameters - ---------- - limit : float - The clamp limit. - alpha : float - The scaling factor for the sigmoid function used in the activation. - cache_quantized_input : bool, default = False - Quantize input tensor when caching for use in the backward pass. - """ - - def __init__( - self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False - ): - super().__init__(cache_quantized_input=cache_quantized_input) - self.limit = limit - self.alpha = alpha - - def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) - - def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) diff --git a/transformer_engine/pytorch/ops/basic/add_extra_input.py b/transformer_engine/pytorch/ops/basic/add_extra_input.py index 47f2b6e24..fc3ca9cad 100644 --- a/transformer_engine/pytorch/ops/basic/add_extra_input.py +++ b/transformer_engine/pytorch/ops/basic/add_extra_input.py @@ -30,7 +30,7 @@ class AddExtraInput(BasicOperation): feature and most users are discouraged from it. In-place operations break some autograd assumptions and they can result in subtle, esoteric bugs. - Compare to `MakeExtraOutput`, which does a similar operation in + Compare to ``MakeExtraOutput``, which does a similar operation in the backward pass. """ diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index e640f3ffb..48376a297 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -48,8 +48,8 @@ def _wait_async(handle: Optional[Any]) -> None: class BasicLinear(BasicOperation): """Apply linear transformation: :math:`y = x A^T` - This is a drop-in replacement for `torch.nn.Linear` with - `bias=False`. + This is a drop-in replacement for ``torch.nn.Linear`` with + ``bias=False``. Parameters ---------- @@ -61,27 +61,27 @@ class BasicLinear(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - tensor_parallel_mode : {`None`, "column", "row"}, default = `None` + tensor_parallel_mode : {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group : torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel : bool, default = `False` + sequence_parallel : bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) rng_state_tracker_function : callable - Function that returns `CudaRNGStatesTracker`, which is used + Function that returns ``CudaRNGStatesTracker``, which is used for model-parallel weight initialization - accumulate_into_main_grad : bool, default = `False` + accumulate_into_main_grad : bool, default = False Whether to directly accumulate weight gradients into the - weight's `main_grad` attribute instead of relying on PyTorch - autograd. The weight's `main_grad` must be set externally and - there is no guarantee that `grad` will be set or be - meaningful. This is primarily intented to integrate with + weight's ``main_grad`` attribute instead of relying on PyTorch + autograd. The weight's ``main_grad`` must be set externally + and there is no guarantee that ``grad`` will be set or be + meaningful. This is primarily intended to integrate with Megatron-LM. This argument along with weight tensor having - attribute 'overwrite_main_grad' set to True will overwrite - `main_grad` instead of accumulating. + attribute ``overwrite_main_grad`` set to ``True`` will + overwrite ``main_grad`` instead of accumulating. userbuffers_options, dict, optional Options for overlapping tensor-parallel communication with compute using Userbuffers. This feature is highly @@ -184,7 +184,7 @@ def _canonicalize_tensor_parallelism( Parameters ---------- - mode: {`None`, "column", "row"} + mode: {None, "column", "row"} Mode for tensor parallelism process_group: torch.distributed.ProcessGroup Process group for tensor parallelism @@ -200,7 +200,7 @@ def _canonicalize_tensor_parallelism( Returns ------- - mode: {`None`, "column", "row"} + mode: {None, "column", "row"} Mode for tensor parallelism process_group: torch.distributed.ProcessGroup Process group for tensor parallelism @@ -446,18 +446,18 @@ def _functional_forward( Output tensor beta: float, optional Scaling factor applied to original value of out when accumulating into it - accumulate_into_out: bool, default = `False` + accumulate_into_out: bool, default = False Add result to output tensor instead of overwriting - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. input_quantizer: Quantizer, optional Builder class for quantized input tensor. @@ -465,10 +465,10 @@ def _functional_forward( Builder class for quantized weight tensor. output_quantizer: Quantizer, optional Builder class for quantized output tensor. - input_requires_grad: bool, default = `True` + input_requires_grad: bool, default = True Whether the loss gradient w.r.t. the input tensor is required in the backward pass. - weight_requires_grad: bool, default = `True` + weight_requires_grad: bool, default = True Whether the loss gradient w.r.t. the weight tensor is required in the backward pass. @@ -477,11 +477,11 @@ def _functional_forward( torch.Tensor Output tensor torch.Tensor, optional - Input tensor, ready for use in backward pass. `None` is + Input tensor, ready for use in backward pass. ``None`` is returned if loss gradient w.r.t. the weight tensor is not required. torch.Tensor, optional - Weight tensor, ready for use in backward pass. `None` is + Weight tensor, ready for use in backward pass. ``None`` is returned if loss gradient w.r.t. the input tensor is not required. @@ -682,24 +682,24 @@ def _functional_backward( Loss gradient w.r.t. weight tensor grad_weight_beta: float, optional Scaling factor applied to original value of grad_weight when accumulating into it - accumulate_into_grad_weight: bool, default = `False` + accumulate_into_grad_weight: bool, default = False Add result to weight grad instead of overwriting grad_input: torch.Tensor, optional Loss gradient w.r.t. input tensor grad_input_beta: float, optional Scaling factor applied to original value of grad_input when accumulating into it - accumulate_into_grad_input: bool, default = `False` + accumulate_into_grad_input: bool, default = False Add result to input grad instead of overwriting - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. input_quantizer: Quantizer, optional Builder class for quantized input tensor. diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index 8b6025108..d580f8486 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -18,7 +18,7 @@ class Bias(BasicOperation): """Apply additive bias - This is equivalent to the additive bias in `torch.nn.Linear`. + This is equivalent to the additive bias in ``torch.nn.Linear``. Parameters ---------- @@ -28,7 +28,7 @@ class Bias(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - tensor_parallel : bool, default = `False` + tensor_parallel : bool, default = False Whether to distribute input tensor and bias tensors along inner dimension tensor_parallel_group : torch.distributed.ProcessGroup, default = world group diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py new file mode 100644 index 000000000..b44e77b0c --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -0,0 +1,702 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for grouped linear layer.""" + +from __future__ import annotations +from collections.abc import Callable, Iterable, Sequence +import contextlib +import math +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...cpp_extensions import general_grouped_gemm +from ...distributed import CudaRNGStatesTracker +from ...module.base import ( + _2X_ACC_FPROP, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, + get_dummy_wgrad, +) +from ...quantization import FP8GlobalStateManager, Recipe +from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, + round_up_to_nearest_multiple, +) +from .._common import is_quantized_tensor, maybe_dequantize +from ..op import BasicOperation, OperationContext + + +class GroupedLinear(BasicOperation): + r"""Apply multiple linear transformations: :math:``y_i = x_i W_i^T + b_i`` + + This feature is experimental and subject to change. + + This is equivalent to splitting the input tensor along its first + dimension, applying a separate ``torch.nn.Linear`` to each split, + and concatenating along the first dimension. + + Parameters + ---------- + num_groups : int + Number of linear transformations. + in_features : int + Inner dimension of input tensor. + out_features : int + Inner dimension of output tensor. + bias : bool, default = ``True`` + Apply additive bias. + device : torch.device, default = default CUDA device + Tensor device. + dtype : torch.dtype, default = default dtype + Tensor datatype. + rng_state_tracker_function : callable + Function that returns ``CudaRNGStatesTracker``, which is used + for model-parallel weight initialization. + accumulate_into_main_grad : bool, default = ``False`` + Whether to directly accumulate weight gradients into the + weight's ``main_grad`` attribute instead of relying on PyTorch + autograd. The weight's ``main_grad`` must be set externally + and there is no guarantee that `grad` will be set or be + meaningful. This is primarily intended to integrate with + Megatron-LM. This argument along with weight tensor having + attribute ``overwrite_main_grad`` set to True will overwrite + ``main_grad`` instead of accumulating. + + """ + + # Operation expects input split sizes + num_extra_inputs: int = 1 + + def __init__( + self, + num_groups: int, + in_features: int, + out_features: int, + *, + bias: bool = True, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, + accumulate_into_main_grad: bool = False, + ) -> None: + super().__init__() + + # Weight tensor dimensions + self.num_groups: int = num_groups + self.in_features: int = in_features + self.out_features: int = out_features + if self.num_groups <= 0: + raise ValueError(f"Invalid number of groups ({self.num_groups})") + if self.in_features <= 0: + raise ValueError(f"Invalid input size ({self.in_features})") + if self.out_features <= 0: + raise ValueError(f"Invalid output size ({self.out_features})") + + # Weight tensor attributes + device = canonicalize_device(device) + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Initialize recipe state if needed for natively quantized weight + self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters() + if self._with_quantized_weight: + self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe()) + + # RNG state tracker + self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] + self._rng_state_tracker_function = rng_state_tracker_function + + # Register weights + self.weight0: torch.nn.Parameter + for group_idx in range(self.num_groups): + weight_tensor = torch.empty( + self.out_features, + self.in_features, + device="meta", + dtype=dtype, + ) + self.register_parameter( + f"weight{group_idx}", + torch.nn.Parameter(weight_tensor), + ) + + # Register biases + self.bias0: Optional[torch.nn.Parameter] + for group_idx in range(self.num_groups): + bias_tensor = None + if bias: + bias_tensor = torch.empty( + self.out_features, + device="meta", + dtype=dtype, + ) + bias_tensor = torch.nn.Parameter(bias_tensor) + self.register_parameter(f"bias{group_idx}", bias_tensor) + + # Initialize weights if needed + if device.type != "meta": + self.reset_parameters() + + # Whether to accumulate weight gradient into main_grad + self._accumulate_into_main_grad: bool = accumulate_into_main_grad + + def num_quantizers(self, mode: str) -> int: + if mode == "forward": + return 2 * self.num_groups + if mode == "backward": + return self.num_groups + return 0 + + @property + def has_bias(self) -> bool: + """Whether an additive bias is being applied""" + return self.bias0 is not None + + def reset_parameters(self) -> None: + """Initialize parameter buffers and values""" + + # Parameter device + device = self.weight0.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize weight values + # Note: Allocate a single buffer in order to support grouped + # GEMM kernels that expect a single weight buffer. + packed_weights = torch.empty( + self.num_groups, + self.out_features, + self.in_features, + dtype=self.weight0.dtype, + device=device, + ) + weights = [packed_weights[idx] for idx in range(self.num_groups)] + for weight in weights: + init_context = contextlib.nullcontext() + if self._rng_state_tracker_function is not None: + init_context = self._rng_state_tracker_function().fork() + with init_context: + torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + + # Quantize weights if needed + if self._with_quantized_weight: + + # Configure quantizers + quantizers = [ + self.get_quantizer("forward", 2 * idx + 1) for idx in range(self.num_groups) + ] + with_rowwise_usage = True + with_columnwise_usage = torch.is_grad_enabled() + for quantizer in quantizers: + if quantizer is None: + raise RuntimeError( + "Tried to quantize weight with deferred initialization " + "due to meta device, but no quantizer was available. " + "This is most likely because the weight was initialized " + "within quantized_model_init, but the forward pass was not " + "performed within autocast." + ) + quantizer.set_usage( + rowwise=with_rowwise_usage, + columnwise=with_columnwise_usage, + ) + quantizer.internal = False + + # Quantize weights + weights = self._quantize_weights(weights, quantizers) + + # Register weights + for group_idx, weight in enumerate(weights): + if not isinstance(weight, torch.nn.Parameter): + weight = torch.nn.Parameter(weight) + setattr(self, f"weight{group_idx}", weight) + + # Initialize biases if needed + if self.bias0 is not None: + packed_biases = torch.zeros( + self.num_groups, + self.out_features, + dtype=self.bias0.dtype, + device=device, + ) + for group_idx in range(self.num_groups): + bias = torch.nn.Parameter(packed_biases[group_idx]) + setattr(self, f"bias{group_idx}", bias) + + def _quantize_weights( + self, + weights: Sequence[torch.Tensor], + quantizers: Sequence[Quantizer], + ) -> Sequence[torch.Tensor]: + """Construct quantized weight tensors.""" + + # Manually construct MXFP8 weights + if isinstance(quantizers[0], MXFP8Quantizer): + return self._quantize_weights_mxfp8(weights, quantizers) + + # Use quantizers to construct quantized weights + with torch.no_grad(): + return [quantizer(weight) for quantizer, weight in zip(quantizers, weights)] + + def _quantize_weights_mxfp8( + self, + weights: Sequence[torch.Tensor], + quantizers: Sequence[Quantizer], + ) -> Sequence[MXFP8Tensor]: + """Construct MXFP8 weight tensors. + + Instead of allocating separate buffers for each weight tensor, + this function constructs large buffers and assigns subviews to + each tensor. This is intended to support grouped GEMM kernels + that expect packed buffers. + + """ + + # Tensor dimensions + num_groups = len(weights) + out_features, in_features = weights[0].size() + packed_shape = (num_groups, out_features, in_features) + unpacked_shape = (out_features, in_features) + + # Tensor attributes + device = weights[0].device + dtype = weights[0].dtype + requires_grad = torch.is_grad_enabled() + with_rowwise_usage = quantizers[0].rowwise_usage + with_columnwise_usage = quantizers[0].columnwise_usage + + # Construct packed buffers + rowwise_data = [None] * num_groups + rowwise_scales = [None] * num_groups + columnwise_data = [None] * num_groups + columnwise_scales = [None] * num_groups + if with_rowwise_usage: + scale_shape = ( + num_groups, + round_up_to_nearest_multiple(out_features, 128), + round_up_to_nearest_multiple(in_features // 32, 4), + ) + packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device) + packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device) + rowwise_data = [packed_data[idx] for idx in range(num_groups)] + rowwise_scales = [packed_scales[idx] for idx in range(num_groups)] + if with_columnwise_usage: + scale_shape = ( + num_groups, + round_up_to_nearest_multiple(out_features // 32, 4), + round_up_to_nearest_multiple(in_features, 128), + ) + packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device) + packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device) + columnwise_data = [packed_data[idx] for idx in range(num_groups)] + columnwise_scales = [packed_scales[idx] for idx in range(num_groups)] + + # Construct MXFP8 tensors and cast to MXFP8 + out = [] + with torch.no_grad(): + for group_idx in range(num_groups): + weight = MXFP8Tensor( + shape=unpacked_shape, + dtype=dtype, + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise_data=rowwise_data[group_idx], + rowwise_scale_inv=rowwise_scales[group_idx], + columnwise_data=columnwise_data[group_idx], + columnwise_scale_inv=columnwise_scales[group_idx], + quantizer=quantizers[group_idx], + requires_grad=requires_grad, + with_gemm_swizzled_scales=False, + ) + weight.copy_(weights[group_idx]) + out.append(weight) + + return out + + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() + + # Initialize params if needed + if any(param.device.type == "meta" for param in self.parameters()): + self.reset_parameters() + + # Check that weights are consistent + dtype = self.weight0.dtype + device = self.weight0.device + weight_requires_grad = self.weight0.requires_grad + weight_tensor_type = type(self.weight0.data) + for group_idx in range(self.num_groups): + weight = getattr(self, f"weight{group_idx}") + if weight.dtype != dtype: + raise RuntimeError( + f"Weight {group_idx} has invalid dtype (expected {dtype}, got {weight.dtype})." + ) + if not devices_match(weight.device, device): + raise RuntimeError( + f"Weight {group_idx} has invalid device " + f"(expected {device}, got {weight.device})." + ) + if weight.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Weight {group_idx} has requires_grad={weight.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." + ) + if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck + raise RuntimeError( + f"Weight {group_idx} has invalid tensor type " + f"(expected {weight_tensor_type.__name__}, " + f"got {type(weight.data).__name__})." + ) + + # Check that biases are consistent + for group_idx in range(self.num_groups): + bias = getattr(self, f"bias{group_idx}") + if self.has_bias: + if bias is None: + raise RuntimeError(f"Expected biases, but bias {group_idx} is uninitialized") + if bias.dtype != dtype: + raise RuntimeError( + f"Bias {group_idx} has invalid dtype (expected {dtype}, got {bias.dtype})." + ) + if not devices_match(bias.device, device): + raise RuntimeError( + f"Bias {group_idx} has invalid device " + f"(expected {device}, got {bias.device})." + ) + if bias.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Bias {group_idx} has requires_grad={bias.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." + ) + else: + if bias is not None: + raise RuntimeError(f"Expected no biases, but bias {group_idx} is initialized") + + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + super().pre_fuser_forward(requires_grad=requires_grad) + if FP8GlobalStateManager.is_fp8_enabled(): + # Assume weights have consistent grad requirement + weight_requires_grad = requires_grad and self.weight0.requires_grad + + # Configure quantizer usages + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. + for group_idx in range(self.num_groups): + input_quantizer = self.get_quantizer("forward", 2 * group_idx) + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizer = self.get_quantizer("backward", group_idx) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: + super().reset_recipe_state(recipe=recipe) + + for group_idx in range(self.num_groups): + # Input/grad output quantizers use internal tensors + input_quantizer = self.get_quantizer("forward", 2 * group_idx) + grad_output_quantizer = self.get_quantizer("backward", group_idx) + if input_quantizer is not None: + input_quantizer.internal = True + if grad_output_quantizer is not None: + grad_output_quantizer.internal = True + + # Handle weight quantizer + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + if weight_quantizer is None: + pass + elif is_quantized_tensor(getattr(self, f"weight{group_idx}", None)): + # Make sure weight param has correct quantizer + weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) + weight_quantizer.internal = False + getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy()) + else: + # Use internal tensors if quantized weights will not be + # exposed externally + weight_quantizer.internal = ( + not FP8GlobalStateManager.with_fp8_parameters() + and not getattr(self, "_with_quantized_weight", False) + ) + + # Recipe-specific configuration + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + if recipe is not None: + if recipe.float8_current_scaling(): + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon + grad_output_quantizer.force_pow_2_scales = ( + recipe.fp8_quant_bwd_grad.power_2_scale + ) + grad_output_quantizer.amax_epsilon_scales = ( + recipe.fp8_quant_bwd_grad.amax_epsilon + ) + + def op_forward(self, *args, **kwargs): + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs): + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + num_groups = self.num_groups + has_bias = self.has_bias + device = self.weight0.device + + # Check which grads are required + ctx = basic_op_ctxs[0] + input_requires_grad = ctx.requires_grad + weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad + + # Quantizers + input_quantizers = [None] * num_groups + weight_quantizers = [None] * num_groups + grad_output_quantizers = [None] * num_groups + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + for group_idx in range(num_groups): + input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx) + weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx) + + # Get autocast dtype if needed + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = self.weight0.dtype + + # Extract split sizes from extra input + split_sizes = basic_op_extra_inputs[0][0] + split_sizes_int = [int(s) for s in split_sizes.tolist()] + if len(split_sizes_int) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_int)}.") + + # Extract params + weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)] + bs = None + if has_bias: + bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(num_groups)] + + # Convert weight dtype if needed + ws = [] + for w, quantizer in zip(weights, weight_quantizers): + if not with_quantized_compute: + w = maybe_dequantize(w, dtype) + elif with_quantized_compute and not is_quantized_tensor(w): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + w = quantizer(w) + ws.append(w) + + # Split input tensor and convert dtypes if needed + x = maybe_dequantize(input_, dtype) + xs = None + if with_quantized_compute: + for quantizer in input_quantizers: + quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + xs = tex.split_quantize(x, split_sizes_int, input_quantizers) + else: + xs = torch.split(x, split_sizes_int) + + # Allocate output tensor + in_shape = list(input_.size()) + out_shape = in_shape[:-1] + [self.out_features] + out = torch.empty(out_shape, dtype=dtype, device=device) + + # Perform GEMMs + general_grouped_gemm( + ws, + xs, + [out], + [None] * num_groups, # quantization_params + dtype, + m_splits=split_sizes_int, + bias=bs, + use_bias=has_bias, + use_split_accumulator=_2X_ACC_FPROP, + single_output=True, + ) + + # Prepare weight tensors for backward pass + if not input_requires_grad: + ws = [None] * num_groups + elif with_quantized_compute: + for w, weight_param in zip(ws, weights): + if w is not weight_param: + w.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Prepare input tensor for backward pass + if not weight_requires_grad: + xs = [None] * num_groups + elif with_quantized_compute: + for x in xs: + x.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Save state for backward pass + if ctx.requires_grad: + ctx.save_for_backward(split_sizes, *xs, *ws) + ctx.with_quantized_compute = with_quantized_compute + ctx.input_quantizers = input_quantizers + ctx.weight_quantizers = weight_quantizers + ctx.grad_output_quantizers = grad_output_quantizers + ctx.grad_input_quantizers = None + ctx.dtype = dtype + ctx.input_requires_grad = input_requires_grad + ctx.weight_requires_grad = weight_requires_grad + + return out, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + num_groups = self.num_groups + has_bias = self.has_bias + device = self.weight0.device + + # Saved tensors from forward pass + ctx = basic_op_ctxs[0] + saved_tensors = ctx.saved_tensors + split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:] + xs, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:] + ws, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:] + + # Split grad output tensor and convert dtypes if needed + split_sizes_int = [int(s) for s in split_sizes.tolist()] + dy = maybe_dequantize(grad_output, ctx.dtype) + dys = None + grad_biases = [None] * num_groups + if ctx.with_quantized_compute: + for quantizer in ctx.grad_output_quantizers: + quantizer.set_usage( + rowwise=ctx.input_requires_grad, + columnwise=ctx.weight_requires_grad, + ) + dys = tex.split_quantize(dy, split_sizes_int, ctx.grad_output_quantizers) + if has_bias: + grad_biases = [ + dy.reshape(-1, dy.size(-1)).sum(dim=0) + for dy in torch.split(grad_output, split_sizes_int) + ] + else: + dys = torch.split(dy, split_sizes_int) + if has_bias: + grad_biases = [dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys] + + # Initialize grad weight buffers + accumulate_into_main_grad = self._accumulate_into_main_grad + grad_weights = [None] * num_groups + if ctx.weight_requires_grad: + if accumulate_into_main_grad: + # Megatron-LM wgrad fusion + # Note: Get grad tensors from params so we can + # accumulate directly into it. + for group_idx in range(num_groups): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + grad_weights[group_idx] = weight_param.main_grad + accumulate_into_main_grad = not getattr(self.weight0, "overwrite_main_grad", False) + else: + weight_shape = ws[0].size() + for group_idx in range(num_groups): + grad_weights[group_idx] = torch.empty( + weight_shape, + dtype=ctx.dtype, + device=device, + ) + else: + accumulate_into_main_grad = False + + # Perform dgrad GEMMs + grad_input = None + if ctx.input_requires_grad: + out_shape = list(grad_output.size()) + in_shape = out_shape[:-1] + [self.in_features] + grad_input = torch.empty( + in_shape, + dtype=ctx.dtype, + device=device, + ) + general_grouped_gemm( + ws, + dys, + [grad_input], + [None] * num_groups, # quantization_params + ctx.dtype, + layout="NN", + m_splits=split_sizes_int, + use_split_accumulator=_2X_ACC_DGRAD, + single_output=True, + ) + + # Perform wgrad GEMMs + if ctx.weight_requires_grad: + general_grouped_gemm( + xs, + dys, + grad_weights, + [None] * num_groups, # quantization_params + ctx.dtype, + layout="NT", + m_splits=split_sizes_int, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_into_main_grad, + ) + + # Clear input tensors if possible + clear_tensor_data(*xs) + + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. + if accumulate_into_main_grad: + grad_weights = [None] * num_groups + for group_idx in range(num_groups): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weights[group_idx] = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + + grad_params = grad_weights + grad_biases if has_bias else grad_weights + return grad_input, [grad_params], [(None,)] diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 340c2b895..b15dd3660 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -36,7 +36,7 @@ class LayerNorm(BasicOperation): r"""Layer Normalization Applies Layer Normalization over a mini-batch of inputs as described in - the paper `Layer Normalization `__ + the paper `Layer Normalization `__ . .. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta @@ -56,9 +56,9 @@ class LayerNorm(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - zero_centered_gamma : bool, default = 'False' - If `True`, the :math:`\gamma` parameter is initialized to zero - and the calculation changes to + zero_centered_gamma : bool, default = False + If ``True``, the :math:`\gamma` parameter is initialized to + zero and the calculation changes to .. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py index 61caaaf65..0d9c87026 100644 --- a/transformer_engine/pytorch/ops/basic/make_extra_output.py +++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py @@ -35,7 +35,7 @@ class MakeExtraOutput(BasicOperation): operations break some autograd assumptions and they can result in subtle, esoteric bugs. - Compare to `AddExtraInput`, which does a similar operation in the + Compare to ``AddExtraInput``, which does a similar operation in the backward pass. """ diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index d126b554b..fa3efc380 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -18,14 +18,14 @@ class Quantize(BasicOperation): """Quantize tensor data - Uses recipe from `autocast` context. When called outside - of an `autocast` context, this is an identity operation. + Uses recipe from ``autocast`` context. When called outside + of an ``autocast`` context, this is an identity operation. Parameters ---------- - forward : bool, default = `True` + forward : bool, default = True Perform quantization in forward pass - backward : bool, default = `False` + backward : bool, default = False Perform quantization in backward pass """ diff --git a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py index f8ae86fec..4a171c294 100644 --- a/transformer_engine/pytorch/ops/basic/reshape.py +++ b/transformer_engine/pytorch/ops/basic/reshape.py @@ -20,7 +20,7 @@ class Reshape(BasicOperation): """Reshape tensor - See `torch.reshape`. + See ``torch.reshape``. Parameters ---------- diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index a9bbeab8c..0491ab914 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -40,7 +40,7 @@ class RMSNorm(BasicOperation): Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in the paper - `Root Mean Square Layer Normalization `__ + `Root Mean Square Layer Normalization `__ . .. math:: y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma @@ -58,8 +58,8 @@ class RMSNorm(BasicOperation): Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - zero_centered_gamma : bool, default = 'False' - If `True`, the :math:`\gamma` parameter is initialized to zero + zero_centered_gamma : bool, default = False + If ``True``, the :math:`\gamma` parameter is initialized to zero and the calculation changes to .. math:: diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py new file mode 100644 index 000000000..b4427df41 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -0,0 +1,502 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for SwiGLU and variants.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...tensor import Float8CurrentScalingQuantizer, Quantizer +from ...utils import clear_tensor_data +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize + +__all__ = ["SwiGLU", "ClampedSwiGLU", "ScaledSwiGLU"] + + +class SwiGLU(BasicOperation): + r"""Swish gated linear unit + + The input tensor is split into chunks :math:``a`` and :math:``b`` + along the last dimension and the following is computed: + + .. math:: + + \text{SwiGLU}(a,b) = \text{SiLU}(a) * b + + where + + .. math:: + + \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:``a`` and + :math:``b``. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + The Sigmoid Linear Unit (SiLU) gating function is also known as + the swish function. See + `GLU Variants Improve Transformer `__. + + Parameters + ---------- + cache_quantized_input : bool, default = False + Quantize input tensor when caching for use in the backward + pass. This will typically reduce memory usage but require + extra compute and increase numerical error. This feature is + highly experimental. + glu_interleave_size : int, optional + When set, the GLU activations will use a block interleaved + format. Instead of interpreting the input tensor as a + concatenation of gates and linear units (e.g. + :math:``[a_1, a_2, a_3, a_4, b_1, b_2, b_3, b_4]`` + in the above notation), it will be interpreted + as alternating blocks of gates and linear units (e.g. + :math:``[a_1, a_2, b_1, b_2, a_3, a_4, b_3, b_4]`` + when the interleave size is 2). This data format is highly + experiental and is primarily intended to support some advanced + fused kernels. + + """ + + def __init__( + self, + *, + cache_quantized_input: bool = False, + glu_interleave_size: Optional[int] = None, + ): + super().__init__() + self.cache_quantized_input: bool = cache_quantized_input + self.glu_interleave_size: Optional[int] = glu_interleave_size + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + input_ = maybe_dequantize(input_.contiguous(), dtype) + + # Remove interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Launch kernel + out = tex.swiglu(swiglu_in, next_op_input_quantizer) + + # Quantize input to FP8 before caching if needed + if self.cache_quantized_input: + input_quantizer = Float8CurrentScalingQuantizer( + tex.DType.kFloat8E4M3, + input_.device, + ) + input_quantizer.set_usage(rowwise=True, columnwise=False) + input_ = input_quantizer(input_) + + # Save state for backward pass + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(input_) + ctx.save_for_backward(input_) + ctx.dtype = dtype + ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer + + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (input_,) = ctx.saved_tensors + + # Make sure tensors have correct dtypes + x = maybe_dequantize(input_.contiguous(), ctx.dtype) + dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype) + + # Remove interleaving if needed + swiglu_in = x + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Quantizer for grad input + quantizer = ctx.prev_op_grad_output_quantizer + if self.glu_interleave_size is not None: + quantizer = None + + # Launch kernel + grad_swiglu_in = tex.dswiglu(dy, swiglu_in, quantizer) + + # Apply interleaving if needed + dx = grad_swiglu_in + if self.glu_interleave_size is not None: + shape = dx.size() + dx = dx.reshape( + -1, + 2, + shape[-1] // (2 * self.glu_interleave_size), + self.glu_interleave_size, + ) + dx = dx.transpose(1, 2).contiguous() + dx = dx.view(shape) + + # Clear input tensor if possible + clear_tensor_data(input_) + + return dx, () + + +class ClampedSwiGLU(BasicOperation): + r"""GPT-OSS + Implementation based on `GPT-OSS `__. + + This activation has two differences compared to the original SwiGLU + 1. Both gate and pre-activations are clipped based on parameter limit. + 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. + + .. warning:: + + The input tensor is chunked along the last dimension to get + gates/pre-activations which is different from GPT OSS + implementation where the gates/pre-activations are assumed to + be interleaved in the input tensor. + + Parameters + ---------- + limit : float + The clamp limit. + alpha : float + The scaling factor for the sigmoid function used in the activation. + cache_quantized_input : bool, default = ``False`` + Quantize input tensor when caching for use in the backward pass. + glu_interleave_size : int, optional + When set, the GLU activations will use an experimental block + interleaved format. See the corresponding option in the SwiGLU + operation for more details. + + """ + + def __init__( + self, + *, + limit: float = 7.0, + alpha: float = 1.702, + cache_quantized_input: bool = False, + glu_interleave_size: Optional[int] = None, + ): + super().__init__() + self.limit: float = limit + self.alpha: float = alpha + self.cache_quantized_input: bool = cache_quantized_input + self.glu_interleave_size: Optional[int] = glu_interleave_size + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + x = maybe_dequantize(input_.contiguous(), dtype) + + # Remove interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Launch kernel + out = tex.clamped_swiglu( + swiglu_in, + next_op_input_quantizer, + limit=self.limit, + alpha=self.alpha, + ) + + # Quantize input to FP8 before caching if needed + if self.cache_quantized_input: + input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) + input_quantizer.set_usage(rowwise=True, columnwise=False) + x = input_quantizer(x) + + # Save state for backward pass + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x) + ctx.save_for_backward(x) + ctx.dtype = dtype + ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer + + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (input_,) = ctx.saved_tensors + + # Make sure tensors have correct dtypes + x = maybe_dequantize(input_.contiguous(), ctx.dtype) + dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype) + + # Remove interleaving if needed + swiglu_in = x + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Quantizer for grad input + quantizer = ctx.prev_op_grad_output_quantizer + if self.glu_interleave_size is not None: + quantizer = None + + # Launch kernel + grad_swiglu_in = tex.clamped_dswiglu( + dy, + swiglu_in, + quantizer, + limit=self.limit, + alpha=self.alpha, + ) + + # Apply interleaving if needed + dx = grad_swiglu_in + if self.glu_interleave_size is not None: + shape = dx.size() + dx = dx.reshape( + -1, + 2, + shape[-1] // (2 * self.glu_interleave_size), + self.glu_interleave_size, + ) + dx = dx.transpose(1, 2).contiguous() + dx = dx.view(shape) + + # Clear input tensor if possible + clear_tensor_data(input_) + + return dx, () + + +class ScaledSwiGLU(BasicOperation): + r"""SwiGLU with post-scaling. + + If the SwiGLU output has shape ``(d_1, ..., d_n)``, it is + multiplied with an extra input tensor of shape + ``(d_1, ..., d_{n-1})``. + + Parameters + ---------- + glu_interleave_size : int, optional + When set, the GLU activations will use an experimental block + interleaved format. See the corresponding option in the SwiGLU + operation for more details. + + """ + + # Operation expects scales + num_extra_inputs: int = 1 + + def __init__(self, glu_interleave_size: Optional[int] = None): + super().__init__() + self.glu_interleave_size: Optional[int] = glu_interleave_size + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + extra_input = basic_op_extra_inputs[0][0] + + # Determine compute dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + elif isinstance(input_, torch.Tensor): + dtype = input_.dtype + else: + dtype = extra_input.dtype + + # Make sure inputs are in correct dtype + input_ = maybe_dequantize(input_, dtype) + scales = maybe_dequantize(extra_input, dtype) + + # Remove gate interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Compute scaled SwiGLU + swiglu_out = tex.swiglu(swiglu_in, None) + out = swiglu_out * scales.unsqueeze(-1) + + # Save state for backward pass + ctx = basic_op_ctxs[0] + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(input_) + ctx.input_requires_grad = True + ctx.extra_input_requires_grad = extra_input.requires_grad + ctx.dtype = dtype + ctx.save_for_backward( + input_, + scales if ctx.input_requires_grad else None, + ) + + return out, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + ctx = basic_op_ctxs[0] + input_, scales = ctx.saved_tensors + input_ = maybe_dequantize(input_, ctx.dtype) + if scales is not None: + scales = maybe_dequantize(scales, ctx.dtype) + grad_output = maybe_dequantize(grad_output, ctx.dtype) + + # Remove gate interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Compute input grad + grad_input = None + if ctx.input_requires_grad: + grad_swiglu_out = grad_output * scales.unsqueeze(-1) + grad_swiglu_in = tex.dswiglu(grad_swiglu_out, swiglu_in, None) + grad_input = grad_swiglu_in + if self.glu_interleave_size is not None: + shape = grad_input.size() + grad_input = grad_input.reshape( + -1, + 2, + shape[-1] // (2 * self.glu_interleave_size), + self.glu_interleave_size, + ) + grad_input = grad_input.transpose(1, 2).contiguous() + grad_input = grad_input.view(shape) + + # Compute scales grad by recomputing SwiGLU + grad_extra_input = None + if ctx.extra_input_requires_grad: + swiglu_out = tex.swiglu(swiglu_in, None) + grad_extra_input = torch.linalg.vecdot(swiglu_out, grad_output) + + # Clear input tensor if possible + clear_tensor_data(ctx.saved_tensors[0]) # input_ + + return grad_input, [()], [(grad_extra_input,)] diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index e1a51197d..19608894e 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -4,40 +4,27 @@ """Compound tensor operation supported by the operation fuser.""" -from .backward_activation_bias import ( - BackwardActivationBias, - fuse_backward_activation_bias, -) -from .backward_add_rmsnorm import ( - BackwardAddRMSNorm, - fuse_backward_add_rmsnorm, -) -from .backward_linear_add import ( - BackwardLinearAdd, - fuse_backward_linear_add, -) -from .backward_linear_scale import ( - BackwardLinearScale, - fuse_backward_linear_scale, -) -from .forward_linear_bias_activation import ( - ForwardLinearBiasActivation, - fuse_forward_linear_bias_activation, -) -from .forward_linear_bias_add import ( - ForwardLinearBiasAdd, - fuse_forward_linear_bias_add, -) -from .forward_linear_scale_add import ( - ForwardLinearScaleAdd, - fuse_forward_linear_scale_add, -) +from ..fuser import register_backward_fusion, register_forward_fusion +from .backward_activation_bias import BackwardActivationBias +from .backward_add_rmsnorm import BackwardAddRMSNorm +from .backward_linear_add import BackwardLinearAdd +from .backward_linear_scale import BackwardLinearScale +from .forward_linear_bias_activation import ForwardLinearBiasActivation +from .forward_linear_bias_add import ForwardLinearBiasAdd +from .forward_linear_scale_add import ForwardLinearScaleAdd +from .userbuffers_backward_linear import UserbuffersBackwardLinear +from .userbuffers_forward_linear import UserbuffersForwardLinear -from .userbuffers_backward_linear import ( - UserbuffersBackwardLinear, - fuse_userbuffers_backward_linear, -) -from .userbuffers_forward_linear import ( - UserbuffersForwardLinear, - fuse_userbuffers_forward_linear, -) + +# Register forward fusions +register_forward_fusion(UserbuffersForwardLinear.fuse_forward_ops) +register_forward_fusion(ForwardLinearBiasAdd.fuse_forward_ops) +register_forward_fusion(ForwardLinearBiasActivation.fuse_forward_ops) +register_forward_fusion(ForwardLinearScaleAdd.fuse_forward_ops) + +# Register backward fusions +register_backward_fusion(UserbuffersBackwardLinear.fuse_backward_ops) +register_backward_fusion(BackwardLinearAdd.fuse_backward_ops) +register_backward_fusion(BackwardLinearScale.fuse_backward_ops) +register_backward_fusion(BackwardActivationBias.fuse_backward_ops) +register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops) diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index d5b9ce0e9..4ab082d32 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -53,8 +53,8 @@ def fuser_backward( ]: # Get basic operation contexts - activation_op_ctx = basic_op_ctxs[0] - bias_op_ctx = basic_op_ctxs[1] + bias_op_ctx = basic_op_ctxs[0] + activation_op_ctx = basic_op_ctxs[1] # Saved tensors from forward pass (act_input,) = activation_op_ctx.saved_tensors @@ -79,68 +79,59 @@ def fuser_backward( # Clear activation input tensor clear_tensor_data(act_input) - return dx, [(), (db,)], [(), ()] + return dx, [(db,), ()], [(), ()] - -def fuse_backward_activation_bias( - ops: list[tuple[FusibleOperation, list[int]]], - recipe: Optional[Recipe], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward dact + dbias + quantize - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. - recipe : Recipe, optional - Used quantization recipe - - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Check if recipe supports bias activation fusion - if recipe is None: - return ops - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 3: + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. + recipe : Recipe, optional + Quantization recipe. + + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Check if recipe supports bias activation fusion + if recipe is None: + return ops + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + if ( + isinstance(window[2], _fusible_activations) + and isinstance(window[1], Bias) + and window[0].get_grad_output_quantizer() is not None + ): + # Construct fused op if window matches pattern + op = BackwardActivationBias(bias=window[1], activation=window[2]) + window = [window[0], op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-2]) + window = window[-2:] + + # Adjust window to expected size + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is a supported activation - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, _fusible_activations): - continue - - # Check if second op is bias - op, _ = ops[0] - if not isinstance(op, Bias): - continue - - # Check if third op has a grad input quantizer - op, _ = ops[1] - if not op.num_quantizers("backward") > 0: - continue - - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardActivationBias( - activation=window[0][0], - bias=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py index 186619caa..a3c81e60c 100644 --- a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py +++ b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py @@ -42,7 +42,7 @@ def fuser_backward( # Get basic operations rmsnorm_op = self.basic_ops[1] - rmsnorm_op_ctx = basic_op_ctxs[0] + rmsnorm_op_ctx = basic_op_ctxs[1] # Saved tensors from forward pass x, rstdevs = rmsnorm_op_ctx.saved_tensors @@ -53,7 +53,7 @@ def fuser_backward( # Check input tensors dtype = rmsnorm_op_ctx.dtype - extra_grad = basic_op_grad_extra_outputs[1][0] + extra_grad = basic_op_grad_extra_outputs[0][0] dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size()) w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,)) add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size()) @@ -77,57 +77,51 @@ def fuser_backward( grad_input = dx.view(grad_output.size()) grad_weight = dw.view(weight_dims) - return grad_input, [(grad_weight,), ()], [(), ()] - - -def fuse_backward_add_rmsnorm( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward RMNorm + add - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + return grad_input, [(), (grad_weight,)], [(), ()] + + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. + + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + if ( + isinstance(window[0], MakeExtraOutput) + and isinstance(window[1], RMSNorm) + and not window[0]._in_place + ): + # Construct fused op if window matches pattern + op = BackwardAddRMSNorm(add=window[0], rmsnorm=window[1]) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, RMSNorm): - continue - - # Check if second op is "make extra output" - op, _ = ops[0] - if not isinstance(op, MakeExtraOutput): - continue - if op._in_place: - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardAddRMSNorm( - rmsnorm=window[0][0], - add=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 5e7339db8..c06e212e8 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -45,7 +45,7 @@ def fuser_backward( # Get basic operations linear_op = self.basic_ops[1] - linear_op_ctx = basic_op_ctxs[0] + linear_op_ctx = basic_op_ctxs[1] # Saved tensors from forward pass (x_local, w) = linear_op_ctx.saved_tensors @@ -71,7 +71,7 @@ def fuser_backward( accumulate_into_main_grad = False # Linear backward pass - grad_input = basic_op_grad_extra_outputs[1][0] + grad_input = basic_op_grad_extra_outputs[0][0] grad_input, grad_weight = BasicLinear._functional_backward( grad_output=grad_output, input=x_local, @@ -109,61 +109,60 @@ def fuser_backward( zero=getattr(weight_param, "zero_out_wgrad", False), ) - return grad_input, [(grad_weight,), ()], [(), ()] - - -def fuse_backward_linear_add( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward dgrad GEMM + add - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + return grad_input, [(), (grad_weight,)], [(), ()] + + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. + + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + + # Check if window matches pattern + matches_pattern = True + if not (isinstance(window[0], MakeExtraOutput) and isinstance(window[1], BasicLinear)): + matches_pattern = False + elif not window[0]._in_place: + # Fused op accumulates grad input in-place + matches_pattern = False + elif window[1].tensor_parallel_mode == "column": + # Column tensor-parallelism requires communication + # after the dgrad GEMM + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = BackwardLinearAdd(backward_add=window[0], linear=window[1]) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "column": - # Row tensor-parallelism requires communication after the - # GEMM - continue - - # Check if second op is "make extra output" - op, _ = ops[0] - if not isinstance(op, MakeExtraOutput): - continue - if not op._in_place: - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardLinearAdd( - linear=window[0][0], - backward_add=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index f7f59e65c..709073e6f 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -45,7 +45,7 @@ def fuser_backward( # Get basic operations linear_op = self.basic_ops[0] - linear_op_ctx = basic_op_ctxs[1] + linear_op_ctx = basic_op_ctxs[0] scale_op = self.basic_ops[1] # Saved tensors from forward pass @@ -109,58 +109,57 @@ def fuser_backward( zero=getattr(weight_param, "zero_out_wgrad", False), ) - return grad_input, [(), (grad_weight,)], [(), ()] - - -def fuse_backward_linear_scale( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward dgrad GEMM + constant scale - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + return grad_input, [(grad_weight,), ()], [(), ()] + + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. + + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + + # Check if window matches pattern + matches_pattern = True + if not (isinstance(window[0], BasicLinear) and isinstance(window[1], ConstantScale)): + matches_pattern = False + elif window[0].tensor_parallel_mode == "column": + # Column tensor-parallelism requires communication + # after the dgrad GEMM + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = BackwardLinearScale(linear=window[0], scale=window[1]) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is constant scale - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, ConstantScale): - continue - - # Check if second op is linear - op, _ = ops[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "column": - # Column tensor-parallelism requires communication after the dgrad GEMM - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardLinearScale( - scale=window[0][0], - linear=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 1c5edfcfc..dfc11a19e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -134,62 +134,63 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] - -def fuse_forward_linear_bias_activation( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fuse forward GEMM + bias + activation - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated forward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. + + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + + # Check if window matches pattern + matches_pattern = True + if not (isinstance(window[0], BasicLinear) and isinstance(window[1], Bias)): + matches_pattern = False + elif window[0].tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after + # the GEMM + matches_pattern = False + elif window[0].weight.dtype not in (torch.float16, torch.bfloat16): + # cuBLAS only supports fused GEMM+bias+activation with + # FP16 and BF16 output + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = ForwardLinearBiasActivation( + linear=window[0], + bias=window[1], + activation=None, + ) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op1, _ = window[0] - if not isinstance(op1, BasicLinear): - continue - if op1.tensor_parallel_mode == "row": - # Row tensor-parallelism requires communication after the - # GEMM - continue - if op1.weight.dtype not in (torch.float16, torch.bfloat16): - # cuBLAS only supports fused GEMM+bias+activation with - # FP16 and BF16 output - continue - - # Check if second op is bias - op2, _ = ops[0] - if not isinstance(op2, Bias): - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = ForwardLinearBiasActivation( - linear=window[0][0], - bias=window[1][0], - activation=None, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 4efb33e03..2dfc0566b 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -131,72 +131,63 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. + + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + out.extend(window) + window = [ops[0]] + ops = ops[1:] -def fuse_forward_linear_bias_add( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fuse forward GEMM + bias + add - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. + # Check if first op is linear + if not isinstance(window[0], BasicLinear): + continue + if window[0].tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after + # the GEMM + continue + linear = window[0] - Returns - ------- - ops : list of tuples - Updated forward pass operations + # Check if next op is bias + bias = None + if ops and isinstance(ops[0], Bias): + window.append(ops[0]) + ops = ops[1:] + bias = window[-1] + + # Check if next op is in-place add extra input + if ops and isinstance(ops[0], AddExtraInput) and ops[0]._in_place: + window.append(ops[0]) + ops = ops[1:] + add = window[-1] + else: + continue - """ + # Replace window with fused op + op = ForwardLinearBiasAdd(linear=linear, bias=bias, add=add) + window = [op] - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "row": - # Row tensor-parallelism requires communication after the - # GEMM - continue - linear = op - op, _ = ops[0] - - # Check if next op is bias - bias = None - if isinstance(op, Bias): - bias = op - window.extend(ops[:1]) - ops = ops[1:] - if len(ops) == 0: - continue - op, _ = ops[0] - - # Check if next op is in-place add extra input - if not isinstance(op, AddExtraInput): - continue - if not op._in_place: - continue - add = op - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = ForwardLinearBiasAdd( - linear=linear, - bias=bias, - add=add, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 25b40f76e..ae4bdd4b1 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -110,70 +110,66 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] - -def fuse_forward_linear_scale_add( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fuse forward GEMM + scale + add - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated forward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 3: + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. + + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + + # Check if window matches pattern + matches_pattern = True + if not ( + isinstance(window[0], BasicLinear) + and isinstance(window[1], ConstantScale) + and isinstance(window[2], AddExtraInput) + ): + matches_pattern = False + elif window[0].tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after + # the GEMM + matches_pattern = False + elif not window[2]._in_place: + # Fused op accumulates output in-place + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = ForwardLinearScaleAdd( + linear=window[0], + scale=window[1], + add=window[2], + ) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-2]) + window = window[-2:] + + # Adjust window to expected size + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "row": - # Row tensor-parallelism requires communication after the - # GEMM - continue - linear = op - op, _ = ops[0] - - # Check if next op is constant scale - if not isinstance(op, ConstantScale): - continue - scale = op - window.extend(ops[:1]) - ops = ops[1:] - op, _ = ops[0] - - # Check if next op is in-place add extra input - if not isinstance(op, AddExtraInput): - continue - if not op._in_place: - continue - add = op - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = ForwardLinearScaleAdd( - linear=linear, - scale=scale, - add=add, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 6c889ba04..fbaf69d75 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -125,18 +125,18 @@ def _functional_backward( Tensor datatype grad_weight: torch.Tensor, optional Loss gradient w.r.t. weight tensor - accumulate_into_grad_weight: bool, default = `False` + accumulate_into_grad_weight: bool, default = False Add result to weight grad instead of overwriting - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. input_quantizer: Quantizer, optional Builder class for quantized input tensor. @@ -503,7 +503,7 @@ def fuser_backward( # Get basic operations idx = self._op_idxs["linear"] linear_op = self.basic_ops[idx] - linear_op_ctx = basic_op_ctxs[-1] + linear_op_ctx = basic_op_ctxs[0] bias_op = None if self._op_idxs["bias"] is not None: idx = self._op_idxs["bias"] @@ -578,99 +578,84 @@ def fuser_backward( grad_params[self._op_idxs["linear"]] = (grad_weight,) if bias_op is not None: grad_params[self._op_idxs["bias"]] = (grad_bias,) - grad_params.reverse() grad_extra_inputs = [() for _ in range(len(self.basic_ops))] return grad_input, grad_params, grad_extra_inputs + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. -def fuse_userbuffers_backward_linear( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Substitute linear operations with Userbuffers implementation + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. + recipe : Recipe, optional + Quantization recipe. - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations - Returns - ------- - ops : list of tuples - Updated backward pass operations + """ - """ + # Return immediately if environment is not distributed + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: + return ops - # Return immediately if environment is not distributed - if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: - return ops - - # Sliding window in list of ops - window = [] - - def peek_next_op() -> Optional[FusibleOperation]: - """Get next op in list of ops""" - nonlocal ops - if not ops: - return None - return ops[-1][0] - - def pop_next_op() -> FusibleOperation: - """Remove next op from list of ops and add to sliding window""" - nonlocal ops, window - window.insert(0, ops[-1]) - ops = ops[:-1] - return window[0][0] - - # Scan through ops in reverse order, fusing if possible - out_reversed = [] - while ops: - out_reversed.extend(reversed(window)) - window.clear() - - # Check if next op is linear - next_op = pop_next_op() - if not isinstance(next_op, BasicLinear): - continue - linear = next_op - if linear._userbuffers_options is None: - continue - - # Check if next op is bias - bias = None - if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): - bias = pop_next_op() - - # Check if next op is reduce-scatter - reduce_scatter = None - if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): - reduce_scatter = pop_next_op() - - # Check for invalid combinations - if reduce_scatter is None: - if linear.tensor_parallel_mode is None: - continue - if linear.tensor_parallel_size == 1: - continue - if linear.tensor_parallel_mode == "row" and bias is not None: - continue - else: - if linear.tensor_parallel_mode is not None: + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + out.extend(window) + window, ops = ops[:1], ops[1:] + + # Check if first op is linear + if not isinstance(window[0], BasicLinear): continue - if reduce_scatter.process_group_size == 1: + linear = window[0] + if linear._userbuffers_options is None: continue - # Replace window with fused op - op = UserbuffersBackwardLinear( - linear=linear, - bias=bias, - reduce_scatter=reduce_scatter, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out_reversed.extend(reversed(window)) - out = out_reversed - out.reverse() - return out + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias): + bias, ops = ops[0], ops[1:] + window.append(bias) + + # Check if next op is reduce-scatter + reduce_scatter = None + if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter): + reduce_scatter, ops = ops[0], ops[1:] + window.append(reduce_scatter) + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_size == 1: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + if reduce_scatter.process_group_size == 1: + continue + + # Replace window with fused op + op = UserbuffersBackwardLinear( + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, + ) + window = [op] + + # Return list of ops + out.extend(window) + return out diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index fe04aa1e0..0d3e1d041 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -115,16 +115,16 @@ def _functional_forward( Tensor device dtype: torch.dtype Tensor datatype - tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + tensor_parallel_mode: {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group: torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel: bool, default = `False` + sequence_parallel: bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_quantized_compute: bool, default = `False` + with_quantized_compute: bool, default = False Whether to perform compute with quantized data. input_quantizer: Quantizer, optional Builder class for quantized input tensor. @@ -132,10 +132,10 @@ def _functional_forward( Builder class for quantized weight tensor. output_quantizer: Quantizer, optional Builder class for quantized output tensor. - input_requires_grad: bool, default = `True` + input_requires_grad: bool, default = True Whether the loss gradient w.r.t. the input tensor is required in the backward pass. - weight_requires_grad: bool, default = `True` + weight_requires_grad: bool, default = True Whether the loss gradient w.r.t. the weight tensor is required in the backward pass. ub_comm_name: str @@ -369,93 +369,79 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. -def fuse_userbuffers_forward_linear( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Substitute linear operations with Userbuffers implementation - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated forward pass operations + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. - """ + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations - # Return immediately if environment is not distributed - if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: - return ops - - # Sliding window in list of ops - window = [] - - def peek_next_op() -> Optional[FusibleOperation]: - """Get next op in list of ops""" - nonlocal ops - if not ops: - return None - return ops[0][0] - - def pop_next_op() -> FusibleOperation: - """Remove next op from list of ops and add to sliding window""" - nonlocal ops, window - window.append(ops[0]) - ops = ops[1:] - return window[-1][0] - - # Scan through ops, fusing if possible - out = [] - while ops: - out.extend(window) - window.clear() + """ - # Check if next op is linear - next_op = pop_next_op() - if not isinstance(next_op, BasicLinear): - continue - linear = next_op - if linear._userbuffers_options is None: - continue + # Return immediately if environment is not distributed + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: + return ops - # Check if next op is bias - bias = None - if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): - bias = pop_next_op() + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: - # Check if next op is reduce-scatter - reduce_scatter = None - if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): - reduce_scatter = pop_next_op() + # Shift window + out.extend(window) + window, ops = ops[:1], ops[1:] - # Check for invalid combinations - if reduce_scatter is None: - if linear.tensor_parallel_mode is None: - continue - if linear.tensor_parallel_size == 1: - continue - if linear.tensor_parallel_mode == "row" and bias is not None: - continue - else: - if linear.tensor_parallel_mode is not None: + # Check if first op is linear + if not isinstance(window[0], BasicLinear): continue - if reduce_scatter.process_group_size == 1: + linear = window[0] + if linear._userbuffers_options is None: continue - # Replace window with fused op - op = UserbuffersForwardLinear( - linear=linear, - bias=bias, - reduce_scatter=reduce_scatter, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias): + bias, ops = ops[0], ops[1:] + window.append(bias) + + # Check if next op is reduce-scatter + reduce_scatter = None + if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter): + reduce_scatter, ops = ops[0], ops[1:] + window.append(reduce_scatter) + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_size == 1: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + if reduce_scatter.process_group_size == 1: + continue + + # Replace window with fused op + op = UserbuffersForwardLinear( + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, + ) + window = [op] - # Return list of ops - out.extend(window) - return out + # Return list of ops + out.extend(window) + return out diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 86b279759..80386db2d 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -7,37 +5,20 @@ """Manager class for a pipeline of fusible operations.""" from __future__ import annotations -from collections.abc import Callable, Iterable -from typing import Any, Optional +from collections.abc import Callable, Iterable, Sequence import itertools +from typing import Any, Optional, TypeAlias import torch -from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling -from transformer_engine.pytorch.ops.op import ( +from ..quantization import FP8GlobalStateManager, Recipe, DelayedScaling +from ..quantized_tensor import prepare_for_saving, restore_from_saved +from .op import ( BasicOperation, FusibleOperation, + FusedOperation, OperationContext, ) -from torch.utils.cpp_extension import IS_HIP_EXTENSION -from transformer_engine.pytorch.ops.fused import ( - fuse_backward_activation_bias, - fuse_backward_add_rmsnorm, - fuse_backward_linear_add, - fuse_backward_linear_scale, - fuse_forward_linear_bias_activation, - fuse_forward_linear_bias_add, - fuse_forward_linear_scale_add, -) -if not IS_HIP_EXTENSION: - from transformer_engine.pytorch.ops.fused import ( - fuse_userbuffers_backward_linear, - fuse_userbuffers_forward_linear, - ) -from transformer_engine.pytorch.quantized_tensor import ( - prepare_for_saving, - restore_from_saved, -) def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: @@ -50,7 +31,7 @@ def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: def _is_graph_capturing() -> bool: - """Whether function is called within `make_graphed_callables` + """Whether function is called within ``make_graphed_callables`` Avoid circular import with lazy import. @@ -63,6 +44,12 @@ def _is_graph_capturing() -> bool: return _is_graph_capturing_function() +# Type alias for a function that may perform operation fusion +OperationFusionFunction: TypeAlias = ( + "Callable[tuple[list[FusibleOperation], ...], list[FusibleOperation]]" +) + + class _OperationFuserAutogradFunction(torch.autograd.Function): """Autograd function for a pipeline of operations @@ -226,6 +213,7 @@ def backward( # Restore saved tensors saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors) + func_ctx.tensor_objects = None # Unflatten list of saved tensors for ctx in basic_op_ctxs: @@ -247,7 +235,7 @@ def backward( dx = grad_output grad_params = [None for _ in range(len(basic_ops))] grad_extra_inputs = [None for _ in range(len(basic_ops))] - for op, basic_op_idxs in backward_ops: + for op, basic_op_idxs in reversed(backward_ops): # Stop if no more gradients are required if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs): @@ -321,6 +309,10 @@ class OperationFuser: """ + # Functions to perform operation fusion + forward_fusion_functions: list[OperationFusionFunction] = [] + backward_fusion_functions: list[OperationFusionFunction] = [] + def __init__( self, ops: list[FusibleOperation], @@ -340,7 +332,7 @@ def __init__( self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops) self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs) - # Ops for forward and backward pass, will be populated in fuse_ops + # Ops for forward and backward pass, will be populated in maybe_fuse_ops self._forward_ops: list[tuple[FusibleOperation, list[int]]] self._backward_ops: list[tuple[FusibleOperation, list[int]]] @@ -355,33 +347,48 @@ def __init__( self._flat_basic_op_params = sum(self._basic_op_params, []) @classmethod - def _fuse_forward_ops( - cls, - ops: list[tuple[FusibleOperation, list[int]]], - recipe: Optional[Recipe], # pylint: disable=unused-argument - ) -> list[tuple[FusibleOperation, list[int]]]: - """Attempt to fuse operations in forward pass""" - if not IS_HIP_EXTENSION: - ops = fuse_userbuffers_forward_linear(ops) - ops = fuse_forward_linear_bias_add(ops) - ops = fuse_forward_linear_bias_activation(ops) - ops = fuse_forward_linear_scale_add(ops) - return ops - - @classmethod - def _fuse_backward_ops( + def _fuse_ops( cls, - ops: list[tuple[FusibleOperation, list[int]]], + basic_ops: Sequence[BasicOperation], + fusion_funcs: Iterable[OperationFusionFunction], recipe: Optional[Recipe], ) -> list[tuple[FusibleOperation, list[int]]]: - """Attempt to fuse operations in backward pass""" - if not IS_HIP_EXTENSION: - ops = fuse_userbuffers_backward_linear(ops) - ops = fuse_backward_linear_add(ops) - ops = fuse_backward_linear_scale(ops) - ops = fuse_backward_activation_bias(ops, recipe) - ops = fuse_backward_add_rmsnorm(ops) - return ops + """Apply operation fusions""" + + # Apply op fusions + fused_ops = list(basic_ops) + for func in fusion_funcs: + fused_ops = func(fused_ops, recipe=recipe) + + def raise_mismatch_error() -> None: + """Throw error indicating invalid op fusion""" + raise RuntimeError( + "Found mismatch after fusing operations " + f"(basic_ops={[o.__class__.__name__ for o in basic_ops]}, " + f"fused_ops={[o.__class__.__name__ for o in fused_ops]})" + ) + + # Determine basic op indices corresponding to each op + out = [] + idx = 0 + for op in fused_ops: + if isinstance(op, FusedOperation): + idxs = [] + for basic_op in op.basic_ops: + if basic_op is not basic_ops[idx]: + raise_mismatch_error() + idxs.append(idx) + idx += 1 + out.append((op, idxs)) + else: + if op is not basic_ops[idx]: + raise_mismatch_error() + out.append((op, [idx])) + idx += 1 + if idx != len(basic_ops): + raise_mismatch_error() + + return out def maybe_fuse_ops( self, @@ -432,12 +439,16 @@ def maybe_fuse_ops( op.pre_first_fuser_forward() # Prepare basic op lists for fusions - forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)] - backward_ops = list(reversed(forward_ops[first_op_requiring_backward:])) - - # Fuse ops - self._forward_ops = self._fuse_forward_ops(forward_ops, recipe) - self._backward_ops = self._fuse_backward_ops(backward_ops, recipe) + self._forward_ops = OperationFuser._fuse_ops( + self._basic_ops, + OperationFuser.forward_fusion_functions, + recipe=recipe, + ) + self._backward_ops = OperationFuser._fuse_ops( + self._basic_ops, + OperationFuser.backward_fusion_functions, + recipe=recipe, + ) # Save current fusion params self.recipe_type, self.first_op_requiring_backward = fusion_params @@ -499,3 +510,59 @@ def __call__( *extra_inputs, ) return forward_func(*args) + + +def register_forward_fusion( + op_fusion_func: OperationFusionFunction, + prepend: bool = False, +) -> None: + """Register function to perform operation fusion for forward pass. + + The fusion function should have the following signature: + + .. code-block:: python + + func(ops, *, recipe) -> updated ops + + Parameters + ---------- + op_fusion_func: function + Function that takes a list of operations and may substitute + them with fused operations. + prepend: bool, default = ``False`` + Whether the operation fuser should apply this fusion function + first. The default is to apply it last. + + """ + if prepend: + OperationFuser.forward_fusion_functions.insert(0, op_fusion_func) + else: + OperationFuser.forward_fusion_functions.append(op_fusion_func) + + +def register_backward_fusion( + op_fusion_func: OperationFusionFunction, + prepend: bool = False, +) -> None: + """Register function to perform operation fusion for backward pass. + + The fusion function should have the following signature: + + .. code-block:: python + + func(ops, *, recipe) -> updated ops + + Parameters + ---------- + op_fusion_func: function + Function that takes a list of operations and may substitute + them with fused operations. + prepend: bool, default = ``False`` + Whether the operation fuser should apply this fusion function + first. The default is to apply it last. + + """ + if prepend: + OperationFuser.backward_fusion_functions.insert(0, op_fusion_func) + else: + OperationFuser.backward_fusion_functions.append(op_fusion_func) diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index d5829b0c5..c6ca4786b 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -23,7 +23,7 @@ class Linear(FusedOperation): """Apply linear transformation: :math:`y = x A^T + b` - This is a drop-in replacement for `torch.nn.Linear`. + This is a drop-in replacement for ``torch.nn.Linear``. Parameters ---------- @@ -31,17 +31,17 @@ class Linear(FusedOperation): Inner dimension of input tensor out_features : int Inner dimension of output tensor - bias : bool, default = `True` + bias : bool, default = True Apply additive bias device : torch.device, default = default CUDA device Tensor device dtype : torch.dtype, default = default dtype Tensor datatype - tensor_parallel_mode : {`None`, "column", "row"}, default = `None` + tensor_parallel_mode : {None, "column", "row"}, default = None Mode for tensor parallelism tensor_parallel_group : torch.distributed.ProcessGroup, default = world group Process group for tensor parallelism - sequence_parallel : bool, default = `False` + sequence_parallel : bool, default = False Whether to apply sequence parallelism together with tensor parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing @@ -49,12 +49,12 @@ class Linear(FusedOperation): rng_state_tracker_function : callable Function that returns CudaRNGStatesTracker, which is used for model-parallel weight initialization - accumulate_into_main_grad : bool, default = `False` + accumulate_into_main_grad : bool, default = False Whether to directly accumulate weight gradients into the - weight's `main_grad` attribute instead of relying on PyTorch - autograd. The weight's `main_grad` must be set externally and - there is no guarantee that `grad` will be set or be - meaningful. This is primarily intented to integrate with + weight's ``main_grad`` attribute instead of relying on PyTorch + autograd. The weight's ``main_grad`` must be set externally and + there is no guarantee that ``grad`` will be set or be + meaningful. This is primarily intended to integrate with Megatron-LM. """ diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 47286dfce..54b3f0011 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -94,7 +94,7 @@ def fuser_forward( several of this function's arguments are lists of arguments to forward functions of corresponding basic ops. - Called by `OperationFuser`. + Called by ``OperationFuser``. Parameters ---------- @@ -141,7 +141,7 @@ def fuser_backward( several of this function's arguments are lists of arguments to backward functions of corresponding basic ops. - Called by `OperationFuser`. + Called by ``OperationFuser``. Parameters ---------- diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index a0db3cd2d..592ddae23 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -15,10 +15,10 @@ class Sequential(torch.nn.Module): - """Sequential container for fusible operations + """Sequential container for fusible operations. - This is a drop-in replacement for `torch.nn.Sequential`, with - support for fusing `FusibleOperation`s. + This is a drop-in replacement for ``torch.nn.Sequential`` with + support for fusing ``FusibleOperation`` s. Parameters ---------- diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py index 792eab094..7220f1924 100644 --- a/transformer_engine/pytorch/optimizers/__init__.py +++ b/transformer_engine/pytorch/optimizers/__init__.py @@ -5,6 +5,7 @@ """Fused optimizers and multi-tensor kernels.""" from transformer_engine_torch import ( multi_tensor_scale, + multi_tensor_scale_tensor, multi_tensor_l2norm, multi_tensor_unscale_l2norm, multi_tensor_adam, diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index b2b78f3eb..95efb88cb 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -144,19 +144,24 @@ def __init__( if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]: raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg_sq.") - # Currently, capturable mode only supports fp32 master weights and optimizer states. - # The reason is, if the master weights or optimizer states are not in fp32 dtype, - # they will be copied to temporary fp32 buffers first. These fp32 buffers are then - # used as inputs for the kernel. Consequently, the pointer for earch `.step()` differs, - # making CUDA Graph inapplicable in this scenario. + # Capturable mode requires fp32 master weights, and optimizer states (exp_avg/exp_avg_sq) + # must both be fp32 or both be bf16. This is because master weights in non-fp32 dtypes + # or optimizer states in non-fp32/bf16 dtypes require copying to temporary fp32 buffers + # before kernel execution, causing different pointers on each `.step()` call and making + # CUDA Graph inapplicable. if capturable and master_weights and master_weight_dtype != torch.float32: raise RuntimeError("Capturable mode only supports fp32 master weights.") - if capturable and exp_avg_dtype != torch.float32: - raise RuntimeError("Capturable mode only supports fp32 exp_avg.") - if capturable and exp_avg_sq_dtype != torch.float32: - raise RuntimeError("Capturable mode only supports fp32 exp_avg_sq") - if capturable and store_param_remainders: - raise RuntimeError("Capturable mode doesn't support storing param remainders") + if capturable: + valid_moment_dtypes = ( + exp_avg_dtype == exp_avg_sq_dtype == torch.float32 + or exp_avg_dtype == exp_avg_sq_dtype == torch.bfloat16 + ) + if not valid_moment_dtypes: + raise RuntimeError( + "Capturable mode requires exp_avg_dtype and exp_avg_sq_dtype to be " + "both torch.float32 or both torch.bfloat16, but got " + f"exp_avg_dtype={exp_avg_dtype} and exp_avg_sq_dtype={exp_avg_sq_dtype}." + ) # If the optimizer is capturable then LR should be a tensor (on GPU) lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr @@ -211,6 +216,11 @@ def __init__( self.store_param_remainders = ( store_param_remainders and master_weights and master_weight_dtype == torch.float32 ) + if self.capturable and self.store_param_remainders: + raise RuntimeError("Capturable mode doesn't support storing param remainders") + # If the exp_avg and exp_avg_sq dtypes are bfloat16, we can fuse the unscaling/scaling + # operations into the fused Adam kernel. + self.fuse_unscale = self.exp_avg_dtype == self.exp_avg_sq_dtype == torch.bfloat16 # Deprecated options self.set_grad_none = set_grad_none @@ -272,10 +282,9 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): dtype = self.name_to_dtype_map[state_name] if dtype == torch.uint8: assert isinstance(scaled_state, Float8Tensor) - assert len(scaled_state._quantizer.scale) == 1, ( - "Only scaling with one scaling factor per tensor is supported by the" - " FusedAdam." - ) + assert ( + len(scaled_state._quantizer.scale) == 1 + ), "Only scaling with one scaling factor per tensor is supported by the FusedAdam." else: assert scaled_state.dtype == dtype @@ -297,13 +306,22 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): unscaled_state.mul_(rscale) scaled_state.copy_(unscaled_state) - def get_unscaled_state(self, param, state_name): + def get_unscaled_state( + self, param: torch.nn.Parameter, state_name: str, skip_unscale: bool = False + ) -> torch.Tensor: """Return the unscaled state corresponding to the input `param` and `state_name`. Arguments: param (torch.nn.Parameter): One of parameters in this optimizer. state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', and 'master_param`. + skip_unscale (optional, bool): Whether to skip the unscaling operation. + Should only be True if 'self.fuse_unscale' is True. Default is False. + + Returns: + torch.Tensor: The unscaled state. Note that if the state is in BF16, the returned + tensor is still in BF16 because it doesn't require to be "unscaled", otherwise it + will be unscaled to FP32. """ state = self.state[param] dtype = self.name_to_dtype_map[state_name] @@ -325,7 +343,10 @@ def get_unscaled_state(self, param, state_name): unscaled = state[state_name] elif dtype == torch.bfloat16: assert state[state_name].dtype == torch.bfloat16 - unscaled = state[state_name].float() + if skip_unscale: + unscaled = state[state_name] + else: + unscaled = state[state_name].float() else: raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/bf16/fp32.") return unscaled @@ -377,9 +398,25 @@ def _initialize_state( store_param_remainders (bool): Store only trailing remainder bits. """ dtype = self.name_to_dtype_map[state_name] + # Extract local tensor from DTensor (e.g. from FSDP2) to avoid + # QuantizedTensor.__torch_dispatch__ ignoring the dtype kwarg in + # torch.empty_like, and to ensure optimizer states are plain tensors. + local_param = param._local_tensor if isinstance(param, DTensor) else param # Handle QuantizedTensor by dequantizing first - param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param - if store_param_remainders: + param_for_empty = ( + local_param.dequantize() if isinstance(local_param, QuantizedTensor) else local_param + ) + #ROCm: create plain `torch.Tensor` instead of `FSDPAGTensor` subclasses + if IS_HIP_EXTENSION: + if store_param_remainders: + data = torch.zeros( + param_for_empty.shape, dtype=torch.int16, device=param_for_empty.device + ) + else: + data = torch.empty( + param_for_empty.shape, dtype=dtype, device=param_for_empty.device + ) + elif store_param_remainders: data = torch.zeros_like(param_for_empty, dtype=torch.int16) else: data = torch.empty_like(param_for_empty, dtype=dtype) @@ -423,7 +460,14 @@ def initialize_state(self, param, store_param_remainders): store_param_remainders=store_param_remainders, ) if not store_param_remainders: - self.set_scaled_state(param, "master_param", param.clone().detach().float()) + # Extract local tensor from DTensor and dequantize QuantizedTensor + # to get a plain float32 copy for the master weight. + local_param = param._local_tensor if isinstance(param, DTensor) else param + if isinstance(local_param, QuantizedTensor): + master = local_param.dequantize(dtype=torch.float32).clone().detach() + else: + master = local_param.clone().detach().float() + self.set_scaled_state(param, "master_param", master) def state_dict(self): """Override the state_dict() of pytorch. Before returning the state_dict, cast all @@ -541,6 +585,7 @@ def step(self, closure=None, grad_scaler=None): has_fp16 = False has_bf16 = False + quantized_params_to_update = [] for p in group["params"]: state = self.state[p] @@ -569,12 +614,27 @@ def step(self, closure=None, grad_scaler=None): unscaled_state[name] = self.state[p][name] assert unscaled_state[name].dtype == torch.int16 else: - unscaled = self.get_unscaled_state(p, name) + unscaled = self.get_unscaled_state( + p, name, skip_unscale=self.fuse_unscale + ) unscaled_state[name] = unscaled if self.name_to_dtype_map[name] != torch.float32: unscaled_lists[name].append(unscaled) scaled_lists[name].append(state[name]) state_scales[name].append(self._scales[p][name]) + # ROCm: extract local tensor data from DTensor (FSDP2) to ensure + # consistent shapes with optimizer states (which are local tensors). + local_p = p._local_tensor if IS_HIP_EXTENSION and isinstance(p, DTensor) else p + local_g = ( p_grad._local_tensor if IS_HIP_EXTENSION and + isinstance(p_grad, DTensor) else p_grad ) + + local_p_data = p._local_tensor.data if isinstance(p, DTensor) else p.data + local_g_data = ( + p_grad._local_tensor.data + if isinstance(p_grad, DTensor) + else p_grad.data + ) + if isinstance(p, Float8Tensor) or ( isinstance(p, DTensor) and isinstance(p._local_tensor, Float8Tensor) ): @@ -587,21 +647,44 @@ def step(self, closure=None, grad_scaler=None): scale_invs.append(scale_inv) if self.master_weights: p_main_of_fp8_model.append(unscaled_state["master_param"].data) - g_of_fp8_model.append(p_grad.data) + g_of_fp8_model.append(local_g.data) m_of_fp8_model.append(unscaled_state["exp_avg"]) v_of_fp8_model.append(unscaled_state["exp_avg_sq"]) + elif isinstance(p, QuantizedTensor) or ( + isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor) + ): + # Block-scaling quantized params (MXFP8Tensor, Float8BlockwiseQTensor, + # NVFP4Tensor). Operate on FP32 master weights, requantize back after + # Adam update. + # Note: a fused Adam+requantize kernel (like multi_tensor_adam_fp8 + # for Float8Tensor) would avoid the FP32 round-trip here. + if not self.master_weights: + local_p = p._local_tensor if isinstance(p, DTensor) else p + raise RuntimeError( + "FusedAdam without master_weights does not support " + f"{type(local_p).__name__} parameters. Use master_weights=True." + ) + # Route to the FP32 master-weight path: Adam updates the FP32 master, + # then we write back to the quantized param after kernels run. + # Gradients may be BF16/FP16 from the backward pass — cast to FP32 + # to match the FP32 Adam kernel expectations. + p_f32_model.append(unscaled_state["master_param"].data) + g_of_f32_model.append(local_g.data.float()) + m_of_f32_model.append(unscaled_state["exp_avg"]) + v_of_f32_model.append(unscaled_state["exp_avg_sq"]) + quantized_params_to_update.append((p, unscaled_state["master_param"])) elif p.dtype in [torch.float16, torch.bfloat16]: has_fp16 = has_fp16 or p.dtype == torch.float16 has_bf16 = has_bf16 or p.dtype == torch.bfloat16 - p_f16_model.append(p.data) + p_f16_model.append(local_p.data) if self.master_weights: p_main_of_f16_model.append(unscaled_state["master_param"].data) - g_of_f16_model.append(p_grad.data) + g_of_f16_model.append(local_g.data) m_of_f16_model.append(unscaled_state["exp_avg"]) v_of_f16_model.append(unscaled_state["exp_avg_sq"]) elif p.dtype == torch.float32: - p_f32_model.append(p.data) - g_of_f32_model.append(p_grad.data) + p_f32_model.append(local_p.data) + g_of_f32_model.append(local_g.data) m_of_f32_model.append(unscaled_state["exp_avg"]) v_of_f32_model.append(unscaled_state["exp_avg_sq"]) else: @@ -614,6 +697,13 @@ def step(self, closure=None, grad_scaler=None): "FusedAdam does not support FP8 model weights with capturable=True." ) + if self.capturable and len(quantized_params_to_update) > 0: + raise RuntimeError( + "FusedAdam does not support block-scaling quantized weights " + "with capturable=True. The post-step quantize_() writeback " + "cannot be captured in a CUDA graph." + ) + if has_fp16 and has_bf16: if self.store_param_remainders: raise RuntimeError( @@ -750,8 +840,17 @@ def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=N tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model] apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) + # Write updated FP32 master weights back to quantized parameters + for qt_param, master_w in quantized_params_to_update: + local_p = qt_param._local_tensor if isinstance(qt_param, DTensor) else qt_param + local_p.quantize_(master_w.data) + # Scaling for name in ["exp_avg", "exp_avg_sq", "master_param"]: + if self.fuse_unscale and name in ["exp_avg", "exp_avg_sq"]: + # When fused_unscale is True, the scaling is fused into the Adam kernel. + # The momentums are updated inplace, so we don't need to scale here. + continue if len(unscaled_lists[name]) > 0: for unscaled, scaled, scale in zip( unscaled_lists[name], scaled_lists[name], state_scales[name] diff --git a/transformer_engine/pytorch/optimizers/fused_sgd.py b/transformer_engine/pytorch/optimizers/fused_sgd.py index 08e465e95..d7ab3fe9f 100644 --- a/transformer_engine/pytorch/optimizers/fused_sgd.py +++ b/transformer_engine/pytorch/optimizers/fused_sgd.py @@ -123,7 +123,7 @@ def __init__( self.set_grad_none = set_grad_none if self.set_grad_none is not None: warnings.warn( - "set_grad_none kwarg in FusedAdam constructor is deprecated. " + "set_grad_none kwarg in FusedSGD constructor is deprecated. " "Use set_to_none kwarg in zero_grad instead.", DeprecationWarning, ) @@ -147,7 +147,7 @@ def zero_grad(self, set_to_none: Optional[bool] = None) -> None: if set_to_none is not None and set_to_none != self.set_grad_none: raise ValueError( f"Called zero_grad with set_to_none={set_to_none}, " - f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}" + f"but FusedSGD was initialized with set_grad_none={self.set_grad_none}" ) set_to_none = self.set_grad_none if set_to_none is None: diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 5beeed126..ca59a0ebf 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -42,10 +42,16 @@ def forward( return inp, torch.tensor([], device=inp.device) # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert index.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") + if not index.is_cuda: + raise ValueError(f"index must be a CUDA tensor, but got tensor on {index.device}.") # Shape check - assert inp.size(0) == index.size(0), "Permute not possible" + if inp.size(0) != index.size(0): + raise ValueError( + f"Permute not possible: inp.size(0) ({inp.size(0)}) must match " + f"index.size(0) ({index.size(0)})." + ) # Data type check dtype = TE_DType[inp.dtype] @@ -119,7 +125,8 @@ def forward( # None probs check if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." + if not probs.is_cuda: + raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") if probs.dtype != torch.float32: warnings.warn( @@ -136,8 +143,12 @@ def forward( probs = torch.empty(0) # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") + if not row_id_map.is_cuda: + raise ValueError( + f"row_id_map must be a CUDA tensor, but got tensor on {row_id_map.device}." + ) # Data type check dtype = TE_DType[inp.dtype] @@ -198,19 +209,30 @@ def forward( ctx.probs = probs return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device) - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert routing_map.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") + if not routing_map.is_cuda: + raise ValueError( + f"routing_map must be a CUDA tensor, but got tensor on {routing_map.device}." + ) if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." + if not probs.is_cuda: + raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") if pad_offsets is not None: - assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." + if not pad_offsets.is_cuda: + raise ValueError( + f"pad_offsets must be a CUDA tensor, but got tensor on {pad_offsets.device}." + ) - assert inp.size(0) == routing_map.size(0), "Permute not possible" + if inp.size(0) != routing_map.size(0): + raise ValueError( + f"Permute not possible: inp.size(0) ({inp.size(0)}) must match " + f"routing_map.size(0) ({routing_map.size(0)})." + ) num_tokens, hidden_size = inp.size() num_experts = routing_map.size(1) - assert ( - num_out_tokens is not None - ), "num_out_tokens must be provided to the fused permute function." + if num_out_tokens is None: + raise ValueError("num_out_tokens must be provided to the fused permute function.") row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts) @@ -226,13 +248,25 @@ def forward( if blockwise_recipe: fp8_scale = inp._rowwise_scale_inv.T.contiguous() scale_hidden_dim = fp8_scale.shape[1] - assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). " + f"Input shape: ({num_tokens}, {hidden_size}), " + f"scale shape: {tuple(fp8_scale.shape)}." + ) inp = inp._rowwise_data # mxfp8 scaling elif mxfp8_recipe: fp8_scale = inp._rowwise_scale_inv.contiguous() scale_hidden_dim = fp8_scale.shape[1] - assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). " + f"Input shape: ({num_tokens}, {hidden_size}), " + f"scale shape: {tuple(fp8_scale.shape)}." + ) inp = inp._rowwise_data # per-tensor scaling elif per_tensor_recipe: @@ -318,9 +352,11 @@ def backward( probs_grad = None if ctx.needs_input_grad[0]: row_id_map, pad_offsets = ctx.saved_tensors - assert not isinstance( - permuted_act_grad, QuantizedTensor - ), "The backward of moe_permute does not support FP8." + if isinstance(permuted_act_grad, QuantizedTensor): + raise TypeError( + "The backward of moe_permute does not support FP8, but got " + f"QuantizedTensor of type {type(permuted_act_grad).__name__}." + ) act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( permuted_act_grad, row_id_map, @@ -360,17 +396,30 @@ def forward( with_probs = merging_probs is not None if with_probs: - assert merging_probs.is_cuda, "TransformerEngine needs CUDA." + if not merging_probs.is_cuda: + raise ValueError( + "merging_probs must be a CUDA tensor, but got tensor on " + f"{merging_probs.device}." + ) # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") + if not row_id_map.is_cuda: + raise ValueError( + f"row_id_map must be a CUDA tensor, but got tensor on {row_id_map.device}." + ) if pad_offsets is not None: - assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." + if not pad_offsets.is_cuda: + raise ValueError( + f"pad_offsets must be a CUDA tensor, but got tensor on {pad_offsets.device}." + ) - assert not isinstance( - inp, QuantizedTensor - ), "The forward of moe_unpermute does not support FP8." + if isinstance(inp, QuantizedTensor): + raise TypeError( + "The forward of moe_unpermute does not support FP8, but got " + f"QuantizedTensor of type {type(inp).__name__}." + ) unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( inp, row_id_map, @@ -427,13 +476,23 @@ def backward(ctx, unpermuted_act_grad): fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous() unpermuted_act_grad = unpermuted_act_grad._rowwise_data scale_hidden_dim = fp8_scale.shape[1] - assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if ctx.num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({ctx.num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). " + f"Scale shape: {tuple(fp8_scale.shape)}." + ) # mxfp8 scaling elif mxfp8_recipe: fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous() unpermuted_act_grad = unpermuted_act_grad._rowwise_data scale_hidden_dim = fp8_scale.shape[1] - assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if ctx.num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({ctx.num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). " + f"Scale shape: {tuple(fp8_scale.shape)}." + ) else: raise ValueError("Unsupported FP8 recipe") else: @@ -441,10 +500,13 @@ def backward(ctx, unpermuted_act_grad): fp8_dtype = None fp8_scale = None + permuted_scale = None if ctx.with_probs: - assert ( - not fp8 - ), "The backward of moe_unpermute with merging probs does not support FP8." + if fp8: + raise TypeError( + "The backward of moe_unpermute with merging probs does not support FP8, " + f"but got FP8 gradient with dtype {fp8_dtype}." + ) act_grad, probs_grad = ( triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( unpermuted_act_grad, @@ -619,10 +681,12 @@ def moe_permute_and_pad_with_probs( align_size : int the alignment size for the input tensor. """ - assert ( - tokens_per_expert is not None - ), "tokens_per_expert must be provided to the fused permute padding function." - assert align_size > 0, f"align_size must be positive, got {align_size}" + if tokens_per_expert is None: + raise ValueError( + "tokens_per_expert must be provided to the fused permute padding function." + ) + if align_size <= 0: + raise ValueError(f"align_size must be positive, got {align_size}.") # Ensure tokens_per_expert is on the same device as input to avoid device transfers if tokens_per_expert.device != inp.device: @@ -713,15 +777,27 @@ def forward( if not inp.numel(): return inp, probs - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert split_sizes.is_cuda, "TransformerEngine needs CUDA." - assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") + if not split_sizes.is_cuda: + raise ValueError( + f"split_sizes must be a CUDA tensor, but got tensor on {split_sizes.device}." + ) + if not sorted_idxs.is_cuda: + raise ValueError( + f"sorted_idxs must be a CUDA tensor, but got tensor on {sorted_idxs.device}." + ) if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." + if not probs.is_cuda: + raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") num_tokens, hidden_size = inp.shape num_splits = split_sizes.size(0) - assert num_splits == sorted_idxs.size(0) + if num_splits != sorted_idxs.size(0): + raise ValueError( + f"split_sizes.size(0) ({num_splits}) must match " + f"sorted_idxs.size(0) ({sorted_idxs.size(0)})." + ) fp8 = isinstance(inp, Float8Tensor) if fp8: diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 25c2ff7f3..8a0d955f7 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -118,7 +118,8 @@ def check_recipe_support(recipe: Recipe) -> None: recipe_supported, unsupported_reason = check_fp8_block_scaling_support() elif isinstance(recipe, MXFP8BlockScaling): recipe_supported, unsupported_reason = check_mxfp8_support() - assert recipe_supported, unsupported_reason + if not recipe_supported: + raise RuntimeError(unsupported_reason) def get_default_fp8_recipe() -> Recipe: diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 5c1d7290e..f9a567eb7 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -41,6 +41,7 @@ class QuantizedTensorStorage: XTensor should only implement the functionality needed to behave like regular torch.Tensor (like __torch_dispatch__).""" + _dtype: torch.dtype _quantizer: Optional[Quantizer] def update_usage( @@ -73,7 +74,9 @@ def get_usages(self) -> Dict[str, bool]: f"{self.__class__.__name__} class does not implement get_usages function" ) - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: """Prepare the tensor base for saving for backward""" raise NotImplementedError( f"{self.__class__.__name__} class does not implement prepare_for_saving function" @@ -119,11 +122,18 @@ def update_quantizer(self, quantizer: Quantizer): warnings.warn("Quantizer is being updated, this may affect model behavior") self._quantizer = quantizer + def copy_from_storage(self, src: QuantizedTensorStorage) -> None: + """Copy data from another QuantizedTensorStorage.""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement copy_from_storage function" + ) + def prepare_for_saving( *tensors: Union[torch.Tensor, QuantizedTensorStorage], ) -> Tuple[ - list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorStorage]] + list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], + list[Optional[QuantizedTensorStorage]], ]: """Prepare tensors for saving. Needed because save_for_backward accepts only torch.Tensor/torch.nn.Parameter types, while we want to be able to save @@ -148,7 +158,10 @@ def restore_from_saved( return_saved_tensors: bool = False, ) -> ( list[Optional[torch.Tensor | QuantizedTensorStorage]] - | tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]] + | tuple[ + list[Optional[torch.Tensor | QuantizedTensorStorage]], + list[Optional[torch.Tensor]], + ] ): """Recombine the tensor data and metadata during backward pass.""" tensor_objects = [] @@ -358,11 +371,23 @@ def __new__( shape: Iterable[int], dtype: torch.dtype, *, + fake_dtype: Optional[torch.dtype] = None, requires_grad: bool = False, device: Optional[torch.device] = None, + stride: Optional[Iterable[int]] = None, ): - # We are assuming only contiguous tensors - stride = _stride_from_shape(shape) + if fake_dtype is not None and fake_dtype != dtype: + raise ValueError(f"fake_dtype ({fake_dtype}) does not match dtype ({dtype})") + # For stride, We are assuming only contiguous tensors + # Calculate stride from shape if not provided. When creating this object from + # C++ code, we provide the stride computed from shape in C++ to avoid the + # PyobjectVectorCall overhead of calling _stride_from_shape from C++ to Python. + stride = _stride_from_shape(shape) if stride is None else stride + if IS_HIP_EXTENSION and device == torch.device("cuda"): + # Without passing explicit device index to _make_wrapper_subclass tests fail with + # RuntimeError at autograd: 0 <= device.index() && + # device.index() < static_cast(device_ready_queues_.size()) + device = torch.device("cuda", torch.cuda.current_device()) instance = torch.Tensor._make_wrapper_subclass( cls, shape, @@ -373,9 +398,75 @@ def __new__( requires_grad=requires_grad, device=torch.cuda.current_device() if device is None else device, ) - + instance._requires_grad = requires_grad + instance._dtype = dtype return instance + @property + def dtype(self) -> torch.dtype: + """ + Return the high precision data type of the tensor + Attribute access of custom tensors goes through an + expensive Pyobject lookup. Since dtype for a tensor is never + change after creation, we cache it in a member variable and return + """ + # Lazy initialization for tensors created via alternate paths + if not hasattr(self, "_dtype"): + # pylint: disable=unnecessary-dunder-call + self._dtype = torch._C.TensorBase.dtype.__get__(self, type(self)) + return self._dtype + + @dtype.setter + def dtype(self, value: torch.dtype) -> None: + """Set dtype property""" + self._dtype = value + + @property + def requires_grad(self) -> bool: + """ + Return whether or not the tensor requires gradient. + Attribute access of custom tensors goes through an + expensive Pyobject lookup. Since requires_grad is set during + initialization and may be updated, we cache it in a member variable. + """ + # Fallback to parent if not cached yet + if not hasattr(self, "_requires_grad"): + # pylint: disable=unnecessary-dunder-call + self._requires_grad = torch._C.TensorBase.requires_grad.__get__(self, type(self)) + return self._requires_grad + + @requires_grad.setter + def requires_grad(self, value: bool) -> None: + """Set requires_grad property so that autograd engine is aware of the change""" + # Update the cached value and call parent class method to ensure autograd engine is aware + self.requires_grad_(value) + + def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: + """Cache requires_grad property and call parent class method""" + # pylint: disable=missing-function-docstring + # Update the cached value + self._requires_grad = requires_grad + # Call parent class method to ensure autograd engine is aware + super().requires_grad_(requires_grad) + return self + + def _get_data(self) -> torch.Tensor: + """Get tensor data property""" + return super().data + + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + Updates the underlying tensor data and syncs the dtype cache. + """ + # Update the parent class's data descriptor + # pylint: disable=unnecessary-dunder-call + super(QuantizedTensor, type(self)).data.__set__(self, tensor) + # Update the dtype cache + self._dtype = tensor.dtype + + # Create the data property with getter and setter + data = property(_get_data, _set_data) + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Convert quantized data to standard PyTorch tensor""" raise NotImplementedError( @@ -406,7 +497,7 @@ def clear(self): ) def __repr__(self, *, tensor_contents=None) -> str: - return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" + return f"{self.__class__.__name__}(data={self.dequantize()})" def float(self) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -469,7 +560,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst.quantize_(src) else: if isinstance(src, QuantizedTensor): - src = src.dequantize() + dtype = dst.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + dtype = torch.float32 + src = src.dequantize(dtype=dtype) dst.copy_(src) return None @@ -477,6 +571,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.view.default: raise NotImplementedError("{cls.__name__} class does not support tensor views") + # New empty op (used by DCP async staging to create CPU copies) + if func == torch.ops.aten.new_empty.default: + tensor = args[0] + size = args[1] + dtype = kwargs.get("dtype", tensor.dtype) + device = kwargs.get("device", tensor.device) + pin_memory = kwargs.get("pin_memory", False) + if tensor._quantizer is None: + raise RuntimeError( + f"{type(tensor).__name__} does not have a quantizer; " + "cannot create new_empty QuantizedTensor" + ) + out = tensor._quantizer.make_empty( + shape=torch.Size(size), + dtype=dtype, + device=device, + requires_grad=tensor.requires_grad, + pin_memory=pin_memory, + ) + return out + # Empty like op if func == torch.ops.aten.empty_like.default: tensor = args[0] @@ -509,7 +624,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): def maybe_unwrap(arg): if isinstance(arg, QuantizedTensor): - return arg.dequantize(dtype=arg.dtype) + return arg.dequantize() return arg def maybe_update_inplace(arg, new_arg, schema_arg): @@ -592,6 +707,7 @@ def make_like( shape = shape if shape is not None else tensor.shape dtype = dtype if dtype is not None else tensor.dtype kwargs = tensor.get_metadata() + kwargs["fake_dtype"] = dtype return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs) def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor: diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py index 52d1d9d6c..b56b1cd5e 100644 --- a/transformer_engine/pytorch/router.py +++ b/transformer_engine/pytorch/router.py @@ -3,7 +3,18 @@ # See LICENSE for license information. """ Fused functions used in the MoE router + +Precision Notes: +- FP64 is currently not supported. +- Inputs are casted into FP32 when loading from global memory. +- All the math/calculations/accumulations are in FP32 in the kernels. +- "scores" is always in FP32 (match the MCore implementation). +- "intermediate_output" is always in FP32 for better backward precision. +- Only cast to low-precision when necessary and the casting only happens in writing to + global memory. For example, the gradient is required to have the same dtype as the input. """ +from typing import Optional + import torch import transformer_engine_torch as tex @@ -11,7 +22,7 @@ class FusedTopkScoreFunction(torch.autograd.Function): """ Fused Topk with Score Function router. - Currently, only support softmax and sigmoid. + Currently, support "softmax", "sigmoid" and "sqrtsoftplus". """ @staticmethod @@ -20,11 +31,11 @@ def forward( logits: torch.Tensor, topk: int, use_pre_softmax: bool, - num_groups: int, - group_topk: int, - scaling_factor: float, + num_groups: Optional[int], + group_topk: Optional[int], + scaling_factor: Optional[float], score_function: str, - expert_bias: torch.Tensor, + expert_bias: Optional[torch.Tensor], ): # pylint: disable=missing-function-docstring # Save the shape of the logits @@ -52,6 +63,7 @@ def forward( ctx.topk = topk ctx.scaling_factor = scaling_factor ctx.score_function = score_function + ctx.logits_dtype = logits.dtype return probs, routing_map @staticmethod @@ -62,12 +74,16 @@ def backward(ctx, grad_probs, _): tensor_shape = grad_probs.shape # Adjust the shape of the grad_probs to 2D shape grad_probs = grad_probs.contiguous().view(-1, tensor_shape[-1]) - grad_logits = tex.fused_topk_with_score_function_bwd( + grad_logits = torch.empty( + (ctx.num_tokens, ctx.num_experts), dtype=ctx.logits_dtype, device=grad_probs.device + ) + tex.fused_topk_with_score_function_bwd( ctx.num_tokens, ctx.num_experts, routing_map, intermediate_output, grad_probs, + grad_logits, ctx.topk, ctx.use_pre_softmax, ctx.scaling_factor, @@ -82,37 +98,37 @@ def fused_topk_with_score_function( logits: torch.Tensor, topk: int, use_pre_softmax: bool, - num_groups: int, - group_topk: int, - scaling_factor: float, + num_groups: Optional[int], + group_topk: Optional[int], + scaling_factor: Optional[float], score_function: str, - expert_bias: torch.Tensor, + expert_bias: Optional[torch.Tensor], ): """ Fused topk with score function router. Parameters ---------- - logits : torch.Tensor + logits : torch.Tensor in fp32/bf16/fp16 topk : int use_pre_softmax : bool - if enabled, the computation order: softmax -> topk - num_groups : int + if enabled, the computation order: softmax -> topk. + num_groups : int, optional used in the group topk - group_topk : int + group_topk : int, optional used in the group topk - scaling_factor : float + scaling_factor : float, optional score_function : str - currently only support softmax and sigmoid - expert_bias : torch.Tensor - could be used in the sigmoid + currently support "softmax", "sigmoid" and "sqrtsoftplus". + expert_bias : torch.Tensor, optional + could be used with the sigmoid/sqrtsoftplus score functions. Returns ------- - probs : torch.Tensor - routing_map : torch.Tensor + probs : torch.Tensor in the same dtype as the "logits". + routing_map : torch.Tensor in bool. """ if logits.dtype == torch.float64: - raise ValueError("Current TE does not support float64 router type") + raise ValueError("Current TE does not support float64 router type.") return FusedTopkScoreFunction.apply( logits, topk, @@ -154,6 +170,7 @@ def forward( ctx.score_function = score_function ctx.num_tokens = num_tokens ctx.num_experts = num_experts + ctx.logits_dtype = logits.dtype return routing_map, scores @staticmethod @@ -164,11 +181,15 @@ def backward(ctx, _, grad_scores): tensor_shape = grad_scores.shape # Adjust the shape of the grad_scores to 2D shape grad_scores = grad_scores.contiguous().view(-1, tensor_shape[-1]) - grad_logits = tex.fused_score_for_moe_aux_loss_bwd( + grad_logits = torch.empty( + (ctx.num_tokens, ctx.num_experts), dtype=ctx.logits_dtype, device=grad_scores.device + ) + tex.fused_score_for_moe_aux_loss_bwd( num_tokens=ctx.num_tokens, num_experts=ctx.num_experts, intermediate_output=intermediate_output, grad_scores=grad_scores, + grad_logits=grad_logits, topk=ctx.topk, score_function=ctx.score_function, ) @@ -186,15 +207,15 @@ def fused_compute_score_for_moe_aux_loss( Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function. Parameters ---------- - logits : torch.Tensor + logits : torch.Tensor in fp32/bf16/fp16 topk : int score_function : str - currently only support softmax and sigmoid + currently support "softmax", "sigmoid" and "sqrtsoftplus". Returns ------- - routing_map : torch.Tensor - scores : torch.Tensor + routing_map : torch.Tensor in bool + scores : torch.Tensor in fp32 """ return FusedComputeScoresForMoEAuxLoss.apply(logits, topk, score_function) @@ -253,23 +274,24 @@ def fused_moe_aux_loss( num_experts: int, topk: int, coeff: float, -): +) -> torch.Tensor: """ Fused MoE aux loss. Parameters ---------- - probs : torch.Tensor - tokens_per_expert : torch.Tensor - the number of tokens per expert + probs : torch.Tensor in fp32/bf16/fp16 + tokens_per_expert : torch.Tensor in int32/int64/fp32/bf16 + the number of tokens per expert. total_num_tokens : int - the total number of tokens, involved in the aux loss calculation + the total number of tokens used in the aux loss calculation. num_experts : int topk : int coeff : float - the coefficient of the aux loss + the coefficient of the aux loss. Returns ------- - aux_loss : torch.scalar + aux_loss : torch.Tensor. + A scalar tensor in the same dtype as the "probs". """ return FusedAuxLoss.apply(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff) diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index cb199d24b..566805670 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -17,10 +17,12 @@ from .storage.mxfp8_tensor_storage import MXFP8TensorStorage from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from .storage.nvfp4_tensor_storage import NVFP4TensorStorage +from .storage.grouped_tensor_storage import GroupedTensorStorage from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer +from .grouped_tensor import GroupedTensor from .utils import cast_master_weights_to_fp8, replace_raw_data __all__ = [ @@ -35,11 +37,13 @@ "MXFP8TensorStorage", "Float8BlockwiseQTensorStorage", "NVFP4TensorStorage", + "GroupedTensorStorage", "QuantizedTensor", "Float8Tensor", "MXFP8Tensor", "Float8BlockwiseQTensor", "NVFP4Tensor", + "GroupedTensor", "prepare_for_saving", "restore_from_saved", ] @@ -89,5 +93,7 @@ def get_all_tensor_types(): Float8BlockwiseQTensorStorage, NVFP4Tensor, NVFP4TensorStorage, + GroupedTensor, + GroupedTensorStorage, ] return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 180cc6f25..ca79ce1ce 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -8,6 +8,7 @@ from __future__ import annotations from collections.abc import Iterable import math +import warnings from typing import Any, Optional, Tuple, Union import torch @@ -337,7 +338,7 @@ def __repr__(self, *, tensor_contents=None): return ( f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," f" is_2D_scaled={self._is_2D_scaled}," - f" data={self.dequantize(dtype=self.dtype)})" + f" data={self.dequantize()})" ) def quantize_( @@ -440,6 +441,30 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) return Float8BlockwiseQTensor.make_like(tensor) + # as_strided op — applied by FSDP2 on the unsharded param. + # When shape and strides match (no-op), return self to preserve the quantized type. + # If shape differs (e.g. padding needed), fall through to dequantize. + if func == aten.as_strided.default: + tensor = args[0] + shape = args[1] + strides = args[2] + if ( + len(shape) == len(strides) == 2 + and tuple(strides) == (shape[-1], 1) + and tuple(shape) == tuple(tensor.size()) + ): + return Float8BlockwiseQTensor.make_like(tensor) + + # slice op — applied by FSDP2 when shards need unpadding. + # When the slice is a no-op (covers entire dimension), return self. + if func == aten.slice.Tensor: + tensor = args[0] + dim = args[1] + start = args[2] + length = args[3] + if start == 0 and length == tensor.size(dim): + return Float8BlockwiseQTensor.make_like(tensor) + # record stream op if func == torch.ops.aten.record_stream.default: qt, stream = args @@ -579,6 +604,164 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): # Cast to FP8 when setting Float8BlockwiseQTensor.data data = property(_get_data, _set_data) + @property + def shape(self): + """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.shape + if self._columnwise_data is not None: + return self._columnwise_data.shape + raise RuntimeError("Float8BlockwiseQTensor has no data!") + + @property + def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" + if self._rowwise_data is not None: + return self._rowwise_data.is_cuda + if self._columnwise_data is not None: + return self._columnwise_data.is_cuda + raise RuntimeError("Float8BlockwiseQTensor has no data!") + + def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): + """Called by FSDP2 before all-gather of weights for forward and backward passes. + + Args: + mesh: DeviceMesh used by FSDP2 to shard the weights. + orig_size: Original size of the weight tensor. + contiguous_orig_stride: Original stride of the weight tensor. + module: FSDP-wrapped module containing this tensor. + mp_policy: Mixed precision policy used by FSDP2. + + Returns: + sharded_tensors: Tuple of tensors to be all-gathered. + metadata: Metadata needed for reconstructing the tensor after all-gather. + """ + # pylint: disable=unused-argument + from transformer_engine.pytorch.distributed import _get_module_fsdp_state + + if not self._is_2D_scaled: + raise NotImplementedError( + "FSDP2 is only supported for Float8BlockwiseQTensors with 2D block scaling " + "(block_scaling_dim=2). 1D block scaling is not supported because the scale " + "layout has M in dim1, which is incompatible with FSDP2 dim0 all-gather." + ) + + block_len = self._quantizer.block_len # 128 + + # Prepare rowwise tensors — for 2D scaling, M is in dim0 of both data and scale_inv, + # so they naturally align with FSDP2's dim0 all-gather. No unpadding needed. + rowwise_data = self._rowwise_data + rowwise_scale_inv = self._rowwise_scale_inv + + # Prepare columnwise tensors — columnwise data is transposed (K, M) and + # columnwise scale_inv is (ceil(K/128), round_up(ceil(M/128), 4)). + # M is in dim1 for both, so we must transpose to put M in dim0 for all-gather. + columnwise_data = self._columnwise_data + columnwise_scale_inv = self._columnwise_scale_inv + + if columnwise_data is not None: + # Transpose (K, shard_M) -> (shard_M, K) so M is in dim0 + columnwise_data = columnwise_data.t().contiguous() + + if columnwise_scale_inv is not None: + # Original shape: (ceil(K/128), round_up(ceil(shard_M/128), 4)) + # Strip padding from dim1 (the M-block dimension), transpose, then all-gather + shard_M = math.prod(self.shape[:-1]) + m_blocks = (shard_M + block_len - 1) // block_len # ceil(shard_M/128) + columnwise_scale_inv = columnwise_scale_inv[:, :m_blocks] # unpad dim1 + columnwise_scale_inv = columnwise_scale_inv.t().contiguous() # (m_blocks, k_blocks) + + # Always send both rowwise and columnwise data. + # Unlike MXFP8 (where both forms share the same shape), Float8Blockwise has + # differently-shaped rowwise (M, K) and columnwise (K, M) data. The GEMM kernel + # needs both forms available to perform forward and backward operations, so we + # cannot optimize by sending only one usage based on forward/backward pass. + rowwise_usage = True + sharded_tensors = (rowwise_data, rowwise_scale_inv) + columnwise_usage = self._quantizer.columnwise_usage + if columnwise_usage: + sharded_tensors += (columnwise_data, columnwise_scale_inv) + + metadata = (self._fp8_dtype, self._is_2D_scaled, rowwise_usage, columnwise_usage) + return sharded_tensors, metadata + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[Float8BlockwiseQTensor] = None, + ): + """Called by FSDP2 after all-gather of weights for forward and backward passes. + + Args: + all_gather_outputs: All-gathered tensors from fsdp_pre_all_gather. + metadata: Metadata from fsdp_pre_all_gather. + param_dtype: High-precision dtype of the tensor. + out: Existing tensor to update in-place (None on first iteration). + + Returns: + Tuple of (Float8BlockwiseQTensor, all_gather_outputs). + """ + fp8_dtype, is_2D_scaled, rowwise_usage, columnwise_usage = metadata + + # Extract rowwise tensors from all-gather outputs + rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] if rowwise_usage else (None, None) + + # Extract columnwise tensors — they were transposed in pre_all_gather, + # so we need to transpose them back. + columnwise_data, columnwise_scale_inv = ( + all_gather_outputs[-2:] if columnwise_usage else (None, None) + ) + + if columnwise_data is not None: + # All-gathered shape is (full_M, K), transpose back to (K, full_M) + columnwise_data = columnwise_data.t().contiguous() + + if columnwise_scale_inv is not None: + # All-gathered shape is (full_m_blocks, k_blocks), + # transpose back to (k_blocks, full_m_blocks) + columnwise_scale_inv = columnwise_scale_inv.t().contiguous() + # Repad dim1 (M-block dimension) to multiple of 4 for GEMM alignment + current_m_blocks = columnwise_scale_inv.shape[1] + pad_amount = (4 - current_m_blocks % 4) % 4 + if pad_amount > 0: + columnwise_scale_inv = torch.nn.functional.pad( + columnwise_scale_inv, (0, pad_amount) + ) + + # Determine the logical shape from the all-gathered data + if rowwise_data is not None: + data_shape = rowwise_data.shape + else: + # columnwise_data is (K, full_M), logical shape is (full_M, K) + data_shape = (columnwise_data.shape[1], columnwise_data.shape[0]) + + if out is not None: + # Update existing tensor in-place (subsequent iterations) + out._rowwise_data = rowwise_data + out._rowwise_scale_inv = rowwise_scale_inv + out._columnwise_data = columnwise_data + out._columnwise_scale_inv = columnwise_scale_inv + else: + # Construct new tensor (first iteration). + # Float8BlockwiseQTensor constructor copies the quantizer, + # so the sharded tensor's quantizer remains independent. + out = Float8BlockwiseQTensor( + shape=data_shape, + dtype=param_dtype, + fp8_dtype=fp8_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=self._quantizer, + is_2D_scaled=is_2D_scaled, + ) + out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) + return out, all_gather_outputs + class _ViewFunc(torch.autograd.Function): """View function @@ -616,19 +799,27 @@ def forward( if tensor._is_2D_scaled: # For the case of 2D scaled tensor, the last 2 dimensions should not change if shape[-1] != ctx.shape[-1] or shape[-2] != ctx.shape[-2]: - raise RuntimeError( + warnings.warn( "2D scaled Float8BlockwiseQTensor does not support view " "the last 2 dimensions " - f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)}). " + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().view(*shape) else: # For the case of 1D scaled tensor, the last dimension should not change if shape[-1] != ctx.shape[-1]: - raise RuntimeError( + warnings.warn( "1D scaled Float8BlockwiseQTensor does not support view " "the last dimension " - f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)}). " + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().view(*shape) if list(shape) == list(tensor.shape): return tensor @@ -723,19 +914,27 @@ def forward( if tensor._is_2D_scaled: # For the case of 2D scaled tensor, the last 2 dimensions should not change if shape[-1] != ctx.shape[-1] or shape[-2] != ctx.shape[-2]: - raise RuntimeError( + warnings.warn( "2D scaled Float8BlockwiseQTensor does not support reshaping " "the last 2 dimensions " - f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)}). " + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().reshape(*shape) else: # For the case of 1D scaled tensor, the last dimension should not change if shape[-1] != ctx.shape[-1]: - raise RuntimeError( + warnings.warn( "1D scaled Float8BlockwiseQTensor does not support reshaping " "the last dimension " - f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)}). " + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().reshape(*shape) if list(shape) == list(tensor.shape): return tensor diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 8ea81d912..c2dc4d7a3 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -14,7 +14,11 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + Recipe, +) from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func from ..quantized_tensor import QuantizedTensor, Quantizer @@ -167,12 +171,17 @@ def make_empty( requires_grad=requires_grad, data_transpose=data_transpose, quantizer=self, + device=device, ) def calibrate(self, tensor: torch.Tensor) -> None: amin, amax = tensor.aminmax() self.amax.copy_(torch.max(-amin, amax)) + def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for Float8 1D blockwise quantization.""" + return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1]) + def create_tensor_from_data( self, data: torch.Tensor, @@ -193,6 +202,7 @@ def create_tensor_from_data( data=data, fp8_scale_inv=1 / self.scale, fp8_dtype=self.dtype, + fake_dtype=fake_dtype, requires_grad=requires_grad, data_transpose=None, quantizer=self, @@ -295,6 +305,12 @@ def __init__( self.force_pow_2_scales = force_pow_2_scales self.amax_epsilon = amax_epsilon + def __getstate__(self): + """Exclude unpicklable process group from serialized state.""" + state = self.__dict__.copy() + state["amax_reduction_group"] = None + return state + def copy(self) -> Float8CurrentScalingQuantizer: """Create shallow copy""" @@ -394,6 +410,7 @@ def make_empty( requires_grad=requires_grad, data_transpose=data_transpose, quantizer=self, + device=device, ) def calibrate(self, tensor: torch.Tensor) -> None: @@ -423,6 +440,7 @@ def create_tensor_from_data( data=data, fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), fp8_dtype=self.dtype, + fake_dtype=fake_dtype, requires_grad=requires_grad, data_transpose=None, quantizer=self, @@ -438,6 +456,10 @@ def create_tensor_from_data( quantizer=self, ) + def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for Float8 1D blockwise quantization.""" + return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1]) + def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: """Function using primitives with ONNX defined translations.""" if tensor.dtype != torch.float32: @@ -510,7 +532,7 @@ def __repr__(self, *, tensor_contents=None): "Float8Tensor(" f"fp8_dtype={self._fp8_dtype}, " f"scale_inv={self._scale_inv.item()}, " - f"data={self.dequantize(dtype=self.dtype)}" + f"data={self.dequantize()}" ")" ) @@ -799,7 +821,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): kwargs, ) return Float8Tensor.make_like( - tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape + tensor, + data=func_out, + data_transpose=func_transposed_out, + shape=func_out.shape, ) if func == torch.ops.aten.detach.default: @@ -941,6 +966,34 @@ def fsdp_post_all_gather( ) return out, all_gather_outputs + @property + def shape(self): + """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._data is not None: + return self._data.shape + if self._transpose is not None: + transpose_shape = self._transpose.shape + return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],)) + raise RuntimeError("Both data and transpose are None") + + @property + def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" + if self._data is not None: + return self._data.is_cuda + if self._transpose is not None: + return self._transpose.is_cuda + raise RuntimeError("Both data and transpose are None") + + @property + def is_cpu(self): + """Return whether the tensor is on CPU.""" + if self._data is not None: + return self._data.is_cpu + if self._transpose is not None: + return self._transpose.is_cpu + raise RuntimeError("Both data and transpose are None") + @classmethod def _make_in_reduce_ex( cls, @@ -965,7 +1018,16 @@ def _make_in_reduce_ex( ) def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects""" + """Custom pickling to remove references to FP8 metadata objects + + CPU Float8Tensors are serialized as dequantized plain tensors + for compatibility with torch.load(weights_only=True), which is + used by DCP async save staging. + """ + data_is_cpu = self._data is not None and self._data.is_cpu + transpose_is_cpu = self._transpose is not None and self._transpose.is_cpu + if data_is_cpu or transpose_is_cpu: + return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol) return ( Float8Tensor._make_in_reduce_ex, (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py new file mode 100644 index 000000000..2fce9a38e --- /dev/null +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -0,0 +1,351 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Grouped tensor class for handling collections of tensors with different shapes""" +from __future__ import annotations + +from typing import List, Optional, Tuple + +import torch +from torch.utils._pytree import tree_map + +from ..quantized_tensor import QuantizedTensorStorage, Quantizer +from .storage.grouped_tensor_storage import GroupedTensorStorage + + +def _stride_from_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Calculate contiguous stride from shape.""" + if len(shape) == 0: + return () + stride = [1] * len(shape) + for i in range(len(shape) - 2, -1, -1): + stride[i] = stride[i + 1] * shape[i + 1] + return tuple(stride) + + +class _GroupedIdentityFunc(torch.autograd.Function): + """Identity autograd function used to create a dummy grad_fn node.""" + + @staticmethod + def forward(ctx, tensor: "GroupedTensor") -> "GroupedTensor": + # pylint: disable=missing-function-docstring + ctx.input_dtype = tensor.dtype + return tensor.detach() + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + # pylint: disable=missing-function-docstring + grad_input = grad_output + if grad_input.dtype != ctx.input_dtype: + grad_input = grad_input.to(ctx.input_dtype) + return grad_input + + +# For now, conservatively ban 'most' shape manipulating ops. +BANNED_SHAPE_OPS = { + torch.ops.aten.reshape.default, + torch.ops.aten._reshape_alias.default, + torch.ops.aten.flatten.using_ints, + torch.ops.aten.unflatten.int, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + torch.ops.aten.unsqueeze.default, + torch.ops.aten.transpose.int, + torch.ops.aten.permute.default, + torch.ops.aten.movedim.int, + torch.ops.aten.t.default, + torch.ops.aten.slice.Tensor, + torch.ops.aten.narrow.default, + torch.ops.aten.select.int, + torch.ops.aten.split.Tensor, + torch.ops.aten.chunk.default, + torch.ops.aten.cat.default, + torch.ops.aten.stack.default, +} + + +class GroupedTensor(GroupedTensorStorage, torch.Tensor): + """Tensor wrapper class for grouped tensor storage.""" + + def __new__( + cls, + shape: Tuple[int, int], + dtype: torch.dtype, + *, + num_tensors: int, + shapes: Optional[List[Tuple[int, int]]] = None, + quantizer: Optional[Quantizer] = None, + data: Optional[torch.Tensor] = None, + columnwise_data: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + columnwise_scale_inv: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + columnwise_amax: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + first_dims: Optional[torch.Tensor] = None, + last_dims: Optional[torch.Tensor] = None, + tensor_offsets: Optional[torch.Tensor] = None, + offsets: Optional[List[int]] = None, + scale_inv_offsets: Optional[List[int]] = None, + columnwise_scale_inv_offsets: Optional[List[int]] = None, + requires_grad: bool = False, + stride: Optional[List[int]] = None, + with_gemm_swizzled_scales: bool = False, + ): + if ( + shapes is not None + and len(shapes) == num_tensors + and num_tensors > 0 + and all(shapes[0] == s for s in shapes) + ): + wrapper_shape = (num_tensors, shapes[0][0], shapes[0][1]) + else: + wrapper_shape = shape + + device = None + for maybe_tensor in ( + data, + columnwise_data, + scale_inv, + columnwise_scale_inv, + amax, + columnwise_amax, + scale, + first_dims, + last_dims, + tensor_offsets, + ): + if maybe_tensor is not None: + device = maybe_tensor.device + break + if device is None: + device = torch.device("cuda") + + # Match QuantizedTensor __new__: accept externally-computed stride to + # avoid Python-side stride computation overhead for C++ construction. + strides = _stride_from_shape(tuple(wrapper_shape)) if stride is None else tuple(stride) + instance = torch.Tensor._make_wrapper_subclass( + cls, + wrapper_shape, + strides=strides, + storage_offset=0, + dtype=dtype, + layout=torch.strided, + requires_grad=requires_grad, + device=device, + ) + GroupedTensorStorage._initialize_storage_fields( + instance=instance, + shape=shape, + dtype=dtype, + num_tensors=num_tensors, + shapes=shapes, + quantizer=quantizer, + data=data, + columnwise_data=columnwise_data, + scale_inv=scale_inv, + columnwise_scale_inv=columnwise_scale_inv, + amax=amax, + columnwise_amax=columnwise_amax, + scale=scale, + first_dims=first_dims, + last_dims=last_dims, + tensor_offsets=tensor_offsets, + offsets=offsets, + scale_inv_offsets=scale_inv_offsets, + columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) + return instance + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + """Dispatch by dequantizing grouped members, then requantizing writes.""" + if kwargs is None: + kwargs = {} + + def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> None: + """Shallow-copy grouped-storage metadata onto wrapper outputs.""" + dst.num_tensors = src.num_tensors + dst.quantizer = src.quantizer + dst.tensor_shapes = src.tensor_shapes + dst.fake_dtype = src.fake_dtype + dst.rowwise_data = src.rowwise_data + dst.columnwise_data = src.columnwise_data + dst.scale_inv = src.scale_inv + dst.columnwise_scale_inv = src.columnwise_scale_inv + dst.amax = src.amax + dst.columnwise_amax = src.columnwise_amax + dst.scale = src.scale + dst.first_dims = src.first_dims + dst.last_dims = src.last_dims + dst.tensor_offsets = src.tensor_offsets + dst.offsets = src.offsets + dst.scale_inv_offsets = src.scale_inv_offsets + dst.columnwise_scale_inv_offsets = src.columnwise_scale_inv_offsets + dst.logical_shape = src.logical_shape + dst.quantized_tensors = src.quantized_tensors + + def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: + """Create a wrapper of the same type and tensor metadata as src.""" + out = torch.Tensor._make_wrapper_subclass( + type(src), + tuple(src.shape), + strides=tuple(src.stride()), + storage_offset=src.storage_offset(), + dtype=src.dtype, + layout=src.layout, + requires_grad=requires_grad, + device=src.device, + ) + copy_grouped_storage_metadata(out, src) + return out + + # Parameter construction calls detach()/alias-like paths. + if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default): + src = args[0] + if not isinstance(src, GroupedTensor): + raise TypeError(f"Expected GroupedTensor, got {type(src).__name__}") + if func == torch.ops.aten.detach.default: + return make_wrapper_like(src, requires_grad=False) + return make_wrapper_like(src, requires_grad=src.requires_grad) + + # Parameter construction may invoke aten.expand on tensor subclasses. + # Handle this explicitly so grouped parameters can be created safely. + if func == torch.ops.aten.expand.default: + src = args[0] + if not isinstance(src, GroupedTensor): + raise TypeError(f"Expected GroupedTensor, got {type(src).__name__}") + expanded_shape = tuple(args[1]) + src_shape = tuple(src.shape) + if len(expanded_shape) == len(src_shape): + normalized_shape = tuple( + src_shape[i] if dim == -1 else dim for i, dim in enumerate(expanded_shape) + ) + if normalized_shape == src_shape: + return make_wrapper_like(src, requires_grad=src.requires_grad) + return super().__torch_dispatch__(func, types, args, kwargs) + + # DDP and mcore use expand_as(self) to build a dummy autograd node and + # access gradient accumulators during parameter hook registration. + if func == torch.ops.aten.expand_as.default: + src = args[0] + other = args[1] + if not isinstance(src, GroupedTensor): + raise TypeError(f"Expected GroupedTensor, got {type(src).__name__}") + if other is src: + return _GroupedIdentityFunc.apply(src) + if tuple(other.shape) == tuple(src.shape): + return make_wrapper_like(src, requires_grad=src.requires_grad) + return super().__torch_dispatch__(func, types, args, kwargs) + + # Distributed optimizer flattens detached parameters via + # model_param.detach().view(-1). Support this path explicitly by + # returning a flat view of grouped backing storage. + if func in (torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default): + src = args[0] + if not isinstance(src, GroupedTensor): + raise TypeError(f"Expected GroupedTensor, got {type(src).__name__}") + target_shape = tuple(args[1]) + if target_shape in ((-1,), (src.numel(),)): + if src.rowwise_data is not None: + return src.rowwise_data.view(-1) + raise RuntimeError( + f"{cls.__name__} view(-1) requires rowwise_data to be initialized" + ) + raise RuntimeError( + f"{cls.__name__} only supports view(-1) for distributed optimizer flattening" + ) + + # Don't allow reshape/view etc. + if func in BANNED_SHAPE_OPS: + raise RuntimeError(f"{cls.__name__} forbids shape-manipulation op: {func} ") + + def grouped_to_stacked_tensor(grouped: GroupedTensor) -> torch.Tensor: + if not grouped.all_same_shape(): + raise NotImplementedError( + "GroupedTensor __torch_dispatch__ currently supports only uniform member shapes" + ) + grouped_members = grouped.quantized_tensors + if grouped_members is None: + grouped_members = grouped.split_into_quantized_tensors() + dequantized_members = [ + ( + member.dequantize(dtype=grouped.get_dtype()) + if isinstance(member, QuantizedTensorStorage) + else member + ) + for member in grouped_members + ] + return torch.stack(dequantized_members, dim=0) + + def maybe_unwrap(arg): + if isinstance(arg, GroupedTensor): + return grouped_to_stacked_tensor(arg) + return arg + + def update_grouped_tensor_inplace(grouped: GroupedTensor, updated: torch.Tensor): + if not grouped.all_same_shape(): + raise NotImplementedError( + "GroupedTensor __torch_dispatch__ currently supports only uniform member shapes" + ) + updated_members = list(updated.unbind(dim=0)) + if grouped.quantizer is None: + grouped_members = grouped.quantized_tensors + if grouped_members is None: + grouped_members = grouped.split_into_quantized_tensors() + for dst, src in zip(grouped_members, updated_members): + dst.copy_(src) + else: + grouped.quantize(updated_members) + + def maybe_update_inplace(arg, new_arg, schema_arg): + if ( + isinstance(arg, GroupedTensor) + and isinstance(new_arg, torch.Tensor) + and hasattr(schema_arg, "alias_info") + and hasattr(schema_arg.alias_info, "is_write") + and schema_arg.alias_info.is_write + ): + update_grouped_tensor_inplace(arg, new_arg) + elif isinstance(arg, list) and isinstance(new_arg, list): + for a, na in zip(arg, new_arg): + maybe_update_inplace(a, na, schema_arg) + + # In-place op: dequantize members, perform op, write back into grouped storage. + if func._schema.is_mutable: + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + schema_args = func._schema.arguments + args_len = len(args) + super().__torch_dispatch__(func, types, new_args, new_kwargs) + for arg, new_arg, schema_arg in zip(args, new_args, schema_args): + maybe_update_inplace(arg, new_arg, schema_arg) + for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): + if kwarg != new_kwarg or kwarg != schema_arg.name: + raise RuntimeError( + f"Name of kwarg should match schema, got kwarg={kwarg!r}," + f" new_kwarg={new_kwarg!r}, schema_arg.name={schema_arg.name!r}" + ) + maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) + return None + + # Default op: operate on dequantized stacked tensors. + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + return super().__torch_dispatch__(func, types, new_args, new_kwargs) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + # Do not force GroupedTensor on outputs. + return torch._C._disabled_torch_function_impl(func, types, args, kwargs) + + def expand_as(self, other: torch.Tensor) -> torch.Tensor: + # pylint: disable=missing-function-docstring + # Needed during parameter creation/hook registration paths. + if other is self: + return _GroupedIdentityFunc.apply(self) + return super().expand_as(other) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index bd3d93e9f..63a460276 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -203,6 +203,49 @@ def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass + def get_scale_shape( + self, + shape: Iterable[int], + columnwise: bool, + ) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. + + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For MXFP8 1D blockwise quantization, blocksize is 32 + Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. + CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + if columnwise: + # Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]] + # with padding to multiples of [4, 128] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + ) + # Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE] + # with padding to multiples of [128, 4] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + ) + + def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for MXFP8 1D blockwise quantization.""" + return rowwise_data_shape + def create_tensor_from_data( self, data: torch.Tensor, @@ -292,7 +335,7 @@ def __new__( ) def __repr__(self, *, tensor_contents=None): - return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" + return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize()})" def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ @@ -743,7 +786,7 @@ def fsdp_post_all_gather( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=fp8_dtype, dtype=param_dtype, - shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, + shape=(rowwise_data.shape if rowwise_data is not None else columnwise_data.shape), quantizer=self._quantizer, with_gemm_swizzled_scales=False, ) @@ -838,6 +881,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data self._quantizer = tensor._quantizer.copy() @@ -857,6 +901,33 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Cast to FP8 when setting MXFP8Tensor.data data = property(_get_data, _set_data) + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("MXFP8Tensor has no data!") + + @property + def shape(self): + """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.shape + if self._columnwise_data is not None: + return self._columnwise_data.shape + raise RuntimeError("MXFP8Tensor has no data!") + + @property + def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" + if self._rowwise_data is not None: + return self._rowwise_data.is_cuda + if self._columnwise_data is not None: + return self._columnwise_data.is_cuda + raise RuntimeError("MXFP8Tensor has no data!") + class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 101cf78a8..8ed1b4682 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -6,6 +6,7 @@ from __future__ import annotations from collections.abc import Iterable import math +import warnings from typing import Dict, Optional, Tuple, Union import functools @@ -157,6 +158,12 @@ def __init__( ) self.rht_matrix = get_rht_matrix(with_random_sign_mask, torch.cuda.current_device()) + def __getstate__(self): + """Exclude unpicklable process group from serialized state.""" + state = self.__dict__.copy() + state["amax_reduction_group"] = None + return state + def update_quantized( self, src: torch.Tensor, @@ -341,7 +348,10 @@ def make_empty( ) columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) columnwise_scale_inv = torch.empty( - columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + columnwise_scale_shape, + dtype=torch.uint8, + device=device, + pin_memory=pin_memory, ) amax_columnwise = torch.zeros( 1, dtype=torch.float32, device=device, pin_memory=pin_memory @@ -440,7 +450,7 @@ def __new__( return instance def __repr__(self, *, tensor_contents=None): - return f"NVFP4Tensor, data={self.dequantize(dtype=self.dtype)})" + return f"NVFP4Tensor, data={self.dequantize()})" def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ @@ -697,6 +707,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data self._quantizer = tensor._quantizer @@ -716,6 +727,35 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Cast to FP8 when setting NVFP4Tensor.data data = property(_get_data, _set_data) + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("NVFP4Tensor has no data!") + + @property + def shape(self): + """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + byte_shape = self._rowwise_data.shape + return torch.Size(byte_shape[:-1] + (byte_shape[-1] * 2,)) + if self._columnwise_data is not None: + byte_shape = self._columnwise_data.shape + return torch.Size(byte_shape[1:-1] + (byte_shape[-1] * 2, byte_shape[0])) + raise RuntimeError("NVFP4Tensor has no data!") + + @property + def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" + if self._rowwise_data is not None: + return self._rowwise_data.is_cuda + if self._columnwise_data is not None: + return self._columnwise_data.is_cuda + raise RuntimeError("NVFP4Tensor has no data!") + class _ViewFunc(torch.autograd.Function): """View function @@ -752,10 +792,14 @@ def forward( shape[i] = d_inferred break if shape[-1] != cur_shape[-1]: - raise RuntimeError( + warnings.warn( "NVFP4Tensor does not support reshaping inner dimension " - f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)}). " + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().view(*shape) # Reshape data new_rowwise_data = None @@ -874,10 +918,14 @@ def forward( shape[i] = d_inferred break if shape[-1] != cur_shape[-1]: - raise RuntimeError( + warnings.warn( "NVFP4Tensor does not support reshaping inner dimension " - f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)}). " + "If you are using this for FSDP2 without compiled_autograd_enabled, " + "then ignore this warning since this view is not going to be used anywhere.", + stacklevel=2, ) + return tensor.dequantize().reshape(*shape) # Reshape data new_rowwise_data = None diff --git a/transformer_engine/pytorch/tensor/storage/__init__.py b/transformer_engine/pytorch/tensor/storage/__init__.py index d7a271920..44a77d975 100644 --- a/transformer_engine/pytorch/tensor/storage/__init__.py +++ b/transformer_engine/pytorch/tensor/storage/__init__.py @@ -7,3 +7,4 @@ from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401 from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401 from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401 +from .grouped_tensor_storage import GroupedTensorStorage # noqa: F401 diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 278d7dc03..52e292125 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -46,12 +46,14 @@ def __new__( quantizer: Quantizer, is_2D_scaled: bool, *args, + fake_dtype: Optional[torch.dtype] = None, **kwargs, ): if cls is Float8BlockwiseQTensorStorage: instance = object.__new__(cls) + instance._dtype = fake_dtype if fake_dtype is not None else torch.float32 else: - instance = super().__new__(cls, *args, **kwargs) + instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data instance._quantizer = quantizer.copy() if quantizer is not None else None @@ -73,6 +75,24 @@ def clear(self): if t is not None: t.data = _empty_tensor() + def copy_from_storage(self, src: QuantizedTensorStorage) -> None: + """Copy data buffers from another Float8BlockwiseQTensorStorage.""" + if not isinstance(src, Float8BlockwiseQTensorStorage): + raise TypeError("copy_from_storage expects Float8BlockwiseQTensorStorage") + if self._fp8_dtype != src._fp8_dtype: + raise RuntimeError("FP8 dtype mismatch in copy_from_storage") + if self._is_2D_scaled != src._is_2D_scaled: + raise RuntimeError("Scale layout mismatch in copy_from_storage") + + def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): + if dst is not None and src_tensor is not None: + dst.copy_(src_tensor) + + _copy_optional(self._rowwise_data, src._rowwise_data) + _copy_optional(self._columnwise_data, src._columnwise_data) + _copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv) + _copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv) + def get_metadata(self) -> Dict[str, Any]: """Get this tensor's metadata.""" return { @@ -83,6 +103,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "quantizer": self._quantizer, "is_2D_scaled": self._is_2D_scaled, + "fake_dtype": self._dtype, } def prepare_for_saving( @@ -131,7 +152,9 @@ def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch. permute_dims.append(0) return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous() - def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + def _dequantize_vectorwise(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if dtype is None: + dtype = self._dtype block_len = 128 q_M, q_K = 1, 1 @@ -193,10 +216,12 @@ def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch return self._transpose_dq_columnwise_output(result) return result - def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ Construct plain PyTorch tensor from Float8BlockwiseQTensor """ + if dtype is None: + dtype = self._dtype block_len = 128 if not self._is_2D_scaled: return self._dequantize_vectorwise(dtype=dtype) @@ -272,6 +297,15 @@ def size(self, *args, **kwargs): reordered.append(dims[0]) return torch.Size(reordered) + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("Float8BlockwiseQTensorStorage has no data!") + def _create_columnwise(self): """ Update columnwise data and columnwise scale inv. Can only be used when using 2D scaling. diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index adf3ce8ae..de7f8f58e 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -14,7 +14,7 @@ from ...quantized_tensor import QuantizedTensorStorage, Quantizer -from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ...constants import TE_DType as torch_to_transformer_engine_dtype, TE_DType_To_Torch from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor @@ -35,6 +35,13 @@ def forward( if tensor._data is not None: if tensor._data.numel() == 0: return torch.empty_like(tensor._data, dtype=dtype) + if tensor._data.is_cpu: + # CPU fallback: reinterpret uint8 as FP8, cast to target dtype, scale + fp8_torch_dtype = TE_DType_To_Torch[tensor._fp8_dtype] + return ( + tensor._data.view(fp8_torch_dtype).float() + * tensor._scale_inv.to(tensor._data.device) + ).to(dtype) # Cast from FP8 return tex.dequantize(tensor, te_dtype) @@ -75,14 +82,16 @@ def __new__( data: Optional[torch.Tensor], fp8_scale_inv: torch.Tensor, fp8_dtype: TE_DType, + fake_dtype: Optional[torch.dtype] = None, data_transpose: Optional[torch.Tensor] = None, quantizer: Optional[Quantizer] = None, **kwargs, ): if cls is Float8TensorStorage: instance = object.__new__(cls) + instance._dtype = fake_dtype if fake_dtype is not None else torch.float32 else: - instance = super().__new__(cls, *args, **kwargs) + instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs) instance._data = data instance._quantizer = quantizer.copy() if quantizer is not None else None instance._fp8_dtype = fp8_dtype @@ -104,6 +113,24 @@ def clear(self): t.data = _empty_tensor() self._transpose_invalid = True + def copy_from_storage(self, src: QuantizedTensorStorage) -> None: + """Copy data buffers from another Float8TensorStorage.""" + if not isinstance(src, Float8TensorStorage): + raise TypeError("copy_from_storage expects Float8TensorStorage") + if self._fp8_dtype != src._fp8_dtype: + raise RuntimeError("FP8 dtype mismatch in copy_from_storage") + + def _copy_optional( + dst: Optional[torch.Tensor], + src_tensor: Optional[torch.Tensor], + ): + if dst is not None and src_tensor is not None: + dst.copy_(src_tensor) + + _copy_optional(self._data, src._data) + _copy_optional(self._transpose, src._transpose) + _copy_optional(self._scale_inv, src._scale_inv) + def get_metadata(self) -> Dict[str, Any]: """Get this tensor's metadata.""" return { @@ -112,6 +139,12 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "data_transpose": self._transpose, "quantizer": self._quantizer, + "device": ( + self._data.device + if self._data is not None + else (self._transpose.device if self._transpose is not None else None) + ), + "fake_dtype": self._dtype, } def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: @@ -141,8 +174,10 @@ def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = Tr return self._transpose raise ValueError("No data to get, both rowwise_data and columnwise_data are False") - def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" + if dtype is None: + dtype = self._dtype return _FromFloat8Func.forward(None, self, dtype) def size(self, *args, **kwargs): @@ -152,6 +187,15 @@ def size(self, *args, **kwargs): size = self._transpose.size(*args, **kwargs) return torch.Size([size[-1], math.prod(size[:-1])]) + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._data is not None: + return self._data.device + if self._transpose is not None: + return self._transpose.device + raise RuntimeError("Float8TensorStorage has no data!") + def view(self, shape: torch.Size): # pylint: disable=missing-function-docstring out_data = self._data.view(shape) @@ -165,6 +209,7 @@ def view(self, shape: torch.Size): data=out_data, fp8_scale_inv=self._scale_inv, fp8_dtype=self._fp8_dtype, + fake_dtype=self._dtype, data_transpose=out_transpose, quantizer=self._quantizer, ) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py new file mode 100644 index 000000000..68097259c --- /dev/null +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -0,0 +1,994 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Grouped tensor storage class for handling collections of tensors with different shapes""" +from __future__ import annotations +from typing import Optional, Tuple, List, Union +import math + +import torch +from ...quantized_tensor import QuantizedTensorStorage, Quantizer + +from ..mxfp8_tensor import MXFP8Tensor +from ..nvfp4_tensor import NVFP4Tensor +from ..float8_tensor import Float8Tensor +from ..float8_blockwise_tensor import Float8BlockwiseQTensor +from .float8_tensor_storage import Float8TensorStorage +from .mxfp8_tensor_storage import MXFP8TensorStorage +from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from .nvfp4_tensor_storage import NVFP4TensorStorage + + +class GroupedTensorStorage: + """ + EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. + + Grouped tensor is a collection of tensors with different shapes but the same dtype and scaling mode. + + Shape Representation: + - logical_shape: 2D shape representing the conceptual layout, i.e. the shape when member tensors + are flattened to 2D and stacked together (REQUIRED) + + When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N) + + When varying_first_dim(): [~sum_of_first_dims, N] where N is common + + When varying_last_dim(): [M, ~sum_of_last_dims] where M is common + + When varying_both_dims(): [1, total_elements] (fully flattened) + + - first_dims and last_dims are OPTIONAL (None if dimension is uniform) + + None first_dims: all tensors have the same first dimension + + None last_dims: all tensors have the same last dimension + + Both None: all tensors have identical shapes + + Both set: each tensor has unique shape (first_dims[i], last_dims[i]) + + Data Layout: + - ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.) + - logical_shape provides the conceptual 2D interpretation + - All data is stored on device in contiguous layout + + Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode. + """ + + @staticmethod + def _initialize_storage_fields( + instance: "GroupedTensorStorage", + shape: Tuple[int, int], + dtype: torch.dtype, + num_tensors: int, + shapes: Optional[List[Tuple[int, int]]] = None, + quantizer: Optional[Quantizer] = None, + data: Optional[torch.Tensor] = None, + columnwise_data: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + columnwise_scale_inv: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + columnwise_amax: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + first_dims: Optional[torch.Tensor] = None, + last_dims: Optional[torch.Tensor] = None, + tensor_offsets: Optional[torch.Tensor] = None, + offsets: Optional[List[int]] = None, + scale_inv_offsets: Optional[List[int]] = None, + columnwise_scale_inv_offsets: Optional[List[int]] = None, + requires_grad: bool = False, + stride: Optional[List[int]] = None, + with_gemm_swizzled_scales: bool = False, + ) -> None: + """ + Initialize a GroupedTensor. + + Args: + shape: 2D tuple representing conceptual shape + dtype: Data type of the grouped tensor + num_tensors: Number of tensors in the group + shapes: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer used for all tensors in the group + data: Row-wise data buffer (1D flattened) + columnwise_data: Column-wise data buffer (1D flattened) + scale_inv: Row-wise scale inverse buffer + columnwise_scale_inv: Column-wise scale inverse buffer + amax: Row-wise amax buffer + columnwise_amax: Column-wise amax buffer + scale: Scale buffer (for FP8-DS only) + first_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform) + offsets: Vector of integer offsets for each tensor. + """ + # `requires_grad` and `stride` are accepted for API symmetry with + # GroupedTensor.__new__ but are not relevant for storage-only + # initialization; they are intentionally ignored here. + del requires_grad + del stride + + instance.num_tensors = num_tensors + instance.quantizer = quantizer + instance.tensor_shapes = shapes + instance.fake_dtype = dtype + + # Data buffers + instance.rowwise_data = data + instance.columnwise_data = columnwise_data + instance.scale_inv = scale_inv + instance.columnwise_scale_inv = columnwise_scale_inv + instance.amax = amax + instance.columnwise_amax = columnwise_amax + instance.scale = scale + + # For convenient indexing for python GroupedTensor API. + instance.scale_inv_offsets = scale_inv_offsets + instance.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets + + # Shape information (OPTIONAL - None if dimension is uniform across all tensors) + # first_dims[i] = first dimension of tensor i (None if all tensors have same first dim) + # last_dims[i] = last dimension of tensor i (None if all tensors have same last dim) + instance.first_dims = ( + first_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + instance.last_dims = ( + last_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + + # Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape()) + # tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1) + # Usage: tensor_i_ptr = data.data_ptr() + tensor_offsets[i] * element_size + # If None and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions) + instance.tensor_offsets = ( + tensor_offsets # Device pointer to int64_t array of length num_tensors (or None) + ) + instance.offsets = offsets # Vector of integer offsets for each tensor. + + # Logical shape: conceptual 2D shape of the grouped data (REQUIRED) + # Represents how the 1D flattened data should be interpreted as 2D + # Always 2D with positive dimensions + instance.logical_shape = shape + + # Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor. + # Used as a convenience. + instance.quantized_tensors = None + instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales + + def __new__( + cls, + shape: Tuple[int, int], + dtype: torch.dtype, + *, + num_tensors: int, + shapes: Optional[List[Tuple[int, int]]] = None, + quantizer: Optional[Quantizer] = None, + data: Optional[torch.Tensor] = None, + columnwise_data: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + columnwise_scale_inv: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + columnwise_amax: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + first_dims: Optional[torch.Tensor] = None, + last_dims: Optional[torch.Tensor] = None, + tensor_offsets: Optional[torch.Tensor] = None, + offsets: Optional[List[int]] = None, + scale_inv_offsets: Optional[List[int]] = None, + columnwise_scale_inv_offsets: Optional[List[int]] = None, + requires_grad: bool = False, + stride: Optional[List[int]] = None, + with_gemm_swizzled_scales: bool = False, + ): + instance = object.__new__(cls) + cls._initialize_storage_fields( + instance=instance, + shape=shape, + dtype=dtype, + num_tensors=num_tensors, + shapes=shapes, + quantizer=quantizer, + data=data, + columnwise_data=columnwise_data, + scale_inv=scale_inv, + columnwise_scale_inv=columnwise_scale_inv, + amax=amax, + columnwise_amax=columnwise_amax, + scale=scale, + first_dims=first_dims, + last_dims=last_dims, + tensor_offsets=tensor_offsets, + offsets=offsets, + scale_inv_offsets=scale_inv_offsets, + columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, + requires_grad=requires_grad, + stride=stride, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) + return instance + + def has_data(self) -> bool: + """ + Check if the tensor has row-wise data. + + Returns: + True if data buffer is initialized, False otherwise + """ + return self.rowwise_data is not None + + def has_columnwise_data(self) -> bool: + """ + Check if the tensor has column-wise data. + + Returns: + True if columnwise_data buffer is initialized, False otherwise + """ + return self.columnwise_data is not None + + def all_same_first_dim(self) -> bool: + """ + Check if all tensors in the group have the same first dimension. + + Returns: + True if first dimension is uniform across all tensors + """ + return self.first_dims is None + + def all_same_last_dim(self) -> bool: + """ + Check if all tensors in the group have the same last dimension. + + Returns: + True if last dimension is uniform across all tensors + """ + return self.last_dims is None + + def all_same_shape(self) -> bool: + """ + Check if all tensors in the group have identical shapes. + + Returns: + True if all tensors have the same shape + """ + return self.first_dims is None and self.last_dims is None + + def varying_both_dims(self) -> bool: + """ + Check if both dimensions vary across tensors. + + Returns: + True if both first and last dimensions vary + """ + return self.first_dims is not None and self.last_dims is not None + + def get_common_first_dim(self) -> int: + """ + Get the common first dimension when all tensors share it. + + Returns: + The common first dimension + + Raises: + RuntimeError: If first dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_first_dim(): + raise RuntimeError("First dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + if self.all_same_shape(): + # When both dims are uniform: logical_shape = [num_tensors * M, N] + return self.logical_shape[0] // self.num_tensors + # When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims] + return self.logical_shape[0] + + def get_common_last_dim(self) -> int: + """ + Get the common last dimension when all tensors share it. + + Returns: + The common last dimension + + Raises: + RuntimeError: If last dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_last_dim(): + raise RuntimeError("Last dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + # For both uniform and varying first dim cases: logical_shape[1] is the common last dim + return self.logical_shape[1] + + def get_dtype(self) -> torch.dtype: + """ + Get the high precision data type of the tensor. + + Returns: + The high precision dtype of the data buffer + """ + + return self.fake_dtype + + def clear(self) -> None: + """ + Reset tensor data and clear all buffers. + """ + self.rowwise_data = None + self.columnwise_data = None + self.scale_inv = None + self.columnwise_scale_inv = None + self.amax = None + self.columnwise_amax = None + self.scale = None + self.first_dims = None + self.last_dims = None + self.tensor_offsets = None + self.logical_shape = (0, 0) + self.num_tensors = 0 + self.quantizer = None + self.quantized_tensors = None + self.offsets = None + self.scale_inv_offsets = None + self.columnwise_scale_inv_offsets = None + self.tensor_shapes = [] + self.fake_dtype = torch.float32 + + def __repr__(self) -> str: + """String representation of the GroupedTensorStorage.""" + return ( + f"GroupedTensorStorage(num_tensors={self.num_tensors}, " + f"shapes={self.tensor_shapes}, " + f"logical_shape={self.logical_shape}, " + f"quantizer={self.quantizer}, " + f"dtype={self.get_dtype()})" + ) + + @staticmethod + def make_grouped_tensor_with_shapes( + num_tensors: int, + shapes: List[Tuple[int, int]], + quantizer: Optional[Quantizer] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> GroupedTensorStorage: + """ + Create a GroupedTensor for storing multiple weight tensors of the same shape. + + Args: + num_tensors: Number of tensors + shapes: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer used for all tensors + device: Device to allocate tensors on, defaults to current cuda device + dtype: Data type of the tensor (for high precision case) + + Returns: + A GroupedTensor. + """ + + # First dim + first_dim_list = [s[0] for s in shapes] + uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list) + logical_first_dim = sum(first_dim_list) + if uniform_first_dim: + first_dims = None + else: + first_dims = torch.tensor([s[0] for s in shapes], dtype=torch.int64, device=device) + + # Last dim + last_dim_list = [s[1] for s in shapes] + logical_last_dim = last_dim_list[0] + assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform" + + return GroupedTensorStorage.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=first_dims, + last_dims=None, + logical_first_dim=logical_first_dim, + logical_last_dim=logical_last_dim, + quantizer=quantizer, + device=device, + dtype=dtype, + ) + + @staticmethod + def make_grouped_tensor( + num_tensors: int, + first_dims: Optional[torch.Tensor], + last_dims: Optional[torch.Tensor], + logical_first_dim: int, + logical_last_dim: int, + quantizer: Optional[Quantizer] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> GroupedTensorStorage: + """ + Create a GroupedTensor for storing multiple weight tensors of the same shape. + + Args: + num_tensors: Number of tensors + first_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + logical_first_dim: Logical first dimension + logical_last_dim: Logical last dimension + quantizer: Quantizer used for all tensors. Used to figure out recipe + and what to allocate. + device: Device to allocate tensors on, defaults to current cuda device + dtype: Data type of the tensor (for high precision case) + + Returns: + A GroupedTensor. + """ + + # Set device + if device is None: + device = torch.cuda.current_device() + + # Shape patterns and validation. + all_same_first = first_dims is None + all_same_last = last_dims is None + + assert all_same_last, "Last dim must be uniform for GroupedTensor" + assert logical_first_dim > 0, "Logical first dim must be positive for GroupedTensor" + assert logical_last_dim > 0, "Logical last dim must be positive for GroupedTensor" + + # assert ( + # logical_first_dim % 128 == 0 + # ), "Logical first dim must be divisible by 128" + # assert logical_last_dim % 128 == 0, "Logical last dim must be divisible by 128" + + # Calculate tensor offsets (cumulative element offsets) + tensor_offsets = None + offsets = None + shape = [] + if not all_same_first: + # Need explicit offsets for non-uniform shapes + # Offsets are based on number of elements and not pointers. + # Kernels need to calculate precise pointers based on size of elements. + + # TODO(ksivaman): Single kernel + remove the host offset calculation. + tensor_offsets = torch.cat( + [ + torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype), + torch.cumsum(first_dims * logical_last_dim, dim=0), + ] + ) + offsets = tensor_offsets.tolist() + first_dims_list = first_dims.tolist() + for i in range(num_tensors): + shape.append((first_dims_list[i], logical_last_dim)) + else: + offsets = [ + i * logical_first_dim * logical_last_dim // num_tensors + for i in range(num_tensors + 1) + ] + for i in range(num_tensors): + shape.append((logical_first_dim // num_tensors, logical_last_dim)) + + # Calculate logical shape based + logical_shape = (logical_first_dim, logical_last_dim) + + no_quantization = quantizer is None + + rowwise_usage = quantizer.rowwise_usage if not no_quantization else True + columnwise_usage = quantizer.columnwise_usage if not no_quantization else False + + # Calculate total elements across all tensors + total_elements = logical_first_dim * logical_last_dim + + data = None + columnwise_data = None + scale_inv = None + columnwise_scale_inv = None + amax = None + columnwise_amax = None + scale = None + scale_inv_offsets = None + columnwise_scale_inv_offsets = None + if no_quantization: + assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=dtype, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=dtype, device=device) + elif quantizer._get_compatible_recipe().mxfp8(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse buffer for MXFP8 - complex shape based on block scaling + # For grouped tensors, we need to calculate scale_inv size for all tensors + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_elements = math.prod(scale_inv_shape) + total_scale_elements += scale_elements + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse buffer + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + columnwise_scale_elements = math.prod(scale_inv_shape) + total_columnwise_scale_elements += columnwise_scale_elements + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.uint8, device=device + ) + elif quantizer._get_compatible_recipe().delayed(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - one per tensor + scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors + scale_inv_offsets = list(range(num_tensors + 1)) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse - one per tensor + columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors + columnwise_scale_inv_offsets = list(range(num_tensors + 1)) + + # Amax buffer for delayed scaling - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizer._get_compatible_recipe().nvfp4(): + + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte) + data = torch.empty((total_elements) // 2, dtype=torch.uint8, device=device) + # Scale inverse buffer for NVFP4 - complex shape based on block scaling + # For simplicity, calculate total scale elements needed + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) + # Amax buffer - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed) + columnwise_data = torch.empty( + (total_elements) // 2, dtype=torch.uint8, device=device + ) + # Columnwise scale inverse buffer + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.uint8, device=device + ) + # Columnwise amax buffer - one per tensor + columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizer._get_compatible_recipe().float8_block_scaling(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - size depends on block configuration + # For simplicity, calculate total scale elements needed + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.float32, device=device + ) + elif quantizer._get_compatible_recipe().float8_current_scaling(): + # Current scaling - per-tensor scaling computed on the fly + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - one per tensor + scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors + scale_inv_offsets = list(range(num_tensors + 1)) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse - one per tensor + columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors + columnwise_scale_inv_offsets = list(range(num_tensors + 1)) + + # Scale and amax buffers for current scaling - one per tensor + scale = torch.empty(num_tensors, dtype=torch.float32, device=device) + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + else: + raise ValueError(f"Unsupported quantizer for GroupedTensor: {quantizer}") + + # Construct wrapper vs storage based on quantizer.internal. + # If quantizer is None (high precision path), default to wrapper class. + # TODO(ksivaman): Properly handle high precision path. + internal = False if quantizer is None else quantizer.internal + if internal: + grouped_tensor_class = GroupedTensorStorage + else: + from ..grouped_tensor import GroupedTensor + + grouped_tensor_class = GroupedTensor + + grouped_tensor = grouped_tensor_class( + logical_shape, + dtype, + num_tensors=num_tensors, + shapes=shape, + quantizer=quantizer, + data=data, + columnwise_data=columnwise_data, + scale_inv=scale_inv, + columnwise_scale_inv=columnwise_scale_inv, + amax=amax, + columnwise_amax=columnwise_amax, + scale=scale, + first_dims=first_dims, + last_dims=last_dims, + tensor_offsets=tensor_offsets, + offsets=offsets, + scale_inv_offsets=scale_inv_offsets, + columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, + with_gemm_swizzled_scales=( + quantizer.optimize_for_gemm if quantizer is not None else False + ), + ) + + grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() + return grouped_tensor + + def split_into_quantized_tensors( + self, + ) -> List[Union[QuantizedTensorStorage, torch.Tensor]]: + """ + Split the GroupedTensor into a list of `num_tensors` + quantized tensors based on the quantizer. No additional memory allocation is performed, + so the tensors returned are the same as the ones used to create the GroupedTensor. + + If quantizer is None, returns normal torch tensors. + If quantizer.internal is True, returns QuantizedTensorStorage. + Otherwise, returns QuantizedTensor. + + This API is NOT graph safe, but can be used for testing & debugging. + + TODO(ksivaman): Block cases where any dims are varying. This is needed only + to expose the weights as separate parameters. + """ + + result = [] + + no_quantization = self.quantizer is None + + # if self.tensor_shapes is None, then trigger D2H copy and get the shape (not graph safe) + if self.tensor_shapes is None: + first_dims_list = ( + [self.logical_shape[0]] * self.num_tensors + if self.first_dims is None + else self.first_dims.tolist() + ) + last_dims_list = ( + [self.logical_shape[1]] * self.num_tensors + if self.last_dims is None + else self.last_dims.tolist() + ) + shape_list = [] + for i in range(self.num_tensors): + shape_list.append((first_dims_list[i], last_dims_list[i])) + self.tensor_shapes = shape_list + + # edge case: handle the case where tensor_offsets is given but offsets is not set + if self.offsets is None and self.tensor_offsets is not None: + self.offsets = self.tensor_offsets.tolist() + + # Case 1: No quantization - return regular torch tensors + if no_quantization: + for i in range(self.num_tensors): + # Get tensor shape + tensor_shape = self.tensor_shapes[i] + + # Get tensor data slice + if self.offsets is not None: + start_offset = self.offsets[i] + numel = tensor_shape[0] * tensor_shape[1] + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.rowwise_data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + else: + # All same shape case + numel = tensor_shape[0] * tensor_shape[1] + start_offset = i * numel + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.rowwise_data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + + return result + + # Case 2: Quantized tensors + recipe = self.quantizer._get_compatible_recipe() + + # populate scale_inv_offsets from the tensor offsets + if self.scale_inv is not None and self.scale_inv_offsets is None: + if recipe.nvfp4(): + self.scale_inv_offsets = self.tensor_offsets // 16 + if recipe.mxfp8(): + self.scale_inv_offsets = self.tensor_offsets // 32 + if self.columnwise_scale_inv is not None and self.columnwise_scale_inv_offsets is None: + if recipe.nvfp4(): + self.columnwise_scale_inv_offsets = self.tensor_offsets // 16 + if recipe.mxfp8(): + self.columnwise_scale_inv_offsets = self.tensor_offsets // 32 + + for i in range(self.num_tensors): + quantizer = self.quantizer + # Get tensor shape + tensor_shape = self.tensor_shapes[i] + numel = tensor_shape[0] * tensor_shape[1] + + # Get data offsets + if self.offsets is not None: + data_start = self.offsets[i] + data_end = data_start + numel + else: + # All same shape + data_start = i * numel + data_end = data_start + numel + + # Special shape handling for NVFP4. + nvfp4 = quantizer._get_compatible_recipe().nvfp4() + if nvfp4: + data_start = data_start // 2 + data_end = data_end // 2 + + # Extract rowwise and columnwise data + rowwise_data = None + columnwise_data = None + + if self.has_data(): + if nvfp4: + rowwise_tensor_shape = quantizer.convert_shape_for_fp4(tensor_shape) + else: + rowwise_tensor_shape = tensor_shape + rowwise_data = self.rowwise_data[data_start:data_end].view(rowwise_tensor_shape) + + if self.has_columnwise_data(): + columnwise_tensor_shape = quantizer.get_columnwise_shape(tensor_shape) + if nvfp4: + columnwise_tensor_shape = quantizer.convert_shape_for_fp4( + columnwise_tensor_shape + ) + columnwise_data = self.columnwise_data[data_start:data_end].view( + columnwise_tensor_shape + ) + + # MXFP8 format + if recipe.mxfp8(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + # for paged stashing, scale_inv should depend on the split offsets + scale_end = self.scale_inv_offsets[i + 1] + + # Calculate expected scale shape for MXFP8 + scale_shape = quantizer.get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + # for paged stashing, columnwise_scale_inv should depend on the split offsets + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + + cscale_shape = quantizer.get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + if quantizer.internal: + mxfp8_tensor_class = MXFP8TensorStorage + else: + mxfp8_tensor_class = MXFP8Tensor + tensor = mxfp8_tensor_class( + shape=tensor_shape, + dtype=self.fake_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=quantizer.dtype, + quantizer=quantizer, + with_gemm_swizzled_scales=quantizer.optimize_for_gemm, + ) + result.append(tensor) + + # Delayed scaling or current scaling (both use Float8TensorStorage) + elif recipe.delayed() or recipe.float8_current_scaling(): + # Scale inverse - one per tensor + scale_inv = None + if self.scale_inv is not None: + scale_inv = self.scale_inv[i : i + 1] + + if quantizer.internal: + float8_tensor_class = Float8TensorStorage + else: + float8_tensor_class = Float8Tensor + + tensor = float8_tensor_class( + shape=tensor_shape, + dtype=self.fake_dtype, + data=rowwise_data, + fp8_scale_inv=scale_inv, + fp8_dtype=quantizer.dtype, + quantizer=quantizer, + data_transpose=columnwise_data, + ) + result.append(tensor) + + # Float8 block scaling + elif recipe.float8_block_scaling(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + # for paged stashing, scale_inv should depend on the split offsets + scale_end = self.scale_inv_offsets[i + 1] + + # Get scale shape from quantizer + scale_shape = quantizer.get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + # for paged stashing, columnwise_scale_inv should depend on the split offsets + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + + # Get columnwise scale shape from quantizer + cscale_shape = quantizer.get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + # Compute is_2D_scaled and data_format from quantizer attributes + is_2D_scaled = quantizer.block_scaling_dim == 2 + + if quantizer.internal: + float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage + else: + float8_blockwise_q_tensor_class = Float8BlockwiseQTensor + + tensor = float8_blockwise_q_tensor_class( + shape=tensor_shape, + dtype=self.fake_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=quantizer.dtype, + quantizer=quantizer, + is_2D_scaled=is_2D_scaled, + ) + result.append(tensor) + + # NVFP4 format + elif recipe.nvfp4(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + amax_rowwise = None + amax_columnwise = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + # for paged stashing, scale_inv should depend on the split offsets + scale_end = self.scale_inv_offsets[i + 1] + + # Get scale shape from quantizer + scale_shape = quantizer.get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + # for paged stashing, columnwise_scale_inv should depend on the split offsets + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + + # Get columnwise scale shape from quantizer + cscale_shape = quantizer.get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + # Extract amax - one per tensor + if self.amax is not None: + amax_rowwise = self.amax[i : i + 1] + + if self.columnwise_amax is not None: + amax_columnwise = self.columnwise_amax[i : i + 1] + + if quantizer.internal: + nvfp4_tensor_class = NVFP4TensorStorage + else: + nvfp4_tensor_class = NVFP4Tensor + + tensor = nvfp4_tensor_class( + shape=tensor_shape, + dtype=self.fake_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + fp4_dtype=quantizer.dtype, + quantizer=quantizer, + with_gemm_swizzled_scales=quantizer.optimize_for_gemm, + ) + result.append(tensor) + + else: + raise ValueError(f"Unsupported quantization recipe: {recipe}") + + return result + + def quantize( + self, + tensors: List[torch.Tensor], + noop_flag: Optional[torch.Tensor] = None, + ) -> Tuple[QuantizedTensorStorage, ...]: + """ + Quantize the GroupedTensor inplace. + """ + + quantized_tensors = self.split_into_quantized_tensors() + for i in range(self.num_tensors): + self.quantizer.update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag) + return quantized_tensors diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 3f9aaa4d0..19b2ded03 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -41,9 +41,9 @@ def forward( return te_dequantize_triton(tensor, dtype) # Make sure FP8 data is in expected format - if tensor._rowwise_data is not None: + if tensor._rowwise_data is not None or tensor._columnwise_data is not None: return tex.dequantize(tensor, dtype) - raise NotImplementedError("Casting back from the transpose not implemented yet!") + raise ValueError("Cannot dequantize MXFP8 tensor with no data") @staticmethod def backward( @@ -92,12 +92,14 @@ def __new__( quantizer: Optional[Quantizer], with_gemm_swizzled_scales: bool, *args, + fake_dtype: Optional[torch.dtype] = None, **kwargs, ): if cls is MXFP8TensorStorage: instance = object.__new__(cls) + instance._dtype = fake_dtype if fake_dtype is not None else torch.float32 else: - instance = super().__new__(cls, *args, **kwargs) + instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data instance._rowwise_scale_inv = rowwise_scale_inv @@ -119,6 +121,24 @@ def clear(self): if t is not None: t.data = _empty_tensor() + def copy_from_storage(self, src: QuantizedTensorStorage) -> None: + """Copy data buffers from another MXFP8TensorStorage.""" + if not isinstance(src, MXFP8TensorStorage): + raise TypeError("copy_from_storage expects MXFP8TensorStorage") + if self._fp8_dtype != src._fp8_dtype: + raise RuntimeError("FP8 dtype mismatch in copy_from_storage") + if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales: + raise RuntimeError("Scale layout mismatch in copy_from_storage") + + def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): + if dst is not None and src_tensor is not None: + dst.copy_(src_tensor) + + _copy_optional(self._rowwise_data, src._rowwise_data) + _copy_optional(self._columnwise_data, src._columnwise_data) + _copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv) + _copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv) + def get_metadata(self) -> Dict[str, Any]: """Get this tensor's metadata.""" return { @@ -129,6 +149,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "quantizer": self._quantizer, "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, + "fake_dtype": self._dtype, } def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]: @@ -165,8 +186,10 @@ def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = Tr return self._columnwise_data raise ValueError("No data to get, both rowwise_data and columnwise_data are False") - def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" + if dtype is None: + dtype = self._dtype return _FromMXFP8Func.forward(None, self, dtype) def size(self, *args, **kwargs): @@ -175,6 +198,15 @@ def size(self, *args, **kwargs): return self._rowwise_data.size(*args, **kwargs) return self._columnwise_data.size(*args, **kwargs) + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("MXFP8TensorStorage has no data!") + def view(self, shape: torch.Size): # pylint: disable=missing-function-docstring @@ -219,6 +251,7 @@ def view(self, shape: torch.Size): fp8_dtype=self._fp8_dtype, quantizer=self._quantizer, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, + fake_dtype=self._dtype, ) def __repr__(self): diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index b064d711c..fb163c903 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -106,10 +106,14 @@ def __new__( quantizer: Optional[Quantizer], with_gemm_swizzled_scales: bool, *args, + fake_dtype: Optional[torch.dtype] = None, **kwargs, ): - - instance = super().__new__(cls, *args, **kwargs) + if cls is NVFP4TensorStorage: + instance = object.__new__(cls) + instance._dtype = fake_dtype if fake_dtype is not None else torch.float32 + else: + instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data @@ -136,6 +140,26 @@ def clear(self): if t is not None: t.data = _empty_tensor() + def copy_from_storage(self, src: QuantizedTensorStorage) -> None: + """Copy data buffers from another NVFP4TensorStorage.""" + if not isinstance(src, NVFP4TensorStorage): + raise TypeError("copy_from_storage expects NVFP4TensorStorage") + if self._fp4_dtype != src._fp4_dtype: + raise RuntimeError("FP4 dtype mismatch in copy_from_storage") + if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales: + raise RuntimeError("Scale layout mismatch in copy_from_storage") + + def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): + if dst is not None and src_tensor is not None: + dst.copy_(src_tensor) + + _copy_optional(self._rowwise_data, src._rowwise_data) + _copy_optional(self._columnwise_data, src._columnwise_data) + _copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv) + _copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv) + _copy_optional(self._amax_rowwise, src._amax_rowwise) + _copy_optional(self._amax_columnwise, src._amax_columnwise) + def get_metadata(self) -> Dict[str, Any]: """Get this tensor's metadata.""" return { @@ -148,6 +172,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp4_dtype": self._fp4_dtype, "quantizer": self._quantizer, "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, + "fake_dtype": self._dtype, } def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]: @@ -184,8 +209,10 @@ def get_data_tensors(self): """Get this Tensor's data.""" return self._rowwise_data, self._columnwise_data - def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" + if dtype is None: + dtype = self._dtype return _FromNVFP4Func.forward(None, self, dtype) def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]: @@ -208,6 +235,15 @@ def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]: return torch.Size(shape) return shape[dim] + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("NVFP4TensorStorage has no data!") + def view(self, shape: torch.Size): # pylint: disable=missing-function-docstring @@ -266,6 +302,7 @@ def view(self, shape: torch.Size): quantizer=self._quantizer, fp4_dtype=self._fp4_dtype, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, + fake_dtype=self._dtype, ) def __repr__(self): @@ -296,6 +333,20 @@ def update_usage( if columnwise_usage is None: columnwise_usage = self._columnwise_data is not None + # If both rowwise and columnwise are requested, create columnwise from rowwise if needed + if rowwise_usage and columnwise_usage: + if ( + self._rowwise_data is None + or self._rowwise_scale_inv is None + or self._amax_rowwise is None + ): + raise RuntimeError( + "Cannot update to rowwise and columnwise usage because rowwise data is None." + ) + if self._columnwise_data is None or self._columnwise_scale_inv is None: + self._create_columnwise() + return + # Update row-scaled data if rowwise_usage: if self._rowwise_data is None: @@ -336,3 +387,61 @@ def update_usage( self._columnwise_data = None self._columnwise_scale_inv = None self._amax_columnwise = None + + def _create_columnwise(self): + """ + Update columnwise data and columnwise scale inv. Can only be used when using 2D scaling. + """ + if self._quantizer is None or not self._quantizer.with_2d_quantization: + raise RuntimeError("Cannot create columnwise data without 2D quantization enabled.") + rowwise_data = self._rowwise_data + if not rowwise_data.is_contiguous(): + rowwise_data = rowwise_data.contiguous() + # NVFP4 requires a specialized transpose that handles nibble repacking + self._columnwise_data = tex.nvfp4_data_transpose(rowwise_data, out=self._columnwise_data) + if self._columnwise_scale_inv is None: + if self._quantizer is None: + raise RuntimeError("Cannot create columnwise scale inverse: quantizer is None.") + # Use logical shape (self.size()), not packed byte shape (rowwise_data.shape) + # NVFP4 packs 2 elements per byte, so rowwise_data.shape[-1] is K/2 + logical_shape = self.size() + columnwise_scale_inv_shape = self._quantizer.get_scale_shape(logical_shape, True) + self._columnwise_scale_inv = torch.empty( + columnwise_scale_inv_shape, + dtype=self._rowwise_scale_inv.dtype, + device=self._rowwise_scale_inv.device, + ) + if len(self._rowwise_scale_inv.shape) != 2: + raise ValueError( + "Expected rowwise_scale_inv to be 2D, but got" + f" {len(self._rowwise_scale_inv.shape)}D with shape" + f" {self._rowwise_scale_inv.shape}." + ) + if len(self._columnwise_scale_inv.shape) != 2: + raise ValueError( + "Expected columnwise_scale_inv to be 2D, but got" + f" {len(self._columnwise_scale_inv.shape)}D with shape" + f" {self._columnwise_scale_inv.shape}." + ) + + # rowwise_scale_inv has shape [M_padded, K_tiles] where each tile's scale + # is repeated 16 times (once per row in the 16x16 tile). + # columnwise_scale_inv has shape [K_padded, M_tiles] where scales are + # repeated 16 times per tile row. + TILE_SIZE = 16 + logical_shape = self.size() + M, K = logical_shape[0], logical_shape[-1] + M_tiles = (M + TILE_SIZE - 1) // TILE_SIZE + K_tiles = (K + TILE_SIZE - 1) // TILE_SIZE + + tex.nvfp4_2d_scale_transpose( + self._rowwise_scale_inv, + self._columnwise_scale_inv, + M_tiles, + K_tiles, + ) + + # Also set columnwise amax (same as rowwise since it's just transposed data) + if self._amax_columnwise is None: + self._amax_columnwise = torch.empty_like(self._amax_rowwise) + self._amax_columnwise.copy_(self._amax_rowwise) diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 3c72d7861..e3e4c8bda 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -4,7 +4,7 @@ # # See LICENSE for license information. -"""Helper functions for using fp8 tensors as weights""" +"""Helper functions for using fp8/nvfp4 tensors as weights""" from typing import Optional, Union, List import torch @@ -18,10 +18,12 @@ from ..quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer +from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from ..optimizers.multi_tensor_apply import multi_tensor_applier from ..utils import is_non_tn_fp8_gemm_supported, is_fp8_fnuz +from ..constants import NVFP4_BLOCK_SCALING_SIZE def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): @@ -37,23 +39,41 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): """ if isinstance(tensor, Float8Tensor): old_raw_data = tensor._data - assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match" + if old_raw_data.dtype != new_raw_data.dtype: + raise ValueError( + "The data types of raw data don't match: " + f"old dtype={old_raw_data.dtype}, new dtype={new_raw_data.dtype}" + ) new_raw_data.detach().copy_(old_raw_data) tensor._data = new_raw_data del old_raw_data elif isinstance(tensor, Float8BlockwiseQTensor): old_raw_data = tensor._rowwise_data - assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match" + if old_raw_data.dtype != new_raw_data.dtype: + raise ValueError( + "The data types of raw data don't match: " + f"old dtype={old_raw_data.dtype}, new dtype={new_raw_data.dtype}" + ) new_raw_data.detach().copy_(old_raw_data) tensor._rowwise_data = new_raw_data del old_raw_data + elif isinstance(tensor, NVFP4Tensor): + old_rowwise = tensor._rowwise_data + if old_rowwise.dtype != new_raw_data.dtype: + raise ValueError( + f"The data types of raw data don't match: {old_rowwise.dtype} vs" + f" {new_raw_data.dtype}" + ) + new_raw_data.detach().copy_(old_rowwise) + tensor._rowwise_data = new_raw_data + del old_rowwise elif isinstance(tensor, MXFP8Tensor): raise NotImplementedError("replace_raw_data for MXFP8Tensor is not supported yet") else: raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet") -def cast_master_weights_to_fp8( +def quantize_master_weights( model_weights, master_weights, start_offsets, @@ -61,15 +81,15 @@ def cast_master_weights_to_fp8( fsdp_shard_model_weights=None, manual_post_all_gather_processing=False, ): - r"""Helper function to cast master weights to FP8 primary weights. + r"""Helper function to cast master weights to quantized (FP8/NVFP4) primary weights. This is intended for use with ZeRO/FSDP. Each rank has a shard of the master weights (possibly empty) and a full copy of the model - weights. + weights. Supports FP8 (delayed, current, blockwise, MXFP8) and NVFP4 quantization. Parameters ---------- - model_weights : list of FP8 weights. + model_weights : list of quantized weights (FP8 or NVFP4). master_weights : list of master weights. Typically they are FP32 weights. start_offsets : list of integers, the starting index of the master weight in the model weight. master_weight may be smaller than model_weight because it could be distributed @@ -92,6 +112,7 @@ def cast_master_weights_to_fp8( current_scaling_params = [] blockwise_scaling_params = [] mxfp8_scaling_params = [] + nvfp4_params = [] if fsdp_shard_model_weights is None: use_fsdp_shard_model_weights = False @@ -99,6 +120,46 @@ def cast_master_weights_to_fp8( else: use_fsdp_shard_model_weights = True + # Batch convert master_weights to model dtype for NVFP4 (single kernel instead of N kernels) + # Check if there are any NVFP4 weights + has_nvfp4 = any( + isinstance(w._get_quantizer(), NVFP4Quantizer) + for w in model_weights + if hasattr(w, "_get_quantizer") + ) + if has_nvfp4 and len(model_weights) > 0: + # Find target dtype from first NVFP4 weight + target_dtype = None + for w in model_weights: + if hasattr(w, "_get_quantizer") and isinstance(w._get_quantizer(), NVFP4Quantizer): + target_dtype = w.dtype + break + + if target_dtype is not None: + # Collect non-None master_weights and their indices + non_none_indices = [] + non_none_weights = [] + sizes = [] + for i, mw in enumerate(master_weights): + if mw is not None: + non_none_indices.append(i) + non_none_weights.append(mw.view(-1)) + sizes.append(mw.numel()) + + if len(non_none_weights) > 0 and non_none_weights[0].dtype != target_dtype: + # Concatenate, convert once, then split + concatenated = torch.cat(non_none_weights) + converted = concatenated.to(target_dtype) + split_weights = torch.split(converted, sizes) + + # Rebuild master_weights list with converted tensors + converted_master_weights = list(master_weights) + for idx, split_w, orig_mw in zip( + non_none_indices, split_weights, [master_weights[i] for i in non_none_indices] + ): + converted_master_weights[idx] = split_w.view(orig_mw.shape) + master_weights = converted_master_weights + for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip( model_weights, master_weights, start_offsets, fsdp_shard_model_weights ): @@ -117,34 +178,42 @@ def cast_master_weights_to_fp8( if hasattr(model_weight, "clear_high_precision_init_val"): model_weight.clear_high_precision_init_val() - if master_weight is not None: - # When not using fp8_primary_weights, the master_weight (fp32) is first cast to - # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when - # fp8_primary_weights is enabled, we still keep this logic to keep numerical - # consistency. So here we cast the master_weight to model_weight.dtype. - master_weight = master_weight.to(model_weight.dtype) - quantizer = model_weight._get_quantizer() - if isinstance(quantizer, Float8Quantizer): - delayed_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, Float8CurrentScalingQuantizer): - current_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, Float8BlockQuantizer): - blockwise_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, MXFP8Quantizer): - mxfp8_scaling_params.append( + + if isinstance(quantizer, NVFP4Quantizer): + # NVFP4: master_weight dtype conversion already done above + nvfp4_params.append( (model_weight, master_weight, start_offset, fsdp_shard_model_weight) ) else: - raise ValueError( - f"cast_master_weights_to_fp8 for {type(quantizer)} is not supported yet" - ) + # FP8: convert master_weight to model dtype + if master_weight is not None: + # When not using fp8_primary_weights, the master_weight (fp32) is first cast to + # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when + # fp8_primary_weights is enabled, we still keep this logic to keep numerical + # consistency. So here we cast the master_weight to model_weight.dtype. + master_weight = master_weight.to(model_weight.dtype) + + if isinstance(quantizer, Float8Quantizer): + delayed_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, Float8CurrentScalingQuantizer): + current_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, Float8BlockQuantizer): + blockwise_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, MXFP8Quantizer): + mxfp8_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + else: + raise ValueError( + f"quantize_master_weights for {type(quantizer)} is not supported yet" + ) extra_args = [group, use_fsdp_shard_model_weights, manual_post_all_gather_processing] if len(delayed_scaling_params) > 0: @@ -155,6 +224,32 @@ def cast_master_weights_to_fp8( _cast_master_weights_to_fp8_blockwise_scaling(blockwise_scaling_params, *extra_args) if len(mxfp8_scaling_params) > 0: _cast_master_weights_to_fp8_mxfp8_scaling(mxfp8_scaling_params, *extra_args) + if len(nvfp4_params) > 0: + _cast_master_weights_to_nvfp4_2d(nvfp4_params, *extra_args) + + +def cast_master_weights_to_fp8( + model_weights, + master_weights, + start_offsets, + group, + fsdp_shard_model_weights=None, + manual_post_all_gather_processing=False, +): + r"""Helper function to cast master weights to FP8 primary weights. + + .. deprecated:: + Use :func:`quantize_master_weights` instead. + + """ + quantize_master_weights( + model_weights, + master_weights, + start_offsets, + group, + fsdp_shard_model_weights, + manual_post_all_gather_processing, + ) def _cast_master_weights_to_fp8_delayed_scaling( @@ -195,10 +290,16 @@ def _cast_master_weights_to_fp8_delayed_scaling( continue # If master weight is not None, start_offset must be a valid value. - assert start_offset is not None - assert start_offset >= 0 + if start_offset is None: + raise ValueError("start_offset must not be None when master_weight is provided") + if start_offset < 0: + raise ValueError(f"start_offset must be non-negative, got {start_offset}") end_offset = start_offset + master_weight.numel() - assert end_offset <= model_weight.numel() + if end_offset > model_weight.numel(): + raise ValueError( + f"end_offset ({end_offset}) exceeds model_weight numel ({model_weight.numel()}), " + f"start_offset={start_offset}, master_weight numel={master_weight.numel()}" + ) # master_weight may be smaller than model_weight because it could be distributed across # multiple ranks. So we need to create a dummy weight using the raw data from model_weight. @@ -282,9 +383,21 @@ def _cast_master_weights_to_fp8_current_scaling( # Make sure all the model weights have the same numerical options. quantizer = model_weight._get_quantizer() - assert quantizer.dtype == fp8_dtype - assert quantizer.force_pow_2_scales == force_pow_2_scales - assert quantizer.amax_epsilon == amax_epsilon + if quantizer.dtype != fp8_dtype: + raise ValueError( + "All model weights must have the same fp8 dtype, " + f"expected {fp8_dtype} but got {quantizer.dtype}" + ) + if quantizer.force_pow_2_scales != force_pow_2_scales: + raise ValueError( + "All model weights must have the same force_pow_2_scales, " + f"expected {force_pow_2_scales} but got {quantizer.force_pow_2_scales}" + ) + if quantizer.amax_epsilon != amax_epsilon: + raise ValueError( + "All model weights must have the same amax_epsilon, " + f"expected {amax_epsilon} but got {quantizer.amax_epsilon}" + ) scales.append(quantizer.scale.view(1)) scale_invs.append(model_weight._scale_inv.view(1)) @@ -398,19 +511,47 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # Make sure all the model weights have the same numerical options. quantizer = model_weight._get_quantizer() - assert block_len == quantizer.block_len - assert fp8_dtype == quantizer.dtype - assert force_pow_2_scales == quantizer.force_pow_2_scales - assert amax_epsilon == quantizer.amax_epsilon + if block_len != quantizer.block_len: + raise ValueError( + "All model weights must have the same block_len, " + f"expected {block_len} but got {quantizer.block_len}" + ) + if fp8_dtype != quantizer.dtype: + raise ValueError( + "All model weights must have the same fp8 dtype, " + f"expected {fp8_dtype} but got {quantizer.dtype}" + ) + if force_pow_2_scales != quantizer.force_pow_2_scales: + raise ValueError( + "All model weights must have the same force_pow_2_scales, " + f"expected {force_pow_2_scales} but got {quantizer.force_pow_2_scales}" + ) + if amax_epsilon != quantizer.amax_epsilon: + raise ValueError( + "All model weights must have the same amax_epsilon, " + f"expected {amax_epsilon} but got {quantizer.amax_epsilon}" + ) scale_shape = quantizer.get_scale_shape(model_weight.shape, False) amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) scale = torch.empty(scale_shape, dtype=torch.float32, device=device) scale_inv = model_weight._rowwise_scale_inv - assert len(scale_shape) == 2 - assert len(scale_inv.shape) == 2 - assert scale_inv.shape[0] == scale_shape[0] - assert scale_inv.shape[1] == scale_shape[1] + if len(scale_shape) != 2: + raise ValueError(f"scale_shape must be 2D, got {len(scale_shape)}D shape {scale_shape}") + if len(scale_inv.shape) != 2: + raise ValueError( + f"scale_inv must be 2D, got {len(scale_inv.shape)}D shape {scale_inv.shape}" + ) + if scale_inv.shape[0] != scale_shape[0]: + raise ValueError( + f"scale_inv dim 0 mismatch: scale_inv.shape={scale_inv.shape}," + f" scale_shape={scale_shape}" + ) + if scale_inv.shape[1] != scale_shape[1]: + raise ValueError( + f"scale_inv dim 1 mismatch: scale_inv.shape={scale_inv.shape}," + f" scale_shape={scale_shape}" + ) amaxes.append(amax) scales.append(scale) @@ -418,7 +559,11 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # Compute amax of the master weight and store it in packed_amaxes. if master_weight is not None: - assert len(model_weight.shape) == 2 + if len(model_weight.shape) != 2: + raise ValueError( + "model_weight must be 2D for blockwise scaling, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.fp8_block_scaling_compute_partial_amax( master_weight, amax, h, w, start_offset, block_len @@ -469,13 +614,233 @@ def _cast_master_weights_to_fp8_blockwise_scaling( end_offset = start_offset + master_weight.numel() if not use_fsdp_shard_model_weights: model_weight_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] - assert len(model_weight.shape) == 2 + if len(model_weight.shape) != 2: + raise ValueError( + "model_weight must be 2D for blockwise scaling partial cast, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.fp8_block_scaling_partial_cast( master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype ) +def _cast_master_weights_to_nvfp4_2d( + params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False +): + r"""Helper function to cast master weights to NVFP4 2D quantized weights. + + Parameters + ---------- + params : List of tuple, each tuple contains a model weight, a master weight, and an offset + indicating the starting index of the master weight in the model weight. + group : The distributed group to do amax reduction. Typically it's the data parallel + group. + use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded. + """ + + device = params[0][0].device + block_len = NVFP4_BLOCK_SCALING_SIZE + + cu_amax_sizes = [0] + tile_shapes: List[tuple[int, int]] = [] + tile_widths: List[int] = [] + scale_targets: List[torch.Tensor] = [] + amax_targets: List[Optional[torch.Tensor]] = [] + for model_weight, _, _, _ in params: + quantizer = model_weight._get_quantizer() + if not isinstance(quantizer, NVFP4Quantizer): + raise TypeError(f"Expected NVFP4Quantizer, got {type(quantizer).__name__}") + if not quantizer.with_2d_quantization: + raise ValueError("NVFP4 2D quantization must be enabled.") + if len(model_weight.shape) != 2: + raise ValueError(f"Expected 2D model weight, got {len(model_weight.shape)}D") + h, w = model_weight.shape + tile_h = (h + block_len - 1) // block_len + tile_w = (w + block_len - 1) // block_len + tile_shapes.append((tile_h, tile_w)) + tile_widths.append(tile_w) + scale_targets.append(model_weight._rowwise_scale_inv) + amax_targets.append(model_weight._amax_rowwise) + num_amaxes = tile_h * tile_w + cu_amax_sizes.append(cu_amax_sizes[-1] + num_amaxes) + + packed_amaxes = torch.zeros(cu_amax_sizes[-1], dtype=torch.float32, device=device) + packed_scales = torch.zeros(cu_amax_sizes[-1], dtype=torch.float32, device=device) + + amaxes: List[torch.Tensor] = [] + scales: List[torch.Tensor] = [] + global_amaxes = torch.zeros(len(params), dtype=torch.float32, device=device) + global_amax_views: List[torch.Tensor] = [global_amaxes[i : i + 1] for i in range(len(params))] + + # Collect tensors for batched multi-tensor amax computation + master_weight_list: List[torch.Tensor] = [] + partial_amax_list: List[torch.Tensor] = [] + global_amax_list: List[torch.Tensor] = [] + h_list: List[int] = [] + w_list: List[int] = [] + start_offset_list: List[int] = [] + + for i, (model_weight, master_weight, start_offset, _) in enumerate(params): + scale_shape = tile_shapes[i] + amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) + scale = packed_scales[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) + global_amax_view = global_amax_views[i] + + if model_weight._rowwise_scale_inv is None: + raise RuntimeError("model_weight._rowwise_scale_inv must not be None") + + amaxes.append(amax) + scales.append(scale) + + if master_weight is not None and master_weight.numel() > 0: + if len(model_weight.shape) != 2: + raise ValueError(f"Expected 2D model weight, got {len(model_weight.shape)}D") + h, w = model_weight.shape + # Collect for batched processing + master_weight_list.append(master_weight) + partial_amax_list.append(amax) + global_amax_list.append(global_amax_view) + h_list.append(h) + w_list.append(w) + start_offset_list.append(start_offset) + + # Batched multi-tensor call for partial and global amax computation + if master_weight_list: + tex.nvfp4_multi_tensor_compute_partial_amax( + master_weight_list, + partial_amax_list, + global_amax_list, + h_list, + w_list, + start_offset_list, + block_len, + ) + + if packed_amaxes.numel() > 0: + torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + + if global_amaxes.numel() > 0: + torch.distributed.all_reduce(global_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + + # Use GPU kernel to compute global encode scales from global amaxes + # This replaces multiple Python tensor operations with a single kernel + global_scale_tensor = torch.empty_like(global_amaxes) + + tex.nvfp4_compute_global_scale(global_amaxes, global_scale_tensor) + global_scale_views = [global_scale_tensor[i : i + 1] for i in range(len(params))] + + # Collect tensors for batched fused scale kernel + fused_scale_block_amax_list: List[torch.Tensor] = [] + fused_scale_global_amax_list: List[torch.Tensor] = [] + fused_scale_per_block_scale_list: List[torch.Tensor] = [] + fused_scale_target_scale_list: List[torch.Tensor] = [] + fused_scale_target_amax_list: List[torch.Tensor] = [] + fused_scale_tile_rows_list: List[int] = [] + fused_scale_tile_cols_list: List[int] = [] + fused_scale_rows_padded_list: List[int] = [] + + # Collect tensors for batched partial cast kernel + partial_cast_inp_list: List[torch.Tensor] = [] + partial_cast_out_list: List[torch.Tensor] = [] + partial_cast_scale_list: List[torch.Tensor] = [] + partial_cast_global_scale_list: List[torch.Tensor] = [] + partial_cast_h_list: List[int] = [] + partial_cast_w_list: List[int] = [] + partial_cast_start_offset_list: List[int] = [] + + # First pass: collect all tensors and update usage + zipped_meta = zip( + tile_shapes, + tile_widths, + scale_targets, + amax_targets, + params, + amaxes, + scales, + global_scale_views, + ) + for idx, ( + tile_shape, + tile_col_cnt, + target_scale, + target_amax, + (model_weight, master_weight, start_offset, model_weight_fragment), + block_amax, + per_block_decode_scale, + global_scale, + ) in enumerate(zipped_meta): + + if not manual_post_all_gather_processing: + # Reset transpose cache for all model weights. + # We cannot create transpose cache here because users (like megatron) may want to + # overlap the all-gather of model weights and forward process, so the model weight is + # not updated currently. + model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) + + tile_rows = tile_shape[0] + rows_padded = target_scale.shape[0] + global_amax_view = global_amaxes[idx : idx + 1] + + # Collect for fused scale kernel (only if target_amax is not None) + if target_amax is not None: + fused_scale_block_amax_list.append(block_amax) + fused_scale_global_amax_list.append(global_amax_view) + fused_scale_per_block_scale_list.append(per_block_decode_scale) + fused_scale_target_scale_list.append(target_scale) + fused_scale_target_amax_list.append(target_amax) + fused_scale_tile_rows_list.append(tile_rows) + fused_scale_tile_cols_list.append(tile_col_cnt) + fused_scale_rows_padded_list.append(rows_padded) + + # Collect for partial cast kernel (only for layers owned by this rank) + if master_weight is not None and master_weight.numel() > 0: + end_offset = start_offset + master_weight.numel() + if not use_fsdp_shard_model_weights: + rowwise_bytes = model_weight._rowwise_data.view(-1) + byte_start = start_offset // 2 + byte_end = (end_offset + 1) // 2 + model_weight_fragment = rowwise_bytes[byte_start:byte_end] + if len(model_weight.shape) != 2: + raise ValueError(f"Expected 2D model weight, got {len(model_weight.shape)}D") + h, w = model_weight.shape + + partial_cast_inp_list.append(master_weight) + partial_cast_out_list.append(model_weight_fragment) + partial_cast_scale_list.append(per_block_decode_scale) + partial_cast_global_scale_list.append(global_scale) + partial_cast_h_list.append(h) + partial_cast_w_list.append(w) + partial_cast_start_offset_list.append(start_offset) + + # Batched multi-tensor call for fused scale + if fused_scale_block_amax_list: + tex.nvfp4_multi_tensor_fused_scale( + fused_scale_block_amax_list, + fused_scale_global_amax_list, + fused_scale_per_block_scale_list, + fused_scale_target_scale_list, + fused_scale_target_amax_list, + fused_scale_tile_rows_list, + fused_scale_tile_cols_list, + fused_scale_rows_padded_list, + block_len, + ) + + # Batched multi-tensor call for partial cast + if partial_cast_inp_list: + tex.nvfp4_multi_tensor_2d_partial_cast( + partial_cast_inp_list, + partial_cast_out_list, + partial_cast_scale_list, + partial_cast_global_scale_list, + partial_cast_h_list, + partial_cast_w_list, + partial_cast_start_offset_list, + block_len, + ) + + def _cast_master_weights_to_fp8_mxfp8_scaling( params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False ): # pylint: disable=unused-argument @@ -502,9 +867,15 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( cu_colwise_amax_sizes = [0] for model_weight, _, _, _ in params: rowwise_shape = model_weight._rowwise_scale_inv.shape - assert len(rowwise_shape) == 2 + if len(rowwise_shape) != 2: + raise ValueError( + f"rowwise_scale_inv must be 2D, got {len(rowwise_shape)}D shape {rowwise_shape}" + ) colwise_shape = model_weight._columnwise_scale_inv.shape - assert len(colwise_shape) == 2 + if len(colwise_shape) != 2: + raise ValueError( + f"columnwise_scale_inv must be 2D, got {len(colwise_shape)}D shape {colwise_shape}" + ) cu_rowwise_amax_sizes.append( cu_rowwise_amax_sizes[-1] + rowwise_shape[0] * rowwise_shape[1] ) @@ -543,7 +914,11 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( # Compute amax of the master weight and store it in packed_amaxes. if master_weight is not None: - assert len(model_weight.shape) == 2 + if len(model_weight.shape) != 2: + raise ValueError( + "model_weight must be 2D for MXFP8 scaling, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.mxfp8_scaling_compute_partial_amax( master_weight, amax_rowwise, amax_colwise, h, w, start_offset @@ -587,7 +962,11 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( else: rowwise_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] colwise_fragment = model_weight._columnwise_data.reshape(-1)[start_offset:end_offset] - assert len(model_weight.shape) == 2 + if len(model_weight.shape) != 2: + raise ValueError( + "model_weight must be 2D for MXFP8 scaling partial cast, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.mxfp8_scaling_partial_cast( master_weight, @@ -607,9 +986,15 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten - Float8Tensor: may need to create a transposed view to match backend GEMM. - Float8BlockwiseQTensor: create column-wise storage. - Plain pytorch tensor: noop. + + For NVFP4 tensors, uses batched multi-tensor processing to reduce CPU overhead. """ if not isinstance(model_weights, list): model_weights = [model_weights] + + # Collect NVFP4 tensors for batched processing + nvfp4_tensors = [] + for model_weight in model_weights: if isinstance(model_weight, Float8Tensor): # Delayed scaling and per-tensor current scaling: if backend does not support @@ -620,12 +1005,91 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten elif isinstance(model_weight, Float8BlockwiseQTensor): # Blockwise scaling: create column-wise storage. model_weight._create_columnwise() + elif isinstance(model_weight, NVFP4Tensor): + # Collect for batched processing + nvfp4_tensors.append(model_weight) elif isinstance(model_weight, MXFP8Tensor): # MXFP8 scaling: no need to do anything. pass elif isinstance(model_weight, QuantizedTensor): raise ValueError(f"post_processing for {type(model_weight)} is not supported") + # Batch process all NVFP4 tensors with multi-tensor approach + if nvfp4_tensors: + _nvfp4_2d_multi_tensor_transpose(nvfp4_tensors) + + +def _nvfp4_2d_multi_tensor_transpose(nvfp4_tensors: List[NVFP4Tensor]): + """ + Batched columnwise creation for multiple NVFP4 tensors. + Reduces CPU overhead by collecting all tensor metadata and dispatching to C++. + """ + # Prepare tensor lists for batched C++ call + rowwise_data_list = [] + columnwise_data_list = [] + rowwise_scale_inv_list = [] + columnwise_scale_inv_list = [] + M_list = [] + K_list = [] + + for tensor in nvfp4_tensors: + rowwise_data = tensor._rowwise_data + if not rowwise_data.is_contiguous(): + rowwise_data = rowwise_data.contiguous() + tensor._rowwise_data = rowwise_data + + logical_shape = tensor.size() + M, K = logical_shape[0], logical_shape[-1] + + # Allocate columnwise_data if needed + if tensor._columnwise_data is None: + # Output shape: [K, M/2] packed bytes + columnwise_data = torch.empty( + (K, M // 2), + dtype=torch.uint8, + device=rowwise_data.device, + ) + tensor._columnwise_data = columnwise_data + else: + columnwise_data = tensor._columnwise_data + + # Allocate columnwise_scale_inv if needed + if tensor._columnwise_scale_inv is None: + if tensor._quantizer is None: + raise RuntimeError("tensor._quantizer must not be None") + columnwise_scale_inv_shape = tensor._quantizer.get_scale_shape(logical_shape, True) + columnwise_scale_inv = torch.empty( + columnwise_scale_inv_shape, + dtype=tensor._rowwise_scale_inv.dtype, + device=tensor._rowwise_scale_inv.device, + ) + tensor._columnwise_scale_inv = columnwise_scale_inv + else: + columnwise_scale_inv = tensor._columnwise_scale_inv + + rowwise_data_list.append(rowwise_data) + columnwise_data_list.append(columnwise_data) + rowwise_scale_inv_list.append(tensor._rowwise_scale_inv) + columnwise_scale_inv_list.append(columnwise_scale_inv) + M_list.append(M) + K_list.append(K) + + # Copy amax if needed + if tensor._amax_columnwise is None and tensor._amax_rowwise is not None: + tensor._amax_columnwise = tensor._amax_rowwise.clone() + elif tensor._amax_rowwise is not None: + tensor._amax_columnwise.copy_(tensor._amax_rowwise) + + # Dispatch to C++ multi-tensor kernel + tex.nvfp4_2d_multi_tensor_transpose( + rowwise_data_list, + columnwise_data_list, + rowwise_scale_inv_list, + columnwise_scale_inv_list, + M_list, + K_list, + ) + def is_custom(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool: """Check if an object is custom. diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 4f131c3c0..6573ea6a5 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -8,13 +8,13 @@ import os import warnings from contextlib import nullcontext -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch.torch_version import torch_version from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm -from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.jit import ( @@ -37,9 +37,7 @@ from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.module.base import TransformerEngineBaseModule - -from torch.utils.cpp_extension import IS_HIP_EXTENSION - +import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") @@ -153,11 +151,21 @@ class TransformerLayer(torch.nn.Module): distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`. Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by :attr:`window_size` in :meth:`forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = "no_mask" type of attention mask passed into softmax operation for decoder. enc_dec_window_size : Optional[Tuple[int, int]], default = None sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. zero_centered_gamma : bool, default = False if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -179,7 +187,7 @@ class TransformerLayer(torch.nn.Module): if set to ``False``, the transformer layer will not learn any additive biases. activation : str, default = 'gelu' Type of activation used in MLP block. - Options are: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, + Options are: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``. activation_params : Optional[dict], default = None Additional parameters for the activation function. @@ -306,7 +314,9 @@ def __init__( kv_channels: Optional[int] = None, self_attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, enc_dec_attn_mask_type: str = "no_mask", + enc_dec_bottom_right_diagonal: Optional[bool] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, @@ -348,8 +358,10 @@ def __init__( self.self_attn_mask_type = self_attn_mask_type self.window_size = window_size + self.bottom_right_diagonal = bottom_right_diagonal self.enc_dec_attn_mask_type = enc_dec_attn_mask_type self.enc_dec_window_size = enc_dec_window_size + self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad @@ -364,23 +376,35 @@ def __init__( self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm if parallel_attention_mlp: - assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'" - assert not self.apply_residual_connection_post_layernorm, ( - "parallel_attention and apply_residual_connection_post_layernorm " - "not supported simultaneously." - ) - assert ( - not self.output_layernorm - ), "parallel_attention and output_layernorm not supported simultaneously" + if self.layer_type != "encoder": + raise ValueError( + "parallel_attention requires layer_type='encoder', " + f"but got layer_type={self.layer_type!r}" + ) + if self.apply_residual_connection_post_layernorm: + raise ValueError( + "parallel_attention and apply_residual_connection_post_layernorm " + "are not supported simultaneously." + ) + if self.output_layernorm: + raise ValueError( + "parallel_attention and output_layernorm are not supported simultaneously." + ) self.parallel_attention_mlp = parallel_attention_mlp - assert layer_type in LayerTypes, f"layer_type {layer_type} not supported" + if layer_type not in LayerTypes: + raise ValueError( + f"layer_type {layer_type!r} is not supported. " + f"Supported types are: {', '.join(repr(t) for t in LayerTypes)}" + ) if not fuse_qkv_params: - assert ( - not fuse_wgrad_accumulation - ), "Gradient accumulation fusion requires single QKV parameter." + if fuse_wgrad_accumulation: + raise ValueError( + "Gradient accumulation fusion (fuse_wgrad_accumulation=True) " + "requires fuse_qkv_params=True, but fuse_qkv_params is False." + ) if not fuse_qkv_params: qkv_weight_interleaved = False @@ -402,6 +426,7 @@ def __init__( self.softmax_type = softmax_type self.name = name + TransformerEngineBaseModule._validate_name(self) attention_args = ( hidden_size, @@ -450,7 +475,7 @@ def __init__( qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, - name=name + ".self_attention" if name is not None else None, + name=self.name + ".self_attention" if self.name is not None else None, ) if layer_type == "decoder": @@ -467,7 +492,7 @@ def __init__( qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, - name=name + ".inter_attention" if name is not None else None, + name=self.name + ".inter_attention" if self.name is not None else None, ) # LayerNorm -> activation(Linear + Bias) -> Linear @@ -503,7 +528,7 @@ def __init__( activation_params=activation_params, normalization=normalization, device=device, - name=name + ".layernorm_mlp" if name is not None else None, + name=self.name + ".layernorm_mlp" if self.name is not None else None, ) self.hidden_dropout = hidden_dropout @@ -535,6 +560,10 @@ def __init__( device=device, ) + def fast_setattr(self, name: str, value: Any) -> None: + """Fast attribute set for non-parameter fields.""" + self.__dict__[name] = value + def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ Set the tensor parallel group for the given @@ -610,10 +639,12 @@ def forward( attention_mask: Optional[torch.Tensor] = None, self_attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, encoder_output: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, enc_dec_attn_mask_type: Optional[str] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None, + enc_dec_bottom_right_diagonal: Optional[bool] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, @@ -658,6 +689,11 @@ def forward( causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention in encoder. + bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. encoder_output : Optional[torch.Tensor], default = None Output of the encoder block to be fed into the decoder block if using :attr:`layer_type` = ``"decoder"``. @@ -674,6 +710,11 @@ def forward( Type of attention mask passed into softmax operation for decoder. enc_dec_window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -740,40 +781,87 @@ def forward( self_attn_mask_type = self.self_attn_mask_type if window_size is None: window_size = self.window_size + window_size = dpa_utils.check_set_window_size(self_attn_mask_type, window_size) + if enc_dec_attn_mask_type is None: enc_dec_attn_mask_type = self.enc_dec_attn_mask_type if enc_dec_window_size is None: enc_dec_window_size = self.enc_dec_window_size + enc_dec_window_size = dpa_utils.check_set_window_size( + enc_dec_attn_mask_type, enc_dec_window_size + ) - assert ( - self_attn_mask_type in AttnMaskTypes - ), f"self_attn_mask_type {self_attn_mask_type} not supported" - assert ( - enc_dec_attn_mask_type in AttnMaskTypes - ), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported" + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if self_attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or self_attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True + + if enc_dec_bottom_right_diagonal is None: + enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal + if enc_dec_attn_mask_type in {"causal", "padding_causal"}: + enc_dec_bottom_right_diagonal = False + if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + enc_dec_bottom_right_diagonal = True + + if self_attn_mask_type not in AttnMaskTypes: + raise ValueError( + f"self_attn_mask_type {self_attn_mask_type!r} is not supported. " + f"Supported types are: {', '.join(repr(t) for t in AttnMaskTypes)}" + ) + if enc_dec_attn_mask_type not in AttnMaskTypes: + raise ValueError( + f"enc_dec_attn_mask_type {enc_dec_attn_mask_type!r} is not supported. " + f"Supported types are: {', '.join(repr(t) for t in AttnMaskTypes)}" + ) hidden_states = hidden_states.contiguous() if self.sequence_parallel and self.seq_length is not None: - assert ( - hidden_states.shape[0] == self.seq_length // self.tp_size - ), "Sequence dimension must be split across TP group when using sequence parallel." + if hidden_states.shape[0] != self.seq_length // self.tp_size: + raise ValueError( + "Sequence dimension must be split across TP group when using " + "sequence parallel. Expected hidden_states.shape[0] to be " + f"{self.seq_length // self.tp_size} " + f"(seq_length={self.seq_length} // tp_size={self.tp_size}), " + f"but got {hidden_states.shape[0]}." + ) if ( "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary" ) and attention_mask is not None: - assert all( - attention_mask[i].dtype == torch.bool for i in range(len(attention_mask)) - ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors" + if not all(attention_mask[i].dtype == torch.bool for i in range(len(attention_mask))): + non_bool_dtypes = [ + (i, attention_mask[i].dtype) + for i in range(len(attention_mask)) + if attention_mask[i].dtype != torch.bool + ] + raise TypeError( + "Attention mask must be a boolean tensor or a list/tuple of boolean " + f"tensors, but found non-bool dtypes at indices: {non_bool_dtypes}" + ) if ( "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary" ) and enc_dec_attn_mask is not None: - assert all( + if not all( enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) - ), "Encoder-decoder attention mask must be boolean tensor(s)" - - if TEDebugState.debug_enabled: - TransformerEngineBaseModule._validate_name(self) + ): + non_bool_dtypes = [ + (i, enc_dec_attn_mask[i].dtype) + for i in range(len(enc_dec_attn_mask)) + if enc_dec_attn_mask[i].dtype != torch.bool + ] + raise TypeError( + "Encoder-decoder attention mask must be boolean tensor(s), " + f"but found non-bool dtypes at indices: {non_bool_dtypes}" + ) # For AMP if torch.is_autocast_enabled(): @@ -785,6 +873,7 @@ def forward( attention_mask=attention_mask, attn_mask_type=self_attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, inference_params=inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, @@ -820,6 +909,7 @@ def forward( attention_mask=enc_dec_attn_mask, attn_mask_type=enc_dec_attn_mask_type, window_size=enc_dec_window_size, + bottom_right_diagonal=enc_dec_bottom_right_diagonal, encoder_output=encoder_output, inference_params=inference_params, is_first_microbatch=is_first_microbatch, diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 0e7ad7018..2c08f0063 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -54,7 +54,7 @@ def cross_entropy_forward( n_non_ignore = torch.zeros(1, dtype=torch.int64, device=_input.device) # ensure _input and target are contiguous in the last dimension - if _input.stride(-1) != 1: + if _input.stride(-1) != 1 or _input.stride(-2) != _input.shape[-1]: _input = _input.contiguous() if target.stride(-1) != 1: target = target.contiguous() diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 6b5de9ab0..4902bc686 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -427,6 +427,7 @@ def sort_chunks_by_map( inp, row_id_map, probs, + output, # no use in Pytorch side, serves as WAR for JAX side inp.stride(0), inp.stride(1), output.stride(0), diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 7131d45e6..2d3ec389a 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -15,7 +15,6 @@ import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION -from .quantized_tensor import Quantizer from .torch_version import torch_version from ..debug.pytorch.debug_quantization import DebugQuantizedTensor @@ -152,7 +151,8 @@ def compare_tensors(a: torch.Tensor, b: torch.Tensor) -> None: def ensure_divisibility(numerator: int, denominator: int) -> None: """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}" + if numerator % denominator != 0: + raise ValueError(f"{numerator} is not divisible by {denominator}") def divide(numerator: int, denominator: int) -> int: @@ -247,6 +247,7 @@ def forward( fp8_dtype=mixed_x_layer._fp8_dtype, data=x.squeeze(split_dim) if squeeze else x, shape=x.squeeze(split_dim).shape if squeeze else x.shape, + fake_dtype=mixed_x_layer._dtype, quantizer=mixed_x_layer._quantizer, ) for x in torch.split( @@ -276,13 +277,16 @@ def forward( @staticmethod def backward(ctx, *grad_outputs): # pylint: disable=missing-function-docstring - assert len(grad_outputs) > 0, "No gradients received for backprop!" + if len(grad_outputs) == 0: + raise RuntimeError("No gradients received for backprop!") if isinstance(ctx.split_size_or_sections, (list, tuple)): split_sizes = ctx.split_size_or_sections - assert len(grad_outputs) == len( - split_sizes - ), "Unequal number of gradients vs split sections for backprop!" + if len(grad_outputs) != len(split_sizes): + raise RuntimeError( + f"Unequal number of gradients ({len(grad_outputs)}) vs " + f"split sections ({len(split_sizes)}) for backprop!" + ) if isinstance(ctx.split_size_or_sections, int): split_sizes = [ctx.split_size_or_sections] * len(grad_outputs) dims = len(grad_outputs[0].shape) @@ -376,7 +380,8 @@ def validate_rng_states_func(get_rng_tracker: Callable) -> None: """Checks if passed in param function has everything required for tensor/model and sequence parallel. """ - assert callable(get_rng_tracker), "get_rng_tracker is not a valid function" + if not callable(get_rng_tracker): + raise TypeError(f"get_rng_tracker must be callable, got {type(get_rng_tracker).__name__}") rng_tracker = None try: @@ -384,15 +389,13 @@ def validate_rng_states_func(get_rng_tracker: Callable) -> None: except Exception as e: raise RuntimeError("Cannot call get_rng_tracker function") from e - assert hasattr(rng_tracker, "get_states") and callable( - rng_tracker.get_states - ), "rng_tracker object does not have valid method get_states" - assert hasattr(rng_tracker, "set_states") and callable( - rng_tracker.set_states - ), "rng_tracker object does not have valid method set_states" - assert hasattr(rng_tracker, "fork") and callable( - rng_tracker.fork - ), "rng_tracker object does not have valid method fork" + for method_name in ("get_states", "set_states", "fork"): + if not hasattr(rng_tracker, method_name) or not callable(getattr(rng_tracker, method_name)): + raise TypeError( + f"rng_tracker object ({type(rng_tracker).__name__}) does not have " + f"a valid callable method '{method_name}'. " + "Required methods: get_states, set_states, fork." + ) validate_ctx_manager(rng_tracker.fork) @@ -403,11 +406,12 @@ def assert_viewless_tensor(tensor: torch.Tensor, extra_msg: Optional[str] = None return [assert_viewless_tensor(t) for t in tensor] if not isinstance(tensor, torch.Tensor): return tensor - assert tensor._base is None, ( - "Ensure tensor._base is None before setting tensor.data or storing " - "tensor to memory buffer. Otherwise, a memory leak will occur (and " - f"likely accumulate over iterations). {extra_msg}" - ) + if tensor._base is not None: + raise ValueError( + "Ensure tensor._base is None before setting tensor.data or storing " + "tensor to memory buffer. Otherwise, a memory leak will occur (and " + f"likely accumulate over iterations). {extra_msg}" + ) return tensor @@ -445,11 +449,13 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM.""" for tensor in tensors: - assert math.prod(tensor.shape[:-1]) % 8 == 0 and tensor.shape[-1] % 16 == 0, ( - "FP8 execution requires the product of all dimensions except the last to be divisible" - " by 8 and the last dimension to be divisible by 16, but got tensor with" - f" dims={list(tensor.size())}" - ) + if math.prod(tensor.shape[:-1]) % 8 != 0 or tensor.shape[-1] % 16 != 0: + raise ValueError( + "FP8 execution requires the product of all dimensions except the last to be" + " divisible by 8 and the last dimension to be divisible by 16, but got tensor" + f" with dims={list(tensor.size())} (product of leading dims =" + f" {math.prod(tensor.shape[:-1])}, last dim = {tensor.shape[-1]})" + ) if IS_HIP_EXTENSION: @functools.lru_cache(maxsize=None) @@ -457,7 +463,7 @@ def is_mi200(): """check whether this machine is mi200/210/250""" import re return (re.search('AMD Instinct MI2.0', torch.cuda.get_device_name(torch.cuda.current_device())) is not None) - + @functools.lru_cache(maxsize=None) def is_mi308(): """check whether this machine is mi308""" @@ -471,27 +477,15 @@ def is_fp8_fnuz(): get_torch_float8_e4m3_type = lambda: torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn get_torch_float8_e5m2_type = lambda: torch.float8_e5m2fnuz if is_fp8_fnuz() else torch.float8_e5m2 -def assert_dim_for_all_gather( - tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer -) -> None: - """Assert that tensor dimensions are supported for all-gather""" - if with_all_gather: - assert quantizer.is_quantizable(tensor), ( - "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ - ) - -def is_bf16_compatible() -> None: +def is_bf16_compatible() -> bool: if IS_HIP_EXTENSION: # only MI200 and newer machines support bf16 - if get_device_compute_capability() in [(9, 4), (9, 5)] or is_mi200(): - return True - else: - return False - else: - """Replaces torch.cuda.is_bf16_compatible() with an explicit - check on device compute capability to enforce sm_80 or higher. - """ - return torch.cuda.get_device_capability()[0] >= 8 + return get_device_compute_capability() in [(9, 4), (9, 5)] or is_mi200() + """Replaces torch.cuda.is_bf16_compatible() with an explicit + check on device compute capability to enforce sm_80 or higher. + """ + return torch.cuda.get_device_capability()[0] >= 8 + def is_bf16_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ @@ -522,7 +516,10 @@ def is_non_tn_fp8_gemm_supported() -> bool: non-TN layouts for FP8 GEMMs. """ # TODO: release until rocm support non-TN fp8 gemms - return (not IS_HIP_EXTENSION) and (torch.cuda.get_device_capability() >= (10, 0)) + if IS_HIP_EXTENSION: + return False + device_capability = torch.cuda.get_device_capability() + return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0) @functools.lru_cache(maxsize=None) @@ -784,7 +781,9 @@ def __cuda_array_interface__(self): def torch_dtype_to_np_typestr(self): """Convert PyTorch dtype to numpy typestr.""" ret = _torch_dtype_to_np_typestr_dict.get(self.dtype) - assert ret is not None, f"Unsupported dtype: {self.dtype}" + if ret is None: + supported = ", ".join(str(d) for d in _torch_dtype_to_np_typestr_dict) + raise TypeError(f"Unsupported dtype: {self.dtype}. Supported dtypes: {supported}") return ret @@ -823,4 +822,7 @@ def convert_to_torch_tensor(tensor: Union[_WeakRefTensor, torch.Tensor]) -> torc return x if x is None: return None - raise TypeError(f"Invalid type {type(x)} to make weak ref") + raise TypeError( + f"Invalid type {type(x).__name__} to make weak ref. " + "Valid types are: torch.Tensor, tuple, list, dict, int, float, bool, and None." + )