Skip to content

Commit

Permalink
manage default (greedy) gen_kwargs in vllm (#1341)
Browse files Browse the repository at this point in the history
* manage default (greedy) gen_kwargs in vllm better

* mirror HF `do_sample`

* just need to set temp=0 for greedy
  • Loading branch information
baberabb authored Jan 23, 2024
1 parent 969b48b commit 081deb8
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions lm_eval/models/vllm_causallms.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,8 @@ def _model_generate(
stop: Optional[List[str]] = None,
**kwargs,
):
if "do_sample" in kwargs.keys():
kwargs.pop("do_sample")
if generate:
# hf defaults
kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
kwargs["spaces_between_special_tokens"] = kwargs.get(
"spaces_between_special_tokens", False
)
kwargs = self.modify_gen_kwargs(kwargs)
sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
else:
sampling_params = SamplingParams(
Expand Down Expand Up @@ -438,3 +432,16 @@ def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
break

return continuation_logprobs, is_greedy

@staticmethod
def modify_gen_kwargs(kwargs: dict) -> dict:
# sampling_params
do_sample = kwargs.pop("do_sample", False)
if do_sample is not True:
kwargs["temperature"] = 0.0
# hf defaults
kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
kwargs["spaces_between_special_tokens"] = kwargs.get(
"spaces_between_special_tokens", False
)
return kwargs

0 comments on commit 081deb8

Please sign in to comment.