diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
index 8f4595bd6cee..481dd9a0e187 100644
--- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
@@ -41,6 +41,7 @@ def __init__(
         model_transform=None,
         model_accelerator=None,
         trust_remote_code=False,
+        default_dtype=torch.bfloat16,
     ):
         super().__init__()
         self.save_hyperparameters()
@@ -53,6 +54,7 @@ def __init__(
         self.model_transform = model_transform
         self.model_accelerator = model_accelerator
         self.trust_remote_code = trust_remote_code
+        self.default_dtype = default_dtype
 
     @property
     def tokenizer(self):
@@ -79,7 +81,10 @@ def configure_model(self):
         from transformers import AutoConfig
 
         config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code)
-        self.model = AutoModelForCausalLM.from_config(config, trust_remote_code=self.trust_remote_code)
+        dtype = getattr(config, 'torch_dtype', self.default_dtype)
+        self.model = AutoModelForCausalLM.from_config(
+            config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code
+        )
         if self.model_accelerator is not None:
             self.model_accelerator(self.model)
 