Commit 48d1107
Merge pull request #873 from google:lizhiyu/moe_ep_drop_int8
PiperOrigin-RevId: 672790103
maxtext authors committed Sep 10, 2024
2 parents 66684a2 + d0f3ade commit 48d1107
Showing 2 changed files with 34 additions and 22 deletions.
54 changes: 33 additions & 21 deletions MaxText/layers/linears.py
@@ -289,19 +289,19 @@ class MoeBlock(nn.Module):
   dtype: DType = jnp.float32
   quant: Optional[Quant] = None
 
+  # The first axes is expert
+  wi_kernel_axes = ('exp', 'embed_no_exp', 'mlp')
+  wo_kernel_axes = ('exp', 'mlp', 'embed_no_exp')
+
   def generate_kernels(self, num_experts, emb_dim, mlp_dim):
 
     kernel_in_axis = np.arange(1)
     kernel_out_axis = np.arange(1, 2)
     kernel_init = nd_dense_init(1.0, 'fan_in', 'truncated_normal')
 
-    # The first axes is expert
-    kernel_axes = ('exp', 'embed_no_exp', 'mlp')
-    wo_kernel_axes = ('exp', 'mlp', 'embed_no_exp')
-
     w0_kernel = self.param(
         'wi_0',
-        nn.with_logical_partitioning(kernel_init, kernel_axes),
+        nn.with_logical_partitioning(kernel_init, self.wi_kernel_axes),
         (num_experts, emb_dim, mlp_dim),
         self.weight_dtype,
         kernel_in_axis,
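For readers unfamiliar with Flax logical partitioning, the sketch below shows how a kernel created with nn.with_logical_partitioning carries logical axis names such as ('exp', 'embed_no_exp', 'mlp') that sharding rules later map onto mesh axes. It is a minimal standalone sketch, not MaxText code: the TinyExpertParam module, the toy sizes, and the lecun_normal initializer (in place of nd_dense_init) are assumptions for illustration.

    # Hypothetical sketch: a parameter whose first (expert) axis is annotated
    # with the logical name 'exp', mirroring wi_kernel_axes above.
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class TinyExpertParam(nn.Module):
      num_experts: int = 4
      emb_dim: int = 8
      mlp_dim: int = 16

      @nn.compact
      def __call__(self, x):
        w = self.param(
            'wi_0',
            nn.with_logical_partitioning(
                nn.initializers.lecun_normal(), ('exp', 'embed_no_exp', 'mlp')),
            (self.num_experts, self.emb_dim, self.mlp_dim),
            jnp.float32,
        )
        # x: [batch, seq, emb]; contract emb against each expert's kernel.
        return jnp.einsum('BSM,EMH->BSEH', x, w)

    params = TinyExpertParam().init(jax.random.PRNGKey(0), jnp.ones((2, 3, 8)))

Hoisting the axis tuples to class attributes, as the hunk above does, lets the same names be reused both for parameter partitioning and for the quantized einsum calls later in the diff.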
@@ -310,7 +310,7 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim):
     w0_kernel = jnp.asarray(w0_kernel, self.dtype)
     w1_kernel = self.param(
         'wi_1',
-        nn.with_logical_partitioning(kernel_init, kernel_axes),
+        nn.with_logical_partitioning(kernel_init, self.wi_kernel_axes),
         (num_experts, emb_dim, mlp_dim),
         self.weight_dtype,
         kernel_in_axis,
@@ -319,7 +319,7 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim):
     w1_kernel = jnp.asarray(w1_kernel, self.dtype)
     wo_kernel = self.param(
         'wo',
-        nn.with_logical_partitioning(kernel_init, wo_kernel_axes),
+        nn.with_logical_partitioning(kernel_init, self.wo_kernel_axes),
         (num_experts, mlp_dim, emb_dim),
         self.weight_dtype,
         kernel_in_axis,
@@ -470,6 +470,13 @@ def load_balance_loss(self, top_k_indices, logits):
     loss = jnp.mean(density * density_prob) * (self.num_experts ** 2) * self.config.load_balance_loss_weight
     return loss
 
+  def get_einsum(self, rhs_mesh_axes: Tuple[Optional[str], ...] = ()):
+    if self.quant:
+      einsum_op = self.quant.einsum(rhs_mesh_axes)
+    else:
+      einsum_op = jnp.einsum
+    return einsum_op
+
   def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel):
     gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length", "activation_embed"))
     softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
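The new get_einsum helper above picks a quantization-aware einsum when a quantizer is configured and falls back to jnp.einsum otherwise, so every call site keeps the same call signature. A minimal standalone sketch of that dispatch pattern follows; the FakeQuant class and the toy shapes are invented for illustration and are not the MaxText Quant API.

    import jax.numpy as jnp

    class FakeQuant:
      """Stand-in for a quantizer object exposing an einsum factory."""
      def einsum(self, rhs_mesh_axes=()):
        def quantized_einsum(eqn, lhs, rhs):
          # A real implementation would quantize rhs (e.g. to int8) and attach
          # rhs_mesh_axes as sharding metadata; here we just call jnp.einsum.
          return jnp.einsum(eqn, lhs, rhs)
        return quantized_einsum

    def get_einsum(quant=None, rhs_mesh_axes=()):
      # Same shape as MoeBlock.get_einsum: quantized op if available, else jnp.einsum.
      return quant.einsum(rhs_mesh_axes) if quant else jnp.einsum

    x = jnp.ones((2, 3, 8))      # [batch, seq, emb]
    w = jnp.ones((4, 8, 16))     # [expert, emb, mlp]
    out = get_einsum(FakeQuant(), ('exp', None, None))("BSM,EMH -> BSEH", x, w)
    assert out.shape == (2, 3, 4, 16)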
@@ -479,43 +486,48 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel):
     if self.config.capacity_factor > 0:
       # token dropping if needed
       dispatch_mask, combine_mask = self.generate_masks(top_k_indices, softmax_probs)
-      dispatch_mask = nn.with_logical_constraint(dispatch_mask, ("activation_batch", "activation_length", None, None))
-      combine_mask = nn.with_logical_constraint(combine_mask, ("activation_batch", "activation_length", None, None))
+      mask_axes = ("activation_batch", "activation_length", None, None)
+      dispatch_mask = nn.with_logical_constraint(dispatch_mask, mask_axes)
+      combine_mask = nn.with_logical_constraint(combine_mask, mask_axes)
       loss = self.load_balance_loss(top_k_indices, softmax_probs)
       inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
       with jax.named_scope("dispatch"):
-        dispatch = jnp.einsum("BSM,BSEC -> BECM", inputs, dispatch_mask)
+        dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> BECM", inputs, dispatch_mask)
         dispatch = nn.with_logical_constraint(dispatch, ("activation_batch_no_exp", "activation_exp", None, "activation_embed"))
       with jax.named_scope("wi_0"):
-        w0_kernel = nn.with_logical_constraint(w0_kernel, ("exp", None, None))
-        layer_w0 = jnp.einsum("BECM,EMH -> BECH", dispatch, w0_kernel)
+        w0_kernel_axes = ("exp", None, None)
+        w0_kernel = nn.with_logical_constraint(w0_kernel, w0_kernel_axes)
+        layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)("BECM,EMH -> BECH", dispatch, w0_kernel)
         layer_w0 = nn.with_logical_constraint(layer_w0, ("activation_batch_no_exp", "activation_exp", None, "activation_mlp"))
       with jax.named_scope("wi_1"):
-        w1_kernel = nn.with_logical_constraint(w1_kernel, ("exp", None, None))
-        layer_w1 = jnp.einsum("BECM,EMH -> BECH", dispatch, w1_kernel)
+        w1_kernel_axes = ("exp", None, None)
+        w1_kernel = nn.with_logical_constraint(w1_kernel, w1_kernel_axes)
+        layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)("BECM,EMH -> BECH", dispatch, w1_kernel)
         layer_w1 = nn.with_logical_constraint(layer_w1, ("activation_batch_no_exp", "activation_exp", None, "activation_mlp"))
       layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0)
       layer_multiply = jnp.multiply(layer_w0_act, layer_w1)
       with jax.named_scope("wo"):
-        wo_kernel = nn.with_logical_constraint(wo_kernel, ("exp", None, None))
-        intermediate_layer = jnp.einsum("BECH,EHM -> BECM", layer_multiply, wo_kernel)
+        wo_kernel_axes = ("exp", None, None)
+        wo_kernel = nn.with_logical_constraint(wo_kernel, wo_kernel_axes)
+        intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)("BECH,EHM -> BECM", layer_multiply, wo_kernel)
         intermediate_layer = nn.with_logical_constraint(intermediate_layer, ("activation_batch_no_exp", "activation_exp", None, "activation_embed"))
       with jax.named_scope("combine"):
-        output = jnp.einsum("BECM,BSEC -> BSM", intermediate_layer, combine_mask)
+        output = self.get_einsum(rhs_mesh_axes=mask_axes)("BECM,BSEC -> BSM", intermediate_layer, combine_mask)
       return output, loss
     else:
       weights = self.reshape_and_update_weights(top_k_weights, top_k_indices)
       inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
       with jax.named_scope("wi_0"):
-        layer_w0 = jnp.einsum("BSM,EMH -> BSEH", inputs, w0_kernel)
+        layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)("BSM,EMH -> BSEH", inputs, w0_kernel)
       with jax.named_scope("wi_1"):
-        layer_w1 = jnp.einsum("BSM,EMH -> BSEH", inputs, w1_kernel)
+        layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)("BSM,EMH -> BSEH", inputs, w1_kernel)
       layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0)
       layer_multiply = jnp.multiply(layer_w0_act, layer_w1)
       with jax.named_scope("wo"):
-        intermediate_layer = jnp.einsum("BSEH,EHM -> BSEM", layer_multiply, wo_kernel)
+        intermediate_layer = self.get_einsum(rhs_mesh_axes=self.wo_kernel_axes)("BSEH,EHM -> BSEM", layer_multiply, wo_kernel)
       with jax.named_scope("w_sum"):
-        output = jnp.einsum("BSEM,BSE -> BSM", intermediate_layer, weights)
+        weights_axis = ("activation_batch", "activation_length", "activation_exp")
+        output = self.get_einsum(rhs_mesh_axes=weights_axis)("BSEM,BSE -> BSM", intermediate_layer, weights)
       return output, None
 
   @nn.compact
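To make the dropping path's einsum equations above concrete, here is a small standalone shape check. The sizes (batch B=2, sequence S=4, experts E=2, capacity C=3, model dim M=8, hidden H=16) are made up, and the one-hot masks stand in for the real generate_masks output.

    import jax.numpy as jnp

    B, S, E, C, M, H = 2, 4, 2, 3, 8, 16
    inputs = jnp.ones((B, S, M))
    dispatch_mask = jnp.zeros((B, S, E, C)).at[:, 0, 0, 0].set(1.0)  # toy routing
    combine_mask = dispatch_mask * 0.5                               # toy weights
    w0 = jnp.ones((E, M, H))
    wo = jnp.ones((E, H, M))

    dispatch = jnp.einsum("BSM,BSEC -> BECM", inputs, dispatch_mask)  # (B, E, C, M)
    hidden = jnp.einsum("BECM,EMH -> BECH", dispatch, w0)             # (B, E, C, H)
    back = jnp.einsum("BECH,EHM -> BECM", hidden, wo)                 # (B, E, C, M)
    output = jnp.einsum("BECM,BSEC -> BSM", back, combine_mask)       # (B, S, M)
    assert output.shape == (B, S, M)

In the actual layer each of these jnp.einsum calls now runs through self.get_einsum(rhs_mesh_axes=...), so the same contractions go through the quantized (int8) path when a quantizer is configured.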
2 changes: 1 addition & 1 deletion MaxText/layers/quantizations.py
@@ -143,7 +143,7 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
 
   def einsum(self, mesh_axes: Tuple[str, ...] = ()):
     """Returns einsum configured with aqt params."""
-    rhs_axis_metadata_wrapper=self._get_rhs_axis_metadata_wrapper(
+    rhs_axis_metadata_wrapper = self._get_rhs_axis_metadata_wrapper(
         mesh_axes)
     aqt_einsum = functools.partial(
         aqt_flax.AqtEinsum(
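The einsum method shown above is a small factory: it derives rhs-axis metadata from the logical mesh axes and returns a partially applied quantized einsum, so callers can invoke the result exactly like jnp.einsum. A generic sketch of that factory shape follows; quantize_int8 and its rounding are illustrative only and are not the AQT implementation.

    import functools
    import jax.numpy as jnp

    def quantize_int8(x):
      # Illustrative symmetric per-tensor int8 quantize/dequantize, not AQT.
      scale = jnp.max(jnp.abs(x)) / 127.0 + 1e-9
      return jnp.round(x / scale).astype(jnp.int8).astype(jnp.float32) * scale

    def _quantized_einsum(rhs_mesh_axes, eqn, lhs, rhs):
      del rhs_mesh_axes  # in the real code this becomes sharding metadata on rhs
      return jnp.einsum(eqn, lhs, quantize_int8(rhs))

    def einsum_factory(mesh_axes=()):
      # Mirrors the factory shape: bind the axis metadata now, einsum args later.
      return functools.partial(_quantized_einsum, mesh_axes)

    op = einsum_factory(('exp', None, None))
    out = op("BSM,EMH -> BSEH", jnp.ones((2, 3, 8)), jnp.ones((4, 8, 16)))
    assert out.shape == (2, 3, 4, 16)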
