Skip to content

Commit

Permalink
Add MPS as a possible computing backend
Browse files Browse the repository at this point in the history
These changes just add a quick test to know if the MPS device from the Apple Silicon chips is available. Was tested on tutorials with no problem.
  • Loading branch information
laurentperrinet committed Oct 10, 2023
1 parent 71836f8 commit 97f60f5
Show file tree
Hide file tree
Showing 24 changed files with 2,499 additions and 2,685 deletions.
2 changes: 1 addition & 1 deletion docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Define variables for dataloading.

batch_size = 128
data_path='/data/mnist'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

Load MNIST dataset.

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/tutorial_5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ training a fully-connected spiking neural net.
data_path='/data/mnist'
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

::

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/tutorial_6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ here. <https://snntorch.readthedocs.io/en/latest/snntorch.surrogate.html>`__
data_path='/data/mnist'
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

::

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/tutorial_7.rst
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ previous tutorial. The convolutional network architecture to be used is:

::

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
# neuron and simulation parameters
spike_grad = surrogate.atan()
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/tutorial_pop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Define variables for dataloading.

batch_size = 128
data_path='/data/fmnist'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

Load FashionMNIST dataset.

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/tutorial_regression_1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ Instantiate the network below:
::

hidden = 128
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model = Net(timesteps=num_steps, hidden=hidden).to(device)


Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/tutorial_regression_2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ Instantiate the network below:
::

hidden = 128
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model = Net(timesteps=num_steps, hidden=hidden, beta=0.9).to(device)

4. Construct Training Loop
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/tutorial_sae.rst
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@
data_path='/data/mnist'
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
.. container:: cell code

Expand Down
Loading

0 comments on commit 97f60f5

Please sign in to comment.