Skip to content
Snippets Groups Projects
Commit c536479e authored by zhaoting's avatar zhaoting
Browse files

fix pagenet

parent f25a150a
No related branches found
No related tags found
No related merge requests found
......@@ -37,7 +37,7 @@ PAGE-Net网络由三个部分组成,提取特征的CNN模块,金字塔注意
### 数据集配置
数据集目录修改在config.py中,训练集变量为train_dataset_imgs,train_dataset_gts,train_dataset_edges,
数据集目录修改在config.py中,训练集变量为train_dataset_imgs,train_dataset_gts,train_dataset_edges, vgg_init
测试集路径请自行修改
测试集若要使用自己的数据集,请添加数据集路径,并在train.py中添加新增的数据集
......@@ -84,11 +84,10 @@ PAGE-Net网络由三个部分组成,提取特征的CNN模块,金字塔注意
├── default_config_ascend.yaml # 参数配置脚本文件(ascned)
├── default_config_gpu.yaml # 参数配置脚本文件(gpu)
├── scripts
│ ├── run_standalone_train_gpu.sh # 单卡训练脚本文件(gpu)
│ ├── run_standalone_train.sh # 单卡训练脚本文件(ascend)
│ ├── run_standalone_train.sh # 单卡训练脚本文件(ascend & gpu)
│ ├── run_distribute_train_gpu.sh # 多卡训练脚本文件(gpu)
│ ├── run_distribute_train.sh # 多卡训练脚本文件(ascend)
│ ├── run_eval.sh # 评估脚本文件
│ ├── run_eval.sh # 评估脚本文件(ascend & gpu)
├── src
| ├── model_utils
| | ├── config.py
......@@ -124,17 +123,15 @@ model: "output/PAGENET.ckpt" # 测试时使用的checkpoint文件
### 训练
```markdown
cd scripts
bash run_standalone_train_gpu.sh [CONFIG_PATH] #运行gpu单卡训练
bash run_standalone_train.sh [CONFIG_PATH] #运行ascend单卡训练,config路径默认为ascend
```shell
bash scripts/run_standalone_train.sh [DEVICE_ID] [CONFIG_PATH] #运行单卡训练,config路径默认为ascend
```
### 分布式训练
```markdown
bash run_distribute_train_gpu.sh [CONFIG_PATH] #运行gpu分布式训练
bash run_distribute_train.sh 8 rank_table_8pcs.json [CONFIG_PATH] #运行ascend分布式训练,config路径默认为ascend
```shell
bash scripts/run_distribute_gpu.sh [DEVICE_NUM] [CONFIG_PATH] #运行gpu分布式训练
bash scripts/run_distribute_train.sh [DEVICE_NUM] [RANK_TABLE_FILE] [CONFIG_PATH] #运行ascend分布式训练,config路径默认为ascend
```
### 云上训练
......@@ -152,13 +149,13 @@ bash run_distribute_train.sh 8 rank_table_8pcs.json [CONFIG_PATH] #运行asce
## 评估过程
```markdown
bash run_eval.sh [CONFIG_PATH] #运行推理
bash scripts/eval.sh [DEVICE_ID] [CONFIG_PATH] #运行推理
```
## 导出过程
```markdown
python export.py #导出mindir,模型文件路径为config中的ckpt_file
```shell
python export.py --config_path=[CONFIG_PATH] #导出mindir,模型文件路径为config中的ckpt_file
```
## 模型描述
......
......@@ -11,7 +11,7 @@ test_task: "DUT-OMRON"
test_img_path: "./dataset/test_dataset/DUT-OMRON/DUT-OMRON-image"
test_gt_path: "./dataset/test_dataset/DUT-OMRON/DUT-OMRON-mask"
vgg_init: ""
vgg_init: "vgg16_20M.ckpt"
batch_size: 10
train_size: 224
......
......@@ -41,7 +41,7 @@ def main(test_img_path, test_gt_path, ckpt_file):
mae = nn.MAE()
F_score = nn.F1()
# model
model = MindsporeModel()
model = MindsporeModel(config)
ckpt_file_name = ckpt_file
ms.load_checkpoint(ckpt_file_name, net=model)
......
......@@ -32,10 +32,10 @@ def run_export():
"""
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
net = MindsporeModel()
net = MindsporeModel(config)
if not os.path.exists(config.ckpt_file):
print("config.ckpt_file is None.")
raise ValueError("config.ckpt_file is None.")
param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(net, param_dict)
......
......@@ -14,16 +14,35 @@
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: bash scripts/run_distribute_train.sh [DEVICE_NUM] [RANK_TABLE_FILE] [CONFIG_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
BASE_DIR=$(cd "$(dirname "$0")" || exit; pwd)
RANK_SIZE=$1
RANK_TABLE_FILE=$2
echo $RANK_TABLE_FILE
RANK_TABLE_FILE=$(get_real_path $2)
CONFIG_PATH=$(get_real_path $3)
if [ ! -f ${RANK_TABLE_FILE} ]; then
echo "${RANK_TABLE_FILE} file not exists"
echo "rank table ${RANK_TABLE_FILE} file not exists"
exit
fi
if [ ! -f ${CONFIG_PATH} ]; then
echo "config path ${CONFIG_PATH} file not exists"
exit
fi
export RANK_TABLE_FILE=${RANK_TABLE_FILE}
export RANK_SIZE=${RANK_SIZE}
rank_start=0
......@@ -36,11 +55,13 @@ do
rm -rf device$DEVICE_ID
mkdir device$DEVICE_ID
ln -s $BASE_DIR/../data ./device$DEVICE_ID
cp -r $BASE_DIR/../src ./device$DEVICE_ID
cp $BASE_DIR/../*.py ./device$DEVICE_ID
cp $BASE_DIR/../*.ckpt ./device$DEVICE_ID
cp $BASE_DIR/../*.yaml ./device$DEVICE_ID
cd ./device$DEVICE_ID
python -u ./train.py --train_mode 'distribute' config_path $3 > train.log 2>&1 &
python -u ./train.py \
--train_mode="distribute" \
--device_target="Ascend" \
--config_path=$CONFIG_PATH > train.log 2>&1 &
cd ../
done
......@@ -15,9 +15,9 @@
# ============================================================================
# Get absolute path
if [ $# != 1 ]
if [ $# != 2 ]
then
echo "Usage: bash run_distribute_gpu.sh [CONFIG_PATH]"
echo "Usage: bash scripts/run_distribute_gpu.sh [DEVICE_NUM] [CONFIG_PATH]"
exit 1
fi
......@@ -31,9 +31,18 @@ get_real_path(){
# Get current script path
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
CONFIG_PATH=$(get_real_path $2)
if [ ! -f ${CONFIG_PATH} ]; then
echo "config path ${CONFIG_PATH} file not exists"
exit
fi
cd $BASE_PATH/..
mpirun --allow-run-as-root -n 8 python train.py --train_mode 'distribute' --config_path $1 &> distribute.log 2>&1 &
mpirun --allow-run-as-root -n $1 python train.py \
--train_mode="distribute" \
--device_target="GPU" \
--config_path=$CONFIG_PATH &> distribute.log 2>&1 &
echo "The train log is at ../distribute.log."
......@@ -13,11 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
if [ $# != 1 ]
if [ $# != 2 ]
then
echo "Usage: bash run_eval.sh [CONFIG_PATH]"
echo "Usage: bash scripts/eval.sh [DEVICE_ID] [CONFIG_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
cd ..
python eval.py --config_path $1 &> test.log 2>&1 &
export DEVICE_ID=$1
export CUDA_VISIBLE_DEVICES=$1
export RANK_ID=0
CONFIG_PATH=$(get_real_path $2)
if [ ! -f ${CONFIG_PATH} ]; then
echo "config path ${CONFIG_PATH} file not exists"
exit
fi
# Get current script path
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/..
python eval.py --config_path=$CONFIG_PATH &> test.log 2>&1 &
echo "The eval log is at $BASE_PATH/../test.log."
......@@ -13,7 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
421
if [ $# != 2 ]
then
echo "Usage: bash scripts/run_standalone_train.sh [DEVICE_ID] [CONFIG_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
......@@ -21,13 +27,21 @@ get_real_path(){
echo "$(realpath -m $PWD/$1)"
fi
}
export DEVICE_ID=0
export RANK_ID=0
export DEVICE_ID=$1
export CUDA_VISIBLE_DEVICES=$1
export RANK_ID=0
CONFIG_PATH=$(get_real_path $2)
if [ ! -f ${CONFIG_PATH} ]; then
echo "config path ${CONFIG_PATH} file not exists"
exit
fi
# Get current script path
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/..
python train.py --config_path $1 &> standalone_train.log 2>&1 &
python train.py --config_path=$CONFIG_PATH &> standalone_train.log 2>&1 &
echo "The train log is at ../standalone_train.log."
\ No newline at end of file
echo "The train log is at $BASE_PATH/../standalone_train.log."
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Usage: bash run_standalone_train_gpu.sh "
# Get absolute path
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
# Get current script path
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/..
python train.py --train_mode 'single' --config_path $1 &> standalone_train.log 2>&1 &
echo "The train log is at ../standalone_train.log."
......@@ -129,6 +129,7 @@ def get_config():
"../../default_config_ascend.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
print(f"config path is {path_args.config_path}")
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
......
......@@ -25,6 +25,8 @@ from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.mytrainonestep import CustomTrainOneStepCell
ms.set_seed(1)
@moxing_wrapper()
def main():
context.set_context(mode=config.MODE,
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment