Skip to content
Snippets Groups Projects
Commit 66d18ad5 authored by wangchangheng's avatar wangchangheng
Browse files

fix mem of nets

parent 150b71a1
No related branches found
No related tags found
No related merge requests found
......@@ -18,11 +18,9 @@ import math
import re
from copy import deepcopy
import mindspore as ms
from mindspore import nn
from mindspore.common.initializer import (Normal, One, Uniform, Zero)
from mindspore.ops import operations as P
from mindspore.ops.composite import clip_by_value
relu = P.ReLU()
sigmoid = P.Sigmoid()
......@@ -345,14 +343,6 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
return arch_args
def hard_swish(x):
x = P.Cast()(x, ms.float32)
y = x + 3.0
y = clip_by_value(y, 0.0, 6.0)
y = y / 6.0
return x * y
class BlockBuilder(nn.Cell):
def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0, channel_divisor=8,
channel_min=None, pad_type='', act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
......@@ -702,7 +692,7 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
channel_multiplier=channel_multiplier,
num_features=num_features,
bn_args=_resolve_bn_args(kwargs),
act_fn=hard_swish,
act_fn=nn.HSwish(),
**kwargs
)
return model
......
......@@ -125,9 +125,6 @@ def train_net():
ms.set_auto_parallel_context(all_reduce_fusion_config=config.all_reduce_fusion_config)
rank = get_rank()
# Set mempool block size in PYNATIVE_MODE for improving memory utilization, which will not take effect in GRAPH_MODE
ms.set_context(mempool_block_size="31GB")
mindrecord_file = create_mindrecord(config.dataset, "ssd.mindrecord", True)
if config.only_create_dataset:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment