diff --git a/logRequests.go b/logRequests.go index ac0e847..d624032 100644 --- a/logRequests.go +++ b/logRequests.go @@ -1,6 +1,10 @@ package transport import ( + "bytes" + "encoding/json" + "fmt" + "io" "log" "net/http" "time" @@ -8,23 +12,60 @@ import ( "moul.io/http2curl/v2" ) -func LogRequests(next http.RoundTripper) http.RoundTripper { - return RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) { - r := CloneRequest(req) +type RequestLogger interface { + LogRequest(req *http.Request, curl *http2curl.CurlCommand) + LogResponse(r *http.Request, resp *http.Response, startTime time.Time) +} - curlCommand, _ := http2curl.GetCurlCommand(r) - log.Printf("%v", curlCommand) - log.Printf("request: %s %s", r.Method, r.URL) +func LogRequests(logger RequestLogger) func(next http.RoundTripper) http.RoundTripper { + return func(next http.RoundTripper) http.RoundTripper { + return RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) { + r := CloneRequest(req) - startTime := time.Now() - defer func() { - if resp != nil { - log.Printf("response (HTTP %v): %v %s", time.Since(startTime), resp.Status, r.URL) - } else { - log.Printf("response (): %v %s", time.Since(startTime), r.URL) - } - }() + curlCommand, _ := http2curl.GetCurlCommand(r) + + logger.LogRequest(req, curlCommand) + + startTime := time.Now() + defer func() { + logger.LogResponse(r, resp, startTime) + }() + + return next.RoundTrip(r) + }) + } +} + +type DefaultLogger struct { + PrintResponsePayload bool +} - return next.RoundTrip(r) - }) +func (d *DefaultLogger) LogRequest(r *http.Request, curl *http2curl.CurlCommand) { + log.Printf(curl.String()) + log.Printf("request: %s %s", r.Method, r.URL) +} + +func (d *DefaultLogger) LogResponse(r *http.Request, resp *http.Response, startTime time.Time) { + if resp == nil { + log.Printf(fmt.Sprintf("response (): %v %s", time.Since(startTime), r.URL)) + return + } + + log.Printf("response: %s %s", resp.Status, resp.Request.URL) + + if d.PrintResponsePayload && resp.Header.Get("Content-Type") == "application/json" { + var b bytes.Buffer + + tee := io.TeeReader(resp.Body, &b) + resp.Body = io.NopCloser(&b) + + payload, err := io.ReadAll(tee) + if err == nil { + // Pretty print the JSON payload + var prettyJSON bytes.Buffer + if err := json.Indent(&prettyJSON, payload, "", " "); err == nil { + log.Printf("%s", prettyJSON.String()) + } + } + } } diff --git a/setHeaderFunc_test.go b/setHeaderFunc_test.go index bf48d02..5c96a90 100644 --- a/setHeaderFunc_test.go +++ b/setHeaderFunc_test.go @@ -34,7 +34,7 @@ func TestSetHeaderFunc(t *testing.T) { Transport: transport.Chain( http.DefaultTransport, transport.SetHeaderFunc("Authorization", issueRandomAuthToken), - transport.LogRequests, + transport.LogRequests(&transport.DefaultLogger{PrintResponsePayload: true}), ), Timeout: 15 * time.Second, } diff --git a/setHeader_test.go b/setHeader_test.go index 8ea71bf..83e03c5 100644 --- a/setHeader_test.go +++ b/setHeader_test.go @@ -45,7 +45,7 @@ func TestSetHeader(t *testing.T) { transport.SetHeader("User-Agent", userAgent), transport.SetHeader("Authorization", authHeader), transport.SetHeader("x-extra", "value"), - transport.LogRequests, + transport.LogRequests(&transport.DefaultLogger{PrintResponsePayload: true}), ), Timeout: 15 * time.Second, } diff --git a/transport_test.go b/transport_test.go index e2edd3f..cc5ea65 100644 --- a/transport_test.go +++ b/transport_test.go @@ -24,7 +24,7 @@ func TestChain(t *testing.T) { Transport: Chain( nil, SetHeader("User-Agent", "transport-chain/v1.0.0"), - LogRequests, + LogRequests(&DefaultLogger{PrintResponsePayload: true}), ), } @@ -63,7 +63,7 @@ func TestChainWithRetries(t *testing.T) { Transport: Chain( http.DefaultTransport, Retry(http.DefaultTransport, 5), - LogRequests, + LogRequests(&DefaultLogger{PrintResponsePayload: true}), ), } @@ -103,7 +103,7 @@ func TestChainWithRetryAfter(t *testing.T) { Transport: Chain( http.DefaultTransport, Retry(http.DefaultTransport, 5), - LogRequests, + LogRequests(&DefaultLogger{PrintResponsePayload: true}), ), }