diff --git a/official/nlp/gpt/eval.py b/official/nlp/gpt/eval.py
index a6426cf4c41487e64798fc5c5245a376322ec151..a66971c327f89667642f7f177ed53623d8682c8e 100644
--- a/official/nlp/gpt/eval.py
+++ b/official/nlp/gpt/eval.py
@@ -23,10 +23,12 @@ import numpy as np
 from mindspore import context
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
+from mindspore.nn.transformer.loss import CrossEntropyLoss
+from mindspore.nn.transformer.transformer import TransformerOpParallelConfig
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.inference import generate
 from src.dataset import create_dataset
-from src.gpt import GPT, EvalNet, GPTWithLoss, CrossEntropyLoss
+from src.gpt import GPT, EvalNet, GPTWithLoss
 from src.utils import GPTConfig
 
 context.set_context(mode=context.GRAPH_MODE)
@@ -51,8 +53,9 @@ def get_ppl(model, dataset):
     for data in dataset:
         data = data[0].asnumpy()
         input_ids = data
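+        # the eval network takes an explicit mask; token id 0 is treated as padding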
+        input_mask = (data != 0).astype(np.float32)
 
-        logits = model(Tensor(input_ids, mstype.int32)).asnumpy()
+        logits = model(Tensor(input_ids, mstype.int32), Tensor(input_mask, mstype.float32)).asnumpy()
         PPL.append(logits * len(data))
         tokens += len(data)
 
@@ -74,12 +77,10 @@ def get_acc(model, dataset):
             input_mask[i][idx-1] = 0
             data[i][idx-1] = 0
 
-        length = np.sum(data != 50256, 1)
-        input_ids = data
-        logits = model(Tensor(input_ids, mstype.int32)).asnumpy()
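+        # feed the mask built above (label position zeroed out) together with the ids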
+        logits = model(Tensor(data, mstype.int32), Tensor(input_mask, mstype.float32)).asnumpy()
         logits = logits.reshape(len(length), -1)
 
-        predicted_label = np.zeros(length.shape)
+        predicted_label = np.zeros(len(length))
         for i, idx in enumerate(length):
             predicted_label[i] = logits[i][idx-2]
 
@@ -109,7 +110,7 @@ def run_eval():
         raise ValueError("{} is not supported now".format(metrics))
 
 
-    config = GPTConfig(batch_size=16,
+    config = GPTConfig(batch_size=1,
                        seq_length=1024,
                        vocab_size=50257,
                        embedding_size=1024,
@@ -129,8 +130,9 @@ def run_eval():
     elif metrics == "acc":
         gpt_eval = EvalNet(gpt, generate=False)
     else:
-        loss = CrossEntropyLoss(config)
-        gpt_eval = GPTWithLoss(gpt, loss)
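+        # default (single-card) parallel config; its dp_mp_config bundles the
+        # data/model-parallel settings that this CrossEntropyLoss expects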
+        parallel_config = TransformerOpParallelConfig()
+        loss = CrossEntropyLoss(parallel_config.dp_mp_config)
+        gpt_eval = GPTWithLoss(gpt, loss, eos_token=0)
 
     gpt_eval.set_train(False)
     load_param_into_net(gpt_eval, ckpt_dict)
diff --git a/official/nlp/gpt/src/gpt.py b/official/nlp/gpt/src/gpt.py
index acd02a3cc26ce16235c9c1f5e92f4690cf276607..f3747ad21d103d033483025376c35ddb48b4a17b 100644
--- a/official/nlp/gpt/src/gpt.py
+++ b/official/nlp/gpt/src/gpt.py
@@ -15,168 +15,17 @@
 
 """GPT model"""
 
-import math
 import numpy as np
 import mindspore.nn as nn
-from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 import mindspore.common.dtype as mstype
-from mindspore.common.initializer import TruncatedNormal, initializer, Normal
+from mindspore.common.initializer import TruncatedNormal, initializer
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore.nn.transformer.layers import _LayerNorm
+from mindspore.nn.transformer.transformer import AttentionMask, TransformerEncoder
 
 
-class LayerNorm(nn.Cell):
-    """
-    Layer Normalization
-
-    Args:
-        normalized_shape: the corresponding shape of the normalized axes
-        eps: epsilon, a small number avoiding zero division
-
-    Inputs:
-        x: input tensor
-
-    Returns:
-        rescaled_output: Tensor, returned tensor after layernorm
-    """
-    def __init__(self, normalized_shape, eps=1e-5):
-        super(LayerNorm, self).__init__()
-        self.gamma = Parameter(initializer('ones', normalized_shape))
-        self.beta = Parameter(initializer('zeros', normalized_shape))
-        self.mean = P.ReduceMean(keep_dims=True)
-        self.eps = eps
-
-    def construct(self, x):
-        mean = self.mean(x, -1)
-        variance = self.mean(F.square(x - mean), -1)
-        output = (x - mean) / F.sqrt(variance + self.eps)
-        rescaled_output = output * self.gamma + self.beta
-        return rescaled_output
-
-class Softmax(nn.Cell):
-    """
-    softmax realization
-
-    Args:
-        axis: the axis to be applied softmax
-
-    Inputs:
-        x: input tensor
-
-    Returns:
-        output: Tensor, returned tensor after softmax
-    """
-    def __init__(self, axis=-1):
-        super(Softmax, self).__init__()
-        self.max = P.ArgMaxWithValue(axis=axis, keep_dims=True)
-        self.sum = P.ReduceSum(keep_dims=True)
-        self.axis = axis
-
-    def construct(self, x):
-        _, max_value = self.max(x)
-        exp_x = F.tensor_pow(np.e, x - max_value)
-        sum_x = self.sum(exp_x, self.axis)
-        output = exp_x / sum_x
-        return output
-
-
-class Mapping(nn.Cell):
-    """
-    A mapping function with a 3d input
-
-    Args:
-        input_size: the size of the last dimension of the input tensor
-        output_size: the desired size of the last dimension of the output tensor
-        dtype: the compute datatype
-        scale: the scale factor for initialization
-
-    Inputs:
-        x: the 3d input
-
-    Returns:
-        output: Tensor, a 3d tensor after projection
-    """
-    def __init__(self, input_size, output_size, dtype, scale=1.0):
-        super(Mapping, self).__init__()
-        self.output_size = output_size
-        self.input_size = input_size
-        self.weight = Parameter(initializer(Normal(sigma=0.02*scale), [input_size, output_size]))
-        self.bias = Parameter(initializer("zeros", [output_size,]))
-        self.dtype = dtype
-        self.cast = P.Cast()
-
-    def construct(self, x):
-        out_shape = P.Shape()(x)[:-1] + (self.output_size,)
-        x = P.Reshape()(x, (-1, self.input_size))
-        x = nn.MatMul()(x, self.cast(self.weight, self.dtype)) + self.cast(self.bias, self.dtype)
-        output = P.Reshape()(x, out_shape)
-        return output
-
-
-
-class Output(nn.Cell):
-    """
-    The output mapping module for each layer
-
-    Args:
-        config(GPTConfig): the config of network
-        scale: scale factor for initialization
-
-    Inputs:
-        x: output of the self-attention module
-
-    Returns:
-        output: Tensor, the output of this layer after mapping
-    """
-    def __init__(self, config, scale=1.0):
-        super(Output, self).__init__()
-        input_size = config.embedding_size
-        output_size = config.embedding_size*config.expand_ratio
-        self.mapping = Mapping(input_size, output_size, config.compute_dtype)
-        self.projection = Mapping(output_size, input_size, config.compute_dtype, scale)
-        self.activation = nn.GELU()
-        self.dropout = nn.Dropout(1-config.dropout_rate)
-
-    def construct(self, x):
-        hidden = self.activation(self.mapping(x))
-        output = self.projection(hidden)
-        output = self.dropout(output)
-        return output
-
-class AttentionMask(nn.Cell):
-    """
-    Get the attention matrix for self-attention module
-
-    Args:
-        config(GPTConfig): the config of network
-
-    Inputs:
-        input_mask: the mask indicating whether each position is a valid input
-
-    Returns:
-        attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
-    """
-    def __init__(self, config):
-        super(AttentionMask, self).__init__()
-        self.reshape = P.Reshape()
-        self.mul = P.BatchMatMul()
-        ones = np.ones(shape=(config.seq_length, config.seq_length))
-        self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
-        self.multiply = P.Mul()
-
-
-    def construct(self, input_mask):
-        input_shape = P.Shape()(input_mask)
-        shape_right = (input_shape[0], 1, input_shape[1])
-        shape_left = input_shape + (1,)
-        mask_left = self.reshape(input_mask, shape_left)
-        mask_right = self.reshape(input_mask, shape_right)
-        attention_mask = self.mul(mask_left, mask_right)
-        lower_traiangle = P.ExpandDims()(self.lower_triangle_mask, 0)
-        attention_mask = self.multiply(attention_mask, lower_traiangle)  #bs seq_length seq_length
-        return attention_mask
-
 class EmbeddingLookup(nn.Cell):
     """
     The embedding lookup table for vocabulary
@@ -203,188 +52,6 @@ class EmbeddingLookup(nn.Cell):
         return output, self.embedding_table
 
 
-class Attention(nn.Cell):
-    """
-    Self-Attention module for each layer
-
-    Args:
-        config(GPTConfig): the config of network
-        scale: scale factor for initialization
-        layer_idx: current layer index
-    """
-    def __init__(self, config, scale=1.0, layer_idx=None):
-        super(Attention, self).__init__()
-        self.get_attention_mask = AttentionMask(config)
-        self.projection = Mapping(config.embedding_size, config.embedding_size, config.compute_dtype, scale)
-        self.split = P.Split(axis=-1, output_num=3)
-        self.transpose = P.Transpose()
-        self.reshape = P.Reshape()
-        self.n_head = config.num_heads
-        self.size_per_head = config.embedding_size // self.n_head
-        self.concat_k = P.Concat(axis=3)
-        self.concat_v = P.Concat(axis=2)
-        self.multiply_data = Tensor([-10000.0,], dtype=mstype.float32)
-        self.batch_matmul = P.BatchMatMul()
-        self.scale = scale
-        if self.scale:
-            self.scale_factor = Tensor(math.sqrt(self.size_per_head))
-        if layer_idx is not None:
-            self.coeff = math.sqrt(layer_idx * math.sqrt(self.size_per_head))
-            self.coeff = Tensor(self.coeff)
-        self.use_past = config.use_past
-        self.dropout = nn.Dropout(1-config.dropout_rate)
-        self.prob_dropout = nn.Dropout(1-config.dropout_rate)
-
-        self.dense1 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
-        self.dense2 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
-        self.dense3 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
-
-    def construct(self, x, attention_mask, layer_past=None):
-        """
-        self-attention
-
-        Inputs:
-            x: output of previous layer
-            attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
-            layer_past: the previous feature map
-
-        Returns:
-            output: Tensor, the output logit of this layer
-            layer_present: Tensor, the feature map of current layer
-        """
-
-        original_shape = F.shape(x)
-        x = F.reshape(x, (-1, original_shape[-1]))
-        query = self.dense1(x)
-        key = self.dense2(x)
-        value = self.dense3(x)
-        query = self.transpose(F.reshape(query, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 1, 3))
-        key = self.transpose(F.reshape(key, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 3, 1))
-        value = self.transpose(F.reshape(value, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 1, 3))
-        if self.use_past:
-            past_value = layer_past[1]
-            past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
-            key = self.concat_k((past_key, key))
-            value = self.concat_v(past_value, value)
-        layer_present = P.Stack()([self.transpose(key, (0, 1, 3, 2)), value])
-        attention = self._attn(query, key, value, attention_mask)
-        attention_merge = self.merge_heads(attention)
-        output = self.projection(attention_merge)
-        output = self.dropout(output)
-        return output, layer_present
-
-    def split_heads(self, x, transpose):
-        """
-        split 3d tensor to 4d and switch certain axes
-
-        Inputs:
-            x: input tensor
-            transpose: tuple, the transpose sequence
-
-        Returns:
-            x_transpose: the 4d output
-        """
-        x_size = P.Shape()(x)
-        new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head)
-        x = self.reshape(x, new_x_shape)
-        x_transpose = self.transpose(x, transpose)
-        return x_transpose
-
-    def merge_heads(self, x):
-        """
-        convert a 4d input to a 3d output
-
-        Inputs:
-            x: input tensor
-
-        Returns:
-            x_merge: the 3d output
-        """
-        x = self.transpose(x, (0, 2, 1, 3)) #bs, seq_length, head, size_per_head
-        x_shape = P.Shape()(x)
-        new_shape = x_shape[:-2] + (x_shape[-2]*x_shape[-1],)
-        x_merge = self.reshape(x, new_shape)
-        return x_merge
-
-    def _attn(self, query, key, value, attention_mask):
-        """
-        Get the weighted score along the seq_length
-
-        Inputs:
-            query: the query matrix
-            key: the key matrix
-            value: the value matrix
-            attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
-
-        Returns:
-            weighted_values: Tensor, the weighted sum scores
-        """
-        if not self.scale:
-            query = query / F.cast(self.coeff, F.dtype(query))
-            key = key / F.cast(self.coeff, F.dtype(key))
-
-        score = self.batch_matmul(query, key)
-        if self.scale:
-            score = score / P.Cast()(self.scale_factor, P.DType()(score))
-
-        ori_dtype = P.DType()(score)
-        score = P.Cast()(score, mstype.float32)
-        multiplu_out = P.Sub()(P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
-                               P.Cast()(attention_mask, P.DType()(score)))
-
-        adder = P.Mul()(multiplu_out, self.multiply_data)
-        attention_scores = adder + score
-
-        attention_scores = P.Cast()(attention_scores, ori_dtype)
-        attention_probs = Softmax()(attention_scores)
-
-        attention_probs = self.prob_dropout(attention_probs)
-        weighted_values = self.batch_matmul(attention_probs, value)
-        return weighted_values
-
-class Block(nn.Cell):
-    """
-    The basic block of GPT network
-
-    Args:
-        config(GPTConfig): the config of network
-        layer_idx: current layer index
-
-    Inputs:
-        x: the output of previous layer(input_ids for the first layer)
-        attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
-        layer_past: the previous feature map
-
-    Returns:
-        output: Tensor, the output logit of this layer
-        layer_present: Tensor, the feature map of current layer
-    """
-    def __init__(self, config, layer_idx):
-        super(Block, self).__init__()
-        scale = 1.0
-        self.layernorm1 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
-        self.attention = Attention(config, scale, layer_idx)
-        self.layernorm2 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
-        self.output = Output(config, scale)
-        self.post_layernorm_residual = config.post_layernorm_residual
-
-    def construct(self, x, attention_mask, layer_past=None):
-        """basic block of each layer"""
-        input_x = self.layernorm1(x)
-        attention, layer_present = self.attention(input_x, attention_mask, layer_past)
-        if self.post_layernorm_residual:
-            x = input_x + attention
-        else:
-            x = x + attention
-
-        output_x = self.layernorm2(x)
-        mlp_logit = self.output(output_x)
-        if self.post_layernorm_residual:
-            output = output_x + mlp_logit
-        else:
-            output = x + mlp_logit
-        return output, layer_present
-
 class GPT_Model(nn.Cell):
     """
     The backbone of GPT network
@@ -404,14 +71,18 @@ class GPT_Model(nn.Cell):
     """
     def __init__(self, config):
         super(GPT_Model, self).__init__()
-        self.get_attention_mask = AttentionMask(config)
+        self.get_attention_mask = AttentionMask(seq_length=config.seq_length)
         self.word_embedding = EmbeddingLookup(config)
         self.position_embedding = nn.Embedding(config.seq_length, config.embedding_size,
                                                embedding_table=TruncatedNormal(0.02))
         self.blocks = nn.CellList()
-        for i in range(config.num_layers):
-            self.blocks.append(Block(config, i+1))
-        self.layernorm = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
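+        # library-provided encoder stack; ffn_hidden_size keeps the GPT
+        # convention of 4 * hidden_size, other hyper-parameters use the defaults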
+        self.encoder = TransformerEncoder(batch_size=config.batch_size,
+                                          num_layers=config.num_layers,
+                                          hidden_size=config.embedding_size,
+                                          ffn_hidden_size=config.embedding_size * 4,
+                                          seq_length=config.seq_length,
+                                          num_heads=config.num_heads)
+        self.layernorm = _LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
         self.use_past = config.use_past
         self.past = tuple([None]*config.num_layers)
         self.num_layers = config.num_layers
@@ -433,13 +104,8 @@ class GPT_Model(nn.Cell):
 
         hidden_states = P.Cast()(hidden_states, mstype.float16)
         attention_mask = self.get_attention_mask(input_mask)
-        attention_mask = P.ExpandDims()(attention_mask, 1)
-
-        present_layer = ()
-        for i in range(self.num_layers):
-            hidden_states, present = self.blocks[i](hidden_states, attention_mask, layer_past)
-            present_layer = present_layer + (present,)
 
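+        # AttentionMask already yields a (bs, seq_length, seq_length) mask that
+        # the encoder consumes directly; it returns states and per-layer presents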
+        hidden_states, present_layer = self.encoder(hidden_states, attention_mask)
         output_state = self.layernorm(hidden_states)
         return output_state, present_layer, embedding_table
 
@@ -495,42 +161,6 @@ class GPT(nn.Cell):
         logits = self.head(output_states, embedding_table)
         return logits
 
-class CrossEntropyLoss(nn.Cell):
-    """
-    Calculate the cross entropy loss
-
-    Args:
-        config(GPTConfig): the config of the network
-
-    Inputs:
-        logits: the output logits of the backbone
-        label: the ground truth label of the sample
-        input_mask: the mask indicating whether each position is a valid input
-
-    Returns:
-        loss: Tensor, the corrsponding cross entropy loss
-    """
-    def __init__(self, config):
-        super(CrossEntropyLoss, self).__init__()
-        self.log_softmax = nn.LogSoftmax(axis=-1)
-        self.mean = P.ReduceMean()
-        self.sum = P.ReduceSum()
-        self.onehot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
-        self.vocab_size = config.vocab_size
-
-    def construct(self, logits, label, input_mask):
-        logits = self.log_softmax(P.Cast()(logits, mstype.float32))
-        label = P.Reshape()(label, (-1,))
-        one_hot_label = self.onehot(label, self.vocab_size, self.on_value, self.off_value)
-        loss_sum = P.Neg()(self.sum(logits*one_hot_label, (-1,)))
-        input_mask = P.Reshape()(input_mask, (-1,))
-        numerator = self.sum(loss_sum*input_mask)
-        denominator = self.sum(input_mask) + P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32)
-        loss = numerator / denominator
-        return loss
-
 class GPTWithLoss(nn.Cell):
     """
     GPT training loss
@@ -558,6 +188,8 @@ class GPTWithLoss(nn.Cell):
         input_mask = F.cast(F.not_equal(tokens, self.eos_token), mstype.float32)
         logits = self.network(tokens, input_mask, past)
         labels = input_ids[:, 1:]
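+        # this CrossEntropyLoss expects flattened (bs * seq,) labels and mask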
+        labels = P.Reshape()(labels, (-1,))
+        input_mask = P.Reshape()(input_mask, (-1,))
         output = self.loss(logits, labels, input_mask)
         return output
 
@@ -580,10 +212,11 @@ class EvalNet(nn.Cell):
         self.backbone = backbone
         self.argmax = P.Argmax()
         self.generate = generate
+        self.cast = P.Cast()
 
-    def construct(self, input_ids):
+    def construct(self, input_ids, input_mask):
         """evaluation net"""
-        input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32)
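+        # the mask comes from the caller; cast to float32 for the backbone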
+        input_mask = self.cast(input_mask, mstype.float32)
         logits = self.backbone(input_ids, input_mask)
         outputs = None
         if self.generate:
diff --git a/official/nlp/gpt/src/inference.py b/official/nlp/gpt/src/inference.py
index f08d9fbc57ed500fef72329a760724af590ea7e4..7ada70cc7c96f691ca76c7f6ed948ebb51eb1473 100644
--- a/official/nlp/gpt/src/inference.py
+++ b/official/nlp/gpt/src/inference.py
@@ -42,7 +42,8 @@ def generate(model, origin_inputs, seq_length, end_token=50256):
     print("input_ids is ", input_ids)
     while valid_length < seq_length:
         inputs = Tensor(input_ids, mstype.int32)
-        logits = model(inputs).asnumpy()
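+        # EvalNet takes the ids plus an explicit mask; zeros mark padding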
+        input_mask = Tensor((input_ids != 0), mstype.float32)
+        logits = model(inputs, input_mask).asnumpy()
         logits = logits.reshape(bs, seq_length, -1)
         probs = logits[0, valid_length-1, :]
         p_args = probs.argsort()[::-1][:TOPK]
diff --git a/official/nlp/gpt/train.py b/official/nlp/gpt/train.py
index 17d62f7426dc958e1c501074e9ef6bfa2c017128..530caf94b038e35966dd03b966de186cb22ca03b 100644
--- a/official/nlp/gpt/train.py
+++ b/official/nlp/gpt/train.py
@@ -27,10 +27,12 @@ from mindspore.context import ParallelMode
 import mindspore.nn as nn
 from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
+from mindspore.nn.transformer.loss import CrossEntropyLoss
+from mindspore.nn.transformer.transformer import TransformerOpParallelConfig
 import mindspore.common.dtype as mstype
 from mindspore.common import set_seed
 from src.dataset import create_dataset
-from src.gpt import GPT, GPTWithLoss, CrossEntropyLoss
+from src.gpt import GPT, GPTWithLoss
 from src.gpt_wrapcell import GPTTrainOneStepWithLossScaleCell
 from src.utils import GPTConfig, LearningRate
 
@@ -49,6 +51,7 @@ def run_train():
     parser.add_argument("--start_lr", type=float, default="5e-5", help="Start learning rate, default is 5e-5.")
     parser.add_argument("--end_lr", type=float, default="1e-10", help="End learning rate, default is 1e-10.")
     parser.add_argument("--sink_size", type=int, default=100, help="Sink size for every iteration, default is 100")
+    parser.add_argument("--model_parallel_num", type=int, default=8, help="Num of model parallel, default is 8")
 
 
     args_opt = parser.parse_args()
@@ -80,7 +83,11 @@ def run_train():
                        compute_dtype=mstype.float16,
                        use_past=False)
     gpt = GPT(config)
-    loss = CrossEntropyLoss(config)
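+    # split the devices between data and model parallelism; device_num is
+    # assumed to be divisible by model_parallel_num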
+    model_parallel_num = args_opt.model_parallel_num
+    data_parallel_num = device_num // model_parallel_num
+    parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
+                                                  model_parallel=model_parallel_num)
+    loss = CrossEntropyLoss(parallel_config.dp_mp_config)
     gpt_with_loss = GPTWithLoss(gpt, loss)
 
     ds = create_dataset(config.batch_size, data_path=args_opt.data_path, device_num=device_num, rank=rank)