diff --git a/server/service.go b/server/service.go index 78427a59e..123cbf74f 100644 --- a/server/service.go +++ b/server/service.go @@ -4,6 +4,7 @@ import ( "context" "net/http" + "github.com/gorilla/mux" "google.golang.org/grpc" ) @@ -105,3 +106,9 @@ func (h ContextHandlerFunc) ServeHTTPContext(ctx context.Context, rw http.Respon type ContextHandler interface { ServeHTTPContext(context.Context, http.ResponseWriter, *http.Request) } + +// GorillaService lets you define a gorilla configured +// Router as the main service for SimpleServer +type GorillaService interface { + Gorilla() *mux.Router +} diff --git a/server/simple_server.go b/server/simple_server.go index 7adef2878..c8a80a058 100644 --- a/server/simple_server.go +++ b/server/simple_server.go @@ -201,6 +201,7 @@ func (s *SimpleServer) Register(svcI Service) error { ss SimpleService cs ContextService mcs MixedContextService + gs GorillaService ) switch svc := svcI.(type) { @@ -216,6 +217,8 @@ func (s *SimpleServer) Register(svcI Service) error { cs = svc case ContextService: cs = svc + case GorillaService: + gs = svc default: return errors.New("services for SimpleServers must implement the SimpleService, JSONService or MixedService interfaces") } @@ -259,6 +262,11 @@ func (s *SimpleServer) Register(svcI Service) error { } } + if gs != nil { + s.mux = &GorillaRouter{gs.Gorilla()} + s.h = svcI.Middleware(s.mux) + } + RegisterProfiler(s.cfg, s.mux) return nil } diff --git a/server/simple_server_test.go b/server/simple_server_test.go index 9f3ec559f..7986e7c2c 100644 --- a/server/simple_server_test.go +++ b/server/simple_server_test.go @@ -9,6 +9,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/gorilla/mux" ) type benchmarkContextService struct { @@ -494,3 +496,35 @@ func TestNotFoundHandler(t *testing.T) { t.Errorf("expected response body to be \"\", got %q", gotBody) } } + +type gorillaService struct { + mux *mux.Router +} + +func (gs *gorillaService) Prefix() string { + return "" +} + +func (gs *gorillaService) Middleware(h http.Handler) http.Handler { + return h +} + +func (gs *gorillaService) Gorilla() *mux.Router { + return gs.mux +} + +func TestGorillaService(t *testing.T) { + r := mux.NewRouter() + var called bool + r.HandleFunc("/svc", func(w http.ResponseWriter, r *http.Request) { + called = true + }) + ss := NewSimpleServer(nil) + ss.Register(&gorillaService{mux: r}) + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/svc", nil) + ss.ServeHTTP(w, req) + if !called { + t.Fatalf("Expected gorilla router to be called: %v", w.Result().Status) + } +}