aroraaman's picture
Add all of `fourm`
3424266
raw
history blame
No virus
6.18 kB
# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple, Union

import torch
from diffusers import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from tqdm import tqdm

from fourm.utils import to_2tuple
def rescale_noise_cfg(noise_cfg, noise_pred_conditional, guidance_rescale=0.0):
    """
    Blend the guided prediction `noise_cfg` with a copy of itself whose standard deviation has been
    matched to that of `noise_pred_conditional`, weighted by `guidance_rescale`. Based on findings of
    [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
    See Section 3.4

    Args:
        noise_cfg: The prediction after guidance has been applied.
        noise_pred_conditional: The conditional prediction whose standard deviation is the target.
        guidance_rescale: Interpolation weight. 0.0 returns values equal to `noise_cfg`; 1.0 returns
            the fully std-matched prediction.

    Returns:
        The interpolated prediction, with the same shape as `noise_cfg`.
    """

    def std_over_trailing_axes(tensor):
        # Reduce over every axis except the first; keepdim lets the result broadcast back.
        trailing_axes = list(range(1, tensor.ndim))
        return tensor.std(dim=trailing_axes, keepdim=True)

    conditional_std = std_over_trailing_axes(noise_pred_conditional)
    guided_std = std_over_trailing_axes(noise_cfg)

    # Matching the guided output's std to the conditional one fixes overexposure.
    std_matched = noise_cfg * (conditional_std / guided_std)

    # Mixing back some of the raw guided output avoids "plain looking" images.
    return guidance_rescale * std_matched + (1 - guidance_rescale) * noise_cfg
class PipelineCond(DiffusionPipeline):
    """Pipeline for conditional image generation.

    This model inherits from `DiffusionPipeline`. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        model: The conditional diffusion model.
        scheduler: A diffusion scheduler, e.g. see scheduling_ddpm.py
    """

    def __init__(self, model: torch.nn.Module, scheduler: SchedulerMixin):
        super().__init__()
        self.register_modules(model=model, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self,
                 cond: torch.Tensor,
                 generator: Optional[torch.Generator] = None,
                 timesteps: Optional[int] = None,
                 guidance_scale: Union[float, Callable[..., float]] = 0.0,
                 guidance_rescale: float = 0.0,
                 image_size: Optional[Union[Tuple[int, int], int]] = None,
                 verbose: bool = True,
                 scheduler_timesteps_mode: str = 'trailing',
                 orig_res: Optional[Union[torch.LongTensor, Tuple[int, int]]] = None,
                 **kwargs) -> torch.Tensor:
        """The call function to the pipeline for conditional image generation.

        Args:
            cond: The conditional input to the model. Must be 4-dimensional; its first
                dimension is used as the batch size of the generated images.
            generator: A torch.Generator to make generation deterministic. The initial noise
                is sampled on the generator's device and then moved to the model's device.
            timesteps: The number of denoising steps. More denoising steps usually lead to a higher
                quality image at the expense of slower inference. Defaults to the number of training
                timesteps if not given.
            guidance_scale: The scale of the classifier-free guidance. Guidance is only applied
                when this is a float greater than 1.0 or a callable; for any float <= 1.0
                (including the default 0.0) a single conditional forward pass is used. A callable
                is invoked at every step with `t / num_train_timesteps` and must return the
                scale to use for that step.
            guidance_rescale: Rescaling factor to fix the variance when using guidance scaling.
                Only has an effect when guidance is applied and the value is > 0.0.
            image_size: The size of the image to generate. If not given, `model.sample_size`
                is used.
            verbose: Whether to show a progress bar.
            scheduler_timesteps_mode: The mode to use for DDIMScheduler. One of `trailing`, `linspace`,
                `leading`. See https://arxiv.org/abs/2305.08891 for more details.
            orig_res: The original resolution of the image to condition the diffusion on. Ignored if None.
                See SDXL https://arxiv.org/abs/2307.01952 for more details.
            **kwargs: Extra keyword arguments forwarded to every model call.

        Returns:
            The generated image.
        """
        num_train_timesteps = self.scheduler.config.num_train_timesteps
        timesteps = num_train_timesteps if timesteps is None else timesteps
        batch_size, _, _, _ = cond.shape

        # Sample gaussian noise to begin loop
        image_size = self.model.sample_size if image_size is None else image_size
        image_size = to_2tuple(image_size)

        # torch.randn requires the output device to match the generator's device, so sample
        # where the generator lives. Without a generator, keep the default device so that
        # globally seeded runs draw the same noise as before.
        noise_device = None if generator is None else generator.device
        image = torch.randn(
            (batch_size, self.model.in_channels, image_size[0], image_size[1]),
            generator=generator,
            device=noise_device,
        )
        image = image.to(self.model.device)

        scale_is_schedule = callable(guidance_scale)
        do_cfg = scale_is_schedule or guidance_scale > 1.0

        # Set step values
        self.scheduler.set_timesteps(timesteps, mode=scheduler_timesteps_mode)

        # tqdm closes the bar itself (also when a step raises) and is a no-op when disabled.
        for t in tqdm(self.scheduler.timesteps, disable=not verbose):
            # 1. Predict noise model_output
            model_output = self.model(image, t, cond, orig_res=orig_res, **kwargs)

            if do_cfg:
                # TODO: is there a better way to get unconditional output?
                # NOTE(review): unlike the conditional pass, `orig_res` is not forwarded
                # here - confirm this is intended.
                model_output_uncond = self.model(image, t, cond, unconditional=True, **kwargs)

                if scale_is_schedule:
                    guidance_scale_value = guidance_scale(t / num_train_timesteps)
                else:
                    guidance_scale_value = guidance_scale

                # Move from the unconditional prediction towards the conditional one.
                model_output_cfg = model_output_uncond + guidance_scale_value * (model_output - model_output_uncond)

                if guidance_rescale > 0.0:
                    model_output = rescale_noise_cfg(model_output_cfg, model_output, guidance_rescale=guidance_rescale)
                else:
                    model_output = model_output_cfg

            # 2. Compute previous image: x_t -> t_t-1
            # Autocast is disabled and the model output cast to float for the scheduler step.
            with torch.cuda.amp.autocast(enabled=False):
                image = self.scheduler.step(model_output.float(), t, image, generator=generator).prev_sample

        return image