Skip to content

Commit

Permalink
#215 - Response: add support for flushing (#218)
Browse files Browse the repository at this point in the history
Response: add support for flushing #215

- Response implement http.Flusher
- Add goyave.Flusher interface
- Call PreWrite only once on the first Write
- Add CommonWriter to reduce chained writers boilerplate
- Use CommonWriter for log and compress middleware
  • Loading branch information
System-Glitch authored Jul 23, 2024
1 parent 31b83c8 commit 21f3c8c
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 40 deletions.
26 changes: 7 additions & 19 deletions log/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ type Formatter func(ctx *Context) (message string, attributes []slog.Attr)
// Writer chained writer keeping response body in memory.
// Used for loggin in common format.
type Writer struct {
goyave.Component
goyave.CommonWriter
formatter Formatter
writer io.Writer
request *goyave.Request
response *goyave.Response
length int
Expand All @@ -47,28 +46,20 @@ var _ goyave.PreWriter = (*Writer)(nil)
// formatter.
func NewWriter(server *goyave.Server, response *goyave.Response, request *goyave.Request, formatter Formatter) *Writer {
writer := &Writer{
request: request,
writer: response.Writer(),
response: response,
formatter: formatter,
CommonWriter: goyave.NewCommonWriter(response.Writer()),
request: request,
response: response,
formatter: formatter,
}
writer.Init(server)
return writer
}

// PreWrite calls PreWrite on the
// child writer if it implements PreWriter.
func (w *Writer) PreWrite(b []byte) {
if pr, ok := w.writer.(goyave.PreWriter); ok {
pr.PreWrite(b)
}
}

// Write writes the data as a response and keeps its length in memory
// for later logging.
func (w *Writer) Write(b []byte) (int, error) {
w.length += len(b)
n, err := w.writer.Write(b)
n, err := w.CommonWriter.Write(b)
return n, errors.New(err)
}

Expand All @@ -90,10 +81,7 @@ func (w *Writer) Close() error {
w.Logger().Info(message, lo.Map(attrs, func(a slog.Attr, _ int) any { return a })...)
}

if wr, ok := w.writer.(io.Closer); ok {
return wr.Close()
}
return nil
return errors.New(w.CommonWriter.Close())
}

// AccessMiddleware captures response data and outputs it to the logger at the
Expand Down
28 changes: 18 additions & 10 deletions middleware/compress/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,37 @@ type Encoder interface {
}

type compressWriter struct {
io.WriteCloser
http.ResponseWriter
childWriter io.Writer
goyave.CommonWriter
responseWriter http.ResponseWriter
childWriter io.Writer
}

func (w *compressWriter) PreWrite(b []byte) {
if pr, ok := w.childWriter.(goyave.PreWriter); ok {
pr.PreWrite(b)
}
h := w.ResponseWriter.Header()
h := w.responseWriter.Header()
if h.Get("Content-Type") == "" {
h.Set("Content-Type", http.DetectContentType(b))
}
h.Del("Content-Length")
}

