Iceclear committed
Commit
e571ea9
•
1 Parent(s): a68feeb

Update app.py

Files changed (1)
  1. app.py +363 -363
app.py CHANGED
@@ -1,363 +1,363 @@
[Removed version: lines identical to the updated file below are omitted; only the three configuration paths that changed are shown.]
- vqgan_config = OmegaConf.load("./configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
-         config = OmegaConf.load("./configs/stableSRNew/v2-finetune_text_T_512.yaml")
-         config = OmegaConf.load("./configs/stableSRNew/v2-finetune_text_T_768v.yaml")
+ """
+ This file is used for deploying the Hugging Face demo:
+ https://huggingface.co/spaces/
+ """
+
+ import sys
+ sys.path.append('StableSR')
+ import os
+ import cv2
+ import torch
+ import torch.nn.functional as F
+ import gradio as gr
+ import torchvision
+ from itertools import islice
+ from torchvision.transforms.functional import normalize
+ from ldm.util import instantiate_from_config
+ from torch import autocast
+ import PIL
+ import numpy as np
+ from pytorch_lightning import seed_everything
+ from contextlib import nullcontext
+ from omegaconf import OmegaConf
+ from PIL import Image
+ import copy
+ from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_normalization
+ from scripts.util_image import ImageSpliterTh
+ from basicsr.utils.download_util import load_file_from_url
+ from einops import rearrange, repeat
+
+ # os.system("pip freeze")
+
+ pretrain_model_url = {
+     'stablesr_512': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_000117.ckpt',
+     'stablesr_768': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_768v_000139.ckpt',
+     'CFW': 'https://huggingface.co/Iceclear/StableSR/resolve/main/vqgan_cfw_00011.ckpt',
+ }
+ # download weights
+ if not os.path.exists('./stablesr_000117.ckpt'):
+     load_file_from_url(url=pretrain_model_url['stablesr_512'], model_dir='./', progress=True, file_name=None)
+ if not os.path.exists('./stablesr_768v_000139.ckpt'):
+     load_file_from_url(url=pretrain_model_url['stablesr_768'], model_dir='./', progress=True, file_name=None)
+ if not os.path.exists('./vqgan_cfw_00011.ckpt'):
+     load_file_from_url(url=pretrain_model_url['CFW'], model_dir='./', progress=True, file_name=None)
+
+ # download example images
+ torch.hub.download_url_to_file(
+     'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/Lincoln.png',
+     '01.png')
+ torch.hub.download_url_to_file(
+     'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/oldphoto6.png',
+     '02.png')
+ torch.hub.download_url_to_file(
+     'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/comic2.png',
+     '03.png')
+ torch.hub.download_url_to_file(
+     'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/OST_120.png',
+     '04.png')
+ torch.hub.download_url_to_file(
+     'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet65/comic3.png',
+     '05.png')
+
+ def load_img(path):
+     image = Image.open(path).convert("RGB")
+     w, h = image.size
+     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+     image = np.array(image).astype(np.float32) / 255.0
+     image = image[None].transpose(0, 3, 1, 2)
+     image = torch.from_numpy(image)
+     return 2.*image - 1.
+
+ def space_timesteps(num_timesteps, section_counts):
+     """
+     Create a list of timesteps to use from an original diffusion process,
+     given the number of timesteps we want to take from equally-sized portions
+     of the original process.
+     For example, if there are 300 timesteps and the section counts are [10,15,20]
+     then the first 100 timesteps are strided to be 10 timesteps, the second 100
+     are strided to be 15 timesteps, and the final 100 are strided to be 20.
+     If the stride is a string starting with "ddim", then the fixed striding
+     from the DDIM paper is used, and only one section is allowed.
+     :param num_timesteps: the number of diffusion steps in the original
+                           process to divide up.
+     :param section_counts: either a list of numbers, or a string containing
+                            comma-separated numbers, indicating the step count
+                            per section. As a special case, use "ddimN" where N
+                            is a number of steps to use the striding from the
+                            DDIM paper.
+     :return: a set of diffusion steps from the original process to use.
+     """
+     if isinstance(section_counts, str):
+         if section_counts.startswith("ddim"):
+             desired_count = int(section_counts[len("ddim"):])
+             for i in range(1, num_timesteps):
+                 if len(range(0, num_timesteps, i)) == desired_count:
+                     return set(range(0, num_timesteps, i))
+             raise ValueError(
+                 f"cannot create exactly {num_timesteps} steps with an integer stride"
+             )
+         section_counts = [int(x) for x in section_counts.split(",")]  # [250,]
+     size_per = num_timesteps // len(section_counts)
+     extra = num_timesteps % len(section_counts)
+     start_idx = 0
+     all_steps = []
+     for i, section_count in enumerate(section_counts):
+         size = size_per + (1 if i < extra else 0)
+         if size < section_count:
+             raise ValueError(
+                 f"cannot divide section of {size} steps into {section_count}"
+             )
+         if section_count <= 1:
+             frac_stride = 1
+         else:
+             frac_stride = (size - 1) / (section_count - 1)
+         cur_idx = 0.0
+         taken_steps = []
+         for _ in range(section_count):
+             taken_steps.append(start_idx + round(cur_idx))
+             cur_idx += frac_stride
+         all_steps += taken_steps
+         start_idx += size
+     return set(all_steps)
+
+ def chunk(it, size):
+     it = iter(it)
+     return iter(lambda: tuple(islice(it, size)), ())
+
+ def load_model_from_config(config, ckpt, verbose=False):
+     print(f"Loading model from {ckpt}")
+     pl_sd = torch.load(ckpt, map_location="cpu")
+     if "global_step" in pl_sd:
+         print(f"Global Step: {pl_sd['global_step']}")
+     sd = pl_sd["state_dict"]
+     model = instantiate_from_config(config.model)
+     m, u = model.load_state_dict(sd, strict=False)
+     if len(m) > 0 and verbose:
+         print("missing keys:")
+         print(m)
+     if len(u) > 0 and verbose:
+         print("unexpected keys:")
+         print(u)
+
+     model.cuda()
+     model.eval()
+     return model
+
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ device = torch.device("cuda")
+ vqgan_config = OmegaConf.load("StableSR/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
+ vq_model = load_model_from_config(vqgan_config, './vqgan_cfw_00011.ckpt')
+ vq_model = vq_model.to(device)
+
+ os.makedirs('output', exist_ok=True)
+
+ def inference(image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type):
+     """Run a single prediction on the model"""
+     precision_scope = autocast
+     vq_model.decoder.fusion_w = dec_w
+     seed_everything(seed)
+
+     if model_type == '512':
+         config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_512.yaml")
+         model = load_model_from_config(config, "./stablesr_000117.ckpt")
+         min_size = 512
+     else:
+         config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_768v.yaml")
+         model = load_model_from_config(config, "./stablesr_768v_000139.ckpt")
+         min_size = 768
+
+     model = model.to(device)
+     model.configs = config
+     model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000,
+                             linear_start=0.00085, linear_end=0.0120, cosine_s=8e-3)
+     model.num_timesteps = 1000
+
+     sqrt_alphas_cumprod = copy.deepcopy(model.sqrt_alphas_cumprod)
+     sqrt_one_minus_alphas_cumprod = copy.deepcopy(model.sqrt_one_minus_alphas_cumprod)
+
+     # respace the 1000-step schedule so that sampling only uses ddpm_steps steps
+     use_timesteps = set(space_timesteps(1000, [ddpm_steps]))
+     last_alpha_cumprod = 1.0
+     new_betas = []
+     timestep_map = []
+     for i, alpha_cumprod in enumerate(model.alphas_cumprod):
+         if i in use_timesteps:
+             new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+             last_alpha_cumprod = alpha_cumprod
+             timestep_map.append(i)
+     new_betas = [beta.data.cpu().numpy() for beta in new_betas]
+     model.register_schedule(given_betas=np.array(new_betas), timesteps=len(new_betas))
+     model.num_timesteps = 1000
+     model.ori_timesteps = list(use_timesteps)
+     model.ori_timesteps.sort()
+     model = model.to(device)
+
+     try:  # global try
+         with torch.no_grad():
+             with precision_scope("cuda"):
+                 with model.ema_scope():
+                     init_image = load_img(image)
+                     init_image = F.interpolate(
+                         init_image,
+                         size=(int(init_image.size(-2)*upscale),
+                               int(init_image.size(-1)*upscale)),
+                         mode='bicubic',
+                     )
+
+                     if init_image.size(-1) < min_size or init_image.size(-2) < min_size:
+                         ori_size = init_image.size()
+                         rescale = min_size * 1.0 / min(init_image.size(-2), init_image.size(-1))
+                         new_h = max(int(ori_size[-2]*rescale), min_size)
+                         new_w = max(int(ori_size[-1]*rescale), min_size)
+                         init_template = F.interpolate(
+                             init_image,
+                             size=(new_h, new_w),
+                             mode='bicubic',
+                         )
+                     else:
+                         init_template = init_image
+                         rescale = 1
+                     init_template = init_template.clamp(-1, 1)
+                     assert init_template.size(-1) >= min_size
+                     assert init_template.size(-2) >= min_size
+
+                     init_template = init_template.type(torch.float16).to(device)
+
+                     if init_template.size(-1) <= 1280 or init_template.size(-2) <= 1280:
+                         init_latent_generator, enc_fea_lq = vq_model.encode(init_template)
+                         init_latent = model.get_first_stage_encoding(init_latent_generator)
+                         text_init = ['']*init_template.size(0)
+                         semantic_c = model.cond_stage_model(text_init)
+
+                         noise = torch.randn_like(init_latent)
+
+                         t = repeat(torch.tensor([999]), '1 -> b', b=init_image.size(0))
+                         t = t.to(device).long()
+                         x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
+
+                         if init_template.size(-1) <= min_size and init_template.size(-2) <= min_size:
+                             samples, _ = model.sample(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True)
+                         else:
+                             samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size/8), tile_overlap=min_size//16, batch_size_sample=init_template.size(0))
+                         x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
+                         if colorfix_type == 'adain':
+                             x_samples = adaptive_instance_normalization(x_samples, init_template)
+                         elif colorfix_type == 'wavelet':
+                             x_samples = wavelet_reconstruction(x_samples, init_template)
+                         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+                     else:
+                         # very large inputs are split into overlapping patches and restored patch by patch
+                         im_spliter = ImageSpliterTh(init_template, 1280, 1000, sf=1)
+                         for im_lq_pch, index_infos in im_spliter:
+                             init_latent = model.get_first_stage_encoding(model.encode_first_stage(im_lq_pch))  # move to latent space
+                             text_init = ['']*init_latent.size(0)
+                             semantic_c = model.cond_stage_model(text_init)
+                             noise = torch.randn_like(init_latent)
+                             # If you would like to start from the intermediate steps, you can add noise to LR to the specific steps.
+                             t = repeat(torch.tensor([999]), '1 -> b', b=init_template.size(0))
+                             t = t.to(device).long()
+                             x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
+                             # x_T = noise
+                             samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=im_lq_pch.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size/8), tile_overlap=min_size//16, batch_size_sample=im_lq_pch.size(0))
+                             _, enc_fea_lq = vq_model.encode(im_lq_pch)
+                             x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
+                             if colorfix_type == 'adain':
+                                 x_samples = adaptive_instance_normalization(x_samples, im_lq_pch)
+                             elif colorfix_type == 'wavelet':
+                                 x_samples = wavelet_reconstruction(x_samples, im_lq_pch)
+                             im_spliter.update(x_samples, index_infos)
+                         x_samples = im_spliter.gather()
+                         x_samples = torch.clamp((x_samples+1.0)/2.0, min=0.0, max=1.0)
+
+                     if rescale > 1:
+                         # the input was upsampled to reach min_size, so resize back to the requested output size
+                         x_samples = F.interpolate(
+                             x_samples,
+                             size=(int(init_image.size(-2)),
+                                   int(init_image.size(-1))),
+                             mode='bicubic',
+                         )
+                         x_samples = x_samples.clamp(0, 1)
+                     x_sample = 255. * rearrange(x_samples[0].cpu().numpy(), 'c h w -> h w c')
+                     restored_img = x_sample.astype(np.uint8)
+                     Image.fromarray(x_sample.astype(np.uint8)).save('output/out.png')
+
+                     return restored_img, 'output/out.png'
+     except Exception as error:
+         print('Global exception', error)
+         return None, None
+
+
+ title = "Exploiting Diffusion Prior for Real-World Image Super-Resolution"
+ description = r"""<center><img src='https://user-images.githubusercontent.com/22350795/236680126-0b1cdd62-d6fc-4620-b998-75ed6c31bf6f.png' style='height:40px' alt='StableSR logo'></center>
+ <b>Official Gradio demo</b> for <a href='https://github.com/IceClear/StableSR' target='_blank'><b>Exploiting Diffusion Prior for Real-World Image Super-Resolution</b></a>.<br>
+ 🔥 StableSR is a general image super-resolution algorithm for real-world and AIGC images.<br>
+ """
+ article = r"""
+ If StableSR is helpful, please help to ⭐ the <a href='https://github.com/IceClear/StableSR' target='_blank'>GitHub repo</a>. Thanks!
+ [![GitHub Stars](https://img.shields.io/github/stars/IceClear/StableSR?style=social)](https://github.com/IceClear/StableSR)
+
+ ---
+
+ 📝 **Citation**
+
+ If our work is useful for your research, please consider citing:
+ ```bibtex
+ @inproceedings{wang2023exploiting,
+     author = {Wang, Jianyi and Yue, Zongsheng and Zhou, Shangchen and Chan, Kelvin CK and Loy, Chen Change},
+     title = {Exploiting Diffusion Prior for Real-World Image Super-Resolution},
+     booktitle = {arXiv preprint arXiv:2305.07015},
+     year = {2023}
+ }
+ ```
+
+ 📋 **License**
+
+ This project is licensed under the <a rel="license" href="https://github.com/IceClear/StableSR/blob/main/LICENSE.txt">S-Lab License 1.0</a>.
+ Redistribution and use for non-commercial purposes should follow this license.
+
+ 📧 **Contact**
+
+ If you have any questions, please feel free to reach out to me at <b>iceclearwjy@gmail.com</b>.
+
+ <div>
+     🤗 Find Me:
+     <a href="https://twitter.com/Iceclearwjy"><img style="margin-top:0.5em; margin-bottom:0.5em" src="https://img.shields.io/twitter/follow/Iceclearwjy?label=%40Iceclearwjy&style=social" alt="Twitter Follow"></a>
+     <a href="https://github.com/IceClear"><img style="margin-top:0.5em; margin-bottom:2em" src="https://img.shields.io/github/followers/IceClear?style=social" alt="GitHub Follow"></a>
+ </div>
+
+ <center><img src='https://visitor-badge.laobi.icu/badge?page_id=IceClear/StableSR' alt='visitors'></center>
+ """
+
+ demo = gr.Interface(
+     inference, [
+         gr.inputs.Image(type="filepath", label="Input"),
+         gr.inputs.Number(default=1, label="Rescaling_Factor (Large images require huge time)"),
+         gr.Slider(0, 1, value=0.5, step=0.01, label='CFW_Fidelity (0 for better quality, 1 for better identity)'),
+         gr.inputs.Number(default=42, label="Seeds"),
+         gr.Dropdown(
+             choices=["512", "768v"],
+             value="512",
+             label="Model",
+         ),
+         gr.Slider(10, 1000, value=200, step=1, label='Sampling timesteps for DDPM'),
+         gr.Dropdown(
+             choices=["none", "adain", "wavelet"],
+             value="adain",
+             label="Color_Correction",
+         ),
+     ], [
+         gr.outputs.Image(type="numpy", label="Output"),
+         gr.outputs.File(label="Download the output")
+     ],
+     title=title,
+     description=description,
+     article=article,
+     examples=[
+         ['./01.png', 4, 0.5, 42, "512", 200, "adain"],
+         ['./02.png', 4, 0.5, 42, "512", 200, "adain"],
+         ['./03.png', 4, 0.5, 42, "512", 200, "adain"],
+         ['./04.png', 4, 0.5, 42, "512", 200, "adain"],
+         ['./05.png', 4, 0.5, 42, "512", 200, "adain"]
+     ]
+ )
+
+ demo.queue(concurrency_count=1)
+ demo.launch(share=True)