Skip to content
Snippets Groups Projects
Commit f7019cb6 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!1280 modify gpt model for modelzoo

Merge pull request !1280 from lilei/modify_gpt_model
parents 136bbbd7 33bd6de8
No related branches found
No related tags found
No related merge requests found
......@@ -23,10 +23,12 @@ import numpy as np
from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn.transformer.loss import CrossEntropyLoss
from mindspore.nn.transformer.transformer import TransformerOpParallelConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.inference import generate
from src.dataset import create_dataset
from src.gpt import GPT, EvalNet, GPTWithLoss, CrossEntropyLoss
from src.gpt import GPT, EvalNet, GPTWithLoss
from src.utils import GPTConfig
context.set_context(mode=context.GRAPH_MODE)
......@@ -51,8 +53,9 @@ def get_ppl(model, dataset):
for data in dataset:
data = data[0].asnumpy()
input_ids = data
input_mask = (data != 0).astype(np.float32)
logits = model(Tensor(input_ids, mstype.int32)).asnumpy()
logits = model(Tensor(input_ids, mstype.int32), Tensor(input_mask, mstype.float32)).asnumpy()
PPL.append(logits * len(data))
tokens += len(data)
......@@ -74,12 +77,10 @@ def get_acc(model, dataset):
input_mask[i][idx-1] = 0
data[i][idx-1] = 0
length = np.sum(data != 50256, 1)
input_ids = data
logits = model(Tensor(input_ids, mstype.int32)).asnumpy()
logits = model(Tensor(data, mstype.int32), Tensor(input_mask, mstype.float32)).asnumpy()
logits = logits.reshape(len(length), -1)
predicted_label = np.zeros(length.shape)
predicted_label = np.zeros(len(length))
for i, idx in enumerate(length):
predicted_label[i] = logits[i][idx-2]
......@@ -109,7 +110,7 @@ def run_eval():
raise ValueError("{} is not supported now".format(metrics))
config = GPTConfig(batch_size=16,
config = GPTConfig(batch_size=1,
seq_length=1024,
vocab_size=50257,
embedding_size=1024,
......@@ -129,8 +130,9 @@ def run_eval():
elif metrics == "acc":
gpt_eval = EvalNet(gpt, generate=False)
else:
loss = CrossEntropyLoss(config)
gpt_eval = GPTWithLoss(gpt, loss)
parallel_config = TransformerOpParallelConfig()
loss = CrossEntropyLoss(parallel_config.dp_mp_config)
gpt_eval = GPTWithLoss(gpt, loss, eos_token=0)
gpt_eval.set_train(False)
load_param_into_net(gpt_eval, ckpt_dict)
......
......@@ -15,168 +15,17 @@
"""GPT model"""
import math
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype
from mindspore.common.initializer import TruncatedNormal, initializer, Normal
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.transformer.layers import _LayerNorm
from mindspore.nn.transformer.transformer import AttentionMask, TransformerEncoder
class LayerNorm(nn.Cell):
"""
Layer Normalization
Args:
normalized_shape: the corresponding shape of the normalized axes
eps: epsilon, a small number avoiding zero division
Inputs:
x: input tensor
Returns:
rescaled_output: Tensor, returned tensor after layernorm
"""
def __init__(self, normalized_shape, eps=1e-5):
super(LayerNorm, self).__init__()
self.gamma = Parameter(initializer('ones', normalized_shape))
self.beta = Parameter(initializer('zeros', normalized_shape))
self.mean = P.ReduceMean(keep_dims=True)
self.eps = eps
def construct(self, x):
mean = self.mean(x, -1)
variance = self.mean(F.square(x - mean), -1)
output = (x - mean) / F.sqrt(variance + self.eps)
rescaled_output = output * self.gamma + self.beta
return rescaled_output
class Softmax(nn.Cell):
"""
softmax realization
Args:
axis: the axis to be applied softmax
Inputs:
x: input tensor
Returns:
output: Tensor, returned tensor after softmax
"""
def __init__(self, axis=-1):
super(Softmax, self).__init__()
self.max = P.ArgMaxWithValue(axis=axis, keep_dims=True)
self.sum = P.ReduceSum(keep_dims=True)
self.axis = axis
def construct(self, x):
_, max_value = self.max(x)
exp_x = F.tensor_pow(np.e, x - max_value)
sum_x = self.sum(exp_x, self.axis)
output = exp_x / sum_x
return output
class Mapping(nn.Cell):
"""
A mapping function with a 3d input
Args:
input_size: the size of the last dimension of the input tensor
output_size: the desired size of the last dimension of the output tensor
dtype: the compute datatype
scale: the scale factor for initialization
Inputs:
x: the 3d input
Returns:
output: Tensor, a 3d tensor after projection
"""
def __init__(self, input_size, output_size, dtype, scale=1.0):
super(Mapping, self).__init__()
self.output_size = output_size
self.input_size = input_size
self.weight = Parameter(initializer(Normal(sigma=0.02*scale), [input_size, output_size]))
self.bias = Parameter(initializer("zeros", [output_size,]))
self.dtype = dtype
self.cast = P.Cast()
def construct(self, x):
out_shape = P.Shape()(x)[:-1] + (self.output_size,)
x = P.Reshape()(x, (-1, self.input_size))
x = nn.MatMul()(x, self.cast(self.weight, self.dtype)) + self.cast(self.bias, self.dtype)
output = P.Reshape()(x, out_shape)
return output
class Output(nn.Cell):
"""
The output mapping module for each layer
Args:
config(GPTConfig): the config of network
scale: scale factor for initialization
Inputs:
x: output of the self-attention module
Returns:
output: Tensor, the output of this layer after mapping
"""
def __init__(self, config, scale=1.0):
super(Output, self).__init__()
input_size = config.embedding_size
output_size = config.embedding_size*config.expand_ratio
self.mapping = Mapping(input_size, output_size, config.compute_dtype)
self.projection = Mapping(output_size, input_size, config.compute_dtype, scale)
self.activation = nn.GELU()
self.dropout = nn.Dropout(1-config.dropout_rate)
def construct(self, x):
hidden = self.activation(self.mapping(x))
output = self.projection(hidden)
output = self.dropout(output)
return output
class AttentionMask(nn.Cell):
"""
Get the attention matrix for self-attention module
Args:
config(GPTConfig): the config of network
Inputs:
input_mask: the mask indicating whether each position is a valid input
Returns:
attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
"""
def __init__(self, config):
super(AttentionMask, self).__init__()
self.reshape = P.Reshape()
self.mul = P.BatchMatMul()
ones = np.ones(shape=(config.seq_length, config.seq_length))
self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
self.multiply = P.Mul()
def construct(self, input_mask):
input_shape = P.Shape()(input_mask)
shape_right = (input_shape[0], 1, input_shape[1])
shape_left = input_shape + (1,)
mask_left = self.reshape(input_mask, shape_left)
mask_right = self.reshape(input_mask, shape_right)
attention_mask = self.mul(mask_left, mask_right)
lower_traiangle = P.ExpandDims()(self.lower_triangle_mask, 0)
attention_mask = self.multiply(attention_mask, lower_traiangle) #bs seq_length seq_length
return attention_mask
class EmbeddingLookup(nn.Cell):
"""
The embedding lookup table for vocabulary
......@@ -203,188 +52,6 @@ class EmbeddingLookup(nn.Cell):
return output, self.embedding_table
class Attention(nn.Cell):
"""
Self-Attention module for each layer
Args:
config(GPTConfig): the config of network
scale: scale factor for initialization
layer_idx: current layer index
"""
def __init__(self, config, scale=1.0, layer_idx=None):
super(Attention, self).__init__()
self.get_attention_mask = AttentionMask(config)
self.projection = Mapping(config.embedding_size, config.embedding_size, config.compute_dtype, scale)
self.split = P.Split(axis=-1, output_num=3)
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.n_head = config.num_heads
self.size_per_head = config.embedding_size // self.n_head
self.concat_k = P.Concat(axis=3)
self.concat_v = P.Concat(axis=2)
self.multiply_data = Tensor([-10000.0,], dtype=mstype.float32)
self.batch_matmul = P.BatchMatMul()
self.scale = scale
if self.scale:
self.scale_factor = Tensor(math.sqrt(self.size_per_head))
if layer_idx is not None:
self.coeff = math.sqrt(layer_idx * math.sqrt(self.size_per_head))
self.coeff = Tensor(self.coeff)
self.use_past = config.use_past
self.dropout = nn.Dropout(1-config.dropout_rate)
self.prob_dropout = nn.Dropout(1-config.dropout_rate)
self.dense1 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
self.dense2 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
self.dense3 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
def construct(self, x, attention_mask, layer_past=None):
"""
self-attention
Inputs:
x: output of previous layer
attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
layer_past: the previous feature map
Returns:
output: Tensor, the output logit of this layer
layer_present: Tensor, the feature map of current layer
"""
original_shape = F.shape(x)
x = F.reshape(x, (-1, original_shape[-1]))
query = self.dense1(x)
key = self.dense2(x)
value = self.dense3(x)
query = self.transpose(F.reshape(query, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 1, 3))
key = self.transpose(F.reshape(key, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 3, 1))
value = self.transpose(F.reshape(value, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 1, 3))
if self.use_past:
past_value = layer_past[1]
past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
key = self.concat_k((past_key, key))
value = self.concat_v(past_value, value)
layer_present = P.Stack()([self.transpose(key, (0, 1, 3, 2)), value])
attention = self._attn(query, key, value, attention_mask)
attention_merge = self.merge_heads(attention)
output = self.projection(attention_merge)
output = self.dropout(output)
return output, layer_present
def split_heads(self, x, transpose):
"""
split 3d tensor to 4d and switch certain axes
Inputs:
x: input tensor
transpose: tuple, the transpose sequence
Returns:
x_transpose: the 4d output
"""
x_size = P.Shape()(x)
new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head)
x = self.reshape(x, new_x_shape)
x_transpose = self.transpose(x, transpose)
return x_transpose
def merge_heads(self, x):
"""
convert a 4d input to a 3d output
Inputs:
x: input tensor
Returns:
x_merge: the 3d output
"""
x = self.transpose(x, (0, 2, 1, 3)) #bs, seq_length, head, size_per_head
x_shape = P.Shape()(x)
new_shape = x_shape[:-2] + (x_shape[-2]*x_shape[-1],)
x_merge = self.reshape(x, new_shape)
return x_merge
def _attn(self, query, key, value, attention_mask):
"""
Get the weighted score along the seq_length
Inputs:
query: the query matrix
key: the key matrix
value: the value matrix
attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
Returns:
weighted_values: Tensor, the weighted sum scores
"""
if not self.scale:
query = query / F.cast(self.coeff, F.dtype(query))
key = key / F.cast(self.coeff, F.dtype(key))
score = self.batch_matmul(query, key)
if self.scale:
score = score / P.Cast()(self.scale_factor, P.DType()(score))
ori_dtype = P.DType()(score)
score = P.Cast()(score, mstype.float32)
multiplu_out = P.Sub()(P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
P.Cast()(attention_mask, P.DType()(score)))
adder = P.Mul()(multiplu_out, self.multiply_data)
attention_scores = adder + score
attention_scores = P.Cast()(attention_scores, ori_dtype)
attention_probs = Softmax()(attention_scores)
attention_probs = self.prob_dropout(attention_probs)
weighted_values = self.batch_matmul(attention_probs, value)
return weighted_values
class Block(nn.Cell):
"""
The basic block of GPT network
Args:
config(GPTConfig): the config of network
layer_idx: current layer index
Inputs:
x: the output of previous layer(input_ids for the first layer)
attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
layer_past: the previous feature map
Returns:
output: Tensor, the output logit of this layer
layer_present: Tensor, the feature map of current layer
"""
def __init__(self, config, layer_idx):
super(Block, self).__init__()
scale = 1.0
self.layernorm1 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
self.attention = Attention(config, scale, layer_idx)
self.layernorm2 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
self.output = Output(config, scale)
self.post_layernorm_residual = config.post_layernorm_residual
def construct(self, x, attention_mask, layer_past=None):
"""basic block of each layer"""
input_x = self.layernorm1(x)
attention, layer_present = self.attention(input_x, attention_mask, layer_past)
if self.post_layernorm_residual:
x = input_x + attention
else:
x = x + attention
output_x = self.layernorm2(x)
mlp_logit = self.output(output_x)
if self.post_layernorm_residual:
output = output_x + mlp_logit
else:
output = x + mlp_logit
return output, layer_present
class GPT_Model(nn.Cell):
"""
The backbone of GPT network
......@@ -404,14 +71,18 @@ class GPT_Model(nn.Cell):
"""
def __init__(self, config):
super(GPT_Model, self).__init__()
self.get_attention_mask = AttentionMask(config)
self.get_attention_mask = AttentionMask(seq_length=config.seq_length)
self.word_embedding = EmbeddingLookup(config)
self.position_embedding = nn.Embedding(config.seq_length, config.embedding_size,
embedding_table=TruncatedNormal(0.02))
self.blocks = nn.CellList()
for i in range(config.num_layers):
self.blocks.append(Block(config, i+1))
self.layernorm = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
self.encoder = TransformerEncoder(batch_size=config.batch_size,
num_layers=config.num_layers,
hidden_size=config.embedding_size,
ffn_hidden_size=config.embedding_size * 4,
seq_length=config.seq_length,
num_heads=config.num_heads,)
self.layernorm = _LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
self.use_past = config.use_past
self.past = tuple([None]*config.num_layers)
self.num_layers = config.num_layers
......@@ -433,13 +104,8 @@ class GPT_Model(nn.Cell):
hidden_states = P.Cast()(hidden_states, mstype.float16)
attention_mask = self.get_attention_mask(input_mask)
attention_mask = P.ExpandDims()(attention_mask, 1)
present_layer = ()
for i in range(self.num_layers):
hidden_states, present = self.blocks[i](hidden_states, attention_mask, layer_past)
present_layer = present_layer + (present,)
hidden_states, present_layer = self.encoder(hidden_states, attention_mask)
output_state = self.layernorm(hidden_states)
return output_state, present_layer, embedding_table
......@@ -495,42 +161,6 @@ class GPT(nn.Cell):
logits = self.head(output_states, embedding_table)
return logits
class CrossEntropyLoss(nn.Cell):
"""
Calculate the cross entropy loss
Args:
config(GPTConfig): the config of the network
Inputs:
logits: the output logits of the backbone
label: the ground truth label of the sample
input_mask: the mask indicating whether each position is a valid input
Returns:
loss: Tensor, the corrsponding cross entropy loss
"""
def __init__(self, config):
super(CrossEntropyLoss, self).__init__()
self.log_softmax = nn.LogSoftmax(axis=-1)
self.mean = P.ReduceMean()
self.sum = P.ReduceSum()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.vocab_size = config.vocab_size
def construct(self, logits, label, input_mask):
logits = self.log_softmax(P.Cast()(logits, mstype.float32))
label = P.Reshape()(label, (-1,))
one_hot_label = self.onehot(label, self.vocab_size, self.on_value, self.off_value)
loss_sum = P.Neg()(self.sum(logits*one_hot_label, (-1,)))
input_mask = P.Reshape()(input_mask, (-1,))
numerator = self.sum(loss_sum*input_mask)
denominator = self.sum(input_mask) + P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32)
loss = numerator / denominator
return loss
class GPTWithLoss(nn.Cell):
"""
GPT training loss
......@@ -558,6 +188,8 @@ class GPTWithLoss(nn.Cell):
input_mask = F.cast(F.not_equal(tokens, self.eos_token), mstype.float32)
logits = self.network(tokens, input_mask, past)
labels = input_ids[:, 1:]
labels = P.Reshape()(labels, (-1,))
input_mask = P.Reshape()(input_mask, (-1,))
output = self.loss(logits, labels, input_mask)
return output
......@@ -580,10 +212,11 @@ class EvalNet(nn.Cell):
self.backbone = backbone
self.argmax = P.Argmax()
self.generate = generate
self.cast = P.Cast()
def construct(self, input_ids):
def construct(self, input_ids, input_mask):
"""evaluation net"""
input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32)
input_mask = self.cast(input_mask, mstype.float32)
logits = self.backbone(input_ids, input_mask)
outputs = None
if self.generate:
......
......@@ -42,7 +42,8 @@ def generate(model, origin_inputs, seq_length, end_token=50256):
print("input_ids is ", input_ids)
while valid_length < seq_length:
inputs = Tensor(input_ids, mstype.int32)
logits = model(inputs).asnumpy()
inputs_mask = Tensor((input_ids != 0), mstype.float32)
logits = model(inputs, inputs_mask).asnumpy()
logits = logits.reshape(bs, seq_length, -1)
probs = logits[0, valid_length-1, :]
p_args = probs.argsort()[::-1][:TOPK]
......
......@@ -27,10 +27,12 @@ from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.transformer.loss import CrossEntropyLoss
from mindspore.nn.transformer.transformer import TransformerOpParallelConfig
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from src.dataset import create_dataset
from src.gpt import GPT, GPTWithLoss, CrossEntropyLoss
from src.gpt import GPT, GPTWithLoss
from src.gpt_wrapcell import GPTTrainOneStepWithLossScaleCell
from src.utils import GPTConfig, LearningRate
......@@ -49,6 +51,7 @@ def run_train():
parser.add_argument("--start_lr", type=float, default="5e-5", help="Start learning rate, default is 5e-5.")
parser.add_argument("--end_lr", type=float, default="1e-10", help="End learning rate, default is 1e-10.")
parser.add_argument("--sink_size", type=int, default=100, help="Sink size for every iteration, default is 100")
parser.add_argument("--model_parallel_num", type=int, default=8, help="Num of model parallel, default is 8")
args_opt = parser.parse_args()
......@@ -80,7 +83,11 @@ def run_train():
compute_dtype=mstype.float16,
use_past=False)
gpt = GPT(config)
loss = CrossEntropyLoss(config)
model_parallel_num = args_opt.model_parallel_num
data_parallel_num = int(device_num / model_parallel_num)
parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
model_parallel=model_parallel_num)
loss = CrossEntropyLoss(parallel_config.dp_mp_config)
gpt_with_loss = GPTWithLoss(gpt, loss)
ds = create_dataset(config.batch_size, data_path=args_opt.data_path, device_num=device_num, rank=rank)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment