Skip to content

Commit

Permalink
Add Must to rand package
Browse files Browse the repository at this point in the history
  • Loading branch information
itsubaki committed Apr 6, 2024
1 parent 2a5e604 commit 84080a9
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
2 changes: 1 addition & 1 deletion layer/linear.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type LinearOpts struct {
}

func Linear(outSize int, opts ...LinearOpts) *LinearT {
s := rand.NewSource()
s := rand.MustNewSource()
if len(opts) != 0 && opts[0].Source != nil {
s = opts[0].Source
}
Expand Down
2 changes: 1 addition & 1 deletion matrix/matrix.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func From(x [][]int) Matrix {
// rnd returns a pseudo-random number generator.
func rnd(s ...randv2.Source) *randv2.Rand {
if len(s) == 0 {
s = append(s, rand.NewSource())
s = append(s, rand.MustNewSource())
}

return randv2.New(s[0])
Expand Down
19 changes: 16 additions & 3 deletions rand/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,30 @@ package rand

import (
crand "crypto/rand"
"fmt"
randv2 "math/rand/v2"
)

func Must[T any](a T, err error) T {
if err != nil {
panic(err)
}

return a
}

func MustNewSource() randv2.Source {
return Must(NewSource())
}

// NewSource returns a source of pseudo-random number generator
func NewSource() randv2.Source {
func NewSource() (randv2.Source, error) {
var p [32]byte
if _, err := crand.Read(p[:]); err != nil {
panic(err)
return nil, fmt.Errorf("read: %v", err)
}

return randv2.NewChaCha8(p)
return randv2.NewChaCha8(p), nil
}

// Const returns a source of constant pseudo-random number generator
Expand Down
4 changes: 2 additions & 2 deletions rand/source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func ExampleConst() {
// 0.6764556596678
}

func TestNewSource(t *testing.T) {
v := randv2.New(rand.NewSource()).Float64()
func TestMustNewSource(t *testing.T) {
v := randv2.New(rand.MustNewSource()).Float64()
if v >= 0 && v < 1 {
return
}
Expand Down

0 comments on commit 84080a9

Please sign in to comment.