diff --git a/pkg/dnssec/chain.go b/pkg/dnssec/chain.go new file mode 100644 index 0000000..2a58953 --- /dev/null +++ b/pkg/dnssec/chain.go @@ -0,0 +1,201 @@ +package dnssec + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/miekg/dns" + "github.com/qdm12/dns/v2/internal/server" +) + +// delegationChain is the DNSSEC chain of trust from the +// queried zone to the root (.) zone. +// The first signed zone is the queried zone, and the last +// signed zone is the root zone. +// See https://www.ietf.org/rfc/rfc4033.txt +type delegationChain []*signedZone + +// buildDelegationChain queries the RRs required for the zone validation. +// It begins the queries at the desired zone and then go +// up the delegation tree until it reaches the root zone. +// It returns a new delegation chain of signed zones where the +// first signed zone (index 0) is the child zone and the last signed +// zone is the root zone. +func buildDelegationChain(ctx context.Context, exchange server.Exchange, + zone string, qClass uint16) (chain delegationChain, err error) { + zoneParts := strings.Split(zone, ".") + + type result struct { + i int + signedZone *signedZone + err error + } + results := make(chan result) + + for i := range zoneParts { + go func(i int, results chan<- result) { + // the following zone names are queried: + // 'example.com.', 'com.', '.' + zoneName := dns.Fqdn(strings.Join(zoneParts[i:], ".")) + signedZone, err := queryDelegation(ctx, exchange, zoneName, qClass) + if err != nil { + err = fmt.Errorf("querying delegation for zone %s: %w", + zone, err) + } + results <- result{i: i, signedZone: signedZone, err: err} + }(i, results) + } + + chain = make(delegationChain, len(zoneParts)) + for range zoneParts { + result := <-results + if result.err != nil { + if err == nil { + err = result.err + } + continue + } + chain[result.i] = result.signedZone + } + close(results) + + if err != nil { + return nil, err + } + + return chain, nil +} + +// queryDelegation obtains the DNSKEY records and the DS +// records for a given zone, and creates a signed zone with +// this information. It does not query the (non existent) +// DS record for the root zone. +func queryDelegation(ctx context.Context, exchange server.Exchange, + zone string, qClass uint16) (sz *signedZone, err error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + type result struct { + qType uint16 // TODO use rrsig type covered instead + signedRRSet signedRRSet + err error + } + results := make(chan result) + + go func() { + signedRRSet, err := queryDNSKey(ctx, exchange, zone, qClass) + results <- result{qType: dns.TypeDNSKEY, signedRRSet: signedRRSet, err: err} + }() + + go func() { + // Note for the root zone ".", the DS record is not queried since + // it does not exist. + signedRRSet, err := queryDS(ctx, exchange, zone, qClass) + results <- result{qType: dns.TypeDS, signedRRSet: signedRRSet, err: err} + }() + + sz = &signedZone{ + zone: zone, + } + + const parallelQueries = 2 + for i := 0; i < parallelQueries; i++ { + result := <-results + if result.err != nil { + if err == nil { // first error encountered + err = result.err + cancel() + } + continue + } + if result.qType == dns.TypeDS { + // For the root zone ".", both dsRRSig and dsRRSet are nil. + sz.dsRRSet = result.signedRRSet + } else { + sz.dnsKeyRRSet = result.signedRRSet + sz.keyTagToDNSKey = dnsKeyRRSetToMap(result.signedRRSet.rrset) + } + } + close(results) + + if err != nil { + return nil, err + } + + return sz, nil +} + +var ( + ErrSignedRecordNotFound = errors.New("signed record not found") + ErrNotSignedRRSetReceived = errors.New("not signed RRSet received") +) + +func queryDNSKey(ctx context.Context, exchange server.Exchange, + zone string, qClass uint16) (signedrrset signedRRSet, err error) { + notSignedRRSet, signedrrset, err := fetchSingleRRSigRRSet(ctx, + exchange, zone, qClass, dns.TypeDNSKEY) + switch { + case err != nil: + return signedRRSet{}, err + case len(notSignedRRSet) > 0: + for _, rr := range signedrrset.rrset { + fmt.Println("\n\n===> Received signed RR:", rr) + } + for _, rr := range notSignedRRSet { + fmt.Println("\n\n===> Received not signed RR:", rr) + } + + return signedRRSet{}, fmt.Errorf("for %s: %w", + queryParamsToString(zone, qClass, dns.TypeDNSKEY), + ErrNotSignedRRSetReceived) + case len(signedrrset.rrset) == 0: + return signedRRSet{}, fmt.Errorf("for %s: %w", + queryParamsToString(zone, qClass, dns.TypeDNSKEY), + ErrSignedRecordNotFound) + } + return signedrrset, nil +} + +func queryDS(ctx context.Context, exchange server.Exchange, + zone string, qClass uint16) (signedrrset signedRRSet, + err error) { + if zone == "." { + // The root zone has no RRSIG + DS record. + // The root zone DS record is the root anchor, + // and it is not signed by an RRSIG. + return signedRRSet{}, nil + } + notSignedRRSet, signedrrset, err := fetchSingleRRSigRRSet(ctx, + exchange, zone, qClass, dns.TypeDS) + switch { + case err != nil: + return signedRRSet{}, err + case len(notSignedRRSet) > 0: + return signedRRSet{}, fmt.Errorf("for %s: %w", + queryParamsToString(zone, qClass, dns.TypeDNSKEY), + ErrNotSignedRRSetReceived) + case len(signedrrset.rrset) == 0: + return signedRRSet{}, fmt.Errorf("for %s: %w", + queryParamsToString(zone, qClass, dns.TypeDNSKEY), + ErrSignedRecordNotFound) + } + return signedrrset, nil +} + +func dnsKeyRRSetToMap(rrset []dns.RR) (keyTagToDNSKey map[uint16]*dns.DNSKEY) { + keyTagToDNSKey = make(map[uint16]*dns.DNSKEY, len(rrset)) + for _, rr := range rrset { + if rr.Header().Rrtype != dns.TypeDNSKEY { + continue + } + dnsKey, ok := rr.(*dns.DNSKEY) + if !ok { + panic(fmt.Sprintf("RR is of type %T and not of type *dns.DNSKEY", rr)) + } + + keyTagToDNSKey[dnsKey.KeyTag()] = dnsKey + } + return keyTagToDNSKey +} diff --git a/pkg/dnssec/client.go b/pkg/dnssec/client.go new file mode 100644 index 0000000..bbf05e4 --- /dev/null +++ b/pkg/dnssec/client.go @@ -0,0 +1,68 @@ +package dnssec + +import ( + "context" + "fmt" + + "github.com/miekg/dns" + "github.com/qdm12/dns/v2/internal/server" +) + +type Client struct { + exchange server.Exchange +} + +func New(exchange server.Exchange) *Client { + return &Client{ + exchange: exchange, + } +} + +func (c *Client) Exchange(ctx context.Context, request *dns.Msg) ( + response *dns.Msg, err error) { + response = new(dns.Msg) + + for _, question := range request.Question { + rrset, err := fetchAndValidateZone(ctx, c.exchange, + question.Name, question.Qclass, question.Qtype) + if err != nil { + return nil, fmt.Errorf("validating %s %s %s: %w", + question.Name, dns.ClassToString[question.Qclass], + dns.TypeToString[question.Qtype], err) + } + + response.Answer = append(response.Answer, rrset...) + } + + return response, nil +} + +func fetchAndValidateZone(ctx context.Context, exchange server.Exchange, + zone string, qClass, qType uint16) (rrset []dns.RR, err error) { + signedRRSets, notSignedRRSet, err := fetchSignedRRSets(ctx, exchange, zone, qClass, qType) + if err != nil { + return nil, fmt.Errorf("fetching desired zone RRSet with RRSig: %w", err) + } + for _, rr := range notSignedRRSet { + fmt.Println("\n\n===> Received not signed RR:", rr) + } + + delegationChain, err := buildDelegationChain(ctx, exchange, zone, qClass) + if err != nil { + return nil, fmt.Errorf("building delegation chain: %w", err) + } + + minRRSetSize := len(signedRRSets) // 1 RR per RRSig + rrset = make([]dns.RR, 0, minRRSetSize) + for _, signedRRSet := range signedRRSets { + err = verifyWithChain(signedRRSet, delegationChain) + if err != nil { + return nil, fmt.Errorf("verifying RRSet of type %s with RRSig "+ + "and delegation chain: %w", + signedRRSet.Type(), err) + } + rrset = append(rrset, signedRRSet.rrset...) + } + + return rrset, nil +} diff --git a/pkg/dnssec/fetch.go b/pkg/dnssec/fetch.go new file mode 100644 index 0000000..c85b10b --- /dev/null +++ b/pkg/dnssec/fetch.go @@ -0,0 +1,179 @@ +package dnssec + +import ( + "context" + "errors" + "fmt" + "sort" + "strings" + + "github.com/miekg/dns" + "github.com/qdm12/dns/v2/internal/server" + "golang.org/x/exp/maps" +) + +func newRequestWithRRSig(zone string, qClass, qType uint16) (request *dns.Msg) { + request = new(dns.Msg).SetQuestion(zone, qType) + request.Question[0].Qclass = qClass + request.RecursionDesired = true + const maxUDPSize = 4096 + const doEdns0 = true + request.SetEdns0(maxUDPSize, doEdns0) + return request +} + +var ( + ErrMultipleSignedRRSets = errors.New("multiple signed RRSets") +) + +func fetchSingleRRSigRRSet(ctx context.Context, exchange server.Exchange, zone string, + qClass, qType uint16) (notSignedRRSet []dns.RR, signedrrset signedRRSet, err error) { + signedRRSets, notSignedRRSet, err := fetchSignedRRSets(ctx, exchange, zone, qClass, qType) + if err != nil { + return nil, signedRRSet{}, err + } + + switch len(signedRRSets) { + case 0: + return notSignedRRSet, signedRRSet{}, nil + case 1: + return notSignedRRSet, signedRRSets[0], nil + default: + rrsigTypesCovered := make([]string, 0, len(signedRRSets)) + for _, signedRRSet := range signedRRSets { + rrsigTypesCovered = append(rrsigTypesCovered, + dns.TypeToString[signedRRSet.rrsig.TypeCovered]) + } + return nil, signedRRSet{}, fmt.Errorf("for %s: %w: for types %s", + queryParamsToString(zone, qClass, qType), + ErrMultipleSignedRRSets, strings.Join(rrsigTypesCovered, ", ")) + } +} + +var ( + ErrResponseNotAuthenticated = errors.New("response is not authenticated") + ErrRcodeServerFailure = errors.New("upstream return code is server failure") +) + +func fetchSignedRRSets(ctx context.Context, exchange server.Exchange, zone string, + qClass, qType uint16) (signedRRSets []signedRRSet, + notSignedRRSet []dns.RR, err error) { + request := newRequestWithRRSig(zone, qClass, qType) + + response, err := exchange(ctx, request) + if err != nil { + return nil, nil, err + } + + fmt.Println("\n\n===> Request", + queryParamsToString(zone, qClass, qType), + "\n===> Response", response.String()) + + if response.Rcode == dns.RcodeServerFailure { + // this may mean DNSSEC validation failed on the upstream server. + // https://www.ietf.org/rfc/rfc4033.txt + // This specification only defines how security-aware name servers can + // signal non-validating stub resolvers that data was found to be bogus + // (using RCODE=2, "Server Failure"; see [RFC4035]). + return nil, nil, fmt.Errorf("for %s: %w"+ + " (DNSSEC validation may have failed upstream)", + queryParamsToString(zone, qClass, qType), ErrRcodeServerFailure) + } + + if len(response.Answer) == 0 { + return nil, nil, nil + } + + signedRRSets, notSignedRRSet, err = answersToRRSigRRSet(response.Answer) + if err != nil { + return nil, nil, fmt.Errorf("for %s: %w", + queryParamsToString(zone, qClass, qType), err) + } + + return signedRRSets, notSignedRRSet, nil +} + +var ( + ErrRRSigMultipleForType = errors.New("multiple RRSIGs for the same record type") + ErrRRSigForNoRRSet = errors.New("RRSIG for no RRSet") +) + +func answersToRRSigRRSet(answers []dns.RR) ( + signedRRSets []signedRRSet, + notSignedRRSet []dns.RR, err error) { + // For well formed DNSSEC DNS answers, there should + // be at most N/2 signed record types where N is + // the number of total answers. + maxTypes := len(answers) / 2 //nolint:gomnd + typeToRRSig := make(map[uint16]*dns.RRSIG, maxTypes) + typeToRRSet := make(map[uint16][]dns.RR, maxTypes) + + // used to set the capacity of notSignedRRSet + notSignedRRFound := 0 + + for _, rr := range answers { + rrType := rr.Header().Rrtype + if rrType != dns.TypeRRSIG { + rrIsSigned := typeToRRSig[rrType] != nil + if !rrIsSigned { + notSignedRRFound++ + } + rrset := typeToRRSet[rrType] + rrset = append(rrset, rr) + typeToRRSet[rrType] = rrset + continue + } + + rrsig, ok := rr.(*dns.RRSIG) + if !ok { + panic(fmt.Sprintf("RR is of type %T and not of type *dns.RRSIG", rr)) + } + typeCovered := rrsig.TypeCovered + if typeToRRSig[typeCovered] != nil { + return nil, nil, fmt.Errorf("%w: %s", + ErrRRSigMultipleForType, + dns.TypeToString[typeCovered]) + } + typeToRRSig[typeCovered] = rrsig + + rrset, exists := typeToRRSet[typeCovered] + if exists { // it is now signed + notSignedRRFound -= len(rrset) + } + } + + signedRRSets = make([]signedRRSet, 0, len(typeToRRSig)) + + for typeCovered, rrsig := range typeToRRSig { + rrset := typeToRRSet[typeCovered] + if len(rrset) == 0 { + return nil, nil, fmt.Errorf("%w: for type %s", + ErrRRSigForNoRRSet, + dns.TypeToString[typeCovered]) + } + signedRRSets = append(signedRRSets, signedRRSet{ + rrsig: rrsig, + rrset: rrset, + }) + delete(typeToRRSet, typeCovered) + } + + // Predictable order for tests + sort.Slice(signedRRSets, func(i, j int) bool { + return signedRRSets[i].rrsig.TypeCovered < signedRRSets[j].rrsig.TypeCovered + }) + + // Remaining not signed RRSets + // Sort by type for predictable order + notSignedTypes := maps.Keys(typeToRRSet) + sort.Slice(notSignedTypes, func(i, j int) bool { + return notSignedTypes[i] < notSignedTypes[j] + }) + notSignedRRSet = make([]dns.RR, 0, notSignedRRFound) + for _, notSignedType := range notSignedTypes { + rrset := typeToRRSet[notSignedType] + notSignedRRSet = append(notSignedRRSet, rrset...) + } + + return signedRRSets, notSignedRRSet, nil +} diff --git a/pkg/dnssec/fetch_test.go b/pkg/dnssec/fetch_test.go new file mode 100644 index 0000000..b34766d --- /dev/null +++ b/pkg/dnssec/fetch_test.go @@ -0,0 +1,203 @@ +package dnssec + +import ( + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func Test_answersToRRSigRRSet(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + answers []dns.RR + signedRRSets []signedRRSet + notSignedRRSet []dns.RR + errWrapped error + errMessage string + }{ + "no_answer": { + signedRRSets: []signedRRSet{}, + notSignedRRSet: []dns.RR{}, + }, + "bad_single_rrsig_answer": { + answers: []dns.RR{ + &dns.RRSIG{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeRRSIG, + }, + TypeCovered: dns.TypeA, + }, + }, + errWrapped: ErrRRSigForNoRRSet, + errMessage: "RRSIG for no RRSet: for type A", + }, + "bad_rrsig_for_no_rrset": { + answers: []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeAAAA, + }, + }, + &dns.RRSIG{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeRRSIG, + }, + TypeCovered: dns.TypeAAAA, + }, + &dns.RRSIG{ // bad one + Hdr: dns.RR_Header{ + Rrtype: dns.TypeRRSIG, + }, + TypeCovered: dns.TypeA, + }, + }, + errWrapped: ErrRRSigForNoRRSet, + errMessage: "RRSIG for no RRSet: for type A", + }, + "bad_multiple_rrsig_for_same_type": { + answers: []dns.RR{ + &dns.RRSIG{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeRRSIG, + }, + TypeCovered: dns.TypeA, + }, + &dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + }, + }, + &dns.RRSIG{ // bad one + Hdr: dns.RR_Header{ + Rrtype: dns.TypeRRSIG, + }, + TypeCovered: dns.TypeA, + }, + }, + errWrapped: ErrRRSigMultipleForType, + errMessage: "multiple RRSIGs for the same record type: A", + }, + "only_signed_rrsets": { + answers: []dns.RR{ + &dns.RRSIG{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeRRSIG, + }, + TypeCovered: dns.TypeA, + }, + &dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + }, + }, + &dns.AAAA{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeAAAA, + }, + }, + &dns.RRSIG{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeRRSIG, + }, + TypeCovered: dns.TypeAAAA, + }, + }, + signedRRSets: []signedRRSet{ + { + rrsig: &dns.RRSIG{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeRRSIG, + }, + TypeCovered: dns.TypeA, + }, + rrset: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + }, + }, + }, + }, + { + rrsig: &dns.RRSIG{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeRRSIG, + }, + TypeCovered: dns.TypeAAAA, + }, + rrset: []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeAAAA, + }, + }, + }, + }, + }, + notSignedRRSet: []dns.RR{}, + }, + "signed_and_not_signed_rrsets": { + answers: []dns.RR{ + &dns.RRSIG{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeRRSIG, + }, + TypeCovered: dns.TypeA, + }, + &dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + }, + }, + &dns.AAAA{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeAAAA, + }, + }, + }, + signedRRSets: []signedRRSet{ + { + rrsig: &dns.RRSIG{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeRRSIG, + }, + TypeCovered: dns.TypeA, + }, + rrset: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + }, + }, + }, + }, + }, + notSignedRRSet: []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeAAAA, + }, + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + rrSigToRRSet, notSignedRRSet, err := + answersToRRSigRRSet(testCase.answers) + + assert.Equal(t, testCase.signedRRSets, rrSigToRRSet) + assert.Equal(t, testCase.notSignedRRSet, notSignedRRSet) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} diff --git a/pkg/dnssec/helpers.go b/pkg/dnssec/helpers.go new file mode 100644 index 0000000..1061300 --- /dev/null +++ b/pkg/dnssec/helpers.go @@ -0,0 +1,7 @@ +package dnssec + +import "github.com/miekg/dns" + +func queryParamsToString(zone string, qClass, qType uint16) string { + return zone + " " + dns.ClassToString[qClass] + " " + dns.TypeToString[qType] +} diff --git a/pkg/dnssec/integration_test.go b/pkg/dnssec/integration_test.go new file mode 100644 index 0000000..7504196 --- /dev/null +++ b/pkg/dnssec/integration_test.go @@ -0,0 +1,138 @@ +//go:build integration +// +build integration + +package dnssec + +import ( + "context" + "net" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/qdm12/dns/v2/internal/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func getRRSetWithoutValidation(t *testing.T, zone string, + qType, qClass uint16) (rrset []dns.RR) { + t.Helper() + + request := new(dns.Msg) + request.SetQuestion(zone, qType) + request.Question[0].Qclass = qClass + + response, _, err := new(dns.Client).Exchange(request, "1.1.1.1:53") + require.NoError(t, err) + + // Clear TTL since they are not predicatable + for _, rr := range response.Answer { + rr.Header().Ttl = 0 + } + + return response.Answer +} + +func testExchange() server.Exchange { + client := &dns.Client{} + dialer := &net.Dialer{} + return func(ctx context.Context, request *dns.Msg) (response *dns.Msg, err error) { + netConn, err := dialer.DialContext(ctx, "udp", "1.1.1.1:53") + if err != nil { + return nil, err + } + + dnsConn := &dns.Conn{Conn: netConn} + response, _, err = client.ExchangeWithConn(request, dnsConn) + + _ = dnsConn.Close() + + return response, err + } +} + +func Test_fetchAndValidateZone(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + zone string + dnsType uint16 + exchange server.Exchange + rrset []dns.RR + errWrapped error + errMessage string + }{ + // "valid_dnssec": { + // zone: "qqq.ninja.", + // dnsType: dns.TypeA, + // rrset: getRRSetWithoutValidation(t, "qqq.ninja.", dns.TypeA, dns.ClassINET), + // exchange: testExchange(), + // }, + // "no_rrset_no_rsig": { + // zone: "vip.icann.org.", + // dnsType: dns.TypeA, + // exchange: testExchange(), + // }, + "no_ds": { + zone: "textsecure-service.whispersystems.org.", + dnsType: dns.TypeA, + rrset: getRRSetWithoutValidation(t, "textsecure-service.whispersystems.org.", dns.TypeA, dns.ClassINET), + exchange: testExchange(), + }, + // "no DNSSEC": { + // zone: "github.com.", + // dnsType: dns.TypeA, + // rrset: getRRSetWithoutValidation(t, "github.com.", dns.TypeA, dns.ClassINET), + // exchange: testExchange(), + // }, + // "bad DNSSEC already failed by upstream": { + // zone: "dnssec-failed.org.", + // dnsType: dns.TypeA, + // exchange: testExchange(), + // errWrapped: ErrRcodeServerFailure, + // errMessage: "cannot fetch desired RRSet and RRSig: " + + // "for dnssec-failed.org. IN A: " + + // "DNSSEC validation might had failed upstream", + // }, + } + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + deadline, ok := t.Deadline() + if !ok { + deadline = time.Now().Add(5 * time.Second) + } + + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + rrset, err := fetchAndValidateZone(ctx, testCase.exchange, + testCase.zone, dns.ClassINET, testCase.dnsType) + + // Remove TTL fields from rrset + for i := range rrset { + rrset[i].Header().Ttl = 0 + } + + assert.Equal(t, testCase.rrset, rrset) + require.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Benchmark_fetchAndValidateZone(b *testing.B) { + ctx := context.Background() + const zone = "qqq.ninja." + const dnsType = dns.TypeA + exchange := testExchange() + + for i := 0; i < b.N; i++ { + _, _ = fetchAndValidateZone(ctx, exchange, zone, dns.ClassINET, dnsType) + } +} diff --git a/pkg/dnssec/signedzone.go b/pkg/dnssec/signedzone.go new file mode 100644 index 0000000..4da25e2 --- /dev/null +++ b/pkg/dnssec/signedzone.go @@ -0,0 +1,50 @@ +package dnssec + +import ( + "errors" + "fmt" + "time" + + "github.com/miekg/dns" +) + +type signedZone struct { + zone string + dnsKeyRRSet signedRRSet + dsRRSet signedRRSet + keyTagToDNSKey map[uint16]*dns.DNSKEY +} + +type signedRRSet struct { + rrsig *dns.RRSIG + rrset []dns.RR +} + +func (s signedRRSet) Type() string { + return dns.TypeToString[s.rrsig.TypeCovered] +} + +var ( + ErrRRSigExpired = errors.New("RRSIG has expired") +) + +func (s signedRRSet) verify(keyTagToDNSKey map[uint16]*dns.DNSKEY) ( + err error) { + if !s.rrsig.ValidityPeriod(time.Now()) { + return fmt.Errorf("%w", ErrRRSigExpired) + } + + keyTag := s.rrsig.KeyTag + dnsKey, ok := keyTagToDNSKey[keyTag] + if !ok { + return fmt.Errorf("for RRSIG key tag %d: %w", + keyTag, ErrDNSKeyNotFound) + } + + err = s.rrsig.Verify(dnsKey, s.rrset) + if err != nil { + return fmt.Errorf("verification failed: %w", err) + } + + return nil +} diff --git a/pkg/dnssec/verify.go b/pkg/dnssec/verify.go new file mode 100644 index 0000000..f638b59 --- /dev/null +++ b/pkg/dnssec/verify.go @@ -0,0 +1,138 @@ +package dnssec + +import ( + "errors" + "fmt" + "strings" + + "github.com/miekg/dns" +) + +// verify uses the zone data in the signed zone and its parent signed zones +// to validate the DNSSEC chain of trust. +// It starts the verification with the RRSet given as argument, and, +// assuming a signature is valid, it walks through the slice of signed +// zones checking the RRSIGs on the DNSKEY and DS resource record sets. +func verifyWithChain(signedRRSet signedRRSet, chain []*signedZone) error { + childDesiredZone := chain[0] // child desired zone + + // Verify desired RRSet with DNSKEY of desired zone matching the + // RRSIG key tag. + err := signedRRSet.verify( + childDesiredZone.keyTagToDNSKey) + if err != nil { + return fmt.Errorf("verifying desired RRSet: %w", err) + } + + chainWithoutRoot := chain[:len(chain)-1] + for i, signedZone := range chainWithoutRoot { + // Iterate in this order: "example.com.", "com." + + // Verify DNSKEY RRSet with its RRSIG and the DNSKEY matching + // the RRSIG key tag. + err := signedZone.dnsKeyRRSet.verify( + signedZone.keyTagToDNSKey) + if err != nil { + return fmt.Errorf("verifying DNSKEY records for zone %s: %w", + signedZone.zone, err) + } + + // Verify DS RRSet with its RRSIG and the DNSKEY of its parent zone + // matching the RRSIG key tag. + parentSignedZone := chain[i+1] + err = signedZone.dsRRSet.verify( + parentSignedZone.keyTagToDNSKey) + if err != nil { + return fmt.Errorf("verifying DS records for zone %s: %w", + signedZone.zone, err) + } + + // Verify DS RRSet digests with their corresponding DNSKEYs. + err = verifyDSRRSet(signedZone.dsRRSet.rrset, + signedZone.keyTagToDNSKey) + if err != nil { + return fmt.Errorf("verifying DS RRSet for zone %s: %w", + signedZone.zone, err) + } + } + + // Verify the root zone "." + rootZone := chain[len(chain)-1] + + // Verify DNSKEY RRSet with its RRSIG and the DNSKEY matching + // the RRSIG key tag. + err = rootZone.dnsKeyRRSet.verify( + rootZone.keyTagToDNSKey) + if err != nil { + return fmt.Errorf("verifying DNSKEY records for the root zone: %w", + err) + } + + // Verify the root anchor digest against the digest of the DS + // calculated from the DNSKEY of the root zone matching the + // root anchor key tag. + const ( + rootAnchorKeyTag = 20326 + rootAnchorDigest = "E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D" + ) + rootAnchor := &dns.DS{ + Algorithm: dns.RSASHA256, + DigestType: dns.SHA256, + KeyTag: rootAnchorKeyTag, + Digest: rootAnchorDigest, + } + err = verifyDS(rootAnchor, rootZone.keyTagToDNSKey) + if err != nil { + return fmt.Errorf("verifying the root anchor: %w", err) + } + + return nil +} + +// verifyDSRRSet verifies the digest of each received DS +// is equal to the digest of the calculated DS obtained +// from the DNSKEY (KSK) matching the received DS key tag. +func verifyDSRRSet(dsRRSet []dns.RR, + keyTagToDNSKey map[uint16]*dns.DNSKEY) (err error) { + for _, rr := range dsRRSet { + ds, ok := rr.(*dns.DS) + if !ok { + panic(fmt.Sprintf("RR is of type %T and not of type *dns.DS", rr)) + } + err = verifyDS(ds, keyTagToDNSKey) + if err != nil { + return fmt.Errorf("verifying DS record: %w", err) + } + } + return nil +} + +var ( + ErrDNSKeyNotFound = errors.New("DNS Key record not found") + ErrDNSKeyToDS = errors.New("failed to calculate DS from DNSKEY") + ErrDNSKeyDSMismatch = errors.New("DS does not match DNS key") +) + +func verifyDS(receivedDS *dns.DS, + keyTagToDNSKey map[uint16]*dns.DNSKEY) error { + dnsKey, ok := keyTagToDNSKey[receivedDS.KeyTag] + if !ok { + return fmt.Errorf("%w: for key tag %d", + ErrDNSKeyNotFound, receivedDS.KeyTag) + } + + calculatedDS := dnsKey.ToDS(receivedDS.DigestType) + if calculatedDS == nil { + return fmt.Errorf("%w: for DNSKEY name %s and digest type %d", + ErrDNSKeyToDS, dnsKey.Header().Name, receivedDS.DigestType) + } + + if !strings.EqualFold(receivedDS.Digest, calculatedDS.Digest) { + return fmt.Errorf("%w: DS record has digest %s "+ + "but DNSKEY calculated DS has digest %s", + ErrDNSKeyDSMismatch, receivedDS.Digest, calculatedDS.Digest) + } + // TODO NSEC + + return nil +} diff --git a/pkg/doh/handler.go b/pkg/doh/handler.go index a2147d5..b011d63 100644 --- a/pkg/doh/handler.go +++ b/pkg/doh/handler.go @@ -4,6 +4,7 @@ import ( "context" "github.com/qdm12/dns/v2/internal/server" + "github.com/qdm12/dns/v2/pkg/dnssec" ) func newDNSHandler(ctx context.Context, settings ServerSettings) ( @@ -11,6 +12,10 @@ func newDNSHandler(ctx context.Context, settings ServerSettings) ( dial := newDoHDial(settings.Resolver) exchange := server.NewExchange("DoH", dial, settings.Logger) + if *settings.DNSSEC { + dnssecClient := dnssec.New(exchange) + exchange = dnssecClient.Exchange + } return server.New(ctx, exchange, settings.Logger) } diff --git a/pkg/doh/integration_test.go b/pkg/doh/integration_test.go index 391cfd4..0811384 100644 --- a/pkg/doh/integration_test.go +++ b/pkg/doh/integration_test.go @@ -266,6 +266,7 @@ func Test_Server_Mocks(t *testing.T) { Metrics: metrics, }, ListeningAddress: ptrTo(""), + DNSSEC: ptrTo(false), }) require.NoError(t, err) diff --git a/pkg/doh/settings.go b/pkg/doh/settings.go index 231f384..567084b 100644 --- a/pkg/doh/settings.go +++ b/pkg/doh/settings.go @@ -19,6 +19,7 @@ import ( type ServerSettings struct { Resolver ResolverSettings ListeningAddress *string + DNSSEC *bool // Middlewares is a list of middlewares to use. // The first one is the first wrapper, and the last one // is the last wrapper of the handlers in the chain. @@ -47,6 +48,7 @@ func (s *ServerSettings) SetDefaults() { s.Resolver.SetDefaults() s.ListeningAddress = gosettings.DefaultPointer(s.ListeningAddress, ":53") s.Logger = gosettings.DefaultComparable[Logger](s.Logger, lognoop.New()) + s.DNSSEC = gosettings.DefaultPointer(s.DNSSEC, false) } func (s *ResolverSettings) SetDefaults() { @@ -113,6 +115,7 @@ func (s *ResolverSettings) String() string { func (s *ServerSettings) ToLinesNode() (node *gotree.Node) { node = gotree.New("DoH server settings:") node.Appendf("Listening address: %s", *s.ListeningAddress) + node.Appendf("DNSSEC: %s", gosettings.BoolToYesNo(s.DNSSEC)) node.AppendNode(s.Resolver.ToLinesNode()) return node } diff --git a/pkg/doh/settings_test.go b/pkg/doh/settings_test.go index ab63a20..6af9ec8 100644 --- a/pkg/doh/settings_test.go +++ b/pkg/doh/settings_test.go @@ -40,6 +40,7 @@ func Test_ServerSettings_SetDefaults(t *testing.T) { Metrics: metrics, }, ListeningAddress: ptrTo(":53"), + DNSSEC: ptrTo(true), } assert.Equal(t, expectedSettings, s) } diff --git a/pkg/dot/handler.go b/pkg/dot/handler.go index 89f6875..a6fdf7f 100644 --- a/pkg/dot/handler.go +++ b/pkg/dot/handler.go @@ -4,11 +4,17 @@ import ( "context" "github.com/qdm12/dns/v2/internal/server" + "github.com/qdm12/dns/v2/pkg/dnssec" ) func newDNSHandler(ctx context.Context, settings ServerSettings) ( handler *server.Handler) { dial := newDoTDial(settings.Resolver) exchange := server.NewExchange("DoT", dial, settings.Logger) + if *settings.DNSSEC { + dnssecClient := dnssec.New(exchange) + exchange = dnssecClient.Exchange + } + return server.New(ctx, exchange, settings.Logger) } diff --git a/pkg/dot/integration_test.go b/pkg/dot/integration_test.go index d98c710..63f65d3 100644 --- a/pkg/dot/integration_test.go +++ b/pkg/dot/integration_test.go @@ -258,6 +258,7 @@ func Test_Server_Mocks(t *testing.T) { require.NoError(t, err) server, err := NewServer(ServerSettings{ + DNSSEC: ptrTo(false), Logger: logger, Middlewares: []Middleware{metricsMiddleware, cacheMiddleware, filterMiddleware}, Resolver: ResolverSettings{ diff --git a/pkg/dot/settings.go b/pkg/dot/settings.go index 39a12c8..4bcdfe6 100644 --- a/pkg/dot/settings.go +++ b/pkg/dot/settings.go @@ -19,6 +19,7 @@ import ( type ServerSettings struct { Resolver ResolverSettings ListeningAddress *string + DNSSEC *bool // Middlewares is a list of middlewares to use. // The first one is the first wrapper, and the last one // is the last wrapper of the handlers in the chain. @@ -55,6 +56,7 @@ func (s *ServerSettings) SetDefaults() { s.Resolver.SetDefaults() s.ListeningAddress = gosettings.DefaultPointer(s.ListeningAddress, ":53") s.Logger = gosettings.DefaultComparable[Logger](s.Logger, lognoop.New()) + s.DNSSEC = gosettings.DefaultPointer(s.DNSSEC, false) } func (s *ResolverSettings) SetDefaults() { @@ -124,6 +126,7 @@ func (s *ResolverSettings) String() string { func (s *ServerSettings) ToLinesNode() (node *gotree.Node) { node = gotree.New("DoT server settings:") node.Appendf("Listening address: %s", *s.ListeningAddress) + node.Appendf("DNSSEC: %s", gosettings.BoolToYesNo(s.DNSSEC)) node.AppendNode(s.Resolver.ToLinesNode()) return node }