diff --git a/proxy.go b/proxy.go index 0058333..10d5872 100644 --- a/proxy.go +++ b/proxy.go @@ -326,7 +326,12 @@ loop: protocolError() return } - sess, ok = value.(map[string]interface{})["sessionId"].(string) + valueMap, ok := value.(map[string]interface{}) + if !ok { + protocolError() + return + } + sess, ok = valueMap["sessionId"].(string) if !ok { protocolError() return @@ -512,7 +517,7 @@ func withCloseNotifier(handler http.HandlerFunc) http.HandlerFunc { cancel() }() select { - case <-w.(http.CloseNotifier).CloseNotify(): + case <-r.Context().Done(): cancel() case <-ctx.Done(): } diff --git a/proxy_test.go b/proxy_test.go index 4266136..5f7ebe8 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -1207,6 +1207,36 @@ func TestStartSessionJSONWireProtocol(t *testing.T) { AssertThat(t, value["value"].(map[string]interface{})["sessionId"], EqualTo{fmt.Sprintf("%s123", node.Sum())}) } +func TestPanicRouteProtocolError(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/wd/hub/session", postOnly(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`{"value":[]}`)) + })) + selenium := httptest.NewServer(mux) + defer selenium.Close() + + host, port := hostportnum(selenium.URL) + node := Host{Name: host, Port: port, Count: 1} + + test.Lock() + defer test.Unlock() + + browsers := Browsers{Browsers: []Browser{ + {Name: "browser", DefaultVersion: "1.0", Versions: []Version{ + {Number: "1.0", Regions: []Region{ + {Hosts: Hosts{ + node, + }}, + }}, + }}}} + updateQuota(user, browsers) + + rsp, err := createSession(`{"desiredCapabilities":{"browserName":"browser", "version":"1.0"}}`) + + AssertThat(t, err, Is{nil}) + AssertThat(t, rsp.StatusCode, Is{http.StatusBadGateway}) +} + func TestDeleteSession(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/wd/hub/session/", func(w http.ResponseWriter, r *http.Request) {