From 64d0ce1f4061839aa274680362f599cef70262dd Mon Sep 17 00:00:00 2001 From: Eric Hunsberger Date: Fri, 28 Sep 2018 10:22:16 -0300 Subject: [PATCH] Match learning between Nengo/emulator/chip This commit ensures that both the emulator and chip match Nengo in terms of basic PES learning. The emulator now does learning much more accurately, which allows us to better map PES learning parameters (i.e., learning rate, synapse decay) to the chip. This commit also makes several changes to learning tests: - The PES learning test has been simplified and explicitly compares PES learning on Nengo Loihi with core Nengo. - `test_multiple_pes` was fixed. It was not passing when using Nengo as the simulator, so it had been designed with the old Loihi learning rates in mind and needed fixing now that Loihi learning is closer to Nengo learning. - Several test tolerances have been adjusted. Note that some tolerances were changed for the chip; in general, the chip seems to learn slightly faster than the emulator (I'm not quite sure why), but this difference seems to be less apparent for more dimensions/neurons. And adds additional tests: - Add test for PES error clipping. - Add test for learning overflow. - Add test for PES learning with 'simreal' emulator. - Add test for trace increment clip warning. - Add test for discretize_weights.lossy_shift. - Add test for learning trace dropping Co-authored-by: Daniel Rasmussen --- CHANGES.rst | 7 + nengo_loihi/allocators.py | 2 +- nengo_loihi/builder.py | 56 +++++-- nengo_loihi/config.py | 8 +- nengo_loihi/loihi_api.py | 94 ++++++++--- nengo_loihi/loihi_cx.py | 186 +++++++++++++++++---- nengo_loihi/loihi_interface.py | 13 +- nengo_loihi/tests/test_learning.py | 243 +++++++++++++++++++++++----- nengo_loihi/tests/test_loihi_api.py | 18 ++- 9 files changed, 516 insertions(+), 111 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index a98afb7cb..620e9429a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -22,6 +22,13 @@ Release history 0.5.0 (unreleased) ================== +**Changed** + +- PES learning in Nengo Loihi more closely matches learning in core Nengo. + (`#139 `__) +- Learning in the emulator more closely matches learning on hardware. 
+ (`#139 `__) + **Fixed** - We integrate current (U) and voltage (V) more accurately now by accounting diff --git a/nengo_loihi/allocators.py b/nengo_loihi/allocators.py index 0ab0b919a..40c7bbbd1 100644 --- a/nengo_loihi/allocators.py +++ b/nengo_loihi/allocators.py @@ -63,7 +63,7 @@ def core_stdp_pre_cfgs(core): profile_idxs = {} for synapses in core.synapses: if synapses.tracing: - mag_int, mag_frac = tracing_mag_int_frac(synapses) + mag_int, mag_frac = tracing_mag_int_frac(synapses.tracing_mag) tracecfg = TraceCfg( tau=synapses.tracing_tau, spikeLevelInt=mag_int, diff --git a/nengo_loihi/builder.py b/nengo_loihi/builder.py index 00ac601b2..b011bf469 100644 --- a/nengo_loihi/builder.py +++ b/nengo_loihi/builder.py @@ -10,7 +10,7 @@ from nengo.dists import Distribution, get_samples from nengo.connection import LearningRule from nengo.ensemble import Neurons -from nengo.exceptions import BuildError +from nengo.exceptions import BuildError, ValidationError from nengo.solvers import NoSolver, Solver from nengo.utils.builder import default_n_eval_points import nengo.utils.numpy as npext @@ -110,6 +110,13 @@ def __init__(self, dt=0.001, label=None, builder=None): # limit for clipping intercepts, to avoid neurons with high gains self.intercept_limit = 0.95 + # scaling for PES errors, before rounding and clipping to -127..127 + self.pes_error_scale = 100. + + # learning weight exponent for PES (controls the maximum weight + # magnitude/weight resolution) + self.pes_wgt_exp = 4 + # Will be provided by Simulator self.chip2host_params = {} @@ -657,18 +664,41 @@ def build_connection(model, conn): model.objs[conn]['decode_axons'] = dec_ax0 if conn.learning_rule_type is not None: - if isinstance(conn.learning_rule_type, nengo.PES): - pes_learn_rate = conn.learning_rule_type.learning_rate - # scale learning rates to roughly match Nengo - # 1e-4 is the Nengo core default learning rate - pes_learn_rate *= 4 / 1e-4 - assert isinstance(conn.learning_rule_type.pre_synapse, - nengo.synapses.Lowpass) - pes_pre_syn = conn.learning_rule_type.pre_synapse.tau - # scale pre_syn.tau from s to ms - pes_pre_syn *= 1e3 - dec_syn.set_learning(tracing_tau=pes_pre_syn, - tracing_mag=pes_learn_rate) + rule_type = conn.learning_rule_type + if isinstance(rule_type, nengo.PES): + if not isinstance(rule_type.pre_synapse, + nengo.synapses.Lowpass): + raise ValidationError( + "Loihi only supports `Lowpass` pre-synapses for " + "learning rules", attr='pre_synapse', obj=rule_type) + + tracing_tau = rule_type.pre_synapse.tau / model.dt + + # Nengo builder scales PES learning rate by `dt / n_neurons` + n_neurons = (conn.pre_obj.n_neurons + if isinstance(conn.pre_obj, Ensemble) + else conn.pre_obj.size_in) + learning_rate = rule_type.learning_rate * model.dt / n_neurons + + # Account for scaling to put integer error in range [-127, 127] + learning_rate /= model.pes_error_scale + + # Tracing mag set so that the magnitude of the pre trace + # is independent of the pre tau. `dt` factor accounts for + # Nengo's `dt` spike scaling. Where is the second `dt` from? + # Maybe the fact that post interneurons have `vth = 1/dt`? + tracing_mag = -np.expm1(-1. 
/ tracing_tau) / model.dt**2 + + # learning weight exponent controls the maximum weight + # magnitude/weight resolution + wgt_exp = model.pes_wgt_exp + + dec_syn.set_learning( + learning_rate=learning_rate, + tracing_mag=tracing_mag, + tracing_tau=tracing_tau, + wgt_exp=wgt_exp, + ) else: raise NotImplementedError() diff --git a/nengo_loihi/config.py b/nengo_loihi/config.py index 1b246d5df..6e15fa798 100644 --- a/nengo_loihi/config.py +++ b/nengo_loihi/config.py @@ -25,10 +25,10 @@ def add_params(network): """ config = network.config - cfg = config[nengo.Ensemble] - if 'on_chip' not in cfg._extra_params: - cfg.set_param("on_chip", - Parameter('on_chip', default=None, optional=True)) + ens_cfg = config[nengo.Ensemble] + if 'on_chip' not in ens_cfg._extra_params: + ens_cfg.set_param("on_chip", + Parameter('on_chip', default=None, optional=True)) def set_defaults(): diff --git a/nengo_loihi/loihi_api.py b/nengo_loihi/loihi_api.py index d9027249e..494b1fca1 100644 --- a/nengo_loihi/loihi_api.py +++ b/nengo_loihi/loihi_api.py @@ -22,6 +22,22 @@ Q_BITS = 21 # number of bits for synapse accumulator U_BITS = 23 # number of bits for cx input (u) +LEARN_BITS = 15 # number of bits in learning accumulator (not incl. sign) +LEARN_FRAC = 7 # extra least-significant bits added to weights for learning + + +def learn_overflow_bits(n_factors): + """Compute number of bits with which learning will overflow. + + Parameters + ---------- + n_factors : int + The number of learning factors (pre/post terms in the learning rule). + """ + factor_bits = 7 # number of bits per factor + mantissa_bits = 3 # number of bits for learning rate mantissa + return factor_bits*n_factors + mantissa_bits - LEARN_BITS + def overflow_signed(x, bits=7, out=None): """Compute overflow on an array of signed integers. @@ -86,16 +102,10 @@ def bias_to_manexp(bias): return man, exp -def tracing_mag_int_frac(synapses): - mag = synapses.tracing_mag - mag = mag / (synapses.size() / 100) - +def tracing_mag_int_frac(mag): + """Split trace magnitude into integer and fractional components for chip""" mag_int = int(mag) - # TODO: how does mag_frac actually work??? - # It's the x in x/128, I believe mag_frac = int(128 * (mag - mag_int)) - # mag_frac = min(int(round(1./mag_frac)), 128) - return mag_int, mag_frac @@ -157,6 +167,23 @@ def shift(x, s, **kwargs): return np.left_shift(x, s, **kwargs) +def scale_pes_errors(error, scale=1.): + """Scale PES errors based on a scaling factor, round and clip.""" + error = scale * error + error = np.round(error).astype(np.int32) + q = error > 127 + if np.any(q): + warnings.warn("Max PES error (%0.2e) greater than chip max (%0.2e). " + "Clipping." % (error.max() / scale, 127. / scale)) + error[q] = 127 + q = error < -127 + if np.any(q): + warnings.warn("Min PES error (%0.2e) less than chip min (%0.2e). " + "Clipping." % (error.min() / scale, -127. 
/ scale)) + error[q] = -127 + return error + + class CxSlice(object): def __init__(self, board_idx, chip_idx, core_idx, cx_i0, cx_i1): self.board_idx = board_idx @@ -556,6 +583,11 @@ def realIdxBits(self): def isMixed(self): return self.fanoutType == 1 + @property + def shift_bits(self): + """Number of bits the -256..255 weight is right-shifted by.""" + return 8 - self.realWgtBits + self.isMixed + def bits_per_axon(self, n_weights): """For an axon with n weights, compute the weight memory bits used""" bits_per_weight = self.realWgtBits + self.dlyBits + self.tagBits @@ -594,24 +626,48 @@ def validate(self, core=None): assert 0 <= self.idxBits < 8 assert 1 <= self.fanoutType < 4 - def discretize_weights(self, w, dtype=np.int32): - s = 8 - self.realWgtBits + self.isMixed + def discretize_weights( + self, w, dtype=np.int32, lossy_shift=True, check_result=True): + """Takes weights and returns their quantized values with wgtExp. + + The actual weight to be put on the chip is this returned value + divided by the ``scale`` attribute. + + Parameters + ---------- + w : float ndarray + Weights to be discretized, in the range -255 to 255. + dtype : np.dtype, optional (Default: np.int32) + Data type for discretized weights. + lossy_shift : bool, optional (Default: True) + Whether to mimic the two-part weight shift that currently happens + on the chip, which can lose information for small wgtExp. + check_results : bool, optional (Default: True) + Whether to check that the discretized weights fall in + the valid range for weights on the chip (-256 to 255). + """ + s = self.shift_bits m = 2**(8 - s) - 1 w = np.round(w / 2.**s).clip(-m, m).astype(dtype) s2 = s + self.wgtExp - if s2 < 0: - warnings.warn("Lost %d extra bits in weight rounding" % (-s2,)) - # round before `s2` right shift, since just shifting would floor - # everything resulting in weights biased towards being smaller - w = (np.round(w * 2.**s2) / 2**s2).clip(-m, m).astype(dtype) + if lossy_shift: + if s2 < 0: + warnings.warn("Lost %d extra bits in weight rounding" % (-s2,)) + + # Round before `s2` right shift. Just shifting would floor + # everything resulting in weights biased towards being smaller. + w = (np.round(w * 2.**s2) / 2**s2).clip(-m, m).astype(dtype) - shift(w, s2, out=w) - np.left_shift(w, 6, out=w) + shift(w, s2, out=w) + np.left_shift(w, 6, out=w) + else: + shift(w, 6 + s2, out=w) - ws = w // self.scale - assert np.all(ws <= 255) and np.all(ws >= -256) + if check_result: + ws = w // self.scale + assert np.all(ws <= 255) and np.all(ws >= -256) return w diff --git a/nengo_loihi/loihi_cx.py b/nengo_loihi/loihi_cx.py index 4875c7825..087db2a73 100644 --- a/nengo_loihi/loihi_cx.py +++ b/nengo_loihi/loihi_cx.py @@ -14,7 +14,11 @@ bias_to_manexp, decay_int, decay_magnitude, + LEARN_FRAC, + learn_overflow_bits, overflow_signed, + scale_pes_errors, + shift, SynapseFmt, tracing_mag_int_frac, Q_BITS, U_BITS, @@ -265,7 +269,7 @@ def discretize(target, value): w_maxs = [s.max_abs_weight() for s in self.synapses] w_max = max(w_maxs) if len(w_maxs) > 0 else 0 b_max = np.abs(self.bias).max() - wgtExp = -7 + wgtExp = 0 if w_max > 1e-8: w_scale = (255. 
/ w_max) @@ -311,7 +315,10 @@ def discretize(target, value): discretize(self.bias, bias_man * 2**bias_exp) for i, synapse in enumerate(self.synapses): - if w_maxs[i] > 1e-16: + if synapse.tracing: + wgtExp2 = synapse.learning_wgt_exp + dWgtExp = wgtExp - wgtExp2 + elif w_maxs[i] > 1e-16: dWgtExp = int(np.floor(np.log2(w_max / w_maxs[i]))) assert dWgtExp >= 0 wgtExp2 = max(wgtExp - dWgtExp, -6) @@ -323,9 +330,46 @@ def discretize(target, value): ws = w_scale[idxs] if is_iterable(w_scale) else w_scale discretize(w, synapse.synapse_fmt.discretize_weights( w * ws * 2**dWgtExp)) - # TODO: scale this properly, hardcoded for now + + # discretize learning if synapse.tracing: - synapse.synapse_fmt.wgtExp = 4 + synapse.tracing_tau = int(np.round(synapse.tracing_tau)) + + if is_iterable(w_scale): + assert np.all(w_scale == w_scale[0]) + w_scale_i = w_scale[0] if is_iterable(w_scale) else w_scale + + # incorporate weight scale and difference in weight exponents + # to learning rate, since these affect speed at which we learn + ws = w_scale_i * 2**dWgtExp + synapse.learning_rate *= ws + + # Loihi down-scales learning factors based on the number of + # overflow bits. Increasing learning rate maintains true rate. + synapse.learning_rate *= 2**learn_overflow_bits(2) + + # TODO: Currently, Loihi learning rate fixed at 2**-7. + # We should explore adjusting it for better performance. + lscale = 2**-7 / synapse.learning_rate + synapse.learning_rate *= lscale + synapse.tracing_mag /= lscale + + # discretize learning rate into mantissa and exponent + lr_exp = int(np.floor(np.log2(synapse.learning_rate))) + lr_int = int(np.round(synapse.learning_rate * 2**(-lr_exp))) + synapse.learning_rate = lr_int * 2**lr_exp + synapse._lr_int = lr_int + synapse._lr_exp = lr_exp + assert lr_exp >= -7 + + # discretize tracing mag into integer and fractional components + mag_int, mag_frac = tracing_mag_int_frac(synapse.tracing_mag) + if mag_int > 127: + warnings.warn("Trace increment exceeds upper limit " + "(learning rate may be too large)") + mag_int = 127 + mag_frac = 127 + synapse.tracing_mag = mag_int + mag_frac / 128. # --- noise assert (v_scale[0] == v_scale).all() @@ -395,12 +439,16 @@ class CxSynapses(object): can have a different number of target compartments. indices : (population, axon, compartment) ndarray The synapse indices. + learning_rate : float + The learning rate. + learning_wgt_exp : int + The weight exponent used on this connection if learning is enabled. tracing : bool Whether synaptic tracing is enabled for these synapses. - tracing_tau : float - The tracing time constant. + tracing_tau : int + Decay time constant for the learning trace, in timesteps (not seconds). tracing_mag : float - The tracing increment magnitude. + Magnitude by which the learning trace is increased for each spike. """ def __init__(self, n_axons, label=None): self.n_axons = n_axons @@ -411,6 +459,9 @@ def __init__(self, n_axons, label=None): self.indices = None self.axon_cx_bases = None self.axon_to_weight_map = None + + self.learning_rate = 1. 
+ self.learning_wgt_exp = None self.tracing = False self.tracing_tau = None self.tracing_mag = None @@ -541,16 +592,22 @@ def set_population_weights( numSynapses=63, wgtBits=7) - def set_learning(self, tracing_tau=2, tracing_mag=1.0): + def set_learning( + self, learning_rate=1., tracing_tau=2, tracing_mag=1.0, wgt_exp=4): assert tracing_tau == int(tracing_tau), "tracing_tau must be integer" + self.tracing = True self.tracing_tau = int(tracing_tau) self.tracing_mag = tracing_mag self.format(learningCfg=1, stdpProfile=0) # ^ stdpProfile hard-coded for now (see loihi_interface) - mag_int, _ = tracing_mag_int_frac(self) - assert int(mag_int) < 2**7 + self.train_epoch = 2 + self.learn_epoch_k = 1 + self.learn_epoch = self.train_epoch * 2**self.learn_epoch_k + + self.learning_rate = learning_rate * self.learn_epoch + self.learning_wgt_exp = wgt_exp def format(self, **kwargs): if self.synapse_fmt is None: @@ -882,9 +939,64 @@ def overflow(x, bits, name=None): # --- allocate synapse memory self.axons_in = {synapses: [] for group in self.groups for synapses in group.synapses} - self.z = {synapses: np.zeros(synapses.n_axons, dtype=np.float64) - for group in self.groups for synapses in group.synapses - if synapses.tracing} # synapse traces + + learning_synapses = [ + synapses for group in self.groups + for synapses in group.synapses if synapses.tracing] + self.z = {synapses: np.zeros(synapses.n_axons, dtype=group_dtype) + for synapses in learning_synapses} # synapse traces + self.z_spikes = {synapses: set() for synapses in learning_synapses} + # Currently, PES learning only happens on Nodes, where we have + # pairs of on/off neurons. Therefore, the number of error dimensions + # is half the number of neurons. + self.pes_errors = {synapses: np.zeros(group.n//2, dtype=group_dtype) + for synapses in learning_synapses} + self.pes_error_scale = getattr(model, 'pes_error_scale', 1.) 
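+ # The two branches below mirror one another: with an integer dtype the
+ # emulator follows the chip's fixed-point learning pipeline (stochastic
+ # rounding plus shift-based weight updates), while the float branch applies
+ # the equivalent update in ordinary floating point (e.g. when the model is
+ # run without discretization, as in the 'simreal' tests). For illustration,
+ # stochastic_round(2.3) returns 3 with probability 0.3 and 2 otherwise.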
+ + if group_dtype == np.int32: + def stochastic_round(x, dtype=group_dtype, rng=self.rng, + clip=None, name="values"): + x_sign = np.sign(x).astype(dtype) + x_frac, x_int = np.modf(np.abs(x)) + p = rng.rand(*x.shape) + y = x_int.astype(dtype) + (x_frac > p) + if clip is not None: + q = y > clip + if np.any(q): + warnings.warn("Clipping %s" % name) + y[q] = clip + return x_sign * y + + def trace_round(x, dtype=group_dtype, rng=self.rng): + return stochastic_round( + x, dtype=dtype, rng=rng, clip=127, name="synapse trace") + + def weight_update(synapses, delta_ws): + synapse_fmt = synapses.synapse_fmt + wgt_exp = synapse_fmt.realWgtExp + shift_bits = synapse_fmt.shift_bits + overflow = learn_overflow_bits(n_factors=2) + for w, delta_w in zip(synapses.weights, delta_ws): + product = shift( + delta_w * synapses._lr_int, + LEARN_FRAC + synapses._lr_exp - overflow) + learn_w = shift(w, LEARN_FRAC - wgt_exp) + product + learn_w[:] = stochastic_round( + learn_w * 2**(-LEARN_FRAC - shift_bits), + clip=2**(8 - shift_bits) - 1, + name="learning weights") + w[:] = np.left_shift(learn_w, wgt_exp + shift_bits) + + elif group_dtype == np.float32: + def trace_round(x, dtype=group_dtype): + return x # no rounding + + def weight_update(synapses, delta_ws): + for w, delta_w in zip(synapses.weights, delta_ws): + w += synapses.learning_rate * delta_w + + self.trace_round = trace_round + self.weight_update = weight_update # --- noise enableNoise = np.hstack([ @@ -966,15 +1078,15 @@ def host2chip(self, spikes, errors): for cx_spike_input, t, spike_idxs in spikes: cx_spike_input.add_spikes(t, spike_idxs) - learning_rate = 50 # This is set to match hardware - for synapses, t, e in errors: - z = self.z[synapses] - x = np.hstack([-e, e]) + # TODO: these are sent every timestep, but learning only happens every + # `tepoch * 2**learn_k` timesteps (see CxSynapses). Need to average. 
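+ # With the defaults set in CxSynapses.set_learning (train_epoch = 2,
+ # learn_epoch_k = 1), learn_epoch = 2 * 2**1 = 4, so the weight update in
+ # step() only runs every fourth timestep even though errors arrive here
+ # every timestep.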
+ for pes_errors in self.pes_errors.values(): + pes_errors[:] = 0 - delta_w = np.outer(z, x) * learning_rate - - for i, w in enumerate(synapses.weights): - w += delta_w[i].astype('int32') + for synapses, t, e in errors: + pes_errors = self.pes_errors[synapses] + assert pes_errors.shape == e.shape + pes_errors += scale_pes_errors(e, scale=self.pes_error_scale) def step(self): # noqa: C901 """Advance the simulation by 1 step (``dt`` seconds).""" @@ -1020,16 +1132,30 @@ def step(self): # noqa: C901 spike.axon_id, atom=spike.atom) qb[0, cx_base + indices] += weights - if synapses.tracing: - z = self.z[synapses] - tau = synapses.tracing_tau - mag = synapses.tracing_mag - - decay = np.exp(-1.0 / tau) - z *= decay - + # --- learning trace + z_spikes = self.z_spikes.get(synapses, None) + if z_spikes is not None: for spike in self.axons_in[synapses]: - z[spike.axon_id] += mag + if spike.axon_id in z_spikes: + self.error("Synaptic trace spikes lost") + z_spikes.add(spike.axon_id) + + z = self.z.get(synapses, None) + if z is not None and self.t % synapses.train_epoch == 0: + tau = synapses.tracing_tau + decay = np.exp(-synapses.train_epoch / tau) + zi = decay*z + zi[list(z_spikes)] += synapses.tracing_mag + z[:] = self.trace_round(zi) + z_spikes.clear() + + # --- learning update + pes_e = self.pes_errors.get(synapses, None) + if pes_e is not None and self.t % synapses.learn_epoch == 0: + assert z is not None + x = np.hstack([-pes_e, pes_e]) + delta_w = np.outer(z, x) + self.weight_update(synapses, delta_w) # --- updates q0 = self.q[0, :] diff --git a/nengo_loihi/loihi_interface.py b/nengo_loihi/loihi_interface.py index d3a8dafd0..0d6c867c2 100644 --- a/nengo_loihi/loihi_interface.py +++ b/nengo_loihi/loihi_interface.py @@ -32,7 +32,12 @@ def no_nxsdk(*args, **kwargs): import nengo_loihi.loihi_cx as loihi_cx from nengo_loihi.allocators import one_to_one_allocator from nengo_loihi.loihi_api import ( - bias_to_manexp, CX_PROFILES_MAX, SpikeInput, VTH_PROFILES_MAX) + bias_to_manexp, + CX_PROFILES_MAX, + scale_pes_errors, + SpikeInput, + VTH_PROFILES_MAX, +) from nengo_loihi.loihi_cx import CxGroup logger = logging.getLogger(__name__) @@ -593,6 +598,7 @@ def _iter_probes(self): def build(self, cx_model, seed=None): cx_model.validate() self.model = cx_model + self.pes_error_scale = getattr(cx_model, 'pes_error_scale', 1.) 
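+ # As in the emulator, host PES errors are multiplied by pes_error_scale,
+ # rounded, and clipped to the chip's signed 8-bit error range by
+ # scale_pes_errors before being sent down in host2chip below.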
if self.use_snips: # tag all probes as being snip-based, @@ -729,8 +735,6 @@ def host2chip(self, spikes, errors): loihi_errors = [] for synapses, t, e in errors: - x = (100 * e).astype(int) - x = np.clip(x, -100, 100, out=x) cx_group = synapses.group coreid = None for core in self.board.chips[0].cores: @@ -744,7 +748,8 @@ def host2chip(self, spikes, errors): break assert coreid is not None - loihi_errors.append([coreid, len(x)] + x.tolist()) + e = scale_pes_errors(e, scale=self.pes_error_scale) + loihi_errors.append([coreid, len(e)] + e.tolist()) if self.use_snips: return self._host2chip_snips(loihi_spikes, loihi_errors) diff --git a/nengo_loihi/tests/test_learning.py b/nengo_loihi/tests/test_learning.py index a4df2f324..b9de7a8b8 100644 --- a/nengo_loihi/tests/test_learning.py +++ b/nengo_loihi/tests/test_learning.py @@ -1,14 +1,29 @@ import nengo +from nengo.exceptions import ValidationError, SimulationError +from nengo.utils.numpy import rms import numpy as np import pytest +import nengo_loihi.builder -@pytest.mark.parametrize('n_per_dim', [120, 200]) -@pytest.mark.parametrize('dims', [1, 3]) -def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims): - scale = np.linspace(1, 0, dims + 1)[:-1] - input_fn = lambda t: np.sin(t * 2 * np.pi) * scale +def pes_network( + n_per_dim, + dims, + seed, + learning_rule_type=nengo.PES(learning_rate=1e-3), + input_scale=None, + error_scale=1., + learn_synapse=0.005, + probe_synapse=0.02, +): + if input_scale is None: + input_scale = np.linspace(1, 0, dims + 1)[:-1] + assert input_scale.size == dims + + input_fn = lambda t: np.sin(t * 2 * np.pi) * input_scale + + probes = {} with nengo.Network(seed=seed) as model: stim = nengo.Node(input_fn) @@ -19,44 +34,150 @@ def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims): conn = nengo.Connection( pre, post, function=lambda x: np.zeros(dims), - synapse=0.01, - learning_rule_type=nengo.PES(learning_rate=1e-3)) + synapse=learn_synapse, + learning_rule_type=learning_rule_type) - nengo.Connection(post, conn.learning_rule) - nengo.Connection(stim, conn.learning_rule, transform=-1) + nengo.Connection(post, conn.learning_rule, transform=error_scale) + nengo.Connection(stim, conn.learning_rule, transform=-error_scale) - p_stim = nengo.Probe(stim, synapse=0.02) - p_pre = nengo.Probe(pre, synapse=0.02) - p_post = nengo.Probe(post, synapse=0.02) + probes['stim'] = nengo.Probe(stim, synapse=probe_synapse) + probes['pre'] = nengo.Probe(pre, synapse=probe_synapse) + probes['post'] = nengo.Probe(post, synapse=probe_synapse) - with Simulator(model) as sim: - sim.run(5.0) + return model, probes + + +@pytest.mark.parametrize('n_per_dim', [120, 200]) +@pytest.mark.parametrize('dims', [1, 3]) +def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims): + tau = 0.01 + model, probes = pes_network(n_per_dim, dims, seed, learn_synapse=tau) + + simtime = 5.0 + with nengo.Simulator(model) as nengo_sim: + nengo_sim.run(simtime) + + with Simulator(model) as loihi_sim: + loihi_sim.run(simtime) + + with Simulator(model, target='simreal') as real_sim: + real_sim.run(simtime) + + t = nengo_sim.trange() + pre_tmask = t > 0.1 + post_tmask = t > simtime - 1.0 + + inter_tau = loihi_sim.model.inter_tau + y = nengo_sim.data[probes['stim']] + y_dpre = nengo.Lowpass(inter_tau).filt(y) + y_dpost = nengo.Lowpass(tau).combine(nengo.Lowpass(inter_tau)).filt(y_dpre) + y_nengo = nengo_sim.data[probes['post']] + y_loihi = loihi_sim.data[probes['post']] + y_real = real_sim.data[probes['post']] - t = 
sim.trange() plt.subplot(211) - plt.plot(t, sim.data[p_stim]) - plt.plot(t, sim.data[p_pre]) - plt.plot(t, sim.data[p_post]) - - # --- fit input_fn to output, determine magnitude - # The larger the magnitude, the closer the output is to the input - x = np.array([input_fn(tt)[0] for tt in t[t > 4]]) - y = sim.data[p_post][t > 4][:, 0] - m = np.linspace(0, 1, 21) - errors = np.abs(y - m[:, None]*x).mean(axis=1) - m_best = m[np.argmin(errors)] + plt.plot(t, y_dpost, 'k', label='target') + plt.plot(t, y_nengo, 'b', label='nengo') + plt.plot(t, y_loihi, 'g', label='loihi') + plt.plot(t, y_real, 'r:', label='real') + plt.legend() plt.subplot(212) - plt.plot(t[t > 4], x) - plt.plot(t[t > 4], y) - plt.plot(t[t > 4], m_best * x, ':') + plt.plot(t[post_tmask], y_loihi[post_tmask] - y_dpost[post_tmask], 'k') + plt.plot(t[post_tmask], y_loihi[post_tmask] - y_nengo[post_tmask], 'b') + + x_loihi = loihi_sim.data[probes['pre']] + assert allclose(x_loihi[pre_tmask], y_dpre[pre_tmask], + atol=0.1, rtol=0.05) + + assert allclose(y_loihi[post_tmask], y_dpost[post_tmask], + atol=0.1, rtol=0.05) + assert allclose(y_loihi, y_nengo, atol=0.2, rtol=0.2) + + assert allclose(y_real[post_tmask], y_dpost[post_tmask], + atol=0.1, rtol=0.05) + assert allclose(y_real, y_nengo, atol=0.2, rtol=0.2) + + +def test_pes_overflow(allclose, plt, seed, Simulator): + dims = 3 + n_per_dim = 120 + tau = 0.01 + model, probes = pes_network(n_per_dim, dims, seed, learn_synapse=tau, + input_scale=np.linspace(1, 0.7, dims)) + + simtime = 3.0 + loihi_model = nengo_loihi.builder.Model() + # set learning_wgt_exp low to create overflow in weight values + loihi_model.pes_wgt_exp = -1 + + with Simulator(model, model=loihi_model) as loihi_sim: + loihi_sim.run(simtime) - assert allclose(sim.data[p_pre][t > 0.1], - sim.data[p_stim][t > 0.1], - atol=0.15, - rtol=0.15) - assert np.min(errors) < 0.3, "Not able to fit correctly" - assert m_best > (0.3 if n_per_dim < 150 else 0.6) + t = loihi_sim.trange() + post_tmask = t > simtime - 1.0 + + inter_tau = loihi_sim.model.inter_tau + y = loihi_sim.data[probes['stim']] + y_dpre = nengo.Lowpass(inter_tau).filt(y) + y_dpost = nengo.Lowpass(tau).combine(nengo.Lowpass(inter_tau)).filt(y_dpre) + y_loihi = loihi_sim.data[probes['post']] + + plt.plot(t, y_dpost, 'k', label='target') + plt.plot(t, y_loihi, 'g', label='loihi') + + # --- fit output to scaled version of target output + z_ref0 = y_dpost[post_tmask][:, 0] + z_loihi = y_loihi[post_tmask] + scale = np.linspace(0, 1, 50) + E = np.abs(z_loihi - scale[:, None, None]*z_ref0[:, None]) + errors = E.mean(axis=1) # average over time (errors is: scales x dims) + for j in range(dims): + errors_j = errors[:, j] + i = np.argmin(errors_j) + assert errors_j[i] < 0.1, ("Learning output for dim %d did not match " + "any scaled version of the target output" + % j) + assert scale[i] > 0.4, "Learning output for dim %d is too small" % j + assert scale[i] < 0.7, ("Learning output for dim %d is too large " + "(weights or traces not clipping as expected)" + % j) + + +def test_pes_error_clip(allclose, plt, seed, Simulator): + dims = 2 + n_per_dim = 120 + tau = 0.01 + error_scale = 5. 
# scale up error signal so it clips + model, probes = pes_network( + n_per_dim, dims, seed, learn_synapse=tau, + learning_rule_type=nengo.PES(learning_rate=1e-3 / error_scale), + input_scale=np.array([1., -1.]), + error_scale=error_scale) + + simtime = 3.0 + with pytest.warns(UserWarning, match=r'.*PES error.*Clipping.'): + with Simulator(model) as loihi_sim: + loihi_sim.run(simtime) + + t = loihi_sim.trange() + post_tmask = t > simtime - 1.0 + + inter_tau = loihi_sim.model.inter_tau + y = loihi_sim.data[probes['stim']] + y_dpre = nengo.Lowpass(inter_tau).filt(y) + y_dpost = nengo.Lowpass(tau).combine(nengo.Lowpass(inter_tau)).filt(y_dpre) + y_loihi = loihi_sim.data[probes['post']] + + plt.plot(t, y_dpost, 'k', label='target') + plt.plot(t, y_loihi, 'g', label='loihi') + + # --- assert that we've learned something, but not everything + error = (rms(y_loihi[post_tmask] - y_dpost[post_tmask]) + / rms(y_dpost[post_tmask])) + assert error < 0.5 + assert error > 0.05 + # ^ error on emulator vs chip is quite different, hence large tolerances @pytest.mark.parametrize('init_function', [None, lambda x: 0]) @@ -74,19 +195,65 @@ def test_multiple_pes(init_function, allclose, plt, seed, Simulator): pre_ea.ea_ensembles[i], output[i], function=init_function, - learning_rule_type=nengo.PES(learning_rate=5e-4), + learning_rule_type=nengo.PES(learning_rate=3e-3), ) nengo.Connection(target[i], conn.learning_rule, transform=-1) nengo.Connection(output[i], conn.learning_rule) probe = nengo.Probe(output, synapse=0.1) + + simtime = 2.5 with Simulator(model) as sim: - sim.run(1.0) + sim.run(simtime) + t = sim.trange() + tmask = t > simtime * 0.85 plt.plot(t, sim.data[probe]) for target, style in zip(targets, plt.rcParams["axes.prop_cycle"]): plt.axhline(target, **style) for i, target in enumerate(targets): - assert allclose(sim.data[probe][t > 0.8, i], target, atol=0.1) + assert allclose(sim.data[probe][tmask, i], target, + atol=0.05, rtol=0.05), "Target %d not close" % i + + +def test_pes_pre_synapse_type_error(Simulator): + with nengo.Network() as model: + pre = nengo.Ensemble(10, 1) + post = nengo.Node(size_in=1) + rule_type = nengo.PES(pre_synapse=nengo.Alpha(0.005)) + conn = nengo.Connection(pre, post, learning_rule_type=rule_type) + nengo.Connection(post, conn.learning_rule) + + with pytest.raises(ValidationError): + with Simulator(model): + pass + + +def test_pes_trace_increment_clip_warning(seed, Simulator): + dims = 2 + n_per_dim = 120 + model, _ = pes_network( + n_per_dim, dims, seed, + learning_rule_type=nengo.PES(learning_rate=1e-1)) + + with pytest.warns(UserWarning, match="Trace increment exceeds upper"): + with Simulator(model): + pass + + +def test_drop_trace_spikes(Simulator, seed): + with nengo.Network(seed=seed) as net: + a = nengo.Ensemble(10, 1, gain=nengo.dists.Choice([1]), + bias=nengo.dists.Choice([2000]), + neuron_type=nengo.SpikingRectifiedLinear()) + b = nengo.Node(size_in=1) + + conn = nengo.Connection(a, b, learning_rule_type=nengo.PES(1)) + + nengo.Connection(b, conn.learning_rule) + + with Simulator(net, target="sim") as sim: + with pytest.raises(SimulationError): + sim.run(1.0) diff --git a/nengo_loihi/tests/test_loihi_api.py b/nengo_loihi/tests/test_loihi_api.py index 86065cc05..712f16126 100644 --- a/nengo_loihi/tests/test_loihi_api.py +++ b/nengo_loihi/tests/test_loihi_api.py @@ -2,8 +2,8 @@ import numpy as np import pytest -from nengo_loihi.loihi_api import overflow_signed -from nengo_loihi.loihi_api import decay_int, decay_magnitude +from nengo_loihi.loihi_api import ( + 
overflow_signed, decay_int, decay_magnitude, SynapseFmt) @pytest.mark.parametrize("b", (8, 16, 17, 23)) @@ -87,3 +87,17 @@ def empirical_decay_magnitude(decay, x0): plt.plot(relative_diff.clip(0, None)) assert np.all(relative_diff < 1e-6) + + +@pytest.mark.parametrize("lossy_shift", (True, False)) +def test_lossy_shift(lossy_shift, rng): + wgt_bits = 6 + w = rng.uniform(-100, 100, size=(10, 10)) + fmt = SynapseFmt(wgtBits=wgt_bits, wgtExp=0, fanoutType=0) + + w2 = fmt.discretize_weights(w, lossy_shift=lossy_shift) + + clipped = np.round(w / 4).clip(-2 ** wgt_bits, 2 ** wgt_bits).astype( + np.int32) + + assert np.allclose(w2, np.left_shift(clipped, 8))
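
For reference, a rough numeric sketch of the new PES parameter mapping in build_connection (not part of the patch), assuming the nengo.PES defaults (learning_rate=1e-4, pre_synapse=Lowpass(0.005)), dt=0.001, the pes_error_scale of 100 introduced here, and an illustrative 100-neuron pre ensemble:

    import numpy as np

    dt = 0.001                    # simulator timestep
    n_neurons = 100               # pre-ensemble size (illustrative)
    pes_error_scale = 100.        # Model.pes_error_scale default in this commit
    rate, pre_tau = 1e-4, 0.005   # nengo.PES defaults

    tracing_tau = pre_tau / dt                          # 5.0 timesteps
    learning_rate = rate * dt / n_neurons               # Nengo builder scaling
    learning_rate /= pes_error_scale                    # integer-error scaling
    tracing_mag = -np.expm1(-1. / tracing_tau) / dt**2  # pre-trace magnitude

    print(tracing_tau, learning_rate, tracing_mag)      # ~5.0, 1e-11, 1.8e5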
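
Similarly, a small usage sketch of the new scale_pes_errors helper (the array values are made up):

    import numpy as np
    from nengo_loihi.loihi_api import scale_pes_errors

    error = np.array([-1.5, 0.3, 1.4])            # host-side PES error signal
    scaled = scale_pes_errors(error, scale=100.)  # rounds to int32; warns and
                                                  # clips anything beyond +/-127
    # scaled -> array([-127,  30, 127], dtype=int32)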