ZhengPeng7 commited on
Commit
bfe6e38
1 Parent(s): 7ccb658

Upgrade the weights loading method to avoid duplicated loading.

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -53,13 +53,12 @@ birefnet.eval()
53
  @spaces.GPU
54
  def predict(image, resolution, weights_file):
55
  global birefnet
56
- if weights_file != 'General':
57
- # Load BiRefNet with chosen weights
58
- _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else 'BiRefNet'))
59
- print('Change weights to:', _weights_file)
60
- birefnet = birefnet.from_pretrained(_weights_file)
61
- birefnet.to(device)
62
- birefnet.eval()
63
 
64
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
65
  # Image is a RGB numpy array.
 
53
  @spaces.GPU
54
  def predict(image, resolution, weights_file):
55
  global birefnet
56
+ # Load BiRefNet with chosen weights
57
+ _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else 'BiRefNet'))
58
+ print('Change weights to:', _weights_file)
59
+ birefnet = birefnet.from_pretrained(_weights_file)
60
+ birefnet.to(device)
61
+ birefnet.eval()
 
62
 
63
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
64
  # Image is a RGB numpy array.