From c2ad0a11083efb34c2132431d40804528f0421d2 Mon Sep 17 00:00:00 2001
From: caifubi <caifubi1@huawei.com>
Date: Fri, 2 Sep 2022 14:24:07 +0800
Subject: [PATCH] remove create_dataset_pynative

---
 official/cv/resnet/src/dataset.py | 71 -------------------------------
 official/cv/resnet/train.py       |  5 +--
 2 files changed, 1 insertion(+), 75 deletions(-)

diff --git a/official/cv/resnet/src/dataset.py b/official/cv/resnet/src/dataset.py
index 9b340844c..614c0eb14 100644
--- a/official/cv/resnet/src/dataset.py
+++ b/official/cv/resnet/src/dataset.py
@@ -158,77 +158,6 @@ def create_dataset2(dataset_path, do_train, batch_size=32, train_image_size=224,
     return data_set
 
 
-def create_dataset_pynative(dataset_path, do_train, batch_size=32, train_image_size=224,
-                            eval_image_size=224, target="Ascend", distribute=False, enable_cache=False,
-                            cache_session_id=None):
-    """
-    create a train or eval imagenet2012 dataset for resnet50 benchmark
-
-    Args:
-        dataset_path(string): the path of dataset.
-        do_train(bool): whether dataset is used for train or eval.
-        repeat_num(int): the repeat times of dataset. Default: 1
-        batch_size(int): the batch size of dataset. Default: 32
-        target(str): the device target. Default: Ascend
-        distribute(bool): data for distribute or not. Default: False
-        enable_cache(bool): whether tensor caching service is used for eval. Default: False
-        cache_session_id(int): If enable_cache, cache session_id need to be provided. Default: None
-
-    Returns:
-        dataset
-    """
-    device_num, rank_id = _get_rank_info(distribute)
-
-    if device_num == 1:
-        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=get_num_parallel_workers(8), shuffle=True)
-    else:
-        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=get_num_parallel_workers(2), shuffle=True,
-                                         num_shards=device_num, shard_id=rank_id)
-
-    # Computed from random subset of ImageNet training images
-    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
-    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
-
-    # define map operations
-    if do_train:
-        trans = [
-            ds.vision.RandomCropDecodeResize(train_image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
-            ds.vision.RandomHorizontalFlip(prob=0.5),
-            ds.vision.Normalize(mean=mean, std=std),
-            ds.vision.HWC2CHW()
-        ]
-    else:
-        trans = [
-            ds.vision.Decode(),
-            ds.vision.Resize(256),
-            ds.vision.CenterCrop(eval_image_size),
-            ds.vision.Normalize(mean=mean, std=std),
-            ds.vision.HWC2CHW()
-        ]
-
-    type_cast_op = ds.transforms.transforms.TypeCast(ms.int32)
-
-    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=4)
-    # only enable cache for eval
-    if do_train:
-        enable_cache = False
-    if enable_cache:
-        if not cache_session_id:
-            raise ValueError("A cache session_id must be provided to use cache.")
-        eval_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0)
-        data_set = data_set.map(operations=type_cast_op, input_columns="label",
-                                num_parallel_workers=get_num_parallel_workers(2),
-                                cache=eval_cache)
-    else:
-        data_set = data_set.map(operations=type_cast_op, input_columns="label",
-                                num_parallel_workers=get_num_parallel_workers(2))
-
-    # apply batch operations
-    data_set = data_set.batch(batch_size, drop_remainder=True)
-
-    return data_set
-
-
 def create_dataset3(dataset_path, do_train, batch_size=32, train_image_size=224, eval_image_size=224,
                     target="Ascend", distribute=False, enable_cache=False, cache_session_id=None):
     """
diff --git a/official/cv/resnet/train.py b/official/cv/resnet/train.py
index 9eb95e41f..c62f62897 100644
--- a/official/cv/resnet/train.py
+++ b/official/cv/resnet/train.py
@@ -83,10 +83,7 @@ if config.net_name in ("resnet18", "resnet34", "resnet50", "resnet152"):
     if config.dataset == "cifar10":
         from src.dataset import create_dataset1 as create_dataset
     else:
-        if config.mode_name == "GRAPH":
-            from src.dataset import create_dataset2 as create_dataset
-        else:
-            from src.dataset import create_dataset_pynative as create_dataset
+        from src.dataset import create_dataset2 as create_dataset
 elif config.net_name == "resnet101":
     from src.resnet import resnet101 as resnet
     from src.dataset import create_dataset3 as create_dataset
-- 
GitLab