From 29e5ef281064dac1133605a0a7b4d7a86229d6e7 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Fri, 14 Jun 2024 17:42:33 +0200 Subject: [PATCH] Compress: add server-side encoding priority --- middleware/compress/compress.go | 22 +++++-- middleware/compress/compress_test.go | 95 ++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 7 deletions(-) diff --git a/middleware/compress/compress.go b/middleware/compress/compress.go index dd74b9bf..3b63eb21 100644 --- a/middleware/compress/compress.go +++ b/middleware/compress/compress.go @@ -89,7 +89,10 @@ func (w *Gzip) NewWriter(wr io.Writer) io.WriteCloser { // and the value returned by the `Encoder`'s `Encoding()` method. Quality values in // the headers are taken into account. // -// If the header's value is `*`, the first element of the slice is used. +// In case of equal priority, the encoding that is the earliest in the slice is chosen. +// If the header's value is `*` and no encoding already matched, +// the first element of the slice is used. +// // If none of the accepted encodings are available in the `Encoders` slice, then the // response will not be compressed and the middleware immediately passes. // @@ -142,17 +145,22 @@ func (m *Middleware) getEncoder(response *goyave.Response, request *goyave.Reque if response.Hijacked() || request.Header().Get("Upgrade") != "" { return nil } - encodings := httputil.ParseMultiValuesHeader(request.Header().Get("Accept-Encoding")) - for _, h := range encodings { - if h.Value == "*" { - return m.Encoders[0] - } + acceptedEncodings := httputil.ParseMultiValuesHeader(request.Header().Get("Accept-Encoding")) + groupedByPriority := lo.PartitionBy(acceptedEncodings, func(h httputil.HeaderValue) float64 { + return h.Priority + }) + for _, h := range groupedByPriority { w, ok := lo.Find(m.Encoders, func(w Encoder) bool { - return w.Encoding() == h.Value + return lo.ContainsBy(h, func(h httputil.HeaderValue) bool { return h.Value == w.Encoding() }) }) if ok { return w } + + hasWildCard := lo.ContainsBy(h, func(h httputil.HeaderValue) bool { return h.Value == "*" }) + if hasWildCard { + return m.Encoders[0] + } } return nil diff --git a/middleware/compress/compress_test.go b/middleware/compress/compress_test.go index a7674e39..e2736f2f 100644 --- a/middleware/compress/compress_test.go +++ b/middleware/compress/compress_test.go @@ -214,3 +214,98 @@ func TestCompressWriter(t *testing.T) { require.NoError(t, writer.Close()) assert.True(t, closeableWriter.closed) } + +type testEncoder struct { + encoding string +} + +func (e *testEncoder) NewWriter(_ io.Writer) io.WriteCloser { + return nil +} + +func (e *testEncoder) Encoding() string { + return e.encoding +} + +func TestEncoderPriority(t *testing.T) { + + gzip := &testEncoder{encoding: "gzip"} + br := &testEncoder{encoding: "br"} + zstd := &testEncoder{encoding: "zstd"} + + cases := []struct { + want Encoder + acceptEncoding string + encoders []Encoder + }{ + { + encoders: []Encoder{br, zstd, gzip}, + acceptEncoding: "gzip, deflate, br, zstd", + want: br, + }, + { + encoders: []Encoder{br, zstd, gzip}, + acceptEncoding: "*", + want: br, + }, + { + encoders: []Encoder{br, zstd, gzip}, + acceptEncoding: "gzip, *", + want: gzip, + }, + { + encoders: []Encoder{br, zstd, gzip}, + acceptEncoding: "gzip, br, *", + want: br, + }, + { + encoders: []Encoder{br, zstd, gzip}, + acceptEncoding: "*, gzip, br", + want: br, + }, + { + encoders: []Encoder{br, zstd, gzip}, + acceptEncoding: "gzip, *;q=0.9", + want: gzip, + }, + { + encoders: []Encoder{br, zstd, gzip}, + acceptEncoding: "gzip", + want: gzip, + }, + { + encoders: []Encoder{br, zstd, gzip}, + acceptEncoding: "zstd;q=0.9, br;q=0.9", + want: br, + }, + { + encoders: []Encoder{br, zstd, gzip}, + acceptEncoding: "zstd;q=0.9, br;q=0.8", + want: zstd, + }, + { + encoders: []Encoder{br, zstd, gzip}, + acceptEncoding: "gzip;q=0.8, *;q=0.1", + want: gzip, + }, + { + encoders: []Encoder{br, zstd, gzip}, + acceptEncoding: "gzip;q=0.8, *;q=1.0", + want: br, + }, + } + + for _, c := range cases { + c := c + t.Run(c.acceptEncoding, func(t *testing.T) { + middleware := &Middleware{ + Encoders: c.encoders, + } + request := testutil.NewTestRequest(http.MethodGet, "/", nil) + request.Header().Set("Accept-Encoding", c.acceptEncoding) + response, _ := testutil.NewTestResponse(request) + e := middleware.getEncoder(response, request) + assert.Equal(t, c.want, e) + }) + } +}