diff --git a/cmd/run.go b/cmd/run.go index aeb3504c..286ec85f 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -89,6 +89,7 @@ func StopGracefully( pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), conf.Plugin.Timeout) defer cancel() + //nolint:contextcheck _, err := pluginRegistry.Run( pluginTimeoutCtx, map[string]interface{}{"signal": signal}, diff --git a/network/proxy.go b/network/proxy.go index 2cf608a8..56e9d799 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -23,8 +23,8 @@ import ( type IProxy interface { Connect(conn net.Conn) *gerr.GatewayDError Disconnect(conn net.Conn) *gerr.GatewayDError - PassThroughToServer(conn net.Conn) *gerr.GatewayDError - PassThroughToClient(conn net.Conn) *gerr.GatewayDError + PassThroughToServer(conn net.Conn, stack *Stack) *gerr.GatewayDError + PassThroughToClient(conn net.Conn, stack *Stack) *gerr.GatewayDError IsHealthy(cl *Client) (*Client, *gerr.GatewayDError) IsExhausted() bool Shutdown() @@ -260,7 +260,7 @@ func (pr *Proxy) Disconnect(conn net.Conn) *gerr.GatewayDError { } // PassThroughToServer sends the data from the client to the server. -func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { +func (pr *Proxy) PassThroughToServer(conn net.Conn, stack *Stack) *gerr.GatewayDError { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough") defer span.End() @@ -317,6 +317,9 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { return gerr.ErrClientNotConnected.Wrap(origErr) } + // Push the client's request to the stack. + stack.Push(&Request{Data: request}) + // If the hook wants to terminate the connection, do it. if pr.shouldTerminate(result) { if modResponse, modReceived := pr.getPluginModifiedResponse(result); modResponse != nil { @@ -326,6 +329,10 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { metrics.TotalTrafficBytes.Observe(float64(modReceived)) span.AddEvent("Terminating connection") + + // Remove the request from the stack if the response is modified. + stack.PopLastRequest() + return pr.sendTrafficToClient(conn, modResponse, modReceived) } span.RecordError(gerr.ErrHookTerminatedConnection) @@ -337,6 +344,8 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { span.AddEvent("Plugin(s) modified the request") } + stack.UpdateLastRequest(&Request{Data: request}) + // Send the request to the server. _, err = pr.sendTrafficToServer(client, request) span.AddEvent("Sent traffic to server") @@ -370,7 +379,7 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { } // PassThroughToClient sends the data from the server to the client. -func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError { +func (pr *Proxy) PassThroughToClient(conn net.Conn, stack *Stack) *gerr.GatewayDError { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough") defer span.End() @@ -410,12 +419,22 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError { pr.logger.Debug().Fields(fields).Msg("No data to send to client") span.AddEvent("No data to send to client") span.RecordError(err) + + stack.PopLastRequest() + return err } pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.pluginTimeout) defer cancel() + // Get the last request from the stack. + lastRequest := stack.PopLastRequest() + request := make([]byte, 0) + if lastRequest != nil { + request = lastRequest.Data + } + // Run the OnTrafficFromServer hooks. result, err := pr.pluginRegistry.Run( pluginTimeoutCtx, @@ -423,6 +442,10 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError { conn, client, []Field{ + { + Name: "request", + Value: request, + }, { Name: "response", Value: response[:received], @@ -457,6 +480,10 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError { conn, client, []Field{ + { + Name: "request", + Value: request, + }, { Name: "response", Value: response[:received], diff --git a/network/proxy_test.go b/network/proxy_test.go index 8bd82452..379b6df0 100644 --- a/network/proxy_test.go +++ b/network/proxy_test.go @@ -310,10 +310,12 @@ func BenchmarkProxyPassThrough(b *testing.B) { proxy.Connect(conn.Conn) //nolint:errcheck defer proxy.Disconnect(&conn) //nolint:errcheck + stack := NewStack() + // Connect to the proxy for i := 0; i < b.N; i++ { - proxy.PassThroughToClient(&conn) //nolint:errcheck - proxy.PassThroughToServer(&conn) //nolint:errcheck + proxy.PassThroughToClient(&conn, stack) //nolint:errcheck + proxy.PassThroughToServer(&conn, stack) //nolint:errcheck } } diff --git a/network/server.go b/network/server.go index 3c9270be..71738b08 100644 --- a/network/server.go +++ b/network/server.go @@ -270,35 +270,39 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action { } span.AddEvent("Ran the OnTraffic hooks") + stack := NewStack() + // Pass the traffic from the client to server. // If there is an error, log it and close the connection. - go func(server *Server, conn net.Conn, stopConnection chan struct{}) { + go func(server *Server, conn net.Conn, stopConnection chan struct{}, stack *Stack) { for { server.logger.Trace().Msg("Passing through traffic from client to server") - if err := server.proxy.PassThroughToServer(conn); err != nil { + if err := server.proxy.PassThroughToServer(conn, stack); err != nil { server.logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) stopConnection <- struct{}{} break } } - }(s, conn, stopConnection) + }(s, conn, stopConnection, stack) // Pass the traffic from the server to client. // If there is an error, log it and close the connection. - go func(server *Server, conn net.Conn, stopConnection chan struct{}) { + go func(server *Server, conn net.Conn, stopConnection chan struct{}, stack *Stack) { for { server.logger.Debug().Msg("Passing through traffic from server to client") - if err := server.proxy.PassThroughToClient(conn); err != nil { + if err := server.proxy.PassThroughToClient(conn, stack); err != nil { server.logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) stopConnection <- struct{}{} break } } - }(s, conn, stopConnection) + }(s, conn, stopConnection, stack) <-stopConnection + stack.Clear() + return Close } diff --git a/network/stack.go b/network/stack.go new file mode 100644 index 00000000..a3dd6cfb --- /dev/null +++ b/network/stack.go @@ -0,0 +1,82 @@ +package network + +import "sync" + +type Request struct { + Data []byte +} + +type Stack struct { + items []*Request + mu sync.RWMutex +} + +func (s *Stack) Push(req *Request) { + s.mu.Lock() + defer s.mu.Unlock() + + s.items = append(s.items, req) +} + +func (s *Stack) GetLastRequest() *Request { + s.mu.RLock() + defer s.mu.RUnlock() + + if len(s.items) == 0 { + return nil + } + + //nolint:staticcheck + for i := len(s.items) - 1; i >= 0; i-- { + return s.items[i] + } + + return nil +} + +func (s *Stack) PopLastRequest() *Request { + s.mu.Lock() + defer s.mu.Unlock() + + if len(s.items) == 0 { + return nil + } + + //nolint:staticcheck + for i := len(s.items) - 1; i >= 0; i-- { + req := s.items[i] + s.items = append(s.items[:i], s.items[i+1:]...) + return req + } + + return nil +} + +func (s *Stack) UpdateLastRequest(req *Request) { + s.mu.Lock() + defer s.mu.Unlock() + + if len(s.items) == 0 { + return + } + + //nolint:staticcheck + for i := len(s.items) - 1; i >= 0; i-- { + s.items[i] = req + return + } +} + +func (s *Stack) Clear() { + s.mu.Lock() + defer s.mu.Unlock() + + s.items = make([]*Request, 0) +} + +func NewStack() *Stack { + return &Stack{ + items: make([]*Request, 0), + mu: sync.RWMutex{}, + } +}