diff --git a/models/server.go b/models/server.go index fd50e05..7671ca1 100644 --- a/models/server.go +++ b/models/server.go @@ -89,8 +89,13 @@ func (p *Proxy) setDialer(requestContext RequestContext, isClearText bool) (Exit network = "tcp6" format = `[%s]:0` } + backendAddress := backend + if strings.Contains(backendAddress, "@") == true { + parsedURL, _ := url.Parse(backend) + backendAddress = fmt.Sprintf(`%s:%s`, parsedURL.Host, parsedURL.Port()) + } - addr, err := net.ResolveTCPAddr(network, fmt.Sprintf(format, backend)) + addr, err := net.ResolveTCPAddr(network, fmt.Sprintf(format, backendAddress)) if err != nil { log.Trace().Err(err).Str("backend", backend).Msg("Resolve") } @@ -118,7 +123,7 @@ func (p *Proxy) isInWhitelist(requestAddress string) bool { func (p *Proxy) handleRequest(responseWriter http.ResponseWriter, request *http.Request) { defer func() { - // Delete hop by hop headers + //Delete hop by hop headers for _, v := range []string{ "Proxy-Connection", "Proxy-Authorization", @@ -175,15 +180,28 @@ func (p *Proxy) handleHTTP(responseWriter http.ResponseWriter, request *http.Req DialContext(context context.Context, network, address string) (net.Conn, error) }).DialContext, } - if p.IsUpstream { u, err := url.Parse("http://" + exitNode.Upstream) if err != nil { log.Err(err).Str("upstream", exitNode.Upstream).Msg("error parsing upstream") return } + if credentials := u.User.String(); credentials != "" { + request.Header.Set("Proxy-Authorization", fmt.Sprintf("Basic %v", b64.StdEncoding.EncodeToString([]byte(credentials)))) + } else { + for _, v := range []string{ + "Proxy-Connection", + "Proxy-Authorization", + "Proxy-Authenticate", + "Te", + "Trailers", + } { + request.Header.Del(v) + } + } transport.Proxy = http.ProxyURL(u) } + response, err := transport.RoundTrip(request) if err != nil { return @@ -191,7 +209,6 @@ func (p *Proxy) handleHTTP(responseWriter http.ResponseWriter, request *http.Req defer response.Body.Close() copyHeader(responseWriter.Header(), response.Header) responseWriter.WriteHeader(response.StatusCode) - bytesTransferred, _ := io.Copy(responseWriter, response.Body) go func() { p.LogPayload(MetricPayload{ @@ -272,11 +289,9 @@ func (p *Proxy) handleTunnel(responseWriter http.ResponseWriter, request *http.R sourceConnection, _, err := hijacker.Hijack() if err != nil { - return - } - - if err != nil { - _ = sourceConnection.Close() + if sourceConnection != nil { + _ = sourceConnection.Close() + } return } _, _ = sourceConnection.Write([]byte(HTTP200))