eggarsway's picture
main
2fc816b
raw
history blame contribute delete
No virus
5 kB
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from dataclasses import dataclass
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet3DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
deprecate,
logging,
replace_example_docstring,
BaseOutput,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import (
tensor2vid,
)
from ..CrossAttn.InjecterProc import InjecterProcessor
from ..Misc import Logger as log
from ..Misc import Const
def use_dd_temporal(unet, use=True):
    """Toggle the ``use_dd_temporal`` flag on cross-attention processors.

    Walks ``unet.named_modules()`` and, for every module whose class is named
    ``Attention`` and whose qualified name contains ``"attn2"``, stores ``use``
    on ``module.processor.use_dd_temporal``.

    Args:
        unet: Object exposing ``named_modules()`` yielding ``(name, module)``.
        use: Value written to each matching processor's flag.
    """
    for qualified_name, submodule in unet.named_modules():
        # Only modules of class "Attention" are candidates.
        if type(submodule).__name__ != "Attention":
            continue
        # Among those, only the ones registered under an "attn2" name.
        if "attn2" not in qualified_name:
            continue
        submodule.processor.use_dd_temporal = use
def use_dd(unet, use=True):
    """Toggle the ``use_dd`` flag on cross-attention processors.

    Every module reported by ``unet.named_modules()`` whose class is named
    ``Attention`` and whose qualified name contains ``"attn2"`` gets
    ``module.processor.use_dd`` set to ``use``.

    Args:
        unet: Object exposing ``named_modules()`` yielding ``(name, module)``.
        use: Value written to each matching processor's flag.
    """
    matching = (
        submodule
        for qualified_name, submodule in unet.named_modules()
        if type(submodule).__name__ == "Attention" and "attn2" in qualified_name
    )
    for submodule in matching:
        submodule.processor.use_dd = use
def initiailization(unet, bundle, bbox_per_frame):
    """Install an ``InjecterProcessor`` on every cross-attention module.

    For each module from ``unet.named_modules()`` whose class is named
    ``Attention`` and whose qualified name contains ``"attn2"``, a new
    ``InjecterProcessor`` replaces ``module.processor``. Modules whose name
    contains ``"temp_attentions"`` read their scales from the ``temp_*``
    entries of ``bundle``; all others read the ``spatial_*`` entries.

    Args:
        unet: Object exposing ``named_modules()`` yielding ``(name, module)``.
        bundle: Mapping forwarded to the processor; must provide the
            strengthen/weaken scale entries looked up below.
        bbox_per_frame: Forwarded unchanged to each processor.
    """
    log.info("Intialization")
    for qualified_name, submodule in unet.named_modules():
        if type(submodule).__name__ != "Attention" or "attn2" not in qualified_name:
            continue

        # Pick which pair of scale entries applies to this attention layer.
        if "temp_attentions" in qualified_name:
            strengthen_key = "temp_strengthen_scale"
            weaken_key = "temp_weaken_scale"
        else:
            strengthen_key = "spatial_strengthen_scale"
            weaken_key = "spatial_weaken_scale"

        submodule.processor = InjecterProcessor(
            bundle=bundle,
            bbox_per_frame=bbox_per_frame,
            strengthen_scale=bundle[strengthen_key],
            weaken_scale=bundle[weaken_key],
            is_text2vidzero=False,
            name=qualified_name,
        )
    log.info("Initialized")
def keyframed_prompt_embeds(bundle, encode_prompt_func, device):
    """Build per-frame prompt embeddings by blending consecutive keyframe prompts.

    Each pair of neighbouring keyframes defines a segment. The first segment
    starts on the first keyframe's frame; every later segment starts one frame
    after its opening keyframe. Inside a segment the embeddings of the opening
    and closing prompts are linearly blended, reaching the closing prompt on
    the segment's last frame.

    Args:
        bundle: Mapping with a ``"keyframe"`` list; each entry provides a
            ``"frame"`` index and a ``"prompt"`` string. The last entry's
            frame index plus one is used as the total frame count.
        encode_prompt_func: Callable invoked as ``encode_prompt_func(prompt,
            device=..., num_images_per_prompt=1,
            do_classifier_free_guidance=True, negative_prompt=...)`` and
            unpacked as ``(prompt_embeds, negative_prompt_embeds)``.
        device: Forwarded unchanged to ``encode_prompt_func``.

    Returns:
        A tuple ``(prompt_embeds, negative_prompt_embeds)``: the per-frame
        blended embeddings joined with ``torch.cat``, and the negative
        embeddings from the last encoding repeated ``num_frames`` times along
        dim 0.

    Raises:
        AssertionError: If the number of generated frames differs from the
            expected frame count (for example when the first keyframe is not
            at frame 0 or fewer than two keyframes are given).
    """
    keyframe = bundle["keyframe"]
    num_frames = keyframe[-1]["frame"] + 1

    def encode(prompt):
        # All keyframe prompts are encoded with identical settings.
        return encode_prompt_func(
            prompt + Const.POSITIVE_PROMPT,
            device=device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=Const.NEGATIVE_PROMPT,
        )

    keyed_prompt_embeds = []
    negative_prompt_embeds = None
    end_prompt_embeds = None
    for i in range(len(keyframe) - 1):
        start_fr = keyframe[i]["frame"] if i == 0 else keyframe[i]["frame"] + 1
        end_fr = keyframe[i + 1]["frame"]
        clip_length = end_fr - start_fr + 1

        if end_prompt_embeds is None:
            start_prompt_embeds, _ = encode(keyframe[i]["prompt"])
        else:
            # The previous segment's closing prompt is this segment's opening
            # prompt, so reuse its embedding instead of encoding it again.
            start_prompt_embeds = end_prompt_embeds
        end_prompt_embeds, negative_prompt_embeds = encode(keyframe[i + 1]["prompt"])

        for fr in range(clip_length):
            # A one-frame segment (keyframes on consecutive frames) would divide
            # by zero; that single frame is the closing keyframe, so weight 1.
            index = float(fr) / (clip_length - 1) if clip_length > 1 else 1.0
            keyed_prompt_embeds.append(
                (1 - index) * start_prompt_embeds + index * end_prompt_embeds
            )

    assert len(keyed_prompt_embeds) == num_frames, (
        f"expected {num_frames} frames from keyframes, got {len(keyed_prompt_embeds)}"
    )
    return torch.cat(keyed_prompt_embeds), negative_prompt_embeds.repeat_interleave(
        num_frames, dim=0
    )
def keyframed_bbox(bundle):
    """Interpolate one bounding box per frame from the keyframes in ``bundle``.

    Each pair of neighbouring keyframes defines a segment. The first segment
    starts on the first keyframe's frame; every later segment starts one frame
    after its opening keyframe. Inside a segment the four box values are
    linearly blended from the opening keyframe's ``"bbox_ratios"`` to the
    closing keyframe's, reaching the closing box on the segment's last frame.

    Args:
        bundle: Mapping with a ``"keyframe"`` list; each entry provides a
            ``"frame"`` index and a ``"bbox_ratios"`` sequence of which the
            first four items are used.

    Returns:
        A list with one ``[v0, v1, v2, v3]`` list per generated frame. The list
        is empty when fewer than two keyframes are given.
    """
    keyframe = bundle["keyframe"]
    bbox_per_frame = []
    for i in range(len(keyframe) - 1):
        start_fr = keyframe[i]["frame"] if i == 0 else keyframe[i]["frame"] + 1
        end_fr = keyframe[i + 1]["frame"]
        start_bbox = keyframe[i]["bbox_ratios"]
        end_bbox = keyframe[i + 1]["bbox_ratios"]
        clip_length = end_fr - start_fr + 1
        # NOTE(review): later segments restart at weight 0 one frame after the
        # keyframe, so that frame repeats the keyframe's box. Kept as-is to
        # preserve existing outputs; confirm whether this is intended.
        for fr in range(clip_length):
            # A one-frame segment (keyframes on consecutive frames) would divide
            # by zero; that single frame is the closing keyframe, so weight 1.
            index = float(fr) / (clip_length - 1) if clip_length > 1 else 1.0
            bbox_per_frame.append(
                [(1 - index) * start_bbox[j] + index * end_bbox[j] for j in range(4)]
            )
    return bbox_per_frame