Skip to content

Commit

Permalink
Merge pull request #179 from relab/connection
Browse files Browse the repository at this point in the history
feat: Create config regardless of connection failures
  • Loading branch information
meling authored Mar 26, 2024
2 parents ef33eed + b10575b commit 8dd87ab
Show file tree
Hide file tree
Showing 8 changed files with 579 additions and 36 deletions.
152 changes: 129 additions & 23 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gorums

import (
"context"
"fmt"
"math"
"math/rand"
"sync"
Expand All @@ -16,6 +17,8 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
)

var streamDownErr = status.Error(codes.Unavailable, "stream is down")

type request struct {
ctx context.Context
msg *Message
Expand All @@ -40,7 +43,7 @@ type responseRouter struct {

type channel struct {
sendQ chan request
nodeID uint32
node *RawNode
mu sync.Mutex
lastError error
latency time.Duration
Expand All @@ -50,44 +53,83 @@ type channel struct {
gorumsStream ordering.Gorums_NodeStreamClient
streamMut sync.RWMutex
streamBroken atomicFlag
connEstablished atomicFlag
parentCtx context.Context
streamCtx context.Context
cancelStream context.CancelFunc
responseRouters map[uint64]responseRouter
responseMut sync.Mutex
}

// newChannel creates a new channel for the given node and starts the sending goroutine.
//
// Note that we start the sending goroutine even though the
// connection has not yet been established. This is to prevent
// deadlock when invoking a call type, as the goroutine will
// block on the sendQ until a connection has been established.
func newChannel(n *RawNode) *channel {
return &channel{
c := &channel{
sendQ: make(chan request, n.mgr.opts.sendBuffer),
backoffCfg: n.mgr.opts.backoff,
nodeID: n.ID(),
node: n,
latency: -1 * time.Second,
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
responseRouters: make(map[uint64]responseRouter),
}
// parentCtx controls the channel and is used to shut it down
c.parentCtx = n.newContext()
go c.sender()
return c
}

func (c *channel) connect(ctx context.Context, conn *grpc.ClientConn) error {
// newNodeStream creates a stream and starts the receiving goroutine.
//
// Note that the stream could fail even though conn != nil due
// to the non-blocking dial. Hence, we need to try to connect
// to the node before starting the receiving goroutine.
func (c *channel) newNodeStream(conn *grpc.ClientConn) error {
if conn == nil {
// no need to proceed if dial failed
return fmt.Errorf("connection is nil")
}
c.streamMut.Lock()
var err error
c.parentCtx = ctx
c.streamCtx, c.cancelStream = context.WithCancel(c.parentCtx)
c.gorumsClient = ordering.NewGorumsClient(conn)
c.gorumsStream, err = c.gorumsClient.NodeStream(c.streamCtx)
c.streamMut.Unlock()
if err != nil {
return err
}
go c.sendMsgs()
go c.recvMsgs()
c.streamBroken.clear()
// guard against creating multiple receiver goroutines
if !c.connEstablished.get() {
// connEstablished indicates dial was successful
// and that receiver have started
c.connEstablished.set()
go c.receiver()
}
return nil
}

func (c *channel) cancelPendingMsgs() {
c.responseMut.Lock()
defer c.responseMut.Unlock()
for msgID, router := range c.responseRouters {
router.c <- response{nid: c.node.ID(), err: streamDownErr}
// delete the router if we are only expecting a single reply message
if !router.streaming {
delete(c.responseRouters, msgID)
}
}
}

func (c *channel) routeResponse(msgID uint64, resp response) {
c.responseMut.Lock()
defer c.responseMut.Unlock()
if router, ok := c.responseRouters[msgID]; ok {
router.c <- resp
// delete the router if we are only expecting a single message
// delete the router if we are only expecting a single reply message
if !router.streaming {
delete(c.responseRouters, msgID)
}
Expand All @@ -100,10 +142,19 @@ func (c *channel) enqueue(req request, responseChan chan<- response, streaming b
c.responseRouters[req.msg.Metadata.MessageID] = responseRouter{responseChan, streaming}
c.responseMut.Unlock()
}
c.sendQ <- req
// either enqueue the request on the sendQ or respond
// with error if the node is closed
select {
case <-c.parentCtx.Done():
c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.node.ID(), err: fmt.Errorf("channel closed")})
return
case c.sendQ <- req:
}
}

func (c *channel) deleteRouter(msgID uint64) {
c.responseMut.Lock()
defer c.responseMut.Unlock()
delete(c.responseRouters, msgID)
}

Expand Down Expand Up @@ -141,12 +192,12 @@ func (c *channel) sendMsg(req request) (err error) {
case <-done:
// all is good
case <-req.ctx.Done():
// Both channels could be ready at the same time, so we should check 'done' again.
// Both channels could be ready at the same time, so we must check 'done' again.
select {
case <-done:
// false alarm
default:
// cause reconnect
// trigger reconnect
c.cancelStream()
}
}
Expand All @@ -163,30 +214,35 @@ func (c *channel) sendMsg(req request) (err error) {
return err
}

func (c *channel) sendMsgs() {
func (c *channel) sender() {
var req request
for {
select {
case <-c.parentCtx.Done():
return
case req = <-c.sendQ:
}
// try to connect to the node if previous attempts
// have failed or if the node has disconnected
if !c.isConnected() {
// streamBroken will be set if the reconnection fails
c.connect()
}
// return error if stream is broken
if c.streamBroken.get() {
err := status.Errorf(codes.Unavailable, "stream is down")
c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.nodeID, msg: nil, err: err})
c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.node.ID(), err: streamDownErr})
continue
}
// else try to send message
err := c.sendMsg(req)
if err != nil {
// return the error
c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.nodeID, msg: nil, err: err})
c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.node.ID(), err: err})
}
}
}

func (c *channel) recvMsgs() {
func (c *channel) receiver() {
for {
resp := newMessage(responseType)
c.streamMut.RLock()
Expand All @@ -195,12 +251,17 @@ func (c *channel) recvMsgs() {
c.streamBroken.set()
c.streamMut.RUnlock()
c.setLastErr(err)
// attempt to reconnect
c.reconnect()
// we only reach this point when the stream failed AFTER a message
// was sent and we are waiting for a reply. We thus need to respond
// with a stream is down error on all pending messages.
c.cancelPendingMsgs()
// attempt to reconnect indefinitely until the node is closed.
// This is necessary when streaming is enabled.
c.reconnect(-1)
} else {
c.streamMut.RUnlock()
err := status.FromProto(resp.Metadata.GetStatus()).Err()
c.routeResponse(resp.Metadata.MessageID, response{nid: c.nodeID, msg: resp.Message, err: err})
c.routeResponse(resp.Metadata.MessageID, response{nid: c.node.ID(), msg: resp.Message, err: err})
}

select {
Expand All @@ -211,23 +272,60 @@ func (c *channel) recvMsgs() {
}
}

func (c *channel) reconnect() {
c.streamMut.Lock()
defer c.streamMut.Unlock()
func (c *channel) connect() error {
if !c.connEstablished.get() {
// a connection has not yet been established; i.e.,
// a previous dial attempt could have failed.
// try dialing again.
err := c.node.dial()
if err != nil {
c.streamBroken.set()
return err
}
err = c.newNodeStream(c.node.conn)
if err != nil {
c.streamBroken.set()
return err
}
}
// the node was previously connected but is now disconnected
if c.streamBroken.get() {
// try to reconnect only once.
// Maybe add this as a user option?
c.reconnect(1)
}
return nil
}

// reconnect tries to reconnect to the node using an exponential backoff strategy.
// maxRetries = -1 represents infinite retries.
func (c *channel) reconnect(maxRetries float64) {
backoffCfg := c.backoffCfg

var retries float64
for {
var err error

c.streamMut.Lock()
// check if stream is already up
if !c.streamBroken.get() {
// do nothing because stream is up
c.streamMut.Unlock()
return
}
c.streamCtx, c.cancelStream = context.WithCancel(c.parentCtx)
c.gorumsStream, err = c.gorumsClient.NodeStream(c.streamCtx)
if err == nil {
c.streamBroken.clear()
c.streamMut.Unlock()
return
}
c.cancelStream()
c.streamMut.Unlock()
c.setLastErr(err)
if retries >= maxRetries && maxRetries > 0 {
c.streamBroken.set()
return
}
delay := float64(backoffCfg.BaseDelay)
max := float64(backoffCfg.MaxDelay)
for r := retries; delay < max && r > 0; r-- {
Expand Down Expand Up @@ -264,6 +362,14 @@ func (c *channel) channelLatency() time.Duration {
return c.latency
}

// isConnected returns true if the channel has an active connection to the node.
func (c *channel) isConnected() bool {
// streamBroken.get() is initially false and NodeStream could be down
// even though node.conn is not nil. Hence, we need connEstablished
// to make sure a proper connection has been made.
return c.connEstablished.get() && !c.streamBroken.get()
}

type atomicFlag struct {
flag int32
}
Expand Down
Loading

0 comments on commit 8dd87ab

Please sign in to comment.