Skip to content
Snippets Groups Projects
Commit a2db0506 authored by dinglinhe's avatar dinglinhe
Browse files

Update mindspore_hub_conf in vgg16 at branch master

parent 73400843
No related branches found
No related tags found
......@@ -14,7 +14,7 @@
# ============================================================================
"""hub config."""
from src.vgg import vgg16 as VGG16
from model_utils.moxing_adapter import config
from model_utils.config import get_config_static
def vgg16(*args, **kwargs):
return VGG16(*args, **kwargs)
......@@ -22,5 +22,12 @@ def vgg16(*args, **kwargs):
def create_network(name, *args, **kwargs):
if name == "vgg16":
return vgg16(args=config, *args, **kwargs)
num_classes = kwargs.get("num_classes", 10)
if "num_classes" in kwargs:
del kwargs["num_classes"]
if num_classes == 10:
config = get_config_static(config_path="../cifar10_config.yaml")
elif num_classes == 1000:
config = get_config_static(config_path="../imagenet2012_config.yaml")
return vgg16(num_classes=num_classes, args=config, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment