Skip to content

Commit

Permalink
Router: fixes for conflicting subrouters and trailing slash
Browse files Browse the repository at this point in the history
- fix conflicting subrouters with a prefix starting with the same characters (e.g.: /test and /test-2)
- prevent turning back and exploring other branches if subrouter matches but none of its routes do
- fix "/" route at the root router being matched if a subrouter matches but none of its routes do and a trailing slash is present
  • Loading branch information
System-Glitch committed Apr 4, 2024
1 parent 6a47a8f commit ae89345
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 25 deletions.
2 changes: 2 additions & 0 deletions parameterizable.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ func (p *parameterizable) compileParameters(uri string, ends bool, regexCache ma

if ends {
builder.WriteString("$")
} else {
builder.WriteString(`/?$`)
}

pattern := builder.String()
Expand Down
40 changes: 35 additions & 5 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io/fs"
"net/http"
"regexp"
"strings"

"maps"
"slices"
Expand All @@ -29,7 +30,7 @@ var (
errMatchMethodNotAllowed = errors.New("Method not allowed for this route")
errMatchNotFound = errors.New("No match for this URI")

methodNotAllowedRoute = newRoute(func(response *Response, _ *Request) { // TODO document special route names
methodNotAllowedRoute = newRoute(func(response *Response, _ *Request) {
response.Status(http.StatusMethodNotAllowed)
}, RouteMethodNotAllowed)
notFoundRoute = newRoute(func(response *Response, _ *Request) {
Expand Down Expand Up @@ -66,7 +67,8 @@ func (rm *routeMatch) mergeParams(params map[string]string) {
}

func (rm *routeMatch) trimCurrentPath(fullMatch string) {
rm.currentPath = rm.currentPath[len(fullMatch):]
length := len(fullMatch)
rm.currentPath = rm.currentPath[length:]
}

// Router registers routes to be matched and executes a handler.
Expand All @@ -85,6 +87,8 @@ type Router struct {
prefix string
routes []*Route
subrouters []*Router

slashCount int
}

var _ http.Handler = (*Router)(nil) // implements http.Handler
Expand Down Expand Up @@ -282,7 +286,16 @@ func (r *Router) match(method string, match *routeMatch) bool {
// Check if router itself matches
var params []string
if r.parameterizable.regex != nil {
params = r.parameterizable.regex.FindStringSubmatch(match.currentPath)
i := -1
if len(match.currentPath) > 0 {
// Ignore slashes in router prefix
i = nthIndex(match.currentPath[1:], "/", r.slashCount) + 1
}
if i <= 0 {
i = len(match.currentPath)
}
currentPath := match.currentPath[:i]
params = r.parameterizable.regex.FindStringSubmatch(currentPath)
} else {
params = []string{""}
}
Expand Down Expand Up @@ -318,7 +331,21 @@ func (r *Router) match(method string, match *routeMatch) bool {
}

match.route = notFoundRoute
return false
// Return true if the subrouter matched so we don't turn back and check other subrouters
return params != nil && len(params[0]) > 0
}

func nthIndex(str, substr string, n int) int {
index := -1
for nth := 0; nth < n; nth++ {
i := strings.Index(str, substr)
if i == -1 || i == len(str) {
return -1
}
index += i + 1
str = str[i+1:]
}
return index
}

func (r *Router) makeParameters(match []string) map[string]string {
Expand Down Expand Up @@ -350,7 +377,10 @@ func (r *Router) Subrouter(prefix string) *Router {
globalMiddleware: r.globalMiddleware,
regexCache: r.regexCache,
}
router.compileParameters(router.prefix, false, r.regexCache)
if prefix != "" {
router.compileParameters(router.prefix, false, r.regexCache)
router.slashCount = strings.Count(prefix, "/")
}
r.subrouters = append(r.subrouters, router)
return router
}
Expand Down
56 changes: 36 additions & 20 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io/fs"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -480,48 +481,63 @@ func TestRouter(t *testing.T) {
viewers.Get("/", nil).Name("users.viewers.show")
users.Put("/", nil).Name("users.update")

// Conflicting subrouters
conflict := router.Subrouter("/conflict")
conflict.Get("/", nil).Name("conflict.root")
conflict.Get("/child", nil).Name("conflict.child")
conflict2 := router.Subrouter("/conflict-2")
conflict2.Get("/", nil).Name("conflict-2.root")
conflict2.Get("/child", nil).Name("conflict-2.child")

// Multiple segments in subrouter path
subrouter := router.Subrouter("/subrouter/{param}")
subrouter.Get("/", nil).Name("multiple-segments.subroute.index")
subrouter.Get("/subroute", nil).Name("multiple-segments.subroute.show")
subrouter.Get("/subroute/{name}", nil).Name("multiple-segments.subroute.name")

cases := []struct {
path string
method string
expectedRoute string
}{
{path: "/", method: http.MethodGet, expectedRoute: "root"},
{path: "/", method: http.MethodPost, expectedRoute: "method-not-allowed"},
{path: "/", method: http.MethodPost, expectedRoute: RouteMethodNotAllowed},
{path: "/first-level", method: http.MethodGet, expectedRoute: "first-level"},
{path: "/first-level/", method: http.MethodGet, expectedRoute: "not-found"}, // Trailing slash
{path: "/first-level", method: http.MethodPost, expectedRoute: "method-not-allowed"},
{path: "/first-level/", method: http.MethodGet, expectedRoute: RouteNotFound}, // Trailing slash
{path: "/first-level", method: http.MethodPost, expectedRoute: RouteMethodNotAllowed},
{path: "/categories", method: http.MethodGet, expectedRoute: "categories.index"},
{path: "/categories/", method: http.MethodGet, expectedRoute: RouteNotFound}, // Trailing slash
{path: "/categories/123", method: http.MethodGet, expectedRoute: "categories.show"},
{path: "/categories/123/inventory", method: http.MethodGet, expectedRoute: "categories.inventory"},
{path: "/categories/test", method: http.MethodGet, expectedRoute: "not-found"},
{path: "/categories/test", method: http.MethodGet, expectedRoute: RouteNotFound},
{path: "/categories/123/products", method: http.MethodGet, expectedRoute: "products.index"},
{path: "/categories/123/products", method: http.MethodPost, expectedRoute: "products.create"},
{path: "/categories/123/products/1234567890", method: http.MethodGet, expectedRoute: "products.show"},
{path: "/users/manage", method: http.MethodGet, expectedRoute: "users.admins.manage"},
{path: "/users/manage", method: http.MethodGet, expectedRoute: "users.admins.manage"},
{path: "/users/profile", method: http.MethodGet, expectedRoute: "users.viewers.profile"},
{path: "/users", method: http.MethodGet, expectedRoute: "users.viewers.show"}, // Method not allowed on users.admins.create
{path: "/users", method: http.MethodPut, expectedRoute: "users.update"},
{path: "/conflict", method: http.MethodGet, expectedRoute: "conflict.root"},
{path: "/conflict/", method: http.MethodGet, expectedRoute: RouteNotFound},
{path: "/conflict/child", method: http.MethodGet, expectedRoute: "conflict.child"},
{path: "/conflict-2", method: http.MethodGet, expectedRoute: "conflict-2.root"},
{path: "/conflict-2/", method: http.MethodGet, expectedRoute: RouteNotFound},
{path: "/conflict-2/child", method: http.MethodGet, expectedRoute: "conflict-2.child"},
{path: "/categories/123/not-a-route", method: http.MethodGet, expectedRoute: RouteNotFound},
{path: "/categories/123/not-a-route/", method: http.MethodGet, expectedRoute: RouteNotFound},
{path: "/subrouter/value", method: http.MethodGet, expectedRoute: "multiple-segments.subroute.index"},
{path: "/subrouter/value/", method: http.MethodGet, expectedRoute: RouteNotFound},
{path: "/subrouter/value/subroute", method: http.MethodGet, expectedRoute: "multiple-segments.subroute.show"},
{path: "/subrouter/value/subroute/", method: http.MethodGet, expectedRoute: RouteNotFound},
{path: "/subrouter/value/subroute/johndoe", method: http.MethodGet, expectedRoute: "multiple-segments.subroute.name"},
}

for _, c := range cases {
c := c
t.Run(fmt.Sprintf("%s_%s", c.method, c.path), func(t *testing.T) {
t.Run(fmt.Sprintf("%s_%s", c.method, strings.ReplaceAll(c.path, "/", "_")), func(t *testing.T) {
match := routeMatch{currentPath: c.path}
ok := router.match(c.method, &match)
switch c.expectedRoute {
case "":
assert.False(t, ok)
case "not-found":
assert.False(t, ok)
assert.Equal(t, notFoundRoute, match.route)
case "method-not-allowed":
assert.True(t, ok)
assert.Equal(t, methodNotAllowedRoute, match.route)
default:
assert.True(t, ok)
assert.Equal(t, c.expectedRoute, match.route.name)
}
router.match(c.method, &match)
assert.Equal(t, c.expectedRoute, match.route.name)
})
}

Expand Down
11 changes: 11 additions & 0 deletions static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ func TestStaticHandler(t *testing.T) {
assert.Equal(t, "{\n \"custom-entry\": \"value\"\n}", string(body))
},
},
{
uri: "/lang/en-US/fields.json",
directory: "resources",
download: true,
expected: func(t *testing.T, response *Response, result *http.Response, body []byte) {
assert.Equal(t, http.StatusOK, response.GetStatus())
assert.Equal(t, "application/json", result.Header.Get("Content-Type"))
assert.Equal(t, "attachment; filename=\"fields.json\"", result.Header.Get("Content-Disposition"))
assert.Equal(t, "{\n \"email\": \"email address\"\n}", string(body))
},
},
}

for _, c := range cases {
Expand Down

0 comments on commit ae89345

Please sign in to comment.