Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
alex28sh committed Nov 28, 2024
1 parent 6508efc commit 41cf12d
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 35 deletions.
14 changes: 10 additions & 4 deletions verified_cogen/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,16 @@ def add(self, prg: str, checks: str, function: Optional[str] = None) -> str:
self.add_user_prompt(prompt, False)
return self.make_request()

def rewrite(self, prg: str, additional_prompt: str = "") -> str:
self.add_user_prompt(
prompts.rewrite_prompt(self.prompt_dir).replace("{program}", prg)
)
def rewrite(
self,
prg: str,
text_description: Optional[str] = None,
additional_prompt: str = "",
) -> str:
result = prompts.rewrite_prompt(self.prompt_dir).replace("{program}", prg)
if text_description is not None and "{text_description}" in result:
result = result.replace("{text_description}", text_description)
self.add_user_prompt(result)
if additional_prompt:
self.add_user_prompt(additional_prompt)
return self.make_request()
Expand Down
7 changes: 6 additions & 1 deletion verified_cogen/runners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ def __init__(
if self.config.log_tries is not None:
self.config.log_tries.mkdir(exist_ok=True, parents=True)

def rewrite(self, prg: str, text_description: Optional[str] = None, additional_prompt: str = "") -> str:
def rewrite(
self,
prg: str,
text_description: Optional[str] = None,
additional_prompt: str = "",
) -> str:
"""Rewrite the program with additional checks in one step."""
...

Expand Down
14 changes: 11 additions & 3 deletions verified_cogen/runners/flush.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ class FlushRunner(Runner):

def __init__(self, wrapped_runner: Runner):
super().__init__(
wrapped_runner.llm, wrapped_runner.logger, wrapped_runner.verifier, wrapped_runner.config
wrapped_runner.llm,
wrapped_runner.logger,
wrapped_runner.verifier,
wrapped_runner.config,
)
self.wrapped_runner = wrapped_runner

Expand Down Expand Up @@ -44,7 +47,12 @@ def preprocess(self, prg: str, mode: Mode) -> str:
def postprocess(self, inv_prg: str) -> str:
return self.wrapped_runner.postprocess(inv_prg)

def rewrite(self, prg: str, text_description: Optional[str] = None, additional_prompt: str = "") -> str:
def rewrite(
self,
prg: str,
text_description: Optional[str] = None,
additional_prompt: str = "",
) -> str:
return self.wrapped_runner.rewrite(prg, text_description, additional_prompt)

def produce(self, prg: str) -> str:
Expand All @@ -54,4 +62,4 @@ def insert(self, prg: str, checks: str, mode: Mode) -> str:
return self.wrapped_runner.insert(prg, checks, mode)

def precheck(self, prg: str, mode: Mode):
return self.wrapped_runner.precheck(prg, mode)
return self.wrapped_runner.precheck(prg, mode)
9 changes: 7 additions & 2 deletions verified_cogen/runners/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@


class GenerateRunner(Runner):
def rewrite(self, prg: str, text_description: Optional[str] = None, additional_prompt: str = "") -> str:
return self.llm.rewrite(prg, additional_prompt)
def rewrite(
self,
prg: str,
text_description: Optional[str] = None,
additional_prompt: str = "",
) -> str:
return self.llm.rewrite(prg, text_description, additional_prompt)

def produce(self, prg: str) -> str:
raise ValueError("Produce not supported for generate")
Expand Down
9 changes: 7 additions & 2 deletions verified_cogen/runners/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@


class GenericRunner(Runner):
def rewrite(self, prg: str, text_description: Optional[str] = None, additional_prompt: str = "") -> str:
return self.llm.rewrite(prg, additional_prompt)
def rewrite(
self,
prg: str,
text_description: Optional[str] = None,
additional_prompt: str = "",
) -> str:
return self.llm.rewrite(prg, text_description, additional_prompt)

def produce(self, prg: str) -> str:
return self.llm.produce(prg)
Expand Down
9 changes: 7 additions & 2 deletions verified_cogen/runners/invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,13 @@ def insert_invariants(llm: LLM, prg: str, inv: str, mode: Mode):


class InvariantRunner(Runner):
def rewrite(self, prg: str, text_description: Optional[str] = None, additional_prompt: str = "") -> str:
return self.llm.rewrite(prg, additional_prompt)
def rewrite(
self,
prg: str,
text_description: Optional[str] = None,
additional_prompt: str = "",
) -> str:
return self.llm.rewrite(prg, text_description, additional_prompt)

def produce(self, prg: str) -> str:
return self.llm.produce(prg)
Expand Down
9 changes: 7 additions & 2 deletions verified_cogen/runners/languages/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def remove_conditions(self, code: str) -> str: ...
def separate_validator_errors(self, errors: str) -> tuple[str, str]: ...

@abstractmethod
def check_helpers(self, code: str, pure_non_helpers: [str]) -> Tuple[List[str], str]: ...
def check_helpers(
self, code: str, pure_non_helpers: [str]
) -> Tuple[List[str], str]: ...

@abstractmethod
def find_pure_non_helpers(self, code: str) -> [str]: ...
Expand Down Expand Up @@ -105,12 +107,15 @@ def remove_conditions(self, code: str) -> str:
lines = [line for line in lines if self.inline_assert_comment not in line]
return "\n".join(lines).strip()

def check_helpers(self, code: str, pure_non_helpers: [str]) -> Tuple[List[str], str]:
def check_helpers(
self, code: str, pure_non_helpers: [str]
) -> Tuple[List[str], str]:
return [], code

def find_pure_non_helpers(self, code: str) -> [str]:
return []


class LanguageDatabase:
_instance = None
languages: dict[str, Language] = dict()
Expand Down
6 changes: 4 additions & 2 deletions verified_cogen/runners/languages/nagini.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def separate_validator_errors(self, errors: str) -> tuple[str, str]:
]
return "\n".join(lines), ""

def check_helpers(self, code: str, pure_non_helpers: [str]) -> Tuple[List[str], str]:
def check_helpers(
self, code: str, pure_non_helpers: [str]
) -> Tuple[List[str], str]:
return detect_and_replace_pure_calls_nagini(code, pure_non_helpers)

def find_pure_non_helpers(self, code: str) -> [str]:
Expand All @@ -57,4 +59,4 @@ def find_pure_non_helpers(self, code: str) -> [str]:
non_helpers: list[str] = []
for match in methods:
non_helpers.append(match.group(3))
return non_helpers
return non_helpers
7 changes: 6 additions & 1 deletion verified_cogen/runners/step_by_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ def __init__(self, wrapping: Runner, config: Optional[StepByStepConfig] = None):
def preprocess(self, prg: str, mode: Mode) -> str:
return self.wrapped_runner.preprocess(prg, mode)

def rewrite(self, prg: str, text_description: Optional[str] = None, additional_prompt: str = "") -> str:
def rewrite(
self,
prg: str,
text_description: Optional[str] = None,
additional_prompt: str = "",
) -> str:
return (
self.rewrite_full_examples(prg, text_description)
if self._config.full_examples
Expand Down
29 changes: 22 additions & 7 deletions verified_cogen/runners/validating.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def _add_validators(self, prg: str, inv_prg: str):
def preprocess(self, prg: str, mode: Mode) -> str:
if self.config.remove_implementations:
self.pure_non_helpers = self.language.find_pure_non_helpers(prg)
self.logger.info("found pure_non_helpers: " + ",".join(self.pure_non_helpers))
self.logger.info(
"found pure_non_helpers: " + ",".join(self.pure_non_helpers)
)
res_prg = self.language.remove_conditions(prg)
self.wrapped_runner.starting_prg = res_prg
return res_prg
Expand All @@ -51,21 +53,34 @@ def postprocess(self, inv_prg: str) -> str:
assert self.starting_prg is not None
invalid_helpers: [str] = []
try:
invalid_helpers, inv_prg = self.language.check_helpers(inv_prg, self.pure_non_helpers)
invalid_helpers, inv_prg = self.language.check_helpers(
inv_prg, self.pure_non_helpers
)
self.logger.info("invalid_helpers: " + ",".join(invalid_helpers))
except:
except Exception:
self.logger.info("pass")
pass
if invalid_helpers:
self.llm.add_user_prompt(prompts.invalid_helpers_prompt(self.llm.prompt_dir).replace("{invalid_helpers}",
",".join(invalid_helpers)).replace("{program}", inv_prg).replace("{helpers}", ",".join(self.pure_non_helpers)))
self.llm.add_user_prompt(
prompts.invalid_helpers_prompt(self.llm.prompt_dir)
.replace("{invalid_helpers}", ",".join(invalid_helpers))
.replace("{program}", inv_prg)
.replace("{helpers}", ",".join(self.pure_non_helpers))
)
return self._add_validators(
self.starting_prg, self.wrapped_runner.postprocess(inv_prg)
)

def rewrite(self, prg: str, text_description: Optional[str] = None, additional_prompt: str = "") -> str:
def rewrite(
self,
prg: str,
text_description: Optional[str] = None,
additional_prompt: str = "",
) -> str:
if self.config.remove_implementations and self.pure_non_helpers:
additional_prompt += prompts.helpers_prompt(self.llm.prompt_dir).replace("{helpers}",",".join(self.pure_non_helpers))
additional_prompt += prompts.helpers_prompt(self.llm.prompt_dir).replace(
"{helpers}", ",".join(self.pure_non_helpers)
)
return self.wrapped_runner.rewrite(prg, text_description, additional_prompt)

def produce(self, prg: str) -> str:
Expand Down
22 changes: 13 additions & 9 deletions verified_cogen/tools/pureCallsDetectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, pure_non_helpers: [str]):
self.pure_non_helpers = pure_non_helpers

def visit_FunctionDef(self, node: ast.FunctionDef):
is_pure = any(decorator.id == 'Pure' for decorator in node.decorator_list)
is_pure = any(decorator.id == "Pure" for decorator in node.decorator_list)
if is_pure and node.name not in self.pure_non_helpers:
self.pure_functions.append(node.name)

Expand All @@ -34,15 +34,17 @@ def visit_Call(self, node: ast.Call):
if isinstance(node.func, ast.Name):
# print("A")
# print("A " + node.func.id)
if (node.func.id in self.pure_functions and
not self.in_pure_function and
not self.in_condition and
self.current_function is not None):
if (
node.func.id in self.pure_functions
and not self.in_pure_function
and not self.in_condition
and self.current_function is not None
):
self.detected_calls.append(node.func.id)
return ast.Call(
func=ast.Name(id='invalid_call', ctx=ast.Load()),
func=ast.Name(id="invalid_call", ctx=ast.Load()),
args=[],
keywords=[]
keywords=[],
)
if node.func.id not in ["Invariant", "Assert", "Requires", "Ensures"]:
return self.generic_visit(node)
Expand Down Expand Up @@ -76,9 +78,11 @@ def visit_Assert(self, node: ast.Assert):
return node


def detect_and_replace_pure_calls_nagini(code: str, pure_non_helpers: [str]) -> Tuple[List[str], str]:
def detect_and_replace_pure_calls_nagini(
code: str, pure_non_helpers: [str]
) -> Tuple[List[str], str]:
tree = ast.parse(code)
replacer = PureFunctionCallReplacer(pure_non_helpers)
modified_tree = replacer.visit(tree)
new_code = ast.unparse(modified_tree)
return replacer.detected_calls, new_code
return replacer.detected_calls, new_code

0 comments on commit 41cf12d

Please sign in to comment.