diff --git a/TODO.txt b/TODO.txt new file mode 100644 index 0000000..ecb8774 --- /dev/null +++ b/TODO.txt @@ -0,0 +1,6 @@ +- CNAME resolution +- Ensure NS records don't point inward (will require changes to unrelated tests) +- Alias values via `!alias [index, default 0]` + - e.g. `www: !alias example.com A` + - Maybe ALIAS "records" too, to reduce complexity vs + - Either way, resolved at validate time, not at query time, and only searches within loaded zones, not external diff --git a/testdata/bad_cname_with_other.yaml b/testdata/bad_cname_with_other.yaml new file mode 100644 index 0000000..753e494 --- /dev/null +++ b/testdata/bad_cname_with_other.yaml @@ -0,0 +1,16 @@ +org: + example: + "@": + - type: SOA + value: ns1.example.com. admin.example.com. 1 1 1 1 1 + - type: A + value: 192.0.2.100 + - type: AAAA + value: 2001:db8::100 + - type: NS + value: ns1.example.com + "ext": + - type: CNAME + value: example.com + - type: A + value: 1.1.1.1 diff --git a/testdata/bad_glue_with_other.yaml b/testdata/bad_glue_with_other.yaml new file mode 100644 index 0000000..cdd5f9a --- /dev/null +++ b/testdata/bad_glue_with_other.yaml @@ -0,0 +1,16 @@ +org: + example: + "@": + - type: SOA + value: ns1.example.com. admin.example.com. 1 1 1 1 1 + - type: NS + value: ns1.example.com + delegated: + "@": + - type: NS + value: ns1.delegated.example.org + "ns1": + - type: A + value: 1.2.3.4 + - type: TXT + value: foobar diff --git a/testdata/bad_ns_with_other.yaml b/testdata/bad_ns_with_other.yaml new file mode 100644 index 0000000..cc045a1 --- /dev/null +++ b/testdata/bad_ns_with_other.yaml @@ -0,0 +1,16 @@ +org: + example: + "@": + - type: SOA + value: ns1.example.com. admin.example.com. 1 1 1 1 1 + - type: A + value: 192.0.2.100 + - type: AAAA + value: 2001:db8::100 + - type: NS + value: ns1.example.com + "ext": + - type: NS + value: ns1.example.com + - type: A + value: 1.1.1.1 diff --git a/testdata/bad_ns_with_subzone.yaml b/testdata/bad_ns_with_subzone.yaml new file mode 100644 index 0000000..65bf056 --- /dev/null +++ b/testdata/bad_ns_with_subzone.yaml @@ -0,0 +1,14 @@ +org: + example: + "@": + - type: SOA + value: ns1.example.com. admin.example.com. 1 1 1 1 1 + - type: NS + value: ns1.example.com + delegated: + "@": + - type: NS + value: ns1.example.com + subzone: + - type: A + value: 1.2.3.4 diff --git a/testdata/zones.yaml b/testdata/zones.yaml index b45ea3b..2f3ebe4 100644 --- a/testdata/zones.yaml +++ b/testdata/zones.yaml @@ -59,6 +59,15 @@ com: # Comment - type: NS ttl: 3600 value: ns2.example.org + client: + "@": + - type: NS + ttl: 3600 + value: ns1.client.example.com + ns1: + - type: A + ttl: 3600 + value: 192.0.2.3 folders: !include nested/zone.yaml org: example: !include example.org.yaml diff --git a/yamlplugin.go b/yamlplugin.go index 7f42f4a..8f9b06b 100644 --- a/yamlplugin.go +++ b/yamlplugin.go @@ -26,26 +26,6 @@ type YamlPluginConfig struct { var log = clog.NewWithPlugin("yaml") -func (y YamlPlugin) lookupRRs(qname string, qtype string) ([]dns.RR, bool) { - records, ok := y.Zone.LookupType(qname, qtype) - if !ok { - return nil, false - } - rrs := []dns.RR{} - for _, record := range records { - ttl := record.Ttl - if ttl == 0 { - ttl = y.Config.DefaultTtl - } - rr, err := dns.NewRR(fmt.Sprintf("%s %d %s %s", qname, ttl, record.Type, record.Value)) - if err != nil { - return nil, false - } - rrs = append(rrs, rr) - } - return rrs, true -} - func (y YamlPlugin) Name() string { return "yaml" } func (y YamlPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { @@ -57,17 +37,16 @@ func (y YamlPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M reply.SetReply(r) reply.Authoritative = true - rrs, ok := y.lookupRRs(qname, qtype) - if !ok { - return dns.RcodeNameError, nil + rcode, err := y.lookupRRs(qname, qtype, reply) + if err != nil { + return dns.RcodeServerFailure, fmt.Errorf("failed to lookup RRs for %s %s: %w", qname, qtype, err) } - reply.Answer = rrs w.WriteMsg(reply) - return dns.RcodeSuccess, nil + return rcode, nil } func setup(c *caddy.Controller) error { - c.Next() // yaml + c.Next() // skip "yaml" filename := "zones.yaml" if c.NextArg() { filename = c.Val() @@ -97,3 +76,37 @@ func setup(c *caddy.Controller) error { } func init() { plugin.Register("yaml", setup) } + +func (y YamlPlugin) lookupRRs(qname string, qtype string, reply *dns.Msg) (int, error) { + var rcode int + res, ok := y.Zone.LookupType(qname, qtype) + if !ok { + // NXDOMAIN + rcode = dns.RcodeNameError + } + if res.IsReferral { + reply.Authoritative = false + } + sections := []struct { + source []NamedRecord + dest *[]dns.RR + }{ + {res.Answer, &reply.Answer}, + {res.Ns, &reply.Ns}, + {res.Extra, &reply.Extra}, + } + for _, section := range sections { + for _, record := range section.source { + ttl := record.Record.Ttl + if ttl == 0 { + ttl = y.Config.DefaultTtl + } + rr, err := dns.NewRR(fmt.Sprintf("%s %d %s %s", record.Name, ttl, record.Record.Type, record.Record.Value)) + if err != nil { + return rcode, fmt.Errorf("failed to generate RR for %s %s %s: %w", record.Name, record.Record.Type, record.Record.Value, err) + } + *section.dest = append(*section.dest, rr) + } + } + return rcode, nil +} diff --git a/yamlzone.go b/yamlzone.go index 5438c42..b9e2ff0 100644 --- a/yamlzone.go +++ b/yamlzone.go @@ -18,11 +18,24 @@ type Record struct { Value string } +type NamedRecord struct { + Name string + Record Record +} + type Zone struct { Subzones map[string]*Zone Records []Record - GlueRecords []Record + SOA NamedRecord IsDelegationPoint bool + GlueRecords []NamedRecord // Only set (optionally) for delegation points +} + +type LookupResult struct { + Answer []NamedRecord + Ns []NamedRecord + Extra []NamedRecord + IsReferral bool } type contextKey string @@ -57,31 +70,40 @@ func LoadZoneBytes(data []byte, filename string) (*Zone, error) { if err != nil { return nil, err } - if err := z.Validate(".", false); err != nil { + if err := z.Validate(".", nil, nil); err != nil { return nil, err } return z, nil } -func (z *Zone) Validate(name string, zoneApexPresent bool) error { - nameservers := map[string]bool{} - otherRecords := []Record{} +func (z *Zone) Validate(name string, soa *Record, ns *[]Record) error { + nameservers, otherRecords := []Record{}, []Record{} + nameserversMap := map[string]bool{} isZoneApex := false + cnameCount := 0 for _, record := range z.Records { switch record.Type { case "SOA": isZoneApex = true + soa = &record + z.SOA = NamedRecord{Name: name, Record: record} case "NS": - nameservers[record.Value] = true + nameservers = append(nameservers, record) + nameserversMap[record.Value] = true + case "CNAME": + cnameCount++ default: otherRecords = append(otherRecords, record) } } - zoneApexPresent = zoneApexPresent || isZoneApex + if cnameCount > 1 || (cnameCount > 0 && len(otherRecords) > 0) { + return fmt.Errorf("%s: extraneous records found next to CNAME", name) + } + z.IsDelegationPoint = !isZoneApex && len(nameservers) > 0 - if !zoneApexPresent { + if soa == nil { // Outside zone if z.IsDelegationPoint { return fmt.Errorf("%s: delegation point found outside zone", name) @@ -93,51 +115,51 @@ func (z *Zone) Validate(name string, zoneApexPresent bool) error { if len(nameservers) == 0 { return fmt.Errorf("%s: zone apex missing NS records", name) } + ns = &nameservers } else if len(nameservers) > 0 { // Delegation point (does not fall through to subzone validation) if len(otherRecords) > 0 { return fmt.Errorf("%s: non-glue, non-NS records found at delegation point: %v", name, otherRecords) } for subname, subzone := range z.Subzones { - glueRecords, err := subzone.GetGlueRecords(concatName(name, subname), nameservers) + // This populates z.GlueRecords directly + err := subzone.GetGlueRecords(concatName(name, subname), nameserversMap) if err != nil { return err } - z.GlueRecords = append(z.GlueRecords, glueRecords...) } return nil } // Subzone validation // Either we're outside a zone, at a zone apex, or at a non-delegated subzone for subname, subzone := range z.Subzones { - if err := subzone.Validate(concatName(name, subname), zoneApexPresent); err != nil { + if err := subzone.Validate(concatName(name, subname), soa, ns); err != nil { return err } } return nil } -func (z *Zone) GetGlueRecords(name string, nameservers map[string]bool) ([]Record, error) { +func (z *Zone) GetGlueRecords(name string, nameserversMap map[string]bool) error { // If the domain is not a nameserver, it must have no records - if _, ok := nameservers[name]; !ok { + if _, ok := nameserversMap[strings.TrimSuffix(name, ".")]; !ok { if len(z.Records) > 0 { - return nil, fmt.Errorf("%s: non-glue records found under delegation point: %v", name, z.Records) + return fmt.Errorf("%s: non-glue records found under delegation point: %v", name, z.Records) } } // Any records under a delegation point must be glue records for _, record := range z.Records { if !(record.Type == "A" || record.Type == "AAAA") { - return nil, fmt.Errorf("%s: non-glue record found under delegation point: %v", name, record) + return fmt.Errorf("%s: non-glue record found under delegation point: %v", name, record) } + z.GlueRecords = append(z.GlueRecords, NamedRecord{Name: name, Record: record}) } for subname, subzone := range z.Subzones { - if glueRecords, err := subzone.GetGlueRecords(concatName(name, subname), nameservers); err != nil { - return nil, err - } else { - z.GlueRecords = append(z.GlueRecords, glueRecords...) + if err := subzone.GetGlueRecords(concatName(name, subname), nameserversMap); err != nil { + return err } } - return z.GlueRecords, nil + return nil } func (r *Record) UnmarshalYAML(ctx context.Context, data []byte) error { @@ -252,38 +274,67 @@ func nameToPath(name string) []string { return path } -func (z *Zone) Lookup(name string) ([]Record, bool) { +func (z *Zone) Lookup(name string) (LookupResult, bool) { + res := LookupResult{} path := nameToPath(name) for _, label := range path { // Support empty name and trailing dot if label == "" { continue } - if sz, ok := z.Subzones[label]; ok { - z = sz + if z.IsDelegationPoint { + res.IsReferral = true + // Capture NS records + for _, record := range z.Records { + if record.Type == "NS" { + res.Ns = append(res.Ns, NamedRecord{Name: name, Record: record}) + } + } + // Retrieve glue records from cache + res.Extra = append(res.Extra, z.GlueRecords...) + return res, true + } + if sz, ok := z.Subzones[label]; !ok { + // NXDOMAIN + res.Ns = []NamedRecord{z.SOA} + return res, false } else { - return nil, false + z = sz } } - return z.Records, true + res.Answer = []NamedRecord{} + for _, record := range z.Records { + res.Answer = append(res.Answer, NamedRecord{Name: name, Record: record}) + } + if len(res.Answer) == 0 { + // NODATA + res.Ns = []NamedRecord{z.SOA} + return res, true + } + return res, true } -func (z *Zone) FilterRecords(records []Record, recordType string) []Record { - filtered := []Record{} - for _, record := range records { - if record.Type == recordType { - filtered = append(filtered, record) +func (z *Zone) FilterRecords(res LookupResult, recordType string) LookupResult { + filtered := []NamedRecord{} + for _, nr := range res.Answer { + if nr.Record.Type == recordType { + filtered = append(filtered, nr) } } - return filtered + res.Answer = filtered + if len(res.Answer) == 0 && !res.IsReferral { + // This ends up as NODATA even if it had data before filtering + res.Ns = []NamedRecord{z.SOA} + } + return res } -func (z *Zone) LookupType(name string, recordType string) ([]Record, bool) { - records, ok := z.Lookup(name) +func (z *Zone) LookupType(name string, recordType string) (LookupResult, bool) { + res, ok := z.Lookup(name) if !ok { - return nil, false + return LookupResult{}, false } - return z.FilterRecords(records, recordType), true + return z.FilterRecords(res, recordType), true } func concatName(name string, subname string) string { diff --git a/yamlzone_test.go b/yamlzone_test.go index 66b4491..c641fb2 100644 --- a/yamlzone_test.go +++ b/yamlzone_test.go @@ -14,15 +14,15 @@ func assertOk(t *testing.T, ok bool) { } } -func assertNotOk(t *testing.T, ok bool, records []Record) { +func assertNotOk(t *testing.T, ok bool, res LookupResult) { if ok { - t.Fatalf("Expected not ok, got ok (%v)", records) + t.Fatalf("Expected not ok, got ok (%v)", res) } } -func assertRecordCount(t *testing.T, records []Record, expected int) { - if len(records) != expected { - t.Fatalf("Expected %d records, got %d (%v)", expected, len(records), records) +func assertRecordCount(t *testing.T, res LookupResult, expected int) { + if len(res.Answer) != expected { + t.Fatalf("Expected %d records, got %d (%v)", expected, len(res.Answer), res.Answer) } } @@ -65,25 +65,25 @@ func TestEmptyZone(t *testing.T) { } t.Run("Lookup empty string", func(t *testing.T) { - records, ok := zEmpty.Lookup("") + res, ok := zEmpty.Lookup("") assertOk(t, ok) - assertRecordCount(t, records, 0) + assertRecordCount(t, res, 0) }) t.Run("Lookup single dot", func(t *testing.T) { - records, ok := zEmpty.Lookup(".") + res, ok := zEmpty.Lookup(".") assertOk(t, ok) - assertRecordCount(t, records, 0) + assertRecordCount(t, res, 0) }) t.Run("Lookup example.com", func(t *testing.T) { - records, ok := zEmpty.Lookup("example.com") - assertNotOk(t, ok, records) + res, ok := zEmpty.Lookup("example.com") + assertNotOk(t, ok, res) }) t.Run("LookupType example.com A", func(t *testing.T) { - records, ok := zEmpty.LookupType("example.com", "A") - assertNotOk(t, ok, records) + res, ok := zEmpty.LookupType("example.com", "A") + assertNotOk(t, ok, res) }) } @@ -104,23 +104,23 @@ func TestSimpleZone(t *testing.T) { } t.Run("Lookup", func(t *testing.T) { - records, ok := zSimple.Lookup("") + res, ok := zSimple.Lookup("") assertOk(t, ok) - assertRecordCount(t, records, 4) - assertRecord(t, records[0], "SOA", 0, "ns1.example.com. admin.example.com. 1 1 1 1 1") - assertRecord(t, records[1], "A", 0, "192.0.2.100") - assertRecord(t, records[2], "AAAA", 0, "2001:db8::100") - assertRecord(t, records[3], "NS", 0, "ns1.example.com") + assertRecordCount(t, res, 4) + assertRecord(t, res.Answer[0].Record, "SOA", 0, "ns1.example.com. admin.example.com. 1 1 1 1 1") + assertRecord(t, res.Answer[1].Record, "A", 0, "192.0.2.100") + assertRecord(t, res.Answer[2].Record, "AAAA", 0, "2001:db8::100") + assertRecord(t, res.Answer[3].Record, "NS", 0, "ns1.example.com") }) t.Run("LookupType", func(t *testing.T) { - records, ok := zSimple.LookupType("", "A") + res, ok := zSimple.LookupType("", "A") if !ok { t.Fatalf("Expected ok, got false") } - assertRecordCount(t, records, 1) - assertRecord(t, records[0], "A", 0, "192.0.2.100") + assertRecordCount(t, res, 1) + assertRecord(t, res.Answer[0].Record, "A", 0, "192.0.2.100") }) } @@ -141,93 +141,93 @@ func TestFullZone(t *testing.T) { } t.Run("Lookup example.com", func(t *testing.T) { - records, ok := zFull.Lookup("example.com") + res, ok := zFull.Lookup("example.com") if !ok { t.Fatalf("Expected ok, got false") } - assertRecordCount(t, records, 8) - assertRecord(t, records[0], "SOA", 0, "ns1.example.com. admin.example.com. 1 1 1 1 1") - assertRecord(t, records[1], "A", 0, "192.0.2.1") - assertRecord(t, records[2], "AAAA", 0, "2001:db8::1") - assertRecord(t, records[3], "MX", 3600, "10 mail.example.com") // Default TTL - assertRecord(t, records[4], "TXT", 300, "v=spf1 a mx include:mail.example.com ~all") - assertRecord(t, records[5], "CAA", 86400, "0 issue \"letsencrypt.org\"") - assertRecord(t, records[6], "TXT", 3600, "foo=bar") - assertRecord(t, records[7], "NS", 0, "ns1.example.com") + assertRecordCount(t, res, 8) + assertRecord(t, res.Answer[0].Record, "SOA", 0, "ns1.example.com. admin.example.com. 1 1 1 1 1") + assertRecord(t, res.Answer[1].Record, "A", 0, "192.0.2.1") + assertRecord(t, res.Answer[2].Record, "AAAA", 0, "2001:db8::1") + assertRecord(t, res.Answer[3].Record, "MX", 3600, "10 mail.example.com") // Default TTL + assertRecord(t, res.Answer[4].Record, "TXT", 300, "v=spf1 a mx include:mail.example.com ~all") + assertRecord(t, res.Answer[5].Record, "CAA", 86400, "0 issue \"letsencrypt.org\"") + assertRecord(t, res.Answer[6].Record, "TXT", 3600, "foo=bar") + assertRecord(t, res.Answer[7].Record, "NS", 0, "ns1.example.com") }) t.Run("LookupType example.com TXT", func(t *testing.T) { - records, ok := zFull.LookupType("example.com", "TXT") + res, ok := zFull.LookupType("example.com", "TXT") assertOk(t, ok) - assertRecordCount(t, records, 2) - assertRecord(t, records[0], "TXT", 300, "v=spf1 a mx include:mail.example.com ~all") - assertRecord(t, records[1], "TXT", 3600, "foo=bar") + assertRecordCount(t, res, 2) + assertRecord(t, res.Answer[0].Record, "TXT", 300, "v=spf1 a mx include:mail.example.com ~all") + assertRecord(t, res.Answer[1].Record, "TXT", 3600, "foo=bar") }) t.Run("Lookup www.example.com", func(t *testing.T) { - records, ok := zFull.Lookup("www.example.com") + res, ok := zFull.Lookup("www.example.com") assertOk(t, ok) - assertRecordCount(t, records, 1) - assertRecord(t, records[0], "CNAME", 3600, "example.com") + assertRecordCount(t, res, 1) + assertRecord(t, res.Answer[0].Record, "CNAME", 3600, "example.com") }) t.Run("LookupType www.example.com CNAME", func(t *testing.T) { - records, ok := zFull.LookupType("www.example.com", "CNAME") + res, ok := zFull.LookupType("www.example.com", "CNAME") assertOk(t, ok) - assertRecordCount(t, records, 1) - assertRecord(t, records[0], "CNAME", 3600, "example.com") + assertRecordCount(t, res, 1) + assertRecord(t, res.Answer[0].Record, "CNAME", 3600, "example.com") }) t.Run("Lookup www.example.com TXT", func(t *testing.T) { - records, ok := zFull.LookupType("www.example.com", "TXT") + res, ok := zFull.LookupType("www.example.com", "TXT") assertOk(t, ok) - assertRecordCount(t, records, 0) + assertRecordCount(t, res, 0) }) t.Run("Lookup status.example.com", func(t *testing.T) { - records, ok := zFull.Lookup("status.example.com") + res, ok := zFull.Lookup("status.example.com") assertOk(t, ok) - assertRecordCount(t, records, 2) - assertRecord(t, records[0], "A", 3600, "198.51.100.24") - assertRecord(t, records[1], "A", 3600, "203.0.113.24") + assertRecordCount(t, res, 2) + assertRecord(t, res.Answer[0].Record, "A", 3600, "198.51.100.24") + assertRecord(t, res.Answer[1].Record, "A", 3600, "203.0.113.24") }) t.Run("Lookup partner.example.com", func(t *testing.T) { - records, ok := zFull.Lookup("partner.example.com") + res, ok := zFull.Lookup("partner.example.com") assertOk(t, ok) - assertRecordCount(t, records, 2) - assertRecord(t, records[0], "NS", 3600, "ns1.example.org") - assertRecord(t, records[1], "NS", 3600, "ns2.example.org") + assertRecordCount(t, res, 2) + assertRecord(t, res.Answer[0].Record, "NS", 3600, "ns1.example.org") + assertRecord(t, res.Answer[1].Record, "NS", 3600, "ns2.example.org") }) t.Run("Lookup unused.example.com", func(t *testing.T) { - records, ok := zFull.Lookup("unused.example.com") + res, ok := zFull.Lookup("unused.example.com") assertOk(t, ok) - assertRecordCount(t, records, 0) + assertRecordCount(t, res, 0) }) t.Run("Lookup ftp.internal.example.com", func(t *testing.T) { - records, ok := zFull.Lookup("ftp.internal.example.com") + res, ok := zFull.Lookup("ftp.internal.example.com") assertOk(t, ok) - assertRecordCount(t, records, 1) - assertRecord(t, records[0], "A", 3600, "10.0.0.2") + assertRecordCount(t, res, 1) + assertRecord(t, res.Answer[0].Record, "A", 3600, "10.0.0.2") }) t.Run("Lookup _xmpp-server._tcp.example.com", func(t *testing.T) { - records, ok := zFull.Lookup("_xmpp-server._tcp.example.com") + res, ok := zFull.Lookup("_xmpp-server._tcp.example.com") assertOk(t, ok) - assertRecordCount(t, records, 1) - assertRecord(t, records[0], "SRV", 3600, "10 0 5269 example.com") + assertRecordCount(t, res, 1) + assertRecord(t, res.Answer[0].Record, "SRV", 3600, "10 0 5269 example.com") }) t.Run("Lookup multilayer.nested.folders.example.com", func(t *testing.T) { - records, ok := zFull.Lookup("multilayer.nested.folders.example.com") + res, ok := zFull.Lookup("multilayer.nested.folders.example.com") assertOk(t, ok) - assertRecordCount(t, records, 1) - assertRecord(t, records[0], "A", 3600, "192.0.2.1") + assertRecordCount(t, res, 1) + assertRecord(t, res.Answer[0].Record, "A", 3600, "192.0.2.1") }) } @@ -238,6 +238,11 @@ func TestBadZones(t *testing.T) { errorSubstring string } var badZones = []badZone{ + { + name: "CnameWithOther", + filename: "testdata/bad_cname_with_other.yaml", + errorSubstring: "extraneous records found next to CNAME", + }, { name: "NonexistentFile", filename: "testdata/bad_nonexistent.yaml", @@ -303,6 +308,26 @@ func TestBadZones(t *testing.T) { filename: "testdata/bad_missing_ns.yaml", errorSubstring: "zone apex missing NS records", }, + { + name: "NsWithOther", + filename: "testdata/bad_ns_with_other.yaml", + errorSubstring: "non-glue, non-NS records found at delegation point", + }, + { + name: "NsWithSubzone", + filename: "testdata/bad_ns_with_subzone.yaml", + errorSubstring: "non-glue records found under delegation point", + }, + { + name: "GlueWithOther", + filename: "testdata/bad_glue_with_other.yaml", + errorSubstring: "non-glue record found under delegation point", + }, + { + name: "CnameWithOther", + filename: "testdata/bad_cname_with_other.yaml", + errorSubstring: "extraneous records found next to CNAME", + }, } for _, badZone := range badZones {