Prevent duplicate torch_dtype kwargs (#115)
Currently, specifying a `torch_dtype` in the `model_config` throws an error:

```pycon
TypeError: transformers.models.auto.auto_factory._BaseAutoModelClass.from_pretrained() got multiple values for keyword argument 'torch_dtype'
```
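The root cause: `load_model` reads the value with `model_config.get(...)` but leaves the key in the dict, so `**model_config` supplies `torch_dtype` a second time. A minimal, self-contained sketch of the mechanism (the `from_pretrained` below is a stand-in for illustration, not the transformers API):

```python
def from_pretrained(model_name, torch_dtype="auto", **kwargs):
    """Stand-in for transformers' from_pretrained, for illustration only."""
    return model_name, torch_dtype, kwargs


model_config = {"torch_dtype": "float16"}

# Buggy: .get() reads the value but leaves the key in the dict, so
# **model_config passes torch_dtype a second time.
try:
    from_pretrained(
        "gpt2",
        torch_dtype=model_config.get("torch_dtype", "auto"),
        **model_config,
    )
except TypeError as err:
    print(err)  # ... got multiple values for keyword argument 'torch_dtype'

# Fixed: .pop() removes the key, so torch_dtype is passed exactly once.
model_config = {"torch_dtype": "float16"}
from_pretrained(
    "gpt2",
    torch_dtype=model_config.pop("torch_dtype", "auto"),
    **model_config,
)
```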
yasyf authored Mar 24, 2024
1 parent 5828ac3 commit 7a74721
Showing 1 changed file with 2 additions and 2 deletions.
llmlingua/prompt_compressor.py:

```diff
@@ -138,7 +138,7 @@ def load_model(
         if "cuda" in device_map or "cpu" in device_map:
             model = MODEL_CLASS.from_pretrained(
                 model_name,
-                torch_dtype=model_config.get(
+                torch_dtype=model_config.pop(
                     "torch_dtype", "auto" if device_map == "cuda" else torch.float32
                 ),
                 device_map=device_map,
@@ -150,7 +150,7 @@ def load_model(
             model = MODEL_CLASS.from_pretrained(
                 model_name,
                 device_map=device_map,
-                torch_dtype=model_config.get("torch_dtype", "auto"),
+                torch_dtype=model_config.pop("torch_dtype", "auto"),
                 pad_token_id=tokenizer.pad_token_id,
                 **model_config,
             )
```
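Switching to `model_config.pop(...)` consumes the key before `**model_config` is unpacked, so the keyword reaches `from_pretrained` exactly once. Assuming LLMLingua's `PromptCompressor` constructor forwards `model_config` to `load_model` on this code path (the model name below is illustrative), a caller can now override the dtype like this:

```python
import torch
from llmlingua import PromptCompressor

# torch_dtype now flows through model_config without colliding with the
# explicit torch_dtype= keyword inside load_model.
compressor = PromptCompressor(
    model_name="NousResearch/Llama-2-7b-hf",
    model_config={"torch_dtype": torch.bfloat16},
)
```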
