akhaliq HF staff commited on
Commit
6731de0
1 Parent(s): f32e7d8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import tempfile

import gradio as gr
import torch
from PIL import Image
from diffusers.utils import load_image, export_to_video
from huggingface_hub import snapshot_download

from pyramid_dit import PyramidDiTForVideoGeneration
# --- Model download and initialization (runs once at import time) ---
# Fetch the pretrained Pyramid Flow SD3 weights on first run; subsequent runs
# reuse the local copy.
model_path = 'pyramid_flow_model'
if not os.path.exists(model_path):
    snapshot_download("rain1011/pyramid-flow-sd3", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')

# Pin everything to GPU 0; requires a CUDA device to be present.
torch.cuda.set_device(0)
# torch_dtype is used by the autocast contexts in the handlers below.
model_dtype, torch_dtype = 'bf16', torch.bfloat16

model = PyramidDiTForVideoGeneration(
    model_path,
    model_dtype,
    model_variant='diffusion_transformer_768p',  # 768p transformer variant
)

# Move the pipeline's submodules to the GPU.
model.vae.to("cuda")
model.dit.to("cuda")
model.text_encoder.to("cuda")
# Tile VAE decoding to reduce peak GPU memory at 768p resolution.
model.vae.enable_tiling()
def generate_video(prompt, height, width, duration, guidance_scale, video_guidance_scale):
    """Generate a text-to-video clip and return the path to the saved MP4.

    Args:
        prompt: Text description of the video to generate.
        height: Output height in pixels (UI offers 384 or 768).
        width: Output width in pixels (UI offers 640 or 1280).
        duration: "5s" or "10s"; mapped to 16 or 31 temporal latent frames.
        guidance_scale: Classifier-free guidance scale for generation.
        video_guidance_scale: Guidance scale for the video stages.

    Returns:
        Filesystem path of the exported MP4 (24 fps).
    """
    # "5s" maps to 16 latent frames; any other choice (the "10s" option) to 31.
    temp = 16 if duration == "5s" else 31

    with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
        frames = model.generate(
            prompt=prompt,
            num_inference_steps=[20, 20, 20],
            video_num_inference_steps=[10, 10, 10],
            height=height,
            width=width,
            temp=temp,
            guidance_scale=guidance_scale,
            video_guidance_scale=video_guidance_scale,
            output_type="pil",
        )

    # Fix: write to a unique temp file rather than a fixed "generated_video.mp4".
    # With a fixed path, concurrent Gradio requests overwrite each other's output
    # and can serve one user's video to another.
    output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    export_to_video(frames, output_path, fps=24)
    return output_path
def generate_video_from_image(image, prompt, video_guidance_scale):
    """Generate a video conditioned on an input image and return the MP4 path.

    Args:
        image: PIL image used as the first frame / conditioning input.
        prompt: Optional text prompt guiding the motion/content.
        video_guidance_scale: Guidance scale for the video stages.

    Returns:
        Filesystem path of the exported MP4 (24 fps).
    """
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
        frames = model.generate_i2v(
            prompt=prompt,
            input_image=image,
            num_inference_steps=[10, 10, 10],
            temp=16,  # fixed 5-second clip for image-to-video
            video_guidance_scale=video_guidance_scale,
            output_type="pil",
        )

    # Fix: unique temp file instead of a fixed "generated_video_from_image.mp4",
    # so concurrent Gradio requests cannot clobber each other's output
    # (same fix as generate_video).
    output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    export_to_video(frames, output_path, fps=24)
    return output_path
# --- Gradio interface: two tabs wired to the handlers defined above ---
with gr.Blocks() as demo:
    gr.Markdown("# Pyramid Flow Video Generation Demo")

    # Text-to-Video tab
    with gr.Tab("Text-to-Video"):
        with gr.Row():
            with gr.Column():
                txt_prompt = gr.Textbox(label="Prompt")
                # step equals the slider range, so only the two supported
                # resolutions are selectable: 384x640 or 768x1280.
                txt_height = gr.Slider(384, 768, value=768, step=384, label="Height")
                txt_width = gr.Slider(640, 1280, value=1280, step=640, label="Width")
                txt_duration = gr.Radio(["5s", "10s"], value="5s", label="Duration")
                txt_guidance_scale = gr.Slider(1, 15, value=9, step=0.1, label="Guidance Scale")
                txt_video_guidance_scale = gr.Slider(1, 15, value=5, step=0.1, label="Video Guidance Scale")
                txt_generate = gr.Button("Generate Video")
            with gr.Column():
                txt_output = gr.Video(label="Generated Video")

    # Image-to-Video tab
    with gr.Tab("Image-to-Video"):
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="pil", label="Input Image")
                img_prompt = gr.Textbox(label="Prompt (optional)")
                img_video_guidance_scale = gr.Slider(1, 15, value=4, step=0.1, label="Video Guidance Scale")
                img_generate = gr.Button("Generate Video")
            with gr.Column():
                img_output = gr.Video(label="Generated Video")

    # Button wiring: each click runs the matching handler and shows the
    # returned MP4 path in the tab's Video component.
    txt_generate.click(generate_video,
                       inputs=[txt_prompt, txt_height, txt_width, txt_duration, txt_guidance_scale, txt_video_guidance_scale],
                       outputs=txt_output)

    img_generate.click(generate_video_from_image,
                       inputs=[img_input, img_prompt, img_video_guidance_scale],
                       outputs=img_output)

demo.launch()