Skip to content
19 changes: 16 additions & 3 deletions hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
from ...text_encoder import TextEncoder
from ...modules import HYVideoDiffusionTransformer

try:
import torch_musa
except Exception:
torch_musa = None

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """"""
Expand Down Expand Up @@ -835,7 +840,15 @@ def __call__(
else:
batch_size = prompt_embeds.shape[0]

device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
if torch_musa is not None:
device_type = "musa"
else:
device_type = "cuda"

if dist.is_initialized():
device = torch.device(f"{device_type}:{dist.get_rank()}")
else:
device = self._execution_device

# 3. Encode input prompt
lora_scale = (
Expand Down Expand Up @@ -986,7 +999,7 @@ def __call__(

# predict the noise residual
with torch.autocast(
device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
device_type=device_type, dtype=target_dtype, enabled=autocast_enabled
):
noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
latent_model_input, # [2, 16, 33, 24, 42]
Expand Down Expand Up @@ -1069,7 +1082,7 @@ def __call__(
latents = latents / self.vae.config.scaling_factor

with torch.autocast(
device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
device_type=device_type, dtype=vae_dtype, enabled=vae_autocast_enabled
):
if enable_tiling:
self.vae.enable_tiling()
Expand Down
47 changes: 36 additions & 11 deletions hyvideo/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline

try:
import torch_musa
except Exception:
torch_musa = None

try:
import xfuser
from xfuser.core.distributed import (
Expand Down Expand Up @@ -80,7 +85,11 @@ def new_forward(
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

for block in transformer.double_blocks + transformer.single_blocks:
block.hybrid_seq_parallel_attn = xFuserLongContextAttention()
if torch_musa is not None:
from yunchang.kernels import AttnType
block.hybrid_seq_parallel_attn = xFuserLongContextAttention(attn_type=AttnType.TORCH)
else:
block.hybrid_seq_parallel_attn = xFuserLongContextAttention()

output = original_forward(
x,
Expand Down Expand Up @@ -130,13 +139,15 @@ def __init__(
self.use_cpu_offload = use_cpu_offload

self.args = args
self.device = (
device
if device is not None
else "cuda"
if torch.cuda.is_available()
else "cpu"
)
if device is not None:
self.device = device
else:
if torch.cuda.is_available():
self.device = "cuda"
elif torch_musa is not None:
self.device = "musa"
else:
self.device = "cpu"
self.logger = logger
self.parallel_args = parallel_args

Expand All @@ -161,7 +172,12 @@ def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
assert args.use_cpu_offload is False, \
"Cannot enable use_cpu_offload in the distributed environment."

dist.init_process_group("nccl")
if torch_musa is not None:
dist_ccl = "mccl"
else:
dist_ccl = "nccl"

dist.init_process_group(dist_ccl)

assert dist.get_world_size() == args.ring_degree * args.ulysses_degree, \
"number of GPUs should be equal to ring_degree * ulysses_degree."
Expand All @@ -173,10 +189,19 @@ def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
ring_degree=args.ring_degree,
ulysses_degree=args.ulysses_degree,
)
device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
if torch_musa is not None:
device = "musa"
else:
device = "cuda"
device = torch.device(f"{device}:{os.environ['LOCAL_RANK']}")
else:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch_musa is not None:
device = torch.device("musa")
else:
device = torch.device("cpu")

parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}

Expand Down
20 changes: 18 additions & 2 deletions hyvideo/modules/attenion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
flash_attn_varlen_func = None
_flash_attn_forward = None

try:
import torch_musa
except Exception:
torch_musa = None


MEMORY_LAYOUT = {
"flash": (
Expand Down Expand Up @@ -45,7 +50,7 @@ def get_cu_seqlens(text_mask, img_len):
text_len = text_mask.sum(dim=1)
max_len = text_mask.shape[1] + img_len

cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device)

for i in range(batch_size):
s = text_len[i] + img_len
Expand Down Expand Up @@ -93,6 +98,8 @@ def attention(
Returns:
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
"""
if torch_musa is not None and mode == "flash":
mode = "torch"
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
q = pre_attn_layout(q)
k = pre_attn_layout(k)
Expand Down Expand Up @@ -178,7 +185,16 @@ def parallel_attention(
joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
joint_strategy="rear",
)
if flash_attn.__version__ >= '2.7.0':
if torch_musa is not None:
attn2 = F.scaled_dot_product_attention(
q[:,cu_seqlens_q[1]:],
k[:,cu_seqlens_kv[1]:],
v[:,cu_seqlens_kv[1]:],
attn_mask=None,
dropout_p=0,
is_causal=False
)
elif flash_attn.__version__ >= '2.7.0':
attn2, *_ = _flash_attn_forward(
q[:,cu_seqlens_q[1]:],
k[:,cu_seqlens_kv[1]:],
Expand Down