Skip to content

Commit

Permalink
Add test case
Browse files Browse the repository at this point in the history
  • Loading branch information
shendiaomo committed Nov 4, 2020
1 parent 789e461 commit 784a3b7
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions nn/parallel/parallel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package parallel

import (
"fmt"
"github.com/stretchr/testify/assert"
torch "github.com/wangkuiyi/gotorch"
"github.com/wangkuiyi/gotorch/nn"
"testing"
)

type myModelModule struct {
nn.Module // Every model must derive from Module
}

// Forward executes the calculation
func (m *myModelModule) Forward(x torch.Tensor) torch.Tensor {
fmt.Println("Forward")
return torch.Tensor{nil}
}

func myModel() *myModelModule {
m := &myModelModule{}
m.Init(m)
return m
}

func TestDataParallel(t *testing.T) {
m := myModel()
// panic: Parallel API needs -DWITH_CUDA on building libcgotorch.so
assert.Panics(t, func() {
DataParallel(m, torch.Tensor{nil}, []torch.Device{}, torch.Device{}, 0)
})
}

0 comments on commit 784a3b7

Please sign in to comment.