diff --git a/codecov.yml b/codecov.yml
index 49f2b681..767d0825 100644
--- a/codecov.yml
+++ b/codecov.yml
@@ -34,3 +34,4 @@ ignore:
 - ".git"
 - "*.yml"
 - "*.md"
+- "**/minigpt4.py"
diff --git a/examples/README.md b/examples/README.md
index ae89048b..eb6cb063 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1,11 +1,33 @@
 # Example
 
+- [How to run Visual Question Answering with MiniGPT-4](#How-to-run-Visual-Question-Answering-with-MiniGPT-4)
 - [How to set the `embedding` function](#How-to-set-the-embedding-function)
 - [How to set the `data manager` class](#How-to-set-the-data-manager-class)
 - [How to set the `similarity evaluation` interface](#How-to-set-the-similarity-evaluation-interface)
 - [Other cache init params](#Other-cache-init-params)
 - [Benchmark](#Benchmark)
 
+## How to run Visual Question Answering with MiniGPT-4
+
+You can run [vqa_demo.py](./vqa_demo.py) to try image question answering: it uses MiniGPT-4 to generate answers and GPTCache to cache them.
+
+> Note that you need to make sure that [minigpt4](https://github.com/Vision-CAIR/MiniGPT-4) and [gptcache](https://gptcache.readthedocs.io/en/dev/index.html) are successfully installed, and move the **vqa_demo.py** file into the MiniGPT-4 directory.
+
+```bash
+$ python vqa_demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
+```
+
+The above command uses the exact-match cache, i.e. the `map` cache management method. When you ask the same question about the same image again, it hits the cache directly and returns the answer quickly.
+
+If you want to use the similarity-search cache instead, run the following command to set `map` to `False`; the cache is then managed with sqlite and faiss so that similar images and questions can be matched.
+
+> You can also set `--dir` to your workspace directory.
+
+```bash
+$ python vqa_demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0 --dir /path/to/workspace --no-map
+```
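+
+If you want to call the cached MiniGPT-4 pipeline from your own code instead of the Gradio demo, the flow looks roughly like the sketch below, which mirrors the example in the `MiniGPT4` adapter docstring (the config path, GPU id and image path are placeholders for your own setup):
+
+```python
+from gptcache import cache
+from gptcache.processor.pre import get_image_question
+from gptcache.adapter.minigpt4 import MiniGPT4
+
+# exact-match cache keyed on the image content plus the question text
+cache.init(pre_embedding_func=get_image_question)
+
+pipe = MiniGPT4.from_pretrained(cfg_path="eval_configs/minigpt4_eval.yaml", gpu_id=0, options=None)
+answer = pipe("./merlion.png", "Which city was this photo taken in?")
+```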
+
+
 ## How to set the `embedding` function
 
 > Please note that not all data managers are compatible with an embedding function.
diff --git a/examples/vqa_demo.py b/examples/vqa_demo.py
new file mode 100644
index 00000000..f9e18eb9
--- /dev/null
+++ b/examples/vqa_demo.py
@@ -0,0 +1,92 @@
+# ================================================================================
+# This demo comes from [minigpt4](https://github.com/Vision-CAIR/MiniGPT-4)
+# and is integrated with [gptcache](https://github.com/zilliztech/GPTCache)
+# for image question answering.
+# Please make sure you have successfully set up minigpt4.
+# Run `python vqa_demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0`.
+# ================================================================================
+
+import argparse
+
+import gradio as gr
+
+from gptcache import cache
+from gptcache.processor.pre import get_image, get_image_question
+from gptcache.embedding import Timm
+from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
+from gptcache.manager.factory import manager_factory
+
+from gptcache.adapter.minigpt4 import MiniGPT4
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Demo")
+    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
+    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
+    parser.add_argument("--dir", type=str, default=".", help="path for data storage.")
+    parser.add_argument("--map", action='store_true', help="use map for exact match cache.")
+    parser.add_argument('--no-map', dest='map', action='store_false', help="use sqlite and faiss for similar search cache.")
+    parser.set_defaults(map=True)
+    parser.add_argument(
+        "--options",
+        nargs="+",
+        help="override some settings in the used config, the key-value pair "
+        "in xxx=yyy format will be merged into config file (deprecated), "
+        "change to --cfg-options instead.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+args = parse_args()
+
+print("Initializing GPTCache")
+if args.map:
+    data_manager = manager_factory("map", args.dir)
+    cache.init(
+        pre_embedding_func=get_image_question,
+        data_manager=data_manager
+    )  # init with map method for exact match
+else:
+    timm = Timm()
+    data_manager = manager_factory("sqlite,faiss", args.dir, vector_params={"dimension": timm.dimension})
+    cache.init(
+        pre_embedding_func=get_image,
+        data_manager=data_manager,
+        embedding_func=timm.to_embeddings,
+        similarity_evaluation=SearchDistanceEvaluation()
+    )  # init with sqlite and faiss for similarity search
+print("GPTCache Initialization Finished")
+
+print("Initializing Chat")
+pipeline = MiniGPT4.from_pretrained(cfg_path=args.cfg_path, gpu_id=args.gpu_id, options=args.options, return_hit=True)
+print("Chat Initialization Finished")
+
+
+# ========================================
+#             Gradio Setting
+# ========================================
+
+
+title = """<h1 align="center">Demo of MiniGPT-4 and GPTCache</h1>"""
+description = """<h3>This is the demo of MiniGPT-4 and GPTCache. Upload an image and ask a question; the answer will be cached.</h3>"""
+article = """"""
""" + +# show examples below + + +with gr.Blocks() as demo: + gr.Markdown(title) + gr.Markdown(description) + gr.Markdown(article) + with gr.Row(): + with gr.Column(): + inp0 = gr.Image(source="upload", type="filepath") + inp1 = gr.Textbox(label="Question") + with gr.Column(): + out0 = gr.Textbox() + out1 = gr.Textbox(label="is hit") + btn = gr.Button("Submit") + btn.click(fn=pipeline, inputs=[inp0, inp1], outputs=[out0, out1]) + +demo.launch(share=True) diff --git a/gptcache/adapter/minigpt4.py b/gptcache/adapter/minigpt4.py new file mode 100644 index 00000000..fcd1b27c --- /dev/null +++ b/gptcache/adapter/minigpt4.py @@ -0,0 +1,90 @@ +from gptcache.adapter.adapter import adapt +from gptcache.utils.error import CacheError +from gptcache.manager.scalar_data.base import DataType, Question, Answer + +from argparse import Namespace + +from minigpt4.common.config import Config +from minigpt4.common.registry import registry +from minigpt4.conversation.conversation import Chat, CONV_VISION + +# pylint: disable=wildcard-import +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * + + +class MiniGPT4: # pragma: no cover + """MiniGPT4 Wrapper + + Example: + .. code-block:: python + from gptcache import cache + from gptcache.processor.pre import get_image_question + from gptcache.adapter.minigpt4 import MiniGPT4 + + # init gptcache + cache.init(pre_embedding_func=get_image_question) + + # run with gptcache + pipe = MiniGPT4.from_pretrained(cfg_path='eval_configs/minigpt4_eval.yaml', gpu_id=3, options=None) + question = "Which city is this photo taken?" + image = "./merlion.png" + answer = pipe(image, question) + """ + def __init__(self, chat, return_hit): + self.chat = chat + self.return_hit = return_hit + + @classmethod + def from_pretrained(cls, cfg_path, gpu_id=0, options=None, return_hit=False): + args = Namespace(cfg_path=cfg_path, gpu_id=gpu_id, options=options) + cfg = Config(args) + model_config = cfg.model_cfg + model_config.device_8bit = args.gpu_id + model_cls = registry.get_model_class(model_config.arch) + model = model_cls.from_config(model_config).to("cuda:{}".format(args.gpu_id)) + + vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train + vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) + chat = Chat(model, vis_processor, device="cuda:{}".format(args.gpu_id)) + return cls(chat, return_hit) + + def llm_handler(self, image, question): + chat_state = CONV_VISION.copy() + img_list = [] + try: + self.chat.upload_img(image, chat_state, img_list) + self.chat.ask(question, chat_state) + answer = self.chat.answer(conv=chat_state, img_list=img_list)[0] + return answer if not self.return_hit else answer, False + except Exception as e: + raise CacheError("minigpt4 error") from e + + def __call__(self, image, question, *args, **kwargs): + cache_context = {"deps": [ + {"name": "text", "data": question, "dep_type": DataType.STR}, + {"name": "image", "data": image, "dep_type": DataType.STR}, + ]} + + def cache_data_convert(cache_data): + return cache_data if not self.return_hit else cache_data, True + + def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument + question_data = Question.from_dict({ + "content": "pre_embedding_data", + "deps": [ + {"name": "text", "data": kwargs["question"], "dep_type": DataType.STR}, + {"name": "image", 
"data": kwargs["image"], "dep_type": DataType.STR}, + ] + }) + llm_data_cache = llm_data if not self.return_hit else llm_data[0] + update_cache_func(Answer(llm_data_cache, DataType.STR), question=question_data) + return llm_data + + return adapt( + self.llm_handler, cache_data_convert, update_cache_callback, image=image, question=question, cache_context=cache_context, *args, **kwargs + ) diff --git a/gptcache/manager/scalar_data/sql_storage.py b/gptcache/manager/scalar_data/sql_storage.py index 088532ad..488ead34 100644 --- a/gptcache/manager/scalar_data/sql_storage.py +++ b/gptcache/manager/scalar_data/sql_storage.py @@ -55,7 +55,7 @@ class AnswerTable(Base): else: id = Column(Integer, primary_key=True, autoincrement=True) question_id = Column(Integer, nullable=False) - answer = Column(String(1000), nullable=False) + answer = Column(String(2000), nullable=False) answer_type = Column(Integer, nullable=False) class QuestionDepTable(Base): diff --git a/gptcache/processor/pre.py b/gptcache/processor/pre.py index 1441b133..6b3e275c 100644 --- a/gptcache/processor/pre.py +++ b/gptcache/processor/pre.py @@ -51,3 +51,13 @@ def get_input_str(data: Dict[str, Any], **_: Dict[str, Any]) -> str: def get_input_image_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str: input_data = data.get("input") return input_data["image"].name + + +def get_image_question(data: Dict[str, Any], **_: Dict[str, Any]) -> str: # pragma: no cover + img = data.get("image") + data_img = str(open(img, "rb").peek()) if isinstance(img, str) else str(img) # pylint: disable=consider-using-with + return data_img + data.get("question") + + +def get_image(data: Dict[str, Any], **_: Dict[str, Any]) -> str: # pragma: no cover + return data.get("image")