From 72e73ecde8b2fd58fb1e828c0bc3ac47d178c1f3 Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Tue, 5 Dec 2023 15:17:14 +0100 Subject: [PATCH] Zabirauf kwargs (#288) Co-authored-by: Zohaib Rauf --- src/lmql/language/compiler.py | 10 +++++++++- src/lmql/ops/ops.py | 9 +++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/lmql/language/compiler.py b/src/lmql/language/compiler.py index 4d26ec91..6959d024 100644 --- a/src/lmql/language/compiler.py +++ b/src/lmql/language/compiler.py @@ -623,10 +623,18 @@ def transform_node(self, expr, snf): return self.default_transform_node(expr, snf).strip() elif type(expr) is ast.Call: constraint_ref = get_builtin_name(expr.func) or get_inner_constraint_ref(expr.func, self.scope) + if constraint_ref is not None: args = [self.transform_node(a, snf) for a in expr.args] args_list = ", ".join(args) - return f"{constraint_ref}([{args_list}])" + + keywords = {key.arg: self.transform_node(key.value, snf) for key in expr.keywords} + keywords_list = ", ".join([f"{k}={v}" for k,v in keywords.items()]) + if len(keywords_list) > 0: + keywords_list = ", " + keywords_list + + return f"{constraint_ref}([{args_list}]{keywords_list})" + if is_allowed_builtin_python_call(expr.func): return self.default_transform_node(expr, snf).strip() diff --git a/src/lmql/ops/ops.py b/src/lmql/ops/ops.py index 603c77f6..fdf2af67 100644 --- a/src/lmql/ops/ops.py +++ b/src/lmql/ops/ops.py @@ -731,6 +731,9 @@ def remainder(seq: str, phrase: str): @LMQLOp("REGEX") class RegexOp(Node): + def __init__(self, predecessors, **kwargs): + super().__init__(predecessors) + self.kwargs = kwargs def forward(self, *args, **kwargs): if any([a is None for a in args]): return None @@ -744,6 +747,8 @@ def follow(self, *args, **kwargs): x = args[0] ex = args[1] assert isinstance(ex, str) + + verbose = self.kwargs.get("verbose", False) if x == strip_next_token(x): return fmap( @@ -751,8 +756,8 @@ def follow(self, *args, **kwargs): ) r = Regex(ex) - rd = r.d(strip_next_token(x), verbose=False) # take derivative - print(f"r={r.pattern} x={strip_next_token(x)} --> {rd.pattern if rd is not None else '[no drivative]'}") + rd = r.d(strip_next_token(x), verbose=verbose) # take derivative + if verbose: print(f"r={r.pattern} x={strip_next_token(x)} --> {rd.pattern if rd is not None else '[no drivative]'}") if rd is None: return False elif rd.is_empty(): # derivative is empty -> full match; therefore we must end