From 7ca6f24e6caf516b3bc0b1823d70dd716476a845 Mon Sep 17 00:00:00 2001
From: Fabian Dill <fabian.dill@web.de>
Date: Fri, 4 Mar 2022 21:36:18 +0100
Subject: [PATCH] MultiServer: allow multiple, ordered operations MultiServer:
 rename "data" on Get, Retrieved and SetNotify to "keys" MultiServer: add some
 more operators SniClient: some pep8 cleanup

---
 FactorioClient.py |  9 ++---
 MultiServer.py    | 31 ++++++++++------
 SNIClient.py      | 92 +++++++++++++++++++++++++++--------------------
 3 files changed, 78 insertions(+), 54 deletions(-)

diff --git a/FactorioClient.py b/FactorioClient.py
index f1f6844b53e2..82f58cf806b7 100644
--- a/FactorioClient.py
+++ b/FactorioClient.py
@@ -169,8 +169,9 @@ async def game_watcher(ctx: FactorioContext):
                                 # attempt to refill
                                 ctx.last_deplete = time.time()
                                 asyncio.create_task(ctx.send_msgs([{
-                                    "cmd": "Set", "key": "EnergyLink", "operation": "deplete",
-                                    "value": -ctx.energy_link_increment * in_world_bridges,
+                                    "cmd": "Set", "key": "EnergyLink", "operations":
+                                        [{"operation": "add", "value": -ctx.energy_link_increment * in_world_bridges},
+                                         {"operation": "max", "value": 0}],
                                     "last_deplete": ctx.last_deplete
                                 }]))
                             # Above Capacity - (len(Bridges) * ENERGY_INCREMENT)
@@ -178,8 +179,8 @@ async def game_watcher(ctx: FactorioContext):
                                 ctx.energy_link_increment*in_world_bridges:
                                 value = ctx.energy_link_increment * in_world_bridges
                                 asyncio.create_task(ctx.send_msgs([{
-                                    "cmd": "Set", "key": "EnergyLink", "operation": "add",
-                                    "value": value
+                                    "cmd": "Set", "key": "EnergyLink", "operations":
+                                        [{"operation": "add", "value": value}]
                                 }]))
                                 ctx.rcon_client.send_command(
                                     f"/ap-energylink -{value}")
diff --git a/MultiServer.py b/MultiServer.py
index 7a2058214a2f..77e2ffbc42b4 100644
--- a/MultiServer.py
+++ b/MultiServer.py
@@ -41,12 +41,20 @@
 
 # functions callable on storable data on the server by clients
 modify_functions = {
-    "add": operator.add,
+    "add": operator.add,  # add together two objects, using python's "+" operator (works on strings and lists as append)
     "mul": operator.mul,
+    "mod": operator.mod,
     "max": max,
     "min": min,
     "replace": lambda old, new: new,
-    "deplete": lambda value, change: max(0, value + change)
+    "default": lambda old, new: old,
+    "pow": operator.pow,
+    # bitwise:
+    "xor": operator.xor,
+    "or": operator.or_,
+    "and": operator.and_,
+    "left_shift": operator.lshift,
+    "right_shift": operator.rshift,
 }
 
 
@@ -1544,26 +1552,27 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
                     await ctx.send_encoded_msgs(bounceclient, msg)
 
         elif cmd == "Get":
-            if "data" not in args or type(args["data"]) != list:
+            if "keys" not in args or type(args["keys"]) != list:
                 await ctx.send_msgs(client, [{'cmd': 'InvalidPacket', "type": "arguments",
                                               "text": 'Retrieve', "original_cmd": cmd}])
                 return
             args["cmd"] = "Retrieved"
-            keys = args["data"]
-            args["data"] = {key: ctx.stored_data.get(key, None) for key in keys}
+            keys = args["keys"]
+            args["keys"] = {key: ctx.stored_data.get(key, None) for key in keys}
             await ctx.send_msgs(client, [args])
 
         elif cmd == "Set":
