-
Notifications
You must be signed in to change notification settings - Fork 0
/
socket.go
593 lines (497 loc) · 17.6 KB
/
socket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
package engineio
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"slices"
"strconv"
"time"
)
var (
	// ErrInvalidURL is wrapped into the error returned when a URL cannot be
	// parsed: either the server URL given to NewSocket or a transport URL
	// derived from it.
	ErrInvalidURL = errors.New("invalid URL")
)
// SocketState describes where a Socket is in its connection lifecycle.
type SocketState string

// Lifecycle states of a Socket. A new socket is closed; Open moves it to
// opening, the server's open packet moves it to open, and Close moves it to
// closing until the transport reports that it has shut down.
const (
	SocketStateOpen    SocketState = "open"    // handshake complete; Send delivers packets
	SocketStateOpening SocketState = "opening" // transport started, awaiting the server's open packet
	SocketStateClosed  SocketState = "closed"  // no active transport; Open may be called
	SocketStateClosing SocketState = "closing" // Close was called, waiting for the transport to close
)
// Callback types through which a Socket notifies its owner. Each one is
// registered with the matching Socket.On* method; a nil handler is skipped.
type (
	// SocketOpenHandler is called after the server's open packet has been processed.
	SocketOpenHandler func()
	// SocketCloseHandler is called when the socket closes, with a short reason
	// and the error that caused the close (nil when there is none).
	SocketCloseHandler func(reason string, cause error)
	// SocketPacketHandler is called for every packet the socket accepts.
	SocketPacketHandler func(Packet)
	// SocketMessageHandler is called with the payload of each message packet.
	SocketMessageHandler func([]byte)
	// SocketErrorHandler is called whenever the socket encounters an error.
	SocketErrorHandler func(error)
)
// SocketOptions configures a Socket. Every field is optional; a nil field
// selects the default documented on it.
type SocketOptions struct {
	// Client is the TransportClient passed to each transport to make requests.
	// Default: a new *http.Client
	Client TransportClient
	// Header holds the HTTP headers passed to each transport.
	// Default: an empty http.Header
	Header *http.Header
	// Upgrade determines whether the client should try to upgrade the transport from long-polling to something better.
	// Default: true
	Upgrade *bool
	// RememberUpgrade determines whether a new connection should go straight to WebSocket
	// when a previous connection successfully upgraded to it.
	// Default: false
	RememberUpgrade *bool
	// Transports is the list of transports to use, in order of preference.
	// The socket connects with the first one (or with WebSocket when RememberUpgrade applies);
	// the others are candidates for upgrades and, with TryAllTransports, for fallback.
	// Default: []TransportType{TransportTypePolling, TransportTypeWebSocket}
	Transports *[]TransportType
	// TryAllTransports determines whether, when a transport fails while the socket is still
	// opening, the client should fall back to the next transport in the list before giving up.
	// Default: false
	TryAllTransports *bool
}
// Socket is an engine.io client connection. It drives at most one active
// transport at a time, keeps the session parameters from the server's open
// packet, and dispatches events to the handlers registered with the On* methods.
// NOTE(review): fields are read and written from transport callbacks and the
// ping-timeout timer without synchronization — confirm the transports'
// threading model.
type Socket struct {
	// url is the base server URL; transport URLs are derived from it by resolveURL.
	url *url.URL
	// client is the TransportClient passed to every transport.
	client TransportClient
	// header holds the HTTP headers passed to every transport.
	header http.Header
	// upgrade determines whether the socket probes the upgrades offered in the open packet.
	upgrade bool
	// rememberUpgrade determines whether a new connection goes straight to WebSocket after a previous upgrade to it succeeded.
	rememberUpgrade bool
	// transports is the list of transports to try, in order of preference.
	transports []TransportType
	// tryAllTransports determines whether a transport that fails while opening is dropped in favour of the next one.
	tryAllTransports bool
	// onOpenHandler is the handler for when the socket opens.
	onOpenHandler SocketOpenHandler
	// onCloseHandler is the handler for when the socket closes.
	onCloseHandler SocketCloseHandler
	// onPacketHandler is the handler for when the socket receives packets.
	onPacketHandler SocketPacketHandler
	// onMessageHandler is the handler for when the socket receives a message packet.
	onMessageHandler SocketMessageHandler
	// onErrorHandler is the handler for when the socket encounters an error.
	onErrorHandler SocketErrorHandler
	// sessionID is the session identifier from the server's open packet; empty when no session exists.
	sessionID string
	// pingInterval is the server's heartbeat interval. The open packet states it in
	// milliseconds; it is stored here as a time.Duration.
	// https://github.com/socketio/engine.io-protocol?tab=readme-ov-file#heartbeat
	pingInterval time.Duration
	// pingTimeout is the server's heartbeat timeout. The open packet states it in
	// milliseconds; it is stored here as a time.Duration.
	// https://github.com/socketio/engine.io-protocol?tab=readme-ov-file#heartbeat
	pingTimeout time.Duration
	// maxPayload is the maximum number of bytes per payload announced in the open packet.
	// It is not yet enforced by Send (see the TODO there).
	// https://github.com/socketio/engine.io-protocol?tab=readme-ov-file#packet-encoding
	maxPayload int
	// state is the current lifecycle state of the socket.
	state SocketState
	// transport is the active transport; it is set to nil when the socket closes.
	transport Transport
	// priorUpgradeSuccess records whether the most recent connection or upgrade landed on WebSocket; consulted with rememberUpgrade.
	priorUpgradeSuccess bool
	// pingTimeoutTimer closes the socket if no packet arrives from the server within pingInterval+pingTimeout.
	pingTimeoutTimer *time.Timer
}
// NewSocket creates a closed Socket that will connect to serverURL.
//
// Only the first SocketOptions value is used; any further values are ignored.
// Nil option fields take the defaults documented on SocketOptions. The header
// and transport list are copied, so later changes the caller makes through
// those pointers do not affect the socket. An unparsable serverURL yields an
// error wrapping ErrInvalidURL.
func NewSocket(serverURL string, socketOptions ...SocketOptions) (*Socket, error) {
	baseURL, err := url.Parse(serverURL)
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrInvalidURL, err)
	}

	var opts SocketOptions
	if len(socketOptions) != 0 {
		opts = socketOptions[0]
	}

	s := &Socket{
		url:        baseURL,
		client:     &http.Client{},
		header:     http.Header{},
		upgrade:    true,
		transports: []TransportType{TransportTypePolling, TransportTypeWebSocket},
		state:      SocketStateClosed,
	}
	if opts.Client != nil {
		s.client = opts.Client
	}
	if opts.Header != nil {
		// Copy so the caller's map and the socket's map can change independently;
		// a pointer to a nil header keeps the empty default instead of a nil map.
		if h := opts.Header.Clone(); h != nil {
			s.header = h
		}
	}
	if opts.Upgrade != nil {
		s.upgrade = *opts.Upgrade
	}
	if opts.RememberUpgrade != nil {
		s.rememberUpgrade = *opts.RememberUpgrade
	}
	if opts.Transports != nil {
		// Copy so the caller's slice is never shared with the socket.
		s.transports = slices.Clone(*opts.Transports)
	}
	if opts.TryAllTransports != nil {
		s.tryAllTransports = *opts.TryAllTransports
	}
	return s, nil
}
// Open starts connecting the socket. It is a no-op unless the socket is closed.
//
// A transport is created (see createTransport) and its packet, error, and close
// callbacks are wired to the socket. The socket then moves to the opening state
// and the transport is opened; the socket becomes open once the server's open
// packet arrives. If no transport can be created, the error is passed to the
// error handler and the socket stays closed.
func (s *Socket) Open(ctx context.Context) {
	if s.state != SocketStateClosed {
		return
	}

	transport, err := s.createTransport()
	if err != nil {
		s.transport = nil
		// Report the actual cause (empty transport list, unsupported transport
		// type, unresolvable URL, factory failure) rather than a generic message.
		if s.onErrorHandler != nil {
			s.onErrorHandler(err)
		}
		return
	}
	s.transport = transport

	// Route transport events to the socket.
	transport.OnPacket(s.onPacket)
	transport.OnError(s.onError)
	transport.OnClose(func(ctx context.Context) {
		s.onClose(ctx, "transport closed", nil)
	})

	s.state = SocketStateOpening
	transport.Open(ctx)
}
// createTransport creates the transport for a new connection attempt.
//
// WebSocket is used directly when rememberUpgrade is enabled, the previous
// connection successfully ended up on WebSocket, and WebSocket is among the
// configured transports; otherwise the first configured transport is used.
// It returns an error when no transports are configured, when the chosen
// transport type has no registered factory, or when the URL cannot be resolved.
func (s *Socket) createTransport() (Transport, error) {
	if len(s.transports) == 0 {
		return nil, errors.New("no transports available")
	}

	transportType := s.transports[0]
	if s.rememberUpgrade && s.priorUpgradeSuccess && slices.Contains(s.transports, TransportTypeWebSocket) {
		transportType = TransportTypeWebSocket
	}

	// A type with no registered factory (e.g. a typo or an unimplemented
	// transport in the options) would otherwise panic on a nil func call.
	factory := Transports[transportType]
	if factory == nil {
		return nil, fmt.Errorf("unsupported transport %q", transportType)
	}

	transportURL, err := s.resolveURL(transportType)
	if err != nil {
		return nil, fmt.Errorf("resolve URL: %w", err)
	}

	return factory(transportURL, TransportOptions{
		Client: s.client,
		Header: s.header,
	})
}
// resolveURL returns the URL a transport of the given type should connect to.
// It is a copy of the socket's base URL whose query additionally carries the
// protocol version (EIO), the transport name, and, once a session exists, the
// session ID (sid). The socket's base URL is left untouched.
func (s *Socket) resolveURL(transportType TransportType) (*url.URL, error) {
	query := s.url.Query()
	query.Set("EIO", strconv.Itoa(Protocol))
	query.Set("transport", string(transportType))
	if sid := s.sessionID; sid != "" {
		query.Set("sid", sid)
	}

	// Re-parsing the string form yields an independent copy of the base URL.
	resolved, err := url.Parse(s.url.String())
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrInvalidURL, err)
	}
	resolved.RawQuery = query.Encode()
	return resolved, nil
}
// Close begins shutting down an open or opening socket. It moves the socket to
// the closing state and closes the transport; the socket becomes closed when
// the transport's close callback (wired to onClose by Open) runs. Calls in any
// other state are ignored.
func (s *Socket) Close(ctx context.Context) {
	if active := s.state == SocketStateOpen || s.state == SocketStateOpening; !active {
		return
	}
	s.state = SocketStateClosing
	s.transport.Close(ctx)
}
// Send writes packets to the active transport and returns the transport's
// error. Unless the socket is open, the packets are dropped and nil is returned.
func (s *Socket) Send(ctx context.Context, packets []Packet) error {
	if s.state == SocketStateOpen {
		// TODO: Enforce the maximum payload size.
		// Loop over the packets building up the largest payload possible until s.maxPayload is reached.
		// If more packets exist, begin adding those to the next payload. Continue until all packets are sent.
		return s.transport.Send(ctx, packets)
	}
	return nil
}
// onError handles an error reported by the active transport.
//
// The error is passed to the error handler (wrapped as "transport error") and
// the remembered upgrade is cleared, so the next connection does not go
// straight back to the failing transport. If tryAllTransports is enabled and
// the socket is still opening, the failed transport type is dropped from the
// candidate list and the socket reopens with the next candidate. Otherwise the
// socket is closed with reason "transport error".
func (s *Socket) onError(ctx context.Context, err error) {
	if s.onErrorHandler != nil {
		s.onErrorHandler(fmt.Errorf("transport error: %w", err))
	}
	// The transport failed, so it must not be remembered as a successful upgrade.
	s.priorUpgradeSuccess = false

	if s.tryAllTransports && s.state == SocketStateOpening && s.transport != nil {
		failed := s.transport.Type()
		// Drop the transport type that actually failed; it is not necessarily
		// the first entry (rememberUpgrade may have picked WebSocket). Clone
		// first so the slice supplied by the caller is never modified in place.
		remaining := slices.DeleteFunc(slices.Clone(s.transports), func(t TransportType) bool {
			return t == failed
		})
		// Retry only if the failed type was removed and something is left to try.
		if len(remaining) > 0 && len(remaining) < len(s.transports) {
			s.transports = remaining
			if s.pingTimeoutTimer != nil {
				s.pingTimeoutTimer.Stop()
			}
			// Detach the failed transport so it cannot call back, then close it.
			s.transport.OnOpen(nil)
			s.transport.OnClose(nil)
			s.transport.OnPacket(nil)
			s.transport.OnError(nil)
			s.transport.Close(ctx)
			s.transport = nil
			// Open only proceeds from the closed state; without this reset the
			// retry silently does nothing and the socket stays stuck opening.
			s.state = SocketStateClosed
			s.Open(ctx)
			return
		}
	}
	s.onClose(ctx, "transport error", err)
}
// onClose tears the socket down after its transport closes or errors, or after
// the ping timeout expires. It does nothing unless the socket is open, opening,
// or closing. Otherwise it does the following, in order:
//   - stops the ping-timeout timer;
//   - detaches and closes the transport;
//   - clears the session and marks the socket closed;
//   - calls the close handler with reason and cause.
func (s *Socket) onClose(ctx context.Context, reason string, cause error) {
	active := s.state == SocketStateOpen || s.state == SocketStateOpening || s.state == SocketStateClosing
	if !active {
		return
	}

	if timer := s.pingTimeoutTimer; timer != nil {
		timer.Stop()
	}

	if tr := s.transport; tr != nil {
		// Unhook every callback first so the transport cannot re-enter the socket.
		tr.OnOpen(nil)
		tr.OnClose(nil)
		tr.OnPacket(nil)
		tr.OnError(nil)
		tr.Close(ctx)
		s.transport = nil
	}

	s.state = SocketStateClosed
	s.sessionID = ""

	if handler := s.onCloseHandler; handler != nil {
		handler(reason, cause)
	}
}
// onPacket handles a packet delivered by the active transport. Packets are
// ignored unless the socket is open, opening, or closing.
//
// Open packets complete the handshake (see onOpen), ping packets are answered
// with a pong, and message packets go to the message handler. Afterwards the
// ping-timeout timer is re-armed and the packet goes to the packet handler.
func (s *Socket) onPacket(ctx context.Context, p Packet) {
	switch s.state {
	case SocketStateOpen, SocketStateOpening, SocketStateClosing:
		// These states still accept packets.
	default:
		return
	}

	switch p.Type {
	case PacketOpen:
		// The server's handshake; its JSON payload carries the session parameters.
		var openPacket OpenPacket
		if err := json.Unmarshal(p.Data, &openPacket); err != nil {
			if s.onErrorHandler != nil {
				s.onErrorHandler(fmt.Errorf("failed to unmarshal open packet: %w", err))
			}
		} else {
			s.onOpen(ctx, openPacket)
		}
	case PacketPing:
		// The server is alive; reply with a pong.
		if err := s.Send(ctx, []Packet{{Type: PacketPong}}); err != nil && s.onErrorHandler != nil {
			s.onErrorHandler(fmt.Errorf("failed to send pong packet: %w", err))
		}
	case PacketMessage:
		if s.onMessageHandler != nil {
			s.onMessageHandler(p.Data)
		}
	}

	// Any packet shows the server is alive. The timer is re-armed only after
	// the packet is handled: the open packet is what sets pingInterval and
	// pingTimeout (in onOpen), so arming it earlier would use stale, initially
	// zero, durations. Skip it if the socket closed while handling the packet,
	// so no timer outlives the connection.
	if s.state != SocketStateClosed {
		s.reschedulePingTimeout(ctx)
	}

	if s.onPacketHandler != nil {
		s.onPacketHandler(p)
	}
}
// reschedulePingTimeout re-arms the ping-timeout timer. Any pending timer is
// stopped. A new one is started that closes the socket with reason
// "ping timeout" unless it is stopped or rescheduled within
// pingInterval+pingTimeout. Once ctx is done, the timer no longer closes the
// socket.
//
// No timer is armed while pingInterval+pingTimeout is not positive, which is
// the case until an open packet has supplied those values; a zero-length timer
// would close the socket immediately.
func (s *Socket) reschedulePingTimeout(ctx context.Context) {
	if s.pingTimeoutTimer != nil {
		s.pingTimeoutTimer.Stop()
		s.pingTimeoutTimer = nil
	}

	timeout := s.pingInterval + s.pingTimeout
	if timeout <= 0 {
		return
	}

	// time.AfterFunc needs no waiting goroutine: a stopped timer never runs its
	// callback, so replacing the timer leaves nothing blocked behind. (A receive
	// on Timer.C of a stopped timer would block forever, and its comma-ok value
	// is true when the timer fires, not when it is "closed".)
	s.pingTimeoutTimer = time.AfterFunc(timeout, func() {
		if ctx.Err() != nil {
			return
		}
		s.onClose(ctx, "ping timeout", nil)
	})
}
// onOpen completes the handshake described by the server's open packet.
//
// It stores the session ID, payload limit, and heartbeat durations (sent in
// milliseconds), marks the socket open, and points the transport at a URL that
// carries the session ID. If upgrades are enabled, it probes each offered
// upgrade that is also a configured transport, stopping at the first success.
// It then calls the open handler. If the URL cannot be resolved, the error is
// reported and the method returns early.
func (s *Socket) onOpen(ctx context.Context, p OpenPacket) {
	s.sessionID = p.SessionID
	s.maxPayload = p.MaxPayload
	s.pingInterval = time.Duration(p.PingInterval) * time.Millisecond
	s.pingTimeout = time.Duration(p.PingTimeout) * time.Millisecond
	s.state = SocketStateOpen

	transportURL, err := s.resolveURL(s.transport.Type())
	if err != nil {
		if s.onErrorHandler != nil {
			s.onErrorHandler(fmt.Errorf("failed to resolve URL: %w", err))
		}
		return
	}
	s.transport.SetURL(transportURL)

	// Landing on WebSocket directly counts as a successful upgrade.
	s.priorUpgradeSuccess = s.transport.Type() == TransportTypeWebSocket

	if s.upgrade {
		for _, candidate := range p.Upgrades {
			if !slices.Contains(s.transports, candidate) {
				continue
			}
			probeErr := s.probe(ctx, candidate)
			if probeErr == nil {
				break
			}
			if s.onErrorHandler != nil {
				s.onErrorHandler(fmt.Errorf("failed to probe for upgrade: %w", probeErr))
			}
		}
	}

	if s.onOpenHandler != nil {
		s.onOpenHandler()
	}
}
// probe tries to upgrade the socket to a transport of type upgradeTransportType
// using the engine.io probe handshake
// (https://github.com/socketio/engine.io-protocol?tab=readme-ov-file#upgrade).
// A second transport is opened and sends a ping carrying "probe"; the server
// must answer with a pong carrying "probe". On success, probe:
//   - detaches and pauses the current transport;
//   - installs the new transport with the socket's handlers;
//   - sends an upgrade packet and returns nil.
//
// probe blocks until the probe succeeds, or until the probe transport errors
// or closes, in which case it returns an error describing the failure.
// NOTE(review): the wait does not select on ctx.Done(); it relies on the
// transport reporting an error or closing when ctx is cancelled — confirm.
func (s *Socket) probe(ctx context.Context, upgradeTransportType TransportType) error {
	// Nothing counts as upgraded until the probe completes.
	s.priorUpgradeSuccess = false

	// A type with no registered factory would otherwise panic on a nil func call.
	factory := Transports[upgradeTransportType]
	if factory == nil {
		return fmt.Errorf("unsupported transport %q", upgradeTransportType)
	}
	probeURL, err := s.resolveURL(upgradeTransportType)
	if err != nil {
		return fmt.Errorf("failed to resolve URL: %w", err)
	}
	transport, err := factory(probeURL, TransportOptions{
		Client: s.client,
		Header: s.header,
	})
	if err != nil {
		return fmt.Errorf("failed to create transport: %w", err)
	}

	// outcome carries the single result of the probe. report never blocks: the
	// first outcome wins, and later ones (e.g. an error racing a close) are
	// dropped instead of stalling a transport goroutine on the full buffer.
	outcome := make(chan error, 1)
	report := func(err error) {
		select {
		case outcome <- err:
		default:
		}
	}
	// abandon reports err, then detaches every handler from the probe
	// transport and closes it.
	abandon := func(ctx context.Context, err error) {
		report(err)
		transport.OnOpen(nil)
		transport.OnClose(nil)
		transport.OnError(nil)
		transport.OnPacket(nil)
		transport.Close(ctx)
	}

	// Once the probe transport opens, send the probe ping (only once).
	transport.OnOpen(func(ctx context.Context) {
		transport.OnOpen(nil)
		if err := transport.Send(ctx, []Packet{{Type: PacketPing, Data: []byte("probe")}}); err != nil {
			abandon(ctx, fmt.Errorf("sending probe packet: %w", err))
		}
	})
	// Wait for the matching pong, then switch the socket over to the probe transport.
	transport.OnPacket(func(ctx context.Context, p Packet) {
		if p.Type != PacketPong || string(p.Data) != "probe" {
			return
		}
		// The socket may have closed while the probe was in flight, leaving no
		// transport to replace.
		if s.state == SocketStateClosed || s.transport == nil {
			abandon(ctx, errors.New("socket closed during probe"))
			return
		}
		s.priorUpgradeSuccess = transport.Type() == TransportTypeWebSocket
		// Silence and pause the old transport.
		s.transport.OnOpen(nil)
		s.transport.OnClose(nil)
		s.transport.OnError(nil)
		s.transport.OnPacket(nil)
		s.transport.Pause(ctx)
		// Promote the probe transport and route its events to the socket.
		s.transport = transport
		transport.OnPacket(s.onPacket)
		transport.OnError(s.onError)
		transport.OnClose(func(ctx context.Context) {
			s.onClose(ctx, "transport closed", nil)
		})
		// The switch has already happened, so a failed upgrade packet is
		// reported rather than failing the probe.
		if err := transport.Send(ctx, []Packet{{Type: PacketUpgrade}}); err != nil && s.onErrorHandler != nil {
			s.onErrorHandler(fmt.Errorf("failed to send upgrade packet: %w", err))
		}
		report(nil)
	})
	// A close or error before the pong means the probe failed.
	transport.OnClose(func(_ context.Context) {
		abandon(ctx, errors.New("transport closed"))
	})
	transport.OnError(func(ctx context.Context, err error) {
		abandon(ctx, fmt.Errorf("transport error: %w", err))
	})

	transport.Open(ctx)
	if err := <-outcome; err != nil {
		return fmt.Errorf("probing transport: %w", err)
	}
	return nil
}
// OnOpen registers handler to be called once the socket has processed the
// server's open packet. Passing nil removes the current handler.
func (s *Socket) OnOpen(handler SocketOpenHandler) { s.onOpenHandler = handler }

// OnClose registers handler to be called when the socket closes.
// Passing nil removes the current handler.
func (s *Socket) OnClose(handler SocketCloseHandler) { s.onCloseHandler = handler }

// OnPacket registers handler to be called for every packet the socket accepts.
// Passing nil removes the current handler.
func (s *Socket) OnPacket(handler SocketPacketHandler) { s.onPacketHandler = handler }

// OnMessage registers handler to be called with the payload of each message
// packet. Passing nil removes the current handler.
func (s *Socket) OnMessage(handler SocketMessageHandler) { s.onMessageHandler = handler }

// OnError registers handler to be called whenever the socket encounters an
// error. Passing nil removes the current handler.
func (s *Socket) OnError(handler SocketErrorHandler) { s.onErrorHandler = handler }