tabtoyou committed
Commit 2769331 • 1 Parent(s): ae922ce

Update README.md

Files changed (1)
  1. README.md +139 -1
README.md CHANGED
@@ -2,4 +2,142 @@
  license: apache-2.0
  ---
  ### Korean Otter
- We trained the [Otter](https://huggingface.co/luodian/OTTER-9B-LA-InContext) model on the 77k examples of [KoLLaVA-Instruct-150K](https://huggingface.co/datasets/tabtoyou/KoLLaVA-Instruct-150k) that belong to its complex-reasoning category. After confirming in the Otter image [demo](https://github.com/Luodian/Otter) that the model understands Korean questions to some extent but answers in English, we took the model as-is and tested whether it could be further trained on a Korean dataset. Due to GPU memory limits, only the layers of Otter's LLM above a certain depth (>25) were trained, for one epoch. The quality of this model's answers is not good, but training on a larger dataset for more epochs should give better results. Inference can be run by following the code of [Otter](https://huggingface.co/luodian/OTTER-9B-LA-InContext).
+ We trained the [Otter](https://huggingface.co/luodian/OTTER-9B-LA-InContext) model on the 77k examples of [KoLLaVA-Instruct-150K](https://huggingface.co/datasets/tabtoyou/KoLLaVA-Instruct-150k) that belong to its complex-reasoning category. After confirming in the Otter image [demo](https://github.com/Luodian/Otter) that the model understands Korean questions to some extent but answers in English, we took the model as-is and tested whether it could be further trained on a Korean dataset. Due to GPU memory limits, only the layers of Otter's LLM above a certain depth (>25) were trained, for one epoch. The quality of this model's answers is not good, but training on a larger dataset for more epochs should give better results.
+
+
+ ```python
+ import mimetypes
+ from typing import Union
+
+ import requests
+ import transformers
+ from PIL import Image
+
+ from otter.modeling_otter import OtterForConditionalGeneration
+
+
+ # Disable warnings
+ requests.packages.urllib3.disable_warnings()
+
+ # ------------------- Utility Functions -------------------
+
+
+ def get_content_type(file_path):
+     # Guess the MIME type of a local file from its extension.
+     content_type, _ = mimetypes.guess_type(file_path)
+     return content_type
+
+
+ # ------------------- Image and Video Handling Functions -------------------
+
+
+ def get_image(url: str) -> Union[Image.Image, list]:
+     if "://" not in url:  # Local file
+         content_type = get_content_type(url)
+     else:  # Remote URL
+         content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")
+
+     if "image" in content_type:
+         if "://" not in url:  # Local file
+             return Image.open(url)
+         else:  # Remote URL
+             return Image.open(requests.get(url, stream=True, verify=False).raw)
+     else:
+         raise ValueError("Invalid content type. Expected image or video.")
+
+
+ # ------------------- OTTER Prompt and Response Functions -------------------
+
+
+ def get_formatted_prompt(prompt: str, in_context_prompts: list = []) -> str:
+     # Otter's chat format: each in-context example is an <image> block closed
+     # by <|endofchunk|>; the final <image> block carries the actual question.
+     in_context_string = ""
+     for in_context_prompt, in_context_answer in in_context_prompts:
+         in_context_string += f"<image>User: {in_context_prompt} GPT:<answer> {in_context_answer}<|endofchunk|>"
+     return f"{in_context_string}<image>User: {prompt} GPT:<answer>"
+
+
+ def get_response(image_list, prompt: str, model=None, image_processor=None, in_context_prompts: list = []) -> str:
+     input_data = image_list
+
+     # vision_x shape: (batch, num_media, num_frames, channels, height, width)
+     if isinstance(input_data, Image.Image):
+         vision_x = image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
+     elif isinstance(input_data, list):  # list of images / video frames
+         vision_x = image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
+     else:
+         raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")
+
+     lang_x = model.text_tokenizer(
+         [
+             get_formatted_prompt(prompt, in_context_prompts),
+         ],
+         return_tensors="pt",
+     )
+     # Keep the model from generating role tags (or near-misspellings of them).
+     bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids
+     generated_text = model.generate(
+         vision_x=vision_x.to(model.device),
+         lang_x=lang_x["input_ids"].to(model.device),
+         attention_mask=lang_x["attention_mask"].to(model.device),
+         max_new_tokens=512,
+         num_beams=3,
+         no_repeat_ngram_size=3,
+         bad_words_ids=bad_words_id,
+     )
+     # Strip the prompt, chunk delimiter, and surrounding quotes from the decoded text.
+     parsed_output = (
+         model.text_tokenizer.decode(generated_text[0])
+         .split("<answer>")[-1]
+         .strip()
+         .split("<|endofchunk|>")[0]
+         .strip()
+         .strip('"')
+     )
+     return parsed_output
+
+
+ # ------------------- Main Function -------------------
+
+ if __name__ == "__main__":
+     model = OtterForConditionalGeneration.from_pretrained("tabtoyou/Ko-Otter-9B-LACR-v0", device_map="auto")
+     model.text_tokenizer.padding_side = "left"
+     image_processor = transformers.CLIPImageProcessor()
+     model.eval()
+
+     # One image per <image> tag in the prompt: the first pairs with the
+     # in-context example below, the second with the actual question.
+     urls = [
+         "https://images.cocodataset.org/train2017/000000339543.jpg",
+         "https://images.cocodataset.org/train2017/000000140285.jpg",
+     ]
+     encoded_frames_list = [get_image(url) for url in urls]
+
+     # In-context example in "question::answer" form:
+     # "Please describe the image::A family is taking a picture in front of a snowy mountain."
+     in_context_examples = [
+         "이미지에 대해 묘사해주세요::한 가족이 설산 앞에서 사진을 찍고 있습니다.",
+     ]
+     in_context_prompts = []
+     for in_context_input in in_context_examples:
+         in_context_prompt, in_context_answer = in_context_input.split("::")
+         in_context_prompts.append((in_context_prompt.strip(), in_context_answer.strip()))
+
+     # prompts_input = input("Enter the prompts separated by commas (or type 'quit' to exit): ")
+     prompts_input = "이미지에 대해 묘사해주세요"  # "Please describe the image"
+
+     prompts = [prompt.strip() for prompt in prompts_input.split(",")]
+
+     for prompt in prompts:
+         print(f"\nPrompt: {prompt}")
+         response = get_response(encoded_frames_list, prompt, model, image_processor, in_context_prompts)
+         print(f"Response: {response}")
+ ```
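
The README above notes that, to fit in GPU memory, only the upper layers (>25) of Otter's LLM were trained. As a rough illustration of that kind of partial fine-tuning, the sketch below freezes everything except those upper decoder layers. It is not the training code from this commit, and the `model.lang_encoder.model.layers` path is an assumption about where the LLaMA decoder blocks live in this Otter build.

```python
# Hypothetical sketch of the partial fine-tuning described in the README:
# freeze the whole model, then unfreeze only LLM decoder layers above 25.
# Assumption: the LLaMA decoder blocks sit at model.lang_encoder.model.layers.
from otter.modeling_otter import OtterForConditionalGeneration

model = OtterForConditionalGeneration.from_pretrained("luodian/OTTER-9B-LA-InContext")

for param in model.parameters():
    param.requires_grad = False  # freeze everything by default

for idx, layer in enumerate(model.lang_encoder.model.layers):
    if idx > 25:  # train only the upper decoder layers
        for param in layer.parameters():
            param.requires_grad = True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable:,}")
```

Freezing the lower layers shrinks gradient and optimizer-state memory roughly in proportion to the frozen parameter count, which is what makes a one-epoch run over the 77k subset feasible on limited hardware.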