erndgn committed on
Commit
94e7301
1 Parent(s): 14deb22

Upload 6 files

Files changed (7)
  1. .gitattributes +2 -0
  2. README.md +14 -12
  3. app.py +144 -141
  4. baklava.png +3 -0
  5. bee.jpg +3 -0
  6. conversation.py +209 -0
  7. requirements.txt +2 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ baklava.png filter=lfs diff=lfs merge=lfs -text
+ bee.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,14 @@
  ---
- title: CosmosLLaVA
- emoji: 📉
- colorFrom: indigo
- colorTo: purple
+ title: Try CosmosLLaVA
+ emoji: 🔥
+ colorFrom: yellow
+ colorTo: green
  sdk: gradio
- sdk_version: 4.41.0
+ sdk_version: 4.28.3
  app_file: app.py
  pinned: false
+ license: apache-2.0
+ short_description: The best open source Turkish vision model
  ---
  
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,141 +1,144 @@
  import spaces
  
  import time
  from threading import Thread
  
  import gradio as gr
  import torch
  from PIL import Image
  from transformers import AutoProcessor
  from llava.constants import (
      IMAGE_TOKEN_INDEX,
      DEFAULT_IMAGE_TOKEN,
      DEFAULT_IM_START_TOKEN,
      DEFAULT_IM_END_TOKEN,
      IMAGE_PLACEHOLDER,
  )
  from llava.model.builder import load_pretrained_model
  from llava.utils import disable_torch_init
  from llava.mm_utils import (
      process_images,
      tokenizer_image_token,
      get_model_name_from_path,
  )
  from io import BytesIO
  import requests
  import os
  from conversation import Conversation, SeparatorStyle
  
  model_id = "ytu-ce-cosmos/Turkish-LLaVA-v0.1"
  
  disable_torch_init()
  model_name = get_model_name_from_path(model_id)
  tokenizer, model, image_processor, context_len = load_pretrained_model(
      model_id, None, model_name
  )
  
  def load_image(image_file):
      if image_file.startswith("http") or image_file.startswith("https"):
          response = requests.get(image_file)
          image = Image.open(BytesIO(response.content)).convert("RGB")
      elif os.path.exists(image_file):
          image = Image.open(image_file).convert("RGB")
      else:
-         raise FileNotFoundError(f"Image file {image_file} not found.")
+         raise FileNotFoundError(f"Görüntü dosyası {image_file} bulunamadı.")
      return image
  
  def infer_single_image(model_id, image_file, prompt):
      image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
      if IMAGE_PLACEHOLDER in prompt:
          if model.config.mm_use_im_start_end:
              prompt = re.sub(IMAGE_PLACEHOLDER, image_token_se, prompt)
          else:
              prompt = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, prompt)
      else:
          if model.config.mm_use_im_start_end:
              prompt = image_token_se + "\n" + prompt
          else:
              prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
  
      conv = Conversation(
          system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nSen bir yapay zeka asistanısın. Kullanıcı sana bir görev verecek. Amacın görevi olabildiğince sadık bir şekilde tamamlamak. Görevi yerine getirirken adım adım düşün ve adımlarını gerekçelendir.""",
          roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
          version="llama3",
          messages=[],
          offset=0,
          sep_style=SeparatorStyle.MPT,
          sep="<|eot_id|>",
      )
      conv.append_message(conv.roles[0], prompt)
      conv.append_message(conv.roles[1], None)
      full_prompt = conv.get_prompt()
  
      print("full prompt: ", full_prompt)
  
      image = load_image(image_file)
      image_tensor = process_images(
          [image],
          image_processor,
          model.config
      ).to(model.device, dtype=torch.float16)
  
      input_ids = (
          tokenizer_image_token(full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
          .unsqueeze(0)
          .cuda()
      )
  
      with torch.inference_mode():
          output_ids = model.generate(
              input_ids,
              images=image_tensor,
              image_sizes=[image.size],
              do_sample=False,
              max_new_tokens=512,
              use_cache=True,
          )
  
      output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
      return output
  
  @spaces.GPU
  def bot_streaming(message, history):
      print(message)
      if message["files"]:
          if type(message["files"][-1]) == dict:
              image = message["files"][-1]["path"]
          else:
              image = message["files"][-1]
      else:
          for hist in history:
              if type(hist[0]) == tuple:
                  image = hist[0][0]
      try:
          if image is None:
-             gr.Error("You need to upload an image for LLaVA to work.")
+             gr.Error("LLaVA'nın çalışması için bir resim yüklemeniz gerekir.")
      except NameError:
-         gr.Error("You need to upload an image for LLaVA to work.")
+         gr.Error("LLaVA'nın çalışması için bir resim yüklemeniz gerekir.")
  
      prompt = message['text']
  
      result = infer_single_image(model_id, image, prompt)
+
+     print(result)
+
      yield result
  
  chatbot = gr.Chatbot(scale=1)
- chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
+ chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Mesaj girin veya dosya yükleyin...", show_label=False)
  
  with gr.Blocks(fill_height=True) as demo:
      gr.ChatInterface(
          fn=bot_streaming,
          title="LLaVA Llama-3-8B",
          examples=[{"text": "Çiçeğin üzerinde ne var?", "files": ["./bee.jpg"]},
                    {"text": "Bu tatlı nasıl yapılır?", "files": ["./baklava.png"]}],
          description="",
          stop_btn="Stop Generation",
          multimodal=True,
          textbox=chat_input,
          chatbot=chatbot,
      )
  
  demo.queue(api_open=False)
  demo.launch(show_api=False, share=False)
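
Note for reviewers: both the old and new app.py call re.sub(...) inside infer_single_image without importing re, so any prompt containing IMAGE_PLACEHOLDER would raise a NameError; an `import re` at the top of the file would close that gap. For a quick check of the inference path outside the Gradio UI, a minimal sketch (assumes the Space's GPU runtime, a successful model download, and the bee.jpg bundled in this commit):

    # Minimal smoke test for the inference path defined in app.py.
    # Importing app runs its module-level model loading, so this needs
    # the same GPU environment the Space itself uses.
    from app import infer_single_image, model_id

    # One of the example prompts bundled with this commit
    # ("Çiçeğin üzerinde ne var?" = "What is on the flower?").
    answer = infer_single_image(model_id, "./bee.jpg", "Çiçeğin üzerinde ne var?")
    print(answer)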
baklava.png ADDED

Git LFS Details

  • SHA256: 7839e93dd753e5356176bf70d38c43bc56355099d8891ead7aaa342029369268
  • Pointer size: 132 Bytes
  • Size of remote file: 2.04 MB
bee.jpg ADDED

Git LFS Details

  • SHA256: 8b21ba78250f852ca5990063866b1ace6432521d0251bde7f8de783b22c99a6d
  • Pointer size: 132 Bytes
  • Size of remote file: 5.37 MB
conversation.py ADDED
@@ -0,0 +1,209 @@
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple
+ import base64
+ from io import BytesIO
+ from PIL import Image
+
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+     SINGLE = auto()
+     TWO = auto()
+     MPT = auto()
+     PLAIN = auto()
+     LLAMA_2 = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "###"
+     sep2: str = None
+     version: str = "Unknown"
+
+     skip_next: bool = False
+
+     def get_prompt(self):
+         messages = self.messages
+         if len(messages) > 0 and type(messages[0][1]) is tuple:
+             messages = self.messages.copy()
+             init_role, init_msg = messages[0].copy()
+             init_msg = init_msg[0].replace("<image>", "").strip()
+             if 'mmtag' in self.version:
+                 messages[0] = (init_role, init_msg)
+                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                 messages.insert(1, (self.roles[1], "Received."))
+             else:
+                 messages[0] = (init_role, "<image>\n" + init_msg)
+
+         if self.sep_style == SeparatorStyle.SINGLE:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + self.sep
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.MPT:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+         elif self.sep_style == SeparatorStyle.LLAMA_2:
+             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+             ret = ""
+
+             for i, (role, message) in enumerate(messages):
+                 if i == 0:
+                     assert message, "first message should not be none"
+                     assert role == self.roles[0], "first message should come from user"
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i == 0: message = wrap_sys(self.system) + message
+                     if i % 2 == 0:
+                         message = wrap_inst(message)
+                         ret += self.sep + message
+                     else:
+                         ret += " " + message + " " + self.sep2
+                 else:
+                     ret += ""
+             ret = ret.lstrip(self.sep)
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
+         if image_process_mode == "Pad":
+             def expand2square(pil_img, background_color=(122, 116, 104)):
+                 width, height = pil_img.size
+                 if width == height:
+                     return pil_img
+                 elif width > height:
+                     result = Image.new(pil_img.mode, (width, width), background_color)
+                     result.paste(pil_img, (0, (width - height) // 2))
+                     return result
+                 else:
+                     result = Image.new(pil_img.mode, (height, height), background_color)
+                     result.paste(pil_img, ((height - width) // 2, 0))
+                     return result
+
+             image = expand2square(image)
+         elif image_process_mode in ["Default", "Crop"]:
+             pass
+         elif image_process_mode == "Resize":
+             image = image.resize((336, 336))
+         else:
+             raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+         if max(image.size) > max_len:
+             max_hw, min_hw = max(image.size), min(image.size)
+             aspect_ratio = max_hw / min_hw
+             shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+             longest_edge = int(shortest_edge * aspect_ratio)
+             W, H = image.size
+             if H > W:
+                 H, W = longest_edge, shortest_edge
+             else:
+                 H, W = shortest_edge, longest_edge
+             image = image.resize((W, H))
+         if return_pil:
+             return image
+         else:
+             buffered = BytesIO()
+             image.save(buffered, format=image_format)
+             img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+             return img_b64_str
+
+     def get_images(self, return_pil=False):
+         images = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     msg, image, image_process_mode = msg
+                     image = self.process_image(image, image_process_mode, return_pil=return_pil)
+                     images.append(image)
+         return images
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     msg, image, image_process_mode = msg
+                     img_b64_str = self.process_image(
+                         image, "Default", return_pil=False,
+                         image_format='JPEG')
+                     img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = img_str + msg.replace('<image>', '').strip()
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version)
+
+     def dict(self):
+         if len(self.get_images()) > 0:
+             return {
+                 "system": self.system,
+                 "roles": self.roles,
+                 "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                 "offset": self.offset,
+                 "sep": self.sep,
+                 "sep2": self.sep2,
+             }
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+         }
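
For context on how app.py consumes this module: with SeparatorStyle.MPT, get_prompt() concatenates system + sep, then role + message + sep for each filled turn, and emits a bare role for the pending assistant turn. A minimal usage sketch mirroring the settings in app.py (the short system string here is illustrative, not the full prompt the app uses):

    from conversation import Conversation, SeparatorStyle

    # Same construction app.py uses, with a toy Turkish user message.
    conv = Conversation(
        system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nSen bir yapay zeka asistanısın.",
        roles=("<|start_header_id|>user<|end_header_id|>\n\n",
               "<|start_header_id|>assistant<|end_header_id|>\n\n"),
        version="llama3",
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.MPT,
        sep="<|eot_id|>",
    )
    conv.append_message(conv.roles[0], "Merhaba!")
    conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open

    # Prints: system + <|eot_id|> + user header + "Merhaba!" + <|eot_id|> + assistant header
    print(conv.get_prompt())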
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ llava-torch
+ spaces