Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance video gen #8

Merged
merged 3 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 55 additions & 15 deletions mind_renderer/core/video.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
"""
Run video generation with command:
poetry run python mind_renderer/core/video.py submit --prompt the prompt

Fetch generated video with command:
poetry run python mind_renderer/core/video.py fetch --task_id runway_f2450ac24c1e
"""

import argparse
import asyncio
import json
import os
from typing import Any, Dict, Optional

import aiohttp
from dotenv import load_dotenv

from mind_renderer.utils.logger import Logger


class RunwayHTTPClient:
def __init__(self, session: Optional[aiohttp.ClientSession] = None, api_key: str = "") -> None:
Expand All @@ -16,6 +27,7 @@ def __init__(self, session: Optional[aiohttp.ClientSession] = None, api_key: str
session (Optional[aiohttp.ClientSession]): An optional aiohttp ClientSession to use for requests.
"""
self.base_url = "https://api.302.ai"
self.logger = Logger(__name__)
load_dotenv(override=True)
self.api_key = api_key if api_key else os.getenv("302AI_API_KEY", "test_api_key")

Expand Down Expand Up @@ -45,29 +57,56 @@ async def send_post_request(self, prompt: str, seconds: int = 10, seed: str = ""
async with session.post(url, headers=self.headers, data=data) as response:
return await response.json()

async def fetch_generated_image(self, task_id: str) -> str:
async def fetch_generated_video(self, task_id: str) -> str:
"""
Fetches the generated image for a given task ID and saves it as an MP4 file.
Fetches the generated video for a given task ID and saves it as an MP4 file.

Args:
task_id (str): The ID of the task to fetch the image for.
task_id (str): The ID of the task to fetch the video for.
session (Optional[aiohttp.ClientSession]): An optional aiohttp ClientSession to use for requests.

Returns:
str: The path to the saved MP4 file.
"""
url = f"{self.base_url}/runway/task/{task_id}/fetch"

if self.session:
async with self.session.get(url, headers=self.headers) as response:
generated_video_path = None
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=self.headers) as response:
data = await response.read()
else:
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=self.headers) as response:
data = await response.read()
decoded_data = data.decode("utf-8")
response_json = json.loads(decoded_data)
if "task" not in response_json:
self.logger.error(f"Error fetching generated video: {response_json}")
return
task_blob = response_json["task"]
if "status" not in task_blob:
self.logger.error(f"Error fetching generated video: {task_blob}")
return
status = task_blob["status"]
if status == "PENDING":
self.logger.info("Task is still pending...")
return
elif status == "RUNNING":
self.logger.info("Task is still running...")
return
elif status == "SUCCEEDED":
generated_video_path = task_blob["artifacts"][0]["url"]
self.logger.info(f"Generated video path: {generated_video_path}")
else:
self.logger.error(f"Task failed with status: {status}, Response: {task_blob}")
return

file_path = f"{task_id}.mp4"
with open(file_path, "wb") as f:
f.write(data)
# Download mp4 from given url
if generated_video_path:
self.logger.info(f"Downloading generated video from: {generated_video_path}...")
async with aiohttp.ClientSession() as session:
async with session.get(generated_video_path) as response:
data = await response.read()
with open(file_path, "wb") as f:
f.write(data)
self.logger.info(f"Generated video saved as: {file_path}")

return file_path

Expand All @@ -87,8 +126,8 @@ def parse_arg() -> argparse.Namespace:
submit_parser.add_argument("--seconds", type=int, default=10, help="The duration in seconds.")
submit_parser.add_argument("--seed", type=str, default="", help="The seed value.")

fetch_parser = subparsers.add_parser("fetch", help="Fetch a generated image")
fetch_parser.add_argument("--task_id", type=str, required=True, help="The ID of the task to fetch the image for.")
fetch_parser = subparsers.add_parser("fetch", help="Fetch a generated video")
fetch_parser.add_argument("--task_id", type=str, required=True, help="The ID of the task to fetch the video for.")

return parser.parse_args()

Expand All @@ -98,12 +137,13 @@ def parse_arg() -> argparse.Namespace:
client = RunwayHTTPClient()

if args.command == "submit":
asyncio.run(
response = asyncio.run(
client.send_post_request(
prompt=args.prompt,
seconds=args.seconds,
seed=args.seed,
)
)
print(f"Response: {response}")
elif args.command == "fetch":
asyncio.run(client.fetch_generated_image(task_id=args.task_id))
asyncio.run(client.fetch_generated_video(task_id=args.task_id))
32 changes: 1 addition & 31 deletions tests/core/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import os
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
from unittest.mock import AsyncMock, mock_open, patch

import pytest
from aiohttp import ClientSession
Expand Down Expand Up @@ -38,33 +38,3 @@ async def test_send_post_request_success():
url, headers=expected_headers, data={"text_prompt": prompt, "seconds": str(seconds), "seed": seed}
)
assert response == expected_response


@pytest.mark.asyncio
async def test_fetch_generated_image():
api_key = "test_api_key"
task_id = "runway_1234"
expected_file_path = f"{task_id}.mp4"

mock_data = b"fake mp4 data"

with patch.dict(os.environ, {"302AI_API_KEY": api_key}):
mock_response = AsyncMock()
mock_response.read.return_value = mock_data

mock_session = AsyncMock(spec=ClientSession)
mock_session.get.return_value.__aenter__.return_value = mock_response

client = RunwayHTTPClient(session=mock_session, api_key=api_key)

with patch("builtins.open", mock_open()) as mock_file:
file_path = await client.fetch_generated_image(task_id)

expected_url = f"https://api.302.ai/runway/task/{task_id}/fetch"
expected_headers = {"Authorization": f"Bearer {api_key}", "User-Agent": "Apifox/1.0.0 (https://apifox.com)"}

mock_session.get.assert_called_once_with(expected_url, headers=expected_headers)
mock_file.assert_called_once_with(expected_file_path, "wb")
mock_file().write.assert_called_once_with(mock_data)

assert file_path == expected_file_path
Loading