Skip to content

Commit

Permalink
Core: implement APProcedurePatch and APTokenMixin (ArchipelagoMW#2536)
Browse files Browse the repository at this point in the history
* initial work on procedure patch

* more flexibility

load default procedure for version 5 patches
add args for procedure
add default extension for tokens and bsdiff
allow specifying additional required extensions for generation

* pushing current changes to go fix tloz bug

* move tokens into a separate inheritable class

* forgot the commit to remove token from ProcedurePatch

* further cleaning from bad commit

* start on docstrings

* further work on docstrings and typing

* improve docstrings

* fix incorrect docstring

* cleanup

* clean defaults and docstring

* define interface that has only the bare minimum required
for `Patch.create_rom_file`

* change to dictionary.get

* remove unnecessary if statement

* update to explicitly check for procedure, restore compatible version and manual override

* Update Files.py

* remove struct uses

* ensure returning bytes, add token type checking

* Apply suggestions from code review

Co-authored-by: Doug Hoskisson <[email protected]>

* pep8

---------

Co-authored-by: beauxq <[email protected]>
Co-authored-by: Doug Hoskisson <[email protected]>
  • Loading branch information
3 people authored Mar 19, 2024
1 parent 8a8263f commit 94650a0
Showing 1 changed file with 248 additions and 36 deletions.
284 changes: 248 additions & 36 deletions worlds/Files.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import abc
import json
import zipfile
from enum import IntEnum
import os
import threading

from typing import ClassVar, Dict, List, Literal, Tuple, Any, Optional, Union, BinaryIO
from typing import ClassVar, Dict, List, Literal, Tuple, Any, Optional, Union, BinaryIO, overload

import bsdiff4

Expand Down Expand Up @@ -38,6 +39,32 @@ def get_handler(file: str) -> Optional[AutoPatchRegister]:
return None


class AutoPatchExtensionRegister(abc.ABCMeta):
extension_types: ClassVar[Dict[str, AutoPatchExtensionRegister]] = {}
required_extensions: List[str] = []

def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoPatchExtensionRegister:
# construct class
new_class = super().__new__(mcs, name, bases, dct)
if "game" in dct:
AutoPatchExtensionRegister.extension_types[dct["game"]] = new_class
return new_class

@staticmethod
def get_handler(game: str) -> Union[AutoPatchExtensionRegister, List[AutoPatchExtensionRegister]]:
handler = AutoPatchExtensionRegister.extension_types.get(game, APPatchExtension)
if handler.required_extensions:
handlers = [handler]
for required in handler.required_extensions:
ext = AutoPatchExtensionRegister.extension_types.get(required)
if not ext:
raise NotImplementedError(f"No handler for {required}.")
handlers.append(ext)
return handlers
else:
return handler


container_version: int = 6


Expand Down Expand Up @@ -157,27 +184,14 @@ def patch(self, target: str) -> None:
""" create the output file with the file name `target` """


class APDeltaPatch(APAutoPatchInterface):
"""An implementation of `APAutoPatchInterface` that additionally
has delta.bsdiff4 containing a delta patch to get the desired file."""

class APProcedurePatch(APAutoPatchInterface):
"""
An APPatch that defines a procedure to produce the desired file.
"""
hash: Optional[str] # base checksum of source file
patch_file_ending: str = ""
delta: Optional[bytes] = None
source_data: bytes
procedure = None # delete this line when APPP is added

def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any) -> None:
self.patched_path = patched_path
super(APDeltaPatch, self).__init__(*args, **kwargs)

def get_manifest(self) -> Dict[str, Any]:
manifest = super(APDeltaPatch, self).get_manifest()
manifest["base_checksum"] = self.hash
manifest["result_file_ending"] = self.result_file_ending
manifest["patch_file_ending"] = self.patch_file_ending
manifest["compatible_version"] = 5 # delete this line when APPP is added
return manifest
patch_file_ending: str = ""
files: Dict[str, bytes] = {}

@classmethod
def get_source_data(cls) -> bytes:
Expand All @@ -190,21 +204,219 @@ def get_source_data_with_cache(cls) -> bytes:
cls.source_data = cls.get_source_data()
return cls.source_data

def __init__(self, *args: Any, **kwargs: Any):
super(APProcedurePatch, self).__init__(*args, **kwargs)

def get_manifest(self) -> Dict[str, Any]:
manifest = super(APProcedurePatch, self).get_manifest()
manifest["base_checksum"] = self.hash
manifest["result_file_ending"] = self.result_file_ending
manifest["patch_file_ending"] = self.patch_file_ending
manifest["procedure"] = self.procedure
if self.procedure == APDeltaPatch.procedure:
manifest["compatible_version"] = 5
return manifest

def read_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
super(APProcedurePatch, self).read_contents(opened_zipfile)
with opened_zipfile.open("archipelago.json", "r") as f:
manifest = json.load(f)
if "procedure" not in manifest:
# support patching files made before moving to procedures
self.procedure = [("apply_bsdiff4", ["delta.bsdiff4"])]
else:
self.procedure = manifest["procedure"]
for file in opened_zipfile.namelist():
if file not in ["archipelago.json"]:
self.files[file] = opened_zipfile.read(file)

def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
super(APProcedurePatch, self).write_contents(opened_zipfile)
for file in self.files:
opened_zipfile.writestr(file, self.files[file],
compress_type=zipfile.ZIP_STORED if file.endswith(".bsdiff4") else None)

def get_file(self, file: str) -> bytes:
""" Retrieves a file from the patch container."""
if file not in self.files:
self.read()
return self.files[file]

def write_file(self, file_name: str, file: bytes) -> None:
""" Writes a file to the patch container, to be retrieved upon patching. """
self.files[file_name] = file

def patch(self, target: str) -> None:
self.read()
base_data = self.get_source_data_with_cache()
patch_extender = AutoPatchExtensionRegister.get_handler(self.game)
assert not isinstance(self.procedure, str), f"{type(self)} must define procedures"
for step, args in self.procedure:
if isinstance(patch_extender, list):
extension = next((item for item in [getattr(extender, step, None) for extender in patch_extender]
if item is not None), None)
else:
extension = getattr(patch_extender, step, None)
if extension is not None:
base_data = extension(self, base_data, *args)
else:
raise NotImplementedError(f"Unknown procedure {step} for {self.game}.")
with open(target, 'wb') as f:
f.write(base_data)


class APDeltaPatch(APProcedurePatch):
"""An APProcedurePatch that additionally has delta.bsdiff4
containing a delta patch to get the desired file, often a rom."""

procedure = [
("apply_bsdiff4", ["delta.bsdiff4"])
]

def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any) -> None:
super(APDeltaPatch, self).__init__(*args, **kwargs)
self.patched_path = patched_path

def write_contents(self, opened_zipfile: zipfile.ZipFile):
self.write_file("delta.bsdiff4",
bsdiff4.diff(self.get_source_data_with_cache(), open(self.patched_path, "rb").read()))
super(APDeltaPatch, self).write_contents(opened_zipfile)
# write Delta
opened_zipfile.writestr("delta.bsdiff4",
bsdiff4.diff(self.get_source_data_with_cache(), open(self.patched_path, "rb").read()),
compress_type=zipfile.ZIP_STORED) # bsdiff4 is a format with integrated compression

def read_contents(self, opened_zipfile: zipfile.ZipFile):
super(APDeltaPatch, self).read_contents(opened_zipfile)
self.delta = opened_zipfile.read("delta.bsdiff4")

def patch(self, target: str):
"""Base + Delta -> Patched"""
if not self.delta:
self.read()
result = bsdiff4.patch(self.get_source_data_with_cache(), self.delta)
with open(target, "wb") as f:
f.write(result)


class APTokenTypes(IntEnum):
WRITE = 0
COPY = 1
RLE = 2
AND_8 = 3
OR_8 = 4
XOR_8 = 5


class APTokenMixin:
"""
A class that defines functions for generating a token binary, for use in patches.
"""
tokens: List[
Tuple[APTokenTypes, int, Union[
bytes, # WRITE
Tuple[int, int], # COPY, RLE
int # AND_8, OR_8, XOR_8
]]] = []

def get_token_binary(self) -> bytes:
"""
Returns the token binary created from stored tokens.
:return: A bytes object representing the token data.
"""
data = bytearray()
data.extend(len(self.tokens).to_bytes(4, "little"))
for token_type, offset, args in self.tokens:
data.append(token_type)
data.extend(offset.to_bytes(4, "little"))
if token_type in [APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8]:
assert isinstance(args, int), f"Arguments to AND/OR/XOR must be of type int, not {type(args)}"
data.extend(int.to_bytes(1, 4, "little"))
data.append(args)
elif token_type in [APTokenTypes.COPY, APTokenTypes.RLE]:
assert isinstance(args, tuple), f"Arguments to COPY/RLE must be of type tuple, not {type(args)}"
data.extend(int.to_bytes(4, 4, "little"))
data.extend(args[0].to_bytes(4, "little"))
data.extend(args[1].to_bytes(4, "little"))
elif token_type == APTokenTypes.WRITE:
assert isinstance(args, bytes), f"Arguments to WRITE must be of type bytes, not {type(args)}"
data.extend(len(args).to_bytes(4, "little"))
data.extend(args)
else:
raise ValueError(f"Unknown token type {token_type}")
return bytes(data)

@overload
def write_token(self,
token_type: Literal[APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8],
offset: int,
data: int) -> None:
...

@overload
def write_token(self,
token_type: Literal[APTokenTypes.COPY, APTokenTypes.RLE],
offset: int,
data: Tuple[int, int]) -> None:
...

@overload
def write_token(self,
token_type: Literal[APTokenTypes.WRITE],
offset: int,
data: bytes) -> None:
...

def write_token(self, token_type: APTokenTypes, offset: int, data: Union[bytes, Tuple[int, int], int]):
"""
Stores a token to be used by patching.
"""
self.tokens.append((token_type, offset, data))


class APPatchExtension(metaclass=AutoPatchExtensionRegister):
"""Class that defines patch extension functions for a given game.
Patch extension functions must have the following two arguments in the following order:
caller: APProcedurePatch (used to retrieve files from the patch container)
rom: bytes (the data to patch)
Further arguments are passed in from the procedure as defined.
Patch extension functions must return the changed bytes.
"""
game: str
required_extensions: List[str] = []

@staticmethod
def apply_bsdiff4(caller: APProcedurePatch, rom: bytes, patch: str):
"""Applies the given bsdiff4 from the patch onto the current file."""
return bsdiff4.patch(rom, caller.get_file(patch))

@staticmethod
def apply_tokens(caller: APProcedurePatch, rom: bytes, token_file: str) -> bytes:
"""Applies the given token file from the patch onto the current file."""
token_data = caller.get_file(token_file)
rom_data = bytearray(rom)
token_count = int.from_bytes(token_data[0:4], "little")
bpr = 4
for _ in range(token_count):
token_type = token_data[bpr:bpr + 1][0]
offset = int.from_bytes(token_data[bpr + 1:bpr + 5], "little")
size = int.from_bytes(token_data[bpr + 5:bpr + 9], "little")
data = token_data[bpr + 9:bpr + 9 + size]
if token_type in [APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8]:
arg = data[0]
if token_type == APTokenTypes.AND_8:
rom_data[offset] = rom_data[offset] & arg
elif token_type == APTokenTypes.OR_8:
rom_data[offset] = rom_data[offset] | arg
else:
rom_data[offset] = rom_data[offset] ^ arg
elif token_type in [APTokenTypes.COPY, APTokenTypes.RLE]:
length = int.from_bytes(data[:4], "little")
value = int.from_bytes(data[4:], "little")
if token_type == APTokenTypes.COPY:
rom_data[offset: offset + length] = rom_data[value: value + length]
else:
rom_data[offset: offset + length] = bytes([value] * length)
else:
rom_data[offset:offset + len(data)] = data
bpr += 9 + size
return bytes(rom_data)

@staticmethod
def calc_snes_crc(caller: APProcedurePatch, rom: bytes):
"""Calculates and applies a valid CRC for the SNES rom header."""
rom_data = bytearray(rom)
if len(rom) < 0x8000:
raise Exception("Tried to calculate SNES CRC on file too small to be a SNES ROM.")
crc = (sum(rom_data[:0x7FDC] + rom_data[0x7FE0:]) + 0x01FE) & 0xFFFF
inv = crc ^ 0xFFFF
rom_data[0x7FDC:0x7FE0] = [inv & 0xFF, (inv >> 8) & 0xFF, crc & 0xFF, (crc >> 8) & 0xFF]
return bytes(rom_data)

0 comments on commit 94650a0

Please sign in to comment.