-            if "key" not in args or "value" not in args:
+            if "key" not in args or "value" not in args or \
+                    "operations" not in args or not type(args["operations"]) == list:
                 await ctx.send_msgs(client, [{'cmd': 'InvalidPacket', "type": "arguments",
                                               "text": 'Set', "original_cmd": cmd}])
                 return
             args["cmd"] = "SetReply"
             value = ctx.stored_data.get(args["key"], args.get("default", 0))
             args["original_value"] = value
-            operation = args.get("operation", "replace")
-            func = modify_functions[operation]
-            value = func(value, args.get("value"))
+            for operation in args["operations"]:
+                func = modify_functions[operation["operation"]]
+                value = func(value, operation["value"])
             ctx.stored_data[args["key"]] = args["value"] = value
             targets = set(ctx.stored_data_notification_clients[args["key"]])
             if args.get("want_reply", True):
@@ -1572,11 +1581,11 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
                 ctx.broadcast(targets, [args])
 
         elif cmd == "SetNotify":
-            if "data" not in args or type(args["data"]) != list:
+            if "keys" not in args or type(args["keys"]) != list:
                 await ctx.send_msgs(client, [{'cmd': 'InvalidPacket', "type": "arguments",
                                               "text": 'SetNotify', "original_cmd": cmd}])
                 return
-            for key in args["data"]:
+            for key in args["keys"]:
                 ctx.stored_data_notification_clients[key].add(client)
 
 
diff --git a/SNIClient.py b/SNIClient.py
index 2fd9bd030211..73e998d51e34 100644
--- a/SNIClient.py
+++ b/SNIClient.py
@@ -11,9 +11,8 @@
 import logging
 import asyncio
 from json import loads, dumps
-from tkinter import font
 
-from Utils import get_item_name_from_id, init_logging
+from Utils import init_logging
 
 if __name__ == "__main__":
     init_logging("SNIClient", exception_logger="Client")
@@ -22,14 +21,12 @@
 
 from NetUtils import *
 from worlds.alttp import Regions, Shops
-from worlds.alttp import Items
 from worlds.alttp.Rom import ROM_PLAYER_LIMIT
 from worlds.sm.Rom import ROM_PLAYER_LIMIT as SM_ROM_PLAYER_LIMIT
 import Utils
 from CommonClient import CommonContext, server_loop, console_loop, ClientCommandProcessor, gui_enabled, get_base_parser
 from Patch import GAME_ALTTP, GAME_SM
 
-
 snes_logger = logging.getLogger("SNES")
 
 from MultiServer import mark_raw
@@ -41,7 +38,7 @@ class DeathState(enum.IntEnum):
     dead = 3
 
 
-class LttPCommandProcessor(ClientCommandProcessor):
+class SNIClientCommandProcessor(ClientCommandProcessor):
     ctx: Context
 
     def _cmd_slow_mode(self, toggle: str = ""):
@@ -72,7 +69,6 @@ def _cmd_snes(self, snes_options: str = "") -> bool:
             snes_address = options[0]
             snes_device_number = int(options[1])
 
-
         self.ctx.snes_reconnect_address = None
         asyncio.create_task(snes_connect(self.ctx, snes_address, snes_device_number), name="SNES Connect")
         return True
@@ -92,15 +88,23 @@ def _cmd_snes_close(self) -> bool:
     #     if self.ctx.snes_state != SNESState.SNES_ATTACHED:
     #         self.output("No attached SNES Device.")
     #         return False
-    #
     #     snes_buffered_write(self.ctx, int(address, 16), bytes([int(data)]))
     #     asyncio.create_task(snes_flush_writes(self.ctx))
     #     self.output("Data Sent")
     #     return True
 
+    # def _cmd_snes_read(self, address, size=1):
+    #     """Read the SNES' memory address (base16)."""
+    #     if self.ctx.snes_state != SNESState.SNES_ATTACHED:
+    #         self.output("No attached SNES Device.")
+    #         return False
+    #     data = await snes_read(self.ctx, int(address, 16), size)
+    #     self.output(f"Data Read: {data}")
+    #     return True
+
 
 class Context(CommonContext):
