Skip to content
Snippets Groups Projects
Commit c0e33496 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!1677 fix seq2seq2 issue & add shufflenetv2 standalone train script

Merge pull request !1677 from JichenZhao/master
parents 17cf987c 8c3d18a6
No related branches found
No related tags found
No related merge requests found
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# The script takes exactly one argument: the dataset directory.
# Anything else prints the usage line and aborts with status 1.
[ "$#" -eq 1 ] || {
    echo "Usage: sh run_standalone_train.sh [DATASET_PATH]"
    exit 1
}
# Print the absolute form of the path given as $1.
#   - An argument that already starts with "/" is printed unchanged.
#   - Anything else is resolved against the current directory with
#     `realpath -m` (which does not require the path to exist).
# The argument is quoted so that paths containing whitespace stay one word.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}
# Resolve the dataset directory given as the first argument and make sure it
# exists. Every expansion is quoted so that paths with spaces are kept intact.
DATASET_PATH=$(get_real_path "$1")
echo "$DATASET_PATH"
if [ ! -d "$DATASET_PATH" ]
then
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi

# Environment for a single-device run.
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

# Start from a fresh ./train working directory containing a copy of the
# sources found one level above the current directory.
if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../modelarts ./train
cp -r ../src ./train
cp -r ../infer ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log

# Launch training in the background; stdout and stderr go to log.txt.
python train.py \
    --dataset_path="$DATASET_PATH" \
    --is_distributed=False \
    --platform=Ascend > log.txt 2>&1 &
cd ..
......@@ -39,6 +39,11 @@
训练集:5,822,653张图片,85742个类
```bash
#将rec数据格式转换成jpg
python src/rec2jpg_dataset.py --include rec/dataset/path --output output/path
```
# 环境要求
- 硬件:昇腾处理器(Ascend)
......@@ -59,7 +64,7 @@
sh scripts/run_distribute_train.sh rank_size /path/dataset
# 单机训练运行示例
sh scripts/run_standalone_train.sh /path/dataset
sh scripts/run_standalone_train.sh /path/dataset device_id
# 运行评估示例
sh scripts/run_eval.sh /path/evalset /path/ckpt
......@@ -90,6 +95,7 @@ sh scripts/run_eval.sh /path/evalset /path/ckpt
├── loss.py //损失函数
├── dataset.py // 创建数据集
├── iresnet.py // ResNet架构
├── rec2jpg_dataset.py // 将rec数据格式转换成jpg
├── val.py // 测试脚本
├── train.py // 训练脚本
├── export.py
......
......@@ -17,13 +17,13 @@
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh DATA_PATH"
echo "For example: bash run.sh path/MS1M"
echo "For example: bash run.sh path/MS1M DEVICE_ID"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
# shellcheck disable=SC2034
DATA_PATH=$1
export DEVICE_ID=$2
python train.py \
--data_url $DATA_PATH \
--device_num 1 \
......
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
rec format to jpg
"""
import os
import argparse
import cv2
import mxnet as mx
def main(input_args):
    """
    Convert datasets stored in MXNet RecordIO format into JPG files.

    Each directory listed in ``input_args.include`` is expected to contain a
    ``train.rec`` / ``train.idx`` pair. Every image is written to
    ``<output>/<label>/<label>_<imgid>.jpg``, where ``imgid`` is a running
    counter shared by all datasets.

    :param input_args: parsed arguments providing ``include`` (comma-separated
        dataset directories) and ``output`` (destination directory)
    :return: None
    """
    include_datasets = input_args.include.split(',')
    rec_list = []
    try:
        # Open every dataset up front so a missing file is reported before
        # any image has been written.
        for ds in include_datasets:
            path_imgrec = os.path.join(ds, 'train.rec')
            path_imgidx = os.path.join(ds, 'train.idx')
            imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
            rec_list.append(imgrec)
        os.makedirs(input_args.output, exist_ok=True)
        imgid = 0
        for imgrec in rec_list:
            # Record 0 is a header; its label holds the [start, end) index
            # range of the per-identity records.
            s = imgrec.read_idx(0)
            header, _ = mx.recordio.unpack(s)
            assert header.flag > 0, "record 0 is not a header record (flag <= 0)"
            seq_identity = range(int(header.label[0]), int(header.label[1]))
            for identity in seq_identity:
                # Each identity record's label holds the [start, end) index
                # range of the image records that are read next.
                s = imgrec.read_idx(identity)
                id_header, _ = mx.recordio.unpack(s)
                for _idx in range(int(id_header.label[0]), int(id_header.label[1])):
                    s = imgrec.read_idx(_idx)
                    _header, _img = mx.recordio.unpack(s)
                    label = int(_header.label[0])
                    # Use the function parameter, not a module-level `args`,
                    # so main() also works when imported and called directly.
                    class_path = os.path.join(input_args.output, "%d" % label)
                    os.makedirs(class_path, exist_ok=True)
                    _img = mx.image.imdecode(_img).asnumpy()[:, :, ::-1]  # to bgr
                    image_path = os.path.join(class_path, "%d_%d.jpg" % (label, imgid))
                    cv2.imwrite(image_path, _img)
                    imgid += 1
                    # Progress indicator.
                    if imgid % 10000 == 0:
                        print(imgid)
    finally:
        # Release the record/index file handles even if conversion fails.
        for imgrec in rec_list:
            imgrec.close()
if __name__ == '__main__':
    # Command-line entry point: parse the dataset locations and run the conversion.
    parser = argparse.ArgumentParser(description='convert rec format datasets to jpg images')
    # general
    parser.add_argument('--include', default='', type=str,
                        help='comma-separated dataset directories, each containing '
                             'train.rec and train.idx')
    parser.add_argument('--output', default='', type=str,
                        help='directory the jpg files are written to '
                             '(one sub-directory per label)')
    args = parser.parse_args()
    main(args)
......@@ -206,7 +206,7 @@ bash wmt14_en_fr.sh
```bash
# grep "accuracy:"
BLEU scores is :12.9
BLEU scores is :12.1
```
# 模型描述
......@@ -243,7 +243,7 @@ bash wmt14_en_fr.sh
| 数据集 | WMT14 |
| batch_size | 128 |
| 输出 | BLEU |
| 准确性 | 8卡: BLEU=12.9 |
| 准确性 | 8卡: BLEU=12.1 |
# 随机情况说明
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment