Skip to content

Commit

Permalink
Merge pull request #7 from small-thinking/add-runway-api
Browse files Browse the repository at this point in the history
Add video gen api
  • Loading branch information
yxjiang authored Jul 29, 2024
2 parents 109e8e4 + 2cdccc8 commit 5684730
Show file tree
Hide file tree
Showing 10 changed files with 517 additions and 226 deletions.
3 changes: 0 additions & 3 deletions mind_renderer/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
poetry run python mind_renderer/cli.py
"""

import argparse

from mind_renderer.core.story_generators import OneStepStoryGenerator
from mind_renderer.utils.logger import Logger


def main():
Expand Down
51 changes: 51 additions & 0 deletions mind_renderer/core/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import time

import openai
from dotenv import load_dotenv
from pydub.utils import mediainfo
from tqdm import tqdm

load_dotenv(override=True)

# Replace with your OpenAI API key
openai.api_key = os.getenv("OPENAI_API_KEY")

client = openai.OpenAI()


def transcribe_audio(file_path: str) -> str:
# Open the audio file
with open(file_path, "rb") as audio_file:
# Use the OpenAI Whisper API to transcribe the audio
response = openai.audio.transcriptions.create(
file=audio_file, model="whisper-1", response_format="verbose_json", timestamp_granularities=["word"]
)
return response.text


if __name__ == "__main__":

# Path to your audio file
audio_file_path = "test.mp3"
folder_path = os.path.expanduser("~/Downloads")
audio_file_path = os.path.join(folder_path, audio_file_path)

# Estimate the duration of the audio file
audio_info = mediainfo(audio_file_path)
duration = float(audio_info["duration"])

# Simulate progress updates (since the actual API call is synchronous)
print("Transcribing audio...")
for i in tqdm(range(int(duration)), desc="Progress", unit="s"):
time.sleep(1) # Simulate time passing for each second of audio

# Transcribe the audio file
transcription = transcribe_audio(audio_file_path)

# Save the transcription to a text file
output_file_path = "output_transcription.txt"
with open(folder_path, "w", encoding="utf-8") as f:
f.write(transcription)

print("Transcription completed and saved to output_transcription.txt")
20 changes: 16 additions & 4 deletions mind_renderer/core/story_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ class StorySketchGenerator(dspy.Module):

instruction: str = """
You are the story sketching assistant. Please follow the below instruction:
1. Story Sketch Generation: Create a sketch for a multi-section story according to the specified number of
sections.
2. Language Consistency: Ensure the sketch is generated in the SAME language as the user input. For example,
generate the sketch in English if the user input is in English,
generate the sketch in Chinese if the user input is in Chinese.
3. Worldview and Detailed Prompts: Based on the simple prompt provided:
3.1. Generate the story worldview, which is a high-level description of the story to ensure consistency across
Expand Down Expand Up @@ -134,10 +135,17 @@ def __init__(self, config_loader: ConfigLoader, logger: Logger):
self.gen_thumbnail_prompt = self.config_loader.get_value("image_model.gen_thumbnail", False)

