diff --git a/.gitignore b/.gitignore index 53b8998..0711e04 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ preprocess results results_img sample_videos - +outputs __pycache__/ *.py[cod] diff --git a/environment.yml b/environment.yml index f4a4ad2..84987a8 100644 --- a/environment.yml +++ b/environment.yml @@ -14,7 +14,10 @@ dependencies: - tensorboard - einops - transformers + - bitsandbytes + - wandb - av + - opencv-python - scikit-image - decord - pandas @@ -23,3 +26,6 @@ dependencies: - beautifulsoup4 - ftfy - omegaconf + - gradio + - spaces + - uuid diff --git a/experiments.ipynb b/experiments.ipynb new file mode 100644 index 0000000..3515766 --- /dev/null +++ b/experiments.ipynb @@ -0,0 +1,261 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# !git clone https://github.com/AppimateSA/Latte.git\n", + "# %cd Latte\n", + "# !git checkout luthando-contribution" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Install Correct Modules" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install git+https://github.com/huggingface/diffusers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Import" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/luthando/miniconda3/envs/latte/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import imageio\n", + "import torch\n", + "from torchvision.utils import save_image\n", + "from diffusers import LattePipeline\n", + "from diffusers.models import AutoencoderKLTemporalDecoder\n", + "\n", + "\n", + "torch.manual_seed(0)\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# !python -m pip uninstall diffusers -y && conda uninstall diffusers -y\n", + "# !conda clean -ay\n", + "# !python -m pip cache purge" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Normal Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# video_length = 16 # 1 (text-to-image) or 16 (text-to-video)\n", + "# pipe = LattePipeline.from_pretrained(\"maxin-cn/Latte-1\", torch_dtype=torch.float16).to(device)\n", + "\n", + "# # Using temporal decoder of VAE\n", + "# vae = AutoencoderKLTemporalDecoder.from_pretrained(\"maxin-cn/Latte-1\", subfolder=\"vae_temporal_decoder\", torch_dtype=torch.float16).to(device)\n", + "# pipe.vae = vae" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# prompt = \"a cat wearing sunglasses and working as a lifeguard at pool.\"\n", + "# videos = pipe(prompt, video_length=video_length, output_type='pt').frames.cpu()\n", + "\n", + "# if video_length > 1:\n", + "# videos = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8) # convert to uint8\n", + "# imageio.mimwrite('./latte_output.mp4', videos[0].permute(0, 2, 3, 1), fps=8, quality=5) # highest quality is 10, lowest 
is 0\n", + "# else:\n", + "# save_image(videos[0], './latte_output.png')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Inference with 4/8-bit quantization" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading shards: 100%|██████████| 4/4 [00:00<00:00, 17119.61it/s]\n", + "Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00, 1.82it/s]\n", + "Loading pipeline components...: 25%|██▌ | 1/4 [00:00<00:00, 7.50it/s]Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", + "Loading pipeline components...: 100%|██████████| 4/4 [00:00<00:00, 21.44it/s]\n" + ] + } + ], + "source": [ + "import gc\n", + "from transformers import T5EncoderModel, BitsAndBytesConfig\n", + "\n", + "\n", + "torch.manual_seed(0)\n", + "\n", + "def flush():\n", + " gc.collect()\n", + " torch.cuda.empty_cache()\n", + "\n", + "def bytes_to_giga_bytes(bytes):\n", + " return bytes / 1024 / 1024 / 1024\n", + "\n", + "video_length = 16\n", + "model_id = \"maxin-cn/Latte-1\"\n", + "\n", + "text_encoder = T5EncoderModel.from_pretrained(\n", + " model_id,\n", + " subfolder=\"text_encoder\",\n", + " quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),\n", + " device_map=\"auto\",\n", + ")\n", + "pipe = LattePipeline.from_pretrained(\n", + " model_id, \n", + " text_encoder=text_encoder,\n", + " transformer=None,\n", + " device_map=\"balanced\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading pipeline components...: 0%| | 0/4 [00:00 1:\n", + " videos_uint8 = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8) # convert to uint8\n", + " imageio.mimwrite(f\"{outputs_folder}latte_output_3.mp4\", videos_uint8[0].permute(0, 2, 3, 1), fps=8, quality=5) # highest quality is 10, lowest is 0\n", + "else:\n", + " save_image(videos[0], f\"{outputs_folder}latte_output_3.png\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "latte", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/gradio/app.py b/gradio/app.py new file mode 100644 index 0000000..41d59d1 --- /dev/null +++ b/gradio/app.py @@ -0,0 +1,111 @@ +import os +import sys +from types import SimpleNamespace +from huggingface_hub import snapshot_download +import gradio as gr +import spaces + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../scripts'))) +from inference import inference_function, video_to_base64 + + + + + + +@spaces.GPU(duration=200) +def run_inference(prompt_text, visual_prompt=None, is_running_in_api=None): + model_id = "maxin-cn/Latte-1" + + # negative_prompt = "watermark+++, text, shutterstock text, shutterstock++, blurry, ugly, username, url, low resolution, low quality" + negative_prompt = None + args = { + "model": model_id, + "prompt": prompt_text, + "negative_prompt": negative_prompt, + "num_frames": 16, + "num_steps": 50, + # "width": 256, + # "height": 256, + "visual_prompt": visual_prompt, + "device": 'cuda', + "quantize": 
True, + "fps": 4, + "output_dir": "./outputs", + } + + print("is_running_in_api: ", is_running_in_api) + responseFile = inference_function(SimpleNamespace(**args)) + print(model_id, "Produces -> ", responseFile) + if is_running_in_api == "True": + base64_file = video_to_base64(src_path=responseFile, delete_src=True) + filename = responseFile.split("/")[-1] + return {"base64": base64_file, "format": "video/mp4", "filename": filename} + else: + return responseFile + + +def main(): + with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + gr.HTML( + """ +
+                    <div style="text-align: center;">
+                        <h2>
+                            Latte: Efficient High Quality Video Production.
+                        </h2>
+                    </div>
+ """ + ) + + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(scale=0.999): + prompt_source = gr.Dropdown(label="From API", choices=["False", "True"], value="False", interactive=False) + gr.Markdown("## Text Prompt (T2V)") + prompt_text = gr.Textbox(label="Text", placeholder="Enter prompt text here", lines=1) + + gr.Markdown("## Visual Prompt (I2V)") + visual_prompt = gr.Image(label="Image (optional)", show_download_button=False) + submit_button = gr.Button("Run Inference") + + with gr.Column(): + output_video = gr.Video(label="Output Video", height="100%", autoplay=True, show_download_button=True) + + submit_button.click( + fn=run_inference, + inputs=[ + prompt_text, + visual_prompt, + prompt_source, + ], + outputs=output_video + ) + gr.Examples( + examples=[ + [ "A cat wearing sunglasses and working as a lifeguard at pool.", None ], + [ "A cat, made of wooden blocks, wearing sunglasses and working as a lifeguard at pool.", "gradio/examples/1/wooden_blocks.jpg" ], + [ "A car driving fast on the Eastern beach in East London.", None ], + [ "A car, made of flowers, driving fast on the Eastern beach in East London.", "gradio/examples/1/rose.jpg" ], + ], + fn=run_inference, + inputs=[ + prompt_text, + visual_prompt, + prompt_source, + ], + outputs=[output_video], + cache_examples=True, + ) + + + # launch + demo.queue(max_size=5, default_concurrency_limit=1) + demo.launch(debug=True, share=True, max_threads=1) + +if __name__ == "__main__": + main() + +# python gradio/app.py \ No newline at end of file diff --git a/gradio/examples/1/rose.jpg b/gradio/examples/1/rose.jpg new file mode 100644 index 0000000..f99cbf1 Binary files /dev/null and b/gradio/examples/1/rose.jpg differ diff --git a/gradio/examples/1/wooden_blocks.jpg b/gradio/examples/1/wooden_blocks.jpg new file mode 100644 index 0000000..6e0380a Binary files /dev/null and b/gradio/examples/1/wooden_blocks.jpg differ diff --git a/gradio_cached_examples/18/Output Video/682ac195c30a217a19b8/latte_output_b421.mp4 b/gradio_cached_examples/18/Output Video/682ac195c30a217a19b8/latte_output_b421.mp4 new file mode 100644 index 0000000..b9e4a75 Binary files /dev/null and b/gradio_cached_examples/18/Output Video/682ac195c30a217a19b8/latte_output_b421.mp4 differ diff --git a/gradio_cached_examples/18/log.csv b/gradio_cached_examples/18/log.csv new file mode 100644 index 0000000..a22f48d --- /dev/null +++ b/gradio_cached_examples/18/log.csv @@ -0,0 +1,2 @@ +Output Video,flag,username,timestamp +"{""video"": {""path"": ""gradio_cached_examples/18/Output Video/682ac195c30a217a19b8/latte_output_b421.mp4"", ""url"": ""/file=/tmp/gradio/9905b30c9a779f4c4cf2b6a2ce137dbfded8e32a88f609b5205cb420a69edaa9/latte_output_b421.mp4"", ""size"": null, ""orig_name"": ""latte_output_$b421.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-08-18 16:34:17.664024 diff --git a/scripts/inference.py b/scripts/inference.py new file mode 100644 index 0000000..dec21f2 --- /dev/null +++ b/scripts/inference.py @@ -0,0 +1,205 @@ +import os +import io +import platform +import re +import gc +import argparse +import warnings +from pathlib import Path +from typing import List, Optional +from uuid import uuid4 +import av +import cv2 +import uuid +import imageio +import base64 +import shutil + +import numpy as np +import torch +import torchvision +from PIL import Image +import decord +from transformers import T5EncoderModel, BitsAndBytesConfig +from diffusers import LattePipeline + + +def 
flush(): + gc.collect() + torch.cuda.empty_cache() + +def bytes_to_giga_bytes(bytes): + return bytes / 1024 / 1024 / 1024 + +def delete_file(file_path): + if os.path.exists(file_path): + os.remove(file_path) + print("File deleted successfully") + else: + print("File not found.") + +def base64_to_video(base64_string=""): + """Converts a Base64 image string to a PyTorch tensor""" + img_data = base64.b64decode(base64_string) + img = Image.open(io.BytesIO(img_data)) + img_array = np.array(img) + + # Convert to tensor - PyTorch generally expects channels first (e.g., (C, H, W)) + file_output_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float() + return file_output_tensor + +def video_to_base64(src_path="", delete_src=False): + try: + with open(src_path, "rb") as video_file: + video_data = video_file.read() + base64_encoded_data = base64.b64encode(video_data) + if delete_src: + delete_file(src_path) + return base64_encoded_data.decode('utf-8') + except: + return None + +def video_to_tensor(video_path="", output_format="TCHW"): + video_tensor = None + video_tensor, _, info = torchvision.io.read_video(video_path, pts_unit = "sec", output_format=output_format) + return video_tensor + +def tensor_to_video(video_tensor, output_filename="mp_video.mp4", fps=1): + if video_tensor.shape[-1] > 3: + video_tensor = video_tensor.permute(0, 2, 3, 1) + + height, width = video_tensor.shape[1], video_tensor.shape[2] + container = av.open(output_filename, mode='w', format='mp4') + stream = container.add_stream('libx264', rate=fps) # Common codec (H.264) + stream.width = width + stream.height = height + stream.pix_fmt = 'yuv420p' # Set a suitable pixel format + for i in range(video_tensor.shape[0]): + # frame = video_tensor[i].permute(1, 2, 0).numpy() # PyTorch => OpenCV compatible + frame = video_tensor[i].numpy().astype('uint8') # PyTorch => OpenCV compatible + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) # OpenCV usually expects BGR + + av_frame = av.VideoFrame.from_ndarray(frame, format='bgr24') + for packet in stream.encode(av_frame): + container.mux(packet) + + # Flush stream + for packet in stream.encode(): + container.mux(packet) + container.close() + return output_filename + +# def to_video(fn: str, frames: list[np.ndarray], fps: int): +def to_video(fn: str, frames: any, fps: int) -> str: + writer = imageio.get_writer(fn, format='FFMPEG', fps=fps) + for frame in frames: + writer.append_data(frame) + writer.close() + return fn + +def initialize_pipeline( + model_id: str, + device: str = "cuda", + load_in_4bit: int = True, +): + text_encoder = None + if load_in_4bit: + text_encoder = T5EncoderModel.from_pretrained( + model_id, + subfolder="text_encoder", + quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16), + device_map="auto", + ) + pipe = LattePipeline.from_pretrained( + model_id, + text_encoder=text_encoder, + transformer=None, + device_map="balanced", + ) + else: + pipe = LattePipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device) + + # Using temporal decoder of VAE + vae = AutoencoderKLTemporalDecoder.from_pretrained(model_id, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device) + pipe.vae = vae + + return pipe, text_encoder + + + +def inference_function( args ): + decord.bridge.set_bridge("torch") + flush() + + args.model = args.model if hasattr(args, 'model') else "maxin-cn/Latte-1" + args.prompt = args.prompt if hasattr(args, 'prompt') else None + args.negative_prompt = args.negative_prompt if 
hasattr(args, 'negative_prompt') else None
+    args.num_steps = args.num_steps if hasattr(args, 'num_steps') else 16
+    args.num_frames = args.num_frames if hasattr(args, 'num_frames') else 16
+    args.fps = args.fps if hasattr(args, 'fps') else 4
+    args.quantize = args.quantize if hasattr(args, 'quantize') else True
+    args.device = args.device if hasattr(args, 'device') else "cuda"
+    args.seed = args.seed if hasattr(args, 'seed') else 0
+    args.output_dir = args.output_dir if hasattr(args, 'output_dir') else "./outputs"
+
+
+    # =========================================
+    # ====== validate and prepare inputs ======
+    # =========================================
+
+    torch.manual_seed(args.seed)
+    pipe, text_encoder = initialize_pipeline(model_id=args.model, device=args.device, load_in_4bit=args.quantize)
+    if args.quantize is not True:
+        videos = pipe(args.prompt, video_length=args.num_frames, output_type='pt').frames.cpu()
+    else:
+        with torch.no_grad():
+            neg_prompt = args.negative_prompt if args.negative_prompt is not None else ""
+            prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(args.prompt, negative_prompt=neg_prompt)
+
+        del text_encoder
+        del pipe
+        flush()
+
+        pipe = LattePipeline.from_pretrained(
+            args.model,
+            text_encoder=None,
+            torch_dtype=torch.float16,
+        ).to(args.device)
+
+        videos = pipe(
+            video_length=args.num_frames,
+            negative_prompt=args.negative_prompt,
+            num_inference_steps=args.num_steps,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            output_type="pt",
+        ).frames.cpu()
+    print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
+
+
+    output_path = f"{args.output_dir}/latte_output_{uuid.uuid4().hex[:4]}"
+    if args.num_frames > 1:
+        output_path = f"{output_path}.mp4"
+        videos_uint8 = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8)  # convert to uint8
+        imageio.mimwrite(output_path, videos_uint8[0].permute(0, 2, 3, 1), fps=args.fps, quality=5)  # highest quality is 10, lowest is 0
+    else:
+        output_path = f"{output_path}.png"
+        torchvision.utils.save_image(videos[0], output_path)
+
+    return output_path
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-m", "--model", type=str, required=True, help="The Model name or Model id from repository.")
+    parser.add_argument("-p", "--prompt", type=str, required=True, help="Text prompt to condition on.")
+    parser.add_argument("-n", "--negative-prompt", type=str, default=None, help="Text prompt to condition against.")
+    parser.add_argument("-s", "--num-steps", type=int, default=25, help="Number of diffusion steps to run per frame.")
+    parser.add_argument("-T", "--num-frames", type=int, default=16, help="Total number of frames to generate.")
+    parser.add_argument("-f", "--fps", type=int, default=4, help="FPS of output video.")
+    parser.add_argument("-d", "--device", type=str, default="cuda", help="Device to run inference on (defaults to cuda).")
+    parser.add_argument("-q", "--quantize", type=bool, default=True, help="Whether to run the quantized version of the Model.")
+    parser.add_argument("-r", "--seed", type=int, default=0, help="Random seed to make generations reproducible.")
+    parser.add_argument("-o", "--output-dir", type=str, default="./outputs", help="Directory to save output video to.")
+    args = parser.parse_args()
+
+    inference_function(args)
\ No newline at end of file
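Usage sketch (illustrative only): the new inference_function can also be driven directly from Python, mirroring the SimpleNamespace that gradio/app.py builds. The import path scripts.inference is an assumption here (the Gradio app instead adds ../scripts to sys.path); it further assumes the repo root is the working directory and a CUDA GPU with enough memory is available.

    import os
    from types import SimpleNamespace

    # Assumed import path; gradio/app.py uses sys.path.insert(...) + "from inference import inference_function" instead.
    from scripts.inference import inference_function

    # inference_function writes into output_dir but does not create it, so make sure it exists.
    os.makedirs("./outputs", exist_ok=True)

    args = SimpleNamespace(
        model="maxin-cn/Latte-1",
        prompt="a cat wearing sunglasses and working as a lifeguard at pool.",
        negative_prompt=None,
        num_frames=16,      # 1 saves a PNG, >1 saves an MP4
        num_steps=50,
        fps=4,
        quantize=True,      # 4-bit text encoder via bitsandbytes, as in the notebook
        device="cuda",
        seed=0,
        output_dir="./outputs",
    )

    video_path = inference_function(args)
    print(video_path)  # e.g. ./outputs/latte_output_xxxx.mp4

The argparse block at the end of scripts/inference.py exposes the same parameters on the command line.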