Skip to content
Snippets Groups Projects
Commit 517e3346 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!1596 fix crnn 310 and warpctc 910 errors

Merge pull request !1596 from 吕昱峰(Nate.River)/master
parents 046f0d5f db768c43
No related branches found
No related tags found
No related merge requests found
......@@ -71,8 +71,8 @@ int main(int argc, char **argv) {
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310_info->SetDeviceID(FLAGS_device_id);
ascend310_info->SetInsertOpConfigPath({FLAGS_aipp_path});
ascend310->SetPrecisionMode("allow_fp32_to_fp16");
ascend310->SetOpSelectImplMode("high_precision");
ascend310_info->SetPrecisionMode("allow_fp32_to_fp16");
ascend310_info->SetOpSelectImplMode("high_precision");
context->MutableDeviceInfo().push_back(ascend310_info);
Graph graph;
......
......@@ -45,7 +45,7 @@ def run_export():
batch_size = config.batch_size
hidden_size = config.hidden_size
image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float32))
net = StackedRNN(input_size=input_size, hidden_size=hidden_size)
net = StackedRNN(input_size=input_size, hidden_size=hidden_size, batch_size=batch_size)
param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(net, param_dict)
......
......@@ -34,7 +34,7 @@ class StackedRNN(nn.Cell):
num_layer(int): the number of layer of LSTM.
"""
def __init__(self, input_size, hidden_size=512, num_layer=2):
def __init__(self, input_size, hidden_size=512, num_layer=2, batch_size=64):
super(StackedRNN, self).__init__()
self.batch_size = batch_size
self.input_size = input_size
......@@ -42,7 +42,7 @@ class StackedRNN(nn.Cell):
self.reshape = P.Reshape()
self.cast = P.Cast()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2)
self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layer)
self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32)
self.fc_bias = np.random.random(self.num_classes).astype(np.float32)
......
......@@ -136,7 +136,7 @@ def train():
loss = CTCLoss(max_sequence_length=config.captcha_width,
max_label_length=max_captcha_digits,
batch_size=config.batch_size)
net = StackedRNN(input_size=input_size, hidden_size=config.hidden_size)
net = StackedRNN(input_size=input_size, hidden_size=config.hidden_size, batch_size=config.batch_size)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum)
net = WithLossCell(net, loss)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment