Skip to content

Commit

Permalink
support ssh fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Oct 21, 2022
1 parent 165aea6 commit 8fcf198
Show file tree
Hide file tree
Showing 14 changed files with 633 additions and 372 deletions.
223 changes: 114 additions & 109 deletions client.go → client/client.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package client

import (
"crypto/tls"
Expand All @@ -11,22 +11,16 @@ import (
"strings"
"time"

"github.com/wwqgtxx/wstunnel/common"
"github.com/wwqgtxx/wstunnel/config"
"github.com/wwqgtxx/wstunnel/tunnel"
"github.com/wwqgtxx/wstunnel/utils"

"github.com/gorilla/websocket"
)

var PortToClient = make(map[string]Client)

type Client interface {
ClientImpl
Start()
Addr() string
GetClientImpl() ClientImpl
SetClientImpl(impl ClientImpl)
GetServerWSPath() string
}

type client struct {
ClientImpl
common.ClientImpl
bindAddress string
serverWSPath string
}
Expand Down Expand Up @@ -55,27 +49,18 @@ func (c *client) Addr() string {
return c.bindAddress
}

func (c *client) GetClientImpl() ClientImpl {
func (c *client) GetClientImpl() common.ClientImpl {
return c.ClientImpl
}

func (c *client) SetClientImpl(impl ClientImpl) {
func (c *client) SetClientImpl(impl common.ClientImpl) {
c.ClientImpl = impl
}

func (c *client) GetServerWSPath() string {
return c.serverWSPath
}

type ClientImpl interface {
Target() string
Proxy() string
Handle(tcp net.Conn)
Dial(edBuf []byte, inHeader http.Header) (io.Closer, error)
ToRawConn(conn io.Closer) net.Conn
Tunnel(tcp net.Conn, conn io.Closer)
}

type wsClientImpl struct {
header http.Header
wsUrl string
Expand All @@ -86,7 +71,7 @@ type wsClientImpl struct {

type tcpClientImpl struct {
targetAddress string
netDial netDialerFunc
netDial NetDialerFunc
proxy string
}

Expand All @@ -101,7 +86,7 @@ func (c *wsClientImpl) Proxy() string {
func (c *wsClientImpl) Handle(tcp net.Conn) {
defer tcp.Close()
log.Println("Incoming --> ", tcp.RemoteAddr(), " --> ", c.Target(), c.Proxy())
header, edBuf, err := encodeXray0rtt(tcp, c)
header, edBuf, err := utils.EncodeXray0rtt(tcp, c.ed)
if err != nil {
log.Println(err)
return
Expand Down Expand Up @@ -155,7 +140,7 @@ func (c *wsClientImpl) ToRawConn(conn io.Closer) net.Conn {
}

func (c *wsClientImpl) Tunnel(tcp net.Conn, conn io.Closer) {
TunnelTcpWs(tcp, conn.(*websocket.Conn))
tunnel.TunnelTcpWs(tcp, conn.(*websocket.Conn))
}

func (c *tcpClientImpl) Target() string {
Expand Down Expand Up @@ -191,15 +176,29 @@ func (c *tcpClientImpl) ToRawConn(conn io.Closer) net.Conn {
}

func (c *tcpClientImpl) Tunnel(tcp net.Conn, conn io.Closer) {
TunnelTcpTcp(tcp, conn.(net.Conn))
tunnel.TunnelTcpTcp(tcp, conn.(net.Conn))
}

func BuildClient(config ClientConfig) {
var cImpl ClientImpl
var proxyUrl *url.URL
var proxyStr string
if len(config.Proxy) > 0 {
u, err := url.Parse(config.Proxy)
func BuildClient(config config.ClientConfig) {
_, port, err := net.SplitHostPort(config.BindAddress)
if err != nil {
log.Println(err)
}

serverWSPath := strings.ReplaceAll(config.ServerWSPath, "{port}", port)

c := &client{
ClientImpl: NewClientImpl(config),
bindAddress: config.BindAddress,
serverWSPath: serverWSPath,
}

common.PortToClient[port] = c
}

func parseProxy(proxyString string) (proxyUrl *url.URL, proxyStr string) {
if len(proxyString) > 0 {
u, err := url.Parse(proxyString)
if err != nil {
log.Println(err)
}
Expand All @@ -209,110 +208,112 @@ func BuildClient(config ClientConfig) {
ru.User = nil
proxyStr = ru.String()
}
return
}

func NewClientImpl(config config.ClientConfig) common.ClientImpl {
if len(config.TargetAddress) > 0 {
var netDial netDialerFunc
tcpDialer := &net.Dialer{
Timeout: 45 * time.Second,
}
netDial = tcpDialer.Dial
return NewTcpClientImpl(config)
} else {
return NewWsClientImpl(config)
}
}

proxyDialer := proxy_FromEnvironment()
if proxyUrl != nil {
dialer, err := proxy_FromURL(proxyUrl, netDial)
if err != nil {
log.Println(err)
} else {
proxyDialer = dialer
}
}
if proxyDialer != proxy_Direct {
netDial = proxyDialer.Dial
}
func NewTcpClientImpl(config config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(config.Proxy)

cImpl = &tcpClientImpl{
targetAddress: config.TargetAddress,
netDial: netDial,
proxy: proxyStr,
}
} else {
proxy := http.ProxyFromEnvironment
if proxyUrl != nil {
proxy = http.ProxyURL(proxyUrl)
}
var netDial NetDialerFunc
tcpDialer := &net.Dialer{
Timeout: 45 * time.Second,
}
netDial = tcpDialer.Dial

header := http.Header{}
if len(config.WSHeaders) != 0 {
for key, value := range config.WSHeaders {
header.Add(key, value)
}
}
wsDialer := &websocket.Dialer{
Proxy: proxy,
HandshakeTimeout: 45 * time.Second,
ReadBufferSize: BufSize,
WriteBufferSize: BufSize,
WriteBufferPool: WriteBufferPool,
}
wsDialer.TLSClientConfig = &tls.Config{
ServerName: config.ServerName,
InsecureSkipVerify: config.SkipCertVerify,
}
var ed uint32
if u, err := url.Parse(config.WSUrl); err == nil {
if q := u.Query(); q.Get("ed") != "" {
Ed, _ := strconv.Atoi(q.Get("ed"))
ed = uint32(Ed)
q.Del("ed")
u.RawQuery = q.Encode()
config.WSUrl = u.String()
}
}
cImpl = &wsClientImpl{
header: header,
wsUrl: config.WSUrl,
wsDialer: wsDialer,
ed: ed,
proxy: proxyStr,
proxyDialer := proxy_FromEnvironment()
if proxyUrl != nil {
dialer, err := proxy_FromURL(proxyUrl, netDial)
if err != nil {
log.Println(err)
} else {
proxyDialer = dialer
}
}
_, port, err := net.SplitHostPort(config.BindAddress)
if err != nil {
log.Println(err)
if proxyDialer != proxy_Direct {
netDial = proxyDialer.Dial
}

serverWSPath := strings.ReplaceAll(config.ServerWSPath, "{port}", port)
return &tcpClientImpl{
targetAddress: config.TargetAddress,
netDial: netDial,
proxy: proxyStr,
}
}

c := &client{
ClientImpl: cImpl,
bindAddress: config.BindAddress,
serverWSPath: serverWSPath,
func NewWsClientImpl(config config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(config.Proxy)

proxy := http.ProxyFromEnvironment
if proxyUrl != nil {
proxy = http.ProxyURL(proxyUrl)
}

PortToClient[port] = c
header := http.Header{}
if len(config.WSHeaders) != 0 {
for key, value := range config.WSHeaders {
header.Add(key, value)
}
}
wsDialer := &websocket.Dialer{
Proxy: proxy,
HandshakeTimeout: 45 * time.Second,
ReadBufferSize: tunnel.BufSize,
WriteBufferSize: tunnel.BufSize,
WriteBufferPool: tunnel.WriteBufferPool,
}
wsDialer.TLSClientConfig = &tls.Config{
ServerName: config.ServerName,
InsecureSkipVerify: config.SkipCertVerify,
}
var ed uint32
if u, err := url.Parse(config.WSUrl); err == nil {
if q := u.Query(); q.Get("ed") != "" {
Ed, _ := strconv.Atoi(q.Get("ed"))
ed = uint32(Ed)
q.Del("ed")
u.RawQuery = q.Encode()
config.WSUrl = u.String()
}
}
return &wsClientImpl{
header: header,
wsUrl: config.WSUrl,
wsDialer: wsDialer,
ed: ed,
proxy: proxyStr,
}
}

func StartClients() {
for clientPort, client := range PortToClient {
for clientPort, client := range common.PortToClient {
if !strings.HasPrefix(client.Target(), "ws") {
host, port, err := net.SplitHostPort(client.Target())
if err != nil {
log.Println(err)
}

if host == "127.0.0.1" || host == "localhost" {
if _server, ok := PortToServer[port]; ok {
if _server, ok := common.PortToServer[port]; ok {
log.Println("Short circuit replace (",
client.Addr(), "<->", client.Target(),
") to ( [Server]",
client.Addr(),
")")
server := _server.CloneWithNewAddress(client.Addr())
PortToServer[clientPort] = server
delete(PortToClient, clientPort) //It is safe in Golang!!!
newServer := _server.CloneWithNewAddress(client.Addr())
common.PortToServer[clientPort] = newServer
delete(common.PortToClient, clientPort) //It is safe in Golang!!!
continue
}

if _client, ok := PortToClient[port]; ok {
if _client, ok := common.PortToClient[port]; ok {
log.Println("Short circuit replace (",
client.Addr(), "<->", client.Target(),
") to (",
Expand All @@ -325,3 +326,7 @@ func StartClients() {
client.Start()
}
}

func init() {
common.NewClientImpl = NewClientImpl
}
14 changes: 7 additions & 7 deletions proxy.go → client/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main
package client

import (
"bufio"
Expand All @@ -14,18 +14,18 @@ import (
"strings"
)

type netDialerFunc func(network, addr string) (net.Conn, error)

func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr)
}

func init() {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
})
}

type NetDialerFunc func(network, addr string) (net.Conn, error)

func (fn NetDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr)
}

func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host
hostNoPort = u.Host
Expand Down
2 changes: 1 addition & 1 deletion x_net_proxy.go → client/x_net_proxy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 8fcf198

Please sign in to comment.