diff --git a/target.go b/target.go index 05c8414..25afa88 100644 --- a/target.go +++ b/target.go @@ -23,6 +23,7 @@ package main import ( + "bytes" "encoding/base64" "fmt" "io/ioutil" @@ -99,6 +100,10 @@ func (s *targetServer) resolveQueryWithResolver(q *dns.Msg, r resolver) ([]byte, start := time.Now() response, err := r.resolve(q) + if err != nil { + log.Println("Resolution failed: ", err) + return nil, err + } elapsed := time.Since(start) packedResponse, err := response.Pack() @@ -176,6 +181,9 @@ func (s *targetServer) parseObliviousQueryFromRequest(r *http.Request) (odoh.Obl func (s *targetServer) createObliviousResponseForQuery(context odoh.ResponseContext, dnsResponse []byte) (odoh.ObliviousDNSMessage, error) { response := odoh.CreateObliviousDNSResponse(dnsResponse, 0) odohResponse, err := context.EncryptResponse(response) + if err != nil { + return odoh.ObliviousDNSMessage{}, err + } if s.verbose { log.Printf("Encrypted response: %x", odohResponse) @@ -200,6 +208,13 @@ func (s *targetServer) odohQueryHandler(w http.ResponseWriter, r *http.Request) return } + keyID := s.odohKeyPair.Config.Contents.KeyID() + receivedKeyID := odohMessage.KeyID + if !bytes.Equal(keyID, receivedKeyID) { + log.Println("received keyID is different from expected key ID") + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + } + obliviousQuery, responseContext, err := s.odohKeyPair.DecryptQuery(odohMessage) if err != nil { log.Println("DecryptQuery failed:", err) diff --git a/target_test.go b/target_test.go index d9bdca5..94a2d6f 100644 --- a/target_test.go +++ b/target_test.go @@ -33,6 +33,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" odoh "github.com/cloudflare/odoh-go" "github.com/miekg/dns" @@ -359,8 +360,8 @@ func TestQueryHandlerODoHWithInvalidKey(t *testing.T) { rr := httptest.NewRecorder() handler.ServeHTTP(rr, request) - if status := rr.Result().StatusCode; status != http.StatusBadRequest { - t.Fatal(fmt.Errorf("Result did not yield %d, got %d instead", http.StatusBadRequest, status)) + if status := rr.Result().StatusCode; status != http.StatusUnauthorized { + t.Fatal(fmt.Errorf("Result did not yield %d, got %d instead", http.StatusUnauthorized, status)) } } @@ -392,3 +393,86 @@ func TestQueryHandlerODoHWithCorruptCiphertext(t *testing.T) { t.Fatal(fmt.Errorf("Result did not yield %d, got %d instead", http.StatusBadRequest, status)) } } + +func TestQueryHandlerODoHWithMalformedQuery(t *testing.T) { + r := createLocalResolver(t) + target := createTarget(t, r) + + handler := http.HandlerFunc(target.targetQueryHandler) + + // malformed odoh query + queryBytes := []byte{1, 2, 3} + request, err := http.NewRequest(http.MethodPost, queryEndpoint, bytes.NewReader(queryBytes)) + if err != nil { + t.Fatal(err) + } + request.Header.Add("Content-Type", odohMessageContentType) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, request) + + if status := rr.Result().StatusCode; status != http.StatusBadRequest { + t.Fatal(fmt.Errorf("Result did not yield %d, got %d instead", http.StatusBadRequest, status)) + } +} + +func TestODoHResolutionWithRealResolver(t *testing.T) { + r := &targetResolver{ + timeout: 2500 * time.Millisecond, + nameserver: "1.1.1.1:53", + } + target := createTarget(t, r) + + handler := http.HandlerFunc(target.targetQueryHandler) + + // malformed DNS query + obliviousQuery := odoh.CreateObliviousDNSQuery([]byte{1, 2, 3}, 0) + encryptedQuery, _, err := target.odohKeyPair.Config.Contents.EncryptQuery(obliviousQuery) + if err != nil { + t.Fatal(err) + } + + request, err := http.NewRequest(http.MethodPost, queryEndpoint, bytes.NewReader(encryptedQuery.Marshal())) + if err != nil { + t.Fatal(err) + } + request.Header.Add("Content-Type", odohMessageContentType) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, request) + + if status := rr.Result().StatusCode; status != http.StatusBadRequest { + t.Fatal(fmt.Errorf("Result did not yield %d, got %d instead", http.StatusBadRequest, status)) + } + + handler = http.HandlerFunc(target.targetQueryHandler) + + // valid dns query + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + packedQuery, err := q.Pack() + if err != nil { + t.Fatal(err) + } + obliviousQuery = odoh.CreateObliviousDNSQuery([]byte(packedQuery), 0) + encryptedQuery, _, err = target.odohKeyPair.Config.Contents.EncryptQuery(obliviousQuery) + if err != nil { + t.Fatal(err) + } + + request, err = http.NewRequest(http.MethodPost, queryEndpoint, bytes.NewReader(encryptedQuery.Marshal())) + if err != nil { + t.Fatal(err) + } + request.Header.Add("Content-Type", odohMessageContentType) + + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, request) + + if status := rr.Result().StatusCode; status != http.StatusOK { + t.Fatal(fmt.Errorf("Result did not yield %d, got %d instead", http.StatusOK, status)) + } + if rr.Result().Header.Get("Content-Type") != odohMessageContentType { + t.Fatal("Invalid content type response") + } +}