Learning matches Nengo learning more closely #139

Merged 5 commits on Jan 16, 2019
7 changes: 7 additions & 0 deletions CHANGES.rst
@@ -22,6 +22,13 @@ Release history
0.5.0 (unreleased)
==================

**Changed**

- PES learning in Nengo Loihi more closely matches learning in core Nengo.
(`#139 <https://github.com/nengo/nengo-loihi/pull/139>`__)
- Learning in the emulator more closely matches learning on hardware.
(`#139 <https://github.com/nengo/nengo-loihi/pull/139>`__)

**Fixed**

- We integrate current (U) and voltage (V) more accurately now by accounting
18 changes: 10 additions & 8 deletions docs/examples/adaptive_motor_control.ipynb
@@ -427,9 +427,9 @@
"source": [
"## Running the network with Nengo Loihi\n",
"\n",
"Loihi has some implicit gains included in it\n",
"so we use a learning rate of 1e-6 in Nengo Loihi,\n",
"instead of 1e-5, to match the performance of the standard Nengo backend."
"Loihi has limits on the magnitude of the \n",
"error signal, so we need to adjust the\n",
"`pes_error_scale` parameter to avoid too much overflow."
]
},
{
@@ -438,10 +438,12 @@
"metadata": {},
"outputs": [],
"source": [
"adapt_loihi_model = add_adaptation(\n",
" add_gravity(build_baseline_model()), learning_rate=1e-6)\n",
"arm_sim.reset()\n",
"with nengo_loihi.Simulator(adapt_loihi_model) as sim:\n",
"\n",
"model = nengo_loihi.builder.Model()\n",
"model.pes_error_scale = 10\n",
"\n",
"with nengo_loihi.Simulator(adapt_model, model=model) as sim:\n",
" sim.run(runtime)\n",
"adapt_loihi_t = sim.trange()\n",
"adapt_loihi_data = sim.data"
@@ -453,10 +455,10 @@
"metadata": {},
"outputs": [],
"source": [
"adapt_loihi_error = calculate_error(adapt_loihi_model, adapt_loihi_data)\n",
"adapt_loihi_error = calculate_error(adapt_model, adapt_loihi_data)\n",
"plot_xy(\n",
" [adapt_loihi_t],\n",
" [adapt_loihi_data[adapt_loihi_model.probe_hand]],\n",
" [adapt_loihi_data[adapt_model.probe_hand]],\n",
" ['Adapt'])\n",
"plot_data(\n",
" [baseline_t, gravity_t, adapt_t, adapt_loihi_t],\n",
2 changes: 1 addition & 1 deletion nengo_loihi/allocators.py
@@ -63,7 +63,7 @@ def core_stdp_pre_cfgs(core):
profile_idxs = {}
for synapses in core.synapses:
if synapses.tracing:
mag_int, mag_frac = tracing_mag_int_frac(synapses)
mag_int, mag_frac = tracing_mag_int_frac(synapses.tracing_mag)
tracecfg = TraceCfg(
tau=synapses.tracing_tau,
spikeLevelInt=mag_int,
56 changes: 43 additions & 13 deletions nengo_loihi/builder.py
@@ -10,7 +10,7 @@
from nengo.dists import Distribution, get_samples
from nengo.connection import LearningRule
from nengo.ensemble import Neurons
from nengo.exceptions import BuildError
from nengo.exceptions import BuildError, ValidationError
from nengo.solvers import NoSolver, Solver
from nengo.utils.builder import default_n_eval_points
import nengo.utils.numpy as npext
@@ -110,6 +110,13 @@ def __init__(self, dt=0.001, label=None, builder=None):
# limit for clipping intercepts, to avoid neurons with high gains
self.intercept_limit = 0.95

# scaling for PES errors, before rounding and clipping to -127..127
self.pes_error_scale = 100.

# learning weight exponent for PES (controls the maximum weight
# magnitude/weight resolution)
self.pes_wgt_exp = 4

# Will be provided by Simulator
self.chip2host_params = {}
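
A minimal sketch of how these two knobs can be overridden (the pattern follows the updated notebook above; the network itself is illustrative):

import nengo
import nengo_loihi

with nengo.Network() as net:
    stim = nengo.Node(0.5)
    ens = nengo.Ensemble(100, 1)
    nengo.Connection(stim, ens)

# Build a Model up front to override the new PES defaults, then pass it
# to the Simulator, as the adaptive motor control notebook now does.
model = nengo_loihi.builder.Model()
model.pes_error_scale = 10.  # a smaller scale clips less but quantizes more
model.pes_wgt_exp = 4        # trades maximum weight magnitude for resolution

with nengo_loihi.Simulator(net, model=model) as sim:
    sim.run(0.1)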

@@ -657,18 +664,41 @@ def build_connection(model, conn):
model.objs[conn]['decode_axons'] = dec_ax0

if conn.learning_rule_type is not None:
if isinstance(conn.learning_rule_type, nengo.PES):
pes_learn_rate = conn.learning_rule_type.learning_rate
# scale learning rates to roughly match Nengo
# 1e-4 is the Nengo core default learning rate
pes_learn_rate *= 4 / 1e-4
assert isinstance(conn.learning_rule_type.pre_synapse,
nengo.synapses.Lowpass)
pes_pre_syn = conn.learning_rule_type.pre_synapse.tau
# scale pre_syn.tau from s to ms
pes_pre_syn *= 1e3
dec_syn.set_learning(tracing_tau=pes_pre_syn,
tracing_mag=pes_learn_rate)
rule_type = conn.learning_rule_type
if isinstance(rule_type, nengo.PES):
if not isinstance(rule_type.pre_synapse,
nengo.synapses.Lowpass):
raise ValidationError(
"Loihi only supports `Lowpass` pre-synapses for "
"learning rules", attr='pre_synapse', obj=rule_type)

tracing_tau = rule_type.pre_synapse.tau / model.dt

# Nengo builder scales PES learning rate by `dt / n_neurons`
n_neurons = (conn.pre_obj.n_neurons
if isinstance(conn.pre_obj, Ensemble)
else conn.pre_obj.size_in)
learning_rate = rule_type.learning_rate * model.dt / n_neurons

# Account for scaling to put integer error in range [-127, 127]
learning_rate /= model.pes_error_scale

# Tracing mag is set so that the magnitude of the pre trace
# is independent of the pre tau. One `dt` factor accounts for
# Nengo's `dt` spike scaling; the origin of the second `dt` is
# unclear (possibly the post interneurons having `vth = 1/dt`).
tracing_mag = -np.expm1(-1. / tracing_tau) / model.dt**2

# learning weight exponent controls the maximum weight
# magnitude/weight resolution
wgt_exp = model.pes_wgt_exp

dec_syn.set_learning(
learning_rate=learning_rate,
tracing_mag=tracing_mag,
tracing_tau=tracing_tau,
wgt_exp=wgt_exp,
)
else:
raise NotImplementedError()

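To make the new scaling concrete, here is a standalone sketch of the arithmetic above using core Nengo's PES defaults (the values are illustrative, not taken from this PR):

import numpy as np

dt = 0.001              # simulator timestep
n_neurons = 100         # pre ensemble size (illustrative)
learning_rate = 1e-4    # core Nengo's default PES learning rate
pre_tau = 0.005         # core Nengo's default PES pre_synapse Lowpass tau
pes_error_scale = 100.  # Model default above

tracing_tau = pre_tau / dt             # 5.0 timesteps
rate = learning_rate * dt / n_neurons  # Nengo builder scaling -> 1e-9
rate /= pes_error_scale                # undo the error scaling -> 1e-11

# Chosen so the steady-state pre trace is independent of pre_tau; the
# 1 / dt**2 factor is the one questioned in the comment above.
tracing_mag = -np.expm1(-1. / tracing_tau) / dt**2   # ~1.8e5

print(tracing_tau, rate, tracing_mag)
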
8 changes: 4 additions & 4 deletions nengo_loihi/config.py
@@ -25,10 +25,10 @@ def add_params(network):
"""
config = network.config

cfg = config[nengo.Ensemble]
if 'on_chip' not in cfg._extra_params:
cfg.set_param("on_chip",
Parameter('on_chip', default=None, optional=True))
ens_cfg = config[nengo.Ensemble]
if 'on_chip' not in ens_cfg._extra_params:
ens_cfg.set_param("on_chip",
Parameter('on_chip', default=None, optional=True))


def set_defaults():
94 changes: 75 additions & 19 deletions nengo_loihi/loihi_api.py
@@ -22,6 +22,22 @@
Q_BITS = 21 # number of bits for synapse accumulator
U_BITS = 23 # number of bits for cx input (u)

LEARN_BITS = 15 # number of bits in learning accumulator (not incl. sign)
LEARN_FRAC = 7 # extra least-significant bits added to weights for learning


def learn_overflow_bits(n_factors):
"""Compute number of bits with which learning will overflow.

Parameters
----------
n_factors : int
The number of learning factors (pre/post terms in the learning rule).
"""
factor_bits = 7 # number of bits per factor
mantissa_bits = 3 # number of bits for learning rate mantissa
return factor_bits*n_factors + mantissa_bits - LEARN_BITS
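
# Example (sketch): a rule with two 7-bit factors (n_factors=2), e.g. a
# pre trace multiplied by an error term, plus the 3-bit rate mantissa
# overflows the 15-bit accumulator by 7*2 + 3 - 15 = 2 bits:
#
#     >>> learn_overflow_bits(n_factors=2)
#     2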


def overflow_signed(x, bits=7, out=None):
"""Compute overflow on an array of signed integers.
@@ -86,16 +102,10 @@ def bias_to_manexp(bias):
return man, exp


def tracing_mag_int_frac(synapses):
mag = synapses.tracing_mag
mag = mag / (synapses.size() / 100)

def tracing_mag_int_frac(mag):
"""Split trace magnitude into integer and fractional components for chip"""
mag_int = int(mag)
# TODO: how does mag_frac actually work???
# It's the x in x/128, I believe
mag_frac = int(128 * (mag - mag_int))
# mag_frac = min(int(round(1./mag_frac)), 128)

return mag_int, mag_frac
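
# Example (sketch): a magnitude of 2.5 splits into integer part 2 and
# fractional code int(128 * 0.5) = 64, since the chip encodes the
# fraction in units of 1/128:
#
#     >>> tracing_mag_int_frac(2.5)
#     (2, 64)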


@@ -157,6 +167,23 @@ def shift(x, s, **kwargs):
return np.left_shift(x, s, **kwargs)


def scale_pes_errors(error, scale=1.):
"""Scale PES errors based on a scaling factor, round and clip."""
error = scale * error
error = np.round(error).astype(np.int32)
q = error > 127
if np.any(q):
warnings.warn("Max PES error (%0.2e) greater than chip max (%0.2e). "
"Clipping." % (error.max() / scale, 127. / scale))
error[q] = 127
q = error < -127
if np.any(q):
warnings.warn("Min PES error (%0.2e) less than chip min (%0.2e). "
"Clipping." % (error.min() / scale, -127. / scale))
error[q] = -127
return error
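
# Example (sketch): with the Model default scale of 100, an error of 0.5
# becomes the integer 50, while 2.0 scales to 200, exceeds the chip
# maximum of 127, and is clipped (with a warning):
#
#     >>> scale_pes_errors(np.array([0.5, 2.0]), scale=100.)
#     array([ 50, 127], dtype=int32)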


class CxSlice(object):
def __init__(self, board_idx, chip_idx, core_idx, cx_i0, cx_i1):
self.board_idx = board_idx
@@ -556,6 +583,11 @@ def realIdxBits(self):
def isMixed(self):
return self.fanoutType == 1

@property
def shift_bits(self):
"""Number of bits the -256..255 weight is right-shifted by."""
return 8 - self.realWgtBits + self.isMixed
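
# Example (sketch): with realWgtBits == 8 and mixed weights
# (isMixed == 1), the stored weight is right-shifted by 8 - 8 + 1 = 1 bit.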

def bits_per_axon(self, n_weights):
"""For an axon with n weights, compute the weight memory bits used"""
bits_per_weight = self.realWgtBits + self.dlyBits + self.tagBits
@@ -594,24 +626,48 @@ def validate(self, core=None):
assert 0 <= self.idxBits < 8
assert 1 <= self.fanoutType < 4

def discretize_weights(self, w, dtype=np.int32):
s = 8 - self.realWgtBits + self.isMixed
def discretize_weights(
self, w, dtype=np.int32, lossy_shift=True, check_result=True):
"""Takes weights and returns their quantized values with wgtExp.

The actual weight to be put on the chip is this returned value
divided by the ``scale`` attribute.

Parameters
----------
w : float ndarray
Weights to be discretized, in the range -255 to 255.
dtype : np.dtype, optional (Default: np.int32)
Data type for discretized weights.
lossy_shift : bool, optional (Default: True)
Whether to mimic the two-part weight shift that currently happens
on the chip, which can lose information for small wgtExp.
check_result : bool, optional (Default: True)
Whether to check that the discretized weights fall in
the valid range for weights on the chip (-256 to 255).
"""
s = self.shift_bits
m = 2**(8 - s) - 1

w = np.round(w / 2.**s).clip(-m, m).astype(dtype)
s2 = s + self.wgtExp
if s2 < 0:
warnings.warn("Lost %d extra bits in weight rounding" % (-s2,))

# round before `s2` right shift, since just shifting would floor
# everything resulting in weights biased towards being smaller
w = (np.round(w * 2.**s2) / 2**s2).clip(-m, m).astype(dtype)
if lossy_shift:
if s2 < 0:
warnings.warn("Lost %d extra bits in weight rounding" % (-s2,))

# Round before the `s2` right shift; just shifting would floor
# everything, biasing the weights towards smaller magnitudes.
w = (np.round(w * 2.**s2) / 2**s2).clip(-m, m).astype(dtype)

shift(w, s2, out=w)
np.left_shift(w, 6, out=w)
shift(w, s2, out=w)
np.left_shift(w, 6, out=w)
else:
shift(w, 6 + s2, out=w)

ws = w // self.scale
assert np.all(ws <= 255) and np.all(ws >= -256)
if check_result:
ws = w // self.scale
assert np.all(ws <= 255) and np.all(ws >= -256)

return w

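A rough standalone sketch of the two shift paths in ``discretize_weights`` (plain NumPy; it mirrors the method body with ``scale`` folded out rather than constructing a SynapseFmt, whose initializer is not shown in this diff, and it assumes ``shift`` right-shifts for negative amounts):

import numpy as np

def discretize_sketch(w, wgt_exp, shift_bits=1, lossy_shift=True):
    # shift_bits and wgt_exp mirror SynapseFmt.shift_bits and wgtExp.
    s = shift_bits
    m = 2**(8 - s) - 1
    w = np.round(np.asarray(w) / 2.**s).clip(-m, m).astype(np.int32)
    s2 = s + wgt_exp  # assumed >= -6 here so the shifts below stay valid
    if lossy_shift:
        # Round before the right shift so small weights are not floored away.
        w = (np.round(w * 2.**s2) / 2**s2).clip(-m, m).astype(np.int32)
        w = w << s2 if s2 >= 0 else w >> -s2  # the on-chip two-part shift...
        w = w << 6                            # ...loses low bits when s2 < 0
    else:
        w = w << (6 + s2)  # single shift, no information lost
    return w

print(discretize_sketch(np.array([100., -30.]), wgt_exp=4))  # [102400 -30720]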