class AutoPatchExtensionRegister(abc.ABCMeta):
    """Metaclass that keeps a registry of patch extension classes, keyed by game name.

    Every class created through this metaclass whose body defines ``game`` is stored in
    ``extension_types`` under that name; classes without ``game`` are created but not registered.
    """
    # game name -> extension class registered for it
    extension_types: ClassVar[Dict[str, "AutoPatchExtensionRegister"]] = {}
    # fallback for classes that do not declare their own list: no additional extensions
    required_extensions: List[str] = []

    def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> "AutoPatchExtensionRegister":
        """Create the class and, when its body declares ``game``, register it under that name."""
        created = super().__new__(mcs, name, bases, dct)
        if "game" not in dct:
            return created
        AutoPatchExtensionRegister.extension_types[dct["game"]] = created
        return created

    @staticmethod
    def get_handler(game: str) -> Union["AutoPatchExtensionRegister", List["AutoPatchExtensionRegister"]]:
        """Return the extension class for ``game``, plus its required extensions if it names any.

        Falls back to ``APPatchExtension`` when nothing is registered for ``game``.

        :return: the single class when it has no ``required_extensions``; otherwise a list holding
            that class followed by each required extension, in declaration order.
        :raises NotImplementedError: if a required extension is not registered.
        """
        registry = AutoPatchExtensionRegister.extension_types
        primary = registry.get(game, APPatchExtension)
        if not primary.required_extensions:
            return primary
        chain = [primary]
        for dependency in primary.required_extensions:
            found = registry.get(dependency)
            if not found:
                raise NotImplementedError(f"No handler for {dependency}.")
            chain.append(found)
        return chain
class APProcedurePatch(APAutoPatchInterface):
    """
    An APPatch that defines a procedure to produce the desired file.

    ``procedure`` is a sequence of ``(step_name, args)`` pairs. When patching, each step name is looked up
    on the patch extension class(es) returned for ``game`` and called as ``step(patch, data, *args)``;
    the bytes it returns become the input of the next step.
    """
    hash: Optional[str]  # base checksum of source file
    source_data: bytes
    patch_file_ending: str = ""
    # Payload of the patch container (file name -> content), everything except "archipelago.json".
    # Deliberately has no class-level value: a dict defined here would be shared by every patch
    # object, so files written for one patch would show up in all others. See __init__.
    files: Dict[str, bytes]

    # NOTE(review): the bodies of the two classmethods below fall into a gap between diff hunks and are
    # reconstructed from the visible context lines -- verify against the actual file before applying.
    @classmethod
    def get_source_data(cls) -> bytes:
        """Get Base data"""
        raise NotImplementedError()

    @classmethod
    def get_source_data_with_cache(cls) -> bytes:
        """Return the base data, loading it through `get_source_data` and storing it on the class."""
        if not hasattr(cls, "source_data"):
            cls.source_data = cls.get_source_data()
        return cls.source_data

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(APProcedurePatch, self).__init__(*args, **kwargs)
        # one private file table per patch object
        self.files = {}

    def get_manifest(self) -> Dict[str, Any]:
        """Extend the inherited manifest with checksum, file endings and the procedure."""
        manifest = super(APProcedurePatch, self).get_manifest()
        manifest["base_checksum"] = self.hash
        manifest["result_file_ending"] = self.result_file_ending
        manifest["patch_file_ending"] = self.patch_file_ending
        manifest["procedure"] = self.procedure
        if self.procedure == APDeltaPatch.procedure:
            # a plain bsdiff4 procedure is marked as compatible with container version 5
            # (presumably so clients predating procedures can still apply it -- confirm)
            manifest["compatible_version"] = 5
        return manifest

    def read_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
        """Load the procedure from the manifest and every other archive member into `files`."""
        super(APProcedurePatch, self).read_contents(opened_zipfile)
        with opened_zipfile.open("archipelago.json", "r") as f:
            manifest = json.load(f)
        if "procedure" not in manifest:
            # support patching files made before moving to procedures
            self.procedure = [("apply_bsdiff4", ["delta.bsdiff4"])]
        else:
            self.procedure = manifest["procedure"]
        for file in opened_zipfile.namelist():
            if file != "archipelago.json":
                self.files[file] = opened_zipfile.read(file)

    def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
        """Write the inherited contents, then every entry of `files`, into the archive."""
        super(APProcedurePatch, self).write_contents(opened_zipfile)
        for file, content in self.files.items():
            # bsdiff4 is a format with integrated compression, so those members are stored as-is;
            # None leaves the compression choice to the archive's own setting
            opened_zipfile.writestr(file, content,
                                    compress_type=zipfile.ZIP_STORED if file.endswith(".bsdiff4") else None)

    def get_file(self, file: str) -> bytes:
        """ Retrieves a file from the patch container."""
        if file not in self.files:
            # not loaded yet: read the container from disk, then look again (KeyError if still absent)
            self.read()
        return self.files[file]

    def write_file(self, file_name: str, file: bytes) -> None:
        """ Writes a file to the patch container, to be retrieved upon patching. """
        self.files[file_name] = file

    def patch(self, target: str) -> None:
        """Run every procedure step on the base data and write the result to `target`.

        :raises NotImplementedError: if no extension for this game provides a step of the procedure.
        """
        self.read()
        base_data = self.get_source_data_with_cache()
        patch_extender = AutoPatchExtensionRegister.get_handler(self.game)
        assert not isinstance(self.procedure, str), f"{type(self)} must define procedures"
        # get_handler returns either one class or a list of classes; normalise once, outside the loop
        extenders = patch_extender if isinstance(patch_extender, list) else [patch_extender]
        for step, args in self.procedure:
            # the first extension (in handler order) that defines the step wins
            extension = next(
                (func for func in (getattr(extender, step, None) for extender in extenders)
                 if func is not None),
                None)
            if extension is None:
                raise NotImplementedError(f"Unknown procedure {step} for {self.game}.")
            base_data = extension(self, base_data, *args)
        with open(target, 'wb') as f:
            f.write(base_data)


