Skip to content

Commit

Permalink
rename doh.Transport to doh.Resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
jyyi1 committed Feb 21, 2024
1 parent c3a6b5f commit 72ef2cd
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 134 deletions.
6 changes: 3 additions & 3 deletions Android/app/src/go/backend/doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (

// DoHServer represents a DNS-over-HTTPS server.
type DoHServer struct {
tspt doh.Transport
r doh.Resolver
}

// NewDoHServer creates a DoHServer that connects to the specified DoH server.
Expand All @@ -47,7 +47,7 @@ func NewDoHServer(
ips = strings.Split(ipsStr, ",")
}
dialer := protect.MakeDialer(protector)
t, err := doh.NewTransport(url, ips, dialer, nil, makeInternalDoHListener(listener))
t, err := doh.NewResolver(url, ips, dialer, nil, makeInternalDoHListener(listener))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -75,7 +75,7 @@ var dohQuery = []byte{
//
// If the server responds correctly, the function returns nil. Otherwise, the function returns an error.
func (s *DoHServer) Probe() error {
resp, err := s.tspt.Query(context.Background(), dohQuery)
resp, err := s.r.Query(context.Background(), dohQuery)
if err != nil {
return fmt.Errorf("failed to send query: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions Android/app/src/go/backend/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type Session struct {
*intra.Tunnel
}

func (s *Session) SetDoHServer(svr *DoHServer) { s.SetDNS(svr.tspt) }
func (s *Session) SetDoHServer(svr *DoHServer) { s.SetDNS(svr.r) }

// ConnectSession reads packets from a TUN device and applies the Intra routing
// rules. Currently, this only consists of redirecting DNS packets to a specified
Expand Down Expand Up @@ -64,7 +64,7 @@ func ConnectSession(
if dohdns == nil {
return nil, errors.New("dohdns must not be nil")
}
t, err := intra.NewTunnel(fakedns, dohdns.tspt, tun, protector, listener)
t, err := intra.NewTunnel(fakedns, dohdns.r, tun, protector, listener)
if err != nil {
return nil, err
}
Expand Down
93 changes: 46 additions & 47 deletions Android/app/src/go/doh/doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,22 @@ type Listener interface {
OnResponse(Token, *Summary)
}

// Transport represents a DNS query transport.
type Transport interface {
// Query sends a DNS query represented by q (including ID) to this DoH server
// Resolver represents a DNS-over-HTTPS (DoH) resolver.
type Resolver interface {
// Query sends a DNS query represented by q (including ID) to this DoH resolver
// (located at GetURL) using the provided context, and returns the correponding
//
// A non-nil error will be returned if no response was received from the DoH server,
// A non-nil error will be returned if no response was received from the DoH resolver,
// the error may also be accompanied by a SERVFAIL response if appropriate.
Query(ctx context.Context, q []byte) ([]byte, error)

// Return the server URL used to initialize this transport.
// Return the server URL used to initialize this DoH resolver.
GetURL() string
}

// TODO: Keep a context here so that queries can be canceled.
type transport struct {
Transport
type resolver struct {
Resolver
url string
hostname string
port int
Expand All @@ -108,7 +108,7 @@ type transport struct {
// Wait up to three seconds for the TCP handshake to complete.
const tcpTimeout time.Duration = 3 * time.Second

func (t *transport) dial(ctx context.Context, network, addr string) (net.Conn, error) {
func (r *resolver) dial(ctx context.Context, network, addr string) (net.Conn, error) {
logging.Debug.Printf("Dialing %s\n", addr)
domain, portStr, err := net.SplitHostPort(addr)
if err != nil {
Expand All @@ -125,11 +125,11 @@ func (t *transport) dial(ctx context.Context, network, addr string) (net.Conn, e

// TODO: Improve IP fallback strategy with parallelism and Happy Eyeballs.
var conn net.Conn
ips := t.ips.Get(domain)
ips := r.ips.Get(domain)
confirmed := ips.Confirmed()
if confirmed != nil {
logging.Debug.Printf("Trying confirmed IP %s for addr %s\n", confirmed.String(), addr)
if conn, err = split.DialWithSplitRetry(ctx, t.dialer, tcpaddr(confirmed), nil); err == nil {
if conn, err = split.DialWithSplitRetry(ctx, r.dialer, tcpaddr(confirmed), nil); err == nil {
logging.Info.Printf("Confirmed IP %s worked\n", confirmed.String())
return conn, nil
}
Expand All @@ -143,29 +143,29 @@ func (t *transport) dial(ctx context.Context, network, addr string) (net.Conn, e
// Don't try this IP twice.
continue
}
if conn, err = split.DialWithSplitRetry(ctx, t.dialer, tcpaddr(ip), nil); err == nil {
if conn, err = split.DialWithSplitRetry(ctx, r.dialer, tcpaddr(ip), nil); err == nil {
logging.Info.Printf("Found working IP: %s\n", ip.String())
return conn, nil
}
}
return nil, err
}

// NewTransport returns a DoH DNSTransport, ready for use.
// NewResolver returns a DoH [Resolver], ready for use.
// This is a POST-only DoH implementation, so the DoH template should be a URL.
//
// `rawurl` is the DoH template in string form.
//
// `addrs` is a list of domains or IP addresses to use as fallback, if the hostname lookup fails or
// returns non-working addresses.
//
// `dialer` is the dialer that the transport will use. The transport will modify the dialer's
// `dialer` is the dialer that the [Resolver] will use. The [Resolver] will modify the dialer's
// timeout but will not mutate it otherwise.
//
// `auth` will provide a client certificate if required by the TLS server.
//
// `listener` will receive the status of each DNS query when it is complete.
func NewTransport(rawurl string, addrs []string, dialer *net.Dialer, auth ClientAuth, listener Listener) (Transport, error) {
func NewResolver(rawurl string, addrs []string, dialer *net.Dialer, auth ClientAuth, listener Listener) (Resolver, error) {
if dialer == nil {
dialer = &net.Dialer{}
}
Expand All @@ -188,7 +188,7 @@ func NewTransport(rawurl string, addrs []string, dialer *net.Dialer, auth Client
port = 443
}

t := &transport{
t := &resolver{
url: rawurl,
hostname: parsedurl.Hostname(),
port: port,
Expand Down Expand Up @@ -251,15 +251,15 @@ func (e *httpError) Error() string {
// Independent of the query's success or failure, this function also returns the
// address of the server on a best-effort basis, or nil if the address could not
// be determined.
func (t *transport) doQuery(ctx context.Context, q []byte) (response []byte, server *net.TCPAddr, qerr *queryError) {
func (r *resolver) doQuery(ctx context.Context, q []byte) (response []byte, server *net.TCPAddr, qerr *queryError) {
if len(q) < 2 {
qerr = &queryError{BadQuery, fmt.Errorf("Query length is %d", len(q))}
return
}

t.hangoverLock.RLock()
inHangover := time.Now().Before(t.hangoverExpiration)
t.hangoverLock.RUnlock()
r.hangoverLock.RLock()
inHangover := time.Now().Before(r.hangoverExpiration)
r.hangoverLock.RUnlock()
if inHangover {
response = tryServfail(q)
qerr = &queryError{HTTPError, errors.New("Forwarder is in servfail hangover")}
Expand All @@ -276,14 +276,14 @@ func (t *transport) doQuery(ctx context.Context, q []byte) (response []byte, ser
// Zero out the query ID.
id := binary.BigEndian.Uint16(q)
binary.BigEndian.PutUint16(q, 0)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.url, bytes.NewBuffer(q))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, r.url, bytes.NewBuffer(q))
if err != nil {
qerr = &queryError{InternalError, err}
return
}

var hostname string
response, hostname, server, qerr = t.sendRequest(id, req)
response, hostname, server, qerr = r.sendRequest(id, req)

// Restore the query ID.
binary.BigEndian.PutUint16(q, id)
Expand All @@ -301,21 +301,21 @@ func (t *transport) doQuery(ctx context.Context, q []byte) (response []byte, ser

if qerr != nil {
if qerr.status != SendFailed {
t.hangoverLock.Lock()
t.hangoverExpiration = time.Now().Add(hangoverDuration)
t.hangoverLock.Unlock()
r.hangoverLock.Lock()
r.hangoverExpiration = time.Now().Add(hangoverDuration)
r.hangoverLock.Unlock()
}

response = tryServfail(q)
} else if server != nil {
// Record a working IP address for this server iff qerr is nil
t.ips.Get(hostname).Confirm(server.IP)
r.ips.Get(hostname).Confirm(server.IP)
}
return
}

func (t *transport) sendRequest(id uint16, req *http.Request) (response []byte, hostname string, server *net.TCPAddr, qerr *queryError) {
hostname = t.hostname
func (r *resolver) sendRequest(id uint16, req *http.Request) (response []byte, hostname string, server *net.TCPAddr, qerr *queryError) {
hostname = r.hostname

// The connection used for this request. If the request fails, we will close
// this socket, in case it is no longer functioning.
Expand All @@ -331,7 +331,7 @@ func (t *transport) sendRequest(id uint16, req *http.Request) (response []byte,
logging.Info.Printf("%d Query failed: %v\n", id, qerr)
if server != nil {
logging.Debug.Printf("%d Disconfirming %s\n", id, server.IP.String())
t.ips.Get(hostname).Disconfirm(server.IP)
r.ips.Get(hostname).Disconfirm(server.IP)
}
if conn != nil {
logging.Info.Printf("%d Closing failing DoH socket\n", id)
Expand Down Expand Up @@ -401,7 +401,7 @@ func (t *transport) sendRequest(id uint16, req *http.Request) (response []byte,
req.Header.Set("Accept", mimetype)
req.Header.Set("User-Agent", "Intra")
logging.Debug.Printf("%d Sending query\n", id)
httpResponse, err := t.client.Do(req)
httpResponse, err := r.client.Do(req)
if err != nil {
qerr = &queryError{SendFailed, err}
return
Expand Down Expand Up @@ -432,14 +432,14 @@ func (t *transport) sendRequest(id uint16, req *http.Request) (response []byte,
return
}

func (t *transport) Query(ctx context.Context, q []byte) ([]byte, error) {
func (r *resolver) Query(ctx context.Context, q []byte) ([]byte, error) {
var token Token
if t.listener != nil {
token = t.listener.OnQuery(t.url)
if r.listener != nil {
token = r.listener.OnQuery(r.url)
}

before := time.Now()
response, server, qerr := t.doQuery(ctx, q)
response, server, qerr := r.doQuery(ctx, q)
after := time.Now()

errIsCancel := false
Expand Down Expand Up @@ -469,14 +469,14 @@ func (t *transport) Query(ctx context.Context, q []byte) ([]byte, error) {
// Deadlock happens (both Step 1 and Step 5 are marked as synchronized)!
//
// TODO: make stop() an asynchronized function
if t.listener != nil && !errIsCancel {
if r.listener != nil && !errIsCancel {
latency := after.Sub(before)
var ip string
if server != nil {
ip = server.IP.String()
}

t.listener.OnResponse(token, &Summary{
r.listener.OnResponse(token, &Summary{
Latency: latency.Seconds(),
Query: q,
Response: response,
Expand All @@ -488,13 +488,13 @@ func (t *transport) Query(ctx context.Context, q []byte) ([]byte, error) {
return response, err
}

func (t *transport) GetURL() string {
return t.url
func (r *resolver) GetURL() string {
return r.url
}

// Perform a query using the transport, and send the response to the writer.
func forwardQuery(t Transport, q []byte, c io.Writer) error {
resp, qerr := t.Query(context.Background(), q)
// Perform a query using the Resolver, and send the response to the writer.
func forwardQuery(r Resolver, q []byte, c io.Writer) error {
resp, qerr := r.Query(context.Background(), q)
if resp == nil && qerr != nil {
return qerr
}
Expand All @@ -517,18 +517,17 @@ func forwardQuery(t Transport, q []byte, c io.Writer) error {
return qerr
}

// Perform a query using the transport, send the response to the writer,
// Perform a query using the Resolver, send the response to the writer,
// and close the writer if there was an error.
func forwardQueryAndCheck(t Transport, q []byte, c io.WriteCloser) {
if err := forwardQuery(t, q, c); err != nil {
func forwardQueryAndCheck(r Resolver, q []byte, c io.WriteCloser) {
if err := forwardQuery(r, q, c); err != nil {
logging.Warn.Printf("Query forwarding failed: %v\n", err)
c.Close()
}
}

// Accept a DNS-over-TCP socket from a stub resolver, and connect the socket
// to this DNSTransport.
func Accept(t Transport, c io.ReadWriteCloser) {
// Accept a DNS-over-TCP socket, and connect the socket to a DoH Resolver.
func Accept(r Resolver, c io.ReadWriteCloser) {
qlbuf := make([]byte, 2)
for {
n, err := c.Read(qlbuf)
Expand All @@ -555,7 +554,7 @@ func Accept(t Transport, c io.ReadWriteCloser) {
logging.Warn.Printf("Incomplete query: %d < %d\n", n, qlen)
break
}
go forwardQueryAndCheck(t, q, c)
go forwardQueryAndCheck(r, q, c)
}
// TODO: Cancel outstanding queries at this point.
c.Close()
Expand Down
Loading

0 comments on commit 72ef2cd

Please sign in to comment.