Skip to content

Commit

Permalink
TUN-8415: Refactor capnp rpc into a single module
Browse files Browse the repository at this point in the history
Combines the tunnelrpc and quic/schema capnp files into the same module.

To help reduce future issues with capnp id generation, capnpids are
provided in the capnp files from the existing capnp struct ids generated
in the go files.

Reduces the overall interface of the Capnp methods to the rest of
the code by providing an interface that will handle the quic protocol
selection.

Introduces a new `rpc-timeout` config that will allow all of the
SessionManager and ConfigurationManager RPC requests to have a timeout.
The timeout for these values is set to 5 seconds as non of these operations
for the managers should take a long time to complete.

Removed the RPC-specific logger as it never provided good debugging value
as the RPC method names were not visible in the logs.
  • Loading branch information
DevinCarr committed May 17, 2024
1 parent 7d76ce2 commit eb2e434
Show file tree
Hide file tree
Showing 39 changed files with 1,121 additions and 1,028 deletions.
14 changes: 4 additions & 10 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -251,22 +251,16 @@ github-windows-upload:
python3 github_release.py --path built_artifacts/cloudflared-windows-386.exe --release-version $(VERSION) --name cloudflared-windows-386.exe
python3 github_release.py --path built_artifacts/cloudflared-windows-386.msi --release-version $(VERSION) --name cloudflared-windows-386.msi

.PHONY: tunnelrpc-deps
tunnelrpc-deps:
.PHONY: capnp
capnp:
which capnp # https://capnproto.org/install.html
which capnpc-go # go install zombiezen.com/go/capnproto2/capnpc-go@latest
capnp compile -ogo tunnelrpc/tunnelrpc.capnp

.PHONY: quic-deps
quic-deps:
which capnp
which capnpc-go
capnp compile -ogo quic/schema/quic_metadata_protocol.capnp
capnp compile -ogo tunnelrpc/proto/tunnelrpc.capnp tunnelrpc/proto/quic_metadata_protocol.capnp

.PHONY: vet
vet:
go vet -mod=vendor github.com/cloudflare/cloudflared/...

.PHONY: fmt
fmt:
goimports -l -w -local github.com/cloudflare/cloudflared $$(go list -mod=vendor -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc)
goimports -l -w -local github.com/cloudflare/cloudflared $$(go list -mod=vendor -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc/proto)
6 changes: 3 additions & 3 deletions cmd/cloudflared/tunnel/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ const (
// hostKeyPath is the path of the dir to save SSH host keys too
hostKeyPath = "host-key-path"

// udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge
udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout"
// rpcTimeout is how long to wait for a Capnp RPC request to the edge
rpcTimeout = "rpc-timeout"

// writeStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
writeStreamTimeout = "write-stream-timeout"
Expand Down Expand Up @@ -695,7 +695,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: udpUnregisterSessionTimeoutFlag,
Name: rpcTimeout,
Value: 5 * time.Second,
Hidden: true,
}),
Expand Down
2 changes: 1 addition & 1 deletion cmd/cloudflared/tunnel/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func prepareTunnelConfig(
EdgeTLSConfigs: edgeTLSConfigs,
FeatureSelector: featureSelector,
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
RPCTimeout: c.Duration(rpcTimeout),
WriteStreamTimeout: c.Duration(writeStreamTimeout),
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
}
Expand Down
96 changes: 34 additions & 62 deletions connection/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ import (
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management"
"github.com/cloudflare/cloudflared/packet"
quicpogs "github.com/cloudflare/cloudflared/quic"
cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic"
)

