
Merge pull request #12 from ehennenfent/prompt_manager
Add prompt manager
ehennenfent authored May 7, 2024
2 parents 951fb96 + 128c532 commit 653f53c
Showing 7 changed files with 71 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
@@ -31,6 +31,6 @@ jobs:
- name: Lint formatting
run: black --check .
- name: Lint semantics
run: ruff .
run: ruff check .
- name: Lint types
run: mypy .
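
The lint step is the only change here: recent ruff releases moved linting behind an explicit check subcommand and deprecated the bare "ruff ." invocation, so the workflow now calls "ruff check ." to keep CI passing on newer ruff versions. The same three checks can be run locally before pushing (assuming black, ruff, and mypy are installed):

black --check .
ruff check .
mypy .
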
56 changes: 24 additions & 32 deletions live_illustrate/__main__.py
@@ -66,45 +66,27 @@ def get_args() -> argparse.Namespace:
choices=["1792x1024", "1024x1792", "1024x1024", "512x512", "256x256"],
)
parser.add_argument(
"--image_quality",
default="standard",
help="How fancy of an image to render",
choices=["standard", "hd"],
"--image_quality", default="standard", help="How fancy of an image to render", choices=["standard", "hd"]
)
parser.add_argument(
"--image_style",
default="vivid",
help="How stylized of an image to render",
choices=["vivid", "natural"],
)
parser.add_argument(
"--server_host",
default="0.0.0.0",
help="Address to bind web server",
)
parser.add_argument(
"--server_port",
default=8080,
type=int,
help="Port to serve HTML viewer on",
)
parser.add_argument(
"--open",
action="store_true",
help="Automatically open a browser tab for the rendered images",
"--image_style", default="vivid", help="How stylized of an image to render", choices=["vivid", "natural"]
)
parser.add_argument("--server_host", default="0.0.0.0", help="Address to bind web server")
parser.add_argument("--server_port", default=8080, type=int, help="Port to serve HTML viewer on")
parser.add_argument("--open", action="store_true", help="Automatically open a browser tab for the rendered images")
parser.add_argument(
"--persistence_of_memory",
default=0.2, # Expressed as a fraction of the total buffered transcription
type=float,
help="How much of the previous transcription to retain after generating each summary. 0 - 1.0",
)
parser.add_argument(
"-v",
"--verbose",
action="count",
default=0,
"--oneshot",
type=argparse.FileType("r"),
help="Read transcription lines from a text file and render. Useful for testing.",
)
parser.add_argument("--data_dir", type=str, default=str(DEFAULT_DATA_DIR), help="Directory to save session data")
parser.add_argument("-v", "--verbose", action="count", default=0)
return parser.parse_args()


@@ -117,8 +99,11 @@ def main() -> None:
logging.getLogger("requests").setLevel(logging.INFO if args.verbose > 0 else logging.WARNING)
logging.getLogger("werkzeug").setLevel(logging.INFO if args.verbose > 0 else logging.WARNING) # flask

# create each of our thread objects with the appropriate command line args
transcriber = AudioTranscriber(model=args.audio_model, phrase_timeout=args.wait_minutes * args.phrase_timeout)
# We don't test transcription in oneshot mode
if not (is_oneshot := args.oneshot is not None):
transcriber = AudioTranscriber(model=args.audio_model, phrase_timeout=args.wait_minutes * args.phrase_timeout)

# Create each of our thread objects with the appropriate command line args
buffer = TextBuffer(
wait_minutes=args.wait_minutes, max_context=args.max_context, persistence=args.persistence_of_memory
)
@@ -133,7 +118,7 @@ def main() -> None:
host=args.server_host, port=args.server_port, default_image=f"https://placehold.co/{args.image_size}/png"
)

with SessionData(DEFAULT_DATA_DIR, echo=True) as session_data:
with SessionData(Path(args.data_dir), echo=True) as session_data:
# wire up some callbacks to save the intermediate data and forward it along
def on_text_transcribed(transcription: Transcription) -> None:
if is_transcription_interesting(transcription):
@@ -151,7 +136,8 @@ def on_image_rendered(image: Image | None) -> None:
session_data.save_image(image)

# start each thread with the appropriate callback
Thread(target=transcriber.start, args=(on_text_transcribed,), daemon=True).start()
if not is_oneshot:
Thread(target=transcriber.start, args=(on_text_transcribed,), daemon=True).start()
Thread(target=summarizer.start, args=(on_summary_generated,), daemon=True).start()
Thread(target=renderer.start, args=(on_image_rendered,), daemon=True).start()

@@ -169,6 +155,12 @@ def open_browser() -> None:

Thread(target=lambda: open_browser).start()

if is_oneshot:
# Read all the lines from the file, pretend we transcribed them
for line in args.oneshot: # type: ignore
# This will still dump things in the data directory. No sense short circuiting the testing.
on_text_transcribed(Transcription(line.strip()))

# flask feels like it probably has a good ctrl+c handler, so we'll make this one the main thread
server.start()

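Taken together, the new --oneshot and --data_dir flags make the pipeline testable without a microphone: transcription is skipped entirely and each line of the supplied text file is pushed through the same summarize-and-render path. A hypothetical invocation (the file and directory names are made up; running the package with python -m follows from this __main__.py, assuming it is executed from the repository root with dependencies installed):

python -m live_illustrate --oneshot sample_transcript.txt --data_dir ./oneshot_test
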
20 changes: 20 additions & 0 deletions live_illustrate/prompts.py
@@ -0,0 +1,20 @@
import typing as t
from pathlib import Path

PROMPTS_FOLDER = Path(__file__).parent.joinpath("prompts")
IMAGE_EXTENSION = PROMPTS_FOLDER.joinpath("image_extra.txt")
SUMMARY = PROMPTS_FOLDER.joinpath("summary.txt")


class PromptManager:
def __init__(self):
self.cached: t.Dict[Path, str] = {}
self.last_modified = {}

def get_prompt(self, path: Path) -> str:
last_modified = path.stat().st_mtime
if self.last_modified.get(path) != last_modified:
with open(path, "r") as f:
self.cached[path] = f.read()
self.last_modified[path] = last_modified
return self.cached[path]
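
PromptManager re-reads a prompt file only when its modification time changes, so prompt text can be edited while a session is running and the next lookup picks up the new wording without a restart. A minimal usage sketch (the import path matches the new module; the rest is illustrative):

from live_illustrate.prompts import SUMMARY, PromptManager

manager = PromptManager()
first = manager.get_prompt(SUMMARY)   # reads prompts/summary.txt and caches the text
second = manager.get_prompt(SUMMARY)  # mtime unchanged, so this returns the cached copy

# Editing prompts/summary.txt on disk changes its mtime, so the next call
# re-reads the file instead of returning the stale cached text.
third = manager.get_prompt(SUMMARY)
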
1 change: 1 addition & 0 deletions live_illustrate/prompts/image_extra.txt
@@ -0,0 +1 @@
digital painting, fantasy art
19 changes: 19 additions & 0 deletions live_illustrate/prompts/summary.txt
@@ -0,0 +1,19 @@
You are a skilled illustrator who draws pictures from a tabletop role playing game.
You will receive lines of dialogue that will include (in part) details about the physical surroundings and appearance of the characters.
In one to two sentences, describe an illustration of the current setting.

For example, given the following dialog (between quotes):
"Kyran, what are you doing right now?
I'm still exploring this dungeon.
Great. You come around a corner and enter a wide room. It's too dark to see what's inside.
Shouldn't elves have darkvision?
Yes... something's making it still be too dark.
Can I sneak in?
Uh your armor's too loud.
Okay, I'll light a torch and then roll to investigate. And that's a four.
You see a rune painted in a dark red liquid on the wall, but don't seem to recognize it."

You might say: "An armor-clad elf holding a torch peers into a dark dungeon room. A strange red rune painted on the wall catches his eye."

If there is more than one scene described by the dialog, try to focus on the most recent one.
Remember to use clear language and to only include details that can be seen.
8 changes: 3 additions & 5 deletions live_illustrate/render.py
@@ -1,13 +1,10 @@
import typing as t
from datetime import datetime

from openai import OpenAI

from .prompts import IMAGE_EXTENSION, PromptManager
from .util import AsyncThread, Image, Summary

# Prompt engineering level 1,000,000
EXTRA: t.List[str] = ["digital painting, fantasy art"]


class ImageRenderer(AsyncThread):
def __init__(self, model: str, image_size: str, image_quality: str, image_style: str) -> None:
@@ -17,13 +14,14 @@ def __init__(self, model: str, image_size: str, image_quality: str, image_style:
self.size: str = image_size
self.image_quality: str = image_quality
self.image_style: str = image_style
self.prompt_manager = PromptManager()

def work(self, summary: Summary) -> Image | None:
"""Sends the text to Dall-e, spits out an image URL"""
start = datetime.now()
rendered = self.openai_client.images.generate(
model=self.model,
prompt="\n".join((summary.summary, *EXTRA)),
prompt=summary.summary + "\n" + self.prompt_manager.get_prompt(IMAGE_EXTENSION),
size=self.size, # type: ignore[arg-type]
quality=self.image_quality, # type: ignore[arg-type]
style=self.image_style, # type: ignore[arg-type]
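
The module-level EXTRA list is gone; the style suffix now lives in prompts/image_extra.txt and is appended to the scene summary when the image request is built. A sketch of the assembled prompt (the summary text is invented for illustration):

from live_illustrate.prompts import IMAGE_EXTENSION, PromptManager

summary_text = "An armor-clad elf holding a torch peers into a dark dungeon room."  # example only
extra = PromptManager().get_prompt(IMAGE_EXTENSION)  # contents of prompts/image_extra.txt

prompt = summary_text + "\n" + extra  # scene description first, style tags on the following line
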
10 changes: 3 additions & 7 deletions live_illustrate/summarize.py
@@ -2,20 +2,16 @@

from openai import OpenAI

from .prompts import SUMMARY, PromptManager
from .util import AsyncThread, Summary, Transcription, num_tokens_from_string

SYSTEM_PROMPT = "You are a helpful assistant that describes scenes to an artist who wants to draw them. \
You will be given several lines of dialogue that contain details about the physical surroundings of the characters. \
Your job is to summarize the details of the scene in a bulleted list containing 4-7 bullet points. \
If there is more than one scene described by the dialog, summarize only the most recent one. \
Remember to be concise and not include details that cannot be seen." # Not so good about this last bit, eh?


class TextSummarizer(AsyncThread):
def __init__(self, model: str) -> None:
super().__init__("TextSummarizer")
self.openai_client: OpenAI = OpenAI()
self.model: str = model
self.prompt_manager = PromptManager()

def work(self, transcription: Transcription) -> Summary | None:
"""Sends the big buffer of provided text to ChatGPT, returns bullets describing the setting"""
@@ -28,7 +24,7 @@ def work(self, transcription: Transcription) -> Summary | None:
response = self.openai_client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "system", "content": self.prompt_manager.get_prompt(SUMMARY)},
{"role": "user", "content": text},
],
)
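
The inline SYSTEM_PROMPT constant is replaced by the contents of prompts/summary.txt, fetched through the same PromptManager, so the summarization prompt can be tuned without code changes and, thanks to the mtime check, without restarting a running session. Roughly, the request now looks like the following sketch (the model name and user text are placeholders; the real values come from the command-line arguments and the transcription buffer):

from openai import OpenAI
from live_illustrate.prompts import SUMMARY, PromptManager

client = OpenAI()
prompt_manager = PromptManager()

response = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[
        {"role": "system", "content": prompt_manager.get_prompt(SUMMARY)},  # prompts/summary.txt
        {"role": "user", "content": "Kyran, what are you doing right now? ..."},  # buffered transcription
    ],
)
print(response.choices[0].message.content)
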
