bubbliiiing committed
Commit 14d2973
1 Parent(s): 933a5a0

update to v1.1

app.py CHANGED
@@ -19,11 +19,14 @@ if __name__ == "__main__":
19
  server_port = 7860
20
 
21
  # Params below is used when ui_mode = "modelscope"
22
- model_name = "models/Diffusion_Transformer/CogVideoX-Fun-5b-InP"
 
 
 
23
  savedir_sample = "samples"
24
 
25
  if ui_mode == "modelscope":
26
- demo, controller = ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
27
  elif ui_mode == "eas":
28
  demo, controller = ui_eas(model_name, savedir_sample)
29
  else:
 
19
  server_port = 7860
20
 
21
  # Params below is used when ui_mode = "modelscope"
22
+ model_name = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-5b-InP"
23
+ # "Inpaint" or "Control"
24
+ model_type = "Inpaint"
25
+ # Save dir of this model
26
  savedir_sample = "samples"
27
 
28
  if ui_mode == "modelscope":
29
+ demo, controller = ui_modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype)
30
  elif ui_mode == "eas":
31
  demo, controller = ui_eas(model_name, savedir_sample)
32
  else:
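For orientation, below is a minimal sketch of how the new `model_type` switch could point the same entry point at a Control checkpoint. Only the variable names and the updated `ui_modelscope()` signature come from this commit; the Control checkpoint path is a placeholder.

```python
# Hypothetical "Control" configuration for app.py.
# Placeholder checkpoint path; only model_type and the call signature are from the diff above.
model_name     = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-5b-Control"
model_type     = "Control"   # "Inpaint" or "Control", as introduced in this commit
savedir_sample = "samples"

demo, controller = ui_modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype)
```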
cogvideox/api/api.py CHANGED
@@ -68,6 +68,20 @@ def save_base64_video(base64_string):
 
     return file_path
 
+def save_base64_image(base64_string):
+    video_data = base64.b64decode(base64_string)
+
+    md5_hash = hashlib.md5(video_data).hexdigest()
+    filename = f"{md5_hash}.jpg"
+
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, filename)
+
+    with open(file_path, 'wb') as video_file:
+        video_file.write(video_data)
+
+    return file_path
+
 def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
     @app.post("/cogvideox_fun/infer_forward")
     def _infer_forward_api(
@@ -77,7 +91,7 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
         lora_model_path = datas.get('lora_model_path', 'none')
         lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
         prompt_textbox = datas.get('prompt_textbox', None)
-        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. ')
+        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
         sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
         sample_step_slider = datas.get('sample_step_slider', 30)
         resize_method = datas.get('resize_method', "Generate by")
@@ -93,6 +107,8 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
         start_image = datas.get('start_image', None)
         end_image = datas.get('end_image', None)
         validation_video = datas.get('validation_video', None)
+        validation_video_mask = datas.get('validation_video_mask', None)
+        control_video = datas.get('control_video', None)
         denoise_strength = datas.get('denoise_strength', 0.70)
         seed_textbox = datas.get("seed_textbox", 43)
 
@@ -109,6 +125,12 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
         if validation_video is not None:
            validation_video = save_base64_video(validation_video)
 
+        if validation_video_mask is not None:
+            validation_video_mask = save_base64_image(validation_video_mask)
+
+        if control_video is not None:
+            control_video = save_base64_video(control_video)
+
         try:
             save_sample_path, comment = controller.generate(
                 "",
@@ -131,6 +153,8 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
                 start_image,
                 end_image,
                 validation_video,
+                validation_video_mask,
+                control_video,
                 denoise_strength,
                 seed_textbox,
                 is_api = True,
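A sketch of a client call that exercises the new v1.1 fields follows. The endpoint path and field names come from the diff above; the file paths, port, and response handling are illustrative assumptions.

```python
# Hedged client sketch for the extended /cogvideox_fun/infer_forward endpoint.
import base64
import requests

def encode_file(path):
    # base64-encode a local file, mirroring how the server decodes it with
    # save_base64_video() / save_base64_image()
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "prompt_textbox": "A person dancing in a park.",
    "control_video": encode_file("pose.mp4"),            # new in v1.1, decoded by save_base64_video()
    "validation_video_mask": encode_file("mask.jpg"),    # new in v1.1, decoded by save_base64_image()
    "seed_textbox": 43,
}

response = requests.post("http://127.0.0.1:7860/cogvideox_fun/infer_forward", json=payload)
print(response.json())
```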
cogvideox/api/post_infer.py CHANGED
@@ -33,7 +33,7 @@ def post_infer(generation_method, length_slider, url='http://127.0.0.1:7860'):
         "lora_model_path": "none",
         "lora_alpha_slider": 0.55,
         "prompt_textbox": "A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
-        "negative_prompt_textbox": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. ",
+        "negative_prompt_textbox": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ",
         "sampler_dropdown": "Euler",
         "sample_step_slider": 50,
         "width_slider": 672,
cogvideox/data/dataset_image_video.py CHANGED
@@ -322,3 +322,225 @@ class ImageVideoDataset(Dataset):
 
         return sample
 
+
+class ImageVideoControlDataset(Dataset):
+    def __init__(
+        self,
+        ann_path, data_root=None,
+        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
+        image_sample_size=512,
+        video_repeat=0,
+        text_drop_ratio=-1,
+        enable_bucket=False,
+        video_length_drop_start=0.1,
+        video_length_drop_end=0.9,
+        enable_inpaint=False,
+    ):
+        # Loading annotations from files
+        print(f"loading annotations from {ann_path} ...")
+        if ann_path.endswith('.csv'):
+            with open(ann_path, 'r') as csvfile:
+                dataset = list(csv.DictReader(csvfile))
+        elif ann_path.endswith('.json'):
+            dataset = json.load(open(ann_path))
+
+        self.data_root = data_root
+
+        # It's used to balance num of images and videos.
+        self.dataset = []
+        for data in dataset:
+            if data.get('type', 'image') != 'video':
+                self.dataset.append(data)
+        if video_repeat > 0:
+            for _ in range(video_repeat):
+                for data in dataset:
+                    if data.get('type', 'image') == 'video':
+                        self.dataset.append(data)
+        del dataset
+
+        self.length = len(self.dataset)
+        print(f"data scale: {self.length}")
+        # TODO: enable bucket training
+        self.enable_bucket = enable_bucket
+        self.text_drop_ratio = text_drop_ratio
+        self.enable_inpaint = enable_inpaint
+
+        self.video_length_drop_start = video_length_drop_start
+        self.video_length_drop_end = video_length_drop_end
+
+        # Video params
+        self.video_sample_stride = video_sample_stride
+        self.video_sample_n_frames = video_sample_n_frames
+        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
+        self.video_transforms = transforms.Compose(
+            [
+                transforms.Resize(min(self.video_sample_size)),
+                transforms.CenterCrop(self.video_sample_size),
+                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+            ]
+        )
+
+        # Image params
+        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
+        self.image_transforms = transforms.Compose([
+            transforms.Resize(min(self.image_sample_size)),
+            transforms.CenterCrop(self.image_sample_size),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
+        ])
+
+        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
+
+    def get_batch(self, idx):
+        data_info = self.dataset[idx % len(self.dataset)]
+        video_id, control_video_id, text = data_info['file_path'], data_info['control_file_path'], data_info['text']
+
+        if data_info.get('type', 'image')=='video':
+            if self.data_root is None:
+                video_dir = video_id
+            else:
+                video_dir = os.path.join(self.data_root, video_id)
+
+            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
+                min_sample_n_frames = min(
+                    self.video_sample_n_frames,
+                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
+                )
+                if min_sample_n_frames == 0:
+                    raise ValueError(f"No Frames in video.")
+
+                video_length = int(self.video_length_drop_end * len(video_reader))
+                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
+                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
+                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
+
+                try:
+                    sample_args = (video_reader, batch_index)
+                    pixel_values = func_timeout(
+                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+                    )
+                    resized_frames = []
+                    for i in range(len(pixel_values)):
+                        frame = pixel_values[i]
+                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+                        resized_frames.append(resized_frame)
+                    pixel_values = np.array(resized_frames)
+                except FunctionTimedOut:
+                    raise ValueError(f"Read {idx} timeout.")
+                except Exception as e:
+                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+                if not self.enable_bucket:
+                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
+                    pixel_values = pixel_values / 255.
+                    del video_reader
+                else:
+                    pixel_values = pixel_values
+
+                if not self.enable_bucket:
+                    pixel_values = self.video_transforms(pixel_values)
+
+                # Random use no text generation
+                if random.random() < self.text_drop_ratio:
+                    text = ''
+
+            if self.data_root is None:
+                control_video_id = control_video_id
+            else:
+                control_video_id = os.path.join(self.data_root, control_video_id)
+
+            with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
+                try:
+                    sample_args = (control_video_reader, batch_index)
+                    control_pixel_values = func_timeout(
+                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+                    )
+                    resized_frames = []
+                    for i in range(len(control_pixel_values)):
+                        frame = control_pixel_values[i]
+                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+                        resized_frames.append(resized_frame)
+                    control_pixel_values = np.array(resized_frames)
+                except FunctionTimedOut:
+                    raise ValueError(f"Read {idx} timeout.")
+                except Exception as e:
+                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+                if not self.enable_bucket:
+                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
+                    control_pixel_values = control_pixel_values / 255.
+                    del control_video_reader
+                else:
+                    control_pixel_values = control_pixel_values
+
+                if not self.enable_bucket:
+                    control_pixel_values = self.video_transforms(control_pixel_values)
+            return pixel_values, control_pixel_values, text, "video"
+        else:
+            image_path, text = data_info['file_path'], data_info['text']
+            if self.data_root is not None:
+                image_path = os.path.join(self.data_root, image_path)
+            image = Image.open(image_path).convert('RGB')
+            if not self.enable_bucket:
+                image = self.image_transforms(image).unsqueeze(0)
+            else:
+                image = np.expand_dims(np.array(image), 0)
+
+            if random.random() < self.text_drop_ratio:
+                text = ''
+
+            if self.data_root is None:
+                control_image_id = control_image_id
+            else:
+                control_image_id = os.path.join(self.data_root, control_image_id)
+
+            control_image = Image.open(control_image_id).convert('RGB')
+            if not self.enable_bucket:
+                control_image = self.image_transforms(control_image).unsqueeze(0)
+            else:
+                control_image = np.expand_dims(np.array(control_image), 0)
+            return image, control_image, text, 'image'
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        data_info = self.dataset[idx % len(self.dataset)]
+        data_type = data_info.get('type', 'image')
+        while True:
+            sample = {}
+            try:
+                data_info_local = self.dataset[idx % len(self.dataset)]
+                data_type_local = data_info_local.get('type', 'image')
+                if data_type_local != data_type:
+                    raise ValueError("data_type_local != data_type")
+
+                pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
+                sample["pixel_values"] = pixel_values
+                sample["control_pixel_values"] = control_pixel_values
+                sample["text"] = name
+                sample["data_type"] = data_type
+                sample["idx"] = idx
+
+                if len(sample) > 0:
+                    break
+            except Exception as e:
+                print(e, self.dataset[idx % len(self.dataset)])
+                idx = random.randint(0, self.length-1)
+
+        if self.enable_inpaint and not self.enable_bucket:
+            mask = get_random_mask(pixel_values.size())
+            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+            sample["mask_pixel_values"] = mask_pixel_values
+            sample["mask"] = mask
+
+            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
+            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
+            sample["clip_pixel_values"] = clip_pixel_values
+
+            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
+            if (mask == 1).all():
+                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
+            sample["ref_pixel_values"] = ref_pixel_values
+
+        return sample
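A sketch of the annotation format `ImageVideoControlDataset.get_batch` expects and of how the dataset might be wired into a `DataLoader` follows. The key names (`file_path`, `control_file_path`, `text`, `type`) and constructor parameters come from the code above; the concrete paths and loader settings are illustrative. One caveat worth noting: in the image branch above, `control_image_id` appears to be read before it is assigned (there is no `control_file_path` lookup analogous to the video branch), so image-type control samples would likely raise an error as committed.

```python
# Illustrative annotation file and DataLoader wiring for the new control dataset.
import json
from torch.utils.data import DataLoader
from cogvideox.data.dataset_image_video import ImageVideoControlDataset

annotations = [
    {
        "file_path": "videos/dance_0001.mp4",               # target clip
        "control_file_path": "pose_videos/dance_0001.mp4",  # frame-aligned control clip (e.g. pose)
        "text": "A person dancing in a park.",
        "type": "video",
    },
]
with open("metadata_control.json", "w") as f:
    json.dump(annotations, f)

dataset = ImageVideoControlDataset(
    ann_path="metadata_control.json",
    data_root="/path/to/data",
    video_sample_size=512,
    video_sample_n_frames=49,
    video_sample_stride=2,
)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
batch = next(iter(loader))
# batch keys: pixel_values, control_pixel_values, text, data_type, idx
```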
cogvideox/models/transformer3d.py CHANGED
@@ -27,7 +27,7 @@ from diffusers.utils import is_torch_version, logging
 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from diffusers.models.attention import Attention, FeedForward
 from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
-from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -35,6 +35,44 @@ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+class CogVideoXPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        embed_dim: int = 1920,
+        text_embed_dim: int = 4096,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(
+            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+        )
+        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+        r"""
+        Args:
+            text_embeds (`torch.Tensor`):
+                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+            image_embeds (`torch.Tensor`):
+                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+        """
+        text_embeds = self.text_proj(text_embeds)
+
+        batch, num_frames, channels, height, width = image_embeds.shape
+        image_embeds = image_embeds.reshape(-1, channels, height, width)
+        image_embeds = self.proj(image_embeds)
+        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+
+        embeds = torch.cat(
+            [text_embeds, image_embeds], dim=1
+        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
+        return embeds
 
 @maybe_allow_in_graph
 class CogVideoXBlock(nn.Module):
@@ -239,6 +277,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         spatial_interpolation_scale: float = 1.875,
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
+        add_noise_in_inpaint_model: bool = False,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -414,6 +453,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
         inpaint_latents: Optional[torch.Tensor] = None,
+        control_latents: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         return_dict: bool = True,
     ):
@@ -432,6 +472,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         # 2. Patch embedding
         if inpaint_latents is not None:
             hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
+        if control_latents is not None:
+            hidden_states = torch.concat([hidden_states, control_latents], 2)
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
 
         # 3. Position embedding
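Since the control latents are concatenated with the noisy latents along the channel axis before patch embedding, a quick shape sketch (the channel count and spatial size below are illustrative) shows why a Control checkpoint needs a wider patch-embedding convolution than the base 16-channel model:

```python
# Shape check for the control conditioning path added above.
import torch

batch, frames, channels, height, width = 1, 13, 16, 60, 90
hidden_states   = torch.randn(batch, frames, channels, height, width)   # noisy video latents
control_latents = torch.randn(batch, frames, channels, height, width)   # VAE-encoded control video

# Same concat as in CogVideoXTransformer3DModel.forward(): dim=2 is the channel axis.
hidden_states = torch.concat([hidden_states, control_latents], 2)
print(hidden_states.shape)  # torch.Size([1, 13, 32, 60, 90])

# CogVideoXPatchEmbed folds frames into the batch and applies a per-frame Conv2d,
# so a Control checkpoint must instantiate the patch embed with in_channels=32
# instead of the default 16.
```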
cogvideox/pipeline/pipeline_cogvideox_control.py ADDED
@@ -0,0 +1,843 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from einops import rearrange
24
+ from transformers import T5EncoderModel, T5Tokenizer
25
+
26
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
28
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
31
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from diffusers.video_processor import VideoProcessor
34
+ from diffusers.image_processor import VaeImageProcessor
35
+ from einops import rearrange
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```python
44
+ >>> import torch
45
+ >>> from diffusers import CogVideoX_Fun_Pipeline
46
+ >>> from diffusers.utils import export_to_video
47
+
48
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
49
+ >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
50
+ >>> prompt = (
51
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
52
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
53
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
54
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
55
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
56
+ ... "atmosphere of this unique musical performance."
57
+ ... )
58
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
59
+ >>> export_to_video(video, "output.mp4", fps=8)
60
+ ```
61
+ """
62
+
63
+
64
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
65
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
66
+ tw = tgt_width
67
+ th = tgt_height
68
+ h, w = src
69
+ r = h / w
70
+ if r > (th / tw):
71
+ resize_height = th
72
+ resize_width = int(round(th / h * w))
73
+ else:
74
+ resize_width = tw
75
+ resize_height = int(round(tw / w * h))
76
+
77
+ crop_top = int(round((th - resize_height) / 2.0))
78
+ crop_left = int(round((tw - resize_width) / 2.0))
79
+
80
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
81
+
82
+
83
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
84
+ def retrieve_timesteps(
85
+ scheduler,
86
+ num_inference_steps: Optional[int] = None,
87
+ device: Optional[Union[str, torch.device]] = None,
88
+ timesteps: Optional[List[int]] = None,
89
+ sigmas: Optional[List[float]] = None,
90
+ **kwargs,
91
+ ):
92
+ """
93
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
94
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
95
+
96
+ Args:
97
+ scheduler (`SchedulerMixin`):
98
+ The scheduler to get timesteps from.
99
+ num_inference_steps (`int`):
100
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
101
+ must be `None`.
102
+ device (`str` or `torch.device`, *optional*):
103
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
104
+ timesteps (`List[int]`, *optional*):
105
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
106
+ `num_inference_steps` and `sigmas` must be `None`.
107
+ sigmas (`List[float]`, *optional*):
108
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
109
+ `num_inference_steps` and `timesteps` must be `None`.
110
+
111
+ Returns:
112
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
113
+ second element is the number of inference steps.
114
+ """
115
+ if timesteps is not None and sigmas is not None:
116
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
117
+ if timesteps is not None:
118
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
119
+ if not accepts_timesteps:
120
+ raise ValueError(
121
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
122
+ f" timestep schedules. Please check whether you are using the correct scheduler."
123
+ )
124
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
125
+ timesteps = scheduler.timesteps
126
+ num_inference_steps = len(timesteps)
127
+ elif sigmas is not None:
128
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
129
+ if not accept_sigmas:
130
+ raise ValueError(
131
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
132
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
133
+ )
134
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ num_inference_steps = len(timesteps)
137
+ else:
138
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ return timesteps, num_inference_steps
141
+
142
+
143
+ @dataclass
144
+ class CogVideoX_Fun_PipelineOutput(BaseOutput):
145
+ r"""
146
+ Output class for CogVideo pipelines.
147
+
148
+ Args:
149
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
150
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
151
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
152
+ `(batch_size, num_frames, channels, height, width)`.
153
+ """
154
+
155
+ videos: torch.Tensor
156
+
157
+
158
+ class CogVideoX_Fun_Pipeline_Control(DiffusionPipeline):
159
+ r"""
160
+ Pipeline for text-to-video generation using CogVideoX.
161
+
162
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
163
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
164
+
165
+ Args:
166
+ vae ([`AutoencoderKL`]):
167
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
168
+ text_encoder ([`T5EncoderModel`]):
169
+ Frozen text-encoder. CogVideoX_Fun uses
170
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
171
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
172
+ tokenizer (`T5Tokenizer`):
173
+ Tokenizer of class
174
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
175
+ transformer ([`CogVideoXTransformer3DModel`]):
176
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
177
+ scheduler ([`SchedulerMixin`]):
178
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
179
+ """
180
+
181
+ _optional_components = []
182
+ model_cpu_offload_seq = "text_encoder->vae->transformer->vae"
183
+
184
+ _callback_tensor_inputs = [
185
+ "latents",
186
+ "prompt_embeds",
187
+ "negative_prompt_embeds",
188
+ ]
189
+
190
+ def __init__(
191
+ self,
192
+ tokenizer: T5Tokenizer,
193
+ text_encoder: T5EncoderModel,
194
+ vae: AutoencoderKLCogVideoX,
195
+ transformer: CogVideoXTransformer3DModel,
196
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
197
+ ):
198
+ super().__init__()
199
+
200
+ self.register_modules(
201
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
202
+ )
203
+ self.vae_scale_factor_spatial = (
204
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
205
+ )
206
+ self.vae_scale_factor_temporal = (
207
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
208
+ )
209
+
210
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
211
+
212
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
213
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
214
+ self.mask_processor = VaeImageProcessor(
215
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
216
+ )
217
+
218
+ def _get_t5_prompt_embeds(
219
+ self,
220
+ prompt: Union[str, List[str]] = None,
221
+ num_videos_per_prompt: int = 1,
222
+ max_sequence_length: int = 226,
223
+ device: Optional[torch.device] = None,
224
+ dtype: Optional[torch.dtype] = None,
225
+ ):
226
+ device = device or self._execution_device
227
+ dtype = dtype or self.text_encoder.dtype
228
+
229
+ prompt = [prompt] if isinstance(prompt, str) else prompt
230
+ batch_size = len(prompt)
231
+
232
+ text_inputs = self.tokenizer(
233
+ prompt,
234
+ padding="max_length",
235
+ max_length=max_sequence_length,
236
+ truncation=True,
237
+ add_special_tokens=True,
238
+ return_tensors="pt",
239
+ )
240
+ text_input_ids = text_inputs.input_ids
241
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
242
+
243
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
244
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
245
+ logger.warning(
246
+ "The following part of your input was truncated because `max_sequence_length` is set to "
247
+ f" {max_sequence_length} tokens: {removed_text}"
248
+ )
249
+
250
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
251
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
252
+
253
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
254
+ _, seq_len, _ = prompt_embeds.shape
255
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
256
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
257
+
258
+ return prompt_embeds
259
+
260
+ def encode_prompt(
261
+ self,
262
+ prompt: Union[str, List[str]],
263
+ negative_prompt: Optional[Union[str, List[str]]] = None,
264
+ do_classifier_free_guidance: bool = True,
265
+ num_videos_per_prompt: int = 1,
266
+ prompt_embeds: Optional[torch.Tensor] = None,
267
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
268
+ max_sequence_length: int = 226,
269
+ device: Optional[torch.device] = None,
270
+ dtype: Optional[torch.dtype] = None,
271
+ ):
272
+ r"""
273
+ Encodes the prompt into text encoder hidden states.
274
+
275
+ Args:
276
+ prompt (`str` or `List[str]`, *optional*):
277
+ prompt to be encoded
278
+ negative_prompt (`str` or `List[str]`, *optional*):
279
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
280
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
281
+ less than `1`).
282
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
283
+ Whether to use classifier free guidance or not.
284
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
285
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
286
+ prompt_embeds (`torch.Tensor`, *optional*):
287
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
288
+ provided, text embeddings will be generated from `prompt` input argument.
289
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
290
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
291
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
292
+ argument.
293
+ device: (`torch.device`, *optional*):
294
+ torch device
295
+ dtype: (`torch.dtype`, *optional*):
296
+ torch dtype
297
+ """
298
+ device = device or self._execution_device
299
+
300
+ prompt = [prompt] if isinstance(prompt, str) else prompt
301
+ if prompt is not None:
302
+ batch_size = len(prompt)
303
+ else:
304
+ batch_size = prompt_embeds.shape[0]
305
+
306
+ if prompt_embeds is None:
307
+ prompt_embeds = self._get_t5_prompt_embeds(
308
+ prompt=prompt,
309
+ num_videos_per_prompt=num_videos_per_prompt,
310
+ max_sequence_length=max_sequence_length,
311
+ device=device,
312
+ dtype=dtype,
313
+ )
314
+
315
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
316
+ negative_prompt = negative_prompt or ""
317
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
318
+
319
+ if prompt is not None and type(prompt) is not type(negative_prompt):
320
+ raise TypeError(
321
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
322
+ f" {type(prompt)}."
323
+ )
324
+ elif batch_size != len(negative_prompt):
325
+ raise ValueError(
326
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
327
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
328
+ " the batch size of `prompt`."
329
+ )
330
+
331
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
332
+ prompt=negative_prompt,
333
+ num_videos_per_prompt=num_videos_per_prompt,
334
+ max_sequence_length=max_sequence_length,
335
+ device=device,
336
+ dtype=dtype,
337
+ )
338
+
339
+ return prompt_embeds, negative_prompt_embeds
340
+
341
+ def prepare_latents(
342
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
343
+ ):
344
+ shape = (
345
+ batch_size,
346
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
347
+ num_channels_latents,
348
+ height // self.vae_scale_factor_spatial,
349
+ width // self.vae_scale_factor_spatial,
350
+ )
351
+ if isinstance(generator, list) and len(generator) != batch_size:
352
+ raise ValueError(
353
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
354
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
355
+ )
356
+
357
+ if latents is None:
358
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
359
+ else:
360
+ latents = latents.to(device)
361
+
362
+ # scale the initial noise by the standard deviation required by the scheduler
363
+ latents = latents * self.scheduler.init_noise_sigma
364
+ return latents
365
+
366
+ def prepare_control_latents(
367
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
368
+ ):
369
+ # resize the mask to latents shape as we concatenate the mask to the latents
370
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
371
+ # and half precision
372
+
373
+ if mask is not None:
374
+ mask = mask.to(device=device, dtype=self.vae.dtype)
375
+ bs = 1
376
+ new_mask = []
377
+ for i in range(0, mask.shape[0], bs):
378
+ mask_bs = mask[i : i + bs]
379
+ mask_bs = self.vae.encode(mask_bs)[0]
380
+ mask_bs = mask_bs.mode()
381
+ new_mask.append(mask_bs)
382
+ mask = torch.cat(new_mask, dim = 0)
383
+ mask = mask * self.vae.config.scaling_factor
384
+
385
+ if masked_image is not None:
386
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
387
+ bs = 1
388
+ new_mask_pixel_values = []
389
+ for i in range(0, masked_image.shape[0], bs):
390
+ mask_pixel_values_bs = masked_image[i : i + bs]
391
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
392
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
393
+ new_mask_pixel_values.append(mask_pixel_values_bs)
394
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
395
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
396
+ else:
397
+ masked_image_latents = None
398
+
399
+ return mask, masked_image_latents
400
+
401
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
402
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
403
+ latents = 1 / self.vae.config.scaling_factor * latents
404
+
405
+ frames = self.vae.decode(latents).sample
406
+ frames = (frames / 2 + 0.5).clamp(0, 1)
407
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
408
+ frames = frames.cpu().float().numpy()
409
+ return frames
410
+
411
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
412
+ def prepare_extra_step_kwargs(self, generator, eta):
413
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
414
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
415
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
416
+ # and should be between [0, 1]
417
+
418
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
419
+ extra_step_kwargs = {}
420
+ if accepts_eta:
421
+ extra_step_kwargs["eta"] = eta
422
+
423
+ # check if the scheduler accepts generator
424
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
425
+ if accepts_generator:
426
+ extra_step_kwargs["generator"] = generator
427
+ return extra_step_kwargs
428
+
429
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
430
+ def check_inputs(
431
+ self,
432
+ prompt,
433
+ height,
434
+ width,
435
+ negative_prompt,
436
+ callback_on_step_end_tensor_inputs,
437
+ prompt_embeds=None,
438
+ negative_prompt_embeds=None,
439
+ ):
440
+ if height % 8 != 0 or width % 8 != 0:
441
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
442
+
443
+ if callback_on_step_end_tensor_inputs is not None and not all(
444
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
445
+ ):
446
+ raise ValueError(
447
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
448
+ )
449
+ if prompt is not None and prompt_embeds is not None:
450
+ raise ValueError(
451
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
452
+ " only forward one of the two."
453
+ )
454
+ elif prompt is None and prompt_embeds is None:
455
+ raise ValueError(
456
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
457
+ )
458
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
459
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
460
+
461
+ if prompt is not None and negative_prompt_embeds is not None:
462
+ raise ValueError(
463
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
464
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
465
+ )
466
+
467
+ if negative_prompt is not None and negative_prompt_embeds is not None:
468
+ raise ValueError(
469
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
470
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
471
+ )
472
+
473
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
474
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
475
+ raise ValueError(
476
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
477
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
478
+ f" {negative_prompt_embeds.shape}."
479
+ )
480
+
481
+ def fuse_qkv_projections(self) -> None:
482
+ r"""Enables fused QKV projections."""
483
+ self.fusing_transformer = True
484
+ self.transformer.fuse_qkv_projections()
485
+
486
+ def unfuse_qkv_projections(self) -> None:
487
+ r"""Disable QKV projection fusion if enabled."""
488
+ if not self.fusing_transformer:
489
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
490
+ else:
491
+ self.transformer.unfuse_qkv_projections()
492
+ self.fusing_transformer = False
493
+
494
+ def _prepare_rotary_positional_embeddings(
495
+ self,
496
+ height: int,
497
+ width: int,
498
+ num_frames: int,
499
+ device: torch.device,
500
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
501
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
502
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
503
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
504
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
505
+
506
+ grid_crops_coords = get_resize_crop_region_for_grid(
507
+ (grid_height, grid_width), base_size_width, base_size_height
508
+ )
509
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
510
+ embed_dim=self.transformer.config.attention_head_dim,
511
+ crops_coords=grid_crops_coords,
512
+ grid_size=(grid_height, grid_width),
513
+ temporal_size=num_frames,
514
+ use_real=True,
515
+ )
516
+
517
+ freqs_cos = freqs_cos.to(device=device)
518
+ freqs_sin = freqs_sin.to(device=device)
519
+ return freqs_cos, freqs_sin
520
+
521
+ @property
522
+ def guidance_scale(self):
523
+ return self._guidance_scale
524
+
525
+ @property
526
+ def num_timesteps(self):
527
+ return self._num_timesteps
528
+
529
+ @property
530
+ def interrupt(self):
531
+ return self._interrupt
532
+
533
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
534
+ def get_timesteps(self, num_inference_steps, strength, device):
535
+ # get the original timestep using init_timestep
536
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
537
+
538
+ t_start = max(num_inference_steps - init_timestep, 0)
539
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
540
+
541
+ return timesteps, num_inference_steps - t_start
542
+
543
+ @torch.no_grad()
544
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
545
+ def __call__(
546
+ self,
547
+ prompt: Optional[Union[str, List[str]]] = None,
548
+ negative_prompt: Optional[Union[str, List[str]]] = None,
549
+ height: int = 480,
550
+ width: int = 720,
551
+ video: Union[torch.FloatTensor] = None,
552
+ control_video: Union[torch.FloatTensor] = None,
553
+ num_frames: int = 49,
554
+ num_inference_steps: int = 50,
555
+ timesteps: Optional[List[int]] = None,
556
+ guidance_scale: float = 6,
557
+ use_dynamic_cfg: bool = False,
558
+ num_videos_per_prompt: int = 1,
559
+ eta: float = 0.0,
560
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
561
+ latents: Optional[torch.FloatTensor] = None,
562
+ prompt_embeds: Optional[torch.FloatTensor] = None,
563
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
564
+ output_type: str = "numpy",
565
+ return_dict: bool = False,
566
+ callback_on_step_end: Optional[
567
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
568
+ ] = None,
569
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
570
+ max_sequence_length: int = 226,
571
+ comfyui_progressbar: bool = False,
572
+ ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
573
+ """
574
+ Function invoked when calling the pipeline for generation.
575
+
576
+ Args:
577
+ prompt (`str` or `List[str]`, *optional*):
578
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
579
+ instead.
580
+ negative_prompt (`str` or `List[str]`, *optional*):
581
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
582
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
583
+ less than `1`).
584
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
585
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
586
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
587
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
588
+ num_frames (`int`, defaults to `48`):
589
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
590
+ contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
591
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
592
+ needs to be satisfied is that of divisibility mentioned above.
593
+ num_inference_steps (`int`, *optional*, defaults to 50):
594
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
595
+ expense of slower inference.
596
+ timesteps (`List[int]`, *optional*):
597
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
598
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
599
+ passed will be used. Must be in descending order.
600
+ guidance_scale (`float`, *optional*, defaults to 7.0):
601
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
602
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
603
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
604
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
605
+ usually at the expense of lower image quality.
606
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
607
+ The number of videos to generate per prompt.
608
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
609
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
610
+ to make generation deterministic.
611
+ latents (`torch.FloatTensor`, *optional*):
612
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
613
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
614
+ tensor will ge generated by sampling using the supplied random `generator`.
615
+ prompt_embeds (`torch.FloatTensor`, *optional*):
616
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
617
+ provided, text embeddings will be generated from `prompt` input argument.
618
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
619
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
620
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
621
+ argument.
622
+ output_type (`str`, *optional*, defaults to `"pil"`):
623
+ The output format of the generate image. Choose between
624
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
625
+ return_dict (`bool`, *optional*, defaults to `True`):
626
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
627
+ of a plain tuple.
628
+ callback_on_step_end (`Callable`, *optional*):
629
+ A function that calls at the end of each denoising steps during the inference. The function is called
630
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
631
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
632
+ `callback_on_step_end_tensor_inputs`.
633
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
634
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
635
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
636
+ `._callback_tensor_inputs` attribute of your pipeline class.
637
+ max_sequence_length (`int`, defaults to `226`):
638
+ Maximum sequence length in encoded prompt. Must be consistent with
639
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
640
+
641
+ Examples:
642
+
643
+ Returns:
644
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
645
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
646
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
647
+ """
648
+
649
+ if num_frames > 49:
650
+ raise ValueError(
651
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
652
+ )
653
+
654
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
655
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
656
+
657
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
658
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
659
+ num_videos_per_prompt = 1
660
+
661
+ # 1. Check inputs. Raise error if not correct
662
+ self.check_inputs(
663
+ prompt,
664
+ height,
665
+ width,
666
+ negative_prompt,
667
+ callback_on_step_end_tensor_inputs,
668
+ prompt_embeds,
669
+ negative_prompt_embeds,
670
+ )
671
+ self._guidance_scale = guidance_scale
672
+ self._interrupt = False
673
+
674
+ # 2. Default call parameters
675
+ if prompt is not None and isinstance(prompt, str):
676
+ batch_size = 1
677
+ elif prompt is not None and isinstance(prompt, list):
678
+ batch_size = len(prompt)
679
+ else:
680
+ batch_size = prompt_embeds.shape[0]
681
+
682
+ device = self._execution_device
683
+
684
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
685
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
686
+ # corresponds to doing no classifier free guidance.
687
+ do_classifier_free_guidance = guidance_scale > 1.0
688
+
689
+ # 3. Encode input prompt
690
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
691
+ prompt,
692
+ negative_prompt,
693
+ do_classifier_free_guidance,
694
+ num_videos_per_prompt=num_videos_per_prompt,
695
+ prompt_embeds=prompt_embeds,
696
+ negative_prompt_embeds=negative_prompt_embeds,
697
+ max_sequence_length=max_sequence_length,
698
+ device=device,
699
+ )
700
+ if do_classifier_free_guidance:
701
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
702
+
703
+ # 4. Prepare timesteps
704
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
705
+ self._num_timesteps = len(timesteps)
706
+ if comfyui_progressbar:
707
+ from comfy.utils import ProgressBar
708
+ pbar = ProgressBar(num_inference_steps + 2)
709
+
710
+ # 5. Prepare latents.
711
+ latent_channels = self.vae.config.latent_channels
712
+ latents = self.prepare_latents(
713
+ batch_size * num_videos_per_prompt,
714
+ latent_channels,
715
+ num_frames,
716
+ height,
717
+ width,
718
+ prompt_embeds.dtype,
719
+ device,
720
+ generator,
721
+ latents,
722
+ )
723
+ if comfyui_progressbar:
724
+ pbar.update(1)
725
+
726
+ if control_video is not None:
727
+ video_length = control_video.shape[2]
728
+ control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
729
+ control_video = control_video.to(dtype=torch.float32)
730
+ control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
731
+ else:
732
+ control_video = None
733
+ control_video_latents = self.prepare_control_latents(
734
+ None,
735
+ control_video,
736
+ batch_size,
737
+ height,
738
+ width,
739
+ prompt_embeds.dtype,
740
+ device,
741
+ generator,
742
+ do_classifier_free_guidance
743
+ )[1]
744
+ control_video_latents_input = (
745
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
746
+ )
747
+ control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
748
+
749
+ if comfyui_progressbar:
750
+ pbar.update(1)
751
+
752
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
753
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
754
+
755
+ # 7. Create rotary embeds if required
756
+ image_rotary_emb = (
757
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
758
+ if self.transformer.config.use_rotary_positional_embeddings
759
+ else None
760
+ )
761
+
762
+ # 8. Denoising loop
763
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
764
+
765
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
766
+ # for DPM-solver++
767
+ old_pred_original_sample = None
768
+ for i, t in enumerate(timesteps):
769
+ if self.interrupt:
770
+ continue
771
+
772
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
773
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
774
+
775
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
776
+ timestep = t.expand(latent_model_input.shape[0])
777
+
778
+ # predict noise model_output
779
+ noise_pred = self.transformer(
780
+ hidden_states=latent_model_input,
781
+ encoder_hidden_states=prompt_embeds,
782
+ timestep=timestep,
783
+ image_rotary_emb=image_rotary_emb,
784
+ return_dict=False,
785
+ control_latents=control_latents,
786
+ )[0]
787
+ noise_pred = noise_pred.float()
788
+
789
+ # perform guidance
790
+ if use_dynamic_cfg:
791
+ self._guidance_scale = 1 + guidance_scale * (
792
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
793
+ )
794
+ if do_classifier_free_guidance:
795
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
796
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
797
+
798
+ # compute the previous noisy sample x_t -> x_t-1
799
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
800
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
801
+ else:
802
+ latents, old_pred_original_sample = self.scheduler.step(
803
+ noise_pred,
804
+ old_pred_original_sample,
805
+ t,
806
+ timesteps[i - 1] if i > 0 else None,
807
+ latents,
808
+ **extra_step_kwargs,
809
+ return_dict=False,
810
+ )
811
+ latents = latents.to(prompt_embeds.dtype)
812
+
813
+ # call the callback, if provided
814
+ if callback_on_step_end is not None:
815
+ callback_kwargs = {}
816
+ for k in callback_on_step_end_tensor_inputs:
817
+ callback_kwargs[k] = locals()[k]
818
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
819
+
820
+ latents = callback_outputs.pop("latents", latents)
821
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
822
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
823
+
824
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
825
+ progress_bar.update()
826
+ if comfyui_progressbar:
827
+ pbar.update(1)
828
+
829
+ if output_type == "numpy":
830
+ video = self.decode_latents(latents)
831
+ elif not output_type == "latent":
832
+ video = self.decode_latents(latents)
833
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
834
+ else:
835
+ video = latents
836
+
837
+ # Offload all models
838
+ self.maybe_free_model_hooks()
839
+
840
+ if not return_dict:
841
+ video = torch.from_numpy(video)
842
+
843
+ return CogVideoX_Fun_PipelineOutput(videos=video)
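For orientation, a minimal usage sketch of the Control pipeline whose __call__ tail is shown above. This sketch is not part of the commit: the checkpoint path, resolution, and the random tensor standing in for a real pose clip are illustrative assumptions; the keyword names (control_video, num_frames, guidance_scale, generator) mirror the added code, and num_frames stays within the 49-frame ceiling enforced at the top of __call__.

# Hedged sketch: drive CogVideoX_Fun_Pipeline_Control with a control clip.
# The checkpoint path is hypothetical; a CogVideoX-Fun-V1.1 directory with
# vae/transformer/scheduler subfolders is assumed, as in the rest of the repo.
import torch
from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
from cogvideox.pipeline.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control

model_name = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-5b-Pose"  # hypothetical path

vae = AutoencoderKLCogVideoX.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16)
transformer = CogVideoXTransformer3DModel.from_pretrained(model_name, subfolder="transformer").to(torch.bfloat16)
pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
    model_name, vae=vae, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# control_video: (batch, channels, frames, height, width) in [0, 1]; random noise
# here stands in for a real pose video. 49 frames is the current num_frames ceiling.
control_video = torch.rand(1, 3, 49, 384, 672)

video = pipeline(
    "A person dancing in a studio. The video is of high quality, and the view is very clear.",
    num_frames=49,
    height=384,
    width=672,
    guidance_scale=6.0,
    num_inference_steps=50,
    generator=torch.Generator(device="cuda").manual_seed(43),
    control_video=control_video,
).videos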
cogvideox/pipeline/pipeline_cogvideox_inpaint.py CHANGED
@@ -177,6 +177,19 @@ def resize_mask(mask, latent, process_first_frame_only=True):
177
  return resized_mask
178
 
179
 
180
  @dataclass
181
  class CogVideoX_Fun_PipelineOutput(BaseOutput):
182
  r"""
@@ -444,7 +457,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
444
  return outputs
445
 
446
  def prepare_mask_latents(
447
- self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
448
  ):
449
  # resize the mask to latents shape as we concatenate the mask to the latents
450
  # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
@@ -463,6 +476,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
463
  mask = mask * self.vae.config.scaling_factor
464
 
465
  if masked_image is not None:
 
 
466
  masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
467
  bs = 1
468
  new_mask_pixel_values = []
@@ -650,6 +665,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
650
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
651
  max_sequence_length: int = 226,
652
  strength: float = 1,
 
653
  comfyui_progressbar: bool = False,
654
  ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
655
  """
@@ -866,6 +882,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
866
  device,
867
  generator,
868
  do_classifier_free_guidance,
 
869
  )
870
  mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
871
  mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
 
177
  return resized_mask
178
 
179
 
180
+ def add_noise_to_reference_video(image, ratio=None):
181
+ if ratio is None:
182
+ sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
183
+ sigma = torch.exp(sigma).to(image.dtype)
184
+ else:
185
+ sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
186
+
187
+ image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
188
+ image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
189
+ image = image + image_noise
190
+ return image
191
+
192
+
193
  @dataclass
194
  class CogVideoX_Fun_PipelineOutput(BaseOutput):
195
  r"""
 
457
  return outputs
458
 
459
  def prepare_mask_latents(
460
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
461
  ):
462
  # resize the mask to latents shape as we concatenate the mask to the latents
463
  # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
 
476
  mask = mask * self.vae.config.scaling_factor
477
 
478
  if masked_image is not None:
479
+ if self.transformer.config.add_noise_in_inpaint_model:
480
+ masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
481
  masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
482
  bs = 1
483
  new_mask_pixel_values = []
 
665
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
666
  max_sequence_length: int = 226,
667
  strength: float = 1,
668
+ noise_aug_strength: float = 0.0563,
669
  comfyui_progressbar: bool = False,
670
  ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
671
  """
 
882
  device,
883
  generator,
884
  do_classifier_free_guidance,
885
+ noise_aug_strength=noise_aug_strength,
886
  )
887
  mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
888
  mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
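To make the inpaint-side change above easier to follow: when the transformer config sets add_noise_in_inpaint_model, the masked reference frames are lightly noised (noise_aug_strength, default 0.0563) before VAE encoding, while pixels filled with -1 (the masked region) are left untouched. Below is a self-contained illustration of that behaviour; the helper is reproduced from the hunk above, and the toy tensor shapes are arbitrary assumptions.

# Hedged, standalone illustration of the noise augmentation added to prepare_mask_latents.
import torch

def add_noise_to_reference_video(image, ratio=None):
    if ratio is None:
        # Per-sample log-normal noise scale when no explicit ratio is given.
        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
        sigma = torch.exp(sigma).to(image.dtype)
    else:
        # Fixed noise scale, e.g. the pipeline default noise_aug_strength = 0.0563.
        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio

    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
    # Masked-out pixels are filled with -1 and must stay exactly -1.
    image_noise = torch.where(image == -1, torch.zeros_like(image), image_noise)
    return image + image_noise

# Toy masked reference: (batch, channels, frames, height, width) in [-1, 1],
# with every frame after the first fully masked.
reference = torch.rand(1, 3, 9, 32, 48) * 2 - 1
reference[:, :, 1:] = -1

augmented = add_noise_to_reference_video(reference, ratio=0.0563)
assert torch.equal(augmented[:, :, 1:], reference[:, :, 1:])      # masked frames untouched
assert not torch.equal(augmented[:, :, :1], reference[:, :, :1])  # first frame got noise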
cogvideox/ui/ui.py CHANGED
@@ -30,6 +30,8 @@ from cogvideox.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
30
  from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
31
  from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
32
  from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
 
 
33
  from cogvideox.pipeline.pipeline_cogvideox_inpaint import \
34
  CogVideoX_Fun_Pipeline_Inpaint
35
  from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
@@ -58,7 +60,7 @@ css = """
58
  }
59
  """
60
 
61
- class CogVideoX_I2VController:
62
  def __init__(self, low_gpu_memory_mode, weight_dtype):
63
  # config dirs
64
  self.basedir = os.getcwd()
@@ -68,6 +70,7 @@ class CogVideoX_I2VController:
68
  self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
69
  self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
70
  self.savedir_sample = os.path.join(self.savedir, "sample")
 
71
  os.makedirs(self.savedir, exist_ok=True)
72
 
73
  self.diffusion_transformer_list = []
@@ -102,6 +105,9 @@ class CogVideoX_I2VController:
102
  personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
103
  self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
104
 
 
 
 
105
  def update_diffusion_transformer(self, diffusion_transformer_dropdown):
106
  print("Update diffusion transformer")
107
  if diffusion_transformer_dropdown == "none":
@@ -118,16 +124,25 @@ class CogVideoX_I2VController:
118
  ).to(self.weight_dtype)
119
 
120
  # Get pipeline
121
- if self.transformer.config.in_channels != self.vae.config.latent_channels:
122
- self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
123
- diffusion_transformer_dropdown,
124
- vae=self.vae,
125
- transformer=self.transformer,
126
- scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
127
- torch_dtype=self.weight_dtype
128
- )
 
 
129
  else:
130
- self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
131
  diffusion_transformer_dropdown,
132
  vae=self.vae,
133
  transformer=self.transformer,
@@ -191,6 +206,8 @@ class CogVideoX_I2VController:
191
  start_image,
192
  end_image,
193
  validation_video,
 
 
194
  denoise_strength,
195
  seed_textbox,
196
  is_api = False,
@@ -208,20 +225,34 @@ class CogVideoX_I2VController:
208
  if self.lora_model_path != lora_model_dropdown:
209
  print("Update lora model")
210
  self.update_lora_model(lora_model_dropdown)
211
-
 
 
212
  if resize_method == "Resize according to Reference":
213
- if start_image is None and validation_video is None:
214
  if is_api:
215
  return "", f"Please upload an image when using \"Resize according to Reference\"."
216
  else:
217
  raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
218
 
219
  aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
220
-
221
- if validation_video is not None:
222
- original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
 
 
223
  else:
224
- original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
225
  closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
226
  height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
227
 
@@ -255,75 +286,91 @@ class CogVideoX_I2VController:
255
  generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
256
 
257
  try:
258
- if self.transformer.config.in_channels != self.vae.config.latent_channels:
259
- if generation_method == "Long Video Generation":
260
- if validation_video is not None:
261
- raise gr.Error(f"Video to Video is not Support Long Video Generation now.")
262
- init_frames = 0
263
- last_frames = init_frames + partial_video_length
264
- while init_frames < length_slider:
265
- if last_frames >= length_slider:
266
- _partial_video_length = length_slider - init_frames
267
- _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
 
 
268
 
269
- if _partial_video_length <= 0:
 
 
270
  break
271
- else:
272
- _partial_video_length = partial_video_length
273
 
274
- if last_frames >= length_slider:
275
- input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
276
- else:
277
- input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
278
-
279
- with torch.no_grad():
280
- sample = self.pipeline(
281
- prompt_textbox,
282
- negative_prompt = negative_prompt_textbox,
283
- num_inference_steps = sample_step_slider,
284
- guidance_scale = cfg_scale_slider,
285
- width = width_slider,
286
- height = height_slider,
287
- num_frames = _partial_video_length,
288
- generator = generator,
289
-
290
- video = input_video,
291
- mask_video = input_video_mask,
292
- strength = 1,
293
- ).videos
294
-
295
- if init_frames != 0:
296
- mix_ratio = torch.from_numpy(
297
- np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
298
- ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
299
-
300
- new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
301
- sample[:, :, :overlap_video_length] * mix_ratio
302
- new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
303
 
304
- sample = new_sample
 
 
 
 
 
305
  else:
306
- new_sample = sample
307
-
308
- if last_frames >= length_slider:
309
- break
310
-
311
- start_image = [
312
- Image.fromarray(
313
- (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
314
- ) for _index in range(-overlap_video_length, 0)
315
- ]
316
-
317
- init_frames = init_frames + _partial_video_length - overlap_video_length
318
- last_frames = init_frames + _partial_video_length
 
 
 
 
319
  else:
320
- if validation_video is not None:
321
- input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
322
- strength = denoise_strength
323
- else:
324
- input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
325
- strength = 1
326
-
327
  sample = self.pipeline(
328
  prompt_textbox,
329
  negative_prompt = negative_prompt_textbox,
@@ -332,13 +379,11 @@ class CogVideoX_I2VController:
332
  width = width_slider,
333
  height = height_slider,
334
  num_frames = length_slider if not is_image else 1,
335
- generator = generator,
336
-
337
- video = input_video,
338
- mask_video = input_video_mask,
339
- strength = strength,
340
  ).videos
341
  else:
 
 
342
  sample = self.pipeline(
343
  prompt_textbox,
344
  negative_prompt = negative_prompt_textbox,
@@ -347,7 +392,9 @@ class CogVideoX_I2VController:
347
  width = width_slider,
348
  height = height_slider,
349
  num_frames = length_slider if not is_image else 1,
350
- generator = generator
 
 
351
  ).videos
352
  except Exception as e:
353
  gc.collect()
@@ -422,7 +469,7 @@ class CogVideoX_I2VController:
422
 
423
 
424
  def ui(low_gpu_memory_mode, weight_dtype):
425
- controller = CogVideoX_I2VController(low_gpu_memory_mode, weight_dtype)
426
 
427
  with gr.Blocks(css=css) as demo:
428
  gr.Markdown(
@@ -437,7 +484,20 @@ def ui(low_gpu_memory_mode, weight_dtype):
437
  with gr.Column(variant="panel"):
438
  gr.Markdown(
439
  """
440
- ### 1. Model checkpoints (模型路径).
 
 
441
  """
442
  )
443
  with gr.Row():
@@ -488,12 +548,12 @@ def ui(low_gpu_memory_mode, weight_dtype):
488
  with gr.Column(variant="panel"):
489
  gr.Markdown(
490
  """
491
- ### 2. Configs for Generation (生成参数配置).
492
  """
493
  )
494
 
495
  prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
496
- negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. " )
497
 
498
  with gr.Row():
499
  with gr.Column():
@@ -522,7 +582,7 @@ def ui(low_gpu_memory_mode, weight_dtype):
522
  partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
523
 
524
  source_method = gr.Radio(
525
- ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"],
526
  value="Text to Video (文本到视频)",
527
  show_label=False,
528
  )
@@ -535,7 +595,7 @@ def ui(low_gpu_memory_mode, weight_dtype):
535
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
536
  def select_template(evt: gr.SelectData):
537
  text = {
538
- "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
539
  "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
540
  "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
541
  "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
@@ -557,13 +617,36 @@ def ui(low_gpu_memory_mode, weight_dtype):
557
  end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
558
 
559
  with gr.Column(visible = False) as video_to_video_col:
560
- validation_video = gr.Video(
561
- label="The video to convert (视频转视频的参考视频)", show_label=True,
562
- elem_id="v2v", sources="upload",
 
 
563
  )
564
- denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=0.95, step=0.01)
565
 
566
- cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)
567
 
568
  with gr.Row():
569
  seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
@@ -585,6 +668,12 @@ def ui(low_gpu_memory_mode, weight_dtype):
585
  interactive=False
586
  )
587
 
 
 
 
 
 
 
588
  def upload_generation_method(generation_method):
589
  if generation_method == "Video Generation":
590
  return [gr.update(visible=True, maximum=49, value=49), gr.update(visible=False), gr.update(visible=False)]
@@ -598,13 +687,18 @@ def ui(low_gpu_memory_mode, weight_dtype):
598
 
599
  def upload_source_method(source_method):
600
  if source_method == "Text to Video (文本到视频)":
601
- return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
602
  elif source_method == "Image to Video (图片到视频)":
603
- return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None)]
 
 
604
  else:
605
- return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update()]
606
  source_method.change(
607
- upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video]
 
 
 
608
  )
609
 
610
  def upload_resize_method(resize_method):
@@ -639,6 +733,8 @@ def ui(low_gpu_memory_mode, weight_dtype):
639
  start_image,
640
  end_image,
641
  validation_video,
 
 
642
  denoise_strength,
643
  seed_textbox,
644
  ],
@@ -647,8 +743,8 @@ def ui(low_gpu_memory_mode, weight_dtype):
647
  return demo, controller
648
 
649
 
650
- class CogVideoX_I2VController_Modelscope:
651
- def __init__(self, model_name, savedir_sample, low_gpu_memory_mode, weight_dtype):
652
  # Basic dir
653
  self.basedir = os.getcwd()
654
  self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
@@ -658,6 +754,7 @@ class CogVideoX_I2VController_Modelscope:
658
  os.makedirs(self.savedir_sample, exist_ok=True)
659
 
660
  # model path
 
661
  self.weight_dtype = weight_dtype
662
 
663
  self.vae = AutoencoderKLCogVideoX.from_pretrained(
@@ -672,16 +769,25 @@ class CogVideoX_I2VController_Modelscope:
672
  ).to(self.weight_dtype)
673
 
674
  # Get pipeline
675
- if self.transformer.config.in_channels != self.vae.config.latent_channels:
676
- self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
677
- model_name,
678
- vae=self.vae,
679
- transformer=self.transformer,
680
- scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
681
- torch_dtype=self.weight_dtype
682
- )
 
 
683
  else:
684
- self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
685
  model_name,
686
  vae=self.vae,
687
  transformer=self.transformer,
@@ -733,6 +839,8 @@ class CogVideoX_I2VController_Modelscope:
733
  start_image,
734
  end_image,
735
  validation_video,
 
 
736
  denoise_strength,
737
  seed_textbox,
738
  is_api = False,
@@ -747,25 +855,48 @@ class CogVideoX_I2VController_Modelscope:
747
  if self.lora_model_path != lora_model_dropdown:
748
  print("Update lora model")
749
  self.update_lora_model(lora_model_dropdown)
 
 
750
 
751
  if resize_method == "Resize according to Reference":
752
- if start_image is None and validation_video is None:
753
- raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
 
 
 
754
 
755
- aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
756
-
757
- if validation_video is not None:
758
- original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
 
 
759
  else:
760
- original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
761
  closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
762
  height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
763
 
764
  if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
765
- raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
766
-
 
 
 
767
  if start_image is None and end_image is not None:
768
- raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
 
 
 
769
 
770
  is_image = True if generation_method == "Image Generation" else False
771
 
@@ -779,13 +910,42 @@ class CogVideoX_I2VController_Modelscope:
779
  generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
780
 
781
  try:
782
- if self.transformer.config.in_channels != self.vae.config.latent_channels:
783
- if validation_video is not None:
784
- input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
785
- strength = denoise_strength
 
 
786
  else:
787
- input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
788
- strength = 1
 
 
789
 
790
  sample = self.pipeline(
791
  prompt_textbox,
@@ -797,20 +957,7 @@ class CogVideoX_I2VController_Modelscope:
797
  num_frames = length_slider if not is_image else 1,
798
  generator = generator,
799
 
800
- video = input_video,
801
- mask_video = input_video_mask,
802
- strength = strength,
803
- ).videos
804
- else:
805
- sample = self.pipeline(
806
- prompt_textbox,
807
- negative_prompt = negative_prompt_textbox,
808
- num_inference_steps = sample_step_slider,
809
- guidance_scale = cfg_scale_slider,
810
- width = width_slider,
811
- height = height_slider,
812
- num_frames = length_slider if not is_image else 1,
813
- generator = generator
814
  ).videos
815
  except Exception as e:
816
  gc.collect()
@@ -866,8 +1013,8 @@ class CogVideoX_I2VController_Modelscope:
866
  return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
867
 
868
 
869
- def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype):
870
- controller = CogVideoX_I2VController_Modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
871
 
872
  with gr.Blocks(css=css) as demo:
873
  gr.Markdown(
@@ -882,7 +1029,20 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
882
  with gr.Column(variant="panel"):
883
  gr.Markdown(
884
  """
885
- ### 1. Model checkpoints (模型路径).
 
 
886
  """
887
  )
888
  with gr.Row():
@@ -919,12 +1079,12 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
919
  with gr.Column(variant="panel"):
920
  gr.Markdown(
921
  """
922
- ### 2. Configs for Generation (生成参数配置).
923
  """
924
  )
925
 
926
  prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
927
- negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. " )
928
 
929
  with gr.Row():
930
  with gr.Column():
@@ -953,7 +1113,7 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
953
  partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
954
 
955
  source_method = gr.Radio(
956
- ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"],
957
  value="Text to Video (文本到视频)",
958
  show_label=False,
959
  )
@@ -964,7 +1124,7 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
964
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
965
  def select_template(evt: gr.SelectData):
966
  text = {
967
- "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
968
  "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
969
  "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
970
  "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
@@ -986,13 +1146,36 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
986
  end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
987
 
988
  with gr.Column(visible = False) as video_to_video_col:
989
- validation_video = gr.Video(
990
- label="The video to convert (视频转视频的参考视频)", show_label=True,
991
- elem_id="v2v", sources="upload",
 
 
992
  )
993
- denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=0.95, step=0.01)
994
 
995
- cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)
996
 
997
  with gr.Row():
998
  seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
@@ -1025,13 +1208,18 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
1025
 
1026
  def upload_source_method(source_method):
1027
  if source_method == "Text to Video (文本到视频)":
1028
- return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
1029
  elif source_method == "Image to Video (图片到视频)":
1030
- return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None)]
 
 
1031
  else:
1032
- return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update()]
1033
  source_method.change(
1034
- upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video]
 
 
 
1035
  )
1036
 
1037
  def upload_resize_method(resize_method):
@@ -1066,6 +1254,8 @@ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
1066
  start_image,
1067
  end_image,
1068
  validation_video,
 
 
1069
  denoise_strength,
1070
  seed_textbox,
1071
  ],
@@ -1080,7 +1270,7 @@ def post_eas(
1080
  prompt_textbox, negative_prompt_textbox,
1081
  sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1082
  base_resolution, generation_method, length_slider, cfg_scale_slider,
1083
- start_image, end_image, validation_video, denoise_strength, seed_textbox,
1084
  ):
1085
  if start_image is not None:
1086
  with open(start_image, 'rb') as file:
@@ -1100,6 +1290,12 @@ def post_eas(
1100
  validation_video_encoded_content = base64.b64encode(file_content)
1101
  validation_video = validation_video_encoded_content.decode('utf-8')
1102
 
 
 
 
 
 
 
1103
  datas = {
1104
  "base_model_path": base_model_dropdown,
1105
  "lora_model_path": lora_model_dropdown,
@@ -1118,6 +1314,7 @@ def post_eas(
1118
  "start_image": start_image,
1119
  "end_image": end_image,
1120
  "validation_video": validation_video,
 
1121
  "denoise_strength": denoise_strength,
1122
  "seed_textbox": seed_textbox,
1123
  }
@@ -1131,7 +1328,7 @@ def post_eas(
1131
  return outputs
1132
 
1133
 
1134
- class CogVideoX_I2VController_EAS:
1135
  def __init__(self, model_name, savedir_sample):
1136
  self.savedir_sample = savedir_sample
1137
  os.makedirs(self.savedir_sample, exist_ok=True)
@@ -1156,6 +1353,7 @@ class CogVideoX_I2VController_EAS:
1156
  start_image,
1157
  end_image,
1158
  validation_video,
 
1159
  denoise_strength,
1160
  seed_textbox
1161
  ):
@@ -1167,7 +1365,7 @@ class CogVideoX_I2VController_EAS:
1167
  prompt_textbox, negative_prompt_textbox,
1168
  sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1169
  base_resolution, generation_method, length_slider, cfg_scale_slider,
1170
- start_image, end_image, validation_video, denoise_strength,
1171
  seed_textbox
1172
  )
1173
  try:
@@ -1201,7 +1399,7 @@ class CogVideoX_I2VController_EAS:
1201
 
1202
 
1203
  def ui_eas(model_name, savedir_sample):
1204
- controller = CogVideoX_I2VController_EAS(model_name, savedir_sample)
1205
 
1206
  with gr.Blocks(css=css) as demo:
1207
  gr.Markdown(
@@ -1216,7 +1414,7 @@ def ui_eas(model_name, savedir_sample):
1216
  with gr.Column(variant="panel"):
1217
  gr.Markdown(
1218
  """
1219
- ### 1. Model checkpoints.
1220
  """
1221
  )
1222
  with gr.Row():
@@ -1258,7 +1456,7 @@ def ui_eas(model_name, savedir_sample):
1258
  )
1259
 
1260
  prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1261
- negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. " )
1262
 
1263
  with gr.Row():
1264
  with gr.Column():
@@ -1295,7 +1493,7 @@ def ui_eas(model_name, savedir_sample):
1295
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1296
  def select_template(evt: gr.SelectData):
1297
  text = {
1298
- "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1299
  "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1300
  "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1301
  "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
@@ -1317,13 +1515,25 @@ def ui_eas(model_name, savedir_sample):
1317
  end_image = gr.Image(label="The image at the ending of the video (Optional)", show_label=True, elem_id="i2v_end", sources="upload", type="filepath")
1318
 
1319
  with gr.Column(visible = False) as video_to_video_col:
1320
- validation_video = gr.Video(
1321
- label="The video to convert (视频转视频的参考视频)", show_label=True,
1322
- elem_id="v2v", sources="upload",
1323
- )
1324
- denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=0.95, step=0.01)
1325
-
1326
- cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)
 
 
1327
 
1328
  with gr.Row():
1329
  seed_textbox = gr.Textbox(label="Seed", value=43)
@@ -1347,7 +1557,7 @@ def ui_eas(model_name, savedir_sample):
1347
 
1348
  def upload_generation_method(generation_method):
1349
  if generation_method == "Video Generation":
1350
- return gr.update(visible=True, minimum=5, maximum=25, value=25, interactive=True)
1351
  elif generation_method == "Image Generation":
1352
  return gr.update(minimum=1, maximum=1, value=1, interactive=False)
1353
  generation_method.change(
@@ -1356,13 +1566,13 @@ def ui_eas(model_name, savedir_sample):
1356
 
1357
  def upload_source_method(source_method):
1358
  if source_method == "Text to Video (文本到视频)":
1359
- return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
1360
  elif source_method == "Image to Video (图片到视频)":
1361
- return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None)]
1362
  else:
1363
- return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update()]
1364
  source_method.change(
1365
- upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video]
1366
  )
1367
 
1368
  def upload_resize_method(resize_method):
@@ -1395,6 +1605,7 @@ def ui_eas(model_name, savedir_sample):
1395
  start_image,
1396
  end_image,
1397
  validation_video,
 
1398
  denoise_strength,
1399
  seed_textbox,
1400
  ],
 
30
  from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
31
  from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
32
  from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
33
+ from cogvideox.pipeline.pipeline_cogvideox_control import \
34
+ CogVideoX_Fun_Pipeline_Control
35
  from cogvideox.pipeline.pipeline_cogvideox_inpaint import \
36
  CogVideoX_Fun_Pipeline_Inpaint
37
  from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
 
60
  }
61
  """
62
 
63
+ class CogVideoX_Fun_Controller:
64
  def __init__(self, low_gpu_memory_mode, weight_dtype):
65
  # config dirs
66
  self.basedir = os.getcwd()
 
70
  self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
71
  self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
72
  self.savedir_sample = os.path.join(self.savedir, "sample")
73
+ self.model_type = "Inpaint"
74
  os.makedirs(self.savedir, exist_ok=True)
75
 
76
  self.diffusion_transformer_list = []
 
105
  personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
106
  self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
107
 
108
+ def update_model_type(self, model_type):
109
+ self.model_type = model_type
110
+
111
  def update_diffusion_transformer(self, diffusion_transformer_dropdown):
112
  print("Update diffusion transformer")
113
  if diffusion_transformer_dropdown == "none":
 
124
  ).to(self.weight_dtype)
125
 
126
  # Get pipeline
127
+ if self.model_type == "Inpaint":
128
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
129
+ self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
130
+ diffusion_transformer_dropdown,
131
+ vae=self.vae,
132
+ transformer=self.transformer,
133
+ scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
134
+ torch_dtype=self.weight_dtype
135
+ )
136
+ else:
137
+ self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
138
+ diffusion_transformer_dropdown,
139
+ vae=self.vae,
140
+ transformer=self.transformer,
141
+ scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
142
+ torch_dtype=self.weight_dtype
143
+ )
144
  else:
145
+ self.pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
146
  diffusion_transformer_dropdown,
147
  vae=self.vae,
148
  transformer=self.transformer,
 
206
  start_image,
207
  end_image,
208
  validation_video,
209
+ validation_video_mask,
210
+ control_video,
211
  denoise_strength,
212
  seed_textbox,
213
  is_api = False,
 
225
  if self.lora_model_path != lora_model_dropdown:
226
  print("Update lora model")
227
  self.update_lora_model(lora_model_dropdown)
228
+
229
+ if control_video is not None and self.model_type == "Inpaint":
230
+ if is_api:
231
+ return "", f"If specifying the control video, please set the model_type == \"Control\". "
232
+ else:
233
+ raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ")
234
+
235
+ if control_video is None and self.model_type == "Control":
236
+ if is_api:
237
+ return "", f"If set the model_type == \"Control\", please specifying the control video. "
238
+ else:
239
+ raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ")
240
+
241
  if resize_method == "Resize according to Reference":
242
+ if start_image is None and validation_video is None and control_video is None:
243
  if is_api:
244
  return "", f"Please upload an image when using \"Resize according to Reference\"."
245
  else:
246
  raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
247
 
248
  aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
249
+ if self.model_type == "Inpaint":
250
+ if validation_video is not None:
251
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
252
+ else:
253
+ original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
254
  else:
255
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
256
  closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
257
  height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
258
 
 
286
  generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
287
 
288
  try:
289
+ if self.model_type == "Inpaint":
290
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
291
+ if generation_method == "Long Video Generation":
292
+ if validation_video is not None:
293
+ raise gr.Error(f"Video to Video is not Support Long Video Generation now.")
294
+ init_frames = 0
295
+ last_frames = init_frames + partial_video_length
296
+ while init_frames < length_slider:
297
+ if last_frames >= length_slider:
298
+ _partial_video_length = length_slider - init_frames
299
+ _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
300
+
301
+ if _partial_video_length <= 0:
302
+ break
303
+ else:
304
+ _partial_video_length = partial_video_length
305
+
306
+ if last_frames >= length_slider:
307
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
308
+ else:
309
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
310
+
311
+ with torch.no_grad():
312
+ sample = self.pipeline(
313
+ prompt_textbox,
314
+ negative_prompt = negative_prompt_textbox,
315
+ num_inference_steps = sample_step_slider,
316
+ guidance_scale = cfg_scale_slider,
317
+ width = width_slider,
318
+ height = height_slider,
319
+ num_frames = _partial_video_length,
320
+ generator = generator,
321
+
322
+ video = input_video,
323
+ mask_video = input_video_mask,
324
+ strength = 1,
325
+ ).videos
326
 
327
+ if init_frames != 0:
328
+ mix_ratio = torch.from_numpy(
329
+ np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
330
+ ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
331
+
332
+ new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
333
+ sample[:, :, :overlap_video_length] * mix_ratio
334
+ new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
335
+
336
+ sample = new_sample
337
+ else:
338
+ new_sample = sample
339
+
340
+ if last_frames >= length_slider:
341
  break
 
 
342
 
343
+ start_image = [
344
+ Image.fromarray(
345
+ (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
346
+ ) for _index in range(-overlap_video_length, 0)
347
+ ]
 
 
348
 
349
+ init_frames = init_frames + _partial_video_length - overlap_video_length
350
+ last_frames = init_frames + _partial_video_length
351
+ else:
352
+ if validation_video is not None:
353
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=8)
354
+ strength = denoise_strength
355
  else:
356
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
357
+ strength = 1
358
+
359
+ sample = self.pipeline(
360
+ prompt_textbox,
361
+ negative_prompt = negative_prompt_textbox,
362
+ num_inference_steps = sample_step_slider,
363
+ guidance_scale = cfg_scale_slider,
364
+ width = width_slider,
365
+ height = height_slider,
366
+ num_frames = length_slider if not is_image else 1,
367
+ generator = generator,
368
+
369
+ video = input_video,
370
+ mask_video = input_video_mask,
371
+ strength = strength,
372
+ ).videos
373
  else:
 
 
 
 
 
 
 
374
  sample = self.pipeline(
375
  prompt_textbox,
376
  negative_prompt = negative_prompt_textbox,
 
379
  width = width_slider,
380
  height = height_slider,
381
  num_frames = length_slider if not is_image else 1,
382
+ generator = generator
 
 
 
 
383
  ).videos
384
  else:
385
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=8)
386
+
387
  sample = self.pipeline(
388
  prompt_textbox,
389
  negative_prompt = negative_prompt_textbox,
 
392
  width = width_slider,
393
  height = height_slider,
394
  num_frames = length_slider if not is_image else 1,
395
+ generator = generator,
396
+
397
+ control_video = input_video,
398
  ).videos
399
  except Exception as e:
400
  gc.collect()
 
469
 
470
 
471
  def ui(low_gpu_memory_mode, weight_dtype):
472
+ controller = CogVideoX_Fun_Controller(low_gpu_memory_mode, weight_dtype)
473
 
474
  with gr.Blocks(css=css) as demo:
475
  gr.Markdown(
 
484
  with gr.Column(variant="panel"):
485
  gr.Markdown(
486
  """
487
+ ### 1. CogVideoX-Fun Model Type (CogVideoX-Fun模型的种类,正常模型还是控制模型).
488
+ """
489
+ )
490
+ with gr.Row():
491
+ model_type = gr.Dropdown(
492
+ label="The model type of CogVideoX-Fun (CogVideoX-Fun模型的种类,正常模型还是控制模型)",
493
+ choices=["Inpaint", "Control"],
494
+ value="Inpaint",
495
+ interactive=True,
496
+ )
497
+
498
+ gr.Markdown(
499
+ """
500
+ ### 2. Model checkpoints (模型路径).
501
  """
502
  )
503
  with gr.Row():
 
548
  with gr.Column(variant="panel"):
549
  gr.Markdown(
550
  """
551
+ ### 3. Configs for Generation (生成参数配置).
552
  """
553
  )
554
 
555
  prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
556
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
557
 
558
  with gr.Row():
559
  with gr.Column():
 
582
  partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
583
 
584
  source_method = gr.Radio(
585
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"],
586
  value="Text to Video (文本到视频)",
587
  show_label=False,
588
  )
 
595
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
596
  def select_template(evt: gr.SelectData):
597
  text = {
598
+ "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
599
  "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
600
  "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
601
  "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
 
617
  end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
618
 
619
  with gr.Column(visible = False) as video_to_video_col:
620
+ with gr.Row():
621
+ validation_video = gr.Video(
622
+ label="The video to convert (视频转视频的参考视频)", show_label=True,
623
+ elem_id="v2v", sources="upload",
624
+ )
625
+ with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
626
+ gr.Markdown(
627
+ """
628
+ - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
629
+ - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
630
+ """
631
+ )
632
+ validation_video_mask = gr.Image(
633
+ label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
634
+ show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
635
+ )
636
+ denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
637
+
638
+ with gr.Column(visible = False) as control_video_col:
639
+ gr.Markdown(
640
+ """
641
+ Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
642
+ """
643
+ )
644
+ control_video = gr.Video(
645
+ label="The control video (用于提供控制信号的video)", show_label=True,
646
+ elem_id="v2v_control", sources="upload",
647
  )
 
648
 
649
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
650
 
651
  with gr.Row():
652
  seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
 
668
  interactive=False
669
  )
670
 
671
+ model_type.change(
672
+ fn=controller.update_model_type,
673
+ inputs=[model_type],
674
+ outputs=[]
675
+ )
676
+
677
  def upload_generation_method(generation_method):
678
  if generation_method == "Video Generation":
679
  return [gr.update(visible=True, maximum=49, value=49), gr.update(visible=False), gr.update(visible=False)]
 
687
 
688
  def upload_source_method(source_method):
689
  if source_method == "Text to Video (文本到视频)":
690
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
691
  elif source_method == "Image to Video (图片到视频)":
692
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
693
+ elif source_method == "Video to Video (视频到视频)":
694
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
695
  else:
696
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
697
  source_method.change(
698
+ upload_source_method, source_method, [
699
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
700
+ validation_video, validation_video_mask, control_video
701
+ ]
702
  )
703
 
704
  def upload_resize_method(resize_method):
 
733
  start_image,
734
  end_image,
735
  validation_video,
736
+ validation_video_mask,
737
+ control_video,
738
  denoise_strength,
739
  seed_textbox,
740
  ],
 
743
  return demo, controller
744
 
745
 
746
+ class CogVideoX_Fun_Controller_Modelscope:
747
+ def __init__(self, model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype):
748
  # Basic dir
749
  self.basedir = os.getcwd()
750
  self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
 
754
  os.makedirs(self.savedir_sample, exist_ok=True)
755
 
756
  # model path
757
+ self.model_type = model_type
758
  self.weight_dtype = weight_dtype
759
 
760
  self.vae = AutoencoderKLCogVideoX.from_pretrained(
 
769
  ).to(self.weight_dtype)
770
 
771
  # Get pipeline
772
+ if model_type == "Inpaint":
773
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
774
+ self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
775
+ model_name,
776
+ vae=self.vae,
777
+ transformer=self.transformer,
778
+ scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
779
+ torch_dtype=self.weight_dtype
780
+ )
781
+ else:
782
+ self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
783
+ model_name,
784
+ vae=self.vae,
785
+ transformer=self.transformer,
786
+ scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
787
+ torch_dtype=self.weight_dtype
788
+ )
789
  else:
790
+ self.pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
791
  model_name,
792
  vae=self.vae,
793
  transformer=self.transformer,
 
839
  start_image,
840
  end_image,
841
  validation_video,
842
+ validation_video_mask,
843
+ control_video,
844
  denoise_strength,
845
  seed_textbox,
846
  is_api = False,
 
855
  if self.lora_model_path != lora_model_dropdown:
856
  print("Update lora model")
857
  self.update_lora_model(lora_model_dropdown)
858
+
859
+ if control_video is not None and self.model_type == "Inpaint":
860
+ if is_api:
861
+ return "", f"If specifying the control video, please set the model_type == \"Control\". "
862
+ else:
863
+ raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ")
864
+
865
+ if control_video is None and self.model_type == "Control":
866
+ if is_api:
867
+ return "", f"If set the model_type == \"Control\", please specifying the control video. "
868
+ else:
869
+ raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ")
870
 
871
  if resize_method == "Resize according to Reference":
872
+ if start_image is None and validation_video is None and control_video is None:
873
+ if is_api:
874
+ return "", f"Please upload an image when using \"Resize according to Reference\"."
875
+ else:
876
+ raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
877
 
878
+ aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
879
+ if self.model_type == "Inpaint":
880
+ if validation_video is not None:
881
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
882
+ else:
883
+ original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
884
  else:
885
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
886
  closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
887
  height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
888
 
889
  if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
890
+ if is_api:
891
+ return "", f"Please select an image to video pretrained model while using image to video."
892
+ else:
893
+ raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
894
+
895
  if start_image is None and end_image is not None:
896
+ if is_api:
897
+ return "", f"If specifying the ending image of the video, please specify a starting image of the video."
898
+ else:
899
+ raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
900
 
901
  is_image = True if generation_method == "Image Generation" else False
902
 
 
910
  generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
911
 
912
  try:
913
+ if self.model_type == "Inpaint":
914
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
915
+ if validation_video is not None:
916
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=8)
917
+ strength = denoise_strength
918
+ else:
919
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
920
+ strength = 1
921
+
922
+ sample = self.pipeline(
923
+ prompt_textbox,
924
+ negative_prompt = negative_prompt_textbox,
925
+ num_inference_steps = sample_step_slider,
926
+ guidance_scale = cfg_scale_slider,
927
+ width = width_slider,
928
+ height = height_slider,
929
+ num_frames = length_slider if not is_image else 1,
930
+ generator = generator,
931
+
932
+ video = input_video,
933
+ mask_video = input_video_mask,
934
+ strength = strength,
935
+ ).videos
936
  else:
937
+ sample = self.pipeline(
938
+ prompt_textbox,
939
+ negative_prompt = negative_prompt_textbox,
940
+ num_inference_steps = sample_step_slider,
941
+ guidance_scale = cfg_scale_slider,
942
+ width = width_slider,
943
+ height = height_slider,
944
+ num_frames = length_slider if not is_image else 1,
945
+ generator = generator
946
+ ).videos
947
+ else:
948
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=8)
949
 
950
  sample = self.pipeline(
951
  prompt_textbox,
 
957
  num_frames = length_slider if not is_image else 1,
958
  generator = generator,
959
 
960
+ control_video = input_video,
961
  ).videos
962
  except Exception as e:
963
  gc.collect()
 
1013
  return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
1014
 
1015
 
1016
+ def ui_modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype):
1017
+ controller = CogVideoX_Fun_Controller_Modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype)
1018
 
1019
  with gr.Blocks(css=css) as demo:
1020
  gr.Markdown(
 
1029
  with gr.Column(variant="panel"):
1030
  gr.Markdown(
1031
  """
1032
+ ### 1. CogVideoX-Fun Model Type (CogVideoX-Fun模型的种类,正常模型还是控制模型).
1033
+ """
1034
+ )
1035
+ with gr.Row():
1036
+ model_type = gr.Dropdown(
1037
+ label="The model type of CogVideoX-Fun (CogVideoX-Fun模型的种类,正常模型还是控制模型)",
1038
+ choices=[model_type],
1039
+ value=model_type,
1040
+ interactive=False,
1041
+ )
1042
+
1043
+ gr.Markdown(
1044
+ """
1045
+ ### 2. Model checkpoints (模型路径).
1046
  """
1047
  )
1048
  with gr.Row():
 
1079
  with gr.Column(variant="panel"):
1080
  gr.Markdown(
1081
  """
1082
+ ### 3. Configs for Generation (生成参数配置).
1083
  """
1084
  )
1085
 
1086
  prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1087
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
1088
 
1089
  with gr.Row():
1090
  with gr.Column():
 
1113
  partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
1114
 
1115
  source_method = gr.Radio(
1116
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"],
1117
  value="Text to Video (文本到视频)",
1118
  show_label=False,
1119
  )
 
1124
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1125
  def select_template(evt: gr.SelectData):
1126
  text = {
1127
+ "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1128
  "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1129
  "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1130
  "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
 
1146
  end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
1147
 
1148
  with gr.Column(visible = False) as video_to_video_col:
1149
+ with gr.Row():
1150
+ validation_video = gr.Video(
1151
+ label="The video to convert (视频转视频的参考视频)", show_label=True,
1152
+ elem_id="v2v", sources="upload",
1153
+ )
1154
+ with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
1155
+ gr.Markdown(
1156
+ """
1157
+ - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
1158
+ - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
1159
+ """
1160
+ )
1161
+ validation_video_mask = gr.Image(
1162
+ label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
1163
+ show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
1164
+ )
1165
+ denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
1166
+
1167
+ with gr.Column(visible = False) as control_video_col:
1168
+ gr.Markdown(
1169
+ """
1170
+ A demo pose control video can be downloaded here: [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
1171
+ """
1172
+ )
1173
+ control_video = gr.Video(
1174
+ label="The control video (用于提供控制信号的video)", show_label=True,
1175
+ elem_id="v2v_control", sources="upload",
1176
  )
 
1177
 
1178
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
1179
 
1180
  with gr.Row():
1181
  seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
 
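As the mask accordion above notes, masked video-to-video works best with a high denoise strength. A hedged sketch of how such a call could be assembled outside the UI, assuming `pipeline` is an already-loaded CogVideoX-Fun inpaint pipeline and using the helper updated in cogvideox/utils/utils.py below; file names and sizes are placeholders:

```python
# Hedged sketch (not the Gradio handler itself) of a masked video-to-video call.
import torch
from cogvideox.utils.utils import get_video_to_video_latent

def masked_video_to_video(pipeline, prompt, video_path, mask_path,
                          sample_size=(384, 672), num_frames=49, seed=43):
    # Build the video tensor and the per-frame mask (white = repaint, dark = keep).
    video, mask, _ = get_video_to_video_latent(
        video_path, num_frames, sample_size=sample_size,
        validation_video_mask=mask_path, fps=8,
    )
    return pipeline(
        prompt,
        num_inference_steps=50,
        guidance_scale=6.0,
        width=sample_size[1],
        height=sample_size[0],
        num_frames=num_frames,
        generator=torch.Generator(device="cuda").manual_seed(seed),
        video=video,
        mask_video=mask,
        strength=1.0,  # the note above recommends a high strength when a mask is used
    ).videos
```

White regions of the grayscale mask are repainted and dark regions are preserved, matching the 240 threshold applied in the helper.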
1208
 
1209
  def upload_source_method(source_method):
1210
  if source_method == "Text to Video (文本到视频)":
1211
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
1212
  elif source_method == "Image to Video (图片到视频)":
1213
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
1214
+ elif source_method == "Video to Video (视频到视频)":
1215
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
1216
  else:
1217
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
1218
  source_method.change(
1219
+ upload_source_method, source_method, [
1220
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
1221
+ validation_video, validation_video_mask, control_video
1222
+ ]
1223
  )
1224
 
1225
  def upload_resize_method(resize_method):
 
1254
  start_image,
1255
  end_image,
1256
  validation_video,
1257
+ validation_video_mask,
1258
+ control_video,
1259
  denoise_strength,
1260
  seed_textbox,
1261
  ],
 
1270
  prompt_textbox, negative_prompt_textbox,
1271
  sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1272
  base_resolution, generation_method, length_slider, cfg_scale_slider,
1273
+ start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox,
1274
  ):
1275
  if start_image is not None:
1276
  with open(start_image, 'rb') as file:
 
1290
  validation_video_encoded_content = base64.b64encode(file_content)
1291
  validation_video = validation_video_encoded_content.decode('utf-8')
1292
 
1293
+ if validation_video_mask is not None:
1294
+ with open(validation_video_mask, 'rb') as file:
1295
+ file_content = file.read()
1296
+ validation_video_mask_encoded_content = base64.b64encode(file_content)
1297
+ validation_video_mask = validation_video_mask_encoded_content.decode('utf-8')
1298
+
1299
  datas = {
1300
  "base_model_path": base_model_dropdown,
1301
  "lora_model_path": lora_model_dropdown,
 
1314
  "start_image": start_image,
1315
  "end_image": end_image,
1316
  "validation_video": validation_video,
1317
+ "validation_video_mask": validation_video_mask,
1318
  "denoise_strength": denoise_strength,
1319
  "seed_textbox": seed_textbox,
1320
  }
 
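For callers hitting the service directly rather than through this Gradio client, the new validation_video_mask field travels the same way as the other media inputs: as a base64-encoded string in the JSON body. A hedged sketch of such a request; the endpoint URL is a placeholder and only fields visible in the payload above are shown:

```python
# Hedged client-side sketch of the payload this UI assembles.
import base64
import requests

SERVICE_URL = "http://<your-deployment>/<infer-forward-route>"  # placeholder

def encode_file(path):
    """Read a local file and return its base64-encoded contents, or None."""
    if path is None:
        return None
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

datas = {
    "prompt_textbox": "a sailboat sailing in rough seas with a dramatic sunset",
    "validation_video": encode_file("input.mp4"),
    "validation_video_mask": encode_file("mask.png"),  # optional grayscale mask
    "denoise_strength": 1.0,  # use a high strength when a mask is supplied
    "seed_textbox": 43,
}
response = requests.post(SERVICE_URL, json=datas, timeout=600)
```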
1328
  return outputs
1329
 
1330
 
1331
+ class CogVideoX_Fun_Controller_EAS:
1332
  def __init__(self, model_name, savedir_sample):
1333
  self.savedir_sample = savedir_sample
1334
  os.makedirs(self.savedir_sample, exist_ok=True)
 
1353
  start_image,
1354
  end_image,
1355
  validation_video,
1356
+ validation_video_mask,
1357
  denoise_strength,
1358
  seed_textbox
1359
  ):
 
1365
  prompt_textbox, negative_prompt_textbox,
1366
  sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1367
  base_resolution, generation_method, length_slider, cfg_scale_slider,
1368
+ start_image, end_image, validation_video, validation_video_mask, denoise_strength,
1369
  seed_textbox
1370
  )
1371
  try:
 
1399
 
1400
 
1401
  def ui_eas(model_name, savedir_sample):
1402
+ controller = CogVideoX_Fun_Controller_EAS(model_name, savedir_sample)
1403
 
1404
  with gr.Blocks(css=css) as demo:
1405
  gr.Markdown(
 
1414
  with gr.Column(variant="panel"):
1415
  gr.Markdown(
1416
  """
1417
+ ### 1. Model checkpoints (模型路径).
1418
  """
1419
  )
1420
  with gr.Row():
 
1456
  )
1457
 
1458
  prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1459
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
1460
 
1461
  with gr.Row():
1462
  with gr.Column():
 
1493
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1494
  def select_template(evt: gr.SelectData):
1495
  text = {
1496
+ "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1497
  "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1498
  "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1499
  "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
 
1515
  end_image = gr.Image(label="The image at the ending of the video (Optional)", show_label=True, elem_id="i2v_end", sources="upload", type="filepath")
1516
 
1517
  with gr.Column(visible = False) as video_to_video_col:
1518
+ with gr.Row():
1519
+ validation_video = gr.Video(
1520
+ label="The video to convert (视频转视频的参考视频)", show_label=True,
1521
+ elem_id="v2v", sources="upload",
1522
+ )
1523
+ with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
1524
+ gr.Markdown(
1525
+ """
1526
+ - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
1527
+ - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
1528
+ """
1529
+ )
1530
+ validation_video_mask = gr.Image(
1531
+ label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
1532
+ show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
1533
+ )
1534
+ denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
1535
+
1536
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
1537
 
1538
  with gr.Row():
1539
  seed_textbox = gr.Textbox(label="Seed", value=43)
 
1557
 
1558
  def upload_generation_method(generation_method):
1559
  if generation_method == "Video Generation":
1560
+ return gr.update(visible=True, minimum=5, maximum=49, value=49, interactive=True)
1561
  elif generation_method == "Image Generation":
1562
  return gr.update(minimum=1, maximum=1, value=1, interactive=False)
1563
  generation_method.change(
 
1566
 
1567
  def upload_source_method(source_method):
1568
  if source_method == "Text to Video (文本到视频)":
1569
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
1570
  elif source_method == "Image to Video (图片到视频)":
1571
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
1572
  else:
1573
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
1574
  source_method.change(
1575
+ upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
1576
  )
1577
 
1578
  def upload_resize_method(resize_method):
 
1605
  start_image,
1606
  end_image,
1607
  validation_video,
1608
+ validation_video_mask,
1609
  denoise_strength,
1610
  seed_textbox,
1611
  ],
cogvideox/utils/utils.py CHANGED
@@ -166,16 +166,27 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide
166
 
167
  return input_video, input_video_mask, clip_image
168
 
169
- def get_video_to_video_latent(input_video_path, video_length, sample_size):
170
- if type(input_video_path) is str:
171
  cap = cv2.VideoCapture(input_video_path)
172
  input_video = []
 
 
 
 
 
 
173
  while True:
174
  ret, frame = cap.read()
175
  if not ret:
176
  break
177
- frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
178
- input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
 
 
 
 
 
179
  cap.release()
180
  else:
181
  input_video = input_video_path
@@ -183,7 +194,15 @@ def get_video_to_video_latent(input_video_path, video_length, sample_size):
183
  input_video = torch.from_numpy(np.array(input_video))[:video_length]
184
  input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
185
 
186
- input_video_mask = torch.zeros_like(input_video[:, :1])
187
- input_video_mask[:, :, :] = 255
 
 
 
 
 
 
 
 
188
 
189
  return input_video, input_video_mask, None
 
166
 
167
  return input_video, input_video_mask, clip_image
168
 
169
+ def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None):
170
+ if isinstance(input_video_path, str):
171
  cap = cv2.VideoCapture(input_video_path)
172
  input_video = []
173
+
174
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
175
+ frame_skip = 1 if fps is None else int(original_fps // fps)
176
+
177
+ frame_count = 0
178
+
179
  while True:
180
  ret, frame = cap.read()
181
  if not ret:
182
  break
183
+
184
+ if frame_count % frame_skip == 0:
185
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
186
+ input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
187
+
188
+ frame_count += 1
189
+
190
  cap.release()
191
  else:
192
  input_video = input_video_path
 
194
  input_video = torch.from_numpy(np.array(input_video))[:video_length]
195
  input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
196
 
197
+ if validation_video_mask is not None:
198
+ validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
199
+ input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
200
+
201
+ input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
202
+ input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
203
+ input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
204
+ else:
205
+ input_video_mask = torch.zeros_like(input_video[:, :1])
206
+ input_video_mask[:, :, :] = 255
207
 
208
  return input_video, input_video_mask, None
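The reworked helper approximates the requested fps by keeping every frame_skip-th frame, with frame_skip = int(original_fps // fps) (a 24 fps source with fps=8 keeps every third frame), and binarizes an optional grayscale mask at a threshold of 240: pixels at or above 240 become 255, the same value the default all-255 mask uses for regions to regenerate. Note that this assumes the source fps is at least the requested fps; otherwise frame_skip would be 0. A minimal usage sketch, with placeholder file names:

```python
# Minimal sketch of calling the updated helper; file names are placeholders.
from cogvideox.utils.utils import get_video_to_video_latent

# Keep every 3rd frame of a 24 fps clip to approximate 8 fps:
#   frame_skip = int(24 // 8) = 3
video, mask, _ = get_video_to_video_latent(
    "pose.mp4",
    49,                               # video_length: keep at most 49 frames
    sample_size=(384, 672),           # (height, width), multiples of 16
    fps=8,                            # resample the source to roughly 8 fps
    validation_video_mask="mask.png", # optional; pixels >= 240 become 255 (repaint)
)

print(video.shape)  # (1, 3, num_frames, 384, 672), values scaled to [0, 1]
print(mask.shape)   # (1, 1, num_frames, 384, 672), values 0 or 255
```

The returned video tensor is scaled to [0, 1], while the mask keeps its 0/255 values and is tiled across all frames.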