Skip to content

Commit

Permalink
Add try-logging to runners (#6)
Browse files Browse the repository at this point in the history
* Add try-logging to runners

* Formatting fix

* fix bugs, add pyright check

* rename job

---------

Co-authored-by: WeetHet <[email protected]>
  • Loading branch information
gt22 and WeetHet authored Sep 16, 2024
1 parent 782b190 commit ef54ec3
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 19 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/test-pytest-ruff.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Run pytest and ruff
name: Run pytest, ruff, and pyright

on: [push]

Expand All @@ -8,7 +8,7 @@ env:
POETRY_URL: https://install.python-poetry.org

jobs:
run-pytest:
run-tests:
runs-on: ubuntu-latest
steps:
- name: Checkout
Expand Down Expand Up @@ -38,5 +38,7 @@ jobs:
run: poetry run ruff format verified_cogen --check
- name: Run ruff linter
run: poetry run ruff check verified_cogen
- name: Run pyright
run: poetry run pyright verified_cogen
- name: Run pytest
run: poetry run pytest
31 changes: 30 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ ruff = "^0.5.4"
pytest = "^8.3.1"
matplotlib = "^3.9.2"
ipykernel = "^6.29.5"
pyright = "^1.1.380"

[build-system]
requires = ["poetry-core"]
Expand Down
3 changes: 3 additions & 0 deletions verified_cogen/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def get_default_parser():
"-s", "--output-style", choices=["stats", "full"], default="full"
)
parser.add_argument("--filter-by-ext", help="filter by extension", default=None)
parser.add_argument(
"--log-tries", help="Save output of every try to given dir", default=None
)
return parser


Expand Down
3 changes: 2 additions & 1 deletion verified_cogen/experiments/incremental_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def main():
assert args.retries == 0

directory = pathlib.Path(args.dir)
log_tries = pathlib.Path(args.log_tries) if args.log_tries is not None else None
results_directory = pathlib.Path("results")
results_directory.mkdir(exist_ok=True)
json_results = pathlib.Path("results") / f"tries_{directory.name}.json"
Expand All @@ -51,7 +52,7 @@ def main():
args.temperature,
)
runner = ValidatingRunner(
wrapping=InvariantRunner(llm, logger, verifier),
wrapping=InvariantRunner(llm, logger, verifier, log_tries),
language=language,
)
display_name = rename_file(file)
Expand Down
23 changes: 13 additions & 10 deletions verified_cogen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from verified_cogen.tools.modes import Mode
from verified_cogen.tools.verifier import Verifier
from pathlib import Path
from typing import Callable
from typing import Callable, Optional
from verified_cogen.runners import Runner
from logging import Logger

Expand Down Expand Up @@ -85,19 +85,19 @@ def run_once(


def make_runner_cls(
bench_type: str, extension: str
bench_type: str, extension: str, log_tries: Optional[pathlib.Path]
) -> Callable[[LLM, Logger, Verifier], Runner]:
def runner_cls(llm: LLM, logger: Logger, verifier: Verifier):
match bench_type:
case "invariants":
return InvariantRunner(llm, logger, verifier)
return InvariantRunner(llm, logger, verifier, log_tries)
case "generic":
return GenericRunner(llm, logger, verifier)
return GenericRunner(llm, logger, verifier, log_tries)
case "generate":
return GenerateRunner(llm, logger, verifier)
return GenerateRunner(llm, logger, verifier, log_tries)
case "validating":
return ValidatingRunner(
InvariantRunner(llm, logger, verifier),
InvariantRunner(llm, logger, verifier, log_tries),
LanguageDatabase().get(extension),
)
case _:
Expand All @@ -120,11 +120,14 @@ def main():

if args.input is None and args.dir is None:
args.input = input("Input file: ").strip()
log_tries = pathlib.Path(args.log_tries) if args.log_tries is not None else None

verifier = Verifier(args.shell, args.verifier_command, args.verifier_timeout)
if args.dir is not None:
files = sorted(list(pathlib.Path(args.dir).glob(ext_glob(args.filter_by_ext))))
runner_cls = make_runner_cls(args.bench_type, extension_from_file_list(files))
runner_cls = make_runner_cls(
args.bench_type, extension_from_file_list(files), log_tries
)
runner = runner_cls(
LLM(
args.grazie_token,
Expand Down Expand Up @@ -166,9 +169,9 @@ def main():
args.prompts_directory,
args.temperature,
)
runner = make_runner_cls(args.bench_type, Path(args.input).suffix[1:])(
llm, logger, verifier
)
runner = make_runner_cls(
args.bench_type, Path(args.input).suffix[1:], log_tries
)(llm, logger, verifier)
tries = runner.run_on_file(mode, args.tries, args.input)
if tries == 0:
print("Verified without modification")
Expand Down
29 changes: 24 additions & 5 deletions verified_cogen/runners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,21 @@ class Runner:
llm: LLM
logger: Logger
verifier: Verifier
log_tries: Optional[pathlib.Path]

def __init__(self, llm: LLM, logger: Logger, verifier: Verifier):
def __init__(
self,
llm: LLM,
logger: Logger,
verifier: Verifier,
log_tries: Optional[pathlib.Path] = None,
):
self.llm = llm
self.logger = logger
self.verifier = verifier
self.log_tries = log_tries
if self.log_tries is not None:
self.log_tries.mkdir(exist_ok=True, parents=True)

def rewrite(self, prg: str) -> str:
"""Rewrite the program with additional checks in one step."""
Expand Down Expand Up @@ -58,9 +68,16 @@ def invoke(self, prg: str, mode: Mode) -> str:
self.logger.info("Invocation done")
return inv_prg

def verify_program(self, name: str, prg: str):
def _verification_file(self, name: str, try_n: int) -> pathlib.Path:
if self.log_tries is not None:
base, extension = name.rsplit(".", 1)
return self.log_tries / f"{base}.{try_n}.{extension}"
else:
return LLM_GENERATED_DIR / name

def verify_program(self, name: str, try_n: int, prg: str):
LLM_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
output = LLM_GENERATED_DIR / name
output = self._verification_file(name, try_n)
with open(output, "w") as f:
f.write(prg)
return self.verifier.verify(output)
Expand All @@ -73,7 +90,9 @@ def try_fixing(
) -> Optional[int]:
tries = total_tries
while tries > 0:
verification_result = self.verify_program(name, inv_prg)
verification_result = self.verify_program(
name, total_tries - tries + 1, inv_prg
)
if verification_result is None:
self.logger.info("Verification timed out")
tries -= 1
Expand Down Expand Up @@ -105,7 +124,7 @@ def run_on_file(
with open(file, "r") as f:
prg = self.preprocess(f.read(), mode)

verification_result = self.verify_program(name, prg)
verification_result = self.verify_program(name, 0, prg)
if verification_result is not None and verification_result[0]:
return 0
elif verification_result is None:
Expand Down

0 comments on commit ef54ec3

Please sign in to comment.