Skip to content

Commit

Permalink
Merge pull request #247 from SpikeAI/master
Browse files Browse the repository at this point in the history
Add MPS as a possible computing backend
  • Loading branch information
jeshraghian authored Oct 10, 2023
2 parents 71836f8 + 6d1c48e commit 5f2d292
Show file tree
Hide file tree
Showing 31 changed files with 2,523 additions and 2,709 deletions.
4 changes: 2 additions & 2 deletions docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ Define variables for dataloading.
::

batch_size = 128
data_path='/data/mnist'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
data_path='/tmp/data/mnist'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

Load MNIST dataset.

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/legacy/tutorial_1_old.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Let's define a few variables:

# Training Parameters
batch_size=128
data_path='/data/mnist'
data_path='/tmp/data/mnist'
num_classes = 10

# Temporal Dynamics
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/legacy/tutorial_3_old.rst
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Much of the following code has already been explained in the first two tutorials

# Training Parameters
batch_size=128
data_path='/data/mnist'
data_path='/tmp/data/mnist'

# Temporal Dynamics
num_steps = 25
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/tutorial_1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ Install the latest PyPi distribution of snnTorch:

# Training Parameters
batch_size=128
data_path='/data/mnist'
data_path='/tmp/data/mnist'
num_classes = 10 # MNIST has 10 output classes
# Torch Variables
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/tutorial_5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,10 @@ training a fully-connected spiking neural net.

# dataloader arguments
batch_size = 128
data_path='/data/mnist'
data_path='/tmp/data/mnist'
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

::

Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/tutorial_6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ here. <https://snntorch.readthedocs.io/en/latest/snntorch.surrogate.html>`__

# dataloader arguments
batch_size = 128
data_path='/data/mnist'
data_path='/tmp/data/mnist'
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

::

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/tutorial_7.rst
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ previous tutorial. The convolutional network architecture to be used is:

::

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
# neuron and simulation parameters
spike_grad = surrogate.atan()
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/tutorial_ipu_1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Load in the MNIST dataset.
from torchvision import datasets, transforms

batch_size = 128
data_path='/data/mnist'
data_path='/tmp/data/mnist'
# Define a transform
transform = transforms.Compose([
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/tutorial_pop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ Define variables for dataloading.
::

batch_size = 128
data_path='/data/fmnist'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
data_path='/tmp/data/fmnist'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

Load FashionMNIST dataset.

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/tutorial_regression_1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ Instantiate the network below:
::

hidden = 128
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model = Net(timesteps=num_steps, hidden=hidden).to(device)


Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/tutorial_regression_2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ Instantiate the network below:
::

hidden = 128
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model = Net(timesteps=num_steps, hidden=hidden, beta=0.9).to(device)

4. Construct Training Loop
Expand All @@ -332,7 +332,7 @@ temporal data is an exercise left to the reader/coder.
::

batch_size = 128
data_path='/data/mnist'
data_path='/tmp/data/mnist'
# Define a transform
transform = transforms.Compose([
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/tutorial_sae.rst
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,10 @@
# dataloader arguments
batch_size = 250
data_path='/data/mnist'
data_path='/tmp/data/mnist'
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
.. container:: cell code

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/zh-cn/tutorial_1_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ snnTorch 教程系列基于以下论文。如果您发现这些资源或代码

# Training Parameters
batch_size=128
data_path='/data/mnist'
data_path='/tmp/data/mnist'
num_classes = 10 # MNIST has 10 output classes
# Torch Variables
Expand Down
Loading

0 comments on commit 5f2d292

Please sign in to comment.