diff --git a/tests/functional/test_compiler.py b/tests/functional/test_compiler.py index f088be8..07fd984 100644 --- a/tests/functional/test_compiler.py +++ b/tests/functional/test_compiler.py @@ -203,37 +203,35 @@ def test_compile_project(project): assert "older_version" in actual -# -# -# @pytest.mark.parametrize("contract_name", PASSING_CONTRACT_NAMES) -# def test_compile_individual_contracts(project, contract_name, compiler): -# path = project.contracts_folder / contract_name -# assert list(compiler.compile((path,), project=project)) -# -# -# @pytest.mark.parametrize( -# "contract_name", [n for n in FAILING_CONTRACT_NAMES if n != "contract_unknown_pragma.vy"] -# ) -# def test_compile_failures(contract_name, compiler): -# failing_project = ape.Project(FAILING_BASE) -# path = FAILING_BASE / contract_name -# with pytest.raises(VyperCompileError, match=EXPECTED_FAIL_PATTERNS[path.stem]) as err: -# list(compiler.compile((path,), project=failing_project)) -# -# assert isinstance(err.value.base_err, VyperError) -# -# -# def test_compile_zero_four(compiler, project): -# """ -# An easy way to test only Vyper 0.4 changes. -# """ -# paths = ( -# project.contracts_folder / "subdir" / "zero_four_in_subdir.vy", -# project.contracts_folder / "zero_four.vy", -# ) -# result = [x.name for x in compiler.compile(paths, project=project)] -# assert "zero_four" in result -# assert "zero_four_in_subdir" in result +@pytest.mark.parametrize("contract_name", PASSING_CONTRACT_NAMES) +def test_compile_individual_contracts(project, contract_name, compiler): + path = project.contracts_folder / contract_name + assert list(compiler.compile((path,), project=project)) + + +@pytest.mark.parametrize( + "contract_name", [n for n in FAILING_CONTRACT_NAMES if n != "contract_unknown_pragma.vy"] +) +def test_compile_failures(contract_name, compiler): + failing_project = ape.Project(FAILING_BASE) + path = FAILING_BASE / contract_name + with pytest.raises(VyperCompileError, match=EXPECTED_FAIL_PATTERNS[path.stem]) as err: + list(compiler.compile((path,), project=failing_project)) + + assert isinstance(err.value.base_err, VyperError) + + +def test_compile_zero_four(compiler, project): + """ + An easy way to test only Vyper 0.4 changes. + """ + paths = ( + project.contracts_folder / "subdir" / "zero_four_in_subdir.vy", + project.contracts_folder / "zero_four.vy", + ) + result = [x.name for x in compiler.compile(paths, project=project)] + assert "zero_four" in result + assert "zero_four_in_subdir" in result def test_install_failure(compiler): @@ -312,323 +310,322 @@ def test_get_version_map(project, compiler, all_versions): assert actual4 == expected4 -# -# def test_compiler_data_in_manifest(project): -# def run_test(manifest): -# assert len(manifest.compilers) >= 3, manifest.compilers -# -# all_latest_03 = [ -# c for c in manifest.compilers if str(c.version) == str(VERSION_FROM_PRAGMA) -# ] -# evm_opt = [c for c in all_latest_03 if c.settings.get("evmVersion") == "paris"][0] -# gas_opt = [c for c in all_latest_03 if c.settings["optimize"] == "gas"][0] -# true_opt = [ -# c -# for c in manifest.compilers -# if c.settings["optimize"] is True and "non_payable_default" in c.contractTypes -# ][0] -# codesize_opt = [ -# c -# for c in all_latest_03 -# if c.settings["optimize"] == "codesize" and c.settings.get("evmVersion") != "paris" -# ][0] -# vyper_028 = [ -# c for c in manifest.compilers if str(c.version) == str(OLDER_VERSION_FROM_PRAGMA) -# ][0] -# -# for compiler in (vyper_028, codesize_opt): -# assert compiler.name == "vyper" -# -# assert vyper_028.settings["evmVersion"] == "berlin" -# assert codesize_opt.settings["evmVersion"] == "shanghai" -# -# # There is only one contract with evm-version pragma. -# assert evm_opt.contractTypes == ["evm_pragma"] -# assert evm_opt.settings.get("evmVersion") == "paris" -# -# assert "optimize_codesize" in codesize_opt.contractTypes -# assert "older_version" in vyper_028.contractTypes -# assert len(gas_opt.contractTypes) >= 1 -# assert "non_payable_default" in true_opt.contractTypes -# -# project.update_manifest(compilers=[]) -# project.load_contracts(use_cache=False) -# run_test(project.manifest) -# man = project.extract_manifest() -# run_test(man) -# -# -# def test_compile_parse_dev_messages(compiler, dev_revert_source, project): -# """ -# Test parsing of dev messages in a contract. These follow the form of "#dev: ...". -# -# The compiler will output a map that maps dev messages to line numbers. -# See contract_with_dev_messages.vy for more information. -# """ -# result = list(compiler.compile((dev_revert_source,), project=project)) -# -# assert len(result) == 1 -# -# contract = result[0] -# -# assert contract.dev_messages is not None -# assert len(contract.dev_messages) == 4 -# assert contract.dev_messages[6] == "dev: foo" -# assert contract.dev_messages[9] == "dev: bar" -# assert contract.dev_messages[16] == "dev: baz" -# assert contract.dev_messages[20] == "dev: 你好,猿" -# assert 23 not in contract.dev_messages -# -# -# def test_get_imports(compiler, project): -# # Ensure the dependency starts off un-compiled so we can show this -# # is the point at which it will be compiled. We make sure to only -# # compile when we know it is a JSON interface based dependency -# # and not a site-package or relative-path based dependency. -# dependency = project.dependencies["exampledependency"]["local"] -# dependency.manifest.contract_types = {} -# -# vyper_files = [ -# x for x in project.contracts_folder.iterdir() if x.is_file() and x.suffix == ".vy" -# ] -# actual = compiler.get_imports(vyper_files, project=project) -# -# prefix = "contracts/passing_contracts" -# builtin_import = "vyper/interfaces/ERC20.json" -# local_import = "IFace.vy" -# local_from_import = "IFace2.vy" -# local_nested_import = "IFaceNested.vy" -# dependency_import = "Dependency.vy" -# -# # The source IDs end up as absolute paths because they are in tempdir -# # (not direct local project) and because of Vyper 0.4 reasons, we need -# # this to be the case. And we don't know the version map yet at this point. -# contract_37_key = [k for k in actual if f"{prefix}/contract_037.vy" in k][0] -# use_iface_key = [k for k in actual if f"{prefix}/use_iface.vy" in k][0] -# use_iface2_key = [k for k in actual if f"{prefix}/use_iface2.vy" in k][0] -# -# assert set(actual[contract_37_key]) == {builtin_import} -# -# actual_iface_use = actual[use_iface_key] -# for expected in (local_import, local_from_import, dependency_import, local_nested_import): -# assert any(k for k in actual_iface_use if expected in k), f"{expected} not found" -# -# assert actual[use_iface2_key][0].endswith(local_import) -# -# -# @pytest.mark.parametrize("src,vers", [("contract_039", "0.3.9"), ("contract_037", "0.3.7")]) -# def test_pc_map(compiler, project, src, vers): -# """ -# Ensure we de-compress the source map correctly by comparing to the results -# from `compile_src()` which includes the uncompressed source map data. -# """ -# -# path = project.sources.lookup(src) -# result = list(compiler.compile((path,), project=project))[0] -# actual = result.pcmap.root -# code = path.read_text(encoding="utf8") -# vvm.install_vyper(vers) -# cfg = compiler.get_config(project=project) -# evm_version = cfg.evm_version -# compile_result = vvm.compile_source(code, vyper_version=vers, evm_version=evm_version) -# std_result = compile_result[""] -# src_map = std_result["source_map"] -# lines = code.splitlines() -# -# # Use the old-fashioned way of gathering PCMap to ensure our creative way works -# expected = {pc: {"location": ln} for pc, ln in src_map["pc_pos_map"].items()} -# missing_pcs = [] -# empty_locs = [] -# wrong_locs = [] -# for expected_pc, item_dict in expected.items(): -# expected_loc = item_dict["location"] -# -# # Collect matching locations. -# matching_locs = [] -# for mpc, loc in actual.items(): -# if loc["location"] == expected_loc: -# matching_locs.append(mpc) -# -# if expected_pc not in actual: -# missing_pcs.append((expected_pc, expected_loc, matching_locs)) -# continue -# -# if actual[expected_pc]["location"] is None: -# empty_locs.append((expected_pc, expected_loc, matching_locs)) -# continue -# -# if actual[expected_pc]["location"] != expected_loc: -# wrong_locs.append((expected_pc, expected_loc, matching_locs)) -# -# limit = 10 # Only show first ten failures of each category. -# -# def make_failure(title, ls): -# fail_format = "PC={pc}, Expected={ex} (actual matches={match})" -# suffix = ", ".join([fail_format.format(pc=m, ex=e, match=mat) for m, e, mat in ls[:limit]]) -# return f"{title}: {suffix}" -# -# failures = [] -# if len(missing_pcs) != 0: -# failures.append((missing_pcs[0][0], make_failure("Missing PCs", missing_pcs))) -# if len(empty_locs) != 0: -# failures.append((empty_locs[0][0], make_failure("Empty locations", empty_locs))) -# if len(wrong_locs) != 0: -# failures.append((wrong_locs[0][0], make_failure("Wrong locations", wrong_locs))) -# -# # Show first failures to occur first. -# failures.sort(key=lambda x: x[0]) -# -# assert len(failures) == 0, "\n".join([x[1] for x in failures]) -# -# # Test helper methods. -# def _all(check): -# return [x for x in actual.values() if x.get("dev") == f"dev: {check.value}"] -# -# def line(cont: str) -> int: -# # A helper for getting expected line numbers -# return [i + 1 for i, x in enumerate(lines) if cont in x][0] -# -# # Verify non-payable checks. -# nonpayable_checks = _all(RuntimeErrorType.NONPAYABLE_CHECK) -# if nonpayable_checks: -# assert len(nonpayable_checks) >= 1 -# else: -# # NOTE: Vyper 0.3.10 doesn't have these anymore. -# # But they do have a new error type instead. -# checks = _all(RuntimeErrorType.INVALID_CALLDATA_OR_VALUE) -# assert len(checks) >= 1 -# -# # Verify integer overflow checks -# overflows = _all(RuntimeErrorType.INTEGER_OVERFLOW) -# overflow_no = line("return (2**127-1) + i") -# expected_overflow_loc = [overflow_no, 12, overflow_no, 20] -# assert len(overflows) >= 2 -# -# if vers == "0.3.7": -# assert expected_overflow_loc in [o["location"] for o in overflows if o["location"]] -# # else: 0.3.9 registers as IntegerBoundsCheck -# -# # Verify integer underflow checks -# underflows = _all(RuntimeErrorType.INTEGER_UNDERFLOW) -# underflow_no = line("return i - (2**127-1)") -# expected_underflow_loc = [underflow_no, 11, underflow_no, 25] -# assert len(underflows) >= 2 -# -# if vers == "0.3.7": -# assert expected_underflow_loc in [u["location"] for u in underflows if u["location"]] -# # else: 0.3.9 registers as IntegerBoundsCheck -# -# # Verify division by zero checks -# div_zeros = _all(RuntimeErrorType.DIVISION_BY_ZERO) -# div_no = line("return 4 / i") -# expected_div_0 = [div_no, 11, div_no, 16] -# -# if vers == "0.3.7": -# assert len(div_zeros) >= 1 -# assert expected_div_0 in [d["location"] for d in div_zeros if d["location"]] -# # TODO: figure out how to detect these on 0.3.9 -# -# # Verify modulo by zero checks -# mod_zeros = _all(RuntimeErrorType.MODULO_BY_ZERO) -# mod_no = line("return 4 % i") -# expected_mod_0_loc = [mod_no, 11, mod_no, 16] -# assert len(mod_zeros) >= 1 -# assert expected_mod_0_loc in [m["location"] for m in mod_zeros if m["location"]] -# -# # Verify index out of range checks -# range_checks = _all(RuntimeErrorType.INDEX_OUT_OF_RANGE) -# range_no = line("return self.dynArray[idx]") -# expected_range_check = [range_no, 11, range_no, 24] -# if vers == "0.3.7": -# assert len(range_checks) >= 1 -# assert expected_range_check in [r["location"] for r in range_checks] -# # TODO: figure out how to detect these on 0.3.9 -# -# -# def test_enrich_error_int_overflow(geth_provider, traceback_contract, account): -# int_max = 2**256 - 1 -# with pytest.raises(IntegerOverflowError): -# traceback_contract.addBalance(int_max, sender=account) -# -# -# def test_enrich_error_non_payable_check(geth_provider, traceback_contract, account): -# if traceback_contract.contract_type.name.endswith("0310"): -# # NOTE: Nonpayable error is combined with calldata check now. -# with pytest.raises(InvalidCalldataOrValueError): -# traceback_contract.addBalance(123, sender=account, value=1) -# -# else: -# with pytest.raises(NonPayableError): -# traceback_contract.addBalance(123, sender=account, value=1) -# -# -# def test_enrich_error_fallback(geth_provider, traceback_contract, account): -# """ -# Show that when attempting to call a contract's fallback method when there is -# no fallback defined results in a custom contract logic error. -# """ -# with pytest.raises(FallbackNotDefinedError): -# traceback_contract(sender=account) -# -# -# def test_enrich_error_handle_when_name(compiler, geth_provider, mocker): -# """ -# Sometimes, a provider may use the name of the enum instead of the value, -# which we are still able to enrich. -# """ -# -# class TB(SourceTraceback): -# @property -# def revert_type(self) -> Optional[str]: -# return "NONPAYABLE_CHECK" -# -# tb = TB([{"statements": [], "closure": {"name": "fn"}, "depth": 0}]) # type: ignore -# error = ContractLogicError(None, source_traceback=tb) -# new_error = compiler.enrich_error(error) -# assert isinstance(new_error, NonPayableError) -# -# -# @pytest.mark.parametrize("arguments", [(), (123,), (123, 321)]) -# def test_trace_source(account, geth_provider, project, traceback_contract, arguments): -# receipt = traceback_contract.addBalance(*arguments, sender=account) -# actual = receipt.source_traceback -# base_folder = Path(__file__).parent.parent / "contracts" / "passing_contracts" -# contract_name = traceback_contract.contract_type.name -# expected = rf""" -# Traceback (most recent call last) -# File {base_folder}/{contract_name}.vy, in addBalance -# 32 if i != num: -# 33 continue -# 34 -# --> 35 return self._balance -# """.strip() -# assert str(actual) == expected -# -# -# def test_trace_source_content_from_kwarg_default_parametrization( -# account, geth_provider, project, traceback_contract -# ): -# """ -# This test is for verifying stuff around Vyper auto-generated methods from kwarg defaults. -# Mostly, need to make sure the correct content is discoverable in the source traceback -# so that coverage works properly. -# """ -# no_args_tx = traceback_contract.addBalance(sender=account) -# no_args_tb = no_args_tx.source_traceback -# -# def check(name: str, tb): -# items = [x.closure.full_name for x in tb if x.closure.full_name == name] -# assert len(items) >= 1 -# -# check("addBalance()", no_args_tb) -# -# single_arg_tx = traceback_contract.addBalance(442, sender=account) -# single_arg_tb = single_arg_tx.source_traceback -# check("addBalance(uint256)", single_arg_tb) -# -# both_args_tx = traceback_contract.addBalance(4, 5, sender=account) -# both_args_tb = both_args_tx.source_traceback -# check("addBalance(uint256,uint256)", both_args_tb) +def test_compiler_data_in_manifest(project): + def run_test(manifest): + assert len(manifest.compilers) >= 3, manifest.compilers + + all_latest_03 = [ + c for c in manifest.compilers if str(c.version) == str(VERSION_FROM_PRAGMA) + ] + evm_opt = [c for c in all_latest_03 if c.settings.get("evmVersion") == "paris"][0] + gas_opt = [c for c in all_latest_03 if c.settings["optimize"] == "gas"][0] + true_opt = [ + c + for c in manifest.compilers + if c.settings["optimize"] is True and "non_payable_default" in c.contractTypes + ][0] + codesize_opt = [ + c + for c in all_latest_03 + if c.settings["optimize"] == "codesize" and c.settings.get("evmVersion") != "paris" + ][0] + vyper_028 = [ + c for c in manifest.compilers if str(c.version) == str(OLDER_VERSION_FROM_PRAGMA) + ][0] + + for compiler in (vyper_028, codesize_opt): + assert compiler.name == "vyper" + + assert vyper_028.settings["evmVersion"] == "berlin" + assert codesize_opt.settings["evmVersion"] == "shanghai" + + # There is only one contract with evm-version pragma. + assert evm_opt.contractTypes == ["evm_pragma"] + assert evm_opt.settings.get("evmVersion") == "paris" + + assert "optimize_codesize" in codesize_opt.contractTypes + assert "older_version" in vyper_028.contractTypes + assert len(gas_opt.contractTypes) >= 1 + assert "non_payable_default" in true_opt.contractTypes + + project.update_manifest(compilers=[]) + project.load_contracts(use_cache=False) + run_test(project.manifest) + man = project.extract_manifest() + run_test(man) + + +def test_compile_parse_dev_messages(compiler, dev_revert_source, project): + """ + Test parsing of dev messages in a contract. These follow the form of "#dev: ...". + + The compiler will output a map that maps dev messages to line numbers. + See contract_with_dev_messages.vy for more information. + """ + result = list(compiler.compile((dev_revert_source,), project=project)) + + assert len(result) == 1 + + contract = result[0] + + assert contract.dev_messages is not None + assert len(contract.dev_messages) == 4 + assert contract.dev_messages[6] == "dev: foo" + assert contract.dev_messages[9] == "dev: bar" + assert contract.dev_messages[16] == "dev: baz" + assert contract.dev_messages[20] == "dev: 你好,猿" + assert 23 not in contract.dev_messages + + +def test_get_imports(compiler, project): + # Ensure the dependency starts off un-compiled so we can show this + # is the point at which it will be compiled. We make sure to only + # compile when we know it is a JSON interface based dependency + # and not a site-package or relative-path based dependency. + dependency = project.dependencies["exampledependency"]["local"] + dependency.manifest.contract_types = {} + + vyper_files = [ + x for x in project.contracts_folder.iterdir() if x.is_file() and x.suffix == ".vy" + ] + actual = compiler.get_imports(vyper_files, project=project) + + prefix = "contracts/passing_contracts" + builtin_import = "vyper/interfaces/ERC20.json" + local_import = "IFace.vy" + local_from_import = "IFace2.vy" + local_nested_import = "IFaceNested.vy" + dependency_import = "Dependency.vy" + + # The source IDs end up as absolute paths because they are in tempdir + # (not direct local project) and because of Vyper 0.4 reasons, we need + # this to be the case. And we don't know the version map yet at this point. + contract_37_key = [k for k in actual if f"{prefix}/contract_037.vy" in k][0] + use_iface_key = [k for k in actual if f"{prefix}/use_iface.vy" in k][0] + use_iface2_key = [k for k in actual if f"{prefix}/use_iface2.vy" in k][0] + + assert set(actual[contract_37_key]) == {builtin_import} + + actual_iface_use = actual[use_iface_key] + for expected in (local_import, local_from_import, dependency_import, local_nested_import): + assert any(k for k in actual_iface_use if expected in k), f"{expected} not found" + + assert actual[use_iface2_key][0].endswith(local_import) + + +@pytest.mark.parametrize("src,vers", [("contract_039", "0.3.9"), ("contract_037", "0.3.7")]) +def test_pc_map(compiler, project, src, vers): + """ + Ensure we de-compress the source map correctly by comparing to the results + from `compile_src()` which includes the uncompressed source map data. + """ + + path = project.sources.lookup(src) + result = list(compiler.compile((path,), project=project))[0] + actual = result.pcmap.root + code = path.read_text(encoding="utf8") + vvm.install_vyper(vers) + cfg = compiler.get_config(project=project) + evm_version = cfg.evm_version + compile_result = vvm.compile_source(code, vyper_version=vers, evm_version=evm_version) + std_result = compile_result[""] + src_map = std_result["source_map"] + lines = code.splitlines() + + # Use the old-fashioned way of gathering PCMap to ensure our creative way works + expected = {pc: {"location": ln} for pc, ln in src_map["pc_pos_map"].items()} + missing_pcs = [] + empty_locs = [] + wrong_locs = [] + for expected_pc, item_dict in expected.items(): + expected_loc = item_dict["location"] + + # Collect matching locations. + matching_locs = [] + for mpc, loc in actual.items(): + if loc["location"] == expected_loc: + matching_locs.append(mpc) + + if expected_pc not in actual: + missing_pcs.append((expected_pc, expected_loc, matching_locs)) + continue + + if actual[expected_pc]["location"] is None: + empty_locs.append((expected_pc, expected_loc, matching_locs)) + continue + + if actual[expected_pc]["location"] != expected_loc: + wrong_locs.append((expected_pc, expected_loc, matching_locs)) + + limit = 10 # Only show first ten failures of each category. + + def make_failure(title, ls): + fail_format = "PC={pc}, Expected={ex} (actual matches={match})" + suffix = ", ".join([fail_format.format(pc=m, ex=e, match=mat) for m, e, mat in ls[:limit]]) + return f"{title}: {suffix}" + + failures = [] + if len(missing_pcs) != 0: + failures.append((missing_pcs[0][0], make_failure("Missing PCs", missing_pcs))) + if len(empty_locs) != 0: + failures.append((empty_locs[0][0], make_failure("Empty locations", empty_locs))) + if len(wrong_locs) != 0: + failures.append((wrong_locs[0][0], make_failure("Wrong locations", wrong_locs))) + + # Show first failures to occur first. + failures.sort(key=lambda x: x[0]) + + assert len(failures) == 0, "\n".join([x[1] for x in failures]) + + # Test helper methods. + def _all(check): + return [x for x in actual.values() if x.get("dev") == f"dev: {check.value}"] + + def line(cont: str) -> int: + # A helper for getting expected line numbers + return [i + 1 for i, x in enumerate(lines) if cont in x][0] + + # Verify non-payable checks. + nonpayable_checks = _all(RuntimeErrorType.NONPAYABLE_CHECK) + if nonpayable_checks: + assert len(nonpayable_checks) >= 1 + else: + # NOTE: Vyper 0.3.10 doesn't have these anymore. + # But they do have a new error type instead. + checks = _all(RuntimeErrorType.INVALID_CALLDATA_OR_VALUE) + assert len(checks) >= 1 + + # Verify integer overflow checks + overflows = _all(RuntimeErrorType.INTEGER_OVERFLOW) + overflow_no = line("return (2**127-1) + i") + expected_overflow_loc = [overflow_no, 12, overflow_no, 20] + assert len(overflows) >= 2 + + if vers == "0.3.7": + assert expected_overflow_loc in [o["location"] for o in overflows if o["location"]] + # else: 0.3.9 registers as IntegerBoundsCheck + + # Verify integer underflow checks + underflows = _all(RuntimeErrorType.INTEGER_UNDERFLOW) + underflow_no = line("return i - (2**127-1)") + expected_underflow_loc = [underflow_no, 11, underflow_no, 25] + assert len(underflows) >= 2 + + if vers == "0.3.7": + assert expected_underflow_loc in [u["location"] for u in underflows if u["location"]] + # else: 0.3.9 registers as IntegerBoundsCheck + + # Verify division by zero checks + div_zeros = _all(RuntimeErrorType.DIVISION_BY_ZERO) + div_no = line("return 4 / i") + expected_div_0 = [div_no, 11, div_no, 16] + + if vers == "0.3.7": + assert len(div_zeros) >= 1 + assert expected_div_0 in [d["location"] for d in div_zeros if d["location"]] + # TODO: figure out how to detect these on 0.3.9 + + # Verify modulo by zero checks + mod_zeros = _all(RuntimeErrorType.MODULO_BY_ZERO) + mod_no = line("return 4 % i") + expected_mod_0_loc = [mod_no, 11, mod_no, 16] + assert len(mod_zeros) >= 1 + assert expected_mod_0_loc in [m["location"] for m in mod_zeros if m["location"]] + + # Verify index out of range checks + range_checks = _all(RuntimeErrorType.INDEX_OUT_OF_RANGE) + range_no = line("return self.dynArray[idx]") + expected_range_check = [range_no, 11, range_no, 24] + if vers == "0.3.7": + assert len(range_checks) >= 1 + assert expected_range_check in [r["location"] for r in range_checks] + # TODO: figure out how to detect these on 0.3.9 + + +def test_enrich_error_int_overflow(geth_provider, traceback_contract, account): + int_max = 2**256 - 1 + with pytest.raises(IntegerOverflowError): + traceback_contract.addBalance(int_max, sender=account) + + +def test_enrich_error_non_payable_check(geth_provider, traceback_contract, account): + if traceback_contract.contract_type.name.endswith("0310"): + # NOTE: Nonpayable error is combined with calldata check now. + with pytest.raises(InvalidCalldataOrValueError): + traceback_contract.addBalance(123, sender=account, value=1) + + else: + with pytest.raises(NonPayableError): + traceback_contract.addBalance(123, sender=account, value=1) + + +def test_enrich_error_fallback(geth_provider, traceback_contract, account): + """ + Show that when attempting to call a contract's fallback method when there is + no fallback defined results in a custom contract logic error. + """ + with pytest.raises(FallbackNotDefinedError): + traceback_contract(sender=account) + + +def test_enrich_error_handle_when_name(compiler, geth_provider, mocker): + """ + Sometimes, a provider may use the name of the enum instead of the value, + which we are still able to enrich. + """ + + class TB(SourceTraceback): + @property + def revert_type(self) -> Optional[str]: + return "NONPAYABLE_CHECK" + + tb = TB([{"statements": [], "closure": {"name": "fn"}, "depth": 0}]) # type: ignore + error = ContractLogicError(None, source_traceback=tb) + new_error = compiler.enrich_error(error) + assert isinstance(new_error, NonPayableError) + + +@pytest.mark.parametrize("arguments", [(), (123,), (123, 321)]) +def test_trace_source(account, geth_provider, project, traceback_contract, arguments): + receipt = traceback_contract.addBalance(*arguments, sender=account) + actual = receipt.source_traceback + base_folder = Path(__file__).parent.parent / "contracts" / "passing_contracts" + contract_name = traceback_contract.contract_type.name + expected = rf""" +Traceback (most recent call last) + File {base_folder}/{contract_name}.vy, in addBalance + 32 if i != num: + 33 continue + 34 + --> 35 return self._balance +""".strip() + assert str(actual) == expected + + +def test_trace_source_content_from_kwarg_default_parametrization( + account, geth_provider, project, traceback_contract +): + """ + This test is for verifying stuff around Vyper auto-generated methods from kwarg defaults. + Mostly, need to make sure the correct content is discoverable in the source traceback + so that coverage works properly. + """ + no_args_tx = traceback_contract.addBalance(sender=account) + no_args_tb = no_args_tx.source_traceback + + def check(name: str, tb): + items = [x.closure.full_name for x in tb if x.closure.full_name == name] + assert len(items) >= 1 + + check("addBalance()", no_args_tb) + + single_arg_tx = traceback_contract.addBalance(442, sender=account) + single_arg_tb = single_arg_tx.source_traceback + check("addBalance(uint256)", single_arg_tb) + + both_args_tx = traceback_contract.addBalance(4, 5, sender=account) + both_args_tb = both_args_tx.source_traceback + check("addBalance(uint256,uint256)", both_args_tb) def test_trace_err_source(account, geth_provider, project, traceback_contract): diff --git a/tests/functional/test_coverage.py b/tests/functional/test_coverage.py index 55acd9e..478c7a3 100644 --- a/tests/functional/test_coverage.py +++ b/tests/functional/test_coverage.py @@ -5,9 +5,8 @@ from typing import Optional import pytest - -from ape.utils import create_tempdir from ape import Project +from ape.utils import create_tempdir LINES_VALID = 8 MISSES = 0