Generate thumbnails #4

Merged: 5 commits, Jul 18, 2024
5 changes: 4 additions & 1 deletion mind_renderer/cli.py
@@ -5,6 +5,7 @@
"""

import argparse
import os

from mind_renderer.core.story_generators import OneStepStoryGenerator
from mind_renderer.utils.logger import Logger
@@ -22,7 +23,9 @@ def main():
story = story_generator.generate()

logger.info("Generated story:")
print(story)
save_folder = story_generator.config.get("output_path", "outputs")
markdown_path = os.path.join(save_folder, "story.md")
story.to_markdown(markdown_path)


if __name__ == "__main__":
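For context, a reviewer-style sketch (not part of the diff) of how the new CLI save flow can be exercised. The `os.makedirs` guard is a hypothetical addition: `to_markdown` opens the file directly and does not create the folder itself.

```python
import os

from mind_renderer.core.story_generators import OneStepStoryGenerator

generator = OneStepStoryGenerator("story_gen_config.yaml")
story = generator.generate()

# Same lookup cli.py now performs; "output_path" falls back to "outputs".
save_folder = generator.config.get("output_path", "outputs")
os.makedirs(save_folder, exist_ok=True)  # hypothetical guard, not in the PR
story.to_markdown(os.path.join(save_folder, "story.md"))
```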
15 changes: 15 additions & 0 deletions mind_renderer/core/entities.py
@@ -61,3 +61,18 @@ def add_piece(self, piece: StoryPiece):

def __str__(self):
return "\n".join([f"Episode {idx}:\n{str(piece)}\n" for idx, piece in enumerate(self.pieces)])

def to_markdown(self, file_path: str):
with open(file_path, "w") as file:
file.write(f"# {self.genres}\n")
for idx, piece in enumerate(self.pieces):
file.write(f"## Episode {idx}\n")
file.write(f"{piece.text}\n")
if piece.image_uri:
file.write(f"![image](./{piece.image_uri})\n")
if piece.audio_uri:
file.write(f"![audio](./{piece.audio_uri})\n")
if piece.video_uri:
file.write(f"![video](./{piece.video_uri})\n")
file.write("\n")
return file_path
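A quick usage sketch for the new `Story.to_markdown`. Field names are taken from this diff; the assumptions that unset media URIs are falsy and that the genres value is a plain string are mine.

```python
from mind_renderer.core.entities import Story, StoryPiece

piece = StoryPiece(idx=0)
piece.text = "Once upon a time..."
piece.image_uri = "thumbnails/thumbnail_0.png"  # optional; written as a relative image link

story = Story(pieces=[piece], genres="children, educational")  # placeholder genres string
path = story.to_markdown("outputs/story.md")  # returns the path it wrote; the folder must already exist
print(path)
```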
117 changes: 89 additions & 28 deletions mind_renderer/core/story_generators.py
@@ -8,26 +8,15 @@

import dspy
from dotenv import load_dotenv
from openai import OpenAI

from mind_renderer.core.entities import Story, StoryPiece
from mind_renderer.utils.config_loader import ConfigLoader
from mind_renderer.utils.deepseek_lm import DeepSeek
from mind_renderer.utils.custom_lm import init_lm
from mind_renderer.utils.image_processing import download_and_save_image
from mind_renderer.utils.logger import Logger


def init_lm(text_model_config: dict[str, str]) -> dspy.LM:
"""Initialize the language model based on the configuration."""
provider = text_model_config["provider"]
lm_name = text_model_config["lm_name"]
section_length = text_model_config.get("section_length", 1000)
if provider == "DeepSeek":
api_key = os.getenv("DEEPSEEK_API_KEY")
return DeepSeek(base_url="https://api.deepseek.com", model=lm_name, api_key=api_key, max_tokens=section_length)
elif provider == "GROQ":
api_key = os.getenv("GROQ_API_KEY")
return dspy.__dict__[provider](model=lm_name, max_tokens=section_length, api_key=api_key)


class StorySketchGenerator(dspy.Module):
"""StorySketchGenerator is used to expand a simple prompt into detailed piece-level prompts."""

@@ -49,11 +38,16 @@ class SketchGenerateSignature(dspy.Signature):

instruction: str = dspy.InputField(desc="The instruction for the generation.")
simple_prompt: str = dspy.InputField(desc="The user instruction on how the story would look like.")
num_sections: str = dspy.InputField(desc="The number of episodes the story should have.")
story_worldview: str = dspy.OutputField(desc="The world view of the story.")
num_sections: str = dspy.InputField(
desc="""
The number of episodes the story should have. Each episode describes one scene.
Consecutive episodes should be coherent and complementary.
"""
)
story_worldview: str = dspy.OutputField(desc="The world view of the story. Used to ensure consistency.")
prompts: str = dspy.OutputField(
desc="""
The json blob that contains a list of prompts for each piece of the story, they should be coherent, e.g.
The blob that contains a list of prompts, one per piece of the story; the prompts should be coherent, e.g.
{
"prompts": ["prompt1", "prompt2", "prompt3"]
}
@@ -99,6 +93,8 @@ class TextGenerateSignature(dspy.Signature):
should_gen_thumbnail_prompt: bool = dspy.InputField(
desc="Whether to generate the thumbnail image gneeration prompt."
)
prev_piece: str = dspy.InputField(desc="The previous piece of the story. For continuity.")

thumbnail_generation_prompt: str = dspy.OutputField(
desc="The cohensive prompt for generating the thumbnail image in English only. Use descriptive keywords."
)
@@ -115,15 +111,24 @@ def __init__(self, config: dict[str, any], text_model_config: dict[str, str]):
self.writing_style = config["writing_style"]
self.gen_thumbnail_prompt = config["gen_thumbnail"]

def forward(self, prompt: str, story_worldview: str, story_piece: StoryPiece, **kwargs) -> None:
def forward(
self, prompt: str, story_worldview: str, story_piece: StoryPiece, prev_piece: StoryPiece = None, **kwargs
) -> None:
"""Generate the element based on the prompt and populate the story piece with the generated element."""
self.generate(story_description=prompt, story_worldview=story_worldview, story_piece=story_piece, **kwargs)
self.generate(
story_description=prompt,
story_worldview=story_worldview,
story_piece=story_piece,
prev_piece=prev_piece,
**kwargs,
)

def generate(
self,
story_description: str,
story_worldview: str,
story_piece: StoryPiece,
prev_piece: StoryPiece = None,
**kwargs,
) -> None:
"""Generate the element based on the story_description and populate the story piece with the generated element.
@@ -133,6 +138,7 @@ def generate(
story_worldview (str): The story_worldview of the story.
gen_thumbnail_prompt (bool): Whether to generate the thumbnail prompt.
story_piece (StoryPiece): The story piece to populate with the generated element.
prev_piece (StoryPiece): The previous piece of the story.

Returns:
Any: The generated element.
@@ -146,11 +152,54 @@
story_genres=self.genres,
story_writing_style=self.writing_style,
should_gen_thumbnail_prompt=str(self.gen_thumbnail_prompt),
prev_piece=prev_piece.text if prev_piece else "",
)
story_piece.text = response.story_text
story_piece.thumbnail_gen_prompt = response.thumbnail_generation_prompt


class ThumbnailGenerator(dspy.Module):
"""ThumbnailGenerator is used to generate thumbnail images for story pieces."""

def __init__(self, config: dict[str, any]):
self.logger = Logger(__name__)
self.config = config
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
self.root_save_folder = config.get("root_save_folder", "outputs")
self.save_folder = config.get("thumbnail_save_folder", "thumbnails")

def forward(self, story_piece: StoryPiece) -> None:
"""Generate the thumbnail for the given story piece."""
self.generate(story_piece)

def generate(self, story_piece: StoryPiece) -> None:
"""Generate the thumbnail based on the prompt in the story piece."""

genres = self.config.get("genres", "")
enhanced_prompt = f"""
Generate a thumbnail image for the story piece. The story piece is part of a whole story.
The genres of the story are: {genres}.
The description of the thumbnail: {story_piece.thumbnail_gen_prompt}.
"""

response = self.client.images.generate(
model="dall-e-2",
prompt=enhanced_prompt,
size="256x256",
quality="standard",
n=1,
)

image_url = response.data[0].url
thumbnail_path = os.path.join(self.save_folder, f"thumbnail_{story_piece.idx}.png")
save_path = os.path.join(self.root_save_folder, thumbnail_path)

downloaded_path = download_and_save_image(image_url, save_path)
story_piece.image_uri = thumbnail_path

self.logger.info(f"Generated thumbnail for story piece {story_piece.idx}, saved at: {downloaded_path}")


class StoryGenerator(dspy.Module):
"""StoryGenerator is the main interface for the user to interact with the system.
It generates the entire story based on the prompt and the previous version of the story.
@@ -203,19 +252,20 @@ class OneStepStoryGenerator(StoryGenerator):
def __init__(self, config_path: str):
self.logger = Logger(__name__)
self.config_loader = ConfigLoader(config_path)
config = self.config_loader.get_config()
self.config = self.config_loader.get_config()
text_model_config = self.config_loader.get_text_model_config()

super().__init__(
genres=config.get("genres", ""),
writing_style=config.get("writing_style", ""),
gen_thumbnail_prompt=config.get("gen_thumbnail", False),
genres=self.config.get("genres", ""),
writing_style=self.config.get("writing_style", ""),
gen_thumbnail_prompt=self.config.get("gen_thumbnail", False),
provider=text_model_config.get("provider"),
lm_name=text_model_config.get("lm_name"),
)

self.story_sketch_generator = StorySketchGenerator(config=config, text_model_config=text_model_config)
self.text_generator = TextGenerator(config=config, text_model_config=text_model_config)
self.story_sketch_generator = StorySketchGenerator(config=self.config, text_model_config=text_model_config)
self.text_generator = TextGenerator(config=self.config, text_model_config=text_model_config)
self.thumbnail_generator = ThumbnailGenerator(config=self.config)

def generate(self, prev_version: Story = None, **kwargs) -> Story:
config = self.config_loader.get_config()
@@ -225,18 +275,29 @@ def generate(self, prev_version: Story = None, **kwargs) -> Story:
self.logger.tool_log(f"Generating a {num_sections} pieces story with prompt: {prompt}...")

# Generate story sketch
self.logger.info("Generating story sketch...")
self.logger.tool_log("Generating story sketch...")
sketch_response = self.story_sketch_generator(simple_prompt=prompt, num_sections=num_sections)

story_worldview = sketch_response.story_worldview
prompts = json.loads(sketch_response.prompts)["prompts"]
try:
prompts = json.loads(sketch_response.prompts)["prompts"]
except json.JSONDecodeError:
raise ValueError(f"Failed to decode the prompts: {sketch_response.prompts}")

# Generate the story pieces
story_pieces = []
for i, prompt in enumerate(prompts):
self.logger.info(f"Generating story piece {i+1} with prompt: {prompt}...")
story_piece = StoryPiece(idx=i)
self.text_generator(prompt=prompt, story_worldview=story_worldview, story_piece=story_piece)
self.text_generator(
prompt=prompt,
story_worldview=story_worldview,
story_piece=story_piece,
prev_piece=story_pieces[i - 1] if i > 0 else None,
)
# Generate thumbnail if enabled
if self.gen_thumbnail_prompt:
self.thumbnail_generator(story_piece)
story_pieces.append(story_piece)

return Story(pieces=story_pieces, genres=self.genres)
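The `json.loads` guard above expects the sketch step to return a blob shaped like the `prompts` field described in `SketchGenerateSignature`. A hand-written example (not real model output) of what successful parsing looks like:

```python
import json

# Hypothetical sketch output; the real content comes from the language model.
sketch_prompts = '{"prompts": ["The robot learns to wash its hands", "The robot tries vegetables", "Everyone joins the healthy-habits parade"]}'
prompts = json.loads(sketch_prompts)["prompts"]
assert len(prompts) == 3  # should match num_sections in story_gen_config.yaml
```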
mind_renderer/utils/custom_lm.py
@@ -3,6 +3,7 @@

import json
import logging
import os
from typing import Any, Literal, Optional

import backoff
@@ -12,6 +13,19 @@
from mind_renderer.utils.logger import Logger


def init_lm(text_model_config: dict[str, str]) -> dspy.LM:
"""Initialize the language model based on the configuration."""
provider = text_model_config["provider"]
lm_name = text_model_config["lm_name"]
section_length = text_model_config.get("section_length", 1000)
if provider == "DeepSeek":
api_key = os.getenv("DEEPSEEK_API_KEY")
return DeepSeek(base_url="https://api.deepseek.com", model=lm_name, api_key=api_key, max_tokens=section_length)
elif provider == "GROQ":
api_key = os.getenv("GROQ_API_KEY")
return dspy.__dict__[provider](model=lm_name, max_tokens=section_length, api_key=api_key)
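A minimal sketch of how `init_lm` might be called; the config keys come from the function above, while the concrete model name is a placeholder.

```python
text_model_config = {
    "provider": "DeepSeek",        # or "GROQ"
    "lm_name": "deepseek-chat",    # placeholder model name
    "section_length": 1000,
}
# Expects DEEPSEEK_API_KEY (or GROQ_API_KEY) in the environment.
# Providers other than DeepSeek/GROQ currently fall through and return None.
lm = init_lm(text_model_config)
```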


class DeepSeek(dspy.LM):
"""Wrapper around DeepSeek's API.

19 changes: 19 additions & 0 deletions mind_renderer/utils/image_processing.py
@@ -0,0 +1,19 @@
import os

import requests


def download_and_save_image(url: str, save_path: str) -> str:
"""Download an image from a URL and save it to the specified folder path."""
try:
response = requests.get(url)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
file.write(chunk)
except requests.exceptions.RequestException as e:
print(f"Error occurred while downloading image: {e}")
return None

return save_path
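A tiny usage sketch (the URL is a placeholder). The helper returns `None` on request errors, which `ThumbnailGenerator` currently only logs rather than checks:

```python
path = download_and_save_image("https://example.com/thumb.png", "outputs/thumbnails/thumbnail_0.png")
if path is None:
    print("download failed")
```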
6 changes: 5 additions & 1 deletion story_gen_config.yaml
@@ -1,8 +1,12 @@
story: >
  Write a children's story that positively encourages kids to practice good hygiene and to enjoy eating vegetables.

# Save folder
root_save_folder: "outputs"
thumbnail_save_folder: "thumbnails"

# The number of sections or episodes in the story
num_sections: 2
num_sections: 3
section_length: 1000

# Comma-separated list of genres for the story
@@ -7,7 +7,7 @@
import pytest
import requests

from mind_renderer.utils.deepseek_lm import DeepSeek
from mind_renderer.utils.custom_lm import DeepSeek


class TestDeepSeek: