Skip to content
Snippets Groups Projects
Unverified Commit d66e4224 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2797 [哈尔滨工业大学][高校贡献][Mindspore][PAGENet]

Merge pull request !2797 from xiaoxiaoyaoa1207/demo2
parents d00dc7bd 6d79038e
No related branches found
No related tags found
No related merge requests found
## 目录
- [目录](#目录)
- [pagenet描述](#pagenet描述)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [数据集配置](#数据集配置)
- [环境要求](#环境要求)
- [脚本说明](#脚本说明)
- [代码文件说明](#代码文件说明)
- [脚本参数](#脚本参数)
- [训练过程](#训练过程)
- [训练](#训练)
- [分布式训练](#分布式训练)
- [评估过程](#评估过程)
- [导出过程](#导出过程)
- [模型描述](#模型描述)
- [评估性能](#评估性能)
- [推理性能](#推理性能)
- [modelzoo主页](#modelzoo主页)
## pagenet描述
PAGE-Net是通过监督学习解决显著性目标检测问题,它由提取特征的骨干网络模块,金字塔注意力模块和显著性边缘检测模块三部分构成。作者通过融合不同分辨率的显著性信息使得到的特征有更大的感受野和更好的表达能力,同时显著性边缘检测模块获得的边缘信息也能更加精确的分割显著性物体的边缘部分,从而使检测的结果更加精确。与其他19个工作在6个数据集上通过3种评价指标进行评估表明,PAGE-Net有着更加优异的性能和有竞争力的结果。
[PAGE-Net的tensorflow-keras源码](https://github.com/wenguanwang/PAGE-Net),由论文作者提供。具体包含运行文件、模型文件,此外还有数据集,预训练模型的获取途径。
[论文](https://www.researchgate.net/publication/332751907_Salient_Object_Detection_With_Pyramid_Attention_and_Salient_Edges):Wang W , Zhao S , Shen J , et al. Salient Object Detection With Pyramid Attention and Salient Edges[C]// CVPR19. 2019.
## 模型架构
PAGE-Net网络由三个部分组成,提取特征的CNN模块,金字塔注意力模块和边缘检测模块。预处理后的输入图片通过降采样输出特征信息,与此同时,对每一层的特征通过金字塔注意力模块生成更好表达力的特征,然后将边缘信息与不同深度提取出来的多尺度特征进行融合,最终输出了一张融合后的显著性检测图像。
## 数据集
数据集统一放在一个目录
### 数据集配置
数据集目录修改在config.py中,训练集变量为train_dataset_imgs,train_dataset_gts,train_dataset_edges,
测试集路径请自行修改
测试集若要使用自己的数据集,请添加数据集路径,并在train.py中添加新增的数据集
- 训练集:
[THUS10K数据集]([MSRA10K Salient Object Database – 程明明个人主页 (mmcheng.net)](https://mmcheng.net/msra10k/)) , 342MB,共有10000张带有标签的图像
- 测试集:
[ECSSD数据集](https://gitee.com/link?target=http%3A%2F%2Fwww.cse.cuhk.edu.hk%2Fleojia%2Fprojects%2Fhsaliency%2Fdata%2FECSSD%2Fimages.zip%EF%BC%8Chttp%3A%2F%2Fwww.cse.cuhk.edu.hk%2Fleojia%2Fprojects%2Fhsaliency%2Fdata%2FECSSD%2Fground_truth_mask.zip),67.2MB,共1000张
[DUTS-OMRON数据集](https://gitee.com/link?target=http%3A%2F%2Fsaliencydetection.net%2Fdut-omron%2F),113MB,共5163张
[HKU-IS数据集](https://gitee.com/link?target=https%3A%2F%2Fi.cs.hku.hk%2F~gbli%2Fdeep_saliency.html),899MB,共4447张
[SOD数据集](https://gitee.com/link?target=https%3A%2F%2Fwww.elderlab.yorku.ca%2F%3Fsmd_process_download%3D1%26download_id%3D8285),19.7MB,共1000张
[DUTS-TE数据集](https://gitee.com/link?target=http%3A%2F%2Fsaliencydetection.net%2Fduts%2Fdownload%2FDUTS-TE.zip),132MB,共5019张
## 环境要求
- 硬件(CPU/GPU)
- 如需查看详情,请参见如下资源:
[MindSpore教程](https://gitee.com/link?target=https%3A%2F%2Fwww.mindspore.cn%2Ftutorials%2Fzh-CN%2Fmaster%2Findex.html)
[MindSpore Python API](https://gitee.com/link?target=https%3A%2F%2Fwww.mindspore.cn%2Fdocs%2Fapi%2Fzh-CN%2Fmaster%2Findex.html)
- 需要的包
Mindspore-GPU 1.5.0
## 脚本说明
### 代码文件说明
```markdown
├── model_zoo
├── PAGENet
├── dataset
│ ├──train_dataset #训练集
│ ├──test_dataset #测试集
├──README.md # README文件
├── config.py # 参数配置脚本文件
├── scripts
│ ├──run_standalone_train_gpu.sh # 单卡训练脚本文件
│ ├──run_distribute_train_gpu.sh # 多卡训练脚本文件
│ ├──run_eval.sh # 评估脚本文件
├── src
│ ├──mind_dataloader.py # 加载数据集并进行预处理
│ ├──pagenet.py # pageNet的网络结构
│ ├──train_loss.py # 损失定义
├── train.py # 训练脚本
├── eval.py # 评估脚本
├── export.py # 模型导出脚本
```
### 脚本参数
```markdown
device_target: "GPU" # 运行设备 ["CPU", "GPU"]
batch_size: 20 # 训练批次大小
n_ave_grad: 10 # 梯度累积step数
epoch_size: 100 # 总计训练epoch数
image_height: 224 # 输入到模型的图像高度
image_width: 224 # 输入到模型的图像宽度
train_path: "./data/DUTS-TR/" # 训练数据集的路径
test_path: "./data" # 测试数据集的根目录
vgg: "/home/EGnet/EGnet/model/vgg16.ckpt" # vgg预训练模型的路径
resnet: "/home/EGnet/EGnet/model/resnet50.ckpt" # resnet预训练模型的路径
model: "EGNet/run-nnet/models/final_vgg_bone.ckpt" # 测试时使用的checkpoint文件
```
## 训练过程
```bash
### 训练
cd scripts
bash run_standalone_train_gpu.sh
### 分布式训练
bash run_distribute_train_gpu.sh
## 评估过程
bash run_eval.sh [CKPT_FILE] #CKPT_FILE 为权重文件名,请将权重文件放在当前目录下
## 导出过程
python export.py
```
## 模型描述
### 评估性能
THUS10K上的PAGE-Net(GPU)
| 参数 | GPU(单卡) | GPU(8卡) |
| ------------- | -------------------------------- | ------------------------------- |
| 模型 | PAGE-Net | PAGE-Net |
| 上传日期 | 2022.6.20 | 2022.6.20 |
| Mindspore版本 | 1.5.0 | 1.5.0 |
| 数据集 | THUS10K | THUS10K |
| 训练参数 | epoch=100,steps=1000,batch_size=10|epoch=200,steps=125,batch_size=10|
| 损失函数 | MSE&BCE | MSE&BCE |
| 优化器 | Adam | Adam |
| 速度 | 52s/step | 87s/step |
| 总时长 | 7h15m0s | 3h28m0s |
| 微调检查点 | 390M(.ckpt文件) | 390M(.ckpt文件) |
### 推理性能
显著性目标检测数据集上的PAGE-Net(GPU)
| 参数 | GPU(单卡) | GPU(8卡) |
| ------------- | --------------------- | ---------------------|
| 模型 | PAGE-Net | PAGE-Net |
| 上传日期 | 2022.6.20 | 2022.6.20 |
| Mindspore版本 | 1.5.0 | 1.5.0 |
| 数据集 | SOD, 1000张图像 | SOD, 1000张图像 |
| 评估指标 | F-score:0.974 | F-score:0.974 |
| 数据集 | ECCSD, 1000张图像 | ECCSD, 1000张图像 |
| 评估指标 | F-score: 0.845 | F-score:0.845 |
| 数据集 | DUTS-OMRON, 5163张图像| DUTS-OMRON, 5163张图像|
| 评估指标 | F-score: 0.80 | F-score: 0.80 |
| 数据集 | HKU-IS, 4447张图像 | HKU-IS, 4447张图像 |
| 评估指标 | F-score: 0.842 | F-score: 0.842 |
| 数据集 | DUTS-TE, 5019张图像 | DUTS-TE, 5019张图像 |
| 评估指标 | F-score: 0.778 | F-score: 0.778 |
## modelzoo主页
请浏览官网[主页](https://gitee.com/mindspore/models)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
train_img_path = "./dataset/train_dataset/images"
train_gt_path = "./dataset/train_dataset/labels"
train_edge_path = "./dataset/train_dataset/edges"
DUT_OMRON_img_path = "./dataset/test_dataset/DUT-OMRON/DUT-OMRON-image"
DUT_OMRON_gt_path = "./dataset/test_dataset/DUT-OMRON/DUT-OMRON-mask"
DUTS_TE_img_path = "./dataset/test_dataset/DUTS-TE/DUTS-TE-Image"
DUTS_TE_gt_path = "./dataset/test_dataset/DUTS-TE/DUTS-TE-Mask"
ECCSD_img_path = "./dataset/test_dataset/ECCSD/ECCSD-image"
ECCSD_gt_path = "./dataset/test_dataset/ECCSD/ECCSD-mask"
HKU_IS_img_path = "./dataset/test_dataset/HKU-IS/HKU-IS-image"
HKU_IS_gt_path = "./dataset/test_dataset/HKU-IS/HKU-IS-mask"
SOD_img_path = "./dataset/test_dataset/SOD/SOD-image"
SOD_gt_path = "./dataset/test_dataset/SOD/SOD-mask"
batch_size = 10
train_size = 224
device_target = 'GPU'
LR = 2e-5
WD = 0.0005
EPOCH = 100
MODE = context.GRAPH_MODE
ckpt_file = "PAGENET.ckpt"
file_name = 'pagenet'
file_format = 'MINDIR'
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import time
import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore import context
import config
from config import MODE, device_target, train_size
from src.pagenet import MindsporeModel
from src.mind_dataloader_final import get_test_loader
def main(test_img_path, test_gt_path, ckpt_file):
# context_set
context.set_context(mode=MODE,
device_target=device_target,
reserve_class_name_in_scope=False)
# dataset
test_loader = get_test_loader(test_img_path, test_gt_path, batchsize=1, testsize=train_size)
data_iterator = test_loader.create_tuple_iterator()
# step
total_test_step = 0
test_data_size = test_loader.get_dataset_size()
# loss&eval
loss = nn.Loss()
mae = nn.MAE()
F_score = nn.F1()
# model
model = MindsporeModel()
ckpt_file_name = ckpt_file
ms.load_checkpoint(ckpt_file_name, net=model)
model.set_train(False)
mae.clear()
loss.clear()
start = time.time()
for imgs, targets in data_iterator:
targets1 = targets.astype(int)
outputs = model(imgs)
pre_mask = outputs[9]
pre_mask = pre_mask.flatten()
targets1 = targets1.flatten()
pre_mask1 = pre_mask.asnumpy().tolist()
F_pre = np.array([[1 - i, i] for i in pre_mask1])
F_score.update(F_pre, targets1)
mae.update(pre_mask, targets1)
total_test_step = total_test_step + 1
if total_test_step % 100 == 0:
print("evaling:{}/{}".format(total_test_step, test_data_size))
end = time.time()
total = end - start
print("total time is {}h".format(total / 3600))
print("step time is {}s".format(total / (test_data_size)))
mae_result = mae.eval()
F_score_result = F_score.eval()
print("mae: ", mae_result)
print("F-score: ", (F_score_result[0] + F_score_result[1]) / 2)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='manual to this script')
parser.add_argument('-s', '--test_set', type=str)
parser.add_argument('-c', '--ckpt', type=str)
args = parser.parse_args()
if args.test_set == 'DUT-OMRON':
img_path = config.DUT_OMRON_img_path
gt_path = config.DUT_OMRON_gt_path
elif args.test_set == 'DUTS-TE':
img_path = config.DUTS_TE_img_path
gt_path = config.DUTS_TE_gt_path
elif args.test_set == 'ECCSD':
img_path = config.ECCSD_img_path
gt_path = config.ECCSD_gt_path
elif args.test_set == 'HKU-IS':
img_path = config.HKU_IS_img_path
gt_path = config.HKU_IS_gt_path
elif args.test_set == 'SOD':
img_path = config.SOD_img_path
gt_path = config.SOD_gt_path
else:
print("dataset is not exist")
ckpt = args.ckpt
main(img_path, gt_path, ckpt)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, mindir models#################
python export.py
"""
import numpy as np
import mindspore as ms
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.pagenet import MindsporeModel
import config
def run_export():
"""
run export operation
"""
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
net = MindsporeModel()
if config.ckpt_file is not None:
print("config.ckpt_file is None.")
param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.ones([config.batch_size, 3, 224, 224]), ms.float32)
export(net, input_arr, file_name=config.file_name, file_format=config.file_format)
if __name__ == "__main__":
run_export()
numpy
PIL
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Get absolute path
echo "Usage: bash run_distribute_gpu.sh"
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
# Get current script path
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/..
mpirun --allow-run-as-root -n 8 python train.py --train_mode 'distribute' &> distribute.log 2>&1 &
echo "The train log is at ../distribute.log."
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
if [ $# != 1 ]
then
echo "Usage: bash run_eval.sh [CKPT_PATH]"
exit 1
fi
cd ..
python eval.py -s 'DUT-OMRON' -c $1 &> test_O.log 2>&1 &
python eval.py -s 'DUTS-TE' -c $1 &> test_T.log 2>&1 &
python eval.py -s 'ECCSD' -c $1 &> test_E.log 2>&1 &
python eval.py -s 'HKU-IS' -c $1 &> test_H.log 2>&1 &
python eval.py -s 'SOD' -c $1 &> test_S.log 2>&1 &
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Usage: bash run_standalone_train_gpu.sh "
# Get absolute path
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
# Get current script path
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/..
python train.py --train_mode 'single' &> standalone_train.log 2>&1 &
echo "The train log is at ../standalone_train.log."
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms as transforms
import mindspore.dataset.vision as vision
import mindspore.dataset.vision.c_transforms as C
from PIL import Image
# transform img to tensor and resize to 224x224x3
class TrainData:
"""
dataloader for pageNet
"""
def __init__(self, image_root, gt_root, edge_root, img_size, augmentations):
self.img_size = img_size
self.augmentations = augmentations
self.images = [image_root + "/" + f for f in os.listdir(image_root) if
f.endswith('.jpg') or f.endswith('.png')]
self.gts = [gt_root + "/" + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')]
self.edges = [edge_root + "/" + f for f in os.listdir(edge_root) if f.endswith('.jpg') or f.endswith('.png')]
self.images = sorted(self.images)
self.gts = sorted(self.gts)
self.edges = sorted(self.edges)
self.size = len(self.images)
print('no augmentation')
self.img_transform = transforms.c_transforms.Compose([
C.Resize((self.img_size, self.img_size)),
vision.py_transforms.ToTensor()])
self.gt_transform = transforms.c_transforms.Compose([
C.Resize((self.img_size, self.img_size)),
vision.py_transforms.ToTensor()])
self.edge_transform = transforms.c_transforms.Compose([
C.Resize((self.img_size, self.img_size)),
vision.py_transforms.ToTensor()])
def __getitem__(self, index):
img = Image.open(self.images[index], 'r').convert('RGB')
gt = Image.open(self.gts[index], 'r').convert('1')
edge = Image.open(self.edges[index], 'r').convert('1')
if self.img_transform is not None:
img = np.array(img, dtype=np.float32)
img -= np.array((104.00699, 116.66877, 122.67892))
img = self.img_transform(img)
img = img * 255
if self.gt_transform is not None:
gt = self.gt_transform(gt)
gt = gt * 255
if self.edge_transform is not None:
edge = self.edge_transform(edge)
edge = edge * 255
return img, gt, edge
def __len__(self):
return self.size
class TestData:
"""
dataloader for pageNet
"""
def __init__(self, image_root, gt_root, img_size, augmentations):
self.img_size = img_size
self.augmentations = augmentations
self.images = [image_root + "/" + f for f in os.listdir(image_root) if
f.endswith('.jpg') or f.endswith('.png')]
self.gts = [gt_root + "/" + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')]
self.images = sorted(self.images)
self.gts = sorted(self.gts)
self.size = len(self.images)
self.img_transform = transforms.c_transforms.Compose([
C.Resize((self.img_size, self.img_size)),
vision.py_transforms.ToTensor()])
self.gt_transform = transforms.c_transforms.Compose([
C.Resize((self.img_size, self.img_size)),
vision.py_transforms.ToTensor(),
])
def __getitem__(self, index):
img = Image.open(self.images[index], 'r').convert('RGB')
gt = Image.open(self.gts[index], 'r').convert('1')
if self.img_transform is not None:
img = np.array(img, dtype=np.float32)
img -= np.array((104.00699, 116.66877, 122.67892))
img = self.img_transform(img)
img = img * 255
if self.gt_transform is not None:
gt = self.gt_transform(gt)
gt = gt * 255
return img, gt
def __len__(self):
return self.size
def get_train_loader(image_root, gt_root, edge_root, batchsize, trainsize, device_num=1, rank_id=0, shuffle=True,
num_parallel_workers=1, augmentation=False):
dataset_generator = TrainData(image_root, gt_root, edge_root, trainsize, augmentation)
dataset = ds.GeneratorDataset(dataset_generator, ["imgs", "gts", "edges"], shuffle=shuffle,
num_parallel_workers=num_parallel_workers, num_shards=device_num, shard_id=rank_id)
data_loader = dataset.batch(batch_size=batchsize)
return data_loader
def get_test_loader(image_root, gt_root, batchsize, testsize, augmentation=False):
dataset_generator = TestData(image_root, gt_root, testsize, augmentation)
dataset = ds.GeneratorDataset(dataset_generator, ["imgs", "gts"])
data_loader = dataset.batch(batch_size=batchsize)
return data_loader
This diff is collapsed.
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SalEdgeLoss define"""
import mindspore as ms
from mindspore import nn
from mindspore import Parameter
class total_loss(nn.Cell):
def __init__(self):
super(total_loss, self).__init__()
self.loss_fn1 = nn.MSELoss()
self.loss_fn2 = nn.BCELoss(reduction="mean")
self.zero = ms.Tensor(0, dtype=ms.float32)
# for log
self.sal_loss = Parameter(default_input=0.0, requires_grad=False)
self.edge_loss = Parameter(default_input=0.0, requires_grad=False)
self.total_loss = Parameter(default_input=0.0, requires_grad=False)
def construct(self, pres, gts, edges):
loss_edg_5 = self.loss_fn1(pres[1], edges)
loss_sal_5 = self.loss_fn2(pres[2], gts)
loss_5 = loss_sal_5 + loss_edg_5
loss_4 = self.loss_fn1(pres[3], edges) + self.loss_fn2(pres[4], gts)
loss_3 = self.loss_fn1(pres[5], edges) + self.loss_fn2(pres[6], gts)
loss_2 = self.loss_fn1(pres[7], edges) + self.loss_fn2(pres[8], gts)
loss_1 = self.loss_fn1(pres[10], edges) + self.loss_fn2(pres[9], gts)
loss = loss_1 + loss_2 + loss_3 + loss_4 + loss_5
return loss
class WithLossCell(nn.Cell):
"""
loss cell
"""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self.backbone = backbone
self.loss_fn = loss_fn
def construct(self, data, gts, edges):
"""
compute loss
"""
pres = self.backbone(data)
return self.loss_fn(pres, gts, edges)
@property
def backbone_network(self):
return self.backbone
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import argparse
import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from config import MODE, device_target, train_size, train_img_path, train_edge_path, train_gt_path, batch_size, EPOCH, \
LR, WD
from src.pagenet import MindsporeModel
from src.train_loss import total_loss, WithLossCell
from src.mind_dataloader_final import get_train_loader
def main(train_mode='single'):
context.set_context(mode=MODE,
device_target=device_target,
reserve_class_name_in_scope=False)
if train_mode == 'single':
# env set
# dataset
train_loader = get_train_loader(train_img_path, train_gt_path, train_edge_path, batchsize=batch_size,
trainsize=train_size)
train_data_size = train_loader.get_dataset_size()
# epoch
epoch = EPOCH
else:
init()
rank_id = get_rank()
device_num = get_group_size()
context.set_auto_parallel_context(device_num=device_num, gradients_mean=True,
parallel_mode=context.ParallelMode.DATA_PARALLEL)
# dataset
train_loader = get_train_loader(train_img_path, train_gt_path, train_edge_path, device_num=device_num,
rank_id=rank_id, num_parallel_workers=8, batchsize=batch_size,
trainsize=train_size)
train_data_size = train_loader.get_dataset_size()
# epoch
epoch = EPOCH * 2
# setup train_parameters
model = MindsporeModel()
# loss function
loss_fn = total_loss()
# learning_rate and optimizer
optimizer = nn.Adam(model.trainable_params(), learning_rate=LR, weight_decay=WD)
# train model
net_with_loss = WithLossCell(model, loss_fn)
train_network = nn.TrainOneStepCell(net_with_loss, optimizer)
train_network.set_train()
data_iterator = train_loader.create_tuple_iterator(num_epochs=epoch)
start = time.time()
for i in range(epoch):
total_train_step = 0
for imgs, gts, edges in data_iterator:
loss = train_network(imgs, gts, edges)
total_train_step = total_train_step + 1
if total_train_step % 10 == 0:
print("epoch: {}, step: {}/{}, loss: {}".format(i, total_train_step, train_data_size, loss))
if train_mode == 'single':
mindspore.save_checkpoint(train_network, "PAGENET" + '.ckpt')
print("PAGENET.ckpt" + " have saved!")
else:
mindspore.save_checkpoint(train_network, "PAGENET" + str(get_rank()) + '.ckpt')
print("PAGENET" + str(get_rank()) + '.ckpt' + " have saved!")
end = time.time()
total = end - start
print("total time is {}h".format(total / 3600))
print("step time is {}s".format(total / (train_data_size * epoch)))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='manual to this script')
parser.add_argument('-m', '--train_mode', type=str)
args = parser.parse_args()
main(args.train_mode)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment