diff --git a/cmd/local.go b/cmd/local.go index 929f038..ebbbec3 100644 --- a/cmd/local.go +++ b/cmd/local.go @@ -2,7 +2,6 @@ package cmd import ( "context" - "io" "net" "sync" "time" @@ -86,11 +85,7 @@ func handleConnection(conn net.Conn, tc tunnel.Client) { var n int64 n, err = util.CopyBuffer(remote, conn) logger.Debugw("copy from client end", "n", n, "err", err) - if err != nil && err != io.EOF { - once.Do(func() { - remote.Close() - }) - } + remote.CloseWrite() }() go func() { diff --git a/cmd/server.go b/cmd/server.go index 0391cf9..3b64df9 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -1,8 +1,6 @@ package cmd import ( - "io" - "github.com/isayme/go-logger" "github.com/isayme/tox/conf" "github.com/isayme/tox/socks5" @@ -38,8 +36,11 @@ func startServer() { } } -func handler(rw io.ReadWriter) { - request := socks5.NewRequest(rw) +/** + * return when server will not send data anymore + */ +func handler(conn util.ServerConn) { + request := socks5.NewRequest(conn) if err := request.Handle(); err != nil { logger.Errorw("socks5 fail", "err", err) } diff --git a/socks5/request.go b/socks5/request.go index 1f9b690..f5eb3f8 100644 --- a/socks5/request.go +++ b/socks5/request.go @@ -16,16 +16,16 @@ import ( ) type Request struct { - rw io.ReadWriter + rw util.ServerConn cmd byte atyp byte addr string } -func NewRequest(rw io.ReadWriter) *Request { +func NewRequest(conn util.ServerConn) *Request { return &Request{ - rw: rw, + rw: conn, } } @@ -145,17 +145,17 @@ func (r *Request) negotiate() error { } func (r *Request) handleRequest() error { - conn, err := net.DialTimeout("tcp", r.addr, time.Second*5) + remote, err := net.DialTimeout("tcp", r.addr, time.Second*5) if err != nil { logger.Infow("net.Dial fail", "err", err, "addr", r.addr) return err } - defer conn.Close() + defer remote.Close() config := conf.Get() - tcpConn, _ := conn.(*net.TCPConn) - conn = util.NewTimeoutConn(conn, time.Duration(config.Timeout)*time.Second) + remoteTcpConn, _ := remote.(*net.TCPConn) + remote = util.NewTimeoutConn(remote, time.Duration(config.Timeout)*time.Second) logger.Infow("connect ok", "addr", r.addr) @@ -167,7 +167,8 @@ func (r *Request) handleRequest() error { var err error var n int64 - n, err = util.CopyBuffer(r.rw, conn) + n, err = util.CopyBuffer(r.rw, remote) + r.rw.CloseWrite() logger.Debugw("copy from remote end", "n", n, "err", err) }() @@ -176,9 +177,9 @@ func (r *Request) handleRequest() error { var err error var n int64 - n, err = util.CopyBuffer(conn, r.rw) + n, err = util.CopyBuffer(remote, r.rw) logger.Debugw("copy from client end", "n", n, "err", err) - tcpConn.CloseWrite() + remoteTcpConn.CloseWrite() }() wg.Wait() diff --git a/tunnel/grpc/client.go b/tunnel/grpc/client.go index 9c6abfc..38134fe 100644 --- a/tunnel/grpc/client.go +++ b/tunnel/grpc/client.go @@ -4,12 +4,12 @@ import ( "bytes" "context" "crypto/tls" - "io" "net/url" pool "github.com/isayme/go-grpcpool" "github.com/isayme/go-logger" "github.com/isayme/tox/proto" + "github.com/isayme/tox/util" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -82,7 +82,7 @@ func NewClient(tunnel string, password string) (*Client, error) { }, nil } -func (t *Client) Connect(ctx context.Context) (io.ReadWriteCloser, error) { +func (t *Client) Connect(ctx context.Context) (util.LocalConn, error) { conn, err := t.p.Get() if err != nil { return nil, err @@ -141,5 +141,10 @@ func (rw *clientReadWriter) Write(p []byte) (int, error) { func (rw *clientReadWriter) Close() error { rw.c.CloseSend() - return rw.conn.Close() + rw.conn.Close() + return rw.conn.Value().Close() +} + +func (rw *clientReadWriter) CloseWrite() error { + return rw.c.CloseSend() } diff --git a/tunnel/grpc/server.go b/tunnel/grpc/server.go index 4348f3d..c3d894a 100644 --- a/tunnel/grpc/server.go +++ b/tunnel/grpc/server.go @@ -3,11 +3,11 @@ package grpc import ( "bytes" "fmt" - "io" "net" "net/url" "github.com/isayme/tox/proto" + "github.com/isayme/tox/util" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) @@ -15,7 +15,7 @@ import ( type Server struct { proto.UnimplementedTunnelServer - handler func(io.ReadWriter) + handler func(util.ServerConn) tunnel string key string } @@ -27,7 +27,7 @@ func NewServer(tunnel string, password string) (*Server, error) { }, nil } -func (s *Server) ListenAndServe(handler func(io.ReadWriter)) error { +func (s *Server) ListenAndServe(handler func(util.ServerConn)) error { URL, err := url.Parse(s.tunnel) if err != nil { return err @@ -56,7 +56,7 @@ func (s *Server) ListenAndServe(handler func(io.ReadWriter)) error { return grpcs.Serve(l) } -func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(io.ReadWriter)) error { +func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(util.ServerConn)) error { URL, err := url.Parse(s.tunnel) if err != nil { return err @@ -140,3 +140,7 @@ func (rw *serverReadWriter) Write(p []byte) (int, error) { } return len(p), nil } + +func (rw *serverReadWriter) CloseWrite() error { + return nil +} diff --git a/tunnel/h2/client.go b/tunnel/h2/client.go index d4450d5..7cdcd3b 100644 --- a/tunnel/h2/client.go +++ b/tunnel/h2/client.go @@ -4,12 +4,12 @@ import ( "context" "crypto/tls" "fmt" - "io" "net" "net/http" "net/url" "time" + "github.com/isayme/tox/util" "github.com/posener/h2conn" ) @@ -56,7 +56,7 @@ func NewClient(tunnel string, password string) (*Client, error) { }, nil } -func (t *Client) Connect(ctx context.Context) (io.ReadWriteCloser, error) { +func (t *Client) Connect(ctx context.Context) (util.LocalConn, error) { remote, resp, err := t.h2Client.Connect(ctx, t.tunnel) if err != nil { return nil, err @@ -67,5 +67,15 @@ func (t *Client) Connect(ctx context.Context) (io.ReadWriteCloser, error) { return nil, fmt.Errorf("h2: bad status code: %d", resp.StatusCode) } - return remote, nil + return &h2LocalConn{ + Conn: remote, + }, nil +} + +type h2LocalConn struct { + *h2conn.Conn +} + +func (conn *h2LocalConn) CloseWrite() error { + return conn.Close() } diff --git a/tunnel/h2/server.go b/tunnel/h2/server.go index 784d073..79455fe 100644 --- a/tunnel/h2/server.go +++ b/tunnel/h2/server.go @@ -2,11 +2,11 @@ package h2 import ( "fmt" - "io" "net/http" "net/url" "github.com/isayme/go-logger" + "github.com/isayme/tox/util" "github.com/posener/h2conn" ) @@ -22,11 +22,11 @@ func NewServer(tunnel string, password string) (*Server, error) { }, nil } -func (s *Server) ListenAndServe(handler func(io.ReadWriter)) error { +func (s *Server) ListenAndServe(handler func(util.ServerConn)) error { return fmt.Errorf("tls required for http2 protocol") } -func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(io.ReadWriter)) error { +func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(util.ServerConn)) error { URL, err := url.Parse(s.tunnel) if err != nil { return err @@ -44,9 +44,9 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(io.Rea http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - defer conn.Close() + // defer conn.Close() - handler(conn) + handler(&h2LocalConn{Conn: conn}) }) addr := fmt.Sprintf(":%s", URL.Port()) diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 25ad968..e24973e 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -3,7 +3,6 @@ package tunnel import ( "context" "fmt" - "io" "net/url" "github.com/isayme/go-logger" @@ -14,12 +13,12 @@ import ( ) type Client interface { - Connect(context.Context) (io.ReadWriteCloser, error) + Connect(context.Context) (util.LocalConn, error) } type Server interface { - ListenAndServe(handler func(io.ReadWriter)) error - ListenAndServeTLS(certFile, keyFile string, handler func(io.ReadWriter)) error + ListenAndServe(handler func(util.ServerConn)) error + ListenAndServeTLS(certFile, keyFile string, handler func(util.ServerConn)) error } func NewClient(tunnel string, password string) (Client, error) { diff --git a/tunnel/websocket/client.go b/tunnel/websocket/client.go index eb0b231..d37dc28 100644 --- a/tunnel/websocket/client.go +++ b/tunnel/websocket/client.go @@ -3,9 +3,9 @@ package websocket import ( "context" "crypto/tls" - "io" "net/url" + "github.com/isayme/tox/util" "golang.org/x/net/websocket" ) @@ -41,10 +41,18 @@ func NewClient(tunnel string, password string) (*Client, error) { }, nil } -func (t *Client) Connect(ctx context.Context) (io.ReadWriteCloser, error) { +func (t *Client) Connect(ctx context.Context) (util.LocalConn, error) { ws, err := websocket.DialConfig(t.config) if err != nil { return nil, err } - return ws, nil + return &wsLocalConn{Conn: ws}, nil +} + +type wsLocalConn struct { + *websocket.Conn +} + +func (conn *wsLocalConn) CloseWrite() error { + return conn.Conn.Close() } diff --git a/tunnel/websocket/server.go b/tunnel/websocket/server.go index d12d42e..6d449ad 100644 --- a/tunnel/websocket/server.go +++ b/tunnel/websocket/server.go @@ -2,17 +2,17 @@ package websocket import ( "fmt" - "io" "net/http" "net/url" + "github.com/isayme/tox/util" "golang.org/x/net/websocket" ) type Server struct { tunnel string password string - handler func(io.ReadWriter) + handler func(util.ServerConn) } func NewServer(tunnel string, password string) (*Server, error) { @@ -22,7 +22,7 @@ func NewServer(tunnel string, password string) (*Server, error) { }, nil } -func (s *Server) ListenAndServe(handler func(io.ReadWriter)) error { +func (s *Server) ListenAndServe(handler func(util.ServerConn)) error { URL, err := url.Parse(s.tunnel) if err != nil { return err @@ -40,7 +40,7 @@ func (s *Server) ListenAndServe(handler func(io.ReadWriter)) error { return http.ListenAndServe(addr, nil) } -func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(io.ReadWriter)) error { +func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(util.ServerConn)) error { URL, err := url.Parse(s.tunnel) if err != nil { return err @@ -68,12 +68,12 @@ func (s *Server) handshakeWebsocket(config *websocket.Config, req *http.Request) } func (s *Server) handleWebsocket(ws *websocket.Conn) { - defer ws.Close() + // defer ws.Close() token := ws.Request().Header.Get("token") if token != s.password { return } - s.handler(ws) + s.handler(&wsLocalConn{Conn: ws}) } diff --git a/util/conn.go b/util/conn.go new file mode 100644 index 0000000..2119214 --- /dev/null +++ b/util/conn.go @@ -0,0 +1,16 @@ +package util + +import "io" + +type LocalConn interface { + io.Reader + io.Writer + io.Closer + CloseWrite() error +} + +type ServerConn interface { + io.Reader + io.Writer + CloseWrite() error +}