diff --git a/CHANGES.rst b/CHANGES.rst
index a98afb7cb..620e9429a 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -22,6 +22,13 @@ Release history
 0.5.0 (unreleased)
 ==================
 
+**Changed**
+
+- PES learning in Nengo Loihi more closely matches learning in core Nengo.
+  (`#139 <https://github.com/nengo/nengo-loihi/pull/139>`__)
+- Learning in the emulator more closely matches learning on hardware.
+  (`#139 <https://github.com/nengo/nengo-loihi/pull/139>`__)
+
 **Fixed**
 
 - We integrate current (U) and voltage (V) more accurately now by accounting
diff --git a/nengo_loihi/allocators.py b/nengo_loihi/allocators.py
index 0ab0b919a..40c7bbbd1 100644
--- a/nengo_loihi/allocators.py
+++ b/nengo_loihi/allocators.py
@@ -63,7 +63,7 @@ def core_stdp_pre_cfgs(core):
     profile_idxs = {}
     for synapses in core.synapses:
         if synapses.tracing:
-            mag_int, mag_frac = tracing_mag_int_frac(synapses)
+            mag_int, mag_frac = tracing_mag_int_frac(synapses.tracing_mag)
             tracecfg = TraceCfg(
                 tau=synapses.tracing_tau,
                 spikeLevelInt=mag_int,
diff --git a/nengo_loihi/builder.py b/nengo_loihi/builder.py
index 00ac601b2..b011bf469 100644
--- a/nengo_loihi/builder.py
+++ b/nengo_loihi/builder.py
@@ -10,7 +10,7 @@ from nengo.dists import Distribution, get_samples
 from nengo.connection import LearningRule
 from nengo.ensemble import Neurons
-from nengo.exceptions import BuildError
+from nengo.exceptions import BuildError, ValidationError
 from nengo.solvers import NoSolver, Solver
 from nengo.utils.builder import default_n_eval_points
 import nengo.utils.numpy as npext
@@ -110,6 +110,13 @@ def __init__(self, dt=0.001, label=None, builder=None):
         # limit for clipping intercepts, to avoid neurons with high gains
         self.intercept_limit = 0.95
 
+        # scaling for PES errors, before rounding and clipping to -127..127
+        self.pes_error_scale = 100.
+
+        # learning weight exponent for PES (controls the maximum weight
+        # magnitude/weight resolution)
+        self.pes_wgt_exp = 4
+
         # Will be provided by Simulator
         self.chip2host_params = {}
@@ -657,18 +664,41 @@ def build_connection(model, conn):
         model.objs[conn]['decode_axons'] = dec_ax0
 
     if conn.learning_rule_type is not None:
-        if isinstance(conn.learning_rule_type, nengo.PES):
-            pes_learn_rate = conn.learning_rule_type.learning_rate
-            # scale learning rates to roughly match Nengo
-            # 1e-4 is the Nengo core default learning rate
-            pes_learn_rate *= 4 / 1e-4
-            assert isinstance(conn.learning_rule_type.pre_synapse,
-                              nengo.synapses.Lowpass)
-            pes_pre_syn = conn.learning_rule_type.pre_synapse.tau
-            # scale pre_syn.tau from s to ms
-            pes_pre_syn *= 1e3
-            dec_syn.set_learning(tracing_tau=pes_pre_syn,
-                                 tracing_mag=pes_learn_rate)
+        rule_type = conn.learning_rule_type
+        if isinstance(rule_type, nengo.PES):
+            if not isinstance(rule_type.pre_synapse,
+                              nengo.synapses.Lowpass):
+                raise ValidationError(
+                    "Loihi only supports `Lowpass` pre-synapses for "
+                    "learning rules", attr='pre_synapse', obj=rule_type)
+
+            tracing_tau = rule_type.pre_synapse.tau / model.dt
+
+            # Nengo builder scales PES learning rate by `dt / n_neurons`
+            n_neurons = (conn.pre_obj.n_neurons
+                         if isinstance(conn.pre_obj, Ensemble)
+                         else conn.pre_obj.size_in)
+            learning_rate = rule_type.learning_rate * model.dt / n_neurons
+
+            # Account for scaling to put integer error in range [-127, 127]
+            learning_rate /= model.pes_error_scale
+
+            # Tracing mag set so that the magnitude of the pre trace
+            # is independent of the pre tau. `dt` factor accounts for
+            # Nengo's `dt` spike scaling. Where is the second `dt` from?
+            # Maybe the fact that post interneurons have `vth = 1/dt`?
+            tracing_mag = -np.expm1(-1. / tracing_tau) / model.dt**2
+
+            # learning weight exponent controls the maximum weight
+            # magnitude/weight resolution
+            wgt_exp = model.pes_wgt_exp
+
+            dec_syn.set_learning(
+                learning_rate=learning_rate,
+                tracing_mag=tracing_mag,
+                tracing_tau=tracing_tau,
+                wgt_exp=wgt_exp,
+            )
         else:
             raise NotImplementedError()
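To get a feel for the magnitudes these formulas produce, here is a standalone sketch of the same arithmetic. All inputs are hypothetical (Nengo's default `dt`, the default `pre_synapse` tau of 0.005 s, a 100-neuron pre population, the builder's default `pes_error_scale`); nothing below is part of the diff itself.

```python
import numpy as np

dt = 0.001                # simulator timestep (s)
n_neurons = 100           # hypothetical pre population size
pes_learning_rate = 1e-3  # hypothetical nengo.PES learning_rate
pes_error_scale = 100.    # builder.Model default

# pre_synapse tau, converted from seconds to timesteps
tracing_tau = 0.005 / dt  # -> 5.0

# Nengo's builder scales the PES rate by dt / n_neurons; the extra division
# accounts for errors being shipped as integers in [-127, 127]
learning_rate = pes_learning_rate * dt / n_neurons / pes_error_scale

# trace increment chosen so the steady-state trace magnitude is
# independent of tracing_tau
tracing_mag = -np.expm1(-1. / tracing_tau) / dt**2

print(tracing_tau, learning_rate, tracing_mag)
# 5.0 1e-10 181269.246...
```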
diff --git a/nengo_loihi/config.py b/nengo_loihi/config.py
index 1b246d5df..6e15fa798 100644
--- a/nengo_loihi/config.py
+++ b/nengo_loihi/config.py
@@ -25,10 +25,10 @@ def add_params(network):
     """
     config = network.config
 
-    cfg = config[nengo.Ensemble]
-    if 'on_chip' not in cfg._extra_params:
-        cfg.set_param("on_chip",
-                      Parameter('on_chip', default=None, optional=True))
+    ens_cfg = config[nengo.Ensemble]
+    if 'on_chip' not in ens_cfg._extra_params:
+        ens_cfg.set_param("on_chip",
+                          Parameter('on_chip', default=None, optional=True))
 
 
 def set_defaults():
diff --git a/nengo_loihi/loihi_api.py b/nengo_loihi/loihi_api.py
index d9027249e..494b1fca1 100644
--- a/nengo_loihi/loihi_api.py
+++ b/nengo_loihi/loihi_api.py
@@ -22,6 +22,22 @@
 Q_BITS = 21  # number of bits for synapse accumulator
 U_BITS = 23  # number of bits for cx input (u)
 
+LEARN_BITS = 15  # number of bits in learning accumulator (not incl. sign)
+LEARN_FRAC = 7  # extra least-significant bits added to weights for learning
+
+
+def learn_overflow_bits(n_factors):
+    """Compute the number of bits by which learning calculations will overflow.
+
+    Parameters
+    ----------
+    n_factors : int
+        The number of learning factors (pre/post terms in the learning rule).
+    """
+    factor_bits = 7  # number of bits per factor
+    mantissa_bits = 3  # number of bits for learning rate mantissa
+    return factor_bits*n_factors + mantissa_bits - LEARN_BITS
+
 
 def overflow_signed(x, bits=7, out=None):
     """Compute overflow on an array of signed integers.
@@ -86,16 +102,10 @@ def bias_to_manexp(bias):
     return man, exp
 
 
-def tracing_mag_int_frac(synapses):
-    mag = synapses.tracing_mag
-    mag = mag / (synapses.size() / 100)
-
+def tracing_mag_int_frac(mag):
+    """Split trace magnitude into integer and fractional components for chip."""
     mag_int = int(mag)
-    # TODO: how does mag_frac actually work???
-    # It's the x in x/128, I believe
     mag_frac = int(128 * (mag - mag_int))
-    # mag_frac = min(int(round(1./mag_frac)), 128)
-
     return mag_int, mag_frac
 
 
@@ -157,6 +167,23 @@ def shift(x, s, **kwargs):
         return np.left_shift(x, s, **kwargs)
 
 
+def scale_pes_errors(error, scale=1.):
+    """Scale PES errors by a scaling factor, then round and clip them."""
+    error = scale * error
+    error = np.round(error).astype(np.int32)
+    q = error > 127
+    if np.any(q):
+        warnings.warn("Max PES error (%0.2e) greater than chip max (%0.2e). "
+                      "Clipping." % (error.max() / scale, 127. / scale))
+        error[q] = 127
+    q = error < -127
+    if np.any(q):
+        warnings.warn("Min PES error (%0.2e) less than chip min (%0.2e). "
+                      "Clipping." % (error.min() / scale, -127. / scale))
+        error[q] = -127
+    return error
+
+
 class CxSlice(object):
     def __init__(self, board_idx, chip_idx, core_idx, cx_i0, cx_i1):
         self.board_idx = board_idx
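A quick usage sketch of the new `scale_pes_errors` helper: with the builder default `pes_error_scale = 100.`, errors in roughly [-1.27, 1.27] span the full integer range, and anything larger warns and clips.

```python
import numpy as np
from nengo_loihi.loihi_api import scale_pes_errors

print(scale_pes_errors(np.array([0.01, 0.5, 1.5]), scale=100.))
# [  1  50 127]   (plus a UserWarning that the 1.5 entry was clipped)
```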
@@ -556,6 +583,11 @@ def realIdxBits(self):
     def isMixed(self):
         return self.fanoutType == 1
 
+    @property
+    def shift_bits(self):
+        """Number of bits the -256..255 weight is right-shifted by."""
+        return 8 - self.realWgtBits + self.isMixed
+
     def bits_per_axon(self, n_weights):
         """For an axon with n weights, compute the weight memory bits used"""
         bits_per_weight = self.realWgtBits + self.dlyBits + self.tagBits
@@ -594,24 +626,48 @@ def validate(self, core=None):
         assert 0 <= self.idxBits < 8
         assert 1 <= self.fanoutType < 4
 
-    def discretize_weights(self, w, dtype=np.int32):
-        s = 8 - self.realWgtBits + self.isMixed
+    def discretize_weights(
+            self, w, dtype=np.int32, lossy_shift=True, check_result=True):
+        """Take weights and return their quantized values with ``wgtExp``.
+
+        The actual weight to be put on the chip is this returned value
+        divided by the ``scale`` attribute.
+
+        Parameters
+        ----------
+        w : float ndarray
+            Weights to be discretized, in the range -255 to 255.
+        dtype : np.dtype, optional (Default: np.int32)
+            Data type for discretized weights.
+        lossy_shift : bool, optional (Default: True)
+            Whether to mimic the two-part weight shift that currently happens
+            on the chip, which can lose information for small wgtExp.
+        check_result : bool, optional (Default: True)
+            Whether to check that the discretized weights fall in
+            the valid range for weights on the chip (-256 to 255).
+        """
+        s = self.shift_bits
         m = 2**(8 - s) - 1
 
         w = np.round(w / 2.**s).clip(-m, m).astype(dtype)
         s2 = s + self.wgtExp
-        if s2 < 0:
-            warnings.warn("Lost %d extra bits in weight rounding" % (-s2,))
 
-        # round before `s2` right shift, since just shifting would floor
-        # everything resulting in weights biased towards being smaller
-        w = (np.round(w * 2.**s2) / 2**s2).clip(-m, m).astype(dtype)
+        if lossy_shift:
+            if s2 < 0:
+                warnings.warn("Lost %d extra bits in weight rounding" % (-s2,))
+
+                # Round before the `s2` right shift; just shifting would floor
+                # everything, biasing weights towards being smaller.
+                w = (np.round(w * 2.**s2) / 2**s2).clip(-m, m).astype(dtype)
 
-        shift(w, s2, out=w)
-        np.left_shift(w, 6, out=w)
+            shift(w, s2, out=w)
+            np.left_shift(w, 6, out=w)
+        else:
+            shift(w, 6 + s2, out=w)
 
-        ws = w // self.scale
-        assert np.all(ws <= 255) and np.all(ws >= -256)
+        if check_result:
+            ws = w // self.scale
+            assert np.all(ws <= 255) and np.all(ws >= -256)
 
         return w
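To see what `discretize_weights` does, here is the arithmetic of the lossless path (`lossy_shift=False`) for one hypothetical format. The values below assume `realWgtBits=6`, `isMixed=0`, `wgtExp=0`; they are for illustration only and are not part of the diff.

```python
import numpy as np

realWgtBits, isMixed, wgtExp = 6, 0, 0        # hypothetical SynapseFmt values
s = 8 - realWgtBits + isMixed                 # shift_bits -> 2
m = 2**(8 - s) - 1                            # largest stored magnitude -> 63

w = np.array([100.0, -37.0])                  # weights, already in -255..255
w_stored = np.round(w / 2.**s).clip(-m, m).astype(np.int32)  # [25, -9]
w_chip = np.left_shift(w_stored, 6 + s + wgtExp)             # [6400, -2304]
print(w_stored, w_chip)
```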
diff --git a/nengo_loihi/loihi_cx.py b/nengo_loihi/loihi_cx.py
index 4875c7825..087db2a73 100644
--- a/nengo_loihi/loihi_cx.py
+++ b/nengo_loihi/loihi_cx.py
@@ -14,7 +14,11 @@
     bias_to_manexp,
     decay_int,
     decay_magnitude,
+    LEARN_FRAC,
+    learn_overflow_bits,
     overflow_signed,
+    scale_pes_errors,
+    shift,
     SynapseFmt,
     tracing_mag_int_frac,
     Q_BITS,
     U_BITS,
@@ -265,7 +269,7 @@ def discretize(target, value):
         w_maxs = [s.max_abs_weight() for s in self.synapses]
         w_max = max(w_maxs) if len(w_maxs) > 0 else 0
         b_max = np.abs(self.bias).max()
-        wgtExp = -7
+        wgtExp = 0
 
         if w_max > 1e-8:
             w_scale = (255. / w_max)
@@ -311,7 +315,10 @@ def discretize(target, value):
         discretize(self.bias, bias_man * 2**bias_exp)
 
         for i, synapse in enumerate(self.synapses):
-            if w_maxs[i] > 1e-16:
+            if synapse.tracing:
+                wgtExp2 = synapse.learning_wgt_exp
+                dWgtExp = wgtExp - wgtExp2
+            elif w_maxs[i] > 1e-16:
                 dWgtExp = int(np.floor(np.log2(w_max / w_maxs[i])))
                 assert dWgtExp >= 0
                 wgtExp2 = max(wgtExp - dWgtExp, -6)
@@ -323,9 +330,46 @@ def discretize(target, value):
                 ws = w_scale[idxs] if is_iterable(w_scale) else w_scale
                 discretize(w, synapse.synapse_fmt.discretize_weights(
                     w * ws * 2**dWgtExp))
-            # TODO: scale this properly, hardcoded for now
+
+            # discretize learning
             if synapse.tracing:
-                synapse.synapse_fmt.wgtExp = 4
+                synapse.tracing_tau = int(np.round(synapse.tracing_tau))
+
+                if is_iterable(w_scale):
+                    assert np.all(w_scale == w_scale[0])
+                w_scale_i = w_scale[0] if is_iterable(w_scale) else w_scale
+
+                # incorporate the weight scale and the difference in weight
+                # exponents into the learning rate, since these affect the
+                # speed at which we learn
+                ws = w_scale_i * 2**dWgtExp
+                synapse.learning_rate *= ws
+
+                # Loihi down-scales learning factors based on the number of
+                # overflow bits. Increasing learning rate maintains true rate.
+                synapse.learning_rate *= 2**learn_overflow_bits(2)
+
+                # TODO: Currently, the Loihi learning rate is fixed at 2**-7.
+                # We should explore adjusting it for better performance.
+                lscale = 2**-7 / synapse.learning_rate
+                synapse.learning_rate *= lscale
+                synapse.tracing_mag /= lscale
+
+                # discretize learning rate into mantissa and exponent
+                lr_exp = int(np.floor(np.log2(synapse.learning_rate)))
+                lr_int = int(np.round(synapse.learning_rate * 2**(-lr_exp)))
+                synapse.learning_rate = lr_int * 2**lr_exp
+                synapse._lr_int = lr_int
+                synapse._lr_exp = lr_exp
+                assert lr_exp >= -7
+
+                # discretize tracing mag into integer and fractional parts
+                mag_int, mag_frac = tracing_mag_int_frac(synapse.tracing_mag)
+                if mag_int > 127:
+                    warnings.warn("Trace increment exceeds upper limit "
+                                  "(learning rate may be too large)")
+                    mag_int = 127
+                    mag_frac = 127
+                synapse.tracing_mag = mag_int + mag_frac / 128.
 
         # --- noise
         assert (v_scale[0] == v_scale).all()
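The mantissa/exponent step above maps the effective learning rate onto the chip's `lr_int * 2**lr_exp` representation. A standalone example with a hypothetical rate shows how coarse this can be:

```python
import numpy as np

learning_rate = 0.0213  # hypothetical effective rate after all scaling
lr_exp = int(np.floor(np.log2(learning_rate)))        # -6
lr_int = int(np.round(learning_rate * 2**(-lr_exp)))  # 1
print(lr_int, lr_exp, lr_int * 2**lr_exp)             # 1 -6 0.015625
```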
@@ -395,12 +439,16 @@ class CxSynapses(object):
         can have a different number of target compartments.
     indices : (population, axon, compartment) ndarray
         The synapse indices.
+    learning_rate : float
+        The learning rate.
+    learning_wgt_exp : int
+        The weight exponent used on this connection if learning is enabled.
     tracing : bool
         Whether synaptic tracing is enabled for these synapses.
-    tracing_tau : float
-        The tracing time constant.
+    tracing_tau : int
+        Decay time constant for the learning trace, in timesteps (not seconds).
     tracing_mag : float
-        The tracing increment magnitude.
+        Magnitude by which the learning trace is increased for each spike.
     """
     def __init__(self, n_axons, label=None):
         self.n_axons = n_axons
@@ -411,6 +459,9 @@ def __init__(self, n_axons, label=None):
         self.indices = None
         self.axon_cx_bases = None
         self.axon_to_weight_map = None
+
+        self.learning_rate = 1.
+        self.learning_wgt_exp = None
         self.tracing = False
         self.tracing_tau = None
         self.tracing_mag = None
@@ -541,16 +592,22 @@ def set_population_weights(
             numSynapses=63,
             wgtBits=7)
 
-    def set_learning(self, tracing_tau=2, tracing_mag=1.0):
+    def set_learning(
+            self, learning_rate=1., tracing_tau=2, tracing_mag=1.0, wgt_exp=4):
         assert tracing_tau == int(tracing_tau), "tracing_tau must be integer"
+
         self.tracing = True
         self.tracing_tau = int(tracing_tau)
         self.tracing_mag = tracing_mag
         self.format(learningCfg=1, stdpProfile=0)
         # ^ stdpProfile hard-coded for now (see loihi_interface)
 
-        mag_int, _ = tracing_mag_int_frac(self)
-        assert int(mag_int) < 2**7
+        self.train_epoch = 2
+        self.learn_epoch_k = 1
+        self.learn_epoch = self.train_epoch * 2**self.learn_epoch_k
+
+        self.learning_rate = learning_rate * self.learn_epoch
+        self.learning_wgt_exp = wgt_exp
 
     def format(self, **kwargs):
         if self.synapse_fmt is None:
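The epoch bookkeeping in `set_learning` means weight updates happen only every `train_epoch * 2**learn_epoch_k` timesteps, with the stored learning rate pre-multiplied to compensate. With the hard-coded values:

```python
train_epoch = 2
learn_epoch_k = 1
learn_epoch = train_epoch * 2**learn_epoch_k  # 4: weights update every 4 steps

per_step_rate = 1e-10                       # hypothetical builder output
stored_rate = per_step_rate * learn_epoch   # what set_learning keeps
print(learn_epoch, stored_rate)             # 4 4e-10
```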
@@ -882,9 +939,64 @@ def overflow(x, bits, name=None):
         # --- allocate synapse memory
         self.axons_in = {synapses: [] for group in self.groups
                          for synapses in group.synapses}
-        self.z = {synapses: np.zeros(synapses.n_axons, dtype=np.float64)
-                  for group in self.groups for synapses in group.synapses
-                  if synapses.tracing}  # synapse traces
+
+        learning_synapses = [
+            synapses for group in self.groups
+            for synapses in group.synapses if synapses.tracing]
+        self.z = {synapses: np.zeros(synapses.n_axons, dtype=group_dtype)
+                  for synapses in learning_synapses}  # synapse traces
+        self.z_spikes = {synapses: set() for synapses in learning_synapses}
+        # Currently, PES learning only happens on Nodes, where we have
+        # pairs of on/off neurons. Therefore, the number of error dimensions
+        # is half the number of neurons.
+        self.pes_errors = {
+            synapses: np.zeros(synapses.group.n//2, dtype=group_dtype)
+            for synapses in learning_synapses}
+        self.pes_error_scale = getattr(model, 'pes_error_scale', 1.)
+
+        if group_dtype == np.int32:
+            def stochastic_round(x, dtype=group_dtype, rng=self.rng,
+                                 clip=None, name="values"):
+                x_sign = np.sign(x).astype(dtype)
+                x_frac, x_int = np.modf(np.abs(x))
+                p = rng.rand(*x.shape)
+                y = x_int.astype(dtype) + (x_frac > p)
+                if clip is not None:
+                    q = y > clip
+                    if np.any(q):
+                        warnings.warn("Clipping %s" % name)
+                    y[q] = clip
+                return x_sign * y
+
+            def trace_round(x, dtype=group_dtype, rng=self.rng):
+                return stochastic_round(
+                    x, dtype=dtype, rng=rng, clip=127, name="synapse trace")
+
+            def weight_update(synapses, delta_ws):
+                synapse_fmt = synapses.synapse_fmt
+                wgt_exp = synapse_fmt.realWgtExp
+                shift_bits = synapse_fmt.shift_bits
+                overflow = learn_overflow_bits(n_factors=2)
+                for w, delta_w in zip(synapses.weights, delta_ws):
+                    product = shift(
+                        delta_w * synapses._lr_int,
+                        LEARN_FRAC + synapses._lr_exp - overflow)
+                    learn_w = shift(w, LEARN_FRAC - wgt_exp) + product
+                    learn_w[:] = stochastic_round(
+                        learn_w * 2**(-LEARN_FRAC - shift_bits),
+                        clip=2**(8 - shift_bits) - 1,
+                        name="learning weights")
+                    w[:] = np.left_shift(learn_w, wgt_exp + shift_bits)
+
+        elif group_dtype == np.float32:
+            def trace_round(x, dtype=group_dtype):
+                return x  # no rounding
+
+            def weight_update(synapses, delta_ws):
+                for w, delta_w in zip(synapses.weights, delta_ws):
+                    w += synapses.learning_rate * delta_w
+
+        self.trace_round = trace_round
+        self.weight_update = weight_update
 
         # --- noise
         enableNoise = np.hstack([
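Why round stochastically? Over many draws the rounded values are unbiased, so small trace and weight increments are not systematically floored away. A standalone sketch of the same logic as `stochastic_round` above:

```python
import numpy as np

rng = np.random.RandomState(0)
x = np.full(100000, 2.3)                  # value with a 0.3 fractional part

x_sign = np.sign(x).astype(np.int32)
x_frac, x_int = np.modf(np.abs(x))
p = rng.rand(*x.shape)
y = x_sign * (x_int.astype(np.int32) + (x_frac > p))

print(np.unique(y), y.mean())             # [2 3], mean close to 2.3
```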
@@ -966,15 +1078,15 @@ def host2chip(self, spikes, errors):
         for cx_spike_input, t, spike_idxs in spikes:
             cx_spike_input.add_spikes(t, spike_idxs)
 
-        learning_rate = 50  # This is set to match hardware
-        for synapses, t, e in errors:
-            z = self.z[synapses]
-            x = np.hstack([-e, e])
+        # TODO: these are sent every timestep, but learning only happens
+        # every `tepoch * 2**learn_k` timesteps (see CxSynapses). Need to
+        # average incoming errors over each learning epoch.
+        for pes_errors in self.pes_errors.values():
+            pes_errors[:] = 0
 
-            delta_w = np.outer(z, x) * learning_rate
-
-            for i, w in enumerate(synapses.weights):
-                w += delta_w[i].astype('int32')
+        for synapses, t, e in errors:
+            pes_errors = self.pes_errors[synapses]
+            assert pes_errors.shape == e.shape
+            pes_errors += scale_pes_errors(e, scale=self.pes_error_scale)
 
     def step(self):  # noqa: C901
         """Advance the simulation by 1 step (``dt`` seconds)."""
@@ -1020,16 +1132,30 @@ def step(self):  # noqa: C901
                         spike.axon_id, atom=spike.atom)
                     qb[0, cx_base + indices] += weights
 
-            if synapses.tracing:
-                z = self.z[synapses]
-                tau = synapses.tracing_tau
-                mag = synapses.tracing_mag
-
-                decay = np.exp(-1.0 / tau)
-                z *= decay
-
+            # --- learning trace
+            z_spikes = self.z_spikes.get(synapses, None)
+            if z_spikes is not None:
                 for spike in self.axons_in[synapses]:
-                    z[spike.axon_id] += mag
+                    if spike.axon_id in z_spikes:
+                        self.error("Synaptic trace spikes lost")
+                    z_spikes.add(spike.axon_id)
+
+            z = self.z.get(synapses, None)
+            if z is not None and self.t % synapses.train_epoch == 0:
+                tau = synapses.tracing_tau
+                decay = np.exp(-synapses.train_epoch / tau)
+                zi = decay*z
+                zi[list(z_spikes)] += synapses.tracing_mag
+                z[:] = self.trace_round(zi)
+                z_spikes.clear()
+
+            # --- learning update
+            pes_e = self.pes_errors.get(synapses, None)
+            if pes_e is not None and self.t % synapses.learn_epoch == 0:
+                assert z is not None
+                x = np.hstack([-pes_e, pes_e])
+                delta_w = np.outer(z, x)
+                self.weight_update(synapses, delta_w)
 
             # --- updates
             q0 = self.q[0, :]
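Taken together, one learning epoch in the floating-point emulator reduces to a trace decay plus an outer-product update. A sketch with hypothetical sizes and values, following the float32 `weight_update` path (with the per-axon loop flattened into one array operation):

```python
import numpy as np

n_axons, d = 4, 1                      # hypothetical sizes (d error dims)
z = np.zeros(n_axons)                  # presynaptic traces
w = np.zeros((n_axons, 2 * d))         # weights onto on/off interneuron pairs
tracing_tau, tracing_mag = 5.0, 181269.25
train_epoch, learning_rate = 2, 4e-10

pes_e = np.array([50.0])               # accumulated integer error
spiked = [1, 3]                        # axons that spiked since last epoch

z = np.exp(-train_epoch / tracing_tau) * z  # decay traces
z[spiked] += tracing_mag                    # bump traces for spiking axons

x = np.hstack([-pes_e, pes_e])         # error enters with both signs
w += learning_rate * np.outer(z, x)    # float32-path weight update
print(w)
```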
diff --git a/nengo_loihi/loihi_interface.py b/nengo_loihi/loihi_interface.py
index d3a8dafd0..0d6c867c2 100644
--- a/nengo_loihi/loihi_interface.py
+++ b/nengo_loihi/loihi_interface.py
@@ -32,7 +32,12 @@ def no_nxsdk(*args, **kwargs):
 import nengo_loihi.loihi_cx as loihi_cx
 from nengo_loihi.allocators import one_to_one_allocator
 from nengo_loihi.loihi_api import (
-    bias_to_manexp, CX_PROFILES_MAX, SpikeInput, VTH_PROFILES_MAX)
+    bias_to_manexp,
+    CX_PROFILES_MAX,
+    scale_pes_errors,
+    SpikeInput,
+    VTH_PROFILES_MAX,
+)
 from nengo_loihi.loihi_cx import CxGroup
 
 logger = logging.getLogger(__name__)
@@ -593,6 +598,7 @@ def _iter_probes(self):
     def build(self, cx_model, seed=None):
         cx_model.validate()
         self.model = cx_model
+        self.pes_error_scale = getattr(cx_model, 'pes_error_scale', 1.)
 
         if self.use_snips:
             # tag all probes as being snip-based,
@@ -729,8 +735,6 @@ def host2chip(self, spikes, errors):
         loihi_errors = []
         for synapses, t, e in errors:
-            x = (100 * e).astype(int)
-            x = np.clip(x, -100, 100, out=x)
             cx_group = synapses.group
             coreid = None
             for core in self.board.chips[0].cores:
@@ -744,7 +748,8 @@ def host2chip(self, spikes, errors):
                     break
 
             assert coreid is not None
-            loihi_errors.append([coreid, len(x)] + x.tolist())
+            e = scale_pes_errors(e, scale=self.pes_error_scale)
+            loihi_errors.append([coreid, len(e)] + e.tolist())
 
         if self.use_snips:
             return self._host2chip_snips(loihi_spikes, loihi_errors)
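For reference, the error messages handed to the snip are flat integer lists; a hypothetical two-dimensional error destined for core 3 would be packed as:

```python
import numpy as np

e = np.array([12, -5], dtype=np.int32)  # errors already scaled and clipped
coreid = 3
msg = [coreid, len(e)] + e.tolist()
print(msg)  # [3, 2, 12, -5]
```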
diff --git a/nengo_loihi/tests/test_learning.py b/nengo_loihi/tests/test_learning.py
index a4df2f324..b9de7a8b8 100644
--- a/nengo_loihi/tests/test_learning.py
+++ b/nengo_loihi/tests/test_learning.py
@@ -1,14 +1,29 @@
 import nengo
+from nengo.exceptions import ValidationError, SimulationError
+from nengo.utils.numpy import rms
 import numpy as np
 import pytest
 
+import nengo_loihi.builder
 
-@pytest.mark.parametrize('n_per_dim', [120, 200])
-@pytest.mark.parametrize('dims', [1, 3])
-def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims):
-    scale = np.linspace(1, 0, dims + 1)[:-1]
-    input_fn = lambda t: np.sin(t * 2 * np.pi) * scale
+
+def pes_network(
+        n_per_dim,
+        dims,
+        seed,
+        learning_rule_type=nengo.PES(learning_rate=1e-3),
+        input_scale=None,
+        error_scale=1.,
+        learn_synapse=0.005,
+        probe_synapse=0.02,
+):
+    if input_scale is None:
+        input_scale = np.linspace(1, 0, dims + 1)[:-1]
+    assert input_scale.size == dims
+
+    input_fn = lambda t: np.sin(t * 2 * np.pi) * input_scale
+
+    probes = {}
 
     with nengo.Network(seed=seed) as model:
         stim = nengo.Node(input_fn)
@@ -19,44 +34,150 @@ def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims):
         conn = nengo.Connection(
             pre, post,
             function=lambda x: np.zeros(dims),
-            synapse=0.01,
-            learning_rule_type=nengo.PES(learning_rate=1e-3))
+            synapse=learn_synapse,
+            learning_rule_type=learning_rule_type)
 
-        nengo.Connection(post, conn.learning_rule)
-        nengo.Connection(stim, conn.learning_rule, transform=-1)
+        nengo.Connection(post, conn.learning_rule, transform=error_scale)
+        nengo.Connection(stim, conn.learning_rule, transform=-error_scale)
 
-        p_stim = nengo.Probe(stim, synapse=0.02)
-        p_pre = nengo.Probe(pre, synapse=0.02)
-        p_post = nengo.Probe(post, synapse=0.02)
+        probes['stim'] = nengo.Probe(stim, synapse=probe_synapse)
+        probes['pre'] = nengo.Probe(pre, synapse=probe_synapse)
+        probes['post'] = nengo.Probe(post, synapse=probe_synapse)
 
-    with Simulator(model) as sim:
-        sim.run(5.0)
+    return model, probes
+
+
+@pytest.mark.parametrize('n_per_dim', [120, 200])
+@pytest.mark.parametrize('dims', [1, 3])
+def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims):
+    tau = 0.01
+    model, probes = pes_network(n_per_dim, dims, seed, learn_synapse=tau)
+
+    simtime = 5.0
+    with nengo.Simulator(model) as nengo_sim:
+        nengo_sim.run(simtime)
+
+    with Simulator(model) as loihi_sim:
+        loihi_sim.run(simtime)
+
+    with Simulator(model, target='simreal') as real_sim:
+        real_sim.run(simtime)
+
+    t = nengo_sim.trange()
+    pre_tmask = t > 0.1
+    post_tmask = t > simtime - 1.0
+
+    inter_tau = loihi_sim.model.inter_tau
+    y = nengo_sim.data[probes['stim']]
+    y_dpre = nengo.Lowpass(inter_tau).filt(y)
+    y_dpost = nengo.Lowpass(tau).combine(nengo.Lowpass(inter_tau)).filt(y_dpre)
+    y_nengo = nengo_sim.data[probes['post']]
+    y_loihi = loihi_sim.data[probes['post']]
+    y_real = real_sim.data[probes['post']]
 
-    t = sim.trange()
     plt.subplot(211)
-    plt.plot(t, sim.data[p_stim])
-    plt.plot(t, sim.data[p_pre])
-    plt.plot(t, sim.data[p_post])
-
-    # --- fit input_fn to output, determine magnitude
-    # The larger the magnitude, the closer the output is to the input
-    x = np.array([input_fn(tt)[0] for tt in t[t > 4]])
-    y = sim.data[p_post][t > 4][:, 0]
-    m = np.linspace(0, 1, 21)
-    errors = np.abs(y - m[:, None]*x).mean(axis=1)
-    m_best = m[np.argmin(errors)]
+    plt.plot(t, y_dpost, 'k', label='target')
+    plt.plot(t, y_nengo, 'b', label='nengo')
+    plt.plot(t, y_loihi, 'g', label='loihi')
+    plt.plot(t, y_real, 'r:', label='real')
+    plt.legend()
 
     plt.subplot(212)
-    plt.plot(t[t > 4], x)
-    plt.plot(t[t > 4], y)
-    plt.plot(t[t > 4], m_best * x, ':')
+    plt.plot(t[post_tmask], y_loihi[post_tmask] - y_dpost[post_tmask], 'k')
+    plt.plot(t[post_tmask], y_loihi[post_tmask] - y_nengo[post_tmask], 'b')
+
+    x_loihi = loihi_sim.data[probes['pre']]
+    assert allclose(x_loihi[pre_tmask], y_dpre[pre_tmask],
+                    atol=0.1, rtol=0.05)
+
+    assert allclose(y_loihi[post_tmask], y_dpost[post_tmask],
+                    atol=0.1, rtol=0.05)
+    assert allclose(y_loihi, y_nengo, atol=0.2, rtol=0.2)
+
+    assert allclose(y_real[post_tmask], y_dpost[post_tmask],
+                    atol=0.1, rtol=0.05)
+    assert allclose(y_real, y_nengo, atol=0.2, rtol=0.2)
+
+
+def test_pes_overflow(allclose, plt, seed, Simulator):
+    dims = 3
+    n_per_dim = 120
+    tau = 0.01
+    model, probes = pes_network(n_per_dim, dims, seed, learn_synapse=tau,
+                                input_scale=np.linspace(1, 0.7, dims))
+
+    simtime = 3.0
+    loihi_model = nengo_loihi.builder.Model()
+    # set pes_wgt_exp low to create overflow in weight values
+    loihi_model.pes_wgt_exp = -1
+
+    with Simulator(model, model=loihi_model) as loihi_sim:
+        loihi_sim.run(simtime)
 
-    assert allclose(sim.data[p_pre][t > 0.1],
-                    sim.data[p_stim][t > 0.1],
-                    atol=0.15,
-                    rtol=0.15)
-    assert np.min(errors) < 0.3, "Not able to fit correctly"
-    assert m_best > (0.3 if n_per_dim < 150 else 0.6)
+    t = loihi_sim.trange()
+    post_tmask = t > simtime - 1.0
+
+    inter_tau = loihi_sim.model.inter_tau
+    y = loihi_sim.data[probes['stim']]
+    y_dpre = nengo.Lowpass(inter_tau).filt(y)
+    y_dpost = nengo.Lowpass(tau).combine(nengo.Lowpass(inter_tau)).filt(y_dpre)
+    y_loihi = loihi_sim.data[probes['post']]
+
+    plt.plot(t, y_dpost, 'k', label='target')
+    plt.plot(t, y_loihi, 'g', label='loihi')
+
+    # --- fit output to scaled version of target output
+    z_ref0 = y_dpost[post_tmask][:, 0]
+    z_loihi = y_loihi[post_tmask]
+    scale = np.linspace(0, 1, 50)
+    E = np.abs(z_loihi - scale[:, None, None]*z_ref0[:, None])
+    errors = E.mean(axis=1)  # average over time (errors is: scales x dims)
+    for j in range(dims):
+        errors_j = errors[:, j]
+        i = np.argmin(errors_j)
+        assert errors_j[i] < 0.1, ("Learning output for dim %d did not match "
+                                   "any scaled version of the target output"
+                                   % j)
+        assert scale[i] > 0.4, "Learning output for dim %d is too small" % j
+        assert scale[i] < 0.7, ("Learning output for dim %d is too large "
+                                "(weights or traces not clipping as expected)"
+                                % j)
+
+
+def test_pes_error_clip(allclose, plt, seed, Simulator):
+    dims = 2
+    n_per_dim = 120
+    tau = 0.01
+    error_scale = 5.  # scale up error signal so it clips
+    model, probes = pes_network(
+        n_per_dim, dims, seed, learn_synapse=tau,
+        learning_rule_type=nengo.PES(learning_rate=1e-3 / error_scale),
+        input_scale=np.array([1., -1.]),
+        error_scale=error_scale)
+
+    simtime = 3.0
+    with pytest.warns(UserWarning, match=r'.*PES error.*Clipping.'):
+        with Simulator(model) as loihi_sim:
+            loihi_sim.run(simtime)
+
+    t = loihi_sim.trange()
+    post_tmask = t > simtime - 1.0
+
+    inter_tau = loihi_sim.model.inter_tau
+    y = loihi_sim.data[probes['stim']]
+    y_dpre = nengo.Lowpass(inter_tau).filt(y)
+    y_dpost = nengo.Lowpass(tau).combine(nengo.Lowpass(inter_tau)).filt(y_dpre)
+    y_loihi = loihi_sim.data[probes['post']]
+
+    plt.plot(t, y_dpost, 'k', label='target')
+    plt.plot(t, y_loihi, 'g', label='loihi')
+
+    # --- assert that we've learned something, but not everything
+    error = (rms(y_loihi[post_tmask] - y_dpost[post_tmask])
+             / rms(y_dpost[post_tmask]))
+    assert error < 0.5
+    assert error > 0.05
+    # ^ error on emulator vs chip is quite different, hence large tolerances
 
 
 @pytest.mark.parametrize('init_function', [None, lambda x: 0])
@@ -74,19 +195,65 @@ def test_multiple_pes(init_function, allclose, plt, seed, Simulator):
             pre_ea.ea_ensembles[i],
             output[i],
             function=init_function,
-            learning_rule_type=nengo.PES(learning_rate=5e-4),
+            learning_rule_type=nengo.PES(learning_rate=3e-3),
         )
         nengo.Connection(target[i], conn.learning_rule, transform=-1)
        nengo.Connection(output[i], conn.learning_rule)
 
     probe = nengo.Probe(output, synapse=0.1)
+
+    simtime = 2.5
     with Simulator(model) as sim:
-        sim.run(1.0)
+        sim.run(simtime)
+
     t = sim.trange()
+    tmask = t > simtime * 0.85
 
     plt.plot(t, sim.data[probe])
     for target, style in zip(targets, plt.rcParams["axes.prop_cycle"]):
         plt.axhline(target, **style)
 
     for i, target in enumerate(targets):
-        assert allclose(sim.data[probe][t > 0.8, i], target, atol=0.1)
+        assert allclose(sim.data[probe][tmask, i], target,
+                        atol=0.05, rtol=0.05), "Target %d not close" % i
+
+
+def test_pes_pre_synapse_type_error(Simulator):
+    with nengo.Network() as model:
+        pre = nengo.Ensemble(10, 1)
+        post = nengo.Node(size_in=1)
+        rule_type = nengo.PES(pre_synapse=nengo.Alpha(0.005))
+        conn = nengo.Connection(pre, post, learning_rule_type=rule_type)
+        nengo.Connection(post, conn.learning_rule)
+
+    with pytest.raises(ValidationError):
+        with Simulator(model):
+            pass
+
+
+def test_pes_trace_increment_clip_warning(seed, Simulator):
+    dims = 2
+    n_per_dim = 120
+    model, _ = pes_network(
+        n_per_dim, dims, seed,
+        learning_rule_type=nengo.PES(learning_rate=1e-1))
+
+    with pytest.warns(UserWarning, match="Trace increment exceeds upper"):
+        with Simulator(model):
+            pass
+
+
+def test_drop_trace_spikes(Simulator, seed):
+    with nengo.Network(seed=seed) as net:
+        a = nengo.Ensemble(10, 1, gain=nengo.dists.Choice([1]),
+                           bias=nengo.dists.Choice([2000]),
+                           neuron_type=nengo.SpikingRectifiedLinear())
+        b = nengo.Node(size_in=1)
+
+        conn = nengo.Connection(a, b, learning_rule_type=nengo.PES(1))
+
+        nengo.Connection(b, conn.learning_rule)
+
+    with Simulator(net, target="sim") as sim:
+        with pytest.raises(SimulationError):
+            sim.run(1.0)
diff --git a/nengo_loihi/tests/test_loihi_api.py b/nengo_loihi/tests/test_loihi_api.py
index 86065cc05..712f16126 100644
--- a/nengo_loihi/tests/test_loihi_api.py
+++ b/nengo_loihi/tests/test_loihi_api.py
@@ -2,8 +2,8 @@
 import numpy as np
 import pytest
 
-from nengo_loihi.loihi_api import overflow_signed
-from nengo_loihi.loihi_api import decay_int, decay_magnitude
+from nengo_loihi.loihi_api import (
+    overflow_signed, decay_int, decay_magnitude, SynapseFmt)
 
 
 @pytest.mark.parametrize("b", (8, 16, 17, 23))
@@ -87,3 +87,17 @@ def empirical_decay_magnitude(decay, x0):
         plt.plot(relative_diff.clip(0, None))
 
     assert np.all(relative_diff < 1e-6)
+
+
+@pytest.mark.parametrize("lossy_shift", (True, False))
+def test_lossy_shift(lossy_shift, rng):
+    wgt_bits = 6
+    w = rng.uniform(-100, 100, size=(10, 10))
+    fmt = SynapseFmt(wgtBits=wgt_bits, wgtExp=0, fanoutType=0)
+
+    w2 = fmt.discretize_weights(w, lossy_shift=lossy_shift)
+
+    clipped = np.round(w / 4).clip(-2 ** wgt_bits, 2 ** wgt_bits).astype(
+        np.int32)
+
+    assert np.allclose(w2, np.left_shift(clipped, 8))
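A note on why `test_lossy_shift` can expect the same result from both paths: for `wgtExp >= 0` the lossy path's two left-shifts compose into the lossless single shift, and its extra rounding step only triggers when `s2 < 0`. In miniature (hypothetical stored weights):

```python
import numpy as np

s, wgtExp = 2, 0                  # shift_bits for wgtBits=6; the test's wgtExp
s2 = s + wgtExp
stored = np.array([25, -9], dtype=np.int32)

lossy = np.left_shift(np.left_shift(stored, s2), 6)
lossless = np.left_shift(stored, 6 + s2)
assert np.array_equal(lossy, lossless)
print(lossless)  # [ 6400 -2304]
```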