Skip to content

Commit

Permalink
Add VerifyRequestHandler config option (stripe#180)
Browse files Browse the repository at this point in the history
* Add VerifyRequestHandler config option

* GetProxyConnectHeader -> ProxyConnectHeader

* Remove some extraneous changes

* rename to proxyConnectHeaders

* call it CustomRequestHandler

* pctx.Req -> req

* remove the header in the handler and add assertion

* remove unnecessary comment

* fix test name

* improve test names

* create proxyClientWithConnectHeaders method

* add comment explaining smokescreen behaviour when error is returned
  • Loading branch information
param-stripe authored Nov 29, 2022
1 parent 5b7c3b7 commit 90e9012
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pkg/smokescreen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ type Config struct {
// ranges by default (exempting loopback and unicast ranges)
// This setting can be used to configure Smokescreen with a blocklist, rather than an allowlist
UnsafeAllowPrivateRanges bool

// Custom handler for users to allow running code per requests, users can pass in custom methods to verify requests based
// on headers, code for metrics etc.
// If the handler returns an error, smokescreen will deny the request.
CustomRequestHandler func(*http.Request) error
}

type missingRoleError struct {
Expand Down
19 changes: 19 additions & 0 deletions pkg/smokescreen/smokescreen.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,15 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {

sctx.logger.WithField("url", req.RequestURI).Debug("received HTTP proxy request")

// Call the custom request handler if it exists
if config.CustomRequestHandler != nil {
err = config.CustomRequestHandler(req)
if err != nil {
pctx.Error = denyError{err}
return req, rejectResponse(pctx, pctx.Error)
}
}

sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, req, destination)

// Returning any kind of response in this handler is goproxy's way of short circuiting
Expand Down Expand Up @@ -609,6 +618,16 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) {
pctx.Error = denyError{err}
return "", pctx.Error
}

// Call the custom request handler if it exists
if config.CustomRequestHandler != nil {
err = config.CustomRequestHandler(pctx.Req)
if err != nil {
pctx.Error = denyError{err}
return "", pctx.Error
}
}

sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, pctx.Req, destination)
if pctx.Error != nil {
return "", denyError{pctx.Error}
Expand Down
134 changes: 134 additions & 0 deletions pkg/smokescreen/smokescreen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -1019,6 +1020,134 @@ func TestRejectResponseHandler(t *testing.T) {
})
}

func TestCustomRequestHandler(t *testing.T) {
r := require.New(t)
testHeader := "X-Verify-Request-Header"
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get(testHeader) != "" {
w.Write([]byte("header not removed!"))
return
}
w.Write([]byte("OK"))
})
customRequestHandler := func(r *http.Request) error {
header := r.Header.Get(testHeader)
r.Header.Del(testHeader)
if header == "" {
return errors.New("header doesn't exist")
}
if header != "valid" {
return errors.New("invalid header")
}
return nil
}

t.Run("CustomRequestHandler works for HTTPS", func(t *testing.T) {
testCases := []struct {
header http.Header
expectedError bool
}{
{
header: http.Header{testHeader: []string{"valid"}},
expectedError: false,
},
{
header: http.Header{testHeader: []string{"invalid"}},
expectedError: true,
},
}
cfg, err := testConfig("test-local-srv")
r.NoError(err)
err = cfg.SetAllowAddresses([]string{"127.0.0.1"})
r.NoError(err)
cfg.CustomRequestHandler = customRequestHandler

l, err := net.Listen("tcp", "localhost:0")
r.NoError(err)
cfg.Listener = l

proxy := proxyServer(cfg)
remote := httptest.NewTLSServer(h)
defer proxy.Close()
for _, testCase := range testCases {

client, err := proxyClientWithConnectHeaders(proxy.URL, testCase.header)
r.NoError(err)

req, err := http.NewRequest("GET", remote.URL, nil)
r.NoError(err)
resp, err := client.Do(req)
if testCase.expectedError {
r.Nil(resp)
r.Contains(err.Error(), "Request rejected by proxy")
} else {
r.NoError(err)
r.Equal(200, resp.StatusCode)
body, err := ioutil.ReadAll(resp.Body)
r.NoError(err)
resp.Body.Close()
r.Equal([]byte("OK"), body)
}
}
})

t.Run("CustomRequestHandler works for HTTP", func(t *testing.T) {
testCases := []struct {
header string
expectedError bool
}{
{
header: "valid",
expectedError: false,
},
{
header: "invalid",
expectedError: true,
},
}
cfg, err := testConfig("test-local-srv")
r.NoError(err)
err = cfg.SetAllowAddresses([]string{"127.0.0.1"})
r.NoError(err)
cfg.CustomRequestHandler = customRequestHandler

l, err := net.Listen("tcp", "localhost:0")
r.NoError(err)
cfg.Listener = l

remote := httptest.NewServer(h)

proxySrv := proxyServer(cfg)
r.NoError(err)
defer proxySrv.Close()

// Create a http.Client that uses our proxy
client, err := proxyClient(proxySrv.URL)
r.NoError(err)

for _, testCase := range testCases {
req, err := http.NewRequest("GET", remote.URL, nil)
r.NoError(err)
req.Header.Set(testHeader, testCase.header)
resp, err := client.Do(req)
if testCase.expectedError {
r.NoError(err)
errorMessage := resp.Header.Get("X-Smokescreen-Error")
r.Contains(errorMessage, "invalid header")

} else {
r.NoError(err)
r.Equal(200, resp.StatusCode)
body, err := ioutil.ReadAll(resp.Body)
r.NoError(err)
resp.Body.Close()
r.Equal([]byte("OK"), body)

}
}
})
}

func findCanonicalProxyDecision(logs []*logrus.Entry) *logrus.Entry {
for _, entry := range logs {
if entry.Message == CanonicalProxyDecision {
Expand Down Expand Up @@ -1070,6 +1199,10 @@ func proxyServer(conf *Config) *httptest.Server {
}

func proxyClient(proxy string) (*http.Client, error) {
return proxyClientWithConnectHeaders(proxy, nil)
}

func proxyClientWithConnectHeaders(proxy string, proxyConnectHeaders http.Header) (*http.Client, error) {
proxyUrl, err := url.Parse(proxy)
if err != nil {
return nil, err
Expand All @@ -1081,6 +1214,7 @@ func proxyClient(proxy string) (*http.Client, error) {
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
ProxyConnectHeader: proxyConnectHeaders,
},
}, nil
}

0 comments on commit 90e9012

Please sign in to comment.