From ead150a2f00136f6c72dee81f17b7851fa4286d9 Mon Sep 17 00:00:00 2001 From: Max Weinzierl Date: Thu, 18 Oct 2018 11:11:12 -0500 Subject: [PATCH] Fix for RuntimeError on torch tensor indexing This resolves #2 Fix for RuntimeError: tensors used as indices must be long or byte tensors --- pytorch_neat/recurrent_net.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_neat/recurrent_net.py b/pytorch_neat/recurrent_net.py index 1bf811b..413971e 100644 --- a/pytorch_neat/recurrent_net.py +++ b/pytorch_neat/recurrent_net.py @@ -33,7 +33,7 @@ def dense_from_coo(shape, conns, dtype=torch.float64): idxs, weights = conns if len(idxs) == 0: return mat - rows, cols = np.array(idxs).transpose() + rows, cols = np.array(idxs, dtype=np.int64).transpose() mat[torch.tensor(rows), torch.tensor(cols)] = torch.tensor( weights, dtype=dtype) return mat