337 lines
8.7 KiB
Go
337 lines
8.7 KiB
Go
package yamlzone
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/goccy/go-yaml"
|
|
"github.com/goccy/go-yaml/ast"
|
|
"github.com/goccy/go-yaml/parser"
|
|
)
|
|
|
|
// Record is a single resource record as written in the zone description: a
// type mnemonic such as "A", "NS", "CNAME" or "SOA", a TTL, and the record
// data as a string. The owner name is not stored here; it is implied by the
// record's position in the zone tree.
type Record struct {
	Type  string
	Ttl   uint64
	Value string
}
|
|
|
|
type NamedRecord struct {
|
|
Name string
|
|
Record Record
|
|
}
|
|
|
|
type Zone struct {
|
|
Subzones map[string]*Zone
|
|
Records []Record
|
|
SOA NamedRecord
|
|
IsDelegationPoint bool
|
|
GlueRecords []NamedRecord // Only set (optionally) for delegation points
|
|
}
|
|
|
|
type LookupResult struct {
|
|
Answer []NamedRecord
|
|
Ns []NamedRecord
|
|
Extra []NamedRecord
|
|
IsReferral bool
|
|
}
|
|
|
|
// contextKey is an unexported key type for the values this package stores in
// a context.Context. A distinct type means they cannot collide with keys
// defined by other packages.
type contextKey string

