Skip to content

Commit

Permalink
Fix loglikelihood in llava_hf
Browse files Browse the repository at this point in the history
  • Loading branch information
brian.li committed Nov 2, 2024
1 parent fc9dfdf commit 5720bd1
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lmms_eval/models/llava_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,10 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:

formatted_contexts = [prompt]
formatted_continuation = [prompt_and_continuation]
model_inputs = self._image_processor(text=formatted_continuation, images=visuals).to(self._device, self.model.dtype)
model_inputs = self._image_processor(text=formatted_continuation, images=visuals, return_tensors="pt").to(self._device, self.model.dtype)
labels = model_inputs["input_ids"].clone()
contxt_id = self._image_processor(text=formatted_contexts, return_tensors="pt")["input_ids"]
labels[: len(contxt_id)] = -100
labels[:, : contxt_id.shape[1]] = -100

if self.accelerator.is_main_process and doc_id % 100 == 0:
eval_logger.debug(f"Prompt for doc ID {doc_id}:\n\n{formatted_contexts[0]}\n")
Expand Down

0 comments on commit 5720bd1

Please sign in to comment.