File size: 15,263 Bytes
e571ea9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eec823a
 
 
 
 
e571ea9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7444ebf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
"""
This file is used to deploy the Hugging Face Spaces demo:
https://huggingface.co/spaces/
"""

import sys
sys.path.append('StableSR')
import os
import cv2
import torch
import torch.nn.functional as F
import gradio as gr
import torchvision
from torchvision.transforms.functional import normalize
from ldm.util import instantiate_from_config
from torch import autocast
import PIL
import numpy as np
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from omegaconf import OmegaConf
from PIL import Image
import copy
from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_normalization
from scripts.util_image import ImageSpliterTh
from basicsr.utils.download_util import load_file_from_url
from einops import rearrange, repeat

# os.system("pip freeze")

# Release checkpoints used by this demo, keyed by a short name:
# the 512 and 768v diffusion checkpoints (loaded in inference()) and the
# VQGAN/CFW autoencoder checkpoint (loaded once below as vq_model).
pretrain_model_url = {
	'stablesr_512': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_000117.ckpt',
	'stablesr_768': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_768v_000139.ckpt',
	'CFW': 'https://huggingface.co/Iceclear/StableSR/resolve/main/vqgan_cfw_00011.ckpt',
}
# download weights
# Each checkpoint is fetched into the working directory only when a file with
# the expected name is not already there, so restarts do not re-download.
# NOTE(review): with file_name=None, load_file_from_url presumably names the
# file after the URL's basename; the hard-coded names in the existence checks
# must stay in sync with the URLs above.
if not os.path.exists('./stablesr_000117.ckpt'):
	load_file_from_url(url=pretrain_model_url['stablesr_512'], model_dir='./', progress=True, file_name=None)
if not os.path.exists('./stablesr_768v_000139.ckpt'):
	load_file_from_url(url=pretrain_model_url['stablesr_768'], model_dir='./', progress=True, file_name=None)
if not os.path.exists('./vqgan_cfw_00011.ckpt'):
	load_file_from_url(url=pretrain_model_url['CFW'], model_dir='./', progress=True, file_name=None)

# download images
# Example inputs referenced by the `examples=` list of the Gradio interface
# (./01.png .. ./05.png). Unlike the checkpoints above there is no existence
# check, so these are downloaded again on every start-up.
torch.hub.download_url_to_file(
	'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/Lincoln.png',
	'01.png')
torch.hub.download_url_to_file(
	'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/oldphoto6.png',
	'02.png')
torch.hub.download_url_to_file(
	'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/comic2.png',
	'03.png')
torch.hub.download_url_to_file(
	'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/OST_120.png',
	'04.png')
torch.hub.download_url_to_file(
	'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet65/comic3.png',
	'05.png')

def load_img(path):
	"""Load an image file as a model-ready tensor.

	The image is converted to RGB, resized (Lanczos) down to the nearest
	multiple of 32 in each dimension, and scaled from [0, 255] to [-1, 1].

	:param path: path to an image file readable by PIL.
	:return: float32 tensor of shape (1, 3, H, W) with values in [-1, 1],
		where H and W are multiples of 32.
	:raises ValueError: if either side of the image is smaller than 32 pixels
		(rounding down would otherwise produce a zero-sized image).
	"""
	# Context manager so the underlying file is closed even for multi-frame
	# formats, which PIL leaves open after loading.
	with Image.open(path) as img:
		image = img.convert("RGB")
	ori_w, ori_h = image.size
	w, h = ori_w - ori_w % 32, ori_h - ori_h % 32  # resize to integer multiple of 32
	if w == 0 or h == 0:
		raise ValueError(f"input image must be at least 32x32 pixels, got {ori_w}x{ori_h}")
	image = image.resize((w, h), resample=PIL.Image.LANCZOS)
	image = np.array(image).astype(np.float32) / 255.0
	# HWC -> NCHW with a leading batch dimension of 1.
	image = image[None].transpose(0, 3, 1, 2)
	image = torch.from_numpy(image)
	return 2.*image - 1.

