OCR_MLLM_TOY: A Multimodal Large Language Model for OCR

Our model supports image captioning and VQA, and performs especially well on OCR-related images. It supports Simplified Chinese, Traditional Chinese, and English. 🤩🤩🤩 Please give us a star if you find it interesting and useful! [🤗Space]

Overview


Diagram of OCR_MLLM_TOY Model.

  • [1] The OCR image encoder is adapted from an end-to-end OCR recognition model; we use the pretrained weights from Vary.
  • [2] The ViT image encoder weights are adopted from Qwen-VL.

Release

  • [2024/03/08] 🔥 We released the OCR_MLLM_TOY pretrained weights.
  • [2024/03/07] 🔥 We released the training and evaluation code.

Usage and License Notices: The data and code are intended and licensed for research use only. They are also restricted to uses that follow the license agreements of LLaMA, Qwen, and Vary.

Contents

Gradio Demo

To run our Gradio demo, you need to get the checkpoints from Hugging Face and put them in "./checkpoints/qwen14b-finetune_all/checkpoint-8300". Then run the following command.

python -m ocr_mllm_gradio.my_gradio_web_server --host 0.0.0.0 --port 10000
Demo videos: commonx2.mp4, ticketsx2.mp4
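
For reference, my_gradio_web_server.py is a standard Gradio app. Below is a minimal sketch of such a server, not the repo's actual code; the describe_image function is a placeholder and does not call the real model.

```python
# Minimal Gradio sketch (illustration only; not the repo's my_gradio_web_server.py).
# `describe_image` is a placeholder for the model's multimodal inference call.
import gradio as gr


def describe_image(image, prompt):
    # In the real server this would run the OCR_MLLM_TOY model;
    # here we just echo the inputs to keep the sketch self-contained.
    size = image.size if image is not None else "none"
    return f"Prompt: {prompt} (image size: {size})"


demo = gr.Interface(
    fn=describe_image,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="prompt")],
    outputs=gr.Textbox(label="answer"),
    title="OCR_MLLM_TOY demo (sketch)",
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=10000)
```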

OCR_MLLM_TOY Demo

Install

  1. Clone this repository
git clone https://github.com/SuXuping/OCR_MLLM_TOY.git
  2. Install packages
conda create -n OCR_MLLM_TOY python=3.10 -y
conda activate OCR_MLLM_TOY
pip install --upgrade pip
pip install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/
  3. Install additional packages for training
pip install ninja
pip install flash-attn --no-build-isolation
pip install deepspeed

Train

OCR_MLLM_TOY is trained on 8 A100 GPUs with 80GB memory each. To train on fewer GPUs, you can:

  • [1] reduce the per_device_train_batch_size and increase the gradient_accumulation_steps accordingly (see the sketch after this list).
  • [2] use LoRA during training.
  • [3] use a 7B LLM instead.
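
When adjusting these values, it helps to keep the effective global batch size from the hyperparameter table below constant. A minimal sketch of the arithmetic (the per-device values are examples, not the repo's defaults):

```python
# Effective global batch size =
#   per_device_train_batch_size x gradient_accumulation_steps x number of GPUs
# Example: keeping the SFT global batch size of 64 while using fewer GPUs.
def global_batch_size(per_device_bs: int, grad_accum_steps: int, num_gpus: int) -> int:
    return per_device_bs * grad_accum_steps * num_gpus


print(global_batch_size(per_device_bs=8, grad_accum_steps=1, num_gpus=8))  # 64 on 8 GPUs
print(global_batch_size(per_device_bs=8, grad_accum_steps=4, num_gpus=2))  # 64 on 2 GPUs
```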

0、Prepare data

Prepare your data in this format:

[
    {
        "id": "image_id",
        "image": "image path",
        "conversations": [
            {
                "from": "human",
                "value": "Examine the image closely and share its details<image>\n"
            },
            {
                "from": "gpt",
                "value": "The image shows a man fishing on a lawn next to a river with a bridge in the background. Trees can be seen on the other side of the river, and the sky is cloudy."
            }
        ]
    }
]
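
As a convenience, here is a minimal sketch that builds one record in this format and writes it to a JSON file; the id, image path, conversation texts, and output filename are placeholders.

```python
# Minimal sketch: build one training record in the expected format and save it.
# The id, image path, texts, and output filename below are placeholders.
import json

record = {
    "id": "image_0001",
    "image": "images/image_0001.jpg",
    "conversations": [
        {"from": "human", "value": "Examine the image closely and share its details<image>\n"},
        {"from": "gpt", "value": "A man is fishing on a lawn next to a river with a bridge in the background."},
    ],
}

with open("train_data.json", "w", encoding="utf-8") as f:
    json.dump([record], f, ensure_ascii=False, indent=4)
```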

1、Set Hyperparameters

We use a similar set of hyperparameters to LLaVA for pretraining and SFT. For more details, refer to the pretrain and SFT scripts.

| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay | DeepSpeed |
| --- | --- | --- | --- | --- | --- | --- |
| Pretrain | 1024 | 1e-3 | 1 | 1024 | 0 | zero2 |
| SFT | 64 | 2e-5 | 1 | 2048 | 0 | zero3 |

2、Prepare training weights

Before you start training your own MLLM, you need to prepare some weights (a download sketch follows the list):

  • [1] prepare your base LLM and put the weights in "./ocr_mllm_toy/pretrain_weight/qwen_pretrain"
  • [2] prepare the ViT image encoder (448) and put the weights in "./ocr_mllm_toy/pretrain_weight/qwen_vit_448"
  • [3] prepare the OCR image encoder (1024) and put the weights in "./ocr_mllm_toy/pretrain_weight/vary_pretrain"
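
If your weights are hosted on the Hugging Face Hub, one way to download them into these directories is with huggingface_hub, as in the sketch below; the repo IDs are placeholders that you must replace with the repos you actually use.

```python
# Sketch: download pretrained weights into the directories expected by the
# training scripts. The repo IDs are placeholders -- substitute your own.
from huggingface_hub import snapshot_download

weights = {
    "YOUR_ORG/your-base-llm": "./ocr_mllm_toy/pretrain_weight/qwen_pretrain",
    "YOUR_ORG/your-vit-448": "./ocr_mllm_toy/pretrain_weight/qwen_vit_448",
    "YOUR_ORG/your-vary-ocr-encoder": "./ocr_mllm_toy/pretrain_weight/vary_pretrain",
}

for repo_id, local_dir in weights.items():
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
```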

3、Run train scripts

The pretrain script is provided:

sh ./scripts/pretrain_qwen14b.sh

The SFT script is provided:

sh ./scripts/finetune_lora_qwen14b.sh

Attention: you need to change some parameters in the ".sh" files according to your device and data.

Evaluation

We have evaluated OCR_MLLM_TOY on many benchmarks, including TextVQA, MMBench, MMBench-CN, MM-Vet, and MME. On some benchmarks our model achieves results comparable to LLaVA-NeXT-34B.


Evaluation results.


Evaluation results of OCR_MLLM_TOY-14B and LLaVA-13B on the MM-Vet benchmark.

Compared to LLaVA-13B on the MM-Vet benchmark, our model achieves better results in the OCR-related dimensions.

Please see this doc for the details.

Inference

To run inference, you need to get our weights from Hugging Face and put them in "./checkpoints/qwen14b-finetune_all/checkpoint-8300". We provide an interactive API for multimodal inference with streaming outputs.

python cli.py

We also provide LLM-only inference and multimodal inference in "inference.py".

python inference.py
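
For reference, streamed generation with Hugging Face transformers is typically done with a TextStreamer. The sketch below illustrates the idea with a plain causal LM; it is not the repo's cli.py, and whether this checkpoint loads through the Auto classes is an assumption.

```python
# Generic sketch of streaming text generation with Hugging Face transformers.
# This is NOT the repo's cli.py; it illustrates how streamed outputs are
# typically produced with a plain causal LM.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Assumption: the checkpoint can be loaded via the Auto classes; if not,
# use the repo's own model loader instead.
model_path = "./checkpoints/qwen14b-finetune_all/checkpoint-8300"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="auto")

streamer = TextStreamer(tokenizer, skip_prompt=True)
inputs = tokenizer("Describe the layout of an invoice.", return_tensors="pt").to(model.device)
model.generate(**inputs, max_new_tokens=256, streamer=streamer)  # tokens print as they are generated
```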

Acknowledgement

  • LLaVA: the codebase we built upon.
  • Vary: the OCR image encoder pretrained weights we built upon.