Skip to content

Commit

Permalink
Merge pull request #5 from small-thinking/add-title-and-file-name
Browse files Browse the repository at this point in the history
Add thumbnail generation
  • Loading branch information
yxjiang authored Jul 19, 2024
2 parents 1de2594 + b289018 commit fa7cfef
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 61 deletions.
8 changes: 8 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Settings
SKETCH_INSPECT_LENGTH=0 # For debugging, set to 1 to see the generated code

# LLM
OPENAI_API_KEY=YOUR_OPENAI_API
GROQ_API_KEY=YOUR_GROQ_API
ANTHROPIC_API_KEY=YOUR_ANTHROPIC_API
DEEPSEEK_API_KEY=YOUR_DEEPSEEK_API
5 changes: 1 addition & 4 deletions mind_renderer/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""

import argparse
import os

from mind_renderer.core.story_generators import OneStepStoryGenerator
from mind_renderer.utils.logger import Logger
Expand All @@ -23,9 +22,7 @@ def main():
story = story_generator.generate()

logger.info("Generated story:")
save_folder = story_generator.config.get("output_path", "outputs")
markdown_path = os.path.join(save_folder, "story.md")
story.to_markdown(markdown_path)
story.to_markdown()


if __name__ == "__main__":
Expand Down
28 changes: 24 additions & 4 deletions mind_renderer/core/entities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Defines the entities of a story.
"""

import os
import time
from typing import List


Expand Down Expand Up @@ -42,29 +44,47 @@ def __str__(self):
class Story:
"""A story is a collection of story pieces."""

def __init__(
    self,
    config: dict[str, any],
    title: str,
    genres: str,
    visual_style: str = None,
    audio_style: str = None,
    pieces: List[StoryPiece] = None,
):
    """
    Args:
        config (dict): Story-level configuration. Only "output_path" is read
            here (defaults to "outputs"); other keys are ignored by this class.
        title (str): The title of the story. Used verbatim in the on-disk
            folder name.  # NOTE(review): not sanitized — confirm upstream
            titles never contain path separators.
        genres (str): The genres of the story.
        visual_style (str, optional): The visual style of the story, e.g. cartoon, simple color. Defaults to None.
        audio_style (str, optional): The audio style of the story, e.g. story telling. Defaults to None.
        pieces (List[StoryPiece], optional): The pieces of the story. Defaults to an empty list.
    """
    # Creation timestamp (seconds since epoch) keeps each story folder unique.
    self.created_timestamp = str(int(time.time()))
    self.title = title
    self.genres = genres
    self.visual_style = visual_style
    self.audio_style = audio_style
    # Guard against a shared mutable default: each instance gets a fresh list.
    self.pieces = pieces or []
    # os.path.join with a single argument is a no-op; build the name directly.
    story_folder_name = f"{self.created_timestamp}-{self.title}"
    self.story_folder_path = os.path.join(config.get("output_path", "outputs"), story_folder_name)
    # Create the folder eagerly so later writes (markdown, thumbnails) cannot
    # fail on a missing directory.
    os.makedirs(self.story_folder_path, exist_ok=True)

def add_piece(self, piece: StoryPiece):
    """Append *piece* to this story's ordered list of pieces."""
    self.pieces.append(piece)

def __str__(self):
    """Render every piece as a numbered "Episode" section, one per piece."""
    rendered = []
    for idx, piece in enumerate(self.pieces):
        rendered.append(f"Episode {idx}:\n{str(piece)}\n")
    return "\n".join(rendered)

def to_markdown(self, file_path: str):
def to_markdown(self):
file_path = os.path.join(self.story_folder_path, "story.md")
with open(file_path, "w") as file:
file.write(f"# {self.genres}\n")
file.write(f"# {self.title}\n")
file.write(f"### {self.genres}\n")
for idx, piece in enumerate(self.pieces):
file.write(f"## Episode {idx}\n")
file.write(f"{piece.text}\n")
Expand Down
100 changes: 55 additions & 45 deletions mind_renderer/core/story_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from abc import abstractmethod

import dspy
import retrying
from dotenv import load_dotenv
from openai import OpenAI

Expand All @@ -23,14 +24,15 @@ class StorySketchGenerator(dspy.Module):
instruction: str = """
You are the story sketching assistant. Please follow the below instruction:
0. Generate a sketch for a multi-section story according to the specified number of sections.
1. Generate the sketch in the SAME language as the user input,
e.g. generate Chinese sketch if user input is Chinese.
2. Generate the story_worldview and the detailed piece-level prompts based on the simple prompt.
The story_worldview is the high-level description of the story, used to make each piece consistent.
Each of the detailed piece-level prompts will be used to generate a piece of the story.
3. Each episode should be coherent and complementary.
3. Each section should be coherent and complementary of the others.
4. The story in combination should be complete. So the story pieces can be logically connected.
5. Generate the prompt in the same language as the input prompt.
5. Generate the prompt per section in the same language as the input prompt.
"""

class SketchGenerateSignature(dspy.Signature):
Expand All @@ -40,16 +42,19 @@ class SketchGenerateSignature(dspy.Signature):
simple_prompt: str = dspy.InputField(desc="The user instruction on how the story would look like.")
num_sections: str = dspy.InputField(
desc="""
The number of episodes the story should have. Each episode describe one scene.
Consequtive episodes should be coherent and complementary.
The number of sections the story should have. Each section describe one scene.
Consequtive sections should be coherent and complementary.
"""
)
story_title: str = dspy.OutputField(
desc="The title of the story. Should be short and catchy. In the same language as the input prompt."
)
story_worldview: str = dspy.OutputField(desc="The world view of the story. Used to ensure consistency.")
prompts: str = dspy.OutputField(
desc="""
The blob that contains a list of prompts for each piece of the story, they should be coherent, e.g.
{
"prompts": ["prompt1", "prompt2", "prompt3"]
"prompts": ["prompt1", "prompt2", ...]
}
"""
)
Expand Down Expand Up @@ -95,12 +100,12 @@ class TextGenerateSignature(dspy.Signature):
)
prev_piece: str = dspy.InputField(desc="The previous piece of the story. For continuity.")

story_text: str = dspy.OutputField(
desc="The actual story content. Use the same language as the story description."
) # The output text of the story.
thumbnail_generation_prompt: str = dspy.OutputField(
desc="The cohensive prompt for generating the thumbnail image in English only. Use descriptive keywords."
)
story_text: str = dspy.OutputField(
desc="The actual story section. Use the same language as the story description."
) # The output text of the story.

def __init__(self, config: dict[str, any], text_model_config: dict[str, str]):
self.logger = Logger(__name__)
Expand All @@ -123,6 +128,10 @@ def forward(
**kwargs,
)

