diff --git a/server.go b/server.go index 83452db..0286243 100644 --- a/server.go +++ b/server.go @@ -63,6 +63,11 @@ type normalServerHandler struct { } func (s *normalServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !websocket.IsWebSocketUpgrade(r) { + closeTcpHandle(w, r) + return + } + log.Println("Incoming --> ", r.RemoteAddr, r.Header, s.DestAddress) ch := make(chan net.Conn) @@ -84,9 +89,6 @@ func (s *normalServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) go func() { defer close(ch) - if !websocket.IsWebSocketUpgrade(r) { - return - } tcp, err := net.Dial("tcp", s.DestAddress) if err != nil { log.Println(err) @@ -125,6 +127,11 @@ type internalServerHandler struct { } func (s *internalServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !websocket.IsWebSocketUpgrade(r) { + closeTcpHandle(w, r) + return + } + log.Println("Incoming --> ", r.RemoteAddr, r.Header, " --> ( [Client]", s.DestAddress, ") --> ", s.Client.Target()) ch := make(chan io.Closer) @@ -144,9 +151,6 @@ func (s *internalServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request go func() { defer close(ch) - if !websocket.IsWebSocketUpgrade(r) { - return - } // send inHeader to client for Xray's 0rtt ws ws2, err := s.Client.Dial(r.Header) if err != nil { @@ -174,8 +178,21 @@ func (s *internalServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request TunnelTcpTcp(target, source) } +func closeTcpHandle(writer http.ResponseWriter, request *http.Request) { + h, ok := writer.(http.Hijacker) + if !ok { + return + } + netConn, _, err := h.Hijack() + if err != nil { + return + } + _ = netConn.Close() +} + func BuildServer(config ServerConfig) { mux := http.NewServeMux() + hadRoot := false for _, target := range config.Target { if len(target.WSPath) == 0 { target.WSPath = "/" @@ -202,8 +219,14 @@ func BuildServer(config ServerConfig) { } } + if target.WSPath == "/" { + hadRoot = true + } mux.Handle(target.WSPath, sh) } + if !hadRoot { + mux.HandleFunc("/", closeTcpHandle) + } var s Server s = &server{ bindAddress: config.BindAddress,