func (w *compressWriter) Write(b []byte) (int, error) {
n, err := w.WriteCloser.Write(b)
return n, errors.New(err)
func (w *compressWriter) Flush() error {
if err := w.CommonWriter.Flush(); err != nil {
return errors.New(err)
}
switch flusher := w.childWriter.(type) {
case goyave.Flusher:
return errors.New(flusher.Flush())
case http.Flusher:
flusher.Flush()
}
return nil
}

func (w *compressWriter) Close() error {
err := errors.New(w.WriteCloser.Close())
err := errors.New(w.CommonWriter.Close())

if wr, ok := w.childWriter.(io.Closer); ok {
return errors.New(wr.Close())
Expand Down Expand Up @@ -106,8 +114,8 @@ func (m *Middleware) Handle(next goyave.Handler) goyave.Handler {

respWriter := response.Writer()
compressWriter := &compressWriter{
WriteCloser: encoder.NewWriter(respWriter),
ResponseWriter: response,
CommonWriter: goyave.NewCommonWriter(encoder.NewWriter(respWriter)),
responseWriter: response,
childWriter: respWriter,
}
response.SetWriter(compressWriter)
Expand Down
67 changes: 62 additions & 5 deletions middleware/compress/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package compress
import (
"bytes"
"compress/gzip"
"fmt"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -16,10 +17,20 @@ import (
"goyave.dev/goyave/v5/util/testutil"
)

type testBuf struct {
*bytes.Buffer
flushErr error
}

func (w *testBuf) Flush() error {
return w.flushErr
}

type closeableChildWriter struct {
io.Writer
closed bool
preWritten bool
flushed bool
}

func (w *closeableChildWriter) PreWrite(b []byte) {
Expand All @@ -29,6 +40,17 @@ func (w *closeableChildWriter) PreWrite(b []byte) {
}
}

func (w *closeableChildWriter) Flush() error {
w.flushed = true
switch flusher := w.Writer.(type) {
case goyave.Flusher:
return flusher.Flush()
case http.Flusher:
flusher.Flush()
}
return nil
}

func (w *closeableChildWriter) Close() error {
w.closed = true
if wr, ok := w.Writer.(io.Closer); ok {
Expand All @@ -37,6 +59,20 @@ func (w *closeableChildWriter) Close() error {
return nil
}

type closeableChildWriterHTTPFlusher struct {
*closeableChildWriter
}

func (w *closeableChildWriterHTTPFlusher) Flush() {
w.flushed = true
switch flusher := w.Writer.(type) {
case goyave.Flusher:
_ = flusher.Flush()
case http.Flusher:
flusher.Flush()
}
}

func TestCompressMiddleware(t *testing.T) {
server := testutil.NewTestServerWithOptions(t, goyave.Options{Config: config.LoadDefault()})

Expand Down Expand Up @@ -149,26 +185,47 @@ func TestCompressWriter(t *testing.T) {
Level: gzip.BestCompression,
}

buf := bytes.NewBuffer([]byte{})
buf := &testBuf{Buffer: bytes.NewBuffer([]byte{})}
closeableWriter := &closeableChildWriter{
Writer: buf,
closed: false,
}

response := httptest.NewRecorder()

writer := &compressWriter{
WriteCloser: encoder.NewWriter(closeableWriter),
ResponseWriter: response,
CommonWriter: goyave.NewCommonWriter(encoder.NewWriter(closeableWriter)),
responseWriter: response,
childWriter: closeableWriter,
}

writer.PreWrite([]byte("hello world"))

assert.True(t, closeableWriter.preWritten)

assert.NoError(t, writer.Flush())
assert.True(t, closeableWriter.flushed)

buf.flushErr = fmt.Errorf("test error")
assert.ErrorIs(t, writer.Flush(), buf.flushErr)

require.NoError(t, writer.Close())
assert.True(t, closeableWriter.closed)

t.Run("http_flusher", func(t *testing.T) {
buf := &testBuf{Buffer: bytes.NewBuffer([]byte{})}
closeableWriter := &closeableChildWriterHTTPFlusher{
closeableChildWriter: &closeableChildWriter{
Writer: buf,
},
}
response := httptest.NewRecorder()
writer := &compressWriter{
CommonWriter: goyave.NewCommonWriter(encoder.NewWriter(closeableWriter)),
responseWriter: response,
childWriter: closeableWriter,
}
assert.NoError(t, writer.Flush())
assert.True(t, closeableWriter.flushed)
})
}

type testEncoder struct {
Expand Down
89 changes: 83 additions & 6 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,68 @@ var (

// PreWriter is a writter that needs to alter the response headers or status
// before they are written.
// If implemented, PreWrite will be called right before the Write operation.
// If implemented, PreWrite will be called right before the first `Write` operation.
type PreWriter interface {
PreWrite(b []byte)
}

// The Flusher interface is implemented by writers that allow
// handlers to flush buffered data to the client.
//
// Note that even for writers that support flushing, if the client
// is connected through an HTTP proxy, the buffered data may not reach
// the client until the response completes.
type Flusher interface {
Flush() error
}

// CommonWriter is a component meant to be used with composition
// to avoid having to implement the base behavior of the common interfaces
// a chained writer has to implement (`PreWrite()`, `Write()`, `Close()`, `Flush()`)
type CommonWriter struct {
Component
wr io.Writer
}

// NewCommonWriter create a new common writer that will output to the given `io.Writer`.
func NewCommonWriter(wr io.Writer) CommonWriter {
return CommonWriter{
wr: wr,
}
}

// PreWrite calls PreWrite on the
// child writer if it implements PreWriter.
func (w CommonWriter) PreWrite(b []byte) {
if pr, ok := w.wr.(PreWriter); ok {
pr.PreWrite(b)
}
}

func (w CommonWriter) Write(b []byte) (int, error) {
n, err := w.wr.Write(b)
return n, errorutil.New(err)
}

// Close the underlying writer if it implements `io.Closer`.
func (w CommonWriter) Close() error {
if wr, ok := w.wr.(io.Closer); ok {
return errorutil.New(wr.Close())
}
return nil
}

// Flush the underlying writer if it implements `goyave.Flusher` or `http.Flusher`.
func (w *CommonWriter) Flush() error {
switch flusher := w.wr.(type) {
case Flusher:
return errorutil.New(flusher.Flush())
case http.Flusher:
flusher.Flush()
}
return nil
}

// Response implementation wrapping `http.ResponseWriter`. Writing an HTTP response without
// using it is incorrect. This acts as a proxy to one or many `io.Writer` chained, with the original
// `http.ResponseWriter` always last.
Expand Down Expand Up @@ -81,10 +138,12 @@ func (r *Response) reset(server *Server, request *Request, writer http.ResponseW
// PreWrite writes the response header after calling PreWrite on the
// child writer if it implements PreWriter.
func (r *Response) PreWrite(b []byte) {
r.empty = false
if pr, ok := r.writer.(PreWriter); ok {
pr.PreWrite(b)
if r.empty {
if pr, ok := r.writer.(PreWriter); ok {
pr.PreWrite(b)
}
}
r.empty = false
if !r.wroteHeader {
if r.status == 0 {
r.status = http.StatusOK
Expand All @@ -97,7 +156,7 @@ func (r *Response) PreWrite(b []byte) {
// http.ResponseWriter implementation

// Write writes the data as a response.
// See http.ResponseWriter.Write
// See `http.ResponseWriter.Write`.
func (r *Response) Write(data []byte) (int, error) {
r.PreWrite(data)
n, err := r.writer.Write(data)
Expand Down Expand Up @@ -128,6 +187,25 @@ func (r *Response) Cookie(cookie *http.Cookie) {
http.SetCookie(r.responseWriter, cookie)
}

// Flush sends any buffered data to the client if the underlying
// writer implements `goyave.Flusher`.
//
// If the response headers have not been written already, `PreWrite()` will
// be called with an empty byte slice.
func (r *Response) Flush() {
if !r.wroteHeader {
r.PreWrite([]byte{})
}
switch flusher := r.writer.(type) {
case Flusher:
if err := flusher.Flush(); err != nil {
r.server.Logger.Error(errorutil.New(err))
}
case http.Flusher:
flusher.Flush()
}
}

// --------------------------------------
// http.Hijacker implementation

Expand Down Expand Up @@ -254,7 +332,6 @@ func (r *Response) writeFile(fs fs.StatFS, file string, disposition string) {
r.Status(http.StatusNotFound)
return
}
r.empty = false
r.status = http.StatusOK
mime, size, err := fsutil.GetMIMEType(fs, file)
if err != nil {
Expand Down
Loading

0 comments on commit 21f3c8c

Please sign in to comment.