// Context keys consulted by Zone.UnmarshalYAML while following !include tags.
const (
	// ctxDirectory holds the directory (a string) that relative include
	// paths are resolved against.
	ctxDirectory contextKey = "directory"

	// ctxFiles holds the set (a map[string]bool) of files already on the
	// current include chain, used to reject include cycles.
	ctxFiles contextKey = "files"
)
|
|
|
|
func LoadZone(filename string) (*Zone, error) {
|
|
data, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return LoadZoneBytes(data, filename)
|
|
}
|
|
|
|
func LoadZoneBytes(data []byte, filename string) (*Zone, error) {
|
|
// Check for multiple documents before unmarshaling
|
|
file, err := parser.ParseBytes(data, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(file.Docs) != 1 {
|
|
return nil, fmt.Errorf("expected exactly one document, got %d", len(file.Docs))
|
|
}
|
|
|
|
z := &Zone{}
|
|
ctx := context.WithValue(context.TODO(), ctxDirectory, filepath.Dir(filename))
|
|
ctx = context.WithValue(ctx, ctxFiles, map[string]bool{filename: true})
|
|
err = yaml.UnmarshalContext(ctx, data, z)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := z.Validate(".", nil, nil); err != nil {
|
|
return nil, err
|
|
}
|
|
return z, nil
|
|
}
|
|
|
|
func (z *Zone) Validate(name string, soa *NamedRecord, ns *[]Record) error {
|
|
nameservers, otherRecords := []Record{}, []Record{}
|
|
nameserversMap := map[string]bool{}
|
|
isZoneApex := false
|
|
cnameCount := 0
|
|
for _, record := range z.Records {
|
|
switch record.Type {
|
|
case "SOA":
|
|
isZoneApex = true
|
|
soa = &NamedRecord{Name: name, Record: record}
|
|
case "NS":
|
|
nameservers = append(nameservers, record)
|
|
nameserversMap[record.Value] = true
|
|
case "CNAME":
|
|
cnameCount++
|
|
default:
|
|
otherRecords = append(otherRecords, record)
|
|
}
|
|
}
|
|
|
|
z.IsDelegationPoint = !isZoneApex && len(nameservers) > 0
|
|
|
|
if cnameCount > 1 || (cnameCount > 0 && len(otherRecords) > 0) {
|
|
return fmt.Errorf("%s: extraneous records found next to CNAME", name)
|
|
}
|
|
|
|
if soa == nil {
|
|
// Outside zone
|
|
if z.IsDelegationPoint {
|
|
return fmt.Errorf("%s: delegation point found outside zone", name)
|
|
} else if len(otherRecords) > 0 {
|
|
return fmt.Errorf("%s: records found outside zone: %v", name, otherRecords)
|
|
}
|
|
} else if isZoneApex {
|
|
// Zone apex
|
|
if len(nameservers) == 0 {
|
|
return fmt.Errorf("%s: zone apex missing NS records", name)
|
|
}
|
|
z.SOA = *soa
|
|
ns = &nameservers
|
|
} else if len(nameservers) > 0 {
|
|
// Delegation point (does not fall through)
|
|
if len(otherRecords) > 0 {
|
|
return fmt.Errorf("%s: non-glue, non-NS records found at delegation point: %v", name, otherRecords)
|
|
}
|
|
for subname, subzone := range z.Subzones {
|
|
// This populates z.GlueRecords directly
|
|
err := subzone.GetGlueRecords(concatName(name, subname), nameserversMap)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
} else {
|
|
// Normal node
|
|
// Cache SOA at each node in the zone tree
|
|
z.SOA = *soa
|
|
}
|
|
|
|
// Either we're outside a zone, at a zone apex, or at a non-delegated subzone
|
|
|
|
for subname, subzone := range z.Subzones {
|
|
if err := subzone.Validate(concatName(name, subname), soa, ns); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (z *Zone) GetGlueRecords(name string, nameserversMap map[string]bool) error {
|
|
// If the domain is not a nameserver, it must have no records
|
|
if _, ok := nameserversMap[strings.TrimSuffix(name, ".")]; !ok {
|
|
if len(z.Records) > 0 {
|
|
return fmt.Errorf("%s: non-glue records found under delegation point: %v", name, z.Records)
|
|
}
|
|
}
|
|
// Any records under a delegation point must be glue records
|
|
for _, record := range z.Records {
|
|
if !(record.Type == "A" || record.Type == "AAAA") {
|
|
return fmt.Errorf("%s: non-glue record found under delegation point: %v", name, record)
|
|
}
|
|
z.GlueRecords = append(z.GlueRecords, NamedRecord{Name: name, Record: record})
|
|
}
|
|
for subname, subzone := range z.Subzones {
|
|
if err := subzone.GetGlueRecords(concatName(name, subname), nameserversMap); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *Record) UnmarshalYAML(ctx context.Context, data []byte) error {
|
|
type rawRecord struct {
|
|
Type string `yaml:"type"`
|
|
Ttl uint64 `yaml:"ttl"`
|
|
Value string `yaml:"value"`
|
|
}
|
|
var rr rawRecord
|
|
if err := yaml.UnmarshalContext(ctx, data, &rr); err != nil {
|
|
return err
|
|
}
|
|
if rr.Type == "" {
|
|
return fmt.Errorf("record: missing type")
|
|
}
|
|
if rr.Value == "" {
|
|
return fmt.Errorf("record: missing value")
|
|
}
|
|
*r = Record(rr)
|
|
return nil
|
|
}
|
|
|
|
func (z *Zone) UnmarshalYAML(ctx context.Context, data []byte) error {
|
|
// Store the current directory so recursive includes can be resolved
|
|
dir, ok := ctx.Value(ctxDirectory).(string)
|
|
if !ok {
|
|
dir = "."
|
|
}
|
|
|
|
// Store the list of files that have been included in this chain to avoid
|
|
// infinite recursion
|
|
files, ok := ctx.Value(ctxFiles).(map[string]bool)
|
|
if !ok {
|
|
files = map[string]bool{}
|
|
}
|
|
|
|
// Parse the YAML data into an AST to check for !include tags and handle
|
|
// mapping vs sequence nodes
|
|
file, err := parser.ParseBytes(data, 0)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
node := file.Docs[0].Body
|
|
|
|
switch node.Type() {
|
|
case ast.TagType:
|
|
tagNode := node.(*ast.TagNode)
|
|
if tagNode.Start.Value != "!include" {
|
|
return fmt.Errorf("encountered unexpected tag %s", tagNode.Start.Value)
|
|
}
|
|
includeFilenameNode, ok := tagNode.Value.(ast.ScalarNode)
|
|
if !ok {
|
|
return fmt.Errorf("include: expected scalar")
|
|
}
|
|
includeFilename := includeFilenameNode.GetValue().(string)
|
|
qualifiedIncludeFilename := filepath.Join(dir, includeFilename)
|
|
includedData, err := os.ReadFile(qualifiedIncludeFilename)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Check for multiple documents in included file before unmarshaling
|
|
includedFile, err := parser.ParseBytes(includedData, 0)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(includedFile.Docs) != 1 {
|
|
return fmt.Errorf("%s: expected exactly one document, got %d", qualifiedIncludeFilename, len(includedFile.Docs))
|
|
}
|
|
|
|
if _, ok := files[qualifiedIncludeFilename]; ok {
|
|
return fmt.Errorf("infinite recursion detected in %s", qualifiedIncludeFilename)
|
|
}
|
|
files[qualifiedIncludeFilename] = true
|
|
subctx := context.WithValue(ctx, ctxFiles, files)
|
|
subctx = context.WithValue(subctx, ctxDirectory, filepath.Dir(qualifiedIncludeFilename))
|
|
err = yaml.UnmarshalContext(subctx, includedData, z)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case ast.SequenceType:
|
|
var records []Record
|
|
if err := yaml.UnmarshalContext(ctx, data, &records); err != nil {
|
|
return err
|
|
}
|
|
z.Records = records
|
|
case ast.MappingType:
|
|
var subzones map[string]*Zone
|
|
if err := yaml.UnmarshalContext(ctx, data, &subzones); err != nil {
|
|
return err
|
|
}
|
|
if _, ok := subzones["@"]; ok {
|
|
z.Records = subzones["@"].Records
|
|
delete(subzones, "@")
|
|
}
|
|
z.Subzones = subzones
|
|
default:
|
|
return fmt.Errorf("expected a sequence, mapping, or !include tag")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// nameToPath converts a DNS name into its path through the zone tree: the
// dot-separated labels ordered from rightmost (nearest the root) to leftmost,
// e.g. "www.example.com" -> ["com", "example", "www"]. Empty labels are kept,
// so a trailing dot produces a leading "" element.
func nameToPath(name string) []string {
	labels := strings.Split(name, ".")
	for lo, hi := 0, len(labels)-1; lo < hi; lo, hi = lo+1, hi-1 {
		labels[lo], labels[hi] = labels[hi], labels[lo]
	}
	return labels
}
|
|
|
|
func (z *Zone) Lookup(name string) (LookupResult, bool) {
|
|
res, ok := z.LookupType(name, "")
|
|
return res, ok
|
|
}
|
|
|
|
func (z *Zone) LookupType(name string, recordType string) (LookupResult, bool) {
|
|
res := LookupResult{}
|
|
path := nameToPath(name)
|
|
for _, label := range path {
|
|
// Support empty name and trailing dot
|
|
if label == "" {
|
|
continue
|
|
}
|
|
if sz, ok := z.Subzones[label]; !ok {
|
|
// NXDOMAIN
|
|
res.Ns = []NamedRecord{z.SOA}
|
|
return res, false
|
|
} else {
|
|
z = sz
|
|
}
|
|
if z.IsDelegationPoint {
|
|
// Retrieve glue records from cache, whether referral or NS query
|
|
res.Extra = append(res.Extra, z.GlueRecords...)
|
|
if recordType != "NS" {
|
|
// Treat as referral unless client requested NS records
|
|
res.IsReferral = true
|
|
// Capture NS records
|
|
for _, record := range z.Records {
|
|
if record.Type == "NS" {
|
|
res.Ns = append(res.Ns, NamedRecord{Name: name, Record: record})
|
|
}
|
|
}
|
|
return res, true
|
|
}
|
|
}
|
|
}
|
|
res.Answer = []NamedRecord{}
|
|
for _, record := range z.Records {
|
|
if recordType == "" || record.Type == recordType {
|
|
res.Answer = append(res.Answer, NamedRecord{Name: name, Record: record})
|
|
}
|
|
}
|
|
if len(res.Answer) == 0 {
|
|
// NODATA
|
|
res.Ns = []NamedRecord{z.SOA}
|
|
return res, true
|
|
}
|
|
return res, true
|
|
}
|
|
|
|
// concatName prepends the label subname to name, the name of its parent node.
// The root "." is special-cased so that its child "com" becomes "com."
// rather than "com..". When name is dot-terminated, so is the result.
func concatName(name string, subname string) string {
	parent := name
	if parent == "." {
		parent = ""
	}
	return subname + "." + parent
}
|