From c736f4c1e96f5c92501d3f3d74f118e36ffbeeba Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 18 Jun 2025 16:06:33 +0700 Subject: [PATCH] test: improve DNS resolver tests reliability and thread safety - Add timeouts and proper cleanup in Test_osResolver_Singleflight: * Implement context timeout * Add proper PacketConn cleanup * Fix race conditions in error handling * Improve atomic value reporting - Enhance Test_osResolver_HotCache: * Add proper timeout context * Implement more reliable cache verification * Fix potential resource leaks * Add deterministic polling intervals - Add thread safety to Test_Edns0_CacheReply: * Implement proper timeout context * Add proper resource cleanup * Fix concurrent operations handling The changes improve overall test suite reliability by addressing resource management, timeout handling, and thread safety concerns across multiple DNS resolver test cases. --- resolver_test.go | 112 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 85 insertions(+), 27 deletions(-) diff --git a/resolver_test.go b/resolver_test.go index d5a76d6..1606529 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -143,6 +143,8 @@ func Test_osResolver_Singleflight(t *testing.T) { if err != nil { t.Fatalf("failed to listen on LAN address: %v", err) } + defer lanPC.Close() + call := &atomic.Int64{} lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) if err != nil { @@ -153,7 +155,13 @@ func Test_osResolver_Singleflight(t *testing.T) { or := newResolverWithNameserver([]string{lanAddr}) domain := "controld.com" n := 10 + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var wg sync.WaitGroup + errs := make(chan error, n) + wg.Add(n) for i := 0; i < n; i++ { go func() { @@ -161,25 +169,40 @@ func Test_osResolver_Singleflight(t *testing.T) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.RecursionDesired = true - _, err := or.Resolve(context.Background(), m) + _, err := or.Resolve(ctx, m) if err != nil { - t.Error(err) + errs <- err } }() } wg.Wait() + close(errs) + + // Collect any errors that occurred + for err := range errs { + t.Errorf("resolver error: %v", err) + } // All above queries should only make 1 call to server. - if call.Load() != 1 { - t.Fatalf("expected 1 result from singleflight lookup, got %d", call) + if got := call.Load(); got != 1 { + t.Fatalf("expected 1 result from singleflight lookup, got %d", got) } } func Test_osResolver_HotCache(t *testing.T) { + const ( + testIterations = 2 + cacheCheckTimeout = 5 * time.Second + pollInterval = 10 * time.Millisecond + ) + + // Setup test server lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to listen on LAN address: %v", err) } + defer lanPC.Close() + call := &atomic.Int64{} lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) if err != nil { @@ -187,58 +210,81 @@ func Test_osResolver_HotCache(t *testing.T) { } defer lanServer.Shutdown() + // Initialize resolver or := newResolverWithNameserver([]string{lanAddr}) domain := "controld.com" m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.RecursionDesired = true - // Make 2 repeated queries to server, should hit hot cache. - for i := 0; i < 2; i++ { - if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { + // Setup context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Make repeated queries to server, should hit hot cache + for i := 0; i < testIterations; i++ { + resp, err := or.Resolve(ctx, m.Copy()) + if err != nil { t.Fatal(err) } + // Verify response content + if resp.Rcode != dns.RcodeSuccess { + t.Errorf("expected success response, got %v", resp.Rcode) + } } + if call.Load() != 1 { t.Fatalf("cache not hit, server was called: %d", call.Load()) } + // Wait for cache to be cleaned timeoutChan := make(chan struct{}) - time.AfterFunc(5*time.Second, func() { + time.AfterFunc(cacheCheckTimeout, func() { close(timeoutChan) }) + // Check cache with proper polling interval +waitLoop: for { select { case <-timeoutChan: t.Fatal("timed out waiting for cache cleaned") - default: + case <-time.After(pollInterval): count := 0 or.cache.Range(func(key, value interface{}) bool { count++ return true }) - if count != 0 { - t.Logf("hot cache is not empty: %d elements", count) - continue + if count == 0 { + break waitLoop } + t.Logf("hot cache is not empty: %d elements", count) } - break } - if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { + // Verify cache miss after cleanup + resp, err := or.Resolve(ctx, m.Copy()) + if err != nil { t.Fatal(err) } + if resp.Rcode != dns.RcodeSuccess { + t.Errorf("expected success response after cache cleanup, got %v", resp.Rcode) + } if call.Load() != 2 { t.Fatal("cache hit unexpectedly") } } func Test_Edns0_CacheReply(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to listen on LAN address: %v", err) } + defer lanPC.Close() + call := &atomic.Int64{} lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) if err != nil { @@ -252,33 +298,45 @@ func Test_Edns0_CacheReply(t *testing.T) { m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.RecursionDesired = true - do := func() *dns.Msg { + do := func() (*dns.Msg, error) { msg := m.Copy() msg.SetEdns0(4096, true) cookieOption := new(dns.EDNS0_COOKIE) cookieOption.Code = dns.EDNS0COOKIE cookieOption.Cookie = generateEdns0ClientCookie() msg.IsEdns0().Option = append(msg.IsEdns0().Option, cookieOption) + return or.Resolve(ctx, msg) + } - answer, err := or.Resolve(context.Background(), msg) - if err != nil { - t.Fatal(err) - } - return answer + answer1, err := do() + if err != nil { + t.Fatalf("first resolve failed: %v", err) } - answer1 := do() - answer2 := do() - // Ensure the cache was hit, so we can check that edns0 cookie must be modified. - if call.Load() != 1 { - t.Fatalf("cache not hit, server was called: %d", call.Load()) + + answer2, err := do() + if err != nil { + t.Fatalf("second resolve failed: %v", err) } + + // Ensure the cache was hit + if got := call.Load(); got != 1 { + t.Fatalf("expected 1 server call, got: %d", got) + } + cookie1 := getEdns0Cookie(answer1.IsEdns0()) cookie2 := getEdns0Cookie(answer2.IsEdns0()) + if cookie1 == nil || cookie2 == nil { - t.Fatalf("unexpected nil cookie value (cookie1: %v, cookie2: %v)", cookie1, cookie2) + t.Fatalf("unexpected nil cookie (cookie1: %v, cookie2: %v)", cookie1, cookie2) } + if cookie1.Cookie == cookie2.Cookie { - t.Fatalf("edns0 cookie is not modified: %v", cookie1) + t.Fatalf("edns0 cookie was not modified (cookie: %v)", cookie1.Cookie) + } + + // Validate response code + if answer1.Rcode != dns.RcodeSuccess || answer2.Rcode != dns.RcodeSuccess { + t.Errorf("expected success response code, got: %v, %v", answer1.Rcode, answer2.Rcode) } }