
MAXNORM8650/SafeDiffusion-R1


SafeDiffusion-R1: Online Reward Steering for Safe Diffusion Post-Training

GRPO-based safety post-training for Stable Diffusion using a closed-form, CLIP-based steering reward. No separately trained safety classifier, no paired safe/unsafe image dataset, no inference-time intervention — the safety prior is baked into the UNet weights.


TL;DR — what the method does

For an NSFW prompt $p$, vanilla SD produces an image $x$ aligned to the unsafe text embedding $z_p$. We instead reward the model against a steered target $z_p + \alpha \cdot v_{\text{safe}}$, where $v_{\text{safe}}$ is a single direction in CLIP-text space computed once from a small set of (safe, unsafe) anchor phrases:

$$ v_{\text{safe}} \;=\; \overline{z_{\text{safe anchors}}} - \overline{z_{\text{unsafe anchors}}}, \qquad r(x, p) \;=\; \cos\!\big(z_{\text{img}}(x),\; z_p + \alpha\, v_{\text{safe}}\big). $$
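Both formulas fit in a few lines of NumPy. The sketch below assumes the anchor phrases, the prompt, and the generated image have already been encoded by a frozen CLIP model. The random vectors are hypothetical stand-ins for those embeddings, and the anchor counts (25 safe, 20 unsafe) are borrowed from the `scaled` release for illustration only. This is not the code in `rewards/inference_reward.py`.

```python
import numpy as np

def safety_direction(z_safe_anchors: np.ndarray, z_unsafe_anchors: np.ndarray) -> np.ndarray:
    """v_safe = mean(safe-anchor embeddings) - mean(unsafe-anchor embeddings).

    Both inputs have shape (num_anchors, dim); computed once, then frozen.
    """
    return z_safe_anchors.mean(axis=0) - z_unsafe_anchors.mean(axis=0)

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def steering_reward(z_img: np.ndarray, z_prompt: np.ndarray, v_safe: np.ndarray, alpha: float = 0.5) -> float:
    """r(x, p) = cos(z_img(x), z_p + alpha * v_safe)."""
    return cosine(z_img, z_prompt + alpha * v_safe)

# Hypothetical toy embeddings (dim 8) standing in for CLIP features.
rng = np.random.default_rng(0)
dim = 8
v_safe = safety_direction(rng.normal(size=(25, dim)), rng.normal(size=(20, dim)))
z_p = rng.normal(size=dim)      # prompt embedding
z_img = rng.normal(size=dim)    # image embedding of one on-policy sample
print(steering_reward(z_img, z_p, v_safe, alpha=0.5))
```

With `alpha=0` the reward reduces to plain image/text alignment. An image whose embedding lies exactly on the steered target scores 1.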

GRPO post-training then nudges the UNet to satisfy this steered reward. Because $v_{\text{safe}}$ is computed from a frozen CLIP encoder, the target is stationary. The on-policy samples drift, but the anchor they are regressed onto does not. This is what prevents the FID collapse that plain GRPO suffers on the same safety objective (FID 250 at 0% nudity, versus FID 48 for ours at comparable safety).
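GRPO does not use a learned value function. For each prompt it samples a group of images (the group size appears later as `--config.num_generations`) and scores each image relative to its siblings. The sketch below shows the standard group-normalised advantage, assuming the usual mean/std normalisation with a small epsilon. The exact normalisation in `fastvideo/train.py` may differ.

```python
import numpy as np

def group_advantages(rewards: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Group-relative advantages as used in GRPO.

    rewards: shape (num_prompts, num_generations), one row per prompt group.
    Each reward is centred and scaled by its own group's mean and std, so
    samples that beat their siblings get a positive advantage.
    """
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)
    return (rewards - mean) / (std + eps)

# Two prompts, four samples each (toy steering-reward values).
rewards = np.array([[0.21, 0.25, 0.19, 0.27],
                    [0.30, 0.30, 0.31, 0.29]])
adv = group_advantages(rewards)
print(adv.round(2))
```

The per-group normalisation makes only the ranking inside each group matter. A group where every sample scores the same contributes no gradient.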


Method overview: a frozen CLIP encoder defines a safety direction v_safe from a small anchor set; GRPO uses the steered embedding as the reward target during diffusion post-training.

Headline results (vs. SD-v1.4 baseline)

| Benchmark | Baseline (SD-v1.4) | Ours | Δ |
|---|---|---|---|
| I2P inappropriate-content rate | 48.9% | 18.07% | −63% |
| NudeNet detections (I2P, 4.7k prompts) | 646 | 15 | −97.7% |
| GenEval compositional accuracy | 42.08% | 47.83% | +5.75 pp |
| MMA-Diffusion (1000-prompt benchmark, ASR↓) | 22.6% | 2.6% | 8.7× safer |
| SneakyPrompt RL (skip-rate↑, 200 prompts) | 37% | 89.5% | model resists most prompts before any attack |

The safety gains generalize to 7 out-of-distribution (OOD) harm categories (hate, harassment, violence, self-harm, shocking, illegal-activity, sexual), even though training only sees benign prompts and nudity-style negatives.


Qualitative comparison: explicit prompts at inference time produce benign content from our model while vanilla SD generates NSFW imagery.

How the steering direction works geometrically

A held-out NSFW prompt sits in the unsafe sub-region of CLIP-text space. Adding $\alpha\, v_{\text{safe}}$ translates it across the safe/unsafe boundary while preserving the prompt-conditional content. The reward then anchors the on-policy samples to this translated target.
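A short decomposition makes the translation explicit. Write $\hat v = v_{\text{safe}} / \lVert v_{\text{safe}} \rVert$ and split $z_p$ into its component along $\hat v$ plus an orthogonal remainder $z_p^{\perp}$. The last line models the safe/unsafe boundary as a hyperplane $\hat v^{\top} z = b$. That is a simplifying assumption for illustration, not something the method fits.

$$
\begin{aligned}
z_p &= (\hat v^{\top} z_p)\,\hat v + z_p^{\perp}, \qquad \hat v^{\top} z_p^{\perp} = 0,\\
z_p + \alpha\, v_{\text{safe}} &= \big(\hat v^{\top} z_p + \alpha \lVert v_{\text{safe}} \rVert\big)\,\hat v + z_p^{\perp},\\
\hat v^{\top}\big(z_p + \alpha\, v_{\text{safe}}\big) > b \;&\iff\; \alpha > \frac{b - \hat v^{\top} z_p}{\lVert v_{\text{safe}} \rVert}.
\end{aligned}
$$

Only the coordinate along $v_{\text{safe}}$ moves; the orthogonal part $z_p^{\perp}$ is left unchanged. In this simplified picture, that is why the rest of the prompt's content survives the steering.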

Steering strength $\alpha$ is smooth and not knife-edge

We sweep $\alpha \in [0, 2]$ on 180 NSFW + 200 GenEval prompts and record the reward $r(\alpha) = \cos(z_{\text{img}}, z_{\text{text}} + \alpha\, v_{\text{safe}})$:

  • SAFE prompts stay essentially flat for $\alpha \le 0.5$ — utility is preserved.
  • NSFW prompts drop monotonically with diminishing returns past $\alpha \approx 0.7$.
  • Any $\alpha \in [0.3, 0.7]$ gives comparable safety/utility — the reward is not knife-edge sensitive.
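A sweep of this kind can be sketched in NumPy. The sketch assumes precomputed image and text embeddings for each prompt. The synthetic data below (images roughly aligned to their own prompts, a random $v_{\text{safe}}$) is hypothetical. It only shows the mechanics of the measurement and will not reproduce the SAFE/NSFW curves above.

```python
import numpy as np

def cos_rows(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Row-wise cosine similarity between (n, d) arrays a and b."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return (a * b).sum(axis=1)

def alpha_sweep(z_img: np.ndarray, z_text: np.ndarray, v_safe: np.ndarray, alphas) -> np.ndarray:
    """Mean reward r(alpha) = cos(z_img, z_text + alpha * v_safe) for each alpha.

    z_img, z_text: (n_prompts, d) embeddings; v_safe: (d,), broadcast over rows.
    """
    return np.array([cos_rows(z_img, z_text + a * v_safe).mean() for a in alphas])

# Hypothetical synthetic embeddings.
rng = np.random.default_rng(0)
d, n = 16, 32
v_safe = rng.normal(size=d)
z_text = rng.normal(size=(n, d))
z_img = z_text + 0.1 * rng.normal(size=(n, d))  # images close to their prompts
alphas = np.linspace(0.0, 2.0, 9)
r = alpha_sweep(z_img, z_text, v_safe, alphas)
for a, val in zip(alphas, r):
    print(f"alpha={a:.2f}  r={val:.3f}")
```

Here the images match their unsteered prompts, which is the NSFW case. The reward therefore falls as $\alpha$ grows, and that falling reward is the pressure GRPO turns into a weight update.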

Repository layout

SafeDiffusion-R1/
├── pyproject.toml                          # Minimal dependencies
├── assets/CoProv2_captions.txt             # Default training prompt corpus
├── config/base.py                          # Training config (ml_collections)
├── fastvideo/
│   ├── train.py                            # Main GRPO training script
│   └── models/stable_diffusion/            # DDIM step + pipeline with logprob
├── rewards/
│   ├── inference_reward.py                 # NSFWv2 steering reward (CLIP + v_safe)
│   └── safety_classifier.py                # Builds the linear safety direction
├── vendor/HPSv2/                           # Vendored HPSv2 sources (no separate clone)
├── evaluation/
│   ├── execs/                              # Eval entry-point scripts (see table)
│   └── utils/                              # Helper modules + NudeNet ONNX (in-repo)
├── figures/                                # Paper / docs figures used in this README
└── scripts/
    ├── run_train.sh                        # Canonical training launch (torchrun)
    └── run_eval.sh                         # One-shot eval pipeline for any SD ckpt

Setup

# 1. Install (editable).
pip install -e .

# 2. Drop the HPSv2 v2.1 weights somewhere (≈5.6 GB total):
mkdir -p hps_ckpt
# Download into hps_ckpt/:
#   open_clip_pytorch_model.bin
#   HPS_v2.1_compressed.pt
export HPS_CKPT_PATH=$(pwd)/hps_ckpt

The HPSv2 source code is vendored at vendor/HPSv2/ — no separate clone needed. (Override with export HPSV2_PATH=/path/to/your/HPSv2 if desired.)
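A quick pre-flight check can catch a missing download before a multi-GPU job starts. The file names below come from the setup steps above. The helper itself is a hypothetical convenience and is not part of the repository.

```python
import os
from pathlib import Path

# File names taken from the setup steps above.
REQUIRED_HPS_FILES = ("open_clip_pytorch_model.bin", "HPS_v2.1_compressed.pt")

def missing_hps_files(root: str) -> list:
    """Return the required HPSv2 weight files that are not present under `root`."""
    return [name for name in REQUIRED_HPS_FILES if not (Path(root) / name).is_file()]

if __name__ == "__main__":
    root = os.environ.get("HPS_CKPT_PATH")
    if not root:
        print("HPS_CKPT_PATH is not set; export it before training.")
    else:
        missing = missing_hps_files(root)
        print("HPSv2 weights found" if not missing else f"missing under {root}: {missing}")
```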

Train

First, download the prompts:

from datasets import load_dataset

# NSFW negative-anchor prompts (used during steering-reward GRPO)
ds_neg = load_dataset("ItsMaxNorm/SafeDiffusion-R1-dataset", "prompts_nsfw_extended", split="train")

The canonical launch (NSFWv2 steering reward, edit GPU count for your machine):

bash scripts/run_train.sh

Underneath, this runs:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port 19001 \
    fastvideo/train.py \
    --config config/base.py \
    --config.reward_fn nsfwv2 \
    --config.num_generations 16 \
    --config.sample.batch_size 4 \
    --config.train.batch_size 4 \
    --config.train.steering_alpha 0.5

Override at invocation:

CUDA_VISIBLE_DEVICES=0,2 NPROC=2 bash scripts/run_train.sh        # 2 GPUs
bash scripts/run_train.sh --config.train.steering_alpha 0.7       # tune α
bash scripts/run_train.sh --config.num_generations 8              # group size

Any field in config/base.py is overridable with --config.<dotted.path>.
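Each `--config.a.b value` flag walks the nested config fields and replaces exactly one leaf. The plain-Python sketch below imitates that behaviour on nested dicts to show what an override touches. It is not `ml_collections`' own implementation. The field values are copied from the `torchrun` command above, not read from `config/base.py`. Casting the CLI string to the existing field's type is an assumption about how the real flag parser behaves.

```python
import copy

# Hypothetical subset of config/base.py, values taken from the torchrun command above.
BASE = {
    "reward_fn": "nsfwv2",
    "num_generations": 16,
    "sample": {"batch_size": 4},
    "train": {"batch_size": 4, "steering_alpha": 0.5},
}

def apply_override(config: dict, dotted_path: str, raw_value: str) -> dict:
    """Return a copy of `config` with the leaf at `dotted_path` replaced.

    The raw CLI string is cast to the type of the existing value, so "0.7"
    becomes a float and "8" an int. Unknown keys raise KeyError.
    """
    new = copy.deepcopy(config)
    *parents, leaf = dotted_path.split(".")
    node = new
    for key in parents:
        node = node[key]
    old = node[leaf]  # KeyError for fields that do not exist
    if isinstance(old, bool):
        node[leaf] = raw_value.lower() in ("1", "true", "yes")
    else:
        node[leaf] = type(old)(raw_value)
    return new

cfg = apply_override(BASE, "train.steering_alpha", "0.7")
cfg = apply_override(cfg, "num_generations", "8")
print(cfg["train"]["steering_alpha"], cfg["num_generations"])  # 0.7 8
```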

Outputs

  • Checkpoints: <config.checkpoint_dir>/checkpoint_epoch_{N}/diffusion_pytorch_model.safetensors. Load it as the UNet of a fresh StableDiffusionPipeline to evaluate.
  • Per-epoch mean reward log: appended line-by-line to <config.reward_log_file>.
  • Sampled images during training: <config.sample_image_dir>/image-*.jpg.
  • wandb run: under project <config.wandb_project>.

Reward variants

| --config.reward_fn | Description | Extra deps |
|---|---|---|
| nsfwv2 | The paper's steering reward. Closed-form $v_{\text{safe}}$ direction in HPSv2 CLIP space + cosine alignment to the steered target. | HPS_CKPT_PATH |
| hpsv2 | Vanilla HPS-v2 alignment reward (no safety steering; used as the ablation baseline in our paper). | same as nsfwv2 (HPS_CKPT_PATH) |
| hpsv3 | HPS-v3 reward (no safety steering). | pip install hpsv3 |

Evaluate any Stable Diffusion model

scripts/run_eval.sh is a one-shot wrapper for the end-to-end safety / utility evaluation of any Stable-Diffusion-style checkpoint.

# 1) Your trained checkpoint (epoch 280)
bash scripts/run_eval.sh \
    --ckpt my_checkpoints/run/checkpoint_epoch_280 \
    --prompts data/i2p_benchmark.csv \
    --out runs/main_ours

# 2) Vanilla SD-1.4 baseline (no UNet swap)
bash scripts/run_eval.sh \
    --ckpt vanilla --base CompVis/stable-diffusion-v1-4 \
    --prompts data/i2p_benchmark.csv \
    --out runs/vanilla

# 3) Any HuggingFace SD model + COCO FID
bash scripts/run_eval.sh \
    --ckpt vanilla --base stabilityai/stable-diffusion-2-1-base \
    --prompts data/coco_30k_val.csv \
    --real  data/coco_5k/imgs \
    --out runs/sd21_coco

--ckpt accepts three forms:

| Form | Example | What it does |
|---|---|---|
| directory | my_checkpoints/run/checkpoint_epoch_280 | UNet2DConditionModel.from_pretrained(...); the natural output of train.py |
| .safetensors file | path/to/diffusion_pytorch_model.safetensors | loads as a state dict into the base SD UNet |
| vanilla (literal string) | vanilla | skips the UNet swap and uses the --base model as-is |

--base accepts a HuggingFace model id (e.g. runwayml/stable-diffusion-v1-5) or a local snapshot directory. Shorthands 1.4, 2.1 map to the official CompVis / Stability hubs.
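The sketch below shows one way the three `--ckpt` forms could map onto diffusers calls. It is not the code inside `scripts/run_eval.sh`, and the names `ckpt_form` and `load_pipeline` are hypothetical. The `.safetensors` branch assumes the file stores diffusers-format UNet keys, so a strict `load_state_dict` succeeds. Heavy imports are deferred so the classification helper works without diffusers installed.

```python
def ckpt_form(ckpt: str) -> str:
    """Classify a --ckpt argument into one of the three forms in the table above."""
    if ckpt == "vanilla":
        return "vanilla"
    if ckpt.endswith(".safetensors"):
        return "safetensors"
    return "directory"

def load_pipeline(ckpt: str, base: str = "CompVis/stable-diffusion-v1-4"):
    """Build a StableDiffusionPipeline whose UNet is swapped according to `ckpt`."""
    import torch
    from diffusers import StableDiffusionPipeline, UNet2DConditionModel
    from safetensors.torch import load_file

    form = ckpt_form(ckpt)
    if form == "directory":
        # Directory written by train.py (weights + UNet config).
        unet = UNet2DConditionModel.from_pretrained(ckpt)
        pipe = StableDiffusionPipeline.from_pretrained(base, unet=unet)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(base)
        if form == "safetensors":
            # Bare state dict loaded into the base model's UNet.
            pipe.unet.load_state_dict(load_file(ckpt))
    return pipe.to("cuda" if torch.cuda.is_available() else "cpu")
```

Example: `load_pipeline("my_checkpoints/run/checkpoint_epoch_280")("a photo of a cat").images[0]`.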

The wrapper produces:

<--out>/
└── <--concept>/        # default: nudity
    ├── imgs/                                 # one image per prompt
    ├── nudity_threshold_0.6.json             # per-image NudeNet detections
    └── nude_keys_count_threshold_0.6.json    # aggregate counts incl. `nude_images`
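The NudeNet rate can be read straight from these outputs. The sketch assumes `nude_keys_count_threshold_0.6.json` is a flat JSON object with an integer `nude_images` field. It also assumes `imgs/` holds exactly one file per prompt. Both are assumptions about the file layout, and the helper names are hypothetical.

```python
import json
from pathlib import Path

COUNTS_FILE = "nude_keys_count_threshold_0.6.json"

def rate_from_counts(counts: dict, n_images: int) -> float:
    """Fraction of images flagged by NudeNet, given the aggregate counts."""
    return counts["nude_images"] / n_images if n_images else 0.0

def nude_image_rate(concept_dir) -> float:
    """Read <--out>/<--concept>/ and return nude_images / number of generated images."""
    concept_dir = Path(concept_dir)
    counts = json.loads((concept_dir / COUNTS_FILE).read_text())
    n_images = sum(1 for p in (concept_dir / "imgs").iterdir() if p.is_file())
    return rate_from_counts(counts, n_images)

# Usage after an eval run, e.g.:
# print(f"{nude_image_rate('runs/main_ours/nudity'):.2%}")
```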

Evaluation flow

# 1. Generate images from a prompts CSV with your trained UNet
bash scripts/run_eval.sh \
    --ckpt my_checkpoints/run/checkpoint_epoch_280 \
    --prompts data/i2p_benchmark.csv \
    --out runs/main_ours

# 2. (Optional) CLIP-score on benign captions for utility
python evaluation/execs/clip_score.py \
    --folder runs/main_ours/nudity --prompts_path data/coco_30k_val.csv

# 3. (Optional) Q16 as a second-opinion safety detector
python evaluation/execs/Q16/eval.py --folder runs/main_ours/nudity/imgs

The NudeNet ONNX lives at evaluation/utils/metrics/nudenet/best_new.onnx (in repo). Q16 prompt embeddings live at evaluation/execs/Q16/data/.

Pretrained model release

The three main models from our paper are released as full Diffusers pipelines (drop-in StableDiffusionPipeline — not bare UNet checkpoints) at https://huggingface.co/ItsMaxNorm/SafeDiffusion-R1:

| Subfolder | Anchor set | Description | When to use |
|---|---|---|---|
| scaled | 25 safe + 20 unsafe | Main paper checkpoint (geneval_negative_steringreward_8gpus_scale, epoch 280). | Default (headline numbers); best balance of MMA + I2P + GenEval. |
| compact | 5 safe + 3 unsafe | steringreward_7gpus, epoch 300. | Lowest MMA-Diffusion ASR (2.6%); use when adversarial robustness is the priority. |
| empty-positive | 0 safe + 3 unsafe | Ablation: no safe anchors. | Reference for understanding the role of positive anchors. |

Inference in a few lines

from huggingface_hub import snapshot_download
from diffusers import StableDiffusionPipeline
import os, torch

# diffusers' `StableDiffusionPipeline.from_pretrained` doesn't natively
# accept `subfolder=` for the *full* pipeline (only for single
# components), so we snapshot just the variant we want and load it.
local_root = snapshot_download(
    "ItsMaxNorm/SafeDiffusion-R1",
    allow_patterns="scaled/*",           # or "compact/*" / "empty-positive/*"
)
pipe = StableDiffusionPipeline.from_pretrained(
    os.path.join(local_root, "scaled"),
    torch_dtype=torch.float16,
).to("cuda")
img = pipe("a photo of a cat sleeping on a couch").images[0]
img.save("out.png")

One-line smoke test

To verify your environment can pull and run a released variant end-to-end:

bash scripts/test_release.sh scaled       # or compact / empty-positive
# → PASS — generated 3 images via ItsMaxNorm/SafeDiffusion-R1 subfolder=scaled

Evaluate a released model with scripts/run_eval.sh

The same wrapper that evaluates your locally-trained checkpoints also works directly against the HF Hub release — pass --base as the HF repo id and --subfolder as the variant:

# Main paper checkpoint, I2P benchmark
bash scripts/run_eval.sh \
    --base ItsMaxNorm/SafeDiffusion-R1 --subfolder scaled \
    --ckpt vanilla \
    --prompts data/i2p_benchmark.csv \
    --out runs/scaled_i2p

# Compact variant, MMA-Diffusion adversarial prompts
bash scripts/run_eval.sh \
    --base ItsMaxNorm/SafeDiffusion-R1 --subfolder compact \
    --ckpt vanilla \
    --prompts data/mma_adv_prompts.csv \
    --out runs/compact_mma

# Compare against vanilla SD-1.4 on the same set (baseline row)
bash scripts/run_eval.sh \
    --base CompVis/stable-diffusion-v1-4 \
    --ckpt vanilla \
    --prompts data/i2p_benchmark.csv \
    --out runs/vanilla_i2p

Each variant (~2 GB) is downloaded once and cached locally by the HF Hub client on first load.
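To evaluate all three released variants on the same prompt set, the wrapper calls can be generated in a loop. This small sketch only builds the argument lists, using the flags from the examples above. The `runs/<variant>_i2p` output paths are hypothetical.

```python
import shlex

REPO = "ItsMaxNorm/SafeDiffusion-R1"
VARIANTS = ("scaled", "compact", "empty-positive")

def eval_command(variant: str, prompts: str, out: str) -> list:
    """Argument list for scripts/run_eval.sh against one released variant."""
    return ["bash", "scripts/run_eval.sh",
            "--base", REPO, "--subfolder", variant,
            "--ckpt", "vanilla",
            "--prompts", prompts,
            "--out", out]

for variant in VARIANTS:
    cmd = eval_command(variant, "data/i2p_benchmark.csv", f"runs/{variant}_i2p")
    print(shlex.join(cmd))  # or execute with subprocess.run(cmd, check=True)
```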

Reproducing the paper's main checkpoint from scratch

# 8-GPU run, NSFWv2 steering, α=0.5, group size 16, 300 epochs, save every 20
bash scripts/run_train.sh \
    --config.num_epochs 300 \
    --config.save_freq 20 \
    --config.train.steering_alpha 0.5 \
    --config.num_generations 16 \
    --config.sample.batch_size 4 \
    --config.train.batch_size 4
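As a sanity check on the schedule, the sketch below lists which checkpoint directories a 300-epoch run with `save_freq 20` should produce. It assumes `train.py` saves whenever the 1-indexed epoch is a multiple of `save_freq`. That assumption is not confirmed by the source. Under it, the schedule includes both epoch 280 (the `scaled` release) and epoch 300.

```python
def saved_epochs(num_epochs: int, save_freq: int) -> list:
    """Epochs that get a checkpoint, assuming a save whenever epoch % save_freq == 0
    with epochs counted from 1 (an assumption about train.py)."""
    return [e for e in range(1, num_epochs + 1) if e % save_freq == 0]

epochs = saved_epochs(num_epochs=300, save_freq=20)
ckpt_dirs = [f"checkpoint_epoch_{e}" for e in epochs]
print(len(epochs), ckpt_dirs[-2:])  # 15 ['checkpoint_epoch_280', 'checkpoint_epoch_300']
```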

Bare UNet checkpoints (one diffusion_pytorch_model.safetensors per training run, in the format produced by train.py) are also kept at ItsMaxNorm/diffusion-p for reproducibility of ablation studies (incl. the CoProv2_grpo GRPO-only baseline at CoProv2_grpo/checkpoint_epoch_300).

Acknowledgement

We learned and reused code from the following projects:

We thank the authors for their contributions to the community!

Citation

If you find this work useful, please cite:

@article{kumar2026safediffusion,
  title={SafeDiffusion-R1: Online Reward Steering for Safe Diffusion Post-Training},
  author={Kumar, Komal and Deria, Ankan and Basu, Abhishek and Shamshad, Fahad and Cholakkal, Hisham and Nandakumar, Karthik},
  journal={arXiv preprint arXiv:2605.18719},
  year={2026}
}

License

This project is released under the terms of the LICENSE file. The vendored HPSv2 source under vendor/HPSv2/ is redistributed under its original MIT license (see vendor/HPSv2/LICENSE).
