From ba48ff5965938879202dcbd9678ab8519ac9cf8f Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 30 Mar 2023 02:43:37 +0700 Subject: [PATCH] all: fix os resolver hangs when all server failed For os resolver, ctrld queries against all servers concurrently, and get the first success result back. However, if all server failed, the result channel is not closed, causing ctrld hang. Fixing this by closing the result channel once getting back all response from servers. While at it, also shorten the backoff time when waiting for network up, ctrld should serve as fast as possible after network is available. Updates #34 --- cmd/ctrld/cli.go | 3 +-- internal/net/net.go | 2 +- resolver.go | 8 ++++++++ resolver_test.go | 29 +++++++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 3 deletions(-) create mode 100644 resolver_test.go diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 4816908..938860a 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -160,8 +160,7 @@ func initCLI() { log.Fatalf("failed to unmarshal config: %v", err) } - log.Println("starting ctrld ...") - log.Printf("version: %s\n", curVersion()) + log.Printf("starting ctrld %s\n", curVersion()) oi := osinfo.New() log.Printf("os: %s\n", oi.String()) diff --git a/internal/net/net.go b/internal/net/net.go index 888a2d6..4e71206 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -66,7 +66,7 @@ func supportListenIPv6Local() bool { } func probeStack() { - b := backoff.NewBackoff("probeStack", func(format string, args ...any) {}, time.Minute) + b := backoff.NewBackoff("probeStack", func(format string, args ...any) {}, 5*time.Second) for { if _, err := probeStackDialer.Dial("udp", bootstrapDNS); err == nil { hasNetworkUp = true diff --git a/resolver.go b/resolver.go index a12c700..45537fa 100644 --- a/resolver.go +++ b/resolver.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "sync" "github.com/miekg/dns" ) @@ -69,8 +70,15 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error dnsClient := &dns.Client{Net: "udp"} ch := make(chan *osResolverResult, numServers) + var wg sync.WaitGroup + wg.Add(len(o.nameservers)) + go func() { + wg.Wait() + close(ch) + }() for _, server := range o.nameservers { go func(server string) { + defer wg.Done() answer, _, err := dnsClient.ExchangeContext(ctx, msg, server) ch <- &osResolverResult{answer: answer, err: err} }(server) diff --git a/resolver_test.go b/resolver_test.go new file mode 100644 index 0000000..a5a93c4 --- /dev/null +++ b/resolver_test.go @@ -0,0 +1,29 @@ +package ctrld + +import ( + "context" + "testing" + "time" + + "github.com/miekg/dns" +) + +func Test_osResolver_Resolve(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + defer cancel() + resolver := &osResolver{nameservers: []string{"127.0.0.127:5353"}} + m := new(dns.Msg) + m.SetQuestion("controld.com.", dns.TypeA) + m.RecursionDesired = true + _, _ = resolver.Resolve(context.Background(), m) + }() + + select { + case <-time.After(10 * time.Second): + t.Error("os resolver hangs") + case <-ctx.Done(): + } +}