Skip to content
Snippets Groups Projects
Commit 17888472 authored by anzhengqi's avatar anzhengqi
Browse files

modify efficientnet scirpts

parent cfde9bcb
No related branches found
No related tags found
No related merge requests found
......@@ -41,3 +41,10 @@ save_checkpoint: True
keep_checkpoint_max: 10
loss_scale: 1024
resume_start_epoch: 0
width: 224
height: 224
ckpt_file: ""
file_name: "efficientnet"
file_format: "MINDIR" # ["AIR", "ONNX", "MINDIR"]
device_target: "GPU" # ["GPU", "CPU"]
\ No newline at end of file
......@@ -40,3 +40,10 @@ save_checkpoint: True
keep_checkpoint_max: 10
loss_scale: 1024
resume_start_epoch: 0
width: 224
height: 224
ckpt_file: ""
file_name: "efficientnet"
file_format: "MINDIR" # ["AIR", "ONNX", "MINDIR"]
device_target: "GPU" # ["GPU", "CPU"]
\ No newline at end of file
......@@ -38,4 +38,11 @@ bn_tf: False
save_checkpoint: True
keep_checkpoint_max: 10
loss_scale: 1024
resume_start_epoch: 0
\ No newline at end of file
resume_start_epoch: 0
width: 224
height: 224
ckpt_file: ""
file_name: "efficientnet"
file_format: "MINDIR" # ["AIR", "ONNX", "MINDIR"]
device_target: "GPU" # ["GPU", "CPU"]
\ No newline at end of file
......@@ -13,46 +13,29 @@
# limitations under the License.
# ============================================================================
"""export file"""
import argparse
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.efficientnet import efficientnet_b0
from src.config import dataset_config
from src.config import config
parser = argparse.ArgumentParser(description="efficientnet export")
parser.add_argument("--width", type=int, default=224, help="input width")
parser.add_argument("--height", type=int, default=224, help="input height")
parser.add_argument('--dataset', type=str, default='ImageNet', choices=['ImageNet', 'CIFAR10'],
help='ImageNet or CIFAR10')
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="efficientnet", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"],
default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["GPU", "CPU"], default="GPU",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if __name__ == "__main__":
if args.device_target not in ("GPU", "CPU"):
if config.device_target not in ("GPU", "CPU"):
raise ValueError("Only supported CPU and GPU now.")
dataset_type = args.dataset.lower()
cfg = dataset_config[dataset_type].cfg
net = efficientnet_b0(num_classes=cfg.num_classes,
drop_rate=cfg.drop,
drop_connect_rate=cfg.drop_connect,
global_pool=cfg.gp,
bn_tf=cfg.bn_tf,
net = efficientnet_b0(num_classes=config.num_classes,
drop_rate=config.drop,
drop_connect_rate=config.drop_connect,
global_pool=config.gp,
bn_tf=config.bn_tf,
)
ckpt = load_checkpoint(args.ckpt_file)
ckpt = load_checkpoint(config.ckpt_file)
load_param_into_net(net, ckpt)
net.set_train(False)
image = Tensor(np.ones([cfg.batch_size, 3, args.height, args.width], np.float32))
export(net, image, file_name=args.file_name, file_format=args.file_format)
image = Tensor(np.ones([config.batch_size, 3, config.height, config.width], np.float32))
export(net, image, file_name=config.file_name, file_format=config.file_format)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment