diff --git a/mind_renderer/core/video.py b/mind_renderer/core/video.py index 8d62a04..27e5a10 100644 --- a/mind_renderer/core/video.py +++ b/mind_renderer/core/video.py @@ -1,11 +1,22 @@ +""" +Run video generation with command: + poetry run python mind_renderer/core/video.py submit --prompt the prompt + +Fetch generated video with command: + poetry run python mind_renderer/core/video.py fetch --task_id runway_f2450ac24c1e +""" + import argparse import asyncio +import json import os from typing import Any, Dict, Optional import aiohttp from dotenv import load_dotenv +from mind_renderer.utils.logger import Logger + class RunwayHTTPClient: def __init__(self, session: Optional[aiohttp.ClientSession] = None, api_key: str = "") -> None: @@ -16,6 +27,7 @@ def __init__(self, session: Optional[aiohttp.ClientSession] = None, api_key: str session (Optional[aiohttp.ClientSession]): An optional aiohttp ClientSession to use for requests. """ self.base_url = "https://api.302.ai" + self.logger = Logger(__name__) load_dotenv(override=True) self.api_key = api_key if api_key else os.getenv("302AI_API_KEY", "test_api_key") @@ -45,29 +57,56 @@ async def send_post_request(self, prompt: str, seconds: int = 10, seed: str = "" async with session.post(url, headers=self.headers, data=data) as response: return await response.json() - async def fetch_generated_image(self, task_id: str) -> str: + async def fetch_generated_video(self, task_id: str) -> str: """ - Fetches the generated image for a given task ID and saves it as an MP4 file. + Fetches the generated video for a given task ID and saves it as an MP4 file. Args: - task_id (str): The ID of the task to fetch the image for. + task_id (str): The ID of the task to fetch the video for. + session (Optional[aiohttp.ClientSession]): An optional aiohttp ClientSession to use for requests. Returns: str: The path to the saved MP4 file. """ url = f"{self.base_url}/runway/task/{task_id}/fetch" - if self.session: - async with self.session.get(url, headers=self.headers) as response: + generated_video_path = None + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=self.headers) as response: data = await response.read() - else: - async with aiohttp.ClientSession() as session: - async with session.get(url, headers=self.headers) as response: - data = await response.read() + decoded_data = data.decode("utf-8") + response_json = json.loads(decoded_data) + if "task" not in response_json: + self.logger.error(f"Error fetching generated video: {response_json}") + return + task_blob = response_json["task"] + if "status" not in task_blob: + self.logger.error(f"Error fetching generated video: {task_blob}") + return + status = task_blob["status"] + if status == "PENDING": + self.logger.info("Task is still pending...") + return + elif status == "RUNNING": + self.logger.info("Task is still running...") + return + elif status == "SUCCEEDED": + generated_video_path = task_blob["artifacts"][0]["url"] + self.logger.info(f"Generated video path: {generated_video_path}") + else: + self.logger.error(f"Task failed with status: {status}, Response: {task_blob}") + return file_path = f"{task_id}.mp4" - with open(file_path, "wb") as f: - f.write(data) + # Download mp4 from given url + if generated_video_path: + self.logger.info(f"Downloading generated video from: {generated_video_path}...") + async with aiohttp.ClientSession() as session: + async with session.get(generated_video_path) as response: + data = await response.read() + with open(file_path, "wb") as f: + f.write(data) + self.logger.info(f"Generated video saved as: {file_path}") return file_path @@ -87,8 +126,8 @@ def parse_arg() -> argparse.Namespace: submit_parser.add_argument("--seconds", type=int, default=10, help="The duration in seconds.") submit_parser.add_argument("--seed", type=str, default="", help="The seed value.") - fetch_parser = subparsers.add_parser("fetch", help="Fetch a generated image") - fetch_parser.add_argument("--task_id", type=str, required=True, help="The ID of the task to fetch the image for.") + fetch_parser = subparsers.add_parser("fetch", help="Fetch a generated video") + fetch_parser.add_argument("--task_id", type=str, required=True, help="The ID of the task to fetch the video for.") return parser.parse_args() @@ -98,12 +137,13 @@ def parse_arg() -> argparse.Namespace: client = RunwayHTTPClient() if args.command == "submit": - asyncio.run( + response = asyncio.run( client.send_post_request( prompt=args.prompt, seconds=args.seconds, seed=args.seed, ) ) + print(f"Response: {response}") elif args.command == "fetch": - asyncio.run(client.fetch_generated_image(task_id=args.task_id)) + asyncio.run(client.fetch_generated_video(task_id=args.task_id)) diff --git a/tests/core/test_video.py b/tests/core/test_video.py index 8687d7a..1f11f97 100644 --- a/tests/core/test_video.py +++ b/tests/core/test_video.py @@ -3,7 +3,7 @@ """ import os -from unittest.mock import AsyncMock, MagicMock, mock_open, patch +from unittest.mock import AsyncMock, mock_open, patch import pytest from aiohttp import ClientSession @@ -38,33 +38,3 @@ async def test_send_post_request_success(): url, headers=expected_headers, data={"text_prompt": prompt, "seconds": str(seconds), "seed": seed} ) assert response == expected_response - - -@pytest.mark.asyncio -async def test_fetch_generated_image(): - api_key = "test_api_key" - task_id = "runway_1234" - expected_file_path = f"{task_id}.mp4" - - mock_data = b"fake mp4 data" - - with patch.dict(os.environ, {"302AI_API_KEY": api_key}): - mock_response = AsyncMock() - mock_response.read.return_value = mock_data - - mock_session = AsyncMock(spec=ClientSession) - mock_session.get.return_value.__aenter__.return_value = mock_response - - client = RunwayHTTPClient(session=mock_session, api_key=api_key) - - with patch("builtins.open", mock_open()) as mock_file: - file_path = await client.fetch_generated_image(task_id) - - expected_url = f"https://api.302.ai/runway/task/{task_id}/fetch" - expected_headers = {"Authorization": f"Bearer {api_key}", "User-Agent": "Apifox/1.0.0 (https://apifox.com)"} - - mock_session.get.assert_called_once_with(expected_url, headers=expected_headers) - mock_file.assert_called_once_with(expected_file_path, "wb") - mock_file().write.assert_called_once_with(mock_data) - - assert file_path == expected_file_path