diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py index a26a569..9625065 100644 --- a/llmlingua/prompt_compressor.py +++ b/llmlingua/prompt_compressor.py @@ -138,7 +138,7 @@ def load_model( if "cuda" in device_map or "cpu" in device_map: model = MODEL_CLASS.from_pretrained( model_name, - torch_dtype=model_config.get( + torch_dtype=model_config.pop( "torch_dtype", "auto" if device_map == "cuda" else torch.float32 ), device_map=device_map, @@ -150,7 +150,7 @@ def load_model( model = MODEL_CLASS.from_pretrained( model_name, device_map=device_map, - torch_dtype=model_config.get("torch_dtype", "auto"), + torch_dtype=model_config.pop("torch_dtype", "auto"), pad_token_id=tokenizer.pad_token_id, **model_config, )