Skip to content

Commit

Permalink
Move manual decoders onchip
Browse files Browse the repository at this point in the history
With this change, decoders that are manually specified through
a connection from `ens.neurons` with a `transform` are treated the
same as decoders specified through a connection from `ens` with
a `NoSolver` solver.

Co-authored-by: Trevor Bekolay <[email protected]>
  • Loading branch information
tcstewar and tbekolay committed Sep 26, 2018
1 parent 9ca0d0c commit 0811e2a
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 9 deletions.
34 changes: 25 additions & 9 deletions nengo_loihi/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def split(model, inter_rate, inter_n): # noqa: C901
intercepts=[-1] * dim + [-1] * dim)

# scale the input spikes based on the radius of the
# target ensemble
# target ensemble
if isinstance(c.post_obj, nengo.Ensemble):
scaling = 1.0 / c.post_obj.radius
else:
Expand All @@ -216,13 +216,29 @@ def split(model, inter_rate, inter_n): # noqa: C901
nengo.Connection(receive, c.post, synapse=c.synapse)
with chip:
logger.debug("Creating Probe for %s", c)
probe = nengo.Probe(c.pre, synapse=None, solver=c.solver)
chip2host_params[probe] = dict(
learning_rule_type=c.learning_rule_type,
function=c.function,
eval_points=c.eval_points,
scale_eval_points=c.scale_eval_points,
transform=c.transform)
if (isinstance(c.pre, nengo.ensemble.Neurons) and
c.transform.ndim == 2):
# decoders manually specified in the transform
# should be handled like a normal decoder
probe = nengo.Probe(
c.pre.ensemble,
synapse=None,
solver=nengo.solvers.NoSolver(c.transform.T))
dims = c.transform.shape[0]
chip2host_params[probe] = dict(
learning_rule_type=c.learning_rule_type,
function=lambda x, dims=dims: np.zeros(dims),
transform=np.array(1),
)
else:
probe = nengo.Probe(c.pre, synapse=None, solver=c.solver)
chip2host_params[probe] = dict(
learning_rule_type=c.learning_rule_type,
function=c.function,
eval_points=c.eval_points,
scale_eval_points=c.scale_eval_points,
transform=c.transform
)
chip2host_receivers[probe] = receive
if c.learning_rule_type is not None:
modulated_conns[c] = probe
Expand All @@ -249,7 +265,7 @@ def base_obj(obj):
return obj


def split_pre_from_host(host_model): # noqa: C901
def split_pre_from_host(host_model): # noqa: C901
assert len(host_model.networks) == 0
logger.info("Splitting pre model from host")

Expand Down
44 changes: 44 additions & 0 deletions nengo_loihi/tests/test_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest
import nengo
import numpy as np


@pytest.mark.parametrize("pre_dims", [1, 3])
@pytest.mark.parametrize("post_dims", [1])
@pytest.mark.parametrize("learn", [True, False])
@pytest.mark.parametrize("use_solver", [True, False])
def test_manual_decoders(
seed, Simulator, pre_dims, post_dims, learn, use_solver):

with nengo.Network(seed=seed) as model:
pre = nengo.Ensemble(50, dimensions=pre_dims,
gain=np.ones(50),
bias=np.ones(50) * 5)
post = nengo.Node(None, size_in=post_dims)

learning_rule_type = nengo.PES() if learn else None
weights = np.zeros((post_dims, 50))
if use_solver:
conn = nengo.Connection(pre, post,
function=lambda x: np.zeros(post_dims),
learning_rule_type=learning_rule_type,
solver=nengo.solvers.NoSolver(weights.T))
else:
conn = nengo.Connection(pre.neurons, post,
learning_rule_type=learning_rule_type,
transform=weights)

if learn:
error = nengo.Node(np.zeros(post_dims))
nengo.Connection(error, conn.learning_rule)

pre_probe = nengo.Probe(pre.neurons, synapse=None)
post_probe = nengo.Probe(post, synapse=None)

with Simulator(model, precompute=False) as sim:
sim.run(0.1)

# Ensure pre population has a lot of activity
assert np.mean(sim.data[pre_probe]) > 100
# But that post has no activity due to the zero weights
assert np.all(sim.data[post_probe] == 0)

0 comments on commit 0811e2a

Please sign in to comment.