294 lines
7.6 KiB
Go
294 lines
7.6 KiB
Go
package yamlzone
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/goccy/go-yaml"
|
|
"github.com/goccy/go-yaml/ast"
|
|
"github.com/goccy/go-yaml/parser"
|
|
)
|
|
|
|
// Record is a single resource record attached to a name in the zone tree.
type Record struct {
	// Type is the record type, e.g. "SOA", "NS", "A" or "AAAA".
	Type string

	// Ttl is the record's time to live (NOTE(review): presumably seconds — confirm).
	Ttl uint64

	// Value is the record data. For NS records it is the nameserver name,
	// which Validate matches against fully-qualified, dot-terminated names
	// such as "ns1.example.com." when collecting glue.
	Value string
}
|
|
|
|
type Zone struct {
|
|
Subzones map[string]*Zone
|
|
Records []Record
|
|
GlueRecords []Record
|
|
IsDelegationPoint bool
|
|
}
|
|
|
|
// contextKey is a package-private key type for values stored in the decoding
// context, so they cannot collide with context keys set by other packages.
type contextKey string

// Context keys used while decoding zones.
const (
	// ctxDirectory holds the directory (a string) that !include paths are
	// joined onto.
	ctxDirectory contextKey = "directory"

	// ctxFiles holds the set (a map[string]bool) of files already on the
	// current include chain, used to detect include cycles.
	ctxFiles contextKey = "files"
)
|
|
|
|
func LoadZone(filename string) (*Zone, error) {
|
|
data, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return LoadZoneBytes(data, filename)
|
|
}
|
|
|
|
func LoadZoneBytes(data []byte, filename string) (*Zone, error) {
|
|
// Check for multiple documents before unmarshaling
|
|
file, err := parser.ParseBytes(data, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(file.Docs) != 1 {
|
|
return nil, fmt.Errorf("expected exactly one document, got %d", len(file.Docs))
|
|
}
|
|
|
|
z := &Zone{}
|
|
ctx := context.WithValue(context.TODO(), ctxDirectory, filepath.Dir(filename))
|
|
ctx = context.WithValue(ctx, ctxFiles, map[string]bool{filename: true})
|
|
err = yaml.UnmarshalContext(ctx, data, z)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := z.Validate(".", false); err != nil {
|
|
return nil, err
|
|
}
|
|
return z, nil
|
|
}
|
|
|
|
func (z *Zone) Validate(name string, zoneApexPresent bool) error {
|
|
nameservers := map[string]bool{}
|
|
otherRecords := []Record{}
|
|
isZoneApex := false
|
|
for _, record := range z.Records {
|
|
switch record.Type {
|
|
case "SOA":
|
|
isZoneApex = true
|
|
case "NS":
|
|
nameservers[record.Value] = true
|
|
default:
|
|
otherRecords = append(otherRecords, record)
|
|
}
|
|
}
|
|
|
|
zoneApexPresent = zoneApexPresent || isZoneApex
|
|
z.IsDelegationPoint = !isZoneApex && len(nameservers) > 0
|
|
|
|
if !zoneApexPresent {
|
|
// Outside zone
|
|
if z.IsDelegationPoint {
|
|
return fmt.Errorf("%s: delegation point found outside zone", name)
|
|
} else if len(otherRecords) > 0 {
|
|
return fmt.Errorf("%s: records found outside zone: %v", name, otherRecords)
|
|
}
|
|
} else if isZoneApex {
|
|
// Zone apex
|
|
if len(nameservers) == 0 {
|
|
return fmt.Errorf("%s: zone apex missing NS records", name)
|
|
}
|
|
} else if len(nameservers) > 0 {
|
|
// Delegation point (does not fall through to subzone validation)
|
|
if len(otherRecords) > 0 {
|
|
return fmt.Errorf("%s: non-glue, non-NS records found at delegation point: %v", name, otherRecords)
|
|
}
|
|
for subname, subzone := range z.Subzones {
|
|
glueRecords, err := subzone.GetGlueRecords(concatName(name, subname), nameservers)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
z.GlueRecords = append(z.GlueRecords, glueRecords...)
|
|
}
|
|
return nil
|
|
}
|
|
// Subzone validation
|
|
// Either we're outside a zone, at a zone apex, or at a non-delegated subzone
|
|
for subname, subzone := range z.Subzones {
|
|
if err := subzone.Validate(concatName(name, subname), zoneApexPresent); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (z *Zone) GetGlueRecords(name string, nameservers map[string]bool) ([]Record, error) {
|
|
// If the domain is not a nameserver, it must have no records
|
|
if _, ok := nameservers[name]; !ok {
|
|
if len(z.Records) > 0 {
|
|
return nil, fmt.Errorf("%s: non-glue records found under delegation point: %v", name, z.Records)
|
|
}
|
|
}
|
|
// Any records under a delegation point must be glue records
|
|
for _, record := range z.Records {
|
|
if !(record.Type == "A" || record.Type == "AAAA") {
|
|
return nil, fmt.Errorf("%s: non-glue record found under delegation point: %v", name, record)
|
|
}
|
|
}
|
|
for subname, subzone := range z.Subzones {
|
|
if glueRecords, err := subzone.GetGlueRecords(concatName(name, subname), nameservers); err != nil {
|
|
return nil, err
|
|
} else {
|
|
z.GlueRecords = append(z.GlueRecords, glueRecords...)
|
|
}
|
|
}
|
|
return z.GlueRecords, nil
|
|
}
|
|
|
|
func (r *Record) UnmarshalYAML(ctx context.Context, data []byte) error {
|
|
type rawRecord struct {
|
|
Type string `yaml:"type"`
|
|
Ttl uint64 `yaml:"ttl"`
|
|
Value string `yaml:"value"`
|
|
}
|
|
var rr rawRecord
|
|
if err := yaml.UnmarshalContext(ctx, data, &rr); err != nil {
|
|
return err
|
|
}
|
|
if rr.Type == "" {
|
|
return fmt.Errorf("record: missing type")
|
|
}
|
|
if rr.Value == "" {
|
|
return fmt.Errorf("record: missing value")
|
|
}
|
|
*r = Record(rr)
|
|
return nil
|
|
}
|
|
|
|
func (z *Zone) UnmarshalYAML(ctx context.Context, data []byte) error {
|
|
// Store the current directory so recursive includes can be resolved
|
|
dir, ok := ctx.Value(ctxDirectory).(string)
|
|
if !ok {
|
|
dir = "."
|
|
}
|
|
|
|
// Store the list of files that have been included in this chain to avoid
|
|
// infinite recursion
|
|
files, ok := ctx.Value(ctxFiles).(map[string]bool)
|
|
if !ok {
|
|
files = map[string]bool{}
|
|
}
|
|
|
|
// Parse the YAML data into an AST to check for !include tags and handle
|
|
// mapping vs sequence nodes
|
|
file, err := parser.ParseBytes(data, 0)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
node := file.Docs[0].Body
|
|
|
|
switch node.Type() {
|
|
case ast.TagType:
|
|
tagNode := node.(*ast.TagNode)
|
|
if tagNode.Start.Value != "!include" {
|
|
return fmt.Errorf("encountered unexpected tag %s", tagNode.Start.Value)
|
|
}
|
|
includeFilenameNode, ok := tagNode.Value.(ast.ScalarNode)
|
|
if !ok {
|
|
return fmt.Errorf("include: expected scalar")
|
|
}
|
|
includeFilename := includeFilenameNode.GetValue().(string)
|
|
qualifiedIncludeFilename := filepath.Join(dir, includeFilename)
|
|
includedData, err := os.ReadFile(qualifiedIncludeFilename)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Check for multiple documents in included file before unmarshaling
|
|
includedFile, err := parser.ParseBytes(includedData, 0)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(includedFile.Docs) != 1 {
|
|
return fmt.Errorf("%s: expected exactly one document, got %d", qualifiedIncludeFilename, len(includedFile.Docs))
|
|
}
|
|
|
|
if _, ok := files[qualifiedIncludeFilename]; ok {
|
|
return fmt.Errorf("infinite recursion detected in %s", qualifiedIncludeFilename)
|
|
}
|
|
files[qualifiedIncludeFilename] = true
|
|
subctx := context.WithValue(ctx, ctxFiles, files)
|
|
subctx = context.WithValue(subctx, ctxDirectory, filepath.Dir(qualifiedIncludeFilename))
|
|
err = yaml.UnmarshalContext(subctx, includedData, z)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case ast.SequenceType:
|
|
var records []Record
|
|
if err := yaml.UnmarshalContext(ctx, data, &records); err != nil {
|
|
return err
|
|
}
|
|
z.Records = records
|
|
case ast.MappingType:
|
|
var subzones map[string]*Zone
|
|
if err := yaml.UnmarshalContext(ctx, data, &subzones); err != nil {
|
|
return err
|
|
}
|
|
if _, ok := subzones["@"]; ok {
|
|
z.Records = subzones["@"].Records
|
|
delete(subzones, "@")
|
|
}
|
|
z.Subzones = subzones
|
|
default:
|
|
return fmt.Errorf("expected a sequence, mapping, or !include tag")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// nameToPath splits a dotted DNS name into labels ordered from the root
// downwards, which is the order the zone tree is walked in:
// "www.example.com" -> ["com", "example", "www"].
// Empty labels are kept as "". A trailing dot therefore yields a leading "",
// and the empty name yields [""].
func nameToPath(name string) []string {
	labels := strings.Split(name, ".")
	path := make([]string, 0, len(labels))
	for len(labels) > 0 {
		last := len(labels) - 1
		path = append(path, labels[last])
		labels = labels[:last]
	}
	return path
}
|
|
|
|
func (z *Zone) Lookup(name string) ([]Record, bool) {
|
|
path := nameToPath(name)
|
|
for _, label := range path {
|
|
// Support empty name and trailing dot
|
|
if label == "" {
|
|
continue
|
|
}
|
|
if sz, ok := z.Subzones[label]; ok {
|
|
z = sz
|
|
} else {
|
|
return nil, false
|
|
}
|
|
}
|
|
return z.Records, true
|
|
}
|
|
|
|
func (z *Zone) FilterRecords(records []Record, recordType string) []Record {
|
|
filtered := []Record{}
|
|
for _, record := range records {
|
|
if record.Type == recordType {
|
|
filtered = append(filtered, record)
|
|
}
|
|
}
|
|
return filtered
|
|
}
|
|
|
|
func (z *Zone) LookupType(name string, recordType string) ([]Record, bool) {
|
|
records, ok := z.Lookup(name)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
return z.FilterRecords(records, recordType), true
|
|
}
|
|
|
|
// concatName prepends the label subname to the fully-qualified name name and
// returns the resulting fully-qualified name, which ends in a dot. The root is
// written ".", so concatName(".", "com") is "com." and
// concatName("com.", "example") is "example.com.".
func concatName(name string, subname string) string {
	parent := name
	if parent == "." {
		parent = ""
	}
	return subname + "." + parent
}
|