diff --git a/research/nlp/textrcnn/src/textrcnn.py b/research/nlp/textrcnn/src/textrcnn.py index 6fb032778973461bf180614cae05dbd02ba44467..32d71ce99e6fad4f0911dd1d9f7ef1893ea50421 100644 --- a/research/nlp/textrcnn/src/textrcnn.py +++ b/research/nlp/textrcnn/src/textrcnn.py @@ -48,7 +48,7 @@ class textrcnn(nn.Cell): if cell == "lstm": if self.gpu_flag: self.lstm = nn.LSTM(self.embed_size, self.num_hiddens) - self.lstm.to_float(mstype.float16) + self.lstm.to_float(mstype.float32) else: self.lstm = P.DynamicRNN(forget_bias=0.0) self.w1_fw = Parameter( @@ -216,5 +216,5 @@ class textrcnn(nn.Cell): output_dense = self.tanh(output_dense) # sl*bs, num_hidden output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens)) # sl, bs, num_hidden output = self.reduce_max(output, 0) # bs, num_hidden - outputs = self.cast(self.mydense(output), mstype.float16) # bs, num_classes + outputs = self.cast(self.mydense(output), mstype.float32) # bs, num_classes return outputs diff --git a/research/nlp/textrcnn/train.py b/research/nlp/textrcnn/train.py index c2113aa6736f996361dafba82eed6457153d7027..987b92d87c9f405435dd3c191e7bdb3235093253 100644 --- a/research/nlp/textrcnn/train.py +++ b/research/nlp/textrcnn/train.py @@ -84,8 +84,7 @@ def run_train(): loss_cb = LossMonitor() time_cb = TimeMonitor() if cfg.cell == "lstm" and cfg.device_target == "GPU": - model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()}, amp_level="O3", - loss_scale_manager=loss_scale) + model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()}, loss_scale_manager=loss_scale) else: model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()}, amp_level="O3")