Commit 1593978a authored by zhaoxusheng

10.20

parent 38058771
......@@ -52,7 +52,7 @@ Siam-RPN proposes an RPN-based Siamese network architecture, composed of a Siamese subnetwork and an RPN
## Mixed Precision
Training with [mixed precision](https://www.mindspore.cn/tutorials/experts/zh-CN/master/others/mixed_precision.html) uses single-precision and half-precision data together to speed up deep neural network training while preserving the accuracy achievable with single-precision training. Mixed-precision training increases computation speed and reduces memory usage, and it also makes it possible to train larger models or use larger batch sizes on specific hardware.
Training with [mixed precision](https://www.mindspore.cn/tutorials/zh-CN/master/advanced/mixed_precision.html) uses single-precision and half-precision data together to speed up deep neural network training while preserving the accuracy achievable with single-precision training. Mixed-precision training increases computation speed and reduces memory usage, and it also makes it possible to train larger models or use larger batch sizes on specific hardware.
Taking FP16 operators as an example: if the input data type is FP32, the MindSpore backend automatically reduces the precision when processing the data. Users can enable INFO logging and search for "reduce precision" to see which operators had their precision reduced.
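As a minimal sketch of how mixed precision is typically enabled through MindSpore's `Model` wrapper — the tiny network, loss, and optimizer below are illustrative placeholders, not this repository's training setup:
```python
import mindspore.nn as nn
from mindspore import Model

# placeholder network, only to illustrate the amp_level switch
class TinyNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(16, 2)

    def construct(self, x):
        return self.fc(x)

net = TinyNet()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
# amp_level="O3" runs the network in FP16; "O0" keeps pure FP32 training
model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O3")
```
With `amp_level` set, the backend inserts the precision casts automatically, which is what the "reduce precision" log entries above refer to.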
# Environment Requirements
......@@ -117,6 +117,7 @@ Siam-RPN proposes an RPN-based Siamese network architecture, composed of a Siamese subnetwork and an RPN
| |──run_distribute_train_gpu.sh // local multi-GPU training script
| |──run_infer_310.sh // Ascend 310 inference and evaluation script
| |──run_gpu.sh // single-GPU training script
| |──run_evalonnx.sh // ONNX inference script
├── src
│ ├── data_loader.py // dataset loading and preprocessing script
│ ├── net.py // SiamRPN architecture
......@@ -146,8 +147,9 @@ Siam-RPN proposes an RPN-based Siamese network architecture, composed of a Siamese subnetwork and an RPN
│ ├── ······
│ └── list.txt
├── train.py // training script
├── eval.py // evaluation script
├── export_mindir.py // export checkpoint to AIR/MindIR
├── evalonnx.py // ONNX evaluation script
├── eval.py // checkpoint (ckpt) evaluation script
├── export.py // export checkpoint to ONNX or MindIR
```
## Script Parameters
......@@ -331,6 +333,30 @@ cat acc.log
In train.py, we set the random seed.
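For reference, seeding in MindSpore is usually done as below; the seed value 1234 is illustrative, not the value actually used by this repository's train.py:
```python
import random
import numpy as np
import mindspore as ms

ms.set_seed(1234)     # seeds MindSpore operator-level randomness
np.random.seed(1234)  # seeds numpy-based data augmentation
random.seed(1234)     # seeds Python-level shuffling
```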
# ONNX Inference
```bash
# generate the ONNX file
python export.py --ckpt_file=/path/ckpt/siamRPN-xx_xxxx.ckpt
```
For ONNX inference, pick the invocation matching your dataset (VOT2015 or VOT2016). For example, to run inference on VOT2015:
```bash
python evalonnx.py --checkpoint_path=/path/siamrpn.onnx
# or
bash run_evalonnx.sh [dataset_path] [model_path] [filename]
# VOT2016 works the same way
```
- The results are saved in `filename`, for example:
```bash
{"all_videos": {"accuracy": 0.5890433443656077, "robustness": 0.3868562106735027, "eao": 0.30735406482761557}}
}
```
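Before running the full evaluation, the exported model can be sanity-checked with onnxruntime on dummy inputs. This is a minimal sketch: the file name `siamrpn.onnx` and the input shapes follow the export script, while the printed output shapes depend on the network configuration:
```python
import numpy as np
import onnxruntime as ort

# load the exported model, falling back to CPU if no CUDA provider is available
session = ort.InferenceSession(
    "siamrpn.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
feeds = {
    session.get_inputs()[0].name: np.zeros((1, 3, 127, 127), np.float32),  # template patch
    session.get_inputs()[1].name: np.zeros((1, 3, 255, 255), np.float32),  # search patch
}
pred_score, pred_regress = session.run(None, feeds)
print(pred_score.shape, pred_regress.shape)  # classification / regression heads
```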
# ModelZoo Homepage
Please visit the official [homepage](https://gitee.com/mindspore/models).
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval vot"""
import argparse
import os
import json
import sys
import time
import numpy as np
import onnxruntime as ort
import mindspore as ms
from mindspore import context, ops
from mindspore import Tensor
from src import evaluation as eval_
from src.config import config
from src.util import get_exemplar_image, get_instance_image, box_transform_inv
from src.generate_anchors import generate_anchors
from tqdm import tqdm
import cv2
sys.path.append(os.getcwd())
def create_session(checkpoint_path, target_device):
    if target_device == 'GPU':
        providers = ['CUDAExecutionProvider']
    elif target_device == 'CPU':
        providers = ['CPUExecutionProvider']
    else:
        raise ValueError(
            f'Unsupported target device {target_device}, '
            f'Expected one of: "CPU", "GPU"'
        )
    session = ort.InferenceSession(checkpoint_path, providers=providers)
    return session
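
# Helpers below implement the SiamRPN scale/ratio penalty terms:
# change() maps a ratio r to max(r, 1/r); sz()/sz_wh() compute the
# padded equivalent side length of a box, as used in the tracking update.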
def change(r):
    return np.maximum(r, 1. / r)

def sz(w, h):
    pad = (w + h) * 0.5
    sz2 = (w + pad) * (h + pad)
    return np.sqrt(sz2)

def sz_wh(wh):
    pad = (wh[0] + wh[1]) * 0.5
    sz2 = (wh[0] + pad) * (wh[1] + pad)
    return np.sqrt(sz2)
def get_axis_aligned_bbox(region):
    """Convert a region (4-tuple or 8-point polygon) to an axis-aligned (x, y, w, h) box."""
    nv = len(region)
    region = np.array(region)
    if nv == 8:
        # polygon groundtruth: fit an axis-aligned box of equivalent area
        x1 = min(region[0::2])
        x2 = max(region[0::2])
        y1 = min(region[1::2])
        y2 = max(region[1::2])
        A1 = np.linalg.norm(region[0:2] - region[2:4]) * \
            np.linalg.norm(region[2:4] - region[4:6])
        A2 = (x2 - x1) * (y2 - y1)
        s = np.sqrt(A1 / A2)
        w = s * (x2 - x1) + 1
        h = s * (y2 - y1) + 1
        x = x1
        y = y1
    else:
        x = region[0]
        y = region[1]
        w = region[2]
        h = region[3]
    return x, y, w, h
def reshapeimg1(img):
    # HWC uint8 patch -> NCHW float32 template input (1, 3, 127, 127)
    return img.astype(np.float32).transpose(2, 0, 1).reshape(1, 3, 127, 127)

def reshapeimg2(img):
    # HWC uint8 patch -> NCHW float32 search input (1, 3, 255, 255)
    return img.astype(np.float32).transpose(2, 0, 1).reshape(1, 3, 255, 255)

def calculate(bbox):
    # convert a (cx, cy, w, h) box to zero-based (x1, y1, x2, y2) corners
    gbox = np.array(bbox)
    gbox = list((gbox[0] - gbox[2] / 2 + 1 / 2, gbox[1] - gbox[3] / 2 + 1 / 2,
                 gbox[0] + gbox[2] / 2 - 1 / 2, gbox[1] + gbox[3] / 2 - 1 / 2))
    return gbox
def show(accuracy, video_paths, robustness, eao):
    print('accuracy is ', accuracy / float(len(video_paths)))
    print('robustness is ', robustness)
    print('eao is ', eao)

def predscore(pred_score):
    # softmax over the two-class score map, keep the foreground probability
    pred_score = Tensor(pred_score)
    softmax = ops.Softmax(axis=2)
    pred_score = softmax(pred_score)[0, :, 1]
    return pred_score.asnumpy()

def resshow(target, pos, frame, origin_target_sz, lr, target_sz):
    # clamp the updated center to the frame and smooth the size update by lr
    res_x = np.clip(target[0] + pos[0], 0, frame.shape[1])
    res_y = np.clip(target[1] + pos[1], 0, frame.shape[0])
    res_w = np.clip(target_sz[0] * (1 - lr) + target[2] * lr,
                    config.min_scale * origin_target_sz[0],
                    config.max_scale * origin_target_sz[0])
    res_h = np.clip(target_sz[1] * (1 - lr) + target[3] * lr,
                    config.min_scale * origin_target_sz[1],
                    config.max_scale * origin_target_sz[1])
    return res_x, res_y, res_w, res_h

def bboxshow(bbox, frame):
    bbox = (
        np.clip(bbox[0], 0, frame.shape[1]).astype(np.float64),
        np.clip(bbox[1], 0, frame.shape[0]).astype(np.float64),
        np.clip(bbox[2], 10, frame.shape[1]).astype(np.float64),
        np.clip(bbox[3], 10, frame.shape[0]).astype(np.float64))
    return bbox

def result1show(acc, num_failures, frames, duration):
    result1 = {}
    result1['acc'] = acc
    result1['num_failures'] = num_failures
    result1['fps'] = round(len(frames) / duration, 3)
    return result1
def test(model_path, data_path, save_name):
    session = create_session(model_path, "GPU")
    inname = [inp.name for inp in session.get_inputs()]
    outname = [out.name for out in session.get_outputs()]
    direct_file = os.path.join(data_path, 'list.txt')
    with open(direct_file, 'r') as f:
        direct_lines = f.readlines()
    video_names = np.sort([x.split('\n')[0] for x in direct_lines])
    video_paths = [os.path.join(data_path, x) for x in video_names]
    results = {}
    accuracy = 0
    all_overlaps = []
    all_failures = []
    gt_lenth = []
    for video_path in tqdm(video_paths, total=len(video_paths)):
        groundtruth_path = os.path.join(video_path, 'groundtruth.txt')
        with open(groundtruth_path, 'r') as f:
            boxes = f.readlines()
        if ',' in boxes[0]:
            boxes = [list(map(float, box.split(','))) for box in boxes]
        else:
            boxes = [list(map(int, box.split())) for box in boxes]
        # copy the boxes and convert (x, y, w, h) entries to (x1, y1, x2, y2)
        # corners for the accuracy computation; polygon entries stay as-is
        gt = [list(box) for box in boxes]
        for box in gt:
            if len(box) == 4:
                box[2] += box[0]
                box[3] += box[1]
        frames = [os.path.join(video_path, 'color', x)
                  for x in np.sort(os.listdir(os.path.join(video_path, 'color')))]
        frames = [x for x in frames if '.jpg' in x]
        tic = time.perf_counter()
        template_idx = 0
        # pre-compute the RPN anchors and a Hanning window that smooths the score map
        valid_scope = 2 * config.valid_scope + 1
        anchors = generate_anchors(config.total_stride, config.anchor_base_size, config.anchor_scales,
                                   config.anchor_ratios, valid_scope)
        window = np.tile(np.outer(np.hanning(config.score_size), np.hanning(config.score_size))[None, :],
                         [config.anchor_num, 1, 1]).flatten()
        res = []
        for idx, frame in tqdm(enumerate(frames), total=len(frames)):
            frame = cv2.imdecode(np.fromfile(frame, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
            h, w = frame.shape[0], frame.shape[1]
            if idx == template_idx:
                # (re-)initialize the tracker template from the groundtruth box
                bbox = get_axis_aligned_bbox(boxes[idx])
                pos = np.array([bbox[0] + bbox[2] / 2 - 1 / 2,
                                bbox[1] + bbox[3] / 2 - 1 / 2])  # center x, center y, zero based
                target_sz = np.array([bbox[2], bbox[3]])
                bbox = np.array([bbox[0] + bbox[2] / 2 - 1 / 2, bbox[1] + bbox[3] / 2 - 1 / 2, bbox[2], bbox[3]])
                origin_target_sz = np.array([bbox[2], bbox[3]])
                img_mean = np.mean(frame, axis=(0, 1))
                exemplar_img, _, _ = get_exemplar_image(frame, bbox,
                                                        config.exemplar_size, config.context_amount, img_mean)
                exemplar_img = reshapeimg1(exemplar_img)
                res.append([1])
            elif idx < template_idx:
                # frames skipped while waiting to re-initialize after a failure
                res.append([0])
            else:
                instance_img_np, _, _, scale_x = get_instance_image(frame, bbox, config.exemplar_size,
                                                                    config.instance_size,
                                                                    config.context_amount, img_mean)
                instance_img_np = reshapeimg2(instance_img_np)
                pred_score, pred_regress = session.run(outname, {inname[0]: exemplar_img, inname[1]: instance_img_np})
                pred_score = predscore(pred_score)
                delta = pred_regress[0]
                box_pred = box_transform_inv(anchors, delta)
                s_c = change(sz(box_pred[:, 2], box_pred[:, 3]) / (sz_wh(target_sz * scale_x)))  # scale penalty
                r_c = change((target_sz[0] / target_sz[1]) / (box_pred[:, 2] / box_pred[:, 3]))  # ratio penalty
                penalty = np.exp(-(r_c * s_c - 1.) * config.penalty_k)
                pscore = penalty * pred_score
                pscore = pscore * (1 - config.window_influence) + window * config.window_influence
                best_pscore_id = np.argmax(pscore)
                target = box_pred[best_pscore_id, :] / scale_x
                lr = penalty[best_pscore_id] * pred_score[best_pscore_id] * config.lr_box
                res_x, res_y, res_w, res_h = resshow(target, pos, frame, origin_target_sz, lr, target_sz)
                pos = np.array([res_x, res_y])
                target_sz = np.array([res_w, res_h])
                bbox = np.array([res_x, res_y, res_w, res_h])
                bbox = bboxshow(bbox, frame)
                gbox = calculate(bbox)
                if eval_.judge_failures(gbox, boxes[idx], 0):
                    # tracking failure: re-initialize the template five frames later
                    res.append([2])
                    template_idx = min(idx + 5, len(frames) - 1)
                else:
                    res.append(gbox)
        duration = time.perf_counter() - tic
        acc, overlaps, failures, num_failures = eval_.calculate_accuracy_failures(res, gt, [w, h])
        accuracy += acc
        result1 = result1show(acc, num_failures, frames, duration)
        results[video_path.split('/')[-1]] = result1
        all_overlaps.append(overlaps)
        all_failures.append(failures)
        gt_lenth.append(len(frames))
    all_length = sum([len(x) for x in all_overlaps])
    robustness = sum([len(x) for x in all_failures]) / all_length * 100
    eao = eval_.calculate_eao("VOT2015", all_failures, all_overlaps, gt_lenth)
    result1 = {}
    result1['accuracy'] = accuracy / float(len(video_paths))
    result1['robustness'] = robustness
    result1['eao'] = eao
    results['all_videos'] = result1
    show(accuracy, video_paths, robustness, eao)
    with open(save_name, 'w') as f:
        json.dump(results, f)
def parse_args():
    '''parse_args'''
    parser = argparse.ArgumentParser(description='MindSpore SiameseRPN ONNX inference')
    parser.add_argument('--device_target', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
    parser.add_argument('--device_id', type=int, default=0, help='DEVICE_ID')
    # choose the dataset path: vot2015 or vot2016
    parser.add_argument('--dataset_path', type=str, default='vot2015', help='Dataset path')
    parser.add_argument('--checkpoint_path', type=str, default='siamrpn.onnx', help='ONNX checkpoint of SiamRPN')
    parser.add_argument('--filename', type=str, default='onnx2015', help='file to save the results to')
    args_opt = parser.parse_args()
    return args_opt

if __name__ == '__main__':
    args = parse_args()
    if args.device_target == 'GPU':
        device_id = args.device_id
        context.set_context(device_id=device_id)
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    model_file_path = args.checkpoint_path
    data_file_path = args.dataset_path
    save_file_name = args.filename
    test(model_path=model_file_path, data_path=data_file_path, save_name=save_file_name)
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,19 +16,17 @@
import argparse
import numpy as np
import mindspore
from mindspore import context, Tensor, export
from mindspore.train.serialization import load_checkpoint
from src.net import SiameseRPN
def siamrpn_export():
    """ export function """
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target="Ascend",
        device_target="GPU",
        save_graphs=False,
        device_id=args.device_id)
    net = SiameseRPN(groups=1, is_310infer=True)
......@@ -37,6 +35,7 @@ def siamrpn_export():
    input_data1 = Tensor(np.zeros([1, 3, 127, 127]), mindspore.float32)
    input_data2 = Tensor(np.zeros([1, 3, 255, 255]), mindspore.float32)
    input_data = [input_data1, input_data2]
    # choose file_format: "MINDIR" or "ONNX"
    export(net, *input_data, file_name='siamrpn', file_format="MINDIR")
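    # a sketch of the ONNX variant (assuming the ops used by SiameseRPN have ONNX exporters):
    # export(net, *input_data, file_name='siamrpn', file_format="ONNX")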
......
......@@ -2,4 +2,5 @@ lmdb
fire
opencv-python
tqdm
Shapely
\ No newline at end of file
Shapely
onnxruntime-gpu
\ No newline at end of file
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
    echo "Usage: bash run_evalonnx.sh [DATASET_PATH] [MODEL_PATH] [FILENAME]"
    exit 1
fi
export DATA_NAME=$1
export MODEL_PATH=$2
export FILENAME=$3
python evalonnx.py --dataset_path=$DATA_NAME --checkpoint_path=$MODEL_PATH --filename=$FILENAME &> evalonnx2015.log &
......@@ -63,7 +63,7 @@ class Config:
    gray_ratio = 0.25
    score_size = int((instance_size - exemplar_size) / 8 + 1)
    penalty_k = 0.22
    window_influence = 0.40
    window_influence = 0.20
    lr_box = 0.30
    min_scale = 0.1
    max_scale = 10
......
......@@ -74,8 +74,8 @@ class SiameseRPN(nn.Cell):
        self.op_split_input = ops.Split(axis=1, output_num=self.groups)
        self.op_split_krenal = ops.Split(axis=0, output_num=self.groups)
        self.op_concat = ops.Concat(axis=1)
        self.conv2d_cout = ops.Conv2D(out_channel=10, kernel_size=4)
        self.conv2d_rout = ops.Conv2D(out_channel=20, kernel_size=4)
        self.conv2d_cout = ops.Conv2D(out_channel=10, kernel_size=(4, 4))
        self.conv2d_rout = ops.Conv2D(out_channel=20, kernel_size=(4, 4))
        self.regress_adjust = nn.Conv2d(4 * self.k, 4 * self.k, 1, pad_mode='valid', has_bias=True)
        self.reshape = ops.Reshape()
        self.transpose = ops.Transpose()
......