Skip to content
Snippets Groups Projects
Commit 588f3a52 authored by chenhaozhe's avatar chenhaozhe
Browse files

Use ReverseV2 instead of ReverseSequence in CRNN for higher throughput

parent 241b94cf
No related branches found
No related tags found
No related merge requests found
......@@ -231,10 +231,8 @@ class CRNNV1(nn.Cell):
self.transpose = P.Transpose()
self.squeeze = P.Squeeze(axis=0)
self.vgg = VGG()
self.reverse_seq1 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq2 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq3 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq4 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq1 = P.ReverseV2(axis=[1])
self.reverse_seq2 = P.ReverseV2(axis=[1])
self.seq_length = Tensor(np.ones((self.batch_size), np.int32) * config.num_step, mstype.int32)
self.concat1 = P.Concat(axis=2)
self.rnn_dropout = nn.Dropout(0.9)
......@@ -245,10 +243,10 @@ class CRNNV1(nn.Cell):
x = self.reshape(x, (self.batch_size, self.input_size, -1))
x = self.transpose(x, (2, 0, 1))
bw_x = self.reverse_seq1(x, self.seq_length)
bw_x = self.reverse_seq1(x)
y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
y1_bw, _, _, _, _, _, _, _ = self.rnn1_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
y1_bw = self.reverse_seq2(y1_bw, self.seq_length)
y1_bw = self.reverse_seq2(y1_bw)
y1_out = self.concat1((y1, y1_bw))
if self.use_dropout:
y1_out = self.rnn_dropout(y1_out)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment