diff --git a/parameterizable.go b/parameterizable.go index bdf0720a..c6d9ecb5 100644 --- a/parameterizable.go +++ b/parameterizable.go @@ -62,6 +62,8 @@ func (p *parameterizable) compileParameters(uri string, ends bool, regexCache ma if ends { builder.WriteString("$") + } else { + builder.WriteString(`/?$`) } pattern := builder.String() diff --git a/router.go b/router.go index 4feb5129..0d8fdce5 100644 --- a/router.go +++ b/router.go @@ -5,6 +5,7 @@ import ( "io/fs" "net/http" "regexp" + "strings" "maps" "slices" @@ -29,7 +30,7 @@ var ( errMatchMethodNotAllowed = errors.New("Method not allowed for this route") errMatchNotFound = errors.New("No match for this URI") - methodNotAllowedRoute = newRoute(func(response *Response, _ *Request) { // TODO document special route names + methodNotAllowedRoute = newRoute(func(response *Response, _ *Request) { response.Status(http.StatusMethodNotAllowed) }, RouteMethodNotAllowed) notFoundRoute = newRoute(func(response *Response, _ *Request) { @@ -66,7 +67,8 @@ func (rm *routeMatch) mergeParams(params map[string]string) { } func (rm *routeMatch) trimCurrentPath(fullMatch string) { - rm.currentPath = rm.currentPath[len(fullMatch):] + length := len(fullMatch) + rm.currentPath = rm.currentPath[length:] } // Router registers routes to be matched and executes a handler. @@ -85,6 +87,8 @@ type Router struct { prefix string routes []*Route subrouters []*Router + + slashCount int } var _ http.Handler = (*Router)(nil) // implements http.Handler @@ -282,7 +286,16 @@ func (r *Router) match(method string, match *routeMatch) bool { // Check if router itself matches var params []string if r.parameterizable.regex != nil { - params = r.parameterizable.regex.FindStringSubmatch(match.currentPath) + i := -1 + if len(match.currentPath) > 0 { + // Ignore slashes in router prefix + i = nthIndex(match.currentPath[1:], "/", r.slashCount) + 1 + } + if i <= 0 { + i = len(match.currentPath) + } + currentPath := match.currentPath[:i] + params = r.parameterizable.regex.FindStringSubmatch(currentPath) } else { params = []string{""} } @@ -318,7 +331,21 @@ func (r *Router) match(method string, match *routeMatch) bool { } match.route = notFoundRoute - return false + // Return true if the subrouter matched so we don't turn back and check other subrouters + return params != nil && len(params[0]) > 0 +} + +func nthIndex(str, substr string, n int) int { + index := -1 + for nth := 0; nth < n; nth++ { + i := strings.Index(str, substr) + if i == -1 || i == len(str) { + return -1 + } + index += i + 1 + str = str[i+1:] + } + return index } func (r *Router) makeParameters(match []string) map[string]string { @@ -350,7 +377,10 @@ func (r *Router) Subrouter(prefix string) *Router { globalMiddleware: r.globalMiddleware, regexCache: r.regexCache, } - router.compileParameters(router.prefix, false, r.regexCache) + if prefix != "" { + router.compileParameters(router.prefix, false, r.regexCache) + router.slashCount = strings.Count(prefix, "/") + } r.subrouters = append(r.subrouters, router) return router } diff --git a/router_test.go b/router_test.go index c4c98753..bdd1e5ee 100644 --- a/router_test.go +++ b/router_test.go @@ -6,6 +6,7 @@ import ( "io/fs" "net/http" "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -480,48 +481,63 @@ func TestRouter(t *testing.T) { viewers.Get("/", nil).Name("users.viewers.show") users.Put("/", nil).Name("users.update") + // Conflicting subrouters + conflict := router.Subrouter("/conflict") + conflict.Get("/", nil).Name("conflict.root") + conflict.Get("/child", nil).Name("conflict.child") + conflict2 := router.Subrouter("/conflict-2") + conflict2.Get("/", nil).Name("conflict-2.root") + conflict2.Get("/child", nil).Name("conflict-2.child") + + // Multiple segments in subrouter path + subrouter := router.Subrouter("/subrouter/{param}") + subrouter.Get("/", nil).Name("multiple-segments.subroute.index") + subrouter.Get("/subroute", nil).Name("multiple-segments.subroute.show") + subrouter.Get("/subroute/{name}", nil).Name("multiple-segments.subroute.name") + cases := []struct { path string method string expectedRoute string }{ {path: "/", method: http.MethodGet, expectedRoute: "root"}, - {path: "/", method: http.MethodPost, expectedRoute: "method-not-allowed"}, + {path: "/", method: http.MethodPost, expectedRoute: RouteMethodNotAllowed}, {path: "/first-level", method: http.MethodGet, expectedRoute: "first-level"}, - {path: "/first-level/", method: http.MethodGet, expectedRoute: "not-found"}, // Trailing slash - {path: "/first-level", method: http.MethodPost, expectedRoute: "method-not-allowed"}, + {path: "/first-level/", method: http.MethodGet, expectedRoute: RouteNotFound}, // Trailing slash + {path: "/first-level", method: http.MethodPost, expectedRoute: RouteMethodNotAllowed}, {path: "/categories", method: http.MethodGet, expectedRoute: "categories.index"}, + {path: "/categories/", method: http.MethodGet, expectedRoute: RouteNotFound}, // Trailing slash {path: "/categories/123", method: http.MethodGet, expectedRoute: "categories.show"}, {path: "/categories/123/inventory", method: http.MethodGet, expectedRoute: "categories.inventory"}, - {path: "/categories/test", method: http.MethodGet, expectedRoute: "not-found"}, + {path: "/categories/test", method: http.MethodGet, expectedRoute: RouteNotFound}, {path: "/categories/123/products", method: http.MethodGet, expectedRoute: "products.index"}, {path: "/categories/123/products", method: http.MethodPost, expectedRoute: "products.create"}, {path: "/categories/123/products/1234567890", method: http.MethodGet, expectedRoute: "products.show"}, {path: "/users/manage", method: http.MethodGet, expectedRoute: "users.admins.manage"}, - {path: "/users/manage", method: http.MethodGet, expectedRoute: "users.admins.manage"}, {path: "/users/profile", method: http.MethodGet, expectedRoute: "users.viewers.profile"}, {path: "/users", method: http.MethodGet, expectedRoute: "users.viewers.show"}, // Method not allowed on users.admins.create {path: "/users", method: http.MethodPut, expectedRoute: "users.update"}, + {path: "/conflict", method: http.MethodGet, expectedRoute: "conflict.root"}, + {path: "/conflict/", method: http.MethodGet, expectedRoute: RouteNotFound}, + {path: "/conflict/child", method: http.MethodGet, expectedRoute: "conflict.child"}, + {path: "/conflict-2", method: http.MethodGet, expectedRoute: "conflict-2.root"}, + {path: "/conflict-2/", method: http.MethodGet, expectedRoute: RouteNotFound}, + {path: "/conflict-2/child", method: http.MethodGet, expectedRoute: "conflict-2.child"}, + {path: "/categories/123/not-a-route", method: http.MethodGet, expectedRoute: RouteNotFound}, + {path: "/categories/123/not-a-route/", method: http.MethodGet, expectedRoute: RouteNotFound}, + {path: "/subrouter/value", method: http.MethodGet, expectedRoute: "multiple-segments.subroute.index"}, + {path: "/subrouter/value/", method: http.MethodGet, expectedRoute: RouteNotFound}, + {path: "/subrouter/value/subroute", method: http.MethodGet, expectedRoute: "multiple-segments.subroute.show"}, + {path: "/subrouter/value/subroute/", method: http.MethodGet, expectedRoute: RouteNotFound}, + {path: "/subrouter/value/subroute/johndoe", method: http.MethodGet, expectedRoute: "multiple-segments.subroute.name"}, } for _, c := range cases { c := c - t.Run(fmt.Sprintf("%s_%s", c.method, c.path), func(t *testing.T) { + t.Run(fmt.Sprintf("%s_%s", c.method, strings.ReplaceAll(c.path, "/", "_")), func(t *testing.T) { match := routeMatch{currentPath: c.path} - ok := router.match(c.method, &match) - switch c.expectedRoute { - case "": - assert.False(t, ok) - case "not-found": - assert.False(t, ok) - assert.Equal(t, notFoundRoute, match.route) - case "method-not-allowed": - assert.True(t, ok) - assert.Equal(t, methodNotAllowedRoute, match.route) - default: - assert.True(t, ok) - assert.Equal(t, c.expectedRoute, match.route.name) - } + router.match(c.method, &match) + assert.Equal(t, c.expectedRoute, match.route.name) }) } diff --git a/static_test.go b/static_test.go index b8ce6a64..c6ba1dbd 100644 --- a/static_test.go +++ b/static_test.go @@ -84,6 +84,17 @@ func TestStaticHandler(t *testing.T) { assert.Equal(t, "{\n \"custom-entry\": \"value\"\n}", string(body)) }, }, + { + uri: "/lang/en-US/fields.json", + directory: "resources", + download: true, + expected: func(t *testing.T, response *Response, result *http.Response, body []byte) { + assert.Equal(t, http.StatusOK, response.GetStatus()) + assert.Equal(t, "application/json", result.Header.Get("Content-Type")) + assert.Equal(t, "attachment; filename=\"fields.json\"", result.Header.Get("Content-Disposition")) + assert.Equal(t, "{\n \"email\": \"email address\"\n}", string(body)) + }, + }, } for _, c := range cases {