def space_timesteps(num_timesteps, section_counts):
	"""
	Create a list of timesteps to use from an original diffusion process,
	given the number of timesteps we want to take from equally-sized portions
	of the original process.

	For example, if there's 300 timesteps and the section counts are [10,15,20]
	then the first 100 timesteps are strided to be 10 timesteps, the second 100
	are strided to be 15 timesteps, and the final 100 are strided to be 20.

	If the stride is a string starting with "ddim", then the fixed striding
	from the DDIM paper is used, and only one section is allowed.

	:param num_timesteps: the number of diffusion steps in the original
		process to divide up.
	:param section_counts: either a list of numbers, or a string containing
		comma-separated numbers, indicating the step count per section. As a
		special case, use "ddimN" where N is a number of steps to use the
		striding from the DDIM paper.
	:return: a set of diffusion steps from the original process to use.
	:raises ValueError: if "ddimN" cannot be reached with any integer stride,
		or if a section is asked for more steps than it contains.
	"""
	if isinstance(section_counts, str):
		if section_counts.startswith("ddim"):
			desired_count = int(section_counts[len("ddim"):])
			# Search strides up to and including num_timesteps, so that a single
			# step ("ddim1", or num_timesteps == 1) is reachable.
			for stride in range(1, num_timesteps + 1):
				steps = range(0, num_timesteps, stride)
				if len(steps) == desired_count:
					return set(steps)
			raise ValueError(
				f"cannot create exactly {desired_count} steps with an integer stride"
			)
		# e.g. "250" -> [250], "10,15,20" -> [10, 15, 20]
		section_counts = [int(x) for x in section_counts.split(",")]
	size_per = num_timesteps // len(section_counts)
	# The first `extra` sections each get one additional original timestep.
	extra = num_timesteps % len(section_counts)
	start_idx = 0
	all_steps = []
	for i, section_count in enumerate(section_counts):
		size = size_per + (1 if i < extra else 0)
		if size < section_count:
			raise ValueError(
				f"cannot divide section of {size} steps into {section_count}"
			)
		if section_count <= 1:
			frac_stride = 1
		else:
			# Fractional stride so the first and last index of the section are both taken.
			frac_stride = (size - 1) / (section_count - 1)
		cur_idx = 0.0
		taken_steps = []
		for _ in range(section_count):
			taken_steps.append(start_idx + round(cur_idx))
			cur_idx += frac_stride
		all_steps += taken_steps
		start_idx += size
	return set(all_steps)

def chunk(it, size):
	"""Split an iterable into consecutive tuples of at most *size* items.

	:param it: any iterable.
	:param size: maximum number of items per tuple.
	:return: a lazy iterator of tuples; the last one may be shorter.
	"""
	# islice is not imported at module level; without this import chunk()
	# raised NameError as soon as it was iterated.
	from itertools import islice

	it = iter(it)
	# iter(callable, sentinel): keep slicing until an empty tuple comes back.
	return iter(lambda: tuple(islice(it, size)), ())

def load_model_from_config(config, ckpt, verbose=False):
	"""Build the model described by *config* and load weights from *ckpt*.

	The checkpoint is read onto the CPU, its "state_dict" entry is loaded
	non-strictly (missing/unexpected keys are tolerated), and the model is
	then moved to CUDA and switched to eval mode.

	:param config: OmegaConf config whose ``model`` node is passed to
		instantiate_from_config.
	:param ckpt: path to the checkpoint file.
	:param verbose: if True, print any missing / unexpected state-dict keys.
	:return: the loaded model, on CUDA, in eval mode.
	"""
	print(f"Loading model from {ckpt}")
	# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
	# should be passed here (in this demo: the release files downloaded above).
	checkpoint = torch.load(ckpt, map_location="cpu")
	if "global_step" in checkpoint:
		print(f"Global Step: {checkpoint['global_step']}")
	weights = checkpoint["state_dict"]

	net = instantiate_from_config(config.model)
	missing, unexpected = net.load_state_dict(weights, strict=False)
	if verbose:
		for header, keys in (("missing keys:", missing), ("unexpected keys:", unexpected)):
			if len(keys) > 0:
				print(header)
				print(keys)

	net.cuda()
	net.eval()
	return net

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The demo targets CUDA only: load_model_from_config() calls .cuda(), and
# inference() uses autocast("cuda") with float16 inputs, so a CPU fallback
# would presumably need further changes beyond this line.
device = torch.device("cuda")
# VQGAN/CFW autoencoder, loaded once and shared by every request;
# inference() sets its decoder.fusion_w on each call.
vqgan_config = OmegaConf.load("StableSR/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
vq_model = load_model_from_config(vqgan_config, './vqgan_cfw_00011.ckpt')
vq_model = vq_model.to(device)

# Directory for the downloadable result image written by inference().
os.makedirs('output', exist_ok=True)

def inference(image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type):
	"""Run a single StableSR super-resolution prediction (Gradio callback).

	The diffusion model for *model_type* is loaded from its checkpoint on every
	call. Its 1000-step linear noise schedule is respaced to *ddpm_steps*
	steps. The input is then super-resolved either in one pass or, for large
	inputs, patch by patch.

	:param image: path of the input image (the Gradio input uses
		type="filepath"); read with load_img().
	:param upscale: factor applied (bicubic) to the input size before
		super-resolution.
	:param dec_w: CFW fidelity weight, assigned to vq_model.decoder.fusion_w.
	:param seed: seed passed to seed_everything().
	:param model_type: '512' selects the 512 checkpoint; any other value
		selects the 768v checkpoint.
	:param ddpm_steps: number of sampling steps kept out of the 1000-step
		schedule (see space_timesteps()).
	:param colorfix_type: 'adain' or 'wavelet' to apply that color correction
		against the input; any other value skips color correction.
	:return: (restored image as an H x W x C uint8 array, path of the saved
		PNG 'output/out.png'). Any exception raised inside the ``try`` below is
		printed and (None, None) is returned instead; errors while loading the
		model (before the ``try``) propagate to the caller.
	"""
	precision_scope = autocast
	# vq_model is module-level and shared: this setting persists across calls.
	vq_model.decoder.fusion_w = dec_w
	seed_everything(seed)

	# NOTE(review): the full checkpoint is re-read from disk on every request;
	# caching the two models would avoid this at the cost of GPU memory.
	if model_type == '512':
		config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_512.yaml")
		model = load_model_from_config(config, "./stablesr_000117.ckpt")
		min_size = 512
	else:
		config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_768v.yaml")
		model = load_model_from_config(config, "./stablesr_768v_000139.ckpt")
		min_size = 768

	model = model.to(device)
	model.configs = config
	# Start from the original 1000-step linear beta schedule.
	model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000,
							linear_start=0.00085, linear_end=0.0120, cosine_s=8e-3)
	model.num_timesteps = 1000

	# Keep copies of the 1000-step schedule: q_sample_respace() below uses them
	# to noise the input latent at t=999, after the model's own schedule has
	# been replaced by the respaced one.
	sqrt_alphas_cumprod = copy.deepcopy(model.sqrt_alphas_cumprod)
	sqrt_one_minus_alphas_cumprod = copy.deepcopy(model.sqrt_one_minus_alphas_cumprod)

	# Respace: keep ddpm_steps of the 1000 timesteps and derive new betas so the
	# cumulative alpha at each kept step equals the original cumulative alpha.
	use_timesteps = set(space_timesteps(1000, [ddpm_steps]))
	last_alpha_cumprod = 1.0
	new_betas = []
	# NOTE(review): timestep_map is filled but never used in this function.
	timestep_map = []
	for i, alpha_cumprod in enumerate(model.alphas_cumprod):
		if i in use_timesteps:
			new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
			last_alpha_cumprod = alpha_cumprod
			timestep_map.append(i)
	new_betas = [beta.data.cpu().numpy() for beta in new_betas]
	model.register_schedule(given_betas=np.array(new_betas), timesteps=len(new_betas))
	# NOTE(review): presumably register_schedule() sets num_timesteps to
	# len(new_betas); it is forced back to 1000 here — confirm intent.
	model.num_timesteps = 1000
	# Sorted original indices of the kept timesteps.
	model.ori_timesteps = list(use_timesteps)
	model.ori_timesteps.sort()
	model = model.to(device)

	# NOTE(review): every exception below is only printed, so the UI receives
	# empty outputs without any error message.
	try: # global try
			with torch.no_grad():
				with precision_scope("cuda"):
					with model.ema_scope():
						init_image = load_img(image)
						# Apply the user-requested rescaling factor.
						init_image = F.interpolate(
									init_image,
									size=(int(init_image.size(-2)*upscale),
											int(init_image.size(-1)*upscale)),
									mode='bicubic',
									)

						# If either side is below min_size, upsample so the shorter side
						# reaches min_size; the output is resized back at the end.
						if init_image.size(-1) < min_size or init_image.size(-2) < min_size:
							ori_size = init_image.size()
							rescale = min_size * 1.0 / min(init_image.size(-2), init_image.size(-1))
							new_h = max(int(ori_size[-2]*rescale), min_size)
							new_w = max(int(ori_size[-1]*rescale), min_size)
							init_template = F.interpolate(
										init_image,
										size=(new_h, new_w),
										mode='bicubic',
										)
						else:
							init_template = init_image
							rescale = 1
						init_template = init_template.clamp(-1, 1)
						assert init_template.size(-1) >= min_size
						assert init_template.size(-2) >= min_size

						init_template = init_template.type(torch.float16).to(device)

						# NOTE(review): with `or`, only inputs whose sides BOTH exceed 1280
						# take the patch-based branch; e.g. a wide but short image is
						# encoded in one piece. Confirm `or` (not `and`) is intended.
						if init_template.size(-1) <= 1280 or init_template.size(-2) <= 1280:
							# enc_fea_lq is passed to vq_model.decode() below.
							init_latent_generator, enc_fea_lq = vq_model.encode(init_template)
							init_latent = model.get_first_stage_encoding(init_latent_generator)
							# Empty text prompt for every image in the batch.
							text_init = ['']*init_template.size(0)
							semantic_c = model.cond_stage_model(text_init)

							noise = torch.randn_like(init_latent)

							# Sampling starts from the input latent noised to t=999 with
							# the original (non-respaced) schedule.
							t = repeat(torch.tensor([999]), '1 -> b', b=init_image.size(0))
							t = t.to(device).long()
							x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)

							# At most min_size on both sides: single-pass sampling; otherwise
							# tiled sampling with tile_size=min_size/8 (presumably latent units).
							if init_template.size(-1)<= min_size and init_template.size(-2) <= min_size:
								samples, _ = model.sample(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True)
							else:
								samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size/8), tile_overlap=min_size//16, batch_size_sample=init_template.size(0))
							x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
							if colorfix_type == 'adain':
								x_samples = adaptive_instance_normalization(x_samples, init_template)
							elif colorfix_type == 'wavelet':
								x_samples = wavelet_reconstruction(x_samples, init_template)
							# Map from [-1, 1] to [0, 1].
							x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
						else:
							# Large input: process pixel-space patches one by one and reassemble
							# them via update()/gather() (presumably 1280-px patches, stride 1000).
							im_spliter = ImageSpliterTh(init_template, 1280, 1000, sf=1)
							for im_lq_pch, index_infos in im_spliter:
								init_latent = model.get_first_stage_encoding(model.encode_first_stage(im_lq_pch))  # move to latent space
								text_init = ['']*init_latent.size(0)
								semantic_c = model.cond_stage_model(text_init)
								noise = torch.randn_like(init_latent)
								# If you would like to start from the intermediate steps, you can add noise to LR to the specific steps.
								# NOTE(review): batch size taken from init_template rather than
								# im_lq_pch; presumably equal (batch of 1) — confirm.
								t = repeat(torch.tensor([999]), '1 -> b', b=init_template.size(0))
								t = t.to(device).long()
								x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
								# x_T = noise
								samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=im_lq_pch.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size/8), tile_overlap=min_size//16, batch_size_sample=im_lq_pch.size(0))
								# Encoder features of this patch for vq_model.decode().
								_, enc_fea_lq = vq_model.encode(im_lq_pch)
								x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
								if colorfix_type == 'adain':
									x_samples = adaptive_instance_normalization(x_samples, im_lq_pch)
								elif colorfix_type == 'wavelet':
									x_samples = wavelet_reconstruction(x_samples, im_lq_pch)
								im_spliter.update(x_samples, index_infos)
							x_samples = im_spliter.gather()
							x_samples = torch.clamp((x_samples+1.0)/2.0, min=0.0, max=1.0)

			# Undo the min_size upsampling: resize back to the size after `upscale`.
			if rescale > 1:
				x_samples = F.interpolate(
							x_samples,
							size=(int(init_image.size(-2)),
									int(init_image.size(-1))),
							mode='bicubic',
							)
				x_samples = x_samples.clamp(0, 1)
			# First image of the batch, CHW -> HWC, scaled to [0, 255]; astype truncates.
			x_sample = 255. * rearrange(x_samples[0].cpu().numpy(), 'c h w -> h w c')
			restored_img = x_sample.astype(np.uint8)
			# NOTE(review): fixed path shared by every request — a later request
			# overwrites the file offered for download to an earlier one.
			Image.fromarray(x_sample.astype(np.uint8)).save(f'output/out.png')

			return restored_img, f'output/out.png'
	except Exception as error:
		print('Global exception', error)
		return None, None


