From db768c43701c45c151c2771af71578a8a8b9215a Mon Sep 17 00:00:00 2001
From: lvyufeng <lvyufeng@cqu.edu.cn>
Date: Sat, 25 Dec 2021 10:58:19 +0800
Subject: [PATCH] fix crnn 310 and warpctc 910 errors

---
 official/cv/crnn/ascend310_infer/src/main.cc | 4 ++--
 official/cv/warpctc/export.py                | 2 +-
 official/cv/warpctc/src/warpctc.py           | 4 ++--
 official/cv/warpctc/train.py                 | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/official/cv/crnn/ascend310_infer/src/main.cc b/official/cv/crnn/ascend310_infer/src/main.cc
index 69d6d6349..14cf72995 100644
--- a/official/cv/crnn/ascend310_infer/src/main.cc
+++ b/official/cv/crnn/ascend310_infer/src/main.cc
@@ -71,8 +71,8 @@ int main(int argc, char **argv) {
   auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
   ascend310_info->SetDeviceID(FLAGS_device_id);
   ascend310_info->SetInsertOpConfigPath({FLAGS_aipp_path});
-  ascend310->SetPrecisionMode("allow_fp32_to_fp16");
-  ascend310->SetOpSelectImplMode("high_precision");
+  ascend310_info->SetPrecisionMode("allow_fp32_to_fp16");
+  ascend310_info->SetOpSelectImplMode("high_precision");
   context->MutableDeviceInfo().push_back(ascend310_info);
 
   Graph graph;
diff --git a/official/cv/warpctc/export.py b/official/cv/warpctc/export.py
index 62cac25d9..94e40c1e0 100644
--- a/official/cv/warpctc/export.py
+++ b/official/cv/warpctc/export.py
@@ -45,7 +45,7 @@ def run_export():
     batch_size = config.batch_size
     hidden_size = config.hidden_size
     image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float32))
-    net = StackedRNN(input_size=input_size, hidden_size=hidden_size)
+    net = StackedRNN(input_size=input_size, hidden_size=hidden_size, batch_size=batch_size)
     param_dict = load_checkpoint(config.ckpt_file)
     load_param_into_net(net, param_dict)
 
diff --git a/official/cv/warpctc/src/warpctc.py b/official/cv/warpctc/src/warpctc.py
index 4068ec3c9..8bc2efe63 100644
--- a/official/cv/warpctc/src/warpctc.py
+++ b/official/cv/warpctc/src/warpctc.py
@@ -34,7 +34,7 @@ class StackedRNN(nn.Cell):
         num_layer(int): the number of layer of LSTM.
     """
 
-    def __init__(self, input_size, hidden_size=512, num_layer=2):
+    def __init__(self, input_size, hidden_size=512, num_layer=2, batch_size=64):
         super(StackedRNN, self).__init__()
         self.batch_size = batch_size
         self.input_size = input_size
@@ -42,7 +42,7 @@
         self.reshape = P.Reshape()
         self.cast = P.Cast()
 
-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2)
+        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layer)
         self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32)
         self.fc_bias = np.random.random(self.num_classes).astype(np.float32)
 
diff --git a/official/cv/warpctc/train.py b/official/cv/warpctc/train.py
index acb33cb67..314532447 100644
--- a/official/cv/warpctc/train.py
+++ b/official/cv/warpctc/train.py
@@ -136,7 +136,7 @@ def train():
     loss = CTCLoss(max_sequence_length=config.captcha_width,
                    max_label_length=max_captcha_digits,
                    batch_size=config.batch_size)
-    net = StackedRNN(input_size=input_size, hidden_size=config.hidden_size)
+    net = StackedRNN(input_size=input_size, hidden_size=config.hidden_size, batch_size=config.batch_size)
     opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum)
     net = WithLossCell(net, loss)
 
--
GitLab
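
Illustration (not part of the patch): after this change, StackedRNN takes an explicit batch_size argument, which it caches as self.batch_size, so callers must pass a value that matches the data pipeline's batch size. A minimal sketch, assuming the repo's usual `from src.warpctc import StackedRNN` import and illustrative placeholder values in place of the real config:

    # Hedged sketch: input_size/hidden_size/batch_size below are placeholders;
    # in the repo they come from the warpctc config (see export.py and train.py).
    from src.warpctc import StackedRNN

    # batch_size must equal the dataset batch size, since the cell stores it
    # as self.batch_size rather than inferring it from the input tensor.
    net = StackedRNN(input_size=96, hidden_size=512, num_layer=2, batch_size=64)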