Skip to content

Commit

Permalink
Zabirauf kwargs (#288)
Browse files Browse the repository at this point in the history
Co-authored-by: Zohaib Rauf <[email protected]>
  • Loading branch information
lbeurerkellner and zabirauf authored Dec 5, 2023
1 parent 76abda8 commit 72e73ec
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
10 changes: 9 additions & 1 deletion src/lmql/language/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,10 +623,18 @@ def transform_node(self, expr, snf):
return self.default_transform_node(expr, snf).strip()
elif type(expr) is ast.Call:
constraint_ref = get_builtin_name(expr.func) or get_inner_constraint_ref(expr.func, self.scope)

if constraint_ref is not None:
args = [self.transform_node(a, snf) for a in expr.args]
args_list = ", ".join(args)
return f"{constraint_ref}([{args_list}])"

keywords = {key.arg: self.transform_node(key.value, snf) for key in expr.keywords}
keywords_list = ", ".join([f"{k}={v}" for k,v in keywords.items()])
if len(keywords_list) > 0:
keywords_list = ", " + keywords_list

return f"{constraint_ref}([{args_list}]{keywords_list})"

if is_allowed_builtin_python_call(expr.func):
return self.default_transform_node(expr, snf).strip()

Expand Down
9 changes: 7 additions & 2 deletions src/lmql/ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,9 @@ def remainder(seq: str, phrase: str):

@LMQLOp("REGEX")
class RegexOp(Node):
def __init__(self, predecessors, **kwargs):
super().__init__(predecessors)
self.kwargs = kwargs

def forward(self, *args, **kwargs):
if any([a is None for a in args]): return None
Expand All @@ -744,15 +747,17 @@ def follow(self, *args, **kwargs):
x = args[0]
ex = args[1]
assert isinstance(ex, str)

verbose = self.kwargs.get("verbose", False)

if x == strip_next_token(x):
return fmap(
("*", Regex(ex).fullmatch(strip_next_token(x)))
)

r = Regex(ex)
rd = r.d(strip_next_token(x), verbose=False) # take derivative
print(f"r={r.pattern} x={strip_next_token(x)} --> {rd.pattern if rd is not None else '[no drivative]'}")
rd = r.d(strip_next_token(x), verbose=verbose) # take derivative
if verbose: print(f"r={r.pattern} x={strip_next_token(x)} --> {rd.pattern if rd is not None else '[no drivative]'}")
if rd is None:
return False
elif rd.is_empty(): # derivative is empty -> full match; therefore we must end
Expand Down

0 comments on commit 72e73ec

Please sign in to comment.