Unverified commit e5827ca2 authored by i-robot, committed by Gitee

!3578 remove create_dataset_pynative

Merge pull request !3578 from caifubi/master
parents dc15eb1e c2ad0a11
@@ -158,77 +158,6 @@ def create_dataset2(dataset_path, do_train, batch_size=32, train_image_size=224,
     return data_set
 
 
-def create_dataset_pynative(dataset_path, do_train, batch_size=32, train_image_size=224,
-                            eval_image_size=224, target="Ascend", distribute=False, enable_cache=False,
-                            cache_session_id=None):
-    """
-    create a train or eval imagenet2012 dataset for resnet50 benchmark
-
-    Args:
-        dataset_path(string): the path of dataset.
-        do_train(bool): whether dataset is used for train or eval.
-        repeat_num(int): the repeat times of dataset. Default: 1
-        batch_size(int): the batch size of dataset. Default: 32
-        target(str): the device target. Default: Ascend
-        distribute(bool): data for distribute or not. Default: False
-        enable_cache(bool): whether tensor caching service is used for eval. Default: False
-        cache_session_id(int): If enable_cache, cache session_id need to be provided. Default: None
-
-    Returns:
-        dataset
-    """
-    device_num, rank_id = _get_rank_info(distribute)
-    if device_num == 1:
-        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=get_num_parallel_workers(8), shuffle=True)
-    else:
-        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=get_num_parallel_workers(2), shuffle=True,
-                                         num_shards=device_num, shard_id=rank_id)
-
-    # Computed from random subset of ImageNet training images
-    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
-    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
-
-    # define map operations
-    if do_train:
-        trans = [
-            ds.vision.RandomCropDecodeResize(train_image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
-            ds.vision.RandomHorizontalFlip(prob=0.5),
-            ds.vision.Normalize(mean=mean, std=std),
-            ds.vision.HWC2CHW()
-        ]
-    else:
-        trans = [
-            ds.vision.Decode(),
-            ds.vision.Resize(256),
-            ds.vision.CenterCrop(eval_image_size),
-            ds.vision.Normalize(mean=mean, std=std),
-            ds.vision.HWC2CHW()
-        ]
-
-    type_cast_op = ds.transforms.transforms.TypeCast(ms.int32)
-    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=4)
-    # only enable cache for eval
-    if do_train:
-        enable_cache = False
-    if enable_cache:
-        if not cache_session_id:
-            raise ValueError("A cache session_id must be provided to use cache.")
-        eval_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0)
-        data_set = data_set.map(operations=type_cast_op, input_columns="label",
-                                num_parallel_workers=get_num_parallel_workers(2),
-                                cache=eval_cache)
-    else:
-        data_set = data_set.map(operations=type_cast_op, input_columns="label",
-                                num_parallel_workers=get_num_parallel_workers(2))
-
-    # apply batch operations
-    data_set = data_set.batch(batch_size, drop_remainder=True)
-
-    return data_set
-
-
 def create_dataset3(dataset_path, do_train, batch_size=32, train_image_size=224, eval_image_size=224,
                     target="Ascend", distribute=False, enable_cache=False, cache_session_id=None):
     """
@@ -83,10 +83,7 @@ if config.net_name in ("resnet18", "resnet34", "resnet50", "resnet152"):
     if config.dataset == "cifar10":
         from src.dataset import create_dataset1 as create_dataset
     else:
-        if config.mode_name == "GRAPH":
-            from src.dataset import create_dataset2 as create_dataset
-        else:
-            from src.dataset import create_dataset_pynative as create_dataset
+        from src.dataset import create_dataset2 as create_dataset
 elif config.net_name == "resnet101":
     from src.resnet import resnet101 as resnet
     from src.dataset import create_dataset3 as create_dataset
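With this change, the ImageNet branch for resnet18/34/50/152 always imports create_dataset2, whichever execution mode config.mode_name selects. The lines below are a minimal caller sketch, not code from this commit: the dataset path and argument values are placeholders, and the keyword names follow the create_dataset2 signature shown in the hunk header above.

# Sketch only: assumes it runs inside this repository so src/dataset.py is importable.
# The dataset path and hyper-parameter values are illustrative placeholders.
from src.dataset import create_dataset2 as create_dataset

# One code path now serves both GRAPH and PyNative execution modes.
train_set = create_dataset(dataset_path="/path/to/imagenet2012/train",
                           do_train=True,
                           batch_size=32,
                           train_image_size=224,
                           eval_image_size=224,
                           target="Ascend",
                           distribute=False)

# Pull one batch to confirm the pipeline yields (image, label) columns.
for column_dict in train_set.create_dict_iterator(num_epochs=1):
    images, labels = column_dict["image"], column_dict["label"]
    break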