Skip to content

Commit

Permalink
most for grpc
Browse files Browse the repository at this point in the history
  • Loading branch information
isayme committed May 17, 2024
1 parent 5a56fee commit a841f5a
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 48 deletions.
7 changes: 1 addition & 6 deletions cmd/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cmd

import (
"context"
"io"
"net"
"sync"
"time"
Expand Down Expand Up @@ -86,11 +85,7 @@ func handleConnection(conn net.Conn, tc tunnel.Client) {
var n int64
n, err = util.CopyBuffer(remote, conn)
logger.Debugw("copy from client end", "n", n, "err", err)
if err != nil && err != io.EOF {
once.Do(func() {
remote.Close()
})
}
remote.CloseWrite()
}()

go func() {
Expand Down
9 changes: 5 additions & 4 deletions cmd/server.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package cmd

import (
"io"

"github.com/isayme/go-logger"
"github.com/isayme/tox/conf"
"github.com/isayme/tox/socks5"
Expand Down Expand Up @@ -38,8 +36,11 @@ func startServer() {
}
}

func handler(rw io.ReadWriter) {
request := socks5.NewRequest(rw)
/**
* return when server will not send data anymore
*/
func handler(conn util.ServerConn) {
request := socks5.NewRequest(conn)
if err := request.Handle(); err != nil {
logger.Errorw("socks5 fail", "err", err)
}
Expand Down
21 changes: 11 additions & 10 deletions socks5/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ import (
)

type Request struct {
rw io.ReadWriter
rw util.ServerConn

cmd byte
atyp byte
addr string
}

func NewRequest(rw io.ReadWriter) *Request {
func NewRequest(conn util.ServerConn) *Request {
return &Request{
rw: rw,
rw: conn,
}
}

Expand Down Expand Up @@ -145,17 +145,17 @@ func (r *Request) negotiate() error {
}

func (r *Request) handleRequest() error {
conn, err := net.DialTimeout("tcp", r.addr, time.Second*5)
remote, err := net.DialTimeout("tcp", r.addr, time.Second*5)
if err != nil {
logger.Infow("net.Dial fail", "err", err, "addr", r.addr)
return err
}
defer conn.Close()
defer remote.Close()

config := conf.Get()

tcpConn, _ := conn.(*net.TCPConn)
conn = util.NewTimeoutConn(conn, time.Duration(config.Timeout)*time.Second)
remoteTcpConn, _ := remote.(*net.TCPConn)
remote = util.NewTimeoutConn(remote, time.Duration(config.Timeout)*time.Second)

logger.Infow("connect ok", "addr", r.addr)

Expand All @@ -167,7 +167,8 @@ func (r *Request) handleRequest() error {

var err error
var n int64
n, err = util.CopyBuffer(r.rw, conn)
n, err = util.CopyBuffer(r.rw, remote)
r.rw.CloseWrite()
logger.Debugw("copy from remote end", "n", n, "err", err)
}()

Expand All @@ -176,9 +177,9 @@ func (r *Request) handleRequest() error {

var err error
var n int64
n, err = util.CopyBuffer(conn, r.rw)
n, err = util.CopyBuffer(remote, r.rw)
logger.Debugw("copy from client end", "n", n, "err", err)
tcpConn.CloseWrite()
remoteTcpConn.CloseWrite()
}()

wg.Wait()
Expand Down
11 changes: 8 additions & 3 deletions tunnel/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"bytes"
"context"
"crypto/tls"
"io"
"net/url"

pool "github.com/isayme/go-grpcpool"
"github.com/isayme/go-logger"
"github.com/isayme/tox/proto"
"github.com/isayme/tox/util"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
Expand Down Expand Up @@ -82,7 +82,7 @@ func NewClient(tunnel string, password string) (*Client, error) {
}, nil
}

func (t *Client) Connect(ctx context.Context) (io.ReadWriteCloser, error) {
func (t *Client) Connect(ctx context.Context) (util.LocalConn, error) {
conn, err := t.p.Get()
if err != nil {
return nil, err
Expand Down Expand Up @@ -141,5 +141,10 @@ func (rw *clientReadWriter) Write(p []byte) (int, error) {

func (rw *clientReadWriter) Close() error {
rw.c.CloseSend()
return rw.conn.Close()
rw.conn.Close()
return rw.conn.Value().Close()
}

func (rw *clientReadWriter) CloseWrite() error {
return rw.c.CloseSend()
}
12 changes: 8 additions & 4 deletions tunnel/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ package grpc
import (
"bytes"
"fmt"
"io"
"net"
"net/url"

"github.com/isayme/tox/proto"
"github.com/isayme/tox/util"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

type Server struct {
proto.UnimplementedTunnelServer

handler func(io.ReadWriter)
handler func(util.ServerConn)
tunnel string
key string
}
Expand All @@ -27,7 +27,7 @@ func NewServer(tunnel string, password string) (*Server, error) {
}, nil
}

func (s *Server) ListenAndServe(handler func(io.ReadWriter)) error {
func (s *Server) ListenAndServe(handler func(util.ServerConn)) error {
URL, err := url.Parse(s.tunnel)
if err != nil {
return err
Expand Down Expand Up @@ -56,7 +56,7 @@ func (s *Server) ListenAndServe(handler func(io.ReadWriter)) error {
return grpcs.Serve(l)
}

func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(io.ReadWriter)) error {
func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(util.ServerConn)) error {
URL, err := url.Parse(s.tunnel)
if err != nil {
return err
Expand Down Expand Up @@ -140,3 +140,7 @@ func (rw *serverReadWriter) Write(p []byte) (int, error) {
}
return len(p), nil
}

func (rw *serverReadWriter) CloseWrite() error {
return nil
}
16 changes: 13 additions & 3 deletions tunnel/h2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"

"github.com/isayme/tox/util"
"github.com/posener/h2conn"
)

Expand Down Expand Up @@ -56,7 +56,7 @@ func NewClient(tunnel string, password string) (*Client, error) {
}, nil
}

func (t *Client) Connect(ctx context.Context) (io.ReadWriteCloser, error) {
func (t *Client) Connect(ctx context.Context) (util.LocalConn, error) {
remote, resp, err := t.h2Client.Connect(ctx, t.tunnel)
if err != nil {
return nil, err
Expand All @@ -67,5 +67,15 @@ func (t *Client) Connect(ctx context.Context) (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("h2: bad status code: %d", resp.StatusCode)
}

return remote, nil
return &h2LocalConn{
Conn: remote,
}, nil
}

type h2LocalConn struct {
*h2conn.Conn
}

func (conn *h2LocalConn) CloseWrite() error {
return conn.Close()
}
10 changes: 5 additions & 5 deletions tunnel/h2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package h2

import (
"fmt"
"io"
"net/http"
"net/url"

"github.com/isayme/go-logger"
"github.com/isayme/tox/util"
"github.com/posener/h2conn"
)

Expand All @@ -22,11 +22,11 @@ func NewServer(tunnel string, password string) (*Server, error) {
}, nil
}

func (s *Server) ListenAndServe(handler func(io.ReadWriter)) error {
func (s *Server) ListenAndServe(handler func(util.ServerConn)) error {
return fmt.Errorf("tls required for http2 protocol")
}

func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(io.ReadWriter)) error {
func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(util.ServerConn)) error {
URL, err := url.Parse(s.tunnel)
if err != nil {
return err
Expand All @@ -44,9 +44,9 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(io.Rea
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
defer conn.Close()
// defer conn.Close()

handler(conn)
handler(&h2LocalConn{Conn: conn})
})

addr := fmt.Sprintf(":%s", URL.Port())
Expand Down
7 changes: 3 additions & 4 deletions tunnel/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package tunnel
import (
"context"
"fmt"
"io"
"net/url"

"github.com/isayme/go-logger"
Expand All @@ -14,12 +13,12 @@ import (
)

type Client interface {
Connect(context.Context) (io.ReadWriteCloser, error)
Connect(context.Context) (util.LocalConn, error)
}

type Server interface {
ListenAndServe(handler func(io.ReadWriter)) error
ListenAndServeTLS(certFile, keyFile string, handler func(io.ReadWriter)) error
ListenAndServe(handler func(util.ServerConn)) error
ListenAndServeTLS(certFile, keyFile string, handler func(util.ServerConn)) error
}

func NewClient(tunnel string, password string) (Client, error) {
Expand Down
14 changes: 11 additions & 3 deletions tunnel/websocket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package websocket
import (
"context"
"crypto/tls"
"io"
"net/url"

"github.com/isayme/tox/util"
"golang.org/x/net/websocket"
)

Expand Down Expand Up @@ -41,10 +41,18 @@ func NewClient(tunnel string, password string) (*Client, error) {
}, nil
}

func (t *Client) Connect(ctx context.Context) (io.ReadWriteCloser, error) {
func (t *Client) Connect(ctx context.Context) (util.LocalConn, error) {
ws, err := websocket.DialConfig(t.config)
if err != nil {
return nil, err
}
return ws, nil
return &wsLocalConn{Conn: ws}, nil
}

type wsLocalConn struct {
*websocket.Conn
}

func (conn *wsLocalConn) CloseWrite() error {
return conn.Conn.Close()
}
12 changes: 6 additions & 6 deletions tunnel/websocket/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ package websocket

import (
"fmt"
"io"
"net/http"
"net/url"

"github.com/isayme/tox/util"
"golang.org/x/net/websocket"
)

type Server struct {
tunnel string
password string
handler func(io.ReadWriter)
handler func(util.ServerConn)
}

func NewServer(tunnel string, password string) (*Server, error) {
Expand All @@ -22,7 +22,7 @@ func NewServer(tunnel string, password string) (*Server, error) {
}, nil
}

func (s *Server) ListenAndServe(handler func(io.ReadWriter)) error {
func (s *Server) ListenAndServe(handler func(util.ServerConn)) error {
URL, err := url.Parse(s.tunnel)
if err != nil {
return err
Expand All @@ -40,7 +40,7 @@ func (s *Server) ListenAndServe(handler func(io.ReadWriter)) error {
return http.ListenAndServe(addr, nil)
}

func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(io.ReadWriter)) error {
func (s *Server) ListenAndServeTLS(certFile, keyFile string, handler func(util.ServerConn)) error {
URL, err := url.Parse(s.tunnel)
if err != nil {
return err
Expand Down Expand Up @@ -68,12 +68,12 @@ func (s *Server) handshakeWebsocket(config *websocket.Config, req *http.Request)
}

func (s *Server) handleWebsocket(ws *websocket.Conn) {
defer ws.Close()
// defer ws.Close()

token := ws.Request().Header.Get("token")
if token != s.password {
return
}

s.handler(ws)
s.handler(&wsLocalConn{Conn: ws})
}
16 changes: 16 additions & 0 deletions util/conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package util

import "io"

type LocalConn interface {
io.Reader
io.Writer
io.Closer
CloseWrite() error
}

type ServerConn interface {
io.Reader
io.Writer
CloseWrite() error
}

0 comments on commit a841f5a

Please sign in to comment.