Converters ready
AleHD committed Sep 3, 2024
1 parent 930fe81 commit 4759c55
Showing 3 changed files with 39 additions and 37 deletions.
17 changes: 14 additions & 3 deletions examples/xglm/convert_ntmoe2hf.py
@@ -29,7 +29,7 @@


def convert_config(config: GPT3MoEConfig) -> XGLMmoeConfig:
-    assert config.moe_num_experts > 1, f"Why are you using a 1-expert moe? lol"
+    #assert config.moe_num_experts > 1, f"Why are you using a 1-expert moe? lol"
    if config.embd_pdrop != config.resid_pdrop:
        warnings.warn(
            f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with "
@@ -80,14 +80,25 @@ def convert_gate(gate_hf: BasicGate, gate_nt: LearnedRouter):
def convert_ff(ff_hf: XGLMSparseMoeBlock, ff_nt: dMoE):
    convert_gate(ff_hf.gate, ff_nt.gate)
    int_size = ff_nt.config.intermediate_size
+    if len(ff_hf.experts) == 1:
+        assert ff_nt.experts.mlp.w1.module.weight.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size)
+        assert ff_nt.experts.mlp.w2.module.weight.shape == (ff_nt.config.hidden_size, int_size*len(ff_hf.experts))
+    else:
+        assert ff_nt.experts.mlp.w1.module.weight.T.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size)
+        assert ff_nt.experts.mlp.w2.module.weight.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size)

    for i, expert_hf in enumerate(ff_hf.experts):
        # TODO: fc1, fc2 has bias
        i0 = i*int_size
        i1 = (i + 1)*int_size
        with torch.no_grad():
-            expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight.T[i0:i1, :].clone())
+            if len(ff_hf.experts) == 1:
+                expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight[i0:i1, :].clone())
+                expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[:, i0:i1].clone())
+            else:
+                expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight.T[i0:i1, :].clone())
+                expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[i0:i1, :].T.clone())
            expert_hf.fc1.bias.data.zero_()
-            expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[i0:i1, :].T.clone())
            expert_hf.fc2.bias.data.zero_()

def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPT3MoEBlock):
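The two branches above encode different storage layouts for the fused expert weights. A minimal sketch of the slicing they imply, assuming only the shapes stated in the asserts (all sizes here are hypothetical):

import torch

hidden, inter, n_exp = 16, 32, 4            # hypothetical sizes

# Multi-expert layout: experts fused along one dimension,
# w1: (hidden, n_exp*inter), w2: (n_exp*inter, hidden).
w1 = torch.randn(hidden, n_exp * inter)
w2 = torch.randn(n_exp * inter, hidden)
i = 2                                       # expert index
fc1 = w1.T[i*inter:(i+1)*inter, :]          # (inter, hidden) -> HF fc1.weight
fc2 = w2[i*inter:(i+1)*inter, :].T          # (hidden, inter) -> HF fc2.weight

# Single-expert layout already matches nn.Linear's (out_features, in_features)
# convention, so the slices are taken without transposing first:
w1_single = torch.randn(inter, hidden)      # w1_single[0:inter, :] -> fc1.weight
w2_single = torch.randn(hidden, inter)      # w2_single[:, 0:inter] -> fc2.weight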
34 changes: 19 additions & 15 deletions examples/xglm/tests/test_moe.py
@@ -24,8 +24,8 @@
#TEST_SEQUENCE_LENGTH = MAX_SEQUENCE_LENGTH
BATCH_SIZE = 4
HIDDEN_SIZE = 1024
-DTYPE = torch.bfloat16
-#DTYPE = torch.float32
+#DTYPE = torch.bfloat16
+DTYPE = torch.float32
TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:"

CONFIG = GPT3MoEConfig(
@@ -47,13 +47,13 @@
    use_spda=DTYPE is not torch.bfloat16,
    # vvv moe vvv
    is_moe=True,
-    moe_num_experts=4,
-    num_experts_per_tok=4,
+    moe_num_experts=8,
+    num_experts_per_tok=2,
    moe_loss_weight=0.01,
    moe_z_loss_weight=0.0,
    moe_glu=False,
)
-#PARALLEL_CONFIG = ParallelismArgs(dp=1, pp=1, tp=1, expert_parallel_size=1) #CONFIG.moe_num_experts)
+PARALLEL_CONFIG = ParallelismArgs(dp=1, pp=1, tp=1, expert_parallel_size=1) #CONFIG.moe_num_experts)


@pytest.fixture
@@ -93,11 +93,15 @@ def test_nt2hf_gate(hidden_states: torch.Tensor):
    init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_gate)(hidden_states=hidden_states)


-def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tensor):
+def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tensor,
+                   num_experts: int, num_experts_per_tok: int):
    hidden_states = hidden_states.cuda()

-    config_hf = convert_config(CONFIG)
-    ff_nt = dMoE(CONFIG, parallel_context, None).cuda().to(DTYPE)
+    config = {**vars(CONFIG)}
+    config.update({"moe_num_experts": num_experts, "num_experts_per_tok": num_experts_per_tok})
+    config = GPT3MoEConfig(**config)
+    config_hf = convert_config(config)
+    ff_nt = dMoE(config, parallel_context, PARALLEL_CONFIG).cuda().to(DTYPE)
    ff_hf = XGLMSparseMoeBlock(config_hf).cuda().to(DTYPE)
    convert_ff(ff_hf, ff_nt)

@@ -106,9 +110,12 @@ def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tensor,
    out_hf = out_hf.permute(1, 0, 2)

    assert out_nt.size() == out_hf.size()
-    almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.02)
-    #torch.testing.assert_close(out_nt, out_hf)
+    almost_close(out_nt, out_hf, max_far=0.05, far_atol=0.003)


+@pytest.mark.parametrize("num_experts,num_experts_per_tok", [(1, 1), (2, 1), (4, 1), (4, 2), (8, 1), (8, 2), (8, 4)])
+def test_nt2hf_ff(hidden_states: torch.Tensor, num_experts: int, num_experts_per_tok: int):
+    init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_ff)(hidden_states=hidden_states, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok)


def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor):
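almost_close is the repo's own test helper and its definition is not part of this diff; judging from the call sites, it tolerates a bounded fraction of outliers rather than requiring element-wise closeness. A hedged sketch of that implied semantics:

import torch

def almost_close(a: torch.Tensor, b: torch.Tensor, max_far: float, far_atol: float):
    # Allow at most a `max_far` fraction of elements to differ by more than `far_atol`.
    far_frac = ((a - b).abs() > far_atol).float().mean().item()
    assert far_frac <= max_far, f"{far_frac:.4%} of elements differ by more than {far_atol}"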
@@ -166,13 +173,10 @@ def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor):
    return out_nt.cpu(), out_hf.cpu()


-def test_nt2hf_ff(hidden_states: torch.Tensor):
-    init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_ff)(hidden_states=hidden_states)


def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor):
    out_nt, out_hf = _test_nt2hf_model(parallel_context, input_ids, input_mask)
-    almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.02)
+    almost_close(out_nt, out_hf, max_far=0.01, far_atol=2.0) # We allow for less than 1% errors, but some of these are very large!
    #torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16())


def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor):
25 changes: 6 additions & 19 deletions src/nanotron/models/moe.py
@@ -162,8 +162,7 @@ def forward(self, hidden_states: torch.Tensor):
        router_logits, expert_weights, top_experts = self.gate(x)

        # Compute the experts.
-        #return self.experts(x, router_logits, expert_weights, top_experts)
-        x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts) #REMOVE
+        x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts)
        return {
            "hidden_states": x.reshape(batch_size, sequence_length, -1),
            "load_balancing_loss": lbl_loss,
@@ -301,15 +300,12 @@ def forward_once(self, x, expert_weights, top_experts): # TODO: sparse
        ) = self.indices_and_padded_bins(top_experts)

        # Route the tokens for MoE computation.
-        #x_pre = x.clone()
        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.num_experts_per_tok)
-        #print("forward_once a", x.shape)

        with torch.no_grad():
            topo = self.topology(x, padded_bins)

-        x = self.mlp(x, topo) #REMOVE
-        #return x_pre, self.mlp(x, topo)
+        x = self.mlp(x, topo)

        # Un-route the data for the MoE output.
        x = ops.padded_scatter(
@@ -426,11 +422,7 @@ def forward(self, x, router_logits, expert_weights, top_experts):
            top_experts: tensor of shape [sequence_length * batch_size, num_experts_per_tok]
        """
        # Compute the experts.
-        x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) #REMOVE
-        #return router_logits
-        #print("nano b", expert_weights)
-        #return expert_weights.bfloat16()
-        #return self.forward_fn(x, expert_weights.flatten(), top_experts.flatten())
+        x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten())
        if self.training:
            lbl_loss = load_balancing_loss(router_logits, tokens_per_expert, self.config)
            z_loss = router_z_loss(router_logits, self.config)
@@ -603,14 +595,9 @@ def __init__(

    def forward(self, x, topo):
        self.w1.scale_gradients(), self.w2.scale_gradients()
-        x = self.sdd(x.contiguous(), self.w1.module.weight, topo) # REMOVE
-        #x1 = self.sdd(x.contiguous(), self.w1.module.weight, topo)
-        activation_fn_out = act_fn(x, self.act) # REMOVE
-        #print(x.shape, activation_fn_out.shape, self.w2.module.weight.shape)
-        #activation_fn_out = act_fn(x1, self.act)
-        return self.dsd(activation_fn_out, self.w2.module.weight) #REMOVE
-        #x2 = self.dsd(activation_fn_out, self.w2.module.weight)
-        #return x, x1, x2, topo, self.w1.module.weight, self.w2.module.weight
+        x = self.sdd(x.contiguous(), self.w1.module.weight, topo)
+        activation_fn_out = act_fn(x, self.act)
+        return self.dsd(activation_fn_out, self.w2.module.weight)


class MLP(nn.Module):
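sdd and dsd follow block-sparse matmul naming (output-sparse and output-dense products, as in the megablocks/stk kernels this module appears to build on). Ignoring the topology bookkeeping, the cleaned-up forward above reduces to a plain two-matmul MLP; an illustrative dense equivalent, not the actual kernels:

import torch
import torch.nn.functional as F

def dense_equivalent(x, w1, w2, act=F.gelu):
    h = x @ w1      # plays the role of self.sdd(x, w1, topo): dense x dense -> sparse
    h = act(h)      # act_fn(x, self.act), applied element-wise
    return h @ w2   # plays the role of self.dsd(h, w2): sparse x dense -> dense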