const (
Expand Down Expand Up @@ -59,14 +61,14 @@ type QUICConnection struct {
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
sessionManager datagramsession.Manager
// datagramMuxer mux/demux datagrams from quic connection
datagramMuxer *quicpogs.DatagramMuxerV2
datagramMuxer *cfdquic.DatagramMuxerV2
packetRouter *ingress.PacketRouter
controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions
connIndex uint8

udpUnregisterTimeout time.Duration
streamWriteTimeout time.Duration
rpcTimeout time.Duration
streamWriteTimeout time.Duration
}

// NewQUICConnection returns a new instance of QUICConnection.
Expand All @@ -82,7 +84,7 @@ func NewQUICConnection(
controlStreamHandler ControlStreamHandler,
logger *zerolog.Logger,
packetRouterConfig *ingress.GlobalRouterConfig,
udpUnregisterTimeout time.Duration,
rpcTimeout time.Duration,
streamWriteTimeout time.Duration,
) (*QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
Expand All @@ -104,7 +106,7 @@ func NewQUICConnection(
}

sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
datagramMuxer := quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan)
datagramMuxer := cfdquic.NewDatagramMuxerV2(session, logger, sessionDemuxChan)
sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan)
packetRouter := ingress.NewPacketRouter(packetRouterConfig, datagramMuxer, logger)

Expand All @@ -118,7 +120,7 @@ func NewQUICConnection(
controlStreamHandler: controlStreamHandler,
connOptions: connOptions,
connIndex: connIndex,
udpUnregisterTimeout: udpUnregisterTimeout,
rpcTimeout: rpcTimeout,
streamWriteTimeout: streamWriteTimeout,
}, nil
}
Expand Down Expand Up @@ -198,15 +200,16 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {

func (q *QUICConnection) runStream(quicStream quic.Stream) {
ctx := quicStream.Context()
stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close()

// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
// code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
// So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream.
// A call to close will simulate a close to the read-side, which will fail subsequent reads.
noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream}
if err := q.handleStream(ctx, noCloseStream); err != nil {
ss := rpcquic.NewCloudflaredServer(q.handleDataStream, q, q, q.rpcTimeout)
if err := ss.Serve(ctx, noCloseStream); err != nil {
q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream")

// if we received an error at this level, then close write side of stream with an error, which will result in
Expand All @@ -215,30 +218,7 @@ func (q *QUICConnection) runStream(quicStream quic.Stream) {
}
}

func (q *QUICConnection) handleStream(ctx context.Context, stream io.ReadWriteCloser) error {
signature, err := quicpogs.DetermineProtocol(stream)
if err != nil {
return err
}
switch signature {
case quicpogs.DataStreamProtocolSignature:
reqServerStream, err := quicpogs.NewRequestServerStream(stream, signature)
if err != nil {
return err
}
return q.handleDataStream(ctx, reqServerStream)
case quicpogs.RPCStreamProtocolSignature:
rpcStream, err := quicpogs.NewRPCServerStream(stream, signature)
if err != nil {
return err
}
return q.handleRPCStream(rpcStream)
default:
return fmt.Errorf("unknown protocol %v", signature)
}
}

func (q *QUICConnection) handleDataStream(ctx context.Context, stream *quicpogs.RequestServerStream) error {
func (q *QUICConnection) handleDataStream(ctx context.Context, stream *rpcquic.RequestServerStream) error {
request, err := stream.ReadConnectRequestData()
if err != nil {
return err
Expand All @@ -264,22 +244,22 @@ func (q *QUICConnection) handleDataStream(ctx context.Context, stream *quicpogs.
// dispatchRequest will dispatch the request depending on the type and returns an error if it occurs.
// More importantly, it also tells if the during processing of the request the ConnectResponse metadata was sent downstream.
// This is important since it informs
func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) (error, bool) {
func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *rpcquic.RequestServerStream, err error, request *pogs.ConnectRequest) (error, bool) {
originProxy, err := q.orchestrator.GetOriginProxy()
if err != nil {
return err, false
}

switch request.Type {
case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
case pogs.ConnectionTypeHTTP, pogs.ConnectionTypeWebsocket:
tracedReq, err := buildHTTPRequest(ctx, request, stream, q.connIndex, q.logger)
if err != nil {
return err, false
}
w := newHTTPResponseAdapter(stream)
return originProxy.ProxyHTTP(&w, tracedReq, request.Type == quicpogs.ConnectionTypeWebsocket), w.connectResponseSent
return originProxy.ProxyHTTP(&w, tracedReq, request.Type == pogs.ConnectionTypeWebsocket), w.connectResponseSent

case quicpogs.ConnectionTypeTCP:
case pogs.ConnectionTypeTCP:
rwa := &streamReadWriteAcker{RequestServerStream: stream}
metadata := request.MetadataMap()
return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
Expand All @@ -293,14 +273,6 @@ func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.R
}
}

func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error {
if err := rpcStream.Serve(q, q, q.logger); err != nil {
q.logger.Err(err).Msg("failed handling RPC stream")
}

return nil
}

// RegisterUdpSession is the RPC method invoked by edge to register and run a session
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*tunnelpogs.RegisterUdpSessionResponse, error) {
traceCtx := tracing.NewTracedContext(ctx, traceContext, q.logger)
Expand Down Expand Up @@ -377,9 +349,9 @@ func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUI
return
}

stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close()
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger)
rpcClientStream, err := rpcquic.NewSessionClient(ctx, stream, q.rpcTimeout)
if err != nil {
// Log this at debug because this is not an error if session was closed due to lost connection
// with edge
Expand Down Expand Up @@ -408,16 +380,16 @@ func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32,
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
// the client.
type streamReadWriteAcker struct {
*quicpogs.RequestServerStream
*rpcquic.RequestServerStream
connectResponseSent bool
}

// AckConnection acks response back to the proxy.
func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error {
metadata := []quicpogs.Metadata{}
metadata := []pogs.Metadata{}
// Only add tracing if provided by origintunneld
if tracePropagation != "" {
metadata = append(metadata, quicpogs.Metadata{
metadata = append(metadata, pogs.Metadata{
Key: tracing.CanonicalCloudflaredTracingHeader,
Val: tracePropagation,
})
Expand All @@ -428,12 +400,12 @@ func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error {

// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
type httpResponseAdapter struct {
*quicpogs.RequestServerStream
*rpcquic.RequestServerStream
headers http.Header
connectResponseSent bool
}

func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter {
func newHTTPResponseAdapter(s *rpcquic.RequestServerStream) httpResponseAdapter {
return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header)}
}

Expand All @@ -442,12 +414,12 @@ func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
}

func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
metadata := make([]quicpogs.Metadata, 0)
metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
metadata := make([]pogs.Metadata, 0)
metadata = append(metadata, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
for k, vv := range header {
for _, v := range vv {
httpHeaderKey := fmt.Sprintf("%s:%s", HTTPHeaderKey, k)
metadata = append(metadata, quicpogs.Metadata{Key: httpHeaderKey, Val: v})
metadata = append(metadata, pogs.Metadata{Key: httpHeaderKey, Val: v})
}
}

Expand Down Expand Up @@ -483,17 +455,17 @@ func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
}

func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
}

func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...quicpogs.Metadata) error {
func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {
hrw.connectResponseSent = true
return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...)
}

func buildHTTPRequest(
ctx context.Context,
connectRequest *quicpogs.ConnectRequest,
connectRequest *pogs.ConnectRequest,
body io.ReadCloser,
connIndex uint8,
log *zerolog.Logger,
Expand All @@ -502,7 +474,7 @@ func buildHTTPRequest(
dest := connectRequest.Dest
method := metadata[HTTPMethodKey]
host := metadata[HTTPHostKey]
isWebsocket := connectRequest.Type == quicpogs.ConnectionTypeWebsocket
isWebsocket := connectRequest.Type == pogs.ConnectionTypeWebsocket

req, err := http.NewRequestWithContext(ctx, method, dest, body)
if err != nil {
Expand Down Expand Up @@ -597,19 +569,19 @@ func (np *nopCloserReadWriter) Close() error {

// muxerWrapper wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface
type muxerWrapper struct {
muxer *quicpogs.DatagramMuxerV2
muxer *cfdquic.DatagramMuxerV2
}

func (rp *muxerWrapper) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
return rp.muxer.SendPacket(quicpogs.RawPacket(pk))
return rp.muxer.SendPacket(cfdquic.RawPacket(pk))
}

func (rp *muxerWrapper) ReceivePacket(ctx context.Context) (packet.RawPacket, error) {
pk, err := rp.muxer.ReceivePacket(ctx)
if err != nil {
return packet.RawPacket{}, err
}
rawPacket, ok := pk.(quicpogs.RawPacket)
rawPacket, ok := pk.(cfdquic.RawPacket)
if ok {
return packet.RawPacket(rawPacket), nil
}
Expand Down
Loading

0 comments on commit eb2e434

Please sign in to comment.