Skip to content

Commit

Permalink
Merge pull request #177 from randovania/feature/formatting
Browse files Browse the repository at this point in the history
Use ruff's formatter
  • Loading branch information
henriquegemignani authored Aug 22, 2024
2 parents 6943fc8 + fec15d1 commit afb176d
Show file tree
Hide file tree
Showing 102 changed files with 5,136 additions and 4,662 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ repos:
rev: v0.6.1
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
args: [ --fix, --exit-non-zero-on-fix ]
- id: ruff-format
29 changes: 17 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,22 @@


def generate_property_templates():
subprocess.run([
sys.executable,
os.fspath(Path(__file__).parent.joinpath("tools", "create_class_definitions.py")),
"dread",
], check=True)
subprocess.run([
sys.executable,
os.fspath(Path(__file__).parent.joinpath("tools", "create_class_definitions.py")),
"sr",
], check=True)

subprocess.run(
[
sys.executable,
os.fspath(Path(__file__).parent.joinpath("tools", "create_class_definitions.py")),
"dread",
],
check=True,
)
subprocess.run(
[
sys.executable,
os.fspath(Path(__file__).parent.joinpath("tools", "create_class_definitions.py")),
"sr",
],
check=True,
)


class BuildPyCommand(build_py):
Expand All @@ -33,6 +38,6 @@ def run(self):

setup(
cmdclass={
'build_py': BuildPyCommand,
"build_py": BuildPyCommand,
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
# This function returns a list containing only the path to this
# directory, which is the location of these hooks.


def get_hook_dirs():
return [os.path.dirname(__file__)]
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

# https://pyinstaller.readthedocs.io/en/stable/hooks.html#provide-hooks-with-package

datas = collect_data_files('mercury_engine_data_structures', excludes=['__pyinstaller'])
datas = collect_data_files("mercury_engine_data_structures", excludes=["__pyinstaller"])
3 changes: 2 additions & 1 deletion src/mercury_engine_data_structures/_dread_data_construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class CompressedZSTD(construct.Tunnel):
def __init__(self, subcon, level: int = 3):
super().__init__(subcon)
import zstd

self.lib = zstd
self.level = level

Expand All @@ -26,7 +27,7 @@ def __init__(self):
construct.Sequence(
construct.PascalString(construct.Int16un, "ascii"), # key
construct.Int64un, # hash
)
),
)

def _parse(self, stream, context, path) -> typing.Dict[str, int]:
Expand Down
2 changes: 1 addition & 1 deletion src/mercury_engine_data_structures/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def decode_encode_compare_file(file_path: Path, game: Game, file_format: str):
resource = resource_class.parse(raw, target_game=game)
encoded = resource.build()

if raw != encoded and raw.rstrip(b"\xFF") != encoded:
if raw != encoded and raw.rstrip(b"\xff") != encoded:
return f"{file_path}: Results differ (len(raw): {len(raw)}; len(encoded): {len(encoded)})"
return None

Expand Down
112 changes: 70 additions & 42 deletions src/mercury_engine_data_structures/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,23 @@
CVector3D = construct.Array(3, Float)
CVector4D = construct.Array(4, Float)


class VersionAdapter(Adapter):
def __init__(self, value: int | str | tuple[int, int, int] | None = None):
if isinstance(value, str):
value = tuple([int(i) for i in value.split(".")])
elif isinstance(value, int):
value = struct.pack("<I", value)
value = struct.unpack("<HBB", value)
value = struct.unpack("<HBB", value)

if value is None:
subcon = construct.Struct(major=construct.Int16ul, minor=construct.Int8ul, patch=construct.Int8ul)
else:
major, minor, patch = value
subcon = construct.Struct(
major = construct.Const(major, construct.Int16ul),
minor = construct.Const(minor, construct.Int8ul),
patch = construct.Const(patch, construct.Int8ul)
major=construct.Const(major, construct.Int16ul),
minor=construct.Const(minor, construct.Int8ul),
patch=construct.Const(patch, construct.Int8ul),
)
super().__init__(subcon)

Expand All @@ -44,11 +45,8 @@ def _decode(self, obj, context, path):

def _encode(self, obj, context, path):
lst = [int(i) for i in obj.split(".")]
return {
"major": lst[0],
"minor": lst[1],
"patch": lst[2]
}
return {"major": lst[0], "minor": lst[1], "patch": lst[2]}


def _cvector_emitparse(length: int, code: construct.CodeGen) -> str:
"""Specialized construct compile for CVector2/3/4D"""
Expand Down Expand Up @@ -130,8 +128,12 @@ def __init__(self, subcon, *, allow_duplicates: bool = False):
super().__init__(subcon)
self.allow_duplicates = allow_duplicates

def _decode(self, obj: construct.ListContainer, context: construct.Container, path: str,
) -> construct.ListContainer | construct.Container:
def _decode(
self,
obj: construct.ListContainer,
context: construct.Container,
path: str,
) -> construct.ListContainer | construct.Container:
result = construct.Container()
for item in obj:
key = item.key
Expand All @@ -142,14 +144,15 @@ def _decode(self, obj: construct.ListContainer, context: construct.Container, pa
result[key] = item.value
return result

def _encode(self, obj: construct.ListContainer | construct.Container, context: construct.Container, path: str,
) -> list:
def _encode(
self,
obj: construct.ListContainer | construct.Container,
context: construct.Container,
path: str,
) -> list:
if self.allow_duplicates and isinstance(obj, list):
return obj
return construct.ListContainer(
construct.Container(key=type_, value=item)
for type_, item in obj.items()
)
return construct.ListContainer(construct.Container(key=type_, value=item) for type_, item in obj.items())

def _emitparse(self, code):
fname = f"parse_dict_adapter_{code.allocateId()}"
Expand Down Expand Up @@ -200,9 +203,15 @@ def __init__(self, field, key=StrId):

def _parse(self, stream, context, path):
context = construct.Container(
_=context, _params=context._params, _root=None, _parsing=context._parsing,
_building=context._building, _sizing=context._sizing, _io=stream,
_index=context.get("_index", None))
_=context,
_params=context._params,
_root=None,
_parsing=context._parsing,
_building=context._building,
_sizing=context._sizing,
_io=stream,
_index=context.get("_index", None),
)
context._root = context._.get("_root", context)

key = self.key._parsereport(stream, context, path)
Expand All @@ -215,9 +224,15 @@ def _parse(self, stream, context, path):

def _build(self, obj, stream, context, path):
context = construct.Container(
_=context, _params=context._params, _root=None, _parsing=context._parsing,
_building=context._building, _sizing=context._sizing, _io=stream,
_index=context.get("_index", None))
_=context,
_params=context._params,
_root=None,
_parsing=context._parsing,
_building=context._building,
_sizing=context._sizing,
_io=stream,
_index=context.get("_index", None),
)
context._root = context._.get("_root", context)

key = self.key._build(obj.key, stream, context, path)
Expand All @@ -230,9 +245,15 @@ def _build(self, obj, stream, context, path):

def _sizeof(self, context, path):
context = construct.Container(
_=context, _params=context._params, _root=None, _parsing=context._parsing,
_building=context._building, _sizing=context._sizing, _io=None,
_index=context.get("_index", None))
_=context,
_params=context._params,
_root=None,
_parsing=context._parsing,
_building=context._building,
_sizing=context._sizing,
_io=None,
_index=context.get("_index", None),
)
context._root = context._.get("_root", context)

key = self.key._sizeof(context, path)
Expand Down Expand Up @@ -279,8 +300,12 @@ def {fname}(obj, io, this):


class DictConstruct(construct.Construct):
def __init__(self, key_type: construct.Construct, value_type: construct.Construct,
count_type: construct.Construct = construct.Int32ul):
def __init__(
self,
key_type: construct.Construct,
value_type: construct.Construct,
count_type: construct.Construct = construct.Int32ul,
):
super().__init__()
self.key_type = key_type
self.value_type = value_type
Expand Down Expand Up @@ -326,9 +351,11 @@ def {fname}(key, value, io, this):
{self.value_type._compilebuild(code)}
"""
code.append(block)
return (f"(reuse(len(obj), "
f"lambda obj: {self.count_type._compilebuild(code)}), "
f"list({fname}(key, value, io, this) for key, value in obj.items()), obj)[2]")
return (
f"(reuse(len(obj), "
f"lambda obj: {self.count_type._compilebuild(code)}), "
f"list({fname}(key, value, io, this) for key, value in obj.items()), obj)[2]"
)


def make_dict(value: construct.Construct, key=StrId):
Expand All @@ -355,30 +382,31 @@ def make_vector(value: construct.Construct):
if hasattr(value, "_emitparse_vector"):
_emitparse = value._emitparse_vector
else:

def _emitparse(code: construct.CodeGen) -> str:
return (f"ListContainer(({value._compileparse(code)}) "
f"for i in range({construct.Int32ul._compileparse(code)}))")
return (
f"ListContainer(({value._compileparse(code)}) "
f"for i in range({construct.Int32ul._compileparse(code)}))"
)

result._emitparse = _emitparse

def _emitbuild(code):
return (f"(reuse(len(obj), lambda obj: {construct.Int32ul._compilebuild(code)}),"
f" list({value._compilebuild(code)} for obj in obj), obj)[2]")
return (
f"(reuse(len(obj), lambda obj: {construct.Int32ul._compilebuild(code)}),"
f" list({value._compilebuild(code)} for obj in obj), obj)[2]"
)

result._emitbuild = _emitbuild

return result


def make_enum(values: typing.Union[typing.List[str], typing.Dict[str, int]], *,
add_invalid: bool = True):
def make_enum(values: typing.Union[typing.List[str], typing.Dict[str, int]], *, add_invalid: bool = True):
if isinstance(values, dict):
mapping = copy.copy(values)
else:
mapping = {
name: i
for i, name in enumerate(values)
}
mapping = {name: i for i, name in enumerate(values)}
if add_invalid:
mapping["Invalid"] = 0x7fffffff
mapping["Invalid"] = 0x7FFFFFFF
return construct.Enum(construct.Int32ul, **mapping)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class AlignTo(Construct):
def __init__(self, modulus, pattern = b"\x00"):
def __init__(self, modulus, pattern=b"\x00"):
super().__init__()
self.modulus = modulus
self.pattern = pattern
Expand All @@ -29,7 +29,7 @@ def _build(self, obj, stream, context, path):


class AlignedPrefixed(Subconstruct):
def __init__(self, length_field, subcon, modulus, length_size, pad_byte=b"\xFF"):
def __init__(self, length_field, subcon, modulus, length_size, pad_byte=b"\xff"):
super().__init__(subcon)
self.length_field = length_field
self.modulus = modulus
Expand Down Expand Up @@ -112,13 +112,14 @@ def PrefixedAllowZeroLen(lengthfield, subcon, includelengthfield=False):
return FocusedSeq(
"prefixed",
"len" / Peek(lengthfield),
"prefixed" / Prefixed(
"prefixed"
/ Prefixed(
lengthfield,
IfThenElse(
construct.this._parsing,
If(construct.this.len > 0, subcon),
If(construct.this.prefixed, subcon),
),
includelengthfield
)
includelengthfield,
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ def _encode(self, obj: typing.Union[str, enum.IntEnum, int], context, path) -> i
def _emitbuild(self, code: construct.CodeGen):
i = code.allocateId()

mapping = ", ".join(
f"{repr(enum_entry.name)}: {enum_entry.value}"
for enum_entry in self.enum_class
)
mapping = ", ".join(f"{repr(enum_entry.name)}: {enum_entry.value}" for enum_entry in self.enum_class)

code.append(f"""
_enum_name_to_value_{i} = {{{mapping}}}
Expand All @@ -45,5 +42,5 @@ def _encode_enum_{i}(obj, io, this):
def BitMaskEnum(enum_type: typing.Type[enum.IntEnum]):
flags = {}
for enumentry in enum_type:
flags[enumentry.name] = 2 ** enumentry.value
flags[enumentry.name] = 2**enumentry.value
return construct.FlagsEnum(construct.Int32ul, **flags)
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def _resolve_id(type_class: Union[construct.Construct, Type[construct.Construct]


def emit_switch_cases_parse(
code: construct.CodeGen,
fields: Dict[Union[str, int], Union[construct.Construct, Type[construct.Construct]]],
custom_table_name: Optional[str] = None,
code: construct.CodeGen,
fields: Dict[Union[str, int], Union[construct.Construct, Type[construct.Construct]]],
custom_table_name: Optional[str] = None,
) -> str:
"""Construct codegen helper for handling the switch cases dict in _emitparse."""
table_name = custom_table_name
Expand Down Expand Up @@ -47,9 +47,9 @@ def {code_name}(io, this):


def emit_switch_cases_build(
code: construct.CodeGen,
fields: Dict[Union[str, int], Union[construct.Construct, Type[construct.Construct]]],
custom_table_name: Optional[str] = None,
code: construct.CodeGen,
fields: Dict[Union[str, int], Union[construct.Construct, Type[construct.Construct]]],
custom_table_name: Optional[str] = None,
) -> str:
"""Construct codegen helper for handling the switch cases dict in _emitbuild."""
table_name = custom_table_name
Expand Down Expand Up @@ -113,8 +113,10 @@ def _emitparse(self, code):
)

def _emitbuild(self, code):
return (f"(({self.thensubcon._compilebuild(code)}) if ("
f"{self._insert_cond(code)}) else ({self.elsesubcon._compilebuild(code)}))")
return (
f"(({self.thensubcon._compilebuild(code)}) if ("
f"{self._insert_cond(code)}) else ({self.elsesubcon._compilebuild(code)}))"
)


def ComplexIf(condfunc, subcon):
Expand Down
Loading

0 comments on commit afb176d

Please sign in to comment.