本项目是本科毕设“基于昇腾 AI 架构的高效化无人机射频信号识别”的训练端代码实现。仓库围绕 2D .npy 数据集构建了六阶段主线:
base_model:基座模型训练、验证、测试与 UMAP/混淆矩阵可视化pruning:基于torch-pruning的 iterative structured pruning + 微调qat:基于 Torch 原生 FX graph mode 的保守单路径 QATonnx:导出pruning_fp16或qat_convertONNX,并用 ONNX Runtime 评估amct:消费qat_convertONNX,生成 Ascend 侧可继续下游处理的 deploy/fakequant ONNXatc:消费pruning_fp16或amct_deployONNX,编译固定 batch=1 的.om
工程主线为:
base_model checkpoint
-> pruning checkpoint
-> QAT prepare checkpoint
-> ONNX 导出 / ORT 评估
-> AMCT 转换
-> ATC 编译
-> 后续部署 / 推理验证
base_model:已实现,作为训练主线上游稳定使用pruning:已实现,支持多轮剪枝、微调、拓扑导出与实验摘要qat:已实现,支持按剪枝拓扑严格恢复并导出 prepare 后 QAT checkpointonnx:已实现,支持pruning_fp16/qat_convert双分支导出、动态 batch、ORT 精度评估amct:已实现并接入主线;运行该阶段前需要额外准备仓库附带的amct_onnxwheel 与算子包atc:已实现并接入主线;真实编译与运行验证依赖 Ascend 宿主机环境
说明:
- “已实现 / 已接入主线”表示代码、入口脚本和产物契约已经落地。
- 文档统一以代码实现与标准产物契约为准,不把“已实现”误写成“已充分验证”。
- 5 种 2D ResNet 架构:
resnet6_2d、resnet10_2d、resnet14_2d、resnet18_2d、resnet34_2d base_model / pruning支持fp16或fp32数据加载qat固定纯fp32- 稳定的数据集切分与
output/splits/manifest 复用 - 基座、剪枝、QAT 三类结构化 checkpoint
- 剪枝后完整拓扑导出:
channel_cfg + architecture_signature - 所有消费
.pthcheckpoint 的步骤统一强校验architecture_signature - ONNX / deploy ONNX 阶段统一通过同目录 summary 读取并校验上游
architecture_signature与来源信息 - Torch 原生 FX graph mode QAT
pruning_fp16/qat_convert双分支 ONNX 导出- QAT ONNX rewrite + validate,用于 CANN/AMCT/ATC 兼容约束
- AMCT deploy / fakequant ONNX 转换
- pruning FP16 / AMCT deploy ONNX 的 ATC 编译
- TensorBoard、混淆矩阵、UMAP 等辅助分析能力
项目只把以下项目视为“用户独立手动安装的前置项”:
git:用于克隆仓库pixi:负责系统工具链与运行时环境uv:负责 Python 依赖同步与uv run ...direnv(可选):用于自动激活项目公共环境层
以下内容不应再作为“手动环境依赖”单独列出,因为它们由项目命令自动安装:
pixi install:- Python 3.12 运行时
- GCC / G++ / Make / CMake 等工具链
cuda-runtime、cudnnascend-cann-toolkit、ascend-cann-310b-ops
uv sync:torchonnxonnxruntime-gputorch-pruningtensorboardmatplotlibumap-learn- 以及
pyproject.toml中声明的其余 Python 包
- 若要使用 CUDA 加速训练,宿主机需要可用的 NVIDIA GPU 与驱动。
- 若要进行真实的 Ascend 编译或部署验证,宿主机需要对应的 Ascend 设备/驱动环境。
- 上述宿主机硬件/驱动要求是条件说明,不属于仓库自身的可自动安装依赖。
-
克隆项目
git clone git@github.com:wh-wang132/ResNet.git cd ResNet -
安装 Pixi 环境
pixi install
-
同步 Python 依赖
uv sync
-
启用
direnv(推荐)direnv allow
当前项目根目录的
.envrc提供仓库级公共变量:REPO_ROOTPYTHONPATH=$REPO_ROOT/src
说明:
direnv为推荐方案;若不使用direnv自动激活,也必须手动提供与.envrc等价的环境变量- 所有脚本统一通过
.envrc提供的REPO_ROOT识别仓库根
-
若需要运行 AMCT 阶段,额外准备仓库附带的 AMCT 组件
amct_onnx/amct_onnx-0.23.2-py3-none-linux_x86_64.whlamct_onnx/amct_onnx_op.tar.gz说明:- 上述 wheel 与算子包已经随仓库提供
- 它们不在
pyproject.toml或pixi.toml中声明 - 因此它们不是
uv sync/pixi install自动安装的公共依赖,而是 AMCT 阶段专用的手动准备项 - 若宿主机已预装系统全局 CUDA,AMCT 在安装
amct_onnx/amct_onnx_op.tar.gz或运行阶段可能优先搜索到系统 CUDA,从而引发 CUDA 版本不匹配错误;此场景建议在 Docker 环境中部署和运行 AMCT 相关流程
-
准备数据集
- 将
.npy数据集放入Data/ - 目录结构见 数据准备指南
- 将
项目采用“公共层 + 阶段增量层”的环境结构:
| 阶段 | 公共层 | 阶段增量层 |
|---|---|---|
base_model |
.envrc |
source scripts/load_base_model_env.sh |
pruning |
.envrc |
无 |
qat |
.envrc |
无 |
onnx |
.envrc |
source scripts/load_onnx_env.sh |
amct |
.envrc |
source scripts/load_amct_env.sh |
atc |
.envrc |
source scripts/load_atc_env.sh |
说明:
.envrc是唯一公共入口,只负责仓库级变量。scripts/load_*_env.sh只负责补齐阶段增量环境。- 阶段脚本默认在已加载
.envrc的 shell 中运行。
# 完整训练 + 测试
uv run src/base_model_main.py --epochs 20 --model resnet6_2d
# 仅训练
uv run src/base_model_main.py --epochs 20 --Test False
# 仅测试 + UMAP
uv run src/base_model_main.py --Train False --UMAP True# 最小剪枝命令
uv run src/pruning_main.py --model resnet6_2d
# 指定总剪枝率与轮数
uv run src/pruning_main.py \
--model resnet18_2d \
--pruning_ratio 0.30 \
--pruning_steps 5 \
--global_pruning True \
--finetune_epochs 10
# 不做微调,只保存最终剪枝结果
uv run src/pruning_main.py \
--model resnet14_2d \
--finetune_epochs 0 \
--evaluate_test False# 最小 QAT 命令
uv run src/qat_main.py \
--pruning_checkpoint output/pruning/resnet14_2d/ratio0.60_steps8_global_ft10_bs64/best_pruned_model.pth
# 指定保守 QAT 微调参数
uv run src/qat_main.py \
--pruning_checkpoint output/pruning/resnet34_2d/ratio0.80_steps8_global_ft10_bs64/best_pruned_model.pth \
--qat_epochs 10 \
--lr 1e-5 \
--batch_size 64# pruning checkpoint -> FP16 ONNX
uv run src/onnx_main.py \
--branch pruning_fp16 \
--checkpoint output/pruning/resnet10_2d/ratio0.40_steps5_global_ft10_bs64/best_pruned_model.pth \
--eval_batch_size 64
# QAT checkpoint -> convert 后量化 ONNX
uv run src/onnx_main.py \
--branch qat_convert \
--checkpoint output/qat/resnet10_2d/from_ratio0.40_steps5_global_ft10_bs64/best_qat_prepare_model.pth \
--eval_batch_size 64说明:
- ONNX 导出当前统一使用动态 batch。
onnx_summary.json.example_input_shape中的batch=1只表示导出样例输入。--eval_batch_size只影响 Torch / ORT 精度评估,不影响导出图结构。
运行前请先确认仓库附带的 amct_onnx wheel 与算子包已经按目标环境自行安装或部署。
. scripts/load_amct_env.sh
uv run src/amct_main.py \
--onnx_model output/onnx/qat_convert/resnet6_2d/from_ratio0.60_steps8_global_ft10_bs64/model_quant.onnx. scripts/load_atc_env.sh
# pruning_fp16 ONNX -> ATC
pixi run python src/atc_main.py \
--branch pruning_fp16 \
--onnx_model output/onnx/pruning_fp16/resnet10_2d/from_ratio0.40_steps5_global_ft10_bs64/model_fp16.onnx
# AMCT deploy ONNX -> ATC
pixi run python src/atc_main.py \
--branch amct_deploy \
--onnx_model output/amct/resnet6_2d/from_ratio0.60_steps8_global_ft10_bs64/deploy_model.onnx默认:
soc_version=Ascend310B4input_format=NCHWinput_shape默认从上游摘要中的输入接口派生,并将 batch 固定为1- 若用户显式传入
--input_shape,其输入名与各维度必须与自动派生结果完全一致,否则直接报错
剪枝入口不会手动接收基座 checkpoint 路径,而是自动扫描:
output/base_model/<model>/<experiment_dir>/best_model.pth
程序会遍历 output/base_model/<model>/ 下各实验目录,读取每个目录 best_val_acc_info.txt 的最后一条有效记录,并按 val acc 优先、val loss 次优选择最佳实验权重进入剪枝链路。
autorun/ 目录提供 6 份顺序执行脚本:
- autorun/autorun_base_model.sh
- autorun/autorun_pruning.sh
- autorun/autorun_qat.sh
- autorun/autorun_onnx.sh
- autorun/autorun_amct.sh
- autorun/autorun_atc.sh
这些脚本整体仍服务于顺序批处理执行;其中 onnx / amct / atc autorun 已包含 shell 函数、mktemp、trap、find 遍历与临时文件清理等基础控制逻辑,适合直接在服务器终端监视运行并安全清理中间状态。
运行前提:
- shell 必须已自动或手动激活项目根目录的
.envrc autorun脚本统一直接使用.envrc提供的REPO_ROOT作为仓库根
脚本行为概览:
autorun/autorun_base_model.sh- 批量训练 5 个基座模型
- 搜索维度:模型与
batch_size - 固定显式传入:
--full_load True
autorun/autorun_pruning.sh- 批量运行 pruning 实验
- 搜索维度:模型、
pruning_ratio、pruning_steps - 固定显式传入:
--full_load True
autorun/autorun_qat.sh- 批量消费 pruning checkpoint
- 固定显式传入:
--full_load True
autorun/autorun_onnx.sh- 遍历
output/pruning/**/best_pruned_model.pth - 遍历
output/qat/**/best_qat_prepare_model.pth - 默认参数与 ONNX CLI 保持一致:
evaluate_test=True、eval_batch_size=64
- 遍历
autorun/autorun_amct.sh- 遍历
output/onnx/qat_convert/**/model_quant.onnx
- 遍历
autorun/autorun_atc.sh- 遍历
output/onnx/pruning_fp16/**/model_fp16.onnx - 遍历
output/amct/**/deploy_model.onnx - 默认参数与 ATC CLI 保持一致:
soc_version=Ascend310B4
- 遍历
base_model.dataset.data_set_split() 会优先读取:
output/splits/
中的数据集划分 manifest。若 manifest 不存在或与当前配置不匹配,则会重新划分并重新落盘。
manifest 保存:
- 训练 / 验证 / 测试的相对路径清单
class_names与class_to_idx- 划分比例与随机种子
- 相对路径
data_dir=Data