diff --git a/official/cv/centerface/README.md b/official/cv/centerface/README.md index 9286d286e54d813ff0e580c393ae38dceeb1872a..5aee17eb82fffcc6b2486b0630f75ca7e50e1565 100644 --- a/official/cv/centerface/README.md +++ b/official/cv/centerface/README.md @@ -194,7 +194,9 @@ step6: eval # python setup.py install; # cd -; #cd ../../scripts; -bash eval_all.sh [ground_truth_path] +bash eval_all.sh [GROUND_TRUTH_PATH] +or +bash eval_all.sh [GROUND_TRUTH_PATH] [FILTER_EASY](optional) [FILTER_MEDIUM](optional) [FILTER_HARD](optional) [FILTER_SUM](optional) ``` - Running on [ModelArts](https://support.huaweicloud.com/modelarts/) @@ -651,7 +653,9 @@ cd ../../../scripts; ```python # you need to change the parameter in eval_all.sh # default eval the ckpt saved in ./scripts/output/centerface/[89-140] - bash eval_all.sh [ground_truth_path] + bash eval_all.sh [GROUND_TRUTH_PATH] + or + bash eval_all.sh [GROUND_TRUTH_PATH] [FILTER_EASY](optional) [FILTER_MEDIUM](optional) [FILTER_HARD](optional) [FILTER_SUM](optional) ``` 3. test+eval diff --git a/official/cv/centerface/scripts/eval_all.sh b/official/cv/centerface/scripts/eval_all.sh index 4db448818c4b21afdd7902c97c0d0de5ad804089..4acb8b0988d1e8119cd59e1cc695eada06afc2c8 100644 --- a/official/cv/centerface/scripts/eval_all.sh +++ b/official/cv/centerface/scripts/eval_all.sh @@ -14,6 +14,26 @@ # limitations under the License. # ============================================================================ +if [ $# != 1 ] && [ $# != 5 ] +then + echo "Usage: bash eval_all.sh [GROUND_TRUTH_PATH]" + echo " bash eval_all.sh [GROUND_TRUTH_PATH] [FILTER_EASY](optional) [FILTER_MEDIUM](optional) [FILTER_HARD](optional) [FILTER_SUM](optional)" + exit 1 +fi + +FILTER_EASY=0 +FILTER_MEDIUM=0 +FILTER_HARD=0 +FILTER_SUM=0 + +if [ $# = 5 ] +then + FILTER_EASY=$2 + FILTER_MEDIUM=$3 + FILTER_HARD=$4 + FILTER_SUM=$5 +fi + root=$PWD save_path=$root/output/centerface/ if [ ! -d $save_path ] @@ -44,9 +64,8 @@ do done wait -line_number=`awk 'BEGIN{hard=0;nr=0}{if($0~"Hard"){if($4>hard){hard=$4;nr=NR}}}END{print nr}' log_eval_all.txt` -start_line_number=`expr $line_number - 3` -end_line_number=`expr $line_number + 1` - -echo "The best result " >> log_eval_all.txt -sed -n "$start_line_number, $end_line_number p" log_eval_all.txt >> log_eval_all.txt +python ../src/find_best_checkpoint.py --result_file_path=./log_eval_all.txt \ +--filter_easy=$FILTER_EASY \ +--filter_medium=$FILTER_MEDIUM \ +--filter_hard=$FILTER_HARD \ +--filter_sum=$FILTER_SUM >> log_eval_all.txt 2>&1 & diff --git a/official/cv/centerface/src/find_best_checkpoint.py b/official/cv/centerface/src/find_best_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..47c863a06d05031f4bc02f9ceae5de7cf5c38d12 --- /dev/null +++ b/official/cv/centerface/src/find_best_checkpoint.py @@ -0,0 +1,99 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""find best checkpoint""" + +import os +import argparse + +def find_ckpt(file_path, arg): + with open(file_path) as f: + str_result_list = f.readlines() + + easy_list, medium_list, hard_list, path_list = [], [], [], [] + easy_index, medium_index, hard_index, sum_index = -1, -1, -1, -1 + easy_max, medium_max, hard_max, sum_max = -1, -1, -1, -1 + index = -1 + for i in range(len(str_result_list)): + if str_result_list[i].startswith("==================== Results"): + index += 1 + path_list.append(str_result_list[i][:-1]) + easy_ap = float(str_result_list[i + 1][15:-1]) + medium_ap = float(str_result_list[i + 2][15:-1]) + hard_ap = float(str_result_list[i + 3][15:-1]) + sum_ap = easy_ap + medium_ap + hard_ap + easy_list.append(easy_ap) + medium_list.append(medium_ap) + hard_list.append(hard_ap) + if easy_ap >= arg.filter_easy \ + and medium_ap >= arg.filter_medium \ + and hard_ap >= arg.filter_hard \ + and sum_ap >= arg.filter_sum: + if easy_ap >= easy_max: + easy_max = easy_ap + easy_index = index + if medium_ap >= medium_max: + medium_max = medium_ap + medium_index = index + if hard_ap >= hard_max: + hard_max = hard_ap + hard_index = index + if sum_ap >= sum_max: + sum_max = sum_ap + sum_index = index + + if easy_index == -1: + print("\nCannot find a checkpoint that meets the filter requirements.") + else: + print("\nThe best easy result:", flush=True) + print(path_list[easy_index], flush=True) + print("Easy Val AP: {}".format(easy_list[easy_index]), flush=True) + print("Medium Val AP: {}".format(medium_list[easy_index]), flush=True) + print("Hard Val AP: {}".format(hard_list[easy_index]), flush=True) + print("=================================================", flush=True) + + print("\nThe best medium result:", flush=True) + print(path_list[medium_index], flush=True) + print("Easy Val AP: {}".format(easy_list[medium_index]), flush=True) + print("Medium Val AP: {}".format(medium_list[medium_index]), flush=True) + print("Hard Val AP: {}".format(hard_list[medium_index]), flush=True) + print("=================================================", flush=True) + + print("\nThe best hard result:", flush=True) + print(path_list[hard_index], flush=True) + print("Easy Val AP: {}".format(easy_list[hard_index]), flush=True) + print("Medium Val AP: {}".format(medium_list[hard_index]), flush=True) + print("Hard Val AP: {}".format(hard_list[hard_index]), flush=True) + print("=================================================", flush=True) + + print("\nThe best sum result:", flush=True) + print(path_list[sum_index], flush=True) + print("Easy Val AP: {}".format(easy_list[sum_index]), flush=True) + print("Medium Val AP: {}".format(medium_list[sum_index]), flush=True) + print("Hard Val AP: {}".format(hard_list[sum_index]), flush=True) + print("=================================================", flush=True) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--result_file_path', default='', help='result file') + parser.add_argument('--filter_easy', default=0.0, type=float, help='filter easy') + parser.add_argument('--filter_medium', default=0.0, type=float, help='filter medium') + parser.add_argument('--filter_hard', default=0.0, type=float, help='filter hard') + parser.add_argument('--filter_sum', default=0.0, type=float, help='filter sum') + args = parser.parse_args() + + if os.path.isfile(args.result_file_path): + find_ckpt(args.result_file_path, args) + else: + raise FileNotFoundError("{} not found.".format(args.result_file_path))