Commit 4a55557c authored by i-robot, committed by Gitee
!1061 add new model meta-baseline

Merge pull request !1061 from xuncliu/master
parents 208cd92f a17fb2de
Showing
with 1622 additions and 0 deletions
# Few-Shot Meta-Baseline
# Contents
- [Few-Shot Meta-Baseline](#few-shot-meta-baseline)
    - [Datasets](#datasets)
    - [Environment](#environment)
    - [Script and sample code](#script-and-sample-code)
    - [Script parameters](#script-parameters)
    - [Quick start](#quick-start)
    - [Main Results](#main-results)
        - [5-way accuracy (%) on *miniImageNet*](#5-way-accuracy-----on--miniimagenet-)
    - [Running the code](#running-the-code)
        - [1. Training Classifier-Baseline](#1-training-classifier-baseline)
        - [2. Training Meta-Baseline](#2-training-meta-baseline)
        - [3. Test](#3-test)
    - [Performance](#performance)
    - [Citation](#citation)
# [Few-Shot Meta-Baseline](#Contents)
MindSpore implementation of ***Meta-Baseline: Exploring Simple Meta-Learning for Few-Shot Learning***.
The original PyTorch implementation is available [here](https://github.com/cyvius96/few-shot-meta-baseline).
<img src="https://user-images.githubusercontent.com/10364424/76388735-bfb02580-63a4-11ea-8540-4021961a4fbe.png" width="600">
## [Datasets](#Contents)
- [miniImageNet](https://drive.google.com/file/d/1fJAK5WZTjerW7EWHHQAR9pRJVNg1T1Y7/view?usp=sharing) (courtesy of [Spyros Gidaris](https://github.com/gidariss/FewShotWithoutForgetting))

Download the dataset and link its folder into `dataset/` under the name `mini-imagenet`.
Note that `imagenet` refers to the ILSVRC-2012 1K dataset, with two directories `train` and `val` that contain the class folders.
- Directory structure of the dataset:
```markdown
.dataset(root_path)
└── mini-imagenet
    ├── miniImageNet_category_split_val.pickle
    ├── miniImageNet_category_split_train_phase_val.pickle
    ├── miniImageNet_category_split_train_phase_train.pickle
    ├── miniImageNet_category_split_train_phase_test.pickle
    └── miniImageNet_category_split_test.pickle
```
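As a quick sanity check on the downloaded files, the sketch below loads one split and shifts its labels so that they start at zero, mirroring what `src/data/mini_Imagenet.py` does. The helper name `load_split` is hypothetical; the `data` and `labels` keys and the `latin1` encoding are the ones used by the loader in this repository.

```python
import os
import pickle

import numpy as np


def load_split(root_path, split_tag):
    """Load one miniImageNet split and return (images, zero-based labels).

    `load_split` is an illustrative helper, not part of the repository.
    """
    split_file = 'miniImageNet_category_split_{}.pickle'.format(split_tag)
    with open(os.path.join(root_path, split_file), 'rb') as f:
        pack = pickle.load(f, encoding='latin1')
    data = np.asarray(pack['data'])
    labels = np.asarray(pack['labels'])
    # Labels of the val/test splits do not start at 0, so shift them.
    return data, labels - labels.min()
```

For example, `images, labels = load_split('./dataset/mini-imagenet', 'test')` should return the test images together with labels renumbered from 0.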
## [Environment](#Contents)
- Hardware (Ascend/GPU)
    - Prepare hardware environment with Ascend or GPU.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
## [Script and sample code](#Contents)
```markdown
. meta-baseline
├── .dataset(root_path)                       # dataset, see the Datasets section
├── scripts
│   ├── run_eval.sh                           # evaluation script
│   ├── run_standalone_train_classifier.sh    # script to train the classifier
│   └── run_standalone_train_meta.sh          # script to train the meta model
├── src
│   ├── data
│   │   ├── IterSamplers.py                   # episode sampler
│   │   └── mini_Imagenet.py                  # miniImageNet dataset
│   ├── model
│   │   ├── classifier.py                     # classifier network
│   │   ├── meta_baseline.py                  # meta_baseline network
│   │   ├── meta_eval.py                      # evaluation
│   │   └── resnet12.py                       # backbone
│   └── util
│       └── __init__.py                       # util
├── eval.py                                   # evaluation script
├── export.py                                 # export
├── README.md                                 # descriptions about meta-baseline
├── train_classifier.py                       # train_classifier script
└── train_meta.py                             # train_meta script
```
## [Script parameters](#Contents)
Parameters for both train_classifier and train_meta can be set as follows:
- Parameters:
```text
# base setting
"root_path": "../dataset", # dataset root path
"device_target": "GPU", # device GPU or Ascend
"run_offline": False, # run online or offline
"dataset": "mini-imagenet", # dataset name
"ep_per_batch": 4, # number of episodes per batch
"max_epoch": 25, # epoch
"lr": 0.1, # lr
"n_classes": 64, # base classes 64
"batch_size": 128, # batchsize
"weight_decay": 5.e-4, # weight_decay
"num_ways": 5, # way 5
"num_shots": 1, # shot 1 or 5
```
## [Quick start](#Contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
- run on Ascend or GPU
```bash
# standalone training classifier
bash scripts/run_standalone_train_classifier.sh [GPU] [./dataset]
# standalone training meta
bash scripts/run_standalone_train_meta.sh [./save/classifier_mini-imagenet_resnet12/max-va.ckpt] [GPU] [./dataset] [1 or 5]
# standalone evaluation
bash scripts/run_eval.sh [./save/classifier_mini-imagenet_resnet12/max-va.ckpt] [./dataset] [GPU] [1 or 5]
```
## [Main Results](#Contents)
*The models on miniImageNet use ResNet-12 as the backbone, with **64-128-256-512** channels in the four blocks. The backbone does **NOT** use any additional tricks (e.g. DropBlock or the wider channels used in some recent work).*
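Each of the four blocks in `src/model/resnet12.py` ends with a 2x2 average pooling of stride 2, while the 3x3 convolutions use padding 1 and keep the spatial size. The short sketch below works out the resulting feature-map sizes for an 84x84 input; the function name `feature_map_sizes` is hypothetical, and it assumes pooling without padding (output size rounded down).

```python
def feature_map_sizes(input_size=84, n_blocks=4, kernel=2, stride=2):
    """Spatial size after each block, assuming unpadded pooling (floor)."""
    sizes = []
    size = input_size
    for _ in range(n_blocks):
        size = (size - kernel) // stride + 1
        sizes.append(size)
    return sizes


print(feature_map_sizes())  # [42, 21, 10, 5]
```

The final 5x5 map is averaged over its spatial positions, giving one 512-dimensional embedding per image.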
### 5-way accuracy (%) on *miniImageNet*
|method |1-shot |5-shot|
|-----------------------------------------------|------ |------|
|[Baseline++](https://arxiv.org/abs/1904.04232) |51.87 |75.68 |
|[MetaOptNet](https://arxiv.org/abs/1904.03758) |62.64 |78.63 |
|Classifier-Baseline |58.91|77.76|
|Meta-Baseline |63.17|79.26|
|Classifier-Baseline* |60.83|78.12|
|Meta-Baseline* |62.37|78.28|
## [Running the code](#Contents)
### [1. Training Classifier-Baseline](#Contents)
``` python
python train_classifier.py --root_path ./dataset/ --device_id 0 --device_target GPU --run_offline True
```
```text
...
epoch 16, 1-shot, val acc 0.6024
epoch 16, 5-shot, val acc 0.7720
2.0m 53.8m/1.4h
train loss 0.3114, train acc 0.9227
epoch 17, 1-shot, val acc 0.6006
epoch 17, 5-shot, val acc 0.7745
2.0m 55.9m/1.4h
...
```
Note: inference is run after each training epoch, so there is no need to
execute eval.py separately to view the results.
### [2. Training Meta-Baseline](#Contents)
``` python
python train_meta.py --num_shots 1 --load_encoder (dir) --root_path ./dataset/ --device_id 0 --device_target GPU --run_offline True
```
`load_encoder` is the path to the saved Classifier-Baseline checkpoint.
The loss and accuracy values are logged as follows:
```text
...
epoch 5, train 0.3933|0.8947, val 1.0961|0.6113, 2.7m 13.5m/40.6m (@-1)
epoch 6, train 0.3977|0.8903, val 1.0951|0.6103, 2.7m 16.3m/40.7m (@-1)
epoch 7, train 0.3882|0.8931, val 1.0818|0.6219, 2.7m 19.0m/40.7m (@-1)
epoch 8, train 0.3752|0.8989, val 1.0839|0.6075, 2.7m 21.7m/40.7m (@-1)
epoch 9, train 0.3764|0.8967, val 1.0724|0.6116, 2.7m 24.5m/40.8m (@-1)
...
```
Note: inference is run after each training epoch, so there is no need to
execute eval.py separately to view the results.
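Meta-training draws batches of few-shot episodes. The sketch below shows episodic sampling in plain NumPy, following the logic of `CategoriesSampler` in `src/data/IterSamplers.py`. The function `sample_episode_batch` and the toy data are hypothetical and only serve as an illustration; 15 query samples per class matches the value used in `eval.py`.

```python
import numpy as np


def sample_episode_batch(data, label, n_cls, n_per, ep_per_batch, rng):
    """Return an array of shape (ep_per_batch, n_cls, n_per, *data.shape[1:])."""
    label = np.asarray(label)
    # Indices of the samples belonging to each class.
    catlocs = [np.flatnonzero(label == c) for c in range(label.max() + 1)]
    batch = []
    for _ in range(ep_per_batch):
        # Pick n_cls distinct classes, then n_per distinct samples per class.
        classes = rng.choice(len(catlocs), n_cls, replace=False)
        episode = [data[rng.choice(catlocs[c], n_per, replace=False)] for c in classes]
        batch.append(np.stack(episode))
    return np.stack(batch)


# Toy data: 20 classes, 30 samples each, 8-dimensional stand-ins for images.
rng = np.random.default_rng(0)
toy_label = np.repeat(np.arange(20), 30)
toy_data = rng.normal(size=(600, 8))

n_shot, n_query = 1, 15
batch = sample_episode_batch(toy_data, toy_label, 5, n_shot + n_query, 4, rng)
# The first n_shot samples of each class form the support set, the rest are queries.
x_shot, x_query = batch[:, :, :n_shot], batch[:, :, n_shot:]
print(x_shot.shape, x_query.shape)  # (4, 5, 1, 8) (4, 5, 15, 8)
```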
### [3. Test](#Contents)
``` python
python eval.py --load_encoder (dir) --num_shots 1 --root_path ./dataset/ --device_target GPU
```
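At test time, each query image is assigned to the class whose support centroid is closest in cosine similarity, as done in `src/model/meta_eval.py`. A minimal NumPy sketch of that step is given below. The helper names are hypothetical, the embeddings are toy values rather than encoder outputs, and the bias-shift step that the repository applies to the query features is left out for brevity.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale vectors along `axis` to unit length."""
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)


def cosine_logits(x_shot, x_query, temp=5.0):
    """x_shot: (bs, n_way, n_shot, dim), x_query: (bs, n_q, dim) -> (bs, n_q, n_way)."""
    proto = l2_normalize(x_shot.mean(axis=-2))  # class centroids, (bs, n_way, dim)
    query = l2_normalize(x_query)
    # Cosine similarity between every query and every centroid, scaled by temp.
    return temp * np.matmul(query, proto.transpose(0, 2, 1))


# Toy check: one episode, two well-separated classes in 2-D, two shots each.
x_shot = np.array([[[[1.0, 0.0], [0.9, 0.1]],
                    [[0.0, 1.0], [0.1, 0.9]]]])  # (1, 2, 2, 2)
x_query = np.array([[[2.0, 0.1], [0.2, 3.0]]])   # (1, 2, 2)
pred = cosine_logits(x_shot, x_query).argmax(axis=-1)
print(pred)  # [[0 1]]
```

Because both sides are L2-normalized, the logits are bounded by the temperature, which is why a scaling factor (5.0 by default in this repository) is applied before the softmax loss.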
## [Performance](#Contents)
### Training Performance
| Parameters | Ascend 910 | GPU(RTX Titan) |
| -------------------------- | ------------------------------------------------------------ | ----------------------------------------------|
| Model Version | Meta-Baseline | Meta-Baseline |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NVIDIA RTX Titan-24G |
| uploaded Date | 12/28/2021 (month/day/year) | 12/28/2021 (month/day/year) |
| MindSpore Version | 1.3.0, 1.5.0 | 1.3.0, 1.5.0 |
| Dataset | mini-imagenet | mini-imagenet |
| Training Parameters        | epochs=20, steps per epoch=300, batch_size=4, lr=0.001       | epochs=20, steps per epoch=300, batch_size=4, lr=0.001 |
| Optimizer | SGD | SGD |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| outputs | probability | probability |
| Speed | 2.7 ms/step(1pcs, PyNative Mode) | 2.7ms/step(1pcs, PyNative Mode) |
| Total time | about 50min |about 50min |
| Parameters (M) | 8.7M | 8.7M |
| Checkpoint for Fine tuning | 31.4M (.ckpt file) | 31.4M (.ckpt file) |
| Scripts | [link](https://gitee.com/mindspore/models/research/cv/meta-baseline) ||
### Inference Performance
| Parameters | Ascend | GPU(RTX Titan) |
| ----------------- | ----------------------------------------------------------- | ----------------------------------------------------------- |
| Model Version | Meta-Baseline | Meta-Baseline |
| Resource | Ascend 910; OS Euler2.8 | NVIDIA RTX Titan-24G |
| Uploaded Date | 12/28/2021 (month/day/year) | 12/28/2021 (month/day/year) |
| MindSpore Version | 1.5.0, 1.3.0 | 1.5.0, 1.3.0 |
| Dataset | mini-imagenet | mini-imagenet |
| batch_size | 4 | 4 |
| outputs | probability | probability |
| Accuracy          | See the table in [Main Results](#main-results)              | See the table in [Main Results](#main-results)              |
## [Citation](#Contents)
``` text
@misc{chen2020new,
title={A New Meta-Baseline for Few-Shot Learning},
author={Yinbo Chen and Xiaolong Wang and Zhuang Liu and Huijuan Xu and Trevor Darrell},
year={2020},
eprint={2003.04390},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
pretrain_eval
"""
import os
import argparse
import mindspore.dataset as ds
from mindspore import context
from mindspore import load_checkpoint, load_param_into_net
import numpy as np
import src.util as util
from src.model.meta_eval import MetaEval
from src.model.classifier import Classifier
from src.data.IterSamplers import CategoriesSampler
from src.data.mini_Imagenet import MiniImageNet
from tqdm import tqdm


class Eval:
    """
    Eval meta-baseline and EGNN
    """

    def __init__(self):
        pass

    def set_exp_name(self):
        """
        :return: experience setting
        """
        exp_name = 'D-{}'.format(args.dataset)
        exp_name += '_backbone-{}'.format(args.backbone)
        exp_name += '_N-{}_K-{}'.format(args.num_ways, args.num_shots)
        exp_name += '_L-{}_B-{}'.format(args.num_layers, args.meta_batch_size)
        return exp_name

    def pretrain_eval(self):
        """
        :return: meta-baseline eval
        """
        param_dict = load_checkpoint(args.load_encoder)
        net = Classifier(64)
        load_param_into_net(net, param_dict)
        n_way = 5
        n_query = 15
        n_shots = [args.num_shots]
        eval_net = MetaEval()
        root_path = os.path.join(args.root_path, args.dataset)
        testset = MiniImageNet(root_path, 'test')
        fs_loaders = []
        for n_shot in n_shots:
            test_sampler = CategoriesSampler(testset.data, testset.label, n_way, n_shot + n_query,
                                             200,
                                             args.ep_per_batch)
            test_loader = ds.GeneratorDataset(test_sampler, ['data'], shuffle=True)
            fs_loaders.append(test_loader)
        aves_keys = ['tl', 'ta', 'vl', 'va']
        for n_shot in n_shots:
            aves_keys += ['fsa-' + str(n_shot)]
        aves = {k: util.Averager() for k in aves_keys}
        print("few-shot eval start")
        net.set_train(mode=False)
        for i, n_shot in enumerate(n_shots):
            np.random.seed(0)
            for data in tqdm(fs_loaders[i].create_dict_iterator(), desc='test', leave=False):
                x_shot, x_query = data['data'][:, :, :n_shot], data['data'][:, :, n_shot:]
                img_shape = x_query.shape[-3:]
                x_query = x_query.view(args.ep_per_batch, -1,
                                       *img_shape)  # bs*(way*n_query)*3*84*84
                label = util.make_nk_label(n_way, n_query, args.ep_per_batch)  # bs*(way*n_query)
                acc_val, _ = eval_net.eval(x_shot, x_query, label, net.encoder)
                aves['fsa-' + str(n_shot)].add(acc_val.asnumpy())
        for k, v in aves.items():
            aves[k] = v.item()
        for n_shot in n_shots:
            key = 'fsa-' + str(n_shot)
            print("epoch {}, {}-shot, val acc {:.4f}".format(str(1), n_shot, aves[key]))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # parser.add_argument('--config', default='configs/train_classifier_mini.yaml') root_path
    parser.add_argument('--name', default=None)
    parser.add_argument('--root_path', default='./dataset/')
    parser.add_argument('--tag', default=None)
    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU', 'CPU'])
    parser.add_argument('--dataset', default='mini-imagenet')
    parser.add_argument('--backbone', default='convnet', choices=['convnet', 'resnet12'])
    parser.add_argument('--load_encoder',
                        default='./save/epoch-max.ckpt')
    parser.add_argument('--resume', type=str, default="False")
    parser.add_argument('--ep_per_batch', type=int, default=4)
    parser.add_argument('--max_epoch', type=int, default=3)
    parser.add_argument('--visualize_datasets', default=True)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--num_gpu', type=int, default=1)
    parser.add_argument('--classifier', default='linear-classifier')
    parser.add_argument('--n_classes', type=int, default=64)
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--meta_batch_size', type=int, default=40)
    parser.add_argument('--save_epoch', type=int, default=200)
    parser.add_argument('--eval_fs_epoch', type=int, default=3)
    parser.add_argument('--num_ways', type=int, default=5)
    parser.add_argument('--num_shots', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--emb_size', type=int, default=128)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--data_url', default=None, help='Location of data.')
    parser.add_argument('--train_url', default=None, help='Location of training outputs.')
    parser.add_argument('--run_offline', type=str, default="False", help='run in offline')
    args = parser.parse_args()
    eval_model = Eval()
    eval_model.set_exp_name()
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
    if args.device_target == 'GPU' or args.device_target == 'Ascend':
        context.set_context(device_id=args.device_id)
        if args.run_offline == "True":
            import moxing as mox
            mox.file.copy_parallel(src_url=args.data_url, dst_url=args.root_path)
    else:
        raise ValueError("Unsupported platform.")
    eval_model.pretrain_eval()
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import numpy as np
from mindspore import dtype as mstype
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.model.classifier import Classifier
parser = argparse.ArgumentParser(description='meta-baseline')
parser.add_argument('--device_id', type=int, default=0, help='Device id.')
parser.add_argument("--batch_size", type=int, default=128, help="batch size")
parser.add_argument('--n_classes', type=int, default=64)
parser.add_argument('--ckpt_file', type=str, required=True, help='Checkpoint file path.')
parser.add_argument('--file_name', type=str, default='meta_baseline', help='Output file name.')
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='MINDIR',
                    help='file format')
parser.add_argument('--device_target', type=str, choices=['Ascend', 'CPU', 'GPU'], default='Ascend',
                    help='Device target')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)

if __name__ == '__main__':
    network = Classifier(n_classes=args.n_classes)
    assert args.ckpt_file is not None, "args.ckpt_file is None."
    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(network, param_dict)
    # Dummy input with the network's expected shape: batch x 3 x 84 x 84.
    img = Tensor(np.ones([args.batch_size, 3, 84, 84]), mstype.float32)
    export(network, img, file_name=args.file_name, file_format=args.file_format)
#!/usr/bin/env bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=========================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh [load_encoder] [ROOT_PATH] [DEVICE] [num_shots]"
echo "========================================================================================="
export LOAD_ENCODER=$1
ROOT_PATH=$2
DEVICE=$3
NUM_SHOT=$4
python ./eval.py \
--load_encoder $LOAD_ENCODER \
--root_path $ROOT_PATH \
--device_target $DEVICE \
--run_offline "True" \
--num_shots $NUM_SHOT
#!/usr/bin/env bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=========================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train_classifier.sh [device_target] [ROOT_PATH]"
echo "========================================================================================="
export DEVICE=$1
ROOT_PATH=$2
python ./train_classifier.py \
--run_offline "True" \
--device_target $DEVICE \
--root_path $ROOT_PATH > log.txt 2>&1 &
#!/usr/bin/env bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=========================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train_meta.sh [load_encoder/(checkpoint path)] [device_target] [ROOT_PATH] [num_shots/1 or 5]"
echo "========================================================================================="
export LOAD_ENCODER=$1
DEVICE=$2
ROOT_PATH=$3
NUM_SHOT=$4
python ./train_meta.py \
--load_encoder $LOAD_ENCODER \
--run_offline "True" \
--device_target $DEVICE \
--root_path $ROOT_PATH \
--num_shots $NUM_SHOT > log_meta.txt 2>&1 &
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CategoriesSampler
"""
import mindspore.dataset as ds
import numpy as np


class CategoriesSampler(ds.Sampler):
    """
    CategoriesSampler
    """

    def __init__(self, data, label, n_cls, n_per, iterations, ep_per_batch=1):
        super(CategoriesSampler, self).__init__()
        self.__iterations = iterations
        self.n_cls = n_cls  # way
        self.n_per = n_per  # shot = support_shot + query_shot
        self.ep_per_batch = ep_per_batch  # episodes per batch
        self.__iter = 0
        label = np.array(label)
        self.data = data
        self.label = label
        # catlocs[c] holds the indices of all samples of class c
        self.catlocs = []
        for c in range(max(label) + 1):
            self.catlocs.append(np.argwhere(label == c).reshape(-1))

    def __next__(self):
        if self.__iter >= self.__iterations:
            raise StopIteration
        batch = []
        for _ in range(self.ep_per_batch):
            episode = []
            classes = np.random.choice(len(self.catlocs), self.n_cls,
                                       replace=False)
            for c in classes:
                l = np.random.choice(self.catlocs[c], self.n_per,
                                     replace=False)
                episode.append(self.data[l])
            episode = np.stack(episode)
            batch.append(episode)
        batch = np.stack(batch)  # bs * n_cls * n_per
        self.__iter += 1
        return (batch,)

    def __iter__(self):
        self.__iter = 0
        return self

    def __len__(self):
        return self.__iterations
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MiniImageNet
"""
import os
import pickle
import numpy as np
import mindspore.dataset.vision.py_transforms as py_transforms
from mindspore.dataset.transforms.py_transforms import Compose
from PIL import Image


class MiniImageNet:
    """
    MiniImageNet
    """

    def __init__(self, root_path, split='train'):
        self.split = split
        split_tag = split
        if split == 'train':
            split_tag = 'train_phase_train'
        split_file = 'miniImageNet_category_split_{}.pickle'.format(split_tag)
        with open(os.path.join(root_path, split_file), 'rb') as f:
            pack = pickle.load(f, encoding='latin1')
        data = pack['data']
        label = pack['labels']
        # Shift labels so that they start at 0.
        min_label = min(label)
        print("min_label", min_label)
        label = [x - min_label for x in label]
        image_size = 84
        normalize = py_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if split == 'train':
            self.transforms = Compose([
                decode,
                py_transforms.RandomCrop(image_size, padding=4),
                py_transforms.ToTensor(),
                normalize
            ])
        else:
            self.transforms = Compose([
                decode,
                py_transforms.Resize(image_size),
                py_transforms.ToTensor(),
                normalize
            ])
        data = [self.transforms(x)[0] for x in data]
        self.len = len(data)
        self.data = np.array(data)
        self.label = np.array(label)
        self.n_classes = max(self.label) + 1

    def __len__(self):
        return self.len

    def __getitem__(self, i):
        return self.data[i], self.label[i]


def decode(img):
    """
    :param img: image as a numpy array
    :return: PIL image
    """
    return Image.fromarray(img)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ClassifierWithLossCell
"""
import mindspore as ms
from mindspore import ops, Parameter, Tensor
from mindspore import nn
from src.model.resnet12 import resnet12


class Classifier(nn.Cell):
    """
    Classifier
    """

    def __init__(self, n_classes):
        super(Classifier, self).__init__()
        self.encoder = resnet12()
        in_dim = self.encoder.emb_size
        self.classifier = nn.Dense(in_channels=in_dim, out_channels=n_classes, has_bias=False)

    def construct(self, x):
        """
        :param x: data
        :return: logits
        """
        x = self.encoder(x)
        x = self.classifier(x)
        return x


class ClassifierWithLossCell(nn.Cell):
    """
    ClassifierWithLossCell
    """

    def __init__(self, net):
        super(ClassifierWithLossCell, self).__init__()
        self.net = net
        self.loss = nn.SoftmaxCrossEntropyWithLogits(reduction='mean', sparse=True)
        self.acc = Parameter(Tensor(0.0, ms.float32), requires_grad=False)

    def construct(self, img, labels):
        """
        :param img: data
        :param labels: label
        :return: loss cost
        """
        logits = self.net(img)
        ret = ops.Argmax()(logits) == labels
        acc = ops.ReduceMean()(ret.astype(ms.float32))
        self.acc = acc
        return self.loss(logits, labels)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MetaBaseline
"""
import mindspore as ms
from mindspore import nn, Parameter, ops, Tensor
from src.model.resnet12 import resnet12


class MetaBaseline(nn.Cell):
    """
    MetaBaseline
    """

    def __init__(self, method='cos', temp=5.0, temp_learnable=True):
        super(MetaBaseline, self).__init__()
        self.encoder = resnet12()
        self.method = method
        if temp_learnable:
            self.temp = Parameter(Tensor([temp], ms.float32), requires_grad=True)
        else:
            self.temp = [temp]

    def construct(self, x_shot, x_query):
        """
        :param x_shot: support images
        :param x_query: query images
        :return: logit
        """
        shot_shape = x_shot.shape[:-3]
        query_shape = x_query.shape[:-3]
        img_shape = x_shot.shape[-3:]
        x_shot = x_shot.view(-1, *img_shape)
        x_query = x_query.view(-1, *img_shape)
        # Encode support and query images in a single forward pass.
        x_tot = self.encoder(ops.Concat(0)([x_shot, x_query]))
        x_shot, x_query = x_tot[:len(x_shot)], x_tot[-len(x_query):]
        x_shot = x_shot.view(*shot_shape, -1)
        x_query = x_query.view(*query_shape, -1)
        ########## cross-class bias ############
        bs = x_shot.shape[0]
        fs = x_shot.shape[-1]
        bias = x_shot.view(bs, -1, fs).mean(1) - x_query.mean(1)
        x_query = x_query + ops.ExpandDims()(bias, 1)
        # Class centroids, then cosine similarity via normalized dot product.
        x_shot = x_shot.mean(axis=-2)
        x_shot = ops.L2Normalize(axis=-1)(x_shot)
        x_query = ops.L2Normalize(axis=-1)(x_query)
        logit = ops.BatchMatMul()(x_query, x_shot.transpose(0, 2, 1))
        return logit * self.temp[0]


class MetaBaselineWithLossCell(nn.Cell):
    """
    MetaBaselineWithLossCell
    """

    def __init__(self, net):
        super(MetaBaselineWithLossCell, self).__init__()
        self.net = net
        self.loss = nn.SoftmaxCrossEntropyWithLogits(reduction='mean', sparse=True)
        self.acc = Parameter(Tensor(0.0, ms.float32), requires_grad=False)

    def construct(self, x_shot, x_query, labels):
        """
        :param x_shot: support images
        :param x_query: query images
        :param labels: query labels
        :return: loss
        """
        logits = self.net(x_shot, x_query)
        ret = ops.Argmax()(logits) == labels
        acc = ret.astype(ms.float32).mean()
        self.acc = acc
        n_way = logits.shape[-1]
        return self.loss(logits.view(-1, n_way), labels.astype(ms.int32).view(-1))
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MetaEval
"""
import mindspore as ms
from mindspore import nn, ops, Tensor


class MetaEval:
    """
    MetaEval
    """

    def __init__(self, method='cos', temp=5.):
        super(MetaEval, self).__init__()
        # self.encoder = resnet12()
        self.method = method
        self.temp = temp
        self.loss = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')

    def eval(self, x_shot, x_query, labels, encoder):
        """
        :param x_shot: support images
        :param x_query: query images
        :param labels: query labels
        :param encoder: feature encoder
        :return: acc loss
        """
        shot_shape = x_shot.shape[:-3]
        query_shape = x_query.shape[:-3]
        img_shape = x_shot.shape[-3:]
        x_shot = x_shot.view(-1, *img_shape)
        x_query = x_query.view(-1, *img_shape)
        x_tot = encoder(ops.Concat(0)([x_shot, x_query]))
        x_shot, x_query = x_tot[:len(x_shot)], x_tot[-len(x_query):]
        x_shot = x_shot.view(*shot_shape, -1)
        x_query = x_query.view(*query_shape, -1)
        ########## cross-class bias ############
        bs = x_shot.shape[0]
        fs = x_shot.shape[-1]
        bias = x_shot.view(bs, -1, fs).mean(1) - x_query.mean(1)
        x_query = x_query + ops.ExpandDims()(bias, 1)
        x_shot = x_shot.mean(axis=-2)
        x_shot = ops.L2Normalize(axis=-1)(x_shot)
        x_query = ops.L2Normalize(axis=-1)(x_query)
        logits = ops.BatchMatMul()(x_query, x_shot.transpose(0, 2, 1))
        logits = logits * self.temp
        ret = ops.Argmax()(logits) == labels.astype(ms.int32)
        acc = ret.astype(ms.float32).mean()
        n_way = logits.shape[-1]
        loss = self.loss(logits.view(-1, n_way),
                         ops.OneHot()(labels.astype(ms.int32).view(-1), logits.shape[-1],
                                      Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)))
        return acc, loss
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ResNet12
"""
from mindspore import nn
from mindspore.common.initializer import HeNormal, Constant
from mindspore.common import initializer as init
def weight_variable_conv():
"""
:return: HeNormal
"""
return HeNormal(mode='fan_out', nonlinearity='leaky_relu')
def weight_variable_bn(value):
"""
:param value: constant value
:return: Constant
"""
return Constant(value)
def conv3x3(in_planes, out_planes):
"""
:param in_planes: in_planes
:param out_planes: out_planes
:return: conv3x3
"""
return nn.Conv2d(in_planes, out_planes, 3, padding=1, pad_mode='pad', has_bias=False)
def conv1x1(in_planes, out_planes):
"""
:param in_planes: in_planes
:param out_planes: out_planes
:return: conv1x1
"""
return nn.Conv2d(in_planes, out_planes, 1, has_bias=False)
def norm_layer(planes):
"""
:param planes: planes
:return: BatchNorm2d
"""
return nn.BatchNorm2d(planes, momentum=0.1)
class Block(nn.Cell):
    """
    Block
    """
    def __init__(self, inplanes, planes, downsample):
        super(Block, self).__init__()
        self.relu = nn.LeakyReLU(0.1)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn1 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.conv3 = conv3x3(planes, planes)
        self.bn3 = norm_layer(planes)
        self.downsample = downsample
        self.meanpool = nn.AvgPool2d(2, 2)

    def construct(self, x):
        """
        :param x: feat
        :return: block feat
        """
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        out = self.meanpool(out)
        return out
class ResNet12(nn.Cell):
    """
    ResNet12
    """
    def __init__(self, channels):
        super(ResNet12, self).__init__()
        self.inplanes = 3
        self.layer1 = self._make_layer(channels[0])
        self.layer2 = self._make_layer(channels[1])
        self.layer3 = self._make_layer(channels[2])
        self.layer4 = self._make_layer(channels[3])
        self.emb_size = channels[3]
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.set_data(init.initializer(
                    HeNormal(negative_slope=0, mode='fan_out', nonlinearity='leaky_relu'),
                    cell.weight.shape, cell.weight.dtype))

    def _make_layer(self, planes):
        downsample = nn.SequentialCell(
            [conv1x1(self.inplanes, planes),
             norm_layer(planes)]
        )
        block = Block(self.inplanes, planes, downsample)
        self.inplanes = planes
        return block

    def construct(self, x):
        """
        :param x: data
        :return: feat
        """
        x = self.layer1(x)  # 40*40
        x = self.layer2(x)  # 20*20
        x = self.layer3(x)  # 10*10
        x = self.layer4(x)  # 5*5
        x = x.view(x.shape[0], x.shape[1], -1).mean(axis=2)
        return x


def resnet12():
    """
    :return: resnet12
    """
    return ResNet12([64, 128, 256, 512])
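The size comments in `ResNet12.construct` (40, 20, 10, 5) follow from each `Block` ending in a 2x2 average pool with stride 2, while the padded 3x3 convolutions keep the spatial size. The sketch below reproduces that arithmetic in plain Python. The helper name `pooled_sizes` is hypothetical, and it assumes the pooling uses no padding, so the output size is `(size - kernel) // stride + 1`.

```python
def pooled_sizes(size, num_blocks=4, kernel=2, stride=2):
    """Return the spatial size after each block's average pooling."""
    sizes = []
    for _ in range(num_blocks):
        # Unpadded pooling: windows that do not fit entirely are dropped.
        size = (size - kernel) // stride + 1
        sizes.append(size)
    return sizes


# An 80x80 input matches the comments in construct().
print(pooled_sizes(80))
# An 84x84 input also ends at 5x5, because odd sizes are floored.
print(pooled_sizes(84))
```

Either way the final feature map is 5x5, which `construct` then averages into one `emb_size`-dimensional vector per image.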
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
util
"""
import os
import shutil
import time
import numpy as np
from mindspore import Tensor
from mindspore import dtype as ms
from mindspore import ops
_log_path = None


def set_log_path(path):
    """
    :param path: log path stored in the module-level ``_log_path``
    :return: None
    """
    global _log_path
    _log_path = path


class Averager:
    """
    Running weighted average
    """
    def __init__(self):
        self.n = 0.0
        self.v = 0.0

    def add(self, v, n=1.0):
        """
        :param v: new value
        :param n: weight of the new value
        :return: None
        """
        self.v = (self.v * self.n + v * n) / (self.n + n)
        self.n += n

    def item(self):
        """
        :return: current average
        """
        return self.v
class Timer:
    """
    Timer
    """
    def __init__(self):
        self.v = time.time()

    def s(self):
        """
        Restart the timer.
        :return: None
        """
        self.v = time.time()

    def t(self):
        """
        :return: seconds elapsed since the timer was created or restarted
        """
        return time.time() - self.v


def set_gpu(gpu):
    """
    :param gpu: comma-separated GPU ids as a string, e.g. '0,1'
    :return: None
    """
    print('set gpu:', gpu)
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
def ensure_path(path):
    """
    Create an empty directory at path, deleting any existing one first.
    :param path: directory to (re)create
    :return: None
    """
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)
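def time_str(t):
    """
    Format a duration as hours, minutes or seconds.
    :param t: duration in seconds
    :return: formatted string such as '1.5h', '2.0m' or '30.0s'
    """
    if t >= 3600:
        return '{:.1f}h'.format(t / 3600)
    if t >= 60:
        return '{:.1f}m'.format(t / 60)
    return '{:.1f}s'.format(t)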
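def make_nk_label(n, k, batch_size):
    """
    Build episode labels: class ids 0..n-1, each repeated k times.
    :param n: number of classes (ways)
    :param k: number of samples per class
    :param batch_size: number of episodes per batch
    :return: float32 label tensor of shape (batch_size, n * k)
    """
    label = ops.BroadcastTo((k, n))(Tensor(np.arange(n), ms.float32))
    label = ops.Transpose()(label, (1, 0))
    label = ops.Reshape()(label, (-1,))
    label = ops.Tile()(label, (batch_size, 1))
    return label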
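The label layout produced by `make_nk_label` can be checked without MindSpore. The following is a plain-Python sketch of the same layout, not the repository's implementation. The name `make_nk_label_list` is hypothetical, and it returns nested lists of integers instead of a float32 tensor.

```python
def make_nk_label_list(n, k, batch_size):
    """Return batch_size rows, each holding class ids 0..n-1 repeated k times."""
    # One episode: [0]*k followed by [1]*k, and so on up to class n-1.
    episode = [cls for cls in range(n) for _ in range(k)]
    # Every episode in the batch uses the same label row.
    return [list(episode) for _ in range(batch_size)]


labels = make_nk_label_list(n=3, k=2, batch_size=2)
for row in labels:
    print(row)
```

With the defaults used by the training scripts (5 ways, 15 queries), each row has 75 entries.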
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
train_classifier
"""
import os
import argparse
import numpy as np
import mindspore.dataset as ds
from mindspore import nn
from mindspore import save_checkpoint
from mindspore.nn import piecewise_constant_lr
from mindspore import context
from mindspore import ParameterTuple
from mindspore import ops
import src.util as util
from src.data.IterSamplers import CategoriesSampler
from src.data.mini_Imagenet import MiniImageNet
from src.model.classifier import Classifier, ClassifierWithLossCell
from src.model.meta_eval import MetaEval
from tqdm import tqdm
class TrainOneStepCell(nn.Cell):
    """
    Wrap a loss network and an optimizer into a single training step.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=True)
        self.network = network
        self.optimizer = optimizer
        self.weights = ParameterTuple(network.trainable_params())
        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens

    def set_sens(self, value):
        """
        :param value: scaling value (sensitivity) fed to the backward pass
        :return: None
        """
        self.sens = value

    def construct(self, data, label):
        """
        :param data: input images
        :param label: class labels
        :return: loss of this step
        """
        weights = self.weights
        loss = self.network(data, label)
        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(data, label, sens)
        return ops.depend(loss, self.optimizer(grads))
def main():
    """
    Train the classifier baseline and run few-shot evaluation on the test split.
    :return: None
    """
    util.ensure_path(save_path)
    # train set
    root_path = os.path.join(args.root_path, args.dataset)
    n_way, n_query, n_shots = 5, 15, [1, 5]
    if args.dataset == 'mini-imagenet':
        trainset = MiniImageNet(root_path, 'train')
        testset = MiniImageNet(root_path, 'test')
    else:
        # Fail early: the code below needs trainset and testset.
        raise ValueError('unsupported dataset: {}'.format(args.dataset))
    trainloader = ds.GeneratorDataset(trainset, ['data', 'label'], shuffle=True).batch(
        args.batch_size)
    # test set
    fs_loaders = []
    for n_shot in n_shots:
        fs_loader = CategoriesSampler(testset.data, testset.label, n_way, n_shot + n_query,
                                      args.ep_batch, args.ep_per_batch)
        fs_loaders.append(ds.GeneratorDataset(fs_loader, ['data'], shuffle=True))
    # define network
    net = Classifier(args.n_classes)
    net_with_loss = ClassifierWithLossCell(net)
    # define opt
    train_batch = trainset.len // args.batch_size
    multiStepLR = piecewise_constant_lr([(args.max_epoch - 10) * train_batch,
                                         args.max_epoch * train_batch], [args.lr, args.lr * 0.1])
    net_opt = nn.SGD(params=net.trainable_params(), learning_rate=multiStepLR,
                     weight_decay=args.weight_decay, momentum=0.9, nesterov=True)
    train_cell = TrainOneStepCell(net_with_loss, net_opt)
    eval_loss_fn = MetaEval()
    timer_used = util.Timer()
    timer_epoch = util.Timer()
    max_va = 0.
    for epoch in range(1, args.max_epoch + 1):
        timer_epoch.s()
        aves_keys = ['tl', 'ta']
        for n_shot in n_shots:
            aves_keys += ['fsa-' + str(n_shot)]
        aves = {k: util.Averager() for k in aves_keys}
        # pre train
        net.set_train(mode=True)
        for data in tqdm(trainloader.create_dict_iterator(), desc='train', leave=False):
            loss = train_cell(data['data'], data['label'])
            acc = net_with_loss.acc
            aves['tl'].add(loss.asnumpy())
            aves['ta'].add(acc.asnumpy())
        # few-shot eval
        if epoch == args.max_epoch or epoch % args.save_epoch == 0:
            net.set_train(mode=False)
            for i, n_shot in enumerate(n_shots):
                np.random.seed(0)
                for data in tqdm(fs_loaders[i].create_dict_iterator(), desc='test-' + str(n_shot),
                                 leave=False):
                    x_shot, x_query = data['data'][:, :, :n_shot], data['data'][:, :, n_shot:]
                    img_shape = x_query.shape[-3:]
                    x_query = x_query.view(args.ep_per_batch, -1, *img_shape)
                    label = util.make_nk_label(n_way, n_query, args.ep_per_batch)
                    acc_val, _ = eval_loss_fn.eval(x_shot, x_query, label, net.encoder)
                    aves['fsa-' + str(n_shot)].add(acc_val.asnumpy())
        # post
        for k, v in aves.items():
            aves[k] = v.item()
        t_epoch = util.time_str(timer_epoch.t())
        t_used = util.time_str(timer_used.t())
        t_estimate = util.time_str(timer_used.t() / epoch * args.max_epoch)
        print("epoch {},train loss {:.4f}, train acc {:.4f}".format(str(epoch), aves['tl'],
                                                                    aves['ta']))
        if epoch == args.max_epoch or epoch % args.save_epoch == 0:
            for n_shot in n_shots:
                key = 'fsa-' + str(n_shot)
                print("epoch {}, {}-shot, val acc {:.4f}".format(str(epoch), n_shot, aves[key]))
        if epoch <= args.max_epoch:
            print("{} {}/{}".format(t_epoch, t_used, t_estimate))
        else:
            print("{}".format(t_epoch))
        path = os.path.join(save_path, 'epoch-{}.ckpt'.format(epoch))
        if epoch >= 15 and epoch % args.save_epoch == 0:
            save_checkpoint(net, path)
        # keep the checkpoint with the best 5-shot accuracy so far
        if aves['fsa-' + str(5)] > max_va:
            max_va = aves['fsa-' + str(5)]
            save_checkpoint(net, os.path.join(save_path, 'max-va.ckpt'))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default=None)
    parser.add_argument('--root_path', default='./dataset/')
    parser.add_argument('--tag', default=None)
    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU', 'CPU'])
    parser.add_argument('--dataset', default='mini-imagenet')
    parser.add_argument('--encoder', default='resnet12')
    parser.add_argument('--load_encoder',
                        default='./save/classifier3_mini-imagenet_resnet12/epoch-70.ckpt')
    parser.add_argument('--ep_per_batch', type=int, default=4)
    parser.add_argument('--ep_batch', type=int, default=200)
    parser.add_argument('--max_epoch', type=int, default=25)
    parser.add_argument('--visualize_datasets', default=True)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--weight_decay', type=float, default=5.e-4)
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--num_gpu', type=int, default=1)
    parser.add_argument('--classifier', default='linear-classifier')
    parser.add_argument('--n_classes', type=int, default=64)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--save_epoch', type=int, default=1)
    parser.add_argument('--eval_fs_epoch', type=int, default=3)
    parser.add_argument('--data_url', default=None, help='Location of data.')
    parser.add_argument('--train_url', default=None, help='Location of training outputs.')
    parser.add_argument('--run_offline', type=str, default="False", help='run in offline')
    args = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
    if args.device_target == 'GPU' or args.device_target == 'Ascend':
        context.set_context(device_id=args.device_id)
        if args.run_offline == "True":
            print("run_online--")
            import moxing as mox
            mox.file.copy_parallel(src_url=args.data_url, dst_url=args.root_path)
    else:
        raise ValueError("Unsupported platform.")
    svname = 'classifier_{}'.format(args.dataset)
    svname += '_' + args.encoder
    save_path = os.path.join('./save/', svname)
    main()
    if args.run_offline == "True":
        md_save_path = os.path.join(args.train_url, save_path)
        mox.file.make_dirs(md_save_path)
        mox.file.copy_parallel(src_url=save_path, dst_url=md_save_path)
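The learning-rate schedule in `main()` keeps `args.lr` until the last ten epochs and then switches to a tenth of it. The sketch below shows that step-to-rate mapping in plain Python. It assumes `piecewise_constant_lr` returns one rate per step up to the last milestone, which is how the MindSpore documentation describes it. The helper `piecewise_constant` and the toy values for `steps_per_epoch` and `max_epoch` are hypothetical.

```python
def piecewise_constant(milestones, learning_rates):
    """Return one learning rate per step, up to the last milestone."""
    if len(milestones) != len(learning_rates):
        raise ValueError("milestones and learning_rates must have the same length")
    schedule = []
    start = 0
    for end, rate in zip(milestones, learning_rates):
        # Steps in [start, end) all use the same rate.
        schedule.extend([rate] * (end - start))
        start = end
    return schedule


steps_per_epoch = 3  # toy value; the script uses trainset.len // args.batch_size
max_epoch = 12       # toy value; the script default is 25
base_lr = 0.1
schedule = piecewise_constant(
    [(max_epoch - 10) * steps_per_epoch, max_epoch * steps_per_epoch],
    [base_lr, base_lr * 0.1])

print(len(schedule))               # total number of scheduled steps
print(schedule[:steps_per_epoch])  # rates used during the first epoch
```

With these toy values the first two epochs (six steps) run at the base rate and the remaining ten epochs at the reduced rate.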
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
train_meta
"""
import argparse
import os
import numpy as np
from mindspore import nn, ops, ParameterTuple, load_checkpoint, load_param_into_net, save_checkpoint
from mindspore import context
import mindspore.dataset as ds
from tqdm import tqdm
import src.util as util
from src.data.IterSamplers import CategoriesSampler
from src.data.mini_Imagenet import MiniImageNet
from src.model.classifier import Classifier
from src.model.meta_baseline import MetaBaseline, MetaBaselineWithLossCell
class TrainOneStepCell(nn.Cell):
    """
    Wrap a loss network and an optimizer into a single training step.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=True)
        self.network = network
        self.optimizer = optimizer
        self.weights = ParameterTuple(network.trainable_params())
        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens

    def set_sens(self, value):
        """
        :param value: scaling value (sensitivity) fed to the backward pass
        :return: None
        """
        self.sens = value

    def construct(self, x_shot, x_query, label):
        """
        :param x_shot: support images of each episode
        :param x_query: query images of each episode
        :param label: query labels
        :return: loss of this step
        """
        weights = self.weights
        loss = self.network(x_shot, x_query, label)
        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(x_shot, x_query, label, sens)
        return ops.depend(loss, self.optimizer(grads))
def main():
    """
    Meta-train the pretrained encoder on few-shot episodes and validate every epoch.
    :return: None
    """
    svname = args.name
    if svname is None:  # e.g. train-meta_mini-imagenet_ways_5_shots_1_resnet12
        svname = 'train-meta_{}'.format(args.dataset)
        svname += '_ways_' + str(args.num_ways)
        svname += '_shots_' + str(args.num_shots)
        svname += '_' + args.encoder
    # loader dataset
    save_path = os.path.join('./save/', svname)
    util.ensure_path(save_path)
    root_path = os.path.join(args.root_path, args.dataset)
    if args.dataset == 'mini-imagenet':
        trainset = MiniImageNet(root_path, 'train')
        testset = MiniImageNet(root_path, 'test')
    else:
        # Fail early: the code below needs trainset and testset.
        raise ValueError('unsupported dataset: {}'.format(args.dataset))
    n_query = 15
    train_sampler = CategoriesSampler(trainset.data, trainset.label, args.num_ways,
                                      args.num_shots + n_query, args.ep_batch, args.ep_per_batch)
    train_loader = ds.GeneratorDataset(train_sampler, ['data'], shuffle=True)
    test_sampler = CategoriesSampler(testset.data, testset.label, args.num_ways,
                                     args.num_shots + n_query, args.ep_batch, args.ep_per_batch)
    test_loader = ds.GeneratorDataset(test_sampler, ['data'], shuffle=True)
    classifier = Classifier(args.n_classes)
    param_dict = load_checkpoint(args.load_encoder)
    load_param_into_net(classifier, param_dict)
    net = MetaBaseline()
    load_param_into_net(net.encoder, classifier.encoder.parameters_dict())
    net_with_loss = MetaBaselineWithLossCell(net)
    net_opt = nn.SGD(params=net.trainable_params(), learning_rate=args.lr,
                     weight_decay=args.weight_decay, momentum=0.9)
    train_cell = TrainOneStepCell(net_with_loss, net_opt)
    # eval model
    max_va = 0.
    timer_used = util.Timer()
    timer_epoch = util.Timer()
    trlog = dict()
    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    for k in aves_keys:
        trlog[k] = []
    # run epochs 1..max_epoch inclusive, as train_classifier does
    for epoch in range(1, args.max_epoch + 1):
        timer_epoch.s()
        aves = {k: util.Averager() for k in aves_keys}
        net.set_train(True)
        np.random.seed(epoch)
        # train
        for data in tqdm(train_loader.create_dict_iterator(), desc='train', leave=False):
            x_shot, x_query = data['data'][:, :, :args.num_shots], data['data'][:, :,
                                                                                args.num_shots:]
            img_shape = x_query.shape[-3:]
            x_query = x_query.view(args.ep_per_batch, -1, *img_shape)  # bs*(way*n_query)*3*84*84
            label = util.make_nk_label(args.num_ways, n_query,
                                       args.ep_per_batch)  # bs*(way*n_query)
            loss = train_cell(x_shot, x_query, label)
            aves['tl'].add(loss.asnumpy())
            aves['ta'].add(net_with_loss.acc.asnumpy())
        # test
        net.set_train(False)
        # train_cell.set_train(False)
        for data in tqdm(test_loader.create_dict_iterator(), desc='test', leave=False):
            x_shot, x_query = data['data'][:, :, :args.num_shots], data['data'][:, :,
                                                                                args.num_shots:]
            img_shape = x_query.shape[-3:]
            x_query = x_query.view(args.ep_per_batch, -1, *img_shape)  # bs*(way*n_query)*3*84*84
            label = util.make_nk_label(args.num_ways, n_query,
                                       args.ep_per_batch)  # bs*(way*n_query)
            loss_val = net_with_loss(x_shot, x_query, label)
            aves['vl'].add(loss_val.asnumpy())
            aves['va'].add(net_with_loss.acc.asnumpy())
        _sig = int(-1)
        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append([aves[k]])
        t_epoch = util.time_str(timer_epoch.t())
        t_used = util.time_str(timer_used.t())
        t_estimate = util.time_str(timer_used.t() / epoch * args.max_epoch)
        print('epoch {}, train {:.4f}|{:.4f}, '
              'val {:.4f}|{:.4f}, {} {}/{} (@{})'.format(
                  epoch, aves['tl'], aves['ta'], aves['vl'], aves['va'],
                  t_epoch, t_used, t_estimate, _sig))
        if epoch % args.save_epoch == 0:
            path = os.path.join(save_path, 'epoch-{}.ckpt'.format(epoch))
            save_checkpoint(net, path)
        if max_va < aves['va']:
            # record the new best so max-va.ckpt is only overwritten on improvement
            max_va = aves['va']
            path = os.path.join(save_path, 'max-va.ckpt')
            save_checkpoint(net, path)
    if args.run_offline == "True":
        md_save_path = os.path.join(args.train_url, save_path)
        mox.file.make_dirs(md_save_path)
        mox.file.copy_parallel(src_url=save_path, dst_url=md_save_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default=None)
    parser.add_argument('--root_path', default='./dataset/')
    parser.add_argument('--tag', default=None)
    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU', 'CPU'])
    parser.add_argument('--dataset', default='mini-imagenet')
    parser.add_argument('--encoder', default='resnet12')
    parser.add_argument('--load_encoder',
                        default='./save/classifier5_mini-imagenet_resnet12/epoch-17.ckpt')
    parser.add_argument('--ep_per_batch', type=int, default=4)
    parser.add_argument('--ep_batch', type=int, default=200)
    parser.add_argument('--max_epoch', type=int, default=20)
    parser.add_argument('--visualize_datasets', default=True)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--weight_decay', type=float, default=5.e-4)
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--num_gpu', type=int, default=1)
    parser.add_argument('--classifier', default='linear-classifier')
    parser.add_argument('--n_classes', type=int, default=64)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--save_epoch', type=int, default=2)
    parser.add_argument('--eval_fs_epoch', type=int, default=4)
    parser.add_argument('--optimizer', default='sgd')
    parser.add_argument('--milestones', default=[90])
    parser.add_argument('--num_ways', type=int, default=5)
    parser.add_argument('--num_shots', type=int, default=1)
    parser.add_argument('--data_url', default=None, help='Location of data.')
    parser.add_argument('--train_url', default=None, help='Location of training outputs.')
    # string default, because the value is compared against the string "True"
    parser.add_argument('--run_offline', type=str, default="False", help='run in offline')
    args = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
    if args.device_target == 'GPU' or args.device_target == 'Ascend':
        context.set_context(device_id=args.device_id)
        if args.run_offline == "True":
            print("run_online--")
            import moxing as mox
            mox.file.copy_parallel(src_url=args.data_url, dst_url=args.root_path)
    else:
        raise ValueError("Unsupported platform.")
    main()
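Both training scripts slice each episode into a support ("shot") part and a query part, then flatten the queries so that their order matches the labels from `util.make_nk_label`. The sketch below illustrates that alignment for one episode in plain Python. It is not the repository's tensor code: `split_episode` is a hypothetical helper, and each sample is a `(class_id, index)` tuple instead of an image.

```python
def split_episode(episode, n_shot):
    """Split one episode, given as a list of per-class sample lists."""
    # The first n_shot samples of every class form the support set.
    shot = [samples[:n_shot] for samples in episode]
    # The remaining samples are flattened class by class into the query set.
    query = [s for samples in episode for s in samples[n_shot:]]
    return shot, query


n_way, n_shot, n_query = 3, 1, 2
episode = [[(cls, i) for i in range(n_shot + n_query)] for cls in range(n_way)]
shot, query = split_episode(episode, n_shot)

# Class ids of the flattened queries, in the order the labels must follow.
query_classes = [cls for cls, _ in query]
print(query_classes)
```

Because the queries are flattened class by class, their class ids come out as each class repeated `n_query` times, which is the layout `make_nk_label(n_way, n_query, ...)` generates.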