Commit b35be0f0 authored by i-robot, committed by Gitee

!1695 fix some bugs in ProtoNet & Pnasnet

Merge pull request !1695 from zhaoting/master
parents a1bf5a8e 7f137ea0
@@ -29,7 +29,7 @@ get_real_path(){
 export DATA_PATH=$(get_real_path $1) # dataset path
 export TRAIN_CLASS=$2 # train class, propose 20
 export EPOCHS=$3 # num of epochs
-export DEVICE_NUM=$4 # device_num
+export RANK_SIZE=$4 # device_num
 if [ ! -d $DATA_PATH ]
 then
@@ -40,7 +40,7 @@ fi
 rm -rf distribute_output
 mkdir distribute_output
-mpirun --allow-run-as-root -n $DEVICE_NUM --output-filename log_output --merge-stderr-to-stdout \
+mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
 python ../train.py --dataset_root=$DATA_PATH \
 --device_target="GPU" \
 --classes_per_it_tr=$TRAIN_CLASS \
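
The two shell hunks above, which by the commit title come from ProtoNet's distributed GPU launch script, rename the exported variable from DEVICE_NUM to RANK_SIZE so that the name the script exports is the same one the Python side now reads, while mpirun -n keeps receiving the same value. A minimal sketch of that contract, assuming an OpenMPI launch and a MindSpore GPU build; the file name and the assert are illustrative, not part of this commit:

# check_world.py -- hypothetical sanity check, not part of this commit.
# Launch: RANK_SIZE=8 mpirun --allow-run-as-root -n 8 python check_world.py
import os
import mindspore.context as context
from mindspore.communication.management import init, get_rank, get_group_size

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
init()  # set up the NCCL communication group via the MPI launcher
# get_group_size() reports the mpirun world size; it should match RANK_SIZE
assert get_group_size() == int(os.environ.get("RANK_SIZE", "1"))
print("rank", get_rank(), "of", get_group_size())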
@@ -24,6 +24,7 @@ from mindspore import dataset as ds
 import mindspore.context as context
 from mindspore.communication.management import init, get_rank
 from mindspore.context import ParallelMode
+from mindspore.common import set_seed
 from src.EvalCallBack import EvalCallBack
 from src.protonet import WithLossCell
 from src.PrototypicalLoss import PrototypicalLoss
@@ -33,7 +34,7 @@ from model_init import init_dataloader
 local_data_url = './cache/data'
 local_train_url = './cache/out'
+set_seed(1)
 def train(opt, tr_dataloader, net, loss_fn, eval_loss_fn, optim, path, rank_id, val_dataloader=None):
     '''
@@ -68,17 +69,15 @@ def main():
     global local_train_url
     options = get_parser().parse_args()
-    device_num = int(os.environ.get("DEVICE_NUM", 1))
-    if options.device_target == "GPU":
-        rank_id = get_rank()
+    device_num = int(os.environ.get("RANK_SIZE", "1"))
+    rank_id = 0
     if options.run_offline:
         if device_num > 1:
             init()
             context.reset_auto_parallel_context()
             context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
+            rank_id = get_rank()
     if options.device_target == "Ascend":
         context.set_context(device_id=options.device_id)
@@ -105,6 +104,7 @@
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
+        rank_id = get_rank()
         local_data_url = os.path.join(local_data_url, str(device_id))
         local_train_url = os.path.join(local_train_url, str(device_id))
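
Together, the train.py hunks fix an ordering bug and add seeding: on GPU, get_rank() used to be called before init(), which fails because no communication group exists yet, and runs were not seeded. After the fix, device_num is read from RANK_SIZE, rank_id defaults to 0 for single-device runs, and get_rank() is queried only after init() and the parallel context are set up, in both the offline branch and the cached cloud branch. A condensed sketch of the corrected order, reusing the names from the diff (argument parsing and the Ascend/cloud branches are omitted):

# Condensed sketch of the initialization order this commit establishes.
import os
import mindspore.context as context
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode

set_seed(1)  # seed MindSpore RNGs once, before any weight init or shuffling

device_num = int(os.environ.get("RANK_SIZE", "1"))
rank_id = 0  # safe default for single-device runs; no init() needed
if device_num > 1:
    init()  # create the communication group; must precede get_rank()
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=device_num,
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True)
    rank_id = get_rank()  # only valid after init()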
@@ -39,8 +39,6 @@ cutout_length: 56
 # Dataset config
 train_batch_size: 32
 val_batch_size: 125
-# True for GPU, False for Ascend
-drop_remainder: False
 #learning rate config
 lr_init: 0.32
@@ -72,10 +72,9 @@ if __name__ == '__main__':
     else:
         dataset_val_path = os.path.join(dataset_path, 'val')
-    drop_remainder = config.drop_remainder
     dataset = create_dataset(dataset_val_path, do_train=False, rank=device_id,
                              group_size=1, batch_size=config.val_batch_size,
-                             drop_remainder=drop_remainder, shuffle=False)
+                             drop_remainder=False, shuffle=False)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
@@ -118,12 +118,10 @@ def train():
     else:
         dataset_train_path = os.path.join(dataset_path, 'train')
-    drop_remainder = config.drop_remainder
     train_dataset = create_dataset(dataset_train_path, True, config.rank, config.group_size,
                                    num_parallel_workers=config.work_nums,
                                    batch_size=config.train_batch_size,
-                                   drop_remainder=drop_remainder, shuffle=True,
+                                   drop_remainder=True, shuffle=True,
                                    cutout=config.cutout, cutout_length=config.cutout_length)
     train_batches_per_epoch = train_dataset.get_dataset_size()
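
The three Pnasnet hunks remove the drop_remainder key from the YAML config (its "True for GPU, False for Ascend" comment suggested a platform switch, which is not what the flag should depend on) and pin the value per phase instead: True for training, so every batch has the fixed shape that graph compilation expects, and False for evaluation, so the last partial batch of validation samples is not silently dropped. A toy sketch of the batching semantics, using only mindspore.dataset; create_dataset's internals are not shown in the diff:

# Toy illustration of drop_remainder with 10 samples and batch size 3.
import numpy as np
import mindspore.dataset as ds

def num_batches(drop_remainder):
    src = ds.NumpySlicesDataset(np.arange(10), column_names=["x"], shuffle=False)
    return src.batch(3, drop_remainder=drop_remainder).get_dataset_size()

print(num_batches(True))   # 3 -> the 10th sample is dropped (training setting)
print(num_batches(False))  # 4 -> last batch holds 1 sample (evaluation setting)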