diff --git a/official/nlp/bert/src/finetune_eval_model.py b/official/nlp/bert/src/finetune_eval_model.py
index 09c524101bf008d3bd10c6dfea15b0c2b7c3bce0..1dd21ae760fd38102fc85a5c4cadb9b694ed56a5 100644
--- a/official/nlp/bert/src/finetune_eval_model.py
+++ b/official/nlp/bert/src/finetune_eval_model.py
@@ -16,17 +16,13 @@
 '''
 Bert finetune and evaluation model script.
 '''
-
-import numpy as np
-
 import mindspore.nn as nn
 from mindspore.common.initializer import TruncatedNormal
 from mindspore.ops import operations as P
-from mindspore import context, Tensor
+from mindspore import context
 from .bert_model import BertModel
 
 
-
 class BertCLSModel(nn.Cell):
     """
     This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
@@ -142,8 +138,8 @@ class BertNERModel(nn.Cell):
         batch_size = input_ids.shape[0]
         data_type = self.dtype
         hidden_size = self.lstm_hidden_size
-        h0 = Tensor(np.zeros((2, batch_size, hidden_size)), data_type)
-        c0 = Tensor(np.zeros((2, batch_size, hidden_size)), data_type)
+        h0 = P.Zeros()((2, batch_size, hidden_size), data_type)
+        c0 = P.Zeros()((2, batch_size, hidden_size), data_type)
         seq, _ = self.lstm(seq, (h0, c0))
         seq = self.reshape(seq, self.shape)
         logits = self.dense_1(seq)