diff --git a/.env.template b/.env.template new file mode 100644 index 0000000..e31322c --- /dev/null +++ b/.env.template @@ -0,0 +1,8 @@ +# Settings +SKETCH_INSPECT_LENGTH=0 # For debugging, set to 1 to see the generated code + +# LLM +OPENAI_API_KEY=YOUR_OPENAI_API +GROQ_API_KEY=YOUR_GROQ_API +ANTHROPIC_API_KEY=YOUR_ANTHROPIC_API +DEEPSEEK_API_KEY=YOUR_DEEPSEEK_API \ No newline at end of file diff --git a/mind_renderer/cli.py b/mind_renderer/cli.py index 02bd5d9..36049cb 100644 --- a/mind_renderer/cli.py +++ b/mind_renderer/cli.py @@ -5,7 +5,6 @@ """ import argparse -import os from mind_renderer.core.story_generators import OneStepStoryGenerator from mind_renderer.utils.logger import Logger @@ -23,9 +22,7 @@ def main(): story = story_generator.generate() logger.info("Generated story:") - save_folder = story_generator.config.get("output_path", "outputs") - markdown_path = os.path.join(save_folder, "story.md") - story.to_markdown(markdown_path) + story.to_markdown() if __name__ == "__main__": diff --git a/mind_renderer/core/entities.py b/mind_renderer/core/entities.py index 115cb28..86283a7 100644 --- a/mind_renderer/core/entities.py +++ b/mind_renderer/core/entities.py @@ -1,6 +1,8 @@ """Defines the entities of a story. """ +import os +import time from typing import List @@ -42,19 +44,35 @@ def __str__(self): class Story: """A story is a collection of story pieces.""" - def __init__(self, genres: str, visual_style: str = None, audio_style: str = None, pieces: List[StoryPiece] = None): + def __init__( + self, + config: dict[str, any], + title: str, + genres: str, + visual_style: str = None, + audio_style: str = None, + pieces: List[StoryPiece] = None, + ): """ Args: + title (str): The title of the story. genres (str): The genres of the story. visual_style (str, optional): The visual style of the story, e.g. cartoon, simple color. Defaults to None. audio_style (str, optional): The audio style of the story, story telling. Defaults to None. 
pieces (List[StoryPiece], optional): The pieces of the story. Defaults to None. """ + # Created timestamp + self.created_timestamp = str(int(time.time())) + self.title = title self.genres = genres self.visual_style = visual_style self.audio_style = audio_style - self.pieces = pieces + self.pieces = pieces or [] + story_folder_name = os.path.join(f"{self.created_timestamp}-{self.title}") + self.story_folder_path = os.path.join(config.get("output_path", "outputs"), story_folder_name) + # Create the story folder + os.makedirs(self.story_folder_path, exist_ok=True) def add_piece(self, piece: StoryPiece): self.pieces.append(piece) @@ -62,9 +80,11 @@ def add_piece(self, piece: StoryPiece): def __str__(self): return "\n".join([f"Episode {idx}:\n{str(piece)}\n" for idx, piece in enumerate(self.pieces)]) - def to_markdown(self, file_path: str): + def to_markdown(self): + file_path = os.path.join(self.story_folder_path, "story.md") with open(file_path, "w") as file: - file.write(f"# {self.genres}\n") + file.write(f"# {self.title}\n") + file.write(f"### {self.genres}\n") for idx, piece in enumerate(self.pieces): file.write(f"## Episode {idx}\n") file.write(f"{piece.text}\n") diff --git a/mind_renderer/core/story_generators.py b/mind_renderer/core/story_generators.py index f6be231..59ee808 100644 --- a/mind_renderer/core/story_generators.py +++ b/mind_renderer/core/story_generators.py @@ -7,6 +7,7 @@ from abc import abstractmethod import dspy +import retrying from dotenv import load_dotenv from openai import OpenAI @@ -23,14 +24,15 @@ class StorySketchGenerator(dspy.Module): instruction: str = """ You are the story sketching assistant. Please follow the below instruction: + 0. Generate a sketch for a multi-section story according to the specified number of sections. 1. Generate the sketch in the SAME language as the user input, e.g. generate Chinese sketch if user input is Chinese. 2. Generate the story_worldview and the detailed piece-level prompts based on the simple prompt. 
The story_worldview is the high-level description of the story, used to make each piece consistent. Each of the detailed piece-level prompts will be used to generate a piece of the story. - 3. Each episode should be coherent and complementary. + 3. Each section should be coherent and complementary to the others. 4. The story in combination should be complete. So the story pieces can be logically connected. - 5. Generate the prompt in the same language as the input prompt. + 5. Generate the prompt per section in the same language as the input prompt. """ class SketchGenerateSignature(dspy.Signature): @@ -40,16 +42,19 @@ class SketchGenerateSignature(dspy.Signature): simple_prompt: str = dspy.InputField(desc="The user instruction on how the story would look like.") num_sections: str = dspy.InputField( desc=""" - The number of episodes the story should have. Each episode describe one scene. - Consequtive episodes should be coherent and complementary. + The number of sections the story should have. Each section describes one scene. + Consecutive sections should be coherent and complementary. """ ) + story_title: str = dspy.OutputField( + desc="The title of the story. Should be short and catchy. In the same language as the input prompt." + ) story_worldview: str = dspy.OutputField(desc="The world view of the story. Used to ensure consistency.") prompts: str = dspy.OutputField( desc=""" The blob that contains a list of prompts for each piece of the story, they should be coherent, e.g. { - "prompts": ["prompt1", "prompt2", "prompt3"] + "prompts": ["prompt1", "prompt2", ...] } """ ) @@ -95,12 +100,12 @@ class TextGenerateSignature(dspy.Signature): ) prev_piece: str = dspy.InputField(desc="The previous piece of the story. For continuity.") + story_text: str = dspy.OutputField( + desc="The actual story content. Use the same language as the story description." + ) # The output text of the story. 
thumbnail_generation_prompt: str = dspy.OutputField( desc="The cohensive prompt for generating the thumbnail image in English only. Use descriptive keywords." ) - story_text: str = dspy.OutputField( - desc="The actual story section. Use the same language as the story description." - ) # The output text of the story. def __init__(self, config: dict[str, any], text_model_config: dict[str, str]): self.logger = Logger(__name__) @@ -123,6 +128,10 @@ def forward( **kwargs, ) + @retrying.retry( + wait_fixed=5000, + stop_max_attempt_number=3, + ) def generate( self, story_description: str, @@ -131,31 +140,30 @@ def generate( prev_piece: StoryPiece = None, **kwargs, ) -> None: - """Generate the element based on the story_description and populate the story piece with the generated element. - - Args: - story_description (str): The description of the story at the high level. - story_worldview (str): The story_worldview of the story. - gen_thumbnail_prompt (bool): Whether to generate the thumbnail prompt. - story_piece (StoryPiece): The story piece to populate with the generated element. - prev_piece (StoryPiece): The previous piece of the story. - - Returns: - Any: The generated element. 
- """ - story_gen = dspy.ChainOfThought(self.TextGenerateSignature) - with dspy.context(lm=self.lm): - response = story_gen( - instruction=self.instructions, - story_description=story_description, - story_worldview=story_worldview, - story_genres=self.genres, - story_writing_style=self.writing_style, - should_gen_thumbnail_prompt=str(self.gen_thumbnail_prompt), - prev_piece=prev_piece.text if prev_piece else "", - ) - story_piece.text = response.story_text - story_piece.thumbnail_gen_prompt = response.thumbnail_generation_prompt + try: + self.logger.info("Attempting to generate story piece...") + story_gen = dspy.ChainOfThoughtWithHint(self.TextGenerateSignature) + + with dspy.context(lm=self.lm): + response = story_gen( + instruction=self.instructions, + story_description=story_description, + story_worldview=story_worldview, + story_genres=self.genres, + story_writing_style=self.writing_style, + should_gen_thumbnail_prompt=str(self.gen_thumbnail_prompt), + prev_piece=prev_piece.text if prev_piece else "", + ) + story_piece.text = response.story_text + if not story_piece.text: + self.logger.error("Failed to generate the text for the story piece.") + raise ValueError(f"Failed to generate the text for the story piece, {response}") + else: + story_piece.thumbnail_gen_prompt = response.thumbnail_generation_prompt + self.logger.info("Successfully generated story piece.") + except Exception as e: + self.logger.error(f"An error occurred: {str(e)}") + raise class ThumbnailGenerator(dspy.Module): @@ -168,11 +176,11 @@ def __init__(self, config: dict[str, any]): self.root_save_folder = config.get("root_save_folder", "outputs") self.save_folder = config.get("thumbnail_save_folder", "thumbnails") - def forward(self, story_piece: StoryPiece) -> None: + def forward(self, story_piece: StoryPiece, story: Story) -> None: """Generate the thumbnail for the given story piece.""" - self.generate(story_piece) + self.generate(story_piece=story_piece, story=story) - def generate(self, 
story_piece: StoryPiece) -> None: + def generate(self, story_piece: StoryPiece, story: Story) -> None: """Generate the thumbnail based on the prompt in the story piece.""" genres = self.config.get("genres", "") @@ -185,17 +193,17 @@ def generate(self, story_piece: StoryPiece) -> None: response = self.client.images.generate( model="dall-e-2", prompt=enhanced_prompt, - size="256x256", + size="512x512", quality="standard", n=1, ) image_url = response.data[0].url - thumbnail_path = os.path.join(self.save_folder, f"thumbnail_{story_piece.idx}.png") - save_path = os.path.join(self.root_save_folder, thumbnail_path) + thumbnail_name = f"thumbnail_{story_piece.idx}.png" + save_path = os.path.join(story.story_folder_path, thumbnail_name) downloaded_path = download_and_save_image(image_url, save_path) - story_piece.image_uri = thumbnail_path + story_piece.image_uri = thumbnail_name self.logger.info(f"Generated thumbnail for story piece {story_piece.idx}, saved at: {downloaded_path}") @@ -278,6 +286,7 @@ def generate(self, prev_version: Story = None, **kwargs) -> Story: self.logger.tool_log("Generating story sketch...") sketch_response = self.story_sketch_generator(simple_prompt=prompt, num_sections=num_sections) + story_title = sketch_response.story_title.replace("Story Title:", "").strip() story_worldview = sketch_response.story_worldview try: prompts = json.loads(sketch_response.prompts)["prompts"] @@ -285,7 +294,8 @@ def generate(self, prev_version: Story = None, **kwargs) -> Story: raise ValueError(f"Failed to decode the prompts: {sketch_response.prompts}") # Generate the story pieces - story_pieces = [] + story = Story(config=self.config, title=story_title, genres=self.genres) + prev_story_piece = None for i, prompt in enumerate(prompts): self.logger.info(f"Generating story piece {i+1} with prompt: {prompt}...") story_piece = StoryPiece(idx=i) @@ -293,14 +303,14 @@ def generate(self, prev_version: Story = None, **kwargs) -> Story: prompt=prompt, 
story_worldview=story_worldview, story_piece=story_piece, - prev_piece=story_pieces[i - 1] if i > 0 else None, + prev_piece=prev_story_piece, ) # Generate thumbnail if enabled if self.gen_thumbnail_prompt: - self.thumbnail_generator(story_piece) - story_pieces.append(story_piece) + self.thumbnail_generator(story_piece=story_piece, story=story) + story.add_piece(story_piece) - return Story(pieces=story_pieces, genres=self.genres) + return story def _stop_condition_met(self) -> bool: return True diff --git a/mind_renderer/utils/custom_lm.py b/mind_renderer/utils/custom_lm.py index 8cba2e2..7ecf902 100644 --- a/mind_renderer/utils/custom_lm.py +++ b/mind_renderer/utils/custom_lm.py @@ -23,6 +23,8 @@ def init_lm(text_model_config: dict[str, str]) -> dspy.LM: return DeepSeek(base_url="https://api.deepseek.com", model=lm_name, api_key=api_key, max_tokens=section_length) elif provider == "GROQ": api_key = os.getenv("GROQ_API_KEY") + elif provider == "OpenAI": + api_key = os.getenv("OPENAI_API_KEY") return dspy.__dict__[provider](model=lm_name, max_tokens=section_length, api_key=api_key) diff --git a/mind_renderer/utils/image_processing.py b/mind_renderer/utils/image_processing.py index 3325b6d..4844a30 100644 --- a/mind_renderer/utils/image_processing.py +++ b/mind_renderer/utils/image_processing.py @@ -1,8 +1,10 @@ import os import requests +import retrying +@retrying.retry(wait_fixed=5000, stop_max_attempt_number=3) def download_and_save_image(url: str, save_path: str) -> str: """Download an image from a URL and save it to the specified folder path.""" try: diff --git a/poetry.lock b/poetry.lock index 5d7a9b9..3863e0c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2257,6 +2257,20 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "retrying" +version = "1.3.4" +description = "Retrying" +optional = false +python-versions = "*" +files = [ + {file = "retrying-1.3.4-py3-none-any.whl", hash = 
"sha256:8cc4d43cb8e1125e0ff3344e9de678fefd85db3b750b81b2240dc0183af37b35"}, + {file = "retrying-1.3.4.tar.gz", hash = "sha256:345da8c5765bd982b1d1915deb9102fd3d1f7ad16bd84a9700b85f64d24e8f3e"}, +] + +[package.dependencies] +six = ">=1.7.0" + [[package]] name = "setuptools" version = "71.0.0" @@ -2846,4 +2860,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "77efa2615db711c868f824dc49ada4a4578fed42ec611345d750923bae04dcd4" +content-hash = "155c528a020365e68320086394c5497182f09c341ca639d7ece198e2821b21f0" diff --git a/pyproject.toml b/pyproject.toml index bb69bf5..70493bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ notion-client = "^2.2.1" dspy-ai = "^2.4.12" load-dotenv = "^0.1.0" groq = "^0.9.0" +retrying = "^1.3.4" [tool.poetry.group.dev.dependencies] black = "^24.4.2" diff --git a/story_gen_config.yaml b/story_gen_config.yaml index 55fecf1..3b403fd 100644 --- a/story_gen_config.yaml +++ b/story_gen_config.yaml @@ -1,12 +1,8 @@ story: > 写一个儿童故事,用于正面影响小朋友注意卫生。并且喜欢吃蔬菜。 -# Save folder -root_save_folder: "outputs" -thumbnail_save_folder: "thumbnails" - # The number of sections or episodes in the story -num_sections: 3 +num_sections: 4 section_length: 1000 # Comma-separated list of genres for the story @@ -21,12 +17,16 @@ gen_thumbnail: true text_models: # The AI provider to use for generating the story, supported: DeepSeek, OpenAI, GROQ, Claude - provider: "DeepSeek" + provider: "OpenAI" # The specific language model to use, supported: # OpenAI: gpt-3.5-turbo-instruct, gpt-4o # DeepSeek: deep-seek-chat # GROQ: llama3-8b-8192, llama3-70b-8192, gemma2-9b-it, whisper-large-v3, mixtral-8x7b-32768 # Claude: claude-3-5-sonnet-20240620, claude-3-opus-20240229, claude-3-haiku-20240307 - lm_name: "deep-seek-chat" + lm_name: "gpt-4o-mini" # lm_name: "mixtral-8x7b-32768" + + +# Save folder, don't change until you know what you are doing +root_save_folder: "outputs" \ No newline at end of file