Commit 1c9c8038 authored by i-robot, committed by Gitee

!1085 Fix Inference Error

Merge pull request !1085 from huangxinjing/fix_inference_error
Parents: d1e13cd7, c65ea229
@@ -533,7 +533,7 @@ Please check the official [homepage](https://gitee.com/mindspore/models).
 # [Requirements](#contents)

-- mindspore 1.2.1 or higher version
+- mindspore 1.5.0 or higher version
 - jieba 0.42.1
 - sentencepiece 0.1.94
 - transformers >= 4.7.0

@@ -553,4 +553,10 @@ A: It's because the feature column name in `dataset.py` is not consistent with t

 Q: `ERROR: device_num must be the power of 2`.
 A: The number of the cards must be the power of 2 if we use the parallel training. For example, if we want to train
-the 2.6B model, the number of cards should be 2, 4, 8, 16 and so on.
\ No newline at end of file
+the 2.6B model, the number of cards should be 2, 4, 8, 16 and so on.
+
+Q: How to modify the network's hyperparameters?
+A: The pre-defined hyperparameters of the network are set in the function `set_parse` of `src/pangu_alpha_config.py`.
+These parameters determine the number of layers, the hidden size and so on. The data parallel number is determined in
+`train.py` as `device_num / model_parallel`.
\ No newline at end of file
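
The new FAQ entry can be made concrete with a minimal sketch. The numbers below are illustrative only; the authoritative presets live in `set_parse` in `src/pangu_alpha_config.py`:

```python
# Illustrative only: how the data parallel number falls out of the device grid.
device_num = 8       # total cards; must be a power of 2 for parallel training
model_parallel = 4   # model-parallel slicing degree chosen for the model size

assert device_num & (device_num - 1) == 0, "device_num must be the power of 2"
data_parallel = device_num // model_parallel
print(data_parallel)  # 2: the weights are replicated across 2 data-parallel groups
```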
@@ -299,8 +299,9 @@ class PanguAlpha_Model(Cell):
         r"""forward pass of the model"""
         embed, word_table = self.embedding(input_ids, input_position, init_reset, batch_valid_length)
         hidden_state = P.Cast()(embed, self.dtype)
-        hidden_state = self.reshape_to_2d(hidden_state)
-        # encoder_mask = self.create_encoder_mask(encoder_masks)
+        # the input of the incremental prediction is 3d
+        if self._phase != 'predict':
+            hidden_state = self.reshape_to_2d(hidden_state)
         if self.blocks is not None:
             for i in range(self.num_layers - 1):
                 hidden_state, _ = self.blocks[i](hidden_state, encoder_masks, init_reset, batch_valid_length)
@@ -311,6 +312,7 @@ class PanguAlpha_Model(Cell):
                                                encoder_masks, init_reset, batch_valid_length)
             encoder_output = self.layernorm(encoder_output)
         else:
+            hidden_state = self.reshape_to_2d(hidden_state)
             encoder_output = self.layernorm(hidden_state)
         encoder_output = P.Cast()(encoder_output, self.dtype)
         top_query_hidden_states, _ = self.top_query_embedding(input_position)
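
The two hunks above move `reshape_to_2d` so that incremental prediction keeps its 3-d input. A minimal NumPy sketch of the intent of this change, not the MindSpore implementation (`phase` and the shapes are illustrative):

```python
import numpy as np

def prepare_hidden_state(embed, phase):
    """Training/eval flattens (bs, seq, hidden) to (bs * seq, hidden) for the
    transformer blocks; incremental prediction keeps the 3-d shape."""
    if phase != 'predict':
        bs, seq, hidden = embed.shape
        return embed.reshape(bs * seq, hidden)
    return embed  # 3-d input passes through untouched during inference

print(prepare_hidden_state(np.zeros((2, 16, 64)), 'train').shape)    # (32, 64)
print(prepare_hidden_state(np.zeros((2, 1, 64)), 'predict').shape)   # (2, 1, 64)
```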
@@ -468,14 +470,19 @@ class EvalNet(nn.Cell):
         self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),))
         self.get_attention_mask = AttentionMask(seq_length)
         self.expand = P.ExpandDims().shard(((1, 1, 1),))
+        # used for incremental prediction
+        self.all_ones_attention_mask = Tensor(np.ones((1, 1, seq_length)), mstype.float32)

     def construct(self, input_ids, current_index, init_reset=True, batch_valid_length=None):
         """evaluation net"""
         input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32)
         bs, seq_length = F.shape(input_ids)
-        attention_mask = self.get_attention_mask(input_mask)
         input_position = F.tuple_to_array(F.make_range(seq_length))
         input_position = P.Tile()(input_position, (bs, 1))
+        if self.is_first_iteration:
+            attention_mask = self.get_attention_mask(input_mask)
+        else:
+            attention_mask = P.Tile()(self.all_ones_attention_mask, (bs, 1, 1))
         logits = self.backbone(input_ids, input_position, attention_mask,
                                init_reset, batch_valid_length)
         index = current_index.view(1,)
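
The `EvalNet` hunk is the other half of the fix: only the first decoding step needs a real causal mask built from the prompt, while later steps score a single new token against cached keys/values and can use an all-ones mask. A hypothetical NumPy sketch of that two-phase masking (the real `AttentionMask` cell also factors the padding mask into both axes):

```python
import numpy as np

def step_mask(is_first_iteration, input_mask, seq_length):
    """First step: causal mask over the padded prompt.
    Later steps: the single query token may attend to every cached position."""
    if is_first_iteration:
        causal = np.tril(np.ones((seq_length, seq_length)))
        return input_mask[:, None, :] * causal[None, :, :]
    return np.ones((input_mask.shape[0], 1, seq_length))

print(step_mask(True, np.ones((1, 8)), 8).shape)   # (1, 8, 8)
print(step_mask(False, np.ones((1, 1)), 8).shape)  # (1, 1, 8)
```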