Skip to content

Commit

Permalink
[kit] fix proxy in client
Browse files Browse the repository at this point in the history
  • Loading branch information
ehsannm committed Dec 15, 2023
1 parent 4a72ead commit b135843
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 52 deletions.
36 changes: 15 additions & 21 deletions kit/stub/option.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package stub

import (
"fmt"
"io"
"strings"
"time"

"github.com/clubpay/ronykit/kit"
Expand All @@ -25,7 +23,7 @@ type config struct {
tp kit.TracePropagator

readTimeout, writeTimeout, dialTimeout time.Duration
httpProxyConfig *httpproxy.Config
proxy *httpproxy.Config
dialFunc fasthttp.DialFunc
}

Expand Down Expand Up @@ -87,29 +85,25 @@ func WithTracePropagator(tp kit.TracePropagator) Option {
// WithHTTPProxy returns an Option that sets the dialer to the provided HTTP proxy.
// example formats:
//
// http://localhost:9050
// http://username:password@localhost:9050
// https://localhost:9050
func WithHTTPProxy(url string, timeout time.Duration) Option {
// localhost:9050
// username:password@localhost:9050
// localhost:9050
func WithHTTPProxy(proxyURL string, timeout time.Duration) Option {
return func(cfg *config) {
cfg.httpProxyConfig = httpproxy.FromEnvironment()
switch {
default:
panic(fmt.Errorf("unsupported proxy scheme: %s", url))
case strings.HasPrefix(url, "https://"):
cfg.httpProxyConfig.HTTPSProxy = url
case strings.HasPrefix(url, "http://"):
cfg.httpProxyConfig.HTTPProxy = url
}

cfg.dialFunc = fasthttpproxy.FasthttpHTTPDialerTimeout(url, timeout)
cfg.proxy = httpproxy.FromEnvironment()
cfg.proxy.HTTPProxy = proxyURL
cfg.proxy.HTTPSProxy = proxyURL
cfg.dialFunc = fasthttpproxy.FasthttpHTTPDialerTimeout(proxyURL, timeout)
}
}

// WithSocksProxy returns an Option that sets the dialer to the provided SOCKS5 proxy.
// example format: socks5://localhost:9050
func WithSocksProxy(url string) Option {
// example format: localhost:9050
func WithSocksProxy(proxyURL string) Option {
return func(cfg *config) {
cfg.dialFunc = fasthttpproxy.FasthttpSocksDialer(url)
cfg.proxy = httpproxy.FromEnvironment()
cfg.proxy.HTTPProxy = proxyURL
cfg.proxy.HTTPSProxy = proxyURL
cfg.dialFunc = fasthttpproxy.FasthttpSocksDialer(proxyURL)
}
}
68 changes: 37 additions & 31 deletions kit/stub/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type Stub struct {
cfg config
r *reflector.Reflector

httpC fasthttp.Client
httpC *fasthttp.Client
}

func New(hostPort string, opts ...Option) *Stub {
Expand All @@ -37,19 +37,24 @@ func New(hostPort string, opts ...Option) *Stub {
opt(&cfg)
}

return &Stub{
cfg: cfg,
r: reflector.New(),
httpC: fasthttp.Client{
Name: cfg.name,
ReadTimeout: cfg.readTimeout,
WriteTimeout: cfg.writeTimeout,
Dial: cfg.dialFunc,
TLSConfig: &tls.Config{
InsecureSkipVerify: cfg.skipVerifyTLS, //nolint:gosec
},
httpC := &fasthttp.Client{
Name: cfg.name,
ReadTimeout: cfg.readTimeout,
WriteTimeout: cfg.writeTimeout,
TLSConfig: &tls.Config{
InsecureSkipVerify: cfg.skipVerifyTLS, //nolint:gosec
},
}

if cfg.dialFunc != nil {
httpC.Dial = cfg.dialFunc
}

return &Stub{
cfg: cfg,
r: reflector.New(),
httpC: httpC,
}
}

func HTTP(rawURL string, opts ...Option) (*RESTCtx, error) {
Expand Down Expand Up @@ -79,7 +84,7 @@ func HTTP(rawURL string, opts ...Option) (*RESTCtx, error) {

func (s *Stub) REST(opt ...RESTOption) *RESTCtx {
ctx := &RESTCtx{
c: &s.httpC,
c: s.httpC,
r: s.r,
handlers: map[int]RESTResponseHandler{},
uri: fasthttp.AcquireURI(),
Expand Down Expand Up @@ -108,28 +113,29 @@ func (s *Stub) REST(opt ...RESTOption) *RESTCtx {
}

func (s *Stub) Websocket(opts ...WebsocketOption) *WebsocketCtx {
var proxyFunc func(req *http.Request) (*url.URL, error)
if s.cfg.httpProxyConfig != nil {
fn := s.cfg.httpProxyConfig.ProxyFunc()
proxyFunc = func(req *http.Request) (*url.URL, error) {
return fn(req.URL)
defaultProxy := http.ProxyFromEnvironment
if s.cfg.proxy != nil {
defaultProxy = func(req *http.Request) (*url.URL, error) {
return s.cfg.proxy.ProxyFunc()(req.URL)
}
}

defaultDialerBuilder := func() *websocket.Dialer {
return &websocket.Dialer{
Proxy: defaultProxy,
HandshakeTimeout: s.cfg.dialTimeout,
}
}
ctx := &WebsocketCtx{
cfg: wsConfig{
autoReconnect: true,
pingTime: time.Second * 30,
dialTimeout: s.cfg.dialTimeout,
writeTimeout: s.cfg.writeTimeout,
ratelimitChan: make(chan struct{}, defaultConcurrency),
rpcInFactory: common.SimpleIncomingJSONRPC,
rpcOutFactory: common.SimpleOutgoingJSONRPC,
dialerBuilder: func() *websocket.Dialer {
return &websocket.Dialer{
Proxy: proxyFunc,
HandshakeTimeout: s.cfg.dialTimeout,
}
},
autoReconnect: true,
pingTime: time.Second * 30,
dialTimeout: s.cfg.dialTimeout,
writeTimeout: s.cfg.writeTimeout,
ratelimitChan: make(chan struct{}, defaultConcurrency),
rpcInFactory: common.SimpleIncomingJSONRPC,
rpcOutFactory: common.SimpleOutgoingJSONRPC,
dialerBuilder: defaultDialerBuilder,
tracePropagator: s.cfg.tp,
},
r: s.r,
Expand Down

0 comments on commit b135843

Please sign in to comment.