Skip to content

Commit

Permalink
do not lookup cname after looking up the txt for mta-sts, and follow …
Browse files Browse the repository at this point in the history
…cnames for mocks

because the txt would already follow cnames.
the additional cname lookup didn't hurt, it just didn't do anything.
i probably didn't realize that before looking deeper into dns.
  • Loading branch information
mjl- committed Oct 14, 2023
1 parent 8ca1988 commit 101c270
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 76 deletions.
12 changes: 6 additions & 6 deletions dns/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,20 +184,20 @@ func (r MockResolver) LookupHost(ctx context.Context, host string) ([]string, ad

func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net.IP, adns.Result, error) {
mr := mockReq{"ip", host}
_, result, err := r.result(ctx, mr)
name, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err
}
var ips []net.IP
switch network {
case "ip", "ip4":
for _, ip := range r.A[host] {
for _, ip := range r.A[name] {
ips = append(ips, net.ParseIP(ip))
}
}
switch network {
case "ip", "ip6":
for _, ip := range r.AAAA[host] {
for _, ip := range r.AAAA[name] {
ips = append(ips, net.ParseIP(ip))
}
}
Expand All @@ -209,7 +209,7 @@ func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net

func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adns.Result, error) {
mr := mockReq{"mx", name}
_, result, err := r.result(ctx, mr)
name, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err
}
Expand All @@ -222,7 +222,7 @@ func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adn

func (r MockResolver) LookupTXT(ctx context.Context, name string) ([]string, adns.Result, error) {
mr := mockReq{"txt", name}
_, result, err := r.result(ctx, mr)
name, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err
}
Expand All @@ -241,7 +241,7 @@ func (r MockResolver) LookupTLSA(ctx context.Context, port int, protocol string,
name = fmt.Sprintf("_%d._%s.%s", port, protocol, host)
}
mr := mockReq{"tlsa", name}
_, result, err := r.result(ctx, mr)
name, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err
}
Expand Down
51 changes: 16 additions & 35 deletions mtasts/mtasts.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,45 +162,26 @@ var (
)

// LookupRecord looks up the MTA-STS TXT DNS record at "_mta-sts.<domain>",
// following CNAME records, and returns the parsed MTA-STS record, the DNS TXT
// record and any CNAMEs that were followed.
func LookupRecord(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (rrecord *Record, rtxt string, rcnames []string, rerr error) {
// following CNAME records, and returns the parsed MTA-STS record and the DNS TXT
// record.
func LookupRecord(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (rrecord *Record, rtxt string, rerr error) {
log := xlog.WithContext(ctx)
start := time.Now()
defer func() {
log.Debugx("mtasts lookup result", rerr, mlog.Field("domain", domain), mlog.Field("record", rrecord), mlog.Field("cnames", rcnames), mlog.Field("duration", time.Since(start)))
log.Debugx("mtasts lookup result", rerr, mlog.Field("domain", domain), mlog.Field("record", rrecord), mlog.Field("duration", time.Since(start)))
}()

// ../rfc/8461:289
// ../rfc/8461:351
// We lookup the txt record, but must follow CNAME records when the TXT does not exist.
var cnames []string
// We lookup the txt record, but must follow CNAME records when the TXT does not
// exist. LookupTXT follows CNAMEs.
name := "_mta-sts." + domain.ASCII + "."
var txts []string
for {
var err error
txts, _, err = dns.WithPackage(resolver, "mtasts").LookupTXT(ctx, name)
if dns.IsNotFound(err) {
// DNS has no specified limit on how many CNAMEs to follow. Chains of 10 CNAMEs
// have been seen on the internet.
if len(cnames) > 16 {
return nil, "", cnames, fmt.Errorf("too many cnames")
}
cname, _, err := dns.WithPackage(resolver, "mtasts").LookupCNAME(ctx, name)
if dns.IsNotFound(err) {
return nil, "", cnames, ErrNoRecord
}
if err != nil {
return nil, "", cnames, fmt.Errorf("%w: %s", ErrDNS, err)
}
cnames = append(cnames, cname)
name = cname
continue
} else if err != nil {
return nil, "", cnames, fmt.Errorf("%w: %s", ErrDNS, err)
} else {
break
}
txts, _, err := dns.WithPackage(resolver, "mtasts").LookupTXT(ctx, name)
if dns.IsNotFound(err) {
return nil, "", ErrNoRecord
} else if err != nil {
return nil, "", fmt.Errorf("%w: %s", ErrDNS, err)
}

var text string
Expand All @@ -215,18 +196,18 @@ func LookupRecord(ctx context.Context, resolver dns.Resolver, domain dns.Domain)
continue
}
if err != nil {
return nil, "", cnames, err
return nil, "", err
}
if record != nil {
return nil, "", cnames, ErrMultipleRecords
return nil, "", ErrMultipleRecords
}
record = r
text = txt
}
if record == nil {
return nil, "", cnames, ErrNoRecord
return nil, "", ErrNoRecord
}
return record, text, cnames, nil
return record, text, nil
}

// Policy fetch errors.
Expand Down Expand Up @@ -330,7 +311,7 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (record
log.Debugx("mtasts get result", err, mlog.Field("domain", domain), mlog.Field("record", record), mlog.Field("policy", policy), mlog.Field("duration", time.Since(start)))
}()

record, _, _, err = LookupRecord(ctx, resolver, domain)
record, _, err = LookupRecord(ctx, resolver, domain)
if err != nil {
return nil, nil, err
}
Expand Down
33 changes: 17 additions & 16 deletions mtasts/mtasts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ import (
"github.com/mjl-/adns"

"github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mlog"
)

func TestLookup(t *testing.T) {
mlog.SetConfig(map[string]mlog.Level{"": mlog.LevelDebug})

resolver := dns.MockResolver{
TXT: map[string][]string{
"_mta-sts.a.example.": {"v=STSv1; id=1"},
Expand All @@ -37,39 +40,37 @@ func TestLookup(t *testing.T) {
CNAME: map[string]string{
"_mta-sts.a.cnames.example.": "_mta-sts.b.cnames.example.",
"_mta-sts.b.cnames.example.": "_mta-sts.c.cnames.example.",
"_mta-sts.followtemperror.example.": "_mta-sts.cnametemperror.example.",
"_mta-sts.followtemperror.example.": "_mta-sts.temperror.example.",
},
Fail: []string{
"txt _mta-sts.temperror.example.",
"cname _mta-sts.cnametemperror.example.",
},
}

test := func(host string, expRecord *Record, expCNAMEs []string, expErr error) {
test := func(host string, expRecord *Record, expErr error) {
t.Helper()

record, _, cnames, err := LookupRecord(context.Background(), resolver, dns.Domain{ASCII: host})
record, _, err := LookupRecord(context.Background(), resolver, dns.Domain{ASCII: host})
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
t.Fatalf("lookup: got err %#v, expected %#v", err, expErr)
}
if err != nil {
return
}
if !reflect.DeepEqual(record, expRecord) || !reflect.DeepEqual(cnames, expCNAMEs) {
t.Fatalf("lookup: got record %#v, cnames %#v, expected %#v %#v", record, cnames, expRecord, expCNAMEs)
if !reflect.DeepEqual(record, expRecord) {
t.Fatalf("lookup: got record %#v, expected %#v", record, expRecord)
}
}

test("absent.example", nil, nil, ErrNoRecord)
test("other.example", nil, nil, ErrNoRecord)
test("a.example", &Record{Version: "STSv1", ID: "1"}, nil, nil)
test("one.example", &Record{Version: "STSv1", ID: "1"}, nil, nil)
test("bad.example", nil, nil, ErrRecordSyntax)
test("multiple.example", nil, nil, ErrMultipleRecords)
test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, []string{"_mta-sts.b.cnames.example.", "_mta-sts.c.cnames.example."}, nil)
test("temperror.example", nil, nil, ErrDNS)
test("cnametemperror.example", nil, nil, ErrDNS)
test("followtemperror.example", nil, nil, ErrDNS)
test("absent.example", nil, ErrNoRecord)
test("other.example", nil, ErrNoRecord)
test("a.example", &Record{Version: "STSv1", ID: "1"}, nil)
test("one.example", &Record{Version: "STSv1", ID: "1"}, nil)
test("bad.example", nil, ErrRecordSyntax)
test("multiple.example", nil, ErrMultipleRecords)
test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, nil)
test("temperror.example", nil, ErrDNS)
test("followtemperror.example", nil, ErrDNS)
}

func TestMatches(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion mtastsdb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy
policy = &cachedPolicy.Policy
nctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
record, _, _, err := mtasts.LookupRecord(nctx, resolver, domain)
record, _, err := mtasts.LookupRecord(nctx, resolver, domain)
if err != nil {
if !errors.Is(err, mtasts.ErrNoRecord) {
// Could be a temporary DNS or configuration error.
Expand Down
2 changes: 1 addition & 1 deletion mtastsdb/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func refreshDomain(ctx context.Context, db *bstore.DB, resolver dns.Resolver, pr
return
}
log.Debug("refreshing mta-sts policy for domain", mlog.Field("domain", d))
record, _, _, err := mtasts.LookupRecord(ctx, resolver, d)
record, _, err := mtasts.LookupRecord(ctx, resolver, d)
if err == nil && record.ID == pr.RecordID {
qup := bstore.QueryDB[PolicyRecord](ctx, db)
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
Expand Down
8 changes: 1 addition & 7 deletions webadmin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ type MTASTSRecord struct {
mtasts.Record
}
type MTASTSCheckResult struct {
CNAMEs []string
TXT string
Record *MTASTSRecord
PolicyText string
Expand Down Expand Up @@ -1180,15 +1179,10 @@ Ensure a DNS TXT record like the following exists:
defer logPanic(ctx)
defer wg.Done()

record, txt, cnames, err := mtasts.LookupRecord(ctx, resolver, domain)
record, txt, err := mtasts.LookupRecord(ctx, resolver, domain)
if err != nil {
addf(&r.MTASTS.Errors, "Looking up MTA-STS record: %s", err)
}
if cnames != nil {
r.MTASTS.CNAMEs = cnames
} else {
r.MTASTS.CNAMEs = []string{}
}
r.MTASTS.TXT = txt
if record != nil {
r.MTASTS.Record = &MTASTSRecord{*record}
Expand Down
3 changes: 1 addition & 2 deletions webadmin/admin.html
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,7 @@
const detailsTLSRPT = !checks.TLSRPT.TXT ? [] : [
dom.div('TXT record: ' + checks.TLSRPT.TXT),
]
const detailsMTASTS = empty(checks.MTASTS.CNAMEs) && !checks.MTASTS.TXT && !checks.MTASTS.PolicyText ? [] : [
dom.div('CNAMEs followed: ' + (checks.MTASTS.CNAMEs.join(', ') || '(none)')),
const detailsMTASTS = !checks.MTASTS.TXT && !checks.MTASTS.PolicyText ? [] : [
!checks.MTASTS.TXT ? [] : dom.div('MTA-STS record: ' + checks.MTASTS.TXT),
!checks.MTASTS.PolicyText ? [] : dom.div('MTA-STS policy: ', dom('pre.literal', style({maxWidth: '60em'}), checks.MTASTS.PolicyText)),
]
Expand Down
8 changes: 0 additions & 8 deletions webadmin/adminapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -1626,14 +1626,6 @@
"Name": "MTASTSCheckResult",
"Docs": "",
"Fields": [
{
"Name": "CNAMEs",
"Docs": "",
"Typewords": [
"[]",
"string"
]
},
{
"Name": "TXT",
"Docs": "",
Expand Down

0 comments on commit 101c270

Please sign in to comment.