{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import torch\n",
    "import numpy as np\n",
    "import PIL\n",
    "from PIL import Image\n",
    "from IPython.display import HTML, display\n",
    "from pyramid_dit import PyramidDiTForVideoGeneration\n",
    "from IPython.display import Image as ipython_image\n",
    "from diffusers.utils import load_image, export_to_video, export_to_gif"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "variant = 'diffusion_transformer_768p'       # For high resolution\n",
    "# variant = 'diffusion_transformer_384p'     # For low resolution\n",
    "\n",
    "# The downloaded checkpoint dir; set PYRAMID_FLOW_MODEL_PATH to override the default below.\n",
    "model_path = os.environ.get('PYRAMID_FLOW_MODEL_PATH', '/home/jinyang06/models/pyramid-flow')\n",
    "model_dtype = 'bf16'\n",
    "\n",
    "device_id = 0\n",
    "torch.cuda.set_device(device_id)\n",
    "\n",
    "model = PyramidDiTForVideoGeneration(\n",
    "    model_path,\n",
    "    model_dtype,\n",
    "    model_variant=variant,\n",
    ")\n",
    "\n",
    "model.vae.to('cuda')\n",
    "model.dit.to('cuda')\n",
    "model.text_encoder.to('cuda')\n",
    "\n",
    "# Enabled once here so both generation cells below run with the same VAE setting,\n",
    "# regardless of which of them is executed first.\n",
    "model.vae.enable_tiling()\n",
    "\n",
    "# Autocast dtype for model_dtype; any value other than 'bf16'/'fp16' falls back to float32.\n",
    "torch_dtype = {'bf16': torch.bfloat16, 'fp16': torch.float16}.get(model_dtype, torch.float32)\n",
    "\n",
    "\n",
    "def show_video(ori_path, rec_path, width='100%'):\n",
    "    \"\"\"Build an inline HTML player for one or two video files.\n",
    "\n",
    "    Args:\n",
    "        ori_path: Path of a video to show first, or None to omit it.\n",
    "        rec_path: Path of the video that is always shown.\n",
    "        width: Value written to the width attribute of each <video> element.\n",
    "\n",
    "    Returns:\n",
    "        An IPython.display.HTML object holding one <video> element per path.\n",
    "    \"\"\"\n",
    "    from html import escape  # local import: only needed to quote attribute values\n",
    "\n",
    "    def video_tag(path):\n",
    "        # One self-contained player; values are escaped so quotes in a path cannot break the markup.\n",
    "        return (\n",
    "            f'<video controls name=\"media\" width=\"{escape(str(width))}\">'\n",
    "            f'<source src=\"{escape(str(path))}\" type=\"video/mp4\">'\n",
    "            '</video>\\n'\n",
    "        )\n",
    "\n",
    "    paths = [rec_path] if ori_path is None else [ori_path, rec_path]\n",
    "    return HTML(''.join(video_tag(path) for path in paths))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Text-to-Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors\"\n",
    "\n",
    "# used for 384p model variant\n",
    "# width = 640\n",
    "# height = 384\n",
    "\n",
    "# used for 768p model variant\n",
    "width = 1280\n",
    "height = 768\n",
    "\n",
    "temp = 16   # temp in [1, 31] <=> frame in [1, 241] <=> duration in [0, 10s]\n",
    "\n",
    "# Autocast is switched off only when the model was loaded as fp32.\n",
    "with torch.no_grad(), torch.autocast('cuda', enabled=model_dtype != 'fp32', dtype=torch_dtype):\n",
    "    frames = model.generate(\n",
    "        prompt=prompt,\n",
    "        num_inference_steps=[20, 20, 20],\n",
    "        video_num_inference_steps=[10, 10, 10],\n",
    "        height=height,\n",
    "        width=width,\n",
    "        temp=temp,\n",
    "        guidance_scale=9.0,          # The guidance for the first frame\n",
    "        video_guidance_scale=5.0,    # The guidance for the other video latent\n",
    "        output_type=\"pil\",\n",
    "    )\n",
    "\n",
    "export_to_video(frames, \"./text_to_video_sample.mp4\", fps=24)\n",
    "show_video(None, \"./text_to_video_sample.mp4\", \"70%\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Image-to-Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_path = 'assets/the_great_wall.jpg'\n",
    "\n",
    "width = 1280\n",
    "height = 768\n",
    "temp = 16\n",
    "\n",
    "# The context manager closes the file once the RGB copy has been made and resized.\n",
    "with Image.open(image_path) as source_image:\n",
    "    image = source_image.convert(\"RGB\").resize((width, height))\n",
    "\n",
    "display(image)\n",
    "\n",
    "prompt = \"FPV flying over the Great Wall\"\n",
    "\n",
    "# Autocast is switched off only when the model was loaded as fp32.\n",
    "with torch.no_grad(), torch.autocast('cuda', enabled=model_dtype != 'fp32', dtype=torch_dtype):\n",
    "    frames = model.generate_i2v(\n",
    "        prompt=prompt,\n",
    "        input_image=image,\n",
    "        num_inference_steps=[10, 10, 10],\n",
    "        temp=temp,\n",
    "        guidance_scale=7.0,\n",
    "        video_guidance_scale=4.0,\n",
    "        output_type=\"pil\",\n",
    "    )\n",
    "\n",
    "export_to_video(frames, \"./image_to_video_sample.mp4\", fps=24)\n",
    "show_video(None, \"./image_to_video_sample.mp4\", \"70%\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}