61 Star 651 Fork 253

PaddlePaddle / PaddleDetection

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
bifpn.py 10.75 KB
一键复制 编辑 原始数据 按行查看 历史
wangxinxin08 提交于 2021-12-20 17:36 . refine sync bn (#4361)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ..shape_spec import ShapeSpec
__all__ = ['BiFPN']
class SeparableConvLayer(nn.Layer):
    """Depthwise-separable convolution: act -> depthwise conv -> pointwise conv -> norm.

    Note that the activation is applied to the *input* of the block, before
    the two convolutions, and the normalization is applied last.

    Args:
        in_channels (int): number of input channels.
        out_channels (int|None): number of output channels. When None, the
            block keeps ``in_channels`` channels.
        kernel_size (int): kernel size of the depthwise convolution; the
            padding is ``kernel_size // 2``.
        norm_type (str|None): 'bn', 'sync_bn', 'gn' or None (no norm layer).
        norm_groups (int): number of groups, only used when norm_type is 'gn'.
        act (str|None): 'swish', 'relu' or None (no activation).
    """

    def __init__(self,
                 in_channels,
                 out_channels=None,
                 kernel_size=3,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(SeparableConvLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn', None]
        assert act in ['swish', 'relu', None]

        self.in_channels = in_channels
        # Always define out_channels; previously it was only set when the
        # argument was None, so an explicit value crashed a few lines below.
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm_type = norm_type
        self.norm_groups = norm_groups

        # groups == in_channels makes this a per-channel (depthwise) conv.
        self.depthwise_conv = nn.Conv2D(
            in_channels,
            in_channels,
            kernel_size,
            padding=kernel_size // 2,
            groups=in_channels,
            bias_attr=False)
        # 1x1 conv that mixes channels and maps to out_channels.
        self.pointwise_conv = nn.Conv2D(in_channels, self.out_channels, 1)

        # norm type
        if self.norm_type in ['bn', 'sync_bn']:
            # Both options build a plain BatchNorm2D here.
            # NOTE(review): presumably 'sync_bn' is converted to a synchronized
            # variant outside this class -- confirm against the trainer.
            self.norm = nn.BatchNorm2D(self.out_channels)
        elif self.norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=self.norm_groups, num_channels=self.out_channels)
        else:
            self.norm = None

        # activation
        if act == 'swish':
            self.act = nn.Swish()
        elif act == 'relu':
            self.act = nn.ReLU()
        else:
            # Must exist as an attribute: forward() tests `self.act is not None`.
            self.act = None

    def forward(self, x):
        """Apply the optional activation, both convolutions and the optional norm to ``x``."""
        if self.act is not None:
            x = self.act(x)
        out = self.depthwise_conv(x)
        out = self.pointwise_conv(out)
        if self.norm is not None:
            out = self.norm(out)
        return out
class BiFPNCell(nn.Layer):
    """One bidirectional fusion stage over a list of same-channel feature maps.

    The cell first walks the levels from the last input to the first,
    resizing the running feature to each level and fusing it with that
    level's input ("up" pass). It then walks back from the first level to
    the last, max-pooling the running feature and fusing it with the "up"
    result and the original input of each level ("down" pass). Every fusion
    is followed by a SeparableConvLayer.

    Args:
        channels (int): channel count of every input and output level.
        num_levels (int): number of feature levels handled by the cell.
        eps (float): added to the weight sum when normalizing fusion weights.
        use_weighted_fusion (bool): if True, learn per-fusion weights;
            otherwise the inputs are simply summed.
        kernel_size (int): kernel size passed to each SeparableConvLayer.
        norm_type (str|None): norm type passed to each SeparableConvLayer.
        norm_groups (int): norm groups passed to each SeparableConvLayer.
        act (str|None): activation passed to each SeparableConvLayer.
    """

    def __init__(self,
                 channels=256,
                 num_levels=5,
                 eps=1e-5,
                 use_weighted_fusion=True,
                 kernel_size=3,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(BiFPNCell, self).__init__()
        self.channels = channels
        self.num_levels = num_levels
        self.eps = eps
        self.use_weighted_fusion = use_weighted_fusion

        # One fusion (and one conv) per level except the starting level of each pass.
        num_fusions = self.num_levels - 1
        conv_kwargs = dict(
            kernel_size=kernel_size,
            norm_type=norm_type,
            norm_groups=norm_groups,
            act=act)

        # up
        up_convs = []
        for _ in range(num_fusions):
            up_convs.append(SeparableConvLayer(self.channels, **conv_kwargs))
        self.conv_up = nn.LayerList(up_convs)

        # down
        down_convs = []
        for _ in range(num_fusions):
            down_convs.append(SeparableConvLayer(self.channels, **conv_kwargs))
        self.conv_down = nn.LayerList(down_convs)

        if not self.use_weighted_fusion:
            return

        # Learnable fusion weights, all initialised to 1: two inputs per
        # "up" fusion, up to three inputs per "down" fusion.
        self.up_weights = self.create_parameter(
            shape=[num_fusions, 2],
            attr=ParamAttr(initializer=Constant(1.)))
        self.down_weights = self.create_parameter(
            shape=[num_fusions, 3],
            attr=ParamAttr(initializer=Constant(1.)))

    def _feature_fusion_cell(self,
                             conv_layer,
                             lateral_feat,
                             sampling_feat,
                             route_feat=None,
                             weights=None):
        """Fuse two or three feature maps and pass the result through ``conv_layer``.

        With weighted fusion the weights are passed through ReLU, divided by
        (their sum + eps) and used as per-input coefficients; otherwise the
        inputs are added. ``route_feat`` is optional and skipped when None.
        """
        inputs = [lateral_feat, sampling_feat]
        if route_feat is not None:
            inputs.append(route_feat)

        if self.use_weighted_fusion:
            weights = F.relu(weights)
            weights = weights / (weights.sum() + self.eps)
            fused = weights[0] * inputs[0]
            for k in range(1, len(inputs)):
                fused = fused + weights[k] * inputs[k]
        else:
            fused = inputs[0]
            for extra in inputs[1:]:
                fused = fused + extra

        return conv_layer(fused)

    def forward(self, feats):
        """Run the up pass then the down pass; returns one feature per input level, in input order."""
        # feats: [P3 - P7]
        weighted = self.use_weighted_fusion

        # up: start from the last level and move towards the first one.
        running = feats[-1]
        up_results = [running]
        for step, level_feat in enumerate(feats[-2::-1]):
            hw = paddle.shape(level_feat)
            # Resize the running feature to this level's spatial size.
            running = F.interpolate(running, size=[hw[2], hw[3]])
            fuse_w = self.up_weights[step] if weighted else None
            running = self._feature_fusion_cell(
                self.conv_up[step], level_feat, running, weights=fuse_w)
            up_results.append(running)

        # down: start from the first level and move towards the last one.
        running = up_results[-1]
        down_results = []
        last_idx = len(feats) - 1
        first_to_last = up_results[::-1]
        for idx, (up_feat, skip_feat) in enumerate(zip(first_to_last, feats)):
            if idx == 0:
                down_results.append(up_feat)
                continue

            # kernel 3, stride 2, padding 1
            running = F.max_pool2d(running, 3, 2, 1)
            if idx == last_idx:
                # The last level has no separate route input, so only the
                # first two weights of its row are used.
                skip_feat = None
                fuse_w = self.down_weights[idx - 1][:2] if weighted else None
            else:
                fuse_w = self.down_weights[idx - 1] if weighted else None

            running = self._feature_fusion_cell(
                self.conv_down[idx - 1],
                up_feat,
                running,
                skip_feat,
                weights=fuse_w)
            down_results.append(running)

        return down_results
@register
@serializable
class BiFPN(nn.Layer):
    """
    Bidirectional Feature Pyramid Network, see https://arxiv.org/abs/1911.09070

    Each input level goes through a lateral conv to ``out_channel`` channels,
    ``num_extra_levels`` extra levels are derived from the last input feature,
    and the resulting list is passed through ``num_stacks`` BiFPNCell layers.

    Args:
        in_channels (list[int]): input channels of each level which can be
            derived from the output shape of backbone by from_config.
        out_channel (int): output channel of each level.
        num_extra_levels (int): the number of extra stages added to the last level.
            default: 2
        fpn_strides (List): The stride of each level. If its length differs
            from the total number of levels, ``num_extra_levels`` entries are
            appended, each twice the previous one. The given list is copied
            and never modified.
        num_stacks (int): the number of stacks for BiFPN, default: 1.
        use_weighted_fusion (bool): use weighted feature fusion in BiFPN, default: True.
        norm_type (string|None): the normalization type in BiFPN module. If
            norm_type is None, norm will not be used after conv and if
            norm_type is string, bn, gn, sync_bn are available. default: bn.
        norm_groups (int): if you use gn, set this param.
        act (string|None): the activation function of BiFPN.
    """

    def __init__(self,
                 in_channels=(512, 1024, 2048),
                 out_channel=256,
                 num_extra_levels=2,
                 fpn_strides=[8, 16, 32, 64, 128],
                 num_stacks=1,
                 use_weighted_fusion=True,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(BiFPN, self).__init__()
        assert num_stacks > 0, "The number of stacks of BiFPN is at least 1."
        assert norm_type in ['bn', 'sync_bn', 'gn', None]
        assert act in ['swish', 'relu', None]
        assert num_extra_levels >= 0, \
            "The `num_extra_levels` must be non negative(>=0)."

        self.in_channels = in_channels
        self.out_channel = out_channel
        self.num_extra_levels = num_extra_levels
        self.num_stacks = num_stacks
        self.use_weighted_fusion = use_weighted_fusion
        self.norm_type = norm_type
        self.norm_groups = norm_groups
        self.act = act
        self.num_levels = len(self.in_channels) + self.num_extra_levels

        # Work on a private copy: extending the argument in place would
        # modify the caller's list or the shared default list, leaking
        # extra strides into later instances.
        fpn_strides = list(fpn_strides)
        if len(fpn_strides) != self.num_levels:
            for _ in range(self.num_extra_levels):
                fpn_strides.append(fpn_strides[-1] * 2)
        self.fpn_strides = fpn_strides

        # One lateral layer per input level, mapping to out_channel channels.
        # NOTE(review): the positional (1, 1) is presumably (filter_size,
        # stride) of ConvNormLayer -- confirm against its definition.
        self.lateral_convs = nn.LayerList()
        for in_c in in_channels:
            self.lateral_convs.append(
                ConvNormLayer(in_c, self.out_channel, 1, 1))

        if self.num_extra_levels > 0:
            self.extra_convs = nn.LayerList()
            for i in range(self.num_extra_levels):
                if i == 0:
                    # The first extra level is computed from the last raw
                    # input, hence in_channels[-1] as its input width.
                    self.extra_convs.append(
                        ConvNormLayer(self.in_channels[-1], self.out_channel,
                                      3, 2))
                else:
                    # Further extra levels: max pool with kernel 3, stride 2,
                    # padding 1 on the previous extra level.
                    self.extra_convs.append(nn.MaxPool2D(3, 2, 1))

        self.bifpn_cells = nn.LayerList()
        for i in range(self.num_stacks):
            self.bifpn_cells.append(
                BiFPNCell(
                    self.out_channel,
                    self.num_levels,
                    use_weighted_fusion=self.use_weighted_fusion,
                    norm_type=self.norm_type,
                    norm_groups=self.norm_groups,
                    act=self.act))

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Derive ``in_channels`` and ``fpn_strides`` from the ``channels`` and ``stride`` of each input shape."""
        return {
            'in_channels': [i.channels for i in input_shape],
            'fpn_strides': [i.stride for i in input_shape]
        }

    @property
    def out_shape(self):
        """One ShapeSpec per entry of ``fpn_strides``, each with ``out_channel`` channels."""
        return [
            ShapeSpec(
                channels=self.out_channel, stride=s) for s in self.fpn_strides
        ]

    def forward(self, feats):
        """Build the pyramid from ``feats`` (one per entry of ``in_channels``) and run all BiFPN cells."""
        assert len(feats) == len(self.in_channels)
        fpn_feats = []
        for conv_layer, feature in zip(self.lateral_convs, feats):
            fpn_feats.append(conv_layer(feature))

        if self.num_extra_levels > 0:
            # Extra levels are chained starting from the last *input*
            # feature, not from its lateral output.
            feat = feats[-1]
            for conv_layer in self.extra_convs:
                feat = conv_layer(feat)
                fpn_feats.append(feat)

        for bifpn_cell in self.bifpn_cells:
            fpn_feats = bifpn_cell(fpn_feats)
        return fpn_feats
Python
1
https://gitee.com/paddlepaddle/PaddleDetection.git
git@gitee.com:paddlepaddle/PaddleDetection.git
paddlepaddle
PaddleDetection
PaddleDetection
release/2.5

搜索帮助