Skip to content
Snippets Groups Projects
Unverified Commit 427f8e10 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2294 modify centerface eval_all process

Merge pull request !2294 from zhanghuiyao/modify_centerface
parents fcce1675 a26c3d85
No related branches found
No related tags found
No related merge requests found
......@@ -194,7 +194,9 @@ step6: eval
# python setup.py install;
# cd -; #cd ../../scripts;
bash eval_all.sh [ground_truth_path]
bash eval_all.sh [GROUND_TRUTH_PATH]
or
bash eval_all.sh [GROUND_TRUTH_PATH] [FILTER_EASY](optional) [FILTER_MEDIUM](optional) [FILTER_HARD](optional) [FILTER_SUM](optional)
```
- Running on [ModelArts](https://support.huaweicloud.com/modelarts/)
......@@ -651,7 +653,9 @@ cd ../../../scripts;
```python
# you need to change the parameter in eval_all.sh
# default eval the ckpt saved in ./scripts/output/centerface/[89-140]
bash eval_all.sh [ground_truth_path]
bash eval_all.sh [GROUND_TRUTH_PATH]
or
bash eval_all.sh [GROUND_TRUTH_PATH] [FILTER_EASY](optional) [FILTER_MEDIUM](optional) [FILTER_HARD](optional) [FILTER_SUM](optional)
```
3. test+eval
......
......@@ -14,6 +14,26 @@
# limitations under the License.
# ============================================================================
if [ $# != 1 ] && [ $# != 5 ]
then
echo "Usage: bash eval_all.sh [GROUND_TRUTH_PATH]"
echo " bash eval_all.sh [GROUND_TRUTH_PATH] [FILTER_EASY](optional) [FILTER_MEDIUM](optional) [FILTER_HARD](optional) [FILTER_SUM](optional)"
exit 1
fi
FILTER_EASY=0
FILTER_MEDIUM=0
FILTER_HARD=0
FILTER_SUM=0
if [ $# = 5 ]
then
FILTER_EASY=$2
FILTER_MEDIUM=$3
FILTER_HARD=$4
FILTER_SUM=$5
fi
root=$PWD
save_path=$root/output/centerface/
if [ ! -d $save_path ]
......@@ -44,9 +64,8 @@ do
done
wait
line_number=`awk 'BEGIN{hard=0;nr=0}{if($0~"Hard"){if($4>hard){hard=$4;nr=NR}}}END{print nr}' log_eval_all.txt`
start_line_number=`expr $line_number - 3`
end_line_number=`expr $line_number + 1`
echo "The best result " >> log_eval_all.txt
sed -n "$start_line_number, $end_line_number p" log_eval_all.txt >> log_eval_all.txt
python ../src/find_best_checkpoint.py --result_file_path=./log_eval_all.txt \
--filter_easy=$FILTER_EASY \
--filter_medium=$FILTER_MEDIUM \
--filter_hard=$FILTER_HARD \
--filter_sum=$FILTER_SUM >> log_eval_all.txt 2>&1 &
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""find best checkpoint"""
import os
import argparse
def find_ckpt(file_path, arg):
with open(file_path) as f:
str_result_list = f.readlines()
easy_list, medium_list, hard_list, path_list = [], [], [], []
easy_index, medium_index, hard_index, sum_index = -1, -1, -1, -1
easy_max, medium_max, hard_max, sum_max = -1, -1, -1, -1
index = -1
for i in range(len(str_result_list)):
if str_result_list[i].startswith("==================== Results"):
index += 1
path_list.append(str_result_list[i][:-1])
easy_ap = float(str_result_list[i + 1][15:-1])
medium_ap = float(str_result_list[i + 2][15:-1])
hard_ap = float(str_result_list[i + 3][15:-1])
sum_ap = easy_ap + medium_ap + hard_ap
easy_list.append(easy_ap)
medium_list.append(medium_ap)
hard_list.append(hard_ap)
if easy_ap >= arg.filter_easy \
and medium_ap >= arg.filter_medium \
and hard_ap >= arg.filter_hard \
and sum_ap >= arg.filter_sum:
if easy_ap >= easy_max:
easy_max = easy_ap
easy_index = index
if medium_ap >= medium_max:
medium_max = medium_ap
medium_index = index
if hard_ap >= hard_max:
hard_max = hard_ap
hard_index = index
if sum_ap >= sum_max:
sum_max = sum_ap
sum_index = index
if easy_index == -1:
print("\nCannot find a checkpoint that meets the filter requirements.")
else:
print("\nThe best easy result:", flush=True)
print(path_list[easy_index], flush=True)
print("Easy Val AP: {}".format(easy_list[easy_index]), flush=True)
print("Medium Val AP: {}".format(medium_list[easy_index]), flush=True)
print("Hard Val AP: {}".format(hard_list[easy_index]), flush=True)
print("=================================================", flush=True)
print("\nThe best medium result:", flush=True)
print(path_list[medium_index], flush=True)
print("Easy Val AP: {}".format(easy_list[medium_index]), flush=True)
print("Medium Val AP: {}".format(medium_list[medium_index]), flush=True)
print("Hard Val AP: {}".format(hard_list[medium_index]), flush=True)
print("=================================================", flush=True)
print("\nThe best hard result:", flush=True)
print(path_list[hard_index], flush=True)
print("Easy Val AP: {}".format(easy_list[hard_index]), flush=True)
print("Medium Val AP: {}".format(medium_list[hard_index]), flush=True)
print("Hard Val AP: {}".format(hard_list[hard_index]), flush=True)
print("=================================================", flush=True)
print("\nThe best sum result:", flush=True)
print(path_list[sum_index], flush=True)
print("Easy Val AP: {}".format(easy_list[sum_index]), flush=True)
print("Medium Val AP: {}".format(medium_list[sum_index]), flush=True)
print("Hard Val AP: {}".format(hard_list[sum_index]), flush=True)
print("=================================================", flush=True)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--result_file_path', default='', help='result file')
parser.add_argument('--filter_easy', default=0.0, type=float, help='filter easy')
parser.add_argument('--filter_medium', default=0.0, type=float, help='filter medium')
parser.add_argument('--filter_hard', default=0.0, type=float, help='filter hard')
parser.add_argument('--filter_sum', default=0.0, type=float, help='filter sum')
args = parser.parse_args()
if os.path.isfile(args.result_file_path):
find_ckpt(args.result_file_path, args)
else:
raise FileNotFoundError("{} not found.".format(args.result_file_path))
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment