diff --git a/official/cv/vgg16/mindspore_hub_conf.py b/official/cv/vgg16/mindspore_hub_conf.py index 588fede1d065dbc70127286e2a3cb214859f8f7b..c9037fbf277e44ba1825bd35f48c1bd5bf457b31 100644 --- a/official/cv/vgg16/mindspore_hub_conf.py +++ b/official/cv/vgg16/mindspore_hub_conf.py @@ -14,7 +14,7 @@ # ============================================================================ """hub config.""" from src.vgg import vgg16 as VGG16 -from model_utils.moxing_adapter import config +from model_utils.config import get_config_static def vgg16(*args, **kwargs): return VGG16(*args, **kwargs) @@ -22,5 +22,12 @@ def vgg16(*args, **kwargs): def create_network(name, *args, **kwargs): if name == "vgg16": - return vgg16(args=config, *args, **kwargs) + num_classes = kwargs.get("num_classes", 10) + if "num_classes" in kwargs: + del kwargs["num_classes"] + if num_classes == 10: + config = get_config_static(config_path="../cifar10_config.yaml") + elif num_classes == 1000: + config = get_config_static(config_path="../imagenet2012_config.yaml") + return vgg16(num_classes=num_classes, args=config, *args, **kwargs) raise NotImplementedError(f"{name} is not implemented in the repo")