diff --git a/NetUtils.py b/NetUtils.py index b30316ca6d7b..99c37238c35a 100644 --- a/NetUtils.py +++ b/NetUtils.py @@ -347,6 +347,18 @@ def local(self): class _LocationStore(dict, typing.MutableMapping[int, typing.Dict[int, typing.Tuple[int, int, int]]]): + def __init__(self, values: typing.MutableMapping[int, typing.Dict[int, typing.Tuple[int, int, int]]]): + super().__init__(values) + + if not self: + raise ValueError(f"Rejecting game with 0 players") + + if len(self) != max(self): + raise ValueError("Player IDs not continuous") + + if len(self.get(0, {})): + raise ValueError("Invalid player id 0 for location") + def find_item(self, slots: typing.Set[int], seeked_item_id: int ) -> typing.Generator[typing.Tuple[int, int, int, int, int], None, None]: for finding_player, check_data in self.items(): diff --git a/_speedups.pyx b/_speedups.pyx index 95e837d1bba6..fc2413ceb5d1 100644 --- a/_speedups.pyx +++ b/_speedups.pyx @@ -8,6 +8,7 @@ This is deliberately .pyx because using a non-compiled "pure python" may be slow # pip install cython cymem import cython +import warnings from cpython cimport PyObject from typing import Any, Dict, Iterable, Iterator, Generator, Sequence, Tuple, TypeVar, Union, Set, List, TYPE_CHECKING from cymem.cymem cimport Pool @@ -107,13 +108,16 @@ cdef class LocationStore: count += 1 sender_count += 1 - if not count: - raise ValueError("No locations") + if not sender_count: + raise ValueError(f"Rejecting game with 0 players") if sender_count != max_sender: # we assume player 0 will never have locations raise ValueError("Player IDs not continuous") + if not count: + warnings.warn("Game has no locations") + # allocate the arrays and invalidate index (0xff...) self.entries = self._mem.alloc(count, sizeof(LocationEntry)) self.sender_index = self._mem.alloc(max_sender + 1, sizeof(IndexEntry)) @@ -140,9 +144,9 @@ cdef class LocationStore: self._proxies.append(None) # player 0 assert self.sender_index[0].count == 0 for i in range(1, max_sender + 1): - if self.sender_index[i].count == 0 and self.sender_index[i].start >= count: - self.sender_index[i].start = 0 # do not point outside valid entries - assert self.sender_index[i].start < count + assert self.sender_index[i].count == 0 or ( + self.sender_index[i].start < count and + self.sender_index[i].start + self.sender_index[i].count <= count) key = i # allocate python integer proxy = PlayerLocationProxy(self, i) self._keys.append(key) diff --git a/test/netutils/TestLocationStore.py b/test/netutils/TestLocationStore.py index 5c98437a031e..9fe904f68a16 100644 --- a/test/netutils/TestLocationStore.py +++ b/test/netutils/TestLocationStore.py @@ -1,10 +1,13 @@ # Tests for _speedups.LocationStore and NetUtils._LocationStore import typing import unittest +import warnings from NetUtils import LocationStore, _LocationStore +State = typing.Dict[typing.Tuple[int, int], typing.Set[int]] +RawLocations = typing.Dict[int, typing.Dict[int, typing.Tuple[int, int, int]]] -sample_data = { +sample_data: RawLocations = { 1: { 11: (21, 2, 7), 12: (22, 2, 0), @@ -23,28 +26,29 @@ }, } -empty_state = { +empty_state: State = { (0, slot): set() for slot in sample_data } -full_state = { +full_state: State = { (0, slot): set(locations) for (slot, locations) in sample_data.items() } -one_state = { +one_state: State = { (0, 1): {12} } class Base: class TestLocationStore(unittest.TestCase): + """Test method calls on a loaded store.""" store: typing.Union[LocationStore, _LocationStore] - def test_len(self): + def test_len(self) -> None: self.assertEqual(len(self.store), 4) self.assertEqual(len(self.store[1]), 3) - def test_key_error(self): + def test_key_error(self) -> None: with self.assertRaises(KeyError): _ = self.store[0] with self.assertRaises(KeyError): @@ -54,25 +58,25 @@ def test_key_error(self): _ = locations[7] _ = locations[11] # no Exception - def test_getitem(self): + def test_getitem(self) -> None: self.assertEqual(self.store[1][11], (21, 2, 7)) self.assertEqual(self.store[1][13], (13, 1, 0)) self.assertEqual(self.store[2][22], (12, 1, 0)) self.assertEqual(self.store[4][9], (99, 3, 0)) - def test_get(self): + def test_get(self) -> None: self.assertEqual(self.store.get(1, None), self.store[1]) self.assertEqual(self.store.get(0, None), None) self.assertEqual(self.store[1].get(11, (None, None, None)), self.store[1][11]) self.assertEqual(self.store[1].get(10, (None, None, None)), (None, None, None)) - def test_iter(self): + def test_iter(self) -> None: self.assertEqual(sorted(self.store), [1, 2, 3, 4]) self.assertEqual(len(self.store), len(sample_data)) self.assertEqual(list(self.store[1]), [11, 12, 13]) self.assertEqual(len(self.store[1]), len(sample_data[1])) - def test_items(self): + def test_items(self) -> None: self.assertEqual(sorted(p for p, _ in self.store.items()), sorted(self.store)) self.assertEqual(sorted(p for p, _ in self.store[1].items()), sorted(self.store[1])) self.assertEqual(sorted(self.store.items())[0][0], 1) @@ -80,7 +84,7 @@ def test_items(self): self.assertEqual(sorted(self.store[1].items())[0][0], 11) self.assertEqual(sorted(self.store[1].items())[0][1], self.store[1][11]) - def test_find_item(self): + def test_find_item(self) -> None: self.assertEqual(sorted(self.store.find_item(set(), 99)), []) self.assertEqual(sorted(self.store.find_item({3}, 1)), []) self.assertEqual(sorted(self.store.find_item({5}, 99)), []) @@ -89,129 +93,141 @@ def test_find_item(self): self.assertEqual(sorted(self.store.find_item({3, 4}, 99)), [(3, 9, 99, 4, 0), (4, 9, 99, 3, 0)]) - def test_get_for_player(self): + def test_get_for_player(self) -> None: self.assertEqual(self.store.get_for_player(3), {4: {9}}) self.assertEqual(self.store.get_for_player(1), {1: {13}, 2: {22, 23}}) - def get_checked(self): + def get_checked(self) -> None: self.assertEqual(self.store.get_checked(full_state, 0, 1), [11, 12, 13]) self.assertEqual(self.store.get_checked(one_state, 0, 1), [12]) self.assertEqual(self.store.get_checked(empty_state, 0, 1), []) self.assertEqual(self.store.get_checked(full_state, 0, 3), [9]) - def get_missing(self): + def get_missing(self) -> None: self.assertEqual(self.store.get_missing(full_state, 0, 1), []) self.assertEqual(self.store.get_missing(one_state, 0, 1), [11, 13]) self.assertEqual(self.store.get_missing(empty_state, 0, 1), [11, 12, 13]) self.assertEqual(self.store.get_missing(empty_state, 0, 3), [9]) - def get_remaining(self): + def get_remaining(self) -> None: self.assertEqual(self.store.get_remaining(full_state, 0, 1), []) self.assertEqual(self.store.get_remaining(one_state, 0, 1), [13, 21]) self.assertEqual(self.store.get_remaining(empty_state, 0, 1), [13, 21, 22]) self.assertEqual(self.store.get_remaining(empty_state, 0, 3), [99]) + class TestLocationStoreConstructor(unittest.TestCase): + """Test constructors for a given store type.""" + type: type + + def test_hole(self) -> None: + with self.assertRaises(Exception): + self.type({ + 1: {1: (1, 1, 1)}, + 3: {1: (1, 1, 1)}, + }) + + def test_no_slot1(self) -> None: + with self.assertRaises(Exception): + self.type({ + 2: {1: (1, 1, 1)}, + 3: {1: (1, 1, 1)}, + }) + + def test_slot0(self) -> None: + with self.assertRaises(ValueError): + self.type({ + 0: {1: (1, 1, 1)}, + 1: {1: (1, 1, 1)}, + }) + with self.assertRaises(ValueError): + self.type({ + 0: {1: (1, 1, 1)}, + 2: {1: (1, 1, 1)}, + }) + + def test_no_players(self) -> None: + with self.assertRaises(Exception): + _ = self.type({}) + + def test_no_locations(self) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + store = self.type({ + 1: {}, + }) + self.assertEqual(len(store), 1) + self.assertEqual(len(store[1]), 0) + + def test_no_locations_for_1(self) -> None: + store = self.type({ + 1: {}, + 2: {1: (1, 2, 3)}, + }) + self.assertEqual(len(store), 2) + self.assertEqual(len(store[1]), 0) + self.assertEqual(len(store[2]), 1) + + def test_no_locations_for_last(self) -> None: + store = self.type({ + 1: {1: (1, 2, 3)}, + 2: {}, + }) + self.assertEqual(len(store), 2) + self.assertEqual(len(store[1]), 1) + self.assertEqual(len(store[2]), 0) + class TestPurePythonLocationStore(Base.TestLocationStore): + """Run base method tests for pure python implementation.""" def setUp(self) -> None: self.store = _LocationStore(sample_data) super().setUp() +class TestPurePythonLocationStoreConstructor(Base.TestLocationStoreConstructor): + """Run base constructor tests for the pure python implementation.""" + def setUp(self) -> None: + self.type = _LocationStore + super().setUp() + + @unittest.skipIf(LocationStore is _LocationStore, "_speedups not available") class TestSpeedupsLocationStore(Base.TestLocationStore): + """Run base method tests for cython implementation.""" def setUp(self) -> None: self.store = LocationStore(sample_data) super().setUp() @unittest.skipIf(LocationStore is _LocationStore, "_speedups not available") -class TestSpeedupsLocationStoreConstructor(unittest.TestCase): - def test_float_key(self): +class TestSpeedupsLocationStoreConstructor(Base.TestLocationStoreConstructor): + """Run base constructor tests and tests the additional constraints for cython implementation.""" + def setUp(self) -> None: + self.type = LocationStore + super().setUp() + + def test_float_key(self) -> None: with self.assertRaises(Exception): - LocationStore({ + self.type({ 1: {1: (1, 1, 1)}, 1.1: {1: (1, 1, 1)}, 3: {1: (1, 1, 1)} }) - def test_string_key(self): + def test_string_key(self) -> None: with self.assertRaises(Exception): - LocationStore({ + self.type({ "1": {1: (1, 1, 1)}, }) - def test_hole(self): + def test_high_player_number(self) -> None: with self.assertRaises(Exception): - LocationStore({ - 1: {1: (1, 1, 1)}, - 3: {1: (1, 1, 1)}, - }) - - def test_no_slot1(self): - with self.assertRaises(Exception): - LocationStore({ - 2: {1: (1, 1, 1)}, - 3: {1: (1, 1, 1)}, - }) - - def test_slot0(self): - with self.assertRaises(Exception): - LocationStore({ - 0: {1: (1, 1, 1)}, - 1: {1: (1, 1, 1)}, - }) - with self.assertRaises(Exception): - LocationStore({ - 0: {1: (1, 1, 1)}, - 2: {1: (1, 1, 1)}, - }) - - def test_high_player_number(self): - with self.assertRaises(Exception): - LocationStore({ + self.type({ 1 << 32: {1: (1, 1, 1)}, }) - def test_no_players(self): - try: # either is fine: raise during init, or behave like {} - store = LocationStore({}) - self.assertEqual(len(store), 0) - with self.assertRaises(KeyError): - _ = store[1] - except ValueError: - pass - - def test_no_locations(self): - try: # either is fine: raise during init, or behave like {1: {}} - store = LocationStore({ - 1: {}, - }) - self.assertEqual(len(store), 1) - self.assertEqual(len(store[1]), 0) - except ValueError: - pass - - def test_no_locations_for_1(self): - store = LocationStore({ - 1: {}, - 2: {1: (1, 2, 3)}, - }) - self.assertEqual(len(store), 2) - self.assertEqual(len(store[1]), 0) - self.assertEqual(len(store[2]), 1) - - def test_no_locations_for_last(self): - store = LocationStore({ - 1: {1: (1, 2, 3)}, - 2: {}, - }) - self.assertEqual(len(store), 2) - self.assertEqual(len(store[1]), 1) - self.assertEqual(len(store[2]), 0) - - def test_not_a_tuple(self): + def test_not_a_tuple(self) -> None: with self.assertRaises(Exception): - LocationStore({ + self.type({ 1: {1: None}, })