forked from temporalio/samples-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
swarm.go
152 lines (130 loc) · 3.55 KB
/
swarm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package pso
import (
"errors"
"fmt"
"go.temporal.io/sdk/workflow"
)
// Result and state types for the particle swarm optimizer.
type (
	// ParticleResult is the value produced by Swarm.Run: a copy of the
	// swarm's global best position (embedded) and the step at which Run
	// stopped.
	ParticleResult struct {
		Position
		Step int
	}

	// Swarm is the optimizer state. It is handed by value to the particle
	// initialization and update activities.
	Swarm struct {
		Settings  *SwarmSettings // configuration: Size, Steps, PrintEvery, ContinueAsNewEvery, objective function
		Gbest     *Position      // best position seen so far across all particles
		Particles []*Particle    // one entry per particle, indexed 0..Settings.Size-1
	}
)
// NewSwarm builds a swarm for settings and initializes its particles.
//
// Each particle is created by running the InitParticleActivityName activity
// in its own workflow coroutine. The activity receives a copy of the swarm,
// and all activities run in parallel. Once every particle has been
// initialized, Gbest is refreshed from the particles' personal bests.
//
// NOTE(review): the activity options (timeouts, etc.) are presumably
// attached to ctx by the caller; confirm.
//
// If any initialization activity fails, the first error received is returned
// together with the partially initialized swarm. Entries of Particles whose
// activity has not succeeded are nil in that case.
func NewSwarm(ctx workflow.Context, settings *SwarmSettings) (*Swarm, error) {
	swarm := Swarm{Settings: settings}

	// Seed gbest with a very large fitness so any real particle beats it.
	// NOTE(review): presumably a finite sentinel rather than math.Inf(1)
	// because the swarm is serialized as an activity argument and JSON cannot
	// encode infinities; confirm before changing.
	swarm.Gbest = NewPosition(settings.function.dim)
	swarm.Gbest.Fitness = 1e20

	// Buffered to Size: if we return early on the first error below, the
	// coroutines that finish afterwards can still deliver their result and
	// exit instead of blocking forever on Send.
	results := workflow.NewBufferedChannel(ctx, settings.Size)
	swarm.Particles = make([]*Particle, settings.Size)
	for i := 0; i < settings.Size; i++ {
		particleIdx := i // per-iteration copy for the closure (pre-Go 1.22 loop semantics)
		workflow.Go(ctx, func(ctx workflow.Context) {
			var particle Particle
			err := workflow.ExecuteActivity(ctx, InitParticleActivityName, swarm).Get(ctx, &particle)
			if err == nil {
				swarm.Particles[particleIdx] = &particle
			}
			results.Send(ctx, err)
		})
	}

	// Wait for every particle. A successful activity sends a nil error,
	// which leaves v nil, so only failures satisfy the type assertion.
	for i := 0; i < settings.Size; i++ {
		var v interface{}
		results.Receive(ctx, &v)
		if initErr, ok := v.(error); ok {
			return &swarm, initErr
		}
	}

	swarm.updateBest()
	return &swarm, nil
}
// updateBest refreshes the swarm's global best from the particles' personal
// bests. It walks particles 0..Settings.Size-1 in order. Whenever a
// particle's Pbest IsBetterThan the current Gbest, Gbest is replaced with a
// Copy of that Pbest.
func (swarm *Swarm) updateBest() {
	for idx := 0; idx < swarm.Settings.Size; idx++ {
		p := swarm.Particles[idx]
		if !p.Pbest.IsBetterThan(swarm.Gbest) {
			continue
		}
		swarm.Gbest = p.Pbest.Copy()
	}
}
// Run advances the swarm from step up to and including Settings.Steps.
//
// Each step does two things:
//   - It updates every particle in parallel, one UpdateParticleActivityName
//     activity per particle. Each activity is given a copy of the swarm and
//     its particle index.
//   - It then refreshes Gbest from the particles' personal bests.
//
// Every returned ParticleResult carries a copy of Gbest and the current step.
// Run returns early when:
//   - an update activity fails (that activity error is returned);
//   - Gbest.Fitness drops below the objective's Goal (nil error);
//   - Steps has not been reached yet and step is a multiple of
//     Settings.ContinueAsNewEvery. The returned error's message is then
//     ContinueAsNewStr, so the caller can continue-as-new to keep the
//     workflow history small.
//
// Otherwise it returns with a nil error after step == Settings.Steps.
//
// A PrintEvery or ContinueAsNewEvery of 0 disables periodic logging or
// continue-as-new respectively; it does not panic on a modulo by zero.
//
// Run also registers an "iteration" query handler. The handler returns the
// progress message of the most recent step that did not reach the goal, or
// "" before any such step.
func (swarm *Swarm) Run(ctx workflow.Context, step int) (ParticleResult, error) {
	logger := workflow.GetLogger(ctx)

	// Setup query handler for query type "iteration".
	var iterationMessage string
	err := workflow.SetQueryHandler(ctx, "iteration", func(input []byte) (string, error) {
		return iterationMessage, nil
	})
	if err != nil {
		logger.Info("SetQueryHandler failed: " + err.Error())
		return ParticleResult{}, err
	}

	// result snapshots the current global best at the current step.
	result := func() ParticleResult {
		return ParticleResult{Position: *swarm.Gbest, Step: step}
	}

	// Buffered to Size: if we return early on an activity error, coroutines
	// still in flight can deliver their result and exit instead of blocking
	// forever on Send.
	chunkResultChannel := workflow.NewBufferedChannel(ctx, swarm.Settings.Size)
	for step <= swarm.Settings.Steps {
		logger.Info("Iteration ", "step", step)

		// Update particles in parallel.
		for i := 0; i < swarm.Settings.Size; i++ {
			particleIdx := i // per-iteration copy for the closure (pre-Go 1.22 loop semantics)
			workflow.Go(ctx, func(ctx workflow.Context) {
				var particle Particle
				err := workflow.ExecuteActivity(ctx, UpdateParticleActivityName, *swarm, particleIdx).Get(ctx, &particle)
				if err == nil {
					swarm.Particles[particleIdx] = &particle
				}
				chunkResultChannel.Send(ctx, err)
			})
		}

		// Wait for all particles to be updated. A successful activity sends
		// a nil error, which leaves v nil, so only failures match.
		for i := 0; i < swarm.Settings.Size; i++ {
			var v interface{}
			chunkResultChannel.Receive(ctx, &v)
			if updateErr, ok := v.(error); ok {
				return result(), updateErr
			}
		}

		logger.Debug("Iteration Update Swarm Best", "step", step)
		swarm.updateBest()

		// Stop early once the goal has been reached.
		if swarm.Gbest.Fitness < swarm.Settings.function.Goal {
			logger.Debug("Iteration New Swarm Best", "step", step)
			return result(), nil
		}

		iterationMessage = fmt.Sprintf("Step %d :: min err=%.5e\n", step, swarm.Gbest.Fitness)
		// A zero interval disables periodic logging (avoids modulo by zero).
		if every := swarm.Settings.PrintEvery; every != 0 && step%every == 0 {
			logger.Info(iterationMessage)
		}

		// Finished all iterations.
		if step == swarm.Settings.Steps {
			break
		}

		// Not finished yet: periodically hand control back so the caller can
		// continue-as-new and keep the history small. Zero disables this.
		if every := swarm.Settings.ContinueAsNewEvery; every != 0 && step%every == 0 {
			return result(), errors.New(ContinueAsNewStr)
		}
		step++
	}
	return result(), nil
}