ChartGemma / app.py
import gradio as gr
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import requests
from PIL import Image
import torch
# Download example chart images from the ChartQA dataset
torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/74801584018932.png', 'chart_example_1.png')
torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_2.png')
# Load the ChartGemma model and its processor
model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma")
processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
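# Optional tweak (not part of the original app): on a GPU, the model could be
# loaded in half precision to reduce memory use, e.g.:
# model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma", torch_dtype=torch.float16)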
def predict(image, input_text):
    # Move the model to GPU if one is available, otherwise run on CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Preprocess the chart image and the question into model inputs
    image = image.convert("RGB")
    inputs = processor(text=input_text, images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    prompt_length = inputs['input_ids'].shape[1]

    # Generate the answer and decode only the newly generated tokens
    generate_ids = model.generate(**inputs, max_new_tokens=512)
    output_text = processor.batch_decode(generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return output_text
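# Example of calling predict() directly, outside the Gradio UI (assumes the
# example image downloaded above is present):
# print(predict(Image.open("chart_example_1.png"),
#               "Describe the trend of the mortality rates for children before age 5"))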
examples = [
    ["chart_example_1.png", "Describe the trend of the mortality rates for children before age 5"],
    ["chart_example_2.png", "What is the share of respondents who prefer Facebook Messenger in the 30-59 age group?"]
]
title = "Interactive Gradio Demo for the ChartGemma Model"
with gr.Blocks(css="theme.css") as demo:
    gr.Markdown(f"# {title}")
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="pil", label="Chart Image")
            input_prompt = gr.Textbox(label="Input")
        with gr.Column():
            model_output = gr.Textbox(label="Output")
    gr.Examples(examples=examples, inputs=[image, input_prompt])
    submit_button = gr.Button("Run")
    submit_button.click(predict, inputs=[image, input_prompt], outputs=model_output)
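
# Optional (not part of the original app): Gradio's request queue serializes
# concurrent generation requests, which can help on a single GPU, e.g.:
# demo.queue()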
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)