Skip to content

pt模型转rknn模型过大 #515

@lrm2017

Description

@lrm2017

直接将 pt 模型转换为 rknn 模型时,导出的模型高达约 300MB;而先转为 onnx 再导出的 rknn 模型只有约 2MB,两者相差很大。请问是代码中哪里的配置有问题吗?
`#!/usr/bin/env python3
"""
XFeat PyTorch模型直接转换为RKNN
避免GridSample问题,支持灵活配置
"""

import sys
import os
import cv2
import numpy as np
import torch
import glob
from rknn.api import RKNN

def create_xfeat_dataset(input_shape, output_path="dataset_xfeat_pt.txt", num_samples=50,
                         image_dir="datasets", temp_dir="temp_xfeat_pt_data"):
    """Build an RKNN quantization (calibration) dataset for XFeat.

    Reads up to ``num_samples`` readable images found recursively under
    ``image_dir``, resizes each one to the spatial size in ``input_shape``,
    saves it as a float32 ``.npy`` array of shape ``(1, C, H, W)`` in
    ``temp_dir`` and writes one ``.npy`` path per line to ``output_path``.
    That list file is what ``rknn.build(dataset=...)`` consumes.

    Args:
        input_shape: 4-element ``(N, C, H, W)`` model input shape. Only
            ``C == 3`` (RGB) and ``C == 1`` (grayscale) are supported.
        output_path: dataset list file to (over)write.
        num_samples: maximum number of samples to write.
        image_dir: directory searched recursively for .jpg/.jpeg/.png/.bmp files.
        temp_dir: directory the ``.npy`` samples are written to.

    Returns:
        ``output_path`` on success. ``None`` if the shape is unsupported,
        ``image_dir`` is missing, or no image could be read.
    """
    print(f"创建XFeat数据集: {output_path}")

    if len(input_shape) != 4:
        print(f"不支持的输入形状: {input_shape}")
        return None
    _, channels, height, width = input_shape
    if channels not in (1, 3):
        print(f"不支持的通道数: {channels}")
        return None

    print(f"输入格式: NCHW")
    print(f"目标尺寸: {width}x{height}")

    if not os.path.exists(image_dir):
        print(f"❌ datasets目录不存在: {image_dir}")
        return None

    # Collect candidate images recursively. Where glob matching is
    # case-insensitive (e.g. Windows), the lower- and upper-case patterns
    # return the same paths, so de-duplicate. Sort for a reproducible order.
    found = set()
    for ext in ('.jpg', '.jpeg', '.png', '.bmp'):
        for suffix in (ext, ext.upper()):
            found.update(glob.glob(os.path.join(image_dir, f"**/*{suffix}"), recursive=True))
    image_files = sorted(found)

    if not image_files:
        print(f"❌ 在 {image_dir} 目录中未找到图像文件")
        return None

    print(f"找到 {len(image_files)} 个图像文件")
    os.makedirs(temp_dir, exist_ok=True)

    color_code = cv2.COLOR_BGR2RGB if channels == 3 else cv2.COLOR_BGR2GRAY
    written = 0
    with open(output_path, 'w', encoding='utf-8') as f:
        for image_file in image_files:
            if written >= num_samples:
                break
            print(f"处理图像: {os.path.basename(image_file)}")

            img = cv2.imread(image_file)
            if img is None:
                # Unreadable or corrupt file: skip it and try the next one so
                # the dataset still reaches num_samples when possible.
                continue

            pixels = cv2.resize(cv2.cvtColor(img, color_code), (width, height))
            if pixels.ndim == 2:  # grayscale -> explicit channel axis
                pixels = pixels[:, :, np.newaxis]

            # Keep raw [0, 255] pixel values. rknn.config(mean_values=...,
            # std_values=...) normalization is applied to calibration data as
            # well as to runtime input (RKNN-Toolkit2 user guide), so scaling
            # to [0, 1] here would normalize the samples twice.
            # NOTE(review): samples are stored NCHW to mirror input_shape;
            # confirm this is the .npy layout RKNN-Toolkit2 expects for a
            # 4-D calibration input.
            sample = np.transpose(pixels.astype(np.float32), (2, 0, 1))[np.newaxis]

            image_path = os.path.join(temp_dir, f"image_{written:03d}.npy")
            np.save(image_path, sample)
            f.write(f"{image_path}\n")
            written += 1

    if written == 0:
        # An empty list file would only fail later, inside rknn.build().
        print("❌ 没有可读取的图像,数据集为空")
        return None

    print(f"数据集创建完成,包含 {written} 个样本")
    return output_path

def _rknn_step(label, call):
    """Run one RKNN API call, report the outcome under ``label``, and return True iff it returned 0."""
    try:
        ret = call()
    except Exception as e:
        print(f'❌ {label}异常: {e}')
        return False
    if ret == 0:
        print(f'✅ {label}成功')
        return True
    print(f'❌ {label}失败,错误码: {ret}')
    return False


def _load_xfeat_pytorch(rknn, pt_path, input_shape):
    """Load the TorchScript file ``pt_path`` into ``rknn``; return True on the first successful attempt."""
    # load_pytorch takes the full input shape including the batch dimension,
    # i.e. [[N, C, H, W]].
    if _rknn_step('标准格式加载', lambda: rknn.load_pytorch(
            model=pt_path, input_size_list=[list(input_shape)])):
        return True

    # NOTE(review): the two fallbacks below do not match the NCHW shape the
    # model is presumably traced with. If one of them is what succeeds, check
    # the resulting RKNN input shape before deploying.
    print('尝试不同的输入格式...')
    if _rknn_step('HWC格式加载', lambda: rknn.load_pytorch(
            model=pt_path,
            input_size_list=[[input_shape[2], input_shape[3], input_shape[1]]])):
        return True

    print('尝试不指定输入尺寸...')
    return _rknn_step('无尺寸限制加载', lambda: rknn.load_pytorch(model=pt_path))


def _build_xfeat_rknn(rknn, dataset_path, input_shape, allow_float_fallback=True):
    """Build the loaded model, trying every quantized option before any float build."""
    if _rknn_step('量化构建', lambda: rknn.build(
            do_quantization=True, dataset=dataset_path, rknn_batch_size=1)):
        return True

    # Retry quantization with a smaller calibration set *before* falling back
    # to a float build. A non-quantized model keeps float weights and is much
    # larger than the int8 one.
    print('尝试使用更少的样本构建...')
    try:
        small_dataset = create_xfeat_dataset(input_shape, num_samples=10)
    except Exception as e:
        print(f'❌ 小样本构建异常: {e}')
        small_dataset = None
    if small_dataset and _rknn_step('小样本构建', lambda: rknn.build(
            do_quantization=True, dataset=small_dataset)):
        return True

    if not allow_float_fallback:
        return False

    print('尝试不使用量化构建...')
    print('⚠️ 量化构建均失败,将导出非量化(浮点)模型,体积会明显大于量化模型')
    return _rknn_step('非量化构建', lambda: rknn.build(do_quantization=False))


def convert_xfeat_pt_to_rknn(pt_path, input_shape, platform='rk3568', output_path=None,
                             dense=False, top_k=512,
                             optimization_level=1, num_samples=50,
                             allow_float_fallback=True):
    """Convert an XFeat TorchScript model to an (int8-quantized if possible) RKNN file.

    Steps:
        1. Create the calibration dataset via ``create_xfeat_dataset``.
        2. Configure RKNN with asymmetric 8-bit per-channel quantization.
        3. Load the model.
        4. Build it, trying quantized builds first and a float build last.
        5. Export to ``output_path``.

    ``rknn.load_pytorch`` expects a TorchScript file (e.g. saved from
    ``torch.jit.trace``); see the RKNN-Toolkit2 user guide.

    Args:
        pt_path: path of the TorchScript ``.pt`` model.
        input_shape: ``(N, C, H, W)`` model input shape.
        platform: RKNN ``target_platform`` (e.g. ``'rk3568'``).
        output_path: destination ``.rknn`` file. Defaults to
            ``rknn/xfeat_<dense|sparse>_<top_k>_<H>x<W>.rknn``. Missing
            parent directories are created.
        dense, top_k: only used to build the default output file name; they
            do not change the conversion itself.
        optimization_level: forwarded to ``rknn.config``.
        num_samples: number of calibration images for the first quantized build.
        allow_float_fallback: if False, fail instead of exporting a
            non-quantized model when every quantized build fails.

    Returns:
        True if the model was exported, False otherwise. Errors are printed,
        not raised.
    """
    import shutil  # local import, as in the original cleanup code

    # Must match create_xfeat_dataset's default temp_dir.
    temp_data_dir = "temp_xfeat_pt_data"

    if output_path is None:
        model_name = f"xfeat_{'dense' if dense else 'sparse'}_{top_k}_{input_shape[2]}x{input_shape[3]}"
        output_path = os.path.join("rknn", f"{model_name}.rknn")

    print(f"转换XFeat模型: {pt_path}")
    print(f"输入形状: {input_shape}")
    print(f"目标平台: {platform}")
    print(f"输出路径: {output_path}")
    print(f"Dense模式: {dense}")
    print(f"关键点数量: {top_k}")

    dataset_path = create_xfeat_dataset(input_shape, num_samples=num_samples)
    if not dataset_path:
        return False

    rknn = None
    try:
        rknn = RKNN(verbose=True)

        print('--> Config model')
        rknn.config(
            mean_values=[[127.5, 127.5, 127.5]],
            std_values=[[128.0, 128.0, 128.0]],
            quant_img_RGB2BGR=False,
            quantized_algorithm='normal',
            quantized_dtype='asymmetric_quantized-8',  # int8 weights/activations
            quantized_method='channel',                # per-channel quantization
            target_platform=platform,
            optimization_level=optimization_level,
            model_pruning=False,
        )
        print('done')

        print('--> Loading model')
        if not _load_xfeat_pytorch(rknn, pt_path, input_shape):
            print('❌ 所有加载方法都失败了')
            return False
        print('done')

        print('--> Building model')
        if not _build_xfeat_rknn(rknn, dataset_path, input_shape, allow_float_fallback):
            print('❌ 所有构建方法都失败了')
            return False
        print('done')

        print('--> Export rknn model')
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        if rknn.export_rknn(output_path) != 0:
            print('Export rknn model failed!')
            return False
        print('done')

        print(f"✅ 转换成功: {output_path}")
        return True

    except Exception as e:
        print(f"❌ 转换失败: {e}")
        return False

    finally:
        # Release the toolkit and drop the temporary calibration samples on
        # every path, not only on success.
        if rknn is not None:
            rknn.release()
        shutil.rmtree(temp_data_dir, ignore_errors=True)

def export_xfeat_pt_to_rknn(xfeat_path, output_folder="rknn",
                            input_shape=(1, 3, 480, 640),
                            dynamic=False,
                            dense=False,
                            top_k=256,
                            platform='rk3568',
                            optimization_level=1,
                            num_samples=50):
    """Convert an XFeat PyTorch model and save it under ``output_folder``.

    The output file is ``<output_folder>/xfeat_<dense|sparse>_<top_k>_<H>x<W>.rknn``.
    The work is delegated to ``convert_xfeat_pt_to_rknn``; this wrapper prints
    a summary and troubleshooting hints.

    Args:
        xfeat_path: path of the TorchScript ``.pt`` model.
        output_folder: directory for the ``.rknn`` file (created if missing).
        input_shape: fixed ``(N, C, H, W)`` input shape.
        dynamic: accepted for interface compatibility. Dynamic input is not
            implemented; the model is always converted with ``input_shape``.
        dense, top_k: select the output file name.
        platform: RKNN target platform, e.g. ``'rk3568'`` or ``'rk3588'``.
        optimization_level: forwarded to ``rknn.config``.
        num_samples: calibration images used for quantization.

    Returns:
        True on success, False otherwise.
    """
    print("=" * 50)
    print("XFeat PyTorch to RKNN 转换器")
    print("=" * 50)
    print(f"模型路径: {xfeat_path}")
    print(f"输出目录: {output_folder}")
    print(f"输入形状: {input_shape}")
    print(f"动态输入: {dynamic}")
    print(f"Dense模式: {dense}")
    print(f"关键点数量: {top_k}")
    print(f"目标平台: {platform}")
    print(f"优化级别: {optimization_level}")
    print("=" * 50)

    if dynamic:
        # 'dynamic' is never passed on to the converter, so say so explicitly
        # instead of silently producing a fixed-shape model.
        print("⚠️ 暂不支持动态输入,将按固定输入形状转换")

    os.makedirs(output_folder, exist_ok=True)

    model_name = f"xfeat_{'dense' if dense else 'sparse'}_{top_k}_{input_shape[2]}x{input_shape[3]}"
    output_path = os.path.join(output_folder, f"{model_name}.rknn")

    success = convert_xfeat_pt_to_rknn(
        pt_path=xfeat_path,
        input_shape=input_shape,
        platform=platform,
        output_path=output_path,
        dense=dense,
        top_k=top_k,
        optimization_level=optimization_level,
        num_samples=num_samples
    )

    if success:
        print("\n🎉 XFeat转换完成!")
        print(f"RKNN模型已保存到: {output_path}")
        print("现在可以使用转换后的RKNN模型进行推理了。")
    else:
        print("\n❌ 转换失败!")
        print("可能的原因:")
        # Name the platform actually targeted, not a hard-coded one.
        print(f"1. PyTorch模型包含{platform.upper()}不支持的操作")
        print("2. 模型结构过于复杂")
        print("3. 输入形状不匹配")
        print("4. 内存不足")

    return success

def main():
    """Parse command-line options from ``sys.argv`` and run the conversion.

    Usage::

        script MODEL_PT [--input_shape W,H] [--platform P] [--dense]
                        [--top_k N] [--optimization N] [--samples N]

    Exits via ``sys.exit``:
        * status 1 on a usage error, an unknown option, an option missing its
          value, an invalid option value, or a missing model file;
        * otherwise status 0 if the conversion succeeded and 1 if it failed.
    """
    if len(sys.argv) < 2:
        print("Usage: python3 {} xfeat_pt_model_path [options]".format(sys.argv[0]))
        print("\nOptions:")
        print(" --input_shape WIDTH,HEIGHT 输入尺寸 (默认: 640,480)")
        print(" --platform PLATFORM 目标平台 (默认: rk3568)")
        print(" --dense 使用dense模式")
        print(" --top_k NUM 关键点数量 (默认: 256)")
        print(" --optimization NUM 优化级别 (默认: 1)")
        print(" --samples NUM 数据集样本数 (默认: 50)")
        print("\nExample:")
        print(" python3 xfeat_pt_to_rknn.py weights/xfeat.pt --input_shape 1024,1536 --top_k 512")
        print(" python3 xfeat_pt_to_rknn.py weights/xfeat.pt --dense --platform rk3588")
        sys.exit(1)

    pt_path = sys.argv[1]

    # Defaults.
    input_shape = (1, 3, 480, 640)  # N, C, H, W
    platform = 'rk3568'
    dense = False
    top_k = 256
    optimization_level = 1
    num_samples = 50

    # Options that consume the following argument as their value.
    value_options = ('--input_shape', '--platform', '--top_k', '--optimization', '--samples')

    i = 2
    while i < len(sys.argv):
        arg = sys.argv[i]
        if arg == '--dense':
            dense = True
            i += 1
            continue

        if arg not in value_options or i + 1 >= len(sys.argv):
            print(f"未知参数: {arg}")
            sys.exit(1)

        value = sys.argv[i + 1]
        i += 2
        try:
            if arg == '--input_shape':
                # A wrong number of parts also raises ValueError on unpacking.
                width, height = map(int, value.split(','))
                input_shape = (1, 3, height, width)
            elif arg == '--platform':
                platform = value
            elif arg == '--top_k':
                top_k = int(value)
            elif arg == '--optimization':
                optimization_level = int(value)
            else:  # '--samples'
                num_samples = int(value)
        except ValueError:
            print(f"参数值无效: {arg} {value}")
            sys.exit(1)

    if not os.path.exists(pt_path):
        print(f"❌ PyTorch文件不存在: {pt_path}")
        sys.exit(1)

    success = export_xfeat_pt_to_rknn(
        xfeat_path=pt_path,
        input_shape=input_shape,
        dynamic=False,
        dense=dense,
        top_k=top_k,
        platform=platform,
        optimization_level=optimization_level,
        num_samples=num_samples
    )

    sys.exit(0 if success else 1)

if __name__ == '__main__':
    # With no command-line arguments, run an example conversion using the
    # defaults below. Otherwise let main() parse the arguments.
    if len(sys.argv) == 1:
        export_xfeat_pt_to_rknn(
            xfeat_path="weights/xfeat_dummy.pt",
            output_folder="rknn",
            input_shape=(1, 3, 480, 640),  # N C H W
            dynamic=False,                 # fixed input shape
            dense=False,                   # sparse features; avoids complex ops
            top_k=256,                     # keep the keypoint count small
            platform='rk3568',
            optimization_level=1,
            num_samples=50
        )
    else:
        main()
`

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions