// Package yamlzone loads DNS zone data from YAML into a tree of zones keyed
// by label, validates the zone/delegation structure of that tree, and answers
// simple record lookups against it.
package yamlzone

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"github.com/goccy/go-yaml"
	"github.com/goccy/go-yaml/ast"
	"github.com/goccy/go-yaml/parser"
)

// Record is a single resource record attached to a node of the zone tree.
// The owner name is implied by the node's position in the tree.
type Record struct {
	Type  string
	Ttl   uint64
	Value string
}

// Zone is one node of the zone tree; each node corresponds to a DNS label.
type Zone struct {
	// Subzones maps a child label (e.g. "www") to its node.
	Subzones map[string]*Zone

	// Records are the records owned by this node's name.
	Records []Record

	// GlueRecords is filled in by Validate on delegation points with the
	// A/AAAA records found beneath them for the delegation's nameservers.
	GlueRecords []Record

	// IsDelegationPoint is set by Validate: true when the node has NS
	// records but no SOA record.
	IsDelegationPoint bool
}

// contextKey namespaces the values this package stores in the context that
// is threaded through YAML decoding.
type contextKey string

const (
	// ctxDirectory holds the directory used to resolve relative !include paths.
	ctxDirectory = contextKey("directory")

	// ctxFiles holds the set of files on the current !include chain.
	ctxFiles = contextKey("files")
)

// LoadZone reads filename and parses/validates it with LoadZoneBytes.
func LoadZone(filename string) (*Zone, error) {
	data, err := os.ReadFile(filename)
	if err != nil {
		return nil, err
	}
	return LoadZoneBytes(data, filename)
}

// LoadZoneBytes parses data as a single-document YAML zone tree and
// validates it. filename is not read; it is used only to resolve relative
// !include paths and to seed include-cycle detection.
func LoadZoneBytes(data []byte, filename string) (*Zone, error) {
	// Reject empty or multi-document input before decoding.
	file, err := parser.ParseBytes(data, 0)
	if err != nil {
		return nil, err
	}
	if len(file.Docs) != 1 {
		return nil, fmt.Errorf("expected exactly one document, got %d", len(file.Docs))
	}

	// Clean the name so it compares equal to the filepath.Join results
	// recorded for includes (Join always returns a cleaned path).
	filename = filepath.Clean(filename)
	ctx := context.WithValue(context.Background(), ctxDirectory, filepath.Dir(filename))
	ctx = context.WithValue(ctx, ctxFiles, map[string]bool{filename: true})

	z := &Zone{}
	if err := yaml.UnmarshalContext(ctx, data, z); err != nil {
		return nil, err
	}
	if err := z.Validate(".", false); err != nil {
		return nil, err
	}
	return z, nil
}

// Validate checks the structural rules of the zone tree rooted at z and
// records derived state on the nodes (IsDelegationPoint, GlueRecords).
//
// name is the fully-qualified, dot-terminated name of z ("." for the root);
// it is used in error messages and to match nameserver names for glue.
// zoneApexPresent reports whether an ancestor carried an SOA record, i.e.
// whether z lies inside a zone.
//
// Rules enforced:
//   - outside any zone, a node may neither hold records nor delegate;
//   - a zone apex (SOA present) must have NS records;
//   - a delegation point (NS but no SOA, inside a zone) may hold only NS
//     records, and everything beneath it must be A/AAAA glue owned by one
//     of its nameservers (see GetGlueRecords).
//
// Children are visited in sorted label order, so the first error reported
// and the order of GlueRecords are deterministic.
func (z *Zone) Validate(name string, zoneApexPresent bool) error {
	nameservers := map[string]bool{}
	var otherRecords []Record
	isZoneApex := false
	for _, record := range z.Records {
		switch record.Type {
		case "SOA":
			isZoneApex = true
		case "NS":
			nameservers[record.Value] = true
		default:
			otherRecords = append(otherRecords, record)
		}
	}

	zoneApexPresent = zoneApexPresent || isZoneApex
	z.IsDelegationPoint = !isZoneApex && len(nameservers) > 0
	// Recomputed on every call so validating twice cannot duplicate glue.
	z.GlueRecords = nil

	switch {
	case !zoneApexPresent:
		// Outside any zone.
		if z.IsDelegationPoint {
			return fmt.Errorf("%s: delegation point found outside zone", name)
		}
		if len(otherRecords) > 0 {
			return fmt.Errorf("%s: records found outside zone: %v", name, otherRecords)
		}
	case isZoneApex:
		if len(nameservers) == 0 {
			return fmt.Errorf("%s: zone apex missing NS records", name)
		}
	case len(nameservers) > 0:
		// Delegation point: the only data allowed below it is glue for its
		// nameservers, so subzones are checked by GetGlueRecords instead of
		// falling through to ordinary subzone validation.
		if len(otherRecords) > 0 {
			return fmt.Errorf("%s: non-glue, non-NS records found at delegation point: %v", name, otherRecords)
		}
		var glue []Record
		for _, subname := range sortedLabels(z.Subzones) {
			subGlue, err := z.Subzones[subname].GetGlueRecords(concatName(name, subname), nameservers)
			if err != nil {
				return err
			}
			glue = append(glue, subGlue...)
		}
		z.GlueRecords = glue
		return nil
	}

	// Outside a zone, at a zone apex, or at a non-delegated name in a zone.
	for _, subname := range sortedLabels(z.Subzones) {
		if err := z.Subzones[subname].Validate(concatName(name, subname), zoneApexPresent); err != nil {
			return err
		}
	}
	return nil
}

// GetGlueRecords validates the subtree rooted at z, which lies beneath a
// delegation point, and returns the glue records it contains.
//
// name is z's fully-qualified name and nameservers is the set of NS targets
// of the enclosing delegation point. A node may own records only if its
// name is one of those nameservers, and then only A/AAAA records. The
// result holds z's own records followed by its descendants' glue (children
// in sorted label order). z is not modified.
func (z *Zone) GetGlueRecords(name string, nameservers map[string]bool) ([]Record, error) {
	if !nameservers[name] && len(z.Records) > 0 {
		return nil, fmt.Errorf("%s: non-glue records found under delegation point: %v", name, z.Records)
	}
	for _, record := range z.Records {
		if record.Type != "A" && record.Type != "AAAA" {
			return nil, fmt.Errorf("%s: non-glue record found under delegation point: %v", name, record)
		}
	}

	// Copy so appending descendants' glue never writes into z.Records'
	// backing array.
	glue := append([]Record(nil), z.Records...)
	for _, subname := range sortedLabels(z.Subzones) {
		subGlue, err := z.Subzones[subname].GetGlueRecords(concatName(name, subname), nameservers)
		if err != nil {
			return nil, err
		}
		glue = append(glue, subGlue...)
	}
	return glue, nil
}

// UnmarshalYAML decodes a record from a mapping with "type", "ttl" and
// "value" keys. type and value are required; a missing ttl decodes as 0.
func (r *Record) UnmarshalYAML(ctx context.Context, data []byte) error {
	type rawRecord struct {
		Type  string `yaml:"type"`
		Ttl   uint64 `yaml:"ttl"`
		Value string `yaml:"value"`
	}
	var rr rawRecord
	if err := yaml.UnmarshalContext(ctx, data, &rr); err != nil {
		return err
	}
	if rr.Type == "" {
		return fmt.Errorf("record: missing type")
	}
	if rr.Value == "" {
		return fmt.Errorf("record: missing value")
	}
	*r = Record(rr)
	return nil
}

// UnmarshalYAML decodes one node of the zone tree. The node may be:
//   - a sequence: the records owned by this name;
//   - a mapping: child labels to subzones, where the special key "@"
//     supplies this name's own records;
//   - a scalar tagged !include: a file path, relative to the including
//     file, whose single YAML document is decoded in place of the tag.
//
// The include directory and the set of files on the current include chain
// travel in ctx (seeded by LoadZoneBytes).
func (z *Zone) UnmarshalYAML(ctx context.Context, data []byte) error {
	// Parse to an AST first to tell !include tags, mappings and sequences
	// apart before choosing how to decode.
	file, err := parser.ParseBytes(data, 0)
	if err != nil {
		return err
	}
	if len(file.Docs) == 0 || file.Docs[0].Body == nil {
		return fmt.Errorf("expected a sequence, mapping, or !include tag")
	}
	node := file.Docs[0].Body

	switch node.Type() {
	case ast.TagType:
		return z.unmarshalInclude(ctx, node.(*ast.TagNode))
	case ast.SequenceType:
		var records []Record
		if err := yaml.UnmarshalContext(ctx, data, &records); err != nil {
			return err
		}
		z.Records = records
	case ast.MappingType:
		var subzones map[string]*Zone
		if err := yaml.UnmarshalContext(ctx, data, &subzones); err != nil {
			return err
		}
		for label, sz := range subzones {
			// A label with a null value may decode to a nil *Zone; replace it
			// with an empty node so tree walks never dereference nil.
			if sz == nil {
				subzones[label] = &Zone{}
			}
		}
		if apex, ok := subzones["@"]; ok {
			z.Records = apex.Records
			delete(subzones, "@")
		}
		z.Subzones = subzones
	default:
		return fmt.Errorf("expected a sequence, mapping, or !include tag")
	}
	return nil
}

// unmarshalInclude resolves an !include tag: it reads the named file
// relative to the including file's directory and decodes its single
// document into z, rejecting include cycles.
func (z *Zone) unmarshalInclude(ctx context.Context, tagNode *ast.TagNode) error {
	if tagNode.Start.Value != "!include" {
		return fmt.Errorf("encountered unexpected tag %s", tagNode.Start.Value)
	}
	includeFilenameNode, ok := tagNode.Value.(ast.ScalarNode)
	if !ok {
		return fmt.Errorf("include: expected scalar")
	}
	// A non-string operand (e.g. `!include 42` or a bare `!include`) would
	// otherwise panic on an unchecked type assertion.
	includeFilename, ok := includeFilenameNode.GetValue().(string)
	if !ok || includeFilename == "" {
		return fmt.Errorf("include: expected a file name, got %v", includeFilenameNode.GetValue())
	}

	dir, ok := ctx.Value(ctxDirectory).(string)
	if !ok {
		dir = "."
	}
	files, ok := ctx.Value(ctxFiles).(map[string]bool)
	if !ok {
		files = map[string]bool{}
	}

	qualifiedIncludeFilename := filepath.Join(dir, includeFilename)
	if files[qualifiedIncludeFilename] {
		return fmt.Errorf("infinite recursion detected in %s", qualifiedIncludeFilename)
	}

	includedData, err := os.ReadFile(qualifiedIncludeFilename)
	if err != nil {
		return err
	}
	includedFile, err := parser.ParseBytes(includedData, 0)
	if err != nil {
		return err
	}
	if len(includedFile.Docs) != 1 {
		return fmt.Errorf("%s: expected exactly one document, got %d", qualifiedIncludeFilename, len(includedFile.Docs))
	}

	// Give the included file its own copy of the chain. Mutating a shared
	// map would leave this file marked after the include returns, so a
	// sibling that includes the same file would be misreported as a cycle.
	chain := make(map[string]bool, len(files)+1)
	for f := range files {
		chain[f] = true
	}
	chain[qualifiedIncludeFilename] = true

	subctx := context.WithValue(ctx, ctxFiles, chain)
	subctx = context.WithValue(subctx, ctxDirectory, filepath.Dir(qualifiedIncludeFilename))
	return yaml.UnmarshalContext(subctx, includedData, z)
}

// sortedLabels returns the keys of subzones in ascending order so tree
// walks visit children deterministically.
func sortedLabels(subzones map[string]*Zone) []string {
	labels := make([]string, 0, len(subzones))
	for label := range subzones {
		labels = append(labels, label)
	}
	sort.Strings(labels)
	return labels
}

// nameToPath converts a DNS name to its path in the zone tree, most
// significant label first: "www.example.com" -> ["com", "example", "www"].
// Empty labels (from a trailing dot, "" or ".") are kept; Lookup skips them.
func nameToPath(name string) []string {
	labels := strings.Split(name, ".")
	path := make([]string, 0, len(labels))
	for i := len(labels) - 1; i >= 0; i-- {
		path = append(path, labels[i])
	}
	return path
}

// Lookup walks from z to the node for name and returns that node's
// records. The boolean is false when some label of name has no node.
// Empty labels are ignored, so "example.com" and "example.com." are
// equivalent and "" or "." return z's own records. Label matching is exact
// (case-sensitive).
func (z *Zone) Lookup(name string) ([]Record, bool) {
	node := z
	for _, label := range nameToPath(name) {
		if label == "" {
			continue
		}
		sz, ok := node.Subzones[label]
		if !ok || sz == nil {
			return nil, false
		}
		node = sz
	}
	return node.Records, true
}

// FilterRecords returns the records whose Type equals recordType, in their
// original order. The result is non-nil even when nothing matches.
func (z *Zone) FilterRecords(records []Record, recordType string) []Record {
	filtered := []Record{}
	for _, record := range records {
		if record.Type == recordType {
			filtered = append(filtered, record)
		}
	}
	return filtered
}

// LookupType is Lookup restricted to records of recordType. The boolean
// reports whether the name exists, not whether any record matched.
func (z *Zone) LookupType(name string, recordType string) ([]Record, bool) {
	records, ok := z.Lookup(name)
	if !ok {
		return nil, false
	}
	return z.FilterRecords(records, recordType), true
}

// concatName prepends subname to the fully-qualified name, producing a
// dot-terminated name: ("example.", "www") -> "www.example." and
// (".", "com") -> "com.".
func concatName(name string, subname string) string {
	if name == "." {
		return subname + "."
	}
	return subname + "." + name
}