From 6bc69bce30fef95bddf8b1394d741b19915f63a3 Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Thu, 5 Oct 2023 08:54:09 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- __init__.py | 1 - asm.py | 21 ++++++++++----------- mem.py | 43 ++++++++++++++++++++----------------------- mmio.py | 9 ++++----- pci.py | 17 +++++++++++++---- segments.py | 28 +++++++++++----------------- utils.py | 6 +++--- xhci.py | 48 ++++++++++++++++++++---------------------------- 8 files changed, 81 insertions(+), 92 deletions(-) diff --git a/__init__.py b/__init__.py index d17efa7..f2f3246 100644 --- a/__init__.py +++ b/__init__.py @@ -8,7 +8,6 @@ try: t.halt() - pass except: pass diff --git a/asm.py b/asm.py index 746d202..94ec23a 100644 --- a/asm.py +++ b/asm.py @@ -4,16 +4,13 @@ from proc import * def get_registers(thread): - registers = [] register_list = ["eax", "ebx", "ecx", "edx", "esi", "edi", "ebp", "esp", "eip"] - for reg in register_list: - registers.append((reg, thread.arch_register(reg))) - return registers + return [(reg, thread.arch_register(reg)) for reg in register_list] def print_registers(thread=None): if thread is None: thread = t - + was_running = False # Halt the thread if needed if thread.isrunning(): @@ -22,7 +19,7 @@ def print_registers(thread=None): registers = get_registers(thread) print ("Registers : ") for (reg, val) in registers: - print("%s: %s" % (reg, val.ToHex())) + print(f"{reg}: {val.ToHex()}") if was_running: thread.go() @@ -83,7 +80,7 @@ def v3_resume(): def pop(): ss = reg("ss") - ret = t.mem(ss.ToHex() + ":" + reg("esp").ToHex(), 4) + ret = t.mem(f"{ss.ToHex()}:" + reg("esp").ToHex(), 4) reg("esp", reg("esp") + 4) return ret @@ -166,16 +163,18 @@ def printStackContent(): table = ss[2] idx = ss[3:15] base = t.arch_register("ldtbas" if table else "gdtbas") - segment = GDTEntry(t.memblock(str(base.ToUInt32() + 8 * idx.ToUInt32()) + "L", 8, 1)) + segment = GDTEntry( + t.memblock(f"{str(base.ToUInt32() + 8 * idx.ToUInt32())}L", 8, 1) + ) limit = segment.limit - print("ESP : %s" % esp.ToHex()) + print(f"ESP : {esp.ToHex()}") esp = esp & ~0xF - t.memdump(ss.ToHex() + ":" + esp.ToHex(), limit - esp, 1) + t.memdump(f"{ss.ToHex()}:{esp.ToHex()}", limit - esp, 1) def peek(register, offset=0, size=4, value=None): ds = reg("ds") reg_value = reg(register) - return t.mem(ds.ToHex() + ":" + hex(reg_value + offset), size, value) + return t.mem(f"{ds.ToHex()}:{hex(reg_value + offset)}", size, value) def poke(register, offset=0, value=None, size=4): return peek(register, offset, size, value) diff --git a/mem.py b/mem.py index 2084a0d..9f60d29 100644 --- a/mem.py +++ b/mem.py @@ -128,14 +128,14 @@ def print_pages(): def linear_to_pages(addr): addr_bits = ipc.BitData(32, addr) directory = addr_bits[22:31].ToUInt32() - offset = addr_bits[0:21].ToUInt32() + offset = addr_bits[:21].ToUInt32() pd = reg("cr3") pde = t.memblock(phys(pd + directory*4), 4, 1) pde = PDE(directory, pde) if pde.present: if pde.size == 0: table = addr_bits[12:21].ToUInt32() - offset = addr_bits[0:11].ToUInt32() + offset = addr_bits[:11].ToUInt32() pt = pde.base_addr << 12 pte = t.memblock(phys(pt + table*4), 4, 1) pte = PTE(pde, table, pte) @@ -172,19 +172,15 @@ def virt_to_phys(addr, selector="ds"): def linear_to_phys(addr): (pde, pte, offset) = linear_to_pages(addr) if pte: - if pte.present: - return pte.base_addr.ToUInt32() << 12 | offset - return None - if pde.present: - return pde.base_addr.ToUInt32() << 12 | offset - return None + return pte.base_addr.ToUInt32() << 12 | offset if pte.present else None + return pde.base_addr.ToUInt32() << 12 | offset if pde.present else None def dump_pages(filename): save_to_file(filename, print_pages) def memdump_ds(addr, size=0x10): ds = reg("ds") - return t.memdump(ds.ToHex() + ":" + hex(addr), size, 1) + return t.memdump(f"{ds.ToHex()}:{hex(addr)}", size, 1) def memset(addr, value, size): t.memblock(phys(addr), int(size), 4, value) @@ -199,25 +195,26 @@ def phys(addr): def malloc(size): malloc_func = proc_get_address(t, "SYSLIB:MALLOC") - execute_asm(t, - "push %s" % hex(size), - # Need to call using register because asm uses near call and if I - # do a far call with 'cs:addr', it pushes cs to the stack so it always - # allocates 0x1bc bytes - "mov eax, %s" % hex(malloc_func).replace("L", ""), - "call eax") + execute_asm( + t, + f"push {hex(size)}", + f'mov eax, {hex(malloc_func).replace("L", "")}', + "call eax", + ) wait_until_infinite_loop(t, False) return reg("eax") def malign(alignment, size): malign_func = proc_get_address(t, "SYSLIB:MALIGN") - execute_asm(t, - "push 0", - "push %s" % hex(size), - "push %s" % hex(alignment), - "push 0", - "mov eax, %s" % hex(malign_func).replace("L", ""), - "call eax") + execute_asm( + t, + "push 0", + f"push {hex(size)}", + f"push {hex(alignment)}", + "push 0", + f'mov eax, {hex(malign_func).replace("L", "")}', + "call eax", + ) wait_until_infinite_loop(t, False) return reg("eax") diff --git a/mmio.py b/mmio.py index 1d9ae3a..8083423 100644 --- a/mmio.py +++ b/mmio.py @@ -98,7 +98,7 @@ def save_mmios(pwd, mmios, prefix="MMIO_"): # Sort by size mmios.sort(lambda a, b: cmp(a[1], b[1]) if a[1] != b[1] else cmp(a[0], b[0])) for (addr, size) in mmios: - print("Addr: %s, size: %s" % (hex(addr), hex(size))) + print(f"Addr: {hex(addr)}, size: {hex(size)}") path = os.path.join(pwd, prefix + hex(addr)[2:].replace("L", "") + ".bin") if os.path.exists(path): statinfo = os.stat(path) @@ -110,8 +110,7 @@ def save_mmios(pwd, mmios, prefix="MMIO_"): with open(path, "ab") as f: while size > 0: chunk = 4 * 1024 - if chunk > size: - chunk = size + chunk = min(chunk, size) f.write(memtostr(phys(addr), chunk)) addr += chunk size -= chunk @@ -124,7 +123,7 @@ def save_mmios(pwd, mmios, prefix="MMIO_"): def bruteforce_sideband(pwd, group=0, start=0, end=0x100, size=0x8000, rs=1, fid=0): for i in xrange(start, end): channel = (group << 8) + i - print("Dumping Sideband : %s" % hex(channel)) + print(f"Dumping Sideband : {hex(channel)}") dump_sideband_channel(pwd, channel, size=size, rs=rs, fid=fid) def bruteforce_sideband_port(pwd, port, start=0, end=0x100, size=0x1000): @@ -153,7 +152,7 @@ def dump_sideband_channel(pwd, channel, size=0x8000, rs=1, fid=0): sb_channel_port_addr = proc_get_address(t, "SB_CHANNEL") sb_mmio, _ = setup_sideband_channel(channel, rs, fid) t.memdump(phys(sb_mmio), 0x10, 1) - save_mmios(pwd, [(sb_mmio, size)], "SB_" + hex(channel) + "_") + save_mmios(pwd, [(sb_mmio, size)], f"SB_{hex(channel)}_") try: a = t.mem(phys(sb_channel_port_addr + 0x18), 4) diff --git a/pci.py b/pci.py index 36fbd7c..135a626 100644 --- a/pci.py +++ b/pci.py @@ -30,15 +30,24 @@ def list_pci_devices(base_addr=0xE0000000, alt="", bars=True): for func in range (8): device = PCIDevice(bus, dev, func, t, base_addr) vid = device.getVID() - if vid != 0xFFFFFFFF and vid != 0x0: + if vid not in [0xFFFFFFFF, 0x0]: print("PCI %d.%d.%d : %s" % (bus, dev, func, vid.ToHex())) - mmio.save_mmios(pwd, [(device.getIOAddress(), 0x1000)], "PCI_" + alt + "%d.%d.%d_" % (bus, dev, func) ) + mmio.save_mmios( + pwd, + [(device.getIOAddress(), 0x1000)], + f"PCI_{alt}" + "%d.%d.%d_" % (bus, dev, func), + ) if bars: for offset in range(0x10, 0x28, 4): bar = device.readWord(offset) if bar != 0: - bar[0:7] = 0 - mmio.save_mmios(pwd, [(bar, 0x1000)], "BAR_" + alt + "%d.%d.%d_" % (bus, dev, func)) + bar[:7] = 0 + mmio.save_mmios( + pwd, + [(bar, 0x1000)], + f"BAR_{alt}" + + "%d.%d.%d_" % (bus, dev, func), + ) elif dev == 0 and func == 0: break else: diff --git a/segments.py b/segments.py index bb0a324..dcd69a4 100644 --- a/segments.py +++ b/segments.py @@ -6,7 +6,7 @@ class Selector: def __init__(self, bits): self.bits = bits - self.rpl = bits[0:1] + self.rpl = bits[:1] self.table = bits[2] self.idx = bits[3:15] @@ -20,7 +20,7 @@ def __init__(self, bits): self.base_addr = bits[16:31] self.base_addr.Append(bits[32:39]) self.base_addr.Append(bits[56:63]) - self.limit = bits[0:15] + self.limit = bits[:15] self.limit.Append(bits[48:51]) self.access = bits[40:47] self.flags = bits[52:55] @@ -75,12 +75,12 @@ class IDTEntry: def __init__(self, bits): # https://wiki.osdev.org/Interrupt_Descriptor_Table self.bits = bits - self.offset = bits[0:15] + self.offset = bits[:15] self.offset.Append(bits[48:63]) self.selector = bits[16:31] self.zero = bits[32:39] self.type_attr = bits[40:47] - self.gate_type = self.type_attr[0:3] + self.gate_type = self.type_attr[:3] self.s = self.type_attr[4] self.privl = self.type_attr[5:6] self.pr = self.type_attr[7] @@ -115,14 +115,11 @@ def print_segment(name, base, limit): entries = (limit.ToUInt32() + 1) // 8 print("%s (%s, %s) has %d entries" % (name, base, limit, entries)) for i in xrange(entries): - segment = t.memblock(str(base.ToUInt32() + 8 * i) + "L", 8, 1) - if name == "IDT": - entry = IDTEntry(segment) - else: - entry = GDTEntry(segment) + segment = t.memblock(f"{str(base.ToUInt32() + 8 * i)}L", 8, 1) + entry = IDTEntry(segment) if name == "IDT" else GDTEntry(segment) if entry.pr: print("**** %s Entry %d ****" % (name, i)) - print("%s" % str(entry)) + print(f"{str(entry)}") def print_segments(): gdtbas = t.arch_register("gdtbas") @@ -153,12 +150,9 @@ def segment_addr_to_linear(selector, addr): table = selector[2] idx = selector[3:15].ToUInt32() base = t.arch_register("ldtbas" if table else "gdtbas") - segment = t.memblock(str(base.ToUInt32() + 8 * idx) + "L", 8, 1) + segment = t.memblock(f"{str(base.ToUInt32() + 8 * idx)}L", 8, 1) entry = GDTEntry(segment) - if addr < entry.limit: - return entry.base_addr + addr - else: - return None + return entry.base_addr + addr if addr < entry.limit else None def table_to_mmio(base, limit): entries = (limit.ToUInt32() + 1) // 8 @@ -202,8 +196,8 @@ def dump_ldts(): limit = t.arch_register("ldtlim") entries = (limit.ToUInt32() + 1) // 8 for i in xrange(entries): - segment = t.memblock(str(base.ToUInt32() + 8 * i) + "L", 8, 1) + segment = t.memblock(f"{str(base.ToUInt32() + 8 * i)}L", 8, 1) entry = GDTEntry(segment) if entry.pr: - t.memsave("LDT-" + i + ".bin", str(entry.base.ToUInt32()) + "L") + t.memsave(f"LDT-{i}.bin", f"{str(entry.base.ToUInt32())}L") diff --git a/utils.py b/utils.py index 3d35b87..fdd6ff3 100644 --- a/utils.py +++ b/utils.py @@ -32,7 +32,7 @@ def debug(str): def genTaps(max, depth=0, max_depth=1, parent="SPT_TAP"): res = "" for i in xrange(0, max, 2): - name = "%s_%s" % (parent, i) + name = f"{parent}_{i}" res += (' ' * depth + '\n' % (name, i, max, i, max)) if depth + 1 < max_depth: res += genTaps(max, depth + 1, max_depth, name) @@ -48,8 +48,8 @@ def displayValidIdcodes(prefix=""): idcode = d.idcode() proc_id = d.irdrscan(0x2, 32) if proc_id != 0: - idcode += " (" + proc_id.ToHex() + ")" - print("%s : %s" % (d.name, idcode)) + idcode += f" ({proc_id.ToHex()})" + print(f"{d.name} : {idcode}") ipc = connect() print(ipc.devicelist) diff --git a/xhci.py b/xhci.py index 2a037fb..ac85991 100644 --- a/xhci.py +++ b/xhci.py @@ -49,16 +49,7 @@ def __init__(self, data=0): Data.__init__(self, 128, data) def __repr__(self): - return "TRB : {}\n" \ - " PTR Low : {}\n" \ - " PTR High: {}\n" \ - " Status: {}\n" \ - " Control: {}\n" \ - .format(self.data, - self.get(self.PTR_LOW), - self.get(self.PTR_HIGH), - self.get(self.STATUS), - self.get(self.CONTROL)) + return f"TRB : {self.data}\n PTR Low : {self.get(self.PTR_LOW)}\n PTR High: {self.get(self.PTR_HIGH)}\n Status: {self.get(self.STATUS)}\n Control: {self.get(self.CONTROL)}\n" class TRBPtrBits: @@ -200,7 +191,7 @@ def next_command_trb(self, cmd=0): def post_command(self): trb = self.TRB() - xhci_debug("Posting command %s" % TRBType.name(trb.get(TRBControlBits.TT))) + xhci_debug(f"Posting command {TRBType.name(trb.get(TRBControlBits.TT))}") # Set Cycle bit trb.set(TRBControlBits.C, self.pcs) trb.write(self.current) @@ -236,8 +227,7 @@ def enable_slot(self): cc = self.wait_for_command(cmd, False) if cc == TRBCompletionCode.SUCCESS: trb = xhci.er.TRB() - slot_id = trb.get(TRBControlBits.ID) - return slot_id + return trb.get(TRBControlBits.ID) return None def address_device(self, slot_id, ic): @@ -420,8 +410,10 @@ def __init__(self, slot_id, addr=None): self.slot = SlotContext(self.ctx) self.ep0 = EPContext(self.ctx + 0x20) self.eps = [] - for i in range(0x40, 0x40 + 0x20 * (self.NUM_EPS - 2), 0x20): - self.eps.append(EPContext(self.ctx + i)) + self.eps.extend( + EPContext(self.ctx + i) + for i in range(0x40, 0x40 + 0x20 * (self.NUM_EPS - 2), 0x20) + ) def __getitem__(self, idx): if int(idx) > 30: @@ -452,7 +444,7 @@ def __init__(self, thread): def dump_pci_config(self): sb_mmio, _ = setup_sideband_channel(0x050400 | self.port, 0, self.fid << 3) t.memdump(phys(sb_mmio), 0x100, 1) - save_mmios(pwd, [(sb_mmio, 0x1000)], "PCI_" + str(self.fid) + ".0_") + save_mmios(pwd, [(sb_mmio, 0x1000)], f"PCI_{str(self.fid)}.0_") def check_pci_from_ME(self): sb_mmio, _ = setup_sideband_channel(0x050400 | self.port, 0, self.fid << 3) @@ -599,15 +591,15 @@ def init(self): self.page_size = self.bar_read16(0x88).ToUInt32() << 12 self.max_slots = self.bar_read32(0x4).ToUInt32() & 0xff self.max_ports = (self.bar_read32(0x4).ToUInt32() & 0xff000000) >> 24 - xhci_debug("caplen: %s" % hex(self.bar_read32(0))) - xhci_debug("rtsoff: %s" % hex(self.bar_read32(0x18))) - xhci_debug("dboff: %s" % hex(self.bar_read32(0x14))) + xhci_debug(f"caplen: {hex(self.bar_read32(0))}") + xhci_debug(f"rtsoff: {hex(self.bar_read32(24))}") + xhci_debug(f"dboff: {hex(self.bar_read32(20))}") xhci_debug("hciversion: %d.%d" % (self.bar_read8(0x3), self.bar_read8(0x2))) xhci_debug("Max Slots: %d" % self.max_slots) xhci_debug("Max Ports: %d" % self.max_ports) xhci_debug("Page Size: %d" % self.page_size) - + # Allocate resources self.dcbaa = dma_align(64, (self.max_slots + 1 ) * 8, memset_value=0) max_sp_hi = (self.bar_read32(0x8) & 0x03E00000) >> 21 @@ -622,11 +614,11 @@ def init(self): self.set(self.dcbaa, self.sp_ptrs) self.dma_buffer = dma_align(64 * 1024, 64 * 1024) self.cr = XHCICommandRing(4) - xhci_debug("command ring %s" % hex(self.cr.ring)) + xhci_debug(f"command ring {hex(self.cr.ring)}") self.er = XHCIEventRing(64) - xhci_debug("event ring %s" % hex(self.er.ring)) + xhci_debug(f"event ring {hex(self.er.ring)}") self.ev_ring_table = dma_align(64, 0x10, memset_value=0) - xhci_debug("event ring table %s" % hex(self.ev_ring_table)) + xhci_debug(f"event ring table {hex(self.ev_ring_table)}") # Setup hardware self.wait_ready() @@ -636,10 +628,10 @@ def init(self): self.bar_write32(0x98, self.cr.ring | 0x1) self.bar_write32(0x9c, 0) - + self.set(self.ev_ring_table, self.er.ring) self.set(self.ev_ring_table + 8, 64) - + self.bar_write32(0x2028, 1) # Size of evet ring table self.bar_write32(0x2030, self.ev_ring_table) self.bar_write32(0x2034, 0) @@ -652,8 +644,8 @@ def init(self): cc = self.cr.noop() running = self.bar_read32(0x98) & 8 - xhci_debug("NOOP result : %s" % TRBCompletionCode.name(cc)) - xhci_debug("Command ring is %s" % ("running" if running else "not running")) + xhci_debug(f"NOOP result : {TRBCompletionCode.name(cc)}") + xhci_debug(f'Command ring is {"running" if running else "not running"}') self.devs = [None]* self.max_ports self.transfer_rings = [None]* self.max_ports @@ -679,7 +671,7 @@ def hub_reset(self, port): portsc[4] = 1 self.bar_write32(0x480 + 0x10 *port, portsc) else: - xhci_debug("Unknown port state %s" % pls) + xhci_debug(f"Unknown port state {pls}") timeout = 100 while True: portsc = self.bar_read32(0x480 + 0x10 * port)