class APDeltaPatch(APProcedurePatch):
    """An APProcedurePatch that additionally has delta.bsdiff4
    containing a delta patch to get the desired file, often a rom."""

    procedure = [
        ("apply_bsdiff4", ["delta.bsdiff4"])
    ]

    def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any) -> None:
        super(APDeltaPatch, self).__init__(*args, **kwargs)
        self.patched_path = patched_path

    def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
        """Diff the file at `patched_path` against the base data and store it as delta.bsdiff4."""
        source = self.get_source_data_with_cache()
        # read through a context manager so the handle is closed even if diffing fails
        with open(self.patched_path, "rb") as patched_file:
            patched = patched_file.read()
        self.write_file("delta.bsdiff4", bsdiff4.diff(source, patched))
        super(APDeltaPatch, self).write_contents(opened_zipfile)
class APTokenTypes(IntEnum):
    """Operation codes stored in the first byte of each token."""
    WRITE = 0
    COPY = 1
    RLE = 2
    AND_8 = 3
    OR_8 = 4
    XOR_8 = 5


class APTokenMixin:
    """
    A class that defines functions for generating a token binary, for use in patches.

    Binary layout (all integers little-endian):
    4-byte token count, then per token: 1-byte type, 4-byte offset, 4-byte payload size, payload.
    """
    # Class-level default only. It is never appended to: write_token gives each instance its own
    # list on first use, otherwise tokens would be shared between every object using this mixin.
    tokens: List[
        Tuple[APTokenTypes, int, Union[
            bytes,  # WRITE
            Tuple[int, int],  # COPY, RLE
            int  # AND_8, OR_8, XOR_8
        ]]] = []

    def get_token_binary(self) -> bytes:
        """
        Returns the token binary created from stored tokens.
        :return: A bytes object representing the token data.
        :raises ValueError: if a stored token has an unknown type.
        """
        data = bytearray()
        data.extend(len(self.tokens).to_bytes(4, "little"))
        for token_type, offset, args in self.tokens:
            data.append(token_type)
            data.extend(offset.to_bytes(4, "little"))
            if token_type in [APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8]:
                assert isinstance(args, int), f"Arguments to AND/OR/XOR must be of type int, not {type(args)}"
                # payload: the single operand byte
                data.extend((1).to_bytes(4, "little"))
                data.append(args)
            elif token_type in [APTokenTypes.COPY, APTokenTypes.RLE]:
                assert isinstance(args, tuple), f"Arguments to COPY/RLE must be of type tuple, not {type(args)}"
                # payload: two 4-byte integers, so the declared size must be 8. The reader slices
                # the payload and advances to the next token by this size.
                data.extend((8).to_bytes(4, "little"))
                data.extend(args[0].to_bytes(4, "little"))
                data.extend(args[1].to_bytes(4, "little"))
            elif token_type == APTokenTypes.WRITE:
                assert isinstance(args, bytes), f"Arguments to WRITE must be of type bytes, not {type(args)}"
                data.extend(len(args).to_bytes(4, "little"))
                data.extend(args)
            else:
                raise ValueError(f"Unknown token type {token_type}")
        return bytes(data)

    @overload
    def write_token(self,
                    token_type: Literal[APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8],
                    offset: int,
                    data: int) -> None:
        ...

    @overload
    def write_token(self,
                    token_type: Literal[APTokenTypes.COPY, APTokenTypes.RLE],
                    offset: int,
                    data: Tuple[int, int]) -> None:
        ...

    @overload
    def write_token(self,
                    token_type: Literal[APTokenTypes.WRITE],
                    offset: int,
                    data: bytes) -> None:
        ...

    def write_token(self, token_type: APTokenTypes, offset: int, data: Union[bytes, Tuple[int, int], int]) -> None:
        """
        Stores a token to be used by patching.

        :param token_type: operation to perform.
        :param offset: target address the operation applies to.
        :param data: bytes for WRITE, a pair of ints for COPY/RLE, a single int for AND_8/OR_8/XOR_8.
        """
        if "tokens" not in vars(self):
            # first write on this instance: start from a private copy of the class-level default
            self.tokens = list(self.tokens)
        self.tokens.append((token_type, offset, data))
class APPatchExtension(metaclass=AutoPatchExtensionRegister):
    """Class that defines patch extension functions for a given game.
    Patch extension functions must have the following two arguments in the following order:

    caller: APProcedurePatch (used to retrieve files from the patch container)

    rom: bytes (the data to patch)

    Further arguments are passed in from the procedure as defined.

    Patch extension functions must return the changed bytes.
    """
    game: str
    required_extensions: List[str] = []

    @staticmethod
    def apply_bsdiff4(caller: APProcedurePatch, rom: bytes, patch: str):
        """Apply the bsdiff4 delta stored in the container under the name `patch` to `rom`."""
        return bsdiff4.patch(rom, caller.get_file(patch))

    @staticmethod
    def apply_tokens(caller: APProcedurePatch, rom: bytes, token_file: str) -> bytes:
        """Apply the token binary stored in the container under `token_file` to `rom`.

        The binary starts with a 4-byte little-endian token count. Each token is a 1-byte type,
        a 4-byte offset, a 4-byte payload size and the payload itself.
        """
        blob = caller.get_file(token_file)
        patched = bytearray(rom)
        bitwise_kinds = [APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8]
        span_kinds = [APTokenTypes.COPY, APTokenTypes.RLE]
        remaining = int.from_bytes(blob[0:4], "little")
        cursor = 4  # first token starts right after the count
        while remaining > 0:
            remaining -= 1
            kind = blob[cursor]
            address = int.from_bytes(blob[cursor + 1:cursor + 5], "little")
            payload_len = int.from_bytes(blob[cursor + 5:cursor + 9], "little")
            payload = blob[cursor + 9:cursor + 9 + payload_len]
            # the next token begins after the 9-byte header and the payload
            cursor += 9 + payload_len
            if kind in bitwise_kinds:
                operand = payload[0]
                current = patched[address]
                if kind == APTokenTypes.AND_8:
                    patched[address] = current & operand
                elif kind == APTokenTypes.OR_8:
                    patched[address] = current | operand
                else:
                    patched[address] = current ^ operand
                continue
            if kind in span_kinds:
                # payload: span length in the first four bytes, second argument in the rest
                span = int.from_bytes(payload[:4], "little")
                second = int.from_bytes(payload[4:], "little")
                if kind == APTokenTypes.COPY:
                    # second argument is the address to copy from
                    replacement = patched[second: second + span]
                else:
                    # second argument is the byte value to repeat
                    replacement = bytes([second] * span)
                patched[address: address + span] = replacement
                continue
            # every other type is handled as a plain write of the payload
            patched[address:address + len(payload)] = payload
        return bytes(patched)

    @staticmethod
    def calc_snes_crc(caller: APProcedurePatch, rom: bytes):
        """Calculates and applies a valid CRC for the SNES rom header.

        :raises Exception: if `rom` is shorter than 0x8000 bytes.
        """
        buffer = bytearray(rom)
        if len(rom) < 0x8000:
            raise Exception("Tried to calculate SNES CRC on file too small to be a SNES ROM.")
        # Sum every byte except the four at 0x7FDC-0x7FDF that hold the result. 0x01FE stands in for
        # those four bytes: a 16-bit value and its complement always add up to 0xFF + 0xFF byte-wise.
        checksum = (sum(buffer[:0x7FDC]) + sum(buffer[0x7FE0:]) + 0x01FE) & 0xFFFF
        complement = checksum ^ 0xFFFF
        # stored as complement first, then checksum, each little-endian
        buffer[0x7FDC:0x7FE0] = complement.to_bytes(2, "little") + checksum.to_bytes(2, "little")
        return bytes(buffer)