def forward(
self, prompt: str, story_worldview: str, story_piece: StoryPiece, prev_piece: StoryPiece = None, **kwargs
self,
idx: str,
prompt: str,
story_worldview: str,
story_piece: StoryPiece,
prev_piece: StoryPiece = None,
**kwargs,
) -> None:
"""Generate the element based on the prompt and populate the story piece with the generated element."""
self.generate(
idx=idx,
story_description=prompt,
story_worldview=story_worldview,
story_piece=story_piece,
Expand All @@ -151,6 +159,7 @@ def forward(
)
def generate(
self,
idx: str,
story_description: str,
story_worldview: str,
story_piece: StoryPiece,
Expand Down Expand Up @@ -183,7 +192,9 @@ def generate(
"Thumbnail Generation Prompt:", ""
).strip()
self.logger.info(f"Generated story piece text:\n{story_piece.text}")
self.logger.info(f"Generated thumbnail prompt:\n{story_piece.thumbnail_gen_prompt}")
self.logger.info(
f"{idx}: Generated thumbnail prompt:\n{story_piece.thumbnail_gen_prompt}", write_to_file=True
)
except Exception as e:
self.logger.error(f"An error occurred: {str(e)}")
raise
Expand Down Expand Up @@ -245,7 +256,7 @@ def __init__(
load_dotenv(override=True)
self.state: dict[str, any]
self.config_loader = ConfigLoader()
self.logger = Logger(__name__, parent_folder=self.config_loader.get_value("root_save_folder", "outputs"))
self.logger = Logger("story", parent_folder=self.config_loader.get_value("root_save_folder", "outputs"))
self.genres = self.config_loader.get_value("genres", "")
self.writing_style = self.config_loader.get_value("writing_style", "")
self.gen_thumbnail_prompt = self.config_loader.get_value("image_model.gen_thumbnail", False)
Expand Down Expand Up @@ -308,6 +319,7 @@ def generate(self, prev_version: Story = None, **kwargs) -> Story:
self.logger.info(f"Generating story piece {i+1} with prompt: {prompt}...")
story_piece = StoryPiece(idx=i)
self.text_generator(
idx=f"Section {i+1}",
prompt=prompt,
story_worldview=story_worldview,
story_piece=story_piece,
Expand Down
109 changes: 109 additions & 0 deletions mind_renderer/core/video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import argparse
import asyncio
import os
from typing import Any, Dict, Optional

import aiohttp
from dotenv import load_dotenv


class RunwayHTTPClient:
def __init__(self, session: Optional[aiohttp.ClientSession] = None, api_key: str = "") -> None:
"""
Initializes the HTTPClient with the URL and headers.
Args:
session (Optional[aiohttp.ClientSession]): An optional aiohttp ClientSession to use for requests.
"""
self.base_url = "https://api.302.ai"
load_dotenv(override=True)
self.api_key = api_key if api_key else os.getenv("302AI_API_KEY", "test_api_key")

self.headers = {"Authorization": f"Bearer {self.api_key}", "User-Agent": "Apifox/1.0.0 (https://apifox.com)"}
self.session = session

async def send_post_request(self, prompt: str, seconds: int = 10, seed: str = "") -> Dict[str, Any]:
"""
Sends an asynchronous HTTP POST request with the specified text prompt and other form data.
Args:
prompt (str): The text prompt to include in the request.
seconds (int, optional): The duration in seconds. Defaults to 10.
seed (str, optional): The seed value. Defaults to an empty string.
Returns:
Dict[str, Any]: The JSON response from the server.
"""
url = f"{self.base_url}/runway/submit"
data = {"text_prompt": prompt, "seconds": str(seconds), "seed": seed}

if self.session:
async with self.session.post(url, headers=self.headers, data=data) as response:
return await response.json()
else:
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=self.headers, data=data) as response:
return await response.json()

async def fetch_generated_image(self, task_id: str) -> str:
"""
Fetches the generated image for a given task ID and saves it as an MP4 file.
Args:
task_id (str): The ID of the task to fetch the image for.
Returns:
str: The path to the saved MP4 file.
"""
url = f"{self.base_url}/runway/task/{task_id}/fetch"

if self.session:
async with self.session.get(url, headers=self.headers) as response:
data = await response.read()
else:
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=self.headers) as response:
data = await response.read()

file_path = f"{task_id}.mp4"
with open(file_path, "wb") as f:
f.write(data)

return file_path


def parse_arg() -> argparse.Namespace:
"""
Parses the command-line arguments.
Returns:
argparse.Namespace: The parsed arguments.
"""
parser = argparse.ArgumentParser(description="Runway API Client")
subparsers = parser.add_subparsers(dest="command", required=True)

submit_parser = subparsers.add_parser("submit", help="Submit a new task")
submit_parser.add_argument("--prompt", type=str, required=True, help="The text prompt to include in the request.")
submit_parser.add_argument("--seconds", type=int, default=10, help="The duration in seconds.")
submit_parser.add_argument("--seed", type=str, default="", help="The seed value.")

fetch_parser = subparsers.add_parser("fetch", help="Fetch a generated image")
fetch_parser.add_argument("--task_id", type=str, required=True, help="The ID of the task to fetch the image for.")

return parser.parse_args()


if __name__ == "__main__":
args = parse_arg()
client = RunwayHTTPClient()

if args.command == "submit":
asyncio.run(
client.send_post_request(
prompt=args.prompt,
seconds=args.seconds,
seed=args.seed,
)
)
elif args.command == "fetch":
asyncio.run(client.fetch_generated_image(task_id=args.task_id))
56 changes: 41 additions & 15 deletions mind_renderer/utils/logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import logging
import os
import time
from enum import Enum
from typing import Optional

Expand Down Expand Up @@ -47,7 +48,7 @@ def __new__(cls, *args, **kwargs):
return cls._instance

def __init__(
self, logger_name: str, parent_folder: str, verbose: bool = True, level: Optional[LoggingLevel] = None
self, logger_name: str, parent_folder: str = "", verbose: bool = True, level: Optional[LoggingLevel] = None
):
if not hasattr(self, "logger"):
load_dotenv(override=True)
Expand All @@ -67,15 +68,17 @@ def __init__(
self.logger.addHandler(self.console_handler)

# File handler
log_folder_root = self.config.get_value("logger.folder_root", "logs")
log_folder = os.path.join(log_folder_root, parent_folder)
if not os.path.exists(log_folder):
os.makedirs(log_folder)
log_file = os.path.join(log_folder, f"{logger_name}.log")
self.file_handler = logging.FileHandler(log_file)
self.file_handler.setLevel(level=self.logging_level.value)
self.file_handler.setFormatter(self.formatter)
self.logger.addHandler(self.file_handler)
if parent_folder:
log_folder_root = self.config.get_value("logger.folder_root", "logs")
log_folder = os.path.join(parent_folder, log_folder_root)
if not os.path.exists(log_folder):
os.makedirs(log_folder)
timestamp = str(int(time.time()))
log_file = os.path.join(log_folder, f"{timestamp}-{logger_name}.log")
self.file_handler = logging.FileHandler(log_file)
self.file_handler.setLevel(level=self.logging_level.value)
self.file_handler.setFormatter(self.formatter)
self.logger.addHandler(self.file_handler)

def log(self, message: str, level: LoggingLevel, color: str = ansi.Fore.GREEN, write_to_file: bool = False) -> None:
if level >= self.logging_level:
Expand All @@ -85,12 +88,35 @@ def log(self, message: str, level: LoggingLevel, color: str = ansi.Fore.GREEN, w
caller_frame = inspect.stack()[2]
caller_name = caller_frame.function
caller_line = caller_frame.lineno
message = f"{caller_name}({caller_line}): {message}"
log_message = color + message + Fore.RESET
self.logger.info(log_message)
formatted_message = f"{caller_name}({caller_line}): {message}"

# Console logging
console_message = color + formatted_message + Fore.RESET
self.console_handler.handle(
logging.LogRecord(
name=self.logger.name,
level=level.value,
pathname=caller_frame.filename,
lineno=caller_line,
msg=console_message,
args=None,
exc_info=None,
)
)

# File logging
if write_to_file:
self.file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
self.logger.info(message)
self.file_handler.handle(
logging.LogRecord(
name=self.logger.name,
level=level.value,
pathname=caller_frame.filename,
lineno=caller_line,
msg=formatted_message,
args=None,
exc_info=None,
)
)

def debug(self, message: str, write_to_file: bool = False) -> None:
self.log(message, Logger.LoggingLevel.DEBUG, Fore.BLACK, write_to_file)
Expand Down
Loading

0 comments on commit 5684730

Please sign in to comment.