Match learning between Nengo/emulator/chip
This commit ensures that both the emulator and the chip match Nengo
in terms of basic PES learning. The emulator now models learning
much more accurately, which allows us to better map PES learning
parameters (i.e., learning rate, synapse decay) to the chip.
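
A minimal sketch of that mapping, based on the formulas introduced
in `build_connection` below (the concrete values are illustrative
assumptions, not defaults computed by this commit):

    import numpy as np

    dt = 0.001               # simulator timestep
    n_neurons = 100          # neurons in the pre ensemble (assumed)
    pes_error_scale = 100.   # errors are scaled by this, then clipped to [-127, 127]

    # Nengo's builder scales the PES learning rate by dt / n_neurons;
    # the extra division accounts for the integer error scaling.
    learning_rate = 1e-4 * dt / n_neurons / pes_error_scale

    # Pre-synapse decay in timesteps (assuming the Lowpass(0.005) default),
    # and a trace magnitude chosen so that the pre-trace amplitude is
    # independent of that tau.
    tracing_tau = 0.005 / dt
    tracing_mag = -np.expm1(-1. / tracing_tau) / dt**2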

This commit also makes several changes to learning tests:

- The PES learning test has been simplified and explicitly
  compares PES learning on Nengo Loihi with core Nengo.
- `test_multiple_pes` was fixed. It was not passing when using
  Nengo as the simulator because it had been designed with the old
  Loihi learning rates in mind; it needed updating now that Loihi
  learning is closer to Nengo learning.
- Several test tolerances have been adjusted.
  Note that some tolerances were changed for the chip; in general,
  the chip seems to learn slightly faster than the emulator
  (I'm not quite sure why), but this difference seems to be less
  apparent for more dimensions/neurons.

It also adds several new tests:

- Add test for PES error clipping.
- Add test for learning overflow.
- Add test for PES learning with 'simreal' emulator.
- Add test for trace increment clip warning.
- Add test for discretize_weights.lossy_shift.
- Add test for learning trace dropping.

Co-authored-by: Daniel Rasmussen <[email protected]>
2 people authored and tbekolay committed Jan 16, 2019
1 parent 60f0392 commit 64d0ce1
Showing 9 changed files with 516 additions and 111 deletions.
7 changes: 7 additions & 0 deletions CHANGES.rst
@@ -22,6 +22,13 @@ Release history
0.5.0 (unreleased)
==================

**Changed**

- PES learning in Nengo Loihi more closely matches learning in core Nengo.
(`#139 <https://github.com/nengo/nengo-loihi/pull/139>`__)
- Learning in the emulator more closely matches learning on hardware.
(`#139 <https://github.com/nengo/nengo-loihi/pull/139>`__)

**Fixed**

- We integrate current (U) and voltage (V) more accurately now by accounting
2 changes: 1 addition & 1 deletion nengo_loihi/allocators.py
@@ -63,7 +63,7 @@ def core_stdp_pre_cfgs(core):
profile_idxs = {}
for synapses in core.synapses:
if synapses.tracing:
mag_int, mag_frac = tracing_mag_int_frac(synapses)
mag_int, mag_frac = tracing_mag_int_frac(synapses.tracing_mag)
tracecfg = TraceCfg(
tau=synapses.tracing_tau,
spikeLevelInt=mag_int,
56 changes: 43 additions & 13 deletions nengo_loihi/builder.py
@@ -10,7 +10,7 @@
from nengo.dists import Distribution, get_samples
from nengo.connection import LearningRule
from nengo.ensemble import Neurons
from nengo.exceptions import BuildError
from nengo.exceptions import BuildError, ValidationError
from nengo.solvers import NoSolver, Solver
from nengo.utils.builder import default_n_eval_points
import nengo.utils.numpy as npext
@@ -110,6 +110,13 @@ def __init__(self, dt=0.001, label=None, builder=None):
# limit for clipping intercepts, to avoid neurons with high gains
self.intercept_limit = 0.95

# scaling for PES errors, before rounding and clipping to -127..127
self.pes_error_scale = 100.

# learning weight exponent for PES (controls the maximum weight
# magnitude/weight resolution)
self.pes_wgt_exp = 4

# Will be provided by Simulator
self.chip2host_params = {}

@@ -657,18 +664,41 @@ def build_connection(model, conn):
model.objs[conn]['decode_axons'] = dec_ax0

if conn.learning_rule_type is not None:
if isinstance(conn.learning_rule_type, nengo.PES):
pes_learn_rate = conn.learning_rule_type.learning_rate
# scale learning rates to roughly match Nengo
# 1e-4 is the Nengo core default learning rate
pes_learn_rate *= 4 / 1e-4
assert isinstance(conn.learning_rule_type.pre_synapse,
nengo.synapses.Lowpass)
pes_pre_syn = conn.learning_rule_type.pre_synapse.tau
# scale pre_syn.tau from s to ms
pes_pre_syn *= 1e3
dec_syn.set_learning(tracing_tau=pes_pre_syn,
tracing_mag=pes_learn_rate)
rule_type = conn.learning_rule_type
if isinstance(rule_type, nengo.PES):
if not isinstance(rule_type.pre_synapse,
nengo.synapses.Lowpass):
raise ValidationError(
"Loihi only supports `Lowpass` pre-synapses for "
"learning rules", attr='pre_synapse', obj=rule_type)

tracing_tau = rule_type.pre_synapse.tau / model.dt

# Nengo builder scales PES learning rate by `dt / n_neurons`
n_neurons = (conn.pre_obj.n_neurons
if isinstance(conn.pre_obj, Ensemble)
else conn.pre_obj.size_in)
learning_rate = rule_type.learning_rate * model.dt / n_neurons

# Account for scaling to put integer error in range [-127, 127]
learning_rate /= model.pes_error_scale

# Tracing mag set so that the magnitude of the pre trace
# is independent of the pre tau. `dt` factor accounts for
# Nengo's `dt` spike scaling. Where is the second `dt` from?
# Maybe the fact that post interneurons have `vth = 1/dt`?
tracing_mag = -np.expm1(-1. / tracing_tau) / model.dt**2

# learning weight exponent controls the maximum weight
# magnitude/weight resolution
wgt_exp = model.pes_wgt_exp

dec_syn.set_learning(
learning_rate=learning_rate,
tracing_mag=tracing_mag,
tracing_tau=tracing_tau,
wgt_exp=wgt_exp,
)
else:
raise NotImplementedError()

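With the illustrative values from the sketch in the commit message
(1e-4 learning rate, 100 pre neurons, dt = 0.001, a Lowpass(0.005)
pre-synapse, and pes_error_scale = 100), the `set_learning` call above
works out to roughly:

    dec_syn.set_learning(
        learning_rate=1e-11,  # 1e-4 * 0.001 / 100 / 100
        tracing_mag=1.81e5,   # -expm1(-1 / 5) / 0.001**2, approximately
        tracing_tau=5.0,      # 0.005 / 0.001
        wgt_exp=4,            # model.pes_wgt_exp default
    )
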
8 changes: 4 additions & 4 deletions nengo_loihi/config.py
@@ -25,10 +25,10 @@ def add_params(network):
"""
config = network.config

cfg = config[nengo.Ensemble]
if 'on_chip' not in cfg._extra_params:
cfg.set_param("on_chip",
Parameter('on_chip', default=None, optional=True))
ens_cfg = config[nengo.Ensemble]
if 'on_chip' not in ens_cfg._extra_params:
ens_cfg.set_param("on_chip",
Parameter('on_chip', default=None, optional=True))


def set_defaults():
94 changes: 75 additions & 19 deletions nengo_loihi/loihi_api.py
@@ -22,6 +22,22 @@
Q_BITS = 21 # number of bits for synapse accumulator
U_BITS = 23 # number of bits for cx input (u)

LEARN_BITS = 15 # number of bits in learning accumulator (not incl. sign)
LEARN_FRAC = 7 # extra least-significant bits added to weights for learning


def learn_overflow_bits(n_factors):
"""Compute number of bits with which learning will overflow.
Parameters
----------
n_factors : int
The number of learning factors (pre/post terms in the learning rule).
"""
factor_bits = 7 # number of bits per factor
mantissa_bits = 3 # number of bits for learning rate mantissa
return factor_bits*n_factors + mantissa_bits - LEARN_BITS
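
For example, a rule with two learning factors overflows the learning
accumulator by 7 * 2 + 3 - 15 = 2 bits:

    learn_overflow_bits(n_factors=2)  # == 2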


def overflow_signed(x, bits=7, out=None):
"""Compute overflow on an array of signed integers.
@@ -86,16 +102,10 @@ def bias_to_manexp(bias):
return man, exp


def tracing_mag_int_frac(synapses):
mag = synapses.tracing_mag
mag = mag / (synapses.size() / 100)

def tracing_mag_int_frac(mag):
"""Split trace magnitude into integer and fractional components for chip"""
mag_int = int(mag)
# TODO: how does mag_frac actually work???
# It's the x in x/128, I believe
mag_frac = int(128 * (mag - mag_int))
# mag_frac = min(int(round(1./mag_frac)), 128)

return mag_int, mag_frac
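
For example, a trace magnitude of 2.5 splits into an integer part of 2
and a fractional part of 64 (i.e., 0.5 in units of 1/128):

    tracing_mag_int_frac(2.5)  # == (2, 64)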


@@ -157,6 +167,23 @@ def shift(x, s, **kwargs):
return np.left_shift(x, s, **kwargs)


def scale_pes_errors(error, scale=1.):
"""Scale PES errors based on a scaling factor, round and clip."""
error = scale * error
error = np.round(error).astype(np.int32)
q = error > 127
if np.any(q):
warnings.warn("Max PES error (%0.2e) greater than chip max (%0.2e). "
"Clipping." % (error.max() / scale, 127. / scale))
error[q] = 127
q = error < -127
if np.any(q):
warnings.warn("Min PES error (%0.2e) less than chip min (%0.2e). "
"Clipping." % (error.min() / scale, -127. / scale))
error[q] = -127
return error
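
For example (with numpy imported as np, as in this module), a scale of
100 turns an error of 0.5 into the integer 50, while an error of 2.0
exceeds the representable range and is clipped to 127 with a warning:

    scale_pes_errors(np.array([0.5, 2.0]), scale=100.)  # -> [50, 127] (int32)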


class CxSlice(object):
def __init__(self, board_idx, chip_idx, core_idx, cx_i0, cx_i1):
self.board_idx = board_idx
@@ -556,6 +583,11 @@ def realIdxBits(self):
def isMixed(self):
return self.fanoutType == 1

@property
def shift_bits(self):
"""Number of bits the -256..255 weight is right-shifted by."""
return 8 - self.realWgtBits + self.isMixed

def bits_per_axon(self, n_weights):
"""For an axon with n weights, compute the weight memory bits used"""
bits_per_weight = self.realWgtBits + self.dlyBits + self.tagBits
@@ -594,24 +626,48 @@ def validate(self, core=None):
assert 0 <= self.idxBits < 8
assert 1 <= self.fanoutType < 4

def discretize_weights(self, w, dtype=np.int32):
s = 8 - self.realWgtBits + self.isMixed
def discretize_weights(
self, w, dtype=np.int32, lossy_shift=True, check_result=True):
"""Takes weights and returns their quantized values with wgtExp.
The actual weight to be put on the chip is this returned value
divided by the ``scale`` attribute.
Parameters
----------
w : float ndarray
Weights to be discretized, in the range -255 to 255.
dtype : np.dtype, optional (Default: np.int32)
Data type for discretized weights.
lossy_shift : bool, optional (Default: True)
Whether to mimic the two-part weight shift that currently happens
on the chip, which can lose information for small wgtExp.
check_result : bool, optional (Default: True)
Whether to check that the discretized weights fall in
the valid range for weights on the chip (-256 to 255).
"""
s = self.shift_bits
m = 2**(8 - s) - 1

w = np.round(w / 2.**s).clip(-m, m).astype(dtype)
s2 = s + self.wgtExp
if s2 < 0:
warnings.warn("Lost %d extra bits in weight rounding" % (-s2,))

# round before `s2` right shift, since just shifting would floor
# everything resulting in weights biased towards being smaller
w = (np.round(w * 2.**s2) / 2**s2).clip(-m, m).astype(dtype)
if lossy_shift:
if s2 < 0:
warnings.warn("Lost %d extra bits in weight rounding" % (-s2,))

# Round before `s2` right shift. Just shifting would floor
# everything resulting in weights biased towards being smaller.
w = (np.round(w * 2.**s2) / 2**s2).clip(-m, m).astype(dtype)

shift(w, s2, out=w)
np.left_shift(w, 6, out=w)
shift(w, s2, out=w)
np.left_shift(w, 6, out=w)
else:
shift(w, 6 + s2, out=w)

ws = w // self.scale
assert np.all(ws <= 255) and np.all(ws >= -256)
if check_result:
ws = w // self.scale
assert np.all(ws <= 255) and np.all(ws >= -256)

return w
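
A toy illustration of what the two-part (lossy) shift can drop when
`s2` is negative, compared with the single combined shift (plain numpy,
outside the class; note the real method also rounds before the right
shift to reduce this bias):

    w, s2 = np.array([3]), -2
    lossy = np.left_shift(np.right_shift(w, -s2), 6)  # 3 >> 2 == 0, then 0 << 6 == 0
    exact = np.left_shift(w, 6 + s2)                  # 3 << 4 == 48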

