// Package yamlzone loads a tree of DNS records from YAML files and answers
// name lookups against that tree.
package yamlzone

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/goccy/go-yaml"
	"github.com/goccy/go-yaml/ast"
	"github.com/goccy/go-yaml/parser"
)

// Record is a single resource record as written in the YAML source.
type Record struct {
	Type  string
	Ttl   uint64
	Value string
}

// NamedRecord pairs a Record with the name that owns it.
type NamedRecord struct {
	Name   string
	Record Record
}

// Zone is one node of the name tree. Despite the name, a node is not
// necessarily a zone apex: every label in the tree is represented by a Zone.
type Zone struct {
	// Subzones maps a single label to the node directly beneath this one.
	Subzones map[string]*Zone
	// Records holds the records owned by this node.
	Records []Record
	// SOA is set by Validate on nodes that carry an SOA record.
	SOA NamedRecord
	// IsDelegationPoint is set by Validate for nodes that have NS records
	// but no SOA record.
	IsDelegationPoint bool
	// GlueRecords is filled by Validate on delegation points with the A/AAAA
	// records found beneath the cut. Nodes beneath a cut hold the glue of
	// their own subtree.
	GlueRecords []NamedRecord
}

// LookupResult is the outcome of a lookup, split into the three sections a
// response is assembled from.
type LookupResult struct {
	Answer     []NamedRecord
	Ns         []NamedRecord
	Extra      []NamedRecord
	IsReferral bool
}

type contextKey string

const (
	// ctxDirectory carries the directory that relative !include paths are
	// resolved against.
	ctxDirectory = contextKey("directory")
	// ctxFiles carries the set of files on the current include chain.
	ctxFiles = contextKey("files")
)

// LoadZone reads filename and loads it with LoadZoneBytes.
func LoadZone(filename string) (*Zone, error) {
	data, err := os.ReadFile(filename)
	if err != nil {
		return nil, err
	}
	return LoadZoneBytes(data, filename)
}

// LoadZoneBytes parses data as a single-document YAML zone tree and validates
// it. filename is used to resolve relative !include paths and to detect a
// file that includes itself.
func LoadZoneBytes(data []byte, filename string) (*Zone, error) {
	// Check for multiple documents before unmarshaling.
	file, err := parser.ParseBytes(data, 0)
	if err != nil {
		return nil, err
	}
	if len(file.Docs) != 1 {
		return nil, fmt.Errorf("expected exactly one document, got %d", len(file.Docs))
	}

	// Included files are tracked by their filepath.Join'ed (cleaned) path, so
	// the root file has to be recorded in the same form or a self-include via
	// a differently spelled path would go unnoticed.
	root := filepath.Clean(filename)
	ctx := context.WithValue(context.Background(), ctxDirectory, filepath.Dir(root))
	ctx = context.WithValue(ctx, ctxFiles, map[string]bool{root: true})

	z := &Zone{}
	if err := yaml.UnmarshalContext(ctx, data, z); err != nil {
		return nil, err
	}
	if err := z.Validate(".", nil, nil); err != nil {
		return nil, err
	}
	return z, nil
}

