mamba-selective-scan

A high-performance CUDA implementation of Selective Scan, extracted from the official Mamba repository.

Quick install

pip install git+https://github.com/biubushy/mamba-selective-scan.git

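Provenance

This package is an extraction only. All CUDA kernel code is taken unchanged from the following commit of the official Mamba repository:

No changes have been made to the original algorithm. The Selective Scan submodule has simply been separated from the full Mamba project so that it can be installed on its own as a lightweight dependency.

Original algorithm authors: Albert Gu, Tri Dao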

Requirements

  • Python >= 3.10
  • PyTorch >= 2.0
  • CUDA >= 11.6 (NVIDIA) or ROCm >= 6.0 (AMD)
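The requirements above can be checked before installing. The following is a minimal sketch, not part of this package: the helper names `parse_version`, `meets_minimum`, and `check_environment` are hypothetical. It inspects the CUDA or ROCm version that PyTorch was built against (`torch.version.cuda`, `torch.version.hip`), which is assumed here to be a reasonable proxy for the toolkit used to compile the extension.

```python
import importlib.util
import re
import sys


def parse_version(text):
    """Return the leading numeric components of a version string as ints."""
    match = re.match(r"(\d+(?:\.\d+)*)", text)
    if match is None:
        raise ValueError(f"no numeric version found in {text!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def meets_minimum(found, required):
    """True if version string `found` is at least `required`."""
    a, b = parse_version(found), parse_version(required)
    width = max(len(a), len(b))
    # Pad with zeros so that "2" and "2.0" compare as equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment():
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    if sys.version_info[:2] < (3, 10):
        problems.append("Python >= 3.10 is required")
    if importlib.util.find_spec("torch") is None:
        problems.append("PyTorch is not installed")
        return problems
    import torch
    if not meets_minimum(torch.__version__, "2.0"):
        problems.append(f"PyTorch >= 2.0 is required, found {torch.__version__}")
    # Both attributes are None when PyTorch was built without that backend.
    cuda = torch.version.cuda
    hip = getattr(torch.version, "hip", None)
    if cuda is not None:
        if not meets_minimum(cuda, "11.6"):
            problems.append(f"CUDA >= 11.6 is required, found {cuda}")
    elif hip is not None:
        if not meets_minimum(hip, "6.0"):
            problems.append(f"ROCm >= 6.0 is required, found {hip}")
    else:
        problems.append("PyTorch was built without CUDA or ROCm support")
    return problems


if __name__ == "__main__":
    for line in check_environment() or ["environment looks OK"]:
        print(line)
```

Version strings such as `2.1.0+cu118` are handled by reading only the leading numeric part.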

API

selective_scan_fn

CUDA-accelerated Selective Scan with forward and backward passes (the backward pass is available automatically through torch.autograd).

from mamba_selective_scan import selective_scan_fn

out = selective_scan_fn(
    u,              # (batch, dim, seqlen)
    delta,          # (batch, dim, seqlen)
    A,              # (dim, dstate)
    B,              # (dim, dstate) or (batch, ngroups, dstate, seqlen)
    C,              # (dim, dstate) or (batch, ngroups, dstate, seqlen)
    D=None,         # (dim,)
    z=None,         # (batch, dim, seqlen), gating
    delta_bias=None,        # (dim,)
    delta_softplus=False,
    return_last_state=False,
)
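Shape mismatches are a common source of errors when calling the kernel. The sketch below checks plain shape tuples against the layout listed in the call above. `check_scan_shapes` is a hypothetical helper, not part of this package, and the rule that `dim` must be divisible by `ngroups` is an assumption rather than something stated here.

```python
def check_scan_shapes(u, delta, A, B, C, D=None, z=None, delta_bias=None):
    """Validate shape tuples against the documented argument layout.

    Each argument is a tuple of ints, for example tuple(tensor.shape).
    Returns (batch, dim, seqlen, dstate) or raises ValueError.
    """
    u = tuple(u)
    if len(u) != 3:
        raise ValueError(f"u must be (batch, dim, seqlen), got {u}")
    batch, dim, seqlen = u
    if tuple(delta) != u:
        raise ValueError(f"delta must match u {u}, got {tuple(delta)}")
    A = tuple(A)
    if len(A) != 2 or A[0] != dim:
        raise ValueError(f"A must be (dim, dstate) with dim={dim}, got {A}")
    dstate = A[1]
    for name, shape in (("B", tuple(B)), ("C", tuple(C))):
        if len(shape) == 2:
            # Input-independent form: one matrix shared by all positions.
            ok = shape == (dim, dstate)
        elif len(shape) == 4:
            # Input-dependent form: one vector per batch, group and position.
            b, ngroups, n, length = shape
            ok = (b, n, length) == (batch, dstate, seqlen)
            # Assumption: groups split the channels evenly.
            ok = ok and ngroups >= 1 and dim % ngroups == 0
        else:
            ok = False
        if not ok:
            raise ValueError(f"{name} has unexpected shape {shape}")
    for name, shape in (("D", D), ("delta_bias", delta_bias)):
        if shape is not None and tuple(shape) != (dim,):
            raise ValueError(f"{name} must be ({dim},), got {tuple(shape)}")
    if z is not None and tuple(z) != u:
        raise ValueError(f"z must match u {u}, got {tuple(z)}")
    return batch, dim, seqlen, dstate
```

With PyTorch tensors, it could be called as `check_scan_shapes(u.shape, delta.shape, A.shape, B.shape, C.shape)` before invoking `selective_scan_fn`.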

selective_scan_ref

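A pure Python/PyTorch reference implementation that does not depend on the CUDA extension. It is useful for debugging and for verifying correctness. Its function signature matches that of selective_scan_fn.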

from mamba_selective_scan import selective_scan_ref
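To make the computation concrete, here is a minimal sketch of the selective scan recurrence for a single batch element and a single channel, written with plain Python lists instead of tensors. It is not the package's code: `scan_one_channel`, `softplus`, and `silu` are hypothetical names, and `B` and `C` are indexed as `[t][n]` for readability rather than in the `(batch, ngroups, dstate, seqlen)` layout used by the API. It assumes the usual Mamba discretization, h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t and y_t = C_t · h_t, followed by the optional D skip term and SiLU gating by z.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def silu(x):
    """x * sigmoid(x), written to avoid overflow for large |x|."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def scan_one_channel(u, delta, A, B, C, D=None, z=None,
                     delta_bias=None, delta_softplus=False):
    """Sequential selective scan for one batch element and one channel.

    u, delta, z: lists of length seqlen.
    A: list of length dstate.
    B, C: seqlen rows of length dstate (input-dependent case).
    D, delta_bias: scalars for this channel.
    Returns (outputs, last_state).
    """
    dstate = len(A)
    h = [0.0] * dstate  # hidden state starts at zero
    out = []
    for t in range(len(u)):
        dt = delta[t]
        if delta_bias is not None:
            dt += delta_bias
        if delta_softplus:
            dt = softplus(dt)
        for n in range(dstate):
            # Discretized state update for state dimension n.
            h[n] = math.exp(dt * A[n]) * h[n] + dt * B[t][n] * u[t]
        y = sum(h[n] * C[t][n] for n in range(dstate))
        if D is not None:
            y += D * u[t]  # skip connection
        if z is not None:
            y *= silu(z[t])  # gating
        out.append(y)
    return out, h
```

With `A = [0.0]`, `delta` fixed at 1 and `B = C = 1`, the state reduces to a running sum of `u`, which gives a quick sanity check of the loop.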

License

This package is released under the Apache License 2.0, the same license as the official Mamba repository.