# Page text passed to gr.Interface as title / description / article.
# The emoji below were previously mojibake (UTF-8 bytes shown in a Greek
# codepage, e.g. "πŸ”₯" for 🔥) and are restored here.
title = "Exploiting Diffusion Prior for Real-World Image Super-Resolution"
description = r"""<center><img src='https://user-images.githubusercontent.com/22350795/236680126-0b1cdd62-d6fc-4620-b998-75ed6c31bf6f.png' style='height:40px' alt='StableSR logo'></center>
<b>Official Gradio demo</b> for <a href='https://github.com/IceClear/StableSR' target='_blank'><b>Exploiting Diffusion Prior for Real-World Image Super-Resolution</b></a>.<br>
🔥 StableSR is a general image super-resolution algorithm for real-world and AIGC images.<br>
"""
# Markdown/HTML footer; the BibTeX entry uses the standard `journal` field
# for an @article.
article = r"""
If StableSR is helpful, please help to ⭐ the <a href='https://github.com/IceClear/StableSR' target='_blank'>Github Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/IceClear/StableSR?style=social)](https://github.com/IceClear/StableSR)

---

📝 **Citation**

If our work is useful for your research, please consider citing:
```bibtex
@article{wang2024exploiting,
  author = {Wang, Jianyi and Yue, Zongsheng and Zhou, Shangchen and Chan, Kelvin C.K. and Loy, Chen Change},
  title = {Exploiting Diffusion Prior for Real-World Image Super-Resolution},
  journal = {International Journal of Computer Vision},
  year = {2024}
}
```

📋 **License**

This project is licensed under <a rel="license" href="https://github.com/IceClear/StableSR/blob/main/LICENSE.txt">S-Lab License 1.0</a>.
Redistribution and use for non-commercial purposes should follow this license.

📧 **Contact**

If you have any questions, please feel free to reach out to me at <b>iceclearwjy@gmail.com</b>.

<div>
	🤗 Find Me:
	<a href="https://twitter.com/Iceclearwjy"><img style="margin-top:0.5em; margin-bottom:0.5em" src="https://img.shields.io/twitter/follow/Iceclearwjy?label=%40Iceclearwjy&style=social" alt="Twitter Follow"></a>
	<a href="https://github.com/IceClear"><img style="margin-top:0.5em; margin-bottom:2em" src="https://img.shields.io/github/followers/IceClear?style=social" alt="Github Follow"></a>
</div>

<center><img src='https://visitor-badge.laobi.icu/badge?page_id=IceClear/StableSR' alt='visitors'></center>
"""

