Skip to content

Commit

Permalink
Update rand package
Browse files Browse the repository at this point in the history
  • Loading branch information
itsubaki committed Apr 13, 2024
1 parent 764bd73 commit 65bed03
Show file tree
Hide file tree
Showing 32 changed files with 128 additions and 111 deletions.
6 changes: 3 additions & 3 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ type Agent struct {
}

func (a *Agent) GetAction() int {
rng := randv2.New(a.Source)
if a.Epsilon > rng.Float64() {
return rng.IntN(len(a.Qs))
g := randv2.New(a.Source)
if a.Epsilon > g.Float64() {
return g.IntN(len(a.Qs))
}

return vector.Argmax(a.Qs)
Expand Down
6 changes: 3 additions & 3 deletions agent/alpha.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ type AlphaAgent struct {
}

func (a *AlphaAgent) GetAction() int {
rng := randv2.New(a.Source)
if a.Epsilon > rng.Float64() {
return rng.IntN(len(a.Qs))
g := randv2.New(a.Source)
if a.Epsilon > g.Float64() {
return g.IntN(len(a.Qs))
}

return vector.Argmax(a.Qs)
Expand Down
6 changes: 3 additions & 3 deletions agent/dqn.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ func (a *DQNAgent) Sync() {
}

func (a *DQNAgent) GetAction(state []float64) int {
rng := randv2.New(a.Source)
if a.Epsilon > rng.Float64() {
return rng.IntN(a.ActionSize)
g := randv2.New(a.Source)
if a.Epsilon > g.Float64() {
return g.IntN(a.ActionSize)
}

qs := a.Q.Predict(matrix.New(state))
Expand Down
4 changes: 2 additions & 2 deletions agent/env/bandit.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func NewBandit(arms int, s randv2.Source) *Bandit {
}

func (b *Bandit) Play(arm int) float64 {
rng := randv2.New(b.Source)
if b.Rates[arm] > rng.Float64() {
g := randv2.New(b.Source)
if b.Rates[arm] > g.Float64() {
return 1
}

Expand Down
10 changes: 5 additions & 5 deletions agent/qlearning.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package agent

import (
"fmt"
"math/rand"
randv2 "math/rand/v2"

"github.com/itsubaki/neu/math/vector"
)
Expand All @@ -13,13 +13,13 @@ type QLearningAgent struct {
Epsilon float64
ActionSize int
Q DefaultMap[float64]
Source rand.Source
Source randv2.Source
}

func (a *QLearningAgent) GetAction(state fmt.Stringer) int {
rng := rand.New(a.Source)
if a.Epsilon > rng.Float64() {
return rng.Intn(a.ActionSize)
g := randv2.New(a.Source)
if a.Epsilon > g.Float64() {
return g.IntN(a.ActionSize)
}

qs := qstate(a.Q, state.String(), a.ActionSize)
Expand Down
18 changes: 9 additions & 9 deletions agent/qlearning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package agent_test

import (
"fmt"
"math/rand"
"strconv"
"strings"

"github.com/itsubaki/neu/agent"
"github.com/itsubaki/neu/agent/env"
"github.com/itsubaki/neu/math/rand"
)

func ExampleQLearningAgent() {
Expand All @@ -18,7 +18,7 @@ func ExampleQLearningAgent() {
Epsilon: 0.1,
ActionSize: 4,
Q: make(map[string]float64),
Source: rand.NewSource(1),
Source: rand.Const(1),
}

episodes := 10000
Expand Down Expand Up @@ -64,24 +64,24 @@ func ExampleQLearningAgent() {
// (1, 2) UP : 0.9000
// (1, 2) DOWN : 0.7290
// (1, 2) LEFT : 0.8100
// (1, 2) RIGHT : -0.1000
// (1, 2) RIGHT : -0.1001
// (1, 3) UP : 1.0000
// (1, 3) DOWN : 0.0000
// (1, 3) LEFT : 0.6480
// (1, 3) RIGHT : 0.0000
// (1, 3) LEFT : 0.0000
// (1, 3) RIGHT : -0.0812
// (2, 0) UP : 0.6561
// (2, 0) DOWN : 0.5905
// (2, 0) LEFT : 0.5905
// (2, 0) RIGHT : 0.6561
// (2, 1) UP : 0.6561
// (2, 1) DOWN : 0.6557
// (2, 1) UP : 0.6559
// (2, 1) DOWN : 0.6559
// (2, 1) LEFT : 0.5905
// (2, 1) RIGHT : 0.7290
// (2, 2) UP : 0.8100
// (2, 2) DOWN : 0.7288
// (2, 2) DOWN : 0.7290
// (2, 2) LEFT : 0.6561
// (2, 2) RIGHT : 0.0000
// (2, 3) UP : -0.0992
// (2, 3) UP : -0.1000
// (2, 3) DOWN : 0.0000
// (2, 3) LEFT : 0.0000
// (2, 3) RIGHT : 0.0000
Expand Down
6 changes: 3 additions & 3 deletions agent/replay_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type ReplayBuffer struct {

func NewReplayBuffer(bufferSize, batchSize int, s ...randv2.Source) *ReplayBuffer {
if len(s) == 0 {
s = append(s, rand.MustNewSource())
s = append(s, rand.NewSource(rand.MustRead()))
}

return &ReplayBuffer{
Expand All @@ -47,11 +47,11 @@ func (b *ReplayBuffer) Len() int {
}

func (b *ReplayBuffer) Batch() ([][]float64, []int, []float64, [][]float64, []bool) {
rng := randv2.New(b.Source)
g := randv2.New(b.Source)

counter := make(map[int]bool)
for c := 0; c < b.BatchSize; {
n := rng.IntN(b.Len())
n := g.IntN(b.Len())
if _, ok := counter[n]; !ok {
counter[n] = true
c++
Expand Down
2 changes: 1 addition & 1 deletion cmd/dqn/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func main() {
Beta1: beta1,
Beta2: beta2,
},
Source: rand.MustNewSource(),
Source: rand.NewSource(rand.MustRead()),
}

for i := 0; i < episode; i++ {
Expand Down
2 changes: 1 addition & 1 deletion dataset/sequence/sequence.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ type Dataset struct {

func Load(dir, fileName string, s ...randv2.Source) (*Dataset, *Dataset, *Vocab, error) {
if len(s) == 0 {
s = append(s, rand.MustNewSource())
s = append(s, rand.NewSource(rand.MustRead()))
}

// read
Expand Down
4 changes: 2 additions & 2 deletions layer/negative_sampling_loss.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type NegativeSamplingLoss struct {

func NewNegativeSamplingLoss(W matrix.Matrix, corpus []int, power float64, sampleSize int, s ...randv2.Source) *NegativeSamplingLoss {
if len(s) == 0 {
s = append(s, rand.MustNewSource())
s = append(s, rand.NewSource(rand.MustRead()))
}

embed, loss := make([]EmbeddingDot, sampleSize+1), make([]SigmoidWithLoss, sampleSize+1)
Expand Down Expand Up @@ -128,7 +128,7 @@ func NewUnigramSampler(corpus []int, power float64, size int) *UnigramSampler {

func (s *UnigramSampler) NegativeSample(target []int, seed ...randv2.Source) [][]int {
if len(seed) == 0 {
seed = append(seed, rand.MustNewSource())
seed = append(seed, rand.NewSource(rand.MustRead()))
}

N := len(target)
Expand Down
2 changes: 1 addition & 1 deletion math/matrix/matrix.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func One(m, n int) Matrix {
// rnd returns a pseudo-random number generator.
func rnd(s ...randv2.Source) *randv2.Rand {
if len(s) == 0 {
s = append(s, rand.MustNewSource())
s = append(s, rand.NewSource(rand.MustRead()))
}

return randv2.New(s[0])
Expand Down
27 changes: 27 additions & 0 deletions math/rand/crypto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package rand

import (
"crypto/rand"
"fmt"
)

func Must[T any](a T, err error) T {
if err != nil {
panic(err)
}

return a
}

func MustRead() [32]byte {
return Must(Read())
}

func Read() ([32]byte, error) {
var p [32]byte
if _, err := rand.Read(p[:]); err != nil {
return [32]byte{}, fmt.Errorf("read: %v", err)
}

return p, nil
}
36 changes: 36 additions & 0 deletions math/rand/crypto_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package rand_test

import (
"fmt"
randv2 "math/rand/v2"
"testing"

"github.com/itsubaki/neu/math/rand"
)

func TestMustRead(t *testing.T) {
v := randv2.New(rand.NewSource(rand.MustRead())).Float64()
if v >= 0 && v < 1 {
return
}

t.Fail()
}

func TestMustPanic(t *testing.T) {
defer func() {
if rec := recover(); rec != nil {
err, ok := rec.(error)
if !ok {
t.Fail()
}

if err.Error() != "something went wrong" {
t.Fail()
}
}
}()

rand.Must(-1, fmt.Errorf("something went wrong"))
t.Fail()
}
23 changes: 2 additions & 21 deletions math/rand/source.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,12 @@
package rand

import (
crand "crypto/rand"
"fmt"
randv2 "math/rand/v2"
)

func Must[T any](a T, err error) T {
if err != nil {
panic(err)
}

return a
}

func MustNewSource() randv2.Source {
return Must(NewSource())
}

// NewSource returns a source of pseudo-random number generator
func NewSource() (randv2.Source, error) {
var p [32]byte
if _, err := crand.Read(p[:]); err != nil {
return nil, fmt.Errorf("read: %v", err)
}

return randv2.NewChaCha8(p), nil
func NewSource(seed [32]byte) randv2.Source {
return randv2.NewChaCha8(seed)
}

// Const returns a source of constant pseudo-random number generator
Expand Down
27 changes: 0 additions & 27 deletions math/rand/source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,6 @@ func ExampleConst() {
// 0.6764556596678
}

func TestMustNewSource(t *testing.T) {
v := randv2.New(rand.MustNewSource()).Float64()
if v >= 0 && v < 1 {
return
}

t.Fail()
}

func TestConst(t *testing.T) {
v := randv2.New(rand.Const()).Float64()
if v >= 0 && v < 1 {
Expand All @@ -51,21 +42,3 @@ func TestConst(t *testing.T) {

t.Fail()
}

func TestMustPanic(t *testing.T) {
defer func() {
if rec := recover(); rec != nil {
err, ok := rec.(error)
if !ok {
t.Fail()
}

if err.Error() != "something went wrong" {
t.Fail()
}
}
}()

rand.Must(-1, fmt.Errorf("something went wrong"))
t.Fail()
}
Loading

0 comments on commit 65bed03

Please sign in to comment.