package yamlzone

import (
	"context"
	"strings"
	"testing"

	"github.com/goccy/go-yaml"
)

// assertOk fails the test immediately if a lookup reported ok == false.
func assertOk(t *testing.T, ok bool) {
	t.Helper()
	if !ok {
		t.Fatalf("Expected ok, got false")
	}
}

// assertNotOk fails the test immediately if a lookup reported ok == true,
// including the unexpected result in the failure message.
func assertNotOk(t *testing.T, ok bool, res LookupResult) {
	t.Helper()
	if ok {
		t.Fatalf("Expected not ok, got ok (%v)", res)
	}
}

// assertRecordCount fails the test immediately unless res.Answer holds exactly
// expected entries. Because it uses Fatalf, callers may safely index into
// res.Answer afterwards.
func assertRecordCount(t *testing.T, res LookupResult, expected int) {
	t.Helper()
	if len(res.Answer) != expected {
		t.Fatalf("Expected %d records, got %d (%v)", expected, len(res.Answer), res.Answer)
	}
}

// assertRecordType fails the test immediately if record.Type != expected.
func assertRecordType(t *testing.T, record Record, expected string) {
	t.Helper()
	if record.Type != expected {
		t.Fatalf("Expected %s record, got %s", expected, record.Type)
	}
}

// assertRecordTtl fails the test immediately if record.Ttl != expected.
func assertRecordTtl(t *testing.T, record Record, expected uint64) {
	t.Helper()
	if record.Ttl != expected {
		t.Fatalf("Expected TTL %d, got %d", expected, record.Ttl)
	}
}

// assertRecordValue fails the test immediately if record.Value != expected.
func assertRecordValue(t *testing.T, record Record, expected string) {
	t.Helper()
	if record.Value != expected {
		t.Fatalf("Expected %s, got %s", expected, record.Value)
	}
}

// assertRecord checks a record's type, TTL and value, in that order, stopping
// at the first mismatch.
func assertRecord(t *testing.T, record Record, expectedType string, expectedTtl uint64, expectedValue string) {
	t.Helper()
	assertRecordType(t, record, expectedType)
	assertRecordTtl(t, record, expectedTtl)
	assertRecordValue(t, record, expectedValue)
}

// TestEmptyZone unmarshals an empty YAML mapping into a Zone and checks that
// the root ("" and ".") resolves with no records while other names do not.
func TestEmptyZone(t *testing.T) {
	var zEmpty Zone

	t.Run("Unmarshal", func(t *testing.T) {
		err := yaml.Unmarshal([]byte("{}"), &zEmpty)
		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}
	})
	// The lookups below are meaningless if the zone failed to unmarshal.
	if t.Failed() {
		return
	}

	t.Run("Lookup empty string", func(t *testing.T) {
		res, ok := zEmpty.Lookup("")
		assertOk(t, ok)
		assertRecordCount(t, res, 0)
	})

	t.Run("Lookup single dot", func(t *testing.T) {
		res, ok := zEmpty.Lookup(".")
		assertOk(t, ok)
		assertRecordCount(t, res, 0)
	})

	t.Run("Lookup example.com", func(t *testing.T) {
		res, ok := zEmpty.Lookup("example.com")
		assertNotOk(t, ok, res)
	})

	t.Run("LookupType example.com A", func(t *testing.T) {
		res, ok := zEmpty.LookupType("example.com", "A")
		assertNotOk(t, ok, res)
	})
}

// TestSimpleZone loads testdata/example.org.yaml and checks the records
// returned for the zone's own name ("").
func TestSimpleZone(t *testing.T) {
	const zSimpleFile = "testdata/example.org.yaml"
	var zSimple *Zone

	t.Run("Load", func(t *testing.T) {
		var err error
		zSimple, err = LoadZone(zSimpleFile)
		if err != nil {
			t.Fatalf("Expected no error, got \"%v\"", err)
		}
	})
	// zSimple is unusable if loading failed.
	if t.Failed() {
		return
	}

	t.Run("Lookup", func(t *testing.T) {
		res, ok := zSimple.Lookup("")
		assertOk(t, ok)
		assertRecordCount(t, res, 4)
		assertRecord(t, res.Answer[0].Record, "SOA", 0, "ns1.example.com. admin.example.com. 1 1 1 1 1")
		assertRecord(t, res.Answer[1].Record, "A", 0, "192.0.2.100")
		assertRecord(t, res.Answer[2].Record, "AAAA", 0, "2001:db8::100")
		assertRecord(t, res.Answer[3].Record, "NS", 0, "ns1.example.com")
	})

	t.Run("LookupType", func(t *testing.T) {
		res, ok := zSimple.LookupType("", "A")
		assertOk(t, ok)
		assertRecordCount(t, res, 1)
		assertRecord(t, res.Answer[0].Record, "A", 0, "192.0.2.100")
	})
}

// TestFullZone loads testdata/zones.yaml and checks lookups across the apex,
// CNAMEs, multi-value names, delegations, empty names and nested subdomains.
func TestFullZone(t *testing.T) {
	const zFullFile = "testdata/zones.yaml"
	var zFull *Zone

	t.Run("Load", func(t *testing.T) {
		var err error
		zFull, err = LoadZone(zFullFile)
		if err != nil {
			t.Fatalf("Expected no error, got \"%v\"", err)
		}
	})
	// zFull is unusable if loading failed.
	if t.Failed() {
		return
	}

	t.Run("Lookup example.com", func(t *testing.T) {
		res, ok := zFull.Lookup("example.com")
		assertOk(t, ok)
		assertRecordCount(t, res, 8)
		assertRecord(t, res.Answer[0].Record, "SOA", 0, "ns1.example.com. admin.example.com. 1 1 1 1 1")
		assertRecord(t, res.Answer[1].Record, "A", 0, "192.0.2.1")
		assertRecord(t, res.Answer[2].Record, "AAAA", 0, "2001:db8::1")
		assertRecord(t, res.Answer[3].Record, "MX", 3600, "10 mail.example.com") // Default TTL
		assertRecord(t, res.Answer[4].Record, "TXT", 300, "v=spf1 a mx include:mail.example.com ~all")
		assertRecord(t, res.Answer[5].Record, "CAA", 86400, "0 issue \"letsencrypt.org\"")
		assertRecord(t, res.Answer[6].Record, "TXT", 3600, "foo=bar")
		assertRecord(t, res.Answer[7].Record, "NS", 0, "ns1.example.com")
	})

	t.Run("LookupType example.com TXT", func(t *testing.T) {
		res, ok := zFull.LookupType("example.com", "TXT")
		assertOk(t, ok)
		assertRecordCount(t, res, 2)
		assertRecord(t, res.Answer[0].Record, "TXT", 300, "v=spf1 a mx include:mail.example.com ~all")
		assertRecord(t, res.Answer[1].Record, "TXT", 3600, "foo=bar")
	})

	t.Run("Lookup www.example.com", func(t *testing.T) {
		res, ok := zFull.Lookup("www.example.com")
		assertOk(t, ok)
		assertRecordCount(t, res, 1)
		assertRecord(t, res.Answer[0].Record, "CNAME", 3600, "example.com")
	})

	t.Run("LookupType www.example.com CNAME", func(t *testing.T) {
		res, ok := zFull.LookupType("www.example.com", "CNAME")
		assertOk(t, ok)
		assertRecordCount(t, res, 1)
		assertRecord(t, res.Answer[0].Record, "CNAME", 3600, "example.com")
	})

	// The name exists (it holds a CNAME) but has no TXT records: ok with an
	// empty answer.
	t.Run("LookupType www.example.com TXT", func(t *testing.T) {
		res, ok := zFull.LookupType("www.example.com", "TXT")
		assertOk(t, ok)
		assertRecordCount(t, res, 0)
	})

	t.Run("Lookup status.example.com", func(t *testing.T) {
		res, ok := zFull.Lookup("status.example.com")
		assertOk(t, ok)
		assertRecordCount(t, res, 2)
		assertRecord(t, res.Answer[0].Record, "A", 3600, "198.51.100.24")
		assertRecord(t, res.Answer[1].Record, "A", 3600, "203.0.113.24")
	})

	t.Run("Lookup partner.example.com", func(t *testing.T) {
		res, ok := zFull.Lookup("partner.example.com")
		assertOk(t, ok)
		assertRecordCount(t, res, 2)
		assertRecord(t, res.Answer[0].Record, "NS", 3600, "ns1.example.org")
		assertRecord(t, res.Answer[1].Record, "NS", 3600, "ns2.example.org")
	})

	t.Run("Lookup unused.example.com", func(t *testing.T) {
		res, ok := zFull.Lookup("unused.example.com")
		assertOk(t, ok)
		assertRecordCount(t, res, 0)
	})

	t.Run("Lookup ftp.internal.example.com", func(t *testing.T) {
		res, ok := zFull.Lookup("ftp.internal.example.com")
		assertOk(t, ok)
		assertRecordCount(t, res, 1)
		assertRecord(t, res.Answer[0].Record, "A", 3600, "10.0.0.2")
	})

	t.Run("Lookup _xmpp-server._tcp.example.com", func(t *testing.T) {
		res, ok := zFull.Lookup("_xmpp-server._tcp.example.com")
		assertOk(t, ok)
		assertRecordCount(t, res, 1)
		assertRecord(t, res.Answer[0].Record, "SRV", 3600, "10 0 5269 example.com")
	})

	t.Run("Lookup multilayer.nested.folders.example.com", func(t *testing.T) {
		res, ok := zFull.Lookup("multilayer.nested.folders.example.com")
		assertOk(t, ok)
		assertRecordCount(t, res, 1)
		assertRecord(t, res.Answer[0].Record, "A", 3600, "192.0.2.1")
	})
}

// TestBadZones checks that each malformed zone file is rejected with an error
// containing the expected substring. It checks both direct loading (LoadZone)
// and loading through an !include tag (LoadZoneBytes).
func TestBadZones(t *testing.T) {
	type badZone struct {
		name           string
		filename       string
		errorSubstring string
	}

	// Each name must be unique; duplicates would silently re-run the same
	// case under an auto-suffixed subtest name.
	var badZones = []badZone{
		{
			name:           "CnameWithOther",
			filename:       "testdata/bad_cname_with_other.yaml",
			errorSubstring: "extraneous records found next to CNAME",
		},
		{
			name:           "NonexistentFile",
			filename:       "testdata/bad_nonexistent.yaml",
			errorSubstring: "open testdata/bad_nonexistent.yaml: no such file or directory",
		},
		{
			name:           "Directory",
			filename:       "testdata/bad_directory.yaml",
			errorSubstring: "read testdata/bad_directory.yaml: is a directory",
		},
		{
			name:           "InvalidYaml",
			filename:       "testdata/bad_invalid_yaml.yaml",
			errorSubstring: "[1:1]",
		},
		{
			name:           "InvalidTag",
			filename:       "testdata/bad_invalid_tag.yaml",
			errorSubstring: "encountered unexpected tag !foo",
		},
		{
			name:           "InvalidType",
			filename:       "testdata/bad_invalid_type.yaml",
			errorSubstring: "expected a sequence, mapping, or !include tag",
		},
		{
			name:           "MultipleDocuments",
			filename:       "testdata/bad_multiple_documents.yaml",
			errorSubstring: "expected exactly one document, got 2",
		},
		{
			name:           "InvalidRecordMissingType",
			filename:       "testdata/bad_invalid_record_missing_type.yaml",
			errorSubstring: "record: missing type",
		},
		{
			name:           "InvalidRecordMissingValue",
			filename:       "testdata/bad_invalid_record_missing_value.yaml",
			errorSubstring: "record: missing value",
		},
		{
			name:           "InfiniteRecursion",
			filename:       "testdata/bad_recursion_1.yaml",
			errorSubstring: "infinite recursion detected in testdata/bad_recursion_1.yaml",
		},
		{
			name:           "IncludeTagValueNotScalar",
			filename:       "testdata/bad_include_tag_value_not_scalar.yaml",
			errorSubstring: "include: expected scalar",
		},
		{
			name:           "RecordTypeMismatch",
			filename:       "testdata/bad_record_type_mismatch.yaml",
			errorSubstring: "cannot unmarshal string into Go struct field",
		},
		{
			name:           "MissingSoa",
			filename:       "testdata/bad_missing_soa.yaml",
			errorSubstring: "records found outside zone",
		},
		{
			name:           "MissingNs",
			filename:       "testdata/bad_missing_ns.yaml",
			errorSubstring: "zone apex missing NS records",
		},
		{
			name:           "NsWithOther",
			filename:       "testdata/bad_ns_with_other.yaml",
			errorSubstring: "non-glue, non-NS records found at delegation point",
		},
		{
			name:           "NsWithSubzone",
			filename:       "testdata/bad_ns_with_subzone.yaml",
			errorSubstring: "non-glue records found under delegation point",
		},
		{
			name:           "GlueWithOther",
			filename:       "testdata/bad_glue_with_other.yaml",
			errorSubstring: "non-glue record found under delegation point",
		},
	}

	// t.Run is synchronous here (no t.Parallel), so capturing the loop
	// variable in the closures is safe regardless of Go version.
	for _, badZone := range badZones {
		t.Run(badZone.name, func(t *testing.T) {
			z, err := LoadZone(badZone.filename)
			if err == nil {
				t.Errorf("Expected error, got nil (%v)", z)
			} else if !strings.Contains(err.Error(), badZone.errorSubstring) {
				t.Errorf("Expected error \"%s\", got \"%s\"", badZone.errorSubstring, err.Error())
			}
		})

		t.Run("Include"+badZone.name, func(t *testing.T) {
			z, err := LoadZoneBytes([]byte("!include "+badZone.filename), ".")
			if err == nil {
				t.Errorf("Expected error, got nil (%v)", z)
			} else if !strings.Contains(err.Error(), badZone.errorSubstring) {
				t.Errorf("Expected error \"%s\", got \"%s\"", badZone.errorSubstring, err.Error())
			}
		})
	}
}

// This test exists solely to reach 100% test coverage. It triggers an error
// condition that should never normally be possible, but is guarded against for
// completeness.
func TestDirectUnmarshalWithInvalidBytes(t *testing.T) {
	// Call Zone.UnmarshalYAML directly with malformed bytes so that its
	// parser.ParseBytes error branch is taken (the normal yaml.Unmarshal path
	// would reject such input before reaching it).
	z := &Zone{}
	invalidBytes := []byte("[invalid yaml")
	err := z.UnmarshalYAML(context.TODO(), invalidBytes)
	if err == nil {
		t.Fatalf("Expected error, got nil")
	}
	if !strings.Contains(err.Error(), "[1:") {
		t.Errorf("Expected parser error, got %v", err)
	}
}