Skip to content
Snippets Groups Projects
Unverified Commit 3821e5f9 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2734 fix BertNERModel

Merge pull request !2734 from zhaoting/bert
parents dcb1ff3e 92539220
No related branches found
No related tags found
No related merge requests found
......@@ -16,17 +16,13 @@
'''
Bert finetune and evaluation model script.
'''
import numpy as np
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore import context, Tensor
from mindspore import context
from .bert_model import BertModel
class BertCLSModel(nn.Cell):
"""
This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
......@@ -142,8 +138,8 @@ class BertNERModel(nn.Cell):
batch_size = input_ids.shape[0]
data_type = self.dtype
hidden_size = self.lstm_hidden_size
h0 = Tensor(np.zeros((2, batch_size, hidden_size)), data_type)
c0 = Tensor(np.zeros((2, batch_size, hidden_size)), data_type)
h0 = P.Zeros()((2, batch_size, hidden_size), data_type)
c0 = P.Zeros()((2, batch_size, hidden_size), data_type)
seq, _ = self.lstm(seq, (h0, c0))
seq = self.reshape(seq, self.shape)
logits = self.dense_1(seq)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment