ZhengPeng7 commited on
Commit
f406b44
1 Parent(s): afa3ff7

Add option for resolution in text box to do worse but faster prediction.

Browse files
app.py CHANGED
@@ -16,9 +16,9 @@ device = config.device
16
 
17
 
18
  class ImagePreprocessor():
19
- def __init__(self) -> None:
20
  self.transform_image = transforms.Compose([
21
- transforms.Resize((1024, 1024)),
22
  transforms.ToTensor(),
23
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
24
  ])
@@ -42,13 +42,14 @@ model.eval()
42
 
43
  # def predict(image_1, image_2):
44
  # images = [image_1, image_2]
45
- def predict(image):
46
  images = [image]
47
  image_shapes = [image.shape[:2] for image in images]
48
  images = [Image.fromarray(image) for image in images]
49
 
 
 
50
  images_proc = []
51
- image_preprocessor = ImagePreprocessor()
52
  for image in images:
53
  images_proc.append(image_preprocessor.proc(image))
54
  images_proc = torch.cat([image_proc.unsqueeze(0) for image_proc in images_proc])
@@ -73,6 +74,7 @@ examples = [[_] for _ in glob('materials/examples/*')][:]
73
  N = 1
74
  ipt = [gr.Image() for _ in range(N)]
75
  opt = [gr.Image() for _ in range(N)]
 
76
  demo = gr.Interface(
77
  fn=predict,
78
  inputs=ipt,
@@ -80,6 +82,6 @@ demo = gr.Interface(
80
  examples=examples,
81
  title='Online demo for `Bilateral Reference for High-Resolution Dichotomous Image Segmentation`',
82
  description=('Upload a picture, our model will give you the binary maps of the highly accurate segmentation of the salient objects in it. :)'
83
- '\n')
84
  )
85
  demo.launch(debug=True)
 
16
 
17
 
18
  class ImagePreprocessor():
19
+ def __init__(self, resolution=(1024, 1024)) -> None:
20
  self.transform_image = transforms.Compose([
21
+ transforms.Resize(resolution),
22
  transforms.ToTensor(),
23
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
24
  ])
 
42
 
43
  # def predict(image_1, image_2):
44
  # images = [image_1, image_2]
45
+ def predict(image, resolution='1024x1024'):
46
  images = [image]
47
  image_shapes = [image.shape[:2] for image in images]
48
  images = [Image.fromarray(image) for image in images]
49
 
50
+ resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
51
+ image_preprocessor = ImagePreprocessor(resolution=resolution)
52
  images_proc = []
 
53
  for image in images:
54
  images_proc.append(image_preprocessor.proc(image))
55
  images_proc = torch.cat([image_proc.unsqueeze(0) for image_proc in images_proc])
 
74
  N = 1
75
  ipt = [gr.Image() for _ in range(N)]
76
  opt = [gr.Image() for _ in range(N)]
77
+ ipt += [gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `512x512`. Higher resolutions can be much slower for inference.", label="Resolution")]
78
  demo = gr.Interface(
79
  fn=predict,
80
  inputs=ipt,
 
82
  examples=examples,
83
  title='Online demo for `Bilateral Reference for High-Resolution Dichotomous Image Segmentation`',
84
  description=('Upload a picture, our model will give you the binary maps of the highly accurate segmentation of the salient objects in it. :)'
85
+ '\nThe resolution used in our training was `1024x1024`, which is too much burden for the huggingface free spaces like this one (cost ~500s). Please set resolution as more than `768x768` for images with many texture details to obtain good results!')
86
  )
87
  demo.launch(debug=True)
materials/examples/1024x1024-1#Accessories#1#Bag#3713356643_ff7bdcdbf6_o.jpg ADDED

Git LFS Details

  • SHA256: 47854347f888b8691e9e9b66a2401f7c88866f6ca8ac59f0886989b86ea5dcbc
  • Pointer size: 131 Bytes
  • Size of remote file: 643 kB