diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b4b44d4..5bd4d27 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ jobs: fail-fast: false matrix: os: ["windows-latest", "ubuntu-latest", "macOS-latest"] - go: ["1.23.x"] + go: ["1.24.x"] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v3 @@ -21,6 +21,6 @@ jobs: - run: "go test -race ./..." - uses: dominikh/staticcheck-action@v1.3.1 with: - version: "2024.1.1" + version: "2025.1" install-go: false cache-key: ${{ matrix.go }} diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 33012fa..f1aa445 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -84,13 +84,7 @@ type upstreamForResult struct { srcAddr string } -func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { - // Start network monitoring - if err := p.monitorNetworkChanges(mainCtx); err != nil { - mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring") - // Don't return here as we still want DNS service to run - } - +func (p *prog) serveDNS(listenerNum string) error { listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { @@ -1187,7 +1181,7 @@ func FlushDNSCache() error { } // monitorNetworkChanges starts monitoring for network interface changes -func (p *prog) monitorNetworkChanges(ctx context.Context) error { +func (p *prog) monitorNetworkChanges() error { mon, err := netmon.New(func(format string, args ...any) { // Always fetch the latest logger (and inject the prefix) mainLog.Load().Printf("netmon: "+format, args...) @@ -1406,9 +1400,6 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro return err } - msg := new(dns.Msg) - msg.SetQuestion(".", dns.TypeNS) - timeout := 1000 * time.Millisecond if uc.Timeout > 0 { timeout = time.Millisecond * time.Duration(uc.Timeout) @@ -1422,6 +1413,7 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream) start := time.Now() + msg := uc.VerifyMsg() _, err = resolver.Resolve(ctx, msg) duration := time.Since(start) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 90f403d..76f7c36 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -530,6 +530,15 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { go p.watchLinkState(ctx) } + if !reload { + go func() { + // Start network monitoring + if err := p.monitorNetworkChanges(); err != nil { + mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring") + } + }() + } + for listenerNum := range p.cfg.Listener { p.cfg.Listener[listenerNum].Init() if !reload { @@ -541,7 +550,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) - if err := p.serveDNS(ctx, listenerNum); err != nil { + if err := p.serveDNS(listenerNum); err != nil { mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) } mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr) diff --git a/config.go b/config.go index 96f6686..73484d7 100644 --- a/config.go +++ b/config.go @@ -358,6 +358,15 @@ func (uc *UpstreamConfig) Init() { } } +// VerifyMsg creates and returns a new DNS message could be used for testing upstream health. +func (uc *UpstreamConfig) VerifyMsg() *dns.Msg { + msg := new(dns.Msg) + msg.RecursionDesired = true + msg.SetQuestion(".", dns.TypeNS) + msg.SetEdns0(4096, false) // ensure handling of large DNS response + return msg +} + // VerifyDomain returns the domain name that could be resolved by the upstream endpoint. // It returns empty for non-ControlD upstream endpoint. func (uc *UpstreamConfig) VerifyDomain() string { diff --git a/internal/clientinfo/dhcp_lease_files.go b/internal/clientinfo/dhcp_lease_files.go index 1b5d829..34aabf3 100644 --- a/internal/clientinfo/dhcp_lease_files.go +++ b/internal/clientinfo/dhcp_lease_files.go @@ -16,4 +16,5 @@ var clientInfoFiles = map[string]ctrld.LeaseFileFormat{ "/var/dhcpd/var/db/dhcpd.leases": ctrld.IscDhcpd, // Pfsense "/home/pi/.router/run/dhcp/dnsmasq.leases": ctrld.Dnsmasq, // Firewalla "/var/lib/kea/dhcp4.leases": ctrld.KeaDHCP4, // Pfsense + "/var/db/dnsmasq.leases": ctrld.Dnsmasq, // OPNsense } diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index e009e01..a09d729 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -74,7 +74,6 @@ func (m *mdns) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() - //lint:ignore S1008 This is used for readable. if addr.IsLoopback() { // Continue searching if this is loopback address. return true } diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index 8e6b3f7..9a1d10c 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -104,7 +104,6 @@ func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() - //lint:ignore S1008 This is used for readable. if addr.IsLoopback() { // Continue searching if this is loopback address. return true } @@ -120,8 +119,7 @@ func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string { // is reachable, set p.serverDown to false, so p.lookupHostname can continue working. func (p *ptrDiscover) checkServer() { bo := backoff.NewBackoff("ptrDiscover", func(format string, args ...any) {}, time.Minute*5) - m := new(dns.Msg) - m.SetQuestion(".", dns.TypeNS) + m := (&ctrld.UpstreamConfig{}).VerifyMsg() ping := func() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() diff --git a/nameservers_windows.go b/nameservers_windows.go index eb4f2b5..7b16e8e 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -23,20 +23,17 @@ import ( ) const ( - maxDNSAdapterRetries = 5 - retryDelayDNSAdapter = 1 * time.Second - defaultDNSAdapterTimeout = 10 * time.Second - minDNSServers = 1 // Minimum number of DNS servers we want to find - NetSetupUnknown uint32 = 0 - NetSetupWorkgroup uint32 = 1 - NetSetupDomain uint32 = 2 - NetSetupCloudDomain uint32 = 3 - DS_FORCE_REDISCOVERY = 0x00000001 - DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010 - DS_BACKGROUND_ONLY = 0x00000100 - DS_IP_REQUIRED = 0x00000200 - DS_IS_DNS_NAME = 0x00020000 - DS_RETURN_DNS_NAME = 0x40000000 + maxDNSAdapterRetries = 5 + retryDelayDNSAdapter = 1 * time.Second + defaultDNSAdapterTimeout = 10 * time.Second + minDNSServers = 1 // Minimum number of DNS servers we want to find + + DS_FORCE_REDISCOVERY = 0x00000001 + DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010 + DS_BACKGROUND_ONLY = 0x00000100 + DS_IP_REQUIRED = 0x00000200 + DS_IS_DNS_NAME = 0x00020000 + DS_RETURN_DNS_NAME = 0x40000000 ) type DomainControllerInfo struct { @@ -158,7 +155,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { 0, // DomainGuid - not needed 0, // SiteName - not needed uintptr(flags), // Flags - uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output + uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output if ret != 0 { switch ret { @@ -343,27 +340,28 @@ func checkDomainJoined() bool { var domain *uint16 var status uint32 - err := windows.NetGetJoinInformation(nil, &domain, &status) - if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get domain join status: %v", err) + if err := windows.NetGetJoinInformation(nil, &domain, &status); err != nil { + Log(context.Background(), logger.Debug(), "Failed to get domain join status: %v", err) return false } defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) + // NETSETUP_JOIN_STATUS constants from Microsoft Windows API + // See: https://learn.microsoft.com/en-us/windows/win32/api/lmjoin/ne-lmjoin-netsetup_join_status + // + // NetSetupUnknownStatus uint32 = 0 // The status is unknown + // NetSetupUnjoined uint32 = 1 // The computer is not joined to a domain or workgroup + // NetSetupWorkgroupName uint32 = 2 // The computer is joined to a workgroup + // NetSetupDomainName uint32 = 3 // The computer is joined to a domain + // + // We only care about NetSetupDomainName. domainName := windows.UTF16PtrToString(domain) Log(context.Background(), logger.Debug(), - "Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", + "Domain join status: domain=%s status=%d (UnknownStatus=0, Unjoined=1, WorkgroupName=2, DomainName=3)", domainName, status) - // Consider domain or cloud domain as domain-joined - isDomain := status == NetSetupDomain || status == NetSetupCloudDomain - Log(context.Background(), logger.Debug(), - "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", - status, - status == NetSetupDomain, - status == NetSetupCloudDomain, - isDomain) + isDomain := status == syscall.NetSetupDomainName + Log(context.Background(), logger.Debug(), "Is domain joined? status=%d, result=%v", status, isDomain) return isDomain } diff --git a/resolver_test.go b/resolver_test.go index ebcad16..f030739 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -282,6 +282,35 @@ func Test_Edns0_CacheReply(t *testing.T) { } } +// https://github.com/Control-D-Inc/ctrld/issues/255 +func Test_legacyResolverWithBigExtraSection(t *testing.T) { + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") // 127.0.0.1 is considered LAN (loopback) + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, bigExtraSectionHandler()) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + uc := &UpstreamConfig{ + Name: "Legacy", + Type: ResolverTypeLegacy, + Endpoint: lanAddr, + } + uc.Init() + r, err := NewResolver(uc) + if err != nil { + t.Fatal(err) + } + + _, err = r.Resolve(context.Background(), uc.VerifyMsg()) + if err != nil { + t.Fatal(err) + } +} + func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string @@ -370,6 +399,68 @@ func countHandler(call *atomic.Int64) dns.HandlerFunc { } } +func mustRR(s string) dns.RR { + r, err := dns.NewRR(s) + if err != nil { + panic(err) + } + return r +} + +func bigExtraSectionHandler() dns.HandlerFunc { + return func(w dns.ResponseWriter, msg *dns.Msg) { + m := &dns.Msg{ + Answer: []dns.RR{ + mustRR(". 7149 IN NS m.root-servers.net."), + mustRR(". 7149 IN NS c.root-servers.net."), + mustRR(". 7149 IN NS e.root-servers.net."), + mustRR(". 7149 IN NS j.root-servers.net."), + mustRR(". 7149 IN NS g.root-servers.net."), + mustRR(". 7149 IN NS k.root-servers.net."), + mustRR(". 7149 IN NS l.root-servers.net."), + mustRR(". 7149 IN NS d.root-servers.net."), + mustRR(". 7149 IN NS h.root-servers.net."), + mustRR(". 7149 IN NS b.root-servers.net."), + mustRR(". 7149 IN NS a.root-servers.net."), + mustRR(". 7149 IN NS f.root-servers.net."), + mustRR(". 7149 IN NS i.root-servers.net."), + }, + Extra: []dns.RR{ + mustRR("m.root-servers.net. 656 IN A 202.12.27.33"), + mustRR("m.root-servers.net. 656 IN AAAA 2001:dc3::35"), + mustRR("c.root-servers.net. 656 IN A 192.33.4.12"), + mustRR("c.root-servers.net. 656 IN AAAA 2001:500:2::c"), + mustRR("e.root-servers.net. 656 IN A 192.203.230.10"), + mustRR("e.root-servers.net. 656 IN AAAA 2001:500:a8::e"), + mustRR("j.root-servers.net. 656 IN A 192.58.128.30"), + mustRR("j.root-servers.net. 656 IN AAAA 2001:503:c27::2:30"), + mustRR("g.root-servers.net. 656 IN A 192.112.36.4"), + mustRR("g.root-servers.net. 656 IN AAAA 2001:500:12::d0d"), + mustRR("k.root-servers.net. 656 IN A 193.0.14.129"), + mustRR("k.root-servers.net. 656 IN AAAA 2001:7fd::1"), + mustRR("l.root-servers.net. 656 IN A 199.7.83.42"), + mustRR("l.root-servers.net. 656 IN AAAA 2001:500:9f::42"), + mustRR("d.root-servers.net. 656 IN A 199.7.91.13"), + mustRR("d.root-servers.net. 656 IN AAAA 2001:500:2d::d"), + mustRR("h.root-servers.net. 656 IN A 198.97.190.53"), + mustRR("h.root-servers.net. 656 IN AAAA 2001:500:1::53"), + mustRR("b.root-servers.net. 656 IN A 170.247.170.2"), + mustRR("b.root-servers.net. 656 IN AAAA 2801:1b8:10::b"), + mustRR("a.root-servers.net. 656 IN A 198.41.0.4"), + mustRR("a.root-servers.net. 656 IN AAAA 2001:503:ba3e::2:30"), + mustRR("f.root-servers.net. 656 IN A 192.5.5.241"), + mustRR("f.root-servers.net. 656 IN AAAA 2001:500:2f::f"), + mustRR("i.root-servers.net. 656 IN A 192.36.148.17"), + mustRR("i.root-servers.net. 656 IN AAAA 2001:7fe::53"), + }, + } + + m.Compress = true + m.SetReply(msg) + w.WriteMsg(m) + } +} + func generateEdns0ClientCookie() string { cookie := make([]byte, 8) if _, err := rand.Read(cookie); err != nil {