该场景主要是为了使用进化策略对CNN网络结构进行自动剪枝压缩,剪枝的搜索空间为需要剪枝的卷积层的输出通道数。维护一个包含N个个体的种群P,每个个体对应一个压缩的网络模型。通过交叉变异产生一个同样大小N的种群P’,每个压缩的网络模型通过进行training/valid,在验证集的Accuracy、FLOPs、参数量、Mac值、带宽等用户指定的指标,作为优化目标,来对个体进行排序选择,更新维持的种群P。
ResNet-20结构如下图所示,主要包括第一个卷积核和三个顺序stage,每个stage由三个ResBlock构成,当第一个ResBlock的shortcut连接两个通道数不一致的卷积层时,shortcut由1x1的上采样卷积核构成。
以ResNet-20网络为例,详细阐述自动剪枝压缩的过程:
编码:
1.1. 编码表示对网络的卷积层的输出通道进行01编码,0表示剪掉该通道,1表示保留该通道。首先需要对网络结构进行分析,确定需要编码的卷积层;
1.2. 对于上图中的蓝色块表示的存在shortcut连接的卷积层,当不存在1x1上采样操作时,需保证这两个卷积层的输出通道数一致,为了编码的方便,我们将这两个卷积层的输出通道数的编码也保持一致。我们以ch_final表示该部分编码,共3个;
1.3. 对于上图中的橙色块表示的不存在shortcut连接的卷积层,可以任意剪枝,但需要保证不要全部剪掉。我们以ch_middle表示该部分编码,共9个。
搜索:
2.1. 根据编码确定的【搜索空间】,通过交叉、变异等【进化操作】从种群P生成N个压缩模型的编码;
评估:
3.1. 根据【进化操作】生成的N个编码,完成压缩模型的构建;
3.2 执行【评估过程】,产生用户定义的所有评估结果;
优化:
4.1. 调用【进化算法】,对种群P进行更新;
接下来重复【搜索】->【评估】->【优化】过程,完成整个进化自动剪枝流程,搜出Pareto前沿。搜索完剪枝模型之后,我们会对Pareto前沿的剪枝模型进行训练,得到剪枝后的模型的最终表现。
EA Pruning适合网络的通道剪枝,分为两个阶段:搜索剪枝网络阶段+剪枝网络训练阶段。 本方法用于原始大模型进行剪枝压缩,可以用于各种场景。目前给的example是图像分类场景。
我们推荐在华为云的训练环境中运行PruningEA。其入口程序为
vega/examples/prune_ea/main.py
在配置文件中进行参数配置,包括搜索剪枝模型、训练剪枝模型2个过程,配置文件为:
vega/examples/prune_ea/prune.yml
配置文件在main.py
中直接传入给pipeline,两个过程会依次进行,搜索过程会搜出Pareto前沿,然后训练过程会把前沿的模型训到底,得到最终的表现。
prune.yml
的主要配置信息如下:
nas:
search_algorithm: # 进化算法配置信息
type: PruneEA
codec: PruneCodec
policy:
length: 464 # 搜索的总通道数
num_generation: 31 # 进化代数
num_individual: 32 # 每一代的个体数
random_models: 64 # 随机初始化的模型数
search_space: # 搜索空间配置信息
type: SearchSpace
modules: ['backbone']
backbone:
name: 'PruneResNet'
base_chn: [16,16,16,32,32,32,64,64,64]
base_chn_node: [16,16,32,64]
num_classes: 10
目前可以支持的剪枝网络为为ResNet系列分类网络如ResNet20,ResNet32,ResNet56等
search_space: # ResNet20搜索空间
type: SearchSpace
modules: ['backbone']
backbone:
name: 'PruneResNet'
base_chn: [16,16,16,32,32,32,64,64,64]
base_chn_node: [16,16,32,64]
num_classes: 10 # 分类数
目前可支持的网络如下:
search space | search algorithm | |
---|---|---|
ResNet-20 | base_chn_node: [16,16,32,64] base_chn: [[16]*3, [32]*3,[64]*3] num_blocks: [3,3,3] |
length:464 |
ResNet-32 | base_chn_node: [16,16,32,64] base_chn: [[16]*5, [32]*5,[64]*5] num_blocks: [5,5,5] |
length:688 |
ResNet-56 | base_chn_node: [16,16,32,64] base_chn: [[16]*9, [32]*9,[64]*9] num_blocks: [9,9,9] |
length: 1136 |
我们使用NSGA-III多目标优化进化算法进行pareto front的搜索。算法详情参考原论文[1]。
[1] Deb, Kalyanmoy, and Himanshu Jain. "An evolutionary many-objective optimization algorithm using reference-point-based nondominated sorting approach, part I: solving problems with box constraints." IEEE Transactions on Evolutionary Computation 18.4 (2013): 577-601.
输出2个文件:
我们在example里提供了在CIFAR-10数据集上自动剪枝ResNet20网络的配置。实验结果如下:
Pareto front上选3个不同剪枝比例的模型重训400epoch的结果。
model | FLOPs | TOP1 Acc | r_c |
---|---|---|---|
Baseline | 40.8M | 91.45 | 1x |
IND1 | 30.11M | 91.22 | 0.74x |
IND2 | 19.14M | 90.9 | 0.46x |
IND3 | 6.14M | 87.61 | 0.15x |
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。