// Listing metadata: 344 lines, 9.8 KiB, Go.

package yamlzone
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/goccy/go-yaml"
|
|
)
|
|
|
|
// assertOk aborts the running test with t.Fatalf when ok is false.
//
// It registers itself as a test helper so the failure is reported at the
// caller's line rather than inside this function.
func assertOk(t *testing.T, ok bool) {
	t.Helper()
	if !ok {
		t.Fatalf("Expected ok, got false")
	}
}
|
|
|
|
func assertNotOk(t *testing.T, ok bool, records []Record) {
|
|
if ok {
|
|
t.Fatalf("Expected not ok, got ok (%v)", records)
|
|
}
|
|
}
|
|
|
|
func assertRecordCount(t *testing.T, records []Record, expected int) {
|
|
if len(records) != expected {
|
|
t.Fatalf("Expected %d records, got %d (%v)", expected, len(records), records)
|
|
}
|
|
}
|
|
|
|
func assertRecordType(t *testing.T, record Record, expected string) {
|
|
if record.Type != expected {
|
|
t.Fatalf("Expected %s record, got %s", expected, record.Type)
|
|
}
|
|
}
|
|
|
|
func assertRecordTtl(t *testing.T, record Record, expected uint64) {
|
|
if record.Ttl != expected {
|
|
t.Fatalf("Expected TTL %d, got %d", expected, record.Ttl)
|
|
}
|
|
}
|
|
|
|
func assertRecordValue(t *testing.T, record Record, expected string) {
|
|
if record.Value != expected {
|
|
t.Fatalf("Expected %s, got %s", expected, record.Value)
|
|
}
|
|
}
|
|
|
|
func assertRecord(t *testing.T, record Record, expectedType string, expectedTtl uint64, expectedValue string) {
|
|
assertRecordType(t, record, expectedType)
|
|
assertRecordTtl(t, record, expectedTtl)
|
|
assertRecordValue(t, record, expectedValue)
|
|
}
|
|
|
|
func TestEmptyZone(t *testing.T) {
|
|
var zEmpty Zone
|
|
|
|
t.Run("Unmarshal", func(t *testing.T) {
|
|
err := yaml.Unmarshal([]byte("{}"), &zEmpty)
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
})
|
|
|
|
if t.Failed() {
|
|
return
|
|
}
|
|
|
|
t.Run("Lookup empty string", func(t *testing.T) {
|
|
records, ok := zEmpty.Lookup("")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 0)
|
|
})
|
|
|
|
t.Run("Lookup single dot", func(t *testing.T) {
|
|
records, ok := zEmpty.Lookup(".")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 0)
|
|
})
|
|
|
|
t.Run("Lookup example.com", func(t *testing.T) {
|
|
records, ok := zEmpty.Lookup("example.com")
|
|
assertNotOk(t, ok, records)
|
|
})
|
|
|
|
t.Run("LookupType example.com A", func(t *testing.T) {
|
|
records, ok := zEmpty.LookupType("example.com", "A")
|
|
assertNotOk(t, ok, records)
|
|
})
|
|
}
|
|
|
|
func TestSimpleZone(t *testing.T) {
|
|
const zSimpleFile = "testdata/example.org.yaml"
|
|
var zSimple *Zone
|
|
|
|
t.Run("Load", func(t *testing.T) {
|
|
var err error
|
|
zSimple, err = LoadZone(zSimpleFile)
|
|
if err != nil {
|
|
t.Fatalf("Expected no error, got \"%v\"", err)
|
|
}
|
|
})
|
|
|
|
if t.Failed() {
|
|
return
|
|
}
|
|
|
|
t.Run("Lookup", func(t *testing.T) {
|
|
records, ok := zSimple.Lookup("")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 4)
|
|
assertRecord(t, records[0], "SOA", 0, "ns1.example.com. admin.example.com. 1 1 1 1 1")
|
|
assertRecord(t, records[1], "A", 0, "192.0.2.100")
|
|
assertRecord(t, records[2], "AAAA", 0, "2001:db8::100")
|
|
assertRecord(t, records[3], "NS", 0, "ns1.example.com")
|
|
})
|
|
|
|
t.Run("LookupType", func(t *testing.T) {
|
|
records, ok := zSimple.LookupType("", "A")
|
|
if !ok {
|
|
t.Fatalf("Expected ok, got false")
|
|
}
|
|
|
|
assertRecordCount(t, records, 1)
|
|
assertRecord(t, records[0], "A", 0, "192.0.2.100")
|
|
})
|
|
}
|
|
|
|
func TestFullZone(t *testing.T) {
|
|
const zFullFile = "testdata/zones.yaml"
|
|
var zFull *Zone
|
|
|
|
t.Run("Load", func(t *testing.T) {
|
|
var err error
|
|
zFull, err = LoadZone(zFullFile)
|
|
if err != nil {
|
|
t.Fatalf("Expected no error, got \"%v\"", err)
|
|
}
|
|
})
|
|
|
|
if t.Failed() {
|
|
return
|
|
}
|
|
|
|
t.Run("Lookup example.com", func(t *testing.T) {
|
|
records, ok := zFull.Lookup("example.com")
|
|
if !ok {
|
|
t.Fatalf("Expected ok, got false")
|
|
}
|
|
|
|
assertRecordCount(t, records, 8)
|
|
assertRecord(t, records[0], "SOA", 0, "ns1.example.com. admin.example.com. 1 1 1 1 1")
|
|
assertRecord(t, records[1], "A", 0, "192.0.2.1")
|
|
assertRecord(t, records[2], "AAAA", 0, "2001:db8::1")
|
|
assertRecord(t, records[3], "MX", 3600, "10 mail.example.com") // Default TTL
|
|
assertRecord(t, records[4], "TXT", 300, "v=spf1 a mx include:mail.example.com ~all")
|
|
assertRecord(t, records[5], "CAA", 86400, "0 issue \"letsencrypt.org\"")
|
|
assertRecord(t, records[6], "TXT", 3600, "foo=bar")
|
|
assertRecord(t, records[7], "NS", 0, "ns1.example.com")
|
|
|
|
})
|
|
|
|
t.Run("LookupType example.com TXT", func(t *testing.T) {
|
|
records, ok := zFull.LookupType("example.com", "TXT")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 2)
|
|
assertRecord(t, records[0], "TXT", 300, "v=spf1 a mx include:mail.example.com ~all")
|
|
assertRecord(t, records[1], "TXT", 3600, "foo=bar")
|
|
})
|
|
|
|
t.Run("Lookup www.example.com", func(t *testing.T) {
|
|
records, ok := zFull.Lookup("www.example.com")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 1)
|
|
assertRecord(t, records[0], "CNAME", 3600, "example.com")
|
|
|
|
})
|
|
|
|
t.Run("LookupType www.example.com CNAME", func(t *testing.T) {
|
|
records, ok := zFull.LookupType("www.example.com", "CNAME")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 1)
|
|
assertRecord(t, records[0], "CNAME", 3600, "example.com")
|
|
})
|
|
|
|
t.Run("Lookup www.example.com TXT", func(t *testing.T) {
|
|
records, ok := zFull.LookupType("www.example.com", "TXT")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 0)
|
|
})
|
|
|
|
t.Run("Lookup status.example.com", func(t *testing.T) {
|
|
records, ok := zFull.Lookup("status.example.com")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 2)
|
|
assertRecord(t, records[0], "A", 3600, "198.51.100.24")
|
|
assertRecord(t, records[1], "A", 3600, "203.0.113.24")
|
|
})
|
|
|
|
t.Run("Lookup partner.example.com", func(t *testing.T) {
|
|
records, ok := zFull.Lookup("partner.example.com")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 2)
|
|
assertRecord(t, records[0], "NS", 3600, "ns1.example.org")
|
|
assertRecord(t, records[1], "NS", 3600, "ns2.example.org")
|
|
})
|
|
|
|
t.Run("Lookup unused.example.com", func(t *testing.T) {
|
|
records, ok := zFull.Lookup("unused.example.com")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 0)
|
|
})
|
|
|
|
t.Run("Lookup ftp.internal.example.com", func(t *testing.T) {
|
|
records, ok := zFull.Lookup("ftp.internal.example.com")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 1)
|
|
assertRecord(t, records[0], "A", 3600, "10.0.0.2")
|
|
})
|
|
|
|
t.Run("Lookup _xmpp-server._tcp.example.com", func(t *testing.T) {
|
|
records, ok := zFull.Lookup("_xmpp-server._tcp.example.com")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 1)
|
|
assertRecord(t, records[0], "SRV", 3600, "10 0 5269 example.com")
|
|
})
|
|
|
|
t.Run("Lookup multilayer.nested.folders.example.com", func(t *testing.T) {
|
|
records, ok := zFull.Lookup("multilayer.nested.folders.example.com")
|
|
assertOk(t, ok)
|
|
assertRecordCount(t, records, 1)
|
|
assertRecord(t, records[0], "A", 3600, "192.0.2.1")
|
|
})
|
|
}
|
|
|
|
func TestBadZones(t *testing.T) {
|
|
type badZone struct {
|
|
name string
|
|
filename string
|
|
errorSubstring string
|
|
}
|
|
var badZones = []badZone{
|
|
{
|
|
name: "NonexistentFile",
|
|
filename: "testdata/bad_nonexistent.yaml",
|
|
errorSubstring: "open testdata/bad_nonexistent.yaml: no such file or directory",
|
|
},
|
|
{
|
|
name: "Directory",
|
|
filename: "testdata/bad_directory.yaml",
|
|
errorSubstring: "read testdata/bad_directory.yaml: is a directory",
|
|
},
|
|
{
|
|
name: "InvalidYaml",
|
|
filename: "testdata/bad_invalid_yaml.yaml",
|
|
errorSubstring: "[1:1]",
|
|
},
|
|
{
|
|
name: "InvalidTag",
|
|
filename: "testdata/bad_invalid_tag.yaml",
|
|
errorSubstring: "encountered unexpected tag !foo",
|
|
},
|
|
{
|
|
name: "InvalidType",
|
|
filename: "testdata/bad_invalid_type.yaml",
|
|
errorSubstring: "expected a sequence, mapping, or !include tag",
|
|
},
|
|
{
|
|
name: "MultipleDocuments",
|
|
filename: "testdata/bad_multiple_documents.yaml",
|
|
errorSubstring: "expected exactly one document, got 2",
|
|
},
|
|
{
|
|
name: "InvalidRecordMissingType",
|
|
filename: "testdata/bad_invalid_record_missing_type.yaml",
|
|
errorSubstring: "record: missing type",
|
|
},
|
|
{
|
|
name: "InvalidRecordMissingValue",
|
|
filename: "testdata/bad_invalid_record_missing_value.yaml",
|
|
errorSubstring: "record: missing value",
|
|
},
|
|
{
|
|
name: "InfiniteRecursion",
|
|
filename: "testdata/bad_recursion_1.yaml",
|
|
errorSubstring: "infinite recursion detected in testdata/bad_recursion_1.yaml",
|
|
},
|
|
{
|
|
name: "IncludeTagValueNotScalar",
|
|
filename: "testdata/bad_include_tag_value_not_scalar.yaml",
|
|
errorSubstring: "include: expected scalar",
|
|
},
|
|
{
|
|
name: "RecordTypeMismatch",
|
|
filename: "testdata/bad_record_type_mismatch.yaml",
|
|
errorSubstring: "cannot unmarshal string into Go struct field",
|
|
},
|
|
{
|
|
name: "MissingSoa",
|
|
filename: "testdata/bad_missing_soa.yaml",
|
|
errorSubstring: "records found outside zone",
|
|
},
|
|
{
|
|
name: "MissingNs",
|
|
filename: "testdata/bad_missing_ns.yaml",
|
|
errorSubstring: "zone apex missing NS records",
|
|
},
|
|
}
|
|
|
|
for _, badZone := range badZones {
|
|
t.Run(badZone.name, func(t *testing.T) {
|
|
z, err := LoadZone(badZone.filename)
|
|
if err == nil {
|
|
t.Errorf("Expected error, got nil (%v)", z)
|
|
} else if !strings.Contains(err.Error(), badZone.errorSubstring) {
|
|
t.Errorf("Expected error \"%s\", got \"%s\"", badZone.errorSubstring, err.Error())
|
|
}
|
|
})
|
|
t.Run("Include"+badZone.name, func(t *testing.T) {
|
|
z, err := LoadZoneBytes([]byte("!include "+badZone.filename), ".")
|
|
if err == nil {
|
|
t.Errorf("Expected error, got nil (%v)", z)
|
|
} else if !strings.Contains(err.Error(), badZone.errorSubstring) {
|
|
t.Errorf("Expected error \"%s\", got \"%s\"", badZone.errorSubstring, err.Error())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// This test exists solely to reach 100% test coverage. It triggers an error
|
|
// condition that should never normally be possible, but is guarded against for
|
|
// completeness.
|
|
func TestDirectUnmarshalWithInvalidBytes(t *testing.T) {
|
|
// Directly call Zone.UnmarshalYAML with invalid bytes to trigger
|
|
// yamlzone.go:103 (error in parser.ParseBytes)
|
|
z := &Zone{}
|
|
invalidBytes := []byte("[invalid yaml")
|
|
|
|
err := z.UnmarshalYAML(context.TODO(), invalidBytes)
|
|
if err == nil {
|
|
t.Fatalf("Expected error, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "[1:") {
|
|
t.Errorf("Expected parser error, got %v", err)
|
|
}
|
|
}
|