forked from zilliztech/GPTCache
-
Notifications
You must be signed in to change notification settings - Fork 0
/
vqa_demo.py
92 lines (74 loc) · 3.45 KB
/
vqa_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# ================================================================================
# This demo comes from [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4)
# and is integrated with [GPTCache](https://github.com/zilliztech/GPTCache)
# for image question answering.
# Please make sure you have successfully set up MiniGPT-4.
# Run `python vqa_demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0`.
# =================================================================================
import argparse
import gradio as gr
from gptcache import cache
from gptcache.processor.pre import get_image, get_image_question
from gptcache.embedding import Timm
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
from gptcache.manager.factory import manager_factory
from gptcache.adapter.minigpt4 import MiniGPT4
def parse_args(argv=None):
    """Parse the command-line options of the VQA demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so the
            zero-argument call ``parse_args()`` keeps working as before.

    Returns:
        argparse.Namespace: parsed options with the attributes ``cfg_path``,
        ``gpu_id``, ``dir``, ``map`` and ``options``.

    Raises:
        SystemExit: raised by argparse when the arguments are invalid, e.g.
            when the required ``--cfg-path`` is missing.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--dir", type=str, default=".", help="path for data storage.")

    # `--map` / `--no-map` write to the same destination; the exact-match
    # "map" cache is the default when neither flag is given.
    parser.add_argument("--map", action='store_true', help="use map for exact match cache.")
    parser.add_argument('--no-map', dest='map', action='store_false', help="use sqlite and faiss for similar search cache.")
    parser.set_defaults(map=True)

    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser.parse_args(argv)
# Read the command line once; everything below is configured from `args`.
args = parse_args()

print("Initializing GPTCache")
if not args.map:
    # `--no-map`: similar-search cache backed by sqlite + faiss. The vector
    # store is sized from the Timm embedding model's dimension.
    timm = Timm()
    data_manager = manager_factory(
        "sqlite,faiss",
        args.dir,
        vector_params={"dimension": timm.dimension},
    )
    cache.init(
        pre_embedding_func=get_image,
        data_manager=data_manager,
        embedding_func=timm.to_embeddings,
        similarity_evaluation=SearchDistanceEvaluation(),
    )
else:
    # Default: exact-match cache using the "map" data manager.
    data_manager = manager_factory("map", args.dir)
    cache.init(pre_embedding_func=get_image_question, data_manager=data_manager)
print("GPTCache Initialization Finished")

print("Initializing Chat")
# `return_hit=True` is passed so the UI can show a second "is hit" output
# (presumably a cache-hit flag returned next to the answer -- confirm
# against the MiniGPT4 adapter).
pipeline = MiniGPT4.from_pretrained(
    cfg_path=args.cfg_path,
    gpu_id=args.gpu_id,
    options=args.options,
    return_hit=True,
)
print(" Chat Initialization Finished")
# ========================================
#             Gradio UI
# ========================================

# Static HTML snippets rendered, in this order, at the top of the page.
title = """<h1 align="center">Demo of MiniGPT-4 and GPTCache</h1>"""
description = """<h3>This is the demo of MiniGPT-4 and GPTCache. Upload your images and ask question, and it will be cached.</h3>"""
article = """<p><a href="https://github.com/zilliztech/GPTCache"><img src="https://img.shields.io/badge/Github-Code-blue"></a></p>"""

with gr.Blocks() as demo:
    for header_html in (title, description, article):
        gr.Markdown(header_html)

    with gr.Row():
        # First column: the two inputs forwarded to the pipeline.
        with gr.Column():
            image_input = gr.Image(source="upload", type="filepath")
            question_input = gr.Textbox(label="Question")
        # Second column: the two values the pipeline's result is written to.
        with gr.Column():
            answer_output = gr.Textbox()
            hit_output = gr.Textbox(label="is hit")

    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=pipeline,
        inputs=[image_input, question_input],
        outputs=[answer_output, hit_output],
    )

# NOTE(review): share=True requests a public Gradio share link -- confirm
# that exposing the demo outside localhost is intended.
demo.launch(share=True)