-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from small-thinking/add-runway-api
Add video gen api
- Loading branch information
Showing 10 changed files with 517 additions and 226 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
import time | ||
|
||
import openai | ||
from dotenv import load_dotenv | ||
from pydub.utils import mediainfo | ||
from tqdm import tqdm | ||
|
||
# Load variables from a .env file; override=True lets values from the file
# replace variables that are already set in the process environment.
load_dotenv(override=True)

# The API key is read from the OPENAI_API_KEY environment variable
# (None when the variable is not set).
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Module-level client instance built with the library's default settings.
client = openai.OpenAI()
|
||
|
||
def transcribe_audio(file_path: str) -> str:
    """Transcribe an audio file with the OpenAI ``whisper-1`` model.

    Args:
        file_path: Path of the audio file; it is opened in binary mode and
            passed to the transcription endpoint.

    Returns:
        The ``text`` attribute of the transcription response.
    """
    # Request the verbose JSON format with word-level timestamps.
    request_options = {
        "model": "whisper-1",
        "response_format": "verbose_json",
        "timestamp_granularities": ["word"],
    }

    with open(file_path, "rb") as audio_stream:
        result = openai.audio.transcriptions.create(file=audio_stream, **request_options)

    return result.text
|
||
|
||
if __name__ == "__main__":
    # Script entry point: transcribe ~/Downloads/test.mp3 and write the text
    # to output_transcription.txt in the current working directory.

    # Path to your audio file
    folder_path = os.path.expanduser("~/Downloads")
    audio_file_path = os.path.join(folder_path, "test.mp3")

    # Estimate the duration of the audio file
    audio_info = mediainfo(audio_file_path)
    duration = float(audio_info["duration"])

    # NOTE: this progress bar is purely cosmetic. It sleeps one second per
    # second of audio *before* the (synchronous) API call is made; it does not
    # track the actual transcription.
    print("Transcribing audio...")
    for _ in tqdm(range(int(duration)), desc="Progress", unit="s"):
        time.sleep(1)  # Simulate time passing for each second of audio

    # Transcribe the audio file
    transcription = transcribe_audio(audio_file_path)

    # Save the transcription to a text file.
    # Fix: the file to open is output_file_path; the previous code opened
    # folder_path (a directory), which fails before anything is written.
    output_file_path = "output_transcription.txt"
    with open(output_file_path, "w", encoding="utf-8") as f:
        f.write(transcription)

    print("Transcription completed and saved to output_transcription.txt")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import argparse | ||
import asyncio | ||
import os | ||
from typing import Any, Dict, Optional | ||
|
||
import aiohttp | ||
from dotenv import load_dotenv | ||
|
||
|
||
class RunwayHTTPClient:
    """
    Small asynchronous client for the Runway endpoints under https://api.302.ai.

    When a session is supplied to the constructor it is used for every request
    and is never closed by this class. Otherwise a temporary
    aiohttp.ClientSession is opened and closed around each request.
    """

    def __init__(self, session: Optional[aiohttp.ClientSession] = None, api_key: str = "") -> None:
        """
        Initializes the client with the base URL, API key and request headers.
        Args:
            session (Optional[aiohttp.ClientSession]): An optional aiohttp ClientSession to use for requests.
            api_key (str, optional): Key sent as a Bearer token. When empty, the 302AI_API_KEY
                environment variable is used, falling back to the placeholder "test_api_key".
        """
        self.base_url = "https://api.302.ai"
        # Load .env first so a key defined there is visible to os.getenv below.
        load_dotenv(override=True)
        self.api_key = api_key if api_key else os.getenv("302AI_API_KEY", "test_api_key")

        self.headers = {"Authorization": f"Bearer {self.api_key}", "User-Agent": "Apifox/1.0.0 (https://apifox.com)"}
        self.session = session

    async def send_post_request(self, prompt: str, seconds: int = 10, seed: str = "") -> Dict[str, Any]:
        """
        Sends an asynchronous HTTP POST request with the specified text prompt and other form data.
        Args:
            prompt (str): The text prompt to include in the request.
            seconds (int, optional): The duration in seconds. Defaults to 10.
            seed (str, optional): The seed value. Defaults to an empty string.
        Returns:
            Dict[str, Any]: The JSON response from the server.
        """
        url = f"{self.base_url}/runway/submit"
        # The duration is sent as a string form field.
        data = {"text_prompt": prompt, "seconds": str(seconds), "seed": seed}

        if self.session is not None:
            return await self._post_form(self.session, url, data)
        async with aiohttp.ClientSession() as session:
            return await self._post_form(session, url, data)

    async def fetch_generated_image(self, task_id: str) -> str:
        """
        Fetches the generated output for a given task ID and saves it as an MP4 file.
        Args:
            task_id (str): The ID of the task to fetch the output for.
        Returns:
            str: The path to the saved MP4 file ("<task_id>.mp4", relative to the working directory).
        Raises:
            aiohttp.ClientResponseError: If the server answers with an HTTP error status.
                Nothing is written to disk in that case.
        """
        url = f"{self.base_url}/runway/task/{task_id}/fetch"

        if self.session is not None:
            data = await self._get_bytes(self.session, url)
        else:
            async with aiohttp.ClientSession() as session:
                data = await self._get_bytes(session, url)

        file_path = f"{task_id}.mp4"
        with open(file_path, "wb") as f:
            f.write(data)

        return file_path

    async def _post_form(self, session: aiohttp.ClientSession, url: str, data: Dict[str, str]) -> Dict[str, Any]:
        """POSTs the form fields to the URL and returns the decoded JSON body."""
        async with session.post(url, headers=self.headers, data=data) as response:
            return await response.json()

    async def _get_bytes(self, session: aiohttp.ClientSession, url: str) -> bytes:
        """GETs the URL and returns the raw response body, raising on HTTP error statuses."""
        async with session.get(url, headers=self.headers) as response:
            # Fail before reading so an error payload is never saved as a video file.
            response.raise_for_status()
            return await response.read()
|
||
|
||
def parse_arg() -> argparse.Namespace:
    """
    Builds the command-line interface and parses the process arguments.

    Two sub-commands are available, stored under ``command``:
    ``submit`` (--prompt, --seconds, --seed) and ``fetch`` (--task_id).
    Returns:
        argparse.Namespace: The parsed arguments.
    """
    root = argparse.ArgumentParser(description="Runway API Client")
    commands = root.add_subparsers(dest="command", required=True)

    # "submit" options, declared as (flag, keyword options) pairs.
    submit_cmd = commands.add_parser("submit", help="Submit a new task")
    submit_options = (
        ("--prompt", {"type": str, "required": True, "help": "The text prompt to include in the request."}),
        ("--seconds", {"type": int, "default": 10, "help": "The duration in seconds."}),
        ("--seed", {"type": str, "default": "", "help": "The seed value."}),
    )
    for flag, options in submit_options:
        submit_cmd.add_argument(flag, **options)

    # "fetch" takes only the task identifier.
    fetch_cmd = commands.add_parser("fetch", help="Fetch a generated image")
    fetch_cmd.add_argument(
        "--task_id",
        type=str,
        required=True,
        help="The ID of the task to fetch the image for.",
    )

    return root.parse_args()
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the sub-command selected on the command line.
    args = parse_arg()
    client = RunwayHTTPClient()

    if args.command == "submit":
        response = asyncio.run(
            client.send_post_request(
                prompt=args.prompt,
                seconds=args.seconds,
                seed=args.seed,
            )
        )
        # Show the server's JSON reply; previously it was discarded, leaving
        # the user with no output from a submission.
        print(response)
    elif args.command == "fetch":
        saved_path = asyncio.run(client.fetch_generated_image(task_id=args.task_id))
        # Report where the downloaded file was written.
        print(f"Saved to {saved_path}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.