From db768c43701c45c151c2771af71578a8a8b9215a Mon Sep 17 00:00:00 2001
From: lvyufeng <lvyufeng@cqu.edu.cn>
Date: Sat, 25 Dec 2021 10:58:19 +0800
Subject: [PATCH] fix crnn 310 and warpctc 910 errors

---
 official/cv/crnn/ascend310_infer/src/main.cc | 4 ++--
 official/cv/warpctc/export.py                | 2 +-
 official/cv/warpctc/src/warpctc.py           | 4 ++--
 official/cv/warpctc/train.py                 | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/official/cv/crnn/ascend310_infer/src/main.cc b/official/cv/crnn/ascend310_infer/src/main.cc
index 69d6d6349..14cf72995 100644
--- a/official/cv/crnn/ascend310_infer/src/main.cc
+++ b/official/cv/crnn/ascend310_infer/src/main.cc
@@ -71,8 +71,8 @@ int main(int argc, char **argv) {
   auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
   ascend310_info->SetDeviceID(FLAGS_device_id);
   ascend310_info->SetInsertOpConfigPath({FLAGS_aipp_path});
-  ascend310->SetPrecisionMode("allow_fp32_to_fp16");
-  ascend310->SetOpSelectImplMode("high_precision");
+  ascend310_info->SetPrecisionMode("allow_fp32_to_fp16");
+  ascend310_info->SetOpSelectImplMode("high_precision");
   context->MutableDeviceInfo().push_back(ascend310_info);
 
   Graph graph;
diff --git a/official/cv/warpctc/export.py b/official/cv/warpctc/export.py
index 62cac25d9..94e40c1e0 100644
--- a/official/cv/warpctc/export.py
+++ b/official/cv/warpctc/export.py
@@ -45,7 +45,7 @@ def run_export():
     batch_size = config.batch_size
     hidden_size = config.hidden_size
     image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float32))
-    net = StackedRNN(input_size=input_size, hidden_size=hidden_size)
+    net = StackedRNN(input_size=input_size, hidden_size=hidden_size, batch_size=batch_size)
 
     param_dict = load_checkpoint(config.ckpt_file)
     load_param_into_net(net, param_dict)
diff --git a/official/cv/warpctc/src/warpctc.py b/official/cv/warpctc/src/warpctc.py
index 4068ec3c9..8bc2efe63 100644
--- a/official/cv/warpctc/src/warpctc.py
+++ b/official/cv/warpctc/src/warpctc.py
@@ -34,7 +34,7 @@ class StackedRNN(nn.Cell):
         num_layer(int): the number of layer of LSTM.
      """
 
-    def __init__(self, input_size, hidden_size=512, num_layer=2):
+    def __init__(self, input_size, hidden_size=512, num_layer=2, batch_size=64):
         super(StackedRNN, self).__init__()
         self.batch_size = batch_size
         self.input_size = input_size
@@ -42,7 +42,7 @@ class StackedRNN(nn.Cell):
         self.reshape = P.Reshape()
         self.cast = P.Cast()
 
-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2)
+        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layer)
 
         self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32)
         self.fc_bias = np.random.random(self.num_classes).astype(np.float32)
diff --git a/official/cv/warpctc/train.py b/official/cv/warpctc/train.py
index acb33cb67..314532447 100644
--- a/official/cv/warpctc/train.py
+++ b/official/cv/warpctc/train.py
@@ -136,7 +136,7 @@ def train():
     loss = CTCLoss(max_sequence_length=config.captcha_width,
                    max_label_length=max_captcha_digits,
                    batch_size=config.batch_size)
-    net = StackedRNN(input_size=input_size, hidden_size=config.hidden_size)
+    net = StackedRNN(input_size=input_size, hidden_size=config.hidden_size, batch_size=config.batch_size)
     opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum)
 
     net = WithLossCell(net, loss)
-- 
GitLab