From c2ad0a11083efb34c2132431d40804528f0421d2 Mon Sep 17 00:00:00 2001
From: caifubi <caifubi1@huawei.com>
Date: Fri, 2 Sep 2022 14:24:07 +0800
Subject: [PATCH] remove create_dataset_pynative

---
 official/cv/resnet/src/dataset.py | 71 -------------------------------
 official/cv/resnet/train.py       |  5 +--
 2 files changed, 1 insertion(+), 75 deletions(-)

diff --git a/official/cv/resnet/src/dataset.py b/official/cv/resnet/src/dataset.py
index 9b340844c..614c0eb14 100644
--- a/official/cv/resnet/src/dataset.py
+++ b/official/cv/resnet/src/dataset.py
@@ -158,77 +158,6 @@ def create_dataset2(dataset_path, do_train, batch_size=32, train_image_size=224,
     return data_set
 
 
-def create_dataset_pynative(dataset_path, do_train, batch_size=32, train_image_size=224,
-                            eval_image_size=224, target="Ascend", distribute=False, enable_cache=False,
-                            cache_session_id=None):
-    """
-    create a train or eval imagenet2012 dataset for resnet50 benchmark
-
-    Args:
-        dataset_path(string): the path of dataset.
-        do_train(bool): whether dataset is used for train or eval.
-        repeat_num(int): the repeat times of dataset. Default: 1
-        batch_size(int): the batch size of dataset. Default: 32
-        target(str): the device target. Default: Ascend
-        distribute(bool): data for distribute or not. Default: False
-        enable_cache(bool): whether tensor caching service is used for eval. Default: False
-        cache_session_id(int): If enable_cache, cache session_id need to be provided. Default: None
-
-    Returns:
-        dataset
-    """
-    device_num, rank_id = _get_rank_info(distribute)
-
-    if device_num == 1:
-        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=get_num_parallel_workers(8), shuffle=True)
-    else:
-        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=get_num_parallel_workers(2), shuffle=True,
-                                         num_shards=device_num, shard_id=rank_id)
-
-    # Computed from random subset of ImageNet training images
-    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
-    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
-
-    # define map operations
-    if do_train:
-        trans = [
-            ds.vision.RandomCropDecodeResize(train_image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
-            ds.vision.RandomHorizontalFlip(prob=0.5),
-            ds.vision.Normalize(mean=mean, std=std),
-            ds.vision.HWC2CHW()
-        ]
-    else:
-        trans = [
-            ds.vision.Decode(),
-            ds.vision.Resize(256),
-            ds.vision.CenterCrop(eval_image_size),
-            ds.vision.Normalize(mean=mean, std=std),
-            ds.vision.HWC2CHW()
-        ]
-
-    type_cast_op = ds.transforms.transforms.TypeCast(ms.int32)
-
-    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=4)
-    # only enable cache for eval
-    if do_train:
-        enable_cache = False
-    if enable_cache:
-        if not cache_session_id:
-            raise ValueError("A cache session_id must be provided to use cache.")
-        eval_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0)
-        data_set = data_set.map(operations=type_cast_op, input_columns="label",
-                                num_parallel_workers=get_num_parallel_workers(2),
-                                cache=eval_cache)
-    else:
-        data_set = data_set.map(operations=type_cast_op, input_columns="label",
-                                num_parallel_workers=get_num_parallel_workers(2))
-
-    # apply batch operations
-    data_set = data_set.batch(batch_size, drop_remainder=True)
-
-    return data_set
-
-
 def create_dataset3(dataset_path, do_train, batch_size=32, train_image_size=224,
                     eval_image_size=224, target="Ascend", distribute=False, enable_cache=False,
                     cache_session_id=None):
     """
diff --git a/official/cv/resnet/train.py b/official/cv/resnet/train.py
index 9eb95e41f..c62f62897 100644
--- a/official/cv/resnet/train.py
+++ b/official/cv/resnet/train.py
@@ -83,10 +83,7 @@ if config.net_name in ("resnet18", "resnet34", "resnet50", "resnet152"):
     if config.dataset == "cifar10":
         from src.dataset import create_dataset1 as create_dataset
     else:
-        if config.mode_name == "GRAPH":
-            from src.dataset import create_dataset2 as create_dataset
-        else:
-            from src.dataset import create_dataset_pynative as create_dataset
+        from src.dataset import create_dataset2 as create_dataset
 elif config.net_name == "resnet101":
     from src.resnet import resnet101 as resnet
     from src.dataset import create_dataset3 as create_dataset
--
GitLab
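
Note (not part of the patch): after this change, train.py imports
create_dataset2 for ImageNet regardless of config.mode_name, so one
mindspore.dataset pipeline serves both GRAPH and PYNATIVE execution.
A minimal smoke-test sketch under that assumption; the dataset path and
batch size below are placeholders, not values taken from the patch:

    # Sketch only: exercise the unified pipeline in PyNative mode.
    # "/data/imagenet/train" is a hypothetical path, not from the patch.
    import mindspore as ms
    from src.dataset import create_dataset2 as create_dataset

    ms.set_context(mode=ms.PYNATIVE_MODE)  # GRAPH_MODE now takes the same code path
    data_set = create_dataset(dataset_path="/data/imagenet/train", do_train=True,
                              batch_size=32, train_image_size=224,
                              target="Ascend", distribute=False)
    for batch in data_set.create_dict_iterator():
        images, labels = batch["image"], batch["label"]
        break  # one batch is enough to verify the pipeline runs

Judging from the removed body, create_dataset_pynative differed from the
surviving ImageNet pipeline mainly in its num_parallel_workers settings
(e.g. the hard-coded 4 on the image map), not in the transforms applied,
so dropping it should not change training or eval results.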