ZhengPeng7 commited on
Commit
d967d62
1 Parent(s): 1352148

Upgrade the weights loading method to avoid duplicated loading.

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -35,13 +35,6 @@ class ImagePreprocessor():
35
  return image
36
 
37
 
38
- from transformers import AutoModelForImageSegmentation
39
- weights_path = 'BiRefNet'
40
- birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', weights_path)), trust_remote_code=True)
41
- birefnet.to(device)
42
- birefnet.eval()
43
- birefnet.weights_path = weights_path
44
-
45
  usage_to_weights_file = {
46
  'General': 'BiRefNet',
47
  'Portrait': 'BiRefNet-portrait',
@@ -51,6 +44,13 @@ usage_to_weights_file = {
51
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs'
52
  }
53
 
 
 
 
 
 
 
 
54
 
55
  @spaces.GPU
56
  def predict(image, resolution, weights_file):
 
35
  return image
36
 
37
 
 
 
 
 
 
 
 
38
  usage_to_weights_file = {
39
  'General': 'BiRefNet',
40
  'Portrait': 'BiRefNet-portrait',
 
44
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs'
45
  }
46
 
47
+ from transformers import AutoModelForImageSegmentation
48
+ weights_path = 'General'
49
+ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file[weights_path])), trust_remote_code=True)
50
+ birefnet.to(device)
51
+ birefnet.eval()
52
+ birefnet.weights_path = weights_path
53
+
54
 
55
  @spaces.GPU
56
  def predict(image, resolution, weights_file):