diff --git a/.gitignore b/.gitignore index 60fbd5b..79421d2 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,5 @@ Releases # Dependency directories (remove the comment below to include it) # vendor/ + +SNIProxy \ No newline at end of file diff --git a/go.mod b/go.mod index 5ed2ca4..9d6db06 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module github.com/XIU2/SNIProxy go 1.18 require ( - golang.org/x/net v0.0.0-20220812174116-3211cb980234 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect + golang.org/x/net v0.0.0-20220812174116-3211cb980234 + gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index b7adba5..53a9dff 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ golang.org/x/net v0.0.0-20220812174116-3211cb980234 h1:RDqmgfe7SvlMWoqC3xwQ2blLO3fcWcxMa3eBLRdRW7E= golang.org/x/net v0.0.0-20220812174116-3211cb980234/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/main.go b/main.go index 71050a1..4ed468d 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,9 @@ import ( "syscall" "time" + // "net/http" + // _ "net/http/pprof" + "gopkg.in/yaml.v2" ) @@ -66,6 +69,12 @@ https://github.com/XIU2/SNIProxy } } +// func webPprof() { +// if err := http.ListenAndServe(":6060", nil); err != nil { +// serviceLogger(fmt.Sprintf("启动 pprof 服务失败: %v", err), 31, false) +// } +// } + func main() { data, err := os.ReadFile(ConfigFilePath) // 读取配置文件 if err != nil { @@ -87,6 +96,7 @@ func main() { serviceLogger(fmt.Sprintf("前置代理: %v", cfg.EnableSocks), 32, false) serviceLogger(fmt.Sprintf("任意域名: %v", cfg.AllowAllHosts), 32, false) + // go webPprof() // 启动 pprof 服务 startSniProxy() // 启动 SNI Proxy } @@ -238,35 +248,52 @@ func forward(conn net.Conn, data []byte, dst string, raddr string) { serviceLogger(fmt.Sprintf("无法传输到后端, %v", err), 31, false) return } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - conChk := make(chan int) - go ioReflector(backend, conn, false, conChk, raddr, dst) - go ioReflector(conn, backend, true, conChk, raddr, dst) + conChk := make(chan struct{}) + go ioReflector(ctx, backend, conn, false, conChk, raddr, dst) + go ioReflector(ctx, conn, backend, true, conChk, raddr, dst) <-conChk + // 取消上下文,通知另一个 ioReflector 退出 + cancel() } -// ioReflector 函数接收一个 io.WriteCloser 类型的写入对象 dst、一个 io.Reader 类型的读取对象 src、一个 bool 类型的 isToClient、一个 chan int 类型的 conChk,以及两个字符串类型的 raddr 和 dsts +// ioReflector 函数接收一个 io.WriteCloser 类型的写入对象 dst、一个 io.Reader 类型的读取对象 src、一个 bool 类型的 isToClient、一个 chan struct{} 类型的 conChk,以及两个字符串类型的 raddr 和 dsts // 该函数使用 io.Copy 函数将 src 中读取到的数据流复制到 dst 中,然后将转发的字节数写入日志 // 最后,该函数关闭 dst 连接,并向 conChk 通道发送一个信号以表示连接已关闭。 -func ioReflector(dst io.WriteCloser, src io.Reader, isToClient bool, conChk chan int, raddr string, dsts string) { +func ioReflector(ctx context.Context, dst io.WriteCloser, src io.Reader, isToClient bool, conChk chan struct{}, raddr string, dsts string) { // 将 IO 流反映到另一个 defer onDisconnect(dst, conChk) - written, _ := io.Copy(dst, src) - if isToClient { - serviceLogger(fmt.Sprintf("[%v] -> [%v] %d bytes", dsts, raddr, written), 33, true) - } else { - serviceLogger(fmt.Sprintf("[%v] -> [%v] %d bytes", raddr, dsts, written), 33, true) + + done := make(chan struct{}) + go func() { + written, _ := io.Copy(dst, src) + if isToClient { + serviceLogger(fmt.Sprintf("[%v] -> [%v] %d bytes", dsts, raddr, written), 33, true) + } else { + serviceLogger(fmt.Sprintf("[%v] -> [%v] %d bytes", raddr, dsts, written), 33, true) + } + close(done) + }() + + select { + case <-ctx.Done(): + // 上下文取消,退出 + case <-done: + // 复制完成,退出 } - dst.Close() - conChk <- 1 } -// onDisconnect 函数接收一个 io.WriteCloser 类型的写入对象 dst 和一个 chan int 类型的 conChk +// onDisconnect 函数接收一个 io.WriteCloser 类型的写入对象 dst 和一个 chan struct{} 类型的 conChk // 该函数在 dst 连接关闭时被调用,并向 conChk 通道发送一个信号以表示连接已关闭 -func onDisconnect(dst io.WriteCloser, conChk chan int) { +func onDisconnect(dst io.WriteCloser, conChk chan struct{}) { // 关闭时 -> 强制断开另一对连接 dst.Close() - conChk <- 1 + select { + case conChk <- struct{}{}: + default: + } } // 解析 Client Hello 消息