From ddc0219f5691581c73a2107e8d63920d8b1ac20d Mon Sep 17 00:00:00 2001 From: WeetHet Date: Tue, 17 Sep 2024 10:11:05 +0300 Subject: [PATCH] fix nagini regex --- tests/test_nagini.py | 158 ++++++++++++++++++- verified_cogen/runners/languages/language.py | 8 +- verified_cogen/runners/languages/nagini.py | 2 +- 3 files changed, 163 insertions(+), 5 deletions(-) diff --git a/tests/test_nagini.py b/tests/test_nagini.py index 013a611..1e994a1 100644 --- a/tests/test_nagini.py +++ b/tests/test_nagini.py @@ -1,4 +1,3 @@ -import re from textwrap import dedent from verified_cogen.runners.languages import LanguageDatabase, register_basic_languages @@ -25,6 +24,32 @@ def main_valid(value: int) -> int: ) +def test_nagini_with_comments(): + nagini_lang = LanguageDatabase().get("nagini") + code = dedent( + """\ + def main(value: int) -> int: + # pre-conditions-start + Requires(value >= 10) + # pre-conditions-end + # post-conditions-start + Ensures(Result() >= 20) + # post-conditions-end + # impl-start + Assert(value * 2 >= 20) # assert-line + return value * 2 + # impl-end""" + ) + assert nagini_lang.generate_validators(code) == dedent( + """\ + def main_valid(value: int) -> int: + Requires(value >= 10) + Ensures(Result() >= 20) + ret = main(value) + return ret""" + ) + + def test_remove_line(): nagini_lang = LanguageDatabase().get("nagini") code = dedent( @@ -136,3 +161,134 @@ def is_prime(k : int) -> bool: d_2_i_ = (d_2_i_) + (1) return result""" ) + + +def test_nagini_large(): + nagini_lang = LanguageDatabase().get("nagini") + + code = dedent( + """\ + from typing import cast, List, Dict, Set, Optional, Union + from nagini_contracts.contracts import * + + @Pure + def lower(c : int) -> bool : + # impl-start + return ((0) <= (c)) and ((c) <= (25)) + # impl-end + + @Pure + def upper(c : int) -> bool : + # impl-start + return ((26) <= (c)) and ((c) <= (51)) + # impl-end + + @Pure + def alpha(c : int) -> bool : + # impl-start + return (lower(c)) or (upper(c)) + # impl-end + + @Pure + def flip__char(c : int) -> int : + # pre-conditions-start + Ensures(lower(c) == upper(Result())) + Ensures(upper(c) == lower(Result())) + # pre-conditions-end + + # impl-start + if lower(c): + return ((c) - (0)) + (26) + elif upper(c): + return ((c) + (0)) - (26) + elif True: + return c + # impl-end + + def flip__case(s : List[int]) -> List[int] : + # pre-conditions-start + Requires(Acc(list_pred(s))) + # pre-conditions-end + # post-conditions-start + Ensures(Acc(list_pred(s))) + Ensures(Acc(list_pred(Result()))) + Ensures((len(Result())) == (len(s))) + Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), lower((s)[d_0_i_]) == upper((Result())[d_0_i_]))))) + Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), upper((s)[d_0_i_]) == lower((Result())[d_0_i_]))))) + # post-conditions-end + + # impl-start + res = list([int(0)] * len(s)) # type : List[int] + i = int(0) # type : int + while i < len(s): + # invariants-start + Invariant(Acc(list_pred(s))) + Invariant(Acc(list_pred(res))) + Invariant(((0) <= (i)) and ((i) <= (len(s)))) + Invariant((len(res)) == (len(s))) + Invariant(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (i)), lower((s)[d_0_i_]) == upper((res)[d_0_i_]))))) + Invariant(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (i)), upper((s)[d_0_i_]) == lower((res)[d_0_i_]))))) + # invariants-end + res[i] = flip__char(s[i]) + i = i + 1 + return res + # impl-end""" + ) + # print(nagini_lang.generate_validators(code)) + assert nagini_lang.generate_validators(code) == dedent( + """\ + def lower_valid(c : int) -> bool : + ret = lower(c) + return ret + def upper_valid(c : int) -> bool : + ret = upper(c) + return ret + def alpha_valid(c : int) -> bool : + ret = alpha(c) + return ret + def flip__char_valid(c : int) -> int : + Ensures(lower(c) == upper(Result())) + Ensures(upper(c) == lower(Result())) + ret = flip__char(c) + return ret + def flip__case_valid(s : List[int]) -> List[int] : + Requires(Acc(list_pred(s))) + Ensures(Acc(list_pred(s))) + Ensures(Acc(list_pred(Result()))) + Ensures((len(Result())) == (len(s))) + Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), lower((s)[d_0_i_]) == upper((Result())[d_0_i_]))))) + Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), upper((s)[d_0_i_]) == lower((Result())[d_0_i_]))))) + ret = flip__case(s) + return ret""" + ) + + +def test_nagini_small(): + nagini_lang = LanguageDatabase().get("nagini") + + code = dedent( + """\ + @Pure + def flip__char(c : int) -> int : + # pre-conditions-start + Ensures(lower(c) == upper(Result())) + Ensures(upper(c) == lower(Result())) + # pre-conditions-end + + # impl-start + if lower(c): + return ((c) - (0)) + (26) + elif upper(c): + return ((c) + (0)) - (26) + elif True: + return c + # impl-end""" + ) + assert nagini_lang.generate_validators(code) == dedent( + """\ + def flip__char_valid(c : int) -> int : + Ensures(lower(c) == upper(Result())) + Ensures(upper(c) == lower(Result())) + ret = flip__char(c) + return ret""" + ) diff --git a/verified_cogen/runners/languages/language.py b/verified_cogen/runners/languages/language.py index def935a..72be824 100644 --- a/verified_cogen/runners/languages/language.py +++ b/verified_cogen/runners/languages/language.py @@ -1,5 +1,6 @@ from abc import abstractmethod from typing import Pattern +import re class Language: @@ -39,6 +40,7 @@ def __init__( self.inline_assert_comment = inline_assert_comment def generate_validators(self, code: str) -> str: + code = re.sub(r"^ *#.*(\r\n|\r|\n)?", "", code, flags=re.MULTILINE) methods = self.method_regex.finditer(code) validators = [] @@ -53,9 +55,9 @@ def generate_validators(self, code: str) -> str: validators.append( self.validator_template.replace("{method_name}", method_name) - .replace("{parameters}", parameters) - .replace("{returns}", returns) - .replace("{specs}", specs) + .replace("{parameters}", parameters or "") + .replace("{returns}", returns or "") + .replace("{specs}", specs or "\n") .replace( "{param_names}", ", ".join( diff --git a/verified_cogen/runners/languages/nagini.py b/verified_cogen/runners/languages/nagini.py index 76d3f03..4cb1fe5 100644 --- a/verified_cogen/runners/languages/nagini.py +++ b/verified_cogen/runners/languages/nagini.py @@ -16,7 +16,7 @@ class NaginiLanguage(GenericLanguage): def __init__(self): super().__init__( re.compile( - r"def\s+(\w+)\s*\((.*?)\)\s*->\s*(.*?):((?:\r\n|\r|\n) *(?:Requires|Ensures)\(.*\)(?:\r\n|\r|\n))*", + r"def\s+(\w+)\s*\((.*?)\)\s*->\s*(.*?):(:?\n?( *(?:Requires|Ensures)\([^\n]*\)\n?)*)", re.DOTALL, ), NAGINI_VALIDATOR_TEMPLATE,