zhiweili committed
Commit fd86234
1 Parent(s): a4a060a

add app_base

Files changed (4)
  1. app.py +1 -1
  2. app_base.py +143 -0
  3. inversion_run_base.py +222 -0
  4. segment_utils.py +11 -1
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 
-from app_gfp import create_demo as create_demo_face
+from app_base import create_demo as create_demo_face
 
 with gr.Blocks(css="style.css") as demo:
     with gr.Tabs():
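
Note (not part of the commit): a minimal sketch of how app.py presumably consumes the re-pointed import; the tab label "Face" and the queue()/launch() call are assumptions, not shown in this diff.

import gradio as gr

from app_base import create_demo as create_demo_face

with gr.Blocks(css="style.css") as demo:
    with gr.Tabs():
        with gr.Tab("Face"):        # tab label is an assumption
            create_demo_face()      # renders the Blocks built in app_base.py

if __name__ == "__main__":
    demo.queue().launch()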
app_base.py ADDED
@@ -0,0 +1,143 @@
+import spaces
+import gradio as gr
+import time
+import torch
+import os
+import numpy as np
+import cv2
+
+from PIL import Image
+from inversion_run_base import run as base_run
+from segment_utils import (
+    segment_image,
+    restore_result,
+)
+from gfpgan.utils import GFPGANer
+
+
+DEFAULT_SRC_PROMPT = "a woman, photo"
+DEFAULT_EDIT_PROMPT = "a beautiful woman, photo, hollywood style face, 8k, high quality"
+
+DEFAULT_CATEGORY = "face"
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+os.system("pip freeze")
+if not os.path.exists('GFPGANv1.4.pth'):
+    os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
+
+face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2)
+
+os.makedirs('output', exist_ok=True)
+
+@spaces.GPU(duration=15)
+def image_to_image(
+    input_image: Image,
+    input_image_prompt: str,
+    edit_prompt: str,
+    seed: int,
+    w1: float,
+    num_steps: int,
+    start_step: int,
+    guidance_scale: float,
+    generate_size: int,
+    adapter_weights: float,
+):
+    w2 = 1.0
+    run_task_time = 0
+    time_cost_str = ''
+    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+    run_model = base_run
+    res_image = run_model(
+        input_image,
+        input_image_prompt,
+        edit_prompt,
+        generate_size,
+        seed,
+        w1,
+        w2,
+        num_steps,
+        start_step,
+        guidance_scale,
+        adapter_weights,
+    )
+    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+    enhanced_image = enhance(res_image)
+    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+
+    return enhanced_image, res_image, time_cost_str
+
+def get_time_cost(run_task_time, time_cost_str):
+    now_time = int(time.time()*1000)
+    if run_task_time == 0:
+        time_cost_str = 'start'
+    else:
+        if time_cost_str != '':
+            time_cost_str += f'-->'
+        time_cost_str += f'{now_time - run_task_time}'
+    run_task_time = now_time
+    return run_task_time, time_cost_str
+
+def enhance(
+    pil_image: Image,
+):
+    img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
+
+    h, w = img.shape[0:2]
+    if h < 300:
+        img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
+
+    _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=True, paste_back=True)
+    pil_output = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
+
+    return pil_output
+
+def create_demo() -> gr.Blocks:
+    with gr.Blocks() as demo:
+        croper = gr.State()
+        with gr.Row():
+            with gr.Column():
+                input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
+                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
+                category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
+            with gr.Column():
+                num_steps = gr.Slider(minimum=1, maximum=100, value=5, step=1, label="Num Steps")
+                start_step = gr.Slider(minimum=1, maximum=100, value=1, step=1, label="Start Step")
+                guidance_scale = gr.Slider(minimum=0, maximum=20, value=2.5, step=0.5, label="Guidance Scale")
+                with gr.Accordion("Advanced Options", open=False):
+                    generate_size = gr.Number(label="Generate Size", value=512)
+                    mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
+                    mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
+            with gr.Column():
+                seed = gr.Number(label="Seed", value=8)
+                w1 = gr.Number(label="W1", value=1.5)
+                adapter_weights = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Adapter Weights", visible=False)
+                g_btn = gr.Button("Edit Image")
+
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(label="Input Image", type="pil")
+            with gr.Column():
+                restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
+                download_path = gr.File(label="Download the output image", interactive=False)
+            with gr.Column():
+                origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
+                enhanced_image = gr.Image(label="Enhanced Image", type="pil", interactive=False)
+                generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
+                generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
+
+        g_btn.click(
+            fn=segment_image,
+            inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
+            outputs=[origin_area_image, croper],
+        ).success(
+            fn=image_to_image,
+            inputs=[origin_area_image, input_image_prompt, edit_prompt, seed, w1, num_steps, start_step, guidance_scale, generate_size, adapter_weights],
+            outputs=[enhanced_image, generated_image, generated_cost],
+        ).success(
+            fn=restore_result,
+            inputs=[croper, category, enhanced_image],
+            outputs=[restored_image, download_path],
+        )
+
+    return demo
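
Note (not part of the commit): a minimal sketch of driving the same segment-edit-enhance pipeline app_base.py wires behind the "Edit Image" button, without the UI. It assumes a CUDA machine, the GFPGANv1.4.pth weights (downloaded when the module is imported), and a pre-cropped face saved locally as face_crop.png (a hypothetical file); the argument values mirror the UI defaults above.

from PIL import Image

from app_base import create_demo, image_to_image

crop = Image.open("face_crop.png").convert("RGB")   # hypothetical local face crop
enhanced, generated, time_cost = image_to_image(
    input_image=crop,
    input_image_prompt="a woman, photo",
    edit_prompt="a beautiful woman, photo, hollywood style face, 8k, high quality",
    seed=8,
    w1=1.5,
    num_steps=5,
    start_step=1,
    guidance_scale=2.5,
    generate_size=512,
    adapter_weights=0.5,
)
enhanced.save("enhanced.png")
print(time_cost)                                    # per-step timings in ms

# Or serve just this tab on its own:
# create_demo().queue().launch()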
inversion_run_base.py ADDED
@@ -0,0 +1,222 @@
+import torch
+
+from diffusers import (
+    DDPMScheduler,
+    StableDiffusionXLImg2ImgPipeline,
+)
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps, retrieve_latents
+from PIL import Image
+from inversion_utils import get_ddpm_inversion_scheduler, create_xts
+from config import get_config, get_num_steps_actual
+from functools import partial
+from compel import Compel, ReturnedEmbeddingsType
+
+class Object(object):
+    pass
+
+args = Object()
+args.images_paths = None
+args.images_folder = None
+args.force_use_cpu = False
+args.folder_name = 'test_measure_time'
+args.config_from_file = 'run_configs/noise_shift_guidance_1_5.yaml'
+args.save_intermediate_results = False
+args.batch_size = None
+args.skip_p_to_p = True
+args.only_p_to_p = False
+args.fp16 = False
+args.prompts_file = 'dataset_measure_time/dataset.json'
+args.images_in_prompts_file = None
+args.seed = 986
+args.time_measure_n = 1
+
+
+assert (
+    args.batch_size is None or args.save_intermediate_results is False
+), "save_intermediate_results is not implemented for batch_size > 1"
+
+generator = None
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
+BASE_MODEL = "stabilityai/sdxl-turbo"
+
+
+pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+    BASE_MODEL,
+    torch_dtype=torch.float16,
+    variant="fp16",
+    use_safetensors=True,
+)
+pipeline = pipeline.to(device)
+
+pipeline.scheduler = DDPMScheduler.from_pretrained(
+    BASE_MODEL,
+    subfolder="scheduler",
+)
+
+# pipeline.load_lora_weights("checkpoints/lora", weight_name="zy_AmateurStyle_v2.safetensors", adapter_name="zy_AmateurStyle_v2")
+
+config = get_config(args)
+
+compel_proc = Compel(
+    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
+    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
+    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+    requires_pooled=[False, True]
+)
+
+def run(
+    input_image: Image,
+    src_prompt: str,
+    tgt_prompt: str,
+    generate_size: int,
+    seed: int,
+    w1: float,
+    w2: float,
+    num_steps: int,
+    start_step: int,
+    guidance_scale: float,
+    adapter_weights: float,
+):
+    # pipeline.set_adapters(["zy_AmateurStyle_v2"], adapter_weights=[adapter_weights])
+    generator = torch.Generator().manual_seed(seed)
+
+    config.num_steps_inversion = num_steps
+    config.step_start = start_step
+    num_steps_actual = get_num_steps_actual(config)
+
+
+    num_steps_inversion = config.num_steps_inversion
+    denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion
+    print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} denoising_start: {denoising_start}")
+
+    timesteps, num_inference_steps = retrieve_timesteps(
+        pipeline.scheduler, num_steps_inversion, device, None
+    )
+    timesteps, num_inference_steps = pipeline.get_timesteps(
+        num_inference_steps=num_inference_steps,
+        denoising_start=denoising_start,
+        strength=0,
+        device=device,
+    )
+    timesteps = timesteps.type(torch.int64)
+
+    timesteps = [torch.tensor(t) for t in timesteps.tolist()]
+    timesteps_len = len(timesteps)
+    config.step_start = start_step + num_steps_actual - timesteps_len
+    num_steps_actual = timesteps_len
+    config.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
+    print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} step_start: {config.step_start}")
+    print(f"-------->timesteps len: {len(timesteps)} max_norm_zs len: {len(config.max_norm_zs)}")
+    pipeline.__call__ = partial(
+        pipeline.__call__,
+        num_inference_steps=num_steps_inversion,
+        guidance_scale=guidance_scale,
+        generator=generator,
+        denoising_start=denoising_start,
+        strength=0,
+    )
+
+    x_0_image = input_image
+    x_0 = encode_image(x_0_image, pipeline)
+    x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
+    x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
+    latents = [x_ts[0]]
+    x_ts_c_hat = [None]
+    config.ws1 = [w1] * num_steps_actual
+    config.ws2 = [w2] * num_steps_actual
+    pipeline.scheduler = get_ddpm_inversion_scheduler(
+        pipeline.scheduler,
+        config.step_function,
+        config,
+        timesteps,
+        config.save_timesteps,
+        latents,
+        x_ts,
+        x_ts_c_hat,
+        args.save_intermediate_results,
+        pipeline,
+        x_0,
+        v1s_images := [],
+        v2s_images := [],
+        deltas_images := [],
+        v1_x0s := [],
+        v2_x0s := [],
+        deltas_x0s := [],
+        "res12",
+        image_name="im_name",
+        time_measure_n=args.time_measure_n,
+    )
+    latent = latents[0].expand(3, -1, -1, -1)
+    prompt = [src_prompt, src_prompt, tgt_prompt]
+    conditioning, pooled = compel_proc(prompt)
+    image = pipeline.__call__(
+        image=latent,
+        prompt_embeds=conditioning,
+        pooled_prompt_embeds=pooled,
+        eta=1,
+    ).images
+    return image[2]
+
+def encode_image(image, pipe):
+    image = pipe.image_processor.preprocess(image)
+    originDtype = pipe.dtype
+    image = image.to(device=device, dtype=originDtype)
+
+    if pipe.vae.config.force_upcast:
+        image = image.float()
+        pipe.vae.to(dtype=torch.float32)
+
+    if isinstance(generator, list):
+        init_latents = [
+            retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
+            for i in range(1)
+        ]
+        init_latents = torch.cat(init_latents, dim=0)
+    else:
+        init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)
+
+    if pipe.vae.config.force_upcast:
+        pipe.vae.to(originDtype)
+
+    init_latents = init_latents.to(originDtype)
+    init_latents = pipe.vae.config.scaling_factor * init_latents
+
+    return init_latents.to(dtype=torch.float16)
+
+def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=None):
+    # get the original timestep using init_timestep
+    if denoising_start is None:
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+        t_start = max(num_inference_steps - init_timestep, 0)
+    else:
+        t_start = 0
+
+    timesteps = pipe.scheduler.timesteps[t_start * pipe.scheduler.order :]
+
+    # Strength is irrelevant if we directly request a timestep to start at;
+    # that is, strength is determined by the denoising_start instead.
+    if denoising_start is not None:
+        discrete_timestep_cutoff = int(
+            round(
+                pipe.scheduler.config.num_train_timesteps
+                - (denoising_start * pipe.scheduler.config.num_train_timesteps)
+            )
+        )
+
+        num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+        if pipe.scheduler.order == 2 and num_inference_steps % 2 == 0:
+            # if the scheduler is a 2nd order scheduler we might have to do +1
+            # because `num_inference_steps` might be even given that every timestep
+            # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+            # mean that we cut the timesteps in the middle of the denoising step
+            # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+            # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
+            num_inference_steps = num_inference_steps + 1
+
+        # because t_n+1 >= t_n, we slice the timesteps starting from the end
+        timesteps = timesteps[-num_inference_steps:]
+        return timesteps, num_inference_steps
+
+    return timesteps, num_inference_steps - t_start
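
Note (not part of the commit): the denoising_start bookkeeping in run() is easiest to see with concrete numbers. A worked sketch using the UI defaults num_steps=5 and start_step=1, and an assumed get_num_steps_actual(config) result of 4 (that helper lives in the repo's config module; the value here is illustrative only):

num_steps_inversion = 5   # config.num_steps_inversion
num_steps_actual = 4      # assumed return value of get_num_steps_actual(config)
denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion
print(denoising_start)    # 0.2 -> the img2img call skips the first 20% of the schedule

config.max_norm_zs then gets one entry per remaining step (here [-1, -1, -1, 15.5]), so only the last entry carries the 15.5 norm cap.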
segment_utils.py CHANGED
@@ -1,5 +1,6 @@
 import numpy as np
 import mediapipe as mp
+import uuid
 
 from PIL import Image
 from mediapipe.tasks import python
@@ -22,7 +23,16 @@ def restore_result(croper, category, generated_image):
     restored_image = croper.input_image.copy()
     restored_image.paste(cropped_generated_image, (croper.origin_start_x, croper.origin_start_y), cropped_square_mask_image)
 
-    return restored_image
+    extension = 'png'
+    if restored_image.mode == 'RGBA':
+        extension = 'png'
+    else:
+        extension = 'jpg'
+
+    path = f"output/{uuid.uuid4()}.{extension}"
+    restored_image.save(path)
+
+    return restored_image, path
 
 def segment_image(input_image, category, generate_size, mask_expansion, mask_dilation):
     mask_size = int(generate_size)
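
Note (not part of the commit): a standalone sketch of the save step this hunk adds, using a stand-in PIL image in place of the real pasted-back composite; the uuid file name and the PNG-for-RGBA / JPEG-otherwise choice mirror the added lines, and the resulting path is what app_base.py feeds into its gr.File download component.

import os
import uuid

from PIL import Image

os.makedirs("output", exist_ok=True)                 # app_base.py also creates this folder on import

restored_image = Image.new("RGB", (64, 64), "gray")  # stand-in for the restored composite
extension = "png" if restored_image.mode == "RGBA" else "jpg"
path = f"output/{uuid.uuid4()}.{extension}"
restored_image.save(path)
print(path)                                          # e.g. output/<uuid>.jpg for an RGB image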