diff --git a/prompts/humaneval-nagini-few-shot/produce.txt b/prompts/humaneval-nagini-few-shot/produce.txt index 3ee0661..409bf14 100644 --- a/prompts/humaneval-nagini-few-shot/produce.txt +++ b/prompts/humaneval-nagini-few-shot/produce.txt @@ -4,11 +4,43 @@ Even if you think some invariant is not totally necessary, better add it than no Don't add any additional text comments, your response must contain only program with invariants. Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else. -Pay attention to important details of Nagini syntax: -1. You should use only single inequalities of type `a <= b` or equalities `a == b` -2. Use `Acc(list_pred(x))` invariants for variables of type List[] (for example, if x is of type List[int], add `Acc(list_pred(x))` invariants to all cycles in the scope of variable x) -3. Try to reuse the invariants, encountered as preconditions and postconditions, if these invariants hold during cycle execution -4. Don't use any built-in functions +You remember the following aspects of Nagini syntax: + +1. Nagini DOES NOT SUPPORT some Python features as list comprehensions (k + 1 for k in range(5)), as double inequalities (a <= b <= c). +Instead of double inequalities it's customary to use two separate inequalities (a <= b and b <= c). + +2. In Nagini method preconditions (Requires) and postconditions (Ensures) placed right after method signature, like here: +" +def Sum(a : List[int], s : int, t : int) -> int : + Requires(Acc(list_pred(a))) + Requires(((0) <= (s)) and ((s) <= (t)) and ((t) <= (len(a)))) + Ensures(Acc(list_pred(a))) + ... +" + +3. Invariant are placed right after `while` statement and before the code of `while` body: +" + while i < len(numbers): + Invariant(Acc(list_pred(numbers))) + Invariant(0 <= i and i <= len(numbers)) + s = s + numbers[i] +" +Invariants CANNOT be placed in any other position. +You remember that each invariant (and each expression) should contain equal number of opening and closing brackets, so that it is valid. +You should sustain balanced parentheses. + +4. Nagini requires special annotations for working with lists `Acc(list_pred(..))`. You can use these constructs only inside `Invariant`, +anywhere else you should not use `Acc()` or `list_pred()`: +" + while i < len(numbers): + Invariant(Acc(list_pred(numbers))) +" + +5. Nagini contains `Forall` and `Exists` constructs that can be used in invariants. First argument of Forall/Exists is typically a type (i.e `int`), +second argument is a lambda. `Forall(type, lambda x : a)` denotes that assertion `a` is true for every element `x` of type `type`. + +6. In Nagini `Implies(e1, a2)` plays role of implication. `Implies(e1, a2)` denotes that assertion a2 holds if boolean expression e1 is true. +You can use it inside invariants and asserts. Here are some examples of verified Python cycles with Nagini invariants: ``` diff --git a/prompts/humaneval-nagini-few-shot/rewrite.txt b/prompts/humaneval-nagini-few-shot/rewrite.txt index 7c11f6b..5a1d09d 100644 --- a/prompts/humaneval-nagini-few-shot/rewrite.txt +++ b/prompts/humaneval-nagini-few-shot/rewrite.txt @@ -5,11 +5,43 @@ Even if you think some invariant is not totally necessary, better add it than no Don't add any additional text comments, your response must contain only program with invariants. Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else. -Pay attention to important details of Nagini syntax: -1. You should use only single inequalities of type `a <= b` or equalities `a == b` -2. Use `Acc(list_pred(x))` invariants for variables of type List[] (for example, if x is of type List[int], add `Acc(list_pred(x))` invariants to all cycles in the scope of variable x) -3. Try to reuse the invariants, encountered as preconditions and postconditions, if these invariants hold during cycle execution -4. Don't use any built-in functions +You remember the following aspects of Nagini syntax: + +1. Nagini DOES NOT SUPPORT some Python features as list comprehensions (k + 1 for k in range(5)), as double inequalities (a <= b <= c). +Instead of double inequalities it's customary to use two separate inequalities (a <= b and b <= c). + +2. In Nagini method preconditions (Requires) and postconditions (Ensures) placed right after method signature, like here: +" +def Sum(a : List[int], s : int, t : int) -> int : + Requires(Acc(list_pred(a))) + Requires(((0) <= (s)) and ((s) <= (t)) and ((t) <= (len(a)))) + Ensures(Acc(list_pred(a))) + ... +" + +3. Invariant are placed right after `while` statement and before the code of `while` body: +" + while i < len(numbers): + Invariant(Acc(list_pred(numbers))) + Invariant(0 <= i and i <= len(numbers)) + s = s + numbers[i] +" +Invariants CANNOT be placed in any other position. +You remember that each invariant (and each expression) should contain equal number of opening and closing brackets, so that it is valid. +You should sustain balanced parentheses. + +4. Nagini requires special annotations for working with lists `Acc(list_pred(..))`. You can use these constructs only inside `Invariant`, +anywhere else you should not use `Acc()` or `list_pred()`: +" + while i < len(numbers): + Invariant(Acc(list_pred(numbers))) +" + +5. Nagini contains `Forall` and `Exists` constructs that can be used in invariants. First argument of Forall/Exists is typically a type (i.e `int`), +second argument is a lambda. `Forall(type, lambda x : a)` denotes that assertion `a` is true for every element `x` of type `type`. + +6. In Nagini `Implies(e1, a2)` plays role of implication. `Implies(e1, a2)` denotes that assertion a2 holds if boolean expression e1 is true. +You can use it inside invariants and asserts. You might need to work with accumulating functions, such as sum, so here's an example of how to do that: ``` diff --git a/prompts/humaneval-nagini-without-impls-few-shot/ask_for_fixed.txt b/prompts/humaneval-nagini-without-impls-few-shot/ask_for_fixed.txt new file mode 100644 index 0000000..1631341 --- /dev/null +++ b/prompts/humaneval-nagini-without-impls-few-shot/ask_for_fixed.txt @@ -0,0 +1,6 @@ +The following errors occurred during verification: +{error} + +Please fix the error by adding, removing or modifying the implementation, invariants or assertions and return the fixed program. +Don't add any additional text comments, your response must contain only program with invariants. +Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else. \ No newline at end of file diff --git a/prompts/humaneval-nagini-without-impls-few-shot/ask_for_fixed_had_errors.txt b/prompts/humaneval-nagini-without-impls-few-shot/ask_for_fixed_had_errors.txt new file mode 100644 index 0000000..9e6f1ab --- /dev/null +++ b/prompts/humaneval-nagini-without-impls-few-shot/ask_for_fixed_had_errors.txt @@ -0,0 +1,6 @@ +There are still some errors: +{error} + +Could you please fix them? +Don't add any additional text comments, your response must contain only program with invariants. +Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else. \ No newline at end of file diff --git a/prompts/humaneval-nagini-without-impls-few-shot/helpers.txt b/prompts/humaneval-nagini-without-impls-few-shot/helpers.txt new file mode 100644 index 0000000..f090878 --- /dev/null +++ b/prompts/humaneval-nagini-without-impls-few-shot/helpers.txt @@ -0,0 +1,3 @@ +Generally, you should use helper functions (marked with @Pure annotation) only in invariants, asserts and conditions (in `if` or `while` conditions), not in the plain code. +But, the following helper functions you can use anywhere: {helpers}. +Do not change helper functions. \ No newline at end of file diff --git a/prompts/humaneval-nagini-without-impls-few-shot/invalid_helpers.txt b/prompts/humaneval-nagini-without-impls-few-shot/invalid_helpers.txt new file mode 100644 index 0000000..2a397ac --- /dev/null +++ b/prompts/humaneval-nagini-without-impls-few-shot/invalid_helpers.txt @@ -0,0 +1,9 @@ +We detected an improper usage of helper functions. Here is the list of helper functions used in a wrong way: +{invalid_helpers} +You should use helper functions only in invariants, asserts and conditions (in `if` or `while` conditions), not in the plain code. +The following helper functions you can use anywhere: {helpers}. +We replaced all improper usages with `invalid_call()` and got the following program: +{program} +You should rewrite this program without changing pre/postconditions and helper functions (denoted with @Pure). +After rewriting your code should verify. +Your code should not contain any `invalid_call()` invocations. \ No newline at end of file diff --git a/prompts/humaneval-nagini-without-impls-few-shot/rewrite.txt b/prompts/humaneval-nagini-without-impls-few-shot/rewrite.txt new file mode 100644 index 0000000..97d51eb --- /dev/null +++ b/prompts/humaneval-nagini-without-impls-few-shot/rewrite.txt @@ -0,0 +1,82 @@ +Rewrite the following Nagini code with implementations of some functions missing. While rewriting it, ensure that it verifies. Include invariants and assertions. Don't remove any helper functions (they are marked with @Pure annotation), they are there to help you. Prefer loops to recursion. +Use helper functions only in invariants, asserts and conditions (in `if` or `while` conditions). Don't use helpers in the plain code. +Do not change helper functions. +Add code and invariants to other functions. +Ensure that the invariants are as comprehensive as they can be. +Even if you think some invariant is not totally necessary, better add it than not. +Don't add any additional text comments, your response must contain only program with invariants. +Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else. + + +You remember the following aspects of Nagini syntax: + +1. Nagini DOES NOT SUPPORT some Python features as list comprehensions (k + 1 for k in range(5)), as double inequalities (a <= b <= c). +Instead of double inequalities it's customary to use two separate inequalities (a <= b and b <= c). + +2. In Nagini method preconditions (Requires) and postconditions (Ensures) placed right after method signature, like here: +" +def Sum(a : List[int], s : int, t : int) -> int : + Requires(Acc(list_pred(a))) + Requires(((0) <= (s)) and ((s) <= (t)) and ((t) <= (len(a)))) + Ensures(Acc(list_pred(a))) + ... +" + +3. Invariant are placed right after `while` statement and before the code of `while` body: +" + while i < len(numbers): + Invariant(Acc(list_pred(numbers))) + Invariant(0 <= i and i <= len(numbers)) + s = s + numbers[i] +" +Invariants CANNOT be placed in any other position. +You remember that each invariant (and each expression) should contain equal number of opening and closing brackets, so that it is valid. +You should sustain balanced parentheses. + +4. Nagini requires special annotations for working with lists `Acc(list_pred(..))`. You can use these constructs only inside `Invariant`, +anywhere else you should not use `Acc()` or `list_pred()`: +" + while i < len(numbers): + Invariant(Acc(list_pred(numbers))) +" + +5. Nagini contains `Forall` and `Exists` constructs that can be used in invariants. First argument of Forall/Exists is typically a type (i.e `int`), +second argument is a lambda. `Forall(type, lambda x : a)` denotes that assertion `a` is true for every element `x` of type `type`. + +6. In Nagini `Implies(e1, a2)` plays role of implication. `Implies(e1, a2)` denotes that assertion a2 holds if boolean expression e1 is true. +You can use it inside invariants and asserts. + +You might need to work with accumulating functions, such as sum, so here's an example of how to do that: +``` +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +@Pure +def Sum(a : List[int], s : int, t : int) -> int : + Requires(Acc(list_pred(a))) + Requires(((0) <= (s)) and ((s) <= (t)) and ((t) <= (len(a)))) + + if s == t: + return 0 + else: + return (a)[t - 1] + (Sum(a, s, t - 1)) + +def sum_loop(numbers: List[int]) -> int: + Requires(Acc(list_pred(numbers))) + Ensures(Acc(list_pred(numbers))) + Ensures(Result() == Sum(numbers, 0, len(numbers))) + s = int(0) + i = int(0) + while (i) < (len(numbers)): + Invariant(Acc(list_pred(numbers))) + Invariant(0 <= i and i <= len(numbers)) + Invariant(Forall(int, lambda d_1_p_: + (Implies(0 <= d_1_p_ and d_1_p_ < len(numbers), Sum(numbers, 0, d_1_p_ + 1) == Sum(numbers, 0, d_1_p_) + numbers[d_1_p_]), [[Sum(numbers, 0, d_1_p_ + 1)]]))) + Invariant(s == Sum(numbers, 0, i)) + Assert(Sum(numbers, 0, i + 1) == Sum(numbers, 0, i) + numbers[i]) + s = s + (numbers)[i] + i = i + 1 + return s +``` +The program: +{program} \ No newline at end of file diff --git a/prompts/humaneval-nagini-without-impls-few-shot/sys.txt b/prompts/humaneval-nagini-without-impls-few-shot/sys.txt new file mode 100644 index 0000000..b462d86 --- /dev/null +++ b/prompts/humaneval-nagini-without-impls-few-shot/sys.txt @@ -0,0 +1,4 @@ +You are an expert in a Python verification framework Nagini. +You will be given tasks dealing with Python programs including precise annotations. +Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else. +You respond only with code blocks. \ No newline at end of file diff --git a/prompts/humaneval-nagini-without-impls-few-shot/timeout.txt b/prompts/humaneval-nagini-without-impls-few-shot/timeout.txt new file mode 100644 index 0000000..f69fd69 --- /dev/null +++ b/prompts/humaneval-nagini-without-impls-few-shot/timeout.txt @@ -0,0 +1,3 @@ +The verifier timed out during the verification. +This usually means that the provided invariants were too broad or were difficult to check. +Could you please try to improve the invariants and try again? diff --git a/tests/test_pure_calls.py b/tests/test_pure_calls.py new file mode 100644 index 0000000..6a133ad --- /dev/null +++ b/tests/test_pure_calls.py @@ -0,0 +1,276 @@ +from textwrap import dedent +from typing import List + +from verified_cogen.runners.languages import LanguageDatabase, register_basic_languages, AnnotationType +from verified_cogen.tools.pureCallsDetectors import detect_and_replace_pure_calls_nagini + +register_basic_languages( + with_removed=[ + AnnotationType.INVARIANTS, + AnnotationType.ASSERTS, + AnnotationType.IMPLS, + ] +) + +def test_simple(): + code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +@Pure +def factorial__spec(n : int) -> int : + Requires(n >= -1) + Ensures(Result() >= 0) + if n == -1: + return 1 + else: + return (n + 1) * factorial__spec(n - 1) + +@Pure +def sum__spec(n : int) -> int : + Requires(n >= -1) + Ensures(Result() >= 0) + if 0 > n: + return 0 + else: + return n + 1 + sum__spec(n - 1) + +def f(n : int) -> List[int]: + Requires((n) >= (1)) + Ensures(Acc(list_pred(Result()))) + Ensures((len(Result())) == (n)) + Ensures(Forall(int, lambda d_2_i_: + not ((((d_2_i_) >= (0)) and ((d_2_i_) < (len(Result())))) and (((d_2_i_ % 2)) == (0))) or (((Result())[d_2_i_]) == (factorial__spec(d_2_i_ - 1))))) + Ensures(Forall(int, lambda d_3_i_: + not ((((d_3_i_) >= (0)) and ((d_3_i_) < (len(Result())))) and (((d_3_i_ % 2)) != (0))) or (((Result())[d_3_i_]) == (sum__spec(d_3_i_ - 1))))) + + result = [0] * n + i = 0 + while sum__spec(i - 1) < n: + Invariant(Acc(list_pred(result))) + Invariant(0 <= i and i <= n) + Invariant(len(result) == n) + Invariant(sum__spec(i - 1) >= 0) + Invariant(Forall(int, lambda j: Implies(0 <= j and j < i, result[j] == (factorial__spec(j - 1) if j % 2 == 0 else sum__spec(j - 1))))) + Invariant(Forall(int, lambda j: Implies(i <= j and j < n, result[j] == 0))) + if i % 2 == 0: + result[i] = factorial__spec(i - 1) + else: + result[i] = sum__spec(i - 1) + i += 1 + return result +""" + ) + + calls, new_code = detect_and_replace_pure_calls_nagini(code, []) + + assert calls == ["factorial__spec", "sum__spec"] + compare_code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +@Pure +def factorial__spec(n: int) -> int: + Requires(n >= -1) + Ensures(Result() >= 0) + if n == -1: + return 1 + else: + return (n + 1) * factorial__spec(n - 1) + +@Pure +def sum__spec(n: int) -> int: + Requires(n >= -1) + Ensures(Result() >= 0) + if 0 > n: + return 0 + else: + return n + 1 + sum__spec(n - 1) + +def f(n: int) -> List[int]: + Requires(n >= 1) + Ensures(Acc(list_pred(Result()))) + Ensures(len(Result()) == n) + Ensures(Forall(int, lambda d_2_i_: not ((d_2_i_ >= 0 and d_2_i_ < len(Result())) and d_2_i_ % 2 == 0) or Result()[d_2_i_] == factorial__spec(d_2_i_ - 1))) + Ensures(Forall(int, lambda d_3_i_: not ((d_3_i_ >= 0 and d_3_i_ < len(Result())) and d_3_i_ % 2 != 0) or Result()[d_3_i_] == sum__spec(d_3_i_ - 1))) + result = [0] * n + i = 0 + while sum__spec(i - 1) < n: + Invariant(Acc(list_pred(result))) + Invariant(0 <= i and i <= n) + Invariant(len(result) == n) + Invariant(sum__spec(i - 1) >= 0) + Invariant(Forall(int, lambda j: Implies(0 <= j and j < i, result[j] == (factorial__spec(j - 1) if j % 2 == 0 else sum__spec(j - 1))))) + Invariant(Forall(int, lambda j: Implies(i <= j and j < n, result[j] == 0))) + if i % 2 == 0: + result[i] = invalid_call() + else: + result[i] = invalid_call() + i += 1 + return result""" + ) + + assert new_code == compare_code + + +def test_simple1(): + code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +@Pure +def factorial__spec(n : int) -> int : + Requires(n >= -1) + Ensures(Result() >= 0) + if n == -1: + return 1 + else: + return (n + 1) * factorial__spec(n - 1) + +@Pure +def sum__spec(n : int) -> int : + Requires(n >= -1) + Ensures(Result() >= 0) + if 0 > n: + return 0 + else: + return n + 1 + sum__spec(n - 1) + +def f(n : int) -> List[int]: + Requires((n) >= (1)) + Ensures(Acc(list_pred(Result()))) + Ensures((len(Result())) == (n)) + Ensures(Forall(int, lambda d_2_i_: + not ((((d_2_i_) >= (0)) and ((d_2_i_) < (len(Result())))) and (((d_2_i_ % 2)) == (0))) or (((Result())[d_2_i_]) == (factorial__spec(d_2_i_ - 1))))) + Ensures(Forall(int, lambda d_3_i_: + not ((((d_3_i_) >= (0)) and ((d_3_i_) < (len(Result())))) and (((d_3_i_ % 2)) != (0))) or (((Result())[d_3_i_]) == (sum__spec(d_3_i_ - 1))))) + + result = [0] * n + i = 0 + while sum__spec(i - 1) < n: + Invariant(Acc(list_pred(result))) + Invariant(0 <= i and i <= n) + Invariant(len(result) == n) + Invariant(sum__spec(i - 1) >= 0) + Invariant(Forall(int, lambda j: Implies(0 <= j and j < i, result[j] == (factorial__spec(j - 1) if j % 2 == 0 else sum__spec(j - 1))))) + Invariant(Forall(int, lambda j: Implies(i <= j and j < n, result[j] == 0))) + if i % 2 == 0: + result[i] = factorial__spec(i - 1) + else: + result[i] = sum__spec(i - 1) + i += 1 + return result +""" + ) + + calls, new_code = detect_and_replace_pure_calls_nagini(code, ["sum__spec"]) + + assert calls == ["factorial__spec"] + compare_code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +@Pure +def factorial__spec(n: int) -> int: + Requires(n >= -1) + Ensures(Result() >= 0) + if n == -1: + return 1 + else: + return (n + 1) * factorial__spec(n - 1) + +@Pure +def sum__spec(n: int) -> int: + Requires(n >= -1) + Ensures(Result() >= 0) + if 0 > n: + return 0 + else: + return n + 1 + sum__spec(n - 1) + +def f(n: int) -> List[int]: + Requires(n >= 1) + Ensures(Acc(list_pred(Result()))) + Ensures(len(Result()) == n) + Ensures(Forall(int, lambda d_2_i_: not ((d_2_i_ >= 0 and d_2_i_ < len(Result())) and d_2_i_ % 2 == 0) or Result()[d_2_i_] == factorial__spec(d_2_i_ - 1))) + Ensures(Forall(int, lambda d_3_i_: not ((d_3_i_ >= 0 and d_3_i_ < len(Result())) and d_3_i_ % 2 != 0) or Result()[d_3_i_] == sum__spec(d_3_i_ - 1))) + result = [0] * n + i = 0 + while sum__spec(i - 1) < n: + Invariant(Acc(list_pred(result))) + Invariant(0 <= i and i <= n) + Invariant(len(result) == n) + Invariant(sum__spec(i - 1) >= 0) + Invariant(Forall(int, lambda j: Implies(0 <= j and j < i, result[j] == (factorial__spec(j - 1) if j % 2 == 0 else sum__spec(j - 1))))) + Invariant(Forall(int, lambda j: Implies(i <= j and j < n, result[j] == 0))) + if i % 2 == 0: + result[i] = invalid_call() + else: + result[i] = sum__spec(i - 1) + i += 1 + return result""" + ) + + assert new_code == compare_code + + +def test_find_pure_non_helpers(): + code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +#use-as-unpure +@Pure +def factorial__spec(n : int) -> int : + Requires(n >= -1) + Ensures(Result() >= 0) + if n == -1: + return 1 + else: + return (n + 1) * factorial__spec(n - 1) + +@Pure +def sum__spec(n : int) -> int : + Requires(n >= -1) + Ensures(Result() >= 0) + if 0 > n: + return 0 + else: + return n + 1 + sum__spec(n - 1) + +def f(n : int) -> List[int]: + Requires((n) >= (1)) + Ensures(Acc(list_pred(Result()))) + Ensures((len(Result())) == (n)) + Ensures(Forall(int, lambda d_2_i_: + not ((((d_2_i_) >= (0)) and ((d_2_i_) < (len(Result())))) and (((d_2_i_ % 2)) == (0))) or (((Result())[d_2_i_]) == (factorial__spec(d_2_i_ - 1))))) + Ensures(Forall(int, lambda d_3_i_: + not ((((d_3_i_) >= (0)) and ((d_3_i_) < (len(Result())))) and (((d_3_i_ % 2)) != (0))) or (((Result())[d_3_i_]) == (sum__spec(d_3_i_ - 1))))) + + result = [0] * n + i = 0 + while sum__spec(i - 1) < n: + Invariant(Acc(list_pred(result))) + Invariant(0 <= i and i <= n) + Invariant(len(result) == n) + Invariant(sum__spec(i - 1) >= 0) + Invariant(Forall(int, lambda j: Implies(0 <= j and j < i, result[j] == (factorial__spec(j - 1) if j % 2 == 0 else sum__spec(j - 1))))) + Invariant(Forall(int, lambda j: Implies(i <= j and j < n, result[j] == 0))) + if i % 2 == 0: + result[i] = factorial__spec(i - 1) + else: + result[i] = sum__spec(i - 1) + i += 1 + return result""" + ) + + nagini_lang = LanguageDatabase().get("nagini") + + result: List[str] = ["factorial__spec"] + + assert result == nagini_lang.find_pure_non_helpers(code) \ No newline at end of file diff --git a/verified_cogen/experiments/incremental_run.py b/verified_cogen/experiments/incremental_run.py index 12a3149..0b1926a 100644 --- a/verified_cogen/experiments/incremental_run.py +++ b/verified_cogen/experiments/incremental_run.py @@ -74,7 +74,9 @@ def main(): verifier = Verifier(args.verifier_command) config = RunnerConfig( - log_tries=log_tries, include_text_descriptions=args.include_text_descriptions + log_tries=log_tries, + include_text_descriptions=args.include_text_descriptions, + remove_implementations=args.remove_implementations, ) for file in files: llm = LLM( diff --git a/verified_cogen/llm/llm.py b/verified_cogen/llm/llm.py index 9a872cf..8686a23 100644 --- a/verified_cogen/llm/llm.py +++ b/verified_cogen/llm/llm.py @@ -127,10 +127,18 @@ def add(self, prg: str, checks: str, function: Optional[str] = None) -> str: self.add_user_prompt(prompt, False) return self.make_request() - def rewrite(self, prg: str) -> str: - self.add_user_prompt( - prompts.rewrite_prompt(self.prompt_dir).replace("{program}", prg) - ) + def rewrite( + self, + prg: str, + text_description: Optional[str] = None, + additional_prompt: str = "", + ) -> str: + result = prompts.rewrite_prompt(self.prompt_dir).replace("{program}", prg) + if text_description is not None and "{text_description}" in result: + result = result.replace("{text_description}", text_description) + self.add_user_prompt(result) + if additional_prompt: + self.add_user_prompt(additional_prompt) return self.make_request() def ask_for_fixed(self, err: str) -> str: diff --git a/verified_cogen/llm/prompts.py b/verified_cogen/llm/prompts.py index 19650e1..bee4251 100644 --- a/verified_cogen/llm/prompts.py +++ b/verified_cogen/llm/prompts.py @@ -62,3 +62,11 @@ def ask_for_fixed_had_errors_prompt(prompt_dir: str) -> str: def ask_for_timeout_prompt(prompt_dir: str) -> str: return read_prompt(f"{prompt_dir}/timeout.txt") + + +def invalid_helpers_prompt(prompt_dir: str) -> str: + return read_prompt(f"{prompt_dir}/invalid_helpers.txt") + + +def helpers_prompt(prompt_dir: str) -> str: + return read_prompt(f"{prompt_dir}/helpers.txt") diff --git a/verified_cogen/runners/__init__.py b/verified_cogen/runners/__init__.py index 41a01f2..aee390a 100644 --- a/verified_cogen/runners/__init__.py +++ b/verified_cogen/runners/__init__.py @@ -14,14 +14,17 @@ class RunnerConfig: log_tries: Optional[pathlib.Path] = None include_text_descriptions: bool = False + remove_implementations: bool = False def __init__( self, log_tries: Optional[pathlib.Path] = None, include_text_descriptions: bool = False, + remove_implementations: bool = False, ): self.log_tries = log_tries self.include_text_descriptions = include_text_descriptions + self.remove_implementations = remove_implementations class Runner: @@ -41,7 +44,12 @@ def __init__( if self.config.log_tries is not None: self.config.log_tries.mkdir(exist_ok=True, parents=True) - def rewrite(self, prg: str, text_description: Optional[str] = None) -> str: + def rewrite( + self, + prg: str, + text_description: Optional[str] = None, + additional_prompt: str = "", + ) -> str: """Rewrite the program with additional checks in one step.""" ... diff --git a/verified_cogen/runners/flush.py b/verified_cogen/runners/flush.py new file mode 100644 index 0000000..b22c486 --- /dev/null +++ b/verified_cogen/runners/flush.py @@ -0,0 +1,65 @@ +from typing import Optional + +from verified_cogen.runners import Runner +from verified_cogen.tools import compare_errors +from verified_cogen.tools.modes import Mode + + +class FlushRunner(Runner): + previous_error: str = "" + timeout: str = "Verification timed out" + + wrapped_runner: Runner + + def __init__(self, wrapped_runner: Runner): + super().__init__( + wrapped_runner.llm, + wrapped_runner.logger, + wrapped_runner.verifier, + wrapped_runner.config, + ) + self.wrapped_runner = wrapped_runner + + def flush_and_rewrite(self) -> str: + assert self.starting_prg is not None + self.llm.wipe_all() + self.previous_error = "" + self.logger.info("Encountered same error. Rewrite") + return self.rewrite(self.starting_prg) + + def ask_for_timeout(self) -> str: + if compare_errors(self.previous_error, self.timeout): + return self.flush_and_rewrite() + else: + self.previous_error = self.timeout + return self.wrapped_runner.ask_for_timeout() + + def ask_for_fixed(self, err: str) -> str: + if compare_errors(self.previous_error, err): + return self.flush_and_rewrite() + else: + self.previous_error = err + return self.wrapped_runner.ask_for_fixed(err) + + def preprocess(self, prg: str, mode: Mode) -> str: + return self.wrapped_runner.preprocess(prg, mode) + + def postprocess(self, inv_prg: str) -> str: + return self.wrapped_runner.postprocess(inv_prg) + + def rewrite( + self, + prg: str, + text_description: Optional[str] = None, + additional_prompt: str = "", + ) -> str: + return self.wrapped_runner.rewrite(prg, text_description, additional_prompt) + + def produce(self, prg: str) -> str: + return self.wrapped_runner.produce(prg) + + def insert(self, prg: str, checks: str, mode: Mode) -> str: + return self.wrapped_runner.insert(prg, checks, mode) + + def precheck(self, prg: str, mode: Mode): + return self.wrapped_runner.precheck(prg, mode) diff --git a/verified_cogen/runners/generate.py b/verified_cogen/runners/generate.py index 5d4128b..27ad046 100644 --- a/verified_cogen/runners/generate.py +++ b/verified_cogen/runners/generate.py @@ -12,8 +12,13 @@ class GenerateRunner(Runner): - def rewrite(self, prg: str, text_description: Optional[str] = None) -> str: - return self.llm.rewrite(prg) + def rewrite( + self, + prg: str, + text_description: Optional[str] = None, + additional_prompt: str = "", + ) -> str: + return self.llm.rewrite(prg, text_description, additional_prompt) def produce(self, prg: str) -> str: raise ValueError("Produce not supported for generate") diff --git a/verified_cogen/runners/generic.py b/verified_cogen/runners/generic.py index 39055ac..85e14c9 100644 --- a/verified_cogen/runners/generic.py +++ b/verified_cogen/runners/generic.py @@ -9,8 +9,13 @@ class GenericRunner(Runner): - def rewrite(self, prg: str, text_description: Optional[str] = None) -> str: - return self.llm.rewrite(prg) + def rewrite( + self, + prg: str, + text_description: Optional[str] = None, + additional_prompt: str = "", + ) -> str: + return self.llm.rewrite(prg, text_description, additional_prompt) def produce(self, prg: str) -> str: return self.llm.produce(prg) diff --git a/verified_cogen/runners/invariants.py b/verified_cogen/runners/invariants.py index 88ec2f4..19bcb4c 100644 --- a/verified_cogen/runners/invariants.py +++ b/verified_cogen/runners/invariants.py @@ -44,8 +44,13 @@ def insert_invariants(llm: LLM, prg: str, inv: str, mode: Mode): class InvariantRunner(Runner): - def rewrite(self, prg: str, text_description: Optional[str] = None) -> str: - return self.llm.rewrite(prg) + def rewrite( + self, + prg: str, + text_description: Optional[str] = None, + additional_prompt: str = "", + ) -> str: + return self.llm.rewrite(prg, text_description, additional_prompt) def produce(self, prg: str) -> str: return self.llm.produce(prg) diff --git a/verified_cogen/runners/languages/language.py b/verified_cogen/runners/languages/language.py index 3d1d0de..5612c90 100644 --- a/verified_cogen/runners/languages/language.py +++ b/verified_cogen/runners/languages/language.py @@ -1,6 +1,6 @@ from abc import abstractmethod from enum import Enum -from typing import Any, Pattern +from typing import Any, Pattern, List, Tuple class AnnotationType(Enum): @@ -26,6 +26,14 @@ def remove_conditions(self, code: str) -> str: ... @abstractmethod def separate_validator_errors(self, errors: str) -> tuple[str, str]: ... + @abstractmethod + def check_helpers( + self, code: str, pure_non_helpers: [str] + ) -> Tuple[List[str], str]: ... + + @abstractmethod + def find_pure_non_helpers(self, code: str) -> [str]: ... + class GenericLanguage(Language): method_regex: Pattern[str] @@ -99,6 +107,14 @@ def remove_conditions(self, code: str) -> str: lines = [line for line in lines if self.inline_assert_comment not in line] return "\n".join(lines).strip() + def check_helpers( + self, code: str, pure_non_helpers: [str] + ) -> Tuple[List[str], str]: + return [], code + + def find_pure_non_helpers(self, code: str) -> [str]: + return [] + class LanguageDatabase: _instance = None diff --git a/verified_cogen/runners/languages/nagini.py b/verified_cogen/runners/languages/nagini.py index 06ac472..605d525 100644 --- a/verified_cogen/runners/languages/nagini.py +++ b/verified_cogen/runners/languages/nagini.py @@ -1,7 +1,8 @@ import re -from typing import Pattern +from typing import Pattern, List, Tuple from verified_cogen.runners.languages.language import AnnotationType, GenericLanguage +from verified_cogen.tools.pureCallsDetectors import detect_and_replace_pure_calls_nagini NAGINI_VALIDATOR_TEMPLATE = """\ def {method_name}_valid({parameters}) -> {returns}:{specs}\ @@ -23,7 +24,7 @@ def __init__(self, remove_annotations: list[AnnotationType]): # type: ignore } super().__init__( re.compile( - r"def\s+(\w+)\s*\((.*?)\)\s*->\s*(.*?):(.*?(\r\n|\r|\n))\s+# impl-start", + r"def\s+(\w+)\s*\((.*?)\)\s*->\s*(.*?):(.*?(\r\n|\r|\n))\s+# (impl-start|pure-start)", re.DOTALL, ), NAGINI_VALIDATOR_TEMPLATE, @@ -43,3 +44,19 @@ def separate_validator_errors(self, errors: str) -> tuple[str, str]: if "Verification successful" not in line and "Verification took" not in line ] return "\n".join(lines), "" + + def check_helpers( + self, code: str, pure_non_helpers: [str] + ) -> Tuple[List[str], str]: + return detect_and_replace_pure_calls_nagini(code, pure_non_helpers) + + def find_pure_non_helpers(self, code: str) -> [str]: + pattern: Pattern[str] = re.compile( + r"#use-as-unpure(\r\n|\r|\n)@Pure(\r\n|\r|\n)def\s+(\w+)\s*\((.*?)\)\s*->\s*(.*?):", + re.DOTALL, + ) + methods = list(pattern.finditer(code)) + non_helpers: list[str] = [] + for match in methods: + non_helpers.append(match.group(3)) + return non_helpers diff --git a/verified_cogen/runners/step_by_step.py b/verified_cogen/runners/step_by_step.py index 0bd3d82..ae37d22 100644 --- a/verified_cogen/runners/step_by_step.py +++ b/verified_cogen/runners/step_by_step.py @@ -30,7 +30,12 @@ def __init__(self, wrapping: Runner, config: Optional[StepByStepConfig] = None): def preprocess(self, prg: str, mode: Mode) -> str: return self.wrapped_runner.preprocess(prg, mode) - def rewrite(self, prg: str, text_description: Optional[str] = None) -> str: + def rewrite( + self, + prg: str, + text_description: Optional[str] = None, + additional_prompt: str = "", + ) -> str: return ( self.rewrite_full_examples(prg, text_description) if self._config.full_examples diff --git a/verified_cogen/runners/step_by_step_flush.py b/verified_cogen/runners/step_by_step_flush.py deleted file mode 100644 index 2444741..0000000 --- a/verified_cogen/runners/step_by_step_flush.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional -from verified_cogen.runners import Runner -from verified_cogen.runners.step_by_step import StepByStepRunner, StepByStepConfig -from verified_cogen.tools import compare_errors - - -class StepByStepFlushRunner(StepByStepRunner): - previous_error: str = "" - timeout: str = "Verification timed out" - - def __init__(self, wrapping: Runner, config: Optional[StepByStepConfig] = None): - super().__init__(wrapping, config) - - def flush_and_rewrite(self) -> str: - assert self.starting_prg is not None - self.llm.wipe_all() - self.previous_error = "" - self.logger.info("Encountered same error. Rewrite") - return self.rewrite(self.starting_prg) - - def ask_for_timeout(self) -> str: - if compare_errors(self.previous_error, self.timeout): - return self.flush_and_rewrite() - else: - self.previous_error = self.timeout - return self.wrapped_runner.ask_for_timeout() - - def ask_for_fixed(self, err: str) -> str: - if compare_errors(self.previous_error, err): - return self.flush_and_rewrite() - else: - self.previous_error = err - return self.wrapped_runner.ask_for_fixed(err) diff --git a/verified_cogen/runners/validating.py b/verified_cogen/runners/validating.py index a1c1e51..d2bd8d1 100644 --- a/verified_cogen/runners/validating.py +++ b/verified_cogen/runners/validating.py @@ -1,5 +1,6 @@ from typing import Optional +from verified_cogen.llm import prompts from verified_cogen.llm.llm import LLM from verified_cogen.runners import Runner from verified_cogen.runners.languages.language import Language @@ -10,6 +11,7 @@ class ValidatingRunner(Runner): wrapped_runner: Runner language: Language summarizer_llm: LLM + pure_non_helpers: [str] = [] def __init__( self, @@ -38,18 +40,48 @@ def _add_validators(self, prg: str, inv_prg: str): return val_prg def preprocess(self, prg: str, mode: Mode) -> str: + if self.config.remove_implementations: + self.pure_non_helpers = self.language.find_pure_non_helpers(prg) + self.logger.info( + "found pure_non_helpers: " + ",".join(self.pure_non_helpers) + ) res_prg = self.language.remove_conditions(prg) self.wrapped_runner.starting_prg = res_prg return res_prg def postprocess(self, inv_prg: str) -> str: assert self.starting_prg is not None + invalid_helpers: [str] = [] + try: + invalid_helpers, inv_prg = self.language.check_helpers( + inv_prg, self.pure_non_helpers + ) + self.logger.info("invalid_helpers: " + ",".join(invalid_helpers)) + except Exception: + self.logger.info("pass") + pass + if invalid_helpers: + self.llm.add_user_prompt( + prompts.invalid_helpers_prompt(self.llm.prompt_dir) + .replace("{invalid_helpers}", ",".join(invalid_helpers)) + .replace("{program}", inv_prg) + .replace("{helpers}", ",".join(self.pure_non_helpers)) + ) return self._add_validators( self.starting_prg, self.wrapped_runner.postprocess(inv_prg) ) - def rewrite(self, prg: str, text_description: Optional[str] = None) -> str: - return self.wrapped_runner.rewrite(prg, text_description) + def rewrite( + self, + prg: str, + text_description: Optional[str] = None, + additional_prompt: str = "", + ) -> str: + if self.config.remove_implementations and self.pure_non_helpers: + additional_prompt += prompts.helpers_prompt(self.llm.prompt_dir).replace( + "{helpers}", ",".join(self.pure_non_helpers) + ) + return self.wrapped_runner.rewrite(prg, text_description, additional_prompt) def produce(self, prg: str) -> str: return self.wrapped_runner.produce(prg) @@ -74,7 +106,7 @@ def ask_for_fixed(self, err: str) -> str: self.summarizer_llm.user_prompts = [] self.summarizer_llm.responses = [] result += ( - "Also, hidden validation errors occured, here is the summary:\n" + "Also, hidden validation errors occurred, here is the summary:\n" + validator_summary ) return self.wrapped_runner.ask_for_fixed(result) diff --git a/verified_cogen/tools/pureCallsDetectors.py b/verified_cogen/tools/pureCallsDetectors.py new file mode 100644 index 0000000..15d6869 --- /dev/null +++ b/verified_cogen/tools/pureCallsDetectors.py @@ -0,0 +1,88 @@ +import ast +from typing import Tuple, List + + +class PureFunctionCallReplacer(ast.NodeTransformer): + def __init__(self, pure_non_helpers: [str]): + self.pure_functions = list() + self.detected_calls = list() + self.current_function = None + self.in_pure_function = False + self.in_condition = False + self.pure_non_helpers = pure_non_helpers + + def visit_FunctionDef(self, node: ast.FunctionDef): + is_pure = any(decorator.id == "Pure" for decorator in node.decorator_list) + if is_pure and node.name not in self.pure_non_helpers: + self.pure_functions.append(node.name) + + prev_function = self.current_function + prev_in_pure = self.in_pure_function + self.current_function = node.name + self.in_pure_function = is_pure + + if not is_pure: + self.generic_visit(node) + + self.current_function = prev_function + self.in_pure_function = prev_in_pure + return node + + def visit_Call(self, node: ast.Call): + # if node.func: + # print(node.func.id) + if isinstance(node.func, ast.Name): + # print("A") + # print("A " + node.func.id) + if ( + node.func.id in self.pure_functions + and not self.in_pure_function + and not self.in_condition + and self.current_function is not None + ): + self.detected_calls.append(node.func.id) + return ast.Call( + func=ast.Name(id="invalid_call", ctx=ast.Load()), + args=[], + keywords=[], + ) + if node.func.id not in ["Invariant", "Assert", "Requires", "Ensures"]: + return self.generic_visit(node) + # print("B") + # print(node.func) + # print("B " + node.func.id) + return node + + def visit_If(self, node: ast.If): + prev_in_condition = self.in_condition + self.in_condition = True + node.test = self.visit(node.test) + self.in_condition = prev_in_condition + node.body = [self.visit(stmt) for stmt in node.body] + node.orelse = [self.visit(stmt) for stmt in node.orelse] + return node + + def visit_While(self, node: ast.While): + prev_in_condition = self.in_condition + self.in_condition = True + node.test = self.visit(node.test) + self.in_condition = prev_in_condition + node.body = [self.visit(stmt) for stmt in node.body] + return node + + def visit_Assert(self, node: ast.Assert): + prev_in_condition = self.in_condition + self.in_condition = True + node = self.generic_visit(node) + self.in_condition = prev_in_condition + return node + + +def detect_and_replace_pure_calls_nagini( + code: str, pure_non_helpers: [str] +) -> Tuple[List[str], str]: + tree = ast.parse(code) + replacer = PureFunctionCallReplacer(pure_non_helpers) + modified_tree = replacer.visit(tree) + new_code = ast.unparse(modified_tree) + return replacer.detected_calls, new_code