Commit

Fix parallel_embedding (#10975) (#10996)
Co-authored-by: meatybobby <[email protected]>
ko3n1g and meatybobby authored Oct 22, 2024
1 parent 8fc188b commit 931cfbf
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions nemo/export/trt_llm/converter/model_converter.py
@@ -260,9 +260,7 @@ def model_to_trtllm_ckpt(
 
     if mapping.is_first_pp_rank():
         embedding_weight = (
-            np.ascontiguousarray(
-                split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank)
-            )
+            split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank)
             if use_parallel_embedding
             else weights_dict["transformer.vocab_embedding.weight"]
         )
@@ -272,9 +270,7 @@ def model_to_trtllm_ckpt(
         pos_embedding_weight = weights_dict.get("transformer.position_embedding.weight")
         if pos_embedding_weight is not None:
             if use_parallel_embedding:
-                pos_embedding_weight = np.ascontiguousarray(
-                    split(pos_embedding_weight, mapping.tp_size, mapping.tp_rank)
-                )
+                pos_embedding_weight = split(pos_embedding_weight, mapping.tp_size, mapping.tp_rank)
             weights_dict_local["transformer.position_embedding.weight"] = pos_embedding_weight
 
     if mapping.is_last_pp_rank():
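For context on what the retained split(...) call does here: when use_parallel_embedding is enabled, the vocab and position embedding tables are sharded across tensor-parallel ranks, and each rank keeps only its slice of the weight. The sketch below is a minimal, hypothetical illustration of that kind of sharding; the helper name split_embedding, the split axis, and the even-divisibility assumption are illustrative assumptions, not NeMo's actual converter code.

import numpy as np

def split_embedding(weight: np.ndarray, tp_size: int, tp_rank: int, axis: int = 0) -> np.ndarray:
    # Hypothetical illustration: return this rank's shard of an embedding table,
    # assuming the sharded dimension divides evenly by tp_size. The converter's
    # real `split` helper may differ in signature and behavior.
    return np.split(weight, tp_size, axis=axis)[tp_rank]

# Usage sketch: a (vocab_size=8, hidden_size=4) table split across 2 tensor-parallel ranks.
vocab_embedding = np.arange(32, dtype=np.float32).reshape(8, 4)
shard_rank0 = split_embedding(vocab_embedding, tp_size=2, tp_rank=0)  # rows 0-3
shard_rank1 = split_embedding(vocab_embedding, tp_size=2, tp_rank=1)  # rows 4-7
assert shard_rank0.shape == (4, 4) and shard_rank1.shape == (4, 4)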
