From 4759c551862de1b7a30ee7915baa5be5acb50a69 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 3 Sep 2024 10:03:50 +0000 Subject: [PATCH] Converters ready --- examples/xglm/convert_ntmoe2hf.py | 17 +++++++++++++--- examples/xglm/tests/test_moe.py | 34 +++++++++++++++++-------------- src/nanotron/models/moe.py | 25 ++++++----------------- 3 files changed, 39 insertions(+), 37 deletions(-) diff --git a/examples/xglm/convert_ntmoe2hf.py b/examples/xglm/convert_ntmoe2hf.py index d8801caa..7acbb264 100644 --- a/examples/xglm/convert_ntmoe2hf.py +++ b/examples/xglm/convert_ntmoe2hf.py @@ -29,7 +29,7 @@ def convert_config(config: GPT3MoEConfig) -> XGLMmoeConfig: - assert config.moe_num_experts > 1, f"Why are you using a 1-expert moe? lol" + #assert config.moe_num_experts > 1, f"Why are you using a 1-expert moe? lol" if config.embd_pdrop != config.resid_pdrop: warnings.warn( f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " @@ -80,14 +80,25 @@ def convert_gate(gate_hf: BasicGate, gate_nt: LearnedRouter): def convert_ff(ff_hf: XGLMSparseMoeBlock, ff_nt: dMoE): convert_gate(ff_hf.gate, ff_nt.gate) int_size = ff_nt.config.intermediate_size + if len(ff_hf.experts) == 1: + assert ff_nt.experts.mlp.w1.module.weight.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + assert ff_nt.experts.mlp.w2.module.weight.shape == (ff_nt.config.hidden_size, int_size*len(ff_hf.experts)) + else: + assert ff_nt.experts.mlp.w1.module.weight.T.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + assert ff_nt.experts.mlp.w2.module.weight.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + for i, expert_hf in enumerate(ff_hf.experts): # TODO: fc1, fc2 has bias i0 = i*int_size i1 = (i + 1)*int_size with torch.no_grad(): - expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight.T[i0:i1, :].clone()) + if len(ff_hf.experts) == 1: + expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight[i0:i1, :].clone()) + expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[:, i0:i1].clone()) + else: + expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight.T[i0:i1, :].clone()) + expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[i0:i1, :].T.clone()) expert_hf.fc1.bias.data.zero_() - expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[i0:i1, :].T.clone()) expert_hf.fc2.bias.data.zero_() def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPT3MoEBlock): diff --git a/examples/xglm/tests/test_moe.py b/examples/xglm/tests/test_moe.py index 7be847c7..7dda7ca4 100644 --- a/examples/xglm/tests/test_moe.py +++ b/examples/xglm/tests/test_moe.py @@ -24,8 +24,8 @@ #TEST_SEQUENCE_LENGTH = MAX_SEQUENCE_LENGTH BATCH_SIZE = 4 HIDDEN_SIZE = 1024 -DTYPE = torch.bfloat16 -#DTYPE = torch.float32 +#DTYPE = torch.bfloat16 +DTYPE = torch.float32 TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. 
Final note:" CONFIG = GPT3MoEConfig( @@ -47,13 +47,13 @@ use_spda=DTYPE is not torch.bfloat16, # vvv moe vvv is_moe=True, - moe_num_experts=4, - num_experts_per_tok=4, + moe_num_experts=8, + num_experts_per_tok=2, moe_loss_weight=0.01, moe_z_loss_weight=0.0, moe_glu=False, ) -#PARALLEL_CONFIG = ParallelismArgs(dp=1, pp=1, tp=1, expert_parallel_size=1) #CONFIG.moe_num_experts) +PARALLEL_CONFIG = ParallelismArgs(dp=1, pp=1, tp=1, expert_parallel_size=1) #CONFIG.moe_num_experts) @pytest.fixture @@ -93,11 +93,15 @@ def test_nt2hf_gate(hidden_states: torch.Tensor): init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_gate)(hidden_states=hidden_states) -def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tensor): +def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tensor, + num_experts: int, num_experts_per_tok: int): hidden_states = hidden_states.cuda() - config_hf = convert_config(CONFIG) - ff_nt = dMoE(CONFIG, parallel_context, None).cuda().to(DTYPE) + config = {**vars(CONFIG)} + config.update({"moe_num_experts": num_experts, "num_experts_per_tok": num_experts_per_tok}) + config = GPT3MoEConfig(**config) + config_hf = convert_config(config) + ff_nt = dMoE(config, parallel_context, PARALLEL_CONFIG).cuda().to(DTYPE) ff_hf = XGLMSparseMoeBlock(config_hf).cuda().to(DTYPE) convert_ff(ff_hf, ff_nt) @@ -106,9 +110,12 @@ def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tenso out_hf = out_hf.permute(1, 0, 2) assert out_nt.size() == out_hf.size() - almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.02) - #torch.testing.assert_close(out_nt, out_hf) + almost_close(out_nt, out_hf, max_far=0.05, far_atol=0.003) + +@pytest.mark.parametrize("num_experts,num_experts_per_tok", [(1, 1), (2, 1), (4, 1), (4, 2), (8, 1), (8, 2), (8, 4)]) +def test_nt2hf_ff(hidden_states: torch.Tensor, num_experts: int, num_experts_per_tok: int): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_ff)(hidden_states=hidden_states, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok) def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): @@ -166,13 +173,10 @@ def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor return out_nt.cpu(), out_hf.cpu() -def test_nt2hf_ff(hidden_states: torch.Tensor): - init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_ff)(hidden_states=hidden_states) - - def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): out_nt, out_hf = _test_nt2hf_model(parallel_context, input_ids, input_mask) - almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.02) + almost_close(out_nt, out_hf, max_far=0.01, far_atol=2.0) # We allow for less than 1% errors, but some of these are very large! + #torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index 01986e3f..98add57d 100644 --- a/src/nanotron/models/moe.py +++ b/src/nanotron/models/moe.py @@ -162,8 +162,7 @@ def forward(self, hidden_states: torch.Tensor): router_logits, expert_weights, top_experts = self.gate(x) # Compute the experts. 
- #return self.experts(x, router_logits, expert_weights, top_experts) - x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts) #REMOVE + x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts) return { "hidden_states": x.reshape(batch_size, sequence_length, -1), "load_balancing_loss": lbl_loss, @@ -301,15 +300,12 @@ def forward_once(self, x, expert_weights, top_experts): # TODO: sparse ) = self.indices_and_padded_bins(top_experts) # Route the tokens for MoE computation. - #x_pre = x.clone() x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.num_experts_per_tok) - #print("forward_once a", x.shape) with torch.no_grad(): topo = self.topology(x, padded_bins) - x = self.mlp(x, topo) #REMOVE - #return x_pre, self.mlp(x, topo) + x = self.mlp(x, topo) # Un-route the data for the MoE output. x = ops.padded_scatter( @@ -426,11 +422,7 @@ def forward(self, x, router_logits, expert_weights, top_experts): top_experts: tensor of shape [sequence_length * batch_size, num_experts_per_tok] """ # Compute the experts. - x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) #REMOVE - #return router_logits - #print("nano b", expert_weights) - #return expert_weights.bfloat16() - #return self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) + x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) if self.training: lbl_loss = load_balancing_loss(router_logits, tokens_per_expert, self.config) z_loss = router_z_loss(router_logits, self.config) @@ -603,14 +595,9 @@ def __init__( def forward(self, x, topo): self.w1.scale_gradients(), self.w2.scale_gradients() - x = self.sdd(x.contiguous(), self.w1.module.weight, topo) # REMOVE - #x1 = self.sdd(x.contiguous(), self.w1.module.weight, topo) - activation_fn_out = act_fn(x, self.act) # REMOVE - #print(x.shape, activation_fn_out.shape, self.w2.module.weight.shape) - #activation_fn_out = act_fn(x1, self.act) - return self.dsd(activation_fn_out, self.w2.module.weight) #REMOVE - #x2 = self.dsd(activation_fn_out, self.w2.module.weight) - #return x, x1, x2, topo, self.w1.module.weight, self.w2.module.weight + x = self.sdd(x.contiguous(), self.w1.module.weight, topo) + activation_fn_out = act_fn(x, self.act) + return self.dsd(activation_fn_out, self.w2.module.weight) class MLP(nn.Module):
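
Note (not part of the patch): the convert_ff changes above hinge on how the fused expert
weights are sliced into per-expert HF linears. Below is a minimal, self-contained sketch of
that slicing with toy tensors; the names w1_fused/w2_fused and the module layout are
stand-ins rather than the actual nanotron dMoE or XGLM modules, and it assumes the
(num_experts * intermediate_size, hidden_size) layout, whereas the patch additionally
handles the transposed variants distinguished by the shape asserts.

import torch
from torch import nn

num_experts, hidden_size, intermediate_size = 4, 8, 16

# Stand-ins for ff_nt.experts.mlp.w1 / w2 (hypothetical tensors with the fused layout).
w1_fused = torch.randn(num_experts * intermediate_size, hidden_size)
w2_fused = torch.randn(num_experts * intermediate_size, hidden_size)

# Stand-ins for the per-expert HF feed-forward layers (fc1/fc2 with bias).
experts = [
    nn.ModuleDict({
        "fc1": nn.Linear(hidden_size, intermediate_size, bias=True),
        "fc2": nn.Linear(intermediate_size, hidden_size, bias=True),
    })
    for _ in range(num_experts)
]

with torch.no_grad():
    for i, expert in enumerate(experts):
        i0, i1 = i * intermediate_size, (i + 1) * intermediate_size
        # fc1.weight is (intermediate_size, hidden_size): expert i's row block copies directly.
        expert["fc1"].weight.copy_(w1_fused[i0:i1, :])
        # fc2.weight is (hidden_size, intermediate_size): the row block must be transposed.
        expert["fc2"].weight.copy_(w2_fused[i0:i1, :].T)
        # The fused nanotron experts carry no bias, so the HF biases are zeroed, as in the patch.
        expert["fc1"].bias.zero_()
        expert["fc2"].bias.zero_()

The parametrized test_nt2hf_ff in the patch exercises this same copy path for several
(num_experts, num_experts_per_tok) combinations, including the single-expert case.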