// coredns-yaml/yamlzone.go
//
// 345 lines, 8.9 KiB, Go

package yamlzone
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/goccy/go-yaml"
"github.com/goccy/go-yaml/ast"
"github.com/goccy/go-yaml/parser"
)
// Record is a single resource record as written in the zone YAML, under the
// keys "type", "ttl" and "value" (see Record.UnmarshalYAML).
type Record struct {
	// Type is the record type mnemonic, e.g. "SOA", "NS", "CNAME", "A".
	Type  string
	// Ttl is the record's time to live.
	// NOTE(review): presumably seconds, as in DNS — confirm with the consumer.
	Ttl   uint64
	// Value is the record data as a single string; for NS records it is the
	// nameserver's host name.
	Value string
}
// NamedRecord pairs a Record with the owner name it is served under.
type NamedRecord struct {
	// Name is the owner name of Record.
	Name   string
	Record Record
}
// Zone is one node of the name tree, keyed label by label from the DNS root
// downwards (LoadZoneBytes validates the tree it returns as rooted at ".").
// A node may lie outside any zone, be a zone apex (it has an SOA record), be
// a name inside a zone, or be a delegation point (NS records but no SOA);
// see Validate.
type Zone struct {
	// Subzones maps each child label (e.g. "example" below "com") to its node.
	Subzones          map[string]*Zone
	// Records are the records owned by this node's name.
	Records           []Record
	// SOA is set by Validate on zone apex nodes to the apex's SOA record.
	SOA               NamedRecord
	// IsDelegationPoint is set by Validate on nodes that have NS records but
	// no SOA record.
	IsDelegationPoint bool
	GlueRecords       []NamedRecord // Only set (optionally) for delegation points; Lookup returns a delegation point's GlueRecords in LookupResult.Extra on referrals
}
// LookupResult is the outcome of Lookup, grouped roughly like the answer,
// authority and additional sections of a DNS response.
type LookupResult struct {
	// Answer holds the records found at the queried name.
	Answer     []NamedRecord
	// Ns holds the NS records of a referral, or a single SOA record for a
	// negative (NXDOMAIN / NODATA) answer.
	Ns         []NamedRecord
	// Extra holds the glue records that accompany a referral.
	Extra      []NamedRecord
	// IsReferral reports that the result is a referral to a delegated zone
	// rather than an answer from this tree.
	IsReferral bool
}
// contextKey is a package-private key type for values stored in a
// context.Context, so they cannot collide with keys set by other packages.
type contextKey string

// Keys for the state that Zone.UnmarshalYAML threads through nested decoding.
const (
	// ctxDirectory holds the directory (a string) against which relative
	// !include paths are resolved.
	ctxDirectory contextKey = "directory"
	// ctxFiles holds the set (a map[string]bool) of already-included files,
	// used to detect include cycles.
	ctxFiles     contextKey = "files"
)
// LoadZone reads the zone YAML stored in filename and parses and validates
// it with LoadZoneBytes. Relative !include paths inside the file are resolved
// against filename's directory.
func LoadZone(filename string) (*Zone, error) {
	contents, readErr := os.ReadFile(filename)
	if readErr != nil {
		return nil, readErr
	}
	return LoadZoneBytes(contents, filename)
}
// LoadZoneBytes parses a zone tree from data and validates it as rooted at
// the DNS root ".".
//
// filename names the file data came from. Its directory is used to resolve
// relative !include paths, and the file itself seeds the include-cycle
// detection. data must contain exactly one YAML document.
func LoadZoneBytes(data []byte, filename string) (*Zone, error) {
	// Check for multiple documents before unmarshaling.
	file, err := parser.ParseBytes(data, 0)
	if err != nil {
		return nil, err
	}
	if len(file.Docs) != 1 {
		return nil, fmt.Errorf("expected exactly one document, got %d", len(file.Docs))
	}
	// Include paths are built with filepath.Join, which cleans them. Clean
	// the root name the same way so that a file including itself matches
	// its own entry and is rejected immediately.
	root := filepath.Clean(filename)
	ctx := context.WithValue(context.Background(), ctxDirectory, filepath.Dir(root))
	ctx = context.WithValue(ctx, ctxFiles, map[string]bool{root: true})
	z := &Zone{}
	if err := yaml.UnmarshalContext(ctx, data, z); err != nil {
		return nil, err
	}
	if err := z.Validate(".", nil, nil); err != nil {
		return nil, err
	}
	return z, nil
}
// Validate checks the zone tree rooted at z for structural errors and
// records derived state on the nodes it visits. name is the fully qualified
// owner name of z.
//
// soa is the SOA record of the zone enclosing z, or nil when z lies outside
// every zone. ns is the enclosing apex's NS set; it is passed down to
// subzones but not otherwise consulted. LoadZoneBytes starts the walk with
// Validate(".", nil, nil).
//
// Rules enforced:
//   - a node has at most one SOA record. A node with one is a zone apex,
//     must also have NS records, and gets the SOA stored in z.SOA;
//   - a CNAME must be the only record at its node;
//   - outside a zone, neither records nor delegations are allowed;
//   - a node inside a zone with NS records but no SOA is a delegation point.
//     It may hold only NS records. Everything below it must be A/AAAA glue
//     for its nameservers (checked by GetGlueRecords). That glue is gathered
//     into z.GlueRecords for referrals, and validation stops descending there.
//
// z.IsDelegationPoint is set on every visited node. The first violation is
// returned; on error the tree may be left partially annotated.
func (z *Zone) Validate(name string, soa *Record, ns *[]Record) error {
	var nameservers, cnames, otherRecords []Record
	nameserversMap := map[string]bool{}
	var apexSOA *Record
	for _, record := range z.Records {
		switch record.Type {
		case "SOA":
			if apexSOA != nil {
				return fmt.Errorf("%s: multiple SOA records found", name)
			}
			// Copy before taking the address: before Go 1.22 the loop
			// variable is shared by all iterations.
			rec := record
			apexSOA = &rec
		case "NS":
			nameservers = append(nameservers, record)
			// Key by the name without a trailing dot, which is how
			// GetGlueRecords looks nameservers up. This way both
			// "ns1.example.com" and "ns1.example.com." are recognised.
			nameserversMap[strings.TrimSuffix(record.Value, ".")] = true
		case "CNAME":
			cnames = append(cnames, record)
		default:
			otherRecords = append(otherRecords, record)
		}
	}
	isZoneApex := apexSOA != nil
	if isZoneApex {
		soa = apexSOA
		z.SOA = NamedRecord{Name: name, Record: *apexSOA}
	}
	// A CNAME may not coexist with any other data, SOA and NS included
	// (RFC 1034 §3.6.2).
	if len(cnames) > 1 || (len(cnames) == 1 && (isZoneApex || len(nameservers) > 0 || len(otherRecords) > 0)) {
		return fmt.Errorf("%s: extraneous records found next to CNAME", name)
	}
	z.IsDelegationPoint = !isZoneApex && len(nameservers) > 0
	switch {
	case soa == nil:
		// Outside any zone: only structure (and nested zone apexes) allowed.
		if z.IsDelegationPoint {
			return fmt.Errorf("%s: delegation point found outside zone", name)
		}
		if stray := append(otherRecords, cnames...); len(stray) > 0 {
			return fmt.Errorf("%s: records found outside zone: %v", name, stray)
		}
	case isZoneApex:
		if len(nameservers) == 0 {
			return fmt.Errorf("%s: zone apex missing NS records", name)
		}
		ns = &nameservers
	case z.IsDelegationPoint:
		// Delegation point: the data below belongs to the child zone, so
		// this does not fall through to subzone validation.
		if len(otherRecords) > 0 {
			return fmt.Errorf("%s: non-glue, non-NS records found at delegation point: %v", name, otherRecords)
		}
		// collect copies every record in a subtree onto this delegation
		// point. Lookup reads glue from here when it builds a referral.
		var collect func(node *Zone, owner string)
		collect = func(node *Zone, owner string) {
			for _, record := range node.Records {
				z.GlueRecords = append(z.GlueRecords, NamedRecord{Name: owner, Record: record})
			}
			for label, child := range node.Subzones {
				collect(child, concatName(owner, label))
			}
		}
		z.GlueRecords = nil
		for label, subzone := range z.Subzones {
			owner := concatName(name, label)
			// Once this succeeds, every record left in the subtree is
			// A/AAAA glue at a nameserver name.
			if err := subzone.GetGlueRecords(owner, nameserversMap); err != nil {
				return err
			}
			collect(subzone, owner)
		}
		return nil
	}
	// Subzone validation.
	// Either we're outside a zone, at a zone apex, or at a non-delegated subzone.
	for subname, subzone := range z.Subzones {
		if err := subzone.Validate(concatName(name, subname), soa, ns); err != nil {
			return err
		}
	}
	return nil
}
// GetGlueRecords validates the subtree rooted at z, which lies below a
// delegation point, and stores each node's glue in that node's GlueRecords.
//
// name is the fully qualified name of z. nameserversMap holds the
// delegation point's nameserver names (the NS record values), matched here
// with or without a trailing dot. Only nodes named by an NS record may carry
// records, and those records must be of type A or AAAA.
func (z *Zone) GetGlueRecords(name string, nameserversMap map[string]bool) error {
	bare := strings.TrimSuffix(name, ".")
	_, isNameserver := nameserversMap[bare]
	if !isNameserver {
		_, isNameserver = nameserversMap[bare+"."]
	}
	// If the domain is not a nameserver, it must have no records.
	if !isNameserver && len(z.Records) > 0 {
		return fmt.Errorf("%s: non-glue records found under delegation point: %v", name, z.Records)
	}
	// Rebuild instead of appending, so repeated validation does not
	// duplicate glue.
	z.GlueRecords = nil
	// Any records under a delegation point must be glue records.
	for _, record := range z.Records {
		if record.Type != "A" && record.Type != "AAAA" {
			return fmt.Errorf("%s: non-glue record found under delegation point: %v", name, record)
		}
		z.GlueRecords = append(z.GlueRecords, NamedRecord{Name: name, Record: record})
	}
	for label, subzone := range z.Subzones {
		if err := subzone.GetGlueRecords(concatName(name, label), nameserversMap); err != nil {
			return err
		}
	}
	return nil
}
// UnmarshalYAML decodes a record from a YAML mapping with the keys "type",
// "ttl" and "value". The type and value are required; a missing ttl decodes
// as 0.
//
// The type mnemonic is accepted in any case and stored in upper case, so
// that a lowercase "ns" or "soa" is not silently treated as an unknown type
// by Validate's exact comparisons.
func (r *Record) UnmarshalYAML(ctx context.Context, data []byte) error {
	// A local type without this method avoids recursing into UnmarshalYAML.
	type rawRecord struct {
		Type  string `yaml:"type"`
		Ttl   uint64 `yaml:"ttl"`
		Value string `yaml:"value"`
	}
	var rr rawRecord
	if err := yaml.UnmarshalContext(ctx, data, &rr); err != nil {
		return err
	}
	rr.Type = strings.ToUpper(strings.TrimSpace(rr.Type))
	if rr.Type == "" {
		return fmt.Errorf("record: missing type")
	}
	if rr.Value == "" {
		return fmt.Errorf("record: missing value")
	}
	*r = Record(rr)
	return nil
}
// UnmarshalYAML decodes a zone tree node. The node is one of:
//   - a sequence: the records at this name;
//   - a mapping: child labels to zone nodes. The special key "@" holds the
//     records at this name itself, and an empty value is an empty node;
//   - "!include <file>": the node's content is read from another file. A
//     relative path is resolved against the including file's directory.
//
// The include directory and the chain of files currently being included
// travel in ctx (ctxDirectory, ctxFiles). An include that would re-enter a
// file already on the chain is rejected as infinite recursion. Including the
// same file from unrelated places is allowed.
func (z *Zone) UnmarshalYAML(ctx context.Context, data []byte) error {
	// Directory against which relative !include paths are resolved.
	dir, ok := ctx.Value(ctxDirectory).(string)
	if !ok {
		dir = "."
	}
	// Files on the current include chain. This map is treated as read-only;
	// each include gets its own extended copy.
	files, ok := ctx.Value(ctxFiles).(map[string]bool)
	if !ok {
		files = map[string]bool{}
	}
	// Parse the YAML data into an AST to tell !include tags, sequences and
	// mappings apart.
	file, err := parser.ParseBytes(data, 0)
	if err != nil {
		return err
	}
	if len(file.Docs) == 0 || file.Docs[0].Body == nil {
		return fmt.Errorf("expected a sequence, mapping, or !include tag")
	}
	node := file.Docs[0].Body
	switch node.Type() {
	case ast.TagType:
		tagNode := node.(*ast.TagNode)
		if tagNode.Start.Value != "!include" {
			return fmt.Errorf("encountered unexpected tag %s", tagNode.Start.Value)
		}
		includeFilenameNode, ok := tagNode.Value.(ast.ScalarNode)
		if !ok {
			return fmt.Errorf("include: expected scalar")
		}
		// Without the comma-ok check, a non-string scalar (number, bool,
		// null) would panic here.
		includeFilename, ok := includeFilenameNode.GetValue().(string)
		if !ok || includeFilename == "" {
			return fmt.Errorf("include: expected a file name")
		}
		// filepath.Join would graft an absolute path onto dir, so only
		// relative names are resolved against it.
		qualifiedIncludeFilename := filepath.Clean(includeFilename)
		if !filepath.IsAbs(qualifiedIncludeFilename) {
			qualifiedIncludeFilename = filepath.Join(dir, qualifiedIncludeFilename)
		}
		if _, seen := files[qualifiedIncludeFilename]; seen {
			return fmt.Errorf("infinite recursion detected in %s", qualifiedIncludeFilename)
		}
		includedData, err := os.ReadFile(qualifiedIncludeFilename)
		if err != nil {
			return err
		}
		// Check for multiple documents in the included file before unmarshaling.
		includedFile, err := parser.ParseBytes(includedData, 0)
		if err != nil {
			return err
		}
		if len(includedFile.Docs) != 1 {
			return fmt.Errorf("%s: expected exactly one document, got %d", qualifiedIncludeFilename, len(includedFile.Docs))
		}
		// Extend a copy of the chain. Mutating the shared map would make a
		// later, unrelated include of the same file look recursive.
		chain := make(map[string]bool, len(files)+1)
		for f, v := range files {
			chain[f] = v
		}
		chain[qualifiedIncludeFilename] = true
		subctx := context.WithValue(ctx, ctxFiles, chain)
		subctx = context.WithValue(subctx, ctxDirectory, filepath.Dir(qualifiedIncludeFilename))
		return yaml.UnmarshalContext(subctx, includedData, z)
	case ast.SequenceType:
		var records []Record
		if err := yaml.UnmarshalContext(ctx, data, &records); err != nil {
			return err
		}
		z.Records = records
	case ast.MappingType:
		var subzones map[string]*Zone
		if err := yaml.UnmarshalContext(ctx, data, &subzones); err != nil {
			return err
		}
		// An empty value (e.g. "www:") may decode to a nil *Zone. Replace it
		// with an empty node so that later passes never dereference nil.
		for label, subzone := range subzones {
			if subzone == nil {
				subzones[label] = &Zone{}
			}
		}
		if apex, ok := subzones["@"]; ok {
			z.Records = apex.Records
			delete(subzones, "@")
		}
		z.Subzones = subzones
	default:
		return fmt.Errorf("expected a sequence, mapping, or !include tag")
	}
	return nil
}
// nameToPath converts a DNS name to its path in the zone tree, that is, its
// labels ordered from the root downwards:
// "www.example.com" -> ["com", "example", "www"].
// Empty labels (from a trailing dot or an empty name) are kept in place;
// callers are expected to skip them.
func nameToPath(name string) []string {
	labels := strings.Split(name, ".")
	reversed := make([]string, len(labels))
	for i, label := range labels {
		reversed[len(labels)-1-i] = label
	}
	return reversed
}
// Lookup resolves name against the zone tree rooted at z (the DNS root),
// walking from the root towards name one label at a time. Empty labels are
// skipped, so "", "." and a trailing dot are accepted.
//
// Possible outcomes:
//   - A delegation point is reached at or above name. The result is a
//     referral: IsReferral is set, Ns holds the delegation point's NS
//     records (owned by the delegation point's name), and Extra holds its
//     glue. ok is true.
//   - A label has no matching node (NXDOMAIN). Ns holds the SOA of the
//     closest enclosing zone apex, and ok is false.
//   - Otherwise Answer holds every record at name, owned by name. If there
//     are none (NODATA), Ns holds the SOA of the closest enclosing apex.
//     ok is true.
//
// The SOA entry is the zero NamedRecord when name lies outside every zone.
func (z *Zone) Lookup(name string) (LookupResult, bool) {
	res := LookupResult{}
	// owner is the fully qualified name of the node z points at. soa is the
	// SOA of the closest zone apex passed so far; only apex nodes have
	// z.SOA populated, so it has to be carried down the walk.
	owner := "."
	soa := z.SOA
	for _, label := range nameToPath(name) {
		// Support empty name and trailing dot.
		if label == "" {
			continue
		}
		if z.IsDelegationPoint {
			// name lies below a zone cut; answered by the referral below.
			break
		}
		sz, ok := z.Subzones[label]
		if !ok {
			// NXDOMAIN
			res.Ns = []NamedRecord{soa}
			return res, false
		}
		z = sz
		owner = concatName(owner, label)
		if z.SOA.Record.Type == "SOA" {
			soa = z.SOA
		}
	}
	if z.IsDelegationPoint {
		// Referral, also when name is the delegation point itself: the NS
		// set at a zone cut is the child zone's data (RFC 1034 §4.3.2).
		res.IsReferral = true
		for _, record := range z.Records {
			if record.Type == "NS" {
				res.Ns = append(res.Ns, NamedRecord{Name: owner, Record: record})
			}
		}
		// Appending to a nil slice copies, so callers cannot alias the tree.
		res.Extra = append(res.Extra, z.GlueRecords...)
		return res, true
	}
	res.Answer = []NamedRecord{}
	for _, record := range z.Records {
		res.Answer = append(res.Answer, NamedRecord{Name: name, Record: record})
	}
	if len(res.Answer) == 0 {
		// NODATA
		res.Ns = []NamedRecord{soa}
	}
	return res, true
}
// FilterRecords narrows a Lookup result to the records of recordType. A
// CNAME is always kept, because it answers a query of any type
// (RFC 1034 §3.6.2). z must be the tree root that produced res.
//
// Referrals pass through with an empty Answer. If nothing is left of a
// non-referral answer, the result becomes NODATA, and Ns holds the SOA of
// the closest zone apex enclosing the answered name.
func (z *Zone) FilterRecords(res LookupResult, recordType string) LookupResult {
	matched := res.Answer
	filtered := []NamedRecord{}
	for _, nr := range matched {
		if nr.Record.Type == recordType || nr.Record.Type == "CNAME" {
			filtered = append(filtered, nr)
		}
	}
	res.Answer = filtered
	if len(filtered) > 0 || res.IsReferral {
		return res
	}
	switch {
	case len(matched) > 0:
		// The name had data, just not of this type. Only apex nodes carry
		// an SOA, so walk from the root to the name, keeping the closest one.
		soa := z.SOA
		node := z
		for _, label := range nameToPath(matched[0].Name) {
			if label == "" {
				continue
			}
			next, ok := node.Subzones[label]
			if !ok {
				break
			}
			node = next
			if node.SOA.Record.Type == "SOA" {
				soa = node.SOA
			}
		}
		res.Ns = []NamedRecord{soa}
	case len(res.Ns) == 0:
		res.Ns = []NamedRecord{z.SOA}
	}
	// Otherwise Lookup already placed the NODATA SOA in res.Ns; keep it.
	return res
}
// LookupType runs Lookup for name and filters the result to recordType via
// FilterRecords. ok is false when name does not exist (NXDOMAIN). In that
// case the unfiltered Lookup result is returned, so that the SOA in res.Ns
// is still available for the negative response (RFC 2308).
func (z *Zone) LookupType(name string, recordType string) (LookupResult, bool) {
	res, ok := z.Lookup(name)
	if !ok {
		return res, false
	}
	return z.FilterRecords(res, recordType), true
}
// concatName returns the fully qualified name of the child label subname
// under the parent name. The root "." contributes nothing, so
// concatName(".", "com") is "com." and concatName("com.", "example") is
// "example.com.".
func concatName(name string, subname string) string {
	parent := name
	if parent == "." {
		parent = ""
	}
	return subname + "." + parent
}