diff --git a/official/cv/retinaface_resnet50/README.md b/official/cv/retinaface_resnet50/README.md
index d4564b4f2c88121595a333c5a40cecd7060e9c18..28c18d01fe1197d812cbb5dc7953e675eef58131 100644
--- a/official/cv/retinaface_resnet50/README.md
+++ b/official/cv/retinaface_resnet50/README.md
@@ -79,11 +79,19 @@ After installing MindSpore via the official website and downloading the dataset, you
   # run distributed training example
   bash scripts/run_distribute_gpu_train.sh 4 0,1,2,3

+  # run ONNX model export example
+  python export.py
+
   # run evaluation example
   export CUDA_VISIBLE_DEVICES=0
   python eval.py > eval.log 2>&1 &
   OR
   bash run_standalone_gpu_eval.sh 0
+
+  # run ONNX model evaluation example
+  python eval_onnx.py
+  OR
+  bash scripts/run_onnx_eval.sh 0
  ```

# [Script Description](#contents)

@@ -98,6 +106,7 @@ After installing MindSpore via the official website and downloading the dataset, you
 ├── scripts
 │   ├──run_distribute_gpu_train.sh     // shell script for distributed training on GPU
 │   ├──run_standalone_gpu_eval.sh      // shell script for evaluation on GPU
+│   ├──run_onnx_eval.sh                // shell script for ONNX model evaluation on GPU or CPU
 ├── src
 │   ├──dataset.py                      // creating dataset
 │   ├──network.py                      // retinaface architecture
@@ -112,6 +121,8 @@ After installing MindSpore via the official website and downloading the dataset, you
 │   ├── ground_truth                   // eval label
 ├── train.py                           // training script
 ├── eval.py                            // evaluation script
+├── export.py                          // ONNX model export script
+├── eval_onnx.py                       // ONNX model evaluation script
 ```

## [Script Parameters](#contents)

@@ -160,6 +171,10 @@ Parameters for both training and evaluation can be set in config.py
     'val_save_result': False,                          # Whether to save the results
     'val_predict_save_folder': './widerface_result',   # Result save path
     'val_gt_dir': './data/ground_truth/',              # Path of val set ground_truth
+    # onnx
+    'ckpt_model': '../ckpt/retinaface.ckpt',           # Path of the ckpt file to be exported
+    'onnx_model': './retinaface.onnx',                 # Path of the ONNX model to be evaluated
+    'device': 'GPU',                                   # Device type for ONNX export and evaluation: CPU or GPU
 ```

## [Training Process](#contents)

@@ -189,6 +204,16 @@ Parameters for both training and evaluation can be set in config.py
 After training, you'll get some checkpoint files under the folder `./checkpoint/ckpt_0/` by default.

+## ONNX Exporting
+
+- **Preparation**: Modify the `device` parameter in the `src/config.py` file to select the device type, `CPU` or `GPU`; then modify the `ckpt_model` parameter to specify the path of the ckpt file that is used to export the ONNX model.
+
+- **Run script**: Run the following command to export the `ONNX` model; it will be saved in the current directory.
+
+  ```shell
+  python export.py
+  ```
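+
+- **Verify the exported model (optional)**: the short sketch below is an illustrative addition (it assumes the exported `retinaface.onnx` sits in the current directory and that `onnxruntime` is installed); it loads the model and checks the input and output shapes:
+
+  ```python
+  import numpy as np
+  import onnxruntime
+
+  session = onnxruntime.InferenceSession('retinaface.onnx', providers=['CPUExecutionProvider'])
+  inp = session.get_inputs()[0]
+  print(inp.name, inp.shape)  # export.py uses a fixed (1, 3, 2176, 2176) input
+  dummy = np.zeros((1, 3, 2176, 2176), dtype=np.float32)
+  outputs = session.run(None, {inp.name: dummy})
+  print([o.shape for o in outputs])  # three outputs: boxes, confidences, landmarks
+  ```
+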
 ## [Evaluation Process](#contents)

 ### Evaluation

@@ -226,6 +251,32 @@ Parameters for both training and evaluation can be set in config.py
     Hard Val AP : 0.8900
 ```

+## Evaluation for ONNX Model
+
+- **Preparation**: Modify the following parameters in the `src/config.py` file according to your setup:
+
+  - `device`: device type, `CPU` or `GPU`;
+  - `onnx_model`: path of the ONNX model;
+  - `val_dataset_folder`: path of the validation dataset;
+  - `val_gt_dir`: path of the validation dataset ground_truth labels.
+
+- **Run script**: Run the following command to evaluate the `ONNX` model:
+
+  ```bash
+  export CUDA_VISIBLE_DEVICES=0
+  python eval_onnx.py
+  OR
+  bash scripts/run_onnx_eval.sh 0
+  ```
+
+The evaluation results can be viewed in the terminal:
+
+  ```python
+  Easy Val AP : 0.9390
+  Medium Val AP : 0.9306
+  Hard Val AP : 0.8886
+  ```
+
 # [Model Description](#contents)

 ## [Performance](#contents)

diff --git a/official/cv/retinaface_resnet50/README_CN.md b/official/cv/retinaface_resnet50/README_CN.md
index cf4335ba2cf6ffcc394f1e0cc5739c34ae0f2e71..37bc2230672ff2e1ceb3bea108dfd3b418dbc5ce 100644
--- a/official/cv/retinaface_resnet50/README_CN.md
+++ b/official/cv/retinaface_resnet50/README_CN.md
@@ -84,11 +84,19 @@ RetinaFace uses the ResNet50 backbone to extract image features for detection. From ModelZoo you can
   # distributed training example
   bash scripts/run_distribute_gpu_train.sh 3 0,1,2

+  # ONNX model export example
+  python export.py
+
   # evaluation example
   export CUDA_VISIBLE_DEVICES=0
   python eval.py > eval.log 2>&1 &
   OR
   bash run_standalone_gpu_eval.sh 0
+
+  # ONNX model evaluation example
+  python eval_onnx.py
+  OR
+  bash scripts/run_onnx_eval.sh 0
  ```

# Script Description

@@ -103,6 +111,7 @@ RetinaFace uses the ResNet50 backbone to extract image features for detection. From ModelZoo you can
 ├── scripts
 │   ├──run_distribute_gpu_train.sh     // distributed training shell script for GPU
 │   ├──run_standalone_gpu_eval.sh      // evaluation shell script for GPU
+│   ├──run_onnx_eval.sh                // ONNX model evaluation shell script
 ├── src
 │   ├──dataset.py                      // create dataset
 │   ├──network.py                      // RetinaFace architecture
@@ -116,6 +125,8 @@ RetinaFace uses the ResNet50 backbone to extract image features for detection. From ModelZoo you can
 │   ├──ground_truth                    // evaluation labels
 ├── train.py                           // training script
 ├── eval.py                            // evaluation script
+├── export.py                          // ONNX model export script
+├── eval_onnx.py                       // ONNX model evaluation script
 ```

## Script Parameters

@@ -169,6 +180,10 @@ RetinaFace uses the ResNet50 backbone to extract image features for detection. From ModelZoo you can
     'val_save_result': False,                          # Whether to save results
     'val_predict_save_folder': './widerface_result',   # Result save path
     'val_gt_dir': './data/ground_truth/',              # Validation set ground_truth path
+    # onnx
+    'ckpt_model': '../ckpt/retinaface.ckpt',           # Path of the ckpt file to be converted to ONNX
+    'onnx_model': './retinaface.onnx',                 # Path of the ONNX model used for evaluation
+    'device': 'GPU',                                   # Device type for ONNX export and evaluation: CPU or GPU
 ```

## Training Process

@@ -198,6 +213,16 @@ RetinaFace uses the ResNet50 backbone to extract image features for detection. From ModelZoo you can
 After training, the checkpoint files can be found in the default folder `./checkpoint/ckpt_0/`.

+## Exporting the ONNX Model
+
+- **Preparation**: Modify the `device` parameter in the `src/config.py` file to select the device type, `CPU` or `GPU`; then modify the `ckpt_model` parameter to specify the path of the `CKPT` file used to export the `ONNX` model.
+
+- **Run the export script**: Run the following command to export the `ONNX` model; it is saved in the current directory.
+
+  ```bash
+  python export.py
+  ```
+
 ## Evaluation Process

 ### Evaluation

@@ -235,6 +260,32 @@ RetinaFace uses the ResNet50 backbone to extract image features for detection. From ModelZoo you can
     Hard Val AP : 0.8904
 ```

+## ONNX Model Evaluation
+
+- **Preparation**: Modify the following parameters in the `src/config.py` file according to your setup:
+
+  - `device`: device type, `CPU` or `GPU`;
+  - `onnx_model`: path of the `ONNX` model used for evaluation;
+  - `val_dataset_folder`: root directory of the validation set images;
+  - `val_gt_dir`: root directory of the `ground truth` labels.
+
+- **Run the ONNX inference script**: Run the following command to evaluate the `ONNX` model:
+
+  ```bash
+  export CUDA_VISIBLE_DEVICES=0
+  python eval_onnx.py
+  OR
+  bash scripts/run_onnx_eval.sh 0
+  ```
+
+The evaluation results can be viewed in the terminal:
+
+  ```python
+  Easy Val AP : 0.9390
+  Medium Val AP : 0.9306
+  Hard Val AP : 0.8886
+  ```
+
 # Model Description

 ## Performance

diff --git a/official/cv/retinaface_resnet50/eval_onnx.py b/official/cv/retinaface_resnet50/eval_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69244609508d0cfaf850ee2f005ed4643618790
--- /dev/null
+++ b/official/cv/retinaface_resnet50/eval_onnx.py
@@ -0,0 +1,407 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluate RetinaFace_ResNet50 with ONNX Runtime."""
+from __future__ import print_function
+import os
+import time
+import datetime
+import json
+
+import numpy as np
+import cv2
+import onnxruntime
+from scipy.io import loadmat
+
+from src.config import cfg_res50
+from src.utils import decode_bbox, prior_box
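+
+# Pipeline overview: for each validation image the script pads or resizes the
+# image, runs the ONNX model with ONNX Runtime, decodes and NMS-filters the
+# predicted boxes, and finally computes the WiderFace easy/medium/hard AP.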
+
+
+class Timer:
+    """Simple wall-clock timer."""
+    def __init__(self):
+        self.start_time = 0.
+        self.diff = 0.
+
+    def start(self):
+        self.start_time = time.time()
+
+    def end(self):
+        self.diff = time.time() - self.start_time
+
+
+class DetectionEngine:
+    """Collect detection results and compute the WiderFace AP."""
+    def __init__(self, cfg):
+        self.results = {}
+        self.nms_thresh = cfg['val_nms_threshold']
+        self.conf_thresh = cfg['val_confidence_threshold']
+        self.iou_thresh = cfg['val_iou_threshold']
+        self.var = cfg['variance']
+        self.save_prefix = cfg['val_predict_save_folder']
+        self.gt_dir = cfg['val_gt_dir']
+        self.file_path = None
+
+    def _iou(self, a, b):
+        # pairwise IoU between two sets of x1y1x2y2 boxes
+        A = a.shape[0]
+        B = b.shape[0]
+        max_xy = np.minimum(
+            np.broadcast_to(np.expand_dims(a[:, 2:4], 1), [A, B, 2]),
+            np.broadcast_to(np.expand_dims(b[:, 2:4], 0), [A, B, 2]))
+        min_xy = np.maximum(
+            np.broadcast_to(np.expand_dims(a[:, 0:2], 1), [A, B, 2]),
+            np.broadcast_to(np.expand_dims(b[:, 0:2], 0), [A, B, 2]))
+        inter = np.maximum((max_xy - min_xy + 1), np.zeros_like(max_xy - min_xy))
+        inter = inter[:, :, 0] * inter[:, :, 1]
+
+        area_a = np.broadcast_to(
+            np.expand_dims(
+                (a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), 1),
+            np.shape(inter))
+        area_b = np.broadcast_to(
+            np.expand_dims(
+                (b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), 0),
+            np.shape(inter))
+        union = area_a + area_b - inter
+        return inter / union
+
+    def _nms(self, boxes, threshold=0.5):
+        # greedy NMS on boxes given as [x1, y1, x2, y2, score]
+        x1 = boxes[:, 0]
+        y1 = boxes[:, 1]
+        x2 = boxes[:, 2]
+        y2 = boxes[:, 3]
+        scores = boxes[:, 4]
+
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = scores.argsort()[::-1]
+
+        reserved_boxes = []
+        while order.size > 0:
+            i = order[0]
+            reserved_boxes.append(i)
+            max_x1 = np.maximum(x1[i], x1[order[1:]])
+            max_y1 = np.maximum(y1[i], y1[order[1:]])
+            min_x2 = np.minimum(x2[i], x2[order[1:]])
+            min_y2 = np.minimum(y2[i], y2[order[1:]])
+
+            intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
+            intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
+            intersect_area = intersect_w * intersect_h
+            ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
+
+            indices = np.where(ovr <= threshold)[0]
+            order = order[indices + 1]
+
+        return reserved_boxes
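+
+    # Worked example (illustrative comment only): for boxes
+    # [0, 0, 10, 10, 0.9] and [1, 1, 10, 10, 0.8], the +1 convention gives
+    # areas 121 and 100 and an intersection of 100, so IoU = 100 / 121 ≈ 0.83;
+    # that exceeds the 0.5 threshold, the lower-scoring box is suppressed,
+    # and _nms returns [0].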
What(): {}".format(str(e))) + else: + f.close() + return self.file_path + + def detect(self, boxes, confs, resize, scale, image_path, priors): + if boxes.shape[0] == 0: + # add to result + event_name, img_name = image_path.split('/') + self.results[event_name][img_name[:-4]] = {'img_path': image_path, + 'bboxes': []} + return + + boxes = decode_bbox(np.squeeze(boxes, 0), priors, self.var) + boxes = boxes * scale / resize + + scores = np.squeeze(confs, 0)[:, 1] + # ignore low scores + inds = np.where(scores > self.conf_thresh)[0] + boxes = boxes[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1] + boxes = boxes[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = self._nms(dets, self.nms_thresh) + dets = dets[keep, :] + + dets[:, 2:4] = (dets[:, 2:4].astype(np.int32) - dets[:, 0:2].astype(np.int32)).astype(np.float64) # int + dets[:, 0:4] = dets[:, 0:4].astype(np.int32).astype(np.float64) # int + + + # add to result + event_name, img_name = image_path.split('/') + if event_name not in self.results.keys(): + self.results[event_name] = {} + self.results[event_name][img_name[:-4]] = {'img_path': image_path, + 'bboxes': dets[:, :5].astype(np.float64).tolist()} + + def _get_gt_boxes(self): + from scipy.io import loadmat + gt = loadmat(os.path.join(self.gt_dir, 'wider_face_val.mat')) + hard = loadmat(os.path.join(self.gt_dir, 'wider_hard_val.mat')) + medium = loadmat(os.path.join(self.gt_dir, 'wider_medium_val.mat')) + easy = loadmat(os.path.join(self.gt_dir, 'wider_easy_val.mat')) + + faceboxes = gt['face_bbx_list'] + events = gt['event_list'] + files = gt['file_list'] + + hard_gt_list = hard['gt_list'] + medium_gt_list = medium['gt_list'] + easy_gt_list = easy['gt_list'] + + return faceboxes, events, files, hard_gt_list, medium_gt_list, easy_gt_list + + def _norm_pre_score(self): + max_score = 0 + min_score = 1 + + for event in self.results: + for name in self.results[event].keys(): + bbox = np.array(self.results[event][name]['bboxes']).astype(np.float64) + if bbox.shape[0] <= 0: + continue + max_score = max(max_score, np.max(bbox[:, -1])) + min_score = min(min_score, np.min(bbox[:, -1])) + + length = max_score - min_score + for event in self.results: + for name in self.results[event].keys(): + bbox = np.array(self.results[event][name]['bboxes']).astype(np.float64) + if bbox.shape[0] <= 0: + continue + bbox[:, -1] -= min_score + bbox[:, -1] /= length + self.results[event][name]['bboxes'] = bbox.tolist() + + def _image_eval(self, predict, gt, keep, iou_thresh, section_num): + + _predict = predict.copy() + _gt = gt.copy() + + image_p_right = np.zeros(_predict.shape[0]) + image_gt_right = np.zeros(_gt.shape[0]) + proposal = np.ones(_predict.shape[0]) + + # x1y1wh -> x1y1x2y2 + _predict[:, 2:4] = _predict[:, 0:2] + _predict[:, 2:4] + _gt[:, 2:4] = _gt[:, 0:2] + _gt[:, 2:4] + + ious = self._iou(_predict[:, 0:4], _gt[:, 0:4]) + for i in range(_predict.shape[0]): + gt_ious = ious[i, :] + max_iou, max_index = gt_ious.max(), gt_ious.argmax() + if max_iou >= iou_thresh: + if keep[max_index] == 0: + image_gt_right[max_index] = -1 + proposal[i] = -1 + elif image_gt_right[max_index] == 0: + image_gt_right[max_index] = 1 + + right_index = np.where(image_gt_right == 1)[0] + image_p_right[i] = len(right_index) + + + + image_pr = np.zeros((section_num, 2), dtype=np.float64) + for section in range(section_num): + _thresh = 1 - (section + 1)/section_num + over_score_index = 
+            precision = pr_curve[:, 1] / pr_curve[:, 0]
+            recall = pr_curve[:, 1] / count_gt
+
+            precision = np.concatenate((np.array([0.]), precision, np.array([0.])))
+            recall = np.concatenate((np.array([0.]), recall, np.array([1.])))
+            for i in range(precision.shape[0] - 1, 0, -1):
+                precision[i - 1] = np.maximum(precision[i - 1], precision[i])
+            index = np.where(recall[1:] != recall[:-1])[0]
+            ap = np.sum((recall[index + 1] - recall[index]) * precision[index + 1])
+
+            ap_dict[sets[_set]] = ap
+            print(ap_key_dict[_set] + '{:.4f}'.format(ap))
+
+        return ap_dict
+
+
+def val():
+    cfg = cfg_res50
+
+    # testing dataset
+    testset_folder = cfg['val_dataset_folder']
+    testset_label_path = cfg['val_dataset_folder'] + "label.txt"
+    with open(testset_label_path, 'r') as f:
+        _test_dataset = f.readlines()
+    test_dataset = []
+    for im_path in _test_dataset:
+        if im_path.startswith('# '):
+            test_dataset.append(im_path[2:-1])  # strip leading '# ' and trailing '\n'
+
+    num_images = len(test_dataset)
+
+    timers = {'forward_time': Timer(), 'misc': Timer()}
+
+    if cfg['val_origin_size']:
+        # evaluate at the original image size: pad every image to the
+        # largest height/width, rounded up to a multiple of 32
+        h_max, w_max = 0, 0
+        for img_name in test_dataset:
+            image_path = os.path.join(testset_folder, 'images', img_name)
+            _img = cv2.imread(image_path, cv2.IMREAD_COLOR)
+            if _img.shape[0] > h_max:
+                h_max = _img.shape[0]
+            if _img.shape[1] > w_max:
+                w_max = _img.shape[1]
+
+        h_max = (int(h_max / 32) + 1) * 32
+        w_max = (int(w_max / 32) + 1) * 32
+
+        priors = prior_box(image_sizes=(h_max, w_max),
+                           min_sizes=[[16, 32], [64, 128], [256, 512]],
+                           steps=[8, 16, 32],
+                           clip=False)
+    else:
+        target_size = 1600
+        max_size = 2176
+        priors = prior_box(image_sizes=(max_size, max_size),
+                           min_sizes=[[16, 32], [64, 128], [256, 512]],
+                           steps=[8, 16, 32],
+                           clip=False)
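+
+    # NOTE: the priors are built once for the padded input size, since every
+    # image below is padded to that same size before inference; with
+    # val_origin_size False this also matches the fixed (1, 3, 2176, 2176)
+    # input shape baked into the ONNX model by export.py.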
+
+    # init detection engine
+    detection = DetectionEngine(cfg)
+
+    # create the ONNX Runtime session once before the loop,
+    # on the device selected in src/config.py
+    if cfg['device'] == 'GPU':
+        providers = ['CUDAExecutionProvider']
+    elif cfg['device'] == 'CPU':
+        providers = ['CPUExecutionProvider']
+    else:
+        raise ValueError("cfg['device'] must be 'CPU' or 'GPU'")
+    session = onnxruntime.InferenceSession(cfg['onnx_model'], providers=providers)
+    input_name = session.get_inputs()[0].name
+
+    # testing begin
+    print('Predict box starting')
+    for i, img_name in enumerate(test_dataset):
+        image_path = os.path.join(testset_folder, 'images', img_name)
+
+        img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)
+        img = np.float32(img_raw)
+
+        # testing scale
+        if cfg['val_origin_size']:
+            resize = 1
+            assert img.shape[0] <= h_max and img.shape[1] <= w_max
+            image_t = np.empty((h_max, w_max, 3), dtype=img.dtype)
+            image_t[:, :] = (104.0, 117.0, 123.0)
+            image_t[0:img.shape[0], 0:img.shape[1]] = img
+            img = image_t
+        else:
+            im_size_min = np.min(img.shape[0:2])
+            im_size_max = np.max(img.shape[0:2])
+            resize = float(target_size) / float(im_size_min)
+            # prevent the bigger axis from exceeding max_size
+            if np.round(resize * im_size_max) > max_size:
+                resize = float(max_size) / float(im_size_max)
+
+            img = cv2.resize(img, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
+
+            assert img.shape[0] <= max_size and img.shape[1] <= max_size
+            image_t = np.empty((max_size, max_size, 3), dtype=img.dtype)
+            image_t[:, :] = (104.0, 117.0, 123.0)
+            image_t[0:img.shape[0], 0:img.shape[1]] = img
+            img = image_t
+
+        scale = np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]], dtype=img.dtype)
+        # subtract the per-channel BGR means and convert HWC -> NCHW
+        img -= (104, 117, 123)
+        img = img.transpose(2, 0, 1)
+        img = np.expand_dims(img, 0)
+
+        timers['forward_time'].start()
+        boxes, confs, _ = session.run(None, {input_name: img})
+        timers['forward_time'].end()
+        timers['misc'].start()
+        detection.detect(boxes, confs, resize, scale, img_name, priors)
+        timers['misc'].end()
+
+        print('im_detect: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(i + 1, num_images,
+                                                                                     timers['forward_time'].diff,
+                                                                                     timers['misc'].diff))
+    print('Predict box done.')
+    print('Eval starting')
+
+    if cfg['val_save_result']:
+        # save the predict result if needed
+        predict_result_path = detection.write_result()
+        print('predict result path is {}'.format(predict_result_path))
+
+    detection.get_eval_result()
+    print('Eval done.')
+
+
+if __name__ == '__main__':
+    val()
diff --git a/official/cv/retinaface_resnet50/export.py b/official/cv/retinaface_resnet50/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..9176e06ef4fe5ec1dcaa564bdc1506f11017bb01
--- /dev/null
+++ b/official/cv/retinaface_resnet50/export.py
@@ -0,0 +1,49 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Export an ONNX model from a MindSpore checkpoint."""
+from __future__ import print_function
+import numpy as np
+import mindspore as ms
+from mindspore import Tensor, export
+from src.network import RetinaFace, resnet50
+from src.config import cfg_res50
+
+
+def export_onnx_model():
+    cfg = cfg_res50
+
+    ms.set_context(mode=ms.GRAPH_MODE, device_target=cfg['device'], save_graphs=True)
+
+    # build network
+    backbone = resnet50(1001)
+    network = RetinaFace(phase='predict', backbone=backbone)
+    backbone.set_train(False)
+    network.set_train(False)
+
+    # load checkpoint into network
+    param_dict = ms.load_checkpoint(cfg['ckpt_model'])
+    network.init_parameters_data()
+    ms.load_param_into_net(network, param_dict)
+
+    # build input data; the static shape matches the padded evaluation
+    # size used in eval_onnx.py (max_size = 2176)
+    input_data = Tensor(np.ones([1, 3, 2176, 2176]).astype(np.float32))
+
+    # export onnx model
+    print("Begin to Export ONNX Model...")
+    export(network, input_data, file_name='retinaface', file_format='ONNX')
+    print("Export ONNX Model Successfully!")
+
+
+if __name__ == '__main__':
+    export_onnx_model()
diff --git a/official/cv/retinaface_resnet50/requirements.txt b/official/cv/retinaface_resnet50/requirements.txt
index e7590640de37f69a782424616f0d5ecc836b634a..e9ed1cae243b1cabf1cfc6e15a6ae55fd1c45cbb 100644
--- a/official/cv/retinaface_resnet50/requirements.txt
+++ b/official/cv/retinaface_resnet50/requirements.txt
@@ -1,3 +1,4 @@
 numpy
 opencv-python
 scipy
+onnxruntime-gpu
diff --git a/official/cv/retinaface_resnet50/scripts/run_onnx_eval.sh b/official/cv/retinaface_resnet50/scripts/run_onnx_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0ce748981b7cdb36388b6e01405e6b5170c0e504
--- /dev/null
+++ b/official/cv/retinaface_resnet50/scripts/run_onnx_eval.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Before running this script, you should modify 4 parameters in the src/config.py file:"
+echo "val_dataset_folder, val_gt_dir, onnx_model and device."
+echo "Then run the script as: bash run_onnx_eval.sh 0"
+echo "=============================================================================================================="
+echo "And then, you can run the script as: bash run_onnx_eval.sh 0" +echo "==============================================================================================================" + +export CUDA_VISIBLE_DEVICES="$1" +python eval_onnx.py > eval.log 2>&1 & diff --git a/official/cv/retinaface_resnet50/src/config.py b/official/cv/retinaface_resnet50/src/config.py index 473124ccf11eed0cb9fe28861478af77c42c54f7..7765b3eb5fccfe3df930cb85b7e3c1d4dabb813d 100644 --- a/official/cv/retinaface_resnet50/src/config.py +++ b/official/cv/retinaface_resnet50/src/config.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -65,4 +65,8 @@ cfg_res50 = { 'val_predict_save_folder': './widerface_result', 'val_gt_dir': './data/ground_truth/', + # onnx + 'ckpt_model': '../ckpt/retinaface.ckpt', + 'onnx_model': './retinaface.onnx', + 'device': 'GPU', }