ZhengPeng7 commited on
Commit
b59df1c
1 Parent(s): bf97864

Add the weights option.

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
5
  from PIL import Image
6
  import torch
7
  from torchvision import transforms
 
8
  import gradio as gr
9
  import spaces
10
  from gradio_imageslider import ImageSlider
@@ -34,9 +35,9 @@ class ImagePreprocessor():
34
  return image
35
 
36
 
37
-
38
  from transformers import AutoModelForImageSegmentation
39
- birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
 
40
  birefnet.to(device)
41
  birefnet.eval()
42
 
@@ -44,7 +45,12 @@ birefnet.eval()
44
  # def predict(image_1, image_2):
45
  # images = [image_1, image_2]
46
  @spaces.GPU
47
- def predict(image, resolution):
 
 
 
 
 
48
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
49
  # Image is a RGB numpy array.
50
  resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
@@ -84,7 +90,11 @@ examples[-1][1] = '512x512'
84
 
85
  demo = gr.Interface(
86
  fn=predict,
87
- inputs=['image', gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `512x512`. Higher resolutions can be much slower for inference.", label="Resolution")],
 
 
 
 
88
  outputs=ImageSlider(),
89
  examples=examples,
90
  title='Online demo for `Bilateral Reference for High-Resolution Dichotomous Image Segmentation`',
 
5
  from PIL import Image
6
  import torch
7
  from torchvision import transforms
8
+ from transformers import AutoModelForImageSegmentation
9
  import gradio as gr
10
  import spaces
11
  from gradio_imageslider import ImageSlider
 
35
  return image
36
 
37
 
 
38
  from transformers import AutoModelForImageSegmentation
39
+ model_path = 'zhengpeng7/BiRefNet'
40
+ birefnet = AutoModelForImageSegmentation.from_pretrained(model_path, trust_remote_code=True)
41
  birefnet.to(device)
42
  birefnet.eval()
43
 
 
45
  # def predict(image_1, image_2):
46
  # images = [image_1, image_2]
47
  @spaces.GPU
48
+ def predict(image, resolution, weights_file):
49
+ # Load BiRefNet with chosen weights
50
+ birefnet = AutoModelForImageSegmentation.from_pretrained(weights_file, trust_remote_code=True)
51
+ birefnet.to(device)
52
+ birefnet.eval()
53
+
54
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
55
  # Image is a RGB numpy array.
56
  resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
 
90
 
91
  demo = gr.Interface(
92
  fn=predict,
93
+ inputs=[
94
+ 'image',
95
+ gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `512x512`. Higher resolutions can be much slower for inference.", label="Resolution"),
96
+ gr.Checkbox(['zhengpeng7/BiRefNet', 'zhengpeng7/BiRefNet-portrait'], label="Models", info="Choose the weights you want.")
97
+ ],
98
  outputs=ImageSlider(),
99
  examples=examples,
100
  title='Online demo for `Bilateral Reference for High-Resolution Dichotomous Image Segmentation`',