// Validate checks the node called name and everything beneath it, and fills
// in the derived fields SOA, IsDelegationPoint and GlueRecords.
//
// soa is the SOA record of the enclosing zone, or nil when the node lies
// outside any zone. ns is handed down to child nodes and replaced at each
// zone apex; it is not otherwise inspected here.
func (z *Zone) Validate(name string, soa *Record, ns *[]Record) error {
	var nameservers, otherRecords []Record
	nameserversMap := map[string]bool{}
	isZoneApex := false
	cnameCount := 0
	for i := range z.Records {
		record := z.Records[i]
		switch record.Type {
		case "SOA":
			isZoneApex = true
			// Point at the slice element rather than at a loop variable.
			soa = &z.Records[i]
			z.SOA = NamedRecord{Name: name, Record: record}
		case "NS":
			nameservers = append(nameservers, record)
			nameserversMap[record.Value] = true
		case "CNAME":
			cnameCount++
		default:
			otherRecords = append(otherRecords, record)
		}
	}
	if cnameCount > 1 || (cnameCount > 0 && len(otherRecords) > 0) {
		return fmt.Errorf("%s: extraneous records found next to CNAME", name)
	}

	z.IsDelegationPoint = !isZoneApex && len(nameservers) > 0

	switch {
	case soa == nil:
		// Outside zone
		if z.IsDelegationPoint {
			return fmt.Errorf("%s: delegation point found outside zone", name)
		}
		if len(otherRecords) > 0 {
			return fmt.Errorf("%s: records found outside zone: %v", name, otherRecords)
		}
	case isZoneApex:
		if len(nameservers) == 0 {
			return fmt.Errorf("%s: zone apex missing NS records", name)
		}
		ns = &nameservers
	case len(nameservers) > 0:
		// Delegation point (does not fall through to subzone validation)
		if len(otherRecords) > 0 {
			return fmt.Errorf("%s: non-glue, non-NS records found at delegation point: %v", name, otherRecords)
		}
		// Lookup reads GlueRecords from the delegation point itself, so the
		// glue found beneath the cut is gathered onto this node.
		var glue []NamedRecord
		for subname, subzone := range z.Subzones {
			if err := subzone.GetGlueRecords(concatName(name, subname), nameserversMap); err != nil {
				return err
			}
			glue = append(glue, subzone.GlueRecords...)
		}
		z.GlueRecords = glue
		return nil
	}

	// Subzone validation
	// Either we're outside a zone, at a zone apex, or at a non-delegated subzone
	for subname, subzone := range z.Subzones {
		if err := subzone.Validate(concatName(name, subname), soa, ns); err != nil {
			return err
		}
	}
	return nil
}

// GetGlueRecords validates the node called name, which must lie beneath a
// delegation point, and stores in z.GlueRecords every glue record found at
// this node and below it.
//
// nameserversMap holds the NS targets of the delegation point. Only names
// present in it may own records, and those records must be of type A or AAAA.
// A target matches with or without a trailing dot.
func (z *Zone) GetGlueRecords(name string, nameserversMap map[string]bool) error {
	_, isNameserver := nameserversMap[strings.TrimSuffix(name, ".")]
	if !isNameserver {
		_, isNameserver = nameserversMap[name]
	}
	// If the domain is not a nameserver, it must have no records
	if !isNameserver && len(z.Records) > 0 {
		return fmt.Errorf("%s: non-glue records found under delegation point: %v", name, z.Records)
	}

	// Any records under a delegation point must be glue records
	var glue []NamedRecord
	for _, record := range z.Records {
		if record.Type != "A" && record.Type != "AAAA" {
			return fmt.Errorf("%s: non-glue record found under delegation point: %v", name, record)
		}
		glue = append(glue, NamedRecord{Name: name, Record: record})
	}

	for subname, subzone := range z.Subzones {
		if err := subzone.GetGlueRecords(concatName(name, subname), nameserversMap); err != nil {
			return err
		}
		glue = append(glue, subzone.GlueRecords...)
	}

	// Assign rather than append so that validating twice does not duplicate.
	z.GlueRecords = glue
	return nil
}

// UnmarshalYAML decodes one record and rejects records that lack a type or a
// value.
func (r *Record) UnmarshalYAML(ctx context.Context, data []byte) error {
	type rawRecord struct {
		Type  string `yaml:"type"`
		Ttl   uint64 `yaml:"ttl"`
		Value string `yaml:"value"`
	}
	var rr rawRecord
	if err := yaml.UnmarshalContext(ctx, data, &rr); err != nil {
		return err
	}
	if rr.Type == "" {
		return fmt.Errorf("record: missing type")
	}
	if rr.Value == "" {
		return fmt.Errorf("record: missing value")
	}
	*r = Record(rr)
	return nil
}

// UnmarshalYAML decodes one node of the tree. The node may be written as
//
//   - a sequence, which becomes the node's records;
//   - a mapping from label to child node, where the key "@" supplies the
//     records of the node itself; or
//   - an "!include <file>" tag, which loads the node from another file
//     resolved relative to the including file's directory.
func (z *Zone) UnmarshalYAML(ctx context.Context, data []byte) error {
	// Directory that relative includes are resolved against.
	dir, ok := ctx.Value(ctxDirectory).(string)
	if !ok {
		dir = "."
	}
	// Files already on this include chain; a nil map simply reads as empty.
	files, _ := ctx.Value(ctxFiles).(map[string]bool)

	// Parse the YAML data into an AST to check for !include tags and handle
	// mapping vs sequence nodes
	file, err := parser.ParseBytes(data, 0)
	if err != nil {
		return err
	}
	if len(file.Docs) == 0 || file.Docs[0].Body == nil {
		return fmt.Errorf("expected a sequence, mapping, or !include tag")
	}

	node := file.Docs[0].Body
	switch node.Type() {
	case ast.TagType:
		tagNode, ok := node.(*ast.TagNode)
		if !ok {
			return fmt.Errorf("expected a sequence, mapping, or !include tag")
		}
		return z.unmarshalInclude(ctx, tagNode, dir, files)
	case ast.SequenceType:
		var records []Record
		if err := yaml.UnmarshalContext(ctx, data, &records); err != nil {
			return err
		}
		z.Records = records
	case ast.MappingType:
		var subzones map[string]*Zone
		if err := yaml.UnmarshalContext(ctx, data, &subzones); err != nil {
			return err
		}
		// A key without a usable value can decode to a nil child, which would
		// be dereferenced below or during validation.
		for label, sz := range subzones {
			if sz == nil {
				return fmt.Errorf("%s: expected a sequence, mapping, or !include tag", label)
			}
		}
		if apex, ok := subzones["@"]; ok {
			z.Records = apex.Records
			delete(subzones, "@")
		}
		z.Subzones = subzones
	default:
		return fmt.Errorf("expected a sequence, mapping, or !include tag")
	}
	return nil
}

// unmarshalInclude loads z from the file named by an !include tag. dir is the
// directory of the including file and files is the set of files already on
// the include chain leading here.
func (z *Zone) unmarshalInclude(ctx context.Context, tagNode *ast.TagNode, dir string, files map[string]bool) error {
	if tagNode.Start.Value != "!include" {
		return fmt.Errorf("encountered unexpected tag %s", tagNode.Start.Value)
	}
	scalar, ok := tagNode.Value.(ast.ScalarNode)
	if !ok {
		return fmt.Errorf("include: expected scalar")
	}
	includeFilename, ok := scalar.GetValue().(string)
	if !ok {
		return fmt.Errorf("include: expected string filename")
	}

	qualified := filepath.Join(dir, includeFilename)
	if _, seen := files[qualified]; seen {
		return fmt.Errorf("infinite recursion detected in %s", qualified)
	}

	includedData, err := os.ReadFile(qualified)
	if err != nil {
		return err
	}

	// Check for multiple documents in included file before unmarshaling
	includedFile, err := parser.ParseBytes(includedData, 0)
	if err != nil {
		return fmt.Errorf("%s: %w", qualified, err)
	}
	if len(includedFile.Docs) != 1 {
		return fmt.Errorf("%s: expected exactly one document, got %d", qualified, len(includedFile.Docs))
	}

	// Extend a copy of the chain. Mutating the shared map would make a file
	// included from two sibling branches look like a cycle.
	chain := make(map[string]bool, len(files)+1)
	for f, v := range files {
		chain[f] = v
	}
	chain[qualified] = true

	subctx := context.WithValue(ctx, ctxFiles, chain)
	subctx = context.WithValue(subctx, ctxDirectory, filepath.Dir(qualified))
	if err := yaml.UnmarshalContext(subctx, includedData, z); err != nil {
		return fmt.Errorf("%s: %w", qualified, err)
	}
	return nil
}

// nameToPath converts a DNS name to the corresponding path in the zone tree.
// "www.example.com" -> ["com", "example", "www"]
func nameToPath(name string) []string {
	parts := strings.Split(name, ".")
	path := make([]string, 0, len(parts))
	for i := len(parts) - 1; i >= 0; i-- {
		path = append(path, parts[i])
	}
	return path
}

// Lookup resolves name against the tree rooted at z and returns every record
// owned by the matching node.
//
// The boolean is false only when the name does not exist, in which case Ns
// carries the SOA of the enclosing zone. When the node exists but owns no
// records, Answer is empty and Ns carries that SOA as well. When the name is
// at or below a delegation point the result is a referral: IsReferral is
// set, Ns holds the delegation's NS records and Extra holds its glue.
func (z *Zone) Lookup(name string) (LookupResult, bool) {
	res, _, ok := z.lookup(name)
	return res, ok
}

// lookup implements Lookup and additionally returns the SOA of the zone
// enclosing the node where the walk stopped, for use by LookupType.
func (z *Zone) lookup(name string) (LookupResult, NamedRecord, bool) {
	res := LookupResult{}
	// SOA is only stored on apex nodes, so remember the last one passed on
	// the way down; it is the SOA that belongs in negative answers.
	soa := z.SOA
	nodeName := "."
	for _, label := range nameToPath(name) {
		// Support empty name and trailing dot
		if label == "" {
			continue
		}
		if z.IsDelegationPoint {
			return z.referral(nodeName), soa, true
		}
		sz, ok := z.Subzones[label]
		if !ok {
			// NXDOMAIN
			res.Ns = []NamedRecord{soa}
			return res, soa, false
		}
		z = sz
		nodeName = concatName(nodeName, label)
		if z.SOA.Record.Type != "" {
			soa = z.SOA
		}
	}

	// A query for the delegation point itself is a referral too.
	if z.IsDelegationPoint {
		return z.referral(nodeName), soa, true
	}

	res.Answer = []NamedRecord{}
	for _, record := range z.Records {
		res.Answer = append(res.Answer, NamedRecord{Name: name, Record: record})
	}
	if len(res.Answer) == 0 {
		// NODATA
		res.Ns = []NamedRecord{soa}
	}
	return res, soa, true
}

// referral builds the referral for the delegation point z, whose
// fully-qualified name is name. The NS records are owned by the delegation
// point, not by the name that was queried.
func (z *Zone) referral(name string) LookupResult {
	res := LookupResult{IsReferral: true}
	for _, record := range z.Records {
		if record.Type == "NS" {
			res.Ns = append(res.Ns, NamedRecord{Name: name, Record: record})
		}
	}
	// Glue was gathered onto the delegation point during validation.
	res.Extra = append(res.Extra, z.GlueRecords...)
	return res
}

// FilterRecords reduces res.Answer to the records of type recordType. If
// nothing is left and res is not a referral, Ns is replaced by the SOA of
// the receiver z.
func (z *Zone) FilterRecords(res LookupResult, recordType string) LookupResult {
	return filterRecords(res, recordType, z.SOA)
}

// filterRecords is FilterRecords with the SOA for an empty answer supplied by
// the caller.
func filterRecords(res LookupResult, recordType string, soa NamedRecord) LookupResult {
	filtered := []NamedRecord{}
	for _, nr := range res.Answer {
		if nr.Record.Type == recordType {
			filtered = append(filtered, nr)
		}
	}
	res.Answer = filtered
	if len(filtered) == 0 && !res.IsReferral {
		// This ends up as NODATA even if it had data before filtering
		res.Ns = []NamedRecord{soa}
	}
	return res
}

// LookupType resolves name like Lookup and keeps only the answers of type
// recordType. When the name does not exist the boolean is false and the
// result is the one Lookup produced, so Ns still carries the SOA.
func (z *Zone) LookupType(name string, recordType string) (LookupResult, bool) {
	res, soa, ok := z.lookup(name)
	if !ok {
		return res, false
	}
	return filterRecords(res, recordType, soa), true
}

// concatName prepends the label subname to the fully-qualified name, where
// "." denotes the root.
func concatName(name string, subname string) string {
	if name == "." {
		return subname + "."
	}
	return subname + "." + name
}