From 09e9bbcc92b6039365644cba05c0c2d8956c32c7 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Thu, 28 Nov 2024 01:03:58 -0800 Subject: [PATCH] fix dtype when init HF model from config (#11420) * fix dtype when init HF model from config Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa --------- Signed-off-by: Alexandros Koumparoulis Signed-off-by: akoumpa Co-authored-by: akoumpa --- .../llm/gpt/model/hf_auto_model_for_causal_lm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py index 8f4595bd6cee..481dd9a0e187 100644 --- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py +++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py @@ -41,6 +41,7 @@ def __init__( model_transform=None, model_accelerator=None, trust_remote_code=False, + default_dtype=torch.bfloat16, ): super().__init__() self.save_hyperparameters() @@ -53,6 +54,7 @@ def __init__( self.model_transform = model_transform self.model_accelerator = model_accelerator self.trust_remote_code = trust_remote_code + self.default_dtype = default_dtype @property def tokenizer(self): @@ -79,7 +81,10 @@ def configure_model(self): from transformers import AutoConfig config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code) - self.model = AutoModelForCausalLM.from_config(config, trust_remote_code=self.trust_remote_code) + dtype = getattr(config, 'torch_dtype', self.default_dtype) + self.model = AutoModelForCausalLM.from_config( + config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code + ) if self.model_accelerator is not None: self.model_accelerator(self.model)