diff --git a/Makefile b/Makefile index 5e6c6dd1e..60567ea6a 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ generate: go build -buildmode=plugin -o ./transport/http/server/plugin/tests/lura-server-example.so ./transport/http/server/plugin/tests go build -buildmode=plugin -o ./proxy/plugin/tests/lura-request-modifier-example.so ./proxy/plugin/tests/logger go build -buildmode=plugin -o ./proxy/plugin/tests/lura-error-example.so ./proxy/plugin/tests/error + go build -buildmode=plugin -o ./proxy/plugin/tests/lura-shortcircuit-example.so ./proxy/plugin/tests/shortcircuit test: generate go test -cover -race ./... diff --git a/proxy/plugin.go b/proxy/plugin.go index 0767b2251..60fdece63 100644 --- a/proxy/plugin.go +++ b/proxy/plugin.go @@ -107,58 +107,74 @@ func newPluginMiddleware(logger logging.Logger, tag, pattern string, cfg map[str if totRespModifiers == 0 { return func(ctx context.Context, r *Request) (*Response, error) { - var err error - r, err = executeRequestModifiers(ctx, reqModifiers, r) + req, resp, err := executeRequestModifiers(ctx, reqModifiers, r) if err != nil { return nil, err } - return next[0](ctx, r) + if resp != nil { + return resp, nil + } + + return next[0](ctx, req) } } return func(ctx context.Context, r *Request) (*Response, error) { - var err error - r, err = executeRequestModifiers(ctx, reqModifiers, r) + req, resp, err := executeRequestModifiers(ctx, reqModifiers, r) if err != nil { return nil, err } - resp, err := next[0](ctx, r) - if err != nil { - return resp, err + if resp == nil { + var err error + resp, err = next[0](ctx, req) + if err != nil { + return resp, err + } } - return executeResponseModifiers(ctx, respModifiers, resp, newRequestWrapper(ctx, r)) + return executeResponseModifiers(ctx, respModifiers, resp, newRequestWrapper(ctx, req)) } } } -func executeRequestModifiers(ctx context.Context, reqModifiers []func(interface{}) (interface{}, error), r *Request) (*Request, error) { +func executeRequestModifiers(ctx context.Context, reqModifiers []func(interface{}) (interface{}, error), req *Request) (*Request, *Response, error) { var tmp RequestWrapper - tmp = newRequestWrapper(ctx, r) + tmp = newRequestWrapper(ctx, req) + var resp *Response for _, f := range reqModifiers { res, err := f(tmp) if err != nil { - return nil, err + return nil, nil, err } - t, ok := res.(RequestWrapper) - if !ok { + switch t := res.(type) { + case RequestWrapper: + tmp = t + case ResponseWrapper: + resp = new(Response) + resp.Data = t.Data() + resp.IsComplete = t.IsComplete() + resp.Io = t.Io() + resp.Metadata = Metadata{} + resp.Metadata.Headers = t.Headers() + resp.Metadata.StatusCode = t.StatusCode() + break + default: continue } - tmp = t } - r.Method = tmp.Method() - r.URL = tmp.URL() - r.Query = tmp.Query() - r.Path = tmp.Path() - r.Body = tmp.Body() - r.Params = tmp.Params() - r.Headers = tmp.Headers() + req.Method = tmp.Method() + req.URL = tmp.URL() + req.Query = tmp.Query() + req.Path = tmp.Path() + req.Body = tmp.Body() + req.Params = tmp.Params() + req.Headers = tmp.Headers() - return r, nil + return req, resp, nil } func executeResponseModifiers(ctx context.Context, respModifiers []func(interface{}) (interface{}, error), r *Response, req RequestWrapper) (*Response, error) { diff --git a/proxy/plugin/modifier_test.go b/proxy/plugin/modifier_test.go index 60bf6b598..663187b92 100644 --- a/proxy/plugin/modifier_test.go +++ b/proxy/plugin/modifier_test.go @@ -31,7 +31,7 @@ func ExampleLoadWithLoggerAndContext() { fmt.Println(err.Error()) return } - if total != 2 { + if total != 3 { fmt.Printf("unexpected number of loaded plugins!. have %d, want 2\n", total) return } @@ -92,7 +92,7 @@ func TestLoad(t *testing.T) { t.Error(err.Error()) t.Fail() } - if total != 2 { + if total != 3 { t.Errorf("unexpected number of loaded plugins!. have %d, want 2", total) } diff --git a/proxy/plugin/tests/shortcircuit/main.go b/proxy/plugin/tests/shortcircuit/main.go new file mode 100644 index 000000000..5a133dabf --- /dev/null +++ b/proxy/plugin/tests/shortcircuit/main.go @@ -0,0 +1,99 @@ +package main + +import ( + "context" + "errors" + "io" + "net/http" + "net/url" + "strings" +) + +func main() {} + +var ModifierRegisterer = registerer("lura-shortcircuit-example") + +type registerer string + +func (r registerer) RegisterModifiers(f func( + name string, + modifierFactory func(map[string]interface{}) func(interface{}) (interface{}, error), + appliesToRequest bool, + appliesToResponse bool, +), +) { + f(string(r)+"-request", r.requestModifierFactory, true, false) + f(string(r)+"-response", r.reqsponseModifierFactory, false, true) +} + +func (r registerer) requestModifierFactory(_ map[string]interface{}) func(interface{}) (interface{}, error) { + return func(input interface{}) (interface{}, error) { + req, ok := input.(RequestWrapper) + if !ok { + return nil, unknownTypeErr + } + + header := make(http.Header) + header.Add("X-Plugin-Request", "shortcircuit") + return responseWrapper{ + request: req, + io: strings.NewReader("shortcircuit"), + headers: header, + statusCode: http.StatusTeapot, + }, nil + } +} + +func (r registerer) reqsponseModifierFactory(_ map[string]interface{}) func(interface{}) (interface{}, error) { + return func(input interface{}) (interface{}, error) { + resp, ok := input.(ResponseWrapper) + if !ok { + return nil, unknownTypeErr + } + + header := http.Header(resp.Headers()) + header.Add("X-Plugin-Response", "shortcircuit") + return resp, nil + } +} + +type responseWrapper struct { + ctx context.Context + request interface{} + data map[string]interface{} + isComplete bool + headers map[string][]string + statusCode int + io io.Reader +} + +func (r responseWrapper) Context() context.Context { return r.ctx } +func (r responseWrapper) Request() interface{} { return r.request } +func (r responseWrapper) Data() map[string]interface{} { return r.data } +func (r responseWrapper) IsComplete() bool { return r.isComplete } +func (r responseWrapper) Io() io.Reader { return r.io } +func (r responseWrapper) Headers() map[string][]string { return r.headers } +func (r responseWrapper) StatusCode() int { return r.statusCode } + +var unknownTypeErr = errors.New("unknown request type") + +type RequestWrapper interface { + Context() context.Context + Params() map[string]string + Headers() map[string][]string + Body() io.ReadCloser + Method() string + URL() *url.URL + Query() url.Values + Path() string +} + +type ResponseWrapper interface { + Context() context.Context + Request() interface{} + Data() map[string]interface{} + IsComplete() bool + Io() io.Reader + Headers() map[string][]string + StatusCode() int +} diff --git a/proxy/plugin_test.go b/proxy/plugin_test.go index 892c03ac9..daa9b8587 100644 --- a/proxy/plugin_test.go +++ b/proxy/plugin_test.go @@ -193,6 +193,103 @@ func TestNewPluginMiddleware_error_response(t *testing.T) { } } +func TestNewPluginMiddleware_shortcircuit_request(t *testing.T) { + plugin.LoadWithLogger("./plugin/tests", ".so", plugin.RegisterModifier, logging.NoOp) + + validator := func(ctx context.Context, r *Request) (*Response, error) { + t.Helper() + t.Error("the backend should not be called") + return nil, nil + } + + bknd := NewBackendPluginMiddleware( + logging.NoOp, + &config.Backend{}, + )(validator) + + p := NewPluginMiddleware( + logging.NoOp, + &config.EndpointConfig{ + ExtraConfig: map[string]interface{}{ + plugin.Namespace: map[string]interface{}{ + "name": []interface{}{ + "lura-shortcircuit-example-request", + }, + }, + }, + }, + )(bknd) + + resp, err := p(context.Background(), &Request{Path: "/bar"}) + if err != nil { + t.Error(err.Error()) + } + + if resp == nil { + t.Errorf("unexpected response: %v", resp) + return + } + + if sc := resp.Metadata.StatusCode; sc != http.StatusTeapot { + t.Errorf("unexpected status code: %d", sc) + } + + header := http.Header(resp.Metadata.Headers) + if h := header.Get("X-Plugin-Request"); h != "shortcircuit" { + t.Errorf("unexpected header: %s", h) + } +} + +func TestNewPluginMiddleware_shortcircuit_request_response(t *testing.T) { + plugin.LoadWithLogger("./plugin/tests", ".so", plugin.RegisterModifier, logging.NoOp) + + validator := func(ctx context.Context, r *Request) (*Response, error) { + t.Error("the backend should not be called") + return nil, nil + } + + bknd := NewBackendPluginMiddleware( + logging.NoOp, + &config.Backend{}, + )(validator) + + p := NewPluginMiddleware( + logging.NoOp, + &config.EndpointConfig{ + ExtraConfig: map[string]interface{}{ + plugin.Namespace: map[string]interface{}{ + "name": []interface{}{ + "lura-shortcircuit-example-request", + "lura-shortcircuit-example-response", + }, + }, + }, + }, + )(bknd) + + resp, err := p(context.Background(), &Request{Path: "/bar"}) + if err != nil { + t.Error(err.Error()) + } + + if resp == nil { + t.Errorf("unexpected response: %v", resp) + return + } + + if sc := resp.Metadata.StatusCode; sc != http.StatusTeapot { + t.Errorf("unexpected status code: %d", sc) + } + + header := http.Header(resp.Metadata.Headers) + if h := header.Get("X-Plugin-Request"); h != "shortcircuit" { + t.Errorf("unexpected header: %s", h) + } + if h := header.Get("X-Plugin-Response"); h != "shortcircuit" { + t.Errorf("unexpected header: %s", h) + } +} + func TestNewPluginMiddleware_PoisonedPlugin(t *testing.T) { plugin.RegisterModifier("poisoned", func(map[string]interface{}) func(interface{}) (interface{}, error) { return nil