initial commit
This commit is contained in:
commit
772469d21b
27 changed files with 1121 additions and 0 deletions
294
yamlzone.go
Normal file
294
yamlzone.go
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
package yamlzone
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/goccy/go-yaml/ast"
|
||||
"github.com/goccy/go-yaml/parser"
|
||||
)
|
||||
|
||||
type Record struct {
|
||||
Type string
|
||||
Ttl uint64
|
||||
Value string
|
||||
}
|
||||
|
||||
type Zone struct {
|
||||
Subzones map[string]*Zone
|
||||
Records []Record
|
||||
GlueRecords []Record
|
||||
IsDelegationPoint bool
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
ctxDirectory = contextKey("directory")
|
||||
ctxFiles = contextKey("files")
|
||||
)
|
||||
|
||||
func LoadZone(filename string) (*Zone, error) {
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return LoadZoneBytes(data, filename)
|
||||
}
|
||||
|
||||
func LoadZoneBytes(data []byte, filename string) (*Zone, error) {
|
||||
// Check for multiple documents before unmarshaling
|
||||
file, err := parser.ParseBytes(data, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(file.Docs) != 1 {
|
||||
return nil, fmt.Errorf("expected exactly one document, got %d", len(file.Docs))
|
||||
}
|
||||
|
||||
z := &Zone{}
|
||||
ctx := context.WithValue(context.TODO(), ctxDirectory, filepath.Dir(filename))
|
||||
ctx = context.WithValue(ctx, ctxFiles, map[string]bool{filename: true})
|
||||
err = yaml.UnmarshalContext(ctx, data, z)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := z.Validate(".", false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return z, nil
|
||||
}
|
||||
|
||||
func (z *Zone) Validate(name string, zoneApexPresent bool) error {
|
||||
nameservers := map[string]bool{}
|
||||
otherRecords := []Record{}
|
||||
isZoneApex := false
|
||||
for _, record := range z.Records {
|
||||
switch record.Type {
|
||||
case "SOA":
|
||||
isZoneApex = true
|
||||
case "NS":
|
||||
nameservers[record.Value] = true
|
||||
default:
|
||||
otherRecords = append(otherRecords, record)
|
||||
}
|
||||
}
|
||||
|
||||
zoneApexPresent = zoneApexPresent || isZoneApex
|
||||
z.IsDelegationPoint = !isZoneApex && len(nameservers) > 0
|
||||
|
||||
if !zoneApexPresent {
|
||||
// Outside zone
|
||||
if z.IsDelegationPoint {
|
||||
return fmt.Errorf("%s: delegation point found outside zone", name)
|
||||
} else if len(otherRecords) > 0 {
|
||||
return fmt.Errorf("%s: records found outside zone: %v", name, otherRecords)
|
||||
}
|
||||
} else if isZoneApex {
|
||||
// Zone apex
|
||||
if len(nameservers) == 0 {
|
||||
return fmt.Errorf("%s: zone apex missing NS records", name)
|
||||
}
|
||||
} else if len(nameservers) > 0 {
|
||||
// Delegation point (does not fall through to subzone validation)
|
||||
if len(otherRecords) > 0 {
|
||||
return fmt.Errorf("%s: non-glue, non-NS records found at delegation point: %v", name, otherRecords)
|
||||
}
|
||||
for subname, subzone := range z.Subzones {
|
||||
glueRecords, err := subzone.GetGlueRecords(concatName(name, subname), nameservers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
z.GlueRecords = append(z.GlueRecords, glueRecords...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Subzone validation
|
||||
// Either we're outside a zone, at a zone apex, or at a non-delegated subzone
|
||||
for subname, subzone := range z.Subzones {
|
||||
if err := subzone.Validate(concatName(name, subname), zoneApexPresent); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *Zone) GetGlueRecords(name string, nameservers map[string]bool) ([]Record, error) {
|
||||
// If the domain is not a nameserver, it must have no records
|
||||
if _, ok := nameservers[name]; !ok {
|
||||
if len(z.Records) > 0 {
|
||||
return nil, fmt.Errorf("%s: non-glue records found under delegation point: %v", name, z.Records)
|
||||
}
|
||||
}
|
||||
// Any records under a delegation point must be glue records
|
||||
for _, record := range z.Records {
|
||||
if !(record.Type == "A" || record.Type == "AAAA") {
|
||||
return nil, fmt.Errorf("%s: non-glue record found under delegation point: %v", name, record)
|
||||
}
|
||||
}
|
||||
for subname, subzone := range z.Subzones {
|
||||
if glueRecords, err := subzone.GetGlueRecords(concatName(name, subname), nameservers); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
z.GlueRecords = append(z.GlueRecords, glueRecords...)
|
||||
}
|
||||
}
|
||||
return z.GlueRecords, nil
|
||||
}
|
||||
|
||||
func (r *Record) UnmarshalYAML(ctx context.Context, data []byte) error {
|
||||
type rawRecord struct {
|
||||
Type string `yaml:"type"`
|
||||
Ttl uint64 `yaml:"ttl"`
|
||||
Value string `yaml:"value"`
|
||||
}
|
||||
var rr rawRecord
|
||||
if err := yaml.UnmarshalContext(ctx, data, &rr); err != nil {
|
||||
return err
|
||||
}
|
||||
if rr.Type == "" {
|
||||
return fmt.Errorf("record: missing type")
|
||||
}
|
||||
if rr.Value == "" {
|
||||
return fmt.Errorf("record: missing value")
|
||||
}
|
||||
*r = Record(rr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *Zone) UnmarshalYAML(ctx context.Context, data []byte) error {
|
||||
// Store the current directory so recursive includes can be resolved
|
||||
dir, ok := ctx.Value(ctxDirectory).(string)
|
||||
if !ok {
|
||||
dir = "."
|
||||
}
|
||||
|
||||
// Store the list of files that have been included in this chain to avoid
|
||||
// infinite recursion
|
||||
files, ok := ctx.Value(ctxFiles).(map[string]bool)
|
||||
if !ok {
|
||||
files = map[string]bool{}
|
||||
}
|
||||
|
||||
// Parse the YAML data into an AST to check for !include tags and handle
|
||||
// mapping vs sequence nodes
|
||||
file, err := parser.ParseBytes(data, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
node := file.Docs[0].Body
|
||||
|
||||
switch node.Type() {
|
||||
case ast.TagType:
|
||||
tagNode := node.(*ast.TagNode)
|
||||
if tagNode.Start.Value != "!include" {
|
||||
return fmt.Errorf("encountered unexpected tag %s", tagNode.Start.Value)
|
||||
}
|
||||
includeFilenameNode, ok := tagNode.Value.(ast.ScalarNode)
|
||||
if !ok {
|
||||
return fmt.Errorf("include: expected scalar")
|
||||
}
|
||||
includeFilename := includeFilenameNode.GetValue().(string)
|
||||
qualifiedIncludeFilename := filepath.Join(dir, includeFilename)
|
||||
includedData, err := os.ReadFile(qualifiedIncludeFilename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for multiple documents in included file before unmarshaling
|
||||
includedFile, err := parser.ParseBytes(includedData, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(includedFile.Docs) != 1 {
|
||||
return fmt.Errorf("%s: expected exactly one document, got %d", qualifiedIncludeFilename, len(includedFile.Docs))
|
||||
}
|
||||
|
||||
if _, ok := files[qualifiedIncludeFilename]; ok {
|
||||
return fmt.Errorf("infinite recursion detected in %s", qualifiedIncludeFilename)
|
||||
}
|
||||
files[qualifiedIncludeFilename] = true
|
||||
subctx := context.WithValue(ctx, ctxFiles, files)
|
||||
subctx = context.WithValue(subctx, ctxDirectory, filepath.Dir(qualifiedIncludeFilename))
|
||||
err = yaml.UnmarshalContext(subctx, includedData, z)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case ast.SequenceType:
|
||||
var records []Record
|
||||
if err := yaml.UnmarshalContext(ctx, data, &records); err != nil {
|
||||
return err
|
||||
}
|
||||
z.Records = records
|
||||
case ast.MappingType:
|
||||
var subzones map[string]*Zone
|
||||
if err := yaml.UnmarshalContext(ctx, data, &subzones); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, ok := subzones["@"]; ok {
|
||||
z.Records = subzones["@"].Records
|
||||
delete(subzones, "@")
|
||||
}
|
||||
z.Subzones = subzones
|
||||
default:
|
||||
return fmt.Errorf("expected a sequence, mapping, or !include tag")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Converts a DNS name to the corresponding path in the zone tree
|
||||
// "www.example.com" -> ["com", "example", "www"]
|
||||
func nameToPath(name string) []string {
|
||||
parts := strings.Split(name, ".")
|
||||
path := []string{}
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
path = append(path, parts[i])
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (z *Zone) Lookup(name string) ([]Record, bool) {
|
||||
path := nameToPath(name)
|
||||
for _, label := range path {
|
||||
// Support empty name and trailing dot
|
||||
if label == "" {
|
||||
continue
|
||||
}
|
||||
if sz, ok := z.Subzones[label]; ok {
|
||||
z = sz
|
||||
} else {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return z.Records, true
|
||||
}
|
||||
|
||||
func (z *Zone) FilterRecords(records []Record, recordType string) []Record {
|
||||
filtered := []Record{}
|
||||
for _, record := range records {
|
||||
if record.Type == recordType {
|
||||
filtered = append(filtered, record)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (z *Zone) LookupType(name string, recordType string) ([]Record, bool) {
|
||||
records, ok := z.Lookup(name)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return z.FilterRecords(records, recordType), true
|
||||
}
|
||||
|
||||
func concatName(name string, subname string) string {
|
||||
if name == "." {
|
||||
return subname + "."
|
||||
}
|
||||
return subname + "." + name
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue