Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
WeetHet committed Sep 8, 2024
1 parent 114b115 commit f2f11bf
Show file tree
Hide file tree
Showing 13 changed files with 515 additions and 384 deletions.
475 changes: 246 additions & 229 deletions poetry.lock

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions tests/test_dafny_generate copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from textwrap import dedent
from verified_cogen.runners.languages import LanguageDatabase, init_basic_languages

# Register the built-in languages at import time; the test below looks up
# "dafny" through LanguageDatabase.
init_basic_languages()


def test_dafny_generate():
    """Check the validator wrapper generated for a single Dafny method."""
    # Input program: one method with a precondition, a postcondition and an
    # assertion line in its body.
    source = dedent(
        """\
        method main(value: int) returns (result: int)
        requires value >= 10
        ensures result >= 20
        {
        assert value * 2 >= 20; // assert-line
        result := value * 2;
        }"""
    )
    # Expected wrapper: same signature and specification under the name
    # `main_valid`, with a body that only forwards to `main`.
    expected = dedent(
        """\
        method main_valid(value: int) returns (result: int)
        requires value >= 10
        ensures result >= 20
        { var ret := main(value); return ret; }
        """
    )

    generated = LanguageDatabase().get("dafny").generate_validators(source)

    assert generated == expected
26 changes: 26 additions & 0 deletions tests/test_nagini_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# from textwrap import dedent
# from verified_cogen.runners.languages import LanguageDatabase, init_basic_languages

# init_basic_languages()


# def test_nagini_generate():
# nagini_lang = LanguageDatabase().get("nagini")
# code = dedent(
# """\
# def main(value: int) -> int:
# Requires(value >= 10)
# Ensures(Result() >= 20)
# Assert(value * 2 >= 20)  # assert-line
# return value * 2"""
# )
# assert nagini_lang.generate_validators(code) == dedent(
# """\
# def main_valid(value: int) -> int:
# Requires(value >= 10)
# Ensures(Result() >= 20)
# ret = main(value)
# return ret"""
# )

