diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index 6912428e..5c208d90 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -170,14 +170,8 @@ def _model_generate(
         stop: Optional[List[str]] = None,
         **kwargs,
     ):
-        if "do_sample" in kwargs.keys():
-            kwargs.pop("do_sample")
         if generate:
-            # hf defaults
-            kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
-            kwargs["spaces_between_special_tokens"] = kwargs.get(
-                "spaces_between_special_tokens", False
-            )
+            kwargs = self.modify_gen_kwargs(kwargs)
             sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
         else:
             sampling_params = SamplingParams(
@@ -438,3 +432,16 @@ def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
                     break
 
         return continuation_logprobs, is_greedy
+
+    @staticmethod
+    def modify_gen_kwargs(kwargs: dict) -> dict:
+        # sampling_params
+        do_sample = kwargs.pop("do_sample", False)
+        if do_sample is not True:
+            kwargs["temperature"] = 0.0
+        # hf defaults
+        kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
+        kwargs["spaces_between_special_tokens"] = kwargs.get(
+            "spaces_between_special_tokens", False
+        )
+        return kwargs
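
For reference, a minimal standalone sketch of the extracted helper, showing how it normalizes HF-style generation kwargs before they are passed to vLLM's SamplingParams. The function body is copied verbatim from the hunk above; the example calls and their expected outputs are illustrative assumptions, not part of the PR.

    def modify_gen_kwargs(kwargs: dict) -> dict:
        # sampling_params
        do_sample = kwargs.pop("do_sample", False)
        if do_sample is not True:
            kwargs["temperature"] = 0.0
        # hf defaults
        kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
        kwargs["spaces_between_special_tokens"] = kwargs.get(
            "spaces_between_special_tokens", False
        )
        return kwargs

    # No do_sample given: greedy decoding, so temperature is forced to 0.0.
    print(modify_gen_kwargs({}))
    # {'temperature': 0.0, 'skip_special_tokens': False, 'spaces_between_special_tokens': False}

    # do_sample=True: the caller's temperature is left untouched.
    print(modify_gen_kwargs({"do_sample": True, "temperature": 0.8}))
    # {'temperature': 0.8, 'skip_special_tokens': False, 'spaces_between_special_tokens': False}

    # do_sample=False overrides any temperature the caller passed; before this
    # change the flag was silently dropped and temperature=0.7 would have sampled.
    print(modify_gen_kwargs({"do_sample": False, "temperature": 0.7}))
    # {'temperature': 0.0, 'skip_special_tokens': False, 'spaces_between_special_tokens': False}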