Skip to content

Commit

Permalink
refactor: Use custom header to find target cluster
Browse files Browse the repository at this point in the history
* Instead of using path parameter, which can confuse users, use a custom header that will be read by LB and then redirect the query to correct backend. This should work for most of the datasources as we can always add custom headers in Grafana. This will also avoid having to manipulate path manually after stripping path parameter.

* Update tests and docs accordingly.

* Add note in docs that it is possible to find target cluster using query labels and show how to do it.

Signed-off-by: Mahendra Paipuri <[email protected]>
  • Loading branch information
mahendrapaipuri committed Oct 25, 2024
1 parent eb89dc2 commit 220a18e
Show file tree
Hide file tree
Showing 11 changed files with 227 additions and 130 deletions.
5 changes: 3 additions & 2 deletions pkg/lb/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ var mockCEEMSLBApp = *kingpin.New(
"Mock Load Balancer App.",
)

func queryLB(address string) error {
func queryLB(address, clusterID string) error {
req, err := http.NewRequest(http.MethodGet, "http://"+address, nil) //nolint:noctx
if err != nil {
return err
}

req.Header.Add("X-Grafana-User", "usr1")
req.Header.Add("X-Ceems-Cluster-Id", clusterID)

client := &http.Client{Timeout: 10 * time.Second}

Expand Down Expand Up @@ -96,7 +97,7 @@ ceems_lb:

// Query LB
for i := range 10 {
if err := queryLB("localhost:9030/default"); err == nil {
if err := queryLB("localhost:9030", "default"); err == nil {
break
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/lb/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type QueryParamsContextKey struct{}

// QueryParams is the context value.
type QueryParams struct {
id string
clusterID string
uuids []string
queryPeriod time.Duration
}
Expand Down Expand Up @@ -315,7 +315,7 @@ func (lb *loadBalancer) Serve(w http.ResponseWriter, r *http.Request) {

if v, ok := queryParams.(*QueryParams); ok {
queryPeriod = v.queryPeriod
id = v.id
id = v.clusterID
} else {
http.Error(w, "Invalid query parameters", http.StatusBadRequest)

Expand Down
8 changes: 4 additions & 4 deletions pkg/lb/frontend/frontend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func TestNewFrontendSingleGroup(t *testing.T) {
newReq = request.WithContext(
context.WithValue(
request.Context(), QueryParamsContextKey{},
&QueryParams{queryPeriod: period, id: clusterID},
&QueryParams{queryPeriod: period, clusterID: clusterID},
),
)
} else {
Expand All @@ -164,7 +164,7 @@ func TestNewFrontendSingleGroup(t *testing.T) {
newReq := request.WithContext(
context.WithValue(
request.Context(), QueryParamsContextKey{},
&QueryParams{id: "default"},
&QueryParams{clusterID: "default"},
),
)
responseRecorder := httptest.NewRecorder()
Expand Down Expand Up @@ -260,7 +260,7 @@ func TestNewFrontendTwoGroups(t *testing.T) {
newReq = request.WithContext(
context.WithValue(
request.Context(), QueryParamsContextKey{},
&QueryParams{queryPeriod: period, id: test.clusterID},
&QueryParams{queryPeriod: period, clusterID: test.clusterID},
),
)
} else {
Expand All @@ -284,7 +284,7 @@ func TestNewFrontendTwoGroups(t *testing.T) {
newReq := request.WithContext(
context.WithValue(
request.Context(), QueryParamsContextKey{},
&QueryParams{id: "rm-0"},
&QueryParams{clusterID: "rm-0"},
),
)
responseRecorder := httptest.NewRecorder()
Expand Down
73 changes: 38 additions & 35 deletions pkg/lb/frontend/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,59 +70,62 @@ func setQueryParams(r *http.Request, queryParams *QueryParams) *http.Request {
}

// Parse query in the request after cloning it and add query params to context.
func parseQueryParams(r *http.Request, rmIDs []string, logger log.Logger) *http.Request {
func parseQueryParams(r *http.Request, logger log.Logger) *http.Request {
var body []byte

var id string
var clusterID string

var uuids []string

var queryPeriod time.Duration

var err error

// Get id from path parameter.
// Requested paths will be of form /{id}/<rest of path>. Here will strip `id`
// part and proxy the rest to backend
var pathParts []string

for _, p := range strings.Split(r.URL.Path, "/") {
if strings.TrimSpace(p) == "" {
continue
}

pathParts = append(pathParts, p)
}

// First path part must be resource manager ID and check if it is in the valid IDs
if len(pathParts) > 0 {
if slices.Contains(rmIDs, pathParts[0]) {
id = pathParts[0]

// If there is more than 1 pathParts, make URL or set / as URL
if len(pathParts) > 1 {
r.URL.Path = "/" + strings.Join(pathParts[1:], "/")
r.RequestURI = r.URL.Path
} else {
r.URL.Path = "/"
r.RequestURI = "/"
}
}
}
// Get cluster id from X-Ceems-Cluster-Id header
clusterID = r.Header.Get(ceemsClusterIDHeader)

// // Get id from path parameter.
// // Requested paths will be of form /{id}/<rest of path>. Here will strip `id`
// // part and proxy the rest to backend
// var pathParts []string

// for _, p := range strings.Split(r.URL.Path, "/") {
// if strings.TrimSpace(p) == "" {
// continue
// }

// pathParts = append(pathParts, p)
// }

// // First path part must be resource manager ID and check if it is in the valid IDs
// if len(pathParts) > 0 {
// if slices.Contains(rmIDs, pathParts[0]) {
// id = pathParts[0]

// // If there is more than 1 pathParts, make URL or set / as URL
// if len(pathParts) > 1 {
// r.URL.Path = "/" + strings.Join(pathParts[1:], "/")
// r.RequestURI = r.URL.Path
// } else {
// r.URL.Path = "/"
// r.RequestURI = "/"
// }
// }
// }

// Make a new request and add newReader to that request body
clonedReq := r.Clone(r.Context())

// If request has no body go to proxy directly
if r.Body == nil {
return setQueryParams(r, &QueryParams{id, uuids, queryPeriod})
return setQueryParams(r, &QueryParams{clusterID, uuids, queryPeriod})
}

// If failed to read body, skip verification and go to request proxy
if body, err = io.ReadAll(r.Body); err != nil {
level.Error(logger).Log("msg", "Failed to read request body", "err", err)

return setQueryParams(r, &QueryParams{id, uuids, queryPeriod})
return setQueryParams(r, &QueryParams{clusterID, uuids, queryPeriod})
}

// clone body to existing request and new request
Expand All @@ -133,7 +136,7 @@ func parseQueryParams(r *http.Request, rmIDs []string, logger log.Logger) *http.
if err = clonedReq.ParseForm(); err != nil {
level.Error(logger).Log("msg", "Could not parse request body", "err", err)

return setQueryParams(r, &QueryParams{id, uuids, queryPeriod})
return setQueryParams(r, &QueryParams{clusterID, uuids, queryPeriod})
}

// Parse TSDB's query in request query params
Expand All @@ -159,7 +162,7 @@ func parseQueryParams(r *http.Request, rmIDs []string, logger log.Logger) *http.
for _, idMatch := range strings.Split(match[1], "|") {
// Ignore empty strings
if strings.TrimSpace(idMatch) != "" {
id = strings.TrimSpace(idMatch)
clusterID = strings.TrimSpace(idMatch)
}
}
}
Expand All @@ -184,7 +187,7 @@ func parseQueryParams(r *http.Request, rmIDs []string, logger log.Logger) *http.
}

// Set query params to request's context
return setQueryParams(r, &QueryParams{id, uuids, queryPeriod})
return setQueryParams(r, &QueryParams{clusterID, uuids, queryPeriod})
}

// Parse time parameter in request.
Expand Down
4 changes: 2 additions & 2 deletions pkg/lb/frontend/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,10 @@ func TestParseQueryParams(t *testing.T) {
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
}

newReq := parseQueryParams(req, test.rmIDs, log.NewNopLogger())
newReq := parseQueryParams(req, log.NewNopLogger())
queryParams := newReq.Context().Value(QueryParamsContextKey{}).(*QueryParams) //nolint:forcetypeassert
assert.Equal(t, queryParams.uuids, test.uuids)
assert.Equal(t, queryParams.id, test.rmID)
assert.Equal(t, queryParams.clusterID, test.rmID)

if test.method == "POST" {
// Check the new request body can still be parsed
Expand Down
46 changes: 35 additions & 11 deletions pkg/lb/frontend/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/url"
"regexp"
"slices"
"strings"

"github.com/go-kit/log"
Expand All @@ -16,11 +17,12 @@ import (

// Headers.
const (
grafanaUserHeader = "X-Grafana-User"
dashboardUserHeader = "X-Dashboard-User"
loggedUserHeader = "X-Logged-User"
adminUserHeader = "X-Admin-User"
ceemsUserHeader = "X-Ceems-User"
grafanaUserHeader = "X-Grafana-User"
dashboardUserHeader = "X-Dashboard-User"
loggedUserHeader = "X-Logged-User"
adminUserHeader = "X-Admin-User"
ceemsUserHeader = "X-Ceems-User"
ceemsClusterIDHeader = "X-Ceems-Cluster-Id"
)

var (
Expand All @@ -30,7 +32,7 @@ var (
// Playground: https://goplay.tools/snippet/kq_r_1SOgnG
regexpUUID = regexp.MustCompile("(?:.+?)[^gpu]uuid=[~]{0,1}\"(?P<uuid>[a-zA-Z0-9-|]+)\"(?:.*)")

// Regex that will match unit's ID.
// Regex that will match cluster's ID.
regexID = regexp.MustCompile("(?:.+?)ceems_id=[~]{0,1}\"(?P<id>[a-zA-Z0-9-|_]+)\"(?:.*)")
)

Expand Down Expand Up @@ -124,23 +126,24 @@ func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var loggedUser string

var id string
var clusterID string

var uuids []string

var queryParams interface{}

// Clone request, parse query params and set them in request context
// This will ensure we set query params in request's context always
r = parseQueryParams(r, amw.clusterIDs, amw.logger)
r = parseQueryParams(r, amw.logger)

// Apply middleware only for following endpoints:
// - query
// - query_range
// - labels
// - labels values
// - series
if !strings.HasSuffix(r.URL.Path, "query") && !strings.HasSuffix(r.URL.Path, "query_range") &&
if !strings.HasSuffix(r.URL.Path, "query") &&
!strings.HasSuffix(r.URL.Path, "query_range") &&
!strings.HasSuffix(r.URL.Path, "values") &&
!strings.HasSuffix(r.URL.Path, "labels") &&
!strings.HasSuffix(r.URL.Path, "series") {
Expand Down Expand Up @@ -206,9 +209,30 @@ func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler

// Check type assertions
if v, ok := queryParams.(*QueryParams); ok {
id = v.id
clusterID = v.clusterID
uuids = v.uuids

// Verify clusterID is in list of valid cluster IDs
if !slices.Contains(amw.clusterIDs, clusterID) {
// Write an error and stop the handler chain
w.WriteHeader(http.StatusBadRequest)

response := ceems_api.Response[any]{
Status: "error",
ErrorType: "bad_request",
Error: "invalid cluster ID",
}
if err := json.NewEncoder(w).Encode(&response); err != nil {
level.Error(amw.logger).Log("msg", "Failed to encode response", "err", err)
w.Write([]byte("KO"))
}

return
}
} else {
// Write an error and stop the handler chain
w.WriteHeader(http.StatusBadRequest)

response := ceems_api.Response[any]{
Status: "error",
ErrorType: "bad_request",
Expand All @@ -223,7 +247,7 @@ func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler
}

// Check if user is querying for his/her own compute units by looking to DB
if !amw.isUserUnit(r.Context(), loggedUser, []string{id}, uuids) { //nolint:contextcheck // False positive
if !amw.isUserUnit(r.Context(), loggedUser, []string{clusterID}, uuids) { //nolint:contextcheck // False positive
// Write an error and stop the handler chain
w.WriteHeader(http.StatusForbidden)

Expand Down
Loading

0 comments on commit 220a18e

Please sign in to comment.