Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yxjiang committed Jul 21, 2024
1 parent 60946b6 commit 109e8e4
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 103 deletions.
13 changes: 3 additions & 10 deletions mind_renderer/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""The cli to generate the story.
Run the following command to generate the story:
poetry run python mind_renderer/cli.py --config-path story_gen_config.yaml
poetry run python mind_renderer/cli.py
"""

import argparse
Expand All @@ -11,17 +11,10 @@


def main():
parser = argparse.ArgumentParser(description="One-step story generator CLI")
parser.add_argument("--config-path", type=str, default="config.yaml", help="Path to the config file")

args = parser.parse_args()

logger = Logger("StoryGeneratorCLI")

story_generator = OneStepStoryGenerator(config_path=args.config_path)
story_generator = OneStepStoryGenerator()
story = story_generator.generate()

logger.info("Generated story:")
print("Generated story:")
story.to_markdown()


Expand Down
120 changes: 63 additions & 57 deletions mind_renderer/core/story_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,20 @@ class StorySketchGenerator(dspy.Module):
"""StorySketchGenerator is used to expand a simple prompt into detailed piece-level prompts."""

instruction: str = """
You are the story sketching assistant. Please follow the below instruction:
0. Generate a sketch for a multi-section story according to the specified number of sections.
1. Generate the sketch in the SAME language as the user input,
e.g. generate Chinese sketch if user input is Chinese.
2. Generate the story_worldview and the detailed piece-level prompts based on the simple prompt.
The story_worldview is the high-level description of the story, used to make each piece consistent.
Each of the detailed piece-level prompts will be used to generate a piece of the story.
3. Each section should be coherent and complementary of the others.
4. The story in combination should be complete. So the story pieces can be logically connected.
5. Generate the prompt per section in the same language as the input prompt.
You are the story sketching assistant. Please follow the below instruction:
1. Story Sketch Generation: Create a sketch for a multi-section story according to the specified number of
sections.
2. Language Consistency: Ensure the sketch is generated in the SAME language as the user input. For example,
generate the sketch in Chinese if the user input is in Chinese.
3. Worldview and Detailed Prompts: Based on the simple prompt provided:
3.1. Generate the story worldview, which is a high-level description of the story to ensure consistency across
sections.
3.2. Create detailed piece-level prompts for each section, which will be used to generate individual parts of the
story.
4. Coherence and Complementarity: Each section should be coherent within itself and complement the other sections.
5. Logical Connectivity: Ensure the entire story, when combined, forms a complete and logically connected narrative.
6. Section-Specific Prompts: Generate prompts for each section in the same language as the input prompt.
"""

class SketchGenerateSignature(dspy.Signature):
Expand All @@ -59,10 +62,10 @@ class SketchGenerateSignature(dspy.Signature):
"""
)

def __init__(self, config_loader: ConfigLoader, text_model_config: dict[str, str]):
self.logger = Logger(__name__)
def __init__(self, config_loader: ConfigLoader, logger: Logger):
self.logger = logger
self.config_loader = config_loader
self.lm = init_lm(self.config_loader)
self.lm = init_lm(self.config_loader, logger=self.logger)
self.sketch_inspect_length = int(os.getenv("SKETCH_INSPECT_LENGTH", 0))

def forward(self, simple_prompt: str, num_sections: int, **kwargs):
Expand All @@ -83,15 +86,22 @@ class TextGenerator(dspy.Module):
"""TextGenerator is used to generate the text element of the story piece based on the prompt."""

instructions = """
You are the story writing assistant. Your goal is to generate a story piece a time based on the prompt.
Please generate the story piece in the SAME language as the input prompt.
Please put the clean generated text in the output field.
You are the story writing assistant that generates one story piece a time based on the prompt.
### General guidelines:
1. Please generate the story piece in the SAME language as the input prompt.
"""

post_process_instructions = """
You are the post-processing assistant that refines the generated story piece.
The raw story piece generated may contain some errors or uncleaned tags,
we only need the actual story content in the section of story text.
Please only return the cleaned story text.
"""

class TextGenerateSignature(dspy.Signature):
"""Signature for the generate method."""

instruction: str = dspy.InputField(desc="The instruction for the generation.")
instruction: str = dspy.InputField(desc="The instruction for the generation to strictly follow.")
story_description: str = dspy.InputField(desc="The description of the story at the high level.")
story_worldview: str = dspy.InputField(desc="The story_worldview of the story.")
story_genres: str = dspy.InputField(desc="The genres of the story, e.g. horror, romance, etc.")
Expand All @@ -108,10 +118,17 @@ class TextGenerateSignature(dspy.Signature):
desc="The cohensive prompt for generating the thumbnail image in English only. Use descriptive keywords."
)

def __init__(self, config_loader: ConfigLoader, text_model_config: dict[str, str]):
self.logger = Logger(__name__)
class PostProcessSignature(dspy.Signature):
"""Signature for the post_process method."""

instruction: str = dspy.InputField(desc="The instruction for the post-processing.")
raw_text: str = dspy.InputField(desc="The story piece to be post-processed.")
story_text: str = dspy.OutputField(desc="The cleaned story text.")

def __init__(self, config_loader: ConfigLoader, logger: Logger):
self.logger = logger
self.config_loader = config_loader
self.lm = init_lm(self.config_loader)
self.lm = init_lm(self.config_loader, logger=self.logger)
self.genres = self.config_loader.get_value("genres", "")
self.writing_style = self.config_loader.get_value("writing_style", "")
self.gen_thumbnail_prompt = self.config_loader.get_value("image_model.gen_thumbnail", False)
Expand Down Expand Up @@ -142,7 +159,8 @@ def generate(
) -> None:
try:
self.logger.debug("Attempting to generate story piece...")
story_gen = dspy.ChainOfThoughtWithHint(self.TextGenerateSignature)
story_gen = dspy.Predict(self.TextGenerateSignature)
story_post_process = dspy.Predict(self.PostProcessSignature)

with dspy.context(lm=self.lm):
response = story_gen(
Expand All @@ -154,13 +172,18 @@ def generate(
should_gen_thumbnail_prompt=str(self.gen_thumbnail_prompt),
prev_piece=prev_piece.text if prev_piece else "",
)
story_piece.text = response.story_text
post_process_response = story_post_process(
instruction=self.post_process_instructions, raw_text=response.story_text
)
story_piece.text = post_process_response.story_text.replace("Story Text:", "").strip()
if not story_piece.text:
self.logger.error("Failed to generate the text for the story piece.")
raise ValueError(f"Failed to generate the text for the story piece, {response}")
else:
story_piece.thumbnail_gen_prompt = response.thumbnail_generation_prompt
self.logger.debug("Successfully generated story piece.")
story_piece.thumbnail_gen_prompt = response.thumbnail_generation_prompt.replace(
"Thumbnail Generation Prompt:", ""
).strip()
self.logger.info(f"Generated story piece text:\n{story_piece.text}")
self.logger.info(f"Generated thumbnail prompt:\n{story_piece.thumbnail_gen_prompt}")
except Exception as e:
self.logger.error(f"An error occurred: {str(e)}")
raise
Expand All @@ -169,8 +192,8 @@ def generate(
class ThumbnailGenerator(dspy.Module):
"""ThumbnailGenerator is used to generate thumbnail images for story pieces."""

def __init__(self, config_loader: ConfigLoader):
self.logger = Logger(__name__)
def __init__(self, config_loader: ConfigLoader, logger: Logger):
self.logger = logger
self.config_loader = config_loader
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
self.root_save_folder = self.config_loader.get_value("root_save_folder", "outputs")
Expand Down Expand Up @@ -217,21 +240,17 @@ class StoryGenerator(dspy.Module):

def __init__(
self,
genres: str,
writing_style: str,
gen_thumbnail_prompt: bool,
provider: str = None,
lm_name: str = None,
**kwargs,
):
load_dotenv(override=True)
self.state: dict[str, any]
self.logger = Logger(__name__)
self.genres = genres
self.writing_style = writing_style
self.gen_thumbnail_prompt = gen_thumbnail_prompt
self.provider = provider
self.lm_name = lm_name
self.config_loader = ConfigLoader()
self.logger = Logger(__name__, parent_folder=self.config_loader.get_value("root_save_folder", "outputs"))
self.genres = self.config_loader.get_value("genres", "")
self.writing_style = self.config_loader.get_value("writing_style", "")
self.gen_thumbnail_prompt = self.config_loader.get_value("image_model.gen_thumbnail", False)
self.provider = self.config_loader.get_value("text_model.provider")
self.lm_name = self.config_loader.get_value("text_model.lm_name")
self.temperature = kwargs.get("temperature", 0.7)

@abstractmethod
Expand Down Expand Up @@ -259,24 +278,11 @@ def forward(self, **kwargs):


class OneStepStoryGenerator(StoryGenerator):
def __init__(self, config_path: str):
self.logger = Logger(__name__)
self.config_loader = ConfigLoader(config_path)
text_model_config = self.config_loader.get_text_model_config()

super().__init__(
genres=self.config_loader.get_value("genres", ""),
writing_style=self.config_loader.get_value("writing_style", ""),
gen_thumbnail_prompt=self.config_loader.get_value("image_model.gen_thumbnail", False),
provider=self.config_loader.get_value("text_model.provider"),
lm_name=self.config_loader.get_value("text_model.lm_name"),
)

self.story_sketch_generator = StorySketchGenerator(
config_loader=self.config_loader, text_model_config=text_model_config
)
self.text_generator = TextGenerator(config_loader=self.config_loader, text_model_config=text_model_config)
self.thumbnail_generator = ThumbnailGenerator(config_loader=self.config_loader)
def __init__(self):
super().__init__()
self.story_sketch_generator = StorySketchGenerator(config_loader=self.config_loader, logger=self.logger)
self.text_generator = TextGenerator(config_loader=self.config_loader, logger=self.logger)
self.thumbnail_generator = ThumbnailGenerator(config_loader=self.config_loader, logger=self.logger)

def generate(self, prev_version: Story = None, **kwargs) -> Story:
prompt = self.config_loader.get_value("story", "")
Expand Down
15 changes: 8 additions & 7 deletions mind_renderer/utils/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@


class ConfigLoader:
def __init__(self, config_path: str = "config.yaml"):
self.config_path = config_path
self.config = self._load_config()
def __init__(self, config_path=None):
config_path = config_path or "./story_gen_config.yaml"
self.config = self._load_config(config_path=config_path)

def _load_config(self) -> Dict[str, Any]:
if not os.path.exists(self.config_path):
raise FileNotFoundError(f"Config file not found: {self.config_path}")
def _load_config(self, config_path: str) -> Dict[str, Any]:

with open(self.config_path, "r") as config_file:
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config file not found: {config_path}")

with open(config_path, "r") as config_file:
return yaml.safe_load(config_file)

def get_config(self) -> Dict[str, Any]:
Expand Down
15 changes: 11 additions & 4 deletions mind_renderer/utils/custom_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,20 @@
from mind_renderer.utils.logger import Logger


def init_lm(config_loader: ConfigLoader) -> dspy.LM:
def init_lm(config_loader: ConfigLoader, logger: Logger) -> dspy.LM:
"""Initialize the language model based on the configuration."""
provider = config_loader.get_value("text_model.provider")
lm_name = config_loader.get_value("text_model.lm_name")
section_length = config_loader.get_value("text_model.section_length")
section_length = config_loader.get_value("text_model.section_length", 100)
if provider == "DeepSeek":
api_key = os.getenv("DEEPSEEK_API_KEY")
return DeepSeek(base_url="https://api.deepseek.com", model=lm_name, api_key=api_key, max_tokens=section_length)
return DeepSeek(
base_url="https://api.deepseek.com",
logger=logger,
model=lm_name,
api_key=api_key,
max_tokens=section_length,
)
elif provider == "GROQ":
api_key = os.getenv("GROQ_API_KEY")
elif provider == "OpenAI":
Expand All @@ -42,6 +48,7 @@ class DeepSeek(dspy.LM):

def __init__(
self,
logger: Logger,
model: str = "deepseek-chat",
api_key: Optional[str] = None,
api_base: Optional[str] = "https://api.deepseek.com",
Expand All @@ -50,7 +57,7 @@ def __init__(
**kwargs,
):
super().__init__(model)
self.logger = Logger(__name__)
self.logger = logger
self.provider = "openai"
self.api_base = api_base
self.system_prompt = system_prompt
Expand Down
66 changes: 43 additions & 23 deletions mind_renderer/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from colorama import Fore, ansi
from dotenv import load_dotenv

from mind_renderer.utils.config_loader import ConfigLoader


class Logger:
_instance = None
Expand Down Expand Up @@ -44,10 +46,13 @@ def __new__(cls, *args, **kwargs):
cls._instance = super().__new__(cls)
return cls._instance

def __init__(self, logger_name: str, verbose: bool = True, level: Optional[LoggingLevel] = None):
def __init__(
self, logger_name: str, parent_folder: str, verbose: bool = True, level: Optional[LoggingLevel] = None
):
if not hasattr(self, "logger"):
load_dotenv(override=True)
self.logging_level = level if level else Logger.LoggingLevel[os.getenv("LOGGING_LEVEL", "INFO")]
self.config = ConfigLoader()
self.logging_level = level if level else Logger.LoggingLevel[self.config.get_value("log_level", "INFO")]
self.logger = logging.getLogger(logger_name)
self.logger.setLevel(level=self.logging_level.value)

Expand All @@ -61,37 +66,52 @@ def __init__(self, logger_name: str, verbose: bool = True, level: Optional[Loggi
self.console_handler.setFormatter(self.formatter)
self.logger.addHandler(self.console_handler)

def log(self, message: str, level: LoggingLevel, color: str = ansi.Fore.GREEN) -> None:
# File handler
log_folder_root = self.config.get_value("logger.folder_root", "logs")
log_folder = os.path.join(log_folder_root, parent_folder)
if not os.path.exists(log_folder):
os.makedirs(log_folder)
log_file = os.path.join(log_folder, f"{logger_name}.log")
self.file_handler = logging.FileHandler(log_file)
self.file_handler.setLevel(level=self.logging_level.value)
self.file_handler.setFormatter(self.formatter)
self.logger.addHandler(self.file_handler)

def log(self, message: str, level: LoggingLevel, color: str = ansi.Fore.GREEN, write_to_file: bool = False) -> None:
if level >= self.logging_level:
if len(inspect.stack()) >= 4:
caller_frame = inspect.stack()[2]
caller_frame = inspect.stack()[3]
else:
caller_frame = inspect.stack()[2]
caller_name = caller_frame[3]
caller_line = caller_frame[2]
caller_name = caller_frame.function
caller_line = caller_frame.lineno
message = f"{caller_name}({caller_line}): {message}"
self.logger.info(color + message + Fore.RESET)
log_message = color + message + Fore.RESET
self.logger.info(log_message)
if write_to_file:
self.file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
self.logger.info(message)

def debug(self, message: str) -> None:
self.log(message, Logger.LoggingLevel.DEBUG, Fore.BLACK)
def debug(self, message: str, write_to_file: bool = False) -> None:
self.log(message, Logger.LoggingLevel.DEBUG, Fore.BLACK, write_to_file)

def info(self, message: str) -> None:
self.log(message, Logger.LoggingLevel.INFO, Fore.WHITE)
def info(self, message: str, write_to_file: bool = False) -> None:
self.log(message, Logger.LoggingLevel.INFO, Fore.WHITE, write_to_file)

def tool_log(self, message: str) -> None:
self.log(message, Logger.LoggingLevel.TOOL, Fore.YELLOW)
def tool_log(self, message: str, write_to_file: bool = False) -> None:
self.log(message, Logger.LoggingLevel.TOOL, Fore.YELLOW, write_to_file)

def task_log(self, message: str) -> None:
self.log(message, Logger.LoggingLevel.TASK, Fore.BLUE)
def task_log(self, message: str, write_to_file: bool = False) -> None:
self.log(message, Logger.LoggingLevel.TASK, Fore.BLUE, write_to_file)

def thought_process_log(self, message: str) -> None:
self.log(message, Logger.LoggingLevel.THOUGHT_PROCESS, Fore.GREEN)
def thought_process_log(self, message: str, write_to_file: bool = False) -> None:
self.log(message, Logger.LoggingLevel.THOUGHT_PROCESS, Fore.GREEN, write_to_file)

def warning(self, message: str) -> None:
self.log(message, Logger.LoggingLevel.WARNING, Fore.YELLOW)
def warning(self, message: str, write_to_file: bool = False) -> None:
self.log(message, Logger.LoggingLevel.WARNING, Fore.YELLOW, write_to_file)

def error(self, message: str) -> None:
self.log(message, Logger.LoggingLevel.ERROR, Fore.RED)
def error(self, message: str, write_to_file: bool = False) -> None:
self.log(message, Logger.LoggingLevel.ERROR, Fore.RED, write_to_file)

def critical(self, message: str) -> None:
self.log(message, Logger.LoggingLevel.CRITICAL, Fore.MAGENTA)
def critical(self, message: str, write_to_file: bool = False) -> None:
self.log(message, Logger.LoggingLevel.CRITICAL, Fore.MAGENTA, write_to_file)
Loading

0 comments on commit 109e8e4

Please sign in to comment.