Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support shortcircuit in Request Modifiers #731

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ generate:
go build -buildmode=plugin -o ./transport/http/server/plugin/tests/lura-server-example.so ./transport/http/server/plugin/tests
go build -buildmode=plugin -o ./proxy/plugin/tests/lura-request-modifier-example.so ./proxy/plugin/tests/logger
go build -buildmode=plugin -o ./proxy/plugin/tests/lura-error-example.so ./proxy/plugin/tests/error
go build -buildmode=plugin -o ./proxy/plugin/tests/lura-shortcircuit-example.so ./proxy/plugin/tests/shortcircuit

test: generate
go test -cover -race ./...
Expand Down
62 changes: 39 additions & 23 deletions proxy/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,58 +107,74 @@ func newPluginMiddleware(logger logging.Logger, tag, pattern string, cfg map[str

if totRespModifiers == 0 {
return func(ctx context.Context, r *Request) (*Response, error) {
var err error
r, err = executeRequestModifiers(ctx, reqModifiers, r)
req, resp, err := executeRequestModifiers(ctx, reqModifiers, r)
if err != nil {
return nil, err
}

return next[0](ctx, r)
if resp != nil {
return resp, nil
}

return next[0](ctx, req)
}
}

return func(ctx context.Context, r *Request) (*Response, error) {
var err error
r, err = executeRequestModifiers(ctx, reqModifiers, r)
req, resp, err := executeRequestModifiers(ctx, reqModifiers, r)
if err != nil {
return nil, err
}

resp, err := next[0](ctx, r)
if err != nil {
return resp, err
if resp == nil {
var err error
resp, err = next[0](ctx, req)
if err != nil {
return resp, err
}
}

return executeResponseModifiers(ctx, respModifiers, resp, newRequestWrapper(ctx, r))
return executeResponseModifiers(ctx, respModifiers, resp, newRequestWrapper(ctx, req))
}
}
}

func executeRequestModifiers(ctx context.Context, reqModifiers []func(interface{}) (interface{}, error), r *Request) (*Request, error) {
func executeRequestModifiers(ctx context.Context, reqModifiers []func(interface{}) (interface{}, error), req *Request) (*Request, *Response, error) {
var tmp RequestWrapper
tmp = newRequestWrapper(ctx, r)
tmp = newRequestWrapper(ctx, req)
var resp *Response

for _, f := range reqModifiers {
res, err := f(tmp)
if err != nil {
return nil, err
return nil, nil, err
}
t, ok := res.(RequestWrapper)
if !ok {
switch t := res.(type) {
case RequestWrapper:
tmp = t
case ResponseWrapper:
resp = new(Response)
resp.Data = t.Data()
resp.IsComplete = t.IsComplete()
resp.Io = t.Io()
resp.Metadata = Metadata{}
resp.Metadata.Headers = t.Headers()
resp.Metadata.StatusCode = t.StatusCode()
break
default:
continue
}
tmp = t
}

r.Method = tmp.Method()
r.URL = tmp.URL()
r.Query = tmp.Query()
r.Path = tmp.Path()
r.Body = tmp.Body()
r.Params = tmp.Params()
r.Headers = tmp.Headers()
req.Method = tmp.Method()
req.URL = tmp.URL()
req.Query = tmp.Query()
req.Path = tmp.Path()
req.Body = tmp.Body()
req.Params = tmp.Params()
req.Headers = tmp.Headers()

return r, nil
return req, resp, nil
}

func executeResponseModifiers(ctx context.Context, respModifiers []func(interface{}) (interface{}, error), r *Response, req RequestWrapper) (*Response, error) {
Expand Down
4 changes: 2 additions & 2 deletions proxy/plugin/modifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func ExampleLoadWithLoggerAndContext() {
fmt.Println(err.Error())
return
}
if total != 2 {
if total != 3 {
fmt.Printf("unexpected number of loaded plugins!. have %d, want 2\n", total)
return
}
Expand Down Expand Up @@ -92,7 +92,7 @@ func TestLoad(t *testing.T) {
t.Error(err.Error())
t.Fail()
}
if total != 2 {
if total != 3 {
t.Errorf("unexpected number of loaded plugins!. have %d, want 2", total)
}

Expand Down
99 changes: 99 additions & 0 deletions proxy/plugin/tests/shortcircuit/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package main

import (
"context"
"errors"
"io"
"net/http"
"net/url"
"strings"
)

func main() {}

var ModifierRegisterer = registerer("lura-shortcircuit-example")

type registerer string

func (r registerer) RegisterModifiers(f func(
name string,
modifierFactory func(map[string]interface{}) func(interface{}) (interface{}, error),
appliesToRequest bool,
appliesToResponse bool,
),
) {
f(string(r)+"-request", r.requestModifierFactory, true, false)
f(string(r)+"-response", r.reqsponseModifierFactory, false, true)
}

func (r registerer) requestModifierFactory(_ map[string]interface{}) func(interface{}) (interface{}, error) {
return func(input interface{}) (interface{}, error) {
req, ok := input.(RequestWrapper)
if !ok {
return nil, unknownTypeErr
}

header := make(http.Header)
header.Add("X-Plugin-Request", "shortcircuit")
return responseWrapper{
request: req,
io: strings.NewReader("shortcircuit"),
headers: header,
statusCode: http.StatusTeapot,
}, nil
}
}

func (r registerer) reqsponseModifierFactory(_ map[string]interface{}) func(interface{}) (interface{}, error) {
return func(input interface{}) (interface{}, error) {
resp, ok := input.(ResponseWrapper)
if !ok {
return nil, unknownTypeErr
}

header := http.Header(resp.Headers())
header.Add("X-Plugin-Response", "shortcircuit")
return resp, nil
}
}

type responseWrapper struct {
ctx context.Context
request interface{}
data map[string]interface{}
isComplete bool
headers map[string][]string
statusCode int
io io.Reader
}

func (r responseWrapper) Context() context.Context { return r.ctx }
func (r responseWrapper) Request() interface{} { return r.request }
func (r responseWrapper) Data() map[string]interface{} { return r.data }
func (r responseWrapper) IsComplete() bool { return r.isComplete }
func (r responseWrapper) Io() io.Reader { return r.io }
func (r responseWrapper) Headers() map[string][]string { return r.headers }
func (r responseWrapper) StatusCode() int { return r.statusCode }

var unknownTypeErr = errors.New("unknown request type")

type RequestWrapper interface {
Context() context.Context
Params() map[string]string
Headers() map[string][]string
Body() io.ReadCloser
Method() string
URL() *url.URL
Query() url.Values
Path() string
}

type ResponseWrapper interface {
Context() context.Context
Request() interface{}
Data() map[string]interface{}
IsComplete() bool
Io() io.Reader
Headers() map[string][]string
StatusCode() int
}
97 changes: 97 additions & 0 deletions proxy/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,103 @@ func TestNewPluginMiddleware_error_response(t *testing.T) {
}
}

func TestNewPluginMiddleware_shortcircuit_request(t *testing.T) {
plugin.LoadWithLogger("./plugin/tests", ".so", plugin.RegisterModifier, logging.NoOp)

validator := func(ctx context.Context, r *Request) (*Response, error) {
t.Helper()
t.Error("the backend should not be called")
return nil, nil
}

bknd := NewBackendPluginMiddleware(
logging.NoOp,
&config.Backend{},
)(validator)

p := NewPluginMiddleware(
logging.NoOp,
&config.EndpointConfig{
ExtraConfig: map[string]interface{}{
plugin.Namespace: map[string]interface{}{
"name": []interface{}{
"lura-shortcircuit-example-request",
},
},
},
},
)(bknd)

resp, err := p(context.Background(), &Request{Path: "/bar"})
if err != nil {
t.Error(err.Error())
}

if resp == nil {
t.Errorf("unexpected response: %v", resp)
return
}

if sc := resp.Metadata.StatusCode; sc != http.StatusTeapot {
t.Errorf("unexpected status code: %d", sc)
}

header := http.Header(resp.Metadata.Headers)
if h := header.Get("X-Plugin-Request"); h != "shortcircuit" {
t.Errorf("unexpected header: %s", h)
}
}

func TestNewPluginMiddleware_shortcircuit_request_response(t *testing.T) {
plugin.LoadWithLogger("./plugin/tests", ".so", plugin.RegisterModifier, logging.NoOp)

validator := func(ctx context.Context, r *Request) (*Response, error) {
t.Error("the backend should not be called")
return nil, nil
}

bknd := NewBackendPluginMiddleware(
logging.NoOp,
&config.Backend{},
)(validator)

p := NewPluginMiddleware(
logging.NoOp,
&config.EndpointConfig{
ExtraConfig: map[string]interface{}{
plugin.Namespace: map[string]interface{}{
"name": []interface{}{
"lura-shortcircuit-example-request",
"lura-shortcircuit-example-response",
},
},
},
},
)(bknd)

resp, err := p(context.Background(), &Request{Path: "/bar"})
if err != nil {
t.Error(err.Error())
}

if resp == nil {
t.Errorf("unexpected response: %v", resp)
return
}

if sc := resp.Metadata.StatusCode; sc != http.StatusTeapot {
t.Errorf("unexpected status code: %d", sc)
}

header := http.Header(resp.Metadata.Headers)
if h := header.Get("X-Plugin-Request"); h != "shortcircuit" {
t.Errorf("unexpected header: %s", h)
}
if h := header.Get("X-Plugin-Response"); h != "shortcircuit" {
t.Errorf("unexpected header: %s", h)
}
}

func TestNewPluginMiddleware_PoisonedPlugin(t *testing.T) {
plugin.RegisterModifier("poisoned", func(map[string]interface{}) func(interface{}) (interface{}, error) {
return nil
Expand Down