Skip to content
Snippets Groups Projects
Unverified Commit ede0ba8f authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2040 [FasterRCNN]fix import error

Merge pull request !2040 from zhouneng/fasterrcnn_issue_fix
parents edec3082 fd022b93
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,7 @@ from collections import defaultdict
import numpy as np
from pycocotools.coco import COCO
import mindspore as ms
from mindspore.common import set_seed, Parameter
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset, parse_json_annos_from_txt
......
......@@ -15,10 +15,10 @@
"""FasterRcnn training network wrapper."""
import time
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import ParameterTuple
from mindspore import ParameterTuple, Tensor
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
......@@ -134,7 +134,7 @@ class TrainOneStepCell(nn.Cell):
self.optimizer = optimizer
self.grad = ops.GradOperation(get_by_list=True,
sens_param=True)
self.sens = ms.numpy.ones((1,) * sens).astype(np.float32)
self.sens = Tensor([sens,], mstype.float32)
self.reduce_flag = reduce_flag
if reduce_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment