REILX's picture
Update README.md
6d6b671 verified
|
raw
history blame
No virus
4.21 kB
metadata
license: apache-2.0
datasets:
  - REILX/text-description-of-the-meme
language:
  - zh
tags:
  - llava
  - Qwen2
  - txtimage-to-txt
  - lora

模型 llava-Qwen2-7B-Instruct-Chinese-CLIP 增强中文文字识别能力和表情包内涵识别能力,达到gpt4o、claude-3.5-sonnet的识别水平!

  1. 模型结构:
    llava-Qwen2-7B-Instruct-Chinese-CLIP = Qwen/Qwen2-7B-Instruct + multi_modal_projector + OFA-Sys/chinese-clip-vit-large-patch14-336px

  2. 微调模块

  • vision_tower和language_model的q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj模块进行lora训练
  • mmp层全量训练
  1. 微调参数
  • lora_r=32,lora_alpha=64,num_train_epochs=5,per_device_train_batch_size=1,gradient_accumulation_steps=8,high_lr=1e-3,low_lr=2e-5,model_max_length=2048.
  • 设备:8*A800
  • 训练时长:5小时12分钟
  1. 数据集
    使用gemini-1.5-pro, gemini-1.5-flash, yi-vision, gpt4o,claude-3.5-sonnet模型描述emo-visual-data和ChineseBQB数据集。
    文本描述信息通过text-description-of-the-meme 下载
    图像可通过emo-visual-data, ChineseBQB下载
    图片数据总量1.8G,约10835张中文表情包图片。文字总量42Mb,约24332个图像文本对描述信息。

  2. 效果展示
    以下测试结果显示模型能识别图像中的文字信息,且能正确识别表情包想要表达的内涵。对比REILX/llava-1.5-7b-hf-meme-lora模型中也测试了原始llava-1.5-7b-hf模型的输出,模型无法正确识别图像中的文本信息。


    以下三张图为gpt4o的识别效果
  3. 代码
    推理代码

from transformers import LlavaForConditionalGeneration, AutoProcessor
import torch
from PIL import Image

raw_model_name_or_path = "/保存的完整模型路径"
model = LlavaForConditionalGeneration.from_pretrained(raw_model_name_or_path, device_map="cuda:0", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(raw_model_name_or_path)
model.eval()

def build_model_input(model, processor):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "<image>\n 使用中文描述图片中的信息"}
    ]
    prompt = processor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image = Image.open("01.PNG")
    inputs = processor(text=prompt, images=image, return_tensors="pt", return_token_type_ids=False)
    
    for tk in inputs.keys():
        inputs[tk] = inputs[tk].to(model.device)
    generate_ids = model.generate(**inputs, max_new_tokens=200)
    
    generate_ids = [
        oid[len(iids):] for oid, iids in zip(generate_ids, inputs.input_ids)
    ]
    gen_text = processor.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
    return gen_text
build_model_input(model, processor)