-    command_processor = LttPCommandProcessor
+    command_processor = SNIClientCommandProcessor
     game = "A Link to the Past"
     items_handling = None  # set in game_watcher
 
@@ -183,7 +187,8 @@ async def deathlink_kill_player(ctx: Context):
                 continue
             if not invincible[0] and last_health[0] == health[0]:
                 snes_buffered_write(ctx, WRAM_START + 0xF36D, bytes([0]))  # set current health to 0
-                snes_buffered_write(ctx, WRAM_START + 0x0373, bytes([8]))  # deal 1 full heart of damage at next opportunity
+                snes_buffered_write(ctx, WRAM_START + 0x0373,
+                                    bytes([8]))  # deal 1 full heart of damage at next opportunity
         elif ctx.game == GAME_SM:
             snes_buffered_write(ctx, WRAM_START + 0x09C2, bytes([0, 0]))  # set current health to 0
             if not ctx.death_link_allow_survive:
@@ -200,7 +205,8 @@ async def deathlink_kill_player(ctx: Context):
             health = await snes_read(ctx, WRAM_START + 0x09C2, 2)
             if health is not None:
                 health = health[0] | (health[1] << 8)
-            if not gamemode or gamemode[0] in SM_DEATH_MODES or (ctx.death_link_allow_survive and health is not None and health > 0):
+            if not gamemode or gamemode[0] in SM_DEATH_MODES or (
+                    ctx.death_link_allow_survive and health is not None and health > 0):
                 ctx.death_state = DeathState.dead
         ctx.last_death_link = time.time()
 
@@ -914,7 +920,7 @@ async def game_watcher(ctx: Context):
 
             ctx.rom = rom
             death_link = await snes_read(ctx, DEATH_LINK_ACTIVE_ADDR if ctx.game == GAME_ALTTP else
-                                         SM_DEATH_LINK_ACTIVE_ADDR, 1)
+            SM_DEATH_LINK_ACTIVE_ADDR, 1)
             if death_link:
                 ctx.death_link_allow_survive = bool(death_link[0] & 0b10)
                 await ctx.update_death_link(bool(death_link[0] & 0b1))
@@ -976,7 +982,8 @@ async def game_watcher(ctx: Context):
                 item = ctx.items_received[recv_index]
                 recv_index += 1
                 logging.info('Received %s from %s (%s) (%d/%d in list)' % (
-                    color(ctx.item_name_getter(item.item), 'red', 'bold'), color(ctx.player_names[item.player], 'yellow'),
+                    color(ctx.item_name_getter(item.item), 'red', 'bold'),
+                    color(ctx.player_names[item.player], 'yellow'),
                     ctx.location_name_getter(item.location), recv_index, len(ctx.items_received)))
 
                 snes_buffered_write(ctx, RECV_PROGRESS_ADDR,
@@ -998,7 +1005,7 @@ async def game_watcher(ctx: Context):
             if scout_location > 0 and scout_location not in ctx.locations_scouted:
                 ctx.locations_scouted.add(scout_location)
                 await ctx.send_msgs([{"cmd": "LocationScouts", "locations": [scout_location]}])
-            await track_locations(ctx, roomid, roomdata)        
+            await track_locations(ctx, roomid, roomdata)
         elif ctx.game == GAME_SM:
             gamemode = await snes_read(ctx, WRAM_START + 0x0998, 1)
             if "DeathLink" in ctx.tags and gamemode and ctx.last_death_link + 1 < time.time():
@@ -1025,14 +1032,16 @@ async def game_watcher(ctx: Context):
                 itemIndex = (message[4] | (message[5] << 8)) >> 3
 
                 recv_index += 1
-                snes_buffered_write(ctx, SM_RECV_PROGRESS_ADDR + 0x680, bytes([recv_index & 0xFF, (recv_index >> 8) & 0xFF]))
+                snes_buffered_write(ctx, SM_RECV_PROGRESS_ADDR + 0x680,
+                                    bytes([recv_index & 0xFF, (recv_index >> 8) & 0xFF]))
 
                 from worlds.sm.Locations import locations_start_id
                 location_id = locations_start_id + itemIndex
 
                 ctx.locations_checked.add(location_id)
                 location = ctx.location_name_getter(location_id)
-                snes_logger.info(f'New Check: {location} ({len(ctx.locations_checked)}/{len(ctx.missing_locations) + len(ctx.checked_locations)})')
+                snes_logger.info(
+                    f'New Check: {location} ({len(ctx.locations_checked)}/{len(ctx.missing_locations) + len(ctx.checked_locations)})')
                 await ctx.send_msgs([{"cmd": 'LocationChecks', "locations": [location_id]}])
 
             data = await snes_read(ctx, SM_RECV_PROGRESS_ADDR + 0x600, 4)
@@ -1048,11 +1057,14 @@ async def game_watcher(ctx: Context):
                 itemId = item.item - items_start_id
 
                 playerID = item.player if item.player <= SM_ROM_PLAYER_LIMIT else 0
-                snes_buffered_write(ctx, SM_RECV_PROGRESS_ADDR + itemOutPtr * 4, bytes([playerID & 0xFF, (playerID >> 8) & 0xFF, itemId & 0xFF, (itemId >> 8) & 0xFF]))
+                snes_buffered_write(ctx, SM_RECV_PROGRESS_ADDR + itemOutPtr * 4, bytes(
+                    [playerID & 0xFF, (playerID >> 8) & 0xFF, itemId & 0xFF, (itemId >> 8) & 0xFF]))
                 itemOutPtr += 1
-                snes_buffered_write(ctx, SM_RECV_PROGRESS_ADDR + 0x602, bytes([itemOutPtr & 0xFF, (itemOutPtr >> 8) & 0xFF]))
+                snes_buffered_write(ctx, SM_RECV_PROGRESS_ADDR + 0x602,
+                                    bytes([itemOutPtr & 0xFF, (itemOutPtr >> 8) & 0xFF]))
                 logging.info('Received %s from %s (%s) (%d/%d in list)' % (
-                    color(ctx.item_name_getter(item.item), 'red', 'bold'), color(ctx.player_names[item.player], 'yellow'),
+                    color(ctx.item_name_getter(item.item), 'red', 'bold'),
+                    color(ctx.player_names[item.player], 'yellow'),
                     ctx.location_name_getter(item.location), itemOutPtr, len(ctx.items_received)))
             await snes_flush_writes(ctx)
 
@@ -1130,6 +1142,7 @@ async def main():
     if input_task:
         input_task.cancel()
 
+
 def get_alttp_settings(romfile: str):
     lastSettings = Utils.get_adjuster_settings(GAME_ALTTP)
     adjusted = False
@@ -1139,8 +1152,8 @@ def get_alttp_settings(romfile: str):
         if not hasattr(lastSettings, 'auto_apply') or 'ask' in lastSettings.auto_apply:
 
             whitelist = {"music", "menuspeed", "heartbeep", "heartcolor", "ow_palettes", "quickswap",
-                        "uw_palettes", "sprite", "sword_palettes", "shield_palettes", "hud_palettes",
-                        "reduceflashing", "deathlink"}
+                         "uw_palettes", "sprite", "sword_palettes", "shield_palettes", "hud_palettes",
+                         "reduceflashing", "deathlink"}
             printed_options = {name: value for name, value in vars(lastSettings).items() if name in whitelist}
             if hasattr(lastSettings, "sprite_pool"):
                 sprite_pool = {}
@@ -1154,40 +1167,41 @@ def get_alttp_settings(romfile: str):
             import pprint
 
             if gui_enabled:
-            
+
                 from tkinter import Tk, PhotoImage, Label, LabelFrame, Frame, Button
                 applyPromptWindow = Tk()
                 applyPromptWindow.resizable(False, False)
-                applyPromptWindow.protocol('WM_DELETE_WINDOW',lambda: onButtonClick())
+                applyPromptWindow.protocol('WM_DELETE_WINDOW', lambda: onButtonClick())
                 logo = PhotoImage(file=Utils.local_path('data', 'icon.png'))
                 applyPromptWindow.tk.call('wm', 'iconphoto', applyPromptWindow._w, logo)
                 applyPromptWindow.wm_title("Last adjuster settings LttP")
 
                 label = LabelFrame(applyPromptWindow,
-                                text='Last used adjuster settings were found. Would you like to apply these?')
-                label.grid(column=0,row=0, padx=5, pady=5, ipadx=5, ipady=5)
-                label.grid_columnconfigure (0, weight=1) 
-                label.grid_columnconfigure (1, weight=1) 
-                label.grid_columnconfigure (2, weight=1) 
-                label.grid_columnconfigure (3, weight=1) 
-                def onButtonClick(answer: str='no'):
+                                   text='Last used adjuster settings were found. Would you like to apply these?')
+                label.grid(column=0, row=0, padx=5, pady=5, ipadx=5, ipady=5)
+                label.grid_columnconfigure(0, weight=1)
+                label.grid_columnconfigure(1, weight=1)
+                label.grid_columnconfigure(2, weight=1)
+                label.grid_columnconfigure(3, weight=1)
+
+                def onButtonClick(answer: str = 'no'):
                     setattr(onButtonClick, 'choice', answer)
                     applyPromptWindow.destroy()
 
                 framedOptions = Frame(label)
-                framedOptions.grid(column=0, columnspan=4,row=0)
+                framedOptions.grid(column=0, columnspan=4, row=0)
                 framedOptions.grid_columnconfigure(0, weight=1)
                 framedOptions.grid_columnconfigure(1, weight=1)
                 framedOptions.grid_columnconfigure(2, weight=1)
                 curRow = 0
                 curCol = 0
                 for name, value in printed_options.items():
-                    Label(framedOptions, text=name+": "+str(value)).grid(column=curCol, row=curRow, padx=5)
-                    if(curCol==2):
-                        curRow+=1
-                        curCol=0
+                    Label(framedOptions, text=name + ": " + str(value)).grid(column=curCol, row=curRow, padx=5)
+                    if (curCol == 2):
+                        curRow += 1
+                        curCol = 0
                     else:
-                        curCol+=1
+                        curCol += 1
 
                 yesButton = Button(label, text='Yes', command=lambda: onButtonClick('yes'), width=10)
                 yesButton.grid(column=0, row=1)
@@ -1203,8 +1217,8 @@ def onButtonClick(answer: str='no'):
                 choice = getattr(onButtonClick, 'choice')
             else:
                 choice = input(f"Last used adjuster settings were found. Would you like to apply these? \n"
-                                    f"{pprint.pformat(printed_options)}\n"
-                                    f"Enter yes, no, always or never: ")
+                               f"{pprint.pformat(printed_options)}\n"
+                               f"Enter yes, no, always or never: ")
             if choice and choice.startswith("y"):
                 choice = 'yes'
             elif choice and "never" in choice:
@@ -1221,7 +1235,7 @@ def onButtonClick(answer: str='no'):
             choice = 'no'
         elif 'always' in lastSettings.auto_apply:
             choice = 'yes'
-                    
+
         if 'yes' in choice:
             from worlds.alttp.Rom import get_base_rom_path
             lastSettings.rom = romfile
@@ -1247,10 +1261,10 @@ def onButtonClick(answer: str='no'):
             except Exception as e:
                 logging.exception(e)
     else:
-        
         adjusted = False
     return adjustedromfile, adjusted
 
+
 if __name__ == '__main__':
     colorama.init()
     loop = asyncio.get_event_loop()