# CosmosLLaVA / app.py
# Hugging Face Space file uploaded by erndgn ("Upload 2 files"), commit a701996 (verified), 4.88 kB.
import spaces
import time
from threading import Thread
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor
from llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
IMAGE_PLACEHOLDER,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
process_images,
tokenizer_image_token,
get_model_name_from_path,
)
from io import BytesIO
import requests
import os
from conversation import Conversation, SeparatorStyle
# Hugging Face Hub id of the checkpoint served by this app.
model_id = "ytu-ce-cosmos/Turkish-LLaVA-v0.1"
# Name derived from the path; load_pretrained_model uses it to pick the loader.
model_name = get_model_name_from_path(model_id)

# llava helper invoked before building the model (presumably skips torch's
# default weight initialisation since weights are loaded anyway — confirm).
disable_torch_init()
(tokenizer, model, image_processor, context_len) = load_pretrained_model(
    model_id,
    None,  # no separate base model path
    model_name,
)
def load_image(image_file):
    """Load an image from an HTTP(S) URL or a local path and return it as RGB.

    Args:
        image_file: An ``http://`` / ``https://`` URL, or a filesystem path.

    Returns:
        A PIL image converted to RGB mode.

    Raises:
        FileNotFoundError: If ``image_file`` is not a URL and the path does not exist.
        requests.HTTPError: If downloading the URL returns an error status.
    """
    # Match the full scheme prefix so local files whose names merely start
    # with "http" (e.g. "httpd.png") are not mistaken for URLs.
    if image_file.startswith(("http://", "https://")):
        # Bound the wait so a stalled host cannot hang the worker indefinitely.
        response = requests.get(image_file, timeout=30)
        # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    if os.path.exists(image_file):
        # convert() returns a new, fully loaded image, so the file can be closed.
        with Image.open(image_file) as img:
            return img.convert("RGB")
    raise FileNotFoundError(f"Görüntü dosyası {image_file} bulunamadı.")
def infer_single_image(model_id, image_file, prompt):
    """Run one greedy LLaVA generation for a single image and a text prompt.

    Uses the module-level ``model``, ``tokenizer`` and ``image_processor``.

    Args:
        model_id: Unused; kept so existing callers keep working.
        image_file: URL or local path of the image (passed to ``load_image``).
        prompt: User text. Every ``IMAGE_PLACEHOLDER`` occurrence is replaced by
            the image token(s); if there is none, the image token(s) are
            prepended on their own line.

    Returns:
        The decoded reply, with special tokens skipped and whitespace stripped.
    """
    # The image token form depends on whether the model config asks for
    # explicit start/end markers around the image token.
    if model.config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN

    # IMAGE_PLACEHOLDER is a literal marker, so plain string replacement is
    # sufficient (and needs no regex module).
    if IMAGE_PLACEHOLDER in prompt:
        prompt = prompt.replace(IMAGE_PLACEHOLDER, image_token)
    else:
        prompt = image_token + "\n" + prompt

    # Llama-3 style chat template with a Turkish system prompt.
    conv = Conversation(
        system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nSen bir yapay zeka asistanısın. Kullanıcı sana bir görev verecek. Amacın görevi olabildiğince sadık bir şekilde tamamlamak. Görevi yerine getirirken adım adım düşün ve adımlarını gerekçelendir.""",
        roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
        version="llama3",
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.MPT,
        sep="<|eot_id|>",
    )
    conv.append_message(conv.roles[0], prompt)
    # An empty assistant turn leaves the prompt open for the model's reply.
    conv.append_message(conv.roles[1], None)
    full_prompt = conv.get_prompt()
    print("full prompt: ", full_prompt)

    image = load_image(image_file)
    image_tensor = process_images(
        [image],
        image_processor,
        model.config
    ).to(model.device, dtype=torch.float16)
    # Keep token ids on the same device as the model and the image tensor
    # instead of assuming the default CUDA device.
    input_ids = (
        tokenizer_image_token(full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image.size],
            do_sample=False,
            max_new_tokens=512,
            use_cache=True,
        )
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
@spaces.GPU
def bot_streaming(message, history):
    """Gradio chat handler: answer ``message`` about the most recent image.

    Args:
        message: MultimodalTextbox payload with ``"text"`` and ``"files"`` keys.
            Each file entry is either a dict with a ``"path"`` key or a path.
        history: Previous chat turns. A turn whose first element is a tuple is
            treated as an image upload, and that tuple's first item as its path.

    Yields:
        The model's full reply as one string.

    Raises:
        gr.Error: If no image is attached to this message or found in history.
    """
    print(message)
    image = None
    if message["files"]:
        # Prefer the last file attached to the current message.
        last_file = message["files"][-1]
        image = last_file["path"] if isinstance(last_file, dict) else last_file
    else:
        # Otherwise fall back to the latest image uploaded earlier in the chat.
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if image is None:
        # gr.Error must be raised to be shown in the UI; merely constructing it
        # would let execution continue without an image.
        raise gr.Error("LLaVA'nın çalışması için bir resim yüklemeniz gerekir.")
    prompt = message["text"]
    result = infer_single_image(model_id, image, prompt)
    print(result)
    yield result
# Chat area and the multimodal (text + image upload) input box.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Mesaj girin veya dosya yükleyin...",
    show_label=False,
)

# Page layout: one ChatInterface wired to bot_streaming, filling the height.
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        title="Cosmos LLaVA",
        description="",
        # Sample prompts paired with images bundled next to this file.
        examples=[
            {"text": "Bu kitabın adı ne?", "files": ["./book.jpg"]},
            {"text": "Çiçeğin üzerinde ne var?", "files": ["./bee.jpg"]},
            {"text": "Bu tatlı nasıl yapılır?", "files": ["./baklava.png"]},
        ],
        stop_btn="Stop Generation",
    )

# Enable request queuing and start the server. Per Gradio's parameter docs,
# api_open/show_api=False keep the programmatic API closed and unadvertised.
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)