diff --git a/internal/server/api.go b/internal/server/api.go index 0c33394..5567f3a 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -102,7 +102,7 @@ func (a *API) Complete() { } // ServeHTTP handles requests to the server -func (a *API) ServeHTTP(_ http.ResponseWriter, r *http.Request) { +func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) if err != nil { err = fmt.Errorf("failed to read body: %w", err) @@ -122,6 +122,12 @@ func (a *API) ServeHTTP(_ http.ResponseWriter, r *http.Request) { a.pushError(err) } + if actual == nil { + // indicates the kind (endpoint) isn't implemented in decodeWrapper, so return a 501 + w.WriteHeader(http.StatusNotImplemented) + return + } + if kind == "increment_metric" { // Let's just output the metrics data and stop a.outputRequestData(kind, actual) diff --git a/internal/server/api_test.go b/internal/server/api_test.go index c3b2b27..9f4a331 100644 --- a/internal/server/api_test.go +++ b/internal/server/api_test.go @@ -1,6 +1,8 @@ package server import ( + "net/http" + "net/http/httptest" "testing" ) @@ -12,3 +14,17 @@ func Test_decodeWrapper(t *testing.T) { } }) } + +func TestAPI_ServeHTTP(t *testing.T) { + t.Run("doesn't crash when unknown endpoint is used", func(t *testing.T) { + request := httptest.NewRequest("POST", "/unexpected-endpoint", nil) + response := httptest.NewRecorder() + + api := NewAPI(nil, nil) + api.ServeHTTP(response, request) + + if response.Code != http.StatusNotImplemented { + t.Errorf("expected status code %d, got %d", http.StatusNotImplemented, response.Code) + } + }) +}