diff --git a/research/cv/wgan_gp/README_CN.md b/research/cv/wgan_gp/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..2396ebd717179e96bdf79faae21da595a2c90bee --- /dev/null +++ b/research/cv/wgan_gp/README_CN.md @@ -0,0 +1,184 @@ +# 鐩綍 + +<!-- TOC --> + +- [鐩綍](#鐩綍) +- [WGAN-GP鎻忚堪](#wgan-gp鎻忚堪) +- [妯″瀷鏋舵瀯](#妯″瀷鏋舵瀯) +- [鏁版嵁闆哴(#鏁版嵁闆�) +- [鐜瑕佹眰](#鐜瑕佹眰) +- [蹇€熷叆闂╙(#蹇€熷叆闂�) +- [鑴氭湰璇存槑](#鑴氭湰璇存槑) + - [鑴氭湰鍙婃牱渚嬩唬鐮乚(#鑴氭湰鍙婃牱渚嬩唬鐮�) + - [鑴氭湰鍙傛暟](#鑴氭湰鍙傛暟) + - [璁粌杩囩▼](#璁粌杩囩▼) + - [鍗曟満璁粌](#鍗曟満璁粌) +- [妯″瀷鎻忚堪](#妯″瀷鎻忚堪) + - [鎬ц兘](#鎬ц兘) + - [璁粌鎬ц兘](#璁粌鎬ц兘) +- [闅忔満鎯呭喌璇存槑](#闅忔満鎯呭喌璇存槑) +- [ModelZoo涓婚〉](#modelzoo涓婚〉) + +<!-- /TOC --> + +# WGAN-GP鎻忚堪 + +WGAN-GP(Wasserstein GAN-Gradient Penalty)鏄竴绉嶅寘鍚獶CGAN缁撴瀯鍒ゅ埆鍣ㄤ笌鐢熸垚鍣ㄧ殑鐢熸垚瀵规姉缃戠粶锛屽畠鍦╓GAN鍩虹涓婄敤姊害鎯╃綒鏇夸唬浜嗘搴﹀壀瑁侊紝鍦ㄦ崯澶卞嚱鏁板紩鍏ヤ簡鍒ゅ埆鍣ㄨ緭鍑虹浉瀵硅緭鍏ョ殑浜岄樁瀵兼暟锛屼綔涓鸿鑼冨垽鍒櫒鎹熷け妯$殑鍑芥暟锛岃В鍐充簡WGAN闅忔満涓嶆敹鏁涗笌鐢熸垚鏍锋湰璐ㄩ噺宸殑闂銆� + +[璁烘枃](https://arxiv.org/pdf/1704.00028v3.pdf)锛欼mproved Training of Wasserstein GANs + +# 妯″瀷鏋舵瀯 + +WGAN-GP缃戠粶鍖呭惈涓ら儴鍒嗭紝鐢熸垚鍣ㄧ綉缁滃拰鍒ゅ埆鍣ㄧ綉缁溿€傚垽鍒櫒缃戠粶閲囩敤鍗风НDCGAN鐨勬灦鏋勶紝鍗冲灞備簩缁村嵎绉浉杩炪€傜敓鎴愬櫒缃戠粶閲囩敤鍗风НDCGAN鐢熸垚鍣ㄧ粨鏋勩€傝緭鍏ユ暟鎹寘鎷湡瀹炲浘鐗囨暟鎹拰鍣0鏁版嵁锛屾暟鎹泦Cifar10鐨勭湡瀹炲浘鐗噐esize鍒�32*32锛屽櫔澹版暟鎹殢鏈虹敓鎴愩€� + +# 鏁版嵁闆� + +[CIFAR-10](<http://www.cs.toronto.edu/~kriz/cifar.html>) + +- 鏁版嵁闆嗗ぇ灏忥細175M, 60000寮�10鍒嗙被褰╄壊鍥惧儚 + - 璁粌闆嗭細146M锛屽叡50000寮犲浘鍍忋€� + - 娉細瀵逛簬鐢熸垚瀵规姉缃戠粶锛屾帹鐞嗛儴鍒嗘槸浼犲叆鍣0鏁版嵁鐢熸垚鍥剧墖锛屾晠鏃犻渶浣跨敤娴嬭瘯闆嗘暟鎹€� +- 鏁版嵁鏍煎紡锛氫簩杩涘埗鏂囦欢 + +# 鐜瑕佹眰 + +- 纭欢锛圓scend锛� + - 浣跨敤Ascend鏉ユ惌寤虹‖浠剁幆澧冦€� +- 妗嗘灦 + - [MindSpore](https://www.mindspore.cn/install) +- 濡傞渶鏌ョ湅璇︽儏锛岃鍙傝濡備笅璧勬簮锛� + - [MindSpore鏁欑▼](https://www.mindspore.cn/tutorials/zh-CN/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/docs/zh-CN/master/index.html) + +# 蹇€熷叆闂� + +閫氳繃瀹樻柟缃戠珯瀹夎MindSpore鍚庯紝鎮ㄥ彲浠ユ寜鐓у涓嬫楠よ繘琛岃缁冨拰璇勪及锛� + +- Ascend澶勭悊鍣ㄧ幆澧冭繍琛� + + ```python + # 杩愯鍗曟満璁粌绀轰緥锛� + bash run_train.sh [DATAROOT] [DEVICE_ID] + + + # 杩愯璇勪及绀轰緥 + bash run_eval.sh [DEVICE_ID] [CONFIG_PATH] [CKPT_FILE_PATH] [OUTPUT_DIR] [NIMAGES] + ``` + +# 鑴氭湰璇存槑 + +## 鑴氭湰鍙婃牱渚嬩唬鐮� + +```bash +鈹溾攢鈹€ model_zoo + 鈹溾攢鈹€ README.md // 
鎵€鏈夋ā鍨嬬浉鍏宠鏄� + 鈹溾攢鈹€ WGAN-GP + 鈹溾攢鈹€ README.md // WGAN-GP鐩稿叧璇存槑 + 鈹溾攢鈹€ scripts + 鈹� 鈹溾攢鈹€ run_train.sh // 鍗曟満鍒癆scend澶勭悊鍣ㄧ殑shell鑴氭湰 + 鈹� 鈹溾攢鈹€ run_eval.sh // Ascend璇勪及鐨剆hell鑴氭湰 + 鈹溾攢鈹€ src + 鈹� 鈹溾攢鈹€ dataset.py // 鍒涘缓鏁版嵁闆嗗強鏁版嵁棰勫鐞� + 鈹� 鈹溾攢鈹€ model.py // WGAN-GP鐢熸垚鍣ㄤ笌鍒ゅ埆鍣ㄥ畾涔� + 鈹� 鈹溾攢鈹€ args.py // 鍙傛暟閰嶇疆鏂囦欢 + 鈹� 鈹溾攢鈹€ cell.py // 妯″瀷鍗曟璁粌鏂囦欢 + 鈹溾攢鈹€ train.py // 璁粌鑴氭湰 + 鈹溾攢鈹€ eval.py // 璇勪及鑴氭湰 +``` + +## 鑴氭湰鍙傛暟 + +鍦╝rgs.py涓彲浠ュ悓鏃堕厤缃缁冨弬鏁般€佽瘎浼板弬鏁板強妯″瀷瀵煎嚭鍙傛暟銆� + + ```python + # common_config + 'device_target': 'Ascend', # 杩愯璁惧 + 'device_id': 0, # 鐢ㄤ簬璁粌鎴栬瘎浼版暟鎹泦鐨勮澶嘔D + + # train_config + 'dataroot': None, # 鏁版嵁闆嗚矾寰勶紝蹇呴』杈撳叆锛屼笉鑳戒负绌� + 'workers': 8, # 鏁版嵁鍔犺浇绾跨▼鏁� + 'batchSize': 64, # 鎵瑰鐞嗗ぇ灏� + 'imageSize': 32, # 鍥剧墖灏哄澶у皬 + 'DIM': 128, # GAN缃戠粶闅愯棌灞傚ぇ灏� + 'niter': 1200, # 缃戠粶璁粌鐨別poch鏁� + 'save_iterations': 1000, # 淇濆瓨妯″瀷鏂囦欢鐨勭敓鎴愬櫒杩唬娆℃暟 + 'lrD': 0.0001, # 鍒ゅ埆鍣ㄥ垵濮嬪涔犵巼 + 'lrG': 0.0001, # 鐢熸垚鍣ㄥ垵濮嬪涔犵巼 + 'beta1': 0.5, # Adam浼樺寲鍣╞eta1鍙傛暟 + 'beta2': 0.9, # Adam浼樺寲鍣╞eta2鍙傛暟 + 'netG': '', # 鎭㈠璁粌鐨勭敓鎴愬櫒鐨刢kpt鏂囦欢璺緞 + 'netD': '', # 鎭㈠璁粌鐨勫垽鍒櫒鐨刢kpt鏂囦欢璺緞 + 'Diters': 5, # 姣忚缁冧竴娆$敓鎴愬櫒闇€瑕佽缁冨垽鍒櫒鐨勬鏁� + 'experiment': None, # 淇濆瓨妯″瀷鍜岀敓鎴愬浘鐗囩殑璺緞锛岃嫢涓嶆寚瀹氾紝鍒欎娇鐢ㄩ粯璁よ矾寰� + + # eval_config + 'ckpt_file': None, # 璁粌鏃朵繚瀛樼殑鐢熸垚鍣ㄧ殑鏉冮噸鏂囦欢.ckpt鐨勮矾寰勶紝蹇呴』鎸囧畾 + 'output_dir': None, # 鐢熸垚鍥剧墖鐨勮緭鍑鸿矾寰勶紝蹇呴』鎸囧畾 + ``` + +鏇村閰嶇疆缁嗚妭璇峰弬鑰冭剼鏈琡args.py`銆� + +## 璁粌杩囩▼ + +### 鍗曟満璁粌 + +- Ascend澶勭悊鍣ㄧ幆澧冭繍琛� + + ```bash + bash run_train.sh [DATAROOT] [DEVICE_ID] + ``` + + 涓婅堪python鍛戒护灏嗗湪鍚庡彴杩愯锛屾偍鍙互閫氳繃train.log鏂囦欢鏌ョ湅缁撴灉銆� + + 璁粌缁撴潫鍚庯紝鎮ㄥ彲鍦ㄥ瓨鍌ㄧ殑鏂囦欢澶癸紙榛樿鏄�./samples锛変笅鎵惧埌鐢熸垚鐨勫浘鐗囥€佹鏌ョ偣鏂囦欢鍜�.json鏂囦欢銆傞噰鐢ㄤ互涓嬫柟寮忓緱鍒版崯澶卞€硷細 + + ```bash + [0/1200][230/937][23] Loss_D: -379.555344 Loss_G: -33.761238 + [0/1200][235/937][24] Loss_D: -214.557617 Loss_G: -23.762344 + ... 
+ ``` + +## 鎺ㄧ悊杩囩▼ + +### 鎺ㄧ悊 + +- 鍦ˋscend鐜涓嬭瘎浼� + + 鍦ㄨ繍琛屼互涓嬪懡浠や箣鍓嶏紝璇锋鏌ョ敤浜庢帹鐞嗙殑妫€鏌ョ偣鍜宩son鏂囦欢璺緞锛屽苟璁剧疆杈撳嚭鍥剧墖鐨勮矾寰勩€� + + ```bash + bash run_eval.sh [DEVICE_ID] [CKPT_FILE_PATH] [OUTPUT_DIR] + ``` + + 涓婅堪python鍛戒护灏嗗湪鍚庡彴杩愯锛屾偍鍙互閫氳繃eval/eval.log鏂囦欢鏌ョ湅鏃ュ織淇℃伅锛屽湪杈撳嚭鍥剧墖鐨勮矾寰勪笅鏌ョ湅鐢熸垚鐨勫浘鐗囥€� + +# 妯″瀷鎻忚堪 + +## 鎬ц兘 + +### 璁粌鎬ц兘 + +| 鍙傛暟 | Ascend | +| ------------------------- | ----------------------------------------------------- | +| 璧勬簮 | Ascend 910 锛汣PU 2.60GHz锛�192鏍革紱鍐呭瓨锛�755G | +| 涓婁紶鏃ユ湡 | 2022-08-01 | +| MindSpore鐗堟湰 | 1.8.0 | +| 鏁版嵁闆� | CIFAR-10 | +| 璁粌鍙傛暟 | max_epoch=1200, batch_size=64, lr_init=0.0001 | +| 浼樺寲鍣� | Adam | +| 鎹熷け鍑芥暟 | 鑷畾涔夋崯澶卞嚱鏁� | +| 杈撳嚭 | 鐢熸垚鐨勫浘鐗� | +| 閫熷害 | 鍗曞崱锛�0.06绉�/姝� | + +鐢熸垚鍥剧墖鏁堟灉濡備笅锛� + + + +# 闅忔満鎯呭喌璇存槑 + +鍦╰rain.py涓紝鎴戜滑璁剧疆浜嗛殢鏈虹瀛愩€� + +# ModelZoo涓婚〉 + + 璇锋祻瑙堝畼缃慬涓婚〉](https://gitee.com/mindspore/models)銆� diff --git a/research/cv/wgan_gp/eval.py b/research/cv/wgan_gp/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..9c1a53dab7e0c9cc4721dc07a1ebd2acdcbd072f --- /dev/null +++ b/research/cv/wgan_gp/eval.py @@ -0,0 +1,71 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Generate a grid of sample images from a trained WGAN-GP generator."""

import mindspore as ms
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import context
import numpy as np
from PIL import Image

from src.model import DcganG
from src.args import get_args


def save_image(img, img_path, IMAGE_SIZE):
    """Tile up to 64 images into an 8x8 grid and write it to ``img_path``.

    Args:
        img: Either a MindSpore ``Tensor`` or a ``numpy.ndarray``.
            A Tensor is rescaled with ``x * 127.5 + 127.5`` (so values in
            [-1, 1] land in [0, 255]), cast to ``uint8`` and transposed with
            ``(0, 2, 3, 1)`` -- presumably NCHW -> NHWC; confirm against the
            generator output layout.
            An ndarray is used as-is, so every ``img[k]`` must already be
            something ``Image.fromarray`` accepts.
        img_path: Destination file path handed to ``PIL.Image.save``.
        IMAGE_SIZE: Edge length in pixels used to place each tile.

    Raises:
        ValueError: If ``img`` is neither a Tensor nor an ndarray.
    """
    if isinstance(img, Tensor):
        img = ops.Mul()(img, 255 * 0.5)
        img = ops.Add()(img, 255 * 0.5)
        img = img.asnumpy().astype(np.uint8).transpose((0, 2, 3, 1))
    elif not isinstance(img, np.ndarray):
        raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img)))

    IMAGE_ROW = 8  # Row num
    IMAGE_COLUMN = 8  # Column num
    PADDING = 2  # Interval of small pictures
    to_image = Image.new(
        'RGB',
        (IMAGE_COLUMN * IMAGE_SIZE + PADDING * (IMAGE_COLUMN + 1),
         IMAGE_ROW * IMAGE_SIZE + PADDING * (IMAGE_ROW + 1)))

    # Only paste the tiles that actually exist: a batch smaller than 64 used
    # to raise IndexError here; the remaining grid cells simply stay black.
    for index, tile in enumerate(img[:IMAGE_ROW * IMAGE_COLUMN]):
        row, col = divmod(index, IMAGE_COLUMN)
        left = col * IMAGE_SIZE + PADDING * (col + 1)
        top = row * IMAGE_SIZE + PADDING * (row + 1)
        to_image.paste(Image.fromarray(tile), (left, top))

    to_image.save(img_path)


if __name__ == "__main__":

    args_opt = get_args()

    # Both options default to None; fail early with a clear message instead
    # of an obscure error from the checkpoint loader or the image writer.
    if not args_opt.ckpt_file:
        raise ValueError("--ckpt_file must be specified")
    if not args_opt.output_dir:
        raise ValueError("--output_dir must be specified")

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    context.set_context(device_id=args_opt.device_id)

    netG = DcganG(args_opt.DIM)

    # load weights
    load_param_into_net(netG, load_checkpoint(args_opt.ckpt_file))

    # The generator's first Dense layer takes DIM inputs, so the latent width
    # must follow --DIM instead of a hard-coded 128.
    fixed_noise = Tensor(
        np.random.normal(0, 1, size=[args_opt.batchSize, args_opt.DIM]),
        dtype=ms.float32)

    fake = netG(fixed_noise)
    save_image(fake, '{0}/generated_samples.png'.format(args_opt.output_dir), args_opt.imageSize)

    print("Generate images success!")
b/research/cv/wgan_gp/imgs/fake_samples_200000.png new file mode 100644 index 0000000000000000000000000000000000000000..7bd6f304c38021c7a4e39f279f4bc0e2b9a7fe04 Binary files /dev/null and b/research/cv/wgan_gp/imgs/fake_samples_200000.png differ diff --git a/research/cv/wgan_gp/requirements.txt b/research/cv/wgan_gp/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..adf4746ea1793617fe8ca72a5622d958893d310a --- /dev/null +++ b/research/cv/wgan_gp/requirements.txt @@ -0,0 +1,2 @@ +Pillow +onnxruntime-gpu \ No newline at end of file diff --git a/research/cv/wgan_gp/scripts/run_eval.sh b/research/cv/wgan_gp/scripts/run_eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..a34da980cc18bdb69ac90b6b2f6002b988772329 --- /dev/null +++ b/research/cv/wgan_gp/scripts/run_eval.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash bash run_eval.sh device_id ckpt_file output_dir" +echo "For example: bash run_eval.sh DEVICE_ID CKPT_FILE_PATH OUTPUT_DIR" +echo "It is better to use the absolute path." 
#!/bin/bash
# Launch WGAN-GP evaluation in the background from a fresh ../eval directory.
# Usage: bash run_eval.sh DEVICE_ID CKPT_FILE_PATH OUTPUT_DIR
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh device_id ckpt_file output_dir"
echo "For example: bash run_eval.sh DEVICE_ID CKPT_FILE_PATH OUTPUT_DIR"
echo "It is better to use the absolute path."
echo "=============================================================================================================="

if [ $# -lt 3 ]; then
    echo "Error: expected 3 arguments (device_id ckpt_file output_dir), got $#"
    exit 1
fi

EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# Stage a clean copy of the sources so the run does not pollute the repo.
cd ../ || exit 1
rm -rf eval
mkdir -p eval/src
cp ./*.py ./eval
cp ./src/*.py ./eval/src
cd ./eval || exit 1

env > env0.log

echo "eval begin."
python eval.py --device_id "$1" --ckpt_file "$2" --output_dir "$3" > ./eval.log 2>&1 &
eval_pid=$!

# The job runs in the background, so "$?" here would only say whether the
# launch succeeded, never whether evaluation did. Report where to look.
echo "eval started in background (pid ${eval_pid}); check eval/eval.log for the result"
echo "finish"
cd ../
#!/bin/bash
# Launch single-device WGAN-GP training in the background from a fresh ../train directory.
# Usage: bash run_train.sh DATAROOT DEVICE_ID
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_train.sh dataroot device_id"
echo "For example: bash run_train.sh /home/cifar10/cifar-10-batches-bin/ 3"
echo "It is better to use the absolute path."
echo "=============================================================================================================="

if [ $# -lt 2 ]; then
    echo "Error: expected 2 arguments (dataroot device_id), got $#"
    exit 1
fi

EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# Stage a clean copy of the sources so the run does not pollute the repo.
cd ../ || exit 1
rm -rf train
mkdir -p train/src
cp ./*.py ./train
cp ./src/*.py ./train/src
cd ./train || exit 1

env > env0.log

echo "train begin."
python train.py --dataroot "$1" --device_id "$2" > ./train.log 2>&1 &
train_pid=$!

# The job runs in the background, so "$?" here would only say whether the
# launch succeeded, never whether training did. Report where to look.
echo "training started in background (pid ${train_pid}); check train/train.log for the result"
echo "finish"
cd ../
"""get args"""
import argparse


def get_args(argv=None):
    """Parse the command-line options shared by training and evaluation.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which
            is the original behaviour; passing a list allows the parser to be
            driven programmatically.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description='WGAN-GP')

    # common options
    parser.add_argument('--device_target', default='Ascend', help='enables npu')
    parser.add_argument('--device_id', type=int, default=0)

    # training options
    parser.add_argument('--dataroot', default=None, help='path to dataset')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
    parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=32, help='the height/width of the input image to network')
    parser.add_argument('--DIM', type=int, default=128, help='dimension of input samples')
    parser.add_argument('--niter', type=int, default=1200, help='number of epochs to train for')
    parser.add_argument('--save_iterations', type=int, default=1000, help='num of gen iterations to save model')
    parser.add_argument('--lrD', type=float, default=0.0001, help='learning rate for Critic, default=0.0001')
    parser.add_argument('--lrG', type=float, default=0.0001, help='learning rate for Generator, default=0.0001')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for adam. default=0.9')
    parser.add_argument('--netG', default='', help="path to netG (to continue training)")
    parser.add_argument('--netD', default='', help="path to netD (to continue training)")
    parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
    parser.add_argument('--experiment', default="samples", help='Where to store samples and models')

    # evaluation options
    parser.add_argument('--ckpt_file', default=None, help='path to pretrained ckpt model file')
    parser.add_argument('--output_dir', default=None, help='output path of generated images')

    return parser.parse_args(argv)
"""Loss wrappers for the WGAN-GP generator and critic."""

from mindspore import ops, nn
import mindspore.numpy as mnp


class GenWithLossCell(nn.Cell):
    """Generator loss wrapper: returns ``-netD(netG(noise))``."""

    def __init__(self, netG, netD):
        super(GenWithLossCell, self).__init__()
        self.netG = netG
        self.netD = netD

    def construct(self, noise):
        """Generate samples from ``noise`` and return the negated critic output."""
        fake = self.netG(noise)
        errG = self.netD(fake)
        return -errG


class DisWithLossCell(nn.Cell):
    """Critic loss wrapper: ``D(fake) - D(real) + LAMBDA * gradient_penalty``.

    Args:
        netG: Generator cell; its output is detached with ``stop_gradient``
            before being scored, so this loss does not update it.
        netD: Critic cell.
        lambda_gp: Weight applied to the gradient penalty term. Defaults to
            100, the value that was previously hard-coded.
    """

    def __init__(self, netG, netD, lambda_gp=100):
        super(DisWithLossCell, self).__init__()
        self.netG = netG
        self.netD = netD
        self.gradop = ops.GradOperation()
        self.LAMBDA = lambda_gp
        self.uniform = ops.UniformReal()

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """Return mean((||grad netD(x_hat)||_2 - 1)^2) over random interpolates.

        ``x_hat`` is ``alpha * real + (1 - alpha) * fake`` with one ``alpha``
        drawn per sample by ``UniformReal`` and broadcast over the remaining
        three axes.
        """
        # One mixing coefficient per sample, shape (batch, 1, 1, 1).
        alpha = self.uniform((real_samples.shape[0], 1, 1, 1))
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples))

        # Gradient of the critic output with respect to its input.
        # NOTE(review): the DcgannobnD critic in this change set returns the
        # mean over the batch, so these per-sample gradients would carry a
        # 1/batch factor -- confirm LAMBDA was chosen with that in mind.
        grad_fn = self.gradop(self.netD)
        gradients = grad_fn(interpolates)
        gradients = gradients.view(gradients.shape[0], -1)
        gradient_penalty = ops.reduce_mean(((mnp.norm(gradients, 2, axis=1) - 1) ** 2))
        return gradient_penalty

    def construct(self, real, noise):
        """Return the critic loss for one batch of real images and noise."""
        errD_real = self.netD(real)
        fake = self.netG(noise)
        # Detach so the critic update does not back-propagate into netG.
        fake = ops.stop_gradient(fake)
        errD_fake = self.netD(fake)

        gradient_penalty = self.compute_gradient_penalty(real, fake)

        return errD_fake - errD_real + gradient_penalty * self.LAMBDA
"""CIFAR-10 input pipeline for WGAN-GP."""

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as C
import mindspore.dataset.transforms as C2


def create_dataset(dataroot, batchSize, imageSize, repeat_num=1, workers=8, target='Ascend'):
    """Build the shuffled, batched CIFAR-10 pipeline.

    Args:
        dataroot: Directory handed to ``Cifar10Dataset``.
        batchSize: Batch size; incomplete trailing batches are dropped.
        imageSize: Size passed to ``Resize``.
        repeat_num: Number of times the batched dataset is repeated.
        workers: Parallel worker count for loading and for both ``map`` calls.
        target: Accepted for interface compatibility; not read by this function.

    Returns:
        The dataset after image/label mapping, batching and repeating.
    """
    half_range = 0.5 * 255

    # Image column: resize, normalise with mean = std = 127.5 per channel,
    # then move channels first.
    image_ops = [
        C.Resize(imageSize),
        C.Normalize(mean=(half_range, half_range, half_range),
                    std=(half_range, half_range, half_range)),
        C.HWC2CHW(),
    ]
    # Label column: cast to int32.
    label_op = C2.TypeCast(ms.int32)

    pipeline = ds.Cifar10Dataset(dataroot, num_parallel_workers=workers, shuffle=True)
    pipeline = pipeline.map(input_columns='image', operations=image_ops, num_parallel_workers=workers)
    pipeline = pipeline.map(input_columns='label', operations=label_op, num_parallel_workers=workers)

    pipeline = pipeline.batch(batchSize, drop_remainder=True)
    return pipeline.repeat(repeat_num)
"""DCGAN-style critic and generator used by WGAN-GP."""

import mindspore.nn as nn
import mindspore.ops as ops


class DcgannobnD(nn.Cell):
    """DCGAN discriminator (critic) without BatchNorm layers.

    Three stride-2 5x5 convolutions with LeakyReLU(0.2) widen the channels
    3 -> DIM -> 2*DIM -> 4*DIM, followed by a Dense layer to one value.
    The flatten step assumes a 4x4 feature map, i.e. presumably 32x32 inputs
    -- confirm against the dataset's image size.

    Args:
        DIM: Base channel width.
    """

    def __init__(self, DIM):
        super(DcgannobnD, self).__init__()

        self.DIM = DIM
        KERNEL_SIZE = 5
        STRIDE = 2

        main = nn.SequentialCell()
        main.append(nn.Conv2d(3, self.DIM, KERNEL_SIZE, STRIDE, 'same'))
        main.append(nn.LeakyReLU(0.2))

        main.append(nn.Conv2d(self.DIM, self.DIM*2, KERNEL_SIZE, STRIDE, 'same'))
        main.append(nn.LeakyReLU(0.2))

        main.append(nn.Conv2d(self.DIM*2, self.DIM*4, KERNEL_SIZE, STRIDE, 'same'))
        main.append(nn.LeakyReLU(0.2))
        self.main = main
        self.linear = nn.Dense(4*4*4*self.DIM, 1)

    def construct(self, input1):
        """Score a batch and return ``reduce_mean`` of the per-sample outputs."""
        output = self.main(input1)
        output = output.view(-1, 4*4*4*self.DIM)
        output = self.linear(output)
        output = ops.reduce_mean(output)
        return output


class DcganG(nn.Cell):
    """DCGAN generator.

    A Dense layer expands the DIM-wide noise vector to 4*4*4*DIM features,
    which are batch-normalised, passed through ReLU and reshaped to
    (batch, 4*DIM, 4, 4). Three stride-2 5x5 transposed convolutions then
    narrow the channels 4*DIM -> 2*DIM -> DIM -> 3, ending in Tanh.

    Args:
        DIM: Base channel width; also the expected width of the noise input.
    """

    def __init__(self, DIM):
        super(DcganG, self).__init__()

        self.DIM = DIM
        KERNEL_SIZE = 5
        STRIDE = 2

        self.linear = nn.Dense(self.DIM, 4*4*4*self.DIM)
        self.bn = nn.BatchNorm2d(4*4*4*self.DIM)
        self.relu = nn.ReLU()

        main = nn.SequentialCell()
        main.append(nn.Conv2dTranspose(
            self.DIM*4,
            self.DIM*2,
            KERNEL_SIZE,
            stride=STRIDE,
            weight_init='normal',
            pad_mode='same'))
        main.append(nn.BatchNorm2d(self.DIM*2))
        main.append(nn.ReLU())

        main.append(nn.Conv2dTranspose(
            self.DIM*2,
            self.DIM,
            KERNEL_SIZE,
            stride=STRIDE,
            weight_init='normal',
            pad_mode='same'))
        main.append(nn.BatchNorm2d(self.DIM))
        main.append(nn.ReLU())

        main.append(nn.Conv2dTranspose(
            self.DIM,
            3,
            KERNEL_SIZE,
            stride=STRIDE,
            weight_init='normal',
            pad_mode='same'))
        main.append(nn.Tanh())
        self.main = main

    def construct(self, input1):
        """Map a (batch, DIM) noise tensor to generated images."""
        output = self.linear(input1)
        # Infer the batch dimension instead of hard-coding 64 so that any
        # --batchSize works; this matches the -1 used by the discriminator.
        # BatchNorm2d needs a 4-D input, hence the trailing 1x1 spatial dims.
        output = output.view(-1, 4*4*4*self.DIM, 1, 1)
        output = self.bn(output)
        output = self.relu(output)
        output = output.view(-1, 4*self.DIM, 4, 4)
        output = self.main(output)
        return output
"""Train WGAN-GP on CIFAR-10."""

import os
import shutil
import time
import mindspore as ms
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common import initializer as init
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, save_checkpoint
from PIL import Image
import numpy as np

from src.dataset import create_dataset
from src.model import DcganG, DcgannobnD
from src.cell import GenWithLossCell, DisWithLossCell
from src.args import get_args


def init_weight(net):
    """Re-initialise the weights of every conv, BatchNorm and Dense cell in ``net``.

    Conv2d / Conv2dTranspose / Dense weights are drawn from ``Normal(0.02)``;
    BatchNorm2d gamma is drawn from ``numpy.random.normal(1, 0.02)`` and beta
    is set to zeros.
    """
    for _, cell in net.cells_and_names():
        if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose, nn.Dense)):
            cell.weight.set_data(init.initializer(init.Normal(0.02), cell.weight.shape))
        elif isinstance(cell, nn.BatchNorm2d):
            cell.gamma.set_data(init.initializer(
                Tensor(np.random.normal(1, 0.02, cell.gamma.shape), mstype.float32),
                cell.gamma.shape))
            cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


def save_image(img, img_path, IMAGE_SIZE):
    """Tile up to 64 images into an 8x8 grid and write it to ``img_path``.

    A ``Tensor`` is rescaled with ``x * 127.5 + 127.5``, cast to ``uint8`` and
    transposed with ``(0, 2, 3, 1)`` (presumably NCHW -> NHWC -- confirm);
    a ``numpy.ndarray`` is used as-is.

    Raises:
        ValueError: If ``img`` is neither a Tensor nor an ndarray.
    """
    if isinstance(img, Tensor):
        img = ops.Mul()(img, 255 * 0.5)
        img = ops.Add()(img, 255 * 0.5)
        img = img.asnumpy().astype(np.uint8).transpose((0, 2, 3, 1))
    elif not isinstance(img, np.ndarray):
        raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img)))

    IMAGE_ROW = 8  # Row num
    IMAGE_COLUMN = 8  # Column num
    PADDING = 2  # Interval of small pictures
    to_image = Image.new(
        'RGB',
        (IMAGE_COLUMN * IMAGE_SIZE + PADDING * (IMAGE_COLUMN + 1),
         IMAGE_ROW * IMAGE_SIZE + PADDING * (IMAGE_ROW + 1)))

    # Only paste the tiles that exist: a batch smaller than 64 used to raise
    # IndexError here; the remaining grid cells simply stay black.
    for index, tile in enumerate(img[:IMAGE_ROW * IMAGE_COLUMN]):
        row, col = divmod(index, IMAGE_COLUMN)
        left = col * IMAGE_SIZE + PADDING * (col + 1)
        top = row * IMAGE_SIZE + PADDING * (row + 1)
        to_image.paste(Image.fromarray(tile), (left, top))

    to_image.save(img_path)


if __name__ == '__main__':

    t_begin = time.time()
    args_opt = get_args()

    if args_opt.dataroot is None:
        raise ValueError("--dataroot must be specified")

    if args_opt.experiment is None:
        args_opt.experiment = 'samples'
    # Start from an empty output directory. Done with shutil/os rather than
    # shelling out so that paths with spaces or shell metacharacters are safe.
    shutil.rmtree(args_opt.experiment, ignore_errors=True)
    os.makedirs(args_opt.experiment, exist_ok=True)

    # Honour --device_target (default 'Ascend'), as the evaluation script does.
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=int(args_opt.device_id))
    ms.set_seed(0)
    dataset = create_dataset(args_opt.dataroot, args_opt.batchSize, args_opt.imageSize, 1,
                             args_opt.workers, args_opt.device_target)

    # define net----------------------------------------------------------------------------------------------
    # Generator
    netG = DcganG(args_opt.DIM)
    init_weight(netG)

    if args_opt.netG != '':  # load checkpoint if needed
        load_param_into_net(netG, load_checkpoint(args_opt.netG))

    # Critic
    netD = DcgannobnD(args_opt.DIM)
    init_weight(netD)

    if args_opt.netD != '':
        load_param_into_net(netD, load_checkpoint(args_opt.netD))

    # The generator's first Dense layer takes DIM inputs, so every noise
    # tensor must be DIM wide rather than a hard-coded 128.
    noise_shape = [args_opt.batchSize, args_opt.DIM]
    fixed_noise = Tensor(np.random.normal(0, 1, size=noise_shape), dtype=ms.float32)

    # setup optimizer
    optimizerD = nn.Adam(
        netD.trainable_params(),
        learning_rate=args_opt.lrD,
        beta1=args_opt.beta1,
        beta2=args_opt.beta2)
    optimizerG = nn.Adam(
        netG.trainable_params(),
        learning_rate=args_opt.lrG,
        beta1=args_opt.beta1,
        beta2=args_opt.beta2)

    netG_train = nn.TrainOneStepCell(GenWithLossCell(netG, netD), optimizerG)
    netD_train = nn.TrainOneStepCell(DisWithLossCell(netG, netD), optimizerD)

    netG_train.set_train()
    netD_train.set_train()

    gen_iterations = 0

    t0 = time.time()
    # Train
    for epoch in range(args_opt.niter):
        data_iter = dataset.create_dict_iterator()
        length = dataset.get_dataset_size()
        i = 0
        while i < length:
            ###########################
            # (1) Update D network
            ###########################
            for p in netD.trainable_params():  # reset requires_grad
                p.requires_grad = True  # they are set to False below in netG update

            # train the discriminator Diters times
            Diters = args_opt.Diters

            j = 0
            while j < Diters and i < length:
                j += 1

                data = next(data_iter)
                i += 1

                # train with real and fake
                real = data['image']
                noise = Tensor(np.random.normal(0, 1, size=noise_shape), dtype=ms.float32)
                loss_D = netD_train(real, noise)

                print('epoch %d loss_D: %.4f ' % (epoch, float(loss_D)))

            ###########################
            # (2) Update G network
            ###########################
            for p in netD.trainable_params():
                p.requires_grad = False  # to avoid computation

            noise = Tensor(np.random.normal(0, 1, size=noise_shape), dtype=ms.float32)

            loss_G = netG_train(noise)
            gen_iterations += 1

            t1 = time.time()
            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f'
                  % (epoch, args_opt.niter, i, length, gen_iterations,
                     loss_D.asnumpy(), loss_G.asnumpy()))
            print('step_cost: %.4f seconds' % (float(t1 - t0)))
            t0 = t1

            # Periodically dump sample grids and checkpoints.
            if gen_iterations % args_opt.save_iterations == 0:

                fake = netG(fixed_noise)
                save_image(
                    real,
                    '{0}/real_samples.png'.format(args_opt.experiment),
                    args_opt.imageSize)
                save_image(
                    fake,
                    '{0}/fake_samples_{1}.png'.format(args_opt.experiment, gen_iterations),
                    args_opt.imageSize)

                save_checkpoint(netD, '{0}/debug_netD_giter_{1}.ckpt'.format(args_opt.experiment, gen_iterations))
                save_checkpoint(netG, '{0}/debug_netG_giter_{1}.ckpt'.format(args_opt.experiment, gen_iterations))

    t_end = time.time()
    print('total_cost: %.4f seconds' % (float(t_end - t_begin)))
    print("Train success!")