ClaireOzzz committed on
Commit 1c2e38e
1 Parent(s): 847b4f9

Create app.py

Files changed (1)
  1. app.py +397 -0
app.py ADDED
@@ -0,0 +1,397 @@
import gradio as gr
from huggingface_hub import login, HfFileSystem, HfApi, ModelCard
import os
import spaces
import random
import torch

is_shared_ui = False

# Read the token from the environment (e.g. a Space secret named HF_TOKEN)
# instead of hardcoding a credential in the repository
hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)

fs = HfFileSystem(token=hf_token)
api = HfApi()

device = "cuda" if torch.cuda.is_available() else "cpu"

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.utils import load_image
from PIL import Image
import numpy as np
import cv2

# fp16-safe VAE: the stock SDXL VAE can produce NaNs in half precision
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",
    torch_dtype=torch.float16
)

def check_use_custom_or_no(value):
    # Toggle the custom-model controls with the checkbox state
    return gr.update(visible=value)

def get_files(file_paths):
    last_files = {}  # Dictionary to store the last file for each path

    for file_path in file_paths:
        # Split the file path into directory and file components
        directory, file_name = file_path.rsplit('/', 1)

        # Update the last file for the current path
        last_files[directory] = file_name

    # Extract the last files from the dictionary
    result = list(last_files.values())

    return result

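# Quick illustration (a sketch with made-up paths): only the last file listed
# per directory survives, so several matches in one repo collapse to one entry.
assert get_files([
    "user/model/a.safetensors",
    "user/model/b.safetensors",
    "user/other/c.safetensors",
]) == ["b.safetensors", "c.safetensors"]
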
def load_model(model_name):

    if model_name == "":
        gr.Warning("If you want to use a private model, you need to duplicate this space on your personal account.")
        raise gr.Error("You forgot to define a Model ID.")

    # Get instance_prompt, a.k.a. the trigger word
    card = ModelCard.load(model_name)
    repo_data = card.data.to_dict()
    instance_prompt = repo_data.get("instance_prompt")

    if instance_prompt is not None:
        print(f"Trigger word: {instance_prompt}")
    else:
        instance_prompt = "no trigger word needed"
        print("Trigger word: no trigger word needed")

    # List all ".safetensors" files in the repo
    sfts_available_files = fs.glob(f"{model_name}/*safetensors")
    sfts_available_files = get_files(sfts_available_files)

    if sfts_available_files == []:
        sfts_available_files = ["NO SAFETENSORS FILE"]

    print(f"Safetensors available: {sfts_available_files}")

    return model_name, "Model ready", gr.update(choices=sfts_available_files, value=sfts_available_files[0], visible=True), gr.update(value=instance_prompt, visible=True)

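# Illustrative call (hypothetical repo id; needs network access, so left as a comment):
# load_model("your_username/your-sdxl-lora")
# -> ("your_username/your-sdxl-lora", "Model ready", <safetensors dropdown update>, <trigger word update>)
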
def custom_model_changed(model_name, previous_model):
    if model_name == "" and previous_model == "":
        status_message = ""
    elif model_name != previous_model:
        status_message = "Model changed, please reload before any new run"
    else:
        status_message = "Model ready"
    return status_message

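# Sanity check of the status transitions (illustrative literals):
assert custom_model_changed("", "") == ""
assert custom_model_changed("user/new", "user/old") == "Model changed, please reload before any new run"
assert custom_model_changed("user/same", "user/same") == "Model ready"
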
def resize_image(input_path, output_path, target_height):
    # Open the input image
    img = Image.open(input_path)

    # Calculate the aspect ratio of the original image
    original_width, original_height = img.size
    original_aspect_ratio = original_width / original_height

    # Calculate the new width while maintaining the aspect ratio and the target height
    new_width = int(target_height * original_aspect_ratio)

    # Resize the image while maintaining the aspect ratio and fixing the height
    img = img.resize((new_width, target_height), Image.LANCZOS)

    # Save the resized image
    img.save(output_path)

    return output_path

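# Demonstration of the aspect-ratio math (a sketch using a tiny in-memory
# image; the file names are placeholders): a 300x200 source resized to
# height 100 keeps its 3:2 ratio and comes out 150x100.
_demo = Image.new("RGB", (300, 200))
_demo.save("_demo_in.jpg")
assert Image.open(resize_image("_demo_in.jpg", "_demo_out.jpg", 100)).size == (150, 100)
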
@spaces.GPU
def infer(use_custom_model, model_name, weight_name, custom_lora_weight, image_in, prompt, negative_prompt, preprocessor, controlnet_conditioning_scale, guidance_scale, inf_steps, seed, progress=gr.Progress(track_tqdm=True)):

    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=controlnet,
        vae=vae,
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True
    )

    pipe.to(device)

    if seed < 0:
        seed = random.randint(0, 423538377342)

    generator = torch.Generator(device=device).manual_seed(seed)

    if image_in is None:
        raise gr.Error("You forgot to upload a source image.")

    image_in = resize_image(image_in, "resized_input.jpg", 1024)

    if preprocessor == "canny":

        image = load_image(image_in)

        # Canny edge detection; 100/200 are the hysteresis thresholds
        image = np.array(image)
        image = cv2.Canny(image, 100, 200)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        image = Image.fromarray(image)

    if use_custom_model:

        if model_name == "":
            raise gr.Error("You forgot to set a custom model name.")

        custom_model = model_name

        # This is where the trained LoRA weights are loaded
        if weight_name == "NO SAFETENSORS FILE":
            pipe.load_lora_weights(
                custom_model,
                low_cpu_mem_usage=True,
                use_auth_token=True
            )
        else:
            pipe.load_lora_weights(
                custom_model,
                weight_name=weight_name,
                low_cpu_mem_usage=True,
                use_auth_token=True
            )

        lora_scale = custom_lora_weight

        images = pipe(
            prompt,
            negative_prompt=negative_prompt,
            image=image,
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
            guidance_scale=float(guidance_scale),
            num_inference_steps=inf_steps,
            generator=generator,
            cross_attention_kwargs={"scale": lora_scale}
        ).images
    else:
        images = pipe(
            prompt,
            negative_prompt=negative_prompt,
            image=image,
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
            guidance_scale=float(guidance_scale),
            num_inference_steps=inf_steps,
            generator=generator,
        ).images

    images[0].save("result.png")

    return "result.png", seed

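# Standalone sketch of the Canny preprocessing used in infer() (shapes are
# illustrative; a blank frame yields no edges, hence an all-zero map):
_frame = np.zeros((64, 64, 3), dtype=np.uint8)
_edges = cv2.Canny(_frame, 100, 200)                        # -> (64, 64), single channel
_rgb_edges = np.concatenate([_edges[:, :, None]] * 3, axis=2)  # back to 3 channels for the pipeline
assert _rgb_edges.shape == (64, 64, 3)
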
css = """
#col-container{
    margin: 0 auto;
    max-width: 720px;
    text-align: left;
}
div#warning-duplicate {
    background-color: #ebf5ff;
    padding: 0 10px 5px;
    margin: 20px 0;
}
div#warning-duplicate > .gr-prose > h2, div#warning-duplicate > .gr-prose > p {
    color: #0f4592!important;
}
div#warning-duplicate strong {
    color: #0f4592;
}
p.actions {
    display: flex;
    align-items: center;
    margin: 20px 0;
}
div#warning-duplicate .actions a {
    display: inline-block;
    margin-right: 10px;
}
button#load_model_btn{
    height: 46px;
}
#status_info{
    font-size: 0.9em;
}
"""

theme = gr.themes.Soft(
    primary_hue="teal",
    secondary_hue="gray",
).set(
    body_text_color_dark='*neutral_800',
    background_fill_primary_dark='*neutral_50',
    background_fill_secondary_dark='*neutral_50',
    border_color_accent_dark='*primary_300',
    border_color_primary_dark='*neutral_200',
    color_accent_soft_dark='*neutral_50',
    link_text_color_dark='*secondary_600',
    link_text_color_active_dark='*secondary_600',
    link_text_color_hover_dark='*secondary_700',
    link_text_color_visited_dark='*secondary_500',
    code_background_fill_dark='*neutral_100',
    shadow_spread_dark='6px',
    block_background_fill_dark='white',
    block_label_background_fill_dark='*primary_100',
    block_label_text_color_dark='*primary_500',
    block_title_text_color_dark='*primary_500',
    checkbox_background_color_dark='*background_fill_primary',
    checkbox_background_color_selected_dark='*primary_600',
    checkbox_border_color_dark='*neutral_100',
    checkbox_border_color_focus_dark='*primary_500',
    checkbox_border_color_hover_dark='*neutral_300',
    checkbox_border_color_selected_dark='*primary_600',
    checkbox_label_background_fill_selected_dark='*primary_500',
    checkbox_label_text_color_selected_dark='white',
    error_background_fill_dark='#fef2f2',
    error_border_color_dark='#b91c1c',
    error_text_color_dark='#b91c1c',
    error_icon_color_dark='#b91c1c',
    input_background_fill_dark='white',
    input_background_fill_focus_dark='*secondary_500',
    input_border_color_dark='*neutral_50',
    input_border_color_focus_dark='*secondary_300',
    input_placeholder_color_dark='*neutral_400',
    slider_color_dark='*primary_500',
    stat_background_fill_dark='*primary_300',
    table_border_color_dark='*neutral_300',
    table_even_background_fill_dark='white',
    table_odd_background_fill_dark='*neutral_50',
    button_primary_background_fill_dark='*primary_500',
    button_primary_background_fill_hover_dark='*primary_400',
    button_primary_border_color_dark='*primary_500',
    button_secondary_background_fill_dark='white',
    button_secondary_background_fill_hover_dark='*neutral_100',
    button_secondary_border_color_dark='*neutral_200',
    button_secondary_text_color_dark='*neutral_800'
)

with gr.Blocks(theme=theme, css=css) as demo:
    with gr.Column(elem_id="col-container"):

        gr.HTML("""
        <h2 style="text-align: center;">SD-XL Control LoRAs</h2>
        <p style="text-align: center;">Use StableDiffusion XL with <a href="https://huggingface.co/collections/diffusers/sdxl-controlnets-64f9c35846f3f06f5abe351f">Diffusers' SDXL ControlNets</a></p>
        """)

        use_custom_model = gr.Checkbox(label="Use a custom pre-trained LoRA model? (optional)", visible=False, value=False, info="To use a private model, you'll need to duplicate the space with your own access token.")

        with gr.Group(visible=False) as custom_model_box:
            with gr.Row():
                with gr.Column():
                    if not is_shared_ui:
                        your_username = api.whoami()["name"]
                        my_models = api.list_models(author=your_username, filter=["diffusers", "stable-diffusion-xl", "lora"])
                        model_names = [item.modelId for item in my_models]

                        custom_model = gr.Dropdown(
                            label="Your custom model ID",
                            info="You can pick one of your private models",
                            choices=model_names,
                            allow_custom_value=True
                        )
                    else:
                        custom_model = gr.Textbox(
                            label="Your custom model ID",
                            placeholder="your_username/your_trained_model_name",
                            info="Make sure your model is set to PUBLIC"
                        )

                    weight_name = gr.Dropdown(
                        label="Safetensors file",
                        info="specify which one if model has several .safetensors files",
                        allow_custom_value=True,
                        visible=False
                    )
                with gr.Column():
                    with gr.Group():
                        previous_model = gr.Textbox(
                            visible=False
                        )
                        model_status = gr.Textbox(
                            label="model status",
                            show_label=False,
                            elem_id="status_info"
                        )
                    trigger_word = gr.Textbox(label="Trigger word", interactive=False, visible=False)

        image_in = gr.Image(sources="upload", type="filepath")

        with gr.Row():

            with gr.Column():
                with gr.Group():
                    prompt = gr.Textbox(label="Prompt")
                    negative_prompt = gr.Textbox(label="Negative prompt", value="extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured")
                with gr.Group():
                    guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=7.5)
                    inf_steps = gr.Slider(label="Inference steps", minimum=25, maximum=50, step=1, value=25)
                    custom_lora_weight = gr.Slider(label="Custom model weights", minimum=0.1, maximum=0.9, step=0.1, value=0.9)

            with gr.Column():
                with gr.Group():
                    preprocessor = gr.Dropdown(label="Preprocessor", choices=["canny"], value="canny", interactive=False, info="For the moment, only canny is available")
                    controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=0.9, step=0.01, value=0.5)
                with gr.Group():
                    seed = gr.Slider(
                        label="Seed",
                        info="-1 denotes a random seed",
                        minimum=-1,
                        maximum=423538377342,
                        step=1,
                        value=-1
                    )
                    last_used_seed = gr.Number(
                        label="Last used seed",
                        info="the seed used in the last generation",
                    )

        load_model_btn = gr.Button("Load my model", elem_id="load_model_btn")
        submit_btn = gr.Button("Submit")

        result = gr.Image(label="Result")

    use_custom_model.change(
        fn=check_use_custom_or_no,
        inputs=[use_custom_model],
        outputs=[custom_model_box],
        queue=False
    )
    custom_model.blur(
        fn=custom_model_changed,
        inputs=[custom_model, previous_model],
        outputs=[model_status],
        queue=False
    )
    load_model_btn.click(
        fn=load_model,
        inputs=[custom_model],
        outputs=[previous_model, model_status, weight_name, trigger_word],
        queue=False
    )
    submit_btn.click(
        fn=infer,
        inputs=[use_custom_model, custom_model, weight_name, custom_lora_weight, image_in, prompt, negative_prompt, preprocessor, controlnet_conditioning_scale, guidance_scale, inf_steps, seed],
        outputs=[result, last_used_seed]
    )

demo.queue(max_size=12).launch(share=True)