Skip to content
Snippets Groups Projects
Commit 8f44b986 authored by zhaojichen's avatar zhaojichen Committed by caojian05
Browse files

fix dcgan distributed training scripts

parent 38f7c153
No related branches found
No related tags found
No related merge requests found
......@@ -16,6 +16,7 @@
rec format to jpg
"""
import os
import argparse
from skimage import io
import mxnet as mx
from mxnet import recordio
......
......@@ -33,8 +33,12 @@ def create_dataset_imagenet(dataset_path, num_parallel_workers=None):
Returns:
dataset
"""
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers)
device_num, rank_id = _get_rank_info()
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers,
num_shards=device_num, shard_id=rank_id)
assert dcgan_imagenet_cfg.image_height == dcgan_imagenet_cfg.image_width, "image_height not equal image_width"
image_size = dcgan_imagenet_cfg.image_height
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment