diff --git a/nengo_loihi/simulator.py b/nengo_loihi/simulator.py index b9de67c39..cebbbf038 100644 --- a/nengo_loihi/simulator.py +++ b/nengo_loihi/simulator.py @@ -83,6 +83,7 @@ class Simulator(object): def __init__(self, network, dt=0.001, seed=None, model=None, # noqa: C901 precompute=True, target=None): self.closed = True # Start closed in case constructor raises exception + self.running = False if model is None: # Call the builder to make a model @@ -229,8 +230,21 @@ def close(self): `.Simulator.step`, and `.Simulator.reset` on a closed simulator raises a `.SimulatorClosed` exception. """ + if self.closed: + return self.closed = True self.signals = None # signals may no longer exist on some backends + if self.loihi is not None: + if not self.precompute and self.running: + print('send a stop signal to the snip') + self.loihi.nengo_io_h2c.write(1, [-1]) + import time + time.sleep(0.5) # need to wait for stop signal to arrive + print('closing loihi...') + # this takes a while, since loihi seems to need to finish + # all of the steps its been told to do + self.loihi.close() + print('...closed') def _probe(self): """Copy all probed signals to buffers.""" @@ -316,10 +330,11 @@ def step(self): self.run_steps(1) - def run_steps(self, steps): + def run_steps(self, steps): # noqa: C901 if self.closed: raise SimulatorClosed("Simulator cannot run because it is closed.") + self.running = True if self.simulator is not None: if self.precompute: self.host_pre_sim.run_steps(steps) @@ -327,14 +342,19 @@ def run_steps(self, steps): self.simulator.run_steps(steps) self.handle_chip2host_communications() self.host_post_sim.run_steps(steps) + self._n_steps += steps elif self.host_sim is None: self.simulator.run_steps(steps) + self._n_steps += steps else: for i in range(steps): + if self.closed: + break self.host_sim.step() self.handle_host2chip_communications() self.simulator.step() self.handle_chip2host_communications() + self._n_steps += 1 elif self.loihi is not None: if self.precompute: self.host_pre_sim.run_steps(steps) @@ -342,22 +362,35 @@ def run_steps(self, steps): self.loihi.run_steps(steps) self.handle_chip2host_communications() self.host_post_sim.run_steps(steps) + self._n_steps += steps elif self.host_sim is not None: self.loihi.create_io_snip() - self.loihi.run_steps(steps, async=True) - for i in range(steps): - self.host_sim.run_steps(1) - self.handle_host2chip_communications() - self.handle_chip2host_communications() - - print('Waiting for completion') - self.loihi.nengo_io_h2c.write(1, [0]) - self.loihi.wait_for_completion() - print("done") + try: + self.loihi.run_steps(steps, async=True) + for i in range(steps): + if self.closed: + break + self.host_sim.run_steps(1) + self.handle_host2chip_communications() + self.handle_chip2host_communications() + self._n_steps += 1 + except KeyboardInterrupt: + # tell the snip to shut down + self.loihi.nengo_io_h2c.write(1, [-1]) + raise + finally: + try: + # tell the snip to shut down + self.loihi.nengo_io_h2c.write(1, [-1]) + self.loihi.wait_for_completion() + except EOFError: + # it has already been shut down + pass else: self.loihi.run_steps(steps) + self._n_steps += steps - self._n_steps += steps + self.running = False self._probe() def handle_host2chip_communications(self): # noqa: C901 diff --git a/nengo_loihi/snips/nengo_io.c.template b/nengo_loihi/snips/nengo_io.c.template index 4a28593fd..1809c71ec 100644 --- a/nengo_loihi/snips/nengo_io.c.template +++ b/nengo_loihi/snips/nengo_io.c.template @@ -5,9 +5,20 @@ #define N_OUTPUTS %d void nengo_io(runState *s) { + // use the last element in userData as a quit flag + if (s->userData[1023] != 0) { + return; + } + int count[1]; readChannel("nengo_io_h2c", count, 1); //printf("count %%d\n", count[0]); + if (count[0] < 0) { + // a negative value indicates we should stop + s->userData[1023] = 1; + printf("stopping\n"); + return; + } int spike[2]; for (int i=0; i 0.05: + raise Exception('Stopping') + output = nengo.Node(output_func, size_in=1) + nengo.Connection(ens, output) + + with Simulator(model, precompute=False) as sim: + with pytest.raises(Exception): + sim.run(0.1) + + +def test_closing(Simulator, seed): + model = nengo.Network(seed=seed) + with model: + stim = nengo.Node(0.5) + ens = nengo.Ensemble(n_neurons=100, dimensions=1) + nengo.Connection(stim, ens) + + def output_func(t, x): + if t > 0.05: + sim.close() + output = nengo.Node(output_func, size_in=1) + nengo.Connection(ens, output) + + sim = Simulator(model, precompute=False) + + sim.run(0.1) + assert sim.n_steps == 51