diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py
index 676d89f1..7c8e2028 100644
--- a/minigpt4/conversation/conversation.py
+++ b/minigpt4/conversation/conversation.py
@@ -85,123 +85,4 @@ def dict(self):
             # "system_img": self.system_img,
             "roles": self.roles,
             "messages": self.messages,
-            "offset": self.offset,
-            "sep": self.sep,
-            "sep2": self.sep2,
-            "conv_id": self.conv_id,
-        }
-
-
-class StoppingCriteriaSub(StoppingCriteria):
-
-    def __init__(self, stops=[], encounters=1):
-        super().__init__()
-        self.stops = stops
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
-        for stop in self.stops:
-            if torch.all((stop == input_ids[0][-len(stop):])).item():
-                return True
-
-        return False
-
-
-CONV_VISION = Conversation(
-    system="Give the following image: <Img>ImageContent</Img>. "
-           "You will be able to see the image once I provide it to you. Please answer my questions.",
-    roles=("Human", "Assistant"),
-    messages=[],
-    offset=2,
-    sep_style=SeparatorStyle.SINGLE,
-    sep="###",
-)
-
-
-
-class Chat:
-    def __init__(self, model, vis_processor, device='cuda:0'):
-        self.device = device
-        self.model = model
-        self.vis_processor = vis_processor
-        stop_words_ids = [torch.tensor([835]).to(self.device),
-                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
-        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
-
-    def ask(self, text, conv):
-        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
-                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
-            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
-        else:
-            conv.append_message(conv.roles[0], text)
-
-    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
-               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
-        conv.append_message(conv.roles[1], None)
-        embs = self.get_context_emb(conv, img_list)
-
-        current_max_len = embs.shape[1] + max_new_tokens
-        if current_max_len - max_length > 0:
-            print('Warning: The number of tokens in current conversation exceeds the max length. '
-                  'The model will not see the contexts outside the range.')
-        begin_idx = max(0, current_max_len - max_length)
-
-        embs = embs[:, begin_idx:]
-
-        outputs = self.model.llama_model.generate(
-            inputs_embeds=embs,
-            max_new_tokens=max_new_tokens,
-            stopping_criteria=self.stopping_criteria,
-            num_beams=num_beams,
-            do_sample=True,
-            min_length=min_length,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            length_penalty=length_penalty,
-            temperature=temperature,
-        )
-        output_token = outputs[0]
-        if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning. remove it
-            output_token = output_token[1:]
-        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
-            output_token = output_token[1:]
-        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
-        output_text = output_text.split('###')[0]  # remove the stop sign '###'
-        output_text = output_text.split('Assistant:')[-1].strip()
-        conv.messages[-1][1] = output_text
-        return output_text, output_token.cpu().numpy()
-
-    def upload_img(self, image, conv, img_list):
-        if isinstance(image, str):  # is an image path
-            raw_image = Image.open(image).convert('RGB')
-            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
-        elif isinstance(image, Image.Image):
-            raw_image = image
-            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
-        elif isinstance(image, torch.Tensor):
-            if len(image.shape) == 3:
-                image = image.unsqueeze(0)
-            image = image.to(self.device)
-
-        image_emb, _ = self.model.encode_img(image)
-        img_list.append(image_emb)
-        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
-        msg = "Received."
-        # self.conv.append_message(self.conv.roles[1], msg)
-        return msg
-
-    def get_context_emb(self, conv, img_list):
-        prompt = conv.get_prompt()
-        prompt_segs = prompt.split('<ImageHere>')
-        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
-        seg_tokens = [
-            self.model.llama_tokenizer(
-                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
-            # only add bos to the first seg
-            for i, seg in enumerate(prompt_segs)
-        ]
-        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
-        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
-        mixed_embs = torch.cat(mixed_embs, dim=1)
-        return mixed_embs
-
-
+            "offset": self.offset