diff --git a/research/cv/PSPNet/README.md b/research/cv/PSPNet/README.md index 19db7dfeb7b2ba77d8a47cca86e6862dc3983d29..9f866fb248e1e5eb3727208b43f5ffa883c57884 100644 --- a/research/cv/PSPNet/README.md +++ b/research/cv/PSPNet/README.md @@ -35,12 +35,12 @@ The pyramid pooling module fuses features under four different pyramid scales.Fo # [Dataset](#Content) -- [PASCAL VOC 2012 and SBD Dataset Website](http://home.bharathh.info/pubs/codes/SBD/download.html) +- [Semantic Boundaries Dataset](http://home.bharathh.info/pubs/codes/SBD/download.html) - It contains 11,357 finely annotated images split into training and testing sets with 8,498 and 2,857 images respectively. - - The path formats in voc_train_lst.txt and voc_val_lst.txt are different, you can run create_train_lst.py to generate train_lst.txt in data dir for VOC2012. As follow锛� + - The paths in train.txt and val.txt are incomplete, and the .mat annotation files in the cls directory need to be converted to images. You can run preprocess_dataset.py to convert the .mat files and generate train_list.txt and val_list.txt, as follows: ```python - python src/dataset/create_train_lst.py --data_dir [DATA_DIR] + python src/dataset/preprocess_dataset.py --data_dir [DATA_DIR] ``` - [ADE20K Dataset Website](http://groups.csail.mit.edu/vision/datasets/ADE20K/) diff --git a/research/cv/PSPNet/src/dataset/create_train_lst.py b/research/cv/PSPNet/src/dataset/create_train_lst.py deleted file mode 100644 index cda77b897224dd5bbe35f5bafc494a5791d64c8f..0000000000000000000000000000000000000000 --- a/research/cv/PSPNet/src/dataset/create_train_lst.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""generate train_lst.txt""" -import os -import argparse - - -def _parser_args(): - parser = argparse.ArgumentParser('dataset list generator') - parser.add_argument("--data_dir", type=str, default='', help="VOC2012 data dir") - return parser.parse_args() - - -def _get_data_list(data_list_file): - with open(data_list_file, 'r') as f: - return f.readlines() - - -def main(): - args = _parser_args() - data_dir = args.data_dir - voc_train_lst_txt = os.path.join(data_dir, 'voc_train_lst.txt') - train_lst_txt = os.path.join(data_dir, 'train_lst.txt') - - voc_train_data_lst = _get_data_list(voc_train_lst_txt) - with open(train_lst_txt, 'w') as f: - for line in voc_train_data_lst: - img_, anno_ = (os.path.join('VOCdevkit/VOC2012', i.strip()) for i in line.split()) - f.write(f'{img_} {anno_}\n') - print('generating voc train list success.') - - -if __name__ == "__main__": - main() diff --git a/research/cv/PSPNet/src/dataset/preprocess_dataset.py b/research/cv/PSPNet/src/dataset/preprocess_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..132db42f7338d7660a2123f3d3f64aa8c21fd6ee --- /dev/null +++ b/research/cv/PSPNet/src/dataset/preprocess_dataset.py @@ -0,0 +1,85 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""preprocess dataset""" +import os +import argparse +from PIL import Image +from scipy.io import loadmat + + +def _parser_args(): + parser = argparse.ArgumentParser('dataset list generator') + parser.add_argument("--data_dir", type=str, default='', help="VOC2012 data dir") + return parser.parse_args() + + +def _get_data_list(data_list_file): + with open(data_list_file, 'r') as f: + return f.readlines() + + +def _mat_to_arr(mat_path): + data = loadmat(mat_path)['GTcls'] + arr = data[0, 0][1] + return arr + + +def main(): + args = _parser_args() + data_dir = args.data_dir + cls_path = os.path.join(data_dir, 'cls') + cls_png_path = os.path.join(data_dir, 'cls_png') + if not os.path.exists(cls_png_path): + os.mkdir(cls_png_path) + mat_list = os.listdir(cls_path) + print('Start generating png.') + print("It takes a little time. 
Don't quit!") + i = 0 + for mat in mat_list: + mat_path = os.path.join(cls_path, mat) + arr = _mat_to_arr(mat_path) + png_path = os.path.join(cls_png_path, mat.replace('mat', 'png')) + ann_im = Image.fromarray(arr) + ann_im.save(png_path) + i += 1 + print(f"Generate {i} png to data_dir/cls_png.") + + train_txt = os.path.join(data_dir, 'train.txt') + train_list_txt = os.path.join(data_dir, 'train_list.txt') + val_txt = os.path.join(data_dir, 'val.txt') + val_list_txt = os.path.join(data_dir, 'val_list.txt') + + train_data_lst = _get_data_list(train_txt) + with open(train_list_txt, 'w') as f: + for line in train_data_lst: + line = line.strip() + img_ = os.path.join('img', line + '.jpg') + anno_ = os.path.join('cls_png', line + '.png') + f.write(f'{img_} {anno_}\n') + print('Generate train_list to data_dir.') + + val_data_lst = _get_data_list(val_txt) + with open(val_list_txt, 'w') as f: + for line in val_data_lst: + line = line.strip() + img_ = os.path.join('img', line + '.jpg') + anno_ = os.path.join('cls_png', line + '.png') + f.write(f'{img_} {anno_}\n') + print('Generate val_list to data_dir.') + print('Finish.') + + +if __name__ == "__main__": + main()