diff --git a/pkg/frontend/transport/handler.go b/pkg/frontend/transport/handler.go index 36263c2c5d..d323a138cd 100644 --- a/pkg/frontend/transport/handler.go +++ b/pkg/frontend/transport/handler.go @@ -234,20 +234,20 @@ func (f *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { f.reportQueryStats(r, userID, queryString, queryResponseTime, stats, err, statusCode, resp) } + hs := w.Header() + if f.cfg.QueryStatsEnabled { + writeServiceTimingHeader(queryResponseTime, hs, stats) + } + if err != nil { - writeError(w, err) + writeError(w, err, hs) return } - hs := w.Header() for h, vs := range resp.Header { hs[h] = vs } - if f.cfg.QueryStatsEnabled { - writeServiceTimingHeader(queryResponseTime, hs, stats) - } - w.WriteHeader(resp.StatusCode) // log copy response body error so that we will know even though success response code returned bytesCopied, err := io.Copy(w, resp.Body) @@ -422,7 +422,7 @@ func formatQueryString(queryString url.Values) (fields []interface{}) { return fields } -func writeError(w http.ResponseWriter, err error) { +func writeError(w http.ResponseWriter, err error, additionalHeaders http.Header) { switch err { case context.Canceled: err = errCanceled @@ -433,7 +433,22 @@ func writeError(w http.ResponseWriter, err error) { err = errRequestEntityTooLarge } } - server.WriteError(w, err) + + resp, ok := httpgrpc.HTTPResponseFromError(err) + if ok { + for k, values := range additionalHeaders { + resp.Headers = append(resp.Headers, &httpgrpc.Header{Key: k, Values: values}) + } + _ = server.WriteResponse(w, resp) + } else { + headers := w.Header() + for k, values := range additionalHeaders { + for _, value := range values { + headers.Set(k, value) + } + } + http.Error(w, err.Error(), http.StatusInternalServerError) + } } func writeServiceTimingHeader(queryResponseTime time.Duration, headers http.Header, stats *querier_stats.QueryStats) { diff --git a/pkg/frontend/transport/handler_test.go b/pkg/frontend/transport/handler_test.go index 7cf4a4bd34..955323fa99 100644 --- a/pkg/frontend/transport/handler_test.go +++ b/pkg/frontend/transport/handler_test.go @@ -31,19 +31,31 @@ func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { func TestWriteError(t *testing.T) { for _, test := range []struct { - status int - err error + status int + err error + additionalHeaders http.Header }{ - {http.StatusInternalServerError, errors.New("unknown")}, - {http.StatusGatewayTimeout, context.DeadlineExceeded}, - {StatusClientClosedRequest, context.Canceled}, - {http.StatusBadRequest, httpgrpc.Errorf(http.StatusBadRequest, "")}, - {http.StatusRequestEntityTooLarge, errors.New("http: request body too large")}, + {http.StatusInternalServerError, errors.New("unknown"), http.Header{"User-Agent": []string{"Golang"}}}, + {http.StatusInternalServerError, errors.New("unknown"), nil}, + {http.StatusGatewayTimeout, context.DeadlineExceeded, nil}, + {StatusClientClosedRequest, context.Canceled, nil}, + {StatusClientClosedRequest, context.Canceled, http.Header{"User-Agent": []string{"Golang"}}}, + {StatusClientClosedRequest, context.Canceled, http.Header{"User-Agent": []string{"Golang"}, "Content-Type": []string{"application/json"}}}, + {http.StatusBadRequest, httpgrpc.Errorf(http.StatusBadRequest, ""), http.Header{}}, + {http.StatusRequestEntityTooLarge, errors.New("http: request body too large"), http.Header{}}, } { t.Run(test.err.Error(), func(t *testing.T) { w := httptest.NewRecorder() - writeError(w, test.err) + writeError(w, test.err, test.additionalHeaders) require.Equal(t, test.status, w.Result().StatusCode) + expectedAdditionalHeaders := test.additionalHeaders + if expectedAdditionalHeaders != nil { + for key, value := range w.Header() { + if values, ok := expectedAdditionalHeaders[key]; ok { + require.Equal(t, values, value) + } + } + } }) } }