12 changes: 9 additions & 3 deletions Dockerfile
@@ -3,19 +3,25 @@ FROM nvcr.io/nvidia/pytorch:24.07-py3
EXPOSE 28000

ENV TZ=Asia/Shanghai
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && apt update && apt install python3-tk -y
ENV MAX_JOBS=4
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && \
apt update && apt install python3-tk ninja-build -y && \
rm -rf /var/lib/apt/lists/*

RUN mkdir /app

WORKDIR /app
RUN git clone https://github.com/wochenlong/lora-scripts-next.git lora-scripts

WORKDIR /app/lora-scripts
RUN pip install xformers==0.0.27.post2 --no-deps && pip install -r requirements.txt
RUN pip install xformers==0.0.27.post2 --no-deps && \
pip install -r requirements.txt && \
pip install flash-attn==2.7.4.post1 --no-build-isolation && \
python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary; print('Flash Attention 2 OK')"

WORKDIR /app/lora-scripts/scripts
RUN pip install -r requirements.txt

WORKDIR /app/lora-scripts

CMD ["python", "gui.py", "--listen"]
CMD ["python", "gui.py", "--listen"]
19 changes: 14 additions & 5 deletions Dockerfile-for-Mainland-China
@@ -3,7 +3,10 @@ FROM nvcr.io/nvidia/pytorch:24.07-py3
EXPOSE 28000

ENV TZ=Asia/Shanghai
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && apt update && apt install python3-tk -y
ENV MAX_JOBS=4
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && \
apt update && apt install python3-tk ninja-build -y && \
rm -rf /var/lib/apt/lists/*

RUN mkdir /app

@@ -17,7 +20,10 @@ RUN pip config set global.index-url 'https://pypi.tuna.tsinghua.edu.cn/simple' &
pip config set install.trusted-host 'pypi.tuna.tsinghua.edu.cn'

# Initial dependency install
RUN pip install xformers==0.0.27.post2 --no-deps && pip install -r requirements.txt
RUN pip install xformers==0.0.27.post2 --no-deps && \
pip install -r requirements.txt && \
pip install flash-attn==2.7.4.post1 --no-build-isolation && \
python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary; print('Flash Attention 2 OK')"

# Update dependencies for the stable version of the training program
WORKDIR /app/lora-scripts/scripts/stable
@@ -33,7 +39,10 @@ WORKDIR /app/lora-scripts
# ref
# - https://soulteary.com/2024/01/07/fix-opencv-dependency-errors-opencv-fixer.html
# - https://blog.csdn.net/qq_50195602/article/details/124188467
RUN pip install opencv-fixer==0.2.5 && python -c "from opencv_fixer import AutoFix; AutoFix()" \
pip install opencv-python-headless && apt install ffmpeg libsm6 libxext6 libgl1 -y
RUN pip install opencv-fixer==0.2.5 && \
python -c "from opencv_fixer import AutoFix; AutoFix()" && \
pip install opencv-python-headless && \
apt update && apt install ffmpeg libsm6 libxext6 libgl1 -y && \
rm -rf /var/lib/apt/lists/*

CMD ["python", "gui.py", "--listen"]
CMD ["python", "gui.py", "--listen"]
18 changes: 8 additions & 10 deletions README-zh.md
@@ -56,18 +56,16 @@

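> 512 resolution saves roughly 2~3 GB; lowering `network_dim` (e.g. to 8) also reduces VRAM use slightly.

#### Portable package does not support Flash Attention 2 for now (explanation)

**The current Windows portable package (`SD-Trainer-v*.7z`) does not install Flash Attention 2; training uses xformers or PyTorch SDPA.** This is not a failed install but a **deliberate trade-off** of how the portable package runs.
#### Portable package Flash Attention 2 support

| Point | Explanation |
|----|------|
| **flash-attn depends on triton** | A prebuilt `flash-attn` wheel installs into the environment, but at runtime many operators are still invoked through `flash_attn.ops.triton`, which runs CUDA kernels generated by **Triton**. |
| **Embedded Python runs triton poorly** | The portable package uses Python Embeddable (`python_embeded\`), which lacks a full compiler toolchain; `triton` / `triton-windows` often fails on the first JIT compile, crashing startup or training. |
| **Cannot remove only triton and keep flash-attn** | If only `flash-attn` is installed without `triton`, import fails with `No module named 'triton'`; libraries such as `transformers` may still try the flash path when they detect an installed `flash_attn`. |
| **Actual portable-package strategy** | First-time setup skips flash-attn; if the user manually ran `pip install` for an incomplete combination, startup automatically uninstalls flash-attn / triton and sets `TRANSFORMERS_ATTN_IMPLEMENTATION=sdpa`. |
| **Embedded Python + triton** | The portable package uses Python Embeddable (`python_embeded\`), so it pins `triton-windows<3.4` and a Flash Attention 2 wheel matching PyTorch 2.7/CUDA 12.8. |
| **Self-check before enabling** | Both first-time setup and `install_flash_attn.bat` verify `import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary`. |
| **Fallback on failure** | If the self-check fails, startup automatically uninstalls the incomplete flash-attn / triton combination and training falls back to **xformers** or **PyTorch SDPA**. |

**When you need Flash Attention 2:** use "[Install from source](#从源码安装)" below and configure it following **[Flash Attention 2 (source / venv users)](#flash-attention-2源码--venv-用户)**; the portable package **makes no promise** of flash-attn acceleration until embedded Python support matures.
After the self-check passes, Anima / SD3 LoRA automatically selects `attn_mode=flash`; on failure, the log explains why and training continues.

### Install from source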

@@ -108,7 +106,7 @@ python gui.py --browser edge

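#### Flash Attention 2 (source / venv users)

**Portable package users: see the previous section; do not manually install flash-attn into `python_embeded`.**
Portable package users get the same pinned combination through first launch or `install_flash_attn.bat`; source users can also install it manually with the commands below.

This section applies to source users who, after `git clone`, use a **`venv`** (or a full Python under the `python\` directory) and have **PyTorch 2.7.0 + CUDA 12.8** installed.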

@@ -191,15 +189,15 @@ python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary i
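| Wheel installs but training still uses xformers | Run the verify command above; if it fails, triton and flash-attn are not paired, so do not keep flash-attn alone |
| `pip install flash-attn` compiles for a long time or fails | On Windows use the **prebuilt wheel** (URL in the table above) instead of compiling locally |
| PyTorch version is not 2.7+cu128 | The wheel does not match the CUDA tag; align with the torch version in `install.ps1` before installing flash-attn |
| Installing into the portable package's `python_embeded` | **Unsupported**; use source + venv instead |
| Portable package self-check fails | Rerun `install_flash_attn.bat` after updating GPU drivers or dependencies; broken combinations are cleaned up automatically at startup |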

---

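## Features

- **Multi-model support**: SD 1.5 / SDXL / Flux / **Anima** all work out of the box
- **Anima LoRA training**: one-click entry from the sidebar; supports LoRA / LoKr (LyCORIS) / **T-LoRA**
- **Attention acceleration**: the backend is chosen automatically. Source/venv environments prefer Flash Attention 2 (Windows prebuilt wheel); the **portable package** uses xformers / PyTorch SDPA ([flash-attn not supported for now](#整合包暂不支持-flash-attention-2说明))
- **Attention acceleration**: the backend is chosen automatically. Source/venv, the portable package, and Docker all prefer Flash Attention 2 when the self-check passes and fall back to xformers / PyTorch SDPA on failure
- **T-LoRA**: dynamic-rank LoRA based on diffusion timesteps, with orthogonal initialization, to prevent overfitting ([paper](https://github.com/ControlGenAI/T-LoRA))
- **Training monitor page**: starts automatically with the GUI; shows Loss / LR curves from the same source as TensorBoard, a quick reference of key training parameters, real-time progress, synced terminal logs, and preview images
- **Built-in TensorBoard**: view it directly from the sidebar, no extra steps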
18 changes: 8 additions & 10 deletions README.md
@@ -56,18 +56,16 @@ Measured on RTX 4090, batch=1, bf16, standard LoRA (dim=16):

> 512 resolution saves roughly 2–3 GB; lowering `network_dim` (e.g. to 8) also helps marginally.

#### Portable package: Flash Attention 2 not supported (for now)

The **Windows portable package** (`SD-Trainer-v*.7z`) **does not install Flash Attention 2**; training uses **xformers** or **PyTorch SDPA**. This is intentional, not a failed install.
#### Portable package: Flash Attention 2

| Point | Why |
|-------|-----|
| **flash-attn needs triton** | Prebuilt `flash-attn` wheels install, but many kernels still run via **Triton** (`flash_attn.ops.triton`). |
| **Embedded Python + triton** | The portable bundle uses Python Embeddable (`python_embeded\`) without a full toolchain; `triton` / `triton-windows` often fail at JIT compile time. |
| **Cannot keep flash-attn without triton** | Flash-attn-only installs hit `No module named 'triton'`; `transformers` may still probe `flash_attn` if the package is present. |
| **What we do** | Skip flash-attn on first setup; on launch, remove broken flash-attn/triton pairs and set `TRANSFORMERS_ATTN_IMPLEMENTATION=sdpa`. |
| **Embedded Python + triton** | The portable bundle uses Python Embeddable (`python_embeded\`), so the package pins `triton-windows<3.4` and the matching Flash Attention 2 wheel. |
| **Self-check first** | First setup and `install_flash_attn.bat` verify `import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary`. |
| **Fallback** | If the self-check fails, startup removes the broken flash-attn/triton pair and training falls back to **xformers** or **PyTorch SDPA**. |

For Flash Attention 2, use **[install from source](#install-from-source)** and follow **[Flash Attention 2 (source / venv)](#flash-attention-2-source--venv-installs)**. Portable flash-attn support may come later when embed Python + triton is reliable.
When the self-check passes, Anima / SD3 LoRA auto-selects `attn_mode=flash`; otherwise logs explain the fallback.
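
The self-check and fallback order described in the table above can be sketched in Python. This is only an illustration, not the project's actual selection code: `pick_attn_mode`, `importable`, and the `"sdpa"` label are hypothetical names, and the sketch only checks that the three modules import (the real one-liner additionally imports the `apply_rotary` name).

```python
import importlib
from typing import Callable, Iterable

# Modules imported by the self-check one-liner, in the same order.
FLASH_CHECK = ("triton", "flash_attn", "flash_attn.ops.triton.rotary")


def importable(names: Iterable[str],
               importer: Callable[[str], object] = importlib.import_module) -> bool:
    """Return True only if every module in `names` imports without error."""
    for name in names:
        try:
            importer(name)
        except Exception:  # a broken native extension can raise more than ImportError
            return False
    return True


def pick_attn_mode(importer: Callable[[str], object] = importlib.import_module) -> str:
    """Hypothetical helper: prefer flash, then xformers, then PyTorch SDPA."""
    if importable(FLASH_CHECK, importer):
        return "flash"
    if importable(("xformers",), importer):
        return "xformers"
    return "sdpa"  # assumed label for the PyTorch SDPA fallback


if __name__ == "__main__":
    print(pick_attn_mode())
```

Passing the importer as a parameter keeps the decision logic testable on machines without a GPU stack installed.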

### Install from Source

@@ -96,7 +94,7 @@ python gui.py --browser edge

#### Flash Attention 2 (source / venv installs)

**Portable users:** see the section above — do not `pip install flash-attn` into `python_embeded`.
Portable users get the same pinned stack through first-run setup or `install_flash_attn.bat`; source users can also install it manually.

This section is for **`git clone` + `venv`** (or a full Python under `python\`), with **PyTorch 2.7.0 + CUDA 12.8** installed.

@@ -177,15 +175,15 @@ Then run `python gui.py` and start **Anima LoRA** training — logs should show
| Wheel installs but training uses xformers | Run the verify command above; flash-attn without working triton is ignored |
| Long compile or build errors | On Windows use the **prebuilt wheel** URLs, not `pip install flash-attn` from source |
| PyTorch not 2.7+cu128 | Align torch with `install.ps1` before installing flash-attn |
| Installed into portable `python_embeded` | **Unsupported** — use source + venv instead |
| Portable self-check fails | Rerun `install_flash_attn.bat` after updating GPU drivers/dependencies; broken stacks are removed automatically |

---

## Features

- **Multi-model** — SD 1.5 / SDXL / Flux / **Anima** all work out of the box
- **Anima LoRA training** — One-click sidebar entry, supports LoRA / LoKr (LyCORIS) / **T-LoRA**
- **Attention backends** — Source/venv: Flash Attention 2 when available (Windows prebuilt wheels). **Portable package:** xformers / PyTorch SDPA only ([flash-attn not supported yet](#portable-package-flash-attention-2-not-supported-for-now))
- **Attention backends** — Source/venv, portable package, and Docker all prefer Flash Attention 2 when the stack self-checks OK, then fall back to xformers / PyTorch SDPA
- **T-LoRA** — Timestep-Dependent LoRA with dynamic rank and orthogonal init ([paper](https://github.com/ControlGenAI/T-LoRA))
- **Train Monitor** — Auto-opens with GUI, TensorBoard-backed Loss / LR scalar cards, key training parameter checks, real-time progress, terminal log echo, and preview samples
- **Built-in TensorBoard** — Accessible from the sidebar, no extra setup
30 changes: 28 additions & 2 deletions build-scripts/build_portable.ps1
@@ -355,6 +355,9 @@ $updateDepsBat = "@echo off`r`nchcp 65001 >nul 2>&1`r`ncd /d `"%~dp0..`"`r`n"
$updateDepsBat += "echo Updating Python dependencies...`r`n"
$updateDepsBat += "`"python_embeded\python.exe`" -s -m pip install --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cu128`r`n"
$updateDepsBat += "`"python_embeded\python.exe`" -s -m pip install --upgrade -r `"SD-Trainer\requirements.txt`"`r`n"
$updateDepsBat += "`"python_embeded\python.exe`" -s -m pip install `"triton-windows<3.4`" --no-warn-script-location`r`n"
$updateDepsBat += "`"python_embeded\python.exe`" -s -m pip install https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-win_amd64.whl --no-warn-script-location`r`n"
$updateDepsBat += "`"python_embeded\python.exe`" -s -c `"import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary; print('Flash Attention 2 OK')`"`r`n"
$updateDepsBat += "echo Done.`r`npause`r`n"
[System.IO.File]::WriteAllText(
(Join-Path $updateDir "update_dependencies.bat"),
@@ -382,6 +385,28 @@ $xformersBat += "echo.`r`necho Done! You can now use attn_mode = xformers.`r`ne
)
Write-Host " Created install_xformers.bat"

# install_flash_attn.bat — one-click Flash Attention 2 installer for portable users
$flashAttnBat = "@echo off`r`nchcp 65001 >nul 2>&1`r`ntitle Install Flash Attention 2`r`ncd /d `"%~dp0`"`r`n"
$flashAttnBat += "set `"PYTHON_EXE=%~dp0python_embeded\python.exe`"`r`n"
$flashAttnBat += "if not exist `"%PYTHON_EXE%`" (`r`n"
$flashAttnBat += " echo [ERROR] python_embeded\python.exe not found!`r`n"
$flashAttnBat += " pause`r`n exit /b 1`r`n)`r`n"
$flashAttnBat += "echo.`r`necho Installing Flash Attention 2 for Torch 2.7.0 + CUDA 12.8 ...`r`necho.`r`n"
$flashAttnBat += "`"%PYTHON_EXE%`" -s -m pip install `"triton-windows<3.4`" --no-warn-script-location`r`n"
$flashAttnBat += "if errorlevel 1 (`r`n echo [ERROR] triton-windows installation failed.`r`n pause`r`n exit /b 1`r`n)`r`n"
$flashAttnBat += "`"%PYTHON_EXE%`" -s -m pip install https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-win_amd64.whl --no-warn-script-location`r`n"
$flashAttnBat += "if errorlevel 1 (`r`n echo [ERROR] flash-attn wheel installation failed.`r`n pause`r`n exit /b 1`r`n)`r`n"
$flashAttnBat += "echo.`r`necho Verifying...`r`n"
$flashAttnBat += "`"%PYTHON_EXE%`" -s -c `"import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary; print(' Flash Attention 2 OK')`"`r`n"
$flashAttnBat += "if errorlevel 1 (`r`n echo [ERROR] Flash Attention 2 self-check failed. Training will fall back to xformers or PyTorch SDPA.`r`n `"%PYTHON_EXE%`" -s -m pip uninstall flash-attn flash_attn triton-windows triton -y >nul 2>&1`r`n pause`r`n exit /b 1`r`n)`r`n"
$flashAttnBat += "echo.`r`necho Done! Anima training can now use attn_mode = flash.`r`necho.`r`npause`r`n"
[System.IO.File]::WriteAllText(
(Join-Path $portableDir "install_flash_attn.bat"),
$flashAttnBat,
(New-Object System.Text.UTF8Encoding $false)
)
Write-Host " Created install_flash_attn.bat"

# Root-level utility bat files
$templateDir = Join-Path $PSScriptRoot "templates"
foreach ($bat in @("Update-SD-Trainer.bat", "Download-Anima-Model.bat")) {
@@ -428,8 +453,9 @@ $readme += "xformers (recommended):`r`n"
$readme += " If xformers is missing, double-click install_xformers.bat to install.`r`n"
$readme += " xformers provides faster attention than PyTorch SDPA on most GPUs.`r`n`r`n"
$readme += "Flash Attention 2:`r`n"
$readme += " This portable package does NOT use flash-attn (uses xformers / PyTorch SDPA).`r`n"
$readme += " Do not pip install flash-attn into python_embeded. See README in SD-Trainer/.`r`n"
$readme += " First launch installs the pinned flash-attn + triton-windows stack when available.`r`n"
$readme += " If the self-check fails, startup removes the broken stack and falls back to xformers / PyTorch SDPA.`r`n"
$readme += " You can rerun install_flash_attn.bat after updating GPU drivers or dependencies.`r`n"
[System.IO.File]::WriteAllText(
(Join-Path $portableDir "README.txt"),
$readme,
8 changes: 5 additions & 3 deletions install-cn.ps1
@@ -59,19 +59,21 @@ python -m pip install --upgrade -r requirements.txt
Check "训练依赖库安装失败。"


Write-Output "Installing Flash Attention 2 (prebuilt wheel)..."
Write-Output "Installing Flash Attention 2 stack (triton-windows + prebuilt wheel)..."
$pyver = python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')" 2>$null
python -m pip install "triton-windows<3.4"
if ($pyver -match "^cp3(10|11|12)$") {
$whl = "flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-$pyver-$pyver-win_amd64.whl"
$url = "https://hf-mirror.com/lldacing/flash-attention-windows-wheel/resolve/main/$whl"
python -m pip install $url 2>$null
} else {
python -m pip install flash-attn --no-build-isolation 2>$null
}
python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary" 2>$null
if ($LASTEXITCODE -eq 0) {
Write-Output "Flash Attention 2 installed"
Write-Output "Flash Attention 2 installed and verified"
} else {
Write-Output "Flash Attention 2 install failed (non-fatal)"
Write-Output "Flash Attention 2 install/self-check failed (non-fatal)"
}

Write-Output "安装完成"
14 changes: 11 additions & 3 deletions install.bash
@@ -73,8 +73,16 @@ cd "$script_dir" || exit
pip install --upgrade -r requirements.txt

echo "Installing Flash Attention 2 (optional, for training acceleration)..."
pip install flash-attn --no-build-isolation 2>/dev/null && \
echo "Flash Attention 2 installed successfully" || \
echo "Flash Attention 2 install failed (non-fatal, will use PyTorch SDPA)"
export MAX_JOBS="${MAX_JOBS:-4}"
if pip install "flash-attn==2.7.4.post1" --no-build-isolation; then
if python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary" 2>/dev/null; then
echo "Flash Attention 2 installed and verified successfully"
else
echo "Flash Attention 2 self-check failed (non-fatal, will use xformers or PyTorch SDPA)"
pip uninstall flash-attn flash_attn triton -y >/dev/null 2>&1 || true
fi
else
echo "Flash Attention 2 install failed (non-fatal, will use xformers or PyTorch SDPA)"
fi

echo "Install completed"
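
The install, verify, and clean-up flow in the new `install.bash` hunk can be restated as a small Python sketch for readers who want to trace the three outcomes. It mirrors the commands shown above; the function name `install_flash_attn`, the injectable `runner`, and the returned status strings are hypothetical and not part of the repository.

```python
import os
import subprocess
import sys
from typing import Callable, List

Runner = Callable[[List[str]], int]

SELF_CHECK = ("import triton; import flash_attn; "
              "from flash_attn.ops.triton.rotary import apply_rotary")


def run(cmd: List[str]) -> int:
    """Run a command with MAX_JOBS defaulting to 4, like ${MAX_JOBS:-4}."""
    env = dict(os.environ)
    if not env.get("MAX_JOBS"):
        env["MAX_JOBS"] = "4"
    return subprocess.run(cmd, env=env).returncode


def install_flash_attn(runner: Runner = run, python: str = sys.executable) -> str:
    """Install the pinned flash-attn, verify it, and uninstall on a failed check."""
    pip = [python, "-m", "pip"]
    if runner(pip + ["install", "flash-attn==2.7.4.post1", "--no-build-isolation"]) != 0:
        return "install-failed"      # non-fatal: training uses xformers or SDPA
    if runner([python, "-c", SELF_CHECK]) == 0:
        return "verified"
    # Self-check failed: remove the broken stack so it is not picked up later.
    runner(pip + ["uninstall", "flash-attn", "flash_attn", "triton", "-y"])
    return "self-check-failed"
```

Calling `install_flash_attn()` with the defaults would actually invoke pip, so the sketch takes the runner as a parameter to allow a dry run with a stub.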
8 changes: 5 additions & 3 deletions install.ps1
@@ -22,19 +22,21 @@ pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --extra-index-url https
pip install -U -I --no-deps xformers==0.0.30 --extra-index-url https://download.pytorch.org/whl/cu128
pip install --upgrade -r requirements.txt

Write-Output "Installing Flash Attention 2 (prebuilt wheel)..."
Write-Output "Installing Flash Attention 2 stack (triton-windows + prebuilt wheel)..."
$pyver = python -c "import sys; print(f'cp{sys.version_info.major}{sys.version_info.minor}')" 2>$null
pip install "triton-windows<3.4"
if ($pyver -match "^cp3(10|11|12)$") {
$whl = "flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-$pyver-$pyver-win_amd64.whl"
$url = "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/$whl"
pip install $url 2>$null
} else {
pip install flash-attn --no-build-isolation 2>$null
}
python -c "import triton; import flash_attn; from flash_attn.ops.triton.rotary import apply_rotary" 2>$null
if ($LASTEXITCODE -eq 0) {
Write-Output "Flash Attention 2 installed successfully"
Write-Output "Flash Attention 2 installed and verified successfully"
} else {
Write-Output "Flash Attention 2 install failed (non-fatal, will use PyTorch SDPA instead)"
Write-Output "Flash Attention 2 install/self-check failed (non-fatal, will use xformers or PyTorch SDPA instead)"
}

Write-Output "Install completed"
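
The wheel selection in the PowerShell hunk above (a prebuilt wheel for cp310, cp311, and cp312; a source build otherwise) can be expressed as a short Python sketch. The file name pattern and base URL are copied from `install.ps1`; the helper names `python_tag` and `flash_attn_wheel_url` are hypothetical.

```python
import sys
from typing import Optional, Tuple

WHEEL_BASE = "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main"
SUPPORTED_TAGS = {"cp310", "cp311", "cp312"}  # matches the regex ^cp3(10|11|12)$


def python_tag(version: Tuple[int, int] = sys.version_info[:2]) -> str:
    """Build the CPython tag the script derives, e.g. (3, 10) -> 'cp310'."""
    return f"cp{version[0]}{version[1]}"


def flash_attn_wheel_url(tag: str) -> Optional[str]:
    """Return the prebuilt wheel URL, or None when the script would build from source."""
    if tag not in SUPPORTED_TAGS:
        return None
    whl = f"flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-{tag}-{tag}-win_amd64.whl"
    return f"{WHEEL_BASE}/{whl}"


if __name__ == "__main__":
    print(flash_attn_wheel_url(python_tag()))
```

A `None` result corresponds to the `else` branch, where the script runs `pip install flash-attn --no-build-isolation` instead.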