diff --git a/src/autotrain/trainers/clm/__main__.py b/src/autotrain/trainers/clm/__main__.py
index 1b59f219e1..177142e690 100644
--- a/src/autotrain/trainers/clm/__main__.py
+++ b/src/autotrain/trainers/clm/__main__.py
@@ -114,22 +114,23 @@ def train(config):
                 bnb_4bit_compute_dtype=torch.float16,
                 bnb_4bit_use_double_quant=False,
             )
+            config.fp16 = True
         elif config.use_int8:
             bnb_config = BitsAndBytesConfig(load_in_8bit=config.use_int8)
+            config.fp16 = True
         else:
-            bnb_config = BitsAndBytesConfig()
+            bnb_config = None
 
         model = AutoModelForCausalLM.from_pretrained(
             config.model,
             config=model_config,
             token=config.token,
             quantization_config=bnb_config,
-            torch_dtype=torch.float16,
+            torch_dtype=torch.float16 if config.fp16 else torch.float32,
             device_map={"": Accelerator().process_index} if torch.cuda.is_available() else None,
             trust_remote_code=True,
             use_flash_attention_2=config.use_flash_attention_2,
         )
-
     else:
         model = AutoModelForCausalLM.from_pretrained(
             config.model,