From 89a14eb3598eb23ba79147e029aaec08e4e27f99 Mon Sep 17 00:00:00 2001 From: crissi98 Date: Wed, 1 Sep 2021 16:13:22 +0200 Subject: [PATCH] fix routing error when two routes use same path and different methods --- router.go | 8 ++++---- router_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/router.go b/router.go index 8932499..ab615ce 100644 --- a/router.go +++ b/router.go @@ -16,7 +16,7 @@ import ( type Router struct { baseRouter routers.Router errMapper *errorMapper - implementations map[*routers.Route]requestHandler + implementations map[routers.Route]requestHandler } // NewRouter creates a new Router with the path of a OpenAPI specification file in YAML or JSON format. @@ -32,7 +32,7 @@ func NewRouter(swaggerPath string) (*Router, error) { return &Router{ baseRouter: router, errMapper: &errorMapper{errorMapping: make(map[reflect.Type]*HTTPError)}, - implementations: make(map[*routers.Route]requestHandler), + implementations: make(map[routers.Route]requestHandler), }, nil } @@ -50,7 +50,7 @@ func (router *Router) ServeHTTP(writer http.ResponseWriter, request *http.Reques response.write(writer) return } - handler, ok := router.implementations[route] + handler, ok := router.implementations[*route] if ok { validationInput := &openapi3filter.RequestValidationInput{ Request: request, @@ -113,7 +113,7 @@ func (router *Router) AddRequestHandlerWithAuthFunc(method string, path string, options.AuthenticationFunc = authFunc } - router.implementations[route] = requestHandler{ + router.implementations[*route] = requestHandler{ errMapper: router.errMapper, handlerFunction: handleFunc, options: options, diff --git a/router_test.go b/router_test.go index 700f050..babb330 100644 --- a/router_test.go +++ b/router_test.go @@ -239,6 +239,50 @@ func TestRouter_POSTInvalidData(t *testing.T) { } } +func TestRouter_POSTandGETSamePath(t *testing.T) { + //given + router, server := getRouterAndServer() + defer server.Close() + postCalled := false + getCalled := false + sendData := TestData{ + Data: "test", + } + dataBytes, _ := json.Marshal(&sendData) + router.AddRequestHandler(http.MethodPost, "/test", func(_ *http.Request, _ map[string]string) (*Response, error) { + postCalled = true + return &Response{ + StatusCode: http.StatusNoContent, + }, nil + }) + router.AddRequestHandler(http.MethodGet, "/test", func(_ *http.Request, _ map[string]string) (*Response, error) { + getCalled = true + return &Response{ + StatusCode: http.StatusOK, + Body: &TestData{ + Data: "test", + }, + }, nil + }) + + // when + postResponse, postErr := server.Client().Post(server.URL+"/test", "application/json", bytes.NewReader(dataBytes)) + getResponse, getErr := server.Client().Get(server.URL + "/test") + + // then + assert.Nil(t, postErr) + if assert.NotNil(t, postResponse) { + assert.True(t, postCalled) + assert.Equal(t, http.StatusNoContent, postResponse.StatusCode) + } + + assert.Nil(t, getErr) + if assert.NotNil(t, getResponse) { + assert.True(t, getCalled) + assert.Equal(t, http.StatusOK, getResponse.StatusCode) + } +} + func TestRouter_InvalidPath(t *testing.T) { // given _, server := getRouterAndServer()