diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py
index 676d89f1..7c8e2028 100644
--- a/minigpt4/conversation/conversation.py
+++ b/minigpt4/conversation/conversation.py
@@ -85,123 +85,4 @@ def dict(self):
# "system_img": self.system_img,
"roles": self.roles,
"messages": self.messages,
- "offset": self.offset,
- "sep": self.sep,
- "sep2": self.sep2,
- "conv_id": self.conv_id,
- }
-
-
-class StoppingCriteriaSub(StoppingCriteria):
-
- def __init__(self, stops=[], encounters=1):
- super().__init__()
- self.stops = stops
-
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
- for stop in self.stops:
- if torch.all((stop == input_ids[0][-len(stop):])).item():
- return True
-
- return False
-
-
-CONV_VISION = Conversation(
-    system="Give the following image: <Img>ImageContent</Img>. "
- "You will be able to see the image once I provide it to you. Please answer my questions.",
- roles=("Human", "Assistant"),
- messages=[],
- offset=2,
- sep_style=SeparatorStyle.SINGLE,
- sep="###",
-)
-
-
-
-class Chat:
- def __init__(self, model, vis_processor, device='cuda:0'):
- self.device = device
- self.model = model
- self.vis_processor = vis_processor
- stop_words_ids = [torch.tensor([835]).to(self.device),
- torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
- self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
-
- def ask(self, text, conv):
- if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
-                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
- conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
- else:
- conv.append_message(conv.roles[0], text)
-
- def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
- repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
- conv.append_message(conv.roles[1], None)
- embs = self.get_context_emb(conv, img_list)
-
- current_max_len = embs.shape[1] + max_new_tokens
- if current_max_len - max_length > 0:
- print('Warning: The number of tokens in current conversation exceeds the max length. '
- 'The model will not see the contexts outside the range.')
- begin_idx = max(0, current_max_len - max_length)
-
- embs = embs[:, begin_idx:]
-
- outputs = self.model.llama_model.generate(
- inputs_embeds=embs,
- max_new_tokens=max_new_tokens,
- stopping_criteria=self.stopping_criteria,
- num_beams=num_beams,
- do_sample=True,
- min_length=min_length,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- length_penalty=length_penalty,
- temperature=temperature,
- )
- output_token = outputs[0]
-        if output_token[0] == 0:  # the model might output an unknown token at the beginning. remove it
- output_token = output_token[1:]
- if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it
- output_token = output_token[1:]
- output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
- output_text = output_text.split('###')[0] # remove the stop sign '###'
- output_text = output_text.split('Assistant:')[-1].strip()
- conv.messages[-1][1] = output_text
- return output_text, output_token.cpu().numpy()
-
- def upload_img(self, image, conv, img_list):
-        if isinstance(image, str):  # is an image path
- raw_image = Image.open(image).convert('RGB')
- image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
- elif isinstance(image, Image.Image):
- raw_image = image
- image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
- elif isinstance(image, torch.Tensor):
- if len(image.shape) == 3:
- image = image.unsqueeze(0)
- image = image.to(self.device)
-
- image_emb, _ = self.model.encode_img(image)
- img_list.append(image_emb)
-        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
- msg = "Received."
- # self.conv.append_message(self.conv.roles[1], msg)
- return msg
-
- def get_context_emb(self, conv, img_list):
- prompt = conv.get_prompt()
-        prompt_segs = prompt.split('<ImageHere>')
- assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
- seg_tokens = [
- self.model.llama_tokenizer(
- seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
- # only add bos to the first seg
- for i, seg in enumerate(prompt_segs)
- ]
- seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
- mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
- mixed_embs = torch.cat(mixed_embs, dim=1)
- return mixed_embs
-
-
+ "offset": self.offset