Skip to content

Commit

Permalink
Merge pull request #8 from small-thinking/runway-test-5
Browse files Browse the repository at this point in the history
Enhance video gen
  • Loading branch information
yxjiang authored Jul 29, 2024
2 parents 5684730 + 1e51b50 commit 8c43510
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 46 deletions.
70 changes: 55 additions & 15 deletions mind_renderer/core/video.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
"""
Run video generation with command:
poetry run python mind_renderer/core/video.py submit --prompt the prompt
Fetch generated video with command:
poetry run python mind_renderer/core/video.py fetch --task_id runway_f2450ac24c1e
"""

import argparse
import asyncio
import json
import os
from typing import Any, Dict, Optional

import aiohttp
from dotenv import load_dotenv

from mind_renderer.utils.logger import Logger


class RunwayHTTPClient:
def __init__(self, session: Optional[aiohttp.ClientSession] = None, api_key: str = "") -> None:
Expand All @@ -16,6 +27,7 @@ def __init__(self, session: Optional[aiohttp.ClientSession] = None, api_key: str
session (Optional[aiohttp.ClientSession]): An optional aiohttp ClientSession to use for requests.
"""
self.base_url = "https://api.302.ai"
self.logger = Logger(__name__)
load_dotenv(override=True)
self.api_key = api_key if api_key else os.getenv("302AI_API_KEY", "test_api_key")

Expand Down Expand Up @@ -45,29 +57,56 @@ async def send_post_request(self, prompt: str, seconds: int = 10, seed: str = ""
async with session.post(url, headers=self.headers, data=data) as response:
return await response.json()

async def fetch_generated_image(self, task_id: str) -> str:
async def fetch_generated_video(self, task_id: str) -> str:
"""
Fetches the generated image for a given task ID and saves it as an MP4 file.
Fetches the generated video for a given task ID and saves it as an MP4 file.
Args:
task_id (str): The ID of the task to fetch the image for.
task_id (str): The ID of the task to fetch the video for.
session (Optional[aiohttp.ClientSession]): An optional aiohttp ClientSession to use for requests.
Returns:
str: The path to the saved MP4 file.
"""
url = f"{self.base_url}/runway/task/{task_id}/fetch"

if self.session:
async with self.session.get(url, headers=self.headers) as response:
generated_video_path = None
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=self.headers) as response:
data = await response.read()
else:
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=self.headers) as response:
data = await response.read()
decoded_data = data.decode("utf-8")
response_json = json.loads(decoded_data)
if "task" not in response_json:
self.logger.error(f"Error fetching generated video: {response_json}")
return
task_blob = response_json["task"]
if "status" not in task_blob:
self.logger.error(f"Error fetching generated video: {task_blob}")
return
status = task_blob["status"]
if status == "PENDING":
self.logger.info("Task is still pending...")
return
elif status == "RUNNING":
self.logger.info("Task is still running...")
return
elif status == "SUCCEEDED":
generated_video_path = task_blob["artifacts"][0]["url"]
self.logger.info(f"Generated video path: {generated_video_path}")
else:
self.logger.error(f"Task failed with status: {status}, Response: {task_blob}")
return

file_path = f"{task_id}.mp4"
with open(file_path, "wb") as f:
f.write(data)
# Download mp4 from given url
if generated_video_path:
self.logger.info(f"Downloading generated video from: {generated_video_path}...")
async with aiohttp.ClientSession() as session:
async with session.get(generated_video_path) as response:
data = await response.read()
with open(file_path, "wb") as f:
f.write(data)
self.logger.info(f"Generated video saved as: {file_path}")

return file_path

Expand All @@ -87,8 +126,8 @@ def parse_arg() -> argparse.Namespace:
submit_parser.add_argument("--seconds", type=int, default=10, help="The duration in seconds.")
submit_parser.add_argument("--seed", type=str, default="", help="The seed value.")

fetch_parser = subparsers.add_parser("fetch", help="Fetch a generated image")
fetch_parser.add_argument("--task_id", type=str, required=True, help="The ID of the task to fetch the image for.")
fetch_parser = subparsers.add_parser("fetch", help="Fetch a generated video")
fetch_parser.add_argument("--task_id", type=str, required=True, help="The ID of the task to fetch the video for.")

return parser.parse_args()

Expand All @@ -98,12 +137,13 @@ def parse_arg() -> argparse.Namespace:
client = RunwayHTTPClient()

if args.command == "submit":
asyncio.run(
response = asyncio.run(
client.send_post_request(
prompt=args.prompt,
seconds=args.seconds,
seed=args.seed,
)
)
print(f"Response: {response}")
elif args.command == "fetch":
asyncio.run(client.fetch_generated_image(task_id=args.task_id))
asyncio.run(client.fetch_generated_video(task_id=args.task_id))
32 changes: 1 addition & 31 deletions tests/core/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import os
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
from unittest.mock import AsyncMock, mock_open, patch

import pytest
from aiohttp import ClientSession
Expand Down Expand Up @@ -38,33 +38,3 @@ async def test_send_post_request_success():
url, headers=expected_headers, data={"text_prompt": prompt, "seconds": str(seconds), "seed": seed}
)
assert response == expected_response


@pytest.mark.asyncio
async def test_fetch_generated_image():
api_key = "test_api_key"
task_id = "runway_1234"
expected_file_path = f"{task_id}.mp4"

mock_data = b"fake mp4 data"

with patch.dict(os.environ, {"302AI_API_KEY": api_key}):
mock_response = AsyncMock()
mock_response.read.return_value = mock_data

mock_session = AsyncMock(spec=ClientSession)
mock_session.get.return_value.__aenter__.return_value = mock_response

client = RunwayHTTPClient(session=mock_session, api_key=api_key)

with patch("builtins.open", mock_open()) as mock_file:
file_path = await client.fetch_generated_image(task_id)

expected_url = f"https://api.302.ai/runway/task/{task_id}/fetch"
expected_headers = {"Authorization": f"Bearer {api_key}", "User-Agent": "Apifox/1.0.0 (https://apifox.com)"}

mock_session.get.assert_called_once_with(expected_url, headers=expected_headers)
mock_file.assert_called_once_with(expected_file_path, "wb")
mock_file().write.assert_called_once_with(mock_data)

assert file_path == expected_file_path

0 comments on commit 8c43510

Please sign in to comment.