diff --git a/internal/server/connect.go b/internal/server/connect.go index ff42023..dd34085 100644 --- a/internal/server/connect.go +++ b/internal/server/connect.go @@ -100,11 +100,13 @@ func serveHijack(w http.ResponseWriter, targetConn net.Conn, clientAddr string, if bufReader != nil { // snippet borrowed from `proxy` plugin if n := bufReader.Reader.Buffered(); n > 0 { - rbuf, err := bufReader.Reader.Peek(n) - if err != nil { - return http.StatusBadGateway, err + rbuf, peekErr := bufReader.Reader.Peek(n) + if peekErr != nil { + return http.StatusBadGateway, peekErr + } + if _, writeErr := targetConn.Write(rbuf); writeErr != nil { + return http.StatusBadGateway, writeErr } - targetConn.Write(rbuf) } } // Since we hijacked the connection, we lost the ability to write and flush headers via w. @@ -134,7 +136,11 @@ var bufferPool = &sync.Pool{New: func() interface{} { func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Writer, clientAddr string, hostPort string, username string) error { stream := func(w io.Writer, r io.Reader) error { // copy bytes from r to w - buf := bufferPool.Get().([]byte) + buf, ok := bufferPool.Get().([]byte) + if !ok { + return errors.New("failed to get buffer from pool") + } + // nolint:staticcheck defer bufferPool.Put(buf) buf = buf[0:cap(buf)] nw, _err := flushingIoCopy(w, r, buf) @@ -149,7 +155,11 @@ func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Wri return _err } - go stream(target, clientReader) + go func() { + if err := stream(target, clientReader); err != nil { + log.Println("Error in stream:", err) + } + }() return stream(clientWriter, target) } diff --git a/internal/server/handleFunc.go b/internal/server/handleFunc.go index ee81032..bb2e7fb 100644 --- a/internal/server/handleFunc.go +++ b/internal/server/handleFunc.go @@ -8,7 +8,7 @@ import ( "strings" ) -func writeIp(w http.ResponseWriter, r *http.Request) { +func writeIP(w http.ResponseWriter, r *http.Request) { remoteAddr := r.RemoteAddr index := strings.LastIndex(remoteAddr, ":") @@ -25,8 +25,8 @@ func fileHandlerFunc() http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if config.GlobalConfig.Refer != "" && r.Header.Get("referer") != "" && !strings.Contains(r.Header.Get("referer"), config.GlobalConfig.Refer) && (strings.HasSuffix(r.URL.Path, ".html") || strings.HasSuffix(r.URL.Path, "/")) { - HttpRequst.WithLabelValues(r.Header.Get("referer"), r.URL.Path).Inc() - HttpRequst.WithLabelValues("all", "all").Inc() + ReqCount.WithLabelValues(r.Header.Get("referer"), r.URL.Path).Inc() + ReqCount.WithLabelValues("all", "all").Inc() } if containsDotDot(r.URL.Path) { diff --git a/internal/server/server.go b/internal/server/server.go index 3d249e0..d5b5f0a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -13,13 +13,13 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) -var ssl_cert *tls.Certificate = nil -var ssl_last_cert_update time.Time = time.Now() +var sslCert *tls.Certificate = nil +var sslLastCertUpdateTime time.Time = time.Now() -const ssl_cert_update_interval = 5 * time.Hour +const sslCertUpdateInterval = 5 * time.Hour var ( - HttpRequst = promauto.NewCounterVec(prometheus.CounterOpts{ + ReqCount = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "req_from_out_total", Help: "Number of HTTP requests received", }, []string{"referer", "path"}) @@ -30,7 +30,7 @@ var ( ) func Serve() error { - http.HandleFunc("/ip", writeIp) + http.HandleFunc("/ip", writeIP) http.Handle("/metrics", promhttp.Handler()) http.HandleFunc("/", fileHandlerFunc()) @@ -48,7 +48,7 @@ func Serve() error { WriteTimeout: 31 * time.Second, // Set idle timeout TLSConfig: &tls.Config{ GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { - return load_new_cert_if_need(globalConfig.Cert, globalConfig.PrivKey) + return loadNewCertIfNeed(globalConfig.Cert, globalConfig.PrivKey) }, }, } @@ -63,23 +63,23 @@ func Serve() error { return <-errors } -func load_new_cert_if_need(cert_file, privkey string) (*tls.Certificate, error) { +func loadNewCertIfNeed(certFile, privkey string) (*tls.Certificate, error) { now := time.Now() - if ssl_cert == nil || now.Sub(ssl_last_cert_update) > ssl_cert_update_interval { - cert, err := tls.LoadX509KeyPair(cert_file, privkey) + if sslCert == nil || now.Sub(sslLastCertUpdateTime) > sslCertUpdateInterval { + cert, err := tls.LoadX509KeyPair(certFile, privkey) if err != nil { log.Println("Error loading certificate", err) - if ssl_cert != nil { - return ssl_cert, nil + if sslCert != nil { + return sslCert, nil } return nil, err } else { - log.Println("Loaded certificate", cert_file, privkey) + log.Println("Loaded certificate", certFile, privkey) } - ssl_cert = &cert - ssl_last_cert_update = now + sslCert = &cert + sslLastCertUpdateTime = now return &cert, nil } else { - return ssl_cert, nil + return sslCert, nil } }