diff --git a/.gitignore b/.gitignore
index 53b8998..0711e04 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,7 +3,7 @@ preprocess
results
results_img
sample_videos
-
+outputs
__pycache__/
*.py[cod]
diff --git a/environment.yml b/environment.yml
index f4a4ad2..84987a8 100644
--- a/environment.yml
+++ b/environment.yml
@@ -14,7 +14,10 @@ dependencies:
- tensorboard
- einops
- transformers
+ - bitsandbytes
+ - wandb
- av
+ - opencv-python
- scikit-image
- decord
- pandas
@@ -23,3 +26,5 @@ dependencies:
- beautifulsoup4
- ftfy
- omegaconf
+ - gradio
+ - spaces
diff --git a/experiments.ipynb b/experiments.ipynb
new file mode 100644
index 0000000..3515766
--- /dev/null
+++ b/experiments.ipynb
@@ -0,0 +1,271 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "# Latte text-to-video experiments\n",
+    "\n",
+    "Runs [Latte-1](https://huggingface.co/maxin-cn/Latte-1) through the diffusers `LattePipeline`: first plain fp16 inference, then a two-stage variant that quantizes the T5 text encoder to 4-bit to cut peak GPU memory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !git clone https://github.com/AppimateSA/Latte.git\n",
+ "# %cd Latte\n",
+ "# !git checkout luthando-contribution"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "# Install Required Packages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !pip install git+https://github.com/huggingface/diffusers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "# Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/luthando/miniconda3/envs/latte/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "import imageio\n",
+ "import torch\n",
+ "from torchvision.utils import save_image\n",
+ "from diffusers import LattePipeline\n",
+ "from diffusers.models import AutoencoderKLTemporalDecoder\n",
+ "\n",
+ "\n",
+ "torch.manual_seed(0)\n",
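+    "# Prefer the GPU; the fp16 pipelines below effectively require CUDA\n",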
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !python -m pip uninstall diffusers -y && conda uninstall diffusers -y\n",
+ "# !conda clean -ay\n",
+ "# !python -m pip cache purge"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Normal Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# video_length = 16 # 1 (text-to-image) or 16 (text-to-video)\n",
+ "# pipe = LattePipeline.from_pretrained(\"maxin-cn/Latte-1\", torch_dtype=torch.float16).to(device)\n",
+ "\n",
+ "# # Using temporal decoder of VAE\n",
+ "# vae = AutoencoderKLTemporalDecoder.from_pretrained(\"maxin-cn/Latte-1\", subfolder=\"vae_temporal_decoder\", torch_dtype=torch.float16).to(device)\n",
+ "# pipe.vae = vae"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# prompt = \"a cat wearing sunglasses and working as a lifeguard at a pool.\"\n",
+ "# videos = pipe(prompt, video_length=video_length, output_type='pt').frames.cpu()\n",
+ "\n",
+ "# if video_length > 1:\n",
+ "# videos = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8) # convert to uint8\n",
+ "# imageio.mimwrite('./latte_output.mp4', videos[0].permute(0, 2, 3, 1), fps=8, quality=5) # highest quality is 10, lowest is 0\n",
+ "# else:\n",
+ "# save_image(videos[0], './latte_output.png')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "# Inference with 4/8-bit quantization\n",
+    "\n",
+    "Two stages: encode the prompt with a 4-bit (bitsandbytes) T5 text encoder, free it, then run the fp16 transformer and VAE on the cached embeddings. Peak GPU memory stays below 10 GB (measured below)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading shards: 100%|██████████| 4/4 [00:00<00:00, 17119.61it/s]\n",
+ "Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00, 1.82it/s]\n",
+ "Loading pipeline components...: 25%|██▌ | 1/4 [00:00<00:00, 7.50it/s]Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+ "Loading pipeline components...: 100%|██████████| 4/4 [00:00<00:00, 21.44it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import gc\n",
+ "from transformers import T5EncoderModel, BitsAndBytesConfig\n",
+ "\n",
+ "\n",
+ "torch.manual_seed(0)\n",
+ "\n",
+ "def flush():\n",
+ " gc.collect()\n",
+ " torch.cuda.empty_cache()\n",
+ "\n",
+    "def bytes_to_giga_bytes(num_bytes):\n",
+    "    return num_bytes / 1024 / 1024 / 1024\n",
+ "\n",
+ "video_length = 16\n",
+ "model_id = \"maxin-cn/Latte-1\"\n",
+ "\n",
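+    "# Load only the T5 text encoder, quantized to 4-bit with bitsandbytes, for prompt encoding\n",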
+ "text_encoder = T5EncoderModel.from_pretrained(\n",
+ " model_id,\n",
+ " subfolder=\"text_encoder\",\n",
+ " quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),\n",
+ " device_map=\"auto\",\n",
+ ")\n",
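+    "# Build the pipeline without the transformer; this pass is only used to encode prompts\n",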
+ "pipe = LattePipeline.from_pretrained(\n",
+ " model_id, \n",
+ " text_encoder=text_encoder,\n",
+ " transformer=None,\n",
+ " device_map=\"balanced\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading pipeline components...: 0%| | 0/4 [00:00, ?it/s]Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+ "The config attributes {'attention_type': 'default', 'double_self_attention': False, 'norm_num_groups': 32, 'num_vector_embeds': None, 'only_cross_attention': False, 'upcast_attention': False, 'use_linear_projection': False} were passed to LatteTransformer3DModel, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
+ "Some weights of the model checkpoint were not used when initializing LatteTransformer3DModel: \n",
+ " ['caption_projection.y_embedding']\n",
+ "Loading pipeline components...: 100%|██████████| 4/4 [00:00<00:00, 11.42it/s]\n",
+ "100%|██████████| 50/50 [01:01<00:00, 1.24s/it]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Max memory allocated: 9.150381565093994 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "outputs_folder = \"./outputs/\"\n",
+ "with torch.no_grad():\n",
+    "    prompt = \"a cat wearing sunglasses and working as a lifeguard at a pool.\"\n",
+ " negative_prompt = \"\"\n",
+ " prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt, negative_prompt=negative_prompt)\n",
+ "\n",
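+    "# Free the quantized text encoder before loading the fp16 transformer and VAE\n",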
+ "del text_encoder\n",
+ "del pipe\n",
+ "flush()\n",
+ "\n",
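+    "# Reload the pipeline in fp16 with no text encoder and denoise from the cached embeddings\n",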
+ "pipe = LattePipeline.from_pretrained(\n",
+ " model_id,\n",
+ " text_encoder=None,\n",
+ " torch_dtype=torch.float16,\n",
+ ").to(\"cuda\")\n",
+ "# pipe.enable_vae_tiling()\n",
+ "# pipe.enable_vae_slicing()\n",
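+    "# The two commented lines above can be enabled to lower VAE decode memory\n",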
+ "videos = pipe(\n",
+ " video_length=video_length,\n",
+ " num_inference_steps=50,\n",
+ " negative_prompt=None, \n",
+ " prompt_embeds=prompt_embeds,\n",
+ " negative_prompt_embeds=negative_prompt_embeds,\n",
+ " output_type=\"pt\",\n",
+ ").frames.cpu()\n",
+ "print(f\"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if video_length > 1:\n",
+ " videos_uint8 = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8) # convert to uint8\n",
+ " imageio.mimwrite(f\"{outputs_folder}latte_output_3.mp4\", videos_uint8[0].permute(0, 2, 3, 1), fps=8, quality=5) # highest quality is 10, lowest is 0\n",
+ "else:\n",
+ " save_image(videos[0], f\"{outputs_folder}latte_output_3.png\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "latte",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/gradio/app.py b/gradio/app.py
new file mode 100644
index 0000000..41d59d1
--- /dev/null
+++ b/gradio/app.py
@@ -0,0 +1,109 @@
+import os
+import sys
+from types import SimpleNamespace
+from huggingface_hub import snapshot_download
+import gradio as gr
+import spaces
+
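+# Make the repo's scripts/ directory importable so we can reuse its inference helpers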
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../scripts')))
+from inference import inference_function, video_to_base64
+
+
+@spaces.GPU(duration=200)
+def run_inference(prompt_text, visual_prompt=None, is_running_in_api=None):
+ model_id = "maxin-cn/Latte-1"
+
+ # negative_prompt = "watermark+++, text, shutterstock text, shutterstock++, blurry, ugly, username, url, low resolution, low quality"
+ negative_prompt = None
+ args = {
+ "model": model_id,
+ "prompt": prompt_text,
+ "negative_prompt": negative_prompt,
+ "num_frames": 16,
+ "num_steps": 50,
+ # "width": 256,
+ # "height": 256,
+ "visual_prompt": visual_prompt,
+ "device": 'cuda',
+ "quantize": True,
+ "fps": 4,
+ "output_dir": "./outputs",
+ }
+
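+    # Run inference; UI calls get a file path back, API calls a base64-encoded MP4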
+    print("is_running_in_api:", is_running_in_api)
+    response_file = inference_function(SimpleNamespace(**args))
+    print(model_id, "produced ->", response_file)
+    if is_running_in_api == "True":  # the flag is compared as a string because callers pass text, not a bool
+        base64_file = video_to_base64(src_path=response_file, delete_src=True)
+        filename = os.path.basename(response_file)
+        return {"base64": base64_file, "format": "video/mp4", "filename": filename}
+    else:
+        return response_file
+
+
+def main():
+ with gr.Blocks() as demo:
+ with gr.Row():
+ with gr.Column():
+ gr.HTML(
+ """
+
+