From 72ef2cd26fa8afb2401dee6c867c0e7e749f0ed2 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Wed, 21 Feb 2024 14:16:46 -0500 Subject: [PATCH] rename doh.Transport to doh.Resolver --- Android/app/src/go/backend/doh.go | 6 +- Android/app/src/go/backend/tunnel.go | 4 +- Android/app/src/go/doh/doh.go | 93 +++++++------- Android/app/src/go/doh/doh_test.go | 119 +++++++++--------- Android/app/src/go/intra/packet_proxy.go | 6 +- Android/app/src/go/intra/sni_reporter.go | 4 +- Android/app/src/go/intra/sni_reporter_test.go | 12 +- Android/app/src/go/intra/stream_dialer.go | 6 +- Android/app/src/go/intra/tunnel.go | 11 +- 9 files changed, 127 insertions(+), 134 deletions(-) diff --git a/Android/app/src/go/backend/doh.go b/Android/app/src/go/backend/doh.go index 83a9f46a..831719de 100644 --- a/Android/app/src/go/backend/doh.go +++ b/Android/app/src/go/backend/doh.go @@ -26,7 +26,7 @@ import ( // DoHServer represents a DNS-over-HTTPS server. type DoHServer struct { - tspt doh.Transport + r doh.Resolver } // NewDoHServer creates a DoHServer that connects to the specified DoH server. @@ -47,7 +47,7 @@ func NewDoHServer( ips = strings.Split(ipsStr, ",") } dialer := protect.MakeDialer(protector) - t, err := doh.NewTransport(url, ips, dialer, nil, makeInternalDoHListener(listener)) + t, err := doh.NewResolver(url, ips, dialer, nil, makeInternalDoHListener(listener)) if err != nil { return nil, err } @@ -75,7 +75,7 @@ var dohQuery = []byte{ // // If the server responds correctly, the function returns nil. Otherwise, the function returns an error. func (s *DoHServer) Probe() error { - resp, err := s.tspt.Query(context.Background(), dohQuery) + resp, err := s.r.Query(context.Background(), dohQuery) if err != nil { return fmt.Errorf("failed to send query: %w", err) } diff --git a/Android/app/src/go/backend/tunnel.go b/Android/app/src/go/backend/tunnel.go index f9814a2e..5c9058af 100644 --- a/Android/app/src/go/backend/tunnel.go +++ b/Android/app/src/go/backend/tunnel.go @@ -35,7 +35,7 @@ type Session struct { *intra.Tunnel } -func (s *Session) SetDoHServer(svr *DoHServer) { s.SetDNS(svr.tspt) } +func (s *Session) SetDoHServer(svr *DoHServer) { s.SetDNS(svr.r) } // ConnectSession reads packets from a TUN device and applies the Intra routing // rules. Currently, this only consists of redirecting DNS packets to a specified @@ -64,7 +64,7 @@ func ConnectSession( if dohdns == nil { return nil, errors.New("dohdns must not be nil") } - t, err := intra.NewTunnel(fakedns, dohdns.tspt, tun, protector, listener) + t, err := intra.NewTunnel(fakedns, dohdns.r, tun, protector, listener) if err != nil { return nil, err } diff --git a/Android/app/src/go/doh/doh.go b/Android/app/src/go/doh/doh.go index 1a13e270..71686267 100644 --- a/Android/app/src/go/doh/doh.go +++ b/Android/app/src/go/doh/doh.go @@ -78,22 +78,22 @@ type Listener interface { OnResponse(Token, *Summary) } -// Transport represents a DNS query transport. -type Transport interface { - // Query sends a DNS query represented by q (including ID) to this DoH server +// Resolver represents a DNS-over-HTTPS (DoH) resolver. +type Resolver interface { + // Query sends a DNS query represented by q (including ID) to this DoH resolver // (located at GetURL) using the provided context, and returns the correponding // - // A non-nil error will be returned if no response was received from the DoH server, + // A non-nil error will be returned if no response was received from the DoH resolver, // the error may also be accompanied by a SERVFAIL response if appropriate. Query(ctx context.Context, q []byte) ([]byte, error) - // Return the server URL used to initialize this transport. + // Return the server URL used to initialize this DoH resolver. GetURL() string } // TODO: Keep a context here so that queries can be canceled. -type transport struct { - Transport +type resolver struct { + Resolver url string hostname string port int @@ -108,7 +108,7 @@ type transport struct { // Wait up to three seconds for the TCP handshake to complete. const tcpTimeout time.Duration = 3 * time.Second -func (t *transport) dial(ctx context.Context, network, addr string) (net.Conn, error) { +func (r *resolver) dial(ctx context.Context, network, addr string) (net.Conn, error) { logging.Debug.Printf("Dialing %s\n", addr) domain, portStr, err := net.SplitHostPort(addr) if err != nil { @@ -125,11 +125,11 @@ func (t *transport) dial(ctx context.Context, network, addr string) (net.Conn, e // TODO: Improve IP fallback strategy with parallelism and Happy Eyeballs. var conn net.Conn - ips := t.ips.Get(domain) + ips := r.ips.Get(domain) confirmed := ips.Confirmed() if confirmed != nil { logging.Debug.Printf("Trying confirmed IP %s for addr %s\n", confirmed.String(), addr) - if conn, err = split.DialWithSplitRetry(ctx, t.dialer, tcpaddr(confirmed), nil); err == nil { + if conn, err = split.DialWithSplitRetry(ctx, r.dialer, tcpaddr(confirmed), nil); err == nil { logging.Info.Printf("Confirmed IP %s worked\n", confirmed.String()) return conn, nil } @@ -143,7 +143,7 @@ func (t *transport) dial(ctx context.Context, network, addr string) (net.Conn, e // Don't try this IP twice. continue } - if conn, err = split.DialWithSplitRetry(ctx, t.dialer, tcpaddr(ip), nil); err == nil { + if conn, err = split.DialWithSplitRetry(ctx, r.dialer, tcpaddr(ip), nil); err == nil { logging.Info.Printf("Found working IP: %s\n", ip.String()) return conn, nil } @@ -151,7 +151,7 @@ func (t *transport) dial(ctx context.Context, network, addr string) (net.Conn, e return nil, err } -// NewTransport returns a DoH DNSTransport, ready for use. +// NewResolver returns a DoH [Resolver], ready for use. // This is a POST-only DoH implementation, so the DoH template should be a URL. // // `rawurl` is the DoH template in string form. @@ -159,13 +159,13 @@ func (t *transport) dial(ctx context.Context, network, addr string) (net.Conn, e // `addrs` is a list of domains or IP addresses to use as fallback, if the hostname lookup fails or // returns non-working addresses. // -// `dialer` is the dialer that the transport will use. The transport will modify the dialer's +// `dialer` is the dialer that the [Resolver] will use. The [Resolver] will modify the dialer's // timeout but will not mutate it otherwise. // // `auth` will provide a client certificate if required by the TLS server. // // `listener` will receive the status of each DNS query when it is complete. -func NewTransport(rawurl string, addrs []string, dialer *net.Dialer, auth ClientAuth, listener Listener) (Transport, error) { +func NewResolver(rawurl string, addrs []string, dialer *net.Dialer, auth ClientAuth, listener Listener) (Resolver, error) { if dialer == nil { dialer = &net.Dialer{} } @@ -188,7 +188,7 @@ func NewTransport(rawurl string, addrs []string, dialer *net.Dialer, auth Client port = 443 } - t := &transport{ + t := &resolver{ url: rawurl, hostname: parsedurl.Hostname(), port: port, @@ -251,15 +251,15 @@ func (e *httpError) Error() string { // Independent of the query's success or failure, this function also returns the // address of the server on a best-effort basis, or nil if the address could not // be determined. -func (t *transport) doQuery(ctx context.Context, q []byte) (response []byte, server *net.TCPAddr, qerr *queryError) { +func (r *resolver) doQuery(ctx context.Context, q []byte) (response []byte, server *net.TCPAddr, qerr *queryError) { if len(q) < 2 { qerr = &queryError{BadQuery, fmt.Errorf("Query length is %d", len(q))} return } - t.hangoverLock.RLock() - inHangover := time.Now().Before(t.hangoverExpiration) - t.hangoverLock.RUnlock() + r.hangoverLock.RLock() + inHangover := time.Now().Before(r.hangoverExpiration) + r.hangoverLock.RUnlock() if inHangover { response = tryServfail(q) qerr = &queryError{HTTPError, errors.New("Forwarder is in servfail hangover")} @@ -276,14 +276,14 @@ func (t *transport) doQuery(ctx context.Context, q []byte) (response []byte, ser // Zero out the query ID. id := binary.BigEndian.Uint16(q) binary.BigEndian.PutUint16(q, 0) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.url, bytes.NewBuffer(q)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, r.url, bytes.NewBuffer(q)) if err != nil { qerr = &queryError{InternalError, err} return } var hostname string - response, hostname, server, qerr = t.sendRequest(id, req) + response, hostname, server, qerr = r.sendRequest(id, req) // Restore the query ID. binary.BigEndian.PutUint16(q, id) @@ -301,21 +301,21 @@ func (t *transport) doQuery(ctx context.Context, q []byte) (response []byte, ser if qerr != nil { if qerr.status != SendFailed { - t.hangoverLock.Lock() - t.hangoverExpiration = time.Now().Add(hangoverDuration) - t.hangoverLock.Unlock() + r.hangoverLock.Lock() + r.hangoverExpiration = time.Now().Add(hangoverDuration) + r.hangoverLock.Unlock() } response = tryServfail(q) } else if server != nil { // Record a working IP address for this server iff qerr is nil - t.ips.Get(hostname).Confirm(server.IP) + r.ips.Get(hostname).Confirm(server.IP) } return } -func (t *transport) sendRequest(id uint16, req *http.Request) (response []byte, hostname string, server *net.TCPAddr, qerr *queryError) { - hostname = t.hostname +func (r *resolver) sendRequest(id uint16, req *http.Request) (response []byte, hostname string, server *net.TCPAddr, qerr *queryError) { + hostname = r.hostname // The connection used for this request. If the request fails, we will close // this socket, in case it is no longer functioning. @@ -331,7 +331,7 @@ func (t *transport) sendRequest(id uint16, req *http.Request) (response []byte, logging.Info.Printf("%d Query failed: %v\n", id, qerr) if server != nil { logging.Debug.Printf("%d Disconfirming %s\n", id, server.IP.String()) - t.ips.Get(hostname).Disconfirm(server.IP) + r.ips.Get(hostname).Disconfirm(server.IP) } if conn != nil { logging.Info.Printf("%d Closing failing DoH socket\n", id) @@ -401,7 +401,7 @@ func (t *transport) sendRequest(id uint16, req *http.Request) (response []byte, req.Header.Set("Accept", mimetype) req.Header.Set("User-Agent", "Intra") logging.Debug.Printf("%d Sending query\n", id) - httpResponse, err := t.client.Do(req) + httpResponse, err := r.client.Do(req) if err != nil { qerr = &queryError{SendFailed, err} return @@ -432,14 +432,14 @@ func (t *transport) sendRequest(id uint16, req *http.Request) (response []byte, return } -func (t *transport) Query(ctx context.Context, q []byte) ([]byte, error) { +func (r *resolver) Query(ctx context.Context, q []byte) ([]byte, error) { var token Token - if t.listener != nil { - token = t.listener.OnQuery(t.url) + if r.listener != nil { + token = r.listener.OnQuery(r.url) } before := time.Now() - response, server, qerr := t.doQuery(ctx, q) + response, server, qerr := r.doQuery(ctx, q) after := time.Now() errIsCancel := false @@ -469,14 +469,14 @@ func (t *transport) Query(ctx context.Context, q []byte) ([]byte, error) { // Deadlock happens (both Step 1 and Step 5 are marked as synchronized)! // // TODO: make stop() an asynchronized function - if t.listener != nil && !errIsCancel { + if r.listener != nil && !errIsCancel { latency := after.Sub(before) var ip string if server != nil { ip = server.IP.String() } - t.listener.OnResponse(token, &Summary{ + r.listener.OnResponse(token, &Summary{ Latency: latency.Seconds(), Query: q, Response: response, @@ -488,13 +488,13 @@ func (t *transport) Query(ctx context.Context, q []byte) ([]byte, error) { return response, err } -func (t *transport) GetURL() string { - return t.url +func (r *resolver) GetURL() string { + return r.url } -// Perform a query using the transport, and send the response to the writer. -func forwardQuery(t Transport, q []byte, c io.Writer) error { - resp, qerr := t.Query(context.Background(), q) +// Perform a query using the Resolver, and send the response to the writer. +func forwardQuery(r Resolver, q []byte, c io.Writer) error { + resp, qerr := r.Query(context.Background(), q) if resp == nil && qerr != nil { return qerr } @@ -517,18 +517,17 @@ func forwardQuery(t Transport, q []byte, c io.Writer) error { return qerr } -// Perform a query using the transport, send the response to the writer, +// Perform a query using the Resolver, send the response to the writer, // and close the writer if there was an error. -func forwardQueryAndCheck(t Transport, q []byte, c io.WriteCloser) { - if err := forwardQuery(t, q, c); err != nil { +func forwardQueryAndCheck(r Resolver, q []byte, c io.WriteCloser) { + if err := forwardQuery(r, q, c); err != nil { logging.Warn.Printf("Query forwarding failed: %v\n", err) c.Close() } } -// Accept a DNS-over-TCP socket from a stub resolver, and connect the socket -// to this DNSTransport. -func Accept(t Transport, c io.ReadWriteCloser) { +// Accept a DNS-over-TCP socket, and connect the socket to a DoH Resolver. +func Accept(r Resolver, c io.ReadWriteCloser) { qlbuf := make([]byte, 2) for { n, err := c.Read(qlbuf) @@ -555,7 +554,7 @@ func Accept(t Transport, c io.ReadWriteCloser) { logging.Warn.Printf("Incomplete query: %d < %d\n", n, qlen) break } - go forwardQueryAndCheck(t, q, c) + go forwardQueryAndCheck(r, q, c) } // TODO: Cancel outstanding queries at this point. c.Close() diff --git a/Android/app/src/go/doh/doh_test.go b/Android/app/src/go/doh/doh_test.go index b1bf4a94..5d5164c5 100644 --- a/Android/app/src/go/doh/doh_test.go +++ b/Android/app/src/go/doh/doh_test.go @@ -122,17 +122,17 @@ func init() { } // Check that the constructor works. -func TestNewTransport(t *testing.T) { - makeTestDoHTransport(t, googleDoH) +func TestNewResolver(t *testing.T) { + newTestDoHResolver(t, googleDoH) } // Check that the constructor rejects unsupported URLs. func TestBadUrl(t *testing.T) { - _, err := NewTransport("ftp://www.example.com", nil, nil, nil, nil) + _, err := NewResolver("ftp://www.example.com", nil, nil, nil, nil) if err == nil { t.Error("Expected error") } - _, err = NewTransport("https://www.example", nil, nil, nil, nil) + _, err = NewResolver("https://www.example", nil, nil, nil, nil) if err == nil { t.Error("Expected error") } @@ -141,7 +141,7 @@ func TestBadUrl(t *testing.T) { // Check for failure when the query is too short to be valid. func TestShortQuery(t *testing.T) { var qerr *queryError - doh := makeTestDoHTransport(t, googleDoH) + doh := newTestDoHResolver(t, googleDoH) _, err := doh.Query(context.Background(), []byte{}) if err == nil { t.Error("Empty query should fail") @@ -179,7 +179,7 @@ func TestQueryIntegration(t *testing.T) { } testQuery := func(queryData []byte) { - doh := makeTestDoHTransport(t, googleDoH) + doh := newTestDoHResolver(t, googleDoH) resp, err2 := doh.Query(context.Background(), queryData) if err2 != nil { t.Fatal(err2) @@ -226,11 +226,10 @@ func (r *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) // Check that a DNS query is converted correctly into an HTTP query. func TestRequest(t *testing.T) { - doh := makeTestDoHTransport(t, googleDoH) - transport := doh.(*transport) + resolver := newTestDoHResolver(t, googleDoH) rt := makeTestRoundTripper() - transport.client.Transport = rt - go doh.Query(context.Background(), simpleQueryBytes) + resolver.client.Transport = rt + go resolver.Query(context.Background(), simpleQueryBytes) req := <-rt.req if req.URL.String() != googleDoH.url { t.Errorf("URL mismatch: %s != %s", req.URL.String(), googleDoH.url) @@ -275,10 +274,9 @@ func queriesMostlyEqual(m1 dnsmessage.Message, m2 dnsmessage.Message) bool { // Check that a DOH response is returned correctly. func TestResponse(t *testing.T) { - doh := makeTestDoHTransport(t, googleDoH) - transport := doh.(*transport) + resolver := newTestDoHResolver(t, googleDoH) rt := makeTestRoundTripper() - transport.client.Transport = rt + resolver.client.Transport = rt // Fake server. go func() { @@ -296,7 +294,7 @@ func TestResponse(t *testing.T) { w.Close() }() - resp, err := doh.Query(context.Background(), simpleQueryBytes) + resp, err := resolver.Query(context.Background(), simpleQueryBytes) if err != nil { t.Error(err) } @@ -314,10 +312,9 @@ func TestResponse(t *testing.T) { // Simulate an empty response. (This is not a compliant server // behavior.) func TestEmptyResponse(t *testing.T) { - doh := makeTestDoHTransport(t, googleDoH) - transport := doh.(*transport) + resolver := newTestDoHResolver(t, googleDoH) rt := makeTestRoundTripper() - transport.client.Transport = rt + resolver.client.Transport = rt // Fake server. go func() { @@ -332,7 +329,7 @@ func TestEmptyResponse(t *testing.T) { } }() - _, err := doh.Query(context.Background(), simpleQueryBytes) + _, err := resolver.Query(context.Background(), simpleQueryBytes) var qerr *queryError if err == nil { t.Error("Empty body should cause an error") @@ -345,10 +342,9 @@ func TestEmptyResponse(t *testing.T) { // Simulate a non-200 HTTP response code. func TestHTTPError(t *testing.T) { - doh := makeTestDoHTransport(t, googleDoH) - transport := doh.(*transport) + resolver := newTestDoHResolver(t, googleDoH) rt := makeTestRoundTripper() - transport.client.Transport = rt + resolver.client.Transport = rt go func() { <-rt.req @@ -362,7 +358,7 @@ func TestHTTPError(t *testing.T) { w.Close() }() - _, err := doh.Query(context.Background(), simpleQueryBytes) + _, err := resolver.Query(context.Background(), simpleQueryBytes) var qerr *queryError if err == nil { t.Error("Empty body should cause an error") @@ -375,13 +371,12 @@ func TestHTTPError(t *testing.T) { // Simulate an HTTP query error. func TestSendFailed(t *testing.T) { - doh := makeTestDoHTransport(t, googleDoH) - transport := doh.(*transport) + resolver := newTestDoHResolver(t, googleDoH) rt := makeTestRoundTripper() - transport.client.Transport = rt + resolver.client.Transport = rt rt.err = errors.New("test") - _, err := doh.Query(context.Background(), simpleQueryBytes) + _, err := resolver.Query(context.Background(), simpleQueryBytes) var qerr *queryError if err == nil { t.Error("Send failure should be reported") @@ -398,13 +393,12 @@ func TestSendFailed(t *testing.T) { // when queries suceeded and fail, respectively. func TestDohIPConfirmDisconfirm(t *testing.T) { u, _ := url.Parse(googleDoH.url) - doh := makeTestDoHTransport(t, googleDoH) - transport := doh.(*transport) + resolver := newTestDoHResolver(t, googleDoH) hostname := u.Hostname() - ipmap := transport.ips.Get(hostname) + ipmap := resolver.ips.Get(hostname) // send a valid request to first have confirmed-ip set - res, _ := doh.Query(context.Background(), simpleQueryBytes) + res, _ := resolver.Query(context.Background(), simpleQueryBytes) mustUnpack(res) ip1 := ipmap.Confirmed() @@ -414,7 +408,7 @@ func TestDohIPConfirmDisconfirm(t *testing.T) { // simulate http-fail with doh server-ip set to previously confirmed-ip rt := makeTestRoundTripper() - transport.client.Transport = rt + resolver.client.Transport = rt go func() { req := <-rt.req trace := httptrace.ContextClientTrace(req.Context()) @@ -430,7 +424,7 @@ func TestDohIPConfirmDisconfirm(t *testing.T) { Request: &http.Request{URL: u}, } }() - doh.Query(context.Background(), simpleQueryBytes) + resolver.Query(context.Background(), simpleQueryBytes) ip2 := ipmap.Confirmed() if ip2 != nil { @@ -449,10 +443,9 @@ func (c *fakeConn) RemoteAddr() net.Addr { // Check that the DNSListener is called with a correct summary. func TestListener(t *testing.T) { - doh, listener := makeTestDoHTransportWithListener(t, googleDoH) - transport := doh.(*transport) + resolver, listener := newTestDoHResolverWithListener(t, googleDoH) rt := makeTestRoundTripper() - transport.client.Transport = rt + resolver.client.Transport = rt go func() { req := <-rt.req @@ -474,7 +467,7 @@ func TestListener(t *testing.T) { w.Close() }() - doh.Query(context.Background(), simpleQueryBytes) + resolver.Query(context.Background(), simpleQueryBytes) s := listener.summary if s.Latency < 0 { t.Errorf("Negative latency: %f", s.Latency) @@ -521,14 +514,14 @@ func makePair() (io.ReadWriteCloser, io.ReadWriteCloser) { return &socket{r1, w2}, &socket{r2, w1} } -type fakeTransport struct { - Transport +type fakeResolver struct { + Resolver query chan []byte response chan []byte err error } -func (t *fakeTransport) Query(ctx context.Context, q []byte) ([]byte, error) { +func (t *fakeResolver) Query(ctx context.Context, q []byte) ([]byte, error) { t.query <- q if t.err != nil { return nil, t.err @@ -536,18 +529,18 @@ func (t *fakeTransport) Query(ctx context.Context, q []byte) ([]byte, error) { return <-t.response, nil } -func (t *fakeTransport) GetURL() string { +func (t *fakeResolver) GetURL() string { return "fake" } -func (t *fakeTransport) Close() { +func (t *fakeResolver) Close() { t.err = errors.New("closed") close(t.query) close(t.response) } -func newFakeTransport() *fakeTransport { - return &fakeTransport{ +func newFakeResolver() *fakeResolver { + return &fakeResolver{ query: make(chan []byte), response: make(chan []byte), } @@ -555,7 +548,7 @@ func newFakeTransport() *fakeTransport { // Test a successful query over TCP func TestAccept(t *testing.T) { - doh := newFakeTransport() + doh := newFakeResolver() client, server := makePair() // Start the forwarder running. @@ -614,7 +607,7 @@ func TestAccept(t *testing.T) { // Sends a TCP query that results in failure. When a query fails, // Accept should close the TCP socket. func TestAcceptFail(t *testing.T) { - doh := newFakeTransport() + doh := newFakeResolver() client, server := makePair() // Start the forwarder running. @@ -646,7 +639,7 @@ func TestAcceptFail(t *testing.T) { // Sends a TCP query, and closes the socket before the response is sent. // This tests for crashes when a response cannot be delivered. func TestAcceptClose(t *testing.T) { - doh := newFakeTransport() + doh := newFakeResolver() client, server := makePair() // Start the forwarder running. @@ -676,7 +669,7 @@ func TestAcceptClose(t *testing.T) { // Test failure due to a response that is larger than the // maximum message size for DNS over TCP (65535). func TestAcceptOversize(t *testing.T) { - doh := newFakeTransport() + doh := newFakeResolver() client, server := makePair() // Start the forwarder running. @@ -865,16 +858,16 @@ func TestServfail(t *testing.T) { } func TestQueryCanBeCancelled(t *testing.T) { - expectDoHTimeout := func(server testingDoHServer, msg string) { - doh := makeTestDoHTransport(t, server) + expectDoHTimeout := func(config testingDoHConfig, msg string) { + doh := newTestDoHResolver(t, config) st := time.Now() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() doh.Query(ctx, simpleQueryBytes) - require.WithinRange(t, time.Now(), st.Add(1700*time.Millisecond), st.Add(2300*time.Millisecond), msg) + require.WithinRange(t, time.Now(), st.Add(1500*time.Millisecond), st.Add(2500*time.Millisecond), msg) } - expectDoHTimeout(unreachableDoH, "unreachable server should timeout within deadline of ctx") + expectDoHTimeout(unreachableDoH, "unreachable resolver should timeout within deadline of ctx") // Intentionally create a local DoH server that does not accept any requests addr, err := url.Parse(localDoH.url) @@ -885,21 +878,21 @@ func TestQueryCanBeCancelled(t *testing.T) { require.NoError(t, err) defer svr.Close() - expectDoHTimeout(localDoH, "unresponsive server should timeout within deadline of ctx") + expectDoHTimeout(localDoH, "unresponsive resolver should timeout within deadline of ctx") } /******** Test DoH servers ********/ -type testingDoHServer struct { +type testingDoHConfig struct { url string ips []string } -var unreachableDoH = testingDoHServer{ +var unreachableDoH = testingDoHConfig{ url: "https://1.2.3.4:443", ips: []string{"1.2.3.4"}, } -var googleDoH = testingDoHServer{ +var googleDoH = testingDoHConfig{ url: "https://dns.google/dns-query", ips: []string{ "8.8.8.8", @@ -909,12 +902,12 @@ var googleDoH = testingDoHServer{ }, } -var localDoH = testingDoHServer{ +var localDoH = testingDoHConfig{ url: "https://localhost:34443", ips: []string{"127.0.0.1"}, } -/********** DoH Transport Test Helpers **********/ +/********** DoH Resolver Test Helpers **********/ type testingDoHListener struct { Listener summary *Summary @@ -923,17 +916,17 @@ type testingDoHListener struct { func (l *testingDoHListener) OnQuery(url string) Token { return nil } func (l *testingDoHListener) OnResponse(tok Token, s *Summary) { l.summary = s } -func newTestDOHResolver(t *testing.T, config dohConfig) DOHResolver { - doh, err := NewDOHResolver(target.url, target.ips, nil, nil, nil) +func newTestDoHResolver(t *testing.T, config testingDoHConfig) *resolver { + doh, err := NewResolver(config.url, config.ips, nil, nil, nil) require.NoError(t, err) require.NotNil(t, doh) - return doh + return doh.(*resolver) } -func makeTestDoHTransportWithListener(t *testing.T, target testingDoHServer) (Transport, *testingDoHListener) { +func newTestDoHResolverWithListener(t *testing.T, config testingDoHConfig) (*resolver, *testingDoHListener) { listener := &testingDoHListener{} - doh, err := NewTransport(target.url, target.ips, nil, nil, listener) + doh, err := NewResolver(config.url, config.ips, nil, nil, listener) require.NoError(t, err) require.NotNil(t, doh) - return doh, listener + return doh.(*resolver), listener } diff --git a/Android/app/src/go/intra/packet_proxy.go b/Android/app/src/go/intra/packet_proxy.go index 1bd28481..fecb8f48 100644 --- a/Android/app/src/go/intra/packet_proxy.go +++ b/Android/app/src/go/intra/packet_proxy.go @@ -32,7 +32,7 @@ import ( type intraPacketProxy struct { fakeDNSAddr netip.AddrPort - dns atomic.Pointer[doh.Transport] + dns atomic.Pointer[doh.Resolver] proxy network.PacketProxy listener UDPListener ctx context.Context @@ -41,7 +41,7 @@ type intraPacketProxy struct { var _ network.PacketProxy = (*intraPacketProxy)(nil) func newIntraPacketProxy( - ctx context.Context, fakeDNS netip.AddrPort, dns doh.Transport, protector protect.Protector, listener UDPListener, + ctx context.Context, fakeDNS netip.AddrPort, dns doh.Resolver, protector protect.Protector, listener UDPListener, ) (*intraPacketProxy, error) { if dns == nil { return nil, errors.New("dns is required") @@ -88,7 +88,7 @@ func (p *intraPacketProxy) NewSession(resp network.PacketResponseReceiver) (netw }, nil } -func (p *intraPacketProxy) SetDNS(dns doh.Transport) error { +func (p *intraPacketProxy) SetDNS(dns doh.Resolver) error { if dns == nil { return errors.New("dns is required") } diff --git a/Android/app/src/go/intra/sni_reporter.go b/Android/app/src/go/intra/sni_reporter.go index 9472fc34..07353fcf 100644 --- a/Android/app/src/go/intra/sni_reporter.go +++ b/Android/app/src/go/intra/sni_reporter.go @@ -42,13 +42,13 @@ const burst = 10 * time.Second // tcpSNIReporter is a thread-safe wrapper around choir.Reporter type tcpSNIReporter struct { mu sync.RWMutex // Protects dns, suffix, and r. - dns doh.Transport + dns doh.Resolver suffix string r choir.Reporter } // SetDNS changes the DNS transport used for uploading reports. -func (r *tcpSNIReporter) SetDNS(dns doh.Transport) { +func (r *tcpSNIReporter) SetDNS(dns doh.Resolver) { r.mu.Lock() r.dns = dns r.mu.Unlock() diff --git a/Android/app/src/go/intra/sni_reporter_test.go b/Android/app/src/go/intra/sni_reporter_test.go index 0e878a7a..c89122f1 100644 --- a/Android/app/src/go/intra/sni_reporter_test.go +++ b/Android/app/src/go/intra/sni_reporter_test.go @@ -29,17 +29,17 @@ import ( type qfunc func(q []byte) ([]byte, error) -type fakeTransport struct { - doh.Transport +type fakeResolver struct { + doh.Resolver query qfunc } -func (t *fakeTransport) Query(ctx context.Context, q []byte) ([]byte, error) { - return t.query(q) +func (r *fakeResolver) Query(ctx context.Context, q []byte) ([]byte, error) { + return r.query(q) } -func newFakeTransport(query qfunc) *fakeTransport { - return &fakeTransport{query: query} +func newFakeTransport(query qfunc) *fakeResolver { + return &fakeResolver{query: query} } func sendReport(t *testing.T, r *tcpSNIReporter, summary TCPSocketSummary, response []byte, responseErr error) string { diff --git a/Android/app/src/go/intra/stream_dialer.go b/Android/app/src/go/intra/stream_dialer.go index 68d8c7a3..0af65b8a 100644 --- a/Android/app/src/go/intra/stream_dialer.go +++ b/Android/app/src/go/intra/stream_dialer.go @@ -32,7 +32,7 @@ import ( type intraStreamDialer struct { fakeDNSAddr netip.AddrPort - dns atomic.Pointer[doh.Transport] + dns atomic.Pointer[doh.Resolver] dialer *net.Dialer alwaysSplitHTTPS atomic.Bool listener TCPListener @@ -43,7 +43,7 @@ var _ transport.StreamDialer = (*intraStreamDialer)(nil) func newIntraStreamDialer( fakeDNS netip.AddrPort, - dns doh.Transport, + dns doh.Resolver, protector protect.Protector, listener TCPListener, sniReporter *tcpSNIReporter, @@ -86,7 +86,7 @@ func (sd *intraStreamDialer) Dial(ctx context.Context, raddr string) (transport. return makeTCPWrapConn(conn, stats, sd.listener, sd.sniReporter), nil } -func (sd *intraStreamDialer) SetDNS(dns doh.Transport) error { +func (sd *intraStreamDialer) SetDNS(dns doh.Resolver) error { if dns == nil { return errors.New("dns is required") } diff --git a/Android/app/src/go/intra/tunnel.go b/Android/app/src/go/intra/tunnel.go index eccb279e..a8098dec 100644 --- a/Android/app/src/go/intra/tunnel.go +++ b/Android/app/src/go/intra/tunnel.go @@ -59,10 +59,11 @@ type Tunnel struct { // // These will normally be localhost with a high-numbered port. // -// `dohdns` is the initial DOH transport. +// `dohdns` is the initial [Resolver]. +// // `eventListener` will be notified at the completion of every tunneled socket. func NewTunnel( - fakedns string, dohdns doh.Transport, tun io.Closer, protector protect.Protector, eventListener Listener, + fakedns string, dohdns doh.Resolver, tun io.Closer, protector protect.Protector, eventListener Listener, ) (t *Tunnel, err error) { if eventListener == nil { return nil, errors.New("eventListener is required") @@ -99,10 +100,10 @@ func NewTunnel( return } -// Set the DNSTransport. This method must be called before connecting the transport -// to the TUN device. The transport can be changed at any time during operation, but +// Set the DNS Resolver. This method must be called before connecting the transport +// to the TUN device. The transport can be changed at any time during operation, but // must not be nil. -func (t *Tunnel) SetDNS(dns doh.Transport) { +func (t *Tunnel) SetDNS(dns doh.Resolver) { t.sd.SetDNS(dns) t.pp.SetDNS(dns) t.sni.SetDNS(dns)