Skip to content

Commit

Permalink
feat: add support for enable_decimals for Vyper 0.4 (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Sep 28, 2024
1 parent a554daf commit f26db86
Show file tree
Hide file tree
Showing 13 changed files with 175 additions and 96 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ Import the voting contract types like this:
import voting.ballot as ballot
```

### Decimals

To use decimals on Vyper 0.4, use the following config:

```yaml
vyper:
enable_decimals: true
```
### Pragmas
Ape-Vyper supports Vyper 0.3.10's [new pragma formats](https://github.com/vyperlang/vyper/pull/3493)
Expand Down
7 changes: 5 additions & 2 deletions tests/ape-config.yaml → ape-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# Allows compiling to work from the project-level.
contracts_folder: contracts/passing_contracts
contracts_folder: tests/contracts/passing_contracts

# Specify a dependency to use in Vyper imports.
dependencies:
- name: exampledependency
local: ./ExampleDependency
local: ./tests/ExampleDependency

# NOTE: Snekmate does not need to be listed here since
# it is installed in site-packages. However, we include it
# to show it doesn't cause problems when included.
- python: snekmate
config_override:
contracts_folder: .

vyper:
enable_decimals: true
3 changes: 1 addition & 2 deletions ape_vyper/compiler/_versions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ def get_settings(
optimization = False

selection_dict = self._get_selection_dictionary(selection, project=pm)
search_paths = [*getsitepackages()]
search_paths.append(".")
search_paths = [*getsitepackages(), "."]

version_settings[settings_key] = {
"optimize": optimization,
Expand Down
19 changes: 19 additions & 0 deletions ape_vyper/compiler/_versions/vyper_04.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,25 @@ def get_import_remapping(self, project: Optional[ProjectManager] = None) -> dict
# You always import via module or package name.
return {}

def get_settings(
self,
version: Version,
source_paths: Iterable[Path],
compiler_data: dict,
project: Optional[ProjectManager] = None,
) -> dict:
pm = project or self.local_project

enable_decimals = self.api.get_config(project=pm).enable_decimals
if enable_decimals is None:
enable_decimals = False

settings = super().get_settings(version, source_paths, compiler_data, project=pm)
for settings_set in settings.values():
settings_set["enable_decimals"] = enable_decimals

return settings

def _get_sources_dictionary(
self, source_ids: Iterable[str], project: Optional[ProjectManager] = None, **kwargs
) -> dict[str, dict]:
Expand Down
16 changes: 12 additions & 4 deletions ape_vyper/compiler/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,17 @@ def compile(
settings: Optional[dict] = None,
) -> Iterator[ContractType]:
pm = project or self.local_project

original_settings = self.compiler_settings
self.compiler_settings = {**self.compiler_settings, **(settings or {})}
try:
yield from self._compile(contract_filepaths, project=pm)
finally:
self.compiler_settings = original_settings

def _compile(
self, contract_filepaths: Iterable[Path], project: Optional[ProjectManager] = None
):
pm = project or self.local_project
contract_types: list[ContractType] = []
import_map = self._import_resolver.get_imports(pm, contract_filepaths)
config = self.get_config(pm)
Expand Down Expand Up @@ -514,12 +523,11 @@ def init_coverage_profile(
def enrich_error(self, err: ContractLogicError) -> ContractLogicError:
return enrich_error(err)

# TODO: In 0.9, make sure project is a kwarg here.
def trace_source(
self, contract_source: ContractSource, trace: TraceAPI, calldata: HexBytes
) -> SourceTraceback:
frames = trace.get_raw_frames()
tracer = SourceTracer(contract_source, frames, calldata)
return tracer.trace()
return SourceTracer.trace(trace.get_raw_frames(), contract_source, calldata)

def _get_compiler_arguments(
self,
Expand Down
7 changes: 7 additions & 0 deletions ape_vyper/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class VyperConfig(PluginConfig):
"""

enable_decimals: Optional[bool] = None
"""
On Vyper 0.4, to use decimal types, you must enable it.
Defaults to ``None`` to avoid misleading that ``False``
means you cannot use decimals on a lower version.
"""

@field_validator("version", mode="before")
def validate_version(cls, value):
return pragma_str_to_specifier_set(value) if isinstance(value, str) else value
Expand Down
7 changes: 5 additions & 2 deletions ape_vyper/flattener.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ape.logging import logger
from ape.managers import ProjectManager
from ape.utils import ManagerAccessMixin
from ape.utils import ManagerAccessMixin, get_relative_path
from ethpm_types.source import Content

from ape_vyper._utils import get_version_pragma_spec
Expand Down Expand Up @@ -65,7 +65,10 @@ def _flatten_source(
flattened_modules = ""
modules_prefixes: set[str] = set()

for import_path in sorted(imports):
# Source by source ID for greater consistency..
for import_path in sorted(
imports, key=lambda p: f"{get_relative_path(p.absolute(), pm.path)}"
):
import_info = imports[import_path]

# Vyper imported interface names come from their file names
Expand Down
26 changes: 6 additions & 20 deletions ape_vyper/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def __init__(self, project: ProjectManager, paths: list[Path]):
# Even though we build up mappings of all sources, as may be referenced
# later on and that prevents re-calculating over again, we only
# "show" the items requested.
self._request_view: list[Path] = paths
self.paths: list[Path] = paths

def __getitem__(self, item: Union[str, Path], *args, **kwargs) -> list[Import]:
if isinstance(item, str) or not item.is_absolute():
Expand Down Expand Up @@ -294,7 +294,7 @@ def keys(self) -> list[Path]: # type: ignore
result = []
keys = sorted(list(super().keys()))
for path in keys:
if path not in self._request_view:
if path not in self.paths:
continue

result.append(path)
Expand All @@ -311,7 +311,7 @@ def values(self) -> list[list[Import]]: # type: ignore
def items(self) -> list[tuple[Path, list[Import]]]: # type: ignore
result = []
for path in self.keys(): # sorted
if path not in self._request_view:
if path not in self.paths:
continue

result.append((path, self[path]))
Expand All @@ -328,30 +328,16 @@ class ImportResolver(ManagerAccessMixin):
_projects: dict[str, ImportMap] = {}
_dependency_attempted_compile: set[str] = set()

def get_imports(
self,
project: ProjectManager,
contract_filepaths: Iterable[Path],
) -> ImportMap:
def get_imports(self, project: ProjectManager, contract_filepaths: Iterable[Path]) -> ImportMap:
paths = list(contract_filepaths)
reset_view = None
if project.project_id not in self._projects:
self._projects[project.project_id] = ImportMap(project, paths)
else:
# Change the items we "view". Some (or all) may need to be added as well.
reset_view = self._projects[project.project_id]._request_view
self._projects[project.project_id]._request_view = paths

try:
import_map = self._get_imports(paths, project)
finally:
if reset_view is not None:
self._projects[project.project_id]._request_view = reset_view

return import_map
return self._get_imports(paths, project)

def _get_imports(self, paths: list[Path], project: ProjectManager) -> ImportMap:
import_map = self._projects[project.project_id]
import_map.paths = list({*import_map.paths, *paths})
for path in paths:
if path in import_map:
# Already handled.
Expand Down
Loading

0 comments on commit f26db86

Please sign in to comment.