Skip to content

Commit

Permalink
feat: contract flattener (#107)
Browse files Browse the repository at this point in the history
* feat: contract flattener

* fix: add missing dep vyper

* fix: try lowering vyper min version for Python 3.8 support

* fix(test): install necessary vyper versions in test

* fix: compiler version handling in compiler.compile_code()

* style(lint): unused import

* fix(test): explicitly install compiler in test_pc_map

* feat: adds `ape vyper flatten` command

* test: add CLI test for flattener

* chore: bump minimum eth-ape version to 0.7.12

* docs: adds Contract Flattening section to README

* chore: bump eth-ape minimum ersion to 0.7.13

* fix(docs): fix warning directive in README

* docs: update return value docstring

Co-authored-by: antazoey <[email protected]>

* docs: speeling

Co-authored-by: antazoey <[email protected]>

* refactor: splitlines()

* refactor: limit `vyper flatten` command to vyper only

Co-authored-by: antazoey <[email protected]>

* fix(docs): bad bug report link

Co-authored-by: antazoey <[email protected]>

* style(docs): period

Co-authored-by: antazoey <[email protected]>

* fix(docs): comment spelling

Co-authored-by: antazoey <[email protected]>

* fix(docs): comment spelling

Co-authored-by: antazoey <[email protected]>

* refactor: splitlines

Co-authored-by: antazoey <[email protected]>

* fix: missing type hint

Co-authored-by: antazoey <[email protected]>

* docs: how to format returns

Co-authored-by: antazoey <[email protected]>

* refactor: not is None

Co-authored-by: antazoey <[email protected]>

* refactor: check with installed versions when compiling before installing during no-pragma fallback

---------

Co-authored-by: antazoey <[email protected]>
  • Loading branch information
mikeshultz and antazoey authored Mar 22, 2024
1 parent 3555cb4 commit 15b7598
Show file tree
Hide file tree
Showing 10 changed files with 530 additions and 10 deletions.
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ ape compile

The `.vy` files in your project will compile into `ContractTypes` that you can deploy and interact with in Ape.

### Contract Flattening

For ease of publishing, validation, and some other cases it's sometimes useful to "flatten" your contract into a single file.
This combines your contract and any imported interfaces together in a way the compiler can understand.
You can do so with a command like this:

```bash
ape vyper flatten contracts/MyContract.vy build/MyContractFlattened.vy
```

> \[!WARNING\]
> This feature is experimental. Please [report any bugs](https://github.com/ApeWorX/ape-solidity/issues/new?assignees=&labels=bug&projects=&template=bug.md) you find when trying it out.
### Compiler Version

By default, the `ape-vyper` plugin uses version pragma for version specification.
Expand Down
25 changes: 25 additions & 0 deletions ape_vyper/_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pathlib import Path

import ape
import click
from ape.cli import ape_cli_context


@click.group
def cli():
"""`vyper` command group"""


@cli.command(short_help="Flatten select contract source files")
@ape_cli_context()
@click.argument("CONTRACT", type=click.Path(exists=True, resolve_path=True))
@click.argument("OUTFILE", type=click.Path(exists=False, resolve_path=True, writable=True))
def flatten(cli_ctx, contract: Path, outfile: Path):
"""
Flatten a contract into a single file
"""
with Path(outfile).open("w") as fout:
content = ape.compilers.vyper.flatten_contract(
Path(contract), base_path=ape.project.contracts_folder
)
fout.write(str(content))
101 changes: 101 additions & 0 deletions ape_vyper/ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Utilities for dealing with Vyper AST"""

from typing import List

from ethpm_types import ABI, MethodABI
from ethpm_types.abi import ABIType
from vyper.ast import parse_to_ast # type: ignore
from vyper.ast.nodes import FunctionDef, Module, Name, Subscript # type: ignore

DEFAULT_VYPER_MUTABILITY = "nonpayable"
DECORATOR_MUTABILITY = {
"pure", # Function does not read contract state or environment variables
"view", # Function does not alter contract state
"payable", # Function is able to receive Ether and may alter state
"nonpayable", # Function may alter sate
}


def funcdef_decorators(funcdef: FunctionDef) -> List[str]:
return [d.id for d in funcdef.get("decorator_list") or []]


def funcdef_inputs(funcdef: FunctionDef) -> List[ABIType]:
"""Get a FunctionDef's defined input args"""
args = funcdef.get("args")
# TODO: Does Vyper allow complex input types, like structs and arrays?
return (
[ABIType.model_validate({"name": arg.arg, "type": arg.annotation.id}) for arg in args.args]
if args
else []
)


def funcdef_outputs(funcdef: FunctionDef) -> List[ABIType]:
"""Get a FunctionDef's outputs, or return values"""
returns = funcdef.get("returns")

if not returns:
return []

if isinstance(returns, Name):
# TODO: Structs fall in here. I think they're supposed to be a tuple of types in the ABI.
# Need to dig into that more.
return [ABIType.model_validate({"type": returns.id})]

elif isinstance(returns, Subscript):
# An array type
length = returns.slice.value.value
array_type = returns.value.id
# TOOD: Is this an acurrate way to define a fixed length array for ABI?
return [ABIType.model_validate({"type": f"{array_type}[{length}]"})]

raise NotImplementedError(f"Unhandled return type {type(returns)}")


def funcdef_state_mutability(funcdef: FunctionDef) -> str:
"""Get a FunctionDef's declared state mutability"""
for decorator in funcdef_decorators(funcdef):
if decorator in DECORATOR_MUTABILITY:
return decorator
return DEFAULT_VYPER_MUTABILITY


def funcdef_is_external(funcdef: FunctionDef) -> bool:
"""Check if a FunctionDef is declared external"""
for decorator in funcdef_decorators(funcdef):
if decorator == "external":
return True
return False


def funcdef_to_abi(func: FunctionDef) -> ABI:
"""Return a MethodABI instance for a Vyper FunctionDef"""
return MethodABI.model_validate(
{
"name": func.get("name"),
"inputs": funcdef_inputs(func),
"outputs": funcdef_outputs(func),
"stateMutability": funcdef_state_mutability(func),
}
)


def module_to_abi(module: Module) -> List[ABI]:
"""
Create a list of MethodABIs from a Vyper AST Module instance.
"""
abi = []
for child in module.get_children():
if isinstance(child, FunctionDef):
abi.append(funcdef_to_abi(child))
return abi


def source_to_abi(source: str) -> List[ABI]:
"""
Given Vyper source code, return a list of Ape ABI elements needed for an external interface.
This currently does not include complex types or events.
"""
module = parse_to_ast(source)
return module_to_abi(module)
150 changes: 143 additions & 7 deletions ape_vyper/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fnmatch import fnmatch
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast

import vvm # type: ignore
from ape.api import PluginConfig
Expand All @@ -20,21 +20,29 @@
from ethpm_types import ASTNode, PackageManifest, PCMap, SourceMapItem
from ethpm_types.ast import ASTClassification
from ethpm_types.contract_type import SourceMap
from ethpm_types.source import Compiler, ContractSource, Function, SourceLocation
from ethpm_types.source import Compiler, Content, ContractSource, Function, SourceLocation
from evm_trace.enums import CALL_OPCODES
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from packaging.version import Version
from pydantic import field_serializer, field_validator
from vvm import compile_standard as vvm_compile_standard
from vvm.exceptions import VyperError # type: ignore

from ape_vyper.ast import source_to_abi
from ape_vyper.exceptions import (
RUNTIME_ERROR_MAP,
IntegerBoundsCheck,
RuntimeErrorType,
VyperCompileError,
VyperInstallError,
)
from ape_vyper.interface import (
extract_import_aliases,
extract_imports,
extract_meta,
generate_interface,
iface_name_from_file,
)

DEV_MSG_PATTERN = re.compile(r".*\s*#\s*(dev:.+)")
_RETURN_OPCODES = ("RETURN", "REVERT", "STOP")
Expand Down Expand Up @@ -333,12 +341,11 @@ def config_version_pragma(self) -> Optional[SpecifierSet]:
return None

@property
def import_remapping(self) -> Dict[str, Dict]:
def remapped_manifests(self) -> Dict[str, PackageManifest]:
"""
Configured interface imports from dependencies.
Interface import manifests.
"""

interfaces = {}
dependencies: Dict[str, PackageManifest] = {}

for remapping in self.settings.import_remapping:
Expand Down Expand Up @@ -366,7 +373,19 @@ def import_remapping(self) -> Dict[str, Dict]:
dependency = dependency_versions[version].compile()
dependencies[remapping] = dependency

for name, ct in (dependency.contract_types or {}).items():
return dependencies

@property
def import_remapping(self) -> Dict[str, Dict]:
"""
Configured interface imports from dependencies.
"""

interfaces = {}

for remapping in self.settings.import_remapping:
key, _ = remapping.split("=")
for name, ct in (self.remapped_manifests[remapping].contract_types or {}).items():
interfaces[f"{key}/{name}.json"] = {
"abi": [x.model_dump(mode="json", by_alias=True) for x in ct.abi]
}
Expand Down Expand Up @@ -518,8 +537,14 @@ def compile(

def compile_code(self, code: str, base_path: Optional[Path] = None, **kwargs) -> ContractType:
base_path = base_path or self.project_manager.contracts_folder

# Figure out what compiler version we need for this contract...
version = self._source_vyper_version(code)
# ...and install it if necessary
_install_vyper(version)

try:
result = vvm.compile_source(code, base_path=base_path)
result = vvm.compile_source(code, base_path=base_path, vyper_version=version)
except Exception as err:
raise VyperCompileError(str(err)) from err

Expand All @@ -531,6 +556,117 @@ def compile_code(self, code: str, base_path: Optional[Path] = None, **kwargs) ->
**kwargs,
)

def _source_vyper_version(self, code: str) -> Version:
"""Given source code, figure out which Vyper version to use"""
version_spec = get_version_pragma_spec(code)

def first_full_release(versions: Iterable[Version]) -> Optional[Version]:
for vers in versions:
if not vers.is_devrelease and not vers.is_postrelease and not vers.is_prerelease:
return vers
return None

if version_spec is None:
if version := first_full_release(self.installed_versions + self.available_versions):
return version
raise VyperInstallError("No available version.")

return next(version_spec.filter(self.available_versions))

def _flatten_source(
self, path: Path, base_path: Optional[Path] = None, raw_import_name: Optional[str] = None
) -> str:
base_path = base_path or self.config_manager.contracts_folder

# Get the non stdlib import paths for our contracts
imports = list(
filter(
lambda x: not x.startswith("vyper/"),
[y for x in self.get_imports([path], base_path).values() for y in x],
)
)

dependencies: Dict[str, PackageManifest] = {}
for key, manifest in self.remapped_manifests.items():
package = key.split("=")[0]

if manifest.sources is None:
continue

for source_id in manifest.sources.keys():
import_match = f"{package}/{source_id}"
dependencies[import_match] = manifest

flattened_source = ""
interfaces_source = ""
og_source = (base_path / path).read_text()

# Get info about imports and source meta
aliases = extract_import_aliases(og_source)
pragma, source_without_meta = extract_meta(og_source)
stdlib_imports, _, source_without_imports = extract_imports(source_without_meta)

for import_path in sorted(imports):
import_file = base_path / import_path

# Vyper imported interface names come from their file names
file_name = iface_name_from_file(import_file)
# If we have a known alias, ("import X as Y"), use the alias as interface name
iface_name = aliases[file_name] if file_name in aliases else file_name

# We need to compare without extensions because sometimes they're made up for some
# reason. TODO: Cleaner way to deal with this?
def _match_source(import_path: str) -> Optional[PackageManifest]:
import_path_name = ".".join(import_path.split(".")[:-1])
for source_path in dependencies.keys():
if source_path.startswith(import_path_name):
return dependencies[source_path]
return None

if matched_source := _match_source(import_path):
if not matched_source.contract_types:
continue

abis = [
el
for k in matched_source.contract_types.keys()
for el in matched_source.contract_types[k].abi
]
interfaces_source += generate_interface(abis, iface_name)
continue

# Vyper imported interface names come from their file names
file_name = iface_name_from_file(import_file)
# Generate an ABI from the source code
abis = source_to_abi(import_file.read_text())
interfaces_source += generate_interface(abis, iface_name)

def no_nones(it: Iterable[Optional[str]]) -> Iterable[str]:
# Type guard like generator to remove Nones and make mypy happy
for el in it:
if el is not None:
yield el

# Join all the OG and generated parts back together
flattened_source = "\n\n".join(
no_nones((pragma, stdlib_imports, interfaces_source, source_without_imports))
)

# TODO: Replace this nonsense with a real code formatter
def format_source(source: str) -> str:
while "\n\n\n\n" in source:
source = source.replace("\n\n\n\n", "\n\n\n")
return source

return format_source(flattened_source)

def flatten_contract(self, path: Path, base_path: Optional[Path] = None) -> Content:
"""
Returns the flattened contract suitable for compilation or verification as a single file
"""
source = self._flatten_source(path, base_path, path.name)
return Content({i: ln for i, ln in enumerate(source.splitlines())})

def get_version_map(
self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None
) -> Dict[Version, Set[Path]]:
Expand Down
Loading

0 comments on commit 15b7598

Please sign in to comment.