From 10e97ffc5239cf4634e3e7e5dc950be9df0bdf44 Mon Sep 17 00:00:00 2001 From: AlexShefY Date: Thu, 19 Sep 2024 09:32:40 +0200 Subject: [PATCH] fix parsing --- .gitmodules | 1 + tests/test_nagini.py | 16 +++++++++++++++- verified_cogen/runners/languages/language.py | 1 - verified_cogen/runners/languages/nagini.py | 2 +- 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index d0350e6..e049f86 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,6 +5,7 @@ [submodule "benches/HumanEval-Dafny"] path = benches/HumanEval-Dafny url = https://github.com/JetBrains-Research/HumanEval-Dafny + [submodule "benches/HumanEval-Nagini"] path = benches/HumanEval-Nagini url = https://github.com/JetBrains-Research/HumanEval-Nagini diff --git a/tests/test_nagini.py b/tests/test_nagini.py index 1e994a1..0679ac8 100644 --- a/tests/test_nagini.py +++ b/tests/test_nagini.py @@ -11,8 +11,10 @@ def test_nagini_generate(): def main(value: int) -> int: Requires(value >= 10) Ensures(Result() >= 20) + # impl-start Assert(value * 2 >= 20) # assert-line - return value * 2""" + return value * 2 + # impl-end""" ) assert nagini_lang.generate_validators(code) == dedent( """\ @@ -43,8 +45,12 @@ def main(value: int) -> int: assert nagini_lang.generate_validators(code) == dedent( """\ def main_valid(value: int) -> int: + # pre-conditions-start Requires(value >= 10) + # pre-conditions-end + # post-conditions-start Ensures(Result() >= 20) + # post-conditions-end ret = main(value) return ret""" ) @@ -247,17 +253,23 @@ def alpha_valid(c : int) -> bool : ret = alpha(c) return ret def flip__char_valid(c : int) -> int : + # pre-conditions-start Ensures(lower(c) == upper(Result())) Ensures(upper(c) == lower(Result())) + # pre-conditions-end ret = flip__char(c) return ret def flip__case_valid(s : List[int]) -> List[int] : + # pre-conditions-start Requires(Acc(list_pred(s))) + # pre-conditions-end + # post-conditions-start Ensures(Acc(list_pred(s))) Ensures(Acc(list_pred(Result()))) Ensures((len(Result())) == (len(s))) Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), lower((s)[d_0_i_]) == upper((Result())[d_0_i_]))))) Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), upper((s)[d_0_i_]) == lower((Result())[d_0_i_]))))) + # post-conditions-end ret = flip__case(s) return ret""" ) @@ -287,8 +299,10 @@ def flip__char(c : int) -> int : assert nagini_lang.generate_validators(code) == dedent( """\ def flip__char_valid(c : int) -> int : + # pre-conditions-start Ensures(lower(c) == upper(Result())) Ensures(upper(c) == lower(Result())) + # pre-conditions-end ret = flip__char(c) return ret""" ) diff --git a/verified_cogen/runners/languages/language.py b/verified_cogen/runners/languages/language.py index 2b9cb81..6e9d6db 100644 --- a/verified_cogen/runners/languages/language.py +++ b/verified_cogen/runners/languages/language.py @@ -43,7 +43,6 @@ def __init__( self.inline_assert_comment = inline_assert_comment def generate_validators(self, code: str) -> str: - code = re.sub(r"^ *#.*(\r\n|\r|\n)?", "", code, flags=re.MULTILINE) methods = self.method_regex.finditer(code) validators = [] diff --git a/verified_cogen/runners/languages/nagini.py b/verified_cogen/runners/languages/nagini.py index 559ec72..044d7e8 100644 --- a/verified_cogen/runners/languages/nagini.py +++ b/verified_cogen/runners/languages/nagini.py @@ -16,7 +16,7 @@ class NaginiLanguage(GenericLanguage): def __init__(self): super().__init__( re.compile( - r"def\s+(\w+)\s*\((.*?)\)\s*->\s*(.*?):(:?(?:\r\n|\r|\n)?( *(?:Requires|Ensures)\([^\r\n]*\)(?:\r\n|\r|\n)?)*)", + r"def\s+(\w+)\s*\((.*?)\)\s*->\s*(.*?):(.*?(\r\n|\r|\n))\s+# impl-start", re.DOTALL, ), NAGINI_VALIDATOR_TEMPLATE,