# Gradio UI. The input components are passed positionally to inference(), so
# their order must match its parameters:
# (image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type).
# Uses the top-level gr.* components (like gr.Slider / gr.Dropdown here);
# the gr.inputs / gr.outputs namespaces and their `default=` keyword are
# deprecated in Gradio 3 and removed in Gradio 4, where `value=` is used.
demo = gr.Interface(
	inference, [
		gr.Image(type="filepath", label="Input"),
		gr.Number(value=1, label="Rescaling_Factor (Large images require huge time)"),
		gr.Slider(0, 1, value=0.5, step=0.01, label='CFW_Fidelity (0 for better quality, 1 for better identity)'),
		gr.Number(value=42, label="Seeds"),
		gr.Dropdown(
			choices=["512", "768v"],
			value="512",
			label="Model",
			),
		gr.Slider(10, 1000, value=200, step=1, label='Sampling timesteps for DDPM'),
		gr.Dropdown(
			choices=["none", "adain", "wavelet"],
			value="adain",
			label="Color_Correction",
			),
	], [
		# inference() returns (uint8 array, file path).
		gr.Image(type="numpy", label="Output"),
		gr.File(label="Download the output"),
	],
	title=title,
	description=description,
	article=article,
	# Each row supplies one value per input component, in the same order.
	examples=[
		['./01.png', 4, 0.5, 42, "512", 200, "adain"],
		['./02.png', 4, 0.5, 42, "512", 200, "adain"],
		['./03.png', 4, 0.5, 42, "512", 200, "adain"],
		['./04.png', 4, 0.5, 42, "512", 200, "adain"],
		['./05.png', 4, 0.5, 42, "512", 200, "adain"]
		]
	)

# Serve one request at a time; presumably because each inference() call loads
# a full diffusion checkpoint onto the single GPU and all requests write the
# same output/out.png.
# NOTE(review): `concurrency_count` is a Gradio 3.x argument; Gradio 4 removed
# it (the closest equivalent there is `default_concurrency_limit`).
demo.queue(concurrency_count=1)
demo.launch()