diff --git a/resolver.go b/resolver.go index 6379203..b61aef0 100644 --- a/resolver.go +++ b/resolver.go @@ -290,7 +290,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error } } - Log(ctx, ProxyLogger.Load().Debug(), "os resolver query with nameservers: %v", nss) + Log(ctx, ProxyLogger.Load().Debug(), "os resolver query with nameservers: %v public: %v", nss, publicServers) // New check: If no resolvers are available, return an error. if numServers == 0 { @@ -343,11 +343,15 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error } // try local nameservers - do(nss, true) + if len(nss) > 0 { + do(nss, true) + } // we must always try the public servers too, since DCHP may have only public servers // this is okay to do since we always prefer LAN nameserver responses - do(publicServers, false) + if len(publicServers) > 0 { + do(publicServers, false) + } var ( nonSuccessAnswer *dns.Msg @@ -369,33 +373,49 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error case res.server == controldPublicDnsWithPort: controldSuccessAnswer = res.answer case !res.lan: + // if there are no LAN nameservers, we should not wait + // just use the first response + if len(nss) == 0 { + Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", res.server) + cancel() + logAnswer(res.server) + return res.answer, nil + } publicResponses = append(publicResponses, publicResponse{ answer: res.answer, server: res.server, }) } case res.answer != nil: - nonSuccessAnswer = res.answer - nonSuccessServer = res.server Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d", res.server, res.answer.Rcode) + // When there are no LAN nameservers, we should not wait + // for other nameservers to respond. + if len(nss) == 0 { + Log(ctx, ProxyLogger.Load().Debug(), "no lan nameservers using public non success answer") + cancel() + logAnswer(res.server) + return res.answer, nil + } + nonSuccessAnswer = res.answer + nonSuccessServer = res.server } errs = append(errs, res.err) } if len(publicResponses) > 0 { resp := publicResponses[0] - Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", resp.server) + Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", resp.server) logAnswer(resp.server) return resp.answer, nil } if controldSuccessAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort) + Log(ctx, ProxyLogger.Load().Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort) logAnswer(controldPublicDnsWithPort) return controldSuccessAnswer, nil } if nonSuccessAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s", nonSuccessServer) + Log(ctx, ProxyLogger.Load().Debug(), "using non-success answer from: %s", nonSuccessServer) logAnswer(nonSuccessServer) return nonSuccessAnswer, nil } diff --git a/resolver_test.go b/resolver_test.go index e96e875..fb6831b 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -70,41 +70,59 @@ func Test_osResolver_ResolveLanHostname(t *testing.T) { } func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { - ns := make([]string, 0, 2) - servers := make([]*dns.Server, 0, 2) - handlers := []dns.Handler{ + // Set up a LAN nameserver that returns a success response. + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") // 127.0.0.1 is considered LAN (loopback) + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, successHandler()) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + // Set up two public nameservers that return non-success responses. + publicHandlers := []dns.Handler{ nonSuccessHandlerWithRcode(dns.RcodeRefused), nonSuccessHandlerWithRcode(dns.RcodeNameError), - successHandler(), } - for i := range handlers { + var publicNS []string + var publicServers []*dns.Server + for _, handler := range publicHandlers { pc, err := net.ListenPacket("udp", ":0") if err != nil { - t.Fatalf("unexpected error: %v", err) + t.Fatalf("failed to listen on public address: %v", err) } - - s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i]) + s, addr, err := runLocalPacketConnTestServer(t, pc, handler) if err != nil { - t.Fatalf("unexpected error: %v", err) + t.Fatalf("failed to run public test server: %v", err) } - ns = append(ns, addr) - servers = append(servers, s) + publicNS = append(publicNS, addr) + publicServers = append(publicServers, s) } defer func() { - for _, server := range servers { - server.Shutdown() + for _, s := range publicServers { + s.Shutdown() } }() + + // We now create an osResolver which has both a LAN and public nameserver. resolver := &osResolver{} - resolver.publicServers.Store(&ns) + // Explicitly store the LAN nameserver. + resolver.lanServers.Store(&[]string{lanAddr}) + // And store the public nameservers. + resolver.publicServers.Store(&publicNS) + msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) answer, err := resolver.Resolve(context.Background(), msg) if err != nil { t.Fatal(err) } + + // Since a LAN nameserver is available and returns a success answer, we expect RcodeSuccess. if answer.Rcode != dns.RcodeSuccess { - t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode]) + t.Errorf("expected a success answer from LAN nameserver (RcodeSuccess) but got: %s", dns.RcodeToString[answer.Rcode]) } }