[AutoParallel]: update qwen_model_auto #9421

Status: Open. Wants to merge 5 commits into base: develop
171 changes: 76 additions & 95 deletions paddlenlp/transformers/qwen/modeling_3D_auto.py
@@ -19,9 +19,8 @@

import paddle
import paddle.distributed as dist
import paddle.distributed.fleet.meta_parallel as mpu
import paddle.nn.functional as F
from paddle import Tensor, nn
from paddle import nn
from paddle.distributed import fleet
from paddle.distributed.fleet.utils import recompute
from paddle.utils import try_import
@@ -55,6 +54,15 @@
except:
fused_rotary_position_embedding = None

try:
from paddle.incubate.nn.functional import swiglu
except ImportError:
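    # Fallback for Paddle builds without the fused swiglu kernel: split the packed
    # tensor into gate/up halves when only one input is given and apply SiLU gating.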

def swiglu(x, y=None):
if y is None:
x, y = paddle.chunk(x, chunks=2, axis=-1)
return F.silu(x) * y


def get_mesh(pp_idx=0):
mesh = fleet.auto.get_mesh()
@@ -63,31 +71,6 @@ def get_mesh(pp_idx=0):
return mesh


def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
is_fleet_init = True
tensor_parallel_degree = 1
try:
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
tensor_parallel_degree = hcg.get_model_parallel_world_size()
except:
is_fleet_init = False

if is_fleet_init and tensor_parallel_degree > 1 and y.is_distributed:
# if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
logits = paddle.matmul(input_parallel, y, transpose_y=False)

if tensor_parallel_output:
return logits

return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

else:
logits = paddle.matmul(x, y, transpose_y=False)
return logits


def get_triangle_upper_mask(x, mask=None):
if mask is not None:
return mask
@@ -148,6 +131,13 @@ def __init__(self, config, ipp=None):
global attention_cnt
self.attention_cnt = attention_cnt
attention_cnt += 1
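        # Auto-parallel sharding of the attention projections: the fused QKV weight is
        # split column-wise (Shard(1)) and its bias along dim 0, while the output
        # projection c_proj is split row-wise (Shard(0)); the first mesh axis stays replicated.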
self.c_attn.weight = dist.shard_tensor(
self.c_attn.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)]
)
self.c_attn.bias = dist.shard_tensor(self.c_attn.bias, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
self.c_proj.weight = dist.shard_tensor(
self.c_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)]
)

def _attn(self, query, key, value, attention_mask=None):
# Support the flash attention and normal attention
@@ -230,12 +220,15 @@ def forward(
# # [bz, sql, hid] ==> [bz, sql, 3*hid]
mixed_x_layer = self.c_attn(hidden_states)
# [bz, sql, 3*hid] ==> [bz, sql, nh, 3*hdim] ==> 3 x [bz, sql, nh, hdim]
target_shape = [0, 0, self.num_heads, 3 * self.head_dim]

mixed_x_layer = paddle.reshape_(mixed_x_layer, target_shape)
query, key, value = paddle.split(mixed_x_layer, num_or_sections=3, axis=-1)

# [bz, sql, hid] ==> [bz, sql, nh, hdim]
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
# query = self._split_heads(query, self.num_heads, self.head_dim)
# key = self._split_heads(key, self.num_heads, self.head_dim)
# value = self._split_heads(value, self.num_heads, self.head_dim)

kv_seq_len = hidden_states.shape[1]
if layer_past:
@@ -310,18 +303,28 @@ class QWenMLPAuto(nn.Layer):
def __init__(self, config, ipp=None):
super().__init__()
ff_dim_in = config.intermediate_size // 2
self.fuse_attention_ffn = config.fuse_attention_ffn
self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias_attr=not config.no_bias)
self.ipp = ipp
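        # Column-parallel split for the up/gate projections w1 and w2 (Shard(1)) and
        # row-parallel split for the down projection c_proj (Shard(0)) on this stage's mesh.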
self.w1.weight = dist.shard_tensor(self.w1.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
self.w2.weight = dist.shard_tensor(self.w2.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
self.c_proj.weight = dist.shard_tensor(
self.c_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)]
)

def forward(self, hidden_states):
# up
a1 = self.w1(hidden_states)
# gate
a2 = self.w2(hidden_states)
intermediate_parallel = a1 * F.silu(a2)
# # up
# a1 = self.w1(hidden_states)
# # gate
# a2 = self.w2(hidden_states)
# intermediate_parallel = a1 * F.silu(a2)
# down
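        # When fuse_attention_ffn is enabled, a single packed gate_up projection feeds
        # swiglu directly; otherwise swiglu gates the up projection (w1) with silu of
        # the gate projection (w2).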
if self.fuse_attention_ffn:
intermediate_parallel = swiglu(self.gate_up_fused_proj(hidden_states))
else:
intermediate_parallel = swiglu(self.w2(hidden_states), self.w1(hidden_states))
output = self.c_proj(intermediate_parallel)
return output

@@ -330,9 +333,10 @@ class QWenBlockAuto(nn.Layer):
def __init__(self, config, ipp=None, idx=None):
super().__init__()
self.config = config
self.ln_1 = QWenRMSNormAuto(config)
self.ipp = ipp
self.ln_1 = QWenRMSNormAuto(config, self.ipp)
self.attn = QWenAttentionAuto(config, ipp)
self.ln_2 = QWenRMSNormAuto(config)
self.ln_2 = QWenRMSNormAuto(config, self.ipp)
self.mlp = QWenMLPAuto(config, ipp)
self.ipp = ipp
self.idx = idx
@@ -349,7 +353,11 @@ def forward(
output_attentions=False,
):
layernorm_output = self.ln_1(hidden_states)
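        # Reshard the attention mask onto this block's pipeline stage, split along the
        # batch dimension, before computing attention.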

attention_mask = (
dist.reshard(attention_mask, get_mesh(self.ipp), [dist.Shard(0), dist.Replicate()])
if attention_mask is not None
else attention_mask
)
attn_outputs = self.attn(
layernorm_output,
layer_past=layer_past,
@@ -386,9 +394,6 @@ class QWenPretrainedModelAuto(PretrainedModel):
config_class = QWenConfig
base_model_prefix = "qwen"

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

@classmethod
def _get_tensor_parallel_mappings(cls, config, is_split=True):

@@ -497,35 +502,6 @@ def _get_name_mappings(cls, config: QWenConfig) -> List[StateDictNameMapping]:
init_name_mappings(mappings)
return [StateDictNameMapping(*mapping) for mapping in mappings]

def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(
module,
(
nn.Linear,
nn.Embedding,
mpu.ColumnParallelLinear,
mpu.RowParallelLinear,
mpu.VocabParallelEmbedding,
QWenLMHeadAuto,
),
):
module.weight.set_value(
paddle.tensor.normal(mean=0.0, std=self.config.initializer_range, shape=module.weight.shape)
)
if getattr(module, "bias", None) is not None:
module.weight.set_value(paddle.zeros(shape=module.weight.shape, dtype=paddle.get_default_dtype()))

for name, p in module.named_parameters():
if name == "c_proj.weight":
p.set_value(
paddle.tensor.normal(
mean=0.0,
std=self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers),
shape=p.shape,
)
)


class QWenModelAuto(QWenPretrainedModelAuto):
def __init__(self, config):
@@ -538,29 +514,32 @@ def __init__(self, config):
self.recompute_granularity = config.recompute_granularity

self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
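        # Shard the token embedding table along the vocabulary dimension (Shard(0)) on
        # the first pipeline stage's mesh.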

self.wte.weight = dist.shard_tensor(self.wte.weight, get_mesh(), [dist.Replicate(), dist.Shard(0)])
self.drop = nn.Dropout(config.emb_dropout_prob)

def get_layer_ipp(layer_index):
mesh = fleet.auto.get_mesh()
if "pp" not in mesh.dim_names:
return None
else:
pp_degree = mesh.get_dim_size("pp")
layer_per_stage = math.ceil(config.num_hidden_layers / pp_degree)
return layer_index // layer_per_stage

self.h = nn.LayerList(
[
QWenBlockAuto(
config,
get_layer_ipp(i),
self.get_layer_ipp(i),
i,
)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = QWenRMSNormAuto(config)
self.ln_f = QWenRMSNormAuto(config, self.get_last_layer_ipp())

def get_layer_ipp(self, layer_index):
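        # Map a transformer layer index to its pipeline-parallel stage: with no "pp"
        # mesh axis everything stays on one stage; otherwise each stage owns
        # ceil(num_hidden_layers / pp_degree) consecutive layers.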
mesh = fleet.auto.get_mesh()
if "pp" not in mesh.dim_names:
return None
else:
pp_degree = mesh.get_dim_size("pp")
layer_per_stage = math.ceil(self.config.num_hidden_layers / pp_degree)
return layer_index // layer_per_stage

def get_last_layer_ipp(self):
return self.get_layer_ipp(self.config.num_hidden_layers - 1)

def get_input_embeddings(self):
return self.wte
@@ -668,7 +647,8 @@ def forward(

encoder_attention_mask = None
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
with paddle.amp.auto_cast(False):
inputs_embeds = self.wte(input_ids)

hidden_states = inputs_embeds

@@ -681,7 +661,7 @@ def forward(
neg_inf = paddle.full_like(attention_mask, paddle.finfo(paddle.bfloat16).min, dtype=paddle.bfloat16)
# dtype 4D mask
attention_mask = paddle.where(attention_mask, zero, neg_inf)
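            # Place the dense 4D mask on the global mesh, replicated over both mesh axes;
            # each block reshards it to its own pipeline stage as needed.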

attention_mask = dist.shard_tensor(attention_mask, get_mesh(), [dist.Replicate(), dist.Replicate()])
hidden_states = self.drop(hidden_states)
hidden_states = dist.reshard(hidden_states, get_mesh(), [dist.Shard(0), dist.Replicate()])
output_shape = input_shape + [
@@ -718,7 +698,7 @@ def forward(
attention_mask = dist.reshard(
attention_mask,
get_mesh(block.ipp),
[dist.Shard(0), dist.Replicate()],
[dist.Replicate(), dist.Replicate()],
)
if self.enable_recompute and self.training and has_gradient and self.recompute_granularity == "full":
outputs = self.recompute_training(
@@ -774,15 +754,16 @@ def forward(


class QWenLMHeadAuto(nn.Layer):
def __init__(self, config: QWenConfig):
def __init__(self, config: QWenConfig, ipp=None):
super(QWenLMHeadAuto, self).__init__()
self.config = config
vocab_size = config.vocab_size

self.ipp = ipp
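        # The LM head lives on the last pipeline stage; its weight is split along the
        # vocabulary dimension (Shard(1)) for tensor parallelism.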
self.weight = self.create_parameter(
shape=[config.hidden_size, vocab_size],
dtype=paddle.get_default_dtype(),
)
self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])

def forward(self, hidden_states, tensor_parallel_output=None):
if tensor_parallel_output is None:
@@ -835,7 +816,7 @@ class QWenForCausalLM3DAuto(QWenPretrainedModelAuto):
def __init__(self, config):
super().__init__(config)
self.qwen = QWenModelAuto(config)
self.lm_head = QWenLMHeadAuto(config)
self.lm_head = QWenLMHeadAuto(config, self.qwen.get_last_layer_ipp())

def forward(
self,
@@ -898,7 +879,7 @@ def update_cos_sin_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
self._ntk_alpha_cached = ntk_alpha
seq = paddle.arange(self._seq_len_cached)
with paddle.amp.auto_cast(enable=False):
freqs = paddle.outer(seq.astype(self.inv_freq.dtype), self.inv_freq)
freqs = paddle.outer(seq.astype(paddle.float32), self.inv_freq.astype(paddle.float32))
emb = paddle.concat([freqs, freqs], axis=-1)
self.cos_cached = emb.cos()[None, :, None, :]
self.sin_cached = emb.sin()[None, :, None, :]
@@ -940,7 +921,7 @@ def rms_norm_fused(x_in, w, eps):


class QWenRMSNormAuto(nn.Layer):
def __init__(self, config):
def __init__(self, config, ipp):
super().__init__()
self.config = config
self.eps = config.layer_norm_epsilon
@@ -949,14 +930,14 @@ def __init__(self, config):
dtype=paddle.get_default_dtype(),
default_initializer=nn.initializer.Constant(1.0),
)
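        # Keep the RMSNorm scale replicated across both mesh axes on this stage's mesh.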
self.weight = dist.shard_tensor(self.weight, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])

def _norm(self, x):
return x * paddle.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
if self.config.use_fused_rms_norm:
return rms_norm_fused(x, self.weight, self.eps)
with paddle.amp.auto_cast(False):
variance = x.astype("float32").pow(2).mean(-1, keepdim=True)
output = paddle.rsqrt(variance + self.eps) * x

if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
output = paddle.cast(output, self.weight.dtype)
output = self._norm(x.astype(paddle.float32)).astype(x.dtype)
return output * self.weight