fynn3003 committed
Commit c1b0a00
1 Parent(s): ed1cb12
Files changed (1)
  1. app.py +17 -23
app.py CHANGED
@@ -1,8 +1,8 @@
-
 from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
 import torch
 from PIL import Image
 import gradio as gr
+import numpy as np
 
 model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
 feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
@@ -11,30 +11,24 @@ tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
-
-
 max_length = 16
 num_beams = 4
 gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-def predict_step(image_paths):
-    images = []
-    for image_path in image_paths:
-        i_image = Image.open(image_path)
-        if i_image.mode != "RGB":
-            i_image = i_image.convert(mode="RGB")
-
-        images.append(i_image)
 
-    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
-    pixel_values = pixel_values.to(device)
-
-    output_ids = model.generate(pixel_values, **gen_kwargs)
-
-    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-    preds = [pred.strip() for pred in preds]
-    print(preds)
-    return preds
-
-
-predict_step(['cat.jpg'])
+def predict_step(image):
+    i_image = Image.fromarray(np.uint8(image))
+    if i_image.mode != "RGB":
+        i_image = i_image.convert(mode="RGB")
+    pixel_values = feature_extractor(images=i_image, return_tensors="pt").pixel_values
+    pixel_values = pixel_values.to(device)
+    output_ids = model.generate(pixel_values, **gen_kwargs)
+    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+    preds = [pred.strip() for pred in preds]
+    return preds
+
+iface = gr.Interface(fn=predict_step,
+                     inputs=gr.inputs.Image(shape=(224, 224)),
+                     outputs=gr.outputs.Textbox(label="Generated Caption"))
+
+iface.launch(share=True)
 
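
Note that the added interface targets the legacy Gradio API: the gr.inputs / gr.outputs namespaces and the shape= argument were deprecated in Gradio 3.x and removed in 4.x. A minimal sketch of the same interface against the current API, assuming the predict_step defined above is in scope; the shape=(224, 224) resize can be dropped because ViTImageProcessor resizes inputs to the model's expected resolution itself:

import gradio as gr

# Sketch for current Gradio releases, not part of this commit.
# type="numpy" hands predict_step an RGB numpy array, so the
# Image.fromarray conversion in app.py works unchanged.
iface = gr.Interface(
    fn=predict_step,
    inputs=gr.Image(type="numpy", label="Input Image"),
    outputs=gr.Textbox(label="Generated Caption"),
)

iface.launch(share=True)  # share=True also requests a temporary public gradio.live URL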
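
The commit also removes the module-level predict_step(['cat.jpg']) smoke test. With the new single-image signature, a rough equivalent would look like this; "cat.jpg" is the placeholder path from the removed line, not a file tracked in this repo:

import numpy as np
from PIL import Image

# Rough stand-in for the removed smoke test, adapted to the new signature.
image = np.array(Image.open("cat.jpg"))  # placeholder path from the old code
caption = predict_step(image)            # returns a one-element list of strings
print(caption)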