From 90e901206f1f216bee49cc969d760aea39116b9f Mon Sep 17 00:00:00 2001 From: param-stripe <54037899+param-stripe@users.noreply.github.com> Date: Tue, 29 Nov 2022 14:42:57 +0000 Subject: [PATCH] Add VerifyRequestHandler config option (#180) * Add VerifyRequestHandler config option * GetProxyConnectHeader -> ProxyConnectHeader * Remove some extraneous changes * rename to proxyConnectHeaders * call it CustomRequestHandler * pctx.Req -> req * remove the header in the handler and add assertion * remove unnecessary comment * fix test name * improve test names * create proxyClientWithConnectHeaders method * add comment explaining smokescreen behaviour when error is returned --- pkg/smokescreen/config.go | 5 ++ pkg/smokescreen/smokescreen.go | 19 ++++ pkg/smokescreen/smokescreen_test.go | 134 ++++++++++++++++++++++++++++ 3 files changed, 158 insertions(+) diff --git a/pkg/smokescreen/config.go b/pkg/smokescreen/config.go index cbbca07a..ec3f86a8 100644 --- a/pkg/smokescreen/config.go +++ b/pkg/smokescreen/config.go @@ -79,6 +79,11 @@ type Config struct { // ranges by default (exempting loopback and unicast ranges) // This setting can be used to configure Smokescreen with a blocklist, rather than an allowlist UnsafeAllowPrivateRanges bool + + // Custom handler for users to allow running code per requests, users can pass in custom methods to verify requests based + // on headers, code for metrics etc. + // If the handler returns an error, smokescreen will deny the request. + CustomRequestHandler func(*http.Request) error } type missingRoleError struct { diff --git a/pkg/smokescreen/smokescreen.go b/pkg/smokescreen/smokescreen.go index 1e72add8..8da0903a 100644 --- a/pkg/smokescreen/smokescreen.go +++ b/pkg/smokescreen/smokescreen.go @@ -475,6 +475,15 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { sctx.logger.WithField("url", req.RequestURI).Debug("received HTTP proxy request") + // Call the custom request handler if it exists + if config.CustomRequestHandler != nil { + err = config.CustomRequestHandler(req) + if err != nil { + pctx.Error = denyError{err} + return req, rejectResponse(pctx, pctx.Error) + } + } + sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, req, destination) // Returning any kind of response in this handler is goproxy's way of short circuiting @@ -609,6 +618,16 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) { pctx.Error = denyError{err} return "", pctx.Error } + + // Call the custom request handler if it exists + if config.CustomRequestHandler != nil { + err = config.CustomRequestHandler(pctx.Req) + if err != nil { + pctx.Error = denyError{err} + return "", pctx.Error + } + } + sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, pctx.Req, destination) if pctx.Error != nil { return "", denyError{pctx.Error} diff --git a/pkg/smokescreen/smokescreen_test.go b/pkg/smokescreen/smokescreen_test.go index ce74eff4..67ccd43f 100644 --- a/pkg/smokescreen/smokescreen_test.go +++ b/pkg/smokescreen/smokescreen_test.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" "log" "net" "net/http" @@ -1019,6 +1020,134 @@ func TestRejectResponseHandler(t *testing.T) { }) } +func TestCustomRequestHandler(t *testing.T) { + r := require.New(t) + testHeader := "X-Verify-Request-Header" + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(testHeader) != "" { + w.Write([]byte("header not removed!")) + return + } + w.Write([]byte("OK")) + }) + customRequestHandler := func(r *http.Request) error { + header := r.Header.Get(testHeader) + r.Header.Del(testHeader) + if header == "" { + return errors.New("header doesn't exist") + } + if header != "valid" { + return errors.New("invalid header") + } + return nil + } + + t.Run("CustomRequestHandler works for HTTPS", func(t *testing.T) { + testCases := []struct { + header http.Header + expectedError bool + }{ + { + header: http.Header{testHeader: []string{"valid"}}, + expectedError: false, + }, + { + header: http.Header{testHeader: []string{"invalid"}}, + expectedError: true, + }, + } + cfg, err := testConfig("test-local-srv") + r.NoError(err) + err = cfg.SetAllowAddresses([]string{"127.0.0.1"}) + r.NoError(err) + cfg.CustomRequestHandler = customRequestHandler + + l, err := net.Listen("tcp", "localhost:0") + r.NoError(err) + cfg.Listener = l + + proxy := proxyServer(cfg) + remote := httptest.NewTLSServer(h) + defer proxy.Close() + for _, testCase := range testCases { + + client, err := proxyClientWithConnectHeaders(proxy.URL, testCase.header) + r.NoError(err) + + req, err := http.NewRequest("GET", remote.URL, nil) + r.NoError(err) + resp, err := client.Do(req) + if testCase.expectedError { + r.Nil(resp) + r.Contains(err.Error(), "Request rejected by proxy") + } else { + r.NoError(err) + r.Equal(200, resp.StatusCode) + body, err := ioutil.ReadAll(resp.Body) + r.NoError(err) + resp.Body.Close() + r.Equal([]byte("OK"), body) + } + } + }) + + t.Run("CustomRequestHandler works for HTTP", func(t *testing.T) { + testCases := []struct { + header string + expectedError bool + }{ + { + header: "valid", + expectedError: false, + }, + { + header: "invalid", + expectedError: true, + }, + } + cfg, err := testConfig("test-local-srv") + r.NoError(err) + err = cfg.SetAllowAddresses([]string{"127.0.0.1"}) + r.NoError(err) + cfg.CustomRequestHandler = customRequestHandler + + l, err := net.Listen("tcp", "localhost:0") + r.NoError(err) + cfg.Listener = l + + remote := httptest.NewServer(h) + + proxySrv := proxyServer(cfg) + r.NoError(err) + defer proxySrv.Close() + + // Create a http.Client that uses our proxy + client, err := proxyClient(proxySrv.URL) + r.NoError(err) + + for _, testCase := range testCases { + req, err := http.NewRequest("GET", remote.URL, nil) + r.NoError(err) + req.Header.Set(testHeader, testCase.header) + resp, err := client.Do(req) + if testCase.expectedError { + r.NoError(err) + errorMessage := resp.Header.Get("X-Smokescreen-Error") + r.Contains(errorMessage, "invalid header") + + } else { + r.NoError(err) + r.Equal(200, resp.StatusCode) + body, err := ioutil.ReadAll(resp.Body) + r.NoError(err) + resp.Body.Close() + r.Equal([]byte("OK"), body) + + } + } + }) +} + func findCanonicalProxyDecision(logs []*logrus.Entry) *logrus.Entry { for _, entry := range logs { if entry.Message == CanonicalProxyDecision { @@ -1070,6 +1199,10 @@ func proxyServer(conf *Config) *httptest.Server { } func proxyClient(proxy string) (*http.Client, error) { + return proxyClientWithConnectHeaders(proxy, nil) +} + +func proxyClientWithConnectHeaders(proxy string, proxyConnectHeaders http.Header) (*http.Client, error) { proxyUrl, err := url.Parse(proxy) if err != nil { return nil, err @@ -1081,6 +1214,7 @@ func proxyClient(proxy string) (*http.Client, error) { TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + ProxyConnectHeader: proxyConnectHeaders, }, }, nil }