initial commit

This commit is contained in:
Dessa Simpson 2025-11-30 19:25:34 -07:00
commit 772469d21b
27 changed files with 1121 additions and 0 deletions

344
yamlzone_test.go Normal file
View file

@ -0,0 +1,344 @@
package yamlzone
import (
"context"
"strings"
"testing"
"github.com/goccy/go-yaml"
)
func assertOk(t *testing.T, ok bool) {
if !ok {
t.Fatalf("Expected ok, got false")
}
}
func assertNotOk(t *testing.T, ok bool, records []Record) {
if ok {
t.Fatalf("Expected not ok, got ok (%v)", records)
}
}
func assertRecordCount(t *testing.T, records []Record, expected int) {
if len(records) != expected {
t.Fatalf("Expected %d records, got %d (%v)", expected, len(records), records)
}
}
func assertRecordType(t *testing.T, record Record, expected string) {
if record.Type != expected {
t.Fatalf("Expected %s record, got %s", expected, record.Type)
}
}
func assertRecordTtl(t *testing.T, record Record, expected uint64) {
if record.Ttl != expected {
t.Fatalf("Expected TTL %d, got %d", expected, record.Ttl)
}
}
func assertRecordValue(t *testing.T, record Record, expected string) {
if record.Value != expected {
t.Fatalf("Expected %s, got %s", expected, record.Value)
}
}
func assertRecord(t *testing.T, record Record, expectedType string, expectedTtl uint64, expectedValue string) {
assertRecordType(t, record, expectedType)
assertRecordTtl(t, record, expectedTtl)
assertRecordValue(t, record, expectedValue)
}
func TestEmptyZone(t *testing.T) {
var zEmpty Zone
t.Run("Unmarshal", func(t *testing.T) {
err := yaml.Unmarshal([]byte("{}"), &zEmpty)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if t.Failed() {
return
}
t.Run("Lookup empty string", func(t *testing.T) {
records, ok := zEmpty.Lookup("")
assertOk(t, ok)
assertRecordCount(t, records, 0)
})
t.Run("Lookup single dot", func(t *testing.T) {
records, ok := zEmpty.Lookup(".")
assertOk(t, ok)
assertRecordCount(t, records, 0)
})
t.Run("Lookup example.com", func(t *testing.T) {
records, ok := zEmpty.Lookup("example.com")
assertNotOk(t, ok, records)
})
t.Run("LookupType example.com A", func(t *testing.T) {
records, ok := zEmpty.LookupType("example.com", "A")
assertNotOk(t, ok, records)
})
}
func TestSimpleZone(t *testing.T) {
const zSimpleFile = "testdata/example.org.yaml"
var zSimple *Zone
t.Run("Load", func(t *testing.T) {
var err error
zSimple, err = LoadZone(zSimpleFile)
if err != nil {
t.Fatalf("Expected no error, got \"%v\"", err)
}
})
if t.Failed() {
return
}
t.Run("Lookup", func(t *testing.T) {
records, ok := zSimple.Lookup("")
assertOk(t, ok)
assertRecordCount(t, records, 4)
assertRecord(t, records[0], "SOA", 0, "ns1.example.com. admin.example.com. 1 1 1 1 1")
assertRecord(t, records[1], "A", 0, "192.0.2.100")
assertRecord(t, records[2], "AAAA", 0, "2001:db8::100")
assertRecord(t, records[3], "NS", 0, "ns1.example.com")
})
t.Run("LookupType", func(t *testing.T) {
records, ok := zSimple.LookupType("", "A")
if !ok {
t.Fatalf("Expected ok, got false")
}
assertRecordCount(t, records, 1)
assertRecord(t, records[0], "A", 0, "192.0.2.100")
})
}
func TestFullZone(t *testing.T) {
const zFullFile = "testdata/zones.yaml"
var zFull *Zone
t.Run("Load", func(t *testing.T) {
var err error
zFull, err = LoadZone(zFullFile)
if err != nil {
t.Fatalf("Expected no error, got \"%v\"", err)
}
})
if t.Failed() {
return
}
t.Run("Lookup example.com", func(t *testing.T) {
records, ok := zFull.Lookup("example.com")
if !ok {
t.Fatalf("Expected ok, got false")
}
assertRecordCount(t, records, 8)
assertRecord(t, records[0], "SOA", 0, "ns1.example.com. admin.example.com. 1 1 1 1 1")
assertRecord(t, records[1], "A", 0, "192.0.2.1")
assertRecord(t, records[2], "AAAA", 0, "2001:db8::1")
assertRecord(t, records[3], "MX", 3600, "10 mail.example.com") // Default TTL
assertRecord(t, records[4], "TXT", 300, "v=spf1 a mx include:mail.example.com ~all")
assertRecord(t, records[5], "CAA", 86400, "0 issue \"letsencrypt.org\"")
assertRecord(t, records[6], "TXT", 3600, "foo=bar")
assertRecord(t, records[7], "NS", 0, "ns1.example.com")
})
t.Run("LookupType example.com TXT", func(t *testing.T) {
records, ok := zFull.LookupType("example.com", "TXT")
assertOk(t, ok)
assertRecordCount(t, records, 2)
assertRecord(t, records[0], "TXT", 300, "v=spf1 a mx include:mail.example.com ~all")
assertRecord(t, records[1], "TXT", 3600, "foo=bar")
})
t.Run("Lookup www.example.com", func(t *testing.T) {
records, ok := zFull.Lookup("www.example.com")
assertOk(t, ok)
assertRecordCount(t, records, 1)
assertRecord(t, records[0], "CNAME", 3600, "example.com")
})
t.Run("LookupType www.example.com CNAME", func(t *testing.T) {
records, ok := zFull.LookupType("www.example.com", "CNAME")
assertOk(t, ok)
assertRecordCount(t, records, 1)
assertRecord(t, records[0], "CNAME", 3600, "example.com")
})
t.Run("Lookup www.example.com TXT", func(t *testing.T) {
records, ok := zFull.LookupType("www.example.com", "TXT")
assertOk(t, ok)
assertRecordCount(t, records, 0)
})
t.Run("Lookup status.example.com", func(t *testing.T) {
records, ok := zFull.Lookup("status.example.com")
assertOk(t, ok)
assertRecordCount(t, records, 2)
assertRecord(t, records[0], "A", 3600, "198.51.100.24")
assertRecord(t, records[1], "A", 3600, "203.0.113.24")
})
t.Run("Lookup partner.example.com", func(t *testing.T) {
records, ok := zFull.Lookup("partner.example.com")
assertOk(t, ok)
assertRecordCount(t, records, 2)
assertRecord(t, records[0], "NS", 3600, "ns1.example.org")
assertRecord(t, records[1], "NS", 3600, "ns2.example.org")
})
t.Run("Lookup unused.example.com", func(t *testing.T) {
records, ok := zFull.Lookup("unused.example.com")
assertOk(t, ok)
assertRecordCount(t, records, 0)
})
t.Run("Lookup ftp.internal.example.com", func(t *testing.T) {
records, ok := zFull.Lookup("ftp.internal.example.com")
assertOk(t, ok)
assertRecordCount(t, records, 1)
assertRecord(t, records[0], "A", 3600, "10.0.0.2")
})
t.Run("Lookup _xmpp-server._tcp.example.com", func(t *testing.T) {
records, ok := zFull.Lookup("_xmpp-server._tcp.example.com")
assertOk(t, ok)
assertRecordCount(t, records, 1)
assertRecord(t, records[0], "SRV", 3600, "10 0 5269 example.com")
})
t.Run("Lookup multilayer.nested.folders.example.com", func(t *testing.T) {
records, ok := zFull.Lookup("multilayer.nested.folders.example.com")
assertOk(t, ok)
assertRecordCount(t, records, 1)
assertRecord(t, records[0], "A", 3600, "192.0.2.1")
})
}
func TestBadZones(t *testing.T) {
type badZone struct {
name string
filename string
errorSubstring string
}
var badZones = []badZone{
{
name: "NonexistentFile",
filename: "testdata/bad_nonexistent.yaml",
errorSubstring: "open testdata/bad_nonexistent.yaml: no such file or directory",
},
{
name: "Directory",
filename: "testdata/bad_directory.yaml",
errorSubstring: "read testdata/bad_directory.yaml: is a directory",
},
{
name: "InvalidYaml",
filename: "testdata/bad_invalid_yaml.yaml",
errorSubstring: "[1:1]",
},
{
name: "InvalidTag",
filename: "testdata/bad_invalid_tag.yaml",
errorSubstring: "encountered unexpected tag !foo",
},
{
name: "InvalidType",
filename: "testdata/bad_invalid_type.yaml",
errorSubstring: "expected a sequence, mapping, or !include tag",
},
{
name: "MultipleDocuments",
filename: "testdata/bad_multiple_documents.yaml",
errorSubstring: "expected exactly one document, got 2",
},
{
name: "InvalidRecordMissingType",
filename: "testdata/bad_invalid_record_missing_type.yaml",
errorSubstring: "record: missing type",
},
{
name: "InvalidRecordMissingValue",
filename: "testdata/bad_invalid_record_missing_value.yaml",
errorSubstring: "record: missing value",
},
{
name: "InfiniteRecursion",
filename: "testdata/bad_recursion_1.yaml",
errorSubstring: "infinite recursion detected in testdata/bad_recursion_1.yaml",
},
{
name: "IncludeTagValueNotScalar",
filename: "testdata/bad_include_tag_value_not_scalar.yaml",
errorSubstring: "include: expected scalar",
},
{
name: "RecordTypeMismatch",
filename: "testdata/bad_record_type_mismatch.yaml",
errorSubstring: "cannot unmarshal string into Go struct field",
},
{
name: "MissingSoa",
filename: "testdata/bad_missing_soa.yaml",
errorSubstring: "records found outside zone",
},
{
name: "MissingNs",
filename: "testdata/bad_missing_ns.yaml",
errorSubstring: "zone apex missing NS records",
},
}
for _, badZone := range badZones {
t.Run(badZone.name, func(t *testing.T) {
z, err := LoadZone(badZone.filename)
if err == nil {
t.Errorf("Expected error, got nil (%v)", z)
} else if !strings.Contains(err.Error(), badZone.errorSubstring) {
t.Errorf("Expected error \"%s\", got \"%s\"", badZone.errorSubstring, err.Error())
}
})
t.Run("Include"+badZone.name, func(t *testing.T) {
z, err := LoadZoneBytes([]byte("!include "+badZone.filename), ".")
if err == nil {
t.Errorf("Expected error, got nil (%v)", z)
} else if !strings.Contains(err.Error(), badZone.errorSubstring) {
t.Errorf("Expected error \"%s\", got \"%s\"", badZone.errorSubstring, err.Error())
}
})
}
}
// This test exists solely to reach 100% test coverage. It triggers an error
// condition that should never normally be possible, but is guarded against for
// completeness.
func TestDirectUnmarshalWithInvalidBytes(t *testing.T) {
// Directly call Zone.UnmarshalYAML with invalid bytes to trigger
// yamlzone.go:103 (error in parser.ParseBytes)
z := &Zone{}
invalidBytes := []byte("[invalid yaml")
err := z.UnmarshalYAML(context.TODO(), invalidBytes)
if err == nil {
t.Fatalf("Expected error, got nil")
}
if !strings.Contains(err.Error(), "[1:") {
t.Errorf("Expected parser error, got %v", err)
}
}