# print(nagini_lang.generate_validators(code))
3 changes: 3 additions & 0 deletions verified_cogen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from verified_cogen.runners.generate import GenerateRunner
from verified_cogen.runners.generic import GenericRunner
from verified_cogen.runners.invariants import InvariantRunner
from verified_cogen.runners.languages import init_basic_languages
from verified_cogen.runners.validating import ValidatingRunner
from verified_cogen.tools import pprint_stat, rename_file, tabulate_list
from verified_cogen.tools.modes import Mode
Expand Down Expand Up @@ -66,6 +67,8 @@ def run_once(files, args, runner, verifier, mode, is_once) -> tuple[int, int, in


def main():
init_basic_languages()

args = get_args()
mode = Mode(args.insert_conditions_mode)
if mode == Mode.REGEX:
Expand Down
85 changes: 38 additions & 47 deletions verified_cogen/runners/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from abc import ABC
from logging import Logger
from typing import Optional
import pathlib
Expand All @@ -11,103 +10,95 @@
LLM_GENERATED_DIR = pathlib.Path(get_cache_dir()) / "llm-generated"


class Runner(ABC):
@classmethod
def rewrite(cls, llm: LLM, prg: str) -> str:
class Runner:
llm: LLM
logger: Logger
verifier: Verifier

def __init__(self, llm: LLM, logger: Logger, verifier: Verifier):
self.llm = llm
self.logger = logger
self.verifier = verifier

def rewrite(self, prg: str) -> str:
"""Rewrite the program with additional checks in one step."""
...

@classmethod
def produce(cls, llm: LLM, prg: str) -> str:
def produce(self, prg: str) -> str:
"""Produce the additional checks for the program."""
...

@classmethod
def insert(cls, llm: LLM, prg: str, checks: str, mode: Mode) -> str:
def insert(self, prg: str, checks: str, mode: Mode) -> str:
"""Insert the additional checks into the program."""
...

@classmethod
def precheck(cls, prg: str, mode: Mode):
def precheck(self, prg: str, mode: Mode):
pass

@classmethod
def preprocess(cls, prg: str, mode: Mode) -> str:
def preprocess(self, prg: str, mode: Mode) -> str:
return prg

@classmethod
def invoke(cls, logger: Logger, llm: LLM, prg: str, mode: Mode) -> str:
logger.info("Invoking LLM")
def invoke(self, prg: str, mode: Mode) -> str:
self.logger.info("Invoking LLM")
if mode.is_singlestep:
inv_prg = cls.rewrite(llm, prg)
inv_prg = self.rewrite(prg)
else:
raise ValueError(f"Unexpected mode: {mode}")
logger.info("Invocation done")
self.logger.info("Invocation done")
return inv_prg

@classmethod
def verify_program(cls, verifier: Verifier, name: str, prg: str):
def verify_program(self, name: str, prg: str):
LLM_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
output = LLM_GENERATED_DIR / name
with open(output, "w") as f:
f.write(prg)
return verifier.verify(output)
return self.verifier.verify(output)

@classmethod
def try_fixing(
cls,
logger: Logger,
verifier: Verifier,
llm: LLM,
self,
total_tries: int,
inv_prg: str,
name: str,
) -> Optional[int]:
tries = total_tries
while tries > 0:
verification_result = cls.verify_program(verifier, name, inv_prg)
verification_result = self.verify_program(name, inv_prg)
if verification_result is None:
logger.info("Verification timed out")
self.logger.info("Verification timed out")
tries -= 1
if tries > 0:
inv_prg = llm.ask_for_timeout()
inv_prg = self.llm.ask_for_timeout()
else:
verified_inv, out_inv, err_inv = verification_result
if verified_inv:
return total_tries - tries + 1
else:
logger.info("Verification failed:")
logger.info(out_inv)
logger.info(err_inv)
logger.info("Retrying...")
self.logger.info("Verification failed:")
self.logger.info(out_inv)
self.logger.info(err_inv)
self.logger.info("Retrying...")
tries -= 1
if tries > 0:
inv_prg = llm.ask_for_fixed(out_inv + err_inv)
inv_prg = self.llm.ask_for_fixed(out_inv + err_inv)
return None

@classmethod
def run_on_file(
cls,
logger: Logger,
verifier: Verifier,
self,
mode: Mode,
llm: LLM,
total_tries: int,
file: str,
) -> Optional[int]:
name = basename(file)
logger.info(f"Running on {file}")
self.logger.info(f"Running on {file}")

with open(file, "r") as f:
prg = cls.preprocess(f.read(), mode)
prg = self.preprocess(f.read(), mode)

verification_result = cls.verify_program(verifier, name, prg)
verification_result = self.verify_program(name, prg)
if verification_result is not None and verification_result[0]:
return 0
elif verification_result is None:
logger.info("Verification timed out")
cls.precheck(prg, mode)
inv_prg = cls.invoke(logger, llm, prg, mode)
return cls.try_fixing(
logger, verifier, llm, total_tries, inv_prg, name
)
self.logger.info("Verification timed out")
self.precheck(prg, mode)
inv_prg = self.invoke(prg, mode)
return self.try_fixing(total_tries, inv_prg, name)
39 changes: 12 additions & 27 deletions verified_cogen/runners/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

from verified_cogen.tools.modes import Mode

from verified_cogen.llm import LLM
from verified_cogen.runners import Runner
from verified_cogen.tools.verifier import Verifier
from verified_cogen.tools import basename, get_cache_dir


Expand All @@ -16,24 +14,17 @@


class GenerateRunner(Runner):
@classmethod
def rewrite(cls, llm: LLM, prg: str) -> str:
return llm.rewrite(prg)
def rewrite(self, prg: str) -> str:
return self.llm.rewrite(prg)

@classmethod
def produce(cls, llm: LLM, prg: str) -> str:
def produce(self, prg: str) -> str:
raise ValueError("Produce not supported for generate")

@classmethod
def insert(cls, llm: LLM, prg: str, checks: str, mode: Mode) -> str:
def insert(self, prg: str, checks: str, mode: Mode) -> str:
raise ValueError("Insert not supported for generate")

@classmethod
def try_fixing(
cls,
logger: logging.Logger,
verifier: Verifier,
llm: LLM,
self,
total_tries: int,
inv_prg: str,
name: str,
Expand All @@ -44,12 +35,12 @@ def try_fixing(
output = LLM_GENERATED_DIR / f"{name[:-7]}.dfy"
with open(output, "w") as f:
f.write(inv_prg)
verification_result = verifier.verify(output)
verification_result = self.verifier.verify(output)
if verification_result is None:
logger.info("Verification timed out")
tries -= 1
if tries > 0:
inv_prg = llm.ask_for_timeout()
inv_prg = self.llm.ask_for_timeout()
else:
verified_inv, out_inv, err_inv = verification_result
if verified_inv:
Expand All @@ -61,25 +52,19 @@ def try_fixing(
logger.info("Retrying...")
tries -= 1
if tries > 0:
inv_prg = llm.ask_for_fixed(out_inv + err_inv)
inv_prg = self.llm.ask_for_fixed(out_inv + err_inv)
return None

@classmethod
def run_on_file(
cls,
logger: logging.Logger,
verifier: Verifier,
self,
mode: Mode,
llm: LLM,
total_tries: int,
file: str,
) -> Optional[int]:
logger.info(f"Running on {file}")

with open(file, "r") as f:
prg = f.read()
cls.precheck(prg, mode)
inv_prg = cls.invoke(logger, llm, prg, mode)
return cls.try_fixing(
logger, verifier, llm, total_tries, inv_prg, basename(file)
)
self.precheck(prg, mode)
inv_prg = self.invoke(prg, mode)
return self.try_fixing(total_tries, inv_prg, basename(file))
15 changes: 6 additions & 9 deletions verified_cogen/runners/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@


class GenericRunner(Runner):
@classmethod
def rewrite(cls, llm: LLM, prg: str) -> str:
return llm.rewrite(prg)
def rewrite(self, prg: str) -> str:
return self.llm.rewrite(prg)

@classmethod
def produce(cls, llm: LLM, prg: str) -> str:
return llm.produce(prg)
def produce(self, prg: str) -> str:
return self.llm.produce(prg)

@classmethod
def insert(cls, llm: LLM, prg: str, checks: str, mode: Mode) -> str:
def insert(self, prg: str, checks: str, mode: Mode) -> str:
if mode == Mode.REGEX:
raise ValueError("Regex mode not supported for generic")
return llm.add(prg, checks)
return self.llm.add(prg, checks)
18 changes: 7 additions & 11 deletions verified_cogen/runners/invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,16 @@ def insert_invariants(llm: LLM, prg: str, inv: str, mode: Mode):


class InvariantRunner(Runner):
@classmethod
def rewrite(cls, llm: LLM, prg: str) -> str:
return llm.rewrite(prg)
def rewrite(self, prg: str) -> str:
return self.llm.rewrite(prg)

@classmethod
def produce(cls, llm: LLM, prg: str) -> str:
return llm.produce(prg)
def produce(self, prg: str) -> str:
return self.llm.produce(prg)

@classmethod
def insert(cls, llm: LLM, prg: str, checks: str, mode: Mode) -> str:
return insert_invariants(llm, prg, checks, mode)
def insert(self, prg: str, checks: str, mode: Mode) -> str:
return insert_invariants(self.llm, prg, checks, mode)

@classmethod
def precheck(cls, prg: str, mode: Mode):
def precheck(self, prg: str, mode: Mode):
if mode == Mode.REGEX:
while_count = prg.count("while")
if while_count == 0:
Expand Down
8 changes: 8 additions & 0 deletions verified_cogen/runners/languages/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from verified_cogen.runners.languages.dafny import DafnyLanguage
from verified_cogen.runners.languages.language import LanguageDatabase
from verified_cogen.runners.languages.nagini import NaginiLanguage


def init_basic_languages():
    """Register the built-in languages (Dafny and Nagini) in the LanguageDatabase."""
    # (registry name, file extensions, language class) for every built-in language.
    builtin_languages = (
        ("dafny", ["dfy"], DafnyLanguage),
        ("nagini", ["py", "python"], NaginiLanguage),
    )
    for name, extensions, language_cls in builtin_languages:
        # The language object is created right at registration time, one per entry.
        LanguageDatabase().add(name, extensions, language_cls())
21 changes: 21 additions & 0 deletions verified_cogen/runners/languages/dafny.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Pattern
from verified_cogen.runners.languages.language import GenericLanguage
import re

# Template for the Dafny wrapper method generated for each matched method.
# The wrapper repeats the method's signature and specification under the name
# `<method>_valid` and only forwards the call to the original method.
#
# Placeholders: {method_name}, {parameters}, {returns}, {specs}, {param_names}.
# {specs} is presumably the raw text the method regex captures between the
# `returns (...)` clause and the opening brace, so it carries its own
# surrounding newlines; the trailing backslash keeps the template from adding
# another one in front of the body.
#
# NOTE(review): the literal braces of the Dafny body are doubled on the
# assumption that GenericLanguage fills this template with str.format --
# confirm against language.py (with plain str.replace, use single braces).
DAFNY_VALIDATOR_TEMPLATE = """\
method {method_name}_valid({parameters}) returns ({returns}){specs}\
{{ var ret := {method_name}({param_names}); return ret; }}
"""


class DafnyLanguage(GenericLanguage):
    """Dafny flavour of GenericLanguage: a method-header regex plus the validator template."""

    method_regex: Pattern[str]

    def __init__(self):
        # Capture groups, in order: method name, parameter list, `returns`
        # list, and everything between the `returns (...)` clause and the
        # first opening brace. DOTALL lets `.` run across line breaks.
        method_header = re.compile(
            r"method\s+(\w+)\s*\((.*?)\)\s*returns\s*\((.*?)\)(.*?)\{", re.DOTALL
        )
        super().__init__(method_header, DAFNY_VALIDATOR_TEMPLATE)
Loading

0 comments on commit f2f11bf

Please sign in to comment.