diff --git a/CHANGES.rst b/CHANGES.rst index dbae88c95..f84266433 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -33,6 +33,11 @@ Release History vector generators. (`#284 `__, `#285 `__) +- ``supports_sidedness`` decorator that can be used in algebras to define what + ``ElementSidedness`` values are allowed for an operation. This information can + then be retrieved from the respective methods. + (`#281 `__, + `#291 `__) **Removed** diff --git a/docs/modules/nengo_spa.algebras.rst b/docs/modules/nengo_spa.algebras.rst index a30029e35..931893291 100644 --- a/docs/modules/nengo_spa.algebras.rst +++ b/docs/modules/nengo_spa.algebras.rst @@ -10,6 +10,7 @@ The following items are re-exported by :mod:`nengo_spa.algebras`: base.AbstractAlgebra base.CommonProperties base.ElementSidedness + base.supports_sidedness hrr_algebra.HrrAlgebra hrr_algebra.HrrProperties vtb_algebra.VtbAlgebra diff --git a/nengo_spa/algebras/__init__.py b/nengo_spa/algebras/__init__.py index adbb914f0..8acd6a107 100644 --- a/nengo_spa/algebras/__init__.py +++ b/nengo_spa/algebras/__init__.py @@ -1,6 +1,11 @@ """Algebras define the specific superposition and (un)binding operations.""" -from .base import AbstractAlgebra, CommonProperties, ElementSidedness +from .base import ( + AbstractAlgebra, + CommonProperties, + ElementSidedness, + supports_sidedness, +) from .hrr_algebra import HrrAlgebra, HrrProperties from .tvtb_algebra import TvtbAlgebra, TvtbProperties from .vtb_algebra import VtbAlgebra, VtbProperties diff --git a/nengo_spa/algebras/base.py b/nengo_spa/algebras/base.py index 0d6b8f2e0..58bdafa3e 100644 --- a/nengo_spa/algebras/base.py +++ b/nengo_spa/algebras/base.py @@ -11,6 +11,52 @@ class ElementSidedness(Enum): TWO_SIDED = "two-sided" +def supports_sidedness(sidedness): + """Declare supported sidedness values on an operation. + + This decorator can be used with methods in an algebra that take a + *sidedness* parameter. It declares which values of *sidedness* are + supported by the algebra. The valid values are added as a *frozenset* as + a *supported_sidedness* attribute on the method. + + When checking for supported sidedness, it must first be checked whether + the *supported_sidedness* attribute exists (for backwards compatibility). + If it does not exist, it should be assumed that all values for *sidedness* + are supported. + + Parameters + ---------- + sidedness: Iterable[ElementSidedness] + The sidedness values that are supported by the annotated method. + + Returns + ------- + function + The method itself with the *supported_sidedness* attribute added. + + Examples + ------- + + >>> from nengo_spa.algebras import AbstractAlgebra, supports_sidedness + >>> class MyAlgebra(AbstractAlgebra): + ... @supports_sidedness({ElementSidedness.LEFT}) + ... def invert(self, v, sidedness): + ... # ... + ... pass + ... + ... # ... + ... + >>> print(MyAlgebra.invert.supported_sidedness) + frozenset({}) + """ + + def decorator(fn): + setattr(fn, "supported_sidedness", frozenset(sidedness)) + return fn + + return decorator + + class _DuckTypedABCMeta(ABCMeta): def __instancecheck__(cls, instance): if super().__instancecheck__(instance): diff --git a/nengo_spa/algebras/tests/test_algebras.py b/nengo_spa/algebras/tests/test_algebras.py index eff762920..1b95731ea 100644 --- a/nengo_spa/algebras/tests/test_algebras.py +++ b/nengo_spa/algebras/tests/test_algebras.py @@ -5,6 +5,7 @@ from nengo_spa.algebras.base import AbstractAlgebra, CommonProperties, ElementSidedness from nengo_spa.algebras.vtb_algebra import VtbAlgebra +from nengo_spa.conftest import check_sidedness from nengo_spa.vector_generation import UnitLengthVectors @@ -35,39 +36,39 @@ def test_superpose(algebra, rng): @pytest.mark.parametrize("d", [25, 36]) @pytest.mark.parametrize("sidedness", ElementSidedness) def test_binding_and_invert(algebra, d, sidedness, rng): + check_sidedness(algebra, "invert", sidedness) + dissimilarity_passed = 0 unbinding_passed = 0 - try: - for i in range(10): - gen = UnitLengthVectors(d, rng=rng) - a = next(gen) - b = next(gen) - - binding_side = sidedness - if sidedness is ElementSidedness.TWO_SIDED: - binding_side = ( - ElementSidedness.LEFT if i % 1 == 0 else ElementSidedness.RIGHT - ) - - with warnings.catch_warnings(): - warnings.simplefilter("error", DeprecationWarning) - if binding_side is ElementSidedness.LEFT: - bound = algebra.bind(b, a) - r = algebra.bind(algebra.invert(b, sidedness=sidedness), bound) - elif binding_side is ElementSidedness.RIGHT: - bound = algebra.bind(a, b) - r = algebra.bind(bound, algebra.invert(b, sidedness=sidedness)) - else: - raise AssertionError("Invalid binding_side value.") - - for v in (a, b): - dissimilarity_passed += np.dot(v, bound / np.linalg.norm(b)) < 0.7 - unbinding_passed += np.dot(a, r / np.linalg.norm(r)) > 0.6 - - assert dissimilarity_passed >= 2 * 8 - assert unbinding_passed >= 7 - except (NotImplementedError, DeprecationWarning): - pass + + for i in range(10): + gen = UnitLengthVectors(d, rng=rng) + a = next(gen) + b = next(gen) + + binding_side = sidedness + if sidedness is ElementSidedness.TWO_SIDED: + binding_side = ( + ElementSidedness.LEFT if i % 1 == 0 else ElementSidedness.RIGHT + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + if binding_side is ElementSidedness.LEFT: + bound = algebra.bind(b, a) + r = algebra.bind(algebra.invert(b, sidedness=sidedness), bound) + elif binding_side is ElementSidedness.RIGHT: + bound = algebra.bind(a, b) + r = algebra.bind(bound, algebra.invert(b, sidedness=sidedness)) + else: + raise AssertionError("Invalid binding_side value.") + + for v in (a, b): + dissimilarity_passed += np.dot(v, bound / np.linalg.norm(b)) < 0.7 + unbinding_passed += np.dot(a, r / np.linalg.norm(r)) > 0.6 + + assert dissimilarity_passed >= 2 * 8 + assert unbinding_passed >= 7 @pytest.mark.parametrize("d", [25, 36]) @@ -86,19 +87,18 @@ def test_integer_binding_power(algebra, d, rng): @pytest.mark.parametrize("d", [25, 36]) def test_integer_binding_is_consistent_with_base_implementation(algebra, d, rng): + check_sidedness(algebra, "identity_element", ElementSidedness.LEFT) + v = algebra.create_vector(d, set(), rng=rng) - try: - for exponent in range(-2, 4): - assert np.allclose( - algebra.binding_power(v, exponent), - AbstractAlgebra.binding_power(algebra, v, exponent), - ) + for exponent in range(-2, 4): + assert np.allclose( + algebra.binding_power(v, exponent), + AbstractAlgebra.binding_power(algebra, v, exponent), + ) - with pytest.raises(ValueError, match="only supports integer binding powers"): - AbstractAlgebra.binding_power(algebra, v, 0.5) - except NotImplementedError: - pytest.skip() + with pytest.raises(ValueError, match="only supports integer binding powers"): + AbstractAlgebra.binding_power(algebra, v, 0.5) @pytest.mark.parametrize("d", [16, 25]) @@ -139,110 +139,109 @@ def test_get_binding_matrix(algebra, rng): @pytest.mark.filterwarnings("ignore:.*sidedness:DeprecationWarning") @pytest.mark.parametrize("sidedness", ElementSidedness) def test_get_inversion_matrix(algebra, sidedness, rng): + check_sidedness(algebra, "invert", sidedness) a = next(UnitLengthVectors(16, rng=rng)) - try: - m = algebra.get_inversion_matrix(16, sidedness=sidedness) - assert np.allclose(algebra.invert(a, sidedness=sidedness), np.dot(m, a)) - except NotImplementedError: - pass + m = algebra.get_inversion_matrix(16, sidedness=sidedness) + assert np.allclose(algebra.invert(a, sidedness=sidedness), np.dot(m, a)) @pytest.mark.parametrize("sidedness", ElementSidedness) def test_absorbing_element(algebra, sidedness, rng): + check_sidedness(algebra, "absorbing_element", sidedness) + a = next(UnitLengthVectors(16, rng=rng)) - try: - with warnings.catch_warnings(record=True) as caught_warnings: - warnings.simplefilter("always", DeprecationWarning) - p = algebra.absorbing_element(16) - except NotImplementedError: - pass - else: - is_deprecated = len(caught_warnings) > 0 and any( - issubclass(w.category, DeprecationWarning) for w in caught_warnings - ) - if ( - sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED) - and not is_deprecated - ): - r = algebra.bind(p, a) - r /= np.linalg.norm(r) - assert np.allclose(p, r) or np.allclose(p, -r) - if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED): - r = algebra.bind(a, p) - r /= np.linalg.norm(r) - assert np.allclose(p, r) or np.allclose(p, -r) + + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always", DeprecationWarning) + p = algebra.absorbing_element(16) + + is_deprecated = len(caught_warnings) > 0 and any( + issubclass(w.category, DeprecationWarning) for w in caught_warnings + ) + if ( + sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED) + and not is_deprecated + ): + r = algebra.bind(p, a) + r /= np.linalg.norm(r) + assert np.allclose(p, r) or np.allclose(p, -r) + if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED): + r = algebra.bind(a, p) + r /= np.linalg.norm(r) + assert np.allclose(p, r) or np.allclose(p, -r) @pytest.mark.parametrize("sidedness", ElementSidedness) def test_identity_element(algebra, sidedness, rng): + check_sidedness(algebra, "identity_element", sidedness) + a = next(UnitLengthVectors(16, rng=rng)) - try: - with warnings.catch_warnings(record=True) as caught_warnings: - warnings.simplefilter("always", DeprecationWarning) - p = algebra.identity_element(16) - except NotImplementedError: - pass - else: - is_deprecated = len(caught_warnings) > 0 and any( - issubclass(w.category, DeprecationWarning) for w in caught_warnings - ) - if ( - sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED) - and not is_deprecated - ): - assert np.allclose(algebra.bind(p, a), a) - if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED): - assert np.allclose(algebra.bind(a, p), a) + + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always", DeprecationWarning) + p = algebra.identity_element(16) + + is_deprecated = len(caught_warnings) > 0 and any( + issubclass(w.category, DeprecationWarning) for w in caught_warnings + ) + if ( + sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED) + and not is_deprecated + ): + assert np.allclose(algebra.bind(p, a), a) + if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED): + assert np.allclose(algebra.bind(a, p), a) @pytest.mark.parametrize("sidedness", ElementSidedness) def test_negative_identity_element(algebra, sidedness, rng): - try: - x = next(UnitLengthVectors(16, rng=rng)) - a = algebra.abs(x) - with warnings.catch_warnings(record=True) as caught_warnings: - warnings.simplefilter("always", DeprecationWarning) - p = algebra.negative_identity_element(16) - except NotImplementedError: - pass - else: - is_deprecated = len(caught_warnings) > 0 and any( - issubclass(w.category, DeprecationWarning) for w in caught_warnings - ) - if ( - sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED) - and not is_deprecated - ): - b = algebra.bind(p, a) - assert np.allclose(algebra.abs(b), a) - assert algebra.sign(b).is_negative() - if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED): - b = algebra.bind(a, p) - assert np.allclose(algebra.abs(b), a) - assert algebra.sign(b).is_negative() + x = next(UnitLengthVectors(16, rng=rng)) + if algebra.sign(x).is_indefinite(): + pytest.xfail("Generated vector has an indefinite sign.") + + a = algebra.abs(x) + + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always", DeprecationWarning) + p = algebra.negative_identity_element(16) + + is_deprecated = len(caught_warnings) > 0 and any( + issubclass(w.category, DeprecationWarning) for w in caught_warnings + ) + if ( + sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED) + and not is_deprecated + ): + b = algebra.bind(p, a) + assert np.allclose(algebra.abs(b), a) + assert algebra.sign(b).is_negative() + if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED): + b = algebra.bind(a, p) + assert np.allclose(algebra.abs(b), a) + assert algebra.sign(b).is_negative() @pytest.mark.parametrize("sidedness", ElementSidedness) def test_zero_element(algebra, sidedness, rng): + check_sidedness(algebra, "zero_element", sidedness) + a = next(UnitLengthVectors(16, rng=rng)) - try: - with warnings.catch_warnings(record=True) as caught_warnings: - warnings.simplefilter("always", DeprecationWarning) - p = algebra.zero_element(16) - except NotImplementedError: - pass - else: - assert np.all(p == 0.0) - is_deprecated = len(caught_warnings) > 0 and any( - issubclass(w.category, DeprecationWarning) for w in caught_warnings - ) - if ( - sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED) - and not is_deprecated - ): - assert np.allclose(algebra.bind(a, p), 0.0) - if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED): - assert np.allclose(algebra.bind(p, a), 0.0) + + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always", DeprecationWarning) + p = algebra.zero_element(16) + + assert np.all(p == 0.0) + is_deprecated = len(caught_warnings) > 0 and any( + issubclass(w.category, DeprecationWarning) for w in caught_warnings + ) + if ( + sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED) + and not is_deprecated + ): + assert np.allclose(algebra.bind(a, p), 0.0) + if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED): + assert np.allclose(algebra.bind(p, a), 0.0) def test_isinstance_check(algebra): @@ -318,22 +317,19 @@ def test_isinstance_ducktyping_check(): @pytest.mark.parametrize("sidedness", ElementSidedness) @pytest.mark.filterwarnings("ignore:.*sidedness:DeprecationWarning") def test_sign(algebra, element, check_property, sidedness): - try: - v = getattr(algebra, element + "_element")(16, sidedness) - assert getattr(algebra.sign(v), check_property)() - except NotImplementedError: - pass + method_name = f"{element}_element" + check_sidedness(algebra, method_name, sidedness) + v = getattr(algebra, method_name)(16, sidedness) + assert getattr(algebra.sign(v), check_property)() @pytest.mark.parametrize("d", [16, 25]) @pytest.mark.parametrize("sidedness", ElementSidedness) def test_abs(algebra, d, sidedness): - try: - neg_v = algebra.negative_identity_element(d, sidedness) - assert algebra.sign(neg_v).is_negative() - v = algebra.abs(neg_v) - assert algebra.sign(v).is_positive() - assert np.allclose(v, algebra.identity_element(d, sidedness)) - assert np.allclose(algebra.abs(v), v) # idempotency - except NotImplementedError: - pass + check_sidedness(algebra, "negative_identity_element", sidedness) + neg_v = algebra.negative_identity_element(d, sidedness) + assert algebra.sign(neg_v).is_negative() + v = algebra.abs(neg_v) + assert algebra.sign(v).is_positive() + assert np.allclose(v, algebra.identity_element(d, sidedness)) + assert np.allclose(algebra.abs(v), v) # idempotency diff --git a/nengo_spa/algebras/tvtb_algebra.py b/nengo_spa/algebras/tvtb_algebra.py index a50362808..8429647f8 100644 --- a/nengo_spa/algebras/tvtb_algebra.py +++ b/nengo_spa/algebras/tvtb_algebra.py @@ -8,6 +8,7 @@ CommonProperties, ElementSidedness, GenericSign, + supports_sidedness, ) from nengo_spa.networks.tvtb import TVTB @@ -321,6 +322,7 @@ def sign(self, v): else: return TvtbSign(None) + @supports_sidedness({}) def absorbing_element(self, d, sidedness=ElementSidedness.TWO_SIDED): """TVTB has no absorbing element except the zero vector. diff --git a/nengo_spa/algebras/vtb_algebra.py b/nengo_spa/algebras/vtb_algebra.py index 8cd75fdf5..f12f79dfb 100644 --- a/nengo_spa/algebras/vtb_algebra.py +++ b/nengo_spa/algebras/vtb_algebra.py @@ -8,6 +8,7 @@ CommonProperties, ElementSidedness, GenericSign, + supports_sidedness, ) from nengo_spa.networks.vtb import VTB @@ -184,6 +185,7 @@ def bind(self, a, b): m = self.get_binding_matrix(b) return np.dot(m, a) + @supports_sidedness({ElementSidedness.RIGHT}) def invert(self, v, sidedness=ElementSidedness.TWO_SIDED): """Invert vector *v*. @@ -392,6 +394,7 @@ def abs(self, v): # own inverse. return self.bind(v, self.sign(v).to_vector(len(v))) + @supports_sidedness({}) def absorbing_element(self, d, sidedness=ElementSidedness.TWO_SIDED): """VTB has no absorbing element except the zero vector. @@ -399,6 +402,7 @@ def absorbing_element(self, d, sidedness=ElementSidedness.TWO_SIDED): """ raise NotImplementedError("VtbAlgebra does not have any absorbing elements.") + @supports_sidedness({ElementSidedness.RIGHT}) def identity_element(self, d, sidedness=ElementSidedness.TWO_SIDED): """Return the identity element of dimensionality *d*. @@ -437,6 +441,7 @@ def identity_element(self, d, sidedness=ElementSidedness.TWO_SIDED): sub_d = self._get_sub_d(d) return (np.eye(sub_d) / d ** 0.25).flatten() + @supports_sidedness({ElementSidedness.RIGHT}) def negative_identity_element(self, d, sidedness=ElementSidedness.TWO_SIDED): r"""Return the negative identity element of dimensionality *d*. diff --git a/nengo_spa/ast/tests/test_dynamic.py b/nengo_spa/ast/tests/test_dynamic.py index 5a75b0009..a60827862 100644 --- a/nengo_spa/ast/tests/test_dynamic.py +++ b/nengo_spa/ast/tests/test_dynamic.py @@ -4,7 +4,9 @@ from numpy.testing import assert_allclose import nengo_spa as spa +from nengo_spa.algebras.base import ElementSidedness from nengo_spa.ast.symbolic import PointerSymbol +from nengo_spa.conftest import check_sidedness from nengo_spa.exceptions import SpaTypeError from nengo_spa.semantic_pointer import SemanticPointer from nengo_spa.testing import assert_sp_close @@ -49,26 +51,31 @@ def test_unary_operation_on_module(Simulator, algebra, op, suffix, rng): assert_sp_close(sim.trange(), sim.data[p], vocab.parse(op + "A"), skip=0.2) -@pytest.mark.parametrize("sidedness", ["l", "r"]) +@pytest.mark.parametrize("sidedness", [ElementSidedness.LEFT, ElementSidedness.RIGHT]) @pytest.mark.parametrize("suffix", ["", ".output"]) def test_inv_operation_on_module(Simulator, algebra, sidedness, suffix, rng): - try: - vocab = spa.Vocabulary(16, pointer_gen=rng, algebra=algebra) - vocab.populate("A") + check_sidedness(algebra, "invert", sidedness) + sidedness_prefix = {ElementSidedness.LEFT: "l", ElementSidedness.RIGHT: "r"}[ + sidedness + ] - with spa.Network() as model: - stimulus = spa.Transcode("A", output_vocab=vocab) # noqa: F841 - x = eval("stimulus" + suffix + "." + sidedness + "inv()") - p = nengo.Probe(x.construct(), synapse=0.03) + vocab = spa.Vocabulary(16, pointer_gen=rng, algebra=algebra) + vocab.populate("A") - with Simulator(model) as sim: - sim.run(0.3) + with spa.Network() as model: + stimulus = spa.Transcode("A", output_vocab=vocab) # noqa: F841 + x = eval("stimulus" + suffix + "." + sidedness_prefix + "inv()") + p = nengo.Probe(x.construct(), synapse=0.03) - assert_sp_close( - sim.trange(), sim.data[p], vocab.parse("A." + sidedness + "inv()"), skip=0.2 - ) - except NotImplementedError: - pass + with Simulator(model) as sim: + sim.run(0.3) + + assert_sp_close( + sim.trange(), + sim.data[p], + vocab.parse("A." + sidedness_prefix + "inv()"), + skip=0.2, + ) @pytest.mark.parametrize("op", ["+", "-", "*"]) @@ -230,28 +237,27 @@ def test_transformed_and_pointer_symbol(Simulator, algebra, seed, rng): def test_transformed_and_network(Simulator, algebra, seed, rng): - try: - vocab = spa.Vocabulary(16, pointer_gen=rng, algebra=algebra) - vocab.populate("A; B.unitary()") - - with spa.Network(seed=seed) as model: - a = spa.Transcode("A", output_vocab=vocab) - b = spa.Transcode("B", output_vocab=vocab) - x = (a * PointerSymbol("B.linv()")) * b - p = nengo.Probe(x.construct(), synapse=0.3) - - with Simulator(model) as sim: - sim.run(0.3) - - assert_sp_close( - sim.trange(), - sim.data[p], - vocab.parse("A * B.linv() * B"), - skip=0.2, - normalized=True, - ) - except NotImplementedError: - pass + check_sidedness(algebra, "invert", ElementSidedness.LEFT) + + vocab = spa.Vocabulary(16, pointer_gen=rng, algebra=algebra) + vocab.populate("A; B.unitary()") + + with spa.Network(seed=seed) as model: + a = spa.Transcode("A", output_vocab=vocab) + b = spa.Transcode("B", output_vocab=vocab) + x = (a * PointerSymbol("B.linv()")) * b + p = nengo.Probe(x.construct(), synapse=0.3) + + with Simulator(model) as sim: + sim.run(0.3) + + assert_sp_close( + sim.trange(), + sim.data[p], + vocab.parse("A * B.linv() * B"), + skip=0.2, + normalized=True, + ) def test_transformed_and_transformed(Simulator, algebra, seed, rng): diff --git a/nengo_spa/conftest.py b/nengo_spa/conftest.py index 1aeadc148..143565f86 100644 --- a/nengo_spa/conftest.py +++ b/nengo_spa/conftest.py @@ -38,3 +38,15 @@ def pytest_generate_tests(metafunc): "algebra", [pytest.param(a, id=a.__class__.__name__) for a in TestConfig.algebras], ) + + +def check_sidedness(algebra, method_name, sidedness): + method = getattr(algebra, method_name) + if ( + hasattr(method, "supported_sidedness") + and sidedness not in method.supported_sidedness + ): + pytest.xfail( + f"Algebra {algebra.__class__.__name__} does not have a " + f"{sidedness} {method_name}." + ) diff --git a/nengo_spa/tests/test_semantic_pointer.py b/nengo_spa/tests/test_semantic_pointer.py index 9cfb064cf..95e43092f 100644 --- a/nengo_spa/tests/test_semantic_pointer.py +++ b/nengo_spa/tests/test_semantic_pointer.py @@ -15,6 +15,7 @@ ) from nengo_spa.algebras.hrr_algebra import HrrAlgebra, HrrSign from nengo_spa.ast.symbolic import PointerSymbol +from nengo_spa.conftest import check_sidedness from nengo_spa.exceptions import SpaTypeError from nengo_spa.semantic_pointer import ( AbsorbingElement, @@ -123,9 +124,11 @@ def test_add_sub(algebra, rng): @pytest.mark.parametrize("d", [64, 65]) -def test_binding(algebra, d, rng): +@pytest.mark.parametrize("sidedness", ElementSidedness) +def test_binding(algebra, d, sidedness, rng): if not algebra.is_valid_dimensionality(d): - return + pytest.xfail("Invalid dimensionality for algebra.") + check_sidedness(algebra, "identity_element", sidedness) gen = UnitLengthVectors(d, rng=rng) @@ -140,16 +143,12 @@ def test_binding(algebra, d, rng): assert np.allclose((a * b).v, conv_ans) assert np.allclose(a.bind(b).v, conv_ans) assert np.allclose(c.v, conv_ans) - try: - identity = Identity(d, algebra=algebra, sidedness=ElementSidedness.RIGHT) - assert np.allclose((a * identity).v, a.v) - except NotImplementedError: - pass - try: - identity = Identity(d, algebra=algebra, sidedness=ElementSidedness.LEFT) + + identity = Identity(d, algebra=algebra, sidedness=sidedness) + if sidedness in {ElementSidedness.LEFT, ElementSidedness.TWO_SIDED}: assert np.allclose((identity * a).v, a.v) - except NotImplementedError: - pass + if sidedness in {ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED}: + assert np.allclose((a * identity).v, a.v) @pytest.mark.parametrize("d", [64, 65]) @@ -180,31 +179,29 @@ def test_fractional_binding_power(algebra, d, rng): assert np.allclose(a.v, (pow(a, 0.5) ** 2).v) -@pytest.mark.filterwarnings("ignore:.*sidedness:DeprecationWarning") -def test_inverse(algebra, rng): +def test_inverse_two_sided(algebra, rng): + check_sidedness(algebra, "invert", ElementSidedness.TWO_SIDED) gen = UnitLengthVectors(64, rng=rng) a = SemanticPointer(next(gen), algebra=algebra) + assert np.allclose( + (~a).v, algebra.invert(a.v, sidedness=ElementSidedness.TWO_SIDED) + ) - try: - assert np.allclose( - (~a).v, algebra.invert(a.v, sidedness=ElementSidedness.TWO_SIDED) - ) - except NotImplementedError: - pass - try: - assert np.allclose( - a.linv().v, algebra.invert(a.v, sidedness=ElementSidedness.LEFT) - ) - except NotImplementedError: - pass +def test_inverse_left(algebra, rng): + check_sidedness(algebra, "invert", ElementSidedness.LEFT) + gen = UnitLengthVectors(64, rng=rng) + a = SemanticPointer(next(gen), algebra=algebra) + assert np.allclose(a.linv().v, algebra.invert(a.v, sidedness=ElementSidedness.LEFT)) - try: - assert np.allclose( - a.rinv().v, algebra.invert(a.v, sidedness=ElementSidedness.RIGHT) - ) - except NotImplementedError: - pass + +def test_inverse_right(algebra, rng): + check_sidedness(algebra, "invert", ElementSidedness.RIGHT) + gen = UnitLengthVectors(64, rng=rng) + a = SemanticPointer(next(gen), algebra=algebra) + assert np.allclose( + a.rinv().v, algebra.invert(a.v, sidedness=ElementSidedness.RIGHT) + ) def test_multiply(rng): @@ -402,38 +399,31 @@ def test_invalid_algebra(): @pytest.mark.filterwarnings("ignore:.*:DeprecationWarning") @pytest.mark.parametrize("sidedness", ElementSidedness) def test_identity(algebra, sidedness): - try: - assert np.allclose( - Identity(64, algebra=algebra, sidedness=sidedness).v, - algebra.identity_element(64, sidedness=sidedness), - ) - except NotImplementedError: - pass + check_sidedness(algebra, "identity_element", sidedness) + assert np.allclose( + Identity(64, algebra=algebra, sidedness=sidedness).v, + algebra.identity_element(64, sidedness=sidedness), + ) @pytest.mark.filterwarnings("ignore:.*:DeprecationWarning") @pytest.mark.parametrize("sidedness", ElementSidedness) -def test_absorbing_element(algebra, sidedness, plt): - plt.plot([0, 1], [0, 1]) - try: - assert np.allclose( - AbsorbingElement(64, algebra=algebra, sidedness=sidedness).v, - algebra.absorbing_element(64, sidedness=sidedness), - ) - except NotImplementedError: - pass +def test_absorbing_element(algebra, sidedness): + check_sidedness(algebra, "absorbing_element", sidedness) + assert np.allclose( + AbsorbingElement(64, algebra=algebra, sidedness=sidedness).v, + algebra.absorbing_element(64, sidedness=sidedness), + ) @pytest.mark.filterwarnings("ignore:.*:DeprecationWarning") @pytest.mark.parametrize("sidedness", ElementSidedness) def test_zero(algebra, sidedness): - try: - assert np.allclose( - Zero(64, algebra=algebra, sidedness=sidedness).v, - algebra.zero_element(64, sidedness=sidedness), - ) - except NotImplementedError: - pass + check_sidedness(algebra, "zero_element", sidedness) + assert np.allclose( + Zero(64, algebra=algebra, sidedness=sidedness).v, + algebra.zero_element(64, sidedness=sidedness), + ) def test_name():