Skip to content
Snippets Groups Projects
Commit df8f9341 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!40 Use ReverseV2 instead ReverseSequence in CRNN

Merge pull request !40 from chenhaozhe/Use-ReverseV2-in-CRNN
parents 5175ab29 588f3a52
No related branches found
No related tags found
No related merge requests found
......@@ -231,10 +231,8 @@ class CRNNV1(nn.Cell):
self.transpose = P.Transpose()
self.squeeze = P.Squeeze(axis=0)
self.vgg = VGG()
self.reverse_seq1 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq2 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq3 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq4 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq1 = P.ReverseV2(axis=[1])
self.reverse_seq2 = P.ReverseV2(axis=[1])
self.seq_length = Tensor(np.ones((self.batch_size), np.int32) * config.num_step, mstype.int32)
self.concat1 = P.Concat(axis=2)
self.rnn_dropout = nn.Dropout(0.9)
......@@ -245,10 +243,10 @@ class CRNNV1(nn.Cell):
x = self.reshape(x, (self.batch_size, self.input_size, -1))
x = self.transpose(x, (2, 0, 1))
bw_x = self.reverse_seq1(x, self.seq_length)
bw_x = self.reverse_seq1(x)
y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
y1_bw, _, _, _, _, _, _, _ = self.rnn1_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
y1_bw = self.reverse_seq2(y1_bw, self.seq_length)
y1_bw = self.reverse_seq2(y1_bw)
y1_out = self.concat1((y1, y1_bw))
if self.use_dropout:
y1_out = self.rnn_dropout(y1_out)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment