Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
arloor committed Sep 27, 2024
1 parent 177082b commit 7c33fdd
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 24 deletions.
22 changes: 16 additions & 6 deletions internal/server/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,13 @@ func serveHijack(w http.ResponseWriter, targetConn net.Conn, clientAddr string,
if bufReader != nil {
// snippet borrowed from `proxy` plugin
if n := bufReader.Reader.Buffered(); n > 0 {
rbuf, err := bufReader.Reader.Peek(n)
if err != nil {
return http.StatusBadGateway, err
rbuf, peekErr := bufReader.Reader.Peek(n)
if peekErr != nil {
return http.StatusBadGateway, peekErr
}
if _, writeErr := targetConn.Write(rbuf); writeErr != nil {
return http.StatusBadGateway, writeErr
}
targetConn.Write(rbuf)
}
}
// Since we hijacked the connection, we lost the ability to write and flush headers via w.
Expand Down Expand Up @@ -134,7 +136,11 @@ var bufferPool = &sync.Pool{New: func() interface{} {
func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Writer, clientAddr string, hostPort string, username string) error {
stream := func(w io.Writer, r io.Reader) error {
// copy bytes from r to w
buf := bufferPool.Get().([]byte)
buf, ok := bufferPool.Get().([]byte)
if !ok {
return errors.New("failed to get buffer from pool")
}
// nolint:staticcheck
defer bufferPool.Put(buf)
buf = buf[0:cap(buf)]
nw, _err := flushingIoCopy(w, r, buf)
Expand All @@ -149,7 +155,11 @@ func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Wri
return _err
}

go stream(target, clientReader)
go func() {
if err := stream(target, clientReader); err != nil {
log.Println("Error in stream:", err)
}
}()
return stream(clientWriter, target)
}

Expand Down
6 changes: 3 additions & 3 deletions internal/server/handleFunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"strings"
)

func writeIp(w http.ResponseWriter, r *http.Request) {
func writeIP(w http.ResponseWriter, r *http.Request) {

remoteAddr := r.RemoteAddr
index := strings.LastIndex(remoteAddr, ":")
Expand All @@ -25,8 +25,8 @@ func fileHandlerFunc() http.HandlerFunc {

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if config.GlobalConfig.Refer != "" && r.Header.Get("referer") != "" && !strings.Contains(r.Header.Get("referer"), config.GlobalConfig.Refer) && (strings.HasSuffix(r.URL.Path, ".html") || strings.HasSuffix(r.URL.Path, "/")) {
HttpRequst.WithLabelValues(r.Header.Get("referer"), r.URL.Path).Inc()
HttpRequst.WithLabelValues("all", "all").Inc()
ReqCount.WithLabelValues(r.Header.Get("referer"), r.URL.Path).Inc()
ReqCount.WithLabelValues("all", "all").Inc()
}

if containsDotDot(r.URL.Path) {
Expand Down
30 changes: 15 additions & 15 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
)

var ssl_cert *tls.Certificate = nil
var ssl_last_cert_update time.Time = time.Now()
var sslCert *tls.Certificate = nil
var sslLastCertUpdateTime time.Time = time.Now()

const ssl_cert_update_interval = 5 * time.Hour
const sslCertUpdateInterval = 5 * time.Hour

var (
HttpRequst = promauto.NewCounterVec(prometheus.CounterOpts{
ReqCount = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "req_from_out_total",
Help: "Number of HTTP requests received",
}, []string{"referer", "path"})
Expand All @@ -30,7 +30,7 @@ var (
)

func Serve() error {
http.HandleFunc("/ip", writeIp)
http.HandleFunc("/ip", writeIP)
http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/", fileHandlerFunc())

Expand All @@ -48,7 +48,7 @@ func Serve() error {
WriteTimeout: 31 * time.Second, // Set idle timeout
TLSConfig: &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return load_new_cert_if_need(globalConfig.Cert, globalConfig.PrivKey)
return loadNewCertIfNeed(globalConfig.Cert, globalConfig.PrivKey)
},
},
}
Expand All @@ -63,23 +63,23 @@ func Serve() error {
return <-errors
}

func load_new_cert_if_need(cert_file, privkey string) (*tls.Certificate, error) {
func loadNewCertIfNeed(certFile, privkey string) (*tls.Certificate, error) {
now := time.Now()
if ssl_cert == nil || now.Sub(ssl_last_cert_update) > ssl_cert_update_interval {
cert, err := tls.LoadX509KeyPair(cert_file, privkey)
if sslCert == nil || now.Sub(sslLastCertUpdateTime) > sslCertUpdateInterval {
cert, err := tls.LoadX509KeyPair(certFile, privkey)
if err != nil {
log.Println("Error loading certificate", err)
if ssl_cert != nil {
return ssl_cert, nil
if sslCert != nil {
return sslCert, nil
}
return nil, err
} else {
log.Println("Loaded certificate", cert_file, privkey)
log.Println("Loaded certificate", certFile, privkey)
}
ssl_cert = &cert
ssl_last_cert_update = now
sslCert = &cert
sslLastCertUpdateTime = now
return &cert, nil
} else {
return ssl_cert, nil
return sslCert, nil
}
}

0 comments on commit 7c33fdd

Please sign in to comment.