Skip to content
Snippets Groups Projects
Commit c536479e authored by zhaoting's avatar zhaoting
Browse files

fix pagenet

parent f25a150a
Branches
No related tags found
No related merge requests found
......@@ -37,7 +37,7 @@ PAGE-Net网络由三个部分组成,提取特征的CNN模块,金字塔注意
### 数据集配置
数据集目录修改在config.py中,训练集变量为train_dataset_imgs,train_dataset_gts,train_dataset_edges,
数据集目录修改在config.py中,训练集变量为train_dataset_imgs,train_dataset_gts,train_dataset_edges, vgg_init
测试集路径请自行修改
测试集若要使用自己的数据集,请添加数据集路径,并在train.py中添加新增的数据集
......@@ -84,11 +84,10 @@ PAGE-Net网络由三个部分组成,提取特征的CNN模块,金字塔注意
├── default_config_ascend.yaml # 参数配置脚本文件(ascend)
├── default_config_gpu.yaml # 参数配置脚本文件(gpu)
├── scripts
│ ├── run_standalone_train_gpu.sh # 单卡训练脚本文件(gpu)
│ ├── run_standalone_train.sh # 单卡训练脚本文件(ascend)
│ ├── run_standalone_train.sh # 单卡训练脚本文件(ascend & gpu)
│ ├── run_distribute_train_gpu.sh # 多卡训练脚本文件(gpu)
│ ├── run_distribute_train.sh # 多卡训练脚本文件(ascend)
│ ├── run_eval.sh # 评估脚本文件
│ ├── run_eval.sh # 评估脚本文件(ascend & gpu)
├── src
| ├── model_utils
| | ├── config.py
......@@ -124,17 +123,15 @@ model: "output/PAGENET.ckpt" # 测试时使用的checkpoint文件
### 训练
```markdown
cd scripts
bash run_standalone_train_gpu.sh [CONFIG_PATH] #运行gpu单卡训练
bash run_standalone_train.sh [CONFIG_PATH] #运行ascend单卡训练,config路径默认为ascend
```shell
bash scripts/run_standalone_train.sh [DEVICE_ID] [CONFIG_PATH] #运行单卡训练,config路径默认为ascend
```
### 分布式训练
```markdown
bash run_distribute_train_gpu.sh [CONFIG_PATH] #运行gpu分布式训练
bash run_distribute_train.sh 8 rank_table_8pcs.json [CONFIG_PATH] #运行ascend分布式训练,config路径默认为ascend
```shell
bash scripts/run_distribute_gpu.sh [DEVICE_NUM] [CONFIG_PATH] #运行gpu分布式训练
bash scripts/run_distribute_train.sh [DEVICE_NUM] [RANK_TABLE_FILE] [CONFIG_PATH] #运行ascend分布式训练,config路径默认为ascend
```
### 云上训练
......@@ -152,13 +149,13 @@ bash run_distribute_train.sh 8 rank_table_8pcs.json [CONFIG_PATH] #运行asce
## 评估过程
```markdown
bash run_eval.sh [CONFIG_PATH] #运行推理
bash scripts/run_eval.sh [DEVICE_ID] [CONFIG_PATH] #运行推理
```
## 导出过程
```markdown
python export.py #导出mindir,模型文件路径为config中的ckpt_file
```shell
python export.py --config_path=[CONFIG_PATH] #导出mindir,模型文件路径为config中的ckpt_file
```
## 模型描述
......
......@@ -11,7 +11,7 @@ test_task: "DUT-OMRON"
test_img_path: "./dataset/test_dataset/DUT-OMRON/DUT-OMRON-image"
test_gt_path: "./dataset/test_dataset/DUT-OMRON/DUT-OMRON-mask"
vgg_init: ""
vgg_init: "vgg16_20M.ckpt"
batch_size: 10
train_size: 224
......
......@@ -41,7 +41,7 @@ def main(test_img_path, test_gt_path, ckpt_file):
mae = nn.MAE()
F_score = nn.F1()
# model
model = MindsporeModel()
model = MindsporeModel(config)
ckpt_file_name = ckpt_file
ms.load_checkpoint(ckpt_file_name, net=model)
......
......@@ -32,10 +32,10 @@ def run_export():
"""
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
net = MindsporeModel()
net = MindsporeModel(config)
if not os.path.exists(config.ckpt_file):
print("config.ckpt_file is None.")
raise ValueError("config.ckpt_file is None.")
param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(net, param_dict)
......
......@@ -14,16 +14,35 @@
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: bash scripts/run_distribute_train.sh [DEVICE_NUM] [RANK_TABLE_FILE] [CONFIG_PATH]"
exit 1
fi
# Print an absolute form of the given path on stdout.
# Arguments:
#   $1 - a path; either absolute (starts with "/") or relative to $PWD.
# Outputs:
#   $1 unchanged when it already starts with "/"; otherwise the result of
#   `realpath -m` applied to "$PWD/$1" (-m does not require the path to exist).
get_real_path(){
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    # Quote the operand so a path containing whitespace or glob characters
    # reaches realpath as a single, literal argument.
    echo "$(realpath -m "$PWD/$1")"
  fi
}
BASE_DIR=$(cd "$(dirname "$0")" || exit; pwd)
RANK_SIZE=$1
RANK_TABLE_FILE=$2
echo $RANK_TABLE_FILE
RANK_TABLE_FILE=$(get_real_path $2)
CONFIG_PATH=$(get_real_path $3)
if [ ! -f ${RANK_TABLE_FILE} ]; then
echo "${RANK_TABLE_FILE} file not exists"
echo "rank table ${RANK_TABLE_FILE} file not exists"
exit
fi
if [ ! -f ${CONFIG_PATH} ]; then
echo "config path ${CONFIG_PATH} file not exists"
exit
fi
export RANK_TABLE_FILE=${RANK_TABLE_FILE}
export RANK_SIZE=${RANK_SIZE}
rank_start=0
......@@ -36,11 +55,13 @@ do
rm -rf device$DEVICE_ID
mkdir device$DEVICE_ID
ln -s $BASE_DIR/../data ./device$DEVICE_ID
cp -r $BASE_DIR/../src ./device$DEVICE_ID
cp $BASE_DIR/../*.py ./device$DEVICE_ID
cp $BASE_DIR/../*.ckpt ./device$DEVICE_ID
cp $BASE_DIR/../*.yaml ./device$DEVICE_ID
cd ./device$DEVICE_ID
python -u ./train.py --train_mode 'distribute' config_path $3 > train.log 2>&1 &
python -u ./train.py \
--train_mode="distribute" \
--device_target="Ascend" \
--config_path=$CONFIG_PATH > train.log 2>&1 &
cd ../
done
......@@ -15,9 +15,9 @@
# ============================================================================
# Get absolute path
if [ $# != 1 ]
if [ $# != 2 ]
then
echo "Usage: bash run_distribute_gpu.sh [CONFIG_PATH]"
echo "Usage: bash scripts/run_distribute_gpu.sh [DEVICE_NUM] [CONFIG_PATH]"
exit 1
fi
......@@ -31,9 +31,18 @@ get_real_path(){
# Get current script path
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
CONFIG_PATH=$(get_real_path $2)
if [ ! -f ${CONFIG_PATH} ]; then
echo "config path ${CONFIG_PATH} file not exists"
exit
fi
cd $BASE_PATH/..
mpirun --allow-run-as-root -n 8 python train.py --train_mode 'distribute' --config_path $1 &> distribute.log 2>&1 &
mpirun --allow-run-as-root -n $1 python train.py \
--train_mode="distribute" \
--device_target="GPU" \
--config_path=$CONFIG_PATH &> distribute.log 2>&1 &
echo "The train log is at ../distribute.log."
......@@ -13,11 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
if [ $# != 1 ]
if [ $# != 2 ]
then
echo "Usage: bash run_eval.sh [CONFIG_PATH]"
echo "Usage: bash scripts/eval.sh [DEVICE_ID] [CONFIG_PATH]"
exit 1
fi
# Print an absolute form of the given path on stdout.
# Arguments:
#   $1 - a path; either absolute (starts with "/") or relative to $PWD.
# Outputs:
#   $1 unchanged when it already starts with "/"; otherwise the result of
#   `realpath -m` applied to "$PWD/$1" (-m does not require the path to exist).
get_real_path(){
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    # Quote the operand so a path containing whitespace or glob characters
    # reaches realpath as a single, literal argument.
    echo "$(realpath -m "$PWD/$1")"
  fi
}
export DEVICE_ID=$1
export CUDA_VISIBLE_DEVICES=$1
export RANK_ID=0
CONFIG_PATH=$(get_real_path $2)
if [ ! -f ${CONFIG_PATH} ]; then
echo "config path ${CONFIG_PATH} file not exists"
exit
fi
# Get current script path
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd ..
python eval.py --config_path $1 &> test.log 2>&1 &
cd $BASE_PATH/..
python eval.py --config_path=$CONFIG_PATH &> test.log 2>&1 &
echo "The eval log is at $BASE_PATH/../test.log."
......@@ -13,7 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
421
if [ $# != 2 ]
then
echo "Usage: bash scripts/run_standalone_train.sh [DEVICE_ID] [CONFIG_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
......@@ -21,13 +27,21 @@ get_real_path(){
echo "$(realpath -m $PWD/$1)"
fi
}
export DEVICE_ID=0
export DEVICE_ID=$1
export CUDA_VISIBLE_DEVICES=$1
export RANK_ID=0
CONFIG_PATH=$(get_real_path $2)
if [ ! -f ${CONFIG_PATH} ]; then
echo "config path ${CONFIG_PATH} file not exists"
exit
fi
# Get current script path
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/..
python train.py --config_path $1 &> standalone_train.log 2>&1 &
python train.py --config_path=$CONFIG_PATH &> standalone_train.log 2>&1 &
echo "The train log is at ../standalone_train.log."
\ No newline at end of file
echo "The train log is at $BASE_PATH/../standalone_train.log."
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# NOTE(review): this usage line is printed on every run; no argument-count
# check guards it in the visible script.
echo "Usage: bash run_standalone_train_gpu.sh "
# Get absolute path.
# Arguments:
#   $1 - a path; either absolute (starts with "/") or relative to $PWD.
# Outputs:
#   $1 unchanged when it already starts with "/"; otherwise the result of
#   `realpath -m` applied to "$PWD/$1" (-m does not require the path to exist).
get_real_path(){
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    # Quote the operand so a path containing whitespace or glob characters
    # reaches realpath as a single, literal argument.
    echo "$(realpath -m "$PWD/$1")"
  fi
}
# Get current script path
# BASE_PATH is the directory that contains this script ($0), as reported by
# pwd after changing into it; the subshell exits if the cd fails.
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
# Run from the parent of the script's directory; train.py is invoked from there.
cd $BASE_PATH/..
# Launch training in the background, forwarding $1 as --config_path.
# NOTE(review): `&>` already sends both stdout and stderr to the log file, so
# the trailing `2>&1` is redundant.
python train.py --train_mode 'single' --config_path $1 &> standalone_train.log 2>&1 &
echo "The train log is at ../standalone_train.log."
......@@ -129,6 +129,7 @@ def get_config():
"../../default_config_ascend.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
print(f"config path is {path_args.config_path}")
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
......
......@@ -25,6 +25,8 @@ from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.mytrainonestep import CustomTrainOneStepCell
ms.set_seed(1)
@moxing_wrapper()
def main():
context.set_context(mode=config.MODE,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment