From 2e1777197b9705e4454fa8e6ec3d96f9914808ef Mon Sep 17 00:00:00 2001 From: niangao233 <2532475357@qq.com> Date: Thu, 21 May 2026 14:52:39 +0800 Subject: [PATCH] feat: support flash attention 2 across install paths --- Dockerfile | 12 +++++-- Dockerfile-for-Mainland-China | 19 +++++++--- README-zh.md | 18 +++++----- README.md | 18 +++++----- build-scripts/build_portable.ps1 | 30 ++++++++++++++-- install-cn.ps1 | 8 +++-- install.bash | 14 ++++++-- install.ps1 | 8 +++-- mikazuki/app/api.py | 9 ++--- mikazuki/portable_utils.py | 61 +++++++++++++++++++++++--------- run_gui_source.ps1 | 20 +++++++---- setup_environment.py | 60 +++++++++++++------------------ tests/test_portable_utils.py | 60 +++++++++++++++++++++++++++++++ 13 files changed, 236 insertions(+), 101 deletions(-) create mode 100644 tests/test_portable_utils.py diff --git a/Dockerfile b/Dockerfile index 15217fa7..80c0f889 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,10 @@ FROM nvcr.io/nvidia/pytorch:24.07-py3 EXPOSE 28000 ENV TZ=Asia/Shanghai -RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && apt update && apt install python3-tk -y +ENV MAX_JOBS=4 +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && \ + apt update && apt install python3-tk ninja-build -y && \ + rm -rf /var/lib/apt/lists/* RUN mkdir /app @@ -11,11 +14,14 @@ WORKDIR /app RUN git clone https://github.com/wochenlong/lora-scripts-next.git lora-scripts WORKDIR /app/lora-scripts -RUN pip install xformers==0.0.27.post2 --no-deps && pip install -r requirements.txt +RUN pip install xformers==0.0.27.post2 --no-deps && \ + pip install -r requirements.txt && \ + pip install flash-attn==2.7.4.post1 --no-build-isolation && \ + python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary; print('Flash Attention 2 OK')" WORKDIR /app/lora-scripts/scripts RUN pip install -r requirements.txt WORKDIR /app/lora-scripts -CMD ["python", "gui.py", "--listen"] \ No newline at end of file +CMD ["python", "gui.py", "--listen"] diff --git a/Dockerfile-for-Mainland-China b/Dockerfile-for-Mainland-China index 427d923c..6b958bd7 100644 --- a/Dockerfile-for-Mainland-China +++ b/Dockerfile-for-Mainland-China @@ -3,7 +3,10 @@ FROM nvcr.io/nvidia/pytorch:24.07-py3 EXPOSE 28000 ENV TZ=Asia/Shanghai -RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && apt update && apt install python3-tk -y +ENV MAX_JOBS=4 +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && \ + apt update && apt install python3-tk ninja-build -y && \ + rm -rf /var/lib/apt/lists/* RUN mkdir /app @@ -17,7 +20,10 @@ RUN pip config set global.index-url 'https://pypi.tuna.tsinghua.edu.cn/simple' & pip config set install.trusted-host 'pypi.tuna.tsinghua.edu.cn' # 初次安装依赖 -RUN pip install xformers==0.0.27.post2 --no-deps && pip install -r requirements.txt +RUN pip install xformers==0.0.27.post2 --no-deps && \ + pip install -r requirements.txt && \ + pip install flash-attn==2.7.4.post1 --no-build-isolation && \ + python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary; print('Flash Attention 2 OK')" # 更新 训练程序 stable 版本依赖 WORKDIR /app/lora-scripts/scripts/stable @@ -33,7 +39,10 @@ WORKDIR /app/lora-scripts # ref # - https://soulteary.com/2024/01/07/fix-opencv-dependency-errors-opencv-fixer.html # - https://blog.csdn.net/qq_50195602/article/details/124188467 -RUN pip install opencv-fixer==0.2.5 && python -c "from opencv_fixer import AutoFix; AutoFix()" \ - pip install opencv-python-headless && apt install ffmpeg libsm6 libxext6 libgl1 -y +RUN pip install opencv-fixer==0.2.5 && \ + python -c "from opencv_fixer import AutoFix; AutoFix()" && \ + pip install opencv-python-headless && \ + apt update && apt install ffmpeg libsm6 libxext6 libgl1 -y && \ + rm -rf /var/lib/apt/lists/* -CMD ["python", "gui.py", "--listen"] \ No newline at end of file +CMD ["python", "gui.py", "--listen"] diff --git a/README-zh.md b/README-zh.md index a1cb8b30..53319655 100644 --- a/README-zh.md +++ b/README-zh.md @@ -56,18 +56,16 @@ > 512 分辨率约节省 2~3 GB;降低 `network_dim`(如 8)也能少量减少显存。 -#### 整合包暂不支持 Flash Attention 2(说明) - -**当前 Windows 整合包(`SD-Trainer-v*.7z`)不会安装 Flash Attention 2,训练使用 xformers 或 PyTorch SDPA。** 这与「装不上」无关,而是便携包运行方式下的**刻意取舍**。 +#### 整合包 Flash Attention 2 支持 | 点 | 说明 | |----|------| | **flash-attn 依赖 triton** | 预编译的 `flash-attn` wheel 能装进环境,但运行时大量算子仍通过 `flash_attn.ops.triton` 调用 **Triton** 生成的 CUDA kernel。 | -| **嵌入式 Python 跑不好 triton** | 整合包使用 Python Embeddable(`python_embeded\`),缺少完整编译链;`triton` / `triton-windows` 常在首次 JIT 时失败,导致启动或训练崩溃。 | -| **不能只卸 triton、保留 flash-attn** | 若只安装 `flash-attn` 而不装 `triton`,import 时会报 `No module named 'triton'`;`transformers` 等库探测到已安装的 `flash_attn` 也可能仍尝试走 flash 路径。 | -| **整合包实际策略** | 首次安装跳过 flash-attn;若用户手动 `pip install` 了不完整的组合,启动时会自动卸载 flash-attn / triton,并设置 `TRANSFORMERS_ATTN_IMPLEMENTATION=sdpa`。 | +| **嵌入式 Python + triton** | 整合包使用 Python Embeddable(`python_embeded\`),因此固定安装 `triton-windows<3.4` 与匹配 PyTorch 2.7/CUDA 12.8 的 Flash Attention 2 wheel。 | +| **先自检再启用** | 首次安装和 `install_flash_attn.bat` 都会验证 `import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary`。 | +| **失败回退** | 若自检失败,启动时会自动卸载不完整的 flash-attn / triton 组合,训练回退到 **xformers** 或 **PyTorch SDPA**。 | -**需要 Flash Attention 2 时:** 请使用下方「[从源码安装](#从源码安装)」并按 **[Flash Attention 2(源码 / venv 用户)](#flash-attention-2源码--venv-用户)** 配置;整合包在 embed Python 支持成熟前**暂不承诺** flash-attn 加速。 +自检通过后,Anima / SD3 LoRA 会自动选择 `attn_mode=flash`;失败时日志会说明原因并继续训练。 ### 从源码安装 @@ -108,7 +106,7 @@ python gui.py --browser edge #### Flash Attention 2(源码 / venv 用户) -**整合包用户请看上节,不要对 `python_embeded` 手动安装 flash-attn。** +整合包用户可通过首次启动或 `install_flash_attn.bat` 使用同一套固定组合;源码用户也可按下方命令手动安装。 本节适用于:`git clone` 后使用 **`venv`**(或 `python\` 目录下的完整 Python),且已安装 **PyTorch 2.7.0 + CUDA 12.8** 的源码用户。 @@ -191,7 +189,7 @@ python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary i | wheel 安装成功但训练仍用 xformers | 运行上方验证命令;若失败说明 triton 与 flash-attn 未配对,勿只保留 flash-attn | | `pip install flash-attn` 编译很久或失败 | Windows 请改用 **预编译 wheel**(上表 URL),不要在本机编译 | | PyTorch 版本不是 2.7+cu128 | wheel 与 CUDA 标签不匹配,请对齐 `install.ps1` 中的 torch 版本后再装 flash-attn | -| 在整合包 `python_embeded` 里安装 | **不支持**,请改用源码 + venv | +| 整合包自检失败 | 更新显卡驱动或依赖后重新运行 `install_flash_attn.bat`;损坏组合会在启动时自动清理 | --- @@ -199,7 +197,7 @@ python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary i - **多模型支持** — SD 1.5 / SDXL / Flux / **Anima** 全部开箱即用 - **Anima LoRA 训练** — 侧边栏一键进入,支持 LoRA / LoKr(LyCORIS)/ **T-LoRA** -- **Attention 加速** — 自动选择后端:源码/venv 环境优先 Flash Attention 2(Windows 预编译 wheel);**整合包**使用 xformers / PyTorch SDPA([暂不支持 flash-attn](#整合包暂不支持-flash-attention-2说明)) +- **Attention 加速** — 自动选择后端:源码/venv、整合包、Docker 均在自检通过时优先 Flash Attention 2,失败时回退到 xformers / PyTorch SDPA - **T-LoRA** — 基于扩散时间步的动态 Rank LoRA,正交初始化,防止过拟合([论文](https://github.com/ControlGenAI/T-LoRA)) - **训练监控页** — 随 GUI 自动启动,展示 TensorBoard 同源 Loss / LR 曲线、关键训练参数速查、实时进度、终端日志同步和预览图 - **TensorBoard 内置** — 侧边栏直接查看,无需额外操作 diff --git a/README.md b/README.md index ae4608d3..854a8542 100644 --- a/README.md +++ b/README.md @@ -56,18 +56,16 @@ Measured on RTX 4090, batch=1, bf16, standard LoRA (dim=16): > 512 resolution saves roughly 2–3 GB; lowering `network_dim` (e.g. to 8) also helps marginally. -#### Portable package: Flash Attention 2 not supported (for now) - -The **Windows portable package** (`SD-Trainer-v*.7z`) **does not install Flash Attention 2**; training uses **xformers** or **PyTorch SDPA**. This is intentional, not a failed install. +#### Portable package: Flash Attention 2 | Point | Why | |-------|-----| | **flash-attn needs triton** | Prebuilt `flash-attn` wheels install, but many kernels still run via **Triton** (`flash_attn.ops.triton`). | -| **Embedded Python + triton** | The portable bundle uses Python Embeddable (`python_embeded\`) without a full toolchain; `triton` / `triton-windows` often fail at JIT compile time. | -| **Cannot keep flash-attn without triton** | Flash-attn-only installs hit `No module named 'triton'`; `transformers` may still probe `flash_attn` if the package is present. | -| **What we do** | Skip flash-attn on first setup; on launch, remove broken flash-attn/triton pairs and set `TRANSFORMERS_ATTN_IMPLEMENTATION=sdpa`. | +| **Embedded Python + triton** | The portable bundle uses Python Embeddable (`python_embeded\`), so the package pins `triton-windows<3.4` and the matching Flash Attention 2 wheel. | +| **Self-check first** | First setup and `install_flash_attn.bat` verify `import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary`. | +| **Fallback** | If the self-check fails, startup removes the broken flash-attn/triton pair and training falls back to **xformers** or **PyTorch SDPA**. | -For Flash Attention 2, use **[install from source](#install-from-source)** and follow **[Flash Attention 2 (source / venv)](#flash-attention-2-source--venv-installs)**. Portable flash-attn support may come later when embed Python + triton is reliable. +When the self-check passes, Anima / SD3 LoRA auto-selects `attn_mode=flash`; otherwise logs explain the fallback. ### Install from Source @@ -96,7 +94,7 @@ python gui.py --browser edge #### Flash Attention 2 (source / venv installs) -**Portable users:** see the section above — do not `pip install flash-attn` into `python_embeded`. +Portable users get the same pinned stack through first-run setup or `install_flash_attn.bat`; source users can also install it manually. This section is for **`git clone` + `venv`** (or a full Python under `python\`), with **PyTorch 2.7.0 + CUDA 12.8** installed. @@ -177,7 +175,7 @@ Then run `python gui.py` and start **Anima LoRA** training — logs should show | Wheel installs but training uses xformers | Run the verify command above; flash-attn without working triton is ignored | | Long compile or build errors | On Windows use the **prebuilt wheel** URLs, not `pip install flash-attn` from source | | PyTorch not 2.7+cu128 | Align torch with `install.ps1` before installing flash-attn | -| Installed into portable `python_embeded` | **Unsupported** — use source + venv instead | +| Portable self-check fails | Rerun `install_flash_attn.bat` after updating GPU drivers/dependencies; broken stacks are removed automatically | --- @@ -185,7 +183,7 @@ Then run `python gui.py` and start **Anima LoRA** training — logs should show - **Multi-model** — SD 1.5 / SDXL / Flux / **Anima** all work out of the box - **Anima LoRA training** — One-click sidebar entry, supports LoRA / LoKr (LyCORIS) / **T-LoRA** -- **Attention backends** — Source/venv: Flash Attention 2 when available (Windows prebuilt wheels). **Portable package:** xformers / PyTorch SDPA only ([flash-attn not supported yet](#portable-package-flash-attention-2-not-supported-for-now)) +- **Attention backends** — Source/venv, portable package, and Docker all prefer Flash Attention 2 when the stack self-checks OK, then fall back to xformers / PyTorch SDPA - **T-LoRA** — Timestep-Dependent LoRA with dynamic rank and orthogonal init ([paper](https://github.com/ControlGenAI/T-LoRA)) - **Train Monitor** — Auto-opens with GUI, TensorBoard-backed Loss / LR scalar cards, key training parameter checks, real-time progress, terminal log echo, and preview samples - **Built-in TensorBoard** — Accessible from the sidebar, no extra setup diff --git a/build-scripts/build_portable.ps1 b/build-scripts/build_portable.ps1 index 7a286b55..21c8b10c 100644 --- a/build-scripts/build_portable.ps1 +++ b/build-scripts/build_portable.ps1 @@ -355,6 +355,9 @@ $updateDepsBat = "@echo off`r`nchcp 65001 >nul 2>&1`r`ncd /d `"%~dp0..`"`r`n" $updateDepsBat += "echo Updating Python dependencies...`r`n" $updateDepsBat += "`"python_embeded\python.exe`" -s -m pip install --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cu128`r`n" $updateDepsBat += "`"python_embeded\python.exe`" -s -m pip install --upgrade -r `"SD-Trainer\requirements.txt`"`r`n" +$updateDepsBat += "`"python_embeded\python.exe`" -s -m pip install `"triton-windows<3.4`" --no-warn-script-location`r`n" +$updateDepsBat += "`"python_embeded\python.exe`" -s -m pip install https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-win_amd64.whl --no-warn-script-location`r`n" +$updateDepsBat += "`"python_embeded\python.exe`" -s -c `"import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary; print('Flash Attention 2 OK')`"`r`n" $updateDepsBat += "echo Done.`r`npause`r`n" [System.IO.File]::WriteAllText( (Join-Path $updateDir "update_dependencies.bat"), @@ -382,6 +385,28 @@ $xformersBat += "echo.`r`necho Done! You can now use attn_mode = xformers.`r`ne ) Write-Host " Created install_xformers.bat" +# install_flash_attn.bat — one-click Flash Attention 2 installer for portable users +$flashAttnBat = "@echo off`r`nchcp 65001 >nul 2>&1`r`ntitle Install Flash Attention 2`r`ncd /d `"%~dp0`"`r`n" +$flashAttnBat += "set `"PYTHON_EXE=%~dp0python_embeded\python.exe`"`r`n" +$flashAttnBat += "if not exist `"%PYTHON_EXE%`" (`r`n" +$flashAttnBat += " echo [ERROR] python_embeded\python.exe not found!`r`n" +$flashAttnBat += " pause`r`n exit /b 1`r`n)`r`n" +$flashAttnBat += "echo.`r`necho Installing Flash Attention 2 for Torch 2.7.0 + CUDA 12.8 ...`r`necho.`r`n" +$flashAttnBat += "`"%PYTHON_EXE%`" -s -m pip install `"triton-windows<3.4`" --no-warn-script-location`r`n" +$flashAttnBat += "if errorlevel 1 (`r`n echo [ERROR] triton-windows installation failed.`r`n pause`r`n exit /b 1`r`n)`r`n" +$flashAttnBat += "`"%PYTHON_EXE%`" -s -m pip install https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-win_amd64.whl --no-warn-script-location`r`n" +$flashAttnBat += "if errorlevel 1 (`r`n echo [ERROR] flash-attn wheel installation failed.`r`n pause`r`n exit /b 1`r`n)`r`n" +$flashAttnBat += "echo.`r`necho Verifying...`r`n" +$flashAttnBat += "`"%PYTHON_EXE%`" -s -c `"import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary; print(' Flash Attention 2 OK')`"`r`n" +$flashAttnBat += "if errorlevel 1 (`r`n echo [ERROR] Flash Attention 2 self-check failed. Training will fall back to xformers or PyTorch SDPA.`r`n `"%PYTHON_EXE%`" -s -m pip uninstall flash-attn flash_attn triton-windows triton -y >nul 2>&1`r`n pause`r`n exit /b 1`r`n)`r`n" +$flashAttnBat += "echo.`r`necho Done! Anima training can now use attn_mode = flash.`r`necho.`r`npause`r`n" +[System.IO.File]::WriteAllText( + (Join-Path $portableDir "install_flash_attn.bat"), + $flashAttnBat, + (New-Object System.Text.UTF8Encoding $false) +) +Write-Host " Created install_flash_attn.bat" + # Root-level utility bat files $templateDir = Join-Path $PSScriptRoot "templates" foreach ($bat in @("Update-SD-Trainer.bat", "Download-Anima-Model.bat")) { @@ -428,8 +453,9 @@ $readme += "xformers (recommended):`r`n" $readme += " If xformers is missing, double-click install_xformers.bat to install.`r`n" $readme += " xformers provides faster attention than PyTorch SDPA on most GPUs.`r`n`r`n" $readme += "Flash Attention 2:`r`n" -$readme += " This portable package does NOT use flash-attn (uses xformers / PyTorch SDPA).`r`n" -$readme += " Do not pip install flash-attn into python_embeded. See README in SD-Trainer/.`r`n" +$readme += " First launch installs the pinned flash-attn + triton-windows stack when available.`r`n" +$readme += " If the self-check fails, startup removes the broken stack and falls back to xformers / PyTorch SDPA.`r`n" +$readme += " You can rerun install_flash_attn.bat after updating GPU drivers or dependencies.`r`n" [System.IO.File]::WriteAllText( (Join-Path $portableDir "README.txt"), $readme, diff --git a/install-cn.ps1 b/install-cn.ps1 index 893c6066..1e7ed400 100644 --- a/install-cn.ps1 +++ b/install-cn.ps1 @@ -59,8 +59,9 @@ python -m pip install --upgrade -r requirements.txt Check "训练依赖库安装失败。" -Write-Output "Installing Flash Attention 2 (prebuilt wheel)..." +Write-Output "Installing Flash Attention 2 stack (triton-windows + prebuilt wheel)..." $pyver = python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')" 2>$null +python -m pip install "triton-windows<3.4" if ($pyver -match "^cp3(10|11|12)$") { $whl = "flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-$pyver-$pyver-win_amd64.whl" $url = "https://hf-mirror.com/lldacing/flash-attention-windows-wheel/resolve/main/$whl" @@ -68,10 +69,11 @@ if ($pyver -match "^cp3(10|11|12)$") { } else { python -m pip install flash-attn --no-build-isolation 2>$null } +python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary" 2>$null if ($LASTEXITCODE -eq 0) { - Write-Output "Flash Attention 2 installed" + Write-Output "Flash Attention 2 installed and verified" } else { - Write-Output "Flash Attention 2 install failed (non-fatal)" + Write-Output "Flash Attention 2 install/self-check failed (non-fatal)" } Write-Output "安装完成" diff --git a/install.bash b/install.bash index 948f2b4e..5459c6ca 100644 --- a/install.bash +++ b/install.bash @@ -73,8 +73,16 @@ cd "$script_dir" || exit pip install --upgrade -r requirements.txt echo "Installing Flash Attention 2 (optional, for training acceleration)..." -pip install flash-attn --no-build-isolation 2>/dev/null && \ - echo "Flash Attention 2 installed successfully" || \ - echo "Flash Attention 2 install failed (non-fatal, will use PyTorch SDPA)" +export MAX_JOBS="${MAX_JOBS:-4}" +if pip install "flash-attn==2.7.4.post1" --no-build-isolation; then + if python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary" 2>/dev/null; then + echo "Flash Attention 2 installed and verified successfully" + else + echo "Flash Attention 2 self-check failed (non-fatal, will use xformers or PyTorch SDPA)" + pip uninstall flash-attn flash_attn triton -y >/dev/null 2>&1 || true + fi +else + echo "Flash Attention 2 install failed (non-fatal, will use xformers or PyTorch SDPA)" +fi echo "Install completed" diff --git a/install.ps1 b/install.ps1 index 474f43c2..53ad5c42 100644 --- a/install.ps1 +++ b/install.ps1 @@ -22,8 +22,9 @@ pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --extra-index-url https pip install -U -I --no-deps xformers==0.0.30 --extra-index-url https://download.pytorch.org/whl/cu128 pip install --upgrade -r requirements.txt -Write-Output "Installing Flash Attention 2 (prebuilt wheel)..." +Write-Output "Installing Flash Attention 2 stack (triton-windows + prebuilt wheel)..." $pyver = python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')" 2>$null +pip install "triton-windows<3.4" if ($pyver -match "^cp3(10|11|12)$") { $whl = "flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-$pyver-$pyver-win_amd64.whl" $url = "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/$whl" @@ -31,10 +32,11 @@ if ($pyver -match "^cp3(10|11|12)$") { } else { pip install flash-attn --no-build-isolation 2>$null } +python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary" 2>$null if ($LASTEXITCODE -eq 0) { - Write-Output "Flash Attention 2 installed successfully" + Write-Output "Flash Attention 2 installed and verified successfully" } else { - Write-Output "Flash Attention 2 install failed (non-fatal, will use PyTorch SDPA instead)" + Write-Output "Flash Attention 2 install/self-check failed (non-fatal, will use xformers or PyTorch SDPA instead)" } Write-Output "Install completed" diff --git a/mikazuki/app/api.py b/mikazuki/app/api.py index 98bfd05a..578a7256 100644 --- a/mikazuki/app/api.py +++ b/mikazuki/app/api.py @@ -28,7 +28,7 @@ from mikazuki.train_log_hub import hub as train_log_hub from mikazuki.utils import train_utils from mikazuki.utils.devices import printable_devices -from mikazuki.portable_utils import flash_attn_stack_usable, is_embedded_python +from mikazuki.portable_utils import flash_attn_probe, flash_attn_stack_usable from mikazuki.utils.tk_window import (open_directory_selector, open_file_selector, tkinter_available) @@ -252,7 +252,7 @@ def is_preview_enabled(config: dict) -> bool: def _detect_best_attn_mode() -> str: """Auto-detect the best available attention backend for Anima training.""" - if not is_embedded_python() and flash_attn_stack_usable(): + if flash_attn_stack_usable(): return "flash" try: import xformers # noqa: F401 @@ -296,12 +296,13 @@ def apply_anima_training_defaults(config: dict, model_train_type: str): f"falling back to '{best}'" ) elif requested_attn == "flash": - if is_embedded_python() or not flash_attn_stack_usable(): + usable, reason = flash_attn_probe() + if not usable: best = _detect_best_attn_mode() config["attn_mode"] = best log.warning( f"attn_mode='flash' requested but flash-attn is not available, " - f"falling back to '{best}'" + f"falling back to '{best}' ({reason})" ) diff --git a/mikazuki/portable_utils.py b/mikazuki/portable_utils.py index 91963d22..b239fbec 100644 --- a/mikazuki/portable_utils.py +++ b/mikazuki/portable_utils.py @@ -1,11 +1,23 @@ # -*- coding: utf-8 -*- -"""Helpers for Windows portable (embedded) Python — flash-attn needs triton, which does not work reliably here.""" +"""Helpers for portable Python and optional Flash Attention 2 support.""" from __future__ import annotations +import importlib.util import subprocess import sys -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Tuple + + +FLASH_ATTN_WHEEL_VERSION = "2.7.4.post1" +FLASH_ATTN_CUDA_TAG = "cu128" +FLASH_ATTN_TORCH_VERSION = "2.7.0" +TRITON_WINDOWS_SPEC = "triton-windows<3.4" +FLASH_ATTN_WHEEL_BASE = "flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE" +FLASH_ATTN_WHEEL_HOSTS = { + "global": "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main", + "china": "https://hf-mirror.com/lldacing/flash-attention-windows-wheel/resolve/main", +} def is_embedded_python(executable: Optional[str] = None) -> bool: @@ -13,34 +25,52 @@ def is_embedded_python(executable: Optional[str] = None) -> bool: return "python_embeded" in exe or "python_embedded" in exe -def flash_attn_stack_usable() -> bool: - """True only when flash-attn and its triton ops import cleanly (not true on embedded Python).""" +def flash_attn_wheel_name(python_tag: Optional[str] = None) -> str: + tag = python_tag or f"cp{sys.version_info.major}{sys.version_info.minor}" + return f"{FLASH_ATTN_WHEEL_BASE}-{tag}-{tag}-win_amd64.whl" + + +def flash_attn_wheel_url(region: str = "global", python_tag: Optional[str] = None) -> str: + host = FLASH_ATTN_WHEEL_HOSTS.get(region, FLASH_ATTN_WHEEL_HOSTS["global"]) + return f"{host}/{flash_attn_wheel_name(python_tag)}" + + +def flash_attn_probe() -> Tuple[bool, str]: + """Return whether the flash-attn + Triton runtime can import cleanly.""" try: import triton # noqa: F401 import flash_attn # noqa: F401 from flash_attn.ops.triton.rotary import apply_rotary # noqa: F401 - return True - except Exception: - return False + return True, "flash-attn stack import OK" + except Exception as exc: # noqa: BLE001 - probe must never break startup + return False, f"{exc.__class__.__name__}: {exc}" + + +def flash_attn_stack_usable() -> bool: + """True only when flash-attn and its Triton ops import cleanly.""" + usable, _reason = flash_attn_probe() + return usable def sanitize_embedded_deps(log: Optional[Callable[[str], None]] = None) -> None: - """Remove flash-attn / triton from embedded Python if the stack cannot run.""" + """Remove flash-attn / triton from embedded Python only when the stack cannot run.""" if not is_embedded_python(): return - import importlib.util - has_flash = importlib.util.find_spec("flash_attn") is not None has_triton = importlib.util.find_spec("triton") is not None if not has_flash and not has_triton: return - if has_flash and flash_attn_stack_usable(): + + usable, reason = flash_attn_probe() + if has_flash and usable: + if log: + log("Portable package: flash-attn/Triton self-check passed; keeping Flash Attention 2 enabled.") return msg = ( "Portable package: removing incompatible flash-attn/triton " - "(training will use xformers or PyTorch SDPA)." + f"(self-check failed: {reason}; training will use xformers or PyTorch SDPA)." ) if log: log(msg) @@ -69,7 +99,6 @@ def train_env_overrides() -> Dict[str, str]: """Environment for training subprocesses on embedded Python.""" if not is_embedded_python(): return {} - return { - "TRANSFORMERS_ATTN_IMPLEMENTATION": "sdpa", - "XFORMERS_FORCE_DISABLE_TRITON": "1", - } + if flash_attn_stack_usable(): + return {} + return {"TRANSFORMERS_ATTN_IMPLEMENTATION": "sdpa"} diff --git a/run_gui_source.ps1 b/run_gui_source.ps1 index 2f7a5232..2f946673 100644 --- a/run_gui_source.ps1 +++ b/run_gui_source.ps1 @@ -18,24 +18,30 @@ else { Write-Host -ForegroundColor Blue "No virtual environment found, using system python..." } -# Auto-install flash-attn only when triton stack works (source/venv install; not portable embedded). +# Auto-install flash-attn when the pinned triton stack is missing or broken. python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary" 2>$null if ($LASTEXITCODE -ne 0) { python -c "import triton" 2>$null if ($LASTEXITCODE -ne 0) { Write-Host -ForegroundColor Cyan "Installing triton-windows..." - pip install "triton-windows<3.4" 2>$null + pip install "triton-windows<3.4" } python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary" 2>$null if ($LASTEXITCODE -ne 0) { Write-Host -ForegroundColor Cyan "Installing Flash Attention 2 (prebuilt wheel)..." - $whl = "flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-win_amd64.whl" - $url = "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/$whl" - pip install $url 2>$null + $pyver = python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')" 2>$null + if ($pyver -match "^cp3(10|11|12)$") { + $whl = "flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-$pyver-$pyver-win_amd64.whl" + $url = "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/$whl" + pip install $url 2>$null + } else { + pip install flash-attn --no-build-isolation 2>$null + } + python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary" 2>$null if ($LASTEXITCODE -eq 0) { - Write-Host -ForegroundColor Green "Flash Attention 2 installed successfully" + Write-Host -ForegroundColor Green "Flash Attention 2 installed and verified successfully" } else { - Write-Host -ForegroundColor Yellow "Flash Attention 2 install failed (non-fatal, using PyTorch SDPA)" + Write-Host -ForegroundColor Yellow "Flash Attention 2 install/self-check failed (non-fatal, using xformers or PyTorch SDPA)" } } } diff --git a/setup_environment.py b/setup_environment.py index c41f5158..2ee39798 100644 --- a/setup_environment.py +++ b/setup_environment.py @@ -14,6 +14,12 @@ import time import urllib.request +from mikazuki.portable_utils import ( + TRITON_WINDOWS_SPEC, + flash_attn_probe, + flash_attn_wheel_url, +) + # ──────────────────── Configuration ──────────────────── TORCH_VERSION = "2.7.0" @@ -178,7 +184,7 @@ def install_torch(region): def _filter_requirements(req_file): - """Read requirements.txt, filtering out packages incompatible with embedded Python.""" + """Read requirements.txt, filtering packages handled by the controlled FA2 setup step.""" skip_packages = {"triton-windows", "triton"} filtered_path = req_file + ".filtered" with open(req_file, "r", encoding="utf-8") as f: @@ -212,38 +218,24 @@ def install_requirements(region): return ok -_FLASH_ATTN_WHEEL = ( - "flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE" - "-cp310-cp310-win_amd64.whl" -) -_FLASH_ATTN_WHEEL_URLS = { - "global": ( - "https://huggingface.co/lldacing/flash-attention-windows-wheel" - f"/resolve/main/{_FLASH_ATTN_WHEEL}" - ), - "china": ( - "https://hf-mirror.com/lldacing/flash-attention-windows-wheel" - f"/resolve/main/{_FLASH_ATTN_WHEEL}" - ), -} - - -def _is_embedded_portable(): - exe = _python_exe().replace("\\", "/").lower() - return "python_embeded" in exe or "python_embedded" in exe - - def install_flash_attn(region): - """Install flash-attn from prebuilt wheel. Non-fatal on failure. + """Install and verify the pinned flash-attn + Triton stack. Non-fatal on failure.""" + if sys.platform == "win32" and not _run_pip( + ["install", TRITON_WINDOWS_SPEC, "--no-warn-script-location"] + ): + return False - Skipped on embedded portable Python: prebuilt flash-attn still imports triton - kernels at runtime, but triton cannot compile on embedded Python (no dev headers). - """ - if _is_embedded_portable(): + url = flash_attn_wheel_url(region) + if not _run_pip(["install", url, "--no-warn-script-location"]): return False - url = _FLASH_ATTN_WHEEL_URLS.get(region, _FLASH_ATTN_WHEEL_URLS["global"]) - args = ["install", url, "--no-warn-script-location"] - return _run_pip(args) + + usable, reason = flash_attn_probe() + if usable: + return True + + print(f" >>> Flash Attention 2 self-check failed: {reason}") + _run_pip(["uninstall", "flash-attn", "flash_attn", "triton-windows", "triton", "-y"]) + return False def write_mirror_env(region): @@ -338,16 +330,14 @@ def main(): return 1 _ok("训练组件安装完成") - # 5 — flash-attn (optional acceleration; skipped on embedded portable) + # 5 — flash-attn (optional acceleration; falls back cleanly on failure) _separator() _step(5, "安装 Flash Attention 2 训练加速 (可选)...") print() - if _is_embedded_portable(): - print(" >>> 便携包跳过 Flash Attention 2(使用 xformers / PyTorch SDPA,避免 triton 依赖)") - elif install_flash_attn(region): + if install_flash_attn(region): _ok("Flash Attention 2 安装成功,训练将自动启用加速") else: - print(" >>> Flash Attention 2 安装失败(不影响训练,将使用 PyTorch SDPA)") + print(" >>> Flash Attention 2 安装/自检失败(不影响训练,将使用 xformers / PyTorch SDPA)") # Verify _separator() diff --git a/tests/test_portable_utils.py b/tests/test_portable_utils.py new file mode 100644 index 00000000..c59b584e --- /dev/null +++ b/tests/test_portable_utils.py @@ -0,0 +1,60 @@ +import unittest +from unittest import mock + +from mikazuki import portable_utils + + +class PortableFlashAttentionTests(unittest.TestCase): + def test_flash_attn_wheel_url_uses_python_tag_and_region(self): + url = portable_utils.flash_attn_wheel_url("china", "cp310") + + self.assertIn("hf-mirror.com", url) + self.assertTrue(url.endswith("-cp310-cp310-win_amd64.whl")) + self.assertIn("cu128torch2.7.0", url) + + def test_sanitize_embedded_deps_keeps_usable_stack(self): + with ( + mock.patch.object(portable_utils, "is_embedded_python", return_value=True), + mock.patch.object(portable_utils.importlib.util, "find_spec", return_value=object()), + mock.patch.object(portable_utils, "flash_attn_probe", return_value=(True, "ok")), + mock.patch.object(portable_utils.subprocess, "run") as run, + ): + portable_utils.sanitize_embedded_deps() + + run.assert_not_called() + + def test_sanitize_embedded_deps_removes_broken_stack(self): + with ( + mock.patch.object(portable_utils, "is_embedded_python", return_value=True), + mock.patch.object(portable_utils.importlib.util, "find_spec", return_value=object()), + mock.patch.object(portable_utils, "flash_attn_probe", return_value=(False, "missing triton")), + mock.patch.object(portable_utils.subprocess, "run") as run, + ): + portable_utils.sanitize_embedded_deps() + + run.assert_called_once() + args = run.call_args.args[0] + self.assertIn("uninstall", args) + self.assertIn("flash-attn", args) + self.assertIn("triton-windows", args) + + def test_train_env_overrides_do_not_disable_working_flash_stack(self): + with ( + mock.patch.object(portable_utils, "is_embedded_python", return_value=True), + mock.patch.object(portable_utils, "flash_attn_stack_usable", return_value=True), + ): + self.assertEqual(portable_utils.train_env_overrides(), {}) + + def test_train_env_overrides_fallback_to_sdpa_when_broken(self): + with ( + mock.patch.object(portable_utils, "is_embedded_python", return_value=True), + mock.patch.object(portable_utils, "flash_attn_stack_usable", return_value=False), + ): + self.assertEqual( + portable_utils.train_env_overrides(), + {"TRANSFORMERS_ATTN_IMPLEMENTATION": "sdpa"}, + ) + + +if __name__ == "__main__": + unittest.main()