# E-NET
## Contents
<!-- TOC -->
- [E-NET](#e-net)
- [Contents](#contents)
- [E-NET Description](#e-net-description)
- [Overview](#overview)
- [Paper](#paper)
- [Accuracy](#accuracy)
- [Environment](#environment)
- [Dataset](#dataset)
- [Script Description](#script-description)
- [Training and Evaluation](#training-and-evaluation)
- [Single-Device Training](#single-device-training)
- [Distributed Training](#distributed-training)
- [Evaluating a Single Checkpoint](#evaluating-a-single-checkpoint)
- [Model Description](#model-description)
- [310 Inference](#310-inference)
<!-- /TOC -->
## E-NET Description
## Overview
E-NET is an end-to-end network for semantic image segmentation. Deploying semantic segmentation on low-power embedded devices (phones, wearables, and the like) is an important practical problem. VGG-based segmentation models require a large number of floating-point operations, which makes them slow and unsuitable for real-time use. ENet, by contrast, offers fast inference, few floating-point operations, and few parameters at comparable accuracy.
This is a MindSpore reimplementation of E-NET [[paper]](https://arxiv.org/abs/1606.02147).
The project is ported from the PyTorch implementation of ENet [[HERE]](https://github.com/davidtvs/PyTorch-ENet).
### Paper
[Paper](https://arxiv.org/abs/1606.02147): A. Paszke, A. Chaurasia, S. Kim, and E. Culurciello. "ENet: A deep neural network architecture for real-time semantic segmentation."
### Accuracy
| (Val IoU)      | enet_pytorch | enet_mindspore |
| -------------- | ------------ | -------------- |
| **512 x 1024** | **59.5**     | **62.1**       |

The per-class IoU computation follows the [PyTorch implementation of ERFNet](https://github.com/Eromera/erfnet_pytorch).
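Per-class IoU is accumulated over the whole validation set from true positives, false positives, and false negatives; a minimal sketch of the computation (mirroring `iouEval.getIoU` in src/iou_eval.py):

```python
import numpy as np

def per_class_iou(tp, fp, fn, eps=1e-15):
    """iou_c = tp_c / (tp_c + fp_c + fn_c); the epsilon matches iouEval.getIoU."""
    iou = tp / (tp + fp + fn + eps)
    return iou.mean(), iou

# toy counts for two classes: per-class IoU = [0.9, 0.5], mean = 0.7
mean_iou, iou = per_class_iou(np.array([9.0, 4.0]), np.array([1.0, 1.0]), np.array([0.0, 3.0]))
```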
## Environment
Ascend
## Dataset
[**The Cityscapes dataset**](https://www.cityscapes-dataset.com/):
The label files downloaded from the official site annotate pixels with more than 30 classes; for training they must be collapsed into 20 classes, so the labels require preprocessing. For convenience, an already-processed copy can be downloaded directly.
Link: [[HERE]](https://pan.baidu.com/s/1jH9GUDX4grcEoDNLsWPKGw). Extraction code: aChQ.
After downloading, the directory layout is:
```sh
└── cityscapes
├── gtFine .................................. ground truth
    └── leftImg8bit ............................. train / val / test images
```
Run
```bash
python src/build_mrdata.py \
--dataset_path /path/to/cityscapes/ \
--subset train \
--output_name train.mindrecord
```
The script locates the training split under the dataset root /path/to/cityscapes/ and writes a mindrecord file at the path given by output_name. Then create a data folder in the project root and move the generated mindrecord file(s) into it, so that the relative paths used by the training scripts resolve.
## Script Description
```bash
|
├── ascend310_infer
│   ├── inc
│   │   └── utils.h // utils header
│   └── src
│       ├── CMakeLists.txt // cmake file
│       ├── main.cc // inference code
│       ├── build.sh // build script
│       └── utils.cc // utils implementation
├── scripts
│   ├── run_distribute_train.sh // distributed training script
│   └── run_standalone_train.sh // single-device training script
├── src
│   ├── build_mrdata.py // builds the mindrecord dataset
│   ├── config.py // parameter configuration
│   ├── dataset.py // dataset loaders
│   ├── iou_eval.py // metric computation
│   ├── criterion.py // loss function
│   ├── model.py // model definition
│   └── util.py // helper functions
├── README_CN.md // documentation
├── eval.py // evaluation script
├── export.py // MINDIR export script
└── train.py // training script
```
## Training and Evaluation
Before training, generate the mindrecord data file and place it in the data folder under the project root, then launch the scripts.
### Single-Device Training
To train on a single device, change to the project root and run
```bash
nohup bash scripts/run_standalone_train.sh /home/name/cityscapes 0 &
```
where /home/name/cityscapes is the dataset location and the trailing 0 is the device_id.
The script runs both the training and evaluation phases.
Training proceeds in three stages: the first two train the encoder part of the ENet model (stage 2 resumes from the stage 1 checkpoint Encoder-65_496.ckpt), and the third trains the full ENet network starting from the stage 2 checkpoint Encoder_1-85_496.ckpt.
During training, a log_single_device folder is created in the project root; the log_stage*.txt files inside it are the program logs. Run
```bash
tail -f log_single_device/log_stage*.txt
```
to follow training progress.
The evaluation phase computes, on the validation set, the accuracy of every checkpoint under log_single_device and writes a *.metric.txt file next to each checkpoint with the results.
### Distributed Training
For example, to train with 4 devices, change to the project root and run
```bash
nohup bash scripts/run_distribute_train.sh /home/name/cityscapes 4 0,1,2,3 /home/name/rank_table_4pcs.json &
```
where /home/name/cityscapes is the dataset location, 4 is the rank_size, 0,1,2,3 lists the device ids, and /home/name/rank_table_4pcs.json is the distributed-training configuration file. Other device counts work analogously.
A log_multi_device folder is created in the project root; ./log_multi_device/log0/log*.txt are the log files. Run
```bash
tail -f log_multi_device/log0/log*.txt
```
to follow training progress.
### Evaluating a Single Checkpoint
Run
```bash
python eval.py \
--data_path /path/cityscapes \
--run_distribute false \
--encode false \
--model_root_path /path/ENet/ENet.ckpt \
--device_id 1
```
data_path is the dataset root and model_root_path is the checkpoint path.
When evaluation finishes, a *.metric.txt file recording the results is written next to the checkpoint:
```txt
model path ./ENet-100_496.ckpt
mean_iou 0.6219186616013426
mean_loss 0.3161865407142856
iou_class [0.96626199 0.75290523 0.87924483 0.43634233 0.44190292 0.50485979
0.50586298 0.60316052 0.89555818 0.56628902 0.92109006 0.66907491
0.4730712 0.89284724 0.45698707 0.62259347 0.32161359 0.29706163
0.6097276 ]
```
## Model Description
### Performance
#### Training Performance
##### E-Net on Cityscapes
| Parameter                  | Ascend                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model version              | E-Net                                                        |
| Resources                  | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB; EulerOS 2.8 |
| Upload date                | 2021-10-09                                                   |
| MindSpore version          | 1.2.0                                                        |
| Dataset                    | Cityscapes                                                   |
| Training parameters        | epoch=250, steps=496, batch_size=6, lr=5e-4                  |
| Optimizer                  | Adam                                                         |
| Loss function              | weighted softmax cross-entropy                               |
| Output                     | semantic segmentation map                                    |
| Loss                       | 0.17356214                                                   |
| Speed                      | 882 ms/step (1 device)                                       |
| Total time                 | 30 h (1 device)                                              |
| Parameters (M)             | 0.34                                                         |
| Fine-tuning checkpoint     | 4.40 MB (.ckpt file)                                         |
| Inference model            | 9.97 MB (.air file)                                          |
#### Evaluation Performance
##### E-Net on Cityscapes
| Parameter           | Ascend                      |
| ------------------- | --------------------------- |
| Model version       | E-Net                       |
| Resources           | Ascend 910; EulerOS 2.8     |
| Upload date         | 2021-10-09                  |
| MindSpore version   | 1.2.0                       |
| Dataset             | Cityscapes, 500 images      |
| batch_size          | 6                           |
| Output              | semantic segmentation map   |
| Accuracy            | 62.19% (1 device)           |
## 310 Inference
Export the trained ckpt file to obtain a mindir model that can run directly on the Ascend 310:
```sh
python export.py --model_path /path/to/net.ckpt
```
This produces Enet.mindir in the current directory.
```sh
bash scripts/run_infer_310.sh /path/to/Enet.mindir /path/to/images /path/to/result /path/to/label 0
```
/path/to/images holds the validation images. In the original dataset, the images under cityscapes/leftImg8bit/val/ are grouped by the city in which they were taken, so they must first be collected into a single folder before inference, for example:
```sh
cp /path/to/cityscapes/leftImg8bit/val/frankfurt/* /path/to/images/
cp /path/to/cityscapes/leftImg8bit/val/lindau/* /path/to/images/
cp /path/to/cityscapes/leftImg8bit/val/munster/* /path/to/images/
```
The validation ground truth likewise has to be collected under /path/to/label. Of the remaining arguments, /path/to/Enet.mindir is the path to the mindir file, /path/to/result is the output directory for the inference results (create this folder in advance), and 0 is the device_id.
The inference outputs are written to the result directory given above, and a metric.txt containing the accuracy is generated in the current directory.
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/stat.h>
#include <sys/time.h>
#include <dirent.h>
#include <algorithm>
#include <climits>
#include <fstream>
#include <iostream>
#include <string>
#include <sstream>
#include <vector>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
namespace ms = mindspore;
namespace ds = mindspore::dataset;
int WriteResult(const std::string& imageFile, const std::vector<ms::MSTensor> &outputs, const std::string &res_path);
std::vector<std::string> GetAllFiles(std::string_view dir_name);
DIR *OpenDir(std::string_view dir_name);
std::string RealPath(std::string_view path);
ms::MSTensor ReadFile(const std::string &file);
cmake_minimum_required(VERSION 3.14.1)
project(ENet)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT}/../)
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(enet main.cc utils.cc)
target_link_libraries(enet ${MS_LIB} ${MD_LIB})
#! /bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
mkdir build
cd build
cmake .. -DMINDSPORE_PATH="`pip3 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make
cd ..
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/stat.h>
#include <sys/time.h>
#include <dirent.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <string>
#include <sstream>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"
namespace ms = mindspore;
namespace ds = mindspore::dataset;
int main(int argc, char **argv) {
if (argc != 5) {
std::cout << "example: ./enet /path/to/model /path/to/image device_id " << std::endl;
return -1;
}
std::cout << "model_path:" << argv[1] << std::endl;
std::cout << "image_path:" << argv[2] << std::endl;
std::cout << "result_path" << argv[3] << std::endl;
std::cout << "device_id:"<< argv[4] << std::endl;
int device_id = argv[4][0] - '0';
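// note: only a single-digit device id is parsed (first character of argv[4])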
// set context
auto context = std::make_shared<ms::Context>();
auto ascend310_info = std::make_shared<ms::Ascend310DeviceInfo>();
ascend310_info->SetDeviceID(device_id);
context->MutableDeviceInfo().push_back(ascend310_info);
// define model
ms::Graph graph;
ms::Status ret = ms::Serialization::Load(argv[1], ms::ModelType::kMindIR, &graph);
if (ret != ms::kSuccess) {
std::cout << "Load model failed." << std::endl;
return 1;
}
ms::Model enet;
// build model
ret = enet.Build(ms::GraphCell(graph), context);
if (ret != ms::kSuccess) {
std::cout << "Build model failed." << std::endl;
return 1;
}
// get model info
std::vector<ms::MSTensor> model_inputs = enet.GetInputs();
if (model_inputs.empty()) {
std::cout << "Invalid model, inputs is empty." << std::endl;
return 1;
}
// define transforms
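// decode -> resize to 512 x 1024 -> scale pixels to [0, 1] (mean 0, std 255)
// -> HWC to CHW, mirroring the Python-side preprocessing in src/dataset.py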
std::shared_ptr<ds::TensorTransform> decode(new ds::vision::Decode());
std::shared_ptr<ds::TensorTransform> resize(new ds::vision::Resize({512, 1024}));
std::shared_ptr<ds::TensorTransform> normalize(new ds::vision::Normalize({0, 0, 0},
{255, 255, 255}));
std::shared_ptr<ds::TensorTransform> hwc2chw(new ds::vision::HWC2CHW());
// define preprocessor
ds::Execute preprocessor({decode, resize, normalize, hwc2chw});
std::map<double, double> costTime_map;
std::vector<std::string> images = GetAllFiles(argv[2]);
for (const auto &image_file : images) {
struct timeval start = {0};
struct timeval end = {0};
double startTime_ms;
double endTime_ms;
// prepare input
std::vector<ms::MSTensor> outputs;
std::vector<ms::MSTensor> inputs;
// read image file and preprocess
auto image = ReadFile(image_file);
ret = preprocessor(image, &image);
if (ret != ms::kSuccess) {
std::cout << "Image preprocess failed." << std::endl;
return 1;
}
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
image.Data().get(), image.DataSize());
// infer
gettimeofday(&start, NULL);
ret = enet.Predict(inputs, &outputs);
gettimeofday(&end, NULL);
if (ret != ms::kSuccess) {
std::cout << "Predict model failed." << std::endl;
return 1;
}
// print infer result
std::cout << "Image: " << image_file << std::endl;
WriteResult(image_file, outputs, argv[3]);
startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms));
}
double average = 0.0;
int infer_cnt = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
infer_cnt++;
}
average = average/infer_cnt;
std::stringstream timeCost;
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << infer_cnt << std::endl;
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << infer_cnt << std::endl;
std::string file_name = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream file_stream(file_name.c_str(), std::ios::trunc);
file_stream << timeCost.str();
file_stream.close();
costTime_map.clear();
return 0;
}
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/utils.h"
int WriteResult(const std::string& imageFile, const std::vector<ms::MSTensor> &outputs, const std::string &res_path) {
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = res_path + "/" + fileName;
FILE * outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
std::vector<std::string> GetAllFiles(std::string_view dir_name) {
struct dirent *filename;
DIR *dir = OpenDir(dir_name);
if (dir == nullptr) {
return {};
}
/* read all the files in the dir ~ */
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string d_name = std::string(filename->d_name);
// get rid of "." and ".."
if (d_name == "." || d_name == ".." || filename->d_type != DT_REG)
continue;
res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
return res;
}
DIR *OpenDir(std::string_view dir_name) {
// check the parameter !
if (dir_name.empty()) {
std::cout << " dir_name is null ! " << std::endl;
return nullptr;
}
std::string real_path = RealPath(dir_name);
// check if dir_name is a valid dir
struct stat s;
lstat(real_path.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dir_name is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(real_path.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dir_name << std::endl;
return nullptr;
}
return dir;
}
std::string RealPath(std::string_view path) {
char real_path_mem[PATH_MAX] = {0};
char *real_path_ret = realpath(path.data(), real_path_mem);
if (real_path_ret == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
return std::string(real_path_mem);
}
ms::MSTensor ReadFile(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return ms::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return ms::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return ms::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
ms::MSTensor buffer(file, ms::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval Enet"""
import math
import os
from argparse import ArgumentParser
import numpy as np
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.criterion import SoftmaxCrossEntropyLoss
from src.dataset import getCityScapesDataLoader_GeneratorDataset
from src.iou_eval import iouEval
from src.model import Encoder_pred, Enet
from src.util import getBool, getCityLossWeight, seed_seed
def IOU(network_trained, dataloader, num_class, enc):
"""compute IOU"""
ioueval = iouEval(num_class)
loss = SoftmaxCrossEntropyLoss(num_class, getCityLossWeight(enc))
loss_list = []
network_trained.set_train(False)
for index, (images, labels) in enumerate(dataloader):
preds = network_trained(images)
l = loss(preds, labels)
loss_list.append(float(str(l)))
print("step {}/{}: loss: ".format(index+1, dataloader.get_dataset_size()), l)
preds = preds.asnumpy().argmax(axis=1).astype(np.int32)
labels = labels.asnumpy().astype(np.int32)
ioueval.addBatch(preds, labels)
mean_iou, iou_class = ioueval.getIoU()
mean_iou = mean_iou.item()
mean_loss = sum(loss_list) / len(loss_list)
return mean_iou, mean_loss, iou_class
def evalNetwork(network, eval_dataloader, ckptPath, encode, num_class=20):
"""load model,eval and save result"""
if ckptPath is None:
print("no model checkpoint!")
elif not os.path.exists(ckptPath):
print("not exist {}".format(ckptPath))
else:
print("load model checkpoint {}!".format(ckptPath))
param_dict = load_checkpoint(ckptPath)
load_param_into_net(network, param_dict)
mean_iou, mean_loss, iou_class = IOU(network, eval_dataloader, num_class, encode)
with open(ckptPath + ".metric.txt", "w") as file_:
print("model path", ckptPath, file=file_)
print("mean_iou", mean_iou, file=file_)
print("mean_loss", mean_loss, file=file_)
print("iou_class", iou_class, file=file_)
def listCKPTPath(model_root_path, enc):
"""get all the ckpt path in model_root_path"""
paths = []
names = os.listdir(model_root_path)
for name in names:
if name.endswith(".ckpt") and name+".metric.txt" not in names:
if enc and name.startswith("Encoder"):
ckpt_path = os.path.join(model_root_path, name)
paths.append(ckpt_path)
elif not enc and name.startswith("ENet"):
ckpt_path = os.path.join(model_root_path, name)
paths.append(ckpt_path)
return paths
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument('--data_path', type=str)
parser.add_argument('--run_distribute', type=str)
parser.add_argument('--encode', type=str)
parser.add_argument('--model_root_path', type=str)
parser.add_argument('--device_id', type=int)
config = parser.parse_args()
model_root_path_ = config.model_root_path
encode_ = getBool(config.encode)
device_id = config.device_id
CityScapesRoot = config.data_path
run_distribute = getBool(config.run_distribute)
seed_seed()
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
context.set_context(device_id=device_id)
context.set_context(save_graphs=False)
eval_dataloader_ = getCityScapesDataLoader_GeneratorDataset(CityScapesRoot, "val", 6, \
encode_, 512, False, False)
weight_init = "XavierUniform"
if encode_:
network_ = Encoder_pred(num_class=20, weight_init=weight_init, train=False)
else:
network_ = Enet(num_classes=20, init_conv=weight_init, train=False)
if not run_distribute:
if os.path.isdir(model_root_path_):
paths_ = listCKPTPath(model_root_path_, encode_)
for path in paths_:
evalNetwork(network_, eval_dataloader_, path, encode_)
else:
evalNetwork(network_, eval_dataloader_, model_root_path_, encode_)
else:
rank_id = int(os.environ["RANK_ID"])
rank_size = int(os.environ["RANK_SIZE"])
ckpt_files_path = listCKPTPath(model_root_path_, encode_)
n = math.ceil(len(ckpt_files_path) / rank_size)
ckpt_files_path = ckpt_files_path[rank_id*n : rank_id*n + n]
for path in ckpt_files_path:
evalNetwork(network_, eval_dataloader_, path, encode_)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export"""
from argparse import ArgumentParser
import numpy as np
from mindspore import Tensor, context, load_checkpoint, export
from src.model import Enet
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument('--model_path', type=str)
parser.add_argument('--device_id', type=int, default=0)
config = parser.parse_args()
net = Enet(20, "XavierUniform", train=False)
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
context.set_context(device_id=config.device_id)
load_checkpoint(config.model_path, net=net)
net.set_train(False)
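# fixed 1 x 3 x 512 x 1024 input, matching the 512x1024 resolution used for evaluation and 310 inference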
input_data = Tensor(np.zeros([1, 3, 512, 1024]).astype(np.float32))
export(net, input_data, file_name="Enet.mindir", file_format="MINDIR")
#! /bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]
then
echo "Usage: bash scripts/run_distribute_train.sh /path/to/cityscapes DEVICE_ID RANK_TABLE_FILE"
echo "Example: bash scripts/run_distribute_train.sh /home/name/cityscapes 4 0,1,2,3 /home/name/rank_table_4pcs.json"
exit 1
fi
if [ ! -d $1 ]
then
echo "error: DATASET_PATH=$1 is not a directory"
exit 1
fi
if [ ! -f $4 ]
then
echo "error: RANK_TABLE_FILE=$4 is not a file"
exit 1
fi
echo "CityScapes dataset path: $1"
echo "RANK_SIZE: $2"
echo "DEVICE_ID: $3"
echo "RANK_TABLE_FILE: $4"
ps -aux | grep "python -u ../../train.py" | awk '{print $2}' | xargs kill -9
export HCCL_CONNECT_TIMEOUT=600
export RANK_SIZE=$2
cityscapes_path=$1
IFS="," read -r -a devices <<< "$3";
export RANK_TABLE_FILE=$4
mkdir ./log_multi_device
cd ./log_multi_device
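# the three stages run sequentially; stages 2 and 3 resume from the rank-0
# checkpoints of the previous stage (Encoder-65_496.ckpt, Encoder_1-85_496.ckpt)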
# 1.train
for((i=0;i<RANK_SIZE;i++))
do
{
mkdir ./log$i
cd ./log$i
export RANK_ID=$i
export DEVICE_ID=${devices[i]}
echo "start training stage1 for rank $i, device $DEVICE_ID"
python -u ../../train.py \
--lr 1e-3 \
--repeat 2 \
--run_distribute true \
--save_path './' \
--mindrecord_train_data "../../data/train.mindrecord" \
--stage 1 \
--ckpt_path "" \
> log_stage1.txt 2>&1
cd ../
} &
done
wait
# 2.train
for((i=0;i<RANK_SIZE;i++))
do
{
cd ./log$i
export RANK_ID=$i
export DEVICE_ID=${devices[i]}
echo "start training stage2 for rank $i, device $DEVICE_ID"
python -u ../../train.py \
--lr 1e-3 \
--repeat 2 \
--run_distribute true \
--save_path './' \
--mindrecord_train_data "../../data/train.mindrecord" \
--stage 2 \
--ckpt_path "../log0/Encoder-65_496.ckpt" \
> log_stage2.txt 2>&1
cd ../
} &
done
wait
# 3.train
for((i=0;i<RANK_SIZE;i++))
do
{
cd ./log$i
export RANK_ID=$i
export DEVICE_ID=${devices[i]}
echo "start training stage3 for rank $i, device $DEVICE_ID"
python -u ../../train.py \
--lr 1e-3 \
--repeat 2 \
--run_distribute true \
--save_path './' \
--mindrecord_train_data "../../data/train.mindrecord" \
--stage 3 \
--ckpt_path "../log0/Encoder_1-85_496.ckpt" \
> log_stage3.txt 2>&1
cd ../
} &
done
wait
# eval
cd ./log0
echo "start eval , device ${devices[0]}"
python -u ../../eval.py \
--data_path ${cityscapes_path} \
--run_distribute true \
--encode false \
--model_root_path './' \
--device_id ${devices[0]} \
> log_eval.txt 2>&1 &
#! /bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 5 ]
then
echo "Usage: bash scripts/run_infer_310.sh MINDIR IMGS RES LABEL DEVICE_ID"
echo "Example: bash scripts/run_infer_310.sh /path/to/net.mindir /path/to/images /path/to/result /path/to/label 0"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: mindir_path=$1 is not a file"
exit 1
fi
if [ ! -d $2 ]
then
echo "error: images_path=$2 is not a directory"
exit 1
fi
if [ ! -d $3 ]
then
echo "error: result_path=$3 is not a directory"
exit 1
fi
if [ ! -d $4 ]
then
echo "error: label_path=$4 is not a directory"
exit 1
fi
echo "model mindir: $1"
echo "images path: $2"
echo "result path: $3"
echo "label path: $4"
echo "device id: $5"
cd ascend310_infer/src
bash build.sh
./build/enet $1 $2 $3 $5
cd ../..
python src/eval310.py --res_path $3 --label_path $4
#! /bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: bash scripts/run.sh /path/to/cityscapes DEVICE_ID"
echo "Example: bash scripts/run.sh /home/name/cityscapes 0"
exit 1
fi
if [ ! -d $1 ]
then
echo "error: DATASET_PATH=$1 is not a directory"
exit 1
fi
echo "CityScapes dataset path: $1"
echo "DEVICE_ID: $2"
ps -aux | grep "python -u ../train.py" | awk '{print $2}' | xargs kill -9
mkdir ./log_single_device
cd ./log_single_device
cityscapes_path=$1
export RANK_SIZE=1
export DEVICE_ID=$2
python -u ../train.py \
--lr 5e-4 \
--repeat 1 \
--run_distribute false \
--save_path './' \
--mindrecord_train_data "../data/train.mindrecord" \
--stage 1 \
--ckpt_path "" \
> log_stage1.txt 2>&1
python -u ../train.py \
--lr 5e-4 \
--repeat 1 \
--run_distribute false \
--save_path './' \
--mindrecord_train_data "../data/train.mindrecord" \
--stage 2 \
--ckpt_path "./Encoder-65_496.ckpt" \
> log_stage2.txt 2>&1
python -u ../train.py \
--lr 5e-4 \
--repeat 1 \
--run_distribute false \
--save_path './' \
--mindrecord_train_data "../data/train.mindrecord" \
--stage 3 \
--ckpt_path "./Encoder_1-85_496.ckpt" \
> log_stage3.txt 2>&1
python -u ../eval.py \
--data_path ${cityscapes_path} \
--run_distribute false \
--encode false \
--model_root_path './' \
--device_id ${DEVICE_ID} \
> log_eval.txt 2>&1 &
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""build .mindrecord data file"""
from argparse import ArgumentParser
from tqdm import tqdm
from mindspore.mindrecord import FileWriter
from dataset import cityscapes_datapath
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--dataset_path', type=str)
parser.add_argument('--subset', type=str)
parser.add_argument('--output_name', type=str)
config = parser.parse_args()
output_name = config.output_name
subset = config.subset
dataset_path = config.dataset_path
if subset not in ("train", "val"):
raise RuntimeError('subset should be "train" or "val"')
dataPathLoader = cityscapes_datapath(dataset_path, subset)
writer = FileWriter(file_name=output_name)
seg_schema = {"file_name": {"type": "string"},
"label": {"type": "bytes"},
"data": {"type": "bytes"}}
writer.add_schema(seg_schema, "seg_schema")
data_list = []
cnt = 0
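# buffer samples and flush them to the writer every 100 records to bound memory use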
for img_path, label_path in tqdm(dataPathLoader):
sample_ = {"file_name": img_path.split('/')[-1]}
with open(img_path, 'rb') as f:
sample_['data'] = f.read()
with open(label_path, 'rb') as f:
sample_['label'] = f.read()
data_list.append(sample_)
cnt += 1
if cnt % 100 == 0:
writer.write_raw_data(data_list)
print('number of samples written:', cnt)
data_list = []
if data_list:
writer.write_raw_data(data_list)
writer.commit()
print('number of samples written:', cnt)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""argument config"""
import os
from argparse import ArgumentParser
from mindspore import context
from mindspore.common.initializer import XavierUniform
from src.util import getBool, seed_seed, getLR
parser = ArgumentParser()
parser.add_argument('--lr', type=float)
parser.add_argument('--run_distribute', type=str)
parser.add_argument('--save_path', type=str)
parser.add_argument('--repeat', type=int)
parser.add_argument('--mindrecord_train_data', type=str)
parser.add_argument('--stage', type=int)
parser.add_argument('--ckpt_path', type=str)
parser.add_argument('--num_class', type=int, default=20)
config = parser.parse_args()
max_lr = config.lr
run_distribute = getBool(config.run_distribute)
global_size = int(os.environ["RANK_SIZE"])
repeat = config.repeat
stage = config.stage
ckpt_path = config.ckpt_path
save_path = config.save_path
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
context.set_context(device_id=int(os.environ["DEVICE_ID"]))
context.set_context(save_graphs=False)
seed_seed(2) # init random seed
weight_init = XavierUniform() # weight init
ms_train_data = config.mindrecord_train_data
num_class = config.num_class
# train config
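# Stages 1 and 2 share one 150-epoch poly-decay LR curve (stage 1 covers epochs
# 0-65, stage 2 resumes at epoch 65); stage 3 trains the full network on a fresh
# 100-epoch curve. 496 is the number of steps per epoch.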
class TrainConfig_1:
"""encoder training stage 1"""
def __init__(self):
self.subset = "train"
self.num_class = 20
self.train_img_size = 512
self.epoch_num_save = 10
self.epoch = 65
self.encode = True
self.attach_decoder = False
self.lr = getLR(max_lr, 0, 150, 496, \
run_distribute=run_distribute, global_size=global_size, repeat=repeat)
class TrainConfig_2:
"""encoder training stage 2"""
def __init__(self):
self.subset = "train"
self.num_class = 20
self.train_img_size = 512
self.epoch_num_save = 10
self.epoch = 85
self.encode = True
self.attach_decoder = False
self.lr = getLR(max_lr, 65, 150, 496, \
run_distribute=run_distribute, global_size=global_size, repeat=repeat)
class TrainConfig_3:
"""Enet training stage 3"""
def __init__(self):
self.subset = "train"
self.num_class = 20
self.train_img_size = 512
self.epoch_num_save = 10
self.epoch = 100
self.encode = False
self.attach_decoder = True
self.lr = getLR(max_lr, 0, 100, 496, \
run_distribute=run_distribute, global_size=global_size, repeat=repeat)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""criterion function"""
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore import numpy as mnp
from mindspore import ops as mops
from mindspore.ops import operations as P
from mindspore import Tensor
class SoftmaxCrossEntropyLoss(nn.Cell):
"""SoftmaxCrossEntropyLoss"""
def __init__(self, num_cls, weight):
super(SoftmaxCrossEntropyLoss, self).__init__()
self.one_hot = P.OneHot(axis=-1)
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.cast = P.Cast()
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.not_equal = P.NotEqual()
self.num_cls = num_cls
self.mul = P.Mul()
self.sum = P.ReduceSum(False)
self.div = P.RealDiv()
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.unsqueeze = mops.ExpandDims()
self.get_size = mops.Size()
self.exp = mops.Exp()
self.pow = mops.Pow()
self.weight = weight
def construct(self, pred, labels):
"""construct"""
labels = self.cast(labels, mstype.int32)
labels = self.reshape(labels, (-1,))
pred = self.transpose(pred, (0, 2, 3, 1))
pred = self.reshape(pred, (-1, self.num_cls))
one_hot_labels = self.one_hot(labels, self.num_cls, self.on_value, self.off_value)
pred = self.cast(pred, mstype.float32)
num = self.get_size(labels)
if self.weight is not None:
weight = mnp.copy(self.weight)
weight = self.cast(weight, mstype.float32)
weight = self.unsqueeze(weight, 0)
expand = mops.BroadcastTo(pred.shape)
weight = expand(weight)
weight_masked = weight[mnp.arange(num), labels]
loss = self.ce(pred, one_hot_labels)
loss = self.div(self.sum(loss * weight_masked), self.sum(weight_masked))
else:
loss = self.ce(pred, one_hot_labels)
loss = self.div(self.sum(loss), num)
return loss
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset"""
import os
import random
from io import BytesIO
import numpy as np
from PIL import Image, ImageFilter
import mindspore.dataset as ds
EXTENSIONS = ['.jpg', '.png']
class MyGaussianBlur(ImageFilter.Filter):
"""GaussianBlur"""
def __init__(self, radius=2, bounds=None):
self.radius = radius
self.bounds = bounds
def filter(self, image):
if self.bounds:
clips = image.crop(self.bounds).gaussian_blur(self.radius)
image.paste(clips, self.bounds)
return image
return image.gaussian_blur(self.radius)
def resize(img, height, interpolation):
w, h = img.size  # PIL size is (width, height)
width = int(height * w / h)
img_new = img.resize((width, height), interpolation)
return img_new
def load_image(file):
return Image.open(file)
def is_image(filename):
return any(filename.endswith(ext) for ext in EXTENSIONS)
def is_label(filename):
return filename.endswith("_labelTrainIds.png")
def image_path(root, basename, extension):
return os.path.join(root, f'{basename}{extension}')
def image_path_city(root, name):
return os.path.join(root, f'{name}')
def image_basename(filename):
return os.path.basename(os.path.splitext(filename)[0])
class MyCoTransform:
"""Transform"""
def __init__(self, stage, enc, augment, height, if_from_mindrecord=False):
self.enc = enc
self.augment = augment
self.height = height
self.if_from_mindrecord = if_from_mindrecord
if stage not in (1, 2, 3):
raise RuntimeError("stage should be 1, 2, 3")
self.stage = stage
if self.stage == 1:
self.ratio = 1.2
else:
self.ratio = 1.3
def process_one(self, image, target, height):
"""data enhance"""
if self.augment:
# GaussianBlur
image = image.filter(MyGaussianBlur(radius=random.random()))
if random.random() > 0.5: # random crop
if self.stage == 1:
ratio = self.ratio
else:
ratio = random.random() * (self.ratio - 1) + 1
w = int(2048 / ratio)
h = int(1024 / ratio)
x = int(random.random()*(2048-w))
y = int(random.random()*(1024-h))
box = (x, y, x+w, y+h)
image = image.crop(box)
target = target.crop(box)
image = resize(image, height, Image.BILINEAR)
target = resize(target, height, Image.NEAREST)
# Random hflip
if random.random() < 0.5:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
target = target.transpose(Image.FLIP_LEFT_RIGHT)
else:
image = resize(image, height, Image.BILINEAR)
target = resize(target, height, Image.NEAREST)
image = np.array(image).astype(np.float32) / 255
image = image.transpose(2, 0, 1)
target = resize(target, int(height/8), Image.NEAREST) if self.enc else target
target = np.array(target).astype(np.uint32)
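# map the Cityscapes ignore label (255) to class index 19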
target[target == 255] = 19
return image, target
def process_one_infer(self, image, height):
image = resize(image, height, Image.BILINEAR)
image = np.array(image).astype(np.float32) / 255
image = image.transpose(2, 0, 1)
return image
def __call__(self, image, target=None):
if self.if_from_mindrecord:
image = Image.open(BytesIO(image))
target = Image.open(BytesIO(target))
if target is None:
image = self.process_one_infer(image, self.height)
return image
image, target = self.process_one(image, target, self.height)
return image, target
class cityscapes:
"""cityscapes"""
def __init__(self, root, subset, enc, aug, height):
self.images_root = os.path.join(root, 'leftImg8bit/')
self.labels_root = os.path.join(root, 'gtFine/')
self.images_root += subset
self.labels_root += subset
self.filenames = [os.path.join(dp, f) for dp, dn, fn in
os.walk(os.path.expanduser(self.images_root)) for f in fn if is_image(f)]
self.filenames.sort()
self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in
os.walk(os.path.expanduser(self.labels_root)) for f in fn if is_label(f)]
self.filenamesGt.sort()
self.transform = MyCoTransform(1, enc, aug, height)
def __getitem__(self, index):
filename = self.filenames[index]
filenameGt = self.filenamesGt[index]
with open(image_path_city(self.images_root, filename), 'rb') as f:
image = load_image(f).convert('RGB')
with open(image_path_city(self.labels_root, filenameGt), 'rb') as f:
label = load_image(f).convert('P')
image, label = self.transform(image, label)
return image, label
def __len__(self):
return len(self.filenames)
class cityscapes_datapath:
"""get cityscapes data path"""
def __init__(self, root, subset):
self.images_root = os.path.join(root, 'leftImg8bit/')
self.labels_root = os.path.join(root, 'gtFine/')
self.images_root += subset
self.labels_root += subset
self.filenames = [os.path.join(dp, f) for dp, dn, fn in
os.walk(os.path.expanduser(self.images_root)) for f in fn if is_image(f)]
self.filenames.sort()
self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in
os.walk(os.path.expanduser(self.labels_root)) for f in fn if is_label(f)]
self.filenamesGt.sort()
def __getitem__(self, index):
filename = self.filenames[index]
filenameGt = self.filenamesGt[index]
return filename, filenameGt
def __len__(self):
return len(self.filenames)
def getCityScapesDataLoader_GeneratorDataset(CityScapesRoot, subset, batch_size,
enc, height, shuffle, aug, rank_id=0, global_size=1, repeat=1):
"""CityScapesGeneratorDataset"""
dataset = cityscapes(CityScapesRoot, subset, enc, aug, height)
dataloader = ds.GeneratorDataset(dataset, column_names=["images", "labels"],
num_parallel_workers=8, shuffle=shuffle, shard_id=rank_id,
num_shards=global_size, python_multiprocessing=True)
if shuffle:
dataloader = dataloader.shuffle(batch_size*10)
dataloader = dataloader.batch(batch_size, drop_remainder=False)
if repeat > 1:
dataloader = dataloader.repeat(repeat)
return dataloader
def getCityScapesDataLoader_mindrecordDataset(stage, data_path, batch_size, enc, height,
shuffle, aug, rank_id=0, global_size=1, repeat=1):
"""CityScapesmindrecordDataset"""
dataloader = ds.MindDataset(data_path, columns_list=["data", "label"],
num_parallel_workers=8, shuffle=shuffle, shard_id=rank_id, num_shards=global_size)
transform = MyCoTransform(stage, enc, aug, height, if_from_mindrecord=True)
dataloader = dataloader.map(operations=transform,
input_columns=["data", "label"], output_columns=["data", "label"],
num_parallel_workers=8, python_multiprocessing=True)
if shuffle:
dataloader = dataloader.shuffle(batch_size*10)
dataloader = dataloader.batch(batch_size, drop_remainder=False)
if repeat > 1:
dataloader = dataloader.repeat(repeat)
return dataloader
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval for 310 infer"""
import os
from argparse import ArgumentParser
import numpy as np
from PIL import Image
def resize(img, height, interpolation):
w, h = img.size  # PIL size is (width, height)
width = int(height * w / h)
img_new = img.resize((width, height), interpolation)
return img_new
def load_image(file):
return Image.open(file)
def is_label(filename):
return filename.endswith("_labelTrainIds.png")
def convert_to_one_hot(y, C):
return np.transpose(np.eye(C)[y], (0, 3, 1, 2)).astype(np.float32)
class iouEval:
"""compute iou"""
def __init__(self, nClasses, ignoreIndex=19):
self.nClasses = nClasses
self.ignoreIndex = ignoreIndex if nClasses > ignoreIndex else -1
self.reset()
def reset(self):
classes = self.nClasses if self.ignoreIndex == -1 else self.nClasses-1
self.tp = np.zeros(classes)
self.fp = np.zeros(classes)
self.fn = np.zeros(classes)
def addBatch(self, x, y):
"""add a batch and compute its iou"""
x_onehot = convert_to_one_hot(x, self.nClasses)
y_onehot = convert_to_one_hot(y, self.nClasses)
if self.ignoreIndex != -1:
ignores = np.expand_dims(y_onehot[:, self.ignoreIndex], axis=1)
x_onehot = x_onehot[:, :self.ignoreIndex]
y_onehot = y_onehot[:, :self.ignoreIndex]
else:
ignores = 0
tpmult = x_onehot * y_onehot
tp = np.sum(
np.sum(np.sum(tpmult, axis=0, keepdims=True),
axis=2, keepdims=True),
axis=3, keepdims=True
).squeeze()
fpmult = x_onehot * (1-y_onehot-ignores)
fp = np.sum(np.sum(np.sum(fpmult, axis=0,
keepdims=True), axis=2, keepdims=True), axis=3, keepdims=True).squeeze()
fnmult = (1-x_onehot) * (y_onehot)
fn = np.sum(np.sum(np.sum(fnmult, axis=0,
keepdims=True), axis=2, keepdims=True), axis=3, keepdims=True).squeeze()
self.tp += tp
self.fp += fp
self.fn += fn
def getIoU(self):
num = self.tp
den = self.tp + self.fp + self.fn + 1e-15
iou = num / den
return np.mean(iou), iou
class cityscapes_datapath:
"""cityscapes datapath iter"""
def __init__(self, labels_path):
self.labels_path = labels_path
self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in
os.walk(os.path.expanduser(self.labels_path)) for f in fn if is_label(f)]
self.filenamesGt.sort()
def __getitem__(self, index):
filenameGt = self.filenamesGt[index]
return filenameGt
def __len__(self):
return len(self.filenamesGt)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument('--res_path', type=str)
parser.add_argument('--label_path', type=str)
config = parser.parse_args()
res_path = config.res_path
label_path = config.label_path
gt = {}
for i in list(cityscapes_datapath(label_path)):
# strip the fixed filename suffix (str.rstrip would over-strip characters)
gt[i.split("/")[-1][:-len("_gtFine_labelTrainIds.png")]] = i
metrics = iouEval(nClasses=20)
for i, bin_name in enumerate(os.listdir(res_path)):
print(i)
file_name_sof = os.path.join(res_path, bin_name)
key = bin_name.split("_leftImg8bit_0.bin")[0]
with open(gt[key], 'rb') as f:
target = load_image(f).convert('P')
target = resize(target, 512, Image.NEAREST)
target = np.array(target).astype(np.uint32)
target[target == 255] = 19
target = target.reshape((1, 512, 1024))
softmax_out = np.fromfile(
file_name_sof, np.float32).reshape((1, 20, 512, 1024))
preds = softmax_out.argmax(axis=1).astype(np.int32)
labels = target.astype(np.int32)
metrics.addBatch(preds, labels)
mean_iou, iou_class = metrics.getIoU()
mean_iou = mean_iou.item()
with open("metric.txt", "w") as metric_file:
print("mean_iou: ", mean_iou, file=metric_file)
print("iou_class: ", iou_class, file=metric_file)
print("mean_iou: ", mean_iou)
print("iou_class: ", iou_class)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""compute iou"""
import numpy as np
def convert_to_one_hot(y, C):
return np.transpose(np.eye(C)[y], (0, 3, 1, 2)).astype(np.float32)
class iouEval:
"""compute iou"""
def __init__(self, nClasses, ignoreIndex=19):
self.nClasses = nClasses
self.ignoreIndex = ignoreIndex if nClasses > ignoreIndex else -1
self.reset()
def reset(self):
classes = self.nClasses if self.ignoreIndex == -1 else self.nClasses-1
self.tp = np.zeros(classes)
self.fp = np.zeros(classes)
self.fn = np.zeros(classes)
def addBatch(self, x, y):
"""add a batch and compute its iou"""
x_onehot = convert_to_one_hot(x, self.nClasses)
y_onehot = convert_to_one_hot(y, self.nClasses)
if self.ignoreIndex != -1:
ignores = np.expand_dims(y_onehot[:, self.ignoreIndex], axis=1)
x_onehot = x_onehot[:, :self.ignoreIndex]
y_onehot = y_onehot[:, :self.ignoreIndex]
else:
ignores = 0
tpmult = x_onehot * y_onehot
tp = np.sum(
np.sum(np.sum(tpmult, axis=0, keepdims=True), axis=2, keepdims=True),
axis=3, keepdims=True
).squeeze()
fpmult = x_onehot * (1-y_onehot-ignores)
fp = np.sum(np.sum(np.sum(fpmult, axis=0, \
keepdims=True), axis=2, keepdims=True), axis=3, keepdims=True).squeeze()
fnmult = (1-x_onehot) * (y_onehot)
fn = np.sum(np.sum(np.sum(fnmult, axis=0, \
keepdims=True), axis=2, keepdims=True), axis=3, keepdims=True).squeeze()
self.tp += tp
self.fp += fp
self.fn += fn
def getIoU(self):
num = self.tp
den = self.tp + self.fp + self.fn + 1e-15
iou = num / den
return np.mean(iou), iou
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""util"""
import os
import random
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.train.callback import Callback
class LossMonitor_mine(Callback):
"""LossMonitor"""
def __init__(self, per_print_times, learning_rate):
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.loss_list = []
self.learning_rate = learning_rate
def epoch_begin(self, run_context):
"""epoch begin"""
cb_params = run_context.original_args()
print("epoch:%d lr: %s" % (cb_params.cur_epoch_num, \
self.learning_rate[cb_params.cur_step_num]))
def step_end(self, run_context):
"""step end"""
cb_params = run_context.original_args()
loss = cb_params.net_outputs
if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
cb_params.cur_epoch_num, cur_step_in_epoch))
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
self.loss_list.append(loss)
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, \
cur_step_in_epoch, loss))
print("average loss is %s" % (np.mean(self.loss_list)))
print()
def epoch_end(self, run_context):
"""epoch end"""
self.loss_list = []
def getBool(string):
"""from str to bool"""
if string == "true":
b = True
elif string == "false":
b = False
else:
raise RuntimeError('string should be "true" or "false"')
return b
def seed_seed(seed=2):
"""set random seed"""
seed = int(seed)
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
mindspore.set_seed(seed)
def getLR(maxLR, start_epoch, epoch_num, epoch_step_num, run_distribute=False, \
global_size=1, repeat=1):
"""generate learning rate"""
if run_distribute:
epoch_step_num = int(repeat * epoch_step_num / global_size)
LR = np.arange(epoch_num).repeat(epoch_step_num, 0) + 1
LR = np.power(1 - ((LR - 1) / epoch_num), 0.9) * maxLR
LR = LR[start_epoch * epoch_step_num:]
LR_1 = np.zeros(epoch_step_num)
LR_1[:] = LR[-1]
LR = np.concatenate((LR, LR_1), axis=0)
return Tensor(LR, dtype=mindspore.float32)
def getCityLossWeight(encode):
"""class weight for balance"""
# calculate weights by processing dataset histogram
# create a loder to run all images and calculate histogram of labels,
# then create weight array using class balancing
weight = Tensor(np.zeros(20, dtype=np.float32))
if encode:
weight[0] = 2.3653597831726
weight[1] = 4.4237880706787
weight[2] = 2.9691488742828
weight[3] = 5.3442072868347
weight[4] = 5.2983593940735
weight[5] = 5.2275490760803
weight[6] = 5.4394111633301
weight[7] = 5.3659925460815
weight[8] = 3.4170460700989
weight[9] = 5.2414722442627
weight[10] = 4.7376127243042
weight[11] = 5.2286224365234
weight[12] = 5.455126285553
weight[13] = 4.3019247055054
weight[14] = 5.4264230728149
weight[15] = 5.4331531524658
weight[16] = 5.433765411377
weight[17] = 5.4631009101868
weight[18] = 5.3947434425354
else:
weight[0] = 2.8149201869965
weight[1] = 6.9850029945374
weight[2] = 3.7890393733978
weight[3] = 9.9428062438965
weight[4] = 9.7702074050903
weight[5] = 9.5110931396484
weight[6] = 10.311357498169
weight[7] = 10.026463508606
weight[8] = 4.6323022842407
weight[9] = 9.5608062744141
weight[10] = 7.8698215484619
weight[11] = 9.5168733596802
weight[12] = 10.373730659485
weight[13] = 6.6616044044495
weight[14] = 10.260489463806
weight[15] = 10.287888526917
weight[16] = 10.289801597595
weight[17] = 10.405355453491
weight[18] = 10.138095855713
weight[19] = 0
return weight
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train enet"""
import os
from mindspore import Model, context, nn
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.context import ParallelMode
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.serialization import _update_param, load_checkpoint
from src.config import (TrainConfig_1, TrainConfig_2, TrainConfig_3,
ckpt_path, ms_train_data, num_class, repeat, run_distribute, save_path, stage, weight_init)
from src.criterion import SoftmaxCrossEntropyLoss
from src.dataset import getCityScapesDataLoader_mindrecordDataset
from src.model import Encoder_pred, Enet
from src.util import getCityLossWeight
def attach(enet, encoder_pretrain):
"""move the params in encoder to enet"""
print("attach decoder.")
encoder_trained_par = encoder_pretrain.parameters_dict()
enet_par = enet.parameters_dict()
for name, param_old in encoder_trained_par.items():
if name.startswith("encoder"):
_update_param(enet_par[name], param_old)
def train(ckpt_path_, trainConfig_, rank_id, rank_size, stage_):
"""train enet"""
print("stage:", stage_)
save_prefix = "Encoder" if trainConfig_.encode else "ENet"
if trainConfig_.epoch == 0:
raise RuntimeError("epoch num cannot be zero")
if trainConfig_.encode:
network = Encoder_pred(num_class, weight_init)
else:
network = Enet(num_class, weight_init)
if not os.path.exists(ckpt_path_):
print("load no ckpt file.")
else:
load_checkpoint(ckpt_file_name=ckpt_path_, net=network)
print("load ckpt file:", ckpt_path_)
# attach decoder
if trainConfig_.attach_decoder:
network_enet = Enet(num_class, weight_init)
attach(network_enet, network)
network = network_enet
dataloader = getCityScapesDataLoader_mindrecordDataset(stage_, ms_train_data, 6, \
trainConfig_.encode, trainConfig_.train_img_size, shuffle=True, aug=True, \
rank_id=rank_id, global_size=rank_size, repeat=repeat)
opt = nn.Adam(network.trainable_params(), trainConfig_.lr, \
weight_decay=1e-4, eps=1e-08)
loss = SoftmaxCrossEntropyLoss(num_class, getCityLossWeight(trainConfig_.encode))
loss_scale_manager = DynamicLossScaleManager()
wrapper = Model(network, loss, opt, loss_scale_manager=loss_scale_manager, \
keep_batchnorm_fp32=True)
time_cb = TimeMonitor()
loss_cb = LossMonitor()
if rank_id == 0:
config_ck = CheckpointConfig(save_checkpoint_steps= \
trainConfig_.epoch_num_save * dataloader.get_dataset_size(), \
keep_checkpoint_max=9999)
saveModel_cb = ModelCheckpoint(prefix=save_prefix, directory= \
save_path, config=config_ck)
call_backs = [saveModel_cb, time_cb, loss_cb]
else:
call_backs = [time_cb, loss_cb]
print("============== Starting {} Training ==============".format(save_prefix))
wrapper.train(trainConfig_.epoch, dataloader, callbacks=call_backs, dataset_sink_mode=True)
return network
if __name__ == "__main__":
rank_id_ = 0
rank_size_ = 1
if run_distribute:
context.set_auto_parallel_context(parameter_broadcast=True)
context.set_auto_parallel_context(parallel_mode=\
ParallelMode.DATA_PARALLEL, gradients_mean=False)
init()
rank_id_ = get_rank()
rank_size_ = get_group_size()
trainConfig = {
1: TrainConfig_1(),
2: TrainConfig_2(),
3: TrainConfig_3()
}
network_ = train(ckpt_path, trainConfig[stage], rank_id=rank_id_,
rank_size=rank_size_, stage_=stage)