zamal committed on
Commit
7ac4196
1 Parent(s): dd212a2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import (
3
+ AutoModelForCausalLM,
4
+ AutoProcessor,
5
+ GenerationConfig,
6
+ BitsAndBytesConfig,
7
+ )
8
+ from PIL import Image
9
+ import torch
10
+
11
# 4-bit quantization settings handed to `from_pretrained` below.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
)

# Model repository
repo_name = "cyan2k/molmo-7B-O-bnb-4bit"

# Load the processor and model. `trust_remote_code=True` executes the modeling
# code shipped in the repository, so `repo_name` must point at a trusted repo.
processor = AutoProcessor.from_pretrained(repo_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_name,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
)

# `device_map="auto"` has already placed the weights, and transformers can raise
# on `.to()` for bitsandbytes-quantized models, so the model is not moved again.
# Inputs should follow the model to wherever it actually landed.
device = model.device
32
+
33
def _load_image(item):
    """Return *item* as an RGB PIL image.

    Accepts an already-opened PIL image, a file path string, or an
    uploaded-file object that exposes its path as ``.name``.
    """
    if isinstance(item, Image.Image):
        return item.convert("RGB")
    path = item if isinstance(item, str) else getattr(item, "name", item)
    # The context manager closes the file handle; `convert` returns a loaded copy
    # that stays valid after the file is closed.
    with Image.open(path) as opened:
        return opened.convert("RGB")


def describe_images(images):
    """Generate a detailed text description for each supplied image.

    Args:
        images: A list/tuple of images (file paths, PIL images, or uploaded-file
            objects with a ``.name`` path), a single such item, or ``None``.

    Returns:
        The descriptions joined by blank lines, in input order. An empty
        string when nothing was supplied.
    """
    if not images:
        # Nothing uploaded: avoid iterating over None.
        return ""
    if not isinstance(images, (list, tuple)):
        # A single upload arrives as a bare item rather than a list.
        images = [images]

    descriptions = []
    for item in images:
        image = _load_image(item)
        # Process the image
        inputs = processor.process(
            images=[image],
            text="Describe this image in great detail.",
        )
        # `processor.process` yields unbatched tensors; move them to the model's
        # device and add a batch dimension of size 1, which both
        # `generate_from_batch` and the `size(1)` slice below rely on.
        inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
        # Generate output
        with torch.no_grad():
            output = model.generate_from_batch(
                inputs,
                GenerationConfig(max_new_tokens=200, stop_strings=["<|endoftext|>"]),
                tokenizer=processor.tokenizer,
            )
        # Drop the prompt tokens so that only the newly generated text is decoded.
        generated_tokens = output[0, inputs["input_ids"].size(1):]
        generated_text = processor.tokenizer.decode(
            generated_tokens, skip_special_tokens=True
        )
        descriptions.append(generated_text.strip())
    return "\n\n".join(descriptions)
59
+
60
# Gradio interface
# NOTE(review): the original indentation was lost; the nesting below (upload and
# button side by side in one row, textbox underneath) is a reconstruction — confirm.
with gr.Blocks() as demo:
    gr.Markdown("<h3><center>Image Description Generator</center></h3>")
    with gr.Row():
        # `gr.File` has no `multiple` argument; multi-file upload is requested
        # with `file_count="multiple"`.
        image_input = gr.File(
            file_types=["image"], label="Upload Image(s)", file_count="multiple"
        )
        generate_button = gr.Button("Generate Descriptions")
    output_text = gr.Textbox(label="Descriptions", lines=15)

    # Event listeners are registered inside the Blocks context.
    generate_button.click(describe_images, inputs=image_input, outputs=output_text)

demo.launch()