diff --git a/mind_renderer/cli.py b/mind_renderer/cli.py index 59d5e8b..02bd5d9 100644 --- a/mind_renderer/cli.py +++ b/mind_renderer/cli.py @@ -5,6 +5,7 @@ """ import argparse +import os from mind_renderer.core.story_generators import OneStepStoryGenerator from mind_renderer.utils.logger import Logger @@ -22,7 +23,9 @@ def main(): story = story_generator.generate() logger.info("Generated story:") - print(story) + save_folder = story_generator.config.get("output_path", "outputs") + markdown_path = os.path.join(save_folder, "story.md") + story.to_markdown(markdown_path) if __name__ == "__main__": diff --git a/mind_renderer/core/entities.py b/mind_renderer/core/entities.py index d78b4c5..115cb28 100644 --- a/mind_renderer/core/entities.py +++ b/mind_renderer/core/entities.py @@ -61,3 +61,18 @@ def add_piece(self, piece: StoryPiece): def __str__(self): return "\n".join([f"Episode {idx}:\n{str(piece)}\n" for idx, piece in enumerate(self.pieces)]) + + def to_markdown(self, file_path: str): + with open(file_path, "w") as file: + file.write(f"# {self.genres}\n") + for idx, piece in enumerate(self.pieces): + file.write(f"## Episode {idx}\n") + file.write(f"{piece.text}\n") + if piece.image_uri: + file.write(f"![image](./{piece.image_uri})\n") + if piece.audio_uri: + file.write(f"![audio](./{piece.audio_uri})\n") + if piece.video_uri: + file.write(f"![video](./{piece.video_uri})\n") + file.write("\n") + return file_path diff --git a/mind_renderer/core/story_generators.py b/mind_renderer/core/story_generators.py index 40fb76f..f6be231 100644 --- a/mind_renderer/core/story_generators.py +++ b/mind_renderer/core/story_generators.py @@ -8,26 +8,15 @@ import dspy from dotenv import load_dotenv +from openai import OpenAI from mind_renderer.core.entities import Story, StoryPiece from mind_renderer.utils.config_loader import ConfigLoader -from mind_renderer.utils.deepseek_lm import DeepSeek +from mind_renderer.utils.custom_lm import init_lm +from mind_renderer.utils.image_processing import download_and_save_image from mind_renderer.utils.logger import Logger -def init_lm(text_model_config: dict[str, str]) -> dspy.LM: - """Initialize the language model based on the configuration.""" - provider = text_model_config["provider"] - lm_name = text_model_config["lm_name"] - section_length = text_model_config.get("section_length", 1000) - if provider == "DeepSeek": - api_key = os.getenv("DEEPSEEK_API_KEY") - return DeepSeek(base_url="https://api.deepseek.com", model=lm_name, api_key=api_key, max_tokens=section_length) - elif provider == "GROQ": - api_key = os.getenv("GROQ_API_KEY") - return dspy.__dict__[provider](model=lm_name, max_tokens=section_length, api_key=api_key) - - class StorySketchGenerator(dspy.Module): """StorySketchGenerator is used to expand a simple prompt into detailed piece-level prompts.""" @@ -49,11 +38,16 @@ class SketchGenerateSignature(dspy.Signature): instruction: str = dspy.InputField(desc="The instruction for the generation.") simple_prompt: str = dspy.InputField(desc="The user instruction on how the story would look like.") - num_sections: str = dspy.InputField(desc="The number of episodes the story should have.") - story_worldview: str = dspy.OutputField(desc="The world view of the story.") + num_sections: str = dspy.InputField( + desc=""" + The number of episodes the story should have. Each episode describe one scene. + Consequtive episodes should be coherent and complementary. + """ + ) + story_worldview: str = dspy.OutputField(desc="The world view of the story. Used to ensure consistency.") prompts: str = dspy.OutputField( desc=""" - The json blob that contains a list of prompts for each piece of the story, they should be coherent, e.g. + The blob that contains a list of prompts for each piece of the story, they should be coherent, e.g. { "prompts": ["prompt1", "prompt2", "prompt3"] } @@ -99,6 +93,8 @@ class TextGenerateSignature(dspy.Signature): should_gen_thumbnail_prompt: bool = dspy.InputField( desc="Whether to generate the thumbnail image gneeration prompt." ) + prev_piece: str = dspy.InputField(desc="The previous piece of the story. For continuity.") + thumbnail_generation_prompt: str = dspy.OutputField( desc="The cohensive prompt for generating the thumbnail image in English only. Use descriptive keywords." ) @@ -115,15 +111,24 @@ def __init__(self, config: dict[str, any], text_model_config: dict[str, str]): self.writing_style = config["writing_style"] self.gen_thumbnail_prompt = config["gen_thumbnail"] - def forward(self, prompt: str, story_worldview: str, story_piece: StoryPiece, **kwargs) -> None: + def forward( + self, prompt: str, story_worldview: str, story_piece: StoryPiece, prev_piece: StoryPiece = None, **kwargs + ) -> None: """Generate the element based on the prompt and populate the story piece with the generated element.""" - self.generate(story_description=prompt, story_worldview=story_worldview, story_piece=story_piece, **kwargs) + self.generate( + story_description=prompt, + story_worldview=story_worldview, + story_piece=story_piece, + prev_piece=prev_piece, + **kwargs, + ) def generate( self, story_description: str, story_worldview: str, story_piece: StoryPiece, + prev_piece: StoryPiece = None, **kwargs, ) -> None: """Generate the element based on the story_description and populate the story piece with the generated element. @@ -133,6 +138,7 @@ def generate( story_worldview (str): The story_worldview of the story. gen_thumbnail_prompt (bool): Whether to generate the thumbnail prompt. story_piece (StoryPiece): The story piece to populate with the generated element. + prev_piece (StoryPiece): The previous piece of the story. Returns: Any: The generated element. @@ -146,11 +152,54 @@ def generate( story_genres=self.genres, story_writing_style=self.writing_style, should_gen_thumbnail_prompt=str(self.gen_thumbnail_prompt), + prev_piece=prev_piece.text if prev_piece else "", ) story_piece.text = response.story_text story_piece.thumbnail_gen_prompt = response.thumbnail_generation_prompt +class ThumbnailGenerator(dspy.Module): + """ThumbnailGenerator is used to generate thumbnail images for story pieces.""" + + def __init__(self, config: dict[str, any]): + self.logger = Logger(__name__) + self.config = config + self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + self.root_save_folder = config.get("root_save_folder", "outputs") + self.save_folder = config.get("thumbnail_save_folder", "thumbnails") + + def forward(self, story_piece: StoryPiece) -> None: + """Generate the thumbnail for the given story piece.""" + self.generate(story_piece) + + def generate(self, story_piece: StoryPiece) -> None: + """Generate the thumbnail based on the prompt in the story piece.""" + + genres = self.config.get("genres", "") + enhanced_prompt = f""" + Generate a thumbnail image for the story piece. The story piece is part of a whole story. + The genres of the story are: {genres}. + The description of the thumbnail: {story_piece.thumbnail_gen_prompt}. + """ + + response = self.client.images.generate( + model="dall-e-2", + prompt=enhanced_prompt, + size="256x256", + quality="standard", + n=1, + ) + + image_url = response.data[0].url + thumbnail_path = os.path.join(self.save_folder, f"thumbnail_{story_piece.idx}.png") + save_path = os.path.join(self.root_save_folder, thumbnail_path) + + downloaded_path = download_and_save_image(image_url, save_path) + story_piece.image_uri = thumbnail_path + + self.logger.info(f"Generated thumbnail for story piece {story_piece.idx}, saved at: {downloaded_path}") + + class StoryGenerator(dspy.Module): """StoryGenerator is the main interface for the user to interact with the system. It generates the entire story based on the prompt and the previous version of the story. @@ -203,19 +252,20 @@ class OneStepStoryGenerator(StoryGenerator): def __init__(self, config_path: str): self.logger = Logger(__name__) self.config_loader = ConfigLoader(config_path) - config = self.config_loader.get_config() + self.config = self.config_loader.get_config() text_model_config = self.config_loader.get_text_model_config() super().__init__( - genres=config.get("genres", ""), - writing_style=config.get("writing_style", ""), - gen_thumbnail_prompt=config.get("gen_thumbnail", False), + genres=self.config.get("genres", ""), + writing_style=self.config.get("writing_style", ""), + gen_thumbnail_prompt=self.config.get("gen_thumbnail", False), provider=text_model_config.get("provider"), lm_name=text_model_config.get("lm_name"), ) - self.story_sketch_generator = StorySketchGenerator(config=config, text_model_config=text_model_config) - self.text_generator = TextGenerator(config=config, text_model_config=text_model_config) + self.story_sketch_generator = StorySketchGenerator(config=self.config, text_model_config=text_model_config) + self.text_generator = TextGenerator(config=self.config, text_model_config=text_model_config) + self.thumbnail_generator = ThumbnailGenerator(config=self.config) def generate(self, prev_version: Story = None, **kwargs) -> Story: config = self.config_loader.get_config() @@ -225,18 +275,29 @@ def generate(self, prev_version: Story = None, **kwargs) -> Story: self.logger.tool_log(f"Generating a {num_sections} pieces story with prompt: {prompt}...") # Generate story sketch - self.logger.info("Generating story sketch...") + self.logger.tool_log("Generating story sketch...") sketch_response = self.story_sketch_generator(simple_prompt=prompt, num_sections=num_sections) story_worldview = sketch_response.story_worldview - prompts = json.loads(sketch_response.prompts)["prompts"] + try: + prompts = json.loads(sketch_response.prompts)["prompts"] + except json.JSONDecodeError: + raise ValueError(f"Failed to decode the prompts: {sketch_response.prompts}") # Generate the story pieces story_pieces = [] for i, prompt in enumerate(prompts): self.logger.info(f"Generating story piece {i+1} with prompt: {prompt}...") story_piece = StoryPiece(idx=i) - self.text_generator(prompt=prompt, story_worldview=story_worldview, story_piece=story_piece) + self.text_generator( + prompt=prompt, + story_worldview=story_worldview, + story_piece=story_piece, + prev_piece=story_pieces[i - 1] if i > 0 else None, + ) + # Generate thumbnail if enabled + if self.gen_thumbnail_prompt: + self.thumbnail_generator(story_piece) story_pieces.append(story_piece) return Story(pieces=story_pieces, genres=self.genres) diff --git a/mind_renderer/utils/deepseek_lm.py b/mind_renderer/utils/custom_lm.py similarity index 89% rename from mind_renderer/utils/deepseek_lm.py rename to mind_renderer/utils/custom_lm.py index 6ffd825..8cba2e2 100644 --- a/mind_renderer/utils/deepseek_lm.py +++ b/mind_renderer/utils/custom_lm.py @@ -3,6 +3,7 @@ import json import logging +import os from typing import Any, Literal, Optional import backoff @@ -12,6 +13,19 @@ from mind_renderer.utils.logger import Logger +def init_lm(text_model_config: dict[str, str]) -> dspy.LM: + """Initialize the language model based on the configuration.""" + provider = text_model_config["provider"] + lm_name = text_model_config["lm_name"] + section_length = text_model_config.get("section_length", 1000) + if provider == "DeepSeek": + api_key = os.getenv("DEEPSEEK_API_KEY") + return DeepSeek(base_url="https://api.deepseek.com", model=lm_name, api_key=api_key, max_tokens=section_length) + elif provider == "GROQ": + api_key = os.getenv("GROQ_API_KEY") + return dspy.__dict__[provider](model=lm_name, max_tokens=section_length, api_key=api_key) + + class DeepSeek(dspy.LM): """Wrapper around DeepSeek's API. diff --git a/mind_renderer/utils/image_processing.py b/mind_renderer/utils/image_processing.py new file mode 100644 index 0000000..3325b6d --- /dev/null +++ b/mind_renderer/utils/image_processing.py @@ -0,0 +1,19 @@ +import os + +import requests + + +def download_and_save_image(url: str, save_path: str) -> str: + """Download an image from a URL and save it to the specified folder path.""" + try: + response = requests.get(url) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + with open(save_path, "wb") as file: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + file.write(chunk) + except requests.exceptions.RequestException as e: + print(f"Error occurred while downloading image: {e}") + return None + + return save_path diff --git a/story_gen_config.yaml b/story_gen_config.yaml index abb73d6..55fecf1 100644 --- a/story_gen_config.yaml +++ b/story_gen_config.yaml @@ -1,8 +1,12 @@ story: > 写一个儿童故事,用于正面影响小朋友注意卫生。并且喜欢吃蔬菜。 +# Save folder +root_save_folder: "outputs" +thumbnail_save_folder: "thumbnails" + # The number of sections or episodes in the story -num_sections: 2 +num_sections: 3 section_length: 1000 # Comma-separated list of genres for the story diff --git a/tests/utils/test_deepseek_lm.py b/tests/utils/test_custom_lm.py similarity index 98% rename from tests/utils/test_deepseek_lm.py rename to tests/utils/test_custom_lm.py index 0c17d89..a6d62ab 100644 --- a/tests/utils/test_deepseek_lm.py +++ b/tests/utils/test_custom_lm.py @@ -7,7 +7,7 @@ import pytest import requests -from mind_renderer.utils.deepseek_lm import DeepSeek +from mind_renderer.utils.custom_lm import DeepSeek class TestDeepSeek: