Skip to content

Commit

Permalink
Allow nested structures inside schedules
Browse files Browse the repository at this point in the history
  • Loading branch information
guaneec authored and AUTOMATIC1111 committed Oct 4, 2022
1 parent 6c6ae28 commit 2f1b61d
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 66 deletions.
119 changes: 53 additions & 66 deletions modules/prompt_parser.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,11 @@
import re
from collections import namedtuple
import torch
from lark import Lark, Transformer, Visitor
import functools

import modules.shared as shared

re_prompt = re.compile(r'''
(.*?)
\[
([^]:]+):
(?:([^]:]*):)?
([0-9]*\.?[0-9]+)
]
|
(.+)
''', re.X)

# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
Expand All @@ -25,61 +16,57 @@


def get_learned_conditioning_prompt_schedules(prompts, steps):
res = []
cache = {}

for prompt in prompts:
prompt_schedule: list[list[str | int]] = [[steps, ""]]

cached = cache.get(prompt, None)
if cached is not None:
res.append(cached)
continue

for m in re_prompt.finditer(prompt):
plaintext = m.group(1) if m.group(5) is None else m.group(5)
concept_from = m.group(2)
concept_to = m.group(3)
if concept_to is None:
concept_to = concept_from
concept_from = ""
swap_position = float(m.group(4)) if m.group(4) is not None else None

if swap_position is not None:
if swap_position < 1:
swap_position = swap_position * steps
swap_position = int(min(swap_position, steps))

swap_index = None
found_exact_index = False
for i in range(len(prompt_schedule)):
end_step = prompt_schedule[i][0]
prompt_schedule[i][1] += plaintext

if swap_position is not None and swap_index is None:
if swap_position == end_step:
swap_index = i
found_exact_index = True

if swap_position < end_step:
swap_index = i

if swap_index is not None:
if not found_exact_index:
prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])

for i in range(len(prompt_schedule)):
end_step = prompt_schedule[i][0]
must_replace = swap_position < end_step

prompt_schedule[i][1] += concept_to if must_replace else concept_from

res.append(prompt_schedule)
cache[prompt] = prompt_schedule
#for t in prompt_schedule:
# print(t)

return res
grammar = r"""
start: prompt
prompt: (emphasized | scheduled | weighted | plain)*
!emphasized: "(" prompt ")"
| "(" prompt ":" prompt ")"
| "[" prompt "]"
scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
!weighted: "{" weighted_item ("|" weighted_item)* "}"
!weighted_item: prompt (":" prompt)?
plain: /([^\\\[\](){}:|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
"""
parser = Lark(grammar, parser='lalr')
def collect_steps(steps, tree):
l = [steps]
class CollectSteps(Visitor):
def scheduled(self, tree):
tree.children[-1] = float(tree.children[-1])
if tree.children[-1] < 1:
tree.children[-1] *= steps
tree.children[-1] = min(steps, int(tree.children[-1]))
l.append(tree.children[-1])
CollectSteps().visit(tree)
return sorted(set(l))
def at_step(step, tree):
class AtStep(Transformer):
def scheduled(self, args):
if len(args) == 2:
before, after, when = (), *args
else:
before, after, when = args
yield before if step <= when else after
def start(self, args):
def flatten(x):
if type(x) == str:
yield x
else:
for gen in x:
yield from flatten(gen)
return ''.join(flatten(args[0]))
def plain(self, args):
yield args[0].value
def __default__(self, data, children, meta):
for child in children:
yield from child
return AtStep().transform(tree)
@functools.cache
def get_schedule(prompt):
tree = parser.parse(prompt)
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
return [get_schedule(prompt) for prompt in prompts]


ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ clean-fid
resize-right
torchdiffeq
kornia
lark
1 change: 1 addition & 0 deletions requirements_versions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ clean-fid==0.1.29
resize-right==0.0.2
torchdiffeq==0.2.3
kornia==0.6.7
lark==1.1.2

0 comments on commit 2f1b61d

Please sign in to comment.