cocktailpeanut commited on
Commit
c721d8e
1 Parent(s): 14a73d8
Files changed (2) hide show
  1. app.py +13 -7
  2. requirements.txt +4 -4
app.py CHANGED
@@ -2,10 +2,10 @@
2
  # -*- coding: utf-8 -*-
3
  import os
4
 
5
- print("Installing correct gradio version...")
6
- os.system("pip uninstall -y gradio")
7
- os.system("pip install gradio==3.50.0")
8
- print("Installing Finished!")
9
 
10
  ##!/usr/bin/python3
11
  # -*- coding: utf-8 -*-
@@ -19,7 +19,13 @@ import torch
19
  from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
20
  import random
21
 
22
- mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to("cuda")
 
 
 
 
 
 
23
  mobile_sam.eval()
24
  mobile_predictor = SamPredictor(mobile_sam)
25
  colors = [(255, 0, 0), (0, 255, 0)]
@@ -107,7 +113,7 @@ def process(input_image,
107
  init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
108
  mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")
109
 
110
- generator = torch.Generator("cuda").manual_seed(random.randint(0,2147483647) if randomize_seed else seed)
111
 
112
  image = pipe(
113
  [prompt]*2,
@@ -337,4 +343,4 @@ with block:
337
  run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
338
 
339
 
340
- block.launch()
 
2
  # -*- coding: utf-8 -*-
3
  import os
4
 
5
+ #print("Installing correct gradio version...")
6
+ #os.system("pip uninstall -y gradio")
7
+ #os.system("pip install gradio==3.50.0")
8
+ #print("Installing Finished!")
9
 
10
  ##!/usr/bin/python3
11
  # -*- coding: utf-8 -*-
 
19
  from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
20
  import random
21
 
22
+ if torch.backends.mps.is_available():
23
+ DEVICE = "mps"
24
+ elif torch.cuda.is_available():
25
+ DEVICE = "cuda"
26
+ else:
27
+ DEVICE = "cpu"
28
+ mobile_sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to(DEVICE)
29
  mobile_sam.eval()
30
  mobile_predictor = SamPredictor(mobile_sam)
31
  colors = [(255, 0, 0), (0, 255, 0)]
 
113
  init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
114
  mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")
115
 
116
+ generator = torch.Generator(DEVICE).manual_seed(random.randint(0,2147483647) if randomize_seed else seed)
117
 
118
  image = pipe(
119
  [prompt]*2,
 
343
  run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
344
 
345
 
346
+ block.launch()
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- torch
2
- torchvision
3
- torchaudio
4
  transformers>=4.25.1
5
  gradio==3.50.0
6
  ftfy
@@ -16,4 +16,4 @@ torchmetrics
16
  open-clip-torch
17
  clip
18
  segment_anything
19
- git+https://github.com/TencentARC/BrushNet.git
 
1
+ #torch
2
+ #torchvision
3
+ #torchaudio
4
  transformers>=4.25.1
5
  gradio==3.50.0
6
  ftfy
 
16
  open-clip-torch
17
  clip
18
  segment_anything
19
+ git+https://github.com/TencentARC/BrushNet.git