Skip to content

Commit

Permalink
Use a stack data structure to push and restore the last request
Browse files Browse the repository at this point in the history
  • Loading branch information
mostafa committed Oct 15, 2023
1 parent b009dc1 commit 4399644
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 12 deletions.
1 change: 1 addition & 0 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func StopGracefully(
pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), conf.Plugin.Timeout)
defer cancel()

//nolint:contextcheck
_, err := pluginRegistry.Run(
pluginTimeoutCtx,
map[string]interface{}{"signal": signal},
Expand Down
35 changes: 31 additions & 4 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (
type IProxy interface {
Connect(conn net.Conn) *gerr.GatewayDError
Disconnect(conn net.Conn) *gerr.GatewayDError
PassThroughToServer(conn net.Conn) *gerr.GatewayDError
PassThroughToClient(conn net.Conn) *gerr.GatewayDError
PassThroughToServer(conn net.Conn, stack *Stack) *gerr.GatewayDError
PassThroughToClient(conn net.Conn, stack *Stack) *gerr.GatewayDError
IsHealthy(cl *Client) (*Client, *gerr.GatewayDError)
IsExhausted() bool
Shutdown()
Expand Down Expand Up @@ -260,7 +260,7 @@ func (pr *Proxy) Disconnect(conn net.Conn) *gerr.GatewayDError {
}

// PassThroughToServer sends the data from the client to the server.
func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError {
func (pr *Proxy) PassThroughToServer(conn net.Conn, stack *Stack) *gerr.GatewayDError {
_, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough")
defer span.End()

Expand Down Expand Up @@ -317,6 +317,9 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError {
return gerr.ErrClientNotConnected.Wrap(origErr)
}

// Push the client's request to the stack.
stack.Push(&Request{Data: request})

// If the hook wants to terminate the connection, do it.
if pr.shouldTerminate(result) {
if modResponse, modReceived := pr.getPluginModifiedResponse(result); modResponse != nil {
Expand All @@ -326,6 +329,10 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError {
metrics.TotalTrafficBytes.Observe(float64(modReceived))

span.AddEvent("Terminating connection")

// Remove the request from the stack if the response is modified.
stack.PopLastRequest()

return pr.sendTrafficToClient(conn, modResponse, modReceived)
}
span.RecordError(gerr.ErrHookTerminatedConnection)
Expand All @@ -337,6 +344,8 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError {
span.AddEvent("Plugin(s) modified the request")
}

stack.UpdateLastRequest(&Request{Data: request})

// Send the request to the server.
_, err = pr.sendTrafficToServer(client, request)
span.AddEvent("Sent traffic to server")
Expand Down Expand Up @@ -370,7 +379,7 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError {
}

// PassThroughToClient sends the data from the server to the client.
func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError {
func (pr *Proxy) PassThroughToClient(conn net.Conn, stack *Stack) *gerr.GatewayDError {
_, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough")
defer span.End()

Expand Down Expand Up @@ -410,19 +419,33 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError {
pr.logger.Debug().Fields(fields).Msg("No data to send to client")
span.AddEvent("No data to send to client")
span.RecordError(err)

stack.PopLastRequest()

return err
}

pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.pluginTimeout)
defer cancel()

// Get the last request from the stack.
lastRequest := stack.PopLastRequest()
request := make([]byte, 0)
if lastRequest != nil {
request = lastRequest.Data
}

// Run the OnTrafficFromServer hooks.
result, err := pr.pluginRegistry.Run(
pluginTimeoutCtx,
trafficData(
conn,
client,
[]Field{
{
Name: "request",
Value: request,
},
{
Name: "response",
Value: response[:received],
Expand Down Expand Up @@ -457,6 +480,10 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError {
conn,
client,
[]Field{
{
Name: "request",
Value: request,
},
{
Name: "response",
Value: response[:received],
Expand Down
6 changes: 4 additions & 2 deletions network/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,12 @@ func BenchmarkProxyPassThrough(b *testing.B) {
proxy.Connect(conn.Conn) //nolint:errcheck
defer proxy.Disconnect(&conn) //nolint:errcheck

stack := NewStack()

// Connect to the proxy
for i := 0; i < b.N; i++ {
proxy.PassThroughToClient(&conn) //nolint:errcheck
proxy.PassThroughToServer(&conn) //nolint:errcheck
proxy.PassThroughToClient(&conn, stack) //nolint:errcheck
proxy.PassThroughToServer(&conn, stack) //nolint:errcheck
}
}

Expand Down
16 changes: 10 additions & 6 deletions network/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,35 +270,39 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action {
}
span.AddEvent("Ran the OnTraffic hooks")

stack := NewStack()

// Pass the traffic from the client to server.
// If there is an error, log it and close the connection.
go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
go func(server *Server, conn net.Conn, stopConnection chan struct{}, stack *Stack) {
for {
server.logger.Trace().Msg("Passing through traffic from client to server")
if err := server.proxy.PassThroughToServer(conn); err != nil {
if err := server.proxy.PassThroughToServer(conn, stack); err != nil {
server.logger.Trace().Err(err).Msg("Failed to pass through traffic")
span.RecordError(err)
stopConnection <- struct{}{}
break
}
}
}(s, conn, stopConnection)
}(s, conn, stopConnection, stack)

// Pass the traffic from the server to client.
// If there is an error, log it and close the connection.
go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
go func(server *Server, conn net.Conn, stopConnection chan struct{}, stack *Stack) {
for {
server.logger.Debug().Msg("Passing through traffic from server to client")
if err := server.proxy.PassThroughToClient(conn); err != nil {
if err := server.proxy.PassThroughToClient(conn, stack); err != nil {
server.logger.Trace().Err(err).Msg("Failed to pass through traffic")
span.RecordError(err)
stopConnection <- struct{}{}
break
}
}
}(s, conn, stopConnection)
}(s, conn, stopConnection, stack)

<-stopConnection
stack.Clear()

return Close
}

Expand Down
82 changes: 82 additions & 0 deletions network/stack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package network

import "sync"

type Request struct {
Data []byte
}

type Stack struct {
items []*Request
mu sync.RWMutex
}

func (s *Stack) Push(req *Request) {
s.mu.Lock()
defer s.mu.Unlock()

s.items = append(s.items, req)
}

func (s *Stack) GetLastRequest() *Request {
s.mu.RLock()
defer s.mu.RUnlock()

if len(s.items) == 0 {
return nil
}

//nolint:staticcheck
for i := len(s.items) - 1; i >= 0; i-- {
return s.items[i]
}

return nil
}

func (s *Stack) PopLastRequest() *Request {
s.mu.Lock()
defer s.mu.Unlock()

if len(s.items) == 0 {
return nil
}

//nolint:staticcheck
for i := len(s.items) - 1; i >= 0; i-- {
req := s.items[i]
s.items = append(s.items[:i], s.items[i+1:]...)
return req
}

return nil
}

func (s *Stack) UpdateLastRequest(req *Request) {
s.mu.Lock()
defer s.mu.Unlock()

if len(s.items) == 0 {
return
}

//nolint:staticcheck
for i := len(s.items) - 1; i >= 0; i-- {
s.items[i] = req
return
}
}

func (s *Stack) Clear() {
s.mu.Lock()
defer s.mu.Unlock()

s.items = make([]*Request, 0)
}

func NewStack() *Stack {
return &Stack{
items: make([]*Request, 0),
mu: sync.RWMutex{},
}
}

0 comments on commit 4399644

Please sign in to comment.