@retrying.retry(
wait_fixed=5000,
stop_max_attempt_number=3,
)
def generate(
self,
story_description: str,
Expand All @@ -131,31 +140,30 @@ def generate(
prev_piece: StoryPiece = None,
**kwargs,
) -> None:
"""Generate the element based on the story_description and populate the story piece with the generated element.
Args:
story_description (str): The description of the story at the high level.
story_worldview (str): The story_worldview of the story.
gen_thumbnail_prompt (bool): Whether to generate the thumbnail prompt.
story_piece (StoryPiece): The story piece to populate with the generated element.
prev_piece (StoryPiece): The previous piece of the story.
Returns:
Any: The generated element.
"""
story_gen = dspy.ChainOfThought(self.TextGenerateSignature)
with dspy.context(lm=self.lm):
response = story_gen(
instruction=self.instructions,
story_description=story_description,
story_worldview=story_worldview,
story_genres=self.genres,
story_writing_style=self.writing_style,
should_gen_thumbnail_prompt=str(self.gen_thumbnail_prompt),
prev_piece=prev_piece.text if prev_piece else "",
)
story_piece.text = response.story_text
story_piece.thumbnail_gen_prompt = response.thumbnail_generation_prompt
try:
self.logger.info("Attempting to generate story piece...")
story_gen = dspy.ChainOfThoughtWithHint(self.TextGenerateSignature)

with dspy.context(lm=self.lm):
response = story_gen(
instruction=self.instructions,
story_description=story_description,
story_worldview=story_worldview,
story_genres=self.genres,
story_writing_style=self.writing_style,
should_gen_thumbnail_prompt=str(self.gen_thumbnail_prompt),
prev_piece=prev_piece.text if prev_piece else "",
)
story_piece.text = response.story_text
if not story_piece.text:
self.logger.error("Failed to generate the text for the story piece.")
raise ValueError(f"Failed to generate the text for the story piece, {response}")
else:
story_piece.thumbnail_gen_prompt = response.thumbnail_generation_prompt
self.logger.info("Successfully generated story piece.")
except Exception as e:
self.logger.error(f"An error occurred: {str(e)}")
raise


class ThumbnailGenerator(dspy.Module):
Expand All @@ -168,11 +176,11 @@ def __init__(self, config: dict[str, any]):
self.root_save_folder = config.get("root_save_folder", "outputs")
self.save_folder = config.get("thumbnail_save_folder", "thumbnails")

def forward(self, story_piece: StoryPiece, story: Story) -> None:
    """Generate the thumbnail for the given story piece.

    Args:
        story_piece (StoryPiece): The piece whose thumbnail prompt is used.
        story (Story): The story whose folder receives the saved image.
    """
    # Thin dspy.Module entry point: delegate straight to generate().
    self.generate(story_piece=story_piece, story=story)

def generate(self, story_piece: StoryPiece) -> None:
def generate(self, story_piece: StoryPiece, story: Story) -> None:
"""Generate the thumbnail based on the prompt in the story piece."""

genres = self.config.get("genres", "")
Expand All @@ -185,17 +193,17 @@ def generate(self, story_piece: StoryPiece) -> None:
response = self.client.images.generate(
model="dall-e-2",
prompt=enhanced_prompt,
size="256x256",
size="512x512",
quality="standard",
n=1,
)

image_url = response.data[0].url
thumbnail_path = os.path.join(self.save_folder, f"thumbnail_{story_piece.idx}.png")
save_path = os.path.join(self.root_save_folder, thumbnail_path)
thumbnail_name = f"thumbnail_{story_piece.idx}.png"
save_path = os.path.join(story.story_folder_path, thumbnail_name)

downloaded_path = download_and_save_image(image_url, save_path)
story_piece.image_uri = thumbnail_path
story_piece.image_uri = thumbnail_name

self.logger.info(f"Generated thumbnail for story piece {story_piece.idx}, saved at: {downloaded_path}")

Expand Down Expand Up @@ -278,29 +286,31 @@ def generate(self, prev_version: Story = None, **kwargs) -> Story:
self.logger.tool_log("Generating story sketch...")
sketch_response = self.story_sketch_generator(simple_prompt=prompt, num_sections=num_sections)

story_title = sketch_response.story_title.replace("Story Title:", "").strip()
story_worldview = sketch_response.story_worldview
try:
prompts = json.loads(sketch_response.prompts)["prompts"]
except json.JSONDecodeError:
raise ValueError(f"Failed to decode the prompts: {sketch_response.prompts}")

# Generate the story pieces
story_pieces = []
story = Story(config=self.config, title=story_title, genres=self.genres)
prev_story_piece = None
for i, prompt in enumerate(prompts):
self.logger.info(f"Generating story piece {i+1} with prompt: {prompt}...")
story_piece = StoryPiece(idx=i)
self.text_generator(
prompt=prompt,
story_worldview=story_worldview,
story_piece=story_piece,
prev_piece=story_pieces[i - 1] if i > 0 else None,
prev_piece=prev_story_piece,
)
# Generate thumbnail if enabled
if self.gen_thumbnail_prompt:
self.thumbnail_generator(story_piece)
story_pieces.append(story_piece)
self.thumbnail_generator(story_piece=story_piece, story=story)
story.add_piece(story_piece)

return Story(pieces=story_pieces, genres=self.genres)
return story

def _stop_condition_met(self) -> bool:
    """Return True, i.e. generation always stops after a single pass.

    NOTE(review): stub — presumably a hook for an iterative-refinement loop;
    confirm callers before extending.
    """
    return True
2 changes: 2 additions & 0 deletions mind_renderer/utils/custom_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def init_lm(text_model_config: dict[str, str]) -> dspy.LM:
return DeepSeek(base_url="https://api.deepseek.com", model=lm_name, api_key=api_key, max_tokens=section_length)
elif provider == "GROQ":
api_key = os.getenv("GROQ_API_KEY")
elif provider == "OpenAI":
api_key = os.getenv("OPENAI_API_KEY")
return dspy.__dict__[provider](model=lm_name, max_tokens=section_length, api_key=api_key)


Expand Down
2 changes: 2 additions & 0 deletions mind_renderer/utils/image_processing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os

import requests
import retrying


@retrying.retry(wait_fixed=5000, stop_max_attempt_number=3)
def download_and_save_image(url: str, save_path: str) -> str:
"""Download an image from a URL and save it to the specified folder path."""
try:
Expand Down
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ notion-client = "^2.2.1"
dspy-ai = "^2.4.12"
load-dotenv = "^0.1.0"
groq = "^0.9.0"
retrying = "^1.3.4"

[tool.poetry.group.dev.dependencies]
black = "^24.4.2"
Expand Down
14 changes: 7 additions & 7 deletions story_gen_config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
story: >
写一个儿童故事,用于正面影响小朋友注意卫生。并且喜欢吃蔬菜。
# Save folder
root_save_folder: "outputs"
thumbnail_save_folder: "thumbnails"

# The number of sections or episodes in the story
num_sections: 3
num_sections: 4
section_length: 1000

# Comma-separated list of genres for the story
Expand All @@ -21,12 +17,16 @@ gen_thumbnail: true
text_models:

# The AI provider to use for generating the story, supported: DeepSeek, OpenAI, GROQ, Claude
provider: "DeepSeek"
provider: "OpenAI"

# The specific language model to use, supported:
# OpenAI: gpt-3.5-turbo-instruct, gpt-4o
# DeepSeek: deep-seek-chat
# GROQ: llama3-8b-8192, llama3-70b-8192, gemma2-9b-it, whisper-large-v3, mixtral-8x7b-32768
# Claude: claude-3-5-sonnet-20240620, claude-3-opus-20240229, claude-3-haiku-20240307
lm_name: "deep-seek-chat"
lm_name: "gpt-4o-mini"
# lm_name: "mixtral-8x7b-32768"


# Save folder, don't change until you know what you are doing
root_save_folder: "outputs"

0 comments on commit fa7cfef

Please sign in to comment.