ameerazam08 committed on
Commit
84bf924
1 Parent(s): 80e89cb

Upload 9 files

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ figures/main_figure.jpg filter=lfs diff=lfs merge=lfs -text
+ figures/sample_bunny_2K.png filter=lfs diff=lfs merge=lfs -text
+ figures/sample_icecream_4K.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 yhyun225

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
figures/.DS_Store ADDED
Binary file (6.15 kB).
 
figures/main_figure.jpg ADDED

Git LFS Details

  • SHA256: e71cdc429d085e933edf2f0e6b28e8d45a8ef16695e177fa74665487482f2bd6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
figures/sample_bunny_2K.png ADDED

Git LFS Details

  • SHA256: a9c69510165b65c7d5bd98ee720b5f3c8e49a197ae69c02bbdde94afc26d6e8a
  • Pointer size: 132 Bytes
  • Size of remote file: 4.64 MB
figures/sample_icecream_4K.png ADDED

Git LFS Details

  • SHA256: 792333c8fe94d1a0019c7e2eab4de6fe2c0bb87f00c2a34fca18c90e6d7d8dd3
  • Pointer size: 133 Bytes
  • Size of remote file: 17.5 MB
pipeline_diffusehigh_sdxl.py ADDED
@@ -0,0 +1,798 @@
import numpy as np
import PIL
import inspect
import os
from tqdm import tqdm
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F

from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.image_processor import PipelineImageInput
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    StableDiffusionXLPipeline,
    DDIMScheduler,
    EulerDiscreteScheduler,
)
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from pytorch_wavelets import DWTForward, DWTInverse
from torchvision.transforms import GaussianBlur


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


def gaussian_blur_image_sharpening(image, kernel_size=3, sigma=(0.1, 2.0), alpha=1):
    # Unsharp masking: subtract an `alpha`-weighted Gaussian-blurred copy to boost high-frequency detail.
    gaussian_blur = GaussianBlur(kernel_size=kernel_size, sigma=sigma)
    image_blurred = gaussian_blur(image)
    image_sharpened = (alpha + 1) * image - alpha * image_blurred

    return image_sharpened


class DiffuseHighSDXLPipelineOutput(BaseOutput):
    """
    Output class for Stable Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        guidance_images (`List[PIL.Image.Image]` or `np.ndarray`)
            The base-resolution image(s) used as structural guidance during progressive upsampling.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    guidance_images: Union[List[PIL.Image.Image], np.ndarray]

class DiffuseHighSDXLPipeline(StableDiffusionXLPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
            add_watermarker=add_watermarker,
        )

    def _encode_vae_image(
        self,
        image: torch.Tensor,
        normalize: bool = True,
    ):
        if normalize:
            image = image * 2 - 1

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()

        image = image.to(self.device)
        latents = self.vae.encode(image).latent_dist.mode() * self.vae.config.scaling_factor

        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        return latents.to(self.dtype)

    def _decode_vae_latent(
        self,
        latents: torch.Tensor,
        output_type: Optional[str] = 'pt',
    ):
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

        latents = latents.to(self.device)
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        return image

    def edm_scheduler_step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = 0.0,
        s_noise: float = 1.0,
        LL_guidance: Optional[torch.FloatTensor] = None,
        generator: Optional[torch.Generator] = None,
        return_pred_original_sample: bool = False,
    ):
        assert isinstance(self.scheduler, EulerDiscreteScheduler)
        config = self.scheduler.config

        if self.scheduler.step_index is None:
            self.scheduler._init_step_index(timestep)

        step_index = self.scheduler.step_index

        sigma = self.scheduler.sigmas[step_index]

        gamma = min(s_churn / (len(self.scheduler.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if config.prediction_type == "original_sample" or config.prediction_type == "sample":
            pred_original_sample = model_output
        elif config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif config.prediction_type == "v_prediction":
            # denoised = model_output * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. If a guidance LL component is given, perform structural guidance
        if LL_guidance is not None:
            pred_original_image = self._decode_vae_latent(pred_original_sample, output_type='pt')

            _, HH = self.DWT(pred_original_image)
            coeffs = (LL_guidance, HH)
            pred_original_image = self.iDWT(coeffs)

            pred_original_sample = self._encode_vae_image(pred_original_image)

        # 3. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        dt = self.scheduler.sigmas[self.scheduler.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

        self.scheduler._step_index += 1

        if return_pred_original_sample:
            return (prev_sample, pred_original_sample)

        return (prev_sample, )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        ### DiffuseHigh parameters ###
        target_height: Union[int, List[int]] = [2048, 3072, 4096],
        target_width: Union[int, List[int]] = [2048, 3072, 4096],
        guidance_image: Optional[Union[torch.FloatTensor, PIL.Image.Image, np.ndarray]] = None,
        noising_steps: int = 15,
        diffusehigh_guidance_scale: float = 10.0,
        # >>> DWT parameters
        enable_dwt: bool = True,
        dwt_level: Optional[int] = 1,
        dwt_wave: Optional[str] = "db4",
        dwt_mode: Optional[str] = "symmetric",
        dwt_steps: Optional[int] = 5,
        # >>> Sharpening parameters
        enable_sharpening: bool = True,
        sharpening_kernel_size: int = 3,
        sharpening_sigma: Optional[Union[Tuple[float, float], float]] = (0.1, 2.0),
        sharpening_alpha: float = 1.0,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text-encoders.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
                Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                and checkpoints that are not specifically fine-tuned on low resolutions.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
                Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                and checkpoints that are not specifically fine-tuned on low resolutions.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                completed before it is intentionally prematurely terminated. As a result, the returned sample will
                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
                Guidance rescale factor should fix overexposure when using zero terminal SNR.
            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
                explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                For most cases, `target_size` should be set to the desired height and width of the generated image. If
                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
                micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
                micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                To negatively condition the generation process based on a target image resolution. It should be the
                same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2
                of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            target_height (`List[int]` or `int`):
                The height of the image being generated. If a list is given, the pipeline generates the corresponding
                intermediate-resolution images in a progressive manner.
            target_width (`List[int]` or `int`):
                The width of the image being generated. If a list is given, the pipeline generates the corresponding
                intermediate-resolution images in a progressive manner.

        Examples:

        Returns:
            [`DiffuseHighSDXLPipelineOutput`] or `tuple`:
                [`DiffuseHighSDXLPipelineOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
        # 0. Default height and width to unet
        height = self.default_sample_size * self.vae_scale_factor
        width = self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        if ip_adapter_image is not None:
            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])
            image_embeds = image_embeds.to(device)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 8.1 Apply denoising_end
        if (
            self.denoising_end is not None
            and isinstance(self.denoising_end, float)
            and self.denoising_end > 0
            and self.denoising_end < 1
        ):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        # 9. Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 10. Obtain a clean image for structural guidance (can be given by the user or generated here)
        if guidance_image is None:
            self._num_timesteps = len(timesteps)
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # predict the noise residual
                    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                    if ip_adapter_image is not None:
                        added_cond_kwargs["image_embeds"] = image_embeds
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                        add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                        negative_pooled_prompt_embeds = callback_outputs.pop(
                            "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                        )
                        add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                        negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)

                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()
                        if callback is not None and i % callback_steps == 0:
                            step_idx = i // getattr(self.scheduler, "order", 1)
                            callback(step_idx, t, latents)

            image = self._decode_vae_latent(latents, output_type='pt')
        else:
            image = self.image_processor.preprocess(guidance_image, height, width)
            if self.image_processor.config.do_normalize:
                image = (image + 1.) * 0.5

            image = image.to(self.device)

        original_guidance_image = image

        # |-------------------------------- DiffuseHigh process --------------------------------|
        # DWT & inverse DWT work on torch.float32
        if enable_dwt:
            self.DWT = DWTForward(J=dwt_level, wave=dwt_wave, mode=dwt_mode).to(self.device)
            self.iDWT = DWTInverse(wave=dwt_wave, mode=dwt_mode).to(self.device)

        # 11. Prepare progressive DiffuseHigh pipeline
        self.scheduler.set_timesteps(num_inference_steps)
        diffusehigh_timesteps = self.scheduler.timesteps[-noising_steps:]
        self.enable_vae_tiling()  # VAE tiling mode in order to prevent OOM issues

        if isinstance(target_width, int):
            target_width = [target_width]
        if isinstance(target_height, int):
            target_height = [target_height]

        assert len(target_width) == len(target_height)

        # 12. Progressive DiffuseHigh pipeline
        for h, w in zip(target_height, target_width):
            # interpolate the image to the desired resolution
            guidance_image = F.interpolate(image, (h, w), mode="bicubic", align_corners=False)

            # apply sharpening operation to the image
            if enable_sharpening:
                guidance_image = gaussian_blur_image_sharpening(
                    guidance_image,
                    kernel_size=sharpening_kernel_size,
                    sigma=sharpening_sigma,
                    alpha=sharpening_alpha,
                )

            # extract low-frequency component (structural guidance) from the guidance image
            if enable_dwt:
                LL, _ = self.DWT(guidance_image)

            # obtain latent of the interpolated image and noise it
            latents = self._encode_vae_image(guidance_image)
            noise = randn_tensor(latents.shape, generator, device=latents.device, dtype=latents.dtype)
            latents = self.scheduler.add_noise(latents, noise, diffusehigh_timesteps[None, 0])

            for i, t in tqdm(enumerate(diffusehigh_timesteps), total=diffusehigh_timesteps.shape[0]):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,  # None
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + diffusehigh_guidance_scale * (noise_pred_text - noise_pred_uncond)

                # EDM sampler step (with structural guidance for the first `dwt_steps` steps)
                latents = self.edm_scheduler_step(
                    noise_pred,
                    t,
                    latents,
                    **extra_step_kwargs,
                    LL_guidance=LL if (enable_dwt and i < dwt_steps) else None,
                )[0]

            image = self._decode_vae_latent(latents)

            if isinstance(self.scheduler, EulerDiscreteScheduler):
                self.scheduler._step_index = None

        # Offload all models
        self.maybe_free_model_hooks()

        if output_type != 'pt':
            image = self.image_processor.postprocess(image * 2 - 1, output_type=output_type)
            guidance_image = self.image_processor.postprocess(original_guidance_image * 2 - 1, output_type=output_type)

        if not return_dict:
            return (image, guidance_image)

        return DiffuseHighSDXLPipelineOutput(images=image, guidance_images=guidance_image)


def set_seeds(seed):
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # disable cuDNN autotuning so seeded runs stay reproducible


# DEBUGGING
if __name__ == "__main__":
    set_seeds(23)

    model = DiffuseHighSDXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16,  # scheduler=scheduler
    ).to("cuda")

    prompt = "Cinematic photo of delicious chocolate icecream."

    negative_prompt = "blurry, ugly, duplicate, poorly drawn, deformed, mosaic"

    image = model(
        prompt,
        negative_prompt=negative_prompt,
        target_height=[2048, 3072, 4096],
        target_width=[2048, 3072, 4096],
        enable_dwt=True,
        dwt_steps=5,
        enable_sharpening=True,
        sharpening_alpha=1.0,
    ).images[0]

    image.save("sample.png")
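For reference, the low-frequency swap that drives the structural guidance in edm_scheduler_step above can be reproduced in isolation with pytorch_wavelets. The sketch below is illustrative only (tensor names and sizes are made up); it keeps the LL band of the guidance image and the HH bands of the current prediction, mirroring the DWT/iDWT calls in the pipeline.

    import torch
    from pytorch_wavelets import DWTForward, DWTInverse

    dwt = DWTForward(J=1, wave="db4", mode="symmetric")
    idwt = DWTInverse(wave="db4", mode="symmetric")

    guidance = torch.rand(1, 3, 256, 256)    # stand-in for the sharpened, upsampled guidance image
    prediction = torch.rand(1, 3, 256, 256)  # stand-in for the decoded x0 prediction

    LL_guidance, _ = dwt(guidance)            # low-frequency structure from the guidance image
    _, HH_prediction = dwt(prediction)        # high-frequency detail from the current prediction
    fused = idwt((LL_guidance, HH_prediction))  # reconstructed image, same size as the inputs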
requirements.txt ADDED
@@ -0,0 +1,5 @@
diffusers==0.24.0
accelerate
transformers
pywavelets
pytorch-wavelets
utils/utils.py ADDED
@@ -0,0 +1,11 @@
import os
import torch
import numpy as np

def set_seeds(seed):
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # disable cuDNN autotuning so seeded runs stay reproducible
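Taken together, a minimal driver for this commit might look like the sketch below. It assumes the uploaded files are on the Python path (e.g. utils/ is importable as a package, or set_seeds is copied locally) and reuses the checkpoint id and call arguments from the __main__ block of pipeline_diffusehigh_sdxl.py; the output filename is arbitrary.

    import torch
    from pipeline_diffusehigh_sdxl import DiffuseHighSDXLPipeline
    from utils.utils import set_seeds  # assumes utils is importable as a package

    set_seeds(23)

    pipeline = DiffuseHighSDXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    result = pipeline(
        "Cinematic photo of delicious chocolate icecream.",
        negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
        target_height=[2048, 3072, 4096],  # progressive 2K -> 3K -> 4K refinement
        target_width=[2048, 3072, 4096],
        enable_dwt=True,
        dwt_steps=5,
        enable_sharpening=True,
    )
    result.images[0].save("sample_4K.png")  # hypothetical output path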