Skip to content
Snippets Groups Projects
Unverified Commit a8fc71d3 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!3241 modify FaceRecognition and maskrcnn network scripts

Merge pull request !3241 from anzhengqi/modify-networks
parents 3c3e1175 3a8d05a1
Branches
No related tags found
No related merge requests found
Showing
with 14 additions and 23 deletions
...@@ -55,7 +55,7 @@ class FeatPyramidNeck(nn.Cell): ...@@ -55,7 +55,7 @@ class FeatPyramidNeck(nn.Cell):
Tuple, with tensors of same channel size. Tuple, with tensors of same channel size.
Examples: Examples:
neck = FeatPyramidNeck([100,200,300], 50, 4) neck = FeatPyramidNeck([100,200,300], 50, 4, config.feature_shapes)
input_data = (normal(0,0.1,(1,c,1280//(4*2**i), 768//(4*2**i)), input_data = (normal(0,0.1,(1,c,1280//(4*2**i), 768//(4*2**i)),
dtype=np.float32) \ dtype=np.float32) \
for i, c in enumerate(config.fpn_in_channels)) for i, c in enumerate(config.fpn_in_channels))
...@@ -65,7 +65,8 @@ class FeatPyramidNeck(nn.Cell): ...@@ -65,7 +65,8 @@ class FeatPyramidNeck(nn.Cell):
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
num_outs): num_outs,
feature_shapes):
super(FeatPyramidNeck, self).__init__() super(FeatPyramidNeck, self).__init__()
if context.get_context("device_target") == "Ascend": if context.get_context("device_target") == "Ascend":
...@@ -91,9 +92,9 @@ class FeatPyramidNeck(nn.Cell): ...@@ -91,9 +92,9 @@ class FeatPyramidNeck(nn.Cell):
self.fpn_convs_.append(fpn_conv) self.fpn_convs_.append(fpn_conv)
self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_) self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_) self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
self.interpolate1 = P.ResizeBilinear((48, 80)) self.interpolate1 = P.ResizeBilinear(feature_shapes[2])
self.interpolate2 = P.ResizeBilinear((96, 160)) self.interpolate2 = P.ResizeBilinear(feature_shapes[1])
self.interpolate3 = P.ResizeBilinear((192, 320)) self.interpolate3 = P.ResizeBilinear(feature_shapes[0])
self.cast = P.Cast() self.cast = P.Cast()
self.maxpool = P.MaxPool(kernel_size=1, strides=2, pad_mode="same") self.maxpool = P.MaxPool(kernel_size=1, strides=2, pad_mode="same")
......
...@@ -96,7 +96,8 @@ class Mask_Rcnn_Resnet50(nn.Cell): ...@@ -96,7 +96,8 @@ class Mask_Rcnn_Resnet50(nn.Cell):
# Fpn # Fpn
self.fpn_ncek = FeatPyramidNeck(config.fpn_in_channels, self.fpn_ncek = FeatPyramidNeck(config.fpn_in_channels,
config.fpn_out_channels, config.fpn_out_channels,
config.fpn_num_outs) config.fpn_num_outs,
config.feature_shapes)
# Rpn and rpn loss # Rpn and rpn loss
self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8)) self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8))
......
# Copyright 2020-2021 Huawei Technologies Co., Ltd # Copyright 2020-2022 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -64,15 +64,14 @@ class Proposal(nn.Cell): ...@@ -64,15 +64,14 @@ class Proposal(nn.Cell):
self.target_means = target_means self.target_means = target_means
self.target_stds = target_stds self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls self.use_sigmoid_cls = use_sigmoid_cls
self.reshape_shape = (-1, 1)
if self.use_sigmoid_cls: if self.use_sigmoid_cls:
self.cls_out_channels = num_classes - 1 self.cls_out_channels = num_classes - 1
self.activation = P.Sigmoid() self.activation = P.Sigmoid()
self.reshape_shape = (-1, 1)
else: else:
self.cls_out_channels = num_classes self.cls_out_channels = num_classes
self.activation = P.Softmax(axis=1) self.activation = P.Softmax(axis=1)
self.reshape_shape = (-1, 2)
if self.cls_out_channels <= 0: if self.cls_out_channels <= 0:
raise ValueError('num_classes={} is too small'.format(num_classes)) raise ValueError('num_classes={} is too small'.format(num_classes))
......
...@@ -27,7 +27,6 @@ backbone: "r100" ...@@ -27,7 +27,6 @@ backbone: "r100"
use_se: 1 use_se: 1
emb_size: 512 emb_size: 512
act_type: "relu" act_type: "relu"
fp16: 1
pre_bn: 1 pre_bn: 1
inference: 0 inference: 0
use_drop: 1 use_drop: 1
......
...@@ -27,7 +27,6 @@ backbone: "r100" ...@@ -27,7 +27,6 @@ backbone: "r100"
use_se: 1 use_se: 1
emb_size: 512 emb_size: 512
act_type: "relu" act_type: "relu"
fp16: 1
pre_bn: 1 pre_bn: 1
inference: 0 inference: 0
use_drop: 1 use_drop: 1
......
...@@ -27,7 +27,6 @@ backbone: "r100" ...@@ -27,7 +27,6 @@ backbone: "r100"
use_se: 0 use_se: 0
emb_size: 256 emb_size: 256
act_type: "relu" act_type: "relu"
fp16: 1
pre_bn: 0 pre_bn: 0
inference: 0 inference: 0
use_drop: 1 use_drop: 1
......
...@@ -27,7 +27,6 @@ backbone: "r100" ...@@ -27,7 +27,6 @@ backbone: "r100"
use_se: 0 use_se: 0
emb_size: 256 emb_size: 256
act_type: "relu" act_type: "relu"
fp16: 1
pre_bn: 0 pre_bn: 0
inference: 0 inference: 0
use_drop: 1 use_drop: 1
......
...@@ -27,7 +27,6 @@ backbone: "r100" ...@@ -27,7 +27,6 @@ backbone: "r100"
use_se: 1 use_se: 1
emb_size: 512 emb_size: 512
act_type: "relu" act_type: "relu"
fp16: 1
pre_bn: 1 pre_bn: 1
inference: 0 inference: 0
use_drop: 1 use_drop: 1
......
...@@ -29,7 +29,6 @@ backbone: "r100" ...@@ -29,7 +29,6 @@ backbone: "r100"
use_se: 0 use_se: 0
emb_size: 256 emb_size: 256
act_type: "relu" act_type: "relu"
fp16: 1
pre_bn: 0 pre_bn: 0
inference: 1 inference: 1
use_drop: 0 use_drop: 0
......
...@@ -29,7 +29,6 @@ backbone: "r100" ...@@ -29,7 +29,6 @@ backbone: "r100"
use_se: 0 use_se: 0
emb_size: 256 emb_size: 256
act_type: "relu" act_type: "relu"
fp16: 1
pre_bn: 0 pre_bn: 0
inference: 1 inference: 1
use_drop: 0 use_drop: 0
......
...@@ -260,12 +260,11 @@ def run_train(): ...@@ -260,12 +260,11 @@ def run_train():
network_1 = DistributedHelper(_backbone, margin_fc_1) network_1 = DistributedHelper(_backbone, margin_fc_1)
config.logger.info('DistributedHelper----out----') config.logger.info('DistributedHelper----out----')
config.logger.info('network fp16----in----') config.logger.info('network fp16----in----')
if config.fp16 == 1:
network_1.add_flags_recursive(fp16=True) network_1.add_flags_recursive(fp16=True)
config.logger.info('network fp16----out----') config.logger.info('network fp16----out----')
criterion_1 = get_loss(config) criterion_1 = get_loss(config)
if config.fp16 == 1 and config.model_parallel == 0: if config.model_parallel == 0:
criterion_1.add_flags_recursive(fp32=True) criterion_1.add_flags_recursive(fp32=True)
network_1 = load_pretrain(config, network_1) network_1 = load_pretrain(config, network_1)
......
...@@ -148,7 +148,6 @@ class GatherFeatureByInd(nn.Cell): ...@@ -148,7 +148,6 @@ class GatherFeatureByInd(nn.Cell):
self.reshape = ops.Reshape() self.reshape = ops.Reshape()
self.enable_cpu_gatherd = enable_cpu_gatherd self.enable_cpu_gatherd = enable_cpu_gatherd
if self.enable_cpu_gatherd: if self.enable_cpu_gatherd:
self.value = Tensor(2, mstype.int32)
self.gather_nd = ops.GatherD() self.gather_nd = ops.GatherD()
self.expand_dims = ops.ExpandDims() self.expand_dims = ops.ExpandDims()
else: else:
...@@ -165,7 +164,7 @@ class GatherFeatureByInd(nn.Cell): ...@@ -165,7 +164,7 @@ class GatherFeatureByInd(nn.Cell):
# (b, J, K, N) # (b, J, K, N)
index = self.expand_dims(ind, -1) index = self.expand_dims(ind, -1)
index = self.tile(index, (1, 1, 1, N)) index = self.tile(index, (1, 1, 1, N))
feat = self.gather_nd(feat, self.value, index) feat = self.gather_nd(feat, 2, index)
else: else:
ind = self.reshape(ind, (-1, 1)) ind = self.reshape(ind, (-1, 1))
ind_b = nn.Range(0, b * J, 1)() ind_b = nn.Range(0, b * J, 1)()
......
...@@ -134,7 +134,6 @@ class GatherFeature(nn.Cell): ...@@ -134,7 +134,6 @@ class GatherFeature(nn.Cell):
self.reshape = ops.Reshape() self.reshape = ops.Reshape()
self.enable_cpu_gather = enable_cpu_gather self.enable_cpu_gather = enable_cpu_gather
if self.enable_cpu_gather: if self.enable_cpu_gather:
self.value = Tensor(1, mstype.int32)
self.gather_nd = ops.GatherD() self.gather_nd = ops.GatherD()
self.expand_dims = ops.ExpandDims() self.expand_dims = ops.ExpandDims()
else: else:
...@@ -147,7 +146,7 @@ class GatherFeature(nn.Cell): ...@@ -147,7 +146,7 @@ class GatherFeature(nn.Cell):
# (b, N, c) # (b, N, c)
index = self.expand_dims(ind, -1) index = self.expand_dims(ind, -1)
index = self.tile(index, (1, 1, c)) index = self.tile(index, (1, 1, c))
feat = self.gather_nd(feat, self.value, index) feat = self.gather_nd(feat, 1, index)
else: else:
# (b, N)->(b*N, 1) # (b, N)->(b*N, 1)
b, N = self.shape(ind) b, N